import os
import json
import gradio as gr
from langchain_core.documents import Document
from langchain_huggingface import HuggingFaceEmbeddings, HuggingFacePipeline
from langchain_community.vectorstores import FAISS
from langchain_core.prompts import PromptTemplate
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
# Global variables for model components
vectorstore = None
qa_chain = None
config = None
def load_model():
    """Load the saved model artifacts."""
    global vectorstore, qa_chain, config
    with open("config(RAG).json", "r") as f:
        config = json.load(f)
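    # The config file is expected to look roughly like this (hypothetical
    # values; only the keys read here and in get_model_info() are assumed):
    # {
    #   "embedding_model": "pritamdeka/BioBERT-mnli-snli-scinli-scitail-mednli-stsb",
    #   "num_docs": 1000,
    #   "num_chunks": 5000,
    #   "retrieval_k": 5
    # }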
    embeddings = HuggingFaceEmbeddings(model_name=config["embedding_model"])
    # allow_dangerous_deserialization is required because the FAISS index is
    # pickled on disk; only enable it for an index you built yourself.
    vectorstore = FAISS.load_local(
        ".", embeddings=embeddings, allow_dangerous_deserialization=True
    )
    tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-base")
    model_obj = AutoModelForSeq2SeqLM.from_pretrained("google/flan-t5-base")
    pipe = pipeline(
        "text2text-generation", model=model_obj, tokenizer=tokenizer,
        max_new_tokens=256, min_length=30, temperature=0.7,
        do_sample=True, top_p=0.9, repetition_penalty=1.2
    )
    llm = HuggingFacePipeline(pipeline=pipe)
    retriever = vectorstore.as_retriever()
    template = """You are a medical knowledge assistant. Based on the medical context below, provide a detailed and accurate answer to the question.

Context:
{context}

Question: {question}

Provide a comprehensive answer with specific medical details from the context:"""
    # FLAN-T5 is a plain seq2seq model, so PromptTemplate (which renders to a
    # bare string) is a better fit than ChatPromptTemplate, whose format()
    # output prefixes the text with a "Human:" role marker.
    prompt = PromptTemplate.from_template(template)

    def format_docs(docs):
        return "\n\n".join(doc.page_content for doc in docs)
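    # e.g. format_docs([Document(page_content="a"), Document(page_content="b")])
    # evaluates to "a\n\nb" -- retrieved chunks are simply stuffed into one
    # context string, separated by blank lines.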
    class RAGChain:
        """Minimal RAG wrapper: retrieve the top-k chunks, stuff them into
        the prompt, generate an answer, and return the sources alongside it."""

        def __init__(self, retriever, llm, prompt, vectorstore):
            self.retriever = retriever
            self.llm = llm
            self.prompt = prompt
            self.vectorstore = vectorstore

        def __call__(self, inputs, k=5):
            query = inputs["query"]
            # Build a retriever per call so k can follow the UI slider
            # (cast to int, since Gradio sliders may deliver floats).
            dynamic_retriever = self.vectorstore.as_retriever(search_kwargs={"k": int(k)})
            source_docs = dynamic_retriever.invoke(query)
            context = format_docs(source_docs)
            prompt_text = self.prompt.format(context=context, question=query)
            answer = self.llm.invoke(prompt_text)
            return {"result": answer, "source_documents": source_docs}
    qa_chain = RAGChain(retriever, llm, prompt, vectorstore)
    return "✅ Model loaded successfully!"
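# A minimal smoke test of the chain outside Gradio (hypothetical query; run
# only after load_model() has populated the globals):
#
#     load_model()
#     out = qa_chain({"query": "How is type 2 diabetes diagnosed?"}, k=3)
#     print(out["result"])
#     print(len(out["source_documents"]), "chunks retrieved")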
def answer_question(query, num_sources=5):
    if not qa_chain:
        return "❌ Model not loaded. Please wait for initialization.", "", ""
    if not query.strip():
        return "Please enter a medical question.", "", ""
    try:
        result = qa_chain({"query": query}, k=num_sources)
        answer = result["result"]
        sources = result["source_documents"]
        sources_text = f"### 📚 Retrieved {len(sources)} Sources:\n\n"
        for i, doc in enumerate(sources, 1):
            sources_text += f"**Source {i}:**\n{doc.page_content[:300]}...\n\n"
        metrics_text = calculate_metrics(query, answer, sources)
        return answer, sources_text, metrics_text
    except Exception as e:
        return f"❌ Error: {str(e)}", "", ""
def calculate_metrics(query, answer, sources):
    try:
        from sentence_transformers import SentenceTransformer, util
        semantic_model = SentenceTransformer('pritamdeka/BioBERT-mnli-snli-scinli-scitail-mednli-stsb')
        query_embedding = semantic_model.encode(query, convert_to_tensor=True)
        answer_embedding = semantic_model.encode(answer, convert_to_tensor=True)
        relevance = util.cos_sim(query_embedding, answer_embedding).item()
        context = " ".join(doc.page_content for doc in sources)
        # Truncate the context so the similarity model sees a bounded input.
        context_embedding = semantic_model.encode(context[:500], convert_to_tensor=True)
        coherence = util.cos_sim(answer_embedding, context_embedding).item()
        medical_terms = {
            'heart failure': ['lvef', 'ejection fraction', 'cardiac', 'ventricular', 'cardiomyopathy'],
            'diabetes': ['glucose', 'insulin', 'a1c', 'hemoglobin', 'glycemic', 'hyperglycemia'],
            'hypertension': ['blood pressure', 'systolic', 'diastolic', 'antihypertensive', 'bp']
        }
        answer_lower = answer.lower()
        keywords_count = 0
        for topic, keywords in medical_terms.items():
            if any(term in query.lower() for term in topic.split()):
                keywords_count = sum(1 for kw in keywords if kw in answer_lower)
                break
        # The retrieval metrics below treat every retrieved chunk as relevant
        # (fp = fn = 0), so they are optimistic placeholders rather than
        # ground-truth precision/recall.
        total_docs = len(vectorstore.docstore._dict) if vectorstore else 1000
        retrieved = len(sources)
        tp = retrieved
        fp = 0
        fn = 0
        tn = total_docs - retrieved
        precision = tp / (tp + fp) if (tp + fp) > 0 else 0
        recall = tp / (tp + fn) if (tp + fn) > 0 else 1.0
        f1 = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0
        accuracy = (tp + tn) / (tp + tn + fp + fn)
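        # Worked example with illustrative numbers: total_docs = 1000 and
        # retrieved = 5 give tp=5, fp=0, fn=0, tn=995, hence precision = 5/5
        # = 1.0, recall = 5/5 = 1.0, f1 = 1.0, accuracy = (5 + 995)/1000 = 1.0,
        # which is why these scores saturate under the fp = fn = 0 assumption.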
        metrics = f"""### 📊 Evaluation Metrics

**Semantic Quality:**
- 🎯 **Relevance Score:** {relevance:.3f}
- 🔗 **Coherence Score:** {coherence:.3f}
- 🏥 **Clinical Terms Found:** {keywords_count}

**Retrieval Performance:**
- ✅ **Precision:** {precision:.3f}
- 📈 **Recall:** {recall:.3f}
- 🎲 **F1 Score:** {f1:.3f}
- 📊 **Accuracy:** {accuracy:.3f}
"""
        return metrics
    except Exception as e:
        return f"### 📊 Evaluation Metrics\n\n⚠️ Error calculating metrics: {str(e)}"
def get_model_info():
    if config:
        info = f"""### 🤖 Model Configuration
- **Embedding Model:** {config['embedding_model']}
- **LLM Model:** google/flan-t5-base
- **Documents Processed:** {config['num_docs']}
- **Text Chunks:** {config['num_chunks']}
- **Retrieval Documents:** {config['retrieval_k']}
- **Storage:** ~20 MB (FAISS index only)
"""
        return info
    return "Model configuration not available."
# ===== Gradio Interface (Green Themed) =====
with gr.Blocks(title="Medical Q&A RAG System", theme=gr.themes.Monochrome(primary_hue="green")) as demo:
    gr.Markdown("""
    # 🏥 Medical Q&A RAG System
    ### Powered by MIMIC-IV Dataset, BioBERT & FLAN-T5
    Ask medical questions and get evidence-based answers from clinical documentation.
    """)
    with gr.Row():
        with gr.Column(scale=2):
            query_input = gr.Textbox(
                label="💬 Enter your medical question",
                placeholder="e.g., What are the diagnostic criteria for heart failure?",
                lines=3
            )
            with gr.Row():
                submit_btn = gr.Button("🔍 Get Answer", variant="primary", elem_classes="green-btn")
                clear_btn = gr.Button("🗑️ Clear", variant="secondary", elem_classes="green-btn")
            num_sources = gr.Slider(
                minimum=1, maximum=10, value=5, step=1,
                label="Number of source documents to display"
            )
            answer_output = gr.Textbox(
                label="💡 Answer",
                lines=8
            )
            sources_output = gr.Markdown(
                label="📚 Retrieved Sources"
            )
            metrics_output = gr.Markdown(
                label="📊 Evaluation Metrics"
            )
        with gr.Column(scale=1):
            gr.Markdown("### ℹ️ About", elem_classes="green-title")
            gr.Markdown("""
            This system uses:
            - **RAG (Retrieval-Augmented Generation)** to provide accurate medical answers
            - **MIMIC-IV** clinical dataset for knowledge base
            - **BioBERT** for medical text understanding
            - **FLAN-T5** for answer generation

            ⚠️ **Disclaimer:** For educational purposes only.
            """)
            model_info = gr.Markdown(get_model_info())
            gr.Markdown("### 📝 Example Questions", elem_classes="green-title")
            examples = gr.Examples(
                examples=[
                    "What are the diagnostic criteria for heart failure with reduced ejection fraction?",
                    "How is type 2 diabetes diagnosed?",
                    "What is recommended for stage 2 hypertension?",
                    "What are the symptoms of coronary artery disease?",
                    "How is myocardial infarction treated?"
                ],
                inputs=query_input
            )
    # Event handlers
    submit_btn.click(
        fn=answer_question,
        inputs=[query_input, num_sources],
        outputs=[answer_output, sources_output, metrics_output]
    )
    clear_btn.click(
        fn=lambda: ("", "", "", ""),
        outputs=[query_input, answer_output, sources_output, metrics_output]
    )
    # Load the model at startup, then refresh the config panel, which showed
    # only a placeholder while config was still None at build time.
    demo.load(fn=load_model, outputs=None).then(fn=get_model_info, outputs=model_info)
if __name__ == "__main__":
    demo.launch()