# math-ai-system — src/streamlit_app.py
# Phase 2.5: Qdrant database setup + PDF upload (Streamlit app).
import streamlit as st
import os
import time
from qdrant_client import QdrantClient
from qdrant_client.models import Distance, VectorParams, PointStruct
from sentence_transformers import SentenceTransformer
import PyPDF2
import io
# ============================================================================
# CONFIGURATION
# ============================================================================
# Streamlit requires set_page_config() to be the first st.* call in a script.
st.set_page_config(
    page_title="Math AI - Phase 2.5: Database + PDF",
    page_icon="πŸ—„οΈ",
    layout="wide"
)
# Single Qdrant collection shared by every upload/search step below.
COLLECTION_NAME = "math_knowledge_base"
# ============================================================================
# CACHED FUNCTIONS
# ============================================================================
@st.cache_resource(show_spinner="πŸ”Œ Connecting to Qdrant...")
def get_qdrant_client():
    """Build and cache a QdrantClient from environment credentials.

    Reads QDRANT_URL and QDRANT_API_KEY; returns None when either is
    missing so callers can render a "not configured" state instead of
    crashing at import time.
    """
    url = os.getenv("QDRANT_URL")
    api_key = os.getenv("QDRANT_API_KEY")
    if url and api_key:
        return QdrantClient(url=url, api_key=api_key)
    return None
@st.cache_resource(show_spinner="πŸ€– Loading embedding model (30-60s first time)...")
def get_embedding_model():
    """Load and cache the MiniLM sentence encoder (384-dim vectors).

    Returns None (after surfacing the error in the UI) when the model
    cannot be downloaded/loaded, so the app keeps running in a degraded
    state rather than aborting.
    """
    try:
        return SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')
    except Exception as e:
        st.error(f"Failed to load model: {e}")
        return None
def get_vector_count_reliable(client, collection_name):
    """Count the points in a collection by paging through ``scroll()``.

    Scrolls id-only pages (payloads and vectors excluded for speed) and
    sums page sizes, which sidesteps eventually-consistent ``count()``
    results on some Qdrant deployments.

    Args:
        client: QdrantClient (or compatible object exposing ``scroll``).
        collection_name: name of the collection to count.

    Returns:
        int: number of points found, or 0 on any client/network error.
    """
    try:
        count = 0
        offset = None
        # Hard cap of 1000 pages x 100 points guards against a server that
        # never returns a None next-page offset.
        for _ in range(1000):
            result = client.scroll(
                collection_name=collection_name,
                limit=100,
                offset=offset,
                with_payload=False,
                with_vectors=False
            )
            # result is (points, next_offset); stop on an empty/absent page.
            if not result or not result[0]:
                break
            count += len(result[0])
            offset = result[1]
            if offset is None:
                break
        return count
    except Exception:
        # Narrowed from a bare `except:` so KeyboardInterrupt/SystemExit are
        # no longer swallowed; any client error still yields a safe 0.
        return 0
def check_collection_exists(client, collection_name):
    """Return True iff ``collection_name`` exists on the Qdrant server.

    Any client/network failure is reported as False so the caller can use
    this during startup without extra error handling.

    Args:
        client: QdrantClient (or compatible) exposing ``get_collections``.
        collection_name: collection name to look for.
    """
    try:
        collections = client.get_collections().collections
    except Exception:
        # Narrowed from a bare `except:` — don't swallow SystemExit etc.
        return False
    return any(c.name == collection_name for c in collections)
def extract_text_from_pdf(pdf_file):
    """Extract plain text from an uploaded PDF, one section per page.

    Each page's text is prefixed with a "--- Page N ---" marker so chunk
    provenance survives the downstream word-window splitting.

    Args:
        pdf_file: binary file-like object (e.g. a Streamlit UploadedFile).

    Returns:
        str: concatenated page text, or None if extraction fails (the
        error is shown in the UI via ``st.error``).
    """
    try:
        pdf_reader = PyPDF2.PdfReader(pdf_file)
        # str.join avoids the quadratic cost of repeated `+=` concatenation
        # that the original used; output is byte-identical.
        return "".join(
            f"\n\n--- Page {page_num + 1} ---\n\n{page.extract_text()}"
            for page_num, page in enumerate(pdf_reader.pages)
        )
    except Exception as e:
        st.error(f"PDF extraction error: {str(e)}")
        return None
