# Spaces: Sleeping
| import os | |
| os.environ["TF_ENABLE_ONEDNN_OPTS"] = "0" | |
| import zipfile | |
| import gradio as gr | |
| from PIL import Image | |
| import numpy as np | |
| import torch | |
| from torchvision import transforms | |
| import tensorflow as tf | |
| from tensorflow.keras.preprocessing.image import img_to_array | |
| # LangChain | |
| from langchain_community.embeddings import HuggingFaceEmbeddings | |
| from langchain_community.vectorstores import FAISS | |
| from langchain_community.llms import CTransformers | |
| from langchain_core.prompts import PromptTemplate | |
| from langchain_classic.chains import RetrievalQA | |
# ---------------- CONFIG ----------------
# Model files handed to torch.jit.load and tf.lite.Interpreter respectively.
CLASSIFICATION_MODEL_PATH = "health_resnet101_lite.ptl"
SEGMENTATION_MODEL_PATH = "segmentation.tflite"

# FAISS index directory and the archive it is unpacked from when missing.
DB_FAISS_PATH = 'vectorstores/db_faiss'
ZIP_PATH = "vectorstores.zip"

# Identifiers for the LLM and the sentence-embedding model.
LLM_MODEL_NAME = "TheBloke/Llama-2-7B-Chat-GGML"
EMBEDDING_MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"

# ---------------- GLOBAL STATE ----------------
# Filled in lazily by ensure_models_loaded(); all None until the first load.
clf_model = None
device = None
seg_model = None
rag_chain = None

# Label returned by the most recent classification (written by analyze()).
last_prediction = None
# ---------------- UNZIP ----------------
def _hoist_nested_db_faiss():
    """Flatten ``vectorstores/db_faiss/db_faiss`` into ``DB_FAISS_PATH``.

    Returns:
        True when a nested directory was found and moved, False otherwise.
    """
    outer = "vectorstores/db_faiss"
    if not os.path.isdir(os.path.join(outer, "db_faiss")):
        return False

    # A directory cannot be renamed onto its own non-empty parent, so park the
    # wrapper under a sibling name first and then lift the inner directory.
    parked = outer + "__outer"
    os.rename(outer, parked)
    os.rename(os.path.join(parked, "db_faiss"), DB_FAISS_PATH)
    try:
        os.rmdir(parked)
    except OSError:
        # The wrapper held other files; keep them rather than delete data.
        pass
    return True


def ensure_vectorstore():
    """Make sure the FAISS index directory exists at ``DB_FAISS_PATH``.

    If the directory is missing, ``ZIP_PATH`` is extracted into
    ``vectorstores`` and the extracted tree is normalised so the index ends up
    at ``DB_FAISS_PATH`` regardless of how the archive was rooted
    (``db_faiss/``, ``vectorstores/db_faiss/`` or ``db_faiss/db_faiss/``).

    The doubly nested layout is checked before the plain one: whenever
    ``vectorstores/db_faiss/db_faiss`` exists, ``vectorstores/db_faiss``
    exists too, so testing the plain path first would hide the nesting.

    Raises:
        FileNotFoundError: if the archive is missing, or if no ``db_faiss``
            directory can be located after extraction.
    """
    if _hoist_nested_db_faiss() or os.path.exists(DB_FAISS_PATH):
        return

    if not os.path.exists(ZIP_PATH):
        raise FileNotFoundError("vectorstores.zip not found")

    os.makedirs("vectorstores", exist_ok=True)
    with zipfile.ZipFile(ZIP_PATH, 'r') as zip_ref:
        zip_ref.extractall("vectorstores")

    if _hoist_nested_db_faiss() or os.path.exists(DB_FAISS_PATH):
        return

    # Archive was rooted at "vectorstores/", producing a doubled prefix.
    double_root = "vectorstores/vectorstores/db_faiss"
    if os.path.exists(double_root):
        os.rename(double_root, DB_FAISS_PATH)
        return

    raise FileNotFoundError("db_faiss not found after extraction")
