import os

# Set before TensorFlow is imported (below) so the setting is in place when TF starts.
os.environ["TF_ENABLE_ONEDNN_OPTS"] = "0"

import threading
import zipfile

import gradio as gr
from PIL import Image
import numpy as np
import torch
from torchvision import transforms
import tensorflow as tf
from tensorflow.keras.preprocessing.image import img_to_array

# LangChain
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import FAISS
from langchain_community.llms import CTransformers
from langchain_core.documents import Document
from langchain_core.prompts import PromptTemplate
from langchain_classic.chains import RetrievalQA

# ---------------- CONFIG ----------------
CLASSIFICATION_MODEL_PATH = "health_resnet101_lite.ptl"
SEGMENTATION_MODEL_PATH = "segmentation.tflite"
VECTORSTORE_ROOT = "vectorstores"
DB_FAISS_PATH = 'vectorstores/db_faiss'
ZIP_PATH = "vectorstores.zip"
LLM_MODEL_NAME = "TheBloke/Llama-2-7B-Chat-GGML"
EMBEDDING_MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"

# Files FAISS.load_local() reads for the default index name ("index").
FAISS_INDEX_FILES = ("index.faiss", "index.pkl")

# ---------------- GLOBAL STATE ----------------
# Filled lazily by ensure_models_loaded(). clf_model doubles as the
# "everything is loaded" flag, so it is always assigned last.
clf_model, device, seg_model, rag_chain = None, None, None, None
# Label from the most recent analyze() call in this process. It is shared by
# every UI session, so the UI passes its own per-session prediction instead.
last_prediction = None
# Serializes the one-time model load across concurrent Gradio events.
_models_lock = threading.Lock()


# ---------------- UNZIP ----------------
def _has_faiss_index(path):
    """Return True if *path* contains every file in FAISS_INDEX_FILES."""
    return all(os.path.isfile(os.path.join(path, name)) for name in FAISS_INDEX_FILES)


def _find_faiss_dir(root):
    """Return the first directory under *root* holding a FAISS index, or None."""
    for dirpath, _dirnames, filenames in os.walk(root):
        if all(name in filenames for name in FAISS_INDEX_FILES):
            return dirpath
    return None


def ensure_vectorstore():
    """Make a FAISS index available on disk and return its directory.

    Prefers DB_FAISS_PATH. Otherwise searches VECTORSTORE_ROOT (covering
    nested layouts such as ``vectorstores/vectorstores/db_faiss`` or
    ``vectorstores/db_faiss/db_faiss``), extracting ZIP_PATH first if no index
    is found. Nothing is renamed, so a nested layout is simply used in place.

    Returns:
        Path of the directory containing ``index.faiss`` and ``index.pkl``.

    Raises:
        FileNotFoundError: if the zip is missing, or no index exists after
            extraction.
    """
    if _has_faiss_index(DB_FAISS_PATH):
        return DB_FAISS_PATH

    found = _find_faiss_dir(VECTORSTORE_ROOT)
    if found is None:
        if not os.path.exists(ZIP_PATH):
            raise FileNotFoundError(f"{ZIP_PATH} not found")
        os.makedirs(VECTORSTORE_ROOT, exist_ok=True)
        with zipfile.ZipFile(ZIP_PATH, "r") as zip_ref:
            zip_ref.extractall(VECTORSTORE_ROOT)
        found = _find_faiss_dir(VECTORSTORE_ROOT)

    if found is None:
        raise FileNotFoundError("db_faiss not found after extraction")
    return found


# ---------------- CLASSES ----------------
# Index i is the label for classifier output i (see classify_image); the order
# presumably has to match the order used at training time — do not reorder.
CLASSES = [
    'Astrocitoma T1', 'Astrocitoma T1C+', 'Astrocitoma T2',
    'BC - Benign', 'BC - Early', 'BC - Pre', 'BC - Pro',
    'Carcinoma T1', 'Carcinoma T1C+', 'Carcinoma T2',
    'Ependimoma T1', 'Ependimoma T1C+', 'Ependimoma T2',
    'Ganglioglioma T1', 'Ganglioglioma T1C+', 'Ganglioglioma T2',
    'Germinoma T1', 'Germinoma T1C+', 'Germinoma T2',
    'Glioblastoma T1', 'Glioblastoma T1C+', 'Glioblastoma T2',
    'Granuloma T1', 'Granuloma T1C+', 'Granuloma T2',
    'Meduloblastoma T1', 'Meduloblastoma T1C+', 'Meduloblastoma T2',
    'Meningioma T1', 'Meningioma T1C+', 'Meningioma T2',
    'Neurocitoma T1', 'Neurocitoma T1C+', 'Neurocitoma T2',
    'Oligodendroglioma T1', 'Oligodendroglioma T1C+', 'Oligodendroglioma T2',
    'Papiloma T1', 'Papiloma T1C+', 'Papiloma T2',
    'Schwannoma T1', 'Schwannoma T1C+', 'Schwannoma T2',
    'Tuberculoma T1', 'Tuberculoma T1C+', 'Tuberculoma T2',
    '_NORMAL T1', '_NORMAL T2',
]

