Nadun102 committed on
Commit
83db32b
·
verified ·
1 Parent(s): 4bd4f2d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +85 -50
app.py CHANGED
@@ -1,12 +1,14 @@
1
  import torch
2
  import gradio as gr
 
3
  import cv2
4
  from transformers import Owlv2Processor, Owlv2ForObjectDetection
5
 
6
- # Device
 
 
7
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
8
 
9
- # Load model
10
  model = Owlv2ForObjectDetection.from_pretrained(
11
  "google/owlv2-base-patch16-ensemble"
12
  ).to(device)
@@ -15,85 +17,118 @@ processor = Owlv2Processor.from_pretrained(
15
  "google/owlv2-base-patch16-ensemble"
16
  )
17
 
18
- # ------------------------------
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19
  # MAIN FUNCTION
20
- # ------------------------------
21
  def query_image(img, text_queries, score_threshold):
22
 
23
- # Convert text input
24
- queries = [q.strip() for q in text_queries.split(",")]
25
 
26
- # Get image size
27
- h, w = img.shape[:2]
28
- target_sizes = torch.tensor([[h, w]])
29
 
30
- # Preprocess
31
  inputs = processor(
32
- text=queries,
33
  images=img,
34
  return_tensors="pt"
35
  ).to(device)
36
 
37
- # Inference
38
  with torch.no_grad():
39
  outputs = model(**inputs)
40
 
41
- # Move to CPU
42
- outputs.logits = outputs.logits.cpu()
43
- outputs.pred_boxes = outputs.pred_boxes.cpu()
44
 
45
- # ✅ CORRECT FUNCTION
46
  results = processor.post_process_grounded_object_detection(
47
  outputs=outputs,
48
- target_sizes=target_sizes,
49
- threshold=score_threshold
50
- )
51
 
52
- boxes = results[0]["boxes"]
53
- scores = results[0]["scores"]
54
- labels = results[0]["labels"]
55
 
56
- annotated_labels = []
57
 
58
  # Draw boxes
59
  for box, score, label in zip(boxes, scores, labels):
 
 
60
 
61
- x1, y1, x2, y2 = [int(i) for i in box.tolist()]
62
-
63
- class_name = queries[label.item()]
64
- confidence = float(score)
65
 
66
- # Label text
67
- text = f"{class_name} ({confidence:.2f})"
 
 
 
 
68
 
69
  # Draw on image
70
  cv2.rectangle(img, (x1, y1), (x2, y2), (0,255,0), 2)
71
- cv2.putText(img, text, (x1, y1-10),
72
- cv2.FONT_HERSHEY_SIMPLEX, 0.5,
73
- (0,255,0), 2)
74
-
75
- # ✅ IMPORTANT: Only (box, label)
76
- annotated_labels.append((
77
- [x1, y1, x2, y2],
78
- text
79
- ))
80
-
81
- return img, annotated_labels
82
-
83
-
84
- # ------------------------------
85
- # UI
86
- # ------------------------------
87
  demo = gr.Interface(
88
  fn=query_image,
89
  inputs=[
90
  gr.Image(type="numpy"),
91
- gr.Textbox(label="Objects (comma separated)"),
92
- gr.Slider(0, 1, value=0.2, label="Confidence Threshold")
 
 
 
 
93
  ],
94
- outputs=gr.AnnotatedImage(),
95
- title="OWLv2 Object Detection (Fixed)",
96
  )
97
 
98
- # Launch
99
  demo.launch()
 
1
  import torch
2
  import gradio as gr
3
+ import numpy as np
4
  import cv2
5
  from transformers import Owlv2Processor, Owlv2ForObjectDetection
6
 
7
+ # ===============================
8
+ # DEVICE
9
+ # ===============================
10
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
11
 
 
12
  model = Owlv2ForObjectDetection.from_pretrained(
13
  "google/owlv2-base-patch16-ensemble"
14
  ).to(device)
 
17
  "google/owlv2-base-patch16-ensemble"
18
  )
19
 
