# Math AI — Phase 2.5: Streamlit app for Qdrant database setup, PDF/text upload,
# public-dataset ingestion, and semantic search testing.
"""Math AI — Phase 2.5: Qdrant vector database setup plus PDF/text ingestion."""
import streamlit as st
import os
import time
from qdrant_client import QdrantClient
from qdrant_client.models import Distance, VectorParams, PointStruct
from sentence_transformers import SentenceTransformer
import PyPDF2
import io

# ============================================================================
# CONFIGURATION
# ============================================================================
st.set_page_config(
    page_title="Math AI - Phase 2.5: Database + PDF",
    page_icon="ποΈ",
    layout="wide",
)

# Single Qdrant collection that stores every embedded chunk.
COLLECTION_NAME = "math_knowledge_base"

# ============================================================================
# CACHED FUNCTIONS
# ============================================================================
@st.cache_resource
def get_qdrant_client():
    """Return a shared Qdrant client, or None if credentials are missing.

    Reads QDRANT_URL / QDRANT_API_KEY from the environment. Decorated with
    st.cache_resource so the connection is created once per server process —
    the section is titled "CACHED FUNCTIONS" but previously re-created the
    client on every Streamlit rerun.

    Returns:
        QdrantClient if both env vars are set, otherwise None (callers treat
        None as "database not configured").
    """
    qdrant_url = os.getenv("QDRANT_URL")
    qdrant_api_key = os.getenv("QDRANT_API_KEY")
    if not qdrant_url or not qdrant_api_key:
        return None
    return QdrantClient(url=qdrant_url, api_key=qdrant_api_key)
@st.cache_resource
def get_embedding_model():
    """Load and cache the sentence-embedding model (384-dim MiniLM).

    Decorated with st.cache_resource so the model is downloaded/loaded only
    once per server process; previously every Streamlit rerun reloaded it,
    which is the slowest operation in this app.

    Returns:
        SentenceTransformer instance, or None if loading failed (the error
        is surfaced in the UI via st.error).
    """
    try:
        model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')
        return model
    except Exception as e:
        st.error(f"Failed to load model: {e}")
        return None
def get_vector_count_reliable(client, collection_name):
    """Count points in a collection by paging through client.scroll().

    Walks the scroll cursor in batches of 100 instead of trusting the
    collection-info counters. `max_iterations` caps the walk at 100,000
    points so a misbehaving cursor can never loop forever.

    Args:
        client: Qdrant client (anything exposing a compatible ``scroll``).
        collection_name: name of the collection to count.

    Returns:
        Number of points counted; 0 on any error (the count is display-only,
        so failures degrade gracefully instead of crashing the page).
    """
    try:
        count = 0
        offset = None
        max_iterations = 1000  # safety cap: at most 100 * 1000 points
        for _ in range(max_iterations):
            result = client.scroll(
                collection_name=collection_name,
                limit=100,
                offset=offset,
                with_payload=False,
                with_vectors=False,
            )
            # scroll() returns (points, next_offset); stop on an empty page.
            if result is None or result[0] is None or len(result[0]) == 0:
                break
            count += len(result[0])
            offset = result[1]
            if offset is None:  # server signals there are no more pages
                break
        return count
    except Exception:
        # Narrowed from a bare `except:` so KeyboardInterrupt/SystemExit are
        # no longer swallowed; any client failure simply reports 0.
        return 0
def check_collection_exists(client, collection_name):
    """Return True if `collection_name` exists on the Qdrant server.

    Any client/network failure is treated as "does not exist" so callers can
    render a not-ready status instead of crashing the page.

    Args:
        client: Qdrant client (anything exposing ``get_collections``).
        collection_name: collection name to look up.
    """
    try:
        collections = client.get_collections().collections
        return any(c.name == collection_name for c in collections)
    except Exception:
        # Narrowed from a bare `except:` — only real errors map to False,
        # not interpreter-level signals like KeyboardInterrupt.
        return False