# ============================================================================
# SESSION STATE
# ============================================================================
# Flags persist across Streamlit reruns within a single browser session.
if 'db_created' not in st.session_state:
    # True once the Qdrant collection is known to exist.
    st.session_state.db_created = False
if 'embedder_ready' not in st.session_state:
    # True once the SentenceTransformer model loaded successfully.
    st.session_state.embedder_ready = False
if 'show_step' not in st.session_state:
    # Which page sections to render: 'all', 'upload', or 'search'.
    st.session_state.show_step = 'all'
# ============================================================================
# MAIN APP
# ============================================================================
st.title("πŸ—„οΈ Phase 2.5: Database Setup + PDF Upload")
# Both are cached resources (st.cache_resource), so reruns reuse them.
client = get_qdrant_client()
embedder = get_embedding_model()
# ============================================================================
# SIDEBAR
# ============================================================================
with st.sidebar:
    st.header("⚑ Quick Navigation")
    # Navigation buttons only set show_step; the main page renders its
    # sections conditionally via show_all/show_upload/show_search below.
    if st.button("πŸ“‹ Show All Steps", use_container_width=True):
        st.session_state.show_step = 'all'
    if st.button("πŸš€ Skip to Upload", use_container_width=True):
        st.session_state.show_step = 'upload'
    if st.button("πŸ” Skip to Search", use_container_width=True):
        st.session_state.show_step = 'search'
    st.markdown("---")
    st.subheader("πŸ“Š System Status")
    # Re-probe the server every rerun; also refreshes the db_created flag.
    if client and check_collection_exists(client, COLLECTION_NAME):
        st.success("βœ… Database Ready")
        st.session_state.db_created = True
    else:
        st.warning("⚠️ Database Not Ready")
    if embedder:
        st.success("βœ… Model Loaded")
        st.session_state.embedder_ready = True
    else:
        st.warning("⚠️ Model Not Loaded")
    if client and st.session_state.db_created:
        # NOTE(review): this pages through every point on each rerun — can
        # get slow as the collection grows (see get_vector_count_reliable).
        count = get_vector_count_reliable(client, COLLECTION_NAME)
        st.metric("Vectors in DB", f"{count:,}")
# Section-visibility flags derived from the sidebar navigation choice.
show_all = st.session_state.show_step == 'all'
show_upload = st.session_state.show_step in ['all', 'upload']
show_search = st.session_state.show_step in ['all', 'search']
# ============================================================================
# STEP 1-2: Quick Status
# ============================================================================
if show_all:
    st.header("Step 1-2: System Check")
    col1, col2, col3 = st.columns(3)
    with col1:
        # ANTHROPIC_API_KEY is only checked for presence; it is not used
        # anywhere else in this file.
        st.metric("Claude API", "βœ…" if os.getenv("ANTHROPIC_API_KEY") else "❌")
    with col2:
        st.metric("Qdrant", "βœ… Connected" if client else "❌")
    with col3:
        st.metric("Embedder", "βœ… Cached" if embedder else "❌")
    if not client:
        # Nothing below can work without Qdrant; halt the script run.
        # NOTE(review): this guard only executes when show_all is True —
        # with show_step set to 'upload'/'search' a missing client falls
        # through to the later sections. Consider hoisting the check.
        st.error("⚠️ Check Qdrant secrets!")
        st.stop()
    st.markdown("---")
# ============================================================================
# STEP 3: Collection Management
# ============================================================================
if show_all:
    st.header("πŸ—οΈ Step 3: Database Collection")
    if st.session_state.db_created:
        st.success(f"βœ… Collection '{COLLECTION_NAME}' ready!")
        col1, col2 = st.columns(2)
        with col1:
            if st.button("πŸ”„ Recreate Collection"):
                # Deletes the collection and clears the flag; after the
                # rerun the CREATE COLLECTION branch below re-creates it.
                try:
                    client.delete_collection(COLLECTION_NAME)
                    st.session_state.db_created = False
                    st.rerun()
                except Exception as e:
                    st.error(f"Error: {e}")
        with col2:
            if st.button("ℹ️ Collection Info"):
                count = get_vector_count_reliable(client, COLLECTION_NAME)
                st.json({"name": COLLECTION_NAME, "vectors": count, "status": "Ready"})
    else:
        if st.button("πŸ—οΈ CREATE COLLECTION", type="primary"):
            try:
                # size=384 matches all-MiniLM-L6-v2 embeddings; cosine
                # distance suits normalized sentence embeddings.
                client.create_collection(
                    collection_name=COLLECTION_NAME,
                    vectors_config=VectorParams(size=384, distance=Distance.COSINE)
                )
                st.success(f"πŸŽ‰ Created: {COLLECTION_NAME}")
                st.session_state.db_created = True
                st.rerun()
            except Exception as e:
                st.error(f"❌ Failed: {str(e)}")
    st.markdown("---")
