Spaces:
Sleeping
Sleeping
| import os | |
| import Stemmer | |
| import requests | |
| from utils import get_and_chunk_documents, llm, embed_model, get_index | |
| from utils import Settings | |
| from llama_index.retrievers.bm25 import BM25Retriever | |
| from llama_index.core.postprocessor import SentenceTransformerRerank | |
| from llama_index.core.query_engine import RetrieverQueryEngine | |
| from llama_index.core.response_synthesizers import get_response_synthesizer | |
| from llama_index.core.settings import Settings | |
| from llama_index.core import VectorStoreIndex | |
| from llama_index.core.llms import ChatMessage | |
| from llama_index.core.retrievers import QueryFusionRetriever | |
| import json | |
# Route LlamaIndex LLM and embedding calls through the models built in utils.
Settings.llm = llm
Settings.embed_model = embed_model

# Module-level retrieval state, populated below once the index is loaded.
# hybrid_retriever is what the query functions use: a vector+BM25 fusion
# retriever when both could be built, otherwise the plain vector retriever,
# or None when initialization failed entirely.
index = get_index()
hybrid_retriever = None
vector_retriever = None
bm25_retriever = None

if index:
    try:
        vector_retriever = index.as_retriever(similarity_top_k=15)
        print("β Vector retriever initialized successfully")
        all_nodes = index.docstore.docs
        if not all_nodes:
            # BM25 is built from the docstore below, so an empty docstore
            # leaves it nothing to index.
            print("β οΈ Warning: No documents found in index, skipping BM25 retriever")
            hybrid_retriever = vector_retriever
        elif not any(
            getattr(node, "text", None) and node.text.strip()
            for node in all_nodes.values()
        ):
            # Nodes exist but none carry non-blank text for BM25 to tokenize.
            print("β οΈ Warning: No text content found in documents, skipping BM25 retriever")
            hybrid_retriever = vector_retriever
        else:
            try:
                print("π Creating BM25 retriever...")
                bm25_retriever = BM25Retriever.from_defaults(
                    docstore=index.docstore,
                    similarity_top_k=15,
                    verbose=False,
                )
                print("β BM25 retriever initialized successfully")
                # Merge vector and BM25 hits with reciprocal-rank fusion.
                hybrid_retriever = QueryFusionRetriever(
                    retrievers=[vector_retriever, bm25_retriever],
                    similarity_top_k=20,
                    num_queries=1,
                    mode="reciprocal_rerank",
                    use_async=False,
                )
                print("β Hybrid retriever initialized successfully")
            except Exception as e:
                print(f"β Warning: Could not initialize BM25 retriever: {e}")
                print("π Falling back to vector-only retrieval")
                # Keep state consistent: a BM25 retriever that is not part of
                # the active hybrid retriever is not reported as available.
                bm25_retriever = None
                hybrid_retriever = vector_retriever
    except Exception as e:
        print(f"β Warning: Could not initialize retrievers: {e}")
        hybrid_retriever = None
        vector_retriever = None
        bm25_retriever = None
else:
    print("β Warning: Could not initialize retrievers - index is None")
def call_groq_api(prompt):
    """Complete ``prompt`` with the globally configured LLM and return the text.

    Despite the name, this calls whatever model is installed on
    ``Settings.llm`` (assigned from ``utils.llm`` at import time); that it is
    a Groq-backed model is presumably configured in utils — TODO confirm.

    Args:
        prompt: Full prompt string to send to the model.

    Returns:
        The model response converted with ``str()``.

    Raises:
        Exception: Any error from ``Settings.llm.complete`` is logged and then
            re-raised unchanged.
    """
    try:
        response = Settings.llm.complete(prompt)
        return str(response)
    except Exception as e:
        print(f"β Groq API call failed: {e}")
        # Bare ``raise`` re-raises the active exception without adding an
        # extra re-raise frame to its traceback.
        raise
# Cross-encoder used to rerank retrieved chunks before prompting the LLM.
_RERANK_MODEL = 'cross-encoder/ms-marco-MiniLM-L-2-v2'

