adcelis commited on
Commit
3955743
·
verified ·
1 Parent(s): 8efe00d

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +119 -0
app.py ADDED
@@ -0,0 +1,119 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import gradio as gr
3
+ import torch
4
+ from torchvision.ops import nms
5
+ from PIL import Image, ImageDraw
6
+
7
+ from ultralytics import YOLO
8
+ from transformers import AutoImageProcessor, AutoModelForObjectDetection
9
+ from huggingface_hub import hf_hub_download
10
+
11
+ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
12
+
13
+ # ---- CAMBIA ESTO ----
14
+ YOLO_REPO_ID = "adcelis/TU_REPO_YOLOV8" # repo donde subiste best.pt
15
+ YOLO_FILENAME = "best.pt" # nombre del archivo en el repo
16
+
17
+ DETR_REPO_ID = "adcelis/detr_finetuned_raccoon" # tu repo DETR
18
+ # ----------------------
19
+
20
+ # Si tu repo es privado, crea un secret HF_TOKEN en el Space y descomenta:
21
+ # HF_TOKEN = os.environ.get("HF_TOKEN")
22
+
23
+ yolo_path = hf_hub_download(repo_id=YOLO_REPO_ID, filename=YOLO_FILENAME) # , token=HF_TOKEN
24
+ yolo_model = YOLO(yolo_path)
25
+
26
+ detr_processor = AutoImageProcessor.from_pretrained(DETR_REPO_ID) # , token=HF_TOKEN
27
+ detr_model = AutoModelForObjectDetection.from_pretrained(DETR_REPO_ID).to(DEVICE) # , token=HF_TOKEN
28
+ detr_model.eval()
29
+
30
+
31
+ def yolo_predict(pil_img, conf=0.25):
32
+ res = yolo_model.predict(pil_img, conf=conf, verbose=False)[0]
33
+ boxes = res.boxes.xyxy.cpu()
34
+ scores = res.boxes.conf.cpu()
35
+ labels = res.boxes.cls.cpu().long()
36
+ names = res.names # dict id->label
37
+ return boxes, scores, labels, names
38
+
39
+
40
+ @torch.no_grad()
41
+ def detr_predict(pil_img, conf=0.5):
42
+ inputs = detr_processor(images=[pil_img], return_tensors="pt").to(DEVICE)
43
+ outputs = detr_model(**inputs)
44
+ target_sizes = torch.tensor([[pil_img.size[1], pil_img.size[0]]], device=DEVICE)
45
+ results = detr_processor.post_process_object_detection(outputs, threshold=conf, target_sizes=target_sizes)[0]
46
+ return results["boxes"].cpu(), results["scores"].cpu(), results["labels"].cpu()
47
+
48
+
49
+ def ensemble_union_nms(boxes1, scores1, labels1, boxes2, scores2, labels2,
50
+ w2=0.8, iou_thr=0.5, score_thr=0.25):
51
+ boxes = torch.cat([boxes1, boxes2], dim=0)
52
+ scores = torch.cat([scores1, scores2 * w2], dim=0)
53
+ labels = torch.cat([labels1, labels2], dim=0)
54
+
55
+ keep = scores >= score_thr
56
+ boxes, scores, labels = boxes[keep], scores[keep], labels[keep]
57
+
58
+ if boxes.numel() == 0:
59
+ return boxes, scores, labels
60
+
61
+ keep_all = []
62
+ for cls in labels.unique():
63
+ idx = torch.where(labels == cls)[0]
64
+ k = nms(boxes[idx], scores[idx], iou_thr)
65
+ keep_all.append(idx[k])
66
+ keep_all = torch.cat(keep_all)
67
+ keep_all = keep_all[scores[keep_all].argsort(descending=True)]
68
+ return boxes[keep_all], scores[keep_all], labels[keep_all]
69
+
70
+
71
+ def draw_boxes(pil_img, boxes, scores, labels, names):
72
+ img = pil_img.copy()
73
+ draw = ImageDraw.Draw(img)
74
+ for b, s, l in zip(boxes, scores, labels):
75
+ x1, y1, x2, y2 = [float(x) for x in b.tolist()]
76
+ draw.rectangle((x1, y1, x2, y2), outline="green", width=2)
77
+ label = names.get(int(l), str(int(l)))
78
+ draw.text((x1, y1), f"{label} {float(s):.2f}", fill="black")
79
+ return img
80
+
81
+
82
+ def run(pil_img, yolo_conf, detr_conf, w2, iou_thr, score_thr):
83
+ pil_img = pil_img.convert("RGB")
84
+
85
+ b1, s1, l1, names = yolo_predict(pil_img, conf=yolo_conf)
86
+ b2, s2, l2 = detr_predict(pil_img, conf=detr_conf)
87
+
88
+ be, se, le = ensemble_union_nms(b1, s1, l1, b2, s2, l2,
89
+ w2=w2, iou_thr=iou_thr, score_thr=score_thr)
90
+
91
+ out_img = draw_boxes(pil_img, be, se, le, names)
92
+
93
+ rows = []
94
+ for b, s, l in zip(be, se, le):
95
+ x1, y1, x2, y2 = [round(float(x), 2) for x in b.tolist()]
96
+ rows.append([names.get(int(l), str(int(l))), round(float(s), 3), x1, y1, x2, y2])
97
+
98
+ return out_img, rows
99
+
100
+
101
+ demo = gr.Interface(
102
+ fn=run,
103
+ inputs=[
104
+ gr.Image(type="pil", label="Imagen"),
105
+ gr.Slider(0.05, 0.9, value=0.25, step=0.05, label="YOLO conf"),
106
+ gr.Slider(0.05, 0.9, value=0.5, step=0.05, label="DETR conf"),
107
+ gr.Slider(0.1, 1.5, value=0.8, step=0.05, label="Peso DETR (w2)"),
108
+ gr.Slider(0.1, 0.9, value=0.5, step=0.05, label="NMS IoU"),
109
+ gr.Slider(0.05, 0.9, value=0.25, step=0.05, label="Score mínimo (post-ensemble)"),
110
+ ],
111
+ outputs=[
112
+ gr.Image(type="pil", label="Ensemble (NMS)"),
113
+ gr.Dataframe(headers=["label", "score", "x1", "y1", "x2", "y2"], label="Detecciones"),
114
+ ],
115
+ title="Ensemble YOLOv8 + DETR con Non-Maximum Suppression",
116
+ )
117
+
118
+ if __name__ == "__main__":
119
+ demo.launch()