Commit 5a0ba26 by yifehuang97 · Parent: 64a43aa

(feat) semantic post processing

Files changed (2):
  1. app.py +50 -69
  2. utils.py +32 -0
app.py CHANGED
@@ -4,7 +4,7 @@ import torch
 from PIL import Image, ImageDraw
 from transformers import GroundingDinoProcessor
 from hf_model import CountEX
-from utils import post_process_grounded_object_detection
+from utils import post_process_grounded_object_detection, post_process_grounded_object_detection_with_queries
 
 # Global variables for model and processor
 model = None
@@ -75,25 +75,27 @@ def filter_points_by_negative(points, neg_points, image_size, pixel_threshold=5)
 
     return filtered_points, filtered_indices
 
+
+import numpy as np
+
 def discriminative_point_suppression(
     points,
     neg_points,
-    pos_queries,
-    neg_queries,
+    pos_queries,  # (N, D) numpy array
+    neg_queries,  # (M, D) numpy array
     image_size,
     pixel_threshold=5,
-    similarity_threshold=0.5,
-    mode="and"
+    similarity_threshold=0.3,
 ):
     """
     Discriminative Point Suppression (DPS):
-    Suppress positive predictions that are both spatially close to
-    AND semantically similar with negative predictions.
 
-    Motivation: Spatial proximity alone may cause false suppression when
-    positive and negative queries represent different semantic concepts.
-    By jointly verifying spatial AND semantic alignment, we ensure
-    suppression only occurs for true conflicts.
+    Step 1: Find spatially closest negative point for each positive point
+    Step 2: If distance < pixel_threshold, check query similarity
+    Step 3: Suppress only if query similarity > similarity_threshold
+
+    This two-stage design ensures suppression only when predictions are
+    both spatially overlapping AND semantically conflicting.
 
     Args:
         points: List of [x, y] positive points (normalized, 0-1)
@@ -102,13 +104,12 @@ def discriminative_point_suppression(
         neg_queries: (M, D) query embeddings for negative predictions
         image_size: (width, height) in pixels
         pixel_threshold: spatial distance threshold in pixels
-        similarity_threshold: cosine similarity threshold for semantic match
-        mode: "and" for hard joint condition, "weighted" for soft combination
+        similarity_threshold: cosine similarity threshold for semantic conflict
 
     Returns:
         filtered_points: points after suppression
         filtered_indices: indices of kept points
-        suppression_info: dict with detailed suppression decisions (for analysis)
+        suppression_info: dict with detailed suppression decisions
     """
     if not neg_points or not points:
         return points, list(range(len(points))), {}
@@ -116,74 +117,53 @@ def discriminative_point_suppression(
     width, height = image_size
     N, M = len(points), len(neg_points)
 
-    # === Spatial Distance ===
+    # === Step 1: Spatial Matching ===
    points_arr = np.array(points) * np.array([width, height])  # (N, 2)
    neg_points_arr = np.array(neg_points) * np.array([width, height])  # (M, 2)
 
+    # Compute pairwise distances
     spatial_dist = np.linalg.norm(
         points_arr[:, None, :] - neg_points_arr[None, :, :], axis=-1
     )  # (N, M)
 
-    # === Query Similarity (Cosine) ===
+    # Find nearest negative for each positive
+    nearest_neg_idx = spatial_dist.argmin(axis=1)  # (N,)
+    nearest_neg_dist = spatial_dist.min(axis=1)  # (N,)
+
+    # Check spatial condition
+    spatially_close = nearest_neg_dist < pixel_threshold  # (N,)
+
+    # === Step 2: Query Similarity Check (only for spatially close pairs) ===
     # Normalize queries
     pos_q = pos_queries / (np.linalg.norm(pos_queries, axis=-1, keepdims=True) + 1e-8)
     neg_q = neg_queries / (np.linalg.norm(neg_queries, axis=-1, keepdims=True) + 1e-8)
 
-    query_sim = np.dot(pos_q, neg_q.T)  # (N, M), range [-1, 1]
+    # Compute similarity only for matched pairs
+    matched_neg_q = neg_q[nearest_neg_idx]  # (N, D)
+    query_sim = (pos_q * matched_neg_q).sum(axis=-1)  # (N,) cosine similarity
 
-    # === Joint Suppression Decision ===
-    if mode == "and":
-        # Hard condition: suppress only if BOTH spatially close AND semantically similar
-        spatial_close = spatial_dist < pixel_threshold  # (N, M)
-        semantic_similar = query_sim > similarity_threshold  # (N, M)
-
-        # A positive is suppressed if ANY negative satisfies both conditions
-        should_suppress = (spatial_close & semantic_similar).any(axis=1)  # (N,)
-
-    elif mode == "weighted":
-        # Soft combination: weighted score
-        # Convert distance to proximity score (0-1, higher = closer)
-        spatial_proximity = np.exp(-spatial_dist / pixel_threshold)  # (N, M)
-
-        # Normalize similarity to [0, 1]
-        semantic_score = (query_sim + 1) / 2  # (N, M)
-
-        # Combined suppression score
-        suppression_score = spatial_proximity * semantic_score  # (N, M)
-        max_suppression = suppression_score.max(axis=1)  # (N,)
-
-        should_suppress = max_suppression > similarity_threshold
-
-    else:
-        raise ValueError(f"Unknown mode: {mode}")
+    # Check semantic condition
+    semantically_similar = query_sim > similarity_threshold  # (N,)
+
+    # === Step 3: Joint Decision ===
+    # Suppress only if BOTH conditions are met
+    should_suppress = spatially_close & semantically_similar  # (N,)
 
     # === Filter ===
     keep_mask = ~should_suppress
     filtered_points = np.array(points)[keep_mask].tolist()
     filtered_indices = np.where(keep_mask)[0].tolist()
 
-    # === Suppression Info (for analysis/visualization) ===
+    # === Suppression Info ===
     suppression_info = {
-        "spatial_dist": spatial_dist,
-        "query_similarity": query_sim,
+        "nearest_neg_idx": nearest_neg_idx.tolist(),
+        "nearest_neg_dist": nearest_neg_dist.tolist(),
+        "query_similarity": query_sim.tolist(),
+        "spatially_close": spatially_close.tolist(),
+        "semantically_similar": semantically_similar.tolist(),
         "suppressed_indices": np.where(should_suppress)[0].tolist(),
-        "suppressed_reasons": []
     }
 
-    # Record why each point was suppressed
-    for i in np.where(should_suppress)[0]:
-        if mode == "and":
-            matching_negs = np.where(spatial_close[i] & semantic_similar[i])[0]
-        else:
-            matching_negs = [suppression_score[i].argmax()]
-
-        suppression_info["suppressed_reasons"].append({
-            "pos_idx": int(i),
-            "matched_neg_idx": matching_negs.tolist() if isinstance(matching_negs, np.ndarray) else matching_negs,
-            "min_spatial_dist": float(spatial_dist[i].min()),
-            "max_query_sim": float(query_sim[i].max())
-        })
-
     return filtered_points, filtered_indices, suppression_info
 
 def count_objects(image, pos_caption, neg_caption, box_threshold, point_radius, point_color):
@@ -259,7 +239,13 @@ def count_objects(image, pos_caption, neg_caption, box_threshold, point_radius, point_color):
     outputs["pred_logits"] = outputs["logits"]
 
     threshold = box_threshold if box_threshold > 0 else model.box_threshold
-    results = post_process_grounded_object_detection(outputs, box_threshold=threshold)[0]
+    pos_queries = outputs["pos_queries"].squeeze(0).float()
+    neg_queries = outputs["neg_queries"].squeeze(0).float()
+    pos_queries = pos_queries[-1].squeeze(0)
+    neg_queries = neg_queries[-1].squeeze(0)
+    pos_queries = pos_queries.cpu().numpy()
+    neg_queries = neg_queries.cpu().numpy()
+    results = post_process_grounded_object_detection_with_queries(outputs, pos_queries, box_threshold=threshold)[0]
 
     boxes = results["boxes"]
     boxes = [box.tolist() for box in boxes]
@@ -273,17 +259,13 @@ def count_objects(image, pos_caption, neg_caption, box_threshold, point_radius, point_color):
     neg_outputs["pred_points"] = outputs["neg_pred_boxes"][:, :, :2]
     neg_outputs["pred_logits"] = outputs["neg_logits"]
 
-    neg_results = post_process_grounded_object_detection(neg_outputs, box_threshold=threshold)[0]
+    neg_results = post_process_grounded_object_detection_with_queries(neg_outputs, neg_queries, box_threshold=threshold)[0]
     neg_boxes = neg_results["boxes"]
     neg_boxes = [box.tolist() for box in neg_boxes]
     neg_points = [[box[0], box[1]] for box in neg_boxes]
 
-    pos_queries = outputs["pos_queries"].squeeze(0).float()
-    neg_queries = outputs["neg_queries"].squeeze(0).float()
-    pos_queries = pos_queries[-1].squeeze(0)
-    neg_queries = neg_queries[-1].squeeze(0)
-    pos_queries = pos_queries.cpu().numpy()
-    neg_queries = neg_queries.cpu().numpy()
+    pos_queries = results["queries"]
+    neg_queries = neg_results["queries"]
 
     img_size = image.size
     # filtered_points, kept_indices = filter_points_by_negative(
@@ -299,8 +281,7 @@ def count_objects(image, pos_caption, neg_caption, box_threshold, point_radius, point_color):
         neg_queries,
         image_size=img_size,
         pixel_threshold=5,
-        similarity_threshold=0.5,
-        mode="and"
+        similarity_threshold=0.25,
     )
 
     filtered_boxes = [boxes[i] for i in kept_indices]
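
Note: the two-stage DPS rule above (nearest-negative spatial matching, then a cosine-similarity gate) can be sanity-checked in isolation. Below is a minimal standalone sketch on toy data; the helper name dps_decision, the points, the 4-D embeddings, and the 1000x1000 image size are all invented for illustration and are not part of this commit. The thresholds mirror the committed defaults (5 px, 0.3).

import numpy as np

def dps_decision(points, neg_points, pos_q, neg_q, image_size,
                 pixel_threshold=5, similarity_threshold=0.3):
    # Step 1: pixel distance from each positive to its nearest negative
    width, height = image_size
    pts = np.array(points) * np.array([width, height])
    neg = np.array(neg_points) * np.array([width, height])
    dist = np.linalg.norm(pts[:, None, :] - neg[None, :, :], axis=-1)  # (N, M)
    nearest = dist.argmin(axis=1)
    spatially_close = dist.min(axis=1) < pixel_threshold
    # Step 2: cosine similarity against the matched negative's embedding
    pos_q = pos_q / (np.linalg.norm(pos_q, axis=-1, keepdims=True) + 1e-8)
    neg_q = neg_q / (np.linalg.norm(neg_q, axis=-1, keepdims=True) + 1e-8)
    sim = (pos_q * neg_q[nearest]).sum(axis=-1)
    # Step 3: suppress only when both conditions hold
    return spatially_close & (sim > similarity_threshold)

# Positive 0 overlaps a negative with a near-identical embedding -> suppressed.
# Positive 1 is just as close spatially but semantically distinct -> kept.
points = [[0.100, 0.100], [0.500, 0.500]]
neg_points = [[0.101, 0.100], [0.501, 0.500]]
pos_q = np.array([[1.0, 0.0, 0.0, 0.0], [0.0, 1.0, 0.0, 0.0]])
neg_q = np.array([[0.9, 0.1, 0.0, 0.0], [0.0, 0.0, 1.0, 0.0]])
print(dps_decision(points, neg_points, pos_q, neg_q, (1000, 1000)))
# [ True False]

This is the point of the change: a pure pixel-threshold filter would have dropped both positives here, while the semantic gate keeps the second one.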
utils.py CHANGED
@@ -45,6 +45,38 @@ def post_process_grounded_object_detection(
 
     return results
 
+def post_process_grounded_object_detection_with_queries(
+    outputs,
+    queries,
+    box_threshold: float = 0.4,
+):
+    """
+    Post-process grounded object detection outputs.
+    Now also returns the query embeddings for each kept prediction.
+    """
+    logits, boxes = outputs.logits, outputs.pred_boxes
+    assert len(logits) == queries.shape[0], "logits and queries must have the same batch size"
+
+    probs = torch.sigmoid(logits)  # (batch_size, num_queries, 256)
+    scores = torch.max(probs, dim=-1)[0]  # (batch_size, num_queries)
+
+    results = []
+    for idx, (s, b, p) in enumerate(zip(scores, boxes, probs)):
+        mask = s > box_threshold
+        score = s[mask]
+        box = b[mask]
+        prob = p[mask]
+
+        result = {"scores": score, "boxes": box}
+
+        # Save the corresponding query embeddings
+        if queries is not None:
+            result["queries"] = queries[idx][mask]  # (num_kept, D)
+            assert len(result["scores"]) == len(result["boxes"]) == result["queries"].shape[0], "scores, boxes and queries must have the same length"
+
+        results.append(result)
+    return results
+
 
 class collator:
     def __init__(self, processor=None, use_negative=True):
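
Note: for completeness, a hedged usage sketch of the new helper on dummy inputs. It assumes queries is a torch tensor with an explicit batch dimension, (batch, num_queries, D), which is what the batch-size assert implies; DummyOutputs and every shape below (10 queries, 256 text tokens, D=64) are placeholders, not real CountEX output.

import torch
from utils import post_process_grounded_object_detection_with_queries

class DummyOutputs:
    # Mimics the two attributes the helper reads from the model output.
    logits = torch.randn(1, 10, 256)    # (batch, num_queries, num_text_tokens)
    pred_boxes = torch.rand(1, 10, 4)   # (batch, num_queries, 4), cxcywh in [0, 1]

queries = torch.randn(1, 10, 64)        # (batch, num_queries, D)
results = post_process_grounded_object_detection_with_queries(
    DummyOutputs(), queries, box_threshold=0.4
)
kept = results[0]
# scores, boxes, and queries stay index-aligned, so kept["queries"] can be fed
# straight into discriminative_point_suppression alongside the kept points.
print(kept["scores"].shape, kept["boxes"].shape, kept["queries"].shape)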