|
|
import os |
|
|
import io |
|
|
import cv2 |
|
|
import json |
|
|
import time |
|
|
import math |
|
|
import base64 |
|
|
import queue |
|
|
import shutil |
|
|
import numpy as np |
|
|
import requests |
|
|
import onnxruntime as ort |
|
|
from PIL import Image |
|
|
import gradio as gr |
|
|
|
|
|
|
|
|
# Pretrained YOLOv7-P6 bone-fracture weights (ONNX), downloaded on first use.
MODEL_URL = "https://github.com/mdciri/YOLOv7-Bone-Fracture-Detection/releases/download/trained-models/yolov7-p6-bonefracture.onnx"
# Local cache: a "models" folder next to this file; the file keeps the URL's basename.
MODEL_DIR = os.path.join(os.path.dirname(__file__), "models")
MODEL_PATH = os.path.join(MODEL_DIR, MODEL_URL.rsplit("/", 1)[-1])

# Side length, in pixels, of the square image fed to the network.
INPUT_SIZE = 640
# Default values of the confidence and NMS-IoU sliders in the UI.
CONF_THRES_DEFAULT = 0.25
IOU_THRES_DEFAULT = 0.45

# Class names, indexed by the integer class id the model outputs.
CLASSES = (
    "boneanomaly bonelesion foreignbody fracture metal "
    "periostealreaction pronatorsign softtissue text"
).split()

# Lazily initialised ONNX Runtime state, filled in by load_session().
_session = _input_name = _output_name = None
|
|
|
|
|
|
|
|
def ensure_model_available():
    """Download the ONNX model to MODEL_PATH unless it is already there.

    The body is streamed into a sibling ``.downloading`` file, which is moved
    onto MODEL_PATH with os.replace only after the whole response has been
    written. An interrupted download therefore never leaves a file at
    MODEL_PATH. On failure the partial temp file is removed.

    Raises:
        RuntimeError: if downloading or writing the model fails. The message is
            user-facing (French); the original exception is chained as
            ``__cause__``.
    """
    os.makedirs(MODEL_DIR, exist_ok=True)
    if os.path.exists(MODEL_PATH):
        return

    tmp_path = MODEL_PATH + ".downloading"
    try:
        with requests.get(MODEL_URL, stream=True, timeout=120) as r:
            r.raise_for_status()
            with open(tmp_path, "wb") as f:
                for chunk in r.iter_content(chunk_size=1 << 20):
                    if chunk:  # skip empty chunks
                        f.write(chunk)
        os.replace(tmp_path, MODEL_PATH)
    except Exception as e:
        # Best-effort cleanup of the partial download; the download error
        # below is what the caller needs to see, so a cleanup failure is ignored.
        try:
            if os.path.exists(tmp_path):
                os.remove(tmp_path)
        except OSError:
            pass
        raise RuntimeError(
            "Téléchargement du modèle échoué. Activez Internet dans les paramètres du Space ou réessayez plus tard. Détails: "
            + str(e)
        ) from e
|
|
|
|
|
|
|
|
def load_session():
    """Return the shared ONNX Runtime session, creating it on the first call.

    On the first call:
    - the model file is made available via ensure_model_available();
    - the model is opened with the CPU execution provider;
    - the session and the names of its first input and first output are
      cached in ``_session``, ``_input_name`` and ``_output_name``.

    Later calls return the cached session.

    Raises:
        RuntimeError: propagated from ensure_model_available() when the download fails.
        Any exception ONNX Runtime raises while loading the model. Nothing is
        cached in that case, so the next call retries the initialisation.
    """
    global _session, _input_name, _output_name
    if _session is None:
        ensure_model_available()
        session = ort.InferenceSession(MODEL_PATH, providers=["CPUExecutionProvider"])
        input_name = session.get_inputs()[0].name
        output_name = session.get_outputs()[0].name
        # Publish only after everything above succeeded. A non-None _session
        # then always comes with valid I/O names, and a failure leaves nothing
        # half-initialised in the cache.
        _input_name, _output_name = input_name, output_name
        _session = session
    return _session
