Spaces:
Runtime error
Runtime error
| import os | |
| import json | |
| import gradio as gr | |
| from langchain_core.documents import Document | |
| from langchain_huggingface import HuggingFaceEmbeddings, HuggingFacePipeline | |
| from langchain_community.vectorstores import FAISS | |
| from langchain_core.prompts import ChatPromptTemplate | |
| from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline | |
# Global variables for model components
# Populated by load_model() at startup; read by answer_question(),
# calculate_metrics() and get_model_info().
vectorstore = None  # FAISS vector store reloaded from the local index files
qa_chain = None  # RAGChain instance built by load_model(); None until loaded
config = None  # dict parsed from "config(RAG).json"
def load_model():
    """Load the saved RAG artifacts and build the global QA chain.

    Side effects: populates the module globals ``vectorstore``, ``qa_chain``
    and ``config``. Returns a status string for display in the UI. Any
    failure (missing config/index files, model download errors) is reported
    as a string instead of crashing the Space at startup.
    """
    global vectorstore, qa_chain, config
    try:
        # Pipeline/index settings saved when the index was built.
        with open("config(RAG).json", "r") as f:
            config = json.load(f)

        # Rebuild the embedding model and reload the FAISS index from disk.
        # allow_dangerous_deserialization is required because FAISS metadata
        # is pickled; acceptable only because these files are our own.
        embeddings = HuggingFaceEmbeddings(model_name=config["embedding_model"])
        vectorstore = FAISS.load_local(
            ".", allow_dangerous_deserialization=True, embeddings=embeddings
        )

        # Seq2seq generator that writes the final answer.
        tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-base")
        model_obj = AutoModelForSeq2SeqLM.from_pretrained("google/flan-t5-base")
        pipe = pipeline(
            "text2text-generation", model=model_obj, tokenizer=tokenizer,
            max_new_tokens=256, min_length=30, temperature=0.7,
            do_sample=True, top_p=0.9, repetition_penalty=1.2
        )
        llm = HuggingFacePipeline(pipeline=pipe)
        retriever = vectorstore.as_retriever()

        template = """You are a medical knowledge assistant. Based on the medical context below, provide a detailed and accurate answer to the question.
Context:
{context}
Question: {question}
Provide a comprehensive answer with specific medical details from the context:"""
        prompt = ChatPromptTemplate.from_template(template)

        def format_docs(docs):
            # Join retrieved chunks into one context string.
            return "\n\n".join(doc.page_content for doc in docs)

        class RAGChain:
            """Minimal retrieve-then-generate chain with a tunable k."""

            def __init__(self, retriever, llm, prompt, vectorstore):
                self.retriever = retriever  # default retriever (kept for compatibility; __call__ builds its own)
                self.llm = llm
                self.prompt = prompt
                self.vectorstore = vectorstore

            def __call__(self, inputs, k=5):
                query = inputs["query"]
                # Fresh retriever so the caller-supplied k takes effect.
                dynamic_retriever = self.vectorstore.as_retriever(search_kwargs={"k": k})
                source_docs = dynamic_retriever.invoke(query)
                context = format_docs(source_docs)
                prompt_text = self.prompt.format(context=context, question=query)
                answer = self.llm.invoke(prompt_text)
                return {"result": answer, "source_documents": source_docs}

        qa_chain = RAGChain(retriever, llm, prompt, vectorstore)
        return "β Model loaded successfully!"
    except Exception as e:
        # Surface the failure in the UI instead of taking down the app;
        # answer_question() will keep reporting "Model not loaded".
        return f"Error loading model: {e}"