# ---------------- CLASSES ----------------
# Label names; classify_image() indexes this list with the model's predicted
# class index, so the order must not change.
CLASSES = [
    'Astrocitoma T1', 'Astrocitoma T1C+', 'Astrocitoma T2',
    'BC - Benign', 'BC - Early', 'BC - Pre', 'BC - Pro',
    'Carcinoma T1', 'Carcinoma T1C+', 'Carcinoma T2',
    'Ependimoma T1', 'Ependimoma T1C+', 'Ependimoma T2',
    'Ganglioglioma T1', 'Ganglioglioma T1C+', 'Ganglioglioma T2',
    'Germinoma T1', 'Germinoma T1C+', 'Germinoma T2',
    'Glioblastoma T1', 'Glioblastoma T1C+', 'Glioblastoma T2',
    'Granuloma T1', 'Granuloma T1C+', 'Granuloma T2',
    'Meduloblastoma T1', 'Meduloblastoma T1C+', 'Meduloblastoma T2',
    'Meningioma T1', 'Meningioma T1C+', 'Meningioma T2',
    'Neurocitoma T1', 'Neurocitoma T1C+', 'Neurocitoma T2',
    'Oligodendroglioma T1', 'Oligodendroglioma T1C+', 'Oligodendroglioma T2',
    'Papiloma T1', 'Papiloma T1C+', 'Papiloma T2',
    'Schwannoma T1', 'Schwannoma T1C+', 'Schwannoma T2',
    'Tuberculoma T1', 'Tuberculoma T1C+', 'Tuberculoma T2',
    '_NORMAL T1', '_NORMAL T2',
]

# Substrings that mark a predicted label as a tumor class; analyze() runs
# segmentation only when the label contains one of them.
TUMOR_KEYWORDS = [
    'Astrocitoma',
    'Carcinoma',
    'Ependimoma',
    'Ganglioglioma',
    'Germinoma',
    'Glioblastoma',
    'Granuloma',
    'Meduloblastoma',
    'Meningioma',
    'Neurocitoma',
    'Oligodendroglioma',
    'Papiloma',
    'Schwannoma',
    'Tuberculoma',
]
# ---------------- TRANSFORMS ----------------
# Preprocessing applied by classify_image() before the classifier is called:
# resize, force three identical grayscale channels, convert to a tensor, then
# normalise each channel with the mean/std values below.
data_transforms = transforms.Compose([
    transforms.Resize(size=256),
    transforms.Grayscale(num_output_channels=3),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
])
# ---------------- LOAD MODELS ----------------
def load_models():
    """Load every model the app needs and build the retrieval-QA chain.

    Steps, in order: make sure the FAISS index is on disk, load the
    TorchScript classifier on CPU, load the TFLite segmentation interpreter,
    open the FAISS store with the sentence-embedding model, and wrap the
    CTransformers LLM plus a prompt into a ``RetrievalQA`` chain.

    Returns:
        A 4-tuple ``(clf_model, device, interpreter, qa_chain)``.

    Raises:
        FileNotFoundError: propagated from ``ensure_vectorstore`` when the
            vector store cannot be located or extracted.
    """
    device = "cpu"
    ensure_vectorstore()

    clf_model = torch.jit.load(CLASSIFICATION_MODEL_PATH, map_location=device)
    clf_model.eval()

    interpreter = tf.lite.Interpreter(model_path=SEGMENTATION_MODEL_PATH)
    interpreter.allocate_tensors()

    embeddings = HuggingFaceEmbeddings(model_name=EMBEDDING_MODEL_NAME)
    # SECURITY NOTE: allow_dangerous_deserialization=True lets FAISS unpickle
    # the stored docstore. Only acceptable because the index comes from the
    # archive shipped with the app; never point this at untrusted data.
    db = FAISS.load_local(
        DB_FAISS_PATH,
        embeddings,
        allow_dangerous_deserialization=True,
    )

    # Generation settings must be passed through `config`: the CTransformers
    # wrapper forwards that mapping to the underlying model, whereas
    # top-level keyword arguments such as max_new_tokens are not applied.
    llm = CTransformers(
        model=LLM_MODEL_NAME,
        model_type="llama",
        config={"max_new_tokens": 256, "temperature": 0.5},
    )

    # Prompt that asks the model to ignore noisy context fragments.
    prompt = PromptTemplate(
        template=(
            "\n"
            "You are a medical assistant.\n"
            "Use ONLY relevant and clean information from the context.\n"
            "- Ignore broken sentences, MCQs, or random fragments\n"
            "- Do NOT repeat the context\n"
            "- Give a clear, structured explanation\n"
            "Context:\n"
            "{context}\n"
            "Question:\n"
            "{question}\n"
            "Answer:\n"
        ),
        input_variables=["context", "question"],
    )

    retriever = db.as_retriever(search_kwargs={"k": 2})
    qa_chain = RetrievalQA.from_chain_type(
        llm=llm,
        retriever=retriever,
        chain_type_kwargs={'prompt': prompt},
    )
    return clf_model, device, interpreter, qa_chain
