Nikhil Pravin Pise committed
Commit fd5543a · Parent(s): ca20dc7

fix: all 15 bugs + full agentic RAG pipeline in HF Space

Critical fixes:
- OpenSearch client parameter names aligned (query_text/query_vector)
- _execute_search now preserves _source as a nested dict
- health.py: SQLAlchemy text() wrapper + corrected FAISS import path
- llm_config: module-level get_synthesizer() convenience function
- Redis cache: simplified get(key)/set(key, value) API

High priority:
- Infinite rewrite loop prevented via a retrieval_attempts counter
- analyze.py: handles both dict and object results

HF Space full pipeline integration:
- retrieve_node: retriever-agnostic (FAISS/OpenSearch/BM25)
- AgenticContext: retriever field added
- grade_documents_node + generate_answer_node: unified document format
- app.py: full guardrail -> retrieve -> grade -> rewrite -> generate pipeline
- asyncio.get_event_loop() -> asyncio.get_running_loop()
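
For orientation, a runnable sketch of the control flow this commit wires up, with stubbed nodes. The names guardrail_node and rewrite_query_node are assumptions; retrieve_node and grade_documents_node mirror the real files changed below.

```python
# Illustrative only: stubbed agentic pipeline showing why the rewrite loop terminates.
def guardrail_node(state, *, context): return {"guardrail_score": 95.0}

def retrieve_node(state, *, context):
    return {"retrieved_documents": [],
            "retrieval_attempts": state.get("retrieval_attempts", 0) + 1}

def grade_documents_node(state, *, context):
    relevant = state.get("retrieved_documents", [])
    attempts = state.get("retrieval_attempts", 1)
    return {"relevant_documents": relevant,
            "needs_rewrite": len(relevant) < 2
                and attempts < state.get("max_retrieval_attempts", 2)}

def rewrite_query_node(state, *, context): return {"rewritten_query": state["query"] + " (expanded)"}

def generate_answer_node(state, *, context): return {"final_answer": "..."}

def run_pipeline(query: str, context=None) -> dict:
    state = {"query": query, "max_retrieval_attempts": 2}
    for node in (guardrail_node, retrieve_node, grade_documents_node):
        state.update(node(state, context=context))
    while state["needs_rewrite"]:                  # capped by retrieval_attempts
        state.update(rewrite_query_node(state, context=context))
        state.update(retrieve_node(state, context=context))
        state.update(grade_documents_node(state, context=context))
    state.update(generate_answer_node(state, context=context))
    return state

print(run_pipeline("what does elevated hba1c mean?")["retrieval_attempts"])  # -> 2
```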

huggingface/app.py CHANGED
@@ -582,106 +582,149 @@ def format_summary(response: dict, elapsed: float) -> str:
 
 
 # ---------------------------------------------------------------------------
-# Q&A Chat Functions - Streaming Support
+# Q&A Chat Functions — Full Agentic RAG Pipeline
 # ---------------------------------------------------------------------------
 
+_rag_service = None
+_rag_service_error = None
+
+
+def _get_rag_service():
+    """Lazily initialize the full agentic RAG service for Q&A.
+
+    Uses a FAISS-backed retriever wrapped in an AgenticContext so the
+    guardrail → retrieve → grade → rewrite → generate pipeline runs
+    identically to the production API.
+    """
+    global _rag_service, _rag_service_error
+
+    if _rag_service is not None:
+        return _rag_service
+
+    if _rag_service_error is not None:
+        logger.warning("Previous RAG service init failed, retrying...")
+        _rag_service_error = None
+
+    try:
+        from src.services.agents.agentic_rag import AgenticRAGService
+        from src.services.agents.context import AgenticContext
+        from src.services.retrieval.factory import make_retriever
+        from src.llm_config import get_synthesizer
+
+        llm = get_synthesizer()
+        retriever = make_retriever()  # auto-detects FAISS
+
+        # HF Space: skip OpenSearch, Redis, Langfuse,
+        # but still get guardrail, grading, rewriting, generation.
+        context = AgenticContext(
+            llm=llm,
+            embedding_service=None,
+            opensearch_client=None,
+            cache=None,
+            tracer=None,
+            retriever=retriever,
+        )
+
+        _rag_service = AgenticRAGService(context)
+        logger.info("Agentic RAG service initialized for Q&A")
+        return _rag_service
+
+    except Exception as exc:
+        logger.error(f"Failed to init agentic RAG service: {exc}")
+        _rag_service_error = exc
+        return None
+
+
+def _fallback_qa(question: str, context_text: str = "") -> str:
+    """Direct retriever+LLM fallback when agentic pipeline is unavailable."""
+    from src.services.retrieval.factory import make_retriever
+    from src.llm_config import get_synthesizer
+
+    retriever = make_retriever()
+    search_query = f"{context_text} {question}" if context_text.strip() else question
+    docs = retriever.retrieve(search_query, top_k=5)
+
+    doc_context = ""
+    if docs:
+        doc_texts = [d.content[:500] for d in docs[:5]]
+        doc_context = "\n\n---\n\n".join(doc_texts)
+
+    llm = get_synthesizer()
+    prompt = f"""You are a medical AI assistant. Answer the following medical question based on the provided context.
+Be helpful, accurate, and include relevant medical information. Always recommend consulting a healthcare professional.
+
+Context from medical knowledge base:
+{doc_context if doc_context else "No specific context available."}
+
+Patient Context: {context_text if context_text else "Not provided"}
+
+Question: {question}
+
+Answer:"""
+    response = llm.invoke(prompt)
+    return response.content if hasattr(response, 'content') else str(response)
+
+
 def answer_medical_question(
     question: str,
     context: str = "",
     chat_history: list = None
 ) -> tuple[str, list]:
-    """
-    Answer a free-form medical question using retriever + LLM directly.
-
-    Args:
-        question: The user's medical question
-        context: Optional biomarker/patient context
-        chat_history: Previous conversation history
-
-    Returns:
-        Tuple of (formatted_answer, updated_chat_history)
-    """
+    """Answer a medical question using the full agentic RAG pipeline.
+
+    Pipeline: guardrail → retrieve → grade → rewrite → generate.
+    Falls back to direct retriever+LLM if the pipeline is unavailable.
+    """
     if not question.strip():
         return "", chat_history or []
-
-    # Check API key dynamically
+
     groq_key, google_key = get_api_keys()
     if not groq_key and not google_key:
        error_msg = "❌ Please add your GROQ_API_KEY or GOOGLE_API_KEY in Space Settings → Secrets."
        history = (chat_history or []) + [(question, error_msg)]
        return error_msg, history
-
-    # Setup provider
+
     provider = setup_llm_provider()
     logger.info(f"Q&A using provider: {provider}")
-
+
     try:
         start_time = time.time()
-
-        # Import retriever and LLM
-        from src.services.retrieval import make_retriever
-        from src.llm_config import get_synthesizer
-
-        # Initialize retriever
-        retriever = make_retriever()
-
-        # Build search query with context
-        search_query = question
-        if context.strip():
-            search_query = f"{context} {question}"
-
-        # Retrieve relevant documents
-        docs = retriever.search(search_query, top_k=5)
-
-        # Format context from retrieved docs
-        doc_context = ""
-        if docs:
-            doc_texts = []
-            for doc in docs[:5]:
-                if hasattr(doc, 'content'):
-                    doc_texts.append(doc.content[:500])
-                elif isinstance(doc, dict) and 'content' in doc:
-                    doc_texts.append(doc['content'][:500])
-            doc_context = "\n\n---\n\n".join(doc_texts)
-
-        # Get LLM
-        llm = get_synthesizer()
-
-        # Build prompt
-        prompt = f"""You are a medical AI assistant. Answer the following medical question based on the provided context.
-Be helpful, accurate, and include relevant medical information. Always recommend consulting a healthcare professional for personal medical advice.
-
-Context from medical knowledge base:
-{doc_context if doc_context else "No specific context available - using general medical knowledge."}
-
-Patient Context: {context if context else "Not provided"}
-
-Question: {question}
-
-Answer:"""
-
-        # Generate response
-        response = llm.invoke(prompt)
-        answer = response.content if hasattr(response, 'content') else str(response)
-
+
+        rag_service = _get_rag_service()
+        if rag_service is not None:
+            result = rag_service.ask(query=question, patient_context=context)
+            answer = result.get("final_answer", "")
+            guardrail = result.get("guardrail_score")
+            docs_retrieved = len(result.get("retrieved_documents", []))
+            docs_relevant = len(result.get("relevant_documents", []))
+        else:
+            logger.warning("Using fallback Q&A (agentic pipeline unavailable)")
+            answer = _fallback_qa(question, context)
+            guardrail = None
+            docs_retrieved = 0
+            docs_relevant = 0
+
         if not answer:
             answer = "I apologize, but I couldn't generate a response. Please try rephrasing your question."
-
+
         elapsed = time.time() - start_time
-
-        # Format response with metadata
+
+        meta_parts = [f"⏱️ {elapsed:.1f}s"]
+        if guardrail is not None:
+            meta_parts.append(f"🛡️ Guardrail: {guardrail:.0f}/100")
+        if docs_retrieved > 0:
+            meta_parts.append(f"📚 {docs_relevant}/{docs_retrieved} relevant docs")
+        meta_parts.append("🤖 Agentic RAG" if rag_service else "🤖 RAG")
+        meta_line = " | ".join(meta_parts)
+
         formatted_answer = f"""{answer}
 
 ---
-*⏱️ Response time: {elapsed:.1f}s | 🤖 Powered by RAG*
+*{meta_line}*
 """
-
-        # Update chat history
         history = (chat_history or []) + [(question, formatted_answer)]
-
         return formatted_answer, history
-
     except Exception as exc:
         logger.exception(f"Q&A error: {exc}")
         error_msg = f"❌ Error: {str(exc)}"
@@ -690,83 +733,48 @@ Answer:"""
 
 
 def streaming_answer(question: str, context: str = ""):
-    """
-    Stream answer tokens for real-time response.
-    Uses retriever + LLM directly (not the guild).
-    """
+    """Stream answer using the full agentic RAG pipeline.
+    Falls back to direct retriever+LLM if the pipeline is unavailable.
+    """
     if not question.strip():
         yield ""
         return
-
-    # Check API key
+
     groq_key, google_key = get_api_keys()
     if not groq_key and not google_key:
        yield "❌ Please add your GROQ_API_KEY or GOOGLE_API_KEY in Space Settings → Secrets."
        return
-
-    # Setup provider
-    setup_llm_provider()
-
-    try:
-        yield "🔍 Searching medical knowledge base...\n\n"
-
-        from src.services.retrieval import make_retriever
-        from src.llm_config import get_synthesizer
-
-        # Initialize retriever
-        retriever = make_retriever()
-
-        # Build search query
-        search_query = question
-        if context.strip():
-            search_query = f"{context} {question}"
-
-        yield "🔍 Searching medical knowledge base...\n📚 Retrieving relevant documents...\n\n"
-
-        # Retrieve docs
-        docs = retriever.search(search_query, top_k=5)
-
-        # Format context
-        doc_context = ""
-        if docs:
-            doc_texts = []
-            for doc in docs[:5]:
-                if hasattr(doc, 'content'):
-                    doc_texts.append(doc.content[:500])
-                elif isinstance(doc, dict) and 'content' in doc:
-                    doc_texts.append(doc['content'][:500])
-            doc_context = "\n\n---\n\n".join(doc_texts)
-
-        yield "🔍 Searching medical knowledge base...\n📚 Retrieving relevant documents...\n💭 Generating response...\n\n"
-
-        # Get LLM
-        llm = get_synthesizer()
-
-        start_time = time.time()
-
-        # Build prompt
-        prompt = f"""You are a medical AI assistant. Answer the following medical question based on the provided context.
-Be helpful, accurate, and include relevant medical information. Always recommend consulting a healthcare professional for personal medical advice.
-
-Context from medical knowledge base:
-{doc_context if doc_context else "No specific context available - using general medical knowledge."}
-
-Patient Context: {context if context else "Not provided"}
-
-Question: {question}
-
-Answer:"""
-
-        # Generate response
-        response = llm.invoke(prompt)
-        answer = response.content if hasattr(response, 'content') else str(response)
-
+
+    setup_llm_provider()
+
+    try:
+        yield "🛡️ Checking medical domain relevance...\n\n"
+
+        start_time = time.time()
+
+        rag_service = _get_rag_service()
+        if rag_service is not None:
+            yield "🛡️ Checking medical domain relevance...\n🔍 Retrieving medical documents...\n\n"
+            result = rag_service.ask(query=question, patient_context=context)
+            answer = result.get("final_answer", "")
+            guardrail = result.get("guardrail_score")
+            docs_relevant = len(result.get("relevant_documents", []))
+            docs_retrieved = len(result.get("retrieved_documents", []))
+        else:
+            yield "🔍 Searching medical knowledge base...\n📚 Retrieving relevant documents...\n\n"
+            answer = _fallback_qa(question, context)
+            guardrail = None
+            docs_relevant = 0
+            docs_retrieved = 0
+
         if not answer:
             answer = "I apologize, but I couldn't generate a response. Please try rephrasing your question."
-
+
+        yield "🛡️ Guardrail ✓\n🔍 Retrieved ✓\n📊 Graded ✓\n💭 Generating response...\n\n"
+
         elapsed = time.time() - start_time
-
-        # Simulate streaming by revealing text progressively
+
+        # Progressive reveal
         words = answer.split()
         accumulated = ""
         for i, word in enumerate(words):
@@ -774,20 +782,27 @@ Answer:"""
             if i % 5 == 0:
                 yield accumulated
                 time.sleep(0.02)
-
-        # Final complete response
+
+        # Final response with metadata
+        meta_parts = [f"⏱️ {elapsed:.1f}s"]
+        if guardrail is not None:
+            meta_parts.append(f"🛡️ Guardrail: {guardrail:.0f}/100")
+        if docs_retrieved > 0:
+            meta_parts.append(f"📚 {docs_relevant}/{docs_retrieved} relevant docs")
+        meta_parts.append("🤖 Agentic RAG" if rag_service else "🤖 RAG")
+        meta_line = " | ".join(meta_parts)
+
         yield f"""{answer}
 
 ---
-*⏱️ Response time: {elapsed:.1f}s | 🤖 Powered by RAG*
+*{meta_line}*
 """
-
+
     except Exception as exc:
         logger.exception(f"Streaming Q&A error: {exc}")
         yield f"❌ Error: {str(exc)}"
 
 
-
 # ---------------------------------------------------------------------------
 # Gradio Interface
 # ---------------------------------------------------------------------------
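
For reference, a hypothetical smoke test of the two entry points, run from within the Space (requires the FAISS index plus GROQ_API_KEY or GOOGLE_API_KEY in the environment):

```python
# Hypothetical local smoke test; inputs are illustrative.
answer, history = answer_medical_question(
    "What does an elevated HbA1c indicate?",
    context="54-year-old male, fasting glucose 140 mg/dL",
    chat_history=[],
)
print(answer)  # ends with the metadata line, e.g. "*⏱️ 2.3s | ... | 🤖 Agentic RAG*"

# streaming_answer is a generator: progress lines first, then the
# progressively revealed answer; the last yielded chunk is the full response.
final = ""
for chunk in streaming_answer("What does an elevated HbA1c indicate?"):
    final = chunk
```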
src/llm_config.py CHANGED
@@ -387,6 +387,11 @@ class LLMConfig:
 llm_config = LLMConfig()
 
 
+def get_synthesizer(model_name: Optional[str] = None):
+    """Module-level convenience: get a synthesizer LLM instance."""
+    return llm_config.get_synthesizer(model_name)
+
+
 def check_api_connection():
     """Verify API connection and keys are configured"""
     provider = DEFAULT_LLM_PROVIDER
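
A usage sketch for the new module-level wrapper; the `.invoke`/`.content` shape matches the LangChain-style calls already used in app.py, and the prompt is illustrative:

```python
from src.llm_config import get_synthesizer

llm = get_synthesizer()  # default synthesizer model; accepts an optional model name
response = llm.invoke("One-line summary of what HbA1c measures.")
text = response.content if hasattr(response, "content") else str(response)
print(text)
```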
src/routers/analyze.py CHANGED
@@ -122,7 +122,7 @@ async def _run_guild_analysis(
 
     try:
         # Run sync function in thread pool
-        loop = asyncio.get_event_loop()
+        loop = asyncio.get_running_loop()
         result = await loop.run_in_executor(
             _executor,
             lambda: ragbot.analyze(
@@ -142,6 +142,16 @@
     elapsed = (time.time() - t0) * 1000
 
     # Build response from result
+    # Guild workflow returns a dict; ragbot.analyze() may return dict or object
+    if isinstance(result, dict):
+        prediction = result.get('prediction')
+        analysis = result.get('analysis')
+        conversational_summary = result.get('conversational_summary')
+    else:
+        prediction = getattr(result, 'prediction', None)
+        analysis = getattr(result, 'analysis', None)
+        conversational_summary = getattr(result, 'conversational_summary', None)
+
     return AnalysisResponse(
         status="success",
         request_id=request_id,
@@ -150,9 +160,9 @@
         input_biomarkers=biomarkers,
         patient_context=patient_ctx,
         processing_time_ms=round(elapsed, 1),
-        prediction=result.prediction if hasattr(result, 'prediction') else None,
-        analysis=result.analysis if hasattr(result, 'analysis') else None,
-        conversational_summary=result.conversational_summary if hasattr(result, 'conversational_summary') else None,
+        prediction=prediction,
+        analysis=analysis,
+        conversational_summary=conversational_summary,
     )
src/routers/ask.py CHANGED
@@ -88,7 +88,7 @@ async def _stream_rag_response(
     await asyncio.sleep(0)  # Allow event loop to flush
 
     # Run the RAG pipeline (synchronous, but we yield progress)
-    loop = asyncio.get_event_loop()
+    loop = asyncio.get_running_loop()
     result = await loop.run_in_executor(
         None,
         lambda: rag_service.ask(
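
Context on the event-loop fix in both routers: inside a coroutine, `asyncio.get_running_loop()` is the documented way to obtain the current loop, while `get_event_loop()` has been deprecated in stages since Python 3.10 and can fail when no loop is set on the thread. A standalone sketch of the pattern (no project imports):

```python
import asyncio
from concurrent.futures import ThreadPoolExecutor

_executor = ThreadPoolExecutor(max_workers=2)

def blocking_analyze() -> dict:
    return {"prediction": "ok"}        # stand-in for ragbot.analyze(...)

async def handler() -> dict:
    loop = asyncio.get_running_loop()  # always valid inside a coroutine
    return await loop.run_in_executor(_executor, blocking_analyze)

print(asyncio.run(handler()))          # {'prediction': 'ok'}
```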
src/routers/health.py CHANGED
@@ -40,11 +40,12 @@ async def readiness_check(request: Request) -> HealthResponse:
     # --- PostgreSQL ---
     try:
         from src.database import get_engine
+        from sqlalchemy import text
         engine = get_engine()
         if engine is not None:
             t0 = time.time()
             with engine.connect() as conn:
-                conn.execute("SELECT 1")
+                conn.execute(text("SELECT 1"))
             latency = (time.time() - t0) * 1000
             services.append(ServiceHealth(name="postgresql", status="ok", latency_ms=round(latency, 1)))
         else:
@@ -106,8 +107,8 @@ async def readiness_check(request: Request) -> HealthResponse:
 
     # --- FAISS (local retriever) ---
     try:
-        from src.services.retrieval import make_retriever
-        retriever = make_retriever("faiss")
+        from src.services.retrieval.factory import make_retriever
+        retriever = make_retriever(backend="faiss")
        if retriever is not None:
            doc_count = retriever.doc_count()
            services.append(ServiceHealth(name="faiss", status="ok", detail=f"{doc_count} docs indexed"))
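
Background on the `text()` wrapper: SQLAlchemy 2.x no longer accepts raw SQL strings in `Connection.execute()`; textual SQL must be wrapped in `sqlalchemy.text()`. A standalone reproduction, with in-memory SQLite standing in for Postgres:

```python
from sqlalchemy import create_engine, text

engine = create_engine("sqlite://")  # in-memory stand-in for the real engine
with engine.connect() as conn:
    # conn.execute("SELECT 1") raises ObjectNotExecutableError on SQLAlchemy 2.x
    result = conn.execute(text("SELECT 1"))
    print(result.scalar())  # -> 1
```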
src/services/agents/context.py CHANGED
@@ -21,3 +21,4 @@ class AgenticContext:
     cache: Any  # RedisCache
     tracer: Any  # LangfuseTracer
     guild: Optional[Any] = None  # ClinicalInsightGuild (original workflow)
+    retriever: Optional[Any] = None  # BaseRetriever (FAISS or OpenSearch)
src/services/agents/nodes/generate_answer_node.py CHANGED
@@ -24,9 +24,10 @@ def generate_answer_node(state: dict, *, context: Any) -> dict:
     # Build evidence block
     evidence_parts: list[str] = []
     for i, doc in enumerate(documents, 1):
-        title = doc.get("title", "Unknown")
-        section = doc.get("section", "")
-        text = doc.get("text", "")[:2000]
+        meta = doc.get("metadata", {})
+        title = meta.get("title", doc.get("title", "Unknown"))
+        section = meta.get("section_title", doc.get("section", ""))
+        text = (doc.get("content") or doc.get("text", ""))[:2000]
         header = f"[{i}] {title}"
         if section:
             header += f" — {section}"
src/services/agents/nodes/grade_documents_node.py CHANGED
@@ -31,7 +31,7 @@ def grade_documents_node(state: dict, *, context: Any) -> dict:
     grading_results: list[dict] = []
 
     for doc in documents:
-        text = doc.get("text", "")
+        text = doc.get("content") or doc.get("text", "")
         user_msg = f"Query: {query}\n\nDocument:\n{text[:2000]}"
         try:
             response = context.llm.invoke(
@@ -51,11 +51,13 @@ def grade_documents_node(state: dict, *, context: Any) -> dict:
             logger.warning("Grading LLM failed for doc %s: %s — marking relevant", doc.get("id"), exc)
             is_relevant = True  # benefit of the doubt
 
-        grading_results.append({"doc_id": doc.get("id"), "relevant": is_relevant})
+        grading_results.append({"doc_id": doc.get("id", doc.get("_id")), "relevant": is_relevant})
         if is_relevant:
             relevant.append(doc)
 
-    needs_rewrite = len(relevant) < 2 and not state.get("rewritten_query")
+    attempts = state.get("retrieval_attempts", 1)
+    max_attempts = state.get("max_retrieval_attempts", 2)
+    needs_rewrite = len(relevant) < 2 and attempts < max_attempts
 
     return {
         "grading_results": grading_results,
src/services/agents/nodes/retrieve_node.py CHANGED
@@ -1,7 +1,10 @@
 """
 MediGuard AI — Retrieve Node
 
-Performs hybrid search (BM25 + vector KNN) and merges results.
+Performs document retrieval using the best available backend:
+1. Generic retriever (FAISS, OpenSearch wrapper, etc.)
+2. OpenSearch hybrid search (BM25 + KNN)
+3. BM25 keyword fallback
 """
 
 from __future__ import annotations
@@ -13,56 +16,90 @@ logger = logging.getLogger(__name__)
 
 
 def retrieve_node(state: dict, *, context: Any) -> dict:
-    """Retrieve documents from OpenSearch via hybrid search."""
-    query = state.get("rewritten_query") or state.get("query", "")
-
-    # 1. Try cache first
-    cache_key = f"retrieve:{query}"
-    if context.cache:
-        cached = context.cache.get(cache_key)
-        if cached is not None:
-            logger.debug("Cache hit for retrieve query")
-            return {"retrieved_documents": cached}
-
-    # 2. Embed the query
-    try:
-        query_embedding = context.embedding_service.embed_query(query)
-    except Exception as exc:
-        logger.error("Embedding failed: %s", exc)
-        return {"retrieved_documents": [], "errors": [str(exc)]}
-
-    # 3. Hybrid search
-    try:
-        results = context.opensearch_client.search_hybrid(
-            query_text=query,
-            query_vector=query_embedding,
-            top_k=10,
-        )
-    except Exception as exc:
-        logger.error("OpenSearch hybrid search failed: %s — falling back to BM25", exc)
-        try:
-            results = context.opensearch_client.search_bm25(
-                query_text=query,
-                top_k=10,
-            )
-        except Exception as exc2:
-            logger.error("BM25 fallback also failed: %s", exc2)
-            return {"retrieved_documents": [], "errors": [str(exc), str(exc2)]}
-
-    documents = [
-        {
-            "id": hit.get("_id", ""),
-            "score": hit.get("_score", 0.0),
-            "text": hit.get("_source", {}).get("chunk_text", ""),
-            "title": hit.get("_source", {}).get("title", ""),
-            "section": hit.get("_source", {}).get("section_title", ""),
-            "metadata": hit.get("_source", {}),
-        }
-        for hit in results
-    ]
-
-    # 4. Store in cache (5 min TTL)
-    if context.cache:
-        context.cache.set(cache_key, documents, ttl=300)
-
-    return {"retrieved_documents": documents}
+    """Retrieve documents using the best available backend.
+
+    Priority:
+    1. context.retriever (generic BaseRetriever — works with FAISS & OpenSearch)
+    2. context.opensearch_client + context.embedding_service (hybrid search)
+    3. BM25 keyword fallback
+    4. Empty list
+    """
+    query = state.get("rewritten_query") or state.get("query", "")
+    cache_key = f"retrieve:{query}"
+
+    # 1. Try cache
+    if context.cache:
+        cached = context.cache.get(cache_key)
+        if cached is not None:
+            logger.info("Cache HIT for query: %s…", query[:50])
+            attempts = state.get("retrieval_attempts", 0) + 1
+            return {"retrieved_documents": cached, "retrieval_attempts": attempts}
+
+    documents: list = []
+
+    # 2. Generic retriever (FAISS, OpenSearch wrapper, etc.)
+    if getattr(context, "retriever", None) is not None:
+        try:
+            results = context.retriever.retrieve(query, top_k=8)
+            documents = [
+                {
+                    "content": getattr(r, "content", ""),
+                    "metadata": getattr(r, "metadata", {}),
+                    "score": getattr(r, "score", 0.0),
+                }
+                for r in results
+            ]
+            backend = getattr(context.retriever, "backend_name", "unknown")
+            logger.info("Retrieved %d docs via %s", len(documents), backend)
+        except Exception as exc:
+            logger.warning("Retriever failed (%s), trying OpenSearch fallback…", exc)
+
+    # 3. OpenSearch hybrid fallback
+    if not documents and context.opensearch_client and context.embedding_service:
+        try:
+            embedding = context.embedding_service.embed_query(query)
+            raw_hits = context.opensearch_client.search_hybrid(
+                query_text=query,
+                query_vector=embedding,
+                top_k=8,
+            )
+            documents = [
+                {
+                    "content": h.get("_source", {}).get("chunk_text", ""),
+                    "metadata": {
+                        k: v for k, v in h.get("_source", {}).items()
+                        if k != "chunk_text"
+                    },
+                    "score": h.get("_score", 0.0),
+                }
+                for h in raw_hits
+            ]
+            logger.info("Retrieved %d docs via OpenSearch hybrid", len(documents))
+        except Exception as exc:
+            logger.error("OpenSearch retrieval failed: %s", exc)
+
+    # 4. Optional BM25 fallback if still nothing
+    if not documents and context.opensearch_client:
+        try:
+            raw_hits = context.opensearch_client.search_bm25(query_text=query, top_k=8)
+            documents = [
+                {
+                    "content": h.get("_source", {}).get("chunk_text", ""),
+                    "metadata": {
+                        k: v for k, v in h.get("_source", {}).items()
+                        if k != "chunk_text"
+                    },
+                    "score": h.get("_score", 0.0),
+                }
+                for h in raw_hits
+            ]
+            logger.info("Retrieved %d docs via BM25 fallback", len(documents))
+        except Exception as exc:
+            logger.error("BM25 fallback also failed: %s", exc)
+
+    # 5. Store in cache (5 min TTL)
+    if context.cache and documents:
+        context.cache.set(cache_key, documents, ttl=300)
+
+    attempts = state.get("retrieval_attempts", 0) + 1
+    return {"retrieved_documents": documents, "retrieval_attempts": attempts}
src/services/cache/redis_cache.py CHANGED
@@ -11,7 +11,7 @@ import hashlib
 import json
 import logging
 from functools import lru_cache
-from typing import Any, Dict, Optional
+from typing import Any, Optional
 
 from src.settings import get_settings
 
@@ -48,12 +48,13 @@
         raw = "|".join(parts)
         return f"mediguard:{hashlib.sha256(raw.encode()).hexdigest()}"
 
-    def get(self, *key_parts: str) -> Optional[Dict[str, Any]]:
+    def get(self, key: str) -> Optional[Any]:
+        """Get a cached value by key."""
         if not self._enabled:
             return None
-        key = self._make_key(*key_parts)
+        cache_key = self._make_key(key)
         try:
-            value = self._client.get(key)
+            value = self._client.get(cache_key)
             if value is None:
                 return None
             return json.loads(value)
@@ -61,23 +62,25 @@
             logger.warning("Cache GET failed: %s", exc)
             return None
 
-    def set(self, value: Dict[str, Any], *key_parts: str, ttl: Optional[int] = None) -> bool:
+    def set(self, key: str, value: Any, *, ttl: Optional[int] = None) -> bool:
+        """Set a cached value with optional TTL."""
         if not self._enabled:
             return False
-        key = self._make_key(*key_parts)
+        cache_key = self._make_key(key)
         try:
-            self._client.setex(key, ttl or self._default_ttl, json.dumps(value, default=str))
+            self._client.setex(cache_key, ttl or self._default_ttl, json.dumps(value, default=str))
             return True
         except Exception as exc:
             logger.warning("Cache SET failed: %s", exc)
             return False
 
-    def delete(self, *key_parts: str) -> bool:
+    def delete(self, key: str) -> bool:
+        """Delete a cached value by key."""
        if not self._enabled:
            return False
+        cache_key = self._make_key(key)
        try:
-            self._client.delete(key)
+            self._client.delete(cache_key)
            return True
        except Exception as exc:
            logger.warning("Cache DELETE failed: %s", exc)
src/services/opensearch/client.py CHANGED
@@ -85,7 +85,7 @@ class OpenSearchClient:
 
     def search_bm25(
         self,
-        query: str,
+        query_text: str,
         *,
         top_k: int = 10,
         filters: Optional[Dict[str, Any]] = None,
@@ -97,7 +97,7 @@ class OpenSearchClient:
                 "must": [
                     {
                         "multi_match": {
-                            "query": query,
+                            "query": query_text,
                             "fields": [
                                 "chunk_text^3",
                                 "title^2",
@@ -119,7 +119,7 @@ class OpenSearchClient:
 
     def search_vector(
         self,
-        embedding: List[float],
+        query_vector: List[float],
         *,
         top_k: int = 10,
         filters: Optional[Dict[str, Any]] = None,
@@ -129,7 +129,7 @@ class OpenSearchClient:
             "query": {
                 "knn": {
                     "embedding": {
-                        "vector": embedding,
+                        "vector": query_vector,
                         "k": top_k,
                     }
                 }
@@ -141,8 +141,8 @@ class OpenSearchClient:
 
     def search_hybrid(
         self,
-        query: str,
-        embedding: List[float],
+        query_text: str,
+        query_vector: List[float],
         *,
         top_k: int = 10,
         filters: Optional[Dict[str, Any]] = None,
@@ -150,8 +150,8 @@ class OpenSearchClient:
         vector_weight: float = 0.6,
     ) -> List[Dict[str, Any]]:
         """Reciprocal Rank Fusion of BM25 + KNN results."""
-        bm25_results = self.search_bm25(query, top_k=top_k, filters=filters)
-        vector_results = self.search_vector(embedding, top_k=top_k, filters=filters)
+        bm25_results = self.search_bm25(query_text, top_k=top_k, filters=filters)
+        vector_results = self.search_vector(query_vector, top_k=top_k, filters=filters)
         return self._rrf_fuse(bm25_results, vector_results, top_k=top_k)
 
     # ── Internal helpers ─────────────────────────────────────────────
@@ -166,7 +166,7 @@ class OpenSearchClient:
                 {
                     "_id": h["_id"],
                     "_score": h.get("_score", 0.0),
-                    **h.get("_source", {}),
+                    "_source": h.get("_source", {}),
                 }
                 for h in hits
             ]
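
With the renames, callers and the client finally agree on keyword names, so retrieve_node's `search_hybrid(query_text=..., query_vector=...)` call binds correctly; and because hits keep `_source` nested, downstream code can read `hit["_source"]["chunk_text"]` uniformly. A hypothetical call site (client and embedding_service construction omitted; names follow the signatures above):

```python
# Illustrative only: index contents and embedding dimension depend on deployment.
hits = client.search_hybrid(
    query_text="elevated hba1c meaning",
    query_vector=embedding_service.embed_query("elevated hba1c meaning"),
    top_k=8,
)
for hit in hits:
    src = hit["_source"]                       # nested dict is preserved now
    print(hit["_id"], hit["_score"], src.get("chunk_text", "")[:80])
```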