def answer_question(query, num_sources=5):
    """Answer a medical question via the global RAG chain.

    Returns an (answer, sources_markdown, metrics_markdown) triple; when the
    model is unavailable, the query is blank, or the chain raises, the first
    slot carries the message and the other two are empty strings.
    """
    if not qa_chain:
        return "β Model not loaded. Please wait for initialization.", "", ""
    if not query.strip():
        return "Please enter a medical question.", "", ""
    try:
        outcome = qa_chain({"query": query}, k=num_sources)
        generated = outcome["result"]
        docs = outcome["source_documents"]
        # Markdown listing of the retrieved chunks (300-char previews).
        pieces = [f"### π Retrieved {len(docs)} Sources:\n\n"]
        pieces.extend(
            f"**Source {idx}:**\n{doc.page_content[:300]}...\n\n"
            for idx, doc in enumerate(docs, 1)
        )
        sources_md = "".join(pieces)
        metrics_md = calculate_metrics(query, generated, docs)
        return generated, sources_md, metrics_md
    except Exception as exc:
        return f"β Error: {str(exc)}", "", ""
def calculate_metrics(query, answer, sources):
    """Render a Markdown report of semantic and retrieval metrics.

    Computes query/answer relevance and answer/context coherence with a
    biomedical sentence-similarity model, counts topic-specific clinical
    keywords in the answer, and reports retrieval precision/recall/F1.
    Any failure (e.g. sentence_transformers unavailable) is reported inside
    the returned Markdown rather than raised.
    """
    try:
        # Cache the expensive similarity model on the function object so it
        # is loaded once per process instead of once per question.
        cached = getattr(calculate_metrics, "_sim_cache", None)
        if cached is None:
            from sentence_transformers import SentenceTransformer, util
            cached = (
                SentenceTransformer('pritamdeka/BioBERT-mnli-snli-scinli-scitail-mednli-stsb'),
                util,
            )
            calculate_metrics._sim_cache = cached
        semantic_model, util = cached

        query_embedding = semantic_model.encode(query, convert_to_tensor=True)
        answer_embedding = semantic_model.encode(answer, convert_to_tensor=True)
        relevance = util.cos_sim(query_embedding, answer_embedding).item()

        # Coherence: answer vs. (truncated) retrieved context.
        context = " ".join([doc.page_content for doc in sources])
        context_embedding = semantic_model.encode(context[:500], convert_to_tensor=True)
        coherence = util.cos_sim(answer_embedding, context_embedding).item()

        # Count clinical keywords for the first topic matching the query.
        medical_terms = {
            'heart failure': ['lvef', 'ejection fraction', 'cardiac', 'ventricular', 'cardiomyopathy'],
            'diabetes': ['glucose', 'insulin', 'a1c', 'hemoglobin', 'glycemic', 'hyperglycemia'],
            'hypertension': ['blood pressure', 'systolic', 'diastolic', 'antihypertensive', 'bp']
        }
        answer_lower = answer.lower()
        keywords_count = 0
        for topic, keywords in medical_terms.items():
            if any(term in query.lower() for term in topic.split()):
                keywords_count = sum(1 for kw in keywords if kw in answer_lower)
                break

        # NOTE(review): there are no ground-truth relevance labels here, so
        # fp and fn are hard-coded to 0 — precision, recall, F1 are therefore
        # always 1.0 and accuracy ~1.0. These are placeholder metrics only.
        total_docs = len(vectorstore.docstore._dict) if vectorstore else 1000
        retrieved = len(sources)
        tp = retrieved
        fp = 0
        fn = 0
        tn = total_docs - retrieved
        precision = tp / (tp + fp) if (tp + fp) > 0 else 0
        recall = tp / (tp + fn) if (tp + fn) > 0 else 1.0
        f1 = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0
        accuracy = (tp + tn) / (tp + tn + fp + fn)

        metrics = f"""### π Evaluation Metrics
**Semantic Quality:**
- π― **Relevance Score:** {relevance:.3f}
- π **Coherence Score:** {coherence:.3f}
- π₯ **Clinical Terms Found:** {keywords_count}
**Retrieval Performance:**
- β **Precision:** {precision:.3f}
- π **Recall:** {recall:.3f}
- π² **F1 Score:** {f1:.3f}
- π **Accuracy:** {accuracy:.3f}
"""
        return metrics
    except Exception as e:
        return f"### π Evaluation Metrics\n\nβ οΈ Error calculating metrics: {str(e)}"