|
|
|
|
|
|
|
|
def ensure_rgb(image: np.ndarray) -> np.ndarray:
    """Return ``image`` as a 3-channel RGB array.

    - ``None`` is returned unchanged.
    - Single-channel input (HxW or HxWx1) has its gray channel copied into
      R, G and B.
    - HxWx2 input is treated as gray + alpha (e.g. a PIL "LA" image): the alpha
      is dropped and the gray channel is replicated.
    - HxWx4 (RGBA) has its alpha channel dropped without compositing, the same
      result as OpenCV's COLOR_RGBA2RGB.
    - Anything else, including HxWx3, is returned unchanged (same object).

    Plain NumPy is used, so every dtype is preserved as-is.
    """
    if image is None:
        return image

    if image.ndim == 2:
        gray = image
    elif image.ndim == 3 and image.shape[2] in (1, 2):
        gray = image[..., 0]
    elif image.ndim == 3 and image.shape[2] == 4:
        # Contiguous copy so downstream OpenCV / ONNX calls get a plain C array.
        return np.ascontiguousarray(image[..., :3])
    else:
        return image

    return np.repeat(gray[..., np.newaxis], 3, axis=2)
|
|
|
|
|
|
|
|
def letterbox(im, new_shape=(INPUT_SIZE, INPUT_SIZE), color=(114, 114, 114)):
    """Resize ``im`` to fit ``new_shape`` keeping its aspect ratio, then pad it.

    The image is scaled by a single factor so it fits inside the target. The
    leftover border is filled with ``color`` and split as evenly as possible
    between opposite sides; an odd leftover pixel goes to the bottom/right.

    Args:
        im: image array as accepted by cv2.resize / cv2.copyMakeBorder.
        new_shape: target ``(height, width)``, or a single int for a square target.
        color: border fill value.

    Returns:
        ``(padded, gain, (pad_left, pad_top))``. ``gain`` is the scale applied
        to the source image, so a source point (x, y) lands approximately at
        ``(x * gain + pad_left, y * gain + pad_top)`` in ``padded``.

    Raises:
        ValueError: if ``im`` has zero height or width.
    """
    if isinstance(new_shape, int):
        new_shape = (new_shape, new_shape)

    h, w = im.shape[:2]
    if h == 0 or w == 0:
        raise ValueError(f"letterbox: empty image of shape {im.shape}")

    gain = min(new_shape[0] / h, new_shape[1] / w)
    # At least 1 px per side: for extremely thin images the rounded size could
    # otherwise be 0, which cv2.resize rejects.
    nh = max(1, int(round(h * gain)))
    nw = max(1, int(round(w * gain)))

    resized = im if (nh, nw) == (h, w) else cv2.resize(im, (nw, nh), interpolation=cv2.INTER_LINEAR)

    top = (new_shape[0] - nh) // 2
    bottom = new_shape[0] - nh - top
    left = (new_shape[1] - nw) // 2
    right = new_shape[1] - nw - left

    # copyMakeBorder always returns a new array, so the caller's image is never aliased.
    padded = cv2.copyMakeBorder(resized, top, bottom, left, right, cv2.BORDER_CONSTANT, value=color)
    return padded, gain, (left, top)
|
|
|
|
|
|
|
|
def xywh2xyxy(x):
    """Convert boxes from centre format (cx, cy, w, h) to corners (x1, y1, x2, y2).

    Only the first four columns of the 2-D array ``x`` are converted. Any extra
    columns are copied through unchanged. A new array is returned and ``x`` is
    left untouched.
    """
    cx, cy = x[:, 0], x[:, 1]
    half_w, half_h = x[:, 2] / 2, x[:, 3] / 2
    out = x.copy()
    out[:, 0] = cx - half_w
    out[:, 1] = cy - half_h
    out[:, 2] = cx + half_w
    out[:, 3] = cy + half_h
    return out
|
|
|
|
|
|
|
|
def nms(boxes, scores, iou_thres=0.45):
    """Greedy non-maximum suppression over (x1, y1, x2, y2) boxes.

    Boxes are visited from highest to lowest score. Each visited box is kept,
    and every remaining box whose IoU with it is >= ``iou_thres`` is dropped.

    Returns:
        list of indices into ``boxes`` of the kept boxes, in the order they were
        kept (descending score).
    """
    remaining = scores.argsort()[::-1]
    kept = []
    while remaining.size:
        best, rest = remaining[0], remaining[1:]
        kept.append(best)
        if not rest.size:
            break
        overlap = iou(boxes[best], boxes[rest])
        remaining = rest[overlap < iou_thres]
    return kept
