mnhat19 commited on
Commit
b59fc2c
·
1 Parent(s): 1149349

feat: full optimization - Groq LLM, disease cache, deploy configs

Browse files

- Switch backend LLM from Gemini to Groq (llama-3.3-70b-versatile)
- Add GroqKeyManager: round-robin key rotation with 429 cooldown
- Add DiseaseCache: SQLite 7-day TTL cache - 0 LLM calls on cache HIT
- SessionStore: WAL mode + cleanup_expired() on startup
- Fix nhan_xet_tong_quan field in evaluation JSON (was 'nhan_xet')
- Restore proper Vietnamese diacritics in all LLM prompts
- Migrate Tailwind from CDN to PostCSS pipeline
- Add tenacity retry on rate limit errors
- ragService.ts: use VITE_RAG_API_URL env var (not hardcoded localhost)
- Add render.yaml (Render.com free tier backend deploy)
- Add vercel.json (Vercel frontend deploy)
- Add rag_project/Dockerfile (HuggingFace Spaces alternative)
- Update .gitignore: exclude secrets, DB files, training data

.env.example ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Backend Environment Variables
2
+ # Copy this file to .env and fill in your values
3
+ # DO NOT COMMIT the actual .env file — it contains secrets!
4
+
5
+ # ── Groq API Keys (get free keys at console.groq.com) ──
6
+ # Two keys for round-robin rotation to stay within rate limits
7
+ GROQ_API_KEY_1=gsk_your_first_key_here
8
+ GROQ_API_KEY_2=gsk_your_second_key_here
9
+ GROQ_MODEL=llama-3.3-70b-versatile
10
+
11
+ # ── Google Gemini key (used by frontend only — optional for backend) ──
12
+ GOOGLE_API_KEY=your_google_api_key_here
13
+
14
+ # ── CORS: comma-separated list of allowed frontend origins ──
15
+ # Local dev: http://localhost:3000,http://localhost:5173
16
+ # Production: https://your-app.vercel.app
17
+ ALLOWED_ORIGINS=http://localhost:3000,http://localhost:5173
18
+
19
+ # ── Optional API protection (leave blank to disable) ──
20
+ API_SECRET_KEY=
21
+ # For production, change ALLOWED_ORIGINS above to your deployed frontend URL,
22
+ # e.g. ALLOWED_ORIGINS=https://your-app.vercel.app
23
+
24
+ # API_SECRET_KEY (above): set a long random value to require an X-API-Key
25
+ # header on protected endpoints; leave empty to disable auth (development only).
26
+
.gitignore CHANGED
@@ -2,6 +2,13 @@
2
  .env
3
  *.env
4
 
 
 
 
 
 
 
 
5
  # python
6
  __pycache__/
7
  *.pyc
 
2
  .env
3
  *.env
4
 
5
+ # SQLite databases (generated at runtime)
6
+ sessions.db
7
+ disease_cache.db
8
+
9
+ # Training data (not needed for API)
10
+ pediatric_finetune_15k_vietnamese.jsonl
11
+
12
  # python
13
  __pycache__/
14
  *.pyc