def get_model_info():
    """Render the loaded configuration as a Markdown panel.

    Falls back to a plain notice when load_model() has not populated the
    global ``config`` yet.
    """
    if not config:
        return "Model configuration not available."
    return f"""### π€ Model Configuration
- **Embedding Model:** {config['embedding_model']}
- **LLM Model:** google/flan-t5-base
- **Documents Processed:** {config['num_docs']}
- **Text Chunks:** {config['num_chunks']}
- **Retrieval Documents:** {config['retrieval_k']}
- **Storage:** ~20 MB (FAISS index only)
"""
# ===== Gradio Interface (Green Themed) =====
# Layout: two columns — left holds the query box, controls and outputs;
# right holds the about panel, model info and example questions.
with gr.Blocks(title="Medical Q&A RAG System", theme=gr.themes.Monochrome(primary_hue="green")) as demo:
    gr.Markdown("""
# π₯ Medical Q&A RAG System
### Powered by MIMIC-IV Dataset, BioBERT & FLAN-T5
Ask medical questions and get evidence-based answers from clinical documentation.
""")
    with gr.Row():
        with gr.Column(scale=2):
            # Free-text question input.
            query_input = gr.Textbox(
                label="π¬ Enter your medical question",
                placeholder="e.g., What are the diagnostic criteria for heart failure?",
                lines=3
            )
            with gr.Row():
                submit_btn = gr.Button("π Get Answer", variant="primary", elem_classes="green-btn")
                clear_btn = gr.Button("ποΈ Clear", variant="secondary", elem_classes="green-btn")
            # Forwarded to answer_question() as the retrieval k.
            num_sources = gr.Slider(
                minimum=1, maximum=10, value=5, step=1,
                label="Number of source documents to display"
            )
            answer_output = gr.Textbox(
                label="π‘ Answer",
                lines=8
            )
            sources_output = gr.Markdown(
                label="π Retrieved Sources"
            )
            metrics_output = gr.Markdown(
                label="π Evaluation Metrics"
            )
        with gr.Column(scale=1):
            gr.Markdown("### βΉοΈ About", elem_classes="green-title")
            gr.Markdown("""
This system uses:
- **RAG (Retrieval-Augmented Generation)** to provide accurate medical answers
- **MIMIC-IV** clinical dataset for knowledge base
- **BioBERT** for medical text understanding
- **FLAN-T5** for answer generation
β οΈ **Disclaimer:** For educational purposes only.
""")
            # Static snapshot of config; rendered once at interface build time.
            model_info = gr.Markdown(get_model_info())
            gr.Markdown("### π Example Questions", elem_classes="green-title")
            # Clicking an example fills the query box.
            examples = gr.Examples(
                examples=[
                    "What are the diagnostic criteria for heart failure with reduced ejection fraction?",
                    "How is type 2 diabetes diagnosed?",
                    "What is recommended for stage 2 hypertension?",
                    "What are the symptoms of coronary artery disease?",
                    "How is myocardial infarction treated?"
                ],
                inputs=query_input
            )
    # Event handlers
    submit_btn.click(
        fn=answer_question,
        inputs=[query_input, num_sources],
        outputs=[answer_output, sources_output, metrics_output]
    )
    # Resets the query box plus all three output panes.
    clear_btn.click(
        fn=lambda: ("", "", "", ""),
        outputs=[query_input, answer_output, sources_output, metrics_output]
    )
    # Load model artifacts when the page first opens; the returned status
    # string is discarded (outputs=None).
    demo.load(fn=load_model, outputs=None)
if __name__ == "__main__":
    demo.launch()