| import os |
| import json |
| import gradio as gr |
| from langchain_core.documents import Document |
| from langchain_huggingface import HuggingFaceEmbeddings, HuggingFacePipeline |
| from langchain_community.vectorstores import FAISS |
| from langchain_core.prompts import ChatPromptTemplate |
| from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline |
|
|
| |
# Module-level state populated by load_model(); all three stay None until
# initialization completes (answer_question() checks qa_chain before serving).
vectorstore = None  # FAISS vector store restored from the local index
qa_chain = None     # RAGChain instance wired to the LLM + vector store
config = None       # dict loaded from config(RAG).json
|
|
def load_model():
    """Load the saved RAG artifacts into the module-level globals.

    Reads config(RAG).json, restores the FAISS vector store with the
    configured embedding model, builds a FLAN-T5 text2text-generation
    pipeline, and assembles everything into a RAGChain stored in the
    module-level ``qa_chain``.

    Returns:
        A human-readable status string for the UI.
    """
    global vectorstore, qa_chain, config

    with open("config(RAG).json", "r") as f:
        config = json.load(f)

    embeddings = HuggingFaceEmbeddings(model_name=config["embedding_model"])
    # The index is a locally saved, trusted artifact — hence
    # allow_dangerous_deserialization is acceptable here.
    vectorstore = FAISS.load_local(
        ".",
        embeddings,
        allow_dangerous_deserialization=True,
    )

    tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-base")
    model_obj = AutoModelForSeq2SeqLM.from_pretrained("google/flan-t5-base")

    pipe = pipeline(
        "text2text-generation", model=model_obj, tokenizer=tokenizer,
        max_new_tokens=256, min_length=30, temperature=0.7,
        do_sample=True, top_p=0.9, repetition_penalty=1.2,
    )
    llm = HuggingFacePipeline(pipeline=pipe)

    # Plain string template instead of ChatPromptTemplate: the chat template's
    # .format() prepends a "Human:" role marker, which FLAN-T5 (a seq2seq
    # model, not a chat model) was never trained on and which pollutes the
    # prompt fed to the pipeline.
    template = """You are a medical knowledge assistant. Based on the medical context below, provide a detailed and accurate answer to the question.

Context:
{context}

Question: {question}

Provide a comprehensive answer with specific medical details from the context:"""

    def format_docs(docs):
        """Concatenate retrieved document texts, blank-line separated."""
        return "\n\n".join(doc.page_content for doc in docs)

    class RAGChain:
        """Minimal retrieve-then-generate chain with a per-call ``k``."""

        def __init__(self, llm, template, vectorstore):
            self.llm = llm
            self.template = template
            self.vectorstore = vectorstore

        def __call__(self, inputs, k=5):
            """Answer inputs["query"]; returns result + source documents."""
            query = inputs["query"]
            # Build a retriever per call so `k` can vary with each request.
            dynamic_retriever = self.vectorstore.as_retriever(
                search_kwargs={"k": k}
            )
            source_docs = dynamic_retriever.invoke(query)
            prompt_text = self.template.format(
                context=format_docs(source_docs), question=query
            )
            answer = self.llm.invoke(prompt_text)
            return {"result": answer, "source_documents": source_docs}

    qa_chain = RAGChain(llm, template, vectorstore)
    return "✅ Model loaded successfully!"
|
|
def answer_question(query, num_sources=5):
    """Process a medical query and return answer, sources, and metrics.

    Args:
        query: The user's medical question.
        num_sources: Number of documents to retrieve and display.

    Returns:
        A 3-tuple of strings for the Gradio outputs: the generated answer,
        a markdown listing of retrieved source snippets, and a markdown
        metrics report. On failure the answer slot carries the error message
        and the other two slots are empty.
    """
    if not qa_chain:
        return "❌ Model not loaded. Please wait for initialization.", "", ""

    if not query.strip():
        return "Please enter a medical question.", "", ""

    try:
        result = qa_chain({"query": query}, k=num_sources)
        answer = result["result"]
        sources = result["source_documents"]

        # Show only the first 300 chars of each source to keep the UI tidy;
        # join once instead of repeated string concatenation.
        parts = [f"### 📚 Retrieved {len(sources)} Sources:\n\n"]
        for i, doc in enumerate(sources, 1):
            parts.append(f"**Source {i}:**\n{doc.page_content[:300]}...\n\n")
        sources_text = "".join(parts)

        metrics_text = calculate_metrics(query, answer, sources)

        return answer, sources_text, metrics_text

    except Exception as e:
        # Broad catch is deliberate: any backend failure is surfaced in the
        # UI instead of crashing the Gradio handler.
        return f"❌ Error: {str(e)}", "", ""
