Hebaelsayed commited on
Commit
989d169
Β·
verified Β·
1 Parent(s): f75aa89

Update src/streamlit_app.py

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +418 -364
src/streamlit_app.py CHANGED
@@ -6,7 +6,7 @@ from qdrant_client.models import Distance, VectorParams, PointStruct
6
  from sentence_transformers import SentenceTransformer
7
 
8
  # ============================================================================
9
- # PHASE 2: DATABASE + TWO UPLOAD METHODS
10
  # ============================================================================
11
 
12
  st.set_page_config(
@@ -15,14 +15,76 @@ st.set_page_config(
15
  layout="wide"
16
  )
17
 
18
- st.title("πŸ—„οΈ Phase 2: Vector Database Setup")
19
- st.markdown("**Database creation + Upload for custom notes AND public datasets**")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
20
 
21
- # Initialize session state
22
  if 'db_created' not in st.session_state:
23
  st.session_state.db_created = False
24
- if 'embedder_loaded' not in st.session_state:
25
- st.session_state.embedder_loaded = False
 
 
 
 
 
 
 
 
 
 
 
26
 
27
  # ============================================================================
28
  # STEP 1: API Keys Check
@@ -37,370 +99,357 @@ qdrant_api_key = os.getenv("QDRANT_API_KEY")
37
  col1, col2, col3 = st.columns(3)
38
 
39
  with col1:
40
- if anthropic_key:
41
- st.success("βœ… Claude API")
42
- else:
43
- st.error("❌ Claude API")
44
 
45
  with col2:
 
46
  if qdrant_url:
47
- st.success(f"βœ… Qdrant URL")
48
  st.caption(qdrant_url[:30] + "...")
49
- else:
50
- st.error("❌ Qdrant URL")
51
 
52
  with col3:
53
- if qdrant_api_key:
54
- st.success("βœ… Qdrant API Key")
55
- else:
56
- st.error("❌ Qdrant API Key")
57
 
58
  if not all([anthropic_key, qdrant_url, qdrant_api_key]):
59
- st.warning("⚠️ Missing secrets! Add them in Settings β†’ Repository Secrets")
60
  st.stop()
61
 
62
- st.markdown("---")
63
 
64
- # ============================================================================
65
- # STEP 2: Connect to Qdrant
66
- # ============================================================================
67
-
68
- st.header("Step 2: Connect to Qdrant Database")
69
-
70
- col1, col2 = st.columns([2, 1])
71
-
72
- with col1:
73
- st.info("**Platform:** Qdrant Cloud (https://cloud.qdrant.io)")
74
- st.caption("This tests connection to your cloud database cluster")
75
 
76
- with col2:
77
- if st.button("πŸ”Œ Test Connection"):
78
- try:
79
- with st.spinner("Connecting..."):
80
- client = QdrantClient(url=qdrant_url, api_key=qdrant_api_key)
81
- collections = client.get_collections()
82
-
83
- st.success("βœ… Connected!")
84
- st.metric("Collections", len(collections.collections))
85
-
86
- st.session_state.qdrant_client = client
87
-
88
- except Exception as e:
89
- st.error(f"❌ Failed: {str(e)}")
90
 
91
  st.markdown("---")
92
 
93
  # ============================================================================
94
- # STEP 3: Create Collection
95
  # ============================================================================
96
 
97
- st.header("πŸ—οΈ Step 3: Create Vector Database Collection")
98
 
99
- st.info("""
100
- **πŸ–₯️ Where this happens:**
101
- - You click button HERE in your HF Space app
102
- - App creates collection in Qdrant Cloud
103
- - You can verify in Qdrant dashboard
 
 
 
 
 
 
 
 
104
 
105
- **What gets created:**
106
- - Collection name: `math_knowledge_base`
107
- - Vector dimensions: 384 (matches embedding model)
108
- - Distance metric: COSINE similarity
109
- """)
110
 
111
- collection_name = st.text_input(
112
- "Collection Name:",
113
- value="math_knowledge_base",
114
- help="This is your database name"
115
- )
116
 
117
- col1, col2 = st.columns([3, 1])
118
 
119
- with col1:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
120
  if st.button("πŸ—οΈ CREATE DATABASE COLLECTION", type="primary"):
121
  try:
122
- client = QdrantClient(url=qdrant_url, api_key=qdrant_api_key)
123
-
124
- collections = client.get_collections().collections
125
- exists = any(c.name == collection_name for c in collections)
126
-
127
- if exists:
128
- st.warning(f"Collection '{collection_name}' already exists!")
129
- if st.button("βœ… Use Existing"):
130
- st.session_state.db_created = True
131
- else:
132
- with st.spinner("Creating..."):
133
- client.create_collection(
134
- collection_name=collection_name,
135
- vectors_config=VectorParams(size=384, distance=Distance.COSINE)
136
  )
137
-
138
- st.success(f"πŸŽ‰ Created: **{collection_name}**")
139
- st.balloons()
140
- st.session_state.db_created = True
141
-
 
 
 
 
 
 
 
142
  except Exception as e:
143
  st.error(f"❌ Failed: {str(e)}")
144
 
145
- with col2:
146
- st.markdown("**Verify in:**")
147
- st.link_button("Open Qdrant", "https://cloud.qdrant.io", use_container_width=True)
148
-
149
  st.markdown("---")
150
 
151
  # ============================================================================
152
- # STEP 4: Load Embedding Model
153
  # ============================================================================
154
 
155
  st.header("πŸ€– Step 4: Load Embedding Model")
156
 
157
- st.info("""
158
- **πŸ–₯️ Where this happens:**
159
- - Downloads from Hugging Face Model Hub
160
- - Loads into YOUR HF Space's memory
161
- - Takes 30-60 seconds first time
162
-
163
- **Model:** `sentence-transformers/all-MiniLM-L6-v2`
164
- - Size: ~90MB
165
- - Output: 384 dimensions
166
- - Purpose: Convert text β†’ vectors
167
- """)
168
 
169
- if st.button("πŸ“₯ LOAD EMBEDDING MODEL", type="primary"):
170
- try:
171
- with st.spinner("⏳ Loading... (30-60 sec first time)"):
172
- model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')
173
- st.session_state.embedder = model
174
- st.session_state.embedder_loaded = True
175
-
176
- st.success("βœ… Model loaded!")
177
-
178
- # Test
179
- test_text = "Pythagorean theorem: aΒ² + bΒ² = cΒ²"
180
- test_embedding = model.encode(test_text)
181
-
182
- st.write(f"**Test embedding shape:** {test_embedding.shape}")
183
- st.caption(f"First 5 values: {test_embedding[:5]}")
 
 
 
 
184
 
185
- except Exception as e:
186
- st.error(f"❌ Failed: {str(e)}")
 
 
187
 
188
  st.markdown("---")
189
 
190
  # ============================================================================
191
- # STEP 5A: Upload Custom Notes (Manual Text)
192
  # ============================================================================
193
 
194
- st.header("πŸ“ Step 5A: Upload Custom Math Notes (Text)")
195
-
196
- st.success("**For:** Your handwritten notes (converted to text) or typed notes")
197
 
198
- st.info("""
199
- **πŸ–₯️ Where this happens:**
200
- 1. You paste text HERE in HF Space app
201
- 2. App chunks it into pieces
202
- 3. App converts to vectors (using model from Step 4)
203
- 4. App uploads to Qdrant Cloud database
204
- """)
205
-
206
- with st.expander("πŸ“„ Paste your custom math notes here", expanded=True):
207
-
208
- custom_text = st.text_area(
209
- "Your math content:",
210
- value="""Pythagorean Theorem:
211
- For right triangle: aΒ² + bΒ² = cΒ²
212
- Example: a=3, b=4 β†’ c=5
213
 
214
- Quadratic Formula:
215
- axΒ² + bx + c = 0
216
- x = (-b ± √(b²-4ac))/2a
217
 
218
  Derivatives:
219
  d/dx(xⁿ) = nxⁿ⁻¹
220
- d/dx(sin x) = cos x
221
- d/dx(eΛ£) = eΛ£""",
222
- height=200,
223
- key="custom_notes"
224
- )
225
-
226
- source_name = st.text_input("Note name:", value="my_notes.txt")
227
-
228
- if st.button("πŸš€ UPLOAD CUSTOM NOTES", type="primary"):
229
-
230
- if not st.session_state.get('embedder_loaded'):
231
- st.error("⚠️ Load embedding model first (Step 4)")
232
- st.stop()
233
 
234
- if not st.session_state.get('db_created'):
235
- st.error("⚠️ Create collection first (Step 3)")
236
- st.stop()
237
 
238
- try:
239
- client = QdrantClient(url=qdrant_url, api_key=qdrant_api_key)
240
- embedder = st.session_state.embedder
241
 
242
- with st.spinner("Processing..."):
243
-
244
- # Chunk
245
- words = custom_text.split()
246
- chunk_size = 50
247
- chunks = []
248
- for i in range(0, len(words), chunk_size-10):
249
- chunk = ' '.join(words[i:i + chunk_size])
250
- if chunk:
251
- chunks.append(chunk)
252
-
253
- st.write(f"βœ… Created {len(chunks)} chunks")
254
-
255
- # Embed
256
- embeddings = embedder.encode(chunks, show_progress_bar=False)
257
- st.write(f"βœ… Generated {len(embeddings)} embeddings")
258
-
259
- # Upload
260
- points = []
261
- for idx, (chunk, embedding) in enumerate(zip(chunks, embeddings)):
262
- points.append(PointStruct(
263
- id=abs(hash(f"{source_name}_{idx}")) % (2**63),
264
- vector=embedding.tolist(),
265
- payload={
266
- "content": chunk,
267
- "source_name": source_name,
268
- "source_type": "custom_notes",
269
- "chunk_index": idx
270
- }
271
- ))
272
-
273
- client.upsert(collection_name=collection_name, points=points)
274
-
275
- st.success(f"πŸŽ‰ Uploaded {len(points)} vectors!")
276
-
277
- # Show total
278
- info = client.get_collection(collection_name)
279
- st.info(f"πŸ“Š Total vectors in database: {info.vectors_count:,}")
280
-
281
- except Exception as e:
282
- st.error(f"❌ Failed: {str(e)}")
 
 
 
 
 
 
 
 
 
 
 
 
 
283
 
284
  st.markdown("---")
285
 
286
  # ============================================================================
287
- # STEP 5B: Load Public Datasets (MATH, RACE, GSM8K)
288
  # ============================================================================
289
 
290
  st.header("πŸ“š Step 5B: Load Public Datasets")
291
 
292
- st.success("**For:** MATH, RACE, GSM8K datasets from Hugging Face")
293
-
294
- st.info("""
295
- **πŸ–₯️ Where this happens:**
296
- 1. Select dataset HERE in HF Space app
297
- 2. App downloads from Hugging Face Datasets
298
- 3. App processes problems/solutions
299
- 4. App uploads to Qdrant Cloud database
300
-
301
- **Note:** These datasets are large! Start with small samples.
302
- """)
303
-
304
- with st.expander("πŸ“Š Load public math datasets", expanded=False):
305
-
306
- dataset_choice = st.selectbox(
307
- "Choose dataset:",
308
- ["GSM8K (8.5K problems)", "MATH (12.5K problems)", "RACE (28K questions)"]
309
- )
310
-
311
- sample_size = st.slider("Number of problems to load:", 10, 500, 50)
312
-
313
- st.warning(f"⚠️ Loading {sample_size} problems. Larger numbers take longer!")
314
-
315
- if st.button("πŸ“₯ LOAD PUBLIC DATASET", key="load_dataset"):
316
 
317
- if not st.session_state.get('embedder_loaded'):
318
- st.error("⚠️ Load embedding model first (Step 4)")
319
- st.stop()
 
 
 
 
 
320
 
321
- if not st.session_state.get('db_created'):
322
- st.error("⚠️ Create collection first (Step 3)")
323
- st.stop()
324
 
325
- try:
326
- from datasets import load_dataset
327
-
328
- client = QdrantClient(url=qdrant_url, api_key=qdrant_api_key)
329
- embedder = st.session_state.embedder
330
 
331
- with st.spinner(f"Loading {dataset_choice}..."):
 
332
 
333
- # Load appropriate dataset
334
- if "GSM8K" in dataset_choice:
335
- dataset = load_dataset("gsm8k", "main", split="train")
336
- dataset_name = "GSM8K"
337
 
338
- # Format data
339
- texts = []
340
- for i in range(min(sample_size, len(dataset))):
341
- item = dataset[i]
342
- text = f"Problem: {item['question']}\nSolution: {item['answer']}"
343
- texts.append(text)
344
-
345
- elif "MATH" in dataset_choice:
346
- dataset = load_dataset("hendrycks/competition_math", split="train")
347
- dataset_name = "MATH"
348
 
349
- texts = []
350
- for i in range(min(sample_size, len(dataset))):
351
- item = dataset[i]
352
- text = f"Problem ({item['type']}): {item['problem']}\nSolution: {item['solution']}"
353
- texts.append(text)
354
-
355
- else: # RACE
356
- dataset = load_dataset("race", "all", split="train")
357
- dataset_name = "RACE"
358
 
359
- texts = []
360
- for i in range(min(sample_size, len(dataset))):
361
- item = dataset[i]
362
- text = f"Article: {item['article']}\nQuestion: {item['question']}\nAnswer: {item['answer']}"
363
- texts.append(text)
364
-
365
- st.write(f"βœ… Loaded {len(texts)} items from {dataset_name}")
366
-
367
- # Generate embeddings
368
- progress_bar = st.progress(0)
369
- embeddings = []
370
-
371
- for idx, text in enumerate(texts):
372
- embedding = embedder.encode(text)
373
- embeddings.append(embedding)
374
- progress_bar.progress((idx + 1) / len(texts))
375
-
376
- st.write(f"βœ… Generated {len(embeddings)} embeddings")
377
-
378
- # Upload to Qdrant
379
- points = []
380
- for idx, (text, embedding) in enumerate(zip(texts, embeddings)):
381
- points.append(PointStruct(
382
- id=abs(hash(f"{dataset_name}_{idx}")) % (2**63),
383
- vector=embedding.tolist(),
384
- payload={
385
- "content": text[:1000], # Truncate if too long
386
- "source_name": dataset_name,
387
- "source_type": "public_dataset",
388
- "dataset": dataset_name,
389
- "index": idx
390
- }
391
- ))
392
-
393
- client.upsert(collection_name=collection_name, points=points)
394
-
395
- st.success(f"πŸŽ‰ Uploaded {len(points)} vectors from {dataset_name}!")
396
-
397
- # Show total
398
- info = client.get_collection(collection_name)
399
- st.info(f"πŸ“Š Total vectors in database: {info.vectors_count:,}")
400
-
401
- except Exception as e:
402
- st.error(f"❌ Failed: {str(e)}")
403
- st.exception(e)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
404
 
405
  st.markdown("---")
406
 
@@ -410,64 +459,55 @@ st.markdown("---")
410
 
411
  st.header("πŸ” Step 6: Test Search")
412
 
413
- st.info("""
414
- **πŸ–₯️ Search happens:**
415
- 1. You enter question HERE
416
- 2. App converts to vector
417
- 3. App searches Qdrant Cloud
418
- 4. Returns most similar chunks
419
- """)
420
-
421
- search_query = st.text_input(
422
- "Ask a question:",
423
- placeholder="What is the Pythagorean theorem?"
424
- )
425
-
426
- top_k = st.slider("Results:", 1, 10, 3)
427
-
428
- if st.button("πŸ” SEARCH", type="primary") and search_query:
429
 
430
- if not st.session_state.get('embedder_loaded'):
431
- st.error("⚠️ Load embedding model first")
432
- st.stop()
433
 
434
- try:
435
- client = QdrantClient(url=qdrant_url, api_key=qdrant_api_key)
436
- embedder = st.session_state.embedder
437
 
438
- with st.spinner("Searching..."):
439
- query_embedding = embedder.encode(search_query)
440
-
441
- results = client.search(
442
- collection_name=collection_name,
443
- query_vector=query_embedding.tolist(),
444
- limit=top_k
445
- )
446
-
447
- if results:
448
- st.success(f"βœ… Found {len(results)} results!")
449
 
450
- for i, result in enumerate(results, 1):
451
- similarity_pct = result.score * 100
 
 
 
 
 
 
 
452
 
453
- with st.expander(f"Result {i} - {similarity_pct:.1f}% match", expanded=(i==1)):
454
- st.info(result.payload['content'])
455
 
456
- col1, col2 = st.columns(2)
457
- with col1:
458
- st.caption(f"Source: {result.payload['source_name']}")
459
- with col2:
460
- st.caption(f"Type: {result.payload['source_type']}")
461
- else:
462
- st.warning("No results found")
463
-
464
- except Exception as e:
465
- st.error(f"❌ Search failed: {str(e)}")
 
 
 
466
 
467
  st.markdown("---")
468
 
469
  # ============================================================================
470
- # Progress Dashboard
471
  # ============================================================================
472
 
473
  st.header("βœ… Progress Dashboard")
@@ -475,21 +515,35 @@ st.header("βœ… Progress Dashboard")
475
  col1, col2, col3 = st.columns(3)
476
 
477
  with col1:
478
- st.metric("Database", "βœ…" if st.session_state.get('db_created') else "❌")
479
 
480
  with col2:
481
- st.metric("Embedder", "βœ…" if st.session_state.get('embedder_loaded') else "❌")
482
 
483
  with col3:
484
- try:
485
- if st.session_state.get('db_created'):
486
- client = QdrantClient(url=qdrant_url, api_key=qdrant_api_key)
487
- info = client.get_collection(collection_name)
488
- st.metric("Vectors", f"{info.vectors_count:,}")
489
- else:
490
- st.metric("Vectors", "N/A")
491
- except:
492
- st.metric("Vectors", "?")
493
-
494
- if st.session_state.get('db_created') and st.session_state.get('embedder_loaded'):
495
- st.success("πŸŽ‰ Phase 2 Complete! Ready for Phase 3: PDF Upload + Full RAG")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6
  from sentence_transformers import SentenceTransformer
7
 
8
  # ============================================================================
9
+ # CONFIGURATION - RUNS ONCE
10
  # ============================================================================
11
 
12
  st.set_page_config(
 
15
  layout="wide"
16
  )
17
 
18
+ # Collection name - centralized
19
+ COLLECTION_NAME = "math_knowledge_base"
20
+
21
+ # ============================================================================
22
+ # CACHED FUNCTIONS - LOAD ONCE, REUSE FOREVER
23
+ # ============================================================================
24
+
25
+ @st.cache_resource
26
+ def get_qdrant_client():
27
+ """Cache Qdrant client - only connects once"""
28
+ qdrant_url = os.getenv("QDRANT_URL")
29
+ qdrant_api_key = os.getenv("QDRANT_API_KEY")
30
+
31
+ if not qdrant_url or not qdrant_api_key:
32
+ return None
33
+
34
+ return QdrantClient(url=qdrant_url, api_key=qdrant_api_key)
35
+
36
+ @st.cache_resource
37
+ def get_embedding_model():
38
+ """Cache embedding model - only loads once"""
39
+ try:
40
+ model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')
41
+ return model
42
+ except Exception as e:
43
+ st.error(f"Failed to load model: {e}")
44
+ return None
45
+
46
+ @st.cache_data(ttl=10) # Cache for 10 seconds
47
+ def get_vector_count(_client, collection_name):
48
+ """Get vector count with caching"""
49
+ try:
50
+ info = _client.get_collection(collection_name)
51
+ # Handle both old and new Qdrant API versions
52
+ if hasattr(info, 'vectors_count'):
53
+ return info.vectors_count
54
+ elif hasattr(info, 'points_count'):
55
+ return info.points_count
56
+ else:
57
+ return 0
58
+ except:
59
+ return 0
60
+
61
+ def check_collection_exists(client, collection_name):
62
+ """Check if collection exists"""
63
+ try:
64
+ collections = client.get_collections().collections
65
+ return any(c.name == collection_name for c in collections)
66
+ except:
67
+ return False
68
+
69
+ # ============================================================================
70
+ # INITIALIZE SESSION STATE
71
+ # ============================================================================
72
 
 
73
  if 'db_created' not in st.session_state:
74
  st.session_state.db_created = False
75
+
76
+ if 'embedder_ready' not in st.session_state:
77
+ st.session_state.embedder_ready = False
78
+
79
+ if 'manual_db_check' not in st.session_state:
80
+ st.session_state.manual_db_check = False
81
+
82
+ # ============================================================================
83
+ # MAIN APP
84
+ # ============================================================================
85
+
86
+ st.title("πŸ—„οΈ Phase 2: Vector Database Setup")
87
+ st.markdown("**Optimized: Components load once and stay cached!**")
88
 
89
  # ============================================================================
90
  # STEP 1: API Keys Check
 
99
  col1, col2, col3 = st.columns(3)
100
 
101
  with col1:
102
+ st.metric("Claude API", "βœ…" if anthropic_key else "❌")
 
 
 
103
 
104
  with col2:
105
+ st.metric("Qdrant URL", "βœ…" if qdrant_url else "❌")
106
  if qdrant_url:
 
107
  st.caption(qdrant_url[:30] + "...")
 
 
108
 
109
  with col3:
110
+ st.metric("Qdrant Key", "βœ…" if qdrant_api_key else "❌")
 
 
 
111
 
112
  if not all([anthropic_key, qdrant_url, qdrant_api_key]):
113
+ st.error("⚠️ Missing secrets! Add in Settings β†’ Repository Secrets")
114
  st.stop()
115
 
116
+ st.success("βœ… All API keys configured!")
117
 
118
+ # Get cached client
119
+ client = get_qdrant_client()
 
 
 
 
 
 
 
 
 
120
 
121
+ if not client:
122
+ st.error("Failed to create Qdrant client")
123
+ st.stop()
 
 
 
 
 
 
 
 
 
 
 
124
 
125
  st.markdown("---")
126
 
127
  # ============================================================================
128
+ # STEP 2: Auto-check Connection
129
  # ============================================================================
130
 
131
+ st.header("Step 2: Qdrant Connection Status")
132
 
133
+ try:
134
+ collections = client.get_collections()
135
+ st.success(f"βœ… Connected to Qdrant! Found {len(collections.collections)} collections")
136
+
137
+ # Auto-check if our collection exists
138
+ if check_collection_exists(client, COLLECTION_NAME):
139
+ st.info(f"πŸ“Š Collection '{COLLECTION_NAME}' exists!")
140
+ st.session_state.db_created = True
141
+ st.session_state.manual_db_check = True
142
+
143
+ except Exception as e:
144
+ st.error(f"❌ Connection failed: {str(e)}")
145
+ st.stop()
146
 
147
+ st.markdown("---")
 
 
 
 
148
 
149
+ # ============================================================================
150
+ # STEP 3: Create Collection (FIXED)
151
+ # ============================================================================
 
 
152
 
153
+ st.header("πŸ—οΈ Step 3: Create Database Collection")
154
 
155
+ # Show current status
156
+ if st.session_state.db_created:
157
+ st.success(f"βœ… Collection '{COLLECTION_NAME}' is ready to use!")
158
+
159
+ col1, col2 = st.columns(2)
160
+ with col1:
161
+ if st.button("πŸ”„ Recreate Collection (Delete & Rebuild)"):
162
+ try:
163
+ client.delete_collection(COLLECTION_NAME)
164
+ st.session_state.db_created = False
165
+ st.session_state.manual_db_check = False
166
+ st.rerun()
167
+ except Exception as e:
168
+ st.error(f"Delete failed: {e}")
169
+
170
+ with col2:
171
+ if st.button("ℹ️ Show Collection Info"):
172
+ try:
173
+ info = client.get_collection(COLLECTION_NAME)
174
+ st.json({
175
+ "name": COLLECTION_NAME,
176
+ "vectors": get_vector_count(client, COLLECTION_NAME),
177
+ "status": "Ready"
178
+ })
179
+ except Exception as e:
180
+ st.error(f"Error: {e}")
181
+
182
+ else:
183
+ # Collection doesn't exist - show create button
184
+ st.info(f"Collection '{COLLECTION_NAME}' does not exist yet.")
185
+
186
  if st.button("πŸ—οΈ CREATE DATABASE COLLECTION", type="primary"):
187
  try:
188
+ with st.spinner("Creating collection..."):
189
+ client.create_collection(
190
+ collection_name=COLLECTION_NAME,
191
+ vectors_config=VectorParams(
192
+ size=384,
193
+ distance=Distance.COSINE
 
 
 
 
 
 
 
 
194
  )
195
+ )
196
+
197
+ st.success(f"πŸŽ‰ Created collection: {COLLECTION_NAME}")
198
+ st.balloons()
199
+
200
+ # Update state
201
+ st.session_state.db_created = True
202
+ st.session_state.manual_db_check = True
203
+
204
+ # Force reload
205
+ st.rerun()
206
+
207
  except Exception as e:
208
  st.error(f"❌ Failed: {str(e)}")
209
 
 
 
 
 
210
  st.markdown("---")
211
 
212
  # ============================================================================
213
+ # STEP 4: Load Embedding Model (CACHED - LOADS ONCE)
214
  # ============================================================================
215
 
216
  st.header("πŸ€– Step 4: Load Embedding Model")
217
 
218
+ # Try to get cached model
219
+ embedder = get_embedding_model()
 
 
 
 
 
 
 
 
 
220
 
221
+ if embedder is not None:
222
+ st.success("βœ… Embedding model loaded and cached!")
223
+ st.session_state.embedder_ready = True
224
+
225
+ # Show test
226
+ with st.expander("πŸ§ͺ Model Test"):
227
+ test_text = "Pythagorean theorem: aΒ² + bΒ² = cΒ²"
228
+ test_embedding = embedder.encode(test_text)
229
+ st.write(f"**Shape:** {test_embedding.shape}")
230
+ st.write(f"**Sample values:** {test_embedding[:5]}")
231
+ else:
232
+ st.warning("⚠️ Model not loaded yet")
233
+
234
+ if st.button("πŸ“₯ LOAD EMBEDDING MODEL", type="primary"):
235
+ st.info("Loading model... (30-60 seconds first time)")
236
+ with st.spinner("Loading..."):
237
+ # Clear cache and reload
238
+ get_embedding_model.clear()
239
+ embedder = get_embedding_model()
240
 
241
+ if embedder:
242
+ st.success("βœ… Model loaded!")
243
+ st.session_state.embedder_ready = True
244
+ st.rerun()
245
 
246
  st.markdown("---")
247
 
248
  # ============================================================================
249
+ # STEP 5A: Upload Custom Text (FIXED)
250
  # ============================================================================
251
 
252
+ st.header("πŸ“ Step 5A: Upload Custom Math Notes")
 
 
253
 
254
+ # Check prerequisites
255
+ if not st.session_state.db_created:
256
+ st.warning("⚠️ Please create collection first (Step 3)")
257
+ elif not st.session_state.embedder_ready:
258
+ st.warning("⚠️ Please load embedding model first (Step 4)")
259
+ else:
260
+ with st.expander("✍️ Upload text", expanded=True):
261
+
262
+ custom_text = st.text_area(
263
+ "Paste your math notes:",
264
+ value="""Pythagorean Theorem: aΒ² + bΒ² = cΒ²
265
+ Example: If a=3, b=4, then c=5
 
 
 
266
 
267
+ Quadratic Formula: x = (-b ± √(b²-4ac))/2a
268
+ For axΒ² + bx + c = 0
 
269
 
270
  Derivatives:
271
  d/dx(xⁿ) = nxⁿ⁻¹
272
+ d/dx(sin x) = cos x""",
273
+ height=150
274
+ )
 
 
 
 
 
 
 
 
 
 
275
 
276
+ source_name = st.text_input("Note name:", value="my_math_notes.txt")
 
 
277
 
278
+ if st.button("πŸš€ UPLOAD TEXT", type="primary", key="upload_text"):
 
 
279
 
280
+ if not custom_text.strip():
281
+ st.error("Please enter some text!")
282
+ else:
283
+ try:
284
+ with st.spinner("Processing..."):
285
+
286
+ # Chunk text
287
+ words = custom_text.split()
288
+ chunk_size = 50
289
+ overlap = 10
290
+ chunks = []
291
+
292
+ for i in range(0, len(words), chunk_size - overlap):
293
+ chunk = ' '.join(words[i:i + chunk_size])
294
+ if chunk.strip():
295
+ chunks.append(chunk)
296
+
297
+ st.write(f"πŸ“„ Created {len(chunks)} chunks")
298
+
299
+ # Generate embeddings
300
+ embeddings = embedder.encode(chunks, show_progress_bar=False)
301
+ st.write(f"πŸ”’ Generated {len(embeddings)} embeddings")
302
+
303
+ # Upload to Qdrant
304
+ points = []
305
+ for idx, (chunk, embedding) in enumerate(zip(chunks, embeddings)):
306
+ points.append(PointStruct(
307
+ id=abs(hash(f"{source_name}_{idx}_{custom_text[:20]}")) % (2**63),
308
+ vector=embedding.tolist(),
309
+ payload={
310
+ "content": chunk,
311
+ "source_name": source_name,
312
+ "source_type": "custom_notes",
313
+ "chunk_index": idx
314
+ }
315
+ ))
316
+
317
+ client.upsert(
318
+ collection_name=COLLECTION_NAME,
319
+ points=points
320
+ )
321
+
322
+ st.success(f"πŸŽ‰ Uploaded {len(points)} vectors!")
323
+
324
+ # Show updated count
325
+ total = get_vector_count(client, COLLECTION_NAME)
326
+ st.info(f"πŸ“Š Total vectors in database: {total}")
327
+
328
+ # Clear cache to refresh count
329
+ get_vector_count.clear()
330
+
331
+ except Exception as e:
332
+ st.error(f"❌ Upload failed: {str(e)}")
333
+ st.exception(e)
334
 
335
  st.markdown("---")
336
 
337
  # ============================================================================
338
+ # STEP 5B: Load Public Datasets (FIXED)
339
  # ============================================================================
340
 
341
  st.header("πŸ“š Step 5B: Load Public Datasets")
342
 
343
+ if not st.session_state.db_created:
344
+ st.warning("⚠️ Please create collection first (Step 3)")
345
+ elif not st.session_state.embedder_ready:
346
+ st.warning("⚠️ Please load embedding model first (Step 4)")
347
+ else:
348
+ with st.expander("πŸ“Š Load datasets from Hugging Face", expanded=False):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
349
 
350
+ dataset_choice = st.selectbox(
351
+ "Choose dataset:",
352
+ [
353
+ "GSM8K - Grade School Math (8.5K problems)",
354
+ "MATH - Competition Math (12.5K problems)",
355
+ "RACE - Reading Comprehension (28K passages)"
356
+ ]
357
+ )
358
 
359
+ sample_size = st.slider("Number of items to load:", 10, 500, 50)
 
 
360
 
361
+ st.warning(f"⚠️ Loading {sample_size} items. First time takes longer!")
362
+
363
+ if st.button("πŸ“₯ LOAD DATASET", type="primary", key="load_dataset"):
 
 
364
 
365
+ try:
366
+ from datasets import load_dataset
367
 
368
+ with st.spinner(f"Loading {dataset_choice.split('-')[0].strip()}..."):
 
 
 
369
 
370
+ # Determine dataset
371
+ if "GSM8K" in dataset_choice:
372
+ dataset = load_dataset("openai/gsm8k", "main", split="train", trust_remote_code=True)
373
+ dataset_name = "GSM8K"
374
+
375
+ texts = []
376
+ for i in range(min(sample_size, len(dataset))):
377
+ item = dataset[i]
378
+ text = f"Problem: {item['question']}\n\nSolution: {item['answer']}"
379
+ texts.append(text)
380
 
381
+ elif "MATH" in dataset_choice:
382
+ dataset = load_dataset("hendrycks/competition_math", split="train", trust_remote_code=True)
383
+ dataset_name = "MATH"
384
+
385
+ texts = []
386
+ for i in range(min(sample_size, len(dataset))):
387
+ item = dataset[i]
388
+ text = f"Problem ({item['type']}): {item['problem']}\n\nSolution: {item['solution']}"
389
+ texts.append(text)
390
 
391
+ else: # RACE
392
+ dataset = load_dataset("ehovy/race", "all", split="train", trust_remote_code=True)
393
+ dataset_name = "RACE"
394
+
395
+ texts = []
396
+ for i in range(min(sample_size, len(dataset))):
397
+ item = dataset[i]
398
+ text = f"Article: {item['article']}\n\nQuestion: {item['question']}\n\nAnswer: {item['answer']}"
399
+ texts.append(text)
400
+
401
+ st.write(f"βœ… Loaded {len(texts)} items from {dataset_name}")
402
+
403
+ # Generate embeddings with progress
404
+ progress_bar = st.progress(0)
405
+ status_text = st.empty()
406
+
407
+ embeddings = []
408
+ for idx, text in enumerate(texts):
409
+ embedding = embedder.encode(text)
410
+ embeddings.append(embedding)
411
+
412
+ progress_bar.progress((idx + 1) / len(texts))
413
+ status_text.text(f"Embedding {idx + 1}/{len(texts)}")
414
+
415
+ status_text.empty()
416
+ st.write(f"βœ… Generated {len(embeddings)} embeddings")
417
+
418
+ # Upload to Qdrant
419
+ points = []
420
+ for idx, (text, embedding) in enumerate(zip(texts, embeddings)):
421
+ # Truncate long texts
422
+ content = text[:2000] if len(text) > 2000 else text
423
+
424
+ points.append(PointStruct(
425
+ id=abs(hash(f"{dataset_name}_{idx}")) % (2**63),
426
+ vector=embedding.tolist(),
427
+ payload={
428
+ "content": content,
429
+ "source_name": dataset_name,
430
+ "source_type": "public_dataset",
431
+ "dataset": dataset_name,
432
+ "index": idx
433
+ }
434
+ ))
435
+
436
+ client.upsert(
437
+ collection_name=COLLECTION_NAME,
438
+ points=points
439
+ )
440
+
441
+ st.success(f"πŸŽ‰ Uploaded {len(points)} vectors from {dataset_name}!")
442
+
443
+ # Show updated count (FIXED)
444
+ get_vector_count.clear() # Clear cache
445
+ total = get_vector_count(client, COLLECTION_NAME)
446
+ st.info(f"πŸ“Š Total vectors in database: {total}")
447
+
448
+ except ImportError:
449
+ st.error("❌ 'datasets' library not installed. Add 'datasets' to requirements.txt")
450
+ except Exception as e:
451
+ st.error(f"❌ Failed: {str(e)}")
452
+ st.exception(e)
453
 
454
  st.markdown("---")
455
 
 
459
 
460
  st.header("πŸ” Step 6: Test Search")
461
 
462
+ if not st.session_state.db_created or not st.session_state.embedder_ready:
463
+ st.warning("⚠️ Complete Steps 3 & 4 first")
464
+ else:
465
+ search_query = st.text_input(
466
+ "Ask a question:",
467
+ placeholder="What is the Pythagorean theorem?"
468
+ )
 
 
 
 
 
 
 
 
 
469
 
470
+ top_k = st.slider("Number of results:", 1, 10, 3)
 
 
471
 
472
+ if st.button("πŸ” SEARCH", type="primary") and search_query:
 
 
473
 
474
+ try:
475
+ with st.spinner("Searching..."):
476
+
477
+ # Generate query embedding
478
+ query_embedding = embedder.encode(search_query)
 
 
 
 
 
 
479
 
480
+ # Search Qdrant
481
+ results = client.search(
482
+ collection_name=COLLECTION_NAME,
483
+ query_vector=query_embedding.tolist(),
484
+ limit=top_k
485
+ )
486
+
487
+ if results:
488
+ st.success(f"βœ… Found {len(results)} results!")
489
 
490
+ for i, result in enumerate(results, 1):
491
+ similarity_pct = result.score * 100
492
 
493
+ with st.expander(f"πŸ“„ Result {i} - {similarity_pct:.1f}% match", expanded=(i==1)):
494
+ st.info(result.payload['content'])
495
+
496
+ col1, col2 = st.columns(2)
497
+ with col1:
498
+ st.caption(f"**Source:** {result.payload['source_name']}")
499
+ with col2:
500
+ st.caption(f"**Type:** {result.payload['source_type']}")
501
+ else:
502
+ st.warning("No results found. Upload more data!")
503
+
504
+ except Exception as e:
505
+ st.error(f"❌ Search failed: {str(e)}")
506
 
507
  st.markdown("---")
508
 
509
  # ============================================================================
510
+ # PROGRESS DASHBOARD (FIXED)
511
  # ============================================================================
512
 
513
  st.header("βœ… Progress Dashboard")
 
515
  col1, col2, col3 = st.columns(3)
516
 
517
  with col1:
518
+ st.metric("Database", "βœ… Ready" if st.session_state.db_created else "❌ Not Created")
519
 
520
  with col2:
521
+ st.metric("Embedder", "βœ… Ready" if st.session_state.embedder_ready else "❌ Not Loaded")
522
 
523
  with col3:
524
+ if st.session_state.db_created:
525
+ vector_count = get_vector_count(client, COLLECTION_NAME)
526
+ st.metric("Vectors", f"{vector_count:,}" if vector_count else "0")
527
+ else:
528
+ st.metric("Vectors", "N/A")
529
+
530
+ # Success message
531
+ if st.session_state.db_created and st.session_state.embedder_ready:
532
+ st.success("πŸŽ‰ Phase 2 Complete! Ready for Phase 3: PDF Upload + Full RAG")
533
+
534
+ # Debug panel
535
+ with st.expander("πŸ”§ Debug Info"):
536
+ st.json({
537
+ "db_created": st.session_state.db_created,
538
+ "embedder_ready": st.session_state.embedder_ready,
539
+ "collection_name": COLLECTION_NAME,
540
+ "cached_client": client is not None,
541
+ "cached_embedder": embedder is not None
542
+ })
543
+
544
+ if st.button("πŸ”„ Clear All Caches & Restart"):
545
+ get_qdrant_client.clear()
546
+ get_embedding_model.clear()
547
+ get_vector_count.clear()
548
+ st.session_state.clear()
549
+ st.rerun()