import os
import json
import gradio as gr
from langchain_core.documents import Document
from langchain_huggingface import HuggingFaceEmbeddings, HuggingFacePipeline
from langchain_community.vectorstores import FAISS
from langchain_core.prompts import ChatPromptTemplate
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline

# Global variables for model components
vectorstore = None
qa_chain = None
config = None

# Lazily-initialized sentence-similarity model used by calculate_metrics().
# Cached at module level so the (large) BioBERT model is loaded at most once
# instead of on every answered question.
_semantic_model = None


def load_model():
    """Load the saved model artifacts and build the RAG chain.

    Populates the module-level globals ``vectorstore``, ``qa_chain`` and
    ``config`` from artifacts expected next to this script:
    ``config(RAG).json`` (embedding model name + corpus stats) and a FAISS
    index saved in the current directory.

    Returns:
        A human-readable status string for display in the UI.
    """
    global vectorstore, qa_chain, config

    with open("config(RAG).json", "r", encoding="utf-8") as f:
        config = json.load(f)

    embeddings = HuggingFaceEmbeddings(model_name=config["embedding_model"])
    # allow_dangerous_deserialization is required to unpickle the locally
    # saved FAISS docstore; acceptable only because the index was produced
    # by this project, not downloaded from an untrusted source.
    vectorstore = FAISS.load_local(
        ".",
        allow_dangerous_deserialization=True,
        embeddings=embeddings,
    )

    tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-base")
    model_obj = AutoModelForSeq2SeqLM.from_pretrained("google/flan-t5-base")
    pipe = pipeline(
        "text2text-generation",
        model=model_obj,
        tokenizer=tokenizer,
        max_new_tokens=256,
        min_length=30,
        temperature=0.7,
        do_sample=True,
        top_p=0.9,
        repetition_penalty=1.2,
    )
    llm = HuggingFacePipeline(pipeline=pipe)
    retriever = vectorstore.as_retriever()

    template = """You are a medical knowledge assistant. Based on the medical context below, provide a detailed and accurate answer to the question. 
Context: {context} Question: {question} Provide a comprehensive answer with specific medical details from the context:"""
    prompt = ChatPromptTemplate.from_template(template)

    def format_docs(docs):
        """Concatenate retrieved document texts into one context string."""
        return "\n\n".join(doc.page_content for doc in docs)

    class RAGChain:
        """Minimal retrieval-augmented generation chain with a per-call ``k``."""

        def __init__(self, retriever, llm, prompt, vectorstore):
            self.retriever = retriever
            self.llm = llm
            self.prompt = prompt
            self.vectorstore = vectorstore

        def __call__(self, inputs, k=5):
            """Answer ``inputs['query']`` using the top-``k`` retrieved chunks.

            Returns a dict with ``result`` (the generated answer) and
            ``source_documents`` (the retrieved chunks).
            """
            query = inputs["query"]
            # Build a retriever per call so the caller can vary k at runtime
            # (the UI exposes it as a slider).
            dynamic_retriever = self.vectorstore.as_retriever(
                search_kwargs={"k": k}
            )
            source_docs = dynamic_retriever.invoke(query)
            context = format_docs(source_docs)
            prompt_text = self.prompt.format(context=context, question=query)
            answer = self.llm.invoke(prompt_text)
            return {"result": answer, "source_documents": source_docs}

    qa_chain = RAGChain(retriever, llm, prompt, vectorstore)
    return "āœ… Model loaded successfully!"


