amine5970 commited on
Commit
a00909b
·
verified ·
1 Parent(s): e25198b

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +282 -0
app.py ADDED
@@ -0,0 +1,282 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import io
3
+ import cv2
4
+ import json
5
+ import time
6
+ import math
7
+ import base64
8
+ import queue
9
+ import shutil
10
+ import numpy as np
11
+ import requests
12
+ import onnxruntime as ort
13
+ from PIL import Image
14
+ import gradio as gr
15
+
16
# Configs
# Pretrained YOLOv7-p6 bone-fracture detector (ONNX export) published with
# the mdciri/YOLOv7-Bone-Fracture-Detection repository.
MODEL_URL = "https://github.com/mdciri/YOLOv7-Bone-Fracture-Detection/releases/download/trained-models/yolov7-p6-bonefracture.onnx"
# Local cache location for the downloaded model (next to this script).
MODEL_DIR = os.path.join(os.path.dirname(__file__), "models")
MODEL_PATH = os.path.join(MODEL_DIR, "yolov7-p6-bonefracture.onnx")
INPUT_SIZE = 640 # yolov7-p6 typical size
# Default UI values for the confidence and NMS-IoU sliders.
CONF_THRES_DEFAULT = 0.25
IOU_THRES_DEFAULT = 0.45

# Classes from GRAZPEDWRI-DX training
# Index order must match the model's class ids.
CLASSES = [
    "boneanomaly",
    "bonelesion",
    "foreignbody",
    "fracture",
    "metal",
    "periostealreaction",
    "pronatorsign",
    "softtissue",
    "text",
]

# Lazily-initialized ONNX Runtime state, populated once by load_session().
_session = None
_input_name = None
_output_name = None
40
+
41
+
42
def ensure_model_available():
    """Download the ONNX model into MODEL_DIR if it is not already cached.

    The file is streamed to a temporary ``.downloading`` path and atomically
    moved into place, so a partially-downloaded file is never mistaken for a
    valid model.

    Raises:
        RuntimeError: if the download fails (e.g. no network access).
    """
    os.makedirs(MODEL_DIR, exist_ok=True)
    if os.path.exists(MODEL_PATH):
        return
    tmp_path = MODEL_PATH + ".downloading"
    try:
        with requests.get(MODEL_URL, stream=True, timeout=120) as r:
            r.raise_for_status()
            with open(tmp_path, "wb") as f:
                # 1 MiB chunks keep memory bounded for the large model file.
                for chunk in r.iter_content(chunk_size=1 << 20):
                    if chunk:
                        f.write(chunk)
        os.replace(tmp_path, MODEL_PATH)
    except Exception as e:
        # Remove the stale partial download so the next attempt starts clean.
        try:
            os.remove(tmp_path)
        except OSError:
            pass
        raise RuntimeError(
            "Téléchargement du modèle échoué. Activez Internet dans les paramètres du Space ou réessayez plus tard. Détails: "
            + str(e)
        ) from e
59
+
60
+
61
def load_session():
    """Return the shared ONNX Runtime session, creating it on first use.

    Also caches the model's input and output tensor names in module globals
    for later ``session.run`` calls.
    """
    global _session, _input_name, _output_name
    if _session is not None:
        return _session
    ensure_model_available()
    _session = ort.InferenceSession(MODEL_PATH, providers=["CPUExecutionProvider"])
    _input_name = _session.get_inputs()[0].name
    _output_name = _session.get_outputs()[0].name
    return _session
