"""Gradio console for YOLO-based MRI lesion detection with an LLM-written summary.

Pipeline: DetectionModule (YOLO) -> VisualizationPipeline (box overlay)
-> ExplanationModule (Groq chat completion), coordinated by InferenceOrchestrator.
For research use only; not a diagnostic tool.
"""

import os

import gradio as gr
import requests
from dotenv import load_dotenv
from PIL import ImageDraw, ImageFont
from ultralytics import YOLO

# Path to the trained YOLO checkpoint (relative to the working directory).
YOLO_WEIGHTS = "best.pt"
# Groq's OpenAI-compatible chat-completions endpoint.
GROQ_API_URL = "https://api.groq.com/openai/v1/chat/completions"
# Palette consumed by custom_css below.
WINDOWS_XP_COLORS = {
    "bg": "#ece9d8",
    "title": "#0053e1",
    "status": "#f3f3f3",
    "border": "#808080",
}

# Populate os.environ from a local .env file (if present) before reading config.
load_dotenv()
GROQ_API_KEY = os.getenv("GROQ_API_KEY", "")
GROQ_MODEL = os.getenv("GROQ_MODEL", "llama-3.3-70b-versatile")

custom_css = f"""
body {{
    background: {WINDOWS_XP_COLORS["bg"]};
    font-family: Tahoma, Verdana, sans-serif;
}}
.gradio-container {{
    border: 2px solid {WINDOWS_XP_COLORS["border"]};
    background: {WINDOWS_XP_COLORS["bg"]};
    border-radius: 6px;
    max-width: 700px;
    margin: 32px auto;
    box-shadow: 0 4px 16px #bbb;
}}
.gradio-title {{
    background: {WINDOWS_XP_COLORS["title"]};
    color: #fff;
    padding: 10px 16px;
    font-size: 20px;
    border-top-left-radius: 6px;
    border-top-right-radius: 6px;
    margin-bottom: 0;
}}
.status-bar {{
    background: {WINDOWS_XP_COLORS["status"]};
    color: #333;
    padding: 6px 16px;
    font-size: 13px;
    border-bottom-left-radius: 6px;
    border-bottom-right-radius: 6px;
    border-top: 1px solid {WINDOWS_XP_COLORS["border"]};
    margin-top: 0;
}}
"""


class DetectionModule:
    """Run a YOLO model on an image and return plain-dict detections."""

    def __init__(self, weights_path):
        """Load YOLO weights from *weights_path*.

        Raises:
            FileNotFoundError: if *weights_path* does not exist.
        """
        if not os.path.exists(weights_path):
            raise FileNotFoundError(f"YOLO weights not found: {weights_path}")
        self.model = YOLO(weights_path)

    def run(self, image):
        """Detect objects in *image*.

        Returns:
            A list of dicts with keys ``"class"`` (label string), ``"conf"``
            (float score) and ``"box"`` (``[x1, y1, x2, y2]`` floats).
            Empty list when *image* is None or nothing is detected.
        """
        if image is None:
            return []
        results = self.model(image, verbose=False)
        detections = []
        for result in results:
            # Non-detection tasks (e.g. classification checkpoints) have no boxes.
            if result.boxes is None:
                continue
            names = result.names
            for box in result.boxes:
                cls_idx = int(box.cls.item())
                x1, y1, x2, y2 = box.xyxy[0].tolist()
                detections.append(
                    {
                        "class": names.get(cls_idx, str(cls_idx)),
                        "conf": float(box.conf.item()),
                        "box": [x1, y1, x2, y2],
                    }
                )
        return detections


class ExplanationModule:
    """Produce a plain-language summary of detections via the Groq chat API."""

    def __init__(self, api_key, api_url=GROQ_API_URL, model=GROQ_MODEL, timeout=10):
        """Store API settings.

        Args:
            api_key: Groq API key; an empty value disables API calls.
            api_url: Chat-completions endpoint URL.
            model: Model identifier sent in the request body.
            timeout: Seconds to wait for the HTTP response.
        """
        self.api_key = api_key
        self.api_url = api_url
        self.model = model
        self.timeout = timeout

    @staticmethod
    def _build_prompt(detections):
        """Render the user prompt listing every detection and its confidence."""
        det_lines = [
            f"- Tumor type: {d['class']}, Confidence: {d['conf']:.2f}" for d in detections
        ]
        return (
            "You are a medical AI assistant.\n"
            "Input:\n"
            f"Detection count: {len(detections)}\n"
            + "\n".join(det_lines)
            + "\nExplain in simple terms:\n"
            "- What was detected\n"
            "- What confidence means\n"
            "- Avoid medical diagnosis\n"
            "- Add disclaimer\n"
        )

    def generate(self, detections):
        """Return an explanation string for *detections*.

        Never raises for API problems: missing key, network failures, HTTP
        errors and malformed responses are reported as bracketed messages.
        """
        # Checked first: "nothing found" needs no API call, so it should be
        # reported even when no API key is configured.
        if not detections:
            return "No tumor detected with sufficient confidence."
        if not self.api_key:
            return "[Groq API key not set. Cannot generate explanation.]"

        headers = {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json",
        }
        data = {
            "model": self.model,
            "messages": [{"role": "user", "content": self._build_prompt(detections)}],
            "max_tokens": 256,
            "temperature": 0.2,
        }
        try:
            response = requests.post(
                self.api_url, headers=headers, json=data, timeout=self.timeout
            )
            response.raise_for_status()
            payload = response.json()
            return payload["choices"][0]["message"]["content"].strip()
        except (
            requests.RequestException,  # connection errors, timeouts, HTTP status errors
            ValueError,  # response body is not valid JSON
            KeyError,
            IndexError,
            TypeError,
            AttributeError,  # unexpected payload shape, e.g. null content
        ) as exc:
            return f"[Groq API error: {exc}]"