def answer_question(query, num_sources=5):
    """Answer a medical question via the RAG chain.

    Args:
        query: Free-text question from the UI.
        num_sources: How many source chunks to retrieve and display
            (comes from a Gradio slider, so it may arrive as a float).

    Returns:
        A 3-tuple of (answer, sources markdown, metrics markdown);
        on failure the answer slot carries the error message.
    """
    if not qa_chain:
        return "āŒ Model not loaded. Please wait for initialization.", "", ""
    if not query.strip():
        return "Please enter a medical question.", "", ""
    try:
        # Cast defensively: Gradio sliders may deliver floats.
        result = qa_chain({"query": query}, k=int(num_sources))
        answer = result["result"]
        sources = result["source_documents"]

        sources_text = f"### šŸ“š Retrieved {len(sources)} Sources:\n\n"
        for i, doc in enumerate(sources, 1):
            # Truncate each chunk to keep the panel readable.
            sources_text += f"**Source {i}:**\n{doc.page_content[:300]}...\n\n"

        metrics_text = calculate_metrics(query, answer, sources)
        return answer, sources_text, metrics_text
    except Exception as e:
        # Surface the failure in the UI rather than crashing the app.
        return f"āŒ Error: {str(e)}", "", ""


def _get_semantic_model():
    """Return the cached BioBERT similarity model, loading it on first use."""
    global _semantic_model
    if _semantic_model is None:
        from sentence_transformers import SentenceTransformer
        _semantic_model = SentenceTransformer(
            'pritamdeka/BioBERT-mnli-snli-scinli-scitail-mednli-stsb'
        )
    return _semantic_model


def calculate_metrics(query, answer, sources):
    """Build a markdown metrics report for one answered question.

    Computes semantic relevance/coherence via BioBERT embeddings, counts
    topic-specific clinical keywords in the answer, and reports retrieval
    precision/recall/F1/accuracy. NOTE: without ground-truth relevance
    labels, fp and fn are assumed 0, so precision/recall are optimistic
    upper bounds rather than true measurements.

    Returns a markdown string; on any failure, a markdown error notice.
    """
    try:
        from sentence_transformers import util
        semantic_model = _get_semantic_model()

        query_embedding = semantic_model.encode(query, convert_to_tensor=True)
        answer_embedding = semantic_model.encode(answer, convert_to_tensor=True)
        relevance = util.cos_sim(query_embedding, answer_embedding).item()

        context = " ".join([doc.page_content for doc in sources])
        # Only embed a prefix of the context to keep the encode call cheap.
        context_embedding = semantic_model.encode(
            context[:500], convert_to_tensor=True
        )
        coherence = util.cos_sim(answer_embedding, context_embedding).item()

        # Topic -> expected clinical vocabulary; used to count how many
        # domain terms the generated answer actually contains.
        medical_terms = {
            'heart failure': ['lvef', 'ejection fraction', 'cardiac', 'ventricular', 'cardiomyopathy'],
            'diabetes': ['glucose', 'insulin', 'a1c', 'hemoglobin', 'glycemic', 'hyperglycemia'],
            'hypertension': ['blood pressure', 'systolic', 'diastolic', 'antihypertensive', 'bp']
        }
        answer_lower = answer.lower()
        keywords_count = 0
        for topic, keywords in medical_terms.items():
            if any(term in query.lower() for term in topic.split()):
                keywords_count = sum(1 for kw in keywords if kw in answer_lower)
                break

        # Retrieval confusion-matrix approximation (no ground truth, so
        # every retrieved doc counts as a true positive).
        total_docs = len(vectorstore.docstore._dict) if vectorstore else 1000
        retrieved = len(sources)
        tp = retrieved
        fp = 0
        fn = 0
        tn = total_docs - retrieved
        precision = tp / (tp + fp) if (tp + fp) > 0 else 0
        recall = tp / (tp + fn) if (tp + fn) > 0 else 1.0
        f1 = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0
        denominator = tp + tn + fp + fn
        # Guard against an empty docstore (denominator == 0).
        accuracy = (tp + tn) / denominator if denominator > 0 else 0

        metrics = f"""### šŸ“Š Evaluation Metrics

**Semantic Quality:**
- šŸŽÆ **Relevance Score:** {relevance:.3f}
- šŸ”— **Coherence Score:** {coherence:.3f}
- šŸ„ **Clinical Terms Found:** {keywords_count}

**Retrieval Performance:**
- āœ… **Precision:** {precision:.3f}
- šŸ“ˆ **Recall:** {recall:.3f}
- šŸŽ² **F1 Score:** {f1:.3f}
- šŸ“Š **Accuracy:** {accuracy:.3f}
"""
        return metrics
    except Exception as e:
        return f"### šŸ“Š Evaluation Metrics\n\nāš ļø Error calculating metrics: {str(e)}"


