Nadun102 commited on
Commit
10f0dc2
·
verified ·
1 Parent(s): cff8b1e

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +110 -0
app.py ADDED
@@ -0,0 +1,110 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import gradio as gr
3
+ from transformers import Owlv2Processor, Owlv2ForObjectDetection
4
+ import spaces
5
+
6
+ # --------------------------
7
+ # Device setup
8
+ # --------------------------
9
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
10
+
11
+ # --------------------------
12
+ # Load OWLv2 model
13
+ # --------------------------
14
+ model = Owlv2ForObjectDetection.from_pretrained(
15
+ "google/owlv2-base-patch16-ensemble"
16
+ ).to(device)
17
+
18
+ processor = Owlv2Processor.from_pretrained(
19
+ "google/owlv2-base-patch16-ensemble"
20
+ )
21
+
22
+ # --------------------------
23
+ # Detection function
24
+ # --------------------------
25
+ @spaces.GPU
26
+ def query_image(img, text_queries, score_threshold):
27
+
28
+ # Convert query string to list
29
+ text_queries = [q.strip() for q in text_queries.split(",")]
30
+
31
+ # ✅ FIX: Use actual image size
32
+ h, w = img.shape[:2]
33
+ target_sizes = torch.tensor([[h, w]])
34
+
35
+ # Prepare inputs
36
+ inputs = processor(
37
+ text=text_queries,
38
+ images=img,
39
+ return_tensors="pt"
40
+ ).to(device)
41
+
42
+ # Run model
43
+ with torch.no_grad():
44
+ outputs = model(**inputs)
45
+
46
+ # Move outputs to CPU
47
+ outputs.logits = outputs.logits.cpu()
48
+ outputs.pred_boxes = outputs.pred_boxes.cpu()
49
+
50
+ # Post-process predictions
51
+ results = processor.post_process_object_detection(
52
+ outputs=outputs,
53
+ target_sizes=target_sizes
54
+ )
55
+
56
+ boxes = results[0]["boxes"]
57
+ scores = results[0]["scores"]
58
+ labels = results[0]["labels"]
59
+
60
+ detections = []
61
+
62
+ for box, score, label in zip(boxes, scores, labels):
63
+
64
+ if score < score_threshold:
65
+ continue
66
+
67
+ x1, y1, x2, y2 = box.tolist()
68
+
69
+ detections.append({
70
+ "box": [round(x1, 2), round(y1, 2), round(x2, 2), round(y2, 2)],
71
+ "label": text_queries[label.item()],
72
+ "score": round(float(score), 3)
73
+ })
74
+
75
+ return img, detections
76
+
77
+
78
+ # --------------------------
79
+ # Gradio UI
80
+ # --------------------------
81
+ demo = gr.Interface(
82
+ fn=query_image,
83
+ inputs=[
84
+ gr.Image(type="numpy", label="Upload Image"),
85
+ gr.Textbox(
86
+ label="Enter objects (comma separated)",
87
+ value="person, car, dog"
88
+ ),
89
+ gr.Slider(
90
+ minimum=0,
91
+ maximum=1,
92
+ value=0.2,
93
+ step=0.01,
94
+ label="Score Threshold"
95
+ )
96
+ ],
97
+ outputs=gr.AnnotatedImage(label="Detection Results"),
98
+ title="OWLv2 Zero-Shot Object Detection",
99
+ description=(
100
+ "Upload an image and type objects to detect.\n\n"
101
+ "Example: 'person, car, dog'\n\n"
102
+ "Tip: Use natural phrases like 'photo of a car' for better results."
103
+ )
104
+ )
105
+
106
+ # --------------------------
107
+ # Run app
108
+ # --------------------------
109
+ if __name__ == "__main__":
110
+ demo.launch()