def extract_text_from_pdf(pdf_file):
    """Extract text from PDF file"""
    try:
        reader = PyPDF2.PdfReader(pdf_file)
        # Collect one banner-prefixed segment per page, then join once.
        segments = []
        for index, page in enumerate(reader.pages):
            segments.append(f"\n\n--- Page {index + 1} ---\n\n{page.extract_text()}")
        return "".join(segments)
    except Exception as e:
        # Surface the failure in the UI; None tells the caller to abort.
        st.error(f"PDF extraction error: {str(e)}")
        return None
# ============================================================================
# SESSION STATE
# ============================================================================
# Seed every per-session flag exactly once; reruns keep existing values.
for _key, _default in (
    ('db_created', False),
    ('embedder_ready', False),
    ('show_step', 'all'),
):
    if _key not in st.session_state:
        st.session_state[_key] = _default

# ============================================================================
# MAIN APP
# ============================================================================
st.title("ποΈ Phase 2.5: Database Setup + PDF Upload")
client = get_qdrant_client()
embedder = get_embedding_model()
# ============================================================================
# SIDEBAR
# ============================================================================
with st.sidebar:
    st.header("β‘ Quick Navigation")
    # Navigation buttons simply select which page sections render below.
    for _label, _target in (
        ("π Show All Steps", 'all'),
        ("π Skip to Upload", 'upload'),
        ("π Skip to Search", 'search'),
    ):
        if st.button(_label, use_container_width=True):
            st.session_state.show_step = _target
    st.markdown("---")
    st.subheader("π System Status")
    if client and check_collection_exists(client, COLLECTION_NAME):
        st.success("β Database Ready")
        st.session_state.db_created = True
    else:
        st.warning("β οΈ Database Not Ready")
    if embedder:
        st.success("β Model Loaded")
        st.session_state.embedder_ready = True
    else:
        st.warning("β οΈ Model Not Loaded")
    if client and st.session_state.db_created:
        count = get_vector_count_reliable(client, COLLECTION_NAME)
        st.metric("Vectors in DB", f"{count:,}")

# Visibility flags for each page section, driven by the sidebar navigation.
show_all = st.session_state.show_step == 'all'
show_upload = st.session_state.show_step in ['all', 'upload']
show_search = st.session_state.show_step in ['all', 'search']
# ============================================================================
# STEP 1-2: Quick Status
# ============================================================================
if show_all:
    st.header("Step 1-2: System Check")
    # One metric per dependency, rendered side by side.
    checks = (
        ("Claude API", "β " if os.getenv("ANTHROPIC_API_KEY") else "β"),
        ("Qdrant", "β Connected" if client else "β"),
        ("Embedder", "β Cached" if embedder else "β"),
    )
    for column, (label, value) in zip(st.columns(3), checks):
        with column:
            st.metric(label, value)
    # Without a database connection nothing below can work.
    if not client:
        st.error("β οΈ Check Qdrant secrets!")
        st.stop()
    st.markdown("---")
# ============================================================================
# STEP 3: Collection Management
# ============================================================================
if show_all:
    st.header("ποΈ Step 3: Database Collection")
    if not st.session_state.db_created:
        # No collection yet: offer to create it.
        if st.button("ποΈ CREATE COLLECTION", type="primary"):
            try:
                # 384 dims matches all-MiniLM-L6-v2; cosine suits text similarity.
                client.create_collection(
                    collection_name=COLLECTION_NAME,
                    vectors_config=VectorParams(size=384, distance=Distance.COSINE),
                )
                st.success(f"π Created: {COLLECTION_NAME}")
                st.session_state.db_created = True
                st.rerun()
            except Exception as e:
                st.error(f"β Failed: {str(e)}")
    else:
        # Collection exists: offer recreate / inspect actions.
        st.success(f"β Collection '{COLLECTION_NAME}' ready!")
        left, right = st.columns(2)
        with left:
            if st.button("π Recreate Collection"):
                try:
                    client.delete_collection(COLLECTION_NAME)
                    st.session_state.db_created = False
                    st.rerun()
                except Exception as e:
                    st.error(f"Error: {e}")
        with right:
            if st.button("βΉοΈ Collection Info"):
                vector_total = get_vector_count_reliable(client, COLLECTION_NAME)
                st.json({"name": COLLECTION_NAME, "vectors": vector_total, "status": "Ready"})
    st.markdown("---")