Dockerfile ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ── HuggingFace Spaces compatible Dockerfile ──────────────────────────────
2
+ # Port MUST be 7860 for HF Spaces.
3
+ # Runs as user 1000 (HF requirement).
4
+
5
+ FROM python:3.10-slim
6
+
7
+ # System deps
8
+ RUN apt-get update && apt-get install -y --no-install-recommends \
9
+ build-essential curl git \
10
+ && rm -rf /var/lib/apt/lists/*
11
+
12
+ # Create non-root user (required by HF Spaces)
13
+ RUN useradd -m -u 1000 appuser
14
+
15
+ WORKDIR /app
16
+
17
+ # Install Python deps first (better layer caching)
18
+ COPY requirements_api.txt .
19
+ RUN pip install --no-cache-dir -r requirements_api.txt
20
+
21
+ # Copy application files
22
+ COPY . .
23
+
24
+ # Set ownership
25
+ RUN chown -R appuser:appuser /app
26
+
27
+ USER appuser
28
+
29
+ # Pre-build FAISS index at image build time (bakes the 4.4 MB index in)
30
+ # Sentence-transformer model is downloaded here and cached in /home/appuser/.cache
31
+ RUN python src/build_faiss.py
32
+
33
+ # Environment defaults (override via HF Space secrets)
34
+ ENV GROQ_API_KEY_1=""
35
+ ENV GROQ_API_KEY_2=""
36
+ ENV GROQ_MODEL="llama-3.3-70b-versatile"
37
+ ENV ALLOWED_ORIGINS="*"
38
+ ENV PORT=7860
39
+
40
+ # Expose HF Spaces port
41
+ EXPOSE 7860
42
+
43
+ CMD ["python", "api_server_fastapi.py"]
Dockerfile.backend ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ FROM python:3.11-slim
2
+ WORKDIR /app
3
+
4
+ # Install dependencies
5
+ COPY requirements_api.txt .
6
+ RUN pip install --no-cache-dir -r requirements_api.txt
7
+
8
+ # Copy source
9
+ COPY . .
10
+
11
+ EXPOSE 8001
12
+ CMD ["python", "api_server_fastapi.py"]
api_server_fastapi.py CHANGED
@@ -5,7 +5,7 @@ Endpoints:
5
  - GET /api/diseases - Lấy danh sách bệnh từ JSON
6
  - POST /api/start-case - Nhận bệnh, tạo case với triệu chứng
7
  - POST /api/evaluate - Nhận đáp án user, trả về kết quả so sánh
8
- - Docs: http://localhost:5000/docs (Swagger UI)
9
  """
10
  import sys
11
  import io
@@ -14,8 +14,11 @@ import io
14
  sys.stdout = io.TextIOWrapper(sys.stdout.buffer, encoding='utf-8', errors='replace')
15
  sys.stderr = io.TextIOWrapper(sys.stderr.buffer, encoding='utf-8', errors='replace')
16
 
17
- from fastapi import FastAPI, HTTPException
 
 
18
  from fastapi.middleware.cors import CORSMiddleware
 
19
  from pydantic import BaseModel
20
  from typing import Optional, List, Dict, Any
21
  import json
@@ -31,6 +34,8 @@ from config import Config
31
  from doctor_evaluator import DoctorEvaluator
32
  from vector_store import VectorStoreManager
33
  from rag_chain import RAGChain
 
 
34
 
35
  app = FastAPI(
36
  title="Medical RAG API",
@@ -40,15 +45,31 @@ app = FastAPI(
40
  redoc_url="/redoc"
41
  )
42
 
43
- # Configure CORS
 
 
 
 
 
44
  app.add_middleware(
45
  CORSMiddleware,
46
- allow_origins=["*"], # In production, specify exact origins
47
  allow_credentials=True,
48
  allow_methods=["*"],
49
  allow_headers=["*"],
50
  )
51
 
 
 
 
 
 
 
 
 
 
 
 
52
  # Initialize RAG system
53
  print("[*] Initializing RAG system...")
54
  vs_manager = VectorStoreManager()
@@ -60,8 +81,12 @@ rag = RAGChain(vs_manager)
60
  evaluator = DoctorEvaluator(rag)
61
  print("[OK] RAG system ready!")
62
 
63
- # Store active sessions
64
- active_sessions: Dict[str, Dict[str, Any]] = {}
 
 
 
 
65
 
66
 
67
  # Pydantic models for request/response
@@ -189,64 +214,85 @@ async def get_diseases(
189
  raise HTTPException(status_code=500, detail=str(e))
190
 
191
 
192
- @app.post("/api/start-case", response_model=StartCaseResponse)
193
  async def start_case(request: StartCaseRequest):
194
  """
195
- Gọi các hàm SẴN như trong main.py
196
- 1. find_symptoms() - RAG tìm triệu chứng
197
- 2. generate_case() - Gemini tạo case
198
- 3. get_detailed_standard_knowledge() - RAG lấy đáp án chuẩn
199
  """
200
  try:
201
  disease = request.disease.strip()
202
  session_id = request.sessionId
203
-
204
  if not disease:
205
  raise HTTPException(status_code=400, detail="Disease name is required")
206
-
207
  print(f"[INFO] Starting case for disease: {disease}")
208
  print(f"[INFO] Session ID: {session_id}")
209
-
210
- # 1. RAG tìm TRIỆU CHỨNG (như main.py)
211
- print("[INFO] Step 1: Finding symptoms...")
212
- symptoms, symptom_sources = evaluator.find_symptoms(disease)
213
- print(f"[INFO] Found symptoms (first 200 chars): {symptoms[:200]}...")
214
-
215
- # 2. GEMINI tạo CASE (như main.py)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
216
  print("[INFO] Step 2: Generating patient case...")
217
- patient_case = evaluator.generate_case(disease, symptoms)
 
 
218
  print(f"[INFO] Generated case (first 200 chars): {patient_case[:200]}...")
219
-
220
- # 3. RAG lấy đáp án chuẩn (như main.py)
221
- print("[INFO] Step 3: Getting standard knowledge...")
222
- standard_data, all_sources = evaluator.get_detailed_standard_knowledge(disease)
223
- print(f"[INFO] Standard data retrieved (length: {len(standard_data)} chars)")
224
-
225
- # Lưu vào session để /evaluate dùng
226
- session_data = {
 
 
 
 
 
 
 
227
  'disease': disease,
228
  'case': patient_case,
229
  'symptoms': symptoms,
230
  'standard': standard_data,
231
- 'sources': all_sources
232
- }
233
- active_sessions[session_id] = session_data
234
-
235
  return StartCaseResponse(
236
  success=True,
237
  sessionId=session_id,
238
  case=patient_case,
239
  symptoms=symptoms[:300] + "...",
240
- sources=[
241
- {
242
- 'file': doc.metadata.get('source_file', ''),
243
- 'title': doc.metadata.get('chunk_title', ''),
244
- 'section': doc.metadata.get('section_title', '')
245
- }
246
- for doc in all_sources[:3]
247
- ]
248
  )
249
-
250
  except HTTPException:
251
  raise
252
  except Exception as e:
@@ -256,7 +302,7 @@ async def start_case(request: StartCaseRequest):
256
  raise HTTPException(status_code=500, detail=str(e))
257
 
258
 
259
- @app.post("/api/evaluate", response_model=EvaluateResponse)
260
  async def evaluate_diagnosis(request: EvaluateRequest):
261
  """
262
  Nhận câu trả lời user, so sánh với đáp án chuẩn đã có trong session
@@ -264,21 +310,23 @@ async def evaluate_diagnosis(request: EvaluateRequest):
264
  try:
265
  session_id = request.sessionId
266
  diagnosis = request.diagnosis
267
-
268
- if not session_id or session_id not in active_sessions:
 
 
 
 
269
  raise HTTPException(status_code=400, detail="Invalid or expired session")
270
-
271
- session_data = active_sessions[session_id]
272
  disease = session_data['disease']
273
  patient_case = session_data['case']
274
  standard_answer = session_data['standard']
275
- sources = session_data['sources']
276
-
277
  print(f"[INFO] Evaluating diagnosis for: {disease}")
278
  print(f"[INFO] Session ID: {session_id}")
279
  print(f"[INFO] User diagnosis: {diagnosis.dict()}")
280
-
281
- # Format user's answer
282
  user_answer = f"""
283
  CHẨN ĐOÁN:
284
  - Lâm sàng: {diagnosis.clinical or 'Không có'}
@@ -291,29 +339,29 @@ KẾ HOẠCH ĐIỀU TRỊ:
291
  - Thuốc: {diagnosis.medication or 'Không có'}
292
  """
293
  print(f"[INFO] Formatted user answer (first 300 chars): {user_answer[:300]}...")
294
-
295
- print("[INFO] Step 1: Evaluating with Gemini...")
296
- # Gemini đánh giá (dùng hàm CÓ SẴN: detailed_evaluation)
297
- evaluation_result = evaluator.detailed_evaluation(user_answer, standard_answer)
298
- print(f"[INFO] Step 2: Evaluation result (first 500 chars): {evaluation_result[:500]}...")
299
-
 
 
 
300
  # Parse JSON from evaluation
301
- print("[INFO] Step 3: Parsing JSON evaluation...")
302
  try:
303
- import json
304
- # Remove markdown code blocks if present
305
  eval_text = evaluation_result.strip()
306
  if eval_text.startswith('```'):
307
  lines = eval_text.split('\n')
308
  eval_text = '\n'.join(lines[1:-1]) if len(lines) > 2 else eval_text
309
  if eval_text.startswith('json'):
310
  eval_text = eval_text[4:].strip()
311
- evaluation_obj = json.loads(eval_text)
312
- print(f"[INFO] Successfully parsed JSON: {json.dumps(evaluation_obj, ensure_ascii=False, indent=2)[:500]}...")
313
  except Exception as parse_error:
314
  print(f"[ERROR] Failed to parse JSON: {parse_error}")
315
- print(f"[ERROR] Raw evaluation text: {evaluation_result[:500]}...")
316
- # If parsing fails, return as text
317
  evaluation_obj = {
318
  'evaluation_text': evaluation_result,
319
  'diem_so': 'N/A',
@@ -324,22 +372,11 @@ KẾ HOẠCH ĐIỀU TRỊ:
324
  'dien_giai': evaluation_result,
325
  'nhan_xet_tong_quan': 'Lỗi parse JSON'
326
  }
327
-
328
- # Format sources
329
- formatted_sources = [
330
- {
331
- 'file': doc.metadata.get('source_file', ''),
332
- 'title': doc.metadata.get('chunk_title', ''),
333
- 'section': doc.metadata.get('section_title', '')
334
- }
335
- for doc in sources[:3]
336
- ]
337
-
338
- print("[INFO] Step 4: Formatting response...")
339
- print(f"[INFO] Evaluation object keys: {list(evaluation_obj.keys())}")
340
- print(f"[INFO] Standard answer length: {len(standard_answer)} chars")
341
- print(f"[INFO] Number of sources: {len(formatted_sources)}")
342
-
343
  return EvaluateResponse(
344
  success=True,
345
  case=patient_case,
@@ -350,7 +387,7 @@ KẾ HOẠCH ĐIỀU TRỊ:
350
  evaluation=evaluation_obj,
351
  sources=formatted_sources
352
  )
353
-
354
  except HTTPException:
355
  raise
356
  except Exception as e:
@@ -363,13 +400,18 @@ KẾ HOẠCH ĐIỀU TRỊ:
363
  if __name__ == '__main__':
364
  print("[*] Starting FastAPI Server...")
365
  print(f"[*] Server: http://localhost:8001")
366
- print(f"[*] Docs: http://localhost:8001/docs")
367
- print(f"[*] Using API Key: {Config.GOOGLE_API_KEY[:20]}...")
 
 
 
 
 
368
 
369
  uvicorn.run(
370
  app,
371
  host="0.0.0.0",
372
- port=8001,
373
  log_level="info",
374
- reload=False # Set to True for development
375
  )
 
5
  - GET /api/diseases - Lấy danh sách bệnh từ JSON
6
  - POST /api/start-case - Nhận bệnh, tạo case với triệu chứng
7
  - POST /api/evaluate - Nhận đáp án user, trả về kết quả so sánh
8
+ - Docs: http://localhost:8001/docs (Swagger UI)
9
  """
10
  import sys
11
  import io
 
14
  sys.stdout = io.TextIOWrapper(sys.stdout.buffer, encoding='utf-8', errors='replace')
15
  sys.stderr = io.TextIOWrapper(sys.stderr.buffer, encoding='utf-8', errors='replace')
16
 
17
+ import asyncio
18
+
19
+ from fastapi import FastAPI, HTTPException, Depends, Security
20
  from fastapi.middleware.cors import CORSMiddleware
21
+ from fastapi.security.api_key import APIKeyHeader
22
  from pydantic import BaseModel
23
  from typing import Optional, List, Dict, Any
24
  import json
 
34
  from doctor_evaluator import DoctorEvaluator
35
  from vector_store import VectorStoreManager
36
  from rag_chain import RAGChain
37
+ from session_store import SessionStore
38
+ from disease_cache import DiseaseCache
39
 
40
  app = FastAPI(
41
  title="Medical RAG API",
 
45
  redoc_url="/redoc"
46
  )
47
 
48
+ # Configure CORS — restrict to known frontend origins via ALLOWED_ORIGINS env var
49
+ _allowed_origins_env = os.getenv(
50
+ "ALLOWED_ORIGINS", "http://localhost:3000,http://localhost:5173"
51
+ )
52
+ ALLOWED_ORIGINS = [o.strip() for o in _allowed_origins_env.split(",") if o.strip()]
53
+
54
  app.add_middleware(
55
  CORSMiddleware,
56
+ allow_origins=ALLOWED_ORIGINS,
57
  allow_credentials=True,
58
  allow_methods=["*"],
59
  allow_headers=["*"],
60
  )
61
 
62
+ # Optional API key authentication (set API_SECRET_KEY env var to enable)
63
+ _API_SECRET_KEY = os.getenv("API_SECRET_KEY", "")
64
+ _api_key_header = APIKeyHeader(name="X-API-Key", auto_error=False)
65
+
66
+
67
+ async def verify_api_key(api_key: str = Security(_api_key_header)):
68
+ """If API_SECRET_KEY is configured, require matching X-API-Key header."""
69
+ if _API_SECRET_KEY and api_key != _API_SECRET_KEY:
70
+ raise HTTPException(status_code=403, detail="Invalid or missing API key")
71
+ return api_key
72
+
73
  # Initialize RAG system
74
  print("[*] Initializing RAG system...")
75
  vs_manager = VectorStoreManager()
 
81
  evaluator = DoctorEvaluator(rag)
82
  print("[OK] RAG system ready!")
83
 
84
+ # Persistent session store (SQLite)
85
+ session_store = SessionStore()
86
+ session_store.cleanup_expired() # remove stale sessions from previous runs
87
+
88
+ # Disease-level result cache (7-day TTL, avoids repeating RAG queries for same disease)
89
+ disease_cache = DiseaseCache()
90
 
91
 
92
  # Pydantic models for request/response
 
214
  raise HTTPException(status_code=500, detail=str(e))
215
 
216
 
217
+ @app.post("/api/start-case", response_model=StartCaseResponse, dependencies=[Depends(verify_api_key)])
218
  async def start_case(request: StartCaseRequest):
219
  """
220
+ 1. find_symptoms() + get_detailed_standard_knowledge() run IN PARALLEL
221
+ 2. generate_case() runs after symptoms are ready
222
+ 3. Session persisted to SQLite
 
223
  """
224
  try:
225
  disease = request.disease.strip()
226
  session_id = request.sessionId
227
+
228
  if not disease:
229
  raise HTTPException(status_code=400, detail="Disease name is required")
230
+
231
  print(f"[INFO] Starting case for disease: {disease}")
232
  print(f"[INFO] Session ID: {session_id}")
233
+
234
+ loop = asyncio.get_running_loop()
235
+
236
+ # Check disease-level cache first (0 LLM calls if HIT)
237
+ cached = disease_cache.get(disease)
238
+ if cached:
239
+ print(f"[INFO] Disease cache HIT for '{disease}' — skipping RAG queries")
240
+ symptoms = cached["symptoms"]
241
+ standard_data = cached["standard"]
242
+ all_sources_raw = cached["sources"]
243
+ else:
244
+ # Cache MISS — run symptoms + standard in parallel, then cache results
245
+ print("[INFO] Disease cache MISS — running RAG queries in parallel...")
246
+ (symptoms, symptom_sources), (standard_data, std_sources) = await asyncio.gather(
247
+ loop.run_in_executor(None, evaluator.find_symptoms, disease),
248
+ loop.run_in_executor(None, evaluator.get_detailed_standard_knowledge, disease),
249
+ )
250
+ all_sources_raw = symptom_sources + std_sources
251
+ # Cache for future requests
252
+ disease_cache.set(disease, symptoms, standard_data, [
253
+ {"file": d.metadata.get("source_file",""), "title": d.metadata.get("chunk_title",""),
254
+ "section": d.metadata.get("section_title","")} for d in all_sources_raw[:5]
255
+ ])
256
+
257
+ print(f"[INFO] Symptoms (first 200 chars): {symptoms[:200]}...")
258
+ print(f"[INFO] Standard data length: {len(standard_data)} chars")
259
+
260
+ # Step 2: generate case (depends on symptoms output)
261
  print("[INFO] Step 2: Generating patient case...")
262
+ patient_case = await loop.run_in_executor(
263
+ None, evaluator.generate_case, disease, symptoms
264
+ )
265
  print(f"[INFO] Generated case (first 200 chars): {patient_case[:200]}...")
266
+
267
+ # Pre-format sources (Document objects or plain dicts -> plain dicts for JSON storage)
268
+ formatted_sources = []
269
+ for src in (all_sources_raw if not cached else all_sources_raw)[:5]:
270
+ if isinstance(src, dict):
271
+ formatted_sources.append(src)
272
+ else:
273
+ formatted_sources.append({
274
+ 'file': src.metadata.get('source_file', ''),
275
+ 'title': src.metadata.get('chunk_title', ''),
276
+ 'section': src.metadata.get('section_title', ''),
277
+ })
278
+
279
+ # Persist session to SQLite
280
+ session_store.set(session_id, {
281
  'disease': disease,
282
  'case': patient_case,
283
  'symptoms': symptoms,
284
  'standard': standard_data,
285
+ 'sources': formatted_sources,
286
+ })
287
+
 
288
  return StartCaseResponse(
289
  success=True,
290
  sessionId=session_id,
291
  case=patient_case,
292
  symptoms=symptoms[:300] + "...",
293
+ sources=formatted_sources[:3],
 
 
 
 
 
 
 
294
  )
295
+
296
  except HTTPException:
297
  raise
298
  except Exception as e:
 
302
  raise HTTPException(status_code=500, detail=str(e))
303
 
304
 
305
+ @app.post("/api/evaluate", response_model=EvaluateResponse, dependencies=[Depends(verify_api_key)])
306
  async def evaluate_diagnosis(request: EvaluateRequest):
307
  """
308
  Nhận câu trả lời user, so sánh với đáp án chuẩn đã có trong session
 
310
  try:
311
  session_id = request.sessionId
312
  diagnosis = request.diagnosis
313
+
314
+ if not session_id:
315
+ raise HTTPException(status_code=400, detail="Session ID required")
316
+
317
+ session_data = session_store.get(session_id)
318
+ if session_data is None:
319
  raise HTTPException(status_code=400, detail="Invalid or expired session")
320
+
 
321
  disease = session_data['disease']
322
  patient_case = session_data['case']
323
  standard_answer = session_data['standard']
324
+
 
325
  print(f"[INFO] Evaluating diagnosis for: {disease}")
326
  print(f"[INFO] Session ID: {session_id}")
327
  print(f"[INFO] User diagnosis: {diagnosis.dict()}")
328
+
329
+ # Format the user's answer into the evaluator prompt
330
  user_answer = f"""
331
  CHẨN ĐOÁN:
332
  - Lâm sàng: {diagnosis.clinical or 'Không có'}
 
339
  - Thuốc: {diagnosis.medication or 'Không có'}
340
  """
341
  print(f"[INFO] Formatted user answer (first 300 chars): {user_answer[:300]}...")
342
+
343
+ # Run Groq evaluation (blocking I/O executed off the event loop)
344
+ print("[INFO] Step 1: Evaluating with Groq...")
345
+ loop = asyncio.get_running_loop()
346
+ evaluation_result = await loop.run_in_executor(
347
+ None, evaluator.detailed_evaluation, user_answer, standard_answer
348
+ )
349
+ print(f"[INFO] Evaluation result (first 500 chars): {evaluation_result[:500]}...")
350
+
351
  # Parse JSON from evaluation
352
+ print("[INFO] Step 2: Parsing JSON evaluation...")
353
  try:
354
+ import json as _json
 
355
  eval_text = evaluation_result.strip()
356
  if eval_text.startswith('```'):
357
  lines = eval_text.split('\n')
358
  eval_text = '\n'.join(lines[1:-1]) if len(lines) > 2 else eval_text
359
  if eval_text.startswith('json'):
360
  eval_text = eval_text[4:].strip()
361
+ evaluation_obj = _json.loads(eval_text)
362
+ print(f"[INFO] Successfully parsed JSON")
363
  except Exception as parse_error:
364
  print(f"[ERROR] Failed to parse JSON: {parse_error}")
 
 
365
  evaluation_obj = {
366
  'evaluation_text': evaluation_result,
367
  'diem_so': 'N/A',
 
372
  'dien_giai': evaluation_result,
373
  'nhan_xet_tong_quan': 'Lỗi parse JSON'
374
  }
375
+
376
+ # Sources are already pre-formatted plain dicts (stored in session)
377
+ formatted_sources = session_data.get('sources', [])[:3]
378
+
379
+ print("[INFO] Step 3: Formatting response...")
 
 
 
 
 
 
 
 
 
 
 
380
  return EvaluateResponse(
381
  success=True,
382
  case=patient_case,
 
387
  evaluation=evaluation_obj,
388
  sources=formatted_sources
389
  )
390
+
391
  except HTTPException:
392
  raise
393
  except Exception as e:
 
400
  if __name__ == '__main__':
401
  print("[*] Starting FastAPI Server...")
402
  print(f"[*] Server: http://localhost:8001")
403
+ print(f"[*] Docs: http://localhost:8001/docs")
404
+ api_key_status = "configured" if Config.GROQ_API_KEY_1 else "NOT SET (set GROQ_API_KEY_1 in .env)"
405
+ print(f"[*] Groq Key status: {api_key_status}")
406
+ cors_status = "restricted" if ALLOWED_ORIGINS != ["*"] else "OPEN (*)"
407
+ print(f"[*] CORS origins ({cors_status}): {ALLOWED_ORIGINS}")
408
+ auth_status = "enabled" if _API_SECRET_KEY else "disabled (set API_SECRET_KEY to enable)"
409
+ print(f"[*] API auth: {auth_status}")
410
 
411
  uvicorn.run(
412
  app,
413
  host="0.0.0.0",
414
+ port=int(os.getenv("PORT", "8001")),
415
  log_level="info",
416
+ reload=False
417
  )
render.yaml ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Render.com deployment config
2
+ # Free tier: 750 h/month — enough for 24/7 with UptimeRobot keep-alive
3
+ # RAM: 512 MB (tight for PhoBERT — recommend HF Spaces if RAM issues occur)
4
+
5
+ services:
6
+ - type: web
7
+ name: medchat-backend
8
+ env: python
9
+ region: singapore # closest to Vietnam
10
+ plan: free
11
+
12
+ # Build: install deps AND pre-build FAISS index
13
+ buildCommand: |
14
+ pip install -r requirements_api.txt
15
+ python src/build_faiss.py
16
+
17
+ startCommand: python api_server_fastapi.py
18
+
19
+ healthCheckPath: /api/health
20
+
21
+ envVars:
22
+ - key: GROQ_API_KEY_1
23
+ sync: false # set manually in Render dashboard
24
+ - key: GROQ_API_KEY_2
25
+ sync: false
26
+ - key: GROQ_MODEL
27
+ value: llama-3.3-70b-versatile
28
+ - key: ALLOWED_ORIGINS
29
+ sync: false # set to your Vercel URL, e.g. https://medchat.vercel.app
30
+ - key: PORT
31
+ value: "8001"
requirements.txt CHANGED
@@ -1,3 +1,4 @@
 
1
  langchain-core>=0.3.17,<0.4.0
2
  langchain-community>=0.3.7,<0.4.0
3
  langchain-huggingface
@@ -8,9 +9,5 @@ sentence-transformers
8
  google-generativeai
9
  python-dotenv
10
 
11
-
12
- unsloth
13
- transformers
14
- torch
15
- accelerate
16
- bitsandbytes
 
1
+ # Core RAG serving dependencies
2
  langchain-core>=0.3.17,<0.4.0
3
  langchain-community>=0.3.7,<0.4.0
4
  langchain-huggingface
 
9
  google-generativeai
10
  python-dotenv
11
 
12
+ # Fine-tuning dependencies are in requirements_finetune.txt
13
+ # Do NOT install those for serving/inference deployments
 
 
 
 
requirements_api.txt CHANGED
@@ -1,3 +1,19 @@
 
1
  fastapi==0.109.0
2
  uvicorn[standard]==0.27.0
3
  pydantic==2.5.3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # FastAPI server & runtime dependencies
2
  fastapi==0.109.0
3
  uvicorn[standard]==0.27.0
4
  pydantic==2.5.3
5
+
6
+ # LangChain stack
7
+ langchain-core>=0.3.17,<0.4.0
8
+ langchain-community>=0.3.7,<0.4.0
9
+ langchain-huggingface
10
+ langchain-groq>=0.3.0,<1.0.0 # must stay on 0.x to share langchain-core 0.3.x
11
+ langchain-google-genai # kept for optional fallback / frontend
12
+
13
+ # Vector store & embeddings
14
+ faiss-cpu
15
+ sentence-transformers
16
+
17
+ # Utilities
18
+ python-dotenv
19
+ tenacity
requirements_finetune.txt ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ # Fine-tuning dependencies — NOT needed for serving/inference
2
+ # Install separately only for model training:
3
+ # pip install -r requirements_finetune.txt
4
+ unsloth
5
+ transformers
6
+ torch
7
+ accelerate
8
+ bitsandbytes
src/config.py CHANGED
@@ -1,29 +1,31 @@
1
  import os
 
2
  from dotenv import load_dotenv
3
 
4
  load_dotenv()
5
 
6
  class Config:
7
- # SỬA ĐƯỜNG DẪN LẠI NHA
8
- BASE_DIR = r"D:\Storage\rag_project"
9
- DATA_DIR = f"{BASE_DIR}/data"
10
  CHUNK_FILES = [
11
- f"{DATA_DIR}/BoYTe200_v3.json",
12
- f"{DATA_DIR}/NHIKHOA2.json",
13
- f"{DATA_DIR}/PHACDODIEUTRI_2016.json"
14
- ]
15
- # EMBEDDING_MODEL = "sentence-transformers/all-MiniLM-L6-v2" # Model embedding
 
16
  # EMBEDDING_MODEL = "bkai-foundation-models/vietnamese-bi-encoder"
17
  EMBEDDING_MODEL = "VoVanPhuc/sup-SimCSE-VietNamese-phobert-base"
18
-
19
- # GOOGLE API KEY - Thay bằng key của bạn từ https://makersuite.google.com/app/apikey
20
- GOOGLE_API_KEY = os.getenv('GOOGLE_API_KEY', 'YOUR_API_KEY_HERE')
21
-
22
- LLM_MODEL = "gemini-2.5-flash"
23
- K_RETRIEVE = 3 # Số Document muốn truy
24
- TEMPERATURE = 0
25
 
26
- """
27
- AIzaSyABvC8mPrwa0Kgy08mFFzkyeh2_N-Bb3lY
28
- AIzaSyDJqr4nKDrcfmmuKOdDCHkXRvKA48htD6o
29
- """
 
 
 
 
 
 
 
 
import os
from pathlib import Path
from dotenv import load_dotenv

# Pull variables from rag_project/.env into the process environment.
load_dotenv()

# Shorthand for environment lookups with a default.
_env = os.getenv

class Config:
    """Central configuration: data paths, embedding model, and LLM settings."""

    # Project root (rag_project/), resolved relative to this file so the code
    # works regardless of the current working directory.
    BASE_DIR = Path(__file__).parent.parent
    DATA_DIR = BASE_DIR / "data"

    # Chunked source documents loaded into the vector store.
    CHUNK_FILES = [
        str(DATA_DIR / "BoYTe200_v3.json"),
        str(DATA_DIR / "NHIKHOA2.json"),
        str(DATA_DIR / "PHACDODIEUTRI_2016.json"),
    ]

    # Alternative embedding models tried during development:
    # EMBEDDING_MODEL = "sentence-transformers/all-MiniLM-L6-v2"
    # EMBEDDING_MODEL = "bkai-foundation-models/vietnamese-bi-encoder"
    EMBEDDING_MODEL = "VoVanPhuc/sup-SimCSE-VietNamese-phobert-base"

    # Groq API keys (set in rag_project/.env); two keys allow rotation.
    GROQ_API_KEY_1 = _env('GROQ_API_KEY_1', '')
    GROQ_API_KEY_2 = _env('GROQ_API_KEY_2', '')

    # Keep Google key in case frontend/other services need it
    GOOGLE_API_KEY = _env('GOOGLE_API_KEY', '')

    GROQ_MODEL = _env('GROQ_MODEL', 'llama-3.3-70b-versatile')
    LLM_MODEL = GROQ_MODEL  # alias used in older code

    # Retrieval settings: number of documents fetched per query, and LLM temperature.
    K_RETRIEVE = 3
    TEMPERATURE = 0
src/disease_cache.py ADDED
@@ -0,0 +1,112 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
"""
Disease-level result cache — avoids repeating RAG+LLM queries for the same disease.

Stores (symptoms_text, standard_text, sources_json) per disease in SQLite.
TTL: 7 days (medical knowledge stable; rebuild only after data update).

Impact:
- Cold (first request for a disease): 3 LLM calls + FAISS search
- Warm (subsequent requests): 0 LLM calls, 0 FAISS search → ~5 s → <0.1 s
"""
import json
import logging
import sqlite3
import time
from contextlib import closing
from pathlib import Path
from typing import Dict, List, Optional

logger = logging.getLogger(__name__)

TTL_SECONDS = 7 * 24 * 60 * 60  # 7 days


class DiseaseCache:
    """SQLite-backed per-disease cache of RAG results with a 7-day TTL."""

    def __init__(self, db_path: Optional[str] = None):
        """Open (or create) the cache DB.

        Args:
            db_path: Path to the SQLite file. Defaults to
                <project root>/disease_cache.db, resolved relative to this file.
        """
        if db_path is None:
            db_path = str(Path(__file__).parent.parent / "disease_cache.db")
        self.db_path = db_path
        self._init_db()
        logger.info(f"[DiseaseCache] SQLite cache at: {self.db_path}")

    def _conn(self) -> sqlite3.Connection:
        # One short-lived connection per operation. Callers MUST close it
        # (we wrap call sites in contextlib.closing): sqlite3's `with conn:`
        # only manages the transaction, it never closes the connection, so the
        # previous version leaked a file descriptor on every cache operation.
        conn = sqlite3.connect(self.db_path, check_same_thread=False)
        conn.execute("PRAGMA journal_mode=WAL")  # concurrent readers + writer
        conn.row_factory = sqlite3.Row
        return conn

    def _init_db(self):
        """Create the cache table if it does not exist."""
        with closing(self._conn()) as c:
            c.execute("""
                CREATE TABLE IF NOT EXISTS disease_cache (
                    disease TEXT PRIMARY KEY,
                    symptoms TEXT NOT NULL,
                    standard TEXT NOT NULL,
                    sources TEXT NOT NULL DEFAULT '[]',
                    cached_at REAL NOT NULL
                )
            """)
            c.commit()

    # ── read ──────────────────────────────────────────────────────────────────
    def get(self, disease: str) -> Optional[Dict]:
        """Return cached {symptoms, standard, sources} or None if missing/expired.

        Expired rows are deleted on read so the table does not grow unbounded.
        """
        with closing(self._conn()) as c:
            row = c.execute(
                "SELECT symptoms, standard, sources, cached_at FROM disease_cache WHERE disease = ?",
                (disease,)
            ).fetchone()

        if row is None:
            return None

        age = time.time() - row["cached_at"]
        if age > TTL_SECONDS:
            self.invalidate(disease)
            logger.info(f"[DiseaseCache] Cache expired for '{disease}' ({age/3600:.1f}h old)")
            return None

        logger.info(f"[DiseaseCache] Cache HIT for '{disease}' ({age/3600:.1f}h old)")
        return {
            "symptoms": row["symptoms"],
            "standard": row["standard"],
            "sources": json.loads(row["sources"]),
        }

    # ── write ─────────────────────────────────────────────────────────────────
    def set(self, disease: str, symptoms: str, standard: str, sources: List[Dict]):
        """Cache symptoms + standard for a disease (upsert, refreshes the TTL)."""
        now = time.time()
        sources_json = json.dumps(sources, ensure_ascii=False)
        with closing(self._conn()) as c:
            c.execute("""
                INSERT INTO disease_cache (disease, symptoms, standard, sources, cached_at)
                VALUES (?, ?, ?, ?, ?)
                ON CONFLICT(disease) DO UPDATE
                SET symptoms=excluded.symptoms,
                    standard=excluded.standard,
                    sources=excluded.sources,
                    cached_at=excluded.cached_at
            """, (disease, symptoms, standard, sources_json, now))
            c.commit()
        logger.info(f"[DiseaseCache] Cached '{disease}'")

    # ── management ────────────────────────────────────────────────────────────
    def invalidate(self, disease: str):
        """Remove a single disease entry (no-op if absent)."""
        with closing(self._conn()) as c:
            c.execute("DELETE FROM disease_cache WHERE disease = ?", (disease,))
            c.commit()

    def invalidate_all(self):
        """Drop every cached entry (e.g. after a source-data update)."""
        with closing(self._conn()) as c:
            c.execute("DELETE FROM disease_cache")
            c.commit()
        logger.info("[DiseaseCache] All entries invalidated")

    def stats(self) -> Dict:
        """Return {'total', 'fresh', 'expired'} row counts for monitoring."""
        with closing(self._conn()) as c:
            total = c.execute("SELECT COUNT(*) FROM disease_cache").fetchone()[0]
            fresh = c.execute(
                "SELECT COUNT(*) FROM disease_cache WHERE cached_at > ?",
                (time.time() - TTL_SECONDS,)
            ).fetchone()[0]
        return {"total": total, "fresh": fresh, "expired": total - fresh}
src/doctor_evaluator.py CHANGED
@@ -1,166 +1,104 @@
1
- from rag_chain import RAGChain
2
- from vector_store import VectorStoreManager
3
- from data_loader import DataLoader
4
- from config import Config
5
- from langchain_google_genai import ChatGoogleGenerativeAI
 
 
 
 
 
 
 
 
 
 
 
6
 
7
  class DoctorEvaluator:
8
- def __init__(self, rag):
9
  self.rag = rag
10
- self.evaluator_llm = ChatGoogleGenerativeAI(
11
- model="gemini-2.5-flash",
12
- google_api_key=Config.GOOGLE_API_KEY,
13
- temperature=0.1
 
 
 
 
 
 
 
14
  )
15
- print("DoctorEvaluator: Ready (Gemini + RAG)!")
16
-
17
- def generate_case(self, disease: str, symptoms: str):
18
- """DÙNG GEMINI TẠO CASE - NHANH + ỔN ĐỊNH"""
19
- prompt = f"""
20
- Bạn là bác sĩ nhi khoa. Tạo một ca bệnh THỰC TẾ cho bệnh: {disease}
21
-
22
- TRIỆU CHỨNG TỪ TÀI LIỆU:
23
- {symptoms}
24
-
25
- YÊU CẦU:
26
- 1. Chỉ tạo lời thoại của mẹ bệnh nhân (3-4 câu)
27
- 2. PHẢI MÔ TẢ các triệu chứng CỤ THỂ của bệnh {disease} từ tài liệu trên
28
- 3. Dùng ngôn ngữ đời thường, tự nhiên
29
- 4. Format: "Bé [tên] nhà chị [tên mẹ] bữa nay bị [triệu chứng cụ thể]. Chị lo lắm! [thêm chi tiết triệu chứng]."
30
-
31
- DỤ TỐT:
32
- - Bệnh Viêm phổi → " An bị sốt cao 39 độ, ho có đờm, thở nhanh phì phò"
33
- - Bệnh Suy tim "Bé Minh thở nhanh, mệt lả, kém, chân tay lạnh"
34
-
35
- CASE BỆNH:
36
- """
37
- result = self.evaluator_llm.invoke([prompt])
38
- return result.content.strip()
39
-
40
-
41
- def evaluate_doctor(self, disease: str):
42
- print(f"\n ĐÁNH GIÁ: {disease}")
43
- print("=" * 80)
44
-
45
- # 1. RAG tìm TRIỆU CHỨNG
46
- print("Hệ thống đang TRUY TÌM TRIỆU CHỨNG:")
47
- symptoms, symptom_sources = self.find_symptoms(disease)
48
- print(f"Xác định triệu chứng: {symptoms[:100]}...")
49
-
50
- # 2. GEMINI tạo CASE
51
- print("Tiến hành tạo case...")
52
- patient_case = self.generate_case(disease, symptoms)
53
- print(f"Case hoàn chỉnh:\n{patient_case}")
54
-
55
- # 3. NHẬP TRẢ LỜI BS
56
- doctor_answer = input("\n NHẬP CÂU TRẢ LỜI CỦA BÁC SĨ:\n").strip()
57
-
58
- # 4. RAG chi tiết + Đánh giá (giữ nguyên)
59
- print("\n TRUY TÌM ĐÁP ÁN CHUẨN:")
60
- standard_data, all_sources = self.get_detailed_standard_knowledge(disease)
61
- evaluation = self.detailed_evaluation(doctor_answer, standard_data)
62
-
63
- return {
64
- 'case': patient_case,
65
- 'standard': standard_data,
66
- 'evaluation': evaluation,
67
- 'sources': all_sources
68
- }
69
-
70
- def find_symptoms(self, disease: str):
71
- """RAG tìm triệu chứng bệnh - CẢI THIỆN"""
72
- # Query chi tiết hơn để tìm đúng bệnh
73
- queries = [
74
- f"{disease} biểu hiện",
75
- f"{disease} triệu chứng",
76
- f"{disease} dấu hiệu"
77
  ]
78
-
79
- all_symptoms = []
80
- sources = []
81
- for q in queries:
82
- print(f" Query: {q}")
83
- answer, src = self.rag.query(q)
84
- if answer and len(answer.strip()) > 50: # Chỉ lấy answer có nội dung
85
- all_symptoms.append(answer)
86
- sources.extend(src)
87
-
88
- # Gom triệu chứng đầy đủ hơn (không cắt quá ngắn)
89
- if all_symptoms:
90
- # Lấy 2 answer tốt nhất, mỗi cái 500 ký tự
91
- symptoms_summary = "\n\n".join([s[:500] for s in all_symptoms[:2]])
92
- else:
93
- symptoms_summary = f"Không tìm thấy thông tin triệu chứng cho {disease}"
94
-
95
- print(f" Tìm thấy triệu chứng: {symptoms_summary[:200]}...")
96
- return symptoms_summary, sources
97
-
98
- def get_detailed_standard_knowledge(self, disease: str):
99
- """RAG CHẨN ĐOÁN CHI TIẾT + ĐIỀU TRỊ"""
100
- queries = {
101
- 'LAM_SANG': [f"{disease} lâm sàng"],
102
- 'CAN_LAM_SANG': [f"{disease} cận lâm sàng"],
103
- 'CHAN_DOAN_XAC_DINH': [f"{disease} chẩn đoán xác định"],
104
- 'CHAN_DOAN_PHAN_BIET': [f"{disease} chẩn đoán phân biệt"],
105
- 'DIEU_TRI': [f"{disease} điều trị", f"{disease} thuốc"]
106
- }
107
-
108
- results = {}
109
- all_sources = []
110
-
111
- for section, qlist in queries.items():
112
- print(f" {section}:")
113
- section_content = []
114
- for q in qlist:
115
- print(f" {q}")
116
- answer, sources = self.rag.query(q)
117
- section_content.append(answer)
118
- all_sources.extend(sources)
119
- results[section] = "\n".join(section_content[:2])
120
-
121
- # Format đẹp
122
- standard_text = f"""
123
- CHẨN ĐOÁN LÂM SÀNG:
124
- {results['LAM_SANG']}
125
-
126
- CHẨN ĐOÁN CẬN LÂM SÀNG:
127
- {results['CAN_LAM_SANG']}
128
-
129
- CHẨN ĐOÁN XÁC ĐỊNH:
130
- {results['CHAN_DOAN_XAC_DINH']}
131
-
132
- CHẨN ĐOÁN PHÂN BIỆT:
133
- {results['CHAN_DOAN_PHAN_BIET']}
134
 
135
- CÁCH ĐIỀU TRỊ:
136
- {results['DIEU_TRI']}
137
- """
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
138
  return standard_text, all_sources
139
-
140
- def detailed_evaluation(self, doctor_answer: str, standard_data: str):
141
- """ĐÁNH GIÁ CHI TIẾT + DIỄN GIẢI"""
142
- prompt = f"""
143
- BẠN LÀ CHUYÊN GIA Y KHOA ĐÁNH GIÁ BÁC SĨ
144
-
145
- CÂU TRẢ LỜI BÁC SĨ:
146
- {doctor_answer}
147
-
148
- KIẾN THỨC CHUẨN:
149
- {standard_data}
150
 
151
- PHÂN TÍCH CHI TIẾT (JSON):
152
- {{
153
- "diem_manh": ["..."],
154
- "diem_yeu": ["..."],
155
- "da_co": ["..."],
156
- "thieu": ["..."],
157
- "dien_giai": ["Giải thích sao đúng/thiếu..."],
158
- "diem_so": "85/100",
159
- "nhan_xet_tong_quan": "..."
160
- }}
161
-
162
- JSON PURE:
163
- """
164
-
165
- result = self.evaluator_llm.invoke([prompt])
166
- return result.content
 
1
+ """
2
+ DoctorEvaluator uses Groq LLM (via shared GroqKeyManager) for:
3
+ 1. generate_case() : 1 LLM call
4
+ 2. detailed_evaluation() : 1 LLM call (compact JSON, ~4 fields)
5
+
6
+ RAG queries reduced:
7
+ - find_symptoms : 3 → 1 combined query
8
+ - get_detailed_standard_knowledge : 6 → 2 combined queries
9
+ Total LLM calls per start-case: 1(symptoms RAG) + 2(standard RAG) + 1(case) = 4
10
+ Total LLM calls per evaluate : 2(standard RAG) + 1(eval) = 3
11
+ """
12
+ from concurrent.futures import ThreadPoolExecutor, as_completed
13
+ from typing import Dict, Tuple, List
14
+ from tenacity import retry, stop_after_attempt, wait_exponential, retry_if_exception
15
+ from rag_chain import RAGChain, get_key_manager, _is_rate_limit
16
+
17
 
18
class DoctorEvaluator:
    """Evaluate a doctor's answer against RAG-retrieved standard knowledge.

    Uses the shared Groq key manager for its two LLM calls (generate_case and
    detailed_evaluation). RAG queries are batched to keep usage low:
    4 total LLM calls per start-case, 3 per evaluation.
    """

    def __init__(self, rag: "RAGChain"):
        self.rag = rag                # RAGChain used for retrieval
        self._km = get_key_manager()  # shared round-robin Groq key manager
        print("DoctorEvaluator: Ready (Groq + RAG)!")

    # ── internal helper ────────────────────────────────────────────────────────
    def _llm_invoke(self, prompt: str, temperature: float = 0.1) -> str:
        """Call Groq with tenacity retry + key rotation on rate-limit (429).

        On a rate-limit error the current key is put on cooldown and the
        manager rotates to the next key before tenacity retries.
        """
        @retry(
            retry=retry_if_exception(_is_rate_limit),
            wait=wait_exponential(multiplier=1, min=5, max=30),
            stop=stop_after_attempt(4),
            reraise=True,
        )
        def _call():
            try:
                llm = self._km.build_llm(temperature=temperature)
                return llm.invoke([prompt])
            except Exception as exc:
                if _is_rate_limit(exc):
                    self._km.mark_rate_limited(self._km.current())
                    self._km.rotate()
                raise
        return _call().content

    # ── public methods ─────────────────────────────────────────────────────────
    def generate_case(self, disease: str, symptoms: str) -> str:
        """Tạo ca bệnh nhi bằng 1 LLM call, prompt ngắn gọn."""
        prompt = (
            f"Bạn là bác sĩ nhi khoa. Tạo 1 lời thoại của mẹ bệnh nhân (2-3 câu, "
            f"ngôn ngữ đời thường) mô tả triệu chứng cụ thể của bệnh {disease}.\n"
            f"Triệu chứng từ tài liệu: {symptoms[:400]}\n"
            f"Format: 'Bé [tên] nhà chị [tên mẹ] bị [triệu chứng cụ thể]. [Thêm chi tiết].'\n"
            f"CASE:"
        )
        # Slightly higher temperature for varied case wording.
        return self._llm_invoke(prompt, temperature=0.3).strip()

    def find_symptoms(self, disease: str) -> Tuple[str, list]:
        """1 RAG query (thay cho 3 query trước đây). Returns (summary, sources)."""
        answer, sources = self.rag.query(f"{disease} triệu chứng biểu hiện lâm sàng")
        summary = answer[:600] if answer else f"Không tìm thấy thông tin triệu chứng cho {disease}"
        return summary, sources

    def get_detailed_standard_knowledge(self, disease: str) -> Tuple[str, list]:
        """2 RAG queries thay cho 6 query trước đây, run concurrently.

        Returns (standard_text, all_sources); a failed query degrades to a
        placeholder string instead of raising.
        """
        tasks = [
            ("CHAN_DOAN", f"{disease} lâm sàng cận lâm sàng chẩn đoán xác định phân biệt"),
            ("DIEU_TRI", f"{disease} điều trị thuốc"),
        ]

        raw: Dict[str, Tuple] = {}
        with ThreadPoolExecutor(max_workers=2) as pool:
            futures = {pool.submit(self.rag.query, q): key for key, q in tasks}
            for future in as_completed(futures):
                key = futures[future]
                try:
                    raw[key] = future.result()
                except Exception as exc:
                    print(f"[WARN] {key} query failed: {exc}")
                    # FIX: diacritics restored (was "Khong tim thay thong tin"),
                    # consistent with find_symptoms' fallback wording.
                    raw[key] = ("Không tìm thấy thông tin", [])

        all_sources: list = []
        for key, _ in tasks:
            all_sources.extend(raw.get(key, ("", []))[1])

        def r(k): return raw.get(k, ("",))[0]

        # FIX: section labels with proper Vietnamese diacritics (were
        # "CHAN DOAN"/"DIEU TRI") — these labels are fed to the evaluator LLM.
        standard_text = (
            f"CHẨN ĐOÁN:\n{r('CHAN_DOAN')}\n\n"
            f"ĐIỀU TRỊ:\n{r('DIEU_TRI')}"
        )
        return standard_text, all_sources

    def detailed_evaluation(self, doctor_answer: str, standard_data: str) -> str:
        """Đánh giá ngắn gọn — JSON 4 trường, tối đa 300 token output.

        Inputs are truncated (standard 1200 chars, answer 600 chars) to keep
        the prompt within the token budget.
        """
        std = standard_data[:1200]
        doc = doctor_answer[:600]
        prompt = (
            "Chuyên gia y khoa đánh giá câu trả lời bác sĩ. Trả về JSON thuần túy, KHÔNG giải thích thêm.\n\n"
            f"CÂU TRẢ LỜI BÁC SĨ:\n{doc}\n\n"
            f"KIẾN THỨC CHUẨN (tóm tắt):\n{std}\n\n"
            "JSON format (ngắn gọn, mỗi mảng tối đa 3 phần tử):\n"
            '{"diem_so":"85/100","nhan_xet_tong_quan":"2 câu tóm tắt","diem_manh":["...","..."],"thieu":["...","..."]}\n\n'
            "JSON:"
        )
        return self._llm_invoke(prompt, temperature=0)
 
 
 
src/hybrid_retriever.py CHANGED
@@ -1,39 +1,55 @@
1
  from langchain_community.vectorstores import FAISS
2
  import re
3
 
 
 
 
 
 
 
 
 
 
 
 
4
  class HybridRetriever:
5
  def __init__(self, vectorstore):
6
  self.vs = vectorstore
7
-
8
  def keyword_search(self, query, k=5):
9
- """Exact keyword matching - PRIORITY 1"""
10
- keywords = re.findall(r'\b\w{3,}\b', query.lower())
 
 
 
 
 
 
11
  scored_docs = []
12
-
13
  for doc_id, doc in self.vs.docstore._dict.items():
14
  content_lower = doc.page_content.lower()
15
  title_lower = doc.metadata.get('chunk_title', '').lower()
16
-
17
- # Score cao nếu match title + content
18
- score = sum(2 if kw in title_lower else 1
19
- for kw in keywords if kw in content_lower or kw in title_lower)
20
-
 
 
21
  if score > 0:
22
  scored_docs.append((score, doc))
23
-
24
  scored_docs.sort(reverse=True, key=lambda x: x[0])
25
  return [doc for _, doc in scored_docs[:k]]
26
-
27
  def hybrid_search(self, query, k=3):
28
  """KEYWORD FIRST → Semantic backup"""
29
- # PRIORITY 1: Keyword exact match
30
- keyword_docs = self.keyword_search(query, k=k*2)
31
-
32
  if keyword_docs:
33
  print(f" KEYWORD HIT: {len(keyword_docs)} docs")
34
  return keyword_docs[:k]
35
-
36
- # PRIORITY 2: Semantic fallback
37
  print(" Semantic fallback...")
38
  semantic_docs = self.vs.similarity_search(query, k=k)
39
  return semantic_docs
 
1
  from langchain_community.vectorstores import FAISS
2
  import re
3
 
4
# Vietnamese stop words — high-frequency words that corrupt keyword ranking signals
VIETNAMESE_STOPWORDS = {
    'và', 'là', 'của', 'có', 'cho', 'với', 'các', 'được', 'trong',
    'đến', 'khi', 'này', 'bằng', 'theo', 'một', 'những', 'từ', 'hay',
    'như', 'hoặc', 'về', 'tại', 'trên', 'sau', 'trước', 'cùng', 'để',
    'không', 'cần', 'phải', 'nên', 'thể', 'vào', 'ra', 'đây', 'đó',
    'nào', 'mà', 'thì', 'sẽ', 'đã', 'còn', 'vẫn', 'rất', 'nhiều',
    'đặc', 'biệt', 'thêm', 'khác', 'tất', 'cả', 'nếu', 'bởi', 'vì',
}


class HybridRetriever:
    """Keyword-first document retrieval with a semantic (vector) fallback."""

    def __init__(self, vectorstore):
        self.vs = vectorstore

    def keyword_search(self, query, k=5):
        """Exact keyword matching with Vietnamese stop-word filtering - PRIORITY 1"""
        terms = [
            tok for tok in re.findall(r'\b\w{3,}\b', query.lower())
            if tok not in VIETNAMESE_STOPWORDS
        ]
        if not terms:
            return []

        hits = []
        for _doc_id, doc in self.vs.docstore._dict.items():
            body = doc.page_content.lower()
            title = doc.metadata.get('chunk_title', '').lower()
            # Title match scores 2x; content match scores 1x
            relevance = sum(
                2 if term in title else 1
                for term in terms
                if term in body or term in title
            )
            if relevance:
                hits.append((relevance, doc))

        hits.sort(reverse=True, key=lambda pair: pair[0])
        return [doc for _, doc in hits[:k]]

    def hybrid_search(self, query, k=3):
        """KEYWORD FIRST → Semantic backup"""
        exact = self.keyword_search(query, k=k * 2)
        if exact:
            print(f" KEYWORD HIT: {len(exact)} docs")
            return exact[:k]

        print(" Semantic fallback...")
        return self.vs.similarity_search(query, k=k)
src/key_manager.py ADDED
@@ -0,0 +1,65 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ GroqKeyManager — round-robin key rotation with immediate failover on 429.
3
+
4
+ Usage:
5
+ mgr = GroqKeyManager([KEY_1, KEY_2])
6
+ key = mgr.current() # get current key
7
+ key = mgr.rotate() # advance to next key (call on 429)
8
+ llm = mgr.build_llm(model) # ChatGroq with current key
9
+ """
10
+ import threading
11
+ import time
12
+ import logging
13
+ from typing import List
14
+
15
+ from langchain_groq import ChatGroq
16
+
17
+ logger = logging.getLogger(__name__)
18
+
19
+
20
class GroqKeyManager:
    """Thread-safe round-robin Groq API key manager.

    Keys that hit a 429 are placed on a cooldown and skipped by rotate()
    until the cooldown expires.
    """

    def __init__(self, keys: List[str], model: str = "llama-3.3-70b-versatile"):
        cleaned = [k.strip() for k in keys if k and k.strip()]
        if not cleaned:
            raise ValueError("GroqKeyManager: no valid API keys provided")
        self._keys = cleaned
        self._model = model
        self._idx = 0
        self._lock = threading.Lock()
        # key → timestamp at which its cooldown expires
        self._cooldown: dict[str, float] = {}
        logger.info(f"[KeyManager] {len(self._keys)} Groq key(s) loaded, model={model}")

    def current(self) -> str:
        """Return the key at the current rotation position."""
        with self._lock:
            return self._keys[self._idx % len(self._keys)]

    def rotate(self) -> str:
        """Advance to next available (non-cooled-down) key. Returns the new key."""
        with self._lock:
            now = time.time()
            for _ in self._keys:
                self._idx = (self._idx + 1) % len(self._keys)
                candidate = self._keys[self._idx]
                if self._cooldown.get(candidate, 0) <= now:
                    logger.warning(f"[KeyManager] Rotated to key index {self._idx}")
                    return candidate
            # Every key is cooling down — hand back the current one and let
            # the caller's retry/backoff absorb the wait.
            logger.warning("[KeyManager] All keys on cooldown, returning current key")
            return self._keys[self._idx % len(self._keys)]

    def mark_rate_limited(self, key: str, cooldown_secs: int = 62):
        """Mark a key as rate-limited for cooldown_secs seconds."""
        with self._lock:
            self._cooldown[key] = time.time() + cooldown_secs
            logger.warning(f"[KeyManager] Key ...{key[-6:]} cooled down for {cooldown_secs}s")

    def build_llm(self, temperature: float = 0) -> ChatGroq:
        """Return a ChatGroq instance using the current key."""
        return ChatGroq(
            model=self._model,
            api_key=self.current(),
            temperature=temperature,
            max_tokens=800,  # cap output tokens to save quota
        )
src/rag_chain.py CHANGED
@@ -1,83 +1,75 @@
1
- from langchain_google_genai import ChatGoogleGenerativeAI
2
  from langchain_core.prompts import PromptTemplate
3
  from config import Config
 
4
  from hybrid_retriever import HybridRetriever
5
  from vector_store import VectorStoreManager
6
 
7
- class RAGChain:
8
- def __init__(self, vector_store_manager: VectorStoreManager):
9
- self.llm = ChatGoogleGenerativeAI(
10
- model=Config.LLM_MODEL,
11
- google_api_key=Config.GOOGLE_API_KEY,
12
- temperature=0 # 0 để deterministic
 
 
 
 
13
  )
14
-
15
- self.vectorstore = vector_store_manager.vector_store
16
- self.retriever = HybridRetriever(self.vectorstore) # FIX TYPO
17
-
18
- # PROMPT MỚI: TRẢ NỘI DUNG CHUNK + TÓM TẮT
19
- self.custom_prompt = PromptTemplate(
20
- input_variables=["context", "question"],
21
- template="""
22
- Bạn là bác sĩ y khoa. Dựa vào TÀI LIỆU sau:
23
 
24
- CONTEXT:
25
- {context}
26
 
27
- CÂU HỎI: {question}
 
 
28
 
29
- TRẢ LỜI:
30
- 1. TRÍCH DẪN ĐÚNG nội dung từ CONTEXT (giữ nguyên văn bản)
31
- 2. Tóm tắt ngắn gọn nếu cần
32
- 3. Luôn ưu tiên thông tin từ chunk chính xác nhất
33
 
34
- NỘI DUNG TÀI LIỆU:
35
- """
 
 
 
 
 
 
36
  )
37
-
38
- def query(self, question: str):
39
- """HYBRID RETRIEVAL + FULL CHUNK CONTENT"""
40
-
41
- # BƯỚC 1: HYBRID SEARCH - PRIORITY KEYWORD
42
- sources = self.retriever.hybrid_search(question, k=4)
43
-
44
- # BƯỚC 2: RE-RANK theo keyword match
45
- ranked_sources = self.rerank_sources(sources, question)
46
-
47
- # BƯỚC 3: Tạo context FULL CONTENT
48
- context = self.build_context(ranked_sources)
49
-
50
- # BƯỚC 4: Generate với prompt rõ ràng
51
- formatted_prompt = self.custom_prompt.format(
52
- context=context,
53
- question=question
54
  )
55
-
56
- result = self.llm.invoke([formatted_prompt])
57
- return result.content, ranked_sources
58
-
 
 
 
 
 
 
 
 
 
59
  def rerank_sources(self, sources, question):
60
- """RE-RANK: Keyword match > Semantic"""
61
  keywords = question.lower().split()
62
-
63
- def score_doc(doc):
64
- content = doc.page_content.lower()
65
- title = doc.metadata.get('chunk_title', '').lower()
66
- score = sum(1 for kw in keywords if kw in content or kw in title)
67
- return score
68
-
69
- return sorted(sources, key=score_doc, reverse=True)
70
-
71
  def build_context(self, sources):
72
- """FULL CHUNK CONTENT + METADATA"""
73
- context_parts = []
74
  for i, doc in enumerate(sources[:3]):
75
- file = doc.metadata.get('source_file', 'N/A')
76
- chunk_title = doc.metadata.get('chunk_title', 'N/A')
77
- section_title = doc.metadata.get('section_title', 'N/A')
78
-
79
- context_parts.append(
80
- f"[{i+1}] {file} | {chunk_title} | {section_title}\n"
81
- f"NỘI DUNG:\n{doc.page_content}\n{'='*80}"
82
- )
83
- return "\n\n".join(context_parts)
 
1
+ from tenacity import retry, stop_after_attempt, wait_exponential, retry_if_exception
2
  from langchain_core.prompts import PromptTemplate
3
  from config import Config
4
+ from key_manager import GroqKeyManager
5
  from hybrid_retriever import HybridRetriever
6
  from vector_store import VectorStoreManager
7
 
8
+ # Shared key manager -- single instance reused across all RAGChain objects
9
+ _KEY_MANAGER = None
10
+
11
+
12
def get_key_manager():
    """Return the process-wide GroqKeyManager singleton (lazy-initialized).

    Built on first call from the two Groq keys in Config; every RAGChain and
    DoctorEvaluator shares this instance so rotation/cooldown state is global.
    """
    global _KEY_MANAGER
    if _KEY_MANAGER is None:
        _KEY_MANAGER = GroqKeyManager(
            keys=[Config.GROQ_API_KEY_1, Config.GROQ_API_KEY_2],
            model=Config.GROQ_MODEL,
        )
    return _KEY_MANAGER
 
 
 
 
 
 
 
 
20
 
 
 
21
 
22
+ def _is_rate_limit(exc):
23
+ msg = str(exc).lower()
24
+ return "429" in msg or "quota" in msg or "rate limit" in msg or "ratelimit" in msg
25
 
 
 
 
 
26
 
27
class RAGChain:
    """Retrieval-augmented QA over the medical vector store, answered by Groq."""

    def __init__(self, vector_store_manager):
        self._km = get_key_manager()
        self.vectorstore = vector_store_manager.vector_store
        self.retriever = HybridRetriever(self.vectorstore)
        self.prompt_template = PromptTemplate(
            input_variables=["context", "question"],
            template="Tài liệu y khoa:\n{context}\n\nCâu hỏi: {question}\n\nTrả lời ngắn gọn, chọn lọc thông tin quan trọng nhất từ tài liệu (tối đa 200 từ):"
        )

    def query(self, question):
        """Retrieve top chunks, build a context, and answer via Groq with retry.

        Returns (answer_text, ranked_source_docs).
        """
        ranked = self.rerank_sources(
            self.retriever.hybrid_search(question, k=3), question
        )
        prompt = self.prompt_template.format(
            context=self.build_context(ranked), question=question
        )

        @retry(
            retry=retry_if_exception(_is_rate_limit),
            wait=wait_exponential(multiplier=1, min=5, max=30),
            stop=stop_after_attempt(4),
            reraise=True,
        )
        def _invoke():
            try:
                return self._km.build_llm(temperature=0).invoke([prompt])
            except Exception as exc:
                # Cool down + rotate the key before tenacity retries on 429.
                if _is_rate_limit(exc):
                    self._km.mark_rate_limited(self._km.current())
                    self._km.rotate()
                raise

        return _invoke().content, ranked

    def rerank_sources(self, sources, question):
        """Order docs by simple keyword overlap with the question (descending)."""
        terms = question.lower().split()

        def overlap(doc):
            haystack = doc.page_content.lower() + doc.metadata.get("chunk_title", "").lower()
            return sum(kw in haystack for kw in terms)

        return sorted(sources, key=overlap, reverse=True)

    def build_context(self, sources):
        """Join up to 3 docs: '[i] file | title' header + first 600 content chars."""
        blocks = []
        for idx, doc in enumerate(sources[:3]):
            header = f"[{idx+1}] {doc.metadata.get('source_file','?')} | {doc.metadata.get('chunk_title','?')}"
            blocks.append(f"{header}\n{doc.page_content[:600]}")
        return "\n\n".join(blocks)
 
 
 
 
 
src/session_store.py ADDED
@@ -0,0 +1,106 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
"""
SQLite-backed session store with TTL for RAG sessions.
Replaces the in-memory dict to persist sessions across server restarts
and prevent unbounded memory growth.
"""
import sqlite3
import json
import time
from contextlib import closing
from pathlib import Path
from typing import Optional, Dict, Any

SESSION_TTL_SECONDS = 24 * 60 * 60  # 24 hours


class SessionStore:
    """CRUD + TTL cleanup for chat sessions, backed by a local SQLite file.

    Each operation opens a short-lived connection and closes it via
    contextlib.closing — the original code returned connections that were
    never closed (``with conn:`` only manages the transaction), leaking one
    connection per call.
    """

    def __init__(self, db_path: Optional[str] = None):
        """Open (or create) the store. Defaults to <repo root>/sessions.db."""
        if db_path is None:
            db_path = str(Path(__file__).parent.parent / "sessions.db")
        self.db_path = db_path
        self._init_db()
        print(f"[SessionStore] SQLite store at: {self.db_path}")

    def _get_conn(self) -> sqlite3.Connection:
        """Open a fresh connection; callers wrap it in closing() to avoid leaks."""
        conn = sqlite3.connect(self.db_path, check_same_thread=False)
        conn.execute("PRAGMA journal_mode=WAL")  # Better concurrent read performance
        conn.row_factory = sqlite3.Row
        return conn

    def _init_db(self):
        """Create the sessions table on first use (idempotent)."""
        with closing(self._get_conn()) as conn:
            conn.execute(
                """
                CREATE TABLE IF NOT EXISTS sessions (
                    session_id TEXT PRIMARY KEY,
                    data TEXT NOT NULL,
                    created_at REAL NOT NULL,
                    updated_at REAL NOT NULL
                )
                """
            )
            conn.commit()

    def get(self, session_id: str) -> Optional[Dict[str, Any]]:
        """Return session data or None if not found / expired."""
        with closing(self._get_conn()) as conn:
            row = conn.execute(
                "SELECT data, updated_at FROM sessions WHERE session_id = ?",
                (session_id,),
            ).fetchone()

        if row is None:
            return None

        # Lazy expiry: purge the row on read once its TTL has elapsed.
        if time.time() - row["updated_at"] > SESSION_TTL_SECONDS:
            self.delete(session_id)
            return None

        return json.loads(row["data"])

    def set(self, session_id: str, data: Dict[str, Any]):
        """Persist or update a session (all values must be JSON-serialisable)."""
        now = time.time()
        data_json = json.dumps(data, ensure_ascii=False)
        with closing(self._get_conn()) as conn:
            conn.execute(
                """
                INSERT INTO sessions (session_id, data, created_at, updated_at)
                VALUES (?, ?, ?, ?)
                ON CONFLICT(session_id) DO UPDATE
                SET data = excluded.data, updated_at = excluded.updated_at
                """,
                (session_id, data_json, now, now),
            )
            conn.commit()

    def delete(self, session_id: str):
        """Remove a session row if present."""
        with closing(self._get_conn()) as conn:
            conn.execute(
                "DELETE FROM sessions WHERE session_id = ?", (session_id,)
            )
            conn.commit()

    def cleanup_expired(self) -> int:
        """Delete all sessions older than SESSION_TTL_SECONDS. Returns count deleted.

        NOTE(review): the original file defined this method twice with the
        same semantics (the second silently shadowed the first); the
        duplicate has been removed.
        """
        cutoff = time.time() - SESSION_TTL_SECONDS
        with closing(self._get_conn()) as conn:
            cur = conn.execute(
                "DELETE FROM sessions WHERE updated_at < ?", (cutoff,)
            )
            conn.commit()
            deleted = cur.rowcount
        if deleted:
            print(f"[SessionStore] Cleaned up {deleted} expired session(s)")
        return deleted
src/vector_store.py CHANGED
@@ -4,7 +4,7 @@ from embeddings import EmbeddingsManager
4
  from typing import List
5
  from pathlib import Path
6
  import json
7
- from pymongo import MongoClient
8
 
9
  from data_loader import DataLoader
10
 
@@ -72,8 +72,8 @@ class VectorStoreManager:
72
  return self.vector_store.as_retriever(search_kwargs={"k": k})
73
 
74
  def save_documents(self, docs):
75
-
76
- output_dir = Path(r"D:\Storage\rag_project\store")
77
  output_dir.mkdir(parents=True, exist_ok=True)
78
 
79
  records = []
 
4
  from typing import List
5
  from pathlib import Path
6
  import json
7
+ # pymongo is only needed for save_documents(); imported lazily below
8
 
9
  from data_loader import DataLoader
10
 
 
72
  return self.vector_store.as_retriever(search_kwargs={"k": k})
73
 
74
  def save_documents(self, docs):
75
+ from pymongo import MongoClient # optional dependency
76
+ output_dir = Path(__file__).parent.parent / "store"
77
  output_dir.mkdir(parents=True, exist_ok=True)
78
 
79
  records = []