70
+
71
+
72
def ensure_rgb(image: np.ndarray) -> np.ndarray:
    """Ensure input image is 3-channel RGB.

    Accepts grayscale (H, W), single-channel (H, W, 1), RGBA (H, W, 4) or
    RGB (H, W, 3) arrays and returns an (H, W, 3) array. ``None`` passes
    through unchanged.
    """
    if image is None:
        return image
    if image.ndim == 2:
        # Grayscale -> RGB
        return cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
    if image.ndim == 3 and image.shape[2] == 1:
        # Single-channel 3D (common for radiograph exports) -> RGB;
        # previously fell through and broke the later RGB->BGR conversion.
        return cv2.cvtColor(image[:, :, 0], cv2.COLOR_GRAY2RGB)
    if image.ndim == 3 and image.shape[2] == 4:
        # RGBA -> RGB
        return cv2.cvtColor(image, cv2.COLOR_RGBA2RGB)
    return image
83
+
84
+
85
def letterbox(im, new_shape=(INPUT_SIZE, INPUT_SIZE), color=(114, 114, 114)):
    """Resize ``im`` to fit inside ``new_shape`` preserving aspect ratio,
    padding the remaining border with ``color``.

    Returns:
        (padded image, scale ratio, (pad_left, pad_top)).
    """
    h, w = im.shape[:2]
    ratio = min(new_shape[0] / h, new_shape[1] / w)
    new_h = int(round(h * ratio))
    new_w = int(round(w * ratio))
    resized = cv2.resize(im, (new_w, new_h), interpolation=cv2.INTER_LINEAR)
    # Split the padding as evenly as possible between opposite borders.
    pad_h = new_shape[0] - new_h
    pad_w = new_shape[1] - new_w
    top, left = pad_h // 2, pad_w // 2
    padded = cv2.copyMakeBorder(
        resized, top, pad_h - top, left, pad_w - left,
        cv2.BORDER_CONSTANT, value=color,
    )
    return padded, ratio, (left, top)
96
+
97
+
98
def xywh2xyxy(x):
    """Convert (N, 4) boxes from [cx, cy, w, h] to [x1, y1, x2, y2].

    Returns a new array; the input is left untouched.
    """
    out = x.copy()
    half_w = x[:, 2] / 2
    half_h = x[:, 3] / 2
    out[:, 0] = x[:, 0] - half_w
    out[:, 1] = x[:, 1] - half_h
    out[:, 2] = x[:, 0] + half_w
    out[:, 3] = x[:, 1] + half_h
    return out
105
+
106
+
107
def nms(boxes, scores, iou_thres=0.45):
    """Greedy non-maximum suppression.

    Repeatedly keeps the highest-scoring box and discards remaining boxes
    whose IoU with it is >= ``iou_thres``. Returns the kept indices,
    highest score first.
    """
    keep = []
    order = scores.argsort()[::-1]
    while order.size:
        best = order[0]
        keep.append(best)
        if order.size == 1:
            break
        remaining = order[1:]
        overlaps = iou(boxes[best], boxes[remaining])
        order = remaining[overlaps < iou_thres]
    return keep
118
+
119
+
120
def iou(box, boxes):
    """IoU of one [x1, y1, x2, y2] box against an (N, 4) array of boxes.

    Returns an (N,) array; the small epsilon guards against division by
    zero for degenerate (zero-area) boxes.
    """
    top_left = np.maximum(box[:2], boxes[:, :2])
    bottom_right = np.minimum(box[2:4], boxes[:, 2:4])
    wh = np.maximum(0, bottom_right - top_left)
    inter = wh[:, 0] * wh[:, 1]
    area_box = (box[2] - box[0]) * (box[3] - box[1])
    areas = (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])
    return inter / (area_box + areas - inter + 1e-16)
130
+
131
+
132
def scale_boxes(boxes, gain, pad):
    """Map letterboxed boxes back to original-image coordinates.

    Subtracts the (pad_x, pad_y) offset, then divides by the resize gain.
    Mutates ``boxes`` in place and returns it.
    """
    pad_x, pad_y = pad[0], pad[1]
    boxes[:, 0] -= pad_x
    boxes[:, 2] -= pad_x
    boxes[:, 1] -= pad_y
    boxes[:, 3] -= pad_y
    boxes[:, :4] /= gain
    return boxes