# ============================================================================
# STEP 4: Embedding Model
# ============================================================================
if show_all:
    st.header("π€ Step 4: Embedding Model")
    if embedder is None:
        st.warning("β οΈ Model loading failed. Refresh page.")
    else:
        st.success("β Model loaded and cached!")
        st.session_state.embedder_ready = True
    st.markdown("---")
# ============================================================================
# STEP 5A: Upload Custom Text
# ============================================================================
if show_upload:
    st.header("π Step 5A: Upload Custom Notes")
    if not st.session_state.db_created or not st.session_state.embedder_ready:
        st.error("β οΈ Complete Steps 3 & 4 first")
    else:
        # Choose upload method
        upload_method = st.radio(
            "Upload method:",
            ["π Paste Text", "π Upload PDF File"],
            horizontal=True
        )
        if upload_method == "π Paste Text":
            with st.expander("βοΈ Paste text", expanded=True):
                custom_text = st.text_area(
                    "Math notes:",
                    value="""Linear Equations: ax + b = 0, solution is x = -b/a
Quadratic Equations: axΒ² + bx + c = 0
Solution: x = (-b Β± β(bΒ²-4ac)) / 2a
Pythagorean Theorem: aΒ² + bΒ² = cΒ²
Derivatives:
d/dx(xβΏ) = nxβΏβ»ΒΉ
d/dx(sin x) = cos x""",
                    height=200
                )
                source_name = st.text_input("Source name:", value="math_notes.txt")
                if st.button("π UPLOAD TEXT", type="primary"):
                    if not custom_text.strip():
                        st.error("Please enter text!")
                    else:
                        try:
                            progress = st.progress(0)
                            status = st.empty()
                            status.text("π Chunking text...")
                            progress.progress(0.2)
                            # Sliding-window word chunks. BUG FIX: the step was a
                            # hard-coded 40 while chunk_size was 50; the overlap is
                            # now explicit (step = chunk_size - overlap = 40),
                            # matching the PDF path below. Behavior is unchanged.
                            words = custom_text.split()
                            chunks = []
                            chunk_size = 50
                            overlap = 10
                            for i in range(0, len(words), chunk_size - overlap):
                                chunk = ' '.join(words[i:i + chunk_size])
                                if chunk.strip():
                                    chunks.append(chunk)
                            st.write(f"β Created {len(chunks)} chunks")
                            status.text("π’ Generating embeddings...")
                            progress.progress(0.5)
                            embeddings = embedder.encode(chunks, show_progress_bar=False)
                            st.write(f"β Generated {len(embeddings)} embeddings")
                            status.text("βοΈ Uploading...")
                            progress.progress(0.8)
                            points = []
                            for idx, (chunk, embedding) in enumerate(zip(chunks, embeddings)):
                                points.append(PointStruct(
                                    # Pseudo-unique 63-bit id; time.time() keeps repeated
                                    # uploads of the same text from colliding.
                                    id=abs(hash(f"{source_name}_{idx}_{custom_text[:50]}_{time.time()}")) % (2**63),
                                    vector=embedding.tolist(),
                                    payload={
                                        "content": chunk,
                                        "source_name": source_name,
                                        "source_type": "custom_notes",
                                        "chunk_index": idx
                                    }
                                ))
                            client.upsert(collection_name=COLLECTION_NAME, points=points)
                            progress.progress(1.0)
                            status.empty()
                            st.success(f"π Uploaded {len(points)} vectors!")
                            count = get_vector_count_reliable(client, COLLECTION_NAME)
                            st.info(f"π **Total vectors: {count:,}**")
                        except Exception as e:
                            st.error(f"β Failed: {str(e)}")
                            st.exception(e)
        else:  # PDF Upload
            with st.expander("π Upload PDF", expanded=True):
                st.info("π **NEW!** Upload your math PDFs directly")
                uploaded_file = st.file_uploader(
                    "Choose PDF file:",
                    type=['pdf'],
                    help="Upload a PDF with math content"
                )
                if uploaded_file:
                    st.write(f"π File: {uploaded_file.name} ({uploaded_file.size / 1024:.1f} KB)")
                    source_name = st.text_input(
                        "Source name:",
                        value=uploaded_file.name.replace('.pdf', '')
                    )
                    if st.button("π UPLOAD PDF", type="primary"):
                        try:
                            progress = st.progress(0)
                            status = st.empty()
                            # Extract text
                            status.text("π Extracting text from PDF...")
                            progress.progress(0.1)
                            extracted_text = extract_text_from_pdf(uploaded_file)
                            # BUG FIX: this used to call st.stop() inside the try
                            # block; Streamlit's stop exception was then caught by
                            # the generic `except Exception` below and mis-reported
                            # as "Upload failed". Branch instead of stopping.
                            if not extracted_text:
                                st.error("β Failed to extract text from PDF")
                            else:
                                st.write(f"β Extracted {len(extracted_text)} characters")
                                # Show preview
                                with st.expander("ποΈ Preview extracted text"):
                                    st.text(extracted_text[:500] + "..." if len(extracted_text) > 500 else extracted_text)
                                # Chunk: larger windows than pasted text, with overlap
                                # so sentences spanning a boundary stay searchable.
                                status.text("π Chunking text...")
                                progress.progress(0.3)
                                words = extracted_text.split()
                                chunks = []
                                chunk_size = 100  # Larger chunks for PDFs
                                overlap = 20
                                for i in range(0, len(words), chunk_size - overlap):
                                    chunk = ' '.join(words[i:i + chunk_size])
                                    if chunk.strip():
                                        chunks.append(chunk)
                                st.write(f"β Created {len(chunks)} chunks")
                                # Embed one chunk at a time so progress can update.
                                status.text("π’ Generating embeddings...")
                                progress.progress(0.5)
                                embeddings = []
                                for idx, chunk in enumerate(chunks):
                                    embeddings.append(embedder.encode(chunk))
                                    if idx % 20 == 0:
                                        progress.progress(0.5 + (0.3 * idx / len(chunks)))
                                st.write(f"β Generated {len(embeddings)} embeddings")
                                # Upload
                                status.text("βοΈ Uploading to database...")
                                progress.progress(0.9)
                                points = []
                                for idx, (chunk, embedding) in enumerate(zip(chunks, embeddings)):
                                    points.append(PointStruct(
                                        id=abs(hash(f"pdf_{source_name}_{idx}_{time.time()}")) % (2**63),
                                        vector=embedding.tolist(),
                                        payload={
                                            "content": chunk,
                                            "source_name": source_name,
                                            "source_type": "pdf_upload",
                                            "chunk_index": idx,
                                            "file_name": uploaded_file.name
                                        }
                                    ))
                                client.upsert(collection_name=COLLECTION_NAME, points=points)
                                progress.progress(1.0)
                                status.empty()
                                st.success(f"π Uploaded {len(points)} vectors from PDF!")
                                st.balloons()
                                count = get_vector_count_reliable(client, COLLECTION_NAME)
                                st.info(f"π **Total vectors: {count:,}**")
                        except Exception as e:
                            st.error(f"β Upload failed: {str(e)}")
                            st.exception(e)
    st.markdown("---")