# ============================================================================
# STEP 4: Embedding Model
# ============================================================================
if show_all:
    st.header("πŸ€– Step 4: Embedding Model")
    # The model was already loaded at the top of the script through
    # st.cache_resource; this section only reports whether that succeeded.
    if embedder:
        st.success("βœ… Model loaded and cached!")
        st.session_state.embedder_ready = True
    else:
        st.warning("⚠️ Model loading failed. Refresh page.")
    st.markdown("---")
# ============================================================================
# STEP 5A: Upload Custom Text
# ============================================================================
if show_upload:
    st.header("πŸ“ Step 5A: Upload Custom Notes")
    if not st.session_state.db_created or not st.session_state.embedder_ready:
        st.error("⚠️ Complete Steps 3 & 4 first")
    else:
        # Choose upload method
        upload_method = st.radio(
            "Upload method:",
            ["πŸ“ Paste Text", "πŸ“„ Upload PDF File"],
            horizontal=True
        )
        if upload_method == "πŸ“ Paste Text":
            with st.expander("✍️ Paste text", expanded=True):
                custom_text = st.text_area(
                    "Math notes:",
                    value="""Linear Equations: ax + b = 0, solution is x = -b/a
Quadratic Equations: axΒ² + bx + c = 0
Solution: x = (-b ± √(b²-4ac)) / 2a
Pythagorean Theorem: aΒ² + bΒ² = cΒ²
Derivatives:
d/dx(xⁿ) = nxⁿ⁻¹
d/dx(sin x) = cos x""",
                    height=200
                )
            source_name = st.text_input("Source name:", value="math_notes.txt")
            if st.button("πŸš€ UPLOAD TEXT", type="primary"):
                if not custom_text.strip():
                    st.error("Please enter text!")
                else:
                    try:
                        progress = st.progress(0)
                        status = st.empty()
                        status.text("πŸ“„ Chunking text...")
                        progress.progress(0.2)
                        # Word-window chunking: 50-word chunks advancing by
                        # 40 words, i.e. a 10-word overlap between chunks.
                        words = custom_text.split()
                        chunks = []
                        chunk_size = 50
                        for i in range(0, len(words), 40):
                            chunk = ' '.join(words[i:i + chunk_size])
                            if chunk.strip():
                                chunks.append(chunk)
                        st.write(f"βœ… Created {len(chunks)} chunks")
                        status.text("πŸ”’ Generating embeddings...")
                        progress.progress(0.5)
                        # Batch-encode all chunks in one call.
                        embeddings = embedder.encode(chunks, show_progress_bar=False)
                        st.write(f"βœ… Generated {len(embeddings)} embeddings")
                        status.text("☁️ Uploading...")
                        progress.progress(0.8)
                        points = []
                        for idx, (chunk, embedding) in enumerate(zip(chunks, embeddings)):
                            points.append(PointStruct(
                                # Pseudo-unique id from hash of source, index,
                                # text prefix, and wall clock, folded into the
                                # unsigned-63-bit range. NOTE(review): str
                                # hash() is salted per process, so ids are not
                                # reproducible across runs — duplicates are
                                # never deduplicated on re-upload.
                                id=abs(hash(f"{source_name}_{idx}_{custom_text[:50]}_{time.time()}")) % (2**63),
                                vector=embedding.tolist(),
                                payload={
                                    "content": chunk,
                                    "source_name": source_name,
                                    "source_type": "custom_notes",
                                    "chunk_index": idx
                                }
                            ))
                        client.upsert(collection_name=COLLECTION_NAME, points=points)
                        progress.progress(1.0)
                        status.empty()
                        st.success(f"πŸŽ‰ Uploaded {len(points)} vectors!")
                        count = get_vector_count_reliable(client, COLLECTION_NAME)
                        st.info(f"πŸ“Š **Total vectors: {count:,}**")
                    except Exception as e:
                        st.error(f"❌ Failed: {str(e)}")
                        st.exception(e)
        else:  # PDF Upload
            with st.expander("πŸ“„ Upload PDF", expanded=True):
                st.info("πŸŽ‰ **NEW!** Upload your math PDFs directly")
                uploaded_file = st.file_uploader(
                    "Choose PDF file:",
                    type=['pdf'],
                    help="Upload a PDF with math content"
                )
            if uploaded_file:
                st.write(f"πŸ“„ File: {uploaded_file.name} ({uploaded_file.size / 1024:.1f} KB)")
                source_name = st.text_input(
                    "Source name:",
                    value=uploaded_file.name.replace('.pdf', '')
                )
                if st.button("πŸš€ UPLOAD PDF", type="primary"):
                    try:
                        progress = st.progress(0)
                        status = st.empty()
                        # Extract text
                        status.text("πŸ“– Extracting text from PDF...")
                        progress.progress(0.1)
                        extracted_text = extract_text_from_pdf(uploaded_file)
                        if not extracted_text:
                            # NOTE(review): st.stop() works by raising an
                            # exception; if that type derives from Exception
                            # the enclosing `except Exception` below will
                            # catch it and *also* show "Upload failed" —
                            # verify against the pinned Streamlit version and
                            # consider an if/else instead.
                            st.error("❌ Failed to extract text from PDF")
                            st.stop()
                        st.write(f"βœ… Extracted {len(extracted_text)} characters")
                        # Show preview
                        with st.expander("πŸ‘οΈ Preview extracted text"):
                            st.text(extracted_text[:500] + "..." if len(extracted_text) > 500 else extracted_text)
                        # Chunk
                        status.text("πŸ“„ Chunking text...")
                        progress.progress(0.3)
                        # 100-word windows with a 20-word overlap (step 80).
                        words = extracted_text.split()
                        chunks = []
                        chunk_size = 100  # Larger chunks for PDFs
                        overlap = 20
                        for i in range(0, len(words), chunk_size - overlap):
                            chunk = ' '.join(words[i:i + chunk_size])
                            if chunk.strip():
                                chunks.append(chunk)
                        st.write(f"βœ… Created {len(chunks)} chunks")
                        # Embed
                        status.text("πŸ”’ Generating embeddings...")
                        progress.progress(0.5)
                        # One encode() call per chunk keeps the progress bar
                        # responsive (a single batch call would be faster).
                        embeddings = []
                        for idx, chunk in enumerate(chunks):
                            embedding = embedder.encode(chunk)
                            embeddings.append(embedding)
                            if idx % 20 == 0:
                                progress.progress(0.5 + (0.3 * idx / len(chunks)))
                        st.write(f"βœ… Generated {len(embeddings)} embeddings")
                        # Upload
                        status.text("☁️ Uploading to database...")
                        progress.progress(0.9)
                        points = []
                        for idx, (chunk, embedding) in enumerate(zip(chunks, embeddings)):
                            points.append(PointStruct(
                                # Same salted-hash id scheme as the text
                                # branch above — not reproducible across runs.
                                id=abs(hash(f"pdf_{source_name}_{idx}_{time.time()}")) % (2**63),
                                vector=embedding.tolist(),
                                payload={
                                    "content": chunk,
                                    "source_name": source_name,
                                    "source_type": "pdf_upload",
                                    "chunk_index": idx,
                                    "file_name": uploaded_file.name
                                }
                            ))
                        client.upsert(collection_name=COLLECTION_NAME, points=points)
                        progress.progress(1.0)
                        status.empty()
                        st.success(f"πŸŽ‰ Uploaded {len(points)} vectors from PDF!")
                        st.balloons()
                        count = get_vector_count_reliable(client, COLLECTION_NAME)
                        st.info(f"πŸ“Š **Total vectors: {count:,}**")
                    except Exception as e:
                        st.error(f"❌ Upload failed: {str(e)}")
                        st.exception(e)
    st.markdown("---")