137
+
138
+
139
def infer_yolov7(image_rgb, conf_thres=0.25, iou_thres=0.45, only_fracture=True):
    """Run the YOLOv7 ONNX model on an RGB image and return detections.

    Args:
        image_rgb: RGB image array of shape (H, W, 3).
        conf_thres: minimum confidence score for a detection to be kept.
        iou_thres: NMS IoU threshold. NOTE(review): accepted but never used
            in this function — presumably the exported ONNX graph already
            performs NMS internally; confirm against the export script.
        only_fracture: when True, drop every class except "fracture".

    Returns:
        List of dicts: {"box": [x1, y1, x2, y2] in original-image pixels,
        "score": float, "class_id": int, "class_name": str}.
    """
    h0, w0 = image_rgb.shape[:2]
    image_bgr = cv2.cvtColor(image_rgb, cv2.COLOR_RGB2BGR)
    # ONNX model expects 640x640 input as per reference script
    # (plain resize, no letterboxing: aspect ratio is not preserved).
    img = cv2.resize(image_bgr, (INPUT_SIZE, INPUT_SIZE), interpolation=cv2.INTER_LINEAR)
    img = img.astype(np.float32) / 255.0
    img = np.transpose(img, (2, 0, 1))  # HWC -> CHW
    img = np.expand_dims(img, 0)  # add batch dimension -> (1, 3, 640, 640)

    session = load_session()
    pred = session.run([_output_name], {_input_name: img})[0]
    if pred.ndim == 3:
        pred = pred[0]  # drop the batch dimension
    # pred expected shape: [N, 6] -> [x1, y1, x2, y2, score, label]
    if pred.size == 0:
        return []
    boxes_xyxy = pred[:, 0:4].astype(np.float32)
    scores = pred[:, 4].astype(np.float32)
    labels = pred[:, 5].astype(np.int32)

    # confidence filtering
    mask = scores >= conf_thres
    boxes_xyxy = boxes_xyxy[mask]
    scores = scores[mask]
    labels = labels[mask]

    if boxes_xyxy.shape[0] == 0:
        return []

    # scale boxes back from 640x640 to original size
    sx = w0 / float(INPUT_SIZE)
    sy = h0 / float(INPUT_SIZE)
    boxes_xyxy[:, [0, 2]] *= sx
    boxes_xyxy[:, [1, 3]] *= sy

    dets = []
    for b, c, s in zip(boxes_xyxy, labels, scores):
        x1, y1, x2, y2 = b.tolist()
        # Clamp coordinates to the original image bounds.
        x1 = max(0, min(w0 - 1, x1))
        y1 = max(0, min(h0 - 1, y1))
        x2 = max(0, min(w0 - 1, x2))
        y2 = max(0, min(h0 - 1, y2))
        # Fall back to the raw numeric label for out-of-range class ids.
        name = CLASSES[c] if 0 <= c < len(CLASSES) else str(int(c))
        if only_fracture and name != "fracture":
            continue
        dets.append({
            "box": [float(x1), float(y1), float(x2), float(y2)],
            "score": float(s),
            "class_id": int(c),
            "class_name": name,
        })
    return dets
191
+
192
+
193
def draw_detections(image_rgb, dets):
    """Return a copy of ``image_rgb`` with labelled detection boxes drawn.

    Fractures are outlined in red; every other class in orange.
    """
    canvas = image_rgb.copy()
    font = cv2.FONT_HERSHEY_SIMPLEX
    for det in dets:
        x1, y1, x2, y2 = (int(v) for v in det["box"])
        cls_name = det["class_name"]
        color = (255, 0, 0) if cls_name == "fracture" else (0, 150, 255)
        cv2.rectangle(canvas, (x1, y1), (x2, y2), color, 3)
        caption = f"{cls_name}:{det['score']:.2f}"
        (text_w, text_h), _ = cv2.getTextSize(caption, font, 0.8, 2)
        # Draw a filled background so the label stays readable on the X-ray.
        text_y = max(0, y1 - 8)
        cv2.rectangle(canvas, (x1, text_y - text_h - 6), (x1 + text_w + 6, text_y + 2), color, -1)
        cv2.putText(canvas, caption, (x1 + 3, text_y), font, 0.8, (255, 255, 255), 2)
    return canvas
