Medical_Chatbot / app.py
vulcan2506's picture
Update app.py
46f72c5 verified
import os
os.environ["TF_ENABLE_ONEDNN_OPTS"] = "0"
import zipfile
import gradio as gr
from PIL import Image
import numpy as np
import torch
from torchvision import transforms
import tensorflow as tf
from tensorflow.keras.preprocessing.image import img_to_array
# LangChain
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import FAISS
from langchain_community.llms import CTransformers
from langchain_core.prompts import PromptTemplate
from langchain_classic.chains import RetrievalQA
# ---------------- CONFIG ----------------
# Files expected on disk relative to the working directory.
CLASSIFICATION_MODEL_PATH = "health_resnet101_lite.ptl"  # read by torch.jit.load
SEGMENTATION_MODEL_PATH = "segmentation.tflite"  # read by tf.lite.Interpreter
ZIP_PATH = "vectorstores.zip"  # archive unpacked when the index folder is absent
DB_FAISS_PATH = "vectorstores/db_faiss"  # folder handed to FAISS.load_local

# Model identifiers passed to CTransformers and HuggingFaceEmbeddings.
LLM_MODEL_NAME = "TheBloke/Llama-2-7B-Chat-GGML"
EMBEDDING_MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"

# ---------------- GLOBAL STATE ----------------
# Model handles; every one of them is None until the models have been loaded.
clf_model = device = seg_model = rag_chain = None
# Label produced by the most recent classification (None before any analysis).
last_prediction = None
# ---------------- UNZIP ----------------
def _hoist_nested_store(store_dir):
    """Flatten ``store_dir/db_faiss`` into ``store_dir`` when the index sits one level too deep.

    Returns:
        True when entries were moved up, False when there was nothing to fix.
    """
    nested_dir = os.path.join(store_dir, "db_faiss")
    # "index.faiss" is the file name FAISS.load_local reads by default; if it
    # already sits directly in store_dir the layout is usable as-is.
    if not os.path.isdir(nested_dir):
        return False
    if os.path.exists(os.path.join(store_dir, "index.faiss")):
        return False
    for entry in os.listdir(nested_dir):
        os.replace(os.path.join(nested_dir, entry), os.path.join(store_dir, entry))
    os.rmdir(nested_dir)
    return True


def ensure_vectorstore():
    """Make sure the FAISS index folder is available at ``DB_FAISS_PATH``.

    When the folder is missing, ``ZIP_PATH`` is extracted into ``vectorstores/``
    and the common archive layouts are normalised:

    * ``vectorstores/db_faiss``               - already correct
    * ``vectorstores/vectorstores/db_faiss``  - renamed into place
    * ``vectorstores/db_faiss/db_faiss``      - contents moved up one level

    The last layout previously could never be repaired: its parent folder
    matched first, and renaming a folder onto its own parent would fail anyway.
    The flattening step also runs when the folder already exists, so a nested
    layout left behind by an earlier run is fixed too.

    Raises:
        FileNotFoundError: if the archive is missing, or no ``db_faiss`` folder
            is found after extraction.
    """
    if not os.path.exists(DB_FAISS_PATH):
        if not os.path.exists(ZIP_PATH):
            raise FileNotFoundError("vectorstores.zip not found")

        os.makedirs("vectorstores", exist_ok=True)
        with zipfile.ZipFile(ZIP_PATH, "r") as zip_ref:
            zip_ref.extractall("vectorstores")

        # The archive carried its own top-level "vectorstores/" folder.
        doubled_root = "vectorstores/vectorstores/db_faiss"
        if not os.path.exists(DB_FAISS_PATH) and os.path.exists(doubled_root):
            os.rename(doubled_root, DB_FAISS_PATH)

    if not os.path.exists(DB_FAISS_PATH):
        raise FileNotFoundError("db_faiss not found after extraction")

    _hoist_nested_store(DB_FAISS_PATH)