# ============================================================================
# STEP 5B: Load Public Datasets (FIXED - No DeepMind)
# ============================================================================
if show_upload:
    st.header("πŸ“š Step 5B: Load Public Datasets")
    if not st.session_state.db_created or not st.session_state.embedder_ready:
        st.error("⚠️ Complete Steps 3 & 4 first")
    else:
        with st.expander("πŸ“Š Load from Hugging Face", expanded=False):
            dataset_choice = st.selectbox(
                "Dataset:",
                [
                    "GSM8K - Grade School Math (8.5K problems)",
                    "MATH - Competition Math (12.5K problems) ✨",
                    "MathQA - Math Word Problems (37K problems) πŸ†•",
                    "CAMEL-AI Math - GPT-4 Generated (50K problems)",
                    "RACE - Reading Comprehension (28K passages)"
                ]
            )
            # INCREASED LIMIT FROM 500 TO 2000!
            sample_size = st.slider("Items to load:", 10, 2000, 50)
            st.warning(f"⚠️ Loading {sample_size} items. Large numbers take 5-15 minutes!")
            if st.button("πŸ“₯ LOAD DATASET", type="primary"):
                try:
                    # Lazy import so the app still starts when the optional
                    # `datasets` package is missing (handled by ImportError
                    # below).
                    from datasets import load_dataset
                    progress = st.progress(0)
                    status = st.empty()
                    # Each branch downloads one dataset and normalizes its
                    # rows into plain-text entries in `texts`.
                    # GSM8K
                    if "GSM8K" in dataset_choice:
                        status.text("πŸ“₯ Downloading GSM8K...")
                        progress.progress(0.1)
                        dataset = load_dataset("openai/gsm8k", "main", split="train", trust_remote_code=True)
                        dataset_name = "GSM8K"
                        texts = []
                        for i in range(min(sample_size, len(dataset))):
                            item = dataset[i]
                            text = f"Problem: {item['question']}\n\nSolution: {item['answer']}"
                            texts.append(text)
                    # MATH
                    elif "MATH" in dataset_choice and "Competition" in dataset_choice:
                        status.text("πŸ“₯ Downloading MATH...")
                        progress.progress(0.1)
                        dataset = None
                        dataset_name = "MATH"
                        # Try multiple sources
                        for source in ["lighteval/MATH", "DigitalLearningGmbH/MATH-lighteval", "EleutherAI/hendrycks_math"]:
                            try:
                                dataset = load_dataset(source, split="train", trust_remote_code=True)
                                st.success(f"βœ… Using {source}")
                                break
                            except:
                                # NOTE(review): bare except also swallows
                                # KeyboardInterrupt; prefer `except Exception`.
                                continue
                        if dataset is None:
                            st.error("❌ All MATH sources failed")
                            st.stop()
                        texts = []
                        for i in range(min(sample_size, len(dataset))):
                            item = dataset[i]
                            # Mirrors use different field names, so fall back
                            # between problem/question and solution/answer.
                            problem = item.get('problem', item.get('question', ''))
                            solution = item.get('solution', item.get('answer', ''))
                            problem_type = item.get('type', item.get('level', 'general'))
                            text = f"Problem ({problem_type}): {problem}\n\nSolution: {solution}"
                            texts.append(text)
                    # MathQA (REPLACES DEEPMIND)
                    elif "MathQA" in dataset_choice:
                        status.text("πŸ“₯ Downloading MathQA...")
                        progress.progress(0.1)
                        st.info("πŸ†• MathQA: 37K math word problems with detailed solutions")
                        dataset = load_dataset("allenai/math_qa", split="train", trust_remote_code=True)
                        dataset_name = "MathQA"
                        texts = []
                        for i in range(min(sample_size, len(dataset))):
                            item = dataset[i]
                            text = f"Problem: {item['Problem']}\n\nRationale: {item['Rationale']}\n\nAnswer: {item['correct']}"
                            texts.append(text)
                    # CAMEL-AI
                    elif "CAMEL" in dataset_choice:
                        status.text("πŸ“₯ Downloading CAMEL-AI...")
                        progress.progress(0.1)
                        dataset = load_dataset("camel-ai/math", split="train", trust_remote_code=True)
                        dataset_name = "CAMEL-Math"
                        texts = []
                        for i in range(min(sample_size, len(dataset))):
                            item = dataset[i]
                            text = f"Problem: {item['message']}"
                            texts.append(text)
                    # RACE
                    else:
                        status.text("πŸ“₯ Downloading RACE...")
                        progress.progress(0.1)
                        dataset = load_dataset("ehovy/race", "all", split="train", trust_remote_code=True)
                        dataset_name = "RACE"
                        texts = []
                        for i in range(min(sample_size, len(dataset))):
                            item = dataset[i]
                            # Articles truncated to 500 chars to bound size.
                            text = f"Article: {item['article'][:500]}\n\nQuestion: {item['question']}\n\nAnswer: {item['answer']}"
                            texts.append(text)
                    # Common processing
                    st.write(f"βœ… Loaded {len(texts)} items from {dataset_name}")
                    progress.progress(0.3)
                    status.text("πŸ”’ Generating embeddings...")
                    # Per-item encode keeps progress updates flowing; a batch
                    # call would be faster but silent.
                    embeddings = []
                    for idx, text in enumerate(texts):
                        embedding = embedder.encode(text)
                        embeddings.append(embedding)
                        if idx % 50 == 0:
                            progress.progress(0.3 + (0.5 * idx / len(texts)))
                            status.text(f"πŸ”’ Embedding {idx+1}/{len(texts)}")
                    st.write(f"βœ… Generated {len(embeddings)} embeddings")
                    progress.progress(0.8)
                    status.text("☁️ Uploading...")
                    points = []
                    for idx, (text, embedding) in enumerate(zip(texts, embeddings)):
                        # Cap stored payload text at 2000 characters.
                        content = text[:2000] if len(text) > 2000 else text
                        points.append(PointStruct(
                            # Salted-hash id scheme — not reproducible across
                            # runs (see Step 5A note).
                            id=abs(hash(f"{dataset_name}_{idx}_{time.time()}")) % (2**63),
                            vector=embedding.tolist(),
                            payload={
                                "content": content,
                                "source_name": dataset_name,
                                "source_type": "public_dataset",
                                "dataset": dataset_name,
                                "index": idx
                            }
                        ))
                    client.upsert(collection_name=COLLECTION_NAME, points=points)
                    progress.progress(1.0)
                    status.empty()
                    st.success(f"πŸŽ‰ Uploaded {len(points)} vectors from {dataset_name}!")
                    count = get_vector_count_reliable(client, COLLECTION_NAME)
                    st.info(f"πŸ“Š **Total vectors: {count:,}**")
                except ImportError:
                    st.error("❌ Add 'datasets' to requirements.txt")
                except Exception as e:
                    st.error(f"❌ Failed: {str(e)}")
                    st.exception(e)
    st.markdown("---")