# Case-insensitive substrings that mark a retrieved chunk as pregnancy-related.
_PREGNANCY_KEYWORDS = (
    'pregnancy', 'preeclampsia', 'gestational', 'trimester',
    'fetal', 'bleeding', 'contractions', 'prenatal',
)

# SentenceTransformerRerank instances keyed by top_n. Its constructor
# instantiates the cross-encoder model, so we build one per top_n and reuse
# it instead of reloading the model on every call.
_reranker_cache = {}


def _get_reranker(top_n):
    """Return a cached reranker that keeps the ``top_n`` highest-scoring nodes."""
    reranker = _reranker_cache.get(top_n)
    if reranker is None:
        reranker = SentenceTransformerRerank(model=_RERANK_MODEL, top_n=top_n)
        _reranker_cache[top_n] = reranker
    return reranker


def _filter_pregnancy_nodes(nodes):
    """Return the nodes whose lowercased text contains any ``_PREGNANCY_KEYWORDS`` entry."""
    matches = []
    for node in nodes:
        text = node.get_text().lower()
        if any(keyword in text for keyword in _PREGNANCY_KEYWORDS):
            matches.append(node)
    return matches


def _build_context(nodes, max_chars):
    """Join node texts with ``---`` separators within a ``max_chars`` budget.

    Nodes are added whole while they fit. The first node that would overflow
    the budget is truncated (with a ``"..."`` suffix) only if more than 100
    characters of budget remain; no later nodes are considered.
    """
    chunks = []
    used = 0
    for node in nodes:
        text = node.get_text()
        if used + len(text) <= max_chars:
            chunks.append(text)
            used += len(text)
            continue
        remaining = max_chars - used
        if remaining > 100:
            chunks.append(text[:remaining] + "...")
        break
    return "\n\n---\n\n".join(chunks)


def get_direct_answer(question, symptom_summary, conversation_context="", max_context_nodes=8, is_risk_assessment=True, max_context_chars=6000):
    """Answer ``question`` with the LLM, grounded in chunks from ``hybrid_retriever``.

    Pipeline: retrieve with the module-level ``hybrid_retriever`` -> rerank
    with a cross-encoder keeping ``max_context_nodes`` -> prefer chunks that
    mention pregnancy keywords -> pack up to ``max_context_chars`` characters
    of context -> prompt the LLM via ``call_groq_api``.

    Args:
        question: User question; also the retrieval and reranking query.
        symptom_summary: Reported-symptoms text inserted into the prompt.
        conversation_context: Prior conversation text; only used by the
            free-form prompt (``is_risk_assessment=False``).
        max_context_nodes: How many nodes to keep after reranking.
        is_risk_assessment: True selects the fixed-format risk-assessment
            prompt; False selects the free-form Q&A prompt.
        max_context_chars: Character budget for the packed context
            (defaults to 6000, the previously hard-coded value).

    Returns:
        The LLM response text. Retrieval and generation failures are not
        raised; they are reported as a human-readable string instead, and a
        reranking failure falls back to the first ``max_context_nodes``
        retrieved nodes.
    """
    print(f"π― Processing question: {question}")
    if not hybrid_retriever:
        return "Error: Retriever not available. Please check if documents are properly loaded in the index."

    try:
        print("π Retrieving with available retrieval method...")
        retrieved_nodes = hybrid_retriever.retrieve(question)
        print(f"π Retrieved {len(retrieved_nodes)} nodes")
    except Exception as e:
        print(f"β Retrieval failed: {e}")
        return f"Error during document retrieval: {e}. Please check your document index."
    if not retrieved_nodes:
        return "No relevant documents found for this question. Please ensure your medical knowledge base is properly loaded and consult your healthcare provider for medical advice."

    try:
        reranker = _get_reranker(max_context_nodes)
        reranked_nodes = reranker.postprocess_nodes(retrieved_nodes, query_str=question)
        print(f"π― After reranking: {len(reranked_nodes)} nodes")
    except Exception as e:
        print(f"β Reranking failed: {e}, using original nodes")
        reranked_nodes = retrieved_nodes[:max_context_nodes]

    # Prefer pregnancy-related chunks, but keep the reranked set if none match
    # rather than prompting with no context at all.
    pregnancy_nodes = _filter_pregnancy_nodes(reranked_nodes)
    if pregnancy_nodes:
        reranked_nodes = pregnancy_nodes[:max_context_nodes]
        print(f"π After pregnancy keyword filtering: {len(reranked_nodes)} nodes")
    else:
        print("β οΈ No pregnancy-related content found, using original nodes")

    context_text = _build_context(reranked_nodes, max_context_chars)

    if is_risk_assessment:
        prompt = f"""You are the GraviLog Pregnancy Risk Assessment Agent. Use ONLY the context belowβdo not invent or add any new medical facts.
SYMPTOM RESPONSES:
{symptom_summary}
MEDICAL KNOWLEDGE:
{context_text}
Respond ONLY in this exact format (no extra text):
π₯ Risk Assessment Complete
**Risk Level:** <Low/Medium/High>
**Recommended Action:** <from KB's Risk Output Labels>
π¬ Rationale:
<One or two sentences citing which bullet(s) from the KB triggered your risk level.>"""
    else:
        prompt = f"""You are a pregnancy health assistant. Based on the medical knowledge below, answer the user's question about pregnancy symptoms and conditions.
USER QUESTION: {question}
CONVERSATION CONTEXT:
{conversation_context}
CURRENT SYMPTOMS REPORTED:
{symptom_summary}
MEDICAL KNOWLEDGE:
{context_text}
Provide a clear, informative answer based on the medical knowledge. Always mention if symptoms require medical attention and provide risk level (Low/Medium/High) when relevant."""

    try:
        print("π€ Generating response with Groq API...")
        return call_groq_api(prompt)
    except Exception as e:
        print(f"β LLM response failed: {e}")
        import traceback
        traceback.print_exc()
        return f"Error generating response: {e}"
