adityaverma977
Prepare Hugging Face Space
4110b90
import os
import gradio as gr
import requests
from dotenv import load_dotenv
from PIL import ImageDraw, ImageFont
from ultralytics import YOLO
YOLO_WEIGHTS = "best.pt"
GROQ_API_URL = "https://api.groq.com/openai/v1/chat/completions"
WINDOWS_XP_COLORS = {
"bg": "#ece9d8",
"title": "#0053e1",
"status": "#f3f3f3",
"border": "#808080",
}
load_dotenv()
GROQ_API_KEY = os.getenv("GROQ_API_KEY", "")
GROQ_MODEL = os.getenv("GROQ_MODEL", "llama-3.3-70b-versatile")
custom_css = f"""
body {{ background: {WINDOWS_XP_COLORS["bg"]}; font-family: Tahoma, Verdana, sans-serif; }}
.gradio-container {{
border: 2px solid {WINDOWS_XP_COLORS["border"]};
background: {WINDOWS_XP_COLORS["bg"]};
border-radius: 6px;
max-width: 700px;
margin: 32px auto;
box-shadow: 0 4px 16px #bbb;
}}
.gradio-title {{
background: {WINDOWS_XP_COLORS["title"]};
color: #fff;
padding: 10px 16px;
font-size: 20px;
border-top-left-radius: 6px;
border-top-right-radius: 6px;
margin-bottom: 0;
}}
.status-bar {{
background: {WINDOWS_XP_COLORS["status"]};
color: #333;
padding: 6px 16px;
font-size: 13px;
border-bottom-left-radius: 6px;
border-bottom-right-radius: 6px;
border-top: 1px solid {WINDOWS_XP_COLORS["border"]};
margin-top: 0;
}}
"""
class DetectionModule:
def __init__(self, weights_path):
if not os.path.exists(weights_path):
raise FileNotFoundError(f"YOLO weights not found: {weights_path}")
self.model = YOLO(weights_path)
def run(self, image):
if image is None:
return []
results = self.model(image, verbose=False)
detections = []
for result in results:
names = result.names
for box in result.boxes:
cls_idx = int(box.cls.item())
conf = float(box.conf.item())
x1, y1, x2, y2 = box.xyxy[0].tolist()
detections.append(
{
"class": names.get(cls_idx, str(cls_idx)),
"conf": conf,
"box": [x1, y1, x2, y2],
}
)
return detections
class ExplanationModule:
def __init__(self, api_key, api_url=GROQ_API_URL):
self.api_key = api_key
self.api_url = api_url
def generate(self, detections):
if not self.api_key:
return "[Groq API key not set. Cannot generate explanation.]"
if not detections:
return "No tumor detected with sufficient confidence."
det_lines = [f"- Tumor type: {d['class']}, Confidence: {d['conf']:.2f}" for d in detections]
prompt = (
"You are a medical AI assistant.\n"
"Input:\n"
f"Detection count: {len(detections)}\n"
+ "\n".join(det_lines)
+ "\nExplain in simple terms:\n"
"- What was detected\n"
"- What confidence means\n"
"- Avoid medical diagnosis\n"
"- Add disclaimer\n"
)
headers = {
"Authorization": f"Bearer {self.api_key}",
"Content-Type": "application/json",
}
data = {
"model": GROQ_MODEL,
"messages": [{"role": "user", "content": prompt}],
"max_tokens": 256,
"temperature": 0.2,
}
try:
response = requests.post(self.api_url, headers=headers, json=data, timeout=10)
response.raise_for_status()
payload = response.json()
return payload["choices"][0]["message"]["content"].strip()
except Exception as exc:
return f"[Groq API error: {exc}]"
class VisualizationPipeline:
def __init__(self):
self.font = ImageFont.load_default()
self.box_color = (0, 83, 225)
self.text_color = (0, 0, 0)
def draw(self, image, detections):
rendered = image.convert("RGB").copy()
draw = ImageDraw.Draw(rendered)
for detection in detections:
x1, y1, x2, y2 = map(int, detection["box"])
label = f"{detection['class']} ({detection['conf']:.2f})"
draw.rectangle([x1, y1, x2, y2], outline=self.box_color, width=3)
draw.text((x1, max(0, y1 - 16)), label, fill=self.text_color, font=self.font)
return rendered
class InferenceOrchestrator:
def __init__(self, detection_module, explanation_module, visualization):
self.detection = detection_module
self.explanation = explanation_module
self.visualization = visualization
def predict(self, image):
detections = self.detection.run(image)
visual = self.visualization.draw(image, detections)
explanation = self.explanation.generate(detections)
if detections:
top = max(detections, key=lambda item: item["conf"])
return visual, top["class"], top["conf"], explanation
return visual, "no tumor", 0.0, explanation
detection_module = DetectionModule(YOLO_WEIGHTS)
explanation_module = ExplanationModule(GROQ_API_KEY)
visualization = VisualizationPipeline()
orchestrator = InferenceOrchestrator(detection_module, explanation_module, visualization)
def set_ready():
return "Ready"
def analyze(image):
if image is None:
return "Upload an MRI image to analyze.", None, "", 0.0, ""
visual, tumor, conf, expl = orchestrator.predict(image)
return "Analysis complete.", visual, tumor, conf, expl
with gr.Blocks(title="Neuro-Oncology MRI Inference Console") as demo:
gr.Markdown(
"<div class='gradio-title'>Neuro-Oncology MRI Inference Console</div>"
"<div class='status-bar'>YOLO-based lesion localization with structured LLM-assisted explanation for research workflows.</div>"
)
with gr.Row():
with gr.Column():
image_in = gr.Image(type="pil", label="Upload MRI Image", elem_id="img-in")
status = gr.Markdown("Initializing inference pipeline...", elem_id="status-bar")
with gr.Column():
image_out = gr.Image(type="pil", label="Annotated MRI Output", elem_id="img-out")
tumor_type = gr.Textbox(label="Predicted Finding", interactive=False)
confidence = gr.Number(label="Detection Confidence", interactive=False)
explanation = gr.Textbox(label="Structured Interpretation Summary", lines=6, interactive=False)
demo.load(set_ready, None, status)
analyze_btn = gr.Button("Run Inference", elem_id="analyze-btn", interactive=True)
analyze_btn.click(
analyze,
inputs=[image_in],
outputs=[status, image_out, tumor_type, confidence, explanation],
)
gr.Markdown("<div class='status-bar'>For research use only. Not for clinical diagnosis.</div>")
if __name__ == "__main__":
launch_kwargs = {
"theme": gr.themes.Base(),
"css": custom_css,
"show_error": True,
}
if os.getenv("SPACE_ID"):
launch_kwargs["server_name"] = "0.0.0.0"
port = os.getenv("PORT")
if port:
launch_kwargs["server_port"] = int(port)
demo.launch(**launch_kwargs)