class VisualizationPipeline:
    """Draw detection boxes and labels onto a copy of the input image."""

    def __init__(self):
        self.font = ImageFont.load_default()
        self.box_color = (0, 83, 225)
        self.text_color = (0, 0, 0)
        # Backing fill behind labels so black text stays legible on dark scans.
        self.label_bg_color = (255, 255, 255)

    def draw(self, image, detections):
        """Return an RGB copy of *image* with each detection box and label drawn.

        The input image is not modified.
        """
        # convert() always returns a new image, so the caller's image is untouched.
        rendered = image.convert("RGB")
        canvas = ImageDraw.Draw(rendered)
        for detection in detections:
            x1, y1, x2, y2 = map(int, detection["box"])
            label = f"{detection['class']} ({detection['conf']:.2f})"
            canvas.rectangle([x1, y1, x2, y2], outline=self.box_color, width=3)
            # Place the label just above the box, clamped to the top edge.
            origin = (x1, max(0, y1 - 16))
            left, top, right, bottom = canvas.textbbox(origin, label, font=self.font)
            canvas.rectangle(
                [left - 2, top - 2, right + 2, bottom + 2], fill=self.label_bg_color
            )
            canvas.text(origin, label, fill=self.text_color, font=self.font)
        return rendered


class InferenceOrchestrator:
    """Tie detection, visualization and explanation into one predict call."""

    def __init__(self, detection_module, explanation_module, visualization):
        self.detection = detection_module
        self.explanation = explanation_module
        self.visualization = visualization

    def predict(self, image):
        """Analyze *image*.

        Returns:
            ``(annotated_image, top_class, top_conf, explanation)``; when nothing
            is detected, ``top_class`` is ``"no tumor"`` and ``top_conf`` is 0.0.
            ``annotated_image`` is None when *image* is None.
        """
        detections = self.detection.run(image)
        visual = None if image is None else self.visualization.draw(image, detections)
        explanation = self.explanation.generate(detections)
        if not detections:
            return visual, "no tumor", 0.0, explanation
        top = max(detections, key=lambda item: item["conf"])
        return visual, top["class"], top["conf"], explanation


# Built at import time so the model is loaded once, before the UI starts.
detection_module = DetectionModule(YOLO_WEIGHTS)
explanation_module = ExplanationModule(GROQ_API_KEY)
visualization = VisualizationPipeline()
orchestrator = InferenceOrchestrator(detection_module, explanation_module, visualization)


def set_ready():
    """Status text shown once the page has loaded."""
    return "Ready"


def analyze(image):
    """Gradio click handler: (status, annotated image, finding, confidence, summary)."""
    if image is None:
        return "Upload an MRI image to analyze.", None, "", 0.0, ""
    visual, tumor, conf, expl = orchestrator.predict(image)
    return "Analysis complete.", visual, tumor, conf, expl


with gr.Blocks(title="Neuro-Oncology MRI Inference Console") as demo:
    # Title bar styled by the .gradio-title rule in custom_css.
    gr.Markdown(
        "<div class='gradio-title'>Neuro-Oncology MRI Inference Console</div>"
        "<div>YOLO-based lesion localization with structured LLM-assisted "
        "explanation for research workflows.</div>"
    )
    with gr.Row():
        with gr.Column():
            image_in = gr.Image(type="pil", label="Upload MRI Image", elem_id="img-in")
            status = gr.Markdown(
                "Initializing inference pipeline...",
                elem_id="status-bar",
                # custom_css styles the *class* .status-bar; an elem_id alone
                # only yields #status-bar, which that rule never matches.
                elem_classes=["status-bar"],
            )
        with gr.Column():
            image_out = gr.Image(type="pil", label="Annotated MRI Output", elem_id="img-out")
            tumor_type = gr.Textbox(label="Predicted Finding", interactive=False)
            confidence = gr.Number(label="Detection Confidence", interactive=False)
            explanation = gr.Textbox(
                label="Structured Interpretation Summary", lines=6, interactive=False
            )

    demo.load(set_ready, None, status)
    analyze_btn = gr.Button("Run Inference", elem_id="analyze-btn", interactive=True)
    analyze_btn.click(
        analyze,
        inputs=[image_in],
        outputs=[status, image_out, tumor_type, confidence, explanation],
    )
    gr.Markdown("<div>For research use only. Not for clinical diagnosis.</div>")


if __name__ == "__main__":
    launch_kwargs = {
        "theme": gr.themes.Base(),
        "css": custom_css,
        "show_error": True,
    }
    # On Hugging Face Spaces, bind to all interfaces so the proxy can reach us.
    if os.getenv("SPACE_ID"):
        launch_kwargs["server_name"] = "0.0.0.0"
    port = os.getenv("PORT")
    if port:
        launch_kwargs["server_port"] = int(port)
    demo.launch(**launch_kwargs)