# A prediction containing any of these triggers segmentation in analyze().
TUMOR_KEYWORDS = [
    'Astrocitoma', 'Carcinoma', 'Ependimoma', 'Ganglioglioma', 'Germinoma',
    'Glioblastoma', 'Granuloma', 'Meduloblastoma', 'Meningioma', 'Neurocitoma',
    'Oligodendroglioma', 'Papiloma', 'Schwannoma', 'Tuberculoma',
]

# ---------------- TRANSFORMS ----------------
# Classifier preprocessing: shorter side -> 256, grayscale replicated to 3
# channels, then normalized with the ImageNet mean/std.
data_transforms = transforms.Compose([
    transforms.Resize(256),
    transforms.Grayscale(num_output_channels=3),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])

# ---------------- PROMPT ----------------
RAG_PROMPT_TEMPLATE = """
You are a medical assistant.
Use ONLY relevant and clean information from the context.
- Ignore broken sentences, MCQs, or random fragments
- Do NOT repeat the context
- Give a clear, structured explanation

Context:
{context}

Question:
{question}

Answer:
"""


# ---------------- LOAD MODELS ----------------
def load_models():
    """Load the classifier, segmenter, and RAG chain (all on CPU).

    Returns:
        Tuple ``(clf_model, device, interpreter, qa_chain)``.
    """
    device = "cpu"
    db_path = ensure_vectorstore()

    clf_model = torch.jit.load(CLASSIFICATION_MODEL_PATH, map_location=device)
    clf_model.eval()

    interpreter = tf.lite.Interpreter(model_path=SEGMENTATION_MODEL_PATH)
    interpreter.allocate_tensors()

    embeddings = HuggingFaceEmbeddings(model_name=EMBEDDING_MODEL_NAME)
    # load_local unpickles index.pkl; acceptable only because the index ships
    # with this app. Never point this at an untrusted vector store.
    db = FAISS.load_local(db_path, embeddings, allow_dangerous_deserialization=True)

    # Generation settings must go through `config`; as top-level keyword
    # arguments they are not forwarded to the ctransformers model.
    llm = CTransformers(
        model=LLM_MODEL_NAME,
        model_type="llama",
        config={"max_new_tokens": 256, "temperature": 0.5},
    )

    prompt = PromptTemplate(
        template=RAG_PROMPT_TEMPLATE,
        input_variables=["context", "question"],
    )

    retriever = db.as_retriever(search_kwargs={"k": 2})
    qa_chain = RetrievalQA.from_chain_type(
        llm=llm,
        retriever=retriever,
        chain_type_kwargs={'prompt': prompt},
    )
    return clf_model, device, interpreter, qa_chain


# ---------------- CLEANING ----------------
_MIN_LINE_LEN = 25  # lines shorter than this are treated as noise
_MCQ_MARKERS = ("question", "option", "choose", "correct answer")


def clean_text(text):
    """Drop short lines and MCQ-style lines from *text*.

    Each line is stripped. Lines under _MIN_LINE_LEN characters, or containing
    any of _MCQ_MARKERS (case-insensitive), are removed. The survivors are
    joined with newlines.
    """
    kept = []
    for line in text.split("\n"):
        line = line.strip()
        if len(line) < _MIN_LINE_LEN:
            continue
        lowered = line.lower()
        if any(marker in lowered for marker in _MCQ_MARKERS):
            continue
        kept.append(line)
    return "\n".join(kept)


def ensure_models_loaded():
    """Load all models once, even if several requests arrive concurrently."""
    global clf_model, device, seg_model, rag_chain
    if clf_model is not None:
        return
    with _models_lock:
        if clf_model is not None:  # another thread finished loading meanwhile
            return
        loaded_clf, loaded_device, loaded_seg, loaded_chain = load_models()
        device, seg_model, rag_chain = loaded_device, loaded_seg, loaded_chain
        # Assigned last: the lock-free fast path above treats it as "ready".
        clf_model = loaded_clf


# ---------------- CUSTOM RAG ----------------
def get_clean_rag_response(query):
    """Answer *query* with the LLM, using cleaned retrieved passages as context.

    Returns the generated answer text, or an error string if the chain is
    unavailable.
    """
    ensure_models_loaded()
    if rag_chain is None:
        return "Model not loaded properly."

    # Step 1: retrieve candidate passages.
    docs = rag_chain.retriever.invoke(query)

    # Step 2: clean them into NEW Document objects. The retriever may return
    # objects shared with the vector store, so they are not edited in place.
    cleaned_docs = []
    for doc in docs:
        cleaned = clean_text(doc.page_content)
        if cleaned:
            cleaned_docs.append(Document(page_content=cleaned, metadata=doc.metadata))
    # If cleaning removed everything, keep the raw passages; the prompt already
    # tells the model to ignore fragments.
    context_docs = cleaned_docs or docs

    # Step 3: run only the answer step on our documents. RetrievalQA.invoke()
    # retrieves again and ignores an "input_documents" key, which would throw
    # the cleaned context away.
    response = rag_chain.combine_documents_chain.invoke({
        "input_documents": context_docs,
        "question": query,
    })
    return response["output_text"]


