Hebaelsayed committed on
Commit
3331648
·
verified ·
1 Parent(s): 6e1b1a8

Update src/streamlit_app.py

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +147 -370
src/streamlit_app.py CHANGED
@@ -14,20 +14,20 @@ from sentence_transformers import SentenceTransformer
14
  from huggingface_hub import hf_hub_download, list_repo_files
15
 
16
  # ============================================================================
17
- # PRODUCTION MATH AI SYSTEM - SMART PROCESSING
18
  # ============================================================================
19
 
20
  st.set_page_config(
21
- page_title="Math AI System - Production",
22
  page_icon="πŸŽ“",
23
  layout="wide"
24
  )
25
 
26
  COLLECTION_NAME = "math_knowledge_base"
27
- DATASET_REPO = "Hebaelsayed/math-ai-documents"
28
 
29
  # ============================================================================
30
- # AVAILABLE EMBEDDING MODELS
31
  # ============================================================================
32
 
33
  EMBEDDING_MODELS = {
@@ -43,7 +43,7 @@ EMBEDDING_MODELS = {
43
  "speed": "Medium",
44
  "quality": "Better"
45
  },
46
- "MPNet (Best Quality, 768D)": {
47
  "name": "sentence-transformers/all-mpnet-base-v2",
48
  "dimensions": 768,
49
  "speed": "Slower",
@@ -57,7 +57,6 @@ EMBEDDING_MODELS = {
57
 
58
  @st.cache_resource
59
  def get_qdrant_client():
60
- """Initialize Qdrant client"""
61
  return QdrantClient(
62
  url=os.getenv("QDRANT_URL"),
63
  api_key=os.getenv("QDRANT_API_KEY")
@@ -65,44 +64,34 @@ def get_qdrant_client():
65
 
66
  @st.cache_resource
67
  def get_claude_client():
68
- """Initialize Claude client"""
69
  return Anthropic(api_key=os.getenv("ANTHROPIC_API_KEY"))
70
 
71
  @st.cache_resource
72
  def get_embedding_model(model_name):
73
- """Load embedding model (cached per model)"""
74
  return SentenceTransformer(model_name)
75
 
76
  # ============================================================================
77
  # HELPER FUNCTIONS
78
  # ============================================================================
79
 
80
- def get_file_hash(file_path):
81
- """Generate unique hash for file to track if already processed"""
82
- return hashlib.md5(file_path.encode()).hexdigest()
83
-
84
  def check_if_processed(qdrant, file_name):
85
- """Check if file already processed in Qdrant"""
86
  try:
87
  results = qdrant.scroll(
88
  collection_name=COLLECTION_NAME,
89
  scroll_filter={
90
- "must": [
91
- {"key": "source_name", "match": {"value": file_name}}
92
- ]
93
  },
94
  limit=1,
95
  with_payload=True,
96
  with_vectors=False
97
  )
98
-
99
  return len(results[0]) > 0 if results and results[0] else False
100
-
101
  except:
102
  return False
103
 
104
  def list_dataset_files(folder_path):
105
- """List PDF files in HF Dataset folder"""
106
  try:
107
  hf_token = os.getenv("HF_TOKEN")
108
  all_files = list_repo_files(
@@ -110,38 +99,27 @@ def list_dataset_files(folder_path):
110
  repo_type="dataset",
111
  token=hf_token
112
  )
113
-
114
- pdf_files = [
115
- f for f in all_files
116
- if f.startswith(folder_path) and f.endswith('.pdf')
117
- ]
118
-
119
- return pdf_files
120
-
121
  except Exception as e:
122
- st.error(f"Error listing files: {e}")
123
  return []
124
 
125
  def download_from_dataset(file_path):
126
- """Download file from HF Dataset"""
127
  try:
128
  hf_token = os.getenv("HF_TOKEN")
129
-
130
- local_path = hf_hub_download(
131
  repo_id=DATASET_REPO,
132
  filename=file_path,
133
  repo_type="dataset",
134
  token=hf_token
135
  )
136
-
137
- return local_path
138
-
139
  except Exception as e:
140
  st.error(f"Download error: {e}")
141
  return None
142
 
143
  def extract_text_from_pdf(pdf_path):
144
- """Extract text from typed PDF"""
145
  try:
146
  with open(pdf_path, 'rb') as file:
147
  reader = PyPDF2.PdfReader(file)
@@ -150,74 +128,60 @@ def extract_text_from_pdf(pdf_path):
150
  text += f"\n\n=== Page {page_num + 1} ===\n\n{page.extract_text()}"
151
  return text
152
  except Exception as e:
153
- st.error(f"Text extraction error: {e}")
154
  return None
155
 
156
  def pdf_to_images(pdf_path):
157
- """Convert PDF to images for OCR"""
158
  try:
159
  images = convert_from_path(pdf_path, dpi=200)
160
  return images
161
  except Exception as e:
162
- st.error(f"PDF to image error: {e}")
163
- st.info("πŸ’‘ This requires poppler-utils. Add 'poppler-utils' to packages.txt file in your Space")
164
  return []
165
 
166
  def resize_image(image, max_size=(2048, 2048)):
167
- """Resize image for Claude Vision"""
168
  image.thumbnail(max_size, Image.Resampling.LANCZOS)
169
  return image
170
 
171
  def image_to_base64(image):
172
- """Convert PIL Image to base64"""
173
  buffered = BytesIO()
174
  image.save(buffered, format="PNG")
175
  return base64.b64encode(buffered.getvalue()).decode()
176
 
177
  def ocr_with_claude(claude_client, image, context=""):
178
- """AI OCR with Claude Vision"""
179
-
180
  resized = resize_image(image.copy())
181
  img_b64 = image_to_base64(resized)
182
 
183
- prompt = f"""Transcribe handwritten math solution.
184
-
185
- STYLE: Italian cursive (connected letters)
186
  LANGUAGE: English
187
-
188
  CONTEXT: {context[:2000] if context else ""}
189
-
190
- INSTRUCTIONS:
191
- 1. Transcribe in English
192
- 2. Use proper math notation: ∫, βˆ‘, √, βˆ‚, etc.
193
- 3. Maintain structure
194
- 4. Mark unclear: [unclear: guess]
195
-
196
  OUTPUT: Transcription only."""
197
 
198
  try:
199
  message = claude_client.messages.create(
200
  model="claude-sonnet-4-20250514",
201
  max_tokens=4000,
202
- messages=[
203
- {
204
- "role": "user",
205
- "content": [
206
- {"type": "image", "source": {"type": "base64", "media_type": "image/png", "data": img_b64}},
207
- {"type": "text", "text": prompt}
208
- ]
209
- }
210
- ]
211
  )
212
-
213
  return message.content[0].text, message.usage.input_tokens + message.usage.output_tokens
214
-
215
  except Exception as e:
216
  st.error(f"OCR error: {e}")
217
  return None, 0
218
 
219
  def chunk_text(text, chunk_size=150, overlap=30):
220
- """Split text into chunks"""
221
  words = text.split()
222
  chunks = []
223
  for i in range(0, len(words), chunk_size - overlap):
@@ -227,7 +191,7 @@ def chunk_text(text, chunk_size=150, overlap=30):
227
  return chunks
228
 
229
  def get_vector_count(qdrant):
230
- """Get total vectors in database"""
231
  try:
232
  count = 0
233
  offset = None
@@ -250,7 +214,7 @@ def get_vector_count(qdrant):
250
  return 0
251
 
252
  # ============================================================================
253
- # INITIALIZE CLIENTS
254
  # ============================================================================
255
 
256
  try:
@@ -258,357 +222,221 @@ try:
258
  claude = get_claude_client()
259
  st.sidebar.success("βœ… System Ready")
260
  except Exception as e:
261
- st.error(f"❌ Initialization failed: {e}")
262
- st.info("Add these secrets: QDRANT_URL, QDRANT_API_KEY, ANTHROPIC_API_KEY, HF_TOKEN")
263
  st.stop()
264
 
265
  # ============================================================================
266
  # SIDEBAR
267
  # ============================================================================
268
 
269
- st.sidebar.title("πŸŽ“ Math AI System")
270
- st.sidebar.caption("Production Version")
271
 
272
  try:
273
  vector_count = get_vector_count(qdrant)
274
- st.sidebar.metric("Total Vectors", f"{vector_count:,}")
275
-
276
- storage_mb = (vector_count * 384 * 4) / (1024 * 1024)
277
- st.sidebar.metric("Storage", f"{storage_mb:.1f} MB")
278
  except:
279
- st.sidebar.warning("Database unavailable")
280
 
281
  st.sidebar.markdown("---")
282
 
283
  # ============================================================================
284
- # MAIN TABS (Reordered as requested)
285
  # ============================================================================
286
 
287
- tab1, tab2, tab3 = st.tabs([
288
- "πŸ“Š Dataset Manager",
289
- "πŸ” Search & Solve",
290
- "πŸ“ˆ Statistics"
291
- ])
292
 
293
  # ============================================================================
294
- # TAB 1: DATASET MANAGER (Primary Interface)
295
  # ============================================================================
296
 
297
  with tab1:
298
-
299
  st.title("πŸ“Š Dataset Manager")
300
- st.markdown("*Manage all your data sources in one place*")
301
 
302
- # Check HF Token
303
  if not os.getenv("HF_TOKEN"):
304
- st.error("⚠️ Missing HF_TOKEN in secrets!")
305
- st.info("Add it in Settings β†’ Repository Secrets")
306
  st.stop()
307
 
308
  # Collection setup
309
- st.header("πŸ—οΈ Step 1: Database Setup")
310
-
311
- col1, col2 = st.columns([2, 1])
312
 
313
- with col1:
314
- try:
315
- collections = qdrant.get_collections().collections
316
- exists = any(c.name == COLLECTION_NAME for c in collections)
 
 
 
 
317
 
318
- if exists:
319
- st.success(f"βœ… Collection '{COLLECTION_NAME}' exists")
320
- else:
321
- st.warning(f"Collection '{COLLECTION_NAME}' doesn't exist")
322
-
323
- # Show embedding model choice for initial creation
324
- st.subheader("Choose Embedding Model")
325
-
326
- for model_name, specs in EMBEDDING_MODELS.items():
327
- with st.expander(f"{model_name} - {specs['quality']} quality, {specs['speed']} speed"):
328
- st.write(f"**Dimensions:** {specs['dimensions']}")
329
- st.write(f"**Model:** `{specs['name']}`")
330
-
331
- selected_model_key = st.selectbox(
332
- "Select embedding model:",
333
- list(EMBEDDING_MODELS.keys())
334
  )
335
-
336
- if st.button("πŸ—οΈ Create Collection", type="primary"):
337
- dimensions = EMBEDDING_MODELS[selected_model_key]["dimensions"]
338
-
339
- qdrant.create_collection(
340
- collection_name=COLLECTION_NAME,
341
- vectors_config=VectorParams(
342
- size=dimensions,
343
- distance=Distance.COSINE
344
- )
345
- )
346
-
347
- st.success(f"βœ… Created with {dimensions}D vectors!")
348
- st.session_state.embedding_model = EMBEDDING_MODELS[selected_model_key]["name"]
349
- st.rerun()
350
-
351
- except Exception as e:
352
- st.error(f"Error: {e}")
353
-
354
- with col2:
355
- st.info(f"""
356
- **Dataset:**
357
- `{DATASET_REPO}`
358
-
359
- **Collection:**
360
- `{COLLECTION_NAME}`
361
- """)
362
 
363
  st.markdown("---")
364
 
365
- # Processing options
366
- st.header("βš™οΈ Step 2: Processing Configuration")
367
 
368
- col1, col2, col3 = st.columns(3)
369
 
370
  with col1:
371
- st.subheader("Chunking Strategy")
372
- chunk_size = st.slider("Chunk size (words):", 50, 500, 150)
373
- chunk_overlap = st.slider("Overlap (words):", 0, 100, 30)
374
- st.caption(f"Overlap: {(chunk_overlap/chunk_size*100):.0f}%")
375
 
376
  with col2:
377
- st.subheader("Embedding Model")
378
- # Get current model from collection or use default
379
  current_model = st.session_state.get('embedding_model', EMBEDDING_MODELS["MiniLM-L6 (Fast, 384D)"]["name"])
380
-
381
- # Find which key this model belongs to
382
- current_model_key = "MiniLM-L6 (Fast, 384D)"
383
- for key, specs in EMBEDDING_MODELS.items():
384
- if specs["name"] == current_model:
385
- current_model_key = key
386
- break
387
-
388
- st.info(f"**Active:** {current_model_key}")
389
- st.caption(f"Model: `{current_model}`")
390
-
391
- with col3:
392
- st.subheader("OCR Settings")
393
- use_context_for_ocr = st.checkbox("Use book context", value=True, help="Better accuracy, higher cost")
394
- st.caption("Context helps Claude understand symbols")
395
 
396
  st.markdown("---")
397
 
398
  # Data sources
399
- st.header("πŸ“ Step 3: Data Sources")
400
-
401
- source_tabs = st.tabs([
402
- "πŸ“‚ Your Dataset Files",
403
- "🌐 Public Datasets (GSM8K, MATH, etc.)"
404
- ])
405
 
406
- # ========================================================================
407
- # SOURCE 1: HF Dataset Files
408
- # ========================================================================
409
 
410
  with source_tabs[0]:
411
-
412
- st.subheader("Files from Your HF Dataset")
413
-
414
  folder_type = st.radio(
415
- "Select folder:",
416
- ["πŸ“š Books (Typed PDFs)", "πŸ“ Exams (Typed PDFs)", "πŸ–ŠοΈ Answers (Handwritten - needs OCR)"],
417
  horizontal=True
418
  )
419
 
420
- # Determine folder path
421
  if "Books" in folder_type:
422
- folder_path = "books/"
423
- doc_type = "book"
424
  elif "Exams" in folder_type:
425
- folder_path = "exams/"
426
- doc_type = "exam"
427
  else:
428
- folder_path = "answers/"
429
- doc_type = "answer_handwritten"
430
 
431
- # List files
432
- if st.button(f"πŸ” Scan {folder_path} folder"):
433
- with st.spinner("Scanning dataset..."):
434
  files = list_dataset_files(folder_path)
435
 
436
  if files:
437
- # Check processing status for each file
438
  file_status = []
439
  for file in files:
440
- file_name = file.split('/')[-1]
441
- is_processed = check_if_processed(qdrant, file_name)
442
- file_status.append({
443
- "file": file,
444
- "name": file_name,
445
- "processed": is_processed
446
- })
447
 
448
  st.session_state.current_files = file_status
449
  st.session_state.current_folder = folder_path
450
  st.session_state.current_doc_type = doc_type
451
  else:
452
- st.warning(f"No files found in {folder_path}")
453
 
454
- # Display files with status
455
  if 'current_files' in st.session_state and st.session_state.current_folder == folder_path:
456
 
457
- st.write(f"**Found {len(st.session_state.current_files)} files:**")
458
-
459
- # Summary
460
  processed_count = sum(1 for f in st.session_state.current_files if f['processed'])
461
  pending_count = len(st.session_state.current_files) - processed_count
462
 
463
  col1, col2, col3 = st.columns(3)
464
- with col1:
465
- st.metric("Total", len(st.session_state.current_files))
466
- with col2:
467
- st.metric("βœ… Processed", processed_count)
468
- with col3:
469
- st.metric("⏳ Pending", pending_count)
470
 
471
- # File list with checkboxes
472
- st.subheader("Select files to process:")
473
 
474
  selected_files = []
475
-
476
  for file_info in st.session_state.current_files:
477
- col1, col2 = st.columns([3, 1])
478
-
479
- with col1:
480
- # Only allow selection if not processed
481
- if file_info['processed']:
482
- st.checkbox(
483
- f"βœ… {file_info['name']} (Already processed)",
484
- value=False,
485
- disabled=True,
486
- key=f"file_{file_info['name']}"
487
- )
488
- else:
489
- if st.checkbox(
490
- f"⏳ {file_info['name']}",
491
- value=True, # Auto-select pending files
492
- key=f"file_{file_info['name']}"
493
- ):
494
- selected_files.append(file_info)
495
-
496
- with col2:
497
- if file_info['processed']:
498
- st.caption("Skip")
499
- else:
500
- st.caption("Ready")
501
 
502
- # Process button
503
  if selected_files:
504
-
505
  st.markdown("---")
506
- st.write(f"**Ready to process {len(selected_files)} file(s)**")
507
 
508
- # Show cost estimate for OCR
509
  if doc_type == "answer_handwritten":
510
- est_pages = len(selected_files) * 5 # Assume 5 pages per PDF
511
- est_cost = est_pages * 0.08
512
- st.warning(f"⚠️ OCR Cost Estimate: ~${est_cost:.2f} ({est_pages} pages Γ— ~$0.08/page)")
513
 
514
- if st.button(f"πŸš€ PROCESS SELECTED FILES", type="primary"):
515
 
516
- # Load embedding model
517
  embedder = get_embedding_model(current_model)
518
 
519
- # Get context if needed
520
  context_books = ""
521
- if doc_type == "answer_handwritten" and use_context_for_ocr:
522
  try:
523
- book_samples = qdrant.scroll(
524
  collection_name=COLLECTION_NAME,
525
  limit=10,
526
  with_payload=True,
527
  with_vectors=False,
528
  scroll_filter={"must": [{"key": "source_type", "match": {"value": "book"}}]}
529
  )
530
-
531
- if book_samples and book_samples[0]:
532
- context_books = "\n".join([p.payload['content'] for p in book_samples[0][:5]])
533
- st.info("βœ… Using book context for OCR")
534
  except:
535
- st.caption("No books in database - OCR will work but may be less accurate")
536
 
537
- # Process each selected file
538
  total_tokens = 0
539
  total_vectors = 0
540
 
541
  for file_info in selected_files:
542
-
543
  with st.expander(f"Processing {file_info['name']}", expanded=True):
544
-
545
  try:
546
- # Download
547
  st.write("πŸ“₯ Downloading...")
548
  local_path = download_from_dataset(file_info['file'])
549
 
550
  if not local_path:
551
- st.error("Download failed")
552
  continue
553
 
554
- # Extract or OCR
555
  if doc_type == "answer_handwritten":
556
- # OCR path
557
- st.write("πŸ–ΌοΈ Converting to images...")
558
  images = pdf_to_images(local_path)
559
 
560
  if not images:
561
- st.error("Conversion failed - poppler-utils not installed?")
562
  continue
563
 
564
  st.write(f"βœ… {len(images)} pages")
565
 
566
- # OCR each page
567
- transcribed_pages = []
568
- page_tokens = 0
569
 
570
- for page_num, image in enumerate(images, 1):
571
- st.write(f"πŸ€– OCR page {page_num}/{len(images)}...")
572
-
573
- transcription, tokens = ocr_with_claude(
574
- claude,
575
- image,
576
- context=context_books
577
- )
578
-
579
- if transcription:
580
- transcribed_pages.append(f"\n=== Page {page_num} ===\n\n{transcription}")
581
- page_tokens += tokens
582
 
583
- if not transcribed_pages:
584
  st.error("OCR failed")
585
  continue
586
 
587
- text = "\n\n".join(transcribed_pages)
588
- total_tokens += page_tokens
589
-
590
- st.success(f"βœ… Transcribed {len(text):,} chars (${page_tokens * 0.000003:.3f})")
591
 
592
  else:
593
- # Text extraction
594
- st.write("πŸ“– Extracting text...")
595
  text = extract_text_from_pdf(local_path)
596
-
597
  if not text:
598
- st.error("Text extraction failed")
599
  continue
600
-
601
  st.write(f"βœ… {len(text):,} chars")
602
 
603
- # Chunk
604
  chunks = chunk_text(text, chunk_size, chunk_overlap)
605
  st.write(f"βœ‚οΈ {len(chunks)} chunks")
606
 
607
- # Embed
608
  st.write("πŸ”’ Embedding...")
609
  embeddings = embedder.encode(chunks, show_progress_bar=False)
610
 
611
- # Upload
612
  points = []
613
  for i, (chunk, emb) in enumerate(zip(chunks, embeddings)):
614
  points.append(PointStruct(
@@ -618,92 +446,60 @@ with tab1:
618
  "content": chunk,
619
  "source_name": file_info['name'],
620
  "source_type": doc_type,
621
- "chunk_index": i,
622
- "embedding_model": current_model
623
  }
624
  ))
625
 
626
  qdrant.upsert(collection_name=COLLECTION_NAME, points=points)
627
  total_vectors += len(points)
628
-
629
- st.success(f"βœ… Uploaded {len(points)} vectors!")
630
 
631
  except Exception as e:
632
  st.error(f"Error: {e}")
633
 
634
- # Summary
635
  st.balloons()
636
- st.success(f"""
637
- πŸŽ‰ Processing Complete!
638
-
639
- - Files processed: {len(selected_files)}
640
- - Vectors added: {total_vectors:,}
641
- - OCR tokens used: {total_tokens:,}
642
- - OCR cost: ${total_tokens * 0.000003:.2f}
643
- """)
644
-
645
- # Clear selection
646
  st.session_state.pop('current_files', None)
647
  st.rerun()
648
 
649
- # ========================================================================
650
- # SOURCE 2: Public Datasets
651
- # ========================================================================
652
-
653
  with source_tabs[1]:
654
-
655
- st.subheader("Public Math Datasets")
656
-
657
  dataset_choice = st.selectbox(
658
- "Select dataset:",
659
- [
660
- "GSM8K - Grade School Math (8.5K problems)",
661
- "MATH - Competition Math (12.5K problems)",
662
- "MathQA - Math Word Problems (37K problems)"
663
- ]
664
  )
665
 
666
- sample_size = st.slider("Number of samples:", 10, 2000, 100)
667
 
668
- # Check if already loaded
669
  dataset_name = dataset_choice.split(" - ")[0]
670
  already_loaded = check_if_processed(qdrant, dataset_name)
671
 
672
  if already_loaded:
673
- st.success(f"βœ… {dataset_name} already loaded!")
674
- st.info("Vectors from this dataset are already in your database.")
675
  else:
676
- if st.button(f"πŸ“₯ Load {dataset_name}", type="primary"):
677
-
678
  try:
679
  from datasets import load_dataset
680
 
681
  embedder = get_embedding_model(current_model)
682
 
683
- with st.spinner(f"Loading {dataset_name}..."):
684
-
685
  if "GSM8K" in dataset_choice:
686
  dataset = load_dataset("openai/gsm8k", "main", split="train", trust_remote_code=True)
687
  texts = [f"Problem: {dataset[i]['question']}\n\nSolution: {dataset[i]['answer']}"
688
  for i in range(min(sample_size, len(dataset)))]
689
-
690
  elif "MATH" in dataset_choice:
691
  dataset = load_dataset("lighteval/MATH", split="train", trust_remote_code=True)
692
  texts = [f"Problem: {dataset[i].get('problem', '')}\n\nSolution: {dataset[i].get('solution', '')}"
693
  for i in range(min(sample_size, len(dataset)))]
694
-
695
- else: # MathQA
696
  dataset = load_dataset("allenai/math_qa", split="train", trust_remote_code=True)
697
  texts = [f"Problem: {dataset[i]['Problem']}\n\nAnswer: {dataset[i]['correct']}"
698
  for i in range(min(sample_size, len(dataset)))]
699
 
700
- st.write(f"βœ… Loaded {len(texts)} problems")
701
 
702
- # Embed
703
- st.write("πŸ”’ Embedding...")
704
  embeddings = embedder.encode(texts, show_progress_bar=True)
705
 
706
- # Upload
707
  points = []
708
  for i, (text, emb) in enumerate(zip(texts, embeddings)):
709
  points.append(PointStruct(
@@ -713,13 +509,12 @@ with tab1:
713
  "content": text[:2000],
714
  "source_name": dataset_name,
715
  "source_type": "public_dataset",
716
- "index": i,
717
- "embedding_model": current_model
718
  }
719
  ))
720
 
721
  qdrant.upsert(collection_name=COLLECTION_NAME, points=points)
722
- st.success(f"βœ… Uploaded {len(points)} vectors!")
723
  st.balloons()
724
 
725
  except Exception as e:
@@ -730,30 +525,20 @@ with tab1:
730
  # ============================================================================
731
 
732
  with tab2:
733
-
734
  st.title("πŸ” Search & Solve")
735
 
736
  problem = st.text_area(
737
- "Enter math problem:",
738
- placeholder="Find the gradient of the loss function L(w) = (1/2)||Xw - y||Β²",
739
  height=150
740
  )
741
 
742
  col1, col2 = st.columns(2)
743
-
744
- with col1:
745
- top_k = st.slider("Retrieve top:", 3, 20, 5)
746
-
747
- with col2:
748
- detail = st.select_slider(
749
- "Detail level:",
750
- ["Concise", "Standard", "Detailed", "Exhaustive"],
751
- value="Detailed"
752
- )
753
 
754
  if st.button("πŸš€ SOLVE", type="primary") and problem:
755
 
756
- # Get embedding model
757
  current_model = st.session_state.get('embedding_model', EMBEDDING_MODELS["MiniLM-L6 (Fast, 384D)"]["name"])
758
  embedder = get_embedding_model(current_model)
759
 
@@ -764,7 +549,7 @@ with tab2:
764
  results = qdrant.search(
765
  collection_name=COLLECTION_NAME,
766
  query_vector=query_emb.tolist(),
767
- limit=top_k
768
  )
769
  except:
770
  results = []
@@ -776,31 +561,31 @@ with tab2:
776
 
777
  with st.expander("πŸ“š References"):
778
  for i, r in enumerate(results, 1):
779
- st.markdown(f"**{i}.** ({r.score*100:.0f}% match)")
780
  st.text(r.payload['content'][:200] + "...")
781
  st.caption(f"Source: {r.payload.get('source_name')}")
782
 
783
- with st.spinner("Generating solution..."):
784
 
785
  context = "\n\n".join([r.payload['content'] for r in results])
786
 
787
- prompt = f"""Solve this problem using references.
788
 
789
  PROBLEM: {problem}
790
 
791
  REFERENCES: {context}
792
 
793
- DETAIL: {detail}
794
 
795
  FORMAT:
796
  ## SOLUTION
797
  [Steps]
798
 
799
  ## REASONING
800
- [Why this approach]
801
 
802
  ## REFERENCES
803
- [Which sources helped]"""
804
 
805
  try:
806
  message = claude.messages.create(
@@ -813,7 +598,7 @@ FORMAT:
813
  st.markdown(message.content[0].text)
814
 
815
  st.download_button(
816
- "πŸ“₯ Download Solution",
817
  message.content[0].text,
818
  file_name=f"solution_{int(time.time())}.md"
819
  )
@@ -826,8 +611,7 @@ FORMAT:
826
  # ============================================================================
827
 
828
  with tab3:
829
-
830
- st.title("πŸ“ˆ Statistics & Analytics")
831
 
832
  try:
833
  sample = qdrant.scroll(
@@ -847,20 +631,14 @@ with tab3:
847
  sources.add(point.payload.get('source_name', 'Unknown'))
848
 
849
  col1, col2, col3 = st.columns(3)
 
 
 
850
 
851
- with col1:
852
- st.metric("Total Vectors", get_vector_count(qdrant))
853
-
854
- with col2:
855
- st.metric("Unique Sources", len(sources))
856
-
857
- with col3:
858
- st.metric("Document Types", len(types))
859
-
860
- st.subheader("Distribution by Type")
861
  for doc_type, count in sorted(types.items()):
862
  pct = count / sum(types.values()) * 100
863
- st.progress(count / sum(types.values()), text=f"{doc_type}: {count} ({pct:.1f}%)")
864
 
865
  st.subheader("All Sources")
866
  for src in sorted(sources):
@@ -869,5 +647,4 @@ with tab3:
869
  except Exception as e:
870
  st.error(f"Error: {e}")
871
 
872
- st.sidebar.markdown("---")
873
- st.sidebar.caption("v2.0 - Production")
 
14
  from huggingface_hub import hf_hub_download, list_repo_files
15
 
16
  # ============================================================================
17
+ # PRODUCTION MATH AI SYSTEM
18
  # ============================================================================
19
 
20
  st.set_page_config(
21
+ page_title="Math AI System",
22
  page_icon="πŸŽ“",
23
  layout="wide"
24
  )
25
 
26
  COLLECTION_NAME = "math_knowledge_base"
27
+ DATASET_REPO = "yourusername/math-ai-documents" # ← CHANGE THIS!
28
 
29
  # ============================================================================
30
+ # EMBEDDING MODELS
31
  # ============================================================================
32
 
33
  EMBEDDING_MODELS = {
 
43
  "speed": "Medium",
44
  "quality": "Better"
45
  },
46
+ "MPNet (Best, 768D)": {
47
  "name": "sentence-transformers/all-mpnet-base-v2",
48
  "dimensions": 768,
49
  "speed": "Slower",
 
57
 
58
  @st.cache_resource
59
  def get_qdrant_client():
 
60
  return QdrantClient(
61
  url=os.getenv("QDRANT_URL"),
62
  api_key=os.getenv("QDRANT_API_KEY")
 
64
 
65
  @st.cache_resource
66
  def get_claude_client():
 
67
  return Anthropic(api_key=os.getenv("ANTHROPIC_API_KEY"))
68
 
69
  @st.cache_resource
70
  def get_embedding_model(model_name):
 
71
  return SentenceTransformer(model_name)
72
 
73
  # ============================================================================
74
  # HELPER FUNCTIONS
75
  # ============================================================================
76
 
 
 
 
 
77
  def check_if_processed(qdrant, file_name):
78
+ """Check if file already in database"""
79
  try:
80
  results = qdrant.scroll(
81
  collection_name=COLLECTION_NAME,
82
  scroll_filter={
83
+ "must": [{"key": "source_name", "match": {"value": file_name}}]
 
 
84
  },
85
  limit=1,
86
  with_payload=True,
87
  with_vectors=False
88
  )
 
89
  return len(results[0]) > 0 if results and results[0] else False
 
90
  except:
91
  return False
92
 
93
  def list_dataset_files(folder_path):
94
+ """List PDFs in HF Dataset folder"""
95
  try:
96
  hf_token = os.getenv("HF_TOKEN")
97
  all_files = list_repo_files(
 
99
  repo_type="dataset",
100
  token=hf_token
101
  )
102
+ return [f for f in all_files if f.startswith(folder_path) and f.endswith('.pdf')]
 
 
 
 
 
 
 
103
  except Exception as e:
104
+ st.error(f"Error listing: {e}")
105
  return []
106
 
107
  def download_from_dataset(file_path):
108
+ """Download from HF Dataset"""
109
  try:
110
  hf_token = os.getenv("HF_TOKEN")
111
+ return hf_hub_download(
 
112
  repo_id=DATASET_REPO,
113
  filename=file_path,
114
  repo_type="dataset",
115
  token=hf_token
116
  )
 
 
 
117
  except Exception as e:
118
  st.error(f"Download error: {e}")
119
  return None
120
 
121
  def extract_text_from_pdf(pdf_path):
122
+ """Extract text from PDF"""
123
  try:
124
  with open(pdf_path, 'rb') as file:
125
  reader = PyPDF2.PdfReader(file)
 
128
  text += f"\n\n=== Page {page_num + 1} ===\n\n{page.extract_text()}"
129
  return text
130
  except Exception as e:
131
+ st.error(f"Extraction error: {e}")
132
  return None
133
 
134
  def pdf_to_images(pdf_path):
135
+ """Convert PDF to images"""
136
  try:
137
  images = convert_from_path(pdf_path, dpi=200)
138
  return images
139
  except Exception as e:
140
+ st.error(f"Conversion error: {e}")
141
+ st.info("πŸ’‘ Add 'poppler-utils' to packages.txt")
142
  return []
143
 
144
  def resize_image(image, max_size=(2048, 2048)):
145
+ """Resize for Claude"""
146
  image.thumbnail(max_size, Image.Resampling.LANCZOS)
147
  return image
148
 
149
  def image_to_base64(image):
150
+ """Convert to base64"""
151
  buffered = BytesIO()
152
  image.save(buffered, format="PNG")
153
  return base64.b64encode(buffered.getvalue()).decode()
154
 
155
  def ocr_with_claude(claude_client, image, context=""):
156
+ """AI OCR"""
 
157
  resized = resize_image(image.copy())
158
  img_b64 = image_to_base64(resized)
159
 
160
+ prompt = f"""Transcribe handwritten math.
161
+ STYLE: Italian cursive
 
162
  LANGUAGE: English
 
163
  CONTEXT: {context[:2000] if context else ""}
 
 
 
 
 
 
 
164
  OUTPUT: Transcription only."""
165
 
166
  try:
167
  message = claude_client.messages.create(
168
  model="claude-sonnet-4-20250514",
169
  max_tokens=4000,
170
+ messages=[{
171
+ "role": "user",
172
+ "content": [
173
+ {"type": "image", "source": {"type": "base64", "media_type": "image/png", "data": img_b64}},
174
+ {"type": "text", "text": prompt}
175
+ ]
176
+ }]
 
 
177
  )
 
178
  return message.content[0].text, message.usage.input_tokens + message.usage.output_tokens
 
179
  except Exception as e:
180
  st.error(f"OCR error: {e}")
181
  return None, 0
182
 
183
  def chunk_text(text, chunk_size=150, overlap=30):
184
+ """Split into chunks"""
185
  words = text.split()
186
  chunks = []
187
  for i in range(0, len(words), chunk_size - overlap):
 
191
  return chunks
192
 
193
  def get_vector_count(qdrant):
194
+ """Get total vectors"""
195
  try:
196
  count = 0
197
  offset = None
 
214
  return 0
215
 
216
  # ============================================================================
217
+ # INITIALIZE
218
  # ============================================================================
219
 
220
  try:
 
222
  claude = get_claude_client()
223
  st.sidebar.success("βœ… System Ready")
224
  except Exception as e:
225
+ st.error(f"❌ Init failed: {e}")
226
+ st.info("Add secrets: QDRANT_URL, QDRANT_API_KEY, ANTHROPIC_API_KEY, HF_TOKEN")
227
  st.stop()
228
 
229
  # ============================================================================
230
  # SIDEBAR
231
  # ============================================================================
232
 
233
+ st.sidebar.title("πŸŽ“ Math AI")
234
+ st.sidebar.caption("Production v2.0")
235
 
236
  try:
237
  vector_count = get_vector_count(qdrant)
238
+ st.sidebar.metric("Vectors", f"{vector_count:,}")
239
+ st.sidebar.metric("Storage", f"{(vector_count * 384 * 4) / (1024 * 1024):.1f} MB")
 
 
240
  except:
241
+ st.sidebar.warning("DB unavailable")
242
 
243
  st.sidebar.markdown("---")
244
 
245
  # ============================================================================
246
+ # TABS
247
  # ============================================================================
248
 
249
+ tab1, tab2, tab3 = st.tabs(["πŸ“Š Dataset Manager", "πŸ” Search & Solve", "πŸ“ˆ Statistics"])
 
 
 
 
250
 
251
  # ============================================================================
252
+ # TAB 1: DATASET MANAGER
253
  # ============================================================================
254
 
255
  with tab1:
 
256
  st.title("πŸ“Š Dataset Manager")
 
257
 
 
258
  if not os.getenv("HF_TOKEN"):
259
+ st.error("⚠️ Add HF_TOKEN in Settings β†’ Secrets")
 
260
  st.stop()
261
 
262
  # Collection setup
263
+ st.header("πŸ—οΈ Database Setup")
 
 
264
 
265
+ try:
266
+ collections = qdrant.get_collections().collections
267
+ exists = any(c.name == COLLECTION_NAME for c in collections)
268
+
269
+ if exists:
270
+ st.success(f"βœ… Collection '{COLLECTION_NAME}' ready")
271
+ else:
272
+ st.warning("Collection doesn't exist")
273
 
274
+ selected_model = st.selectbox("Embedding model:", list(EMBEDDING_MODELS.keys()))
275
+
276
+ if st.button("🏗️ Create Collection"):
277
+ dimensions = EMBEDDING_MODELS[selected_model]["dimensions"]
278
+ qdrant.create_collection(
279
+ collection_name=COLLECTION_NAME,
280
+ vectors_config=VectorParams(size=dimensions, distance=Distance.COSINE)
 
 
 
 
 
 
 
 
 
281
  )
282
+ st.success("Created!")
283
+ st.session_state.embedding_model = EMBEDDING_MODELS[selected_model]["name"]
284
+ st.rerun()
285
+ except Exception as e:
286
+ st.error(f"Error: {e}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
287
 
288
  st.markdown("---")
289
 
290
+ # Processing config
291
+ st.header("⚙️ Configuration")
292
 
293
+ col1, col2 = st.columns(2)
294
 
295
  with col1:
296
+ chunk_size = st.slider("Chunk size:", 50, 500, 150)
297
+ chunk_overlap = st.slider("Overlap:", 0, 100, 30)
 
 
298
 
299
  with col2:
 
 
300
  current_model = st.session_state.get('embedding_model', EMBEDDING_MODELS["MiniLM-L6 (Fast, 384D)"]["name"])
301
+ st.info(f"**Active Model:**\n{current_model}")
302
+ use_context = st.checkbox("Use context for OCR", value=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
303
 
304
  st.markdown("---")
305
 
306
  # Data sources
307
+ st.header("📁 Data Sources")
 
 
 
 
 
308
 
309
+ source_tabs = st.tabs(["📂 Your Files", "🌐 Public Datasets"])
 
 
310
 
311
  with source_tabs[0]:
 
 
 
312
  folder_type = st.radio(
313
+ "Folder:",
314
+ ["📚 Books", "📝 Exams", "🖊️ Answers (OCR)"],
315
  horizontal=True
316
  )
317
 
 
318
  if "Books" in folder_type:
319
+ folder_path, doc_type = "books/", "book"
 
320
  elif "Exams" in folder_type:
321
+ folder_path, doc_type = "exams/", "exam"
 
322
  else:
323
+ folder_path, doc_type = "answers/", "answer_handwritten"
 
324
 
325
+ if st.button(f"🔍 Scan {folder_path}"):
326
+ with st.spinner("Scanning..."):
 
327
  files = list_dataset_files(folder_path)
328
 
329
  if files:
 
330
  file_status = []
331
  for file in files:
332
+ name = file.split('/')[-1]
333
+ processed = check_if_processed(qdrant, name)
334
+ file_status.append({"file": file, "name": name, "processed": processed})
 
 
 
 
335
 
336
  st.session_state.current_files = file_status
337
  st.session_state.current_folder = folder_path
338
  st.session_state.current_doc_type = doc_type
339
  else:
340
+ st.warning("No files found")
341
 
 
342
  if 'current_files' in st.session_state and st.session_state.current_folder == folder_path:
343
 
 
 
 
344
  processed_count = sum(1 for f in st.session_state.current_files if f['processed'])
345
  pending_count = len(st.session_state.current_files) - processed_count
346
 
347
  col1, col2, col3 = st.columns(3)
348
+ col1.metric("Total", len(st.session_state.current_files))
349
+ col2.metric("✅ Done", processed_count)
350
+ col3.metric("⏳ Pending", pending_count)
 
 
 
351
 
352
+ st.subheader("Select files:")
 
353
 
354
  selected_files = []
 
355
  for file_info in st.session_state.current_files:
356
+ if file_info['processed']:
357
+ st.checkbox(f"✅ {file_info['name']}", value=False, disabled=True, key=f"f_{file_info['name']}")
358
+ else:
359
+ if st.checkbox(f"⏳ {file_info['name']}", value=True, key=f"f_{file_info['name']}"):
360
+ selected_files.append(file_info)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
361
 
 
362
  if selected_files:
 
363
  st.markdown("---")
 
364
 
 
365
  if doc_type == "answer_handwritten":
366
+ est_cost = len(selected_files) * 5 * 0.08
367
+ st.warning(f"⚠️ OCR Cost: ~${est_cost:.2f}")
 
368
 
369
+ if st.button("🚀 PROCESS SELECTED", type="primary"):
370
 
 
371
  embedder = get_embedding_model(current_model)
372
 
 
373
  context_books = ""
374
+ if doc_type == "answer_handwritten" and use_context:
375
  try:
376
+ samples = qdrant.scroll(
377
  collection_name=COLLECTION_NAME,
378
  limit=10,
379
  with_payload=True,
380
  with_vectors=False,
381
  scroll_filter={"must": [{"key": "source_type", "match": {"value": "book"}}]}
382
  )
383
+ if samples and samples[0]:
384
+ context_books = "\n".join([p.payload['content'] for p in samples[0][:5]])
 
 
385
  except:
386
+ pass
387
 
 
388
  total_tokens = 0
389
  total_vectors = 0
390
 
391
  for file_info in selected_files:
 
392
  with st.expander(f"Processing {file_info['name']}", expanded=True):
 
393
  try:
 
394
  st.write("📥 Downloading...")
395
  local_path = download_from_dataset(file_info['file'])
396
 
397
  if not local_path:
 
398
  continue
399
 
 
400
  if doc_type == "answer_handwritten":
401
+ st.write("🖼️ Converting...")
 
402
  images = pdf_to_images(local_path)
403
 
404
  if not images:
 
405
  continue
406
 
407
  st.write(f"✅ {len(images)} pages")
408
 
409
+ transcribed = []
410
+ tokens = 0
 
411
 
412
+ for i, img in enumerate(images, 1):
413
+ st.write(f"🤖 OCR {i}/{len(images)}...")
414
+ trans, tok = ocr_with_claude(claude, img, context_books)
415
+ if trans:
416
+ transcribed.append(f"\n=== Page {i} ===\n\n{trans}")
417
+ tokens += tok
 
 
 
 
 
 
418
 
419
+ if not transcribed:
420
  st.error("OCR failed")
421
  continue
422
 
423
+ text = "\n\n".join(transcribed)
424
+ total_tokens += tokens
425
+ st.success(f"✅ {len(text):,} chars (${tokens * 0.000003:.3f})")
 
426
 
427
  else:
428
+ st.write("📖 Extracting...")
 
429
  text = extract_text_from_pdf(local_path)
 
430
  if not text:
 
431
  continue
 
432
  st.write(f"✅ {len(text):,} chars")
433
 
 
434
  chunks = chunk_text(text, chunk_size, chunk_overlap)
435
  st.write(f"✂️ {len(chunks)} chunks")
436
 
 
437
  st.write("🔢 Embedding...")
438
  embeddings = embedder.encode(chunks, show_progress_bar=False)
439
 
 
440
  points = []
441
  for i, (chunk, emb) in enumerate(zip(chunks, embeddings)):
442
  points.append(PointStruct(
 
446
  "content": chunk,
447
  "source_name": file_info['name'],
448
  "source_type": doc_type,
449
+ "chunk_index": i
 
450
  }
451
  ))
452
 
453
  qdrant.upsert(collection_name=COLLECTION_NAME, points=points)
454
  total_vectors += len(points)
455
+ st.success(f"✅ {len(points)} vectors!")
 
456
 
457
  except Exception as e:
458
  st.error(f"Error: {e}")
459
 
 
460
  st.balloons()
461
+ st.success(f"Done! {total_vectors:,} vectors | ${total_tokens * 0.000003:.2f}")
 
 
 
 
 
 
 
 
 
462
  st.session_state.pop('current_files', None)
463
  st.rerun()
464
 
 
 
 
 
465
  with source_tabs[1]:
 
 
 
466
  dataset_choice = st.selectbox(
467
+ "Dataset:",
468
+ ["GSM8K - Grade School Math", "MATH - Competition Math", "MathQA - Word Problems"]
 
 
 
 
469
  )
470
 
471
+ sample_size = st.slider("Samples:", 10, 2000, 100)
472
 
 
473
  dataset_name = dataset_choice.split(" - ")[0]
474
  already_loaded = check_if_processed(qdrant, dataset_name)
475
 
476
  if already_loaded:
477
+ st.success(f"✅ {dataset_name} loaded!")
 
478
  else:
479
+ if st.button(f"📥 Load {dataset_name}"):
 
480
  try:
481
  from datasets import load_dataset
482
 
483
  embedder = get_embedding_model(current_model)
484
 
485
+ with st.spinner("Loading..."):
 
486
  if "GSM8K" in dataset_choice:
487
  dataset = load_dataset("openai/gsm8k", "main", split="train", trust_remote_code=True)
488
  texts = [f"Problem: {dataset[i]['question']}\n\nSolution: {dataset[i]['answer']}"
489
  for i in range(min(sample_size, len(dataset)))]
 
490
  elif "MATH" in dataset_choice:
491
  dataset = load_dataset("lighteval/MATH", split="train", trust_remote_code=True)
492
  texts = [f"Problem: {dataset[i].get('problem', '')}\n\nSolution: {dataset[i].get('solution', '')}"
493
  for i in range(min(sample_size, len(dataset)))]
494
+ else:
 
495
  dataset = load_dataset("allenai/math_qa", split="train", trust_remote_code=True)
496
  texts = [f"Problem: {dataset[i]['Problem']}\n\nAnswer: {dataset[i]['correct']}"
497
  for i in range(min(sample_size, len(dataset)))]
498
 
499
+ st.write(f"✅ {len(texts)} problems")
500
 
 
 
501
  embeddings = embedder.encode(texts, show_progress_bar=True)
502
 
 
503
  points = []
504
  for i, (text, emb) in enumerate(zip(texts, embeddings)):
505
  points.append(PointStruct(
 
509
  "content": text[:2000],
510
  "source_name": dataset_name,
511
  "source_type": "public_dataset",
512
+ "index": i
 
513
  }
514
  ))
515
 
516
  qdrant.upsert(collection_name=COLLECTION_NAME, points=points)
517
+ st.success(f"✅ {len(points)} vectors!")
518
  st.balloons()
519
 
520
  except Exception as e:
 
525
  # ============================================================================
526
 
527
  with tab2:
 
528
  st.title("🔍 Search & Solve")
529
 
530
  problem = st.text_area(
531
+ "Problem:",
532
+ placeholder="Find gradient of L(w) = (1/2)||Xw - y||²",
533
  height=150
534
  )
535
 
536
  col1, col2 = st.columns(2)
537
+ col1.slider("Retrieve:", 3, 20, 5, key="top_k")
538
+ col2.select_slider("Detail:", ["Concise", "Standard", "Detailed", "Exhaustive"], value="Detailed", key="detail")
 
 
 
 
 
 
 
 
539
 
540
  if st.button("🚀 SOLVE", type="primary") and problem:
541
 
 
542
  current_model = st.session_state.get('embedding_model', EMBEDDING_MODELS["MiniLM-L6 (Fast, 384D)"]["name"])
543
  embedder = get_embedding_model(current_model)
544
 
 
549
  results = qdrant.search(
550
  collection_name=COLLECTION_NAME,
551
  query_vector=query_emb.tolist(),
552
+ limit=st.session_state.top_k
553
  )
554
  except:
555
  results = []
 
561
 
562
  with st.expander("📚 References"):
563
  for i, r in enumerate(results, 1):
564
+ st.markdown(f"**{i}.** ({r.score*100:.0f}%)")
565
  st.text(r.payload['content'][:200] + "...")
566
  st.caption(f"Source: {r.payload.get('source_name')}")
567
 
568
+ with st.spinner("Generating..."):
569
 
570
  context = "\n\n".join([r.payload['content'] for r in results])
571
 
572
+ prompt = f"""Solve using references.
573
 
574
  PROBLEM: {problem}
575
 
576
  REFERENCES: {context}
577
 
578
+ DETAIL: {st.session_state.detail}
579
 
580
  FORMAT:
581
  ## SOLUTION
582
  [Steps]
583
 
584
  ## REASONING
585
+ [Why]
586
 
587
  ## REFERENCES
588
+ [Sources]"""
589
 
590
  try:
591
  message = claude.messages.create(
 
598
  st.markdown(message.content[0].text)
599
 
600
  st.download_button(
601
+ "📥 Download",
602
  message.content[0].text,
603
  file_name=f"solution_{int(time.time())}.md"
604
  )
 
611
  # ============================================================================
612
 
613
  with tab3:
614
+ st.title("📈 Statistics")
 
615
 
616
  try:
617
  sample = qdrant.scroll(
 
631
  sources.add(point.payload.get('source_name', 'Unknown'))
632
 
633
  col1, col2, col3 = st.columns(3)
634
+ col1.metric("Vectors", get_vector_count(qdrant))
635
+ col2.metric("Sources", len(sources))
636
+ col3.metric("Types", len(types))
637
 
638
+ st.subheader("Distribution")
 
 
 
 
 
 
 
 
 
639
  for doc_type, count in sorted(types.items()):
640
  pct = count / sum(types.values()) * 100
641
+ st.progress(count / sum(types.values()), text=f"{doc_type}: {count} ({pct:.0f}%)")
642
 
643
  st.subheader("All Sources")
644
  for src in sorted(sources):
 
647
  except Exception as e:
648
  st.error(f"Error: {e}")
649
 
650
+ st.sidebar.caption("v2.0")