# ============================================================================
# STEP 6: Search
# ============================================================================
if show_search:
    st.header("πŸ” Step 6: Test Search")
    if not st.session_state.db_created or not st.session_state.embedder_ready:
        st.error("⚠️ Database and embedder must be ready")
    else:
        search_query = st.text_input(
            "Question:",
            placeholder="Solve xΒ² + 5x - 4 = 0"
        )
        col1, col2 = st.columns([3, 1])
        with col1:
            top_k = st.slider("Results:", 1, 10, 5)
        with col2:
            st.metric("DB Vectors", get_vector_count_reliable(client, COLLECTION_NAME))
        if st.button("πŸ” SEARCH", type="primary") and search_query:
            try:
                with st.spinner("Searching..."):
                    # Embed the query with the same model used at upload
                    # time so the vectors are comparable.
                    query_embedding = embedder.encode(search_query)
                    # NOTE(review): client.search() is deprecated in newer
                    # qdrant-client releases in favor of query_points() —
                    # confirm against the pinned client version.
                    results = client.search(
                        collection_name=COLLECTION_NAME,
                        query_vector=query_embedding.tolist(),
                        limit=top_k
                    )
                if results:
                    st.success(f"βœ… Found {len(results)} results!")
                    for i, result in enumerate(results, 1):
                        similarity_pct = result.score * 100
                        # Traffic-light badge by cosine-similarity percent.
                        if similarity_pct > 50:
                            color = "🟒"
                        elif similarity_pct > 30:
                            color = "🟑"
                        else:
                            color = "πŸ”΄"
                        # Auto-expand only the top two hits.
                        with st.expander(f"{color} Result {i} - {similarity_pct:.1f}% match", expanded=(i<=2)):
                            st.info(result.payload['content'])
                            col1, col2, col3 = st.columns(3)
                            with col1:
                                st.caption(f"**Source:** {result.payload['source_name']}")
                            with col2:
                                st.caption(f"**Type:** {result.payload['source_type']}")
                            with col3:
                                st.caption(f"**Score:** {result.score:.4f}")
                else:
                    st.warning("No results found!")
            except Exception as e:
                st.error(f"❌ Search failed: {str(e)}")
# Footer rendered on every view regardless of show_step.
st.markdown("---")
st.success("πŸŽ‰ Phase 2.5 Complete! You now have: Text, PDF upload, and 4 working datasets!")