Nadun102 commited on
Commit
f94528a
·
verified ·
1 Parent(s): 64fb345

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +15 -32
app.py CHANGED
@@ -4,12 +4,12 @@ from transformers import Owlv2Processor, Owlv2ForObjectDetection
4
  import spaces
5
 
6
  # --------------------------
7
- # Device setup
8
  # --------------------------
9
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
10
 
11
  # --------------------------
12
- # Load OWLv2 model
13
  # --------------------------
14
  model = Owlv2ForObjectDetection.from_pretrained(
15
  "google/owlv2-base-patch16-ensemble"
@@ -20,35 +20,31 @@ processor = Owlv2Processor.from_pretrained(
20
  )
21
 
22
  # --------------------------
23
- # Detection function
24
  # --------------------------
25
  @spaces.GPU
26
  def query_image(img, text_queries, score_threshold):
27
 
28
- # Convert query string to list
29
  text_queries = [q.strip() for q in text_queries.split(",")]
30
 
31
- # FIX: Use actual image size
32
  h, w = img.shape[:2]
33
  target_sizes = torch.tensor([[h, w]])
34
 
35
- # Prepare inputs
36
  inputs = processor(
37
  text=text_queries,
38
  images=img,
39
  return_tensors="pt"
40
  ).to(device)
41
 
42
- # Run model
43
  with torch.no_grad():
44
  outputs = model(**inputs)
45
 
46
- # Move outputs to CPU
47
  outputs.logits = outputs.logits.cpu()
48
  outputs.pred_boxes = outputs.pred_boxes.cpu()
49
 
50
- # Post-process predictions
51
- results = processor.post_process_object_detection(
52
  outputs=outputs,
53
  target_sizes=target_sizes
54
  )
@@ -67,7 +63,7 @@ def query_image(img, text_queries, score_threshold):
67
  x1, y1, x2, y2 = box.tolist()
68
 
69
  detections.append({
70
- "box": [round(x1, 2), round(y1, 2), round(x2, 2), round(y2, 2)],
71
  "label": text_queries[label.item()],
72
  "score": round(float(score), 3)
73
  })
@@ -76,35 +72,22 @@ def query_image(img, text_queries, score_threshold):
76
 
77
 
78
  # --------------------------
79
- # Gradio UI
80
  # --------------------------
81
  demo = gr.Interface(
82
  fn=query_image,
83
  inputs=[
84
- gr.Image(type="numpy", label="Upload Image"),
85
- gr.Textbox(
86
- label="Enter objects (comma separated)",
87
- value="person, car, dog"
88
- ),
89
- gr.Slider(
90
- minimum=0,
91
- maximum=1,
92
- value=0.2,
93
- step=0.01,
94
- label="Score Threshold"
95
- )
96
  ],
97
- outputs=gr.AnnotatedImage(label="Detection Results"),
98
- title="OWLv2 Zero-Shot Object Detection",
99
- description=(
100
- "Upload an image and type objects to detect.\n\n"
101
- "Example: 'person, car, dog'\n\n"
102
- "Tip: Use natural phrases like 'photo of a car' for better results."
103
- )
104
  )
105
 
106
  # --------------------------
107
- # Run app
108
  # --------------------------
109
  if __name__ == "__main__":
110
  demo.launch()
 
4
  import spaces
5
 
6
  # --------------------------
7
+ # Device
8
  # --------------------------
9
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
10
 
11
  # --------------------------
12
+ # Load model
13
  # --------------------------
14
  model = Owlv2ForObjectDetection.from_pretrained(
15
  "google/owlv2-base-patch16-ensemble"
 
20
  )
21
 
22
  # --------------------------
23
+ # Detection
24
  # --------------------------
25
  @spaces.GPU
26
  def query_image(img, text_queries, score_threshold):
27
 
 
28
  text_queries = [q.strip() for q in text_queries.split(",")]
29
 
30
+ # Correct size
31
  h, w = img.shape[:2]
32
  target_sizes = torch.tensor([[h, w]])
33
 
 
34
  inputs = processor(
35
  text=text_queries,
36
  images=img,
37
  return_tensors="pt"
38
  ).to(device)
39
 
 
40
  with torch.no_grad():
41
  outputs = model(**inputs)
42
 
 
43
  outputs.logits = outputs.logits.cpu()
44
  outputs.pred_boxes = outputs.pred_boxes.cpu()
45
 
46
+ # FIXED FUNCTION NAME
47
+ results = processor.post_process_grounded_object_detection(
48
  outputs=outputs,
49
  target_sizes=target_sizes
50
  )
 
63
  x1, y1, x2, y2 = box.tolist()
64
 
65
  detections.append({
66
+ "box": [round(x1,2), round(y1,2), round(x2,2), round(y2,2)],
67
  "label": text_queries[label.item()],
68
  "score": round(float(score), 3)
69
  })
 
72
 
73
 
74
  # --------------------------
75
+ # UI
76
  # --------------------------
77
  demo = gr.Interface(
78
  fn=query_image,
79
  inputs=[
80
+ gr.Image(type="numpy"),
81
+ gr.Textbox(value="person, car, dog"),
82
+ gr.Slider(0, 1, value=0.2)
 
 
 
 
 
 
 
 
 
83
  ],
84
+ outputs=gr.AnnotatedImage(),
85
+ title="OWLv2 Detection",
86
+ description="Enter objects like: person, car, dog"
 
 
 
 
87
  )
88
 
89
  # --------------------------
90
+ # Run
91
  # --------------------------
92
  if __name__ == "__main__":
93
  demo.launch()