|
|
def calculate_metrics(query, answer, sources):
    """Calculate evaluation metrics for the query-answer pair.

    Computes semantic relevance/coherence with a biomedical sentence
    encoder, counts topic-specific clinical vocabulary, and reports
    retrieval precision/recall under the simplifying assumption that every
    retrieved document is relevant.

    Returns:
        A markdown report string; on any failure an error report is
        returned instead of raising, so the UI never breaks.
    """
    try:
        from sentence_transformers import SentenceTransformer, util

        # Reloading the encoder on every call dominates latency; cache it on
        # the function object so subsequent calls reuse the loaded weights.
        semantic_model = getattr(calculate_metrics, "_semantic_model", None)
        if semantic_model is None:
            semantic_model = SentenceTransformer(
                'pritamdeka/BioBERT-mnli-snli-scinli-scitail-mednli-stsb'
            )
            calculate_metrics._semantic_model = semantic_model

        # Relevance: query <-> answer similarity.
        query_embedding = semantic_model.encode(query, convert_to_tensor=True)
        answer_embedding = semantic_model.encode(answer, convert_to_tensor=True)
        relevance = util.cos_sim(query_embedding, answer_embedding).item()

        # Coherence: answer <-> retrieved-context similarity. Only the first
        # 500 chars of the concatenated sources are encoded to bound input size.
        context = " ".join(doc.page_content for doc in sources)
        context_embedding = semantic_model.encode(context[:500], convert_to_tensor=True)
        coherence = util.cos_sim(answer_embedding, context_embedding).item()

        # Count clinical keywords in the answer; the keyword set is picked by
        # matching the topic words against the query, first match wins.
        medical_terms = {
            'heart failure': ['lvef', 'ejection fraction', 'cardiac', 'ventricular', 'cardiomyopathy'],
            'diabetes': ['glucose', 'insulin', 'a1c', 'hemoglobin', 'glycemic', 'hyperglycemia'],
            'hypertension': ['blood pressure', 'systolic', 'diastolic', 'antihypertensive', 'bp']
        }
        answer_lower = answer.lower()
        keywords_count = 0
        for topic, keywords in medical_terms.items():
            if any(term in query.lower() for term in topic.split()):
                keywords_count = sum(1 for kw in keywords if kw in answer_lower)
                break

        # Retrieval "confusion matrix" under the assumption that everything
        # retrieved is relevant (fp = fn = 0) — illustrative numbers only,
        # not a labeled evaluation.
        total_docs = len(vectorstore.docstore._dict) if vectorstore else 1000
        retrieved = len(sources)
        tp = retrieved
        fp = 0
        fn = 0
        tn = total_docs - retrieved

        precision = tp / (tp + fp) if (tp + fp) > 0 else 0
        recall = tp / (tp + fn) if (tp + fn) > 0 else 1.0
        f1 = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0
        accuracy = (tp + tn) / (tp + tn + fp + fn)

        metrics = f"""### 📊 Evaluation Metrics

**Semantic Quality:**
- 🎯 **Relevance Score:** {relevance:.3f} (Query-Answer alignment)
- 📖 **Coherence Score:** {coherence:.3f} (Answer-Context consistency)
- 🏥 **Clinical Terms Found:** {keywords_count} medical keywords

**Retrieval Performance:**
- ✅ **Precision:** {precision:.3f}
- 📈 **Recall:** {recall:.3f}
- 🎲 **F1 Score:** {f1:.3f}
- 📊 **Accuracy:** {accuracy:.3f}

**Confusion Matrix:**
```
              Relevant    Not Relevant
Retrieved      {tp:5d}        {fp:5d}
Not Retrieved  {fn:5d}        {tn:5d}
```

**Interpretation:**
- Relevance > 0.5: ✅ Answer addresses query
- Coherence > 0.6: ✅ Answer grounded in context
- Clinical Terms > 2: ✅ Domain-specific vocabulary
"""
        return metrics

    except Exception as e:
        # Metrics are best-effort; never let evaluation break the answer flow.
        return f"### 📊 Evaluation Metrics\n\n⚠️ Error calculating metrics: {str(e)}"