|
|
|
|
|
|
|
|
def iou(box, boxes):
    """Intersection-over-union of one box against each row of ``boxes``.

    Boxes are (x1, y1, x2, y2) corners; ``boxes`` is indexed as a 2-D array
    whose first four columns are used. A tiny epsilon keeps the division finite
    when both boxes have zero area.
    """
    bx1, by1, bx2, by2 = box[0], box[1], box[2], box[3]
    inter_w = np.maximum(0, np.minimum(bx2, boxes[:, 2]) - np.maximum(bx1, boxes[:, 0]))
    inter_h = np.maximum(0, np.minimum(by2, boxes[:, 3]) - np.maximum(by1, boxes[:, 1]))
    inter = inter_w * inter_h
    own_area = (bx2 - bx1) * (by2 - by1)
    other_area = (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])
    return inter / (own_area + other_area - inter + 1e-16)
|
|
|
|
|
|
|
|
def scale_boxes(boxes, gain, pad):
    """Map boxes from letterboxed coordinates back to the source image, in place.

    Subtracts the (left, top) padding ``pad`` from the x and y corner columns,
    then divides the first four columns by ``gain``. ``boxes`` itself is
    modified and also returned for convenience.
    """
    dx, dy = pad[0], pad[1]
    boxes[:, 0] -= dx
    boxes[:, 2] -= dx
    boxes[:, 1] -= dy
    boxes[:, 3] -= dy
    boxes[:, :4] /= gain
    return boxes
|
|
|
|
|
|
|
|
def _yolo_preprocess(image_rgb):
    """Letterbox an RGB image to INPUT_SIZE and pack it as a (1, 3, H, W) float32 tensor.

    Returns ``(tensor, gain, pad)``. ``gain`` and ``pad`` are letterbox()'s scale
    factor and (left, top) padding, needed to map boxes back to the source image.
    """
    padded, gain, pad = letterbox(image_rgb, (INPUT_SIZE, INPUT_SIZE))
    # YOLOv7's data loaders convert OpenCV's BGR to RGB before training and
    # inference, so the network expects RGB: the image is fed without swapping.
    tensor = padded.astype(np.float32) / 255.0
    tensor = np.ascontiguousarray(tensor.transpose(2, 0, 1)[np.newaxis])
    return tensor, gain, pad


def _yolo_decode(pred):
    """Split a 2-D model output into ``(boxes_xyxy, scores, labels)``.

    Coordinates stay in network-input (letterboxed) pixels. The row layout is
    chosen from the number of columns:

    - 7 columns: ``[batch_id, x1, y1, x2, y2, class_id, score]``. This is the
      layout of a YOLOv7 ``--end2end`` ONNX export with built-in NMS, per
      YOLOv7's ONNX inference example.
    - 6 columns: ``[x1, y1, x2, y2, score, class_id]``.
    - more than 7 columns: raw head ``[cx, cy, w, h, objectness, class scores...]``.
      The score is objectness x best class score, as in YOLOv7's
      non_max_suppression.

    Raises:
        ValueError: for any other number of columns.
    """
    ncols = pred.shape[1]
    if ncols == 7:
        boxes, labels, scores = pred[:, 1:5], pred[:, 5], pred[:, 6]
    elif ncols == 6:
        boxes, scores, labels = pred[:, 0:4], pred[:, 4], pred[:, 5]
    elif ncols > 7:
        cls_scores = pred[:, 5:]
        boxes = xywh2xyxy(pred[:, 0:4].astype(np.float32))
        labels = cls_scores.argmax(axis=1)
        scores = pred[:, 4] * cls_scores.max(axis=1)
    else:
        raise ValueError(f"Format de sortie du modèle non reconnu: forme {pred.shape}")
    # astype copies, so callers get independent, writable arrays.
    return boxes.astype(np.float32), scores.astype(np.float32), labels.astype(np.int32)


def _yolo_class_nms(boxes, scores, labels, iou_thres):
    """Run nms() separately within each class; return kept row indices, best score first."""
    keep = []
    for cls in np.unique(labels):
        members = np.flatnonzero(labels == cls)
        kept_local = nms(boxes[members], scores[members], iou_thres=iou_thres)
        keep.extend(members[k] for k in kept_local)
    keep = np.asarray(keep, dtype=np.int64)
    return keep[np.argsort(-scores[keep], kind="stable")]