# ============================================================================
# STEP 5B: Load Public Datasets (FIXED - No DeepMind)
# ============================================================================
if show_upload:
    st.header("π Step 5B: Load Public Datasets")
    if not st.session_state.db_created or not st.session_state.embedder_ready:
        st.error("β οΈ Complete Steps 3 & 4 first")
    else:
        with st.expander("π Load from Hugging Face", expanded=False):
            dataset_choice = st.selectbox(
                "Dataset:",
                [
                    "GSM8K - Grade School Math (8.5K problems)",
                    "MATH - Competition Math (12.5K problems) β¨",
                    "MathQA - Math Word Problems (37K problems) π",
                    "CAMEL-AI Math - GPT-4 Generated (50K problems)",
                    "RACE - Reading Comprehension (28K passages)"
                ]
            )
            # INCREASED LIMIT FROM 500 TO 2000!
            sample_size = st.slider("Items to load:", 10, 2000, 50)
            st.warning(f"β οΈ Loading {sample_size} items. Large numbers take 5-15 minutes!")
            if st.button("π₯ LOAD DATASET", type="primary"):
                try:
                    from datasets import load_dataset
                    progress = st.progress(0)
                    status = st.empty()
                    texts = None  # becomes a list once a dataset downloads OK
                    dataset_name = None
                    # GSM8K
                    if "GSM8K" in dataset_choice:
                        status.text("π₯ Downloading GSM8K...")
                        progress.progress(0.1)
                        dataset = load_dataset("openai/gsm8k", "main", split="train", trust_remote_code=True)
                        dataset_name = "GSM8K"
                        texts = []
                        for i in range(min(sample_size, len(dataset))):
                            item = dataset[i]
                            texts.append(f"Problem: {item['question']}\n\nSolution: {item['answer']}")
                    # MATH
                    elif "MATH" in dataset_choice and "Competition" in dataset_choice:
                        status.text("π₯ Downloading MATH...")
                        progress.progress(0.1)
                        dataset = None
                        dataset_name = "MATH"
                        # Try multiple sources — mirrors disappear regularly.
                        for source in ["lighteval/MATH", "DigitalLearningGmbH/MATH-lighteval", "EleutherAI/hendrycks_math"]:
                            try:
                                dataset = load_dataset(source, split="train", trust_remote_code=True)
                                st.success(f"β Using {source}")
                                break
                            except Exception:
                                # narrowed from a bare `except:`; try the next mirror
                                continue
                        if dataset is None:
                            # BUG FIX: st.stop() here used to raise inside this
                            # try-block and was caught by the generic handler below,
                            # stacking a bogus "Failed" error on top of this one.
                            # Leaving texts=None now skips the common processing.
                            st.error("β All MATH sources failed")
                        else:
                            texts = []
                            for i in range(min(sample_size, len(dataset))):
                                item = dataset[i]
                                problem = item.get('problem', item.get('question', ''))
                                solution = item.get('solution', item.get('answer', ''))
                                problem_type = item.get('type', item.get('level', 'general'))
                                texts.append(f"Problem ({problem_type}): {problem}\n\nSolution: {solution}")
                    # MathQA (REPLACES DEEPMIND)
                    elif "MathQA" in dataset_choice:
                        status.text("π₯ Downloading MathQA...")
                        progress.progress(0.1)
                        st.info("π MathQA: 37K math word problems with detailed solutions")
                        dataset = load_dataset("allenai/math_qa", split="train", trust_remote_code=True)
                        dataset_name = "MathQA"
                        texts = []
                        for i in range(min(sample_size, len(dataset))):
                            item = dataset[i]
                            texts.append(f"Problem: {item['Problem']}\n\nRationale: {item['Rationale']}\n\nAnswer: {item['correct']}")
                    # CAMEL-AI
                    elif "CAMEL" in dataset_choice:
                        status.text("π₯ Downloading CAMEL-AI...")
                        progress.progress(0.1)
                        dataset = load_dataset("camel-ai/math", split="train", trust_remote_code=True)
                        dataset_name = "CAMEL-Math"
                        texts = []
                        for i in range(min(sample_size, len(dataset))):
                            item = dataset[i]
                            texts.append(f"Problem: {item['message']}")
                    # RACE
                    else:
                        status.text("π₯ Downloading RACE...")
                        progress.progress(0.1)
                        dataset = load_dataset("ehovy/race", "all", split="train", trust_remote_code=True)
                        dataset_name = "RACE"
                        texts = []
                        for i in range(min(sample_size, len(dataset))):
                            item = dataset[i]
                            texts.append(f"Article: {item['article'][:500]}\n\nQuestion: {item['question']}\n\nAnswer: {item['answer']}")
                    # Common processing: embed every text, then upsert in one call.
                    if texts is not None:
                        st.write(f"β Loaded {len(texts)} items from {dataset_name}")
                        progress.progress(0.3)
                        status.text("π’ Generating embeddings...")
                        embeddings = []
                        for idx, text in enumerate(texts):
                            embeddings.append(embedder.encode(text))
                            if idx % 50 == 0:
                                progress.progress(0.3 + (0.5 * idx / len(texts)))
                                status.text(f"π’ Embedding {idx+1}/{len(texts)}")
                        st.write(f"β Generated {len(embeddings)} embeddings")
                        progress.progress(0.8)
                        status.text("βοΈ Uploading...")
                        points = []
                        for idx, (text, embedding) in enumerate(zip(texts, embeddings)):
                            # Cap stored payload text at 2000 chars.
                            content = text[:2000] if len(text) > 2000 else text
                            points.append(PointStruct(
                                id=abs(hash(f"{dataset_name}_{idx}_{time.time()}")) % (2**63),
                                vector=embedding.tolist(),
                                payload={
                                    "content": content,
                                    "source_name": dataset_name,
                                    "source_type": "public_dataset",
                                    "dataset": dataset_name,
                                    "index": idx
                                }
                            ))
                        client.upsert(collection_name=COLLECTION_NAME, points=points)
                        progress.progress(1.0)
                        status.empty()
                        st.success(f"π Uploaded {len(points)} vectors from {dataset_name}!")
                        count = get_vector_count_reliable(client, COLLECTION_NAME)
                        st.info(f"π **Total vectors: {count:,}**")
                except ImportError:
                    st.error("β Add 'datasets' to requirements.txt")
                except Exception as e:
                    st.error(f"β Failed: {str(e)}")
                    st.exception(e)
    st.markdown("---")