# Holds the single reranker used by get_answer_with_query_engine. Its
# constructor instantiates a cross-encoder model, so it is created on first
# use and reused for later queries instead of being rebuilt per call.
_query_engine_reranker_cache = []


def _query_engine_reranker():
    """Return the shared top-5 cross-encoder reranker, creating it on first use."""
    if not _query_engine_reranker_cache:
        _query_engine_reranker_cache.append(
            SentenceTransformerRerank(
                model='cross-encoder/ms-marco-MiniLM-L-2-v2',
                top_n=5,
            )
        )
    return _query_engine_reranker_cache[0]


def get_answer_with_query_engine(question):
    """Answer ``question`` with a LlamaIndex query engine over the loaded index.

    When the module-level ``hybrid_retriever`` is available, a
    ``RetrieverQueryEngine`` is built on it with "compact" response synthesis
    and a top-5 cross-encoder reranker. Otherwise the index's default query
    engine (``similarity_top_k=10``, "compact") is used.

    Args:
        question: The user's question.

    Returns:
        The engine's response as a string. Errors are not raised; they are
        logged with a traceback and returned as a human-readable string.
    """
    try:
        print(f"π― Processing question with query engine: {question}")
        if index is None:
            return "Error: Could not load index"
        if hybrid_retriever:
            query_engine = RetrieverQueryEngine.from_args(
                retriever=hybrid_retriever,
                response_synthesizer=get_response_synthesizer(
                    response_mode="compact",
                    use_async=False,
                ),
                node_postprocessors=[_query_engine_reranker()],
            )
        else:
            query_engine = index.as_query_engine(
                similarity_top_k=10,
                response_mode="compact",
            )
        print("π€ Querying with engine...")
        response = query_engine.query(question)
        return str(response)
    except Exception as e:
        print(f"β Query engine failed: {e}")
        import traceback
        traceback.print_exc()
        return f"Error with query engine: {e}. Please check your setup and try again."