# ---------------- LLM ON DEMAND ----------------
def generate_explanation(prediction_text=None):
    """Explain a predicted class via RAG.

    Args:
        prediction_text: label to explain (the UI passes its per-session
            Prediction box). When None, falls back to the process-wide
            ``last_prediction``.

    Returns:
        The explanation, or a prompt to analyze an image first when there is
        no valid label (empty, None, or a status message such as
        "No image uploaded").
    """
    label = prediction_text if prediction_text is not None else last_prediction
    if label not in CLASSES:
        return "Please analyze an image first."
    return get_clean_rag_response(f"Explain {label}")


def ask_question(q):
    """Answer a free-form user question; blank input yields a hint instead."""
    if not q or not q.strip():
        return "Ask something..."
    return get_clean_rag_response(q)


# ---------------- FUNCTIONS ----------------
def classify_image(image):
    """Return the CLASSES label with the highest classifier score for *image*."""
    input_tensor = data_transforms(image).unsqueeze(0)
    with torch.no_grad():
        output = clf_model(input_tensor)
    _, pred = torch.max(output, 1)
    return CLASSES[pred.item()]


def segment_image(image):
    """Run the TFLite segmenter and return a binary mask image (0 or 255).

    The input is converted to grayscale, resized to 128x128, scaled to [0, 1],
    and given a batch axis. Output values > 0.5 become 255. The mask keeps the
    model's output resolution (presumably 128x128 — confirm against the model).
    """
    input_details = seg_model.get_input_details()
    output_details = seg_model.get_output_details()

    img = image.convert('L').resize((128, 128))
    arr = img_to_array(img) / 255.0
    arr = np.expand_dims(arr, axis=0).astype(np.float32)

    seg_model.set_tensor(input_details[0]['index'], arr)
    seg_model.invoke()

    mask = seg_model.get_tensor(output_details[0]['index'])[0]
    mask = (mask > 0.5).astype(np.uint8) * 255
    # A 2-D uint8 array becomes an "L" image; the explicit mode argument of
    # Image.fromarray is deprecated in recent Pillow.
    return Image.fromarray(mask.squeeze())


def overlay(image, mask):
    """Tint the masked region of *image* red (50% blend) and return RGB.

    Pixels outside the mask keep their original color.
    """
    mask = mask.resize(image.size)
    blended = np.array(image.convert("RGB"), dtype=np.float32)
    hit = np.asarray(mask) > 128
    red = np.array([255, 0, 0], dtype=np.float32)
    blended[hit] = blended[hit] * 0.5 + red * 0.5
    return Image.fromarray(blended.astype(np.uint8))


# ---------------- FAST ANALYSIS ----------------
def analyze(image):
    """Classify *image* and, for tumor classes, return a segmentation overlay.

    Returns:
        ``(label, overlay_image_or_None)``, or ``("No image uploaded", None)``
        when *image* is None.
    """
    global last_prediction
    if image is None:
        return "No image uploaded", None

    ensure_models_loaded()
    pred = classify_image(image)
    last_prediction = pred

    overlay_img = None
    if any(k in pred for k in TUMOR_KEYWORDS):
        overlay_img = overlay(image, segment_image(image))
    return pred, overlay_img


# ---------------- UI ----------------
with gr.Blocks() as demo:
    gr.Markdown("# 🩺 Medical AI Assistant")

    with gr.Row():
        with gr.Column():
            image_input = gr.Image(type="pil", label="Upload Image")
            analyze_btn = gr.Button("Analyze")
        with gr.Column():
            # Read-only: it is also the (per-session) input to the explain button.
            prediction = gr.Textbox(label="Prediction", interactive=False)
            segmented_output = gr.Image(label="Segmentation")

    explain_btn = gr.Button("🧠 Generate Explanation")
    rag_output = gr.Textbox(label="Detailed Explanation")

    gr.Markdown("## ❓ Ask Questions")
    question = gr.Textbox()
    ask_btn = gr.Button("Ask")
    answer = gr.Textbox()

    analyze_btn.click(
        analyze,
        inputs=image_input,
        outputs=[prediction, segmented_output],
    )
    # Pass this session's prediction so users never see another user's result.
    explain_btn.click(
        generate_explanation,
        inputs=prediction,
        outputs=rag_output,
    )
    ask_btn.click(
        ask_question,
        inputs=question,
        outputs=answer,
    )


if __name__ == "__main__":
    demo.launch()