def infer_yolov7(image_rgb, conf_thres=0.25, iou_thres=0.45, only_fracture=True):
    """Detect findings on an RGB radiograph with the YOLOv7 ONNX model.

    Args:
        image_rgb: HxWx3 image in RGB channel order.
        conf_thres: minimum score a detection needs to be kept.
        iou_thres: within one class, a lower-scored box whose IoU with a kept box
            is >= this value is suppressed.
        only_fracture: keep only detections whose class name is "fracture".

    Returns:
        list of dicts with keys ``"box"`` ([x1, y1, x2, y2] in source-image
        pixels, clipped to the image), ``"score"``, ``"class_id"`` and
        ``"class_name"``, sorted by descending score. Empty if nothing passes
        the filters.

    Raises:
        ValueError: if the model output has an unsupported shape.
        Any exception raised by load_session() or by ONNX Runtime during inference.
    """
    h0, w0 = image_rgb.shape[:2]
    tensor, gain, pad = _yolo_preprocess(image_rgb)

    session = load_session()
    input_name = session.get_inputs()[0].name
    output_name = session.get_outputs()[0].name
    pred = np.asarray(session.run([output_name], {input_name: tensor})[0])
    if pred.ndim == 3:  # (batch, rows, cols): keep the rows of our single image
        pred = pred[0]
    if pred.size == 0:
        return []
    if pred.ndim != 2:
        raise ValueError(f"Format de sortie du modèle non reconnu: forme {pred.shape}")

    boxes, scores, labels = _yolo_decode(pred)

    mask = scores >= conf_thres
    boxes, scores, labels = boxes[mask], scores[mask], labels[mask]
    if boxes.shape[0] == 0:
        return []

    # Per-class NMS makes iou_thres effective. For exports that already ran NMS,
    # it only removes extra overlaps when iou_thres is stricter than the export's.
    order = _yolo_class_nms(boxes, scores, labels, iou_thres)
    boxes, scores, labels = boxes[order], scores[order], labels[order]

    # Undo the letterbox: remove the padding, then divide by the resize gain.
    boxes = scale_boxes(boxes, gain, pad)

    dets = []
    for (x1, y1, x2, y2), c, s in zip(boxes.tolist(), labels.tolist(), scores.tolist()):
        name = CLASSES[c] if 0 <= c < len(CLASSES) else str(c)
        if only_fracture and name != "fracture":
            continue
        dets.append({
            "box": [
                float(max(0, min(w0 - 1, x1))),
                float(max(0, min(h0 - 1, y1))),
                float(max(0, min(w0 - 1, x2))),
                float(max(0, min(h0 - 1, y2))),
            ],
            "score": float(s),
            "class_id": int(c),
            "class_name": name,
        })
    return dets
|
|
|
|
|
|
|
|
def draw_detections(image_rgb, dets):
    """Return a copy of ``image_rgb`` with every detection's box and label drawn on it.

    Args:
        image_rgb: HxWx3 image; the colours below are given in RGB order.
        dets: iterable of dicts with ``"box"`` ([x1, y1, x2, y2] pixels),
            ``"class_name"`` and ``"score"`` keys, as returned by infer_yolov7.

    Fractures are outlined in red, every other class in blue. Each box gets a
    filled label "name:score" with white text. The label goes above the box
    when there is room, otherwise just inside its top edge, so boxes that touch
    the top of the image still show a readable label. The input image is not
    modified.
    """
    img = image_rgb.copy()
    img_w = img.shape[1]
    for d in dets:
        x1, y1, x2, y2 = map(int, d["box"])
        name = d["class_name"]
        color = (255, 0, 0) if name == "fracture" else (0, 150, 255)
        cv2.rectangle(img, (x1, y1), (x2, y2), color, 3)

        label = f"{name}:{d['score']:.2f}"
        (tw, th), _ = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, 0.8, 2)

        # Text baseline: 8 px above the box if the whole label fits above it,
        # otherwise inside the box below its top edge.
        baseline_y = y1 - 8
        if baseline_y - th - 6 < 0:
            baseline_y = y1 + th + 8
        # Shift left if needed so the label does not run past the right edge.
        label_x = max(0, min(x1, img_w - tw - 6))

        cv2.rectangle(img, (label_x, baseline_y - th - 6), (label_x + tw + 6, baseline_y + 2), color, -1)
        cv2.putText(img, label, (label_x + 3, baseline_y), cv2.FONT_HERSHEY_SIMPLEX, 0.8, (255, 255, 255), 2)
    return img