# ---------------- CLASSES ----------------
# Labels indexed by the classifier's predicted class position; the order matters.
CLASSES = [
    'Astrocitoma T1', 'Astrocitoma T1C+', 'Astrocitoma T2',
    'BC - Benign', 'BC - Early', 'BC - Pre', 'BC - Pro',
    'Carcinoma T1', 'Carcinoma T1C+', 'Carcinoma T2',
    'Ependimoma T1', 'Ependimoma T1C+', 'Ependimoma T2',
    'Ganglioglioma T1', 'Ganglioglioma T1C+', 'Ganglioglioma T2',
    'Germinoma T1', 'Germinoma T1C+', 'Germinoma T2',
    'Glioblastoma T1', 'Glioblastoma T1C+', 'Glioblastoma T2',
    'Granuloma T1', 'Granuloma T1C+', 'Granuloma T2',
    'Meduloblastoma T1', 'Meduloblastoma T1C+', 'Meduloblastoma T2',
    'Meningioma T1', 'Meningioma T1C+', 'Meningioma T2',
    'Neurocitoma T1', 'Neurocitoma T1C+', 'Neurocitoma T2',
    'Oligodendroglioma T1', 'Oligodendroglioma T1C+', 'Oligodendroglioma T2',
    'Papiloma T1', 'Papiloma T1C+', 'Papiloma T2',
    'Schwannoma T1', 'Schwannoma T1C+', 'Schwannoma T2',
    'Tuberculoma T1', 'Tuberculoma T1C+', 'Tuberculoma T2',
    '_NORMAL T1', '_NORMAL T2',
]

# Substrings that mark a predicted label as a tumour class (segmentation is
# only run for labels containing one of these).
TUMOR_KEYWORDS = [
    'Astrocitoma',
    'Carcinoma',
    'Ependimoma',
    'Ganglioglioma',
    'Germinoma',
    'Glioblastoma',
    'Granuloma',
    'Meduloblastoma',
    'Meningioma',
    'Neurocitoma',
    'Oligodendroglioma',
    'Papiloma',
    'Schwannoma',
    'Tuberculoma',
]
# ---------------- TRANSFORMS ----------------
# Per-channel statistics handed to Normalize.
_NORM_MEAN = [0.485, 0.456, 0.406]
_NORM_STD = [0.229, 0.224, 0.225]

# Preprocessing applied to an image before it is fed to the classifier:
# resize -> grayscale replicated to 3 channels -> tensor -> normalise.
_preprocess_steps = [
    transforms.Resize(256),
    transforms.Grayscale(num_output_channels=3),
    transforms.ToTensor(),
    transforms.Normalize(_NORM_MEAN, _NORM_STD),
]
data_transforms = transforms.Compose(_preprocess_steps)
# ---------------- LOAD MODELS ----------------
def load_models():
    """Load the classifier, the segmenter and the retrieval-QA chain.

    Everything is loaded for CPU. The FAISS index folder is prepared first via
    ``ensure_vectorstore()``.

    Returns:
        tuple: ``(clf_model, device, interpreter, qa_chain)`` where
        ``clf_model`` is the TorchScript classifier in eval mode, ``device`` is
        the string ``"cpu"``, ``interpreter`` is the TFLite interpreter with
        tensors allocated, and ``qa_chain`` is the ``RetrievalQA`` chain.
    """
    device = "cpu"
    ensure_vectorstore()

    # Image classifier (TorchScript).
    clf_model = torch.jit.load(CLASSIFICATION_MODEL_PATH, map_location=device)
    clf_model.eval()

    # Segmentation model (TFLite).
    interpreter = tf.lite.Interpreter(model_path=SEGMENTATION_MODEL_PATH)
    interpreter.allocate_tensors()

    # Vector store used for retrieval.
    embeddings = HuggingFaceEmbeddings(model_name=EMBEDDING_MODEL_NAME)
    # NOTE(security): allow_dangerous_deserialization=True lets load_local
    # unpickle the stored docstore. That is only safe when the index files are
    # produced and shipped by this project; never point it at untrusted data.
    db = FAISS.load_local(DB_FAISS_PATH, embeddings, allow_dangerous_deserialization=True)

    # Generation settings must travel in `config`: CTransformers forwards that
    # dict to the underlying ctransformers model, whereas top-level keyword
    # arguments such as max_new_tokens/temperature are not applied.
    llm = CTransformers(
        model=LLM_MODEL_NAME,
        model_type="llama",
        config={"max_new_tokens": 256, "temperature": 0.5},
    )

    # Prompt fed to the LLM; {context} and {question} are filled in by the chain.
    prompt = PromptTemplate(
        template="""
You are a medical assistant.
Use ONLY relevant and clean information from the context.
- Ignore broken sentences, MCQs, or random fragments
- Do NOT repeat the context
- Give a clear, structured explanation
Context:
{context}
Question:
{question}
Answer:
""",
        input_variables=["context", "question"]
    )

    # Retrieve the two closest chunks per query.
    retriever = db.as_retriever(search_kwargs={"k": 2})
    qa_chain = RetrievalQA.from_chain_type(
        llm=llm,
        retriever=retriever,
        chain_type_kwargs={'prompt': prompt}
    )
    return clf_model, device, interpreter, qa_chain