# ---------------- CLEANING ----------------
def clean_text(text):
    """Drop short and quiz-like lines from *text*.

    Each line is stripped of surrounding whitespace. A line is kept only when
    it is at least 25 characters long and does not contain (case-insensitive)
    any of the words "question", "option", "choose" or "correct answer".

    Args:
        text: Multi-line string to filter.

    Returns:
        The surviving lines joined with newlines ("" when none survive).
    """
    noise_markers = ("question", "option", "choose", "correct answer")

    def is_useful(candidate):
        # Short fragments are treated as noise.
        if len(candidate) < 25:
            return False
        lowered = candidate.lower()
        return not any(marker in lowered for marker in noise_markers)

    stripped_lines = (raw.strip() for raw in text.split("\n"))
    return "\n".join(entry for entry in stripped_lines if is_useful(entry))
def ensure_models_loaded():
    """Populate the module-level model globals on first use.

    Uses ``clf_model`` as the "already loaded" marker: when it is still None,
    ``load_models()`` is called and its four results are stored in
    ``clf_model``, ``device``, ``seg_model`` and ``rag_chain``.
    """
    global clf_model, device, seg_model, rag_chain

    if clf_model is not None:
        return

    loaded = load_models()
    clf_model, device, seg_model, rag_chain = loaded
# ---------------- CUSTOM RAG ----------------
def get_clean_rag_response(query):
    """Answer *query* from retrieved documents after filtering out noise.

    The documents are fetched with the chain's retriever, passed through
    ``clean_text``, and the surviving ones are fed straight to the chain's
    document-combining step. Calling ``rag_chain.invoke`` would not work for
    this: ``RetrievalQA`` ignores an ``input_documents`` entry and runs its
    own retrieval again, so the cleaned documents would never reach the LLM.

    Args:
        query: Question text to retrieve for and to put to the LLM.

    Returns:
        The generated answer, or "Model not loaded properly." when the chain
        is unavailable.
    """
    ensure_models_loaded()
    if rag_chain is None:
        return "Model not loaded properly."

    # Step 1: retrieve docs
    docs = rag_chain.retriever.invoke(query)

    # Step 2: clean docs. Build new document objects rather than editing
    # page_content in place, because the retriever may hand back the very
    # objects held by the vector store.
    cleaned_docs = []
    for doc in docs:
        cleaned_text = clean_text(doc.page_content)
        if cleaned_text.strip():
            cleaned_docs.append(
                type(doc)(page_content=cleaned_text, metadata=dict(doc.metadata))
            )

    # Step 3: run only the combine step, on the cleaned documents.
    combine_chain = rag_chain.combine_documents_chain
    response = combine_chain.invoke({
        "input_documents": cleaned_docs,
        "question": query,
    })
    return response[combine_chain.output_key]
# ---------------- LLM ON DEMAND ----------------
def generate_explanation():
    """Explain the most recent prediction using the RAG pipeline.

    The "nothing analysed yet" check runs first so the models are not loaded
    just to return the hint message. ``get_clean_rag_response`` loads the
    models itself when an answer is actually needed.

    NOTE(review): ``last_prediction`` is a module-level global, so it is
    shared by everything running in this process.

    Returns:
        The generated explanation, or "Please analyze an image first." when no
        prediction has been recorded.
    """
    if last_prediction is None:
        return "Please analyze an image first."
    return get_clean_rag_response(f"Explain {last_prediction}")
def ask_question(q):
    """Answer a free-form question through the RAG pipeline.

    Empty, missing or whitespace-only input is rejected before any model is
    loaded. ``get_clean_rag_response`` takes care of loading the models when a
    real question is asked.

    Args:
        q: Question text from the UI; may be None or empty.

    Returns:
        The generated answer, or "Ask something..." for blank input.
    """
    if not q or not q.strip():
        return "Ask something..."
    return get_clean_rag_response(q.strip())
# ---------------- FUNCTIONS ----------------
def classify_image(image):
    """Return the class label predicted for *image*.

    The image goes through ``data_transforms``, gets a leading dimension via
    ``unsqueeze(0)``, and is passed to ``clf_model`` with gradients disabled.
    The index of the largest value along dimension 1 selects the entry of
    ``CLASSES`` that is returned.

    Args:
        image: Image accepted by ``data_transforms`` (the UI supplies a PIL
            image).

    Returns:
        The matching label string from ``CLASSES``.
    """
    batch = data_transforms(image).unsqueeze(0)
    with torch.no_grad():
        scores = clf_model(batch)
        best = torch.max(scores, 1).indices
    return CLASSES[best.item()]