|
|
def get_model_info():
    """Return a markdown summary of the loaded model configuration.

    Reads the module-level ``config`` dict populated by load_model();
    returns a fallback message when the model has not been loaded yet.
    """
    if not config:
        return "Model configuration not available."
    return f"""### 🤖 Model Configuration
- **Embedding Model:** {config['embedding_model']}
- **LLM Model:** google/flan-t5-base (loaded from HF Hub)
- **Documents Processed:** {config['num_docs']}
- **Text Chunks:** {config['num_chunks']}
- **Retrieval Documents:** {config['retrieval_k']}
- **Storage:** ~20 MB (FAISS index only)
"""
|
|
| |
# ---- Gradio UI ----------------------------------------------------------
with gr.Blocks(title="Medical Q&A RAG System") as demo:
    gr.Markdown("""
    # 🏥 Medical Q&A RAG System
    ### Powered by MIMIC-IV Dataset, BioBERT & FLAN-T5

    Ask medical questions and get evidence-based answers from clinical documentation.
    """)

    with gr.Row():
        with gr.Column(scale=2):
            query_input = gr.Textbox(
                label="💬 Enter your medical question",
                placeholder="e.g., What are the diagnostic criteria for heart failure?",
                lines=3
            )

            with gr.Row():
                submit_btn = gr.Button("🔍 Get Answer", variant="primary")
                clear_btn = gr.Button("🗑️ Clear")

            num_sources = gr.Slider(
                minimum=1, maximum=10, value=5, step=1,
                label="Number of source documents to display"
            )

            answer_output = gr.Textbox(
                label="💡 Answer",
                lines=8
            )

            sources_output = gr.Markdown(
                label="📚 Retrieved Sources"
            )

            metrics_output = gr.Markdown(
                label="📊 Evaluation Metrics"
            )

        with gr.Column(scale=1):
            gr.Markdown("### ℹ️ About")
            gr.Markdown("""
            This system uses:
            - **RAG (Retrieval-Augmented Generation)** to provide accurate medical answers
            - **MIMIC-IV** clinical dataset for knowledge base
            - **BioBERT** for medical text understanding
            - **FLAN-T5** for answer generation

            ⚠️ **Disclaimer:** This is for educational purposes only. Not for clinical use.
            """)

            model_info = gr.Markdown(get_model_info())

    gr.Markdown("### 📝 Example Questions")
    examples = gr.Examples(
        examples=[
            "What are the diagnostic criteria for heart failure with reduced ejection fraction?",
            "How is type 2 diabetes diagnosed?",
            "What is recommended for stage 2 hypertension?",
            "What are the symptoms of coronary artery disease?",
            "How is myocardial infarction treated?"
        ],
        inputs=query_input
    )

    # Wire the buttons: submit runs the RAG pipeline, clear blanks all fields.
    submit_btn.click(
        fn=answer_question,
        inputs=[query_input, num_sources],
        outputs=[answer_output, sources_output, metrics_output]
    )

    clear_btn.click(
        fn=lambda: ("", "", "", ""),
        outputs=[query_input, answer_output, sources_output, metrics_output]
    )

    # Load artifacts once at app startup; the status string is discarded
    # (outputs=None), it only matters that the globals get populated.
    demo.load(fn=load_model, outputs=None)


if __name__ == "__main__":
    demo.launch(share=True)
|
|