# ---------------- CLEANING ----------------
def clean_text(text):
    """Return *text* with short and quiz-like lines removed.

    Each line is stripped of surrounding whitespace. A line is dropped when it
    is shorter than 25 characters after stripping, or when it contains any of
    the markers "question", "option", "choose" or "correct answer"
    (case-insensitive). Surviving lines are re-joined with newlines.
    """
    noise_markers = ("question", "option", "choose", "correct answer")

    def is_usable(candidate):
        # Too short to carry a meaningful sentence.
        if len(candidate) < 25:
            return False
        lowered = candidate.lower()
        # MCQ-style junk.
        return not any(marker in lowered for marker in noise_markers)

    stripped_lines = (raw.strip() for raw in text.split("\n"))
    return "\n".join(filter(is_usable, stripped_lines))
def ensure_models_loaded():
    """Fill the module-level model handles the first time they are needed.

    ``clf_model`` being set is used as the marker that loading already
    happened; in that case the function returns without doing anything.
    """
    global clf_model, device, seg_model, rag_chain
    if clf_model is not None:
        return
    loaded = load_models()
    clf_model, device, seg_model, rag_chain = loaded
# ---------------- CUSTOM RAG ----------------
def get_clean_rag_response(query):
    """Answer *query* with the LLM, using cleaned retrieved passages as context.

    The passages are fetched once from the chain's retriever, filtered through
    ``clean_text`` and handed straight to the chain's document-combining step.
    Calling the ``RetrievalQA`` chain itself would not work for this: it only
    reads the ``"query"`` key, ignores any ``"input_documents"`` passed in and
    runs its own retrieval again.

    Args:
        query: The question text sent to the retriever and the LLM.

    Returns:
        The generated answer, or ``"Model not loaded properly."`` when no chain
        is available.
    """
    ensure_models_loaded()
    if rag_chain is None:
        return "Model not loaded properly."

    # Step 1: retrieve docs
    docs = rag_chain.retriever.invoke(query)

    # Step 2: clean docs. Build new Document objects instead of editing the
    # retrieved ones: the retriever may hand back the very objects held by the
    # vector store, and those should stay untouched.
    cleaned_docs = []
    for doc in docs:
        cleaned_text = clean_text(doc.page_content)
        if cleaned_text.strip():
            cleaned_docs.append(
                type(doc)(page_content=cleaned_text, metadata=doc.metadata)
            )

    # If cleaning discarded everything, fall back to the raw passages rather
    # than asking the LLM to answer with an empty context.
    context_docs = cleaned_docs or docs

    # Step 3: run only the "stuff documents + LLM" part of the chain, which is
    # what RetrievalQA does internally after its own retrieval.
    response = rag_chain.combine_documents_chain.invoke({
        "input_documents": context_docs,
        "question": query,
    })
    return response["output_text"]
# ---------------- LLM ON DEMAND ----------------
def generate_explanation():
    """Return an LLM explanation of the most recent prediction.

    Returns:
        The generated explanation, or ``"Please analyze an image first."`` when
        no prediction has been stored yet.
    """
    # Check the cheap precondition first so that pressing the button before any
    # analysis does not trigger a full model load. Loading itself is handled by
    # get_clean_rag_response.
    if last_prediction is None:
        return "Please analyze an image first."
    query = f"Explain {last_prediction}"
    return get_clean_rag_response(query)
def ask_question(q):
    """Answer a free-form user question through the RAG pipeline.

    Args:
        q: Question text from the UI; may be ``None`` or empty.

    Returns:
        The generated answer, or ``"Ask something..."`` when the question is
        missing, empty or only whitespace.
    """
    # Validate before anything else: a blank question must not trigger a model
    # load or an LLM call. Loading is handled by get_clean_rag_response.
    if not q or not q.strip():
        return "Ask something..."
    return get_clean_rag_response(q)