|
|
|
|
|
|
|
|
def predict(image, region, conf_thres, iou_thres, show_non_fracture):
    """Gradio callback: run detection on one image and report the result.

    Args:
        image: uploaded image array, or None when nothing was uploaded.
        region: selected anatomical region; echoed back in the JSON.
        conf_thres: confidence threshold forwarded to infer_yolov7.
        iou_thres: NMS IoU threshold forwarded to infer_yolov7.
        show_non_fracture: when false, only "fracture" detections are kept.

    Returns:
        ``(annotated_image, json_text)``. On a missing image or an inference
        error the image is None and the JSON holds a single ``"error"`` key.
    """
    if image is None:
        return None, json.dumps({"error": "Aucune image fournie."}, ensure_ascii=False, indent=2)

    image = ensure_rgb(image)
    only_fracture = not show_non_fracture

    # perf_counter is monotonic, so wall-clock adjustments cannot skew the timing.
    start = time.perf_counter()
    try:
        dets = infer_yolov7(image, conf_thres=conf_thres, iou_thres=iou_thres, only_fracture=only_fracture)
    except Exception as e:
        # UI boundary: show the failure in the JSON panel instead of crashing the callback.
        # Some exceptions have an empty message (e.g. KeyError()); fall back to the type name.
        msg = str(e) or type(e).__name__
        return None, json.dumps({"error": msg}, ensure_ascii=False, indent=2)
    elapsed = time.perf_counter() - start

    annotated = draw_detections(image, dets)
    resp = {
        "region": region,
        "detections": dets,
        "count": len(dets),
        "time_s": round(elapsed, 3),
        "note": "Modèle entraîné sur le poignet (GRAZPEDWRI-DX). Les autres régions sont exploratoires.",
        "medical_warning": "Cet outil n’est pas un dispositif médical. Il ne remplace pas l’avis d’un(e) radiologue/médecin.",
    }
    return annotated, json.dumps(resp, ensure_ascii=False, indent=2)
|
|
|
|
|
|
|
|
def build_ui():
    """Assemble and return the Gradio Blocks app without launching it.

    Layout, top to bottom:
    - an introduction;
    - an input row: the image on the left; on the right the region dropdown,
      confidence and NMS-IoU sliders, the "other classes" checkbox and the
      "Analyser" button;
    - an output row: the annotated image and the JSON details;
    - a closing disclaimer.

    The button calls ``predict`` with the five inputs in that order.
    """
    intro = """
# Détection de fracture (Radiographie) — Prototype
- Interface en français, fonctionnement 100% en ligne.
- Téléversez une radiographie, puis lancez l’analyse.
- Modèle détection (boîtes) entraîné sur le poignet; autres régions = usage exploratoire.
- N’est pas un dispositif médical.
"""
    disclaimer = """
### Avertissement
Cet outil sert d’aide et ne remplace pas un avis médical professionnel.
"""
    wrist_choice = "Poignet (modèle entraîné)"

    with gr.Blocks(title="Détection de fracture (Radiographie)") as demo:
        gr.Markdown(intro)

        with gr.Row():
            with gr.Column(scale=2):
                image_in = gr.Image(type="numpy", label="Téléverser une radiographie")
            with gr.Column(scale=1):
                region_dd = gr.Dropdown(
                    choices=[wrist_choice, "Autre (exploratoire)"],
                    value=wrist_choice,
                    label="Région anatomique",
                )
                conf_slider = gr.Slider(0.05, 0.9, value=CONF_THRES_DEFAULT, step=0.01, label="Seuil de confiance")
                # Named iou_slider so it does not shadow the module-level iou() helper.
                iou_slider = gr.Slider(0.1, 0.9, value=IOU_THRES_DEFAULT, step=0.01, label="Seuil NMS (IoU)")
                show_others = gr.Checkbox(False, label="Afficher aussi les autres classes (non-fracture)")
                run_btn = gr.Button("Analyser", variant="primary")

        with gr.Row():
            image_out = gr.Image(type="numpy", label="Résultat annoté")
            json_out = gr.Code(language="json", label="Détails des détections")

        run_btn.click(
            predict,
            inputs=[image_in, region_dd, conf_slider, iou_slider, show_others],
            outputs=[image_out, json_out],
        )

        gr.Markdown(disclaimer)

    return demo
|
|
|
|
|
|
|
|
# Build the app at import time so `demo` exists as a module attribute even when
# this file is imported rather than executed. Presumably this is for a hosting
# runner that imports the module and serves `demo` itself (the download error
# message mentions a "Space") — assumption, confirm against the deployment.
demo = build_ui()

if __name__ == "__main__":
    # Direct execution: start the Gradio server with default launch settings.
    demo.launch()
|
|
|