# ============================================================================
# STEP 6: Test Search
# ============================================================================
if show_search:
    st.header("π Step 6: Test Search")
    if not st.session_state.db_created or not st.session_state.embedder_ready:
        st.error("β οΈ Database and embedder must be ready")
    else:
        search_query = st.text_input(
            "Question:",
            placeholder="Solve xΒ² + 5x - 4 = 0"
        )
        left_col, right_col = st.columns([3, 1])
        with left_col:
            top_k = st.slider("Results:", 1, 10, 5)
        with right_col:
            st.metric("DB Vectors", get_vector_count_reliable(client, COLLECTION_NAME))
        if st.button("π SEARCH", type="primary") and search_query:
            try:
                with st.spinner("Searching..."):
                    query_embedding = embedder.encode(search_query)
                    results = client.search(
                        collection_name=COLLECTION_NAME,
                        query_vector=query_embedding.tolist(),
                        limit=top_k
                    )
                if not results:
                    st.warning("No results found!")
                else:
                    st.success(f"β Found {len(results)} results!")
                    for rank, hit in enumerate(results, 1):
                        similarity_pct = hit.score * 100
                        # Traffic-light badge by similarity band.
                        badge = "π’" if similarity_pct > 50 else ("π‘" if similarity_pct > 30 else "π΄")
                        with st.expander(f"{badge} Result {rank} - {similarity_pct:.1f}% match", expanded=(rank <= 2)):
                            st.info(hit.payload['content'])
                            meta_cols = st.columns(3)
                            with meta_cols[0]:
                                st.caption(f"**Source:** {hit.payload['source_name']}")
                            with meta_cols[1]:
                                st.caption(f"**Type:** {hit.payload['source_type']}")
                            with meta_cols[2]:
                                st.caption(f"**Score:** {hit.score:.4f}")
            except Exception as e:
                st.error(f"β Search failed: {str(e)}")
    st.markdown("---")

st.success("π Phase 2.5 Complete! You now have: Text, PDF upload, and 4 working datasets!")