def segment_image(image):
    """Run the TFLite segmentation interpreter and return a binary mask.

    The image is converted to grayscale, resized to 128x128, scaled by 1/255,
    given a leading dimension and cast to float32 before being written to the
    interpreter's first input tensor. The first element of the first output
    tensor is thresholded at 0.5 and mapped to 0/255.

    Args:
        image: PIL image to segment.

    Returns:
        A PIL image in mode 'L' holding the thresholded mask.
    """
    in_index = seg_model.get_input_details()[0]['index']
    out_index = seg_model.get_output_details()[0]['index']

    gray = image.convert('L').resize((128, 128))
    scaled = img_to_array(gray) / 255.0
    batch = np.expand_dims(scaled, axis=0).astype(np.float32)

    seg_model.set_tensor(in_index, batch)
    seg_model.invoke()
    raw_mask = seg_model.get_tensor(out_index)[0]

    binary = (raw_mask > 0.5).astype(np.uint8) * 255
    return Image.fromarray(binary.squeeze(), 'L')
def overlay(image, mask):
    """Blend a red highlight of *mask* onto *image*.

    The mask is resized to the image size, and pixels whose mask value exceeds
    128 are painted pure red on an otherwise black layer. The result is a
    50/50 mix of the RGB image and that layer, so pixels outside the mask come
    out at half their original brightness.

    Args:
        image: PIL image used as the base.
        mask: PIL mask image; resized to ``image.size`` before use.

    Returns:
        A new PIL image containing the blended result.
    """
    mask_values = np.array(mask.resize(image.size))
    base = np.array(image.convert("RGB"))

    tint = np.zeros_like(base)
    tint[mask_values > 128] = [255, 0, 0]

    blended = base * 0.5 + tint * 0.5
    return Image.fromarray(blended.astype(np.uint8))
# ---------------- FAST ANALYSIS ----------------
def analyze(image):
    """Classify *image* and, for tumor classes, produce a segmentation overlay.

    The missing-image check runs before the models are loaded, so clicking
    "Analyze" without an upload returns at once instead of triggering the
    full model load.

    The predicted label is stored in the module-level ``last_prediction`` for
    ``generate_explanation``. NOTE(review): that global is shared by
    everything running in this process.

    Args:
        image: PIL image from the UI, or None when nothing was uploaded.

    Returns:
        ``(label, overlay_image)``. ``overlay_image`` is None when the label
        contains none of ``TUMOR_KEYWORDS``. For a missing image the result is
        ``("No image uploaded", None)``.
    """
    global last_prediction

    if image is None:
        return "No image uploaded", None

    ensure_models_loaded()

    pred = classify_image(image)
    last_prediction = pred

    overlay_img = None
    if any(k in pred for k in TUMOR_KEYWORDS):
        mask = segment_image(image)
        overlay_img = overlay(image, mask)
    return pred, overlay_img
# ---------------- UI ----------------
# NOTE(review): the original indentation of this section was lost. The
# explanation and question widgets are placed below the two-column row, and
# the event bindings are kept inside the Blocks context - confirm the layout.
with gr.Blocks() as demo:
    gr.Markdown("# 🩺 Medical AI Assistant")

    with gr.Row():
        with gr.Column():
            image_input = gr.Image(type="pil", label="Upload Image")
            analyze_btn = gr.Button("Analyze")
        with gr.Column():
            prediction = gr.Textbox(label="Prediction")
            segmented_output = gr.Image(label="Segmentation")

    explain_btn = gr.Button("🧠 Generate Explanation")
    rag_output = gr.Textbox(label="Detailed Explanation")

    gr.Markdown("## ❓ Ask Questions")
    question = gr.Textbox()
    ask_btn = gr.Button("Ask")
    answer = gr.Textbox()

    # Wire each button to its handler.
    analyze_btn.click(
        analyze,
        inputs=image_input,
        outputs=[prediction, segmented_output],
    )
    explain_btn.click(
        generate_explanation,
        inputs=[],
        outputs=rag_output,
    )
    ask_btn.click(
        ask_question,
        inputs=question,
        outputs=answer,
    )

# Start the server only when run as a script, so importing this module (for
# tests or tooling) does not launch the app as a side effect.
if __name__ == "__main__":
    demo.launch()