def get_model_info():
    """Return a markdown summary of the loaded configuration for the UI."""
    if config:
        info = f"""### šŸ¤– Model Configuration
- **Embedding Model:** {config['embedding_model']}
- **LLM Model:** google/flan-t5-base
- **Documents Processed:** {config['num_docs']}
- **Text Chunks:** {config['num_chunks']}
- **Retrieval Documents:** {config['retrieval_k']}
- **Storage:** ~20 MB (FAISS index only)
"""
        return info
    return "Model configuration not available."
# ===== Gradio Interface (Green Themed) =====
with gr.Blocks(
    title="Medical Q&A RAG System",
    theme=gr.themes.Monochrome(primary_hue="green"),
) as demo:
    # Page header.
    gr.Markdown("""
# šŸ„ Medical Q&A RAG System
### Powered by MIMIC-IV Dataset, BioBERT & FLAN-T5
Ask medical questions and get evidence-based answers from clinical documentation.
""")

    with gr.Row():
        # Main (left) column: question entry, controls, and the three
        # result panels (answer, sources, metrics).
        with gr.Column(scale=2):
            query_input = gr.Textbox(
                label="šŸ’¬ Enter your medical question",
                placeholder="e.g., What are the diagnostic criteria for heart failure?",
                lines=3,
            )
            with gr.Row():
                submit_btn = gr.Button(
                    "šŸ” Get Answer", variant="primary", elem_classes="green-btn"
                )
                clear_btn = gr.Button(
                    "šŸ—‘ļø Clear", variant="secondary", elem_classes="green-btn"
                )
            num_sources = gr.Slider(
                minimum=1,
                maximum=10,
                value=5,
                step=1,
                label="Number of source documents to display",
            )
            answer_output = gr.Textbox(label="šŸ’” Answer", lines=8)
            sources_output = gr.Markdown(label="šŸ“š Retrieved Sources")
            metrics_output = gr.Markdown(label="šŸ“Š Evaluation Metrics")

        # Sidebar (right) column: static info, model config, examples.
        with gr.Column(scale=1):
            gr.Markdown("### ā„¹ļø About", elem_classes="green-title")
            gr.Markdown("""
This system uses:
- **RAG (Retrieval-Augmented Generation)** to provide accurate medical answers
- **MIMIC-IV** clinical dataset for knowledge base
- **BioBERT** for medical text understanding
- **FLAN-T5** for answer generation

āš ļø **Disclaimer:** For educational purposes only.
""")
            model_info = gr.Markdown(get_model_info())
            gr.Markdown("### šŸ“ Example Questions", elem_classes="green-title")
            examples = gr.Examples(
                examples=[
                    "What are the diagnostic criteria for heart failure with reduced ejection fraction?",
                    "How is type 2 diabetes diagnosed?",
                    "What is recommended for stage 2 hypertension?",
                    "What are the symptoms of coronary artery disease?",
                    "How is myocardial infarction treated?",
                ],
                inputs=query_input,
            )

    # Event handlers
    submit_btn.click(
        fn=answer_question,
        inputs=[query_input, num_sources],
        outputs=[answer_output, sources_output, metrics_output],
    )
    clear_btn.click(
        fn=lambda: ("", "", "", ""),
        outputs=[query_input, answer_output, sources_output, metrics_output],
    )
    # Kick off model loading as soon as the app starts serving.
    demo.load(fn=load_model, outputs=None)

if __name__ == "__main__":
    demo.launch()