207
+
208
+
209
def predict(image, region, conf_thres, iou_thres, show_non_fracture):
    """Gradio callback: run detection and build the UI outputs.

    Returns (annotated RGB image or None, pretty-printed JSON report string).
    Errors are reported inside the JSON payload instead of raising.
    """
    if image is None:
        return None, json.dumps({"error": "Aucune image fournie."}, ensure_ascii=False, indent=2)

    # Normalize channels to RGB
    rgb = ensure_rgb(image)

    start = time.time()
    try:
        detections = infer_yolov7(
            rgb,
            conf_thres=conf_thres,
            iou_thres=iou_thres,
            only_fracture=not show_non_fracture,
        )
    except Exception as exc:
        return None, json.dumps({"error": str(exc)}, ensure_ascii=False, indent=2)
    elapsed = time.time() - start

    annotated = draw_detections(rgb, detections)
    report = {
        "region": region,
        "detections": detections,
        "count": len(detections),
        "time_s": round(elapsed, 3),
        "note": "Modèle entraîné sur le poignet (GRAZPEDWRI-DX). Les autres régions sont exploratoires.",
        "medical_warning": "Cet outil n’est pas un dispositif médical. Il ne remplace pas l’avis d’un(e) radiologue/médecin.",
    }
    return annotated, json.dumps(report, ensure_ascii=False, indent=2)
236
+
237
+
238
def build_ui():
    """Assemble and return the Gradio Blocks interface.

    Wires an image input plus threshold controls to predict(), which
    returns the annotated image and a JSON detection report. All
    user-facing text is French.
    """
    with gr.Blocks(title="Détection de fracture (Radiographie)") as demo:
        gr.Markdown("""
# Détection de fracture (Radiographie) — Prototype
- Interface en français, fonctionnement 100% en ligne.
- Téléversez une radiographie, puis lancez l’analyse.
- Modèle détection (boîtes) entraîné sur le poignet; autres régions = usage exploratoire.
- N’est pas un dispositif médical.
""")

        with gr.Row():
            with gr.Column(scale=2):
                inp = gr.Image(type="numpy", label="Téléverser une radiographie")
            with gr.Column(scale=1):
                # The selected region is only echoed back in the JSON report;
                # it does not change the model that runs.
                region = gr.Dropdown(
                    choices=[
                        "Poignet (modèle entraîné)",
                        "Autre (exploratoire)",
                    ],
                    value="Poignet (modèle entraîné)",
                    label="Région anatomique",
                )
                conf = gr.Slider(0.05, 0.9, value=CONF_THRES_DEFAULT, step=0.01, label="Seuil de confiance")
                iou = gr.Slider(0.1, 0.9, value=IOU_THRES_DEFAULT, step=0.01, label="Seuil NMS (IoU)")
                show_non_frac = gr.Checkbox(False, label="Afficher aussi les autres classes (non-fracture)")
                btn = gr.Button("Analyser", variant="primary")

        with gr.Row():
            out_img = gr.Image(type="numpy", label="Résultat annoté")
            out_json = gr.Code(language="json", label="Détails des détections")

        btn.click(predict, inputs=[inp, region, conf, iou, show_non_frac], outputs=[out_img, out_json])

        gr.Markdown("""
### Avertissement
Cet outil sert d’aide et ne remplace pas un avis médical professionnel.
""")

    return demo
277
+
278
+
279
# Build the app at import time so hosting platforms (e.g. HF Spaces) can
# serve the module-level `demo` object directly.
demo = build_ui()

if __name__ == "__main__":
    # Local development entry point.
    demo.launch()