magomerob commited on
Commit
301fa4e
·
verified ·
1 Parent(s): 061ddbb

Upload 2 files

Browse files
Files changed (3) hide show
  1. .gitattributes +1 -0
  2. app.py +117 -0
  3. raccoon-101.jpg +3 -0
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ raccoon-101.jpg filter=lfs diff=lfs merge=lfs -text
app.py ADDED
@@ -0,0 +1,117 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from huggingface_hub import from_pretrained_fastai
2
+ import gradio as gr
3
+ from fastai.vision.all import *
4
+ from PIL import Image, ImageDraw, ImageFont
5
+ import numpy as np
6
+
7
+ repo_id = "magomerob/yolo_finetuned_raccoons"
8
+ learner = from_pretrained_fastai(repo_id)
9
+ labels = learner.dls.vocab
10
+
11
+
12
+ def _to_pil(x):
13
+ if isinstance(x, Image.Image):
14
+ return x.convert("RGB")
15
+ if isinstance(x, np.ndarray):
16
+ return Image.fromarray(x).convert("RGB")
17
+ return PILImage.create(x).convert("RGB")
18
+
19
+
20
+ def _get_detections(pred, labels_vocab):
21
+ """
22
+ Returns: boxes_xyxy (list of [x1,y1,x2,y2]), names (list[str]), scores (list[float] or None)
23
+
24
+ Supports a few common formats:
25
+ - dict with keys like boxes/labels/scores
26
+ - tuple/list like (boxes, labels, scores)
27
+ - fastai-like (pred, pred_idx, probs) where pred holds detections
28
+ """
29
+ # fastai often returns (pred, pred_idx, probs)
30
+ if isinstance(pred, (tuple, list)) and len(pred) == 3:
31
+ pred = pred[0]
32
+
33
+ # dict-like output
34
+ if isinstance(pred, dict):
35
+ boxes = pred.get("boxes") or pred.get("bboxes") or pred.get("bbox")
36
+ lab = pred.get("labels") or pred.get("classes") or pred.get("label_ids")
37
+ scores = pred.get("scores") or pred.get("confs") or pred.get("confidences")
38
+ if boxes is None or lab is None:
39
+ raise ValueError(f"Unsupported dict output keys: {list(pred.keys())}")
40
+ boxes = np.asarray(boxes).tolist()
41
+ lab = np.asarray(lab).tolist()
42
+ scores = None if scores is None else np.asarray(scores).tolist()
43
+
44
+ # tuple/list like (boxes, labels, scores?)
45
+ elif isinstance(pred, (tuple, list)) and len(pred) >= 2:
46
+ boxes = np.asarray(pred[0]).tolist()
47
+ lab = np.asarray(pred[1]).tolist()
48
+ scores = None
49
+ if len(pred) >= 3 and pred[2] is not None:
50
+ try:
51
+ scores = np.asarray(pred[2]).tolist()
52
+ except Exception:
53
+ scores = None
54
+ else:
55
+ raise ValueError(f"Unsupported prediction type: {type(pred)}")
56
+
57
+ names = []
58
+ for x in lab:
59
+ try:
60
+ xi = int(x)
61
+ names.append(str(labels_vocab[xi]) if xi < len(labels_vocab) else str(xi))
62
+ except Exception:
63
+ names.append(str(x))
64
+
65
+ return boxes, names, scores
66
+
67
+
68
+ def show_preds(input_image, display_label=True, display_bbox=True, detection_threshold=0.5):
69
+ if detection_threshold == 0:
70
+ detection_threshold = 0.5
71
+
72
+ img = _to_pil(input_image)
73
+ pred = learner.predict(img)
74
+
75
+ boxes, names, scores = _get_detections(pred, labels)
76
+
77
+ draw = ImageDraw.Draw(img)
78
+ try:
79
+ font = ImageFont.truetype("DejaVuSans.ttf", 16)
80
+ except Exception:
81
+ font = ImageFont.load_default()
82
+
83
+ for i, box in enumerate(boxes):
84
+ score = None if scores is None or i >= len(scores) else float(scores[i])
85
+ if score is not None and score < detection_threshold:
86
+ continue
87
+
88
+ x1, y1, x2, y2 = [int(round(v)) for v in box]
89
+
90
+ if display_bbox:
91
+ draw.rectangle([x1, y1, x2, y2], width=3)
92
+
93
+ if display_label:
94
+ label = names[i]
95
+ text = label + (f" {score:.2f}" if score is not None else "")
96
+ # small filled background for readability
97
+ tw, th = draw.textbbox((0, 0), text, font=font)[2:]
98
+ pad = 3
99
+ draw.rectangle([x1, max(0, y1 - th - 2 * pad), x1 + tw + 2 * pad, y1], fill="black")
100
+ draw.text((x1 + pad, max(0, y1 - th - pad)), text, font=font, fill="white")
101
+
102
+ return img
103
+
104
+
105
+ display_chkbox_label = gr.Checkbox(label="Label", value=True)
106
+ display_chkbox_box = gr.Checkbox(label="Box", value=True)
107
+ detection_threshold_slider = gr.Slider(minimum=0, maximum=1, step=0.05, value=0.5, label="Detection Threshold")
108
+
109
+ demo = gr.Interface(
110
+ fn=show_preds,
111
+ inputs=[gr.Image(type="numpy"), display_chkbox_label, display_chkbox_box, detection_threshold_slider],
112
+ outputs=gr.Image(type="pil"),
113
+ examples=[["raccoon-101.jpg", True, True, 0.5]],
114
+ )
115
+
116
+ if __name__ == "__main__":
117
+ demo.launch()
raccoon-101.jpg ADDED

Git LFS Details

  • SHA256: 96de14aec2c74be4c5544c32a7dfc9902ad230d362afbce6b624fa44f8591aeb
  • Pointer size: 131 Bytes
  • Size of remote file: 125 kB