adityaverma977 commited on
Commit
4110b90
·
0 Parent(s):

Prepare Hugging Face Space

Browse files
Files changed (6) hide show
  1. .gitattributes +1 -0
  2. .gitignore +7 -0
  3. README.md +21 -0
  4. app.py +208 -0
  5. best.pt +3 -0
  6. requirements.txt +7 -0
.gitattributes ADDED
@@ -0,0 +1 @@
 
 
1
+ *.pt filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ .env
2
+ __pycache__/
3
+ *.pyc
4
+ *.pyo
5
+ *.pyd
6
+ .Python
7
+ .gradio/
README.md ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: Neuro-Oncology MRI Inference Console
3
+ colorFrom: blue
4
+ colorTo: indigo
5
+ sdk: gradio
6
+ python_version: "3.10"
7
+ app_file: app.py
8
+ suggested_hardware: cpu-basic
9
+ ---
10
+
11
+ # Neuro-Oncology MRI Inference Console
12
+
13
+ Gradio app for YOLO-based MRI lesion localization with structured explanation generated through the Groq chat-completions API.
14
+
15
+ ## Required Space secret
16
+
17
+ - `GROQ_API_KEY`
18
+
19
+ ## Optional Space variable
20
+
21
+ - `GROQ_MODEL`
app.py ADDED
@@ -0,0 +1,208 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+ import gradio as gr
4
+ import requests
5
+ from dotenv import load_dotenv
6
+ from PIL import ImageDraw, ImageFont
7
+ from ultralytics import YOLO
8
+
9
+ YOLO_WEIGHTS = "best.pt"
10
+ GROQ_API_URL = "https://api.groq.com/openai/v1/chat/completions"
11
+
12
+ WINDOWS_XP_COLORS = {
13
+ "bg": "#ece9d8",
14
+ "title": "#0053e1",
15
+ "status": "#f3f3f3",
16
+ "border": "#808080",
17
+ }
18
+
19
+ load_dotenv()
20
+ GROQ_API_KEY = os.getenv("GROQ_API_KEY", "")
21
+ GROQ_MODEL = os.getenv("GROQ_MODEL", "llama-3.3-70b-versatile")
22
+
23
+ custom_css = f"""
24
+ body {{ background: {WINDOWS_XP_COLORS["bg"]}; font-family: Tahoma, Verdana, sans-serif; }}
25
+ .gradio-container {{
26
+ border: 2px solid {WINDOWS_XP_COLORS["border"]};
27
+ background: {WINDOWS_XP_COLORS["bg"]};
28
+ border-radius: 6px;
29
+ max-width: 700px;
30
+ margin: 32px auto;
31
+ box-shadow: 0 4px 16px #bbb;
32
+ }}
33
+ .gradio-title {{
34
+ background: {WINDOWS_XP_COLORS["title"]};
35
+ color: #fff;
36
+ padding: 10px 16px;
37
+ font-size: 20px;
38
+ border-top-left-radius: 6px;
39
+ border-top-right-radius: 6px;
40
+ margin-bottom: 0;
41
+ }}
42
+ .status-bar {{
43
+ background: {WINDOWS_XP_COLORS["status"]};
44
+ color: #333;
45
+ padding: 6px 16px;
46
+ font-size: 13px;
47
+ border-bottom-left-radius: 6px;
48
+ border-bottom-right-radius: 6px;
49
+ border-top: 1px solid {WINDOWS_XP_COLORS["border"]};
50
+ margin-top: 0;
51
+ }}
52
+ """
53
+
54
+
55
+ class DetectionModule:
56
+ def __init__(self, weights_path):
57
+ if not os.path.exists(weights_path):
58
+ raise FileNotFoundError(f"YOLO weights not found: {weights_path}")
59
+ self.model = YOLO(weights_path)
60
+
61
+ def run(self, image):
62
+ if image is None:
63
+ return []
64
+ results = self.model(image, verbose=False)
65
+ detections = []
66
+ for result in results:
67
+ names = result.names
68
+ for box in result.boxes:
69
+ cls_idx = int(box.cls.item())
70
+ conf = float(box.conf.item())
71
+ x1, y1, x2, y2 = box.xyxy[0].tolist()
72
+ detections.append(
73
+ {
74
+ "class": names.get(cls_idx, str(cls_idx)),
75
+ "conf": conf,
76
+ "box": [x1, y1, x2, y2],
77
+ }
78
+ )
79
+ return detections
80
+
81
+
82
+ class ExplanationModule:
83
+ def __init__(self, api_key, api_url=GROQ_API_URL):
84
+ self.api_key = api_key
85
+ self.api_url = api_url
86
+
87
+ def generate(self, detections):
88
+ if not self.api_key:
89
+ return "[Groq API key not set. Cannot generate explanation.]"
90
+ if not detections:
91
+ return "No tumor detected with sufficient confidence."
92
+ det_lines = [f"- Tumor type: {d['class']}, Confidence: {d['conf']:.2f}" for d in detections]
93
+ prompt = (
94
+ "You are a medical AI assistant.\n"
95
+ "Input:\n"
96
+ f"Detection count: {len(detections)}\n"
97
+ + "\n".join(det_lines)
98
+ + "\nExplain in simple terms:\n"
99
+ "- What was detected\n"
100
+ "- What confidence means\n"
101
+ "- Avoid medical diagnosis\n"
102
+ "- Add disclaimer\n"
103
+ )
104
+ headers = {
105
+ "Authorization": f"Bearer {self.api_key}",
106
+ "Content-Type": "application/json",
107
+ }
108
+ data = {
109
+ "model": GROQ_MODEL,
110
+ "messages": [{"role": "user", "content": prompt}],
111
+ "max_tokens": 256,
112
+ "temperature": 0.2,
113
+ }
114
+ try:
115
+ response = requests.post(self.api_url, headers=headers, json=data, timeout=10)
116
+ response.raise_for_status()
117
+ payload = response.json()
118
+ return payload["choices"][0]["message"]["content"].strip()
119
+ except Exception as exc:
120
+ return f"[Groq API error: {exc}]"
121
+
122
+
123
+ class VisualizationPipeline:
124
+ def __init__(self):
125
+ self.font = ImageFont.load_default()
126
+ self.box_color = (0, 83, 225)
127
+ self.text_color = (0, 0, 0)
128
+
129
+ def draw(self, image, detections):
130
+ rendered = image.convert("RGB").copy()
131
+ draw = ImageDraw.Draw(rendered)
132
+ for detection in detections:
133
+ x1, y1, x2, y2 = map(int, detection["box"])
134
+ label = f"{detection['class']} ({detection['conf']:.2f})"
135
+ draw.rectangle([x1, y1, x2, y2], outline=self.box_color, width=3)
136
+ draw.text((x1, max(0, y1 - 16)), label, fill=self.text_color, font=self.font)
137
+ return rendered
138
+
139
+
140
+ class InferenceOrchestrator:
141
+ def __init__(self, detection_module, explanation_module, visualization):
142
+ self.detection = detection_module
143
+ self.explanation = explanation_module
144
+ self.visualization = visualization
145
+
146
+ def predict(self, image):
147
+ detections = self.detection.run(image)
148
+ visual = self.visualization.draw(image, detections)
149
+ explanation = self.explanation.generate(detections)
150
+ if detections:
151
+ top = max(detections, key=lambda item: item["conf"])
152
+ return visual, top["class"], top["conf"], explanation
153
+ return visual, "no tumor", 0.0, explanation
154
+
155
+
156
+ detection_module = DetectionModule(YOLO_WEIGHTS)
157
+ explanation_module = ExplanationModule(GROQ_API_KEY)
158
+ visualization = VisualizationPipeline()
159
+ orchestrator = InferenceOrchestrator(detection_module, explanation_module, visualization)
160
+
161
+
162
+ def set_ready():
163
+ return "Ready"
164
+
165
+
166
+ def analyze(image):
167
+ if image is None:
168
+ return "Upload an MRI image to analyze.", None, "", 0.0, ""
169
+ visual, tumor, conf, expl = orchestrator.predict(image)
170
+ return "Analysis complete.", visual, tumor, conf, expl
171
+
172
+
173
+ with gr.Blocks(title="Neuro-Oncology MRI Inference Console") as demo:
174
+ gr.Markdown(
175
+ "<div class='gradio-title'>Neuro-Oncology MRI Inference Console</div>"
176
+ "<div class='status-bar'>YOLO-based lesion localization with structured LLM-assisted explanation for research workflows.</div>"
177
+ )
178
+ with gr.Row():
179
+ with gr.Column():
180
+ image_in = gr.Image(type="pil", label="Upload MRI Image", elem_id="img-in")
181
+ status = gr.Markdown("Initializing inference pipeline...", elem_id="status-bar")
182
+ with gr.Column():
183
+ image_out = gr.Image(type="pil", label="Annotated MRI Output", elem_id="img-out")
184
+ tumor_type = gr.Textbox(label="Predicted Finding", interactive=False)
185
+ confidence = gr.Number(label="Detection Confidence", interactive=False)
186
+ explanation = gr.Textbox(label="Structured Interpretation Summary", lines=6, interactive=False)
187
+ demo.load(set_ready, None, status)
188
+ analyze_btn = gr.Button("Run Inference", elem_id="analyze-btn", interactive=True)
189
+ analyze_btn.click(
190
+ analyze,
191
+ inputs=[image_in],
192
+ outputs=[status, image_out, tumor_type, confidence, explanation],
193
+ )
194
+ gr.Markdown("<div class='status-bar'>For research use only. Not for clinical diagnosis.</div>")
195
+
196
+
197
+ if __name__ == "__main__":
198
+ launch_kwargs = {
199
+ "theme": gr.themes.Base(),
200
+ "css": custom_css,
201
+ "show_error": True,
202
+ }
203
+ if os.getenv("SPACE_ID"):
204
+ launch_kwargs["server_name"] = "0.0.0.0"
205
+ port = os.getenv("PORT")
206
+ if port:
207
+ launch_kwargs["server_port"] = int(port)
208
+ demo.launch(**launch_kwargs)
best.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8bd1d5e3077a576bc6ec7e54e8d4d1949b78cede0340902a1b0155a66adf9f35
3
+ size 207411165
requirements.txt ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ gradio==6.13.0
2
+ ultralytics==8.4.41
3
+ opencv-python-headless==4.13.0.92
4
+ pillow==10.4.0
5
+ requests==2.32.5
6
+ torch==2.5.1
7
+ python-dotenv==1.0.1