# ---------------- FUNCTIONS ----------------
def classify_image(image):
    """Run the classifier on *image* and return the matching label from CLASSES.

    The image goes through ``data_transforms``, gets a leading batch axis, and
    the index of the highest score along dimension 1 selects the label.
    """
    batch = data_transforms(image).unsqueeze(0)
    with torch.no_grad():
        scores = clf_model(batch)
    best_index = torch.max(scores, 1)[1].item()
    return CLASSES[best_index]
def segment_image(image):
    """Run the TFLite segmenter on *image* and return a binary mask image.

    The input is converted to grayscale, resized to 128x128 and scaled to
    [0, 1]. Output values above 0.5 become 255, everything else 0, and the
    result is returned as an 'L' mode PIL image.
    """
    in_index = seg_model.get_input_details()[0]['index']
    out_index = seg_model.get_output_details()[0]['index']

    gray = image.convert('L').resize((128, 128))
    scaled = img_to_array(gray) / 255.0
    # Add the batch axis the interpreter input expects.
    batch = np.expand_dims(scaled, axis=0).astype(np.float32)

    seg_model.set_tensor(in_index, batch)
    seg_model.invoke()
    raw_mask = seg_model.get_tensor(out_index)[0]

    binary = (raw_mask > 0.5).astype(np.uint8) * 255
    return Image.fromarray(binary.squeeze(), 'L')
def overlay(image, mask):
    """Blend a red highlight of *mask* over *image* at 50% strength.

    The mask is resized to the image size; pixels whose mask value exceeds 128
    are painted pure red in a tint layer, which is then averaged with the RGB
    version of the image.
    """
    mask_pixels = np.array(mask.resize(image.size))
    base = np.array(image.convert("RGB"))

    tint = np.zeros_like(base)
    tint[mask_pixels > 128] = [255, 0, 0]

    blended = base * 0.5 + tint * 0.5
    return Image.fromarray(blended.astype(np.uint8))
# ---------------- FAST ANALYSIS ----------------
def analyze(image):
    """Classify *image* and, for tumour classes, produce a segmentation overlay.

    Args:
        image: PIL image from the UI, or ``None`` when nothing was uploaded.

    Returns:
        tuple: ``(label, overlay_image)``. ``overlay_image`` is ``None`` when
        the label contains no tumour keyword. With no image the result is
        ``("No image uploaded", None)``.
    """
    # NOTE: last_prediction is module-level state, so it is shared by every
    # caller of this process rather than kept per session.
    global last_prediction

    # Reject a missing upload before touching the models, so an empty click
    # does not pay for model loading.
    if image is None:
        return "No image uploaded", None

    ensure_models_loaded()
    pred = classify_image(image)
    last_prediction = pred

    overlay_img = None
    if any(k in pred for k in TUMOR_KEYWORDS):
        mask = segment_image(image)
        overlay_img = overlay(image, mask)
    return pred, overlay_img
# ---------------- UI ----------------
# Gradio front end: an upload/analyse row, an on-demand explanation area and a
# free-form question box, each wired to one of the handler functions.
# NOTE(review): the nesting of the layout blocks below was reconstructed; the
# placement of the explanation widgets relative to the row/columns should be
# confirmed against the deployed app.
with gr.Blocks() as demo:
    gr.Markdown("# 🩺 Medical AI Assistant")
    with gr.Row():
        # Left column: image upload and the trigger for classification.
        with gr.Column():
            image_input = gr.Image(type="pil", label="Upload Image")
            analyze_btn = gr.Button("Analyze")
        # Right column: predicted label and segmentation overlay.
        with gr.Column():
            prediction = gr.Textbox(label="Prediction")
            segmented_output = gr.Image(label="Segmentation")
    # Explanation of the last prediction, generated only when requested.
    explain_btn = gr.Button("🧠 Generate Explanation")
    rag_output = gr.Textbox(label="Detailed Explanation")
    gr.Markdown("## ❓ Ask Questions")
    question = gr.Textbox()
    ask_btn = gr.Button("Ask")
    answer = gr.Textbox()
    # analyze returns (label, overlay image), matching the two outputs.
    analyze_btn.click(
        analyze,
        inputs=image_input,
        outputs=[prediction, segmented_output]
    )
    # Takes no inputs: generate_explanation reads the stored last prediction.
    explain_btn.click(
        generate_explanation,
        inputs=[],
        outputs=rag_output
    )
    ask_btn.click(
        ask_question,
        inputs=question,
        outputs=answer
    )
# Start the Gradio server as soon as this script is run.
demo.launch()