20
# ===============================
# YOUR PREPROCESSING
# ===============================
def advanced_preprocessing(img_array: np.ndarray,
                           crop_ratio=(0.25, 0.75),
                           target_size=(512, 512),
                           grayscale=True,
                           tile=(1, 1)):
    """Crop, resize, optionally desaturate, contrast-stretch and tile an image.

    Args:
        img_array: Image array shaped ``(H, W)``, ``(H, W, 1)``, ``(H, W, 3)``
            or ``(H, W, 4)``. Single-channel input is replicated to three
            channels and a fourth channel is dropped, so the rest of the
            pipeline always works on three channels (treated as RGB).
        crop_ratio: ``(start, end)`` fractions; the same pair is applied to
            the width and to the height to select the crop window.
        target_size: Size handed to ``cv2.resize`` as its ``dsize`` argument.
        grayscale: When true, the resized crop is converted to gray and back
            to three identical channels.
        tile: ``(rows, cols)`` repetition counts for ``np.tile``; ``(1, 1)``
            leaves the image untiled.

    Returns:
        The processed three-channel image as a new array.

    Raises:
        ValueError: If ``img_array`` has an unsupported shape or the crop
            window selects no pixels.
    """
    img_array = np.asarray(img_array)

    # Bring every supported layout to exactly three channels up front so the
    # colour conversion and the per-channel stretch below cannot fail on it.
    if img_array.ndim == 2:
        img_array = np.stack([img_array] * 3, axis=-1)
    elif img_array.ndim == 3 and img_array.shape[2] == 1:
        img_array = np.repeat(img_array, 3, axis=2)
    elif img_array.ndim == 3 and img_array.shape[2] == 4:
        img_array = img_array[:, :, :3]
    elif img_array.ndim != 3 or img_array.shape[2] != 3:
        raise ValueError(
            f"Expected an image with 1, 3 or 4 channels, got shape {img_array.shape}"
        )

    h, w = img_array.shape[:2]

    x1, x2 = int(crop_ratio[0] * w), int(crop_ratio[1] * w)
    y1, y2 = int(crop_ratio[0] * h), int(crop_ratio[1] * h)

    img_cropped = img_array[y1:y2, x1:x2]
    if img_cropped.size == 0:
        # Fail with a clear message instead of an opaque error from cv2.resize.
        raise ValueError(
            f"crop_ratio={crop_ratio} selects an empty region of a {w}x{h} image"
        )

    # The crop is a strided view; hand OpenCV a contiguous buffer.
    img_resized = cv2.resize(np.ascontiguousarray(img_cropped), target_size)

    if grayscale:
        gray = cv2.cvtColor(img_resized, cv2.COLOR_RGB2GRAY)
        img_resized = cv2.cvtColor(gray, cv2.COLOR_GRAY2RGB)

    # Min-max stretch each channel independently to the 0..255 range.
    img_stretch = np.zeros_like(img_resized)
    for c in range(img_resized.shape[2]):
        img_stretch[:, :, c] = cv2.normalize(
            img_resized[:, :, c], None, 0, 255, cv2.NORM_MINMAX
        )

    # tuple() so that an equivalent list such as [1, 1] also means "no tiling".
    if tuple(tile) != (1, 1):
        img_stretch = np.tile(img_stretch, (tile[0], tile[1], 1))

    return img_stretch
51
+
52
+
53
# ===============================
# MAIN FUNCTION
# ===============================
def query_image(img, text_queries, score_threshold):
    """Detect the queried classes on the preprocessed image and annotate it.

    Args:
        img: Input image array, or ``None`` when no image was supplied.
        text_queries: Comma-separated class names; surrounding whitespace is
            stripped and blank entries are ignored.
        score_threshold: Minimum detection score to keep.

    Returns:
        A ``(image, detections)`` pair. ``image`` is the preprocessed image
        with boxes and labels drawn on it (``None`` when no image was given).
        ``detections`` is a list of ``{"box", "label", "score"}`` dicts whose
        box coordinates refer to the preprocessed image, not the original.
    """
    if img is None:
        # Nothing to run detection on.
        return None, []

    # Drop blank entries so "cat, ,dog" or a trailing comma does not send an
    # empty text query to the model.
    queries = [q.strip() for q in (text_queries or "").split(",")]
    queries = [q for q in queries if q]

    # preprocess
    img = advanced_preprocessing(img)
    if not queries:
        return img, []

    inputs = processor(
        text=queries,
        images=img,
        return_tensors="pt",
    ).to(device)

    with torch.no_grad():
        outputs = model(**inputs)

    # Boxes are scaled to the image that was actually fed to the model, i.e.
    # the preprocessed one, because that is also the image being drawn on.
    target_sizes = torch.tensor([img.shape[:2]])

    # Forward the requested threshold: without it the post-processor applies
    # its own default first, so lower slider values would have no effect.
    results = processor.post_process_grounded_object_detection(
        outputs=outputs,
        target_sizes=target_sizes,
        threshold=score_threshold,
    )[0]

    output_data = []

    # Draw boxes
    for box, score, label in zip(
        results["boxes"], results["scores"], results["labels"]
    ):
        x1, y1, x2, y2 = map(int, box.tolist())
        class_name = queries[label.item()]
        conf = float(score)

        # Save structured output
        output_data.append({
            "box": [x1, y1, x2, y2],
            "label": class_name,
            "score": round(conf, 3),
        })

        # Draw on image
        cv2.rectangle(img, (x1, y1), (x2, y2), (0, 255, 0), 2)
        cv2.putText(
            img,
            f"{class_name} {conf:.2f}",
            # Keep the caption inside the image for boxes touching the top edge.
            (x1, max(y1 - 5, 12)),
            cv2.FONT_HERSHEY_SIMPLEX,
            0.5,
            (0, 255, 0),
            2,
        )

    return img, output_data
115
+
116
+
117
# ===============================
# GRADIO UI
# ===============================
# Inputs map positionally onto query_image(img, text_queries, score_threshold);
# outputs map onto its (annotated image, detections) return pair.
demo = gr.Interface(
    fn=query_image,
    inputs=[
        gr.Image(type="numpy"),
        gr.Textbox(label="Classes (comma separated)"),
        gr.Slider(minimum=0, maximum=1, value=0.2),
    ],
    outputs=[
        gr.Image(label="Result"),
        gr.JSON(label="Detections"),
    ],
    title="Correct Bounding Box Detection (OWLv2)",
)

# Start the web app.
demo.launch()