Spaces:
Sleeping
Sleeping
Commit ·
5a0ba26
1
Parent(s): 64a43aa
(feat) semantic post processing
Browse files
app.py
CHANGED
|
@@ -4,7 +4,7 @@ import torch
|
|
| 4 |
from PIL import Image, ImageDraw
|
| 5 |
from transformers import GroundingDinoProcessor
|
| 6 |
from hf_model import CountEX
|
| 7 |
-
from utils import post_process_grounded_object_detection
|
| 8 |
|
| 9 |
# Global variables for model and processor
|
| 10 |
model = None
|
|
@@ -75,25 +75,27 @@ def filter_points_by_negative(points, neg_points, image_size, pixel_threshold=5)
|
|
| 75 |
|
| 76 |
return filtered_points, filtered_indices
|
| 77 |
|
|
|
|
|
|
|
|
|
|
| 78 |
def discriminative_point_suppression(
|
| 79 |
points,
|
| 80 |
neg_points,
|
| 81 |
-
pos_queries,
|
| 82 |
-
neg_queries,
|
| 83 |
image_size,
|
| 84 |
pixel_threshold=5,
|
| 85 |
-
similarity_threshold=0.
|
| 86 |
-
mode="and"
|
| 87 |
):
|
| 88 |
"""
|
| 89 |
Discriminative Point Suppression (DPS):
|
| 90 |
-
Suppress positive predictions that are both spatially close to
|
| 91 |
-
AND semantically similar with negative predictions.
|
| 92 |
|
| 93 |
-
|
| 94 |
-
|
| 95 |
-
|
| 96 |
-
|
|
|
|
|
|
|
| 97 |
|
| 98 |
Args:
|
| 99 |
points: List of [x, y] positive points (normalized, 0-1)
|
|
@@ -102,13 +104,12 @@ def discriminative_point_suppression(
|
|
| 102 |
neg_queries: (M, D) query embeddings for negative predictions
|
| 103 |
image_size: (width, height) in pixels
|
| 104 |
pixel_threshold: spatial distance threshold in pixels
|
| 105 |
-
similarity_threshold: cosine similarity threshold for semantic
|
| 106 |
-
mode: "and" for hard joint condition, "weighted" for soft combination
|
| 107 |
|
| 108 |
Returns:
|
| 109 |
filtered_points: points after suppression
|
| 110 |
filtered_indices: indices of kept points
|
| 111 |
-
suppression_info: dict with detailed suppression decisions
|
| 112 |
"""
|
| 113 |
if not neg_points or not points:
|
| 114 |
return points, list(range(len(points))), {}
|
|
@@ -116,74 +117,53 @@ def discriminative_point_suppression(
|
|
| 116 |
width, height = image_size
|
| 117 |
N, M = len(points), len(neg_points)
|
| 118 |
|
| 119 |
-
# === Spatial
|
| 120 |
points_arr = np.array(points) * np.array([width, height]) # (N, 2)
|
| 121 |
neg_points_arr = np.array(neg_points) * np.array([width, height]) # (M, 2)
|
| 122 |
|
|
|
|
| 123 |
spatial_dist = np.linalg.norm(
|
| 124 |
points_arr[:, None, :] - neg_points_arr[None, :, :], axis=-1
|
| 125 |
) # (N, M)
|
| 126 |
|
| 127 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 128 |
# Normalize queries
|
| 129 |
pos_q = pos_queries / (np.linalg.norm(pos_queries, axis=-1, keepdims=True) + 1e-8)
|
| 130 |
neg_q = neg_queries / (np.linalg.norm(neg_queries, axis=-1, keepdims=True) + 1e-8)
|
| 131 |
|
| 132 |
-
|
|
|
|
|
|
|
| 133 |
|
| 134 |
-
#
|
| 135 |
-
|
| 136 |
-
# Hard condition: suppress only if BOTH spatially close AND semantically similar
|
| 137 |
-
spatial_close = spatial_dist < pixel_threshold # (N, M)
|
| 138 |
-
semantic_similar = query_sim > similarity_threshold # (N, M)
|
| 139 |
-
|
| 140 |
-
# A positive is suppressed if ANY negative satisfies both conditions
|
| 141 |
-
should_suppress = (spatial_close & semantic_similar).any(axis=1) # (N,)
|
| 142 |
-
|
| 143 |
-
elif mode == "weighted":
|
| 144 |
-
# Soft combination: weighted score
|
| 145 |
-
# Convert distance to proximity score (0-1, higher = closer)
|
| 146 |
-
spatial_proximity = np.exp(-spatial_dist / pixel_threshold) # (N, M)
|
| 147 |
-
|
| 148 |
-
# Normalize similarity to [0, 1]
|
| 149 |
-
semantic_score = (query_sim + 1) / 2 # (N, M)
|
| 150 |
-
|
| 151 |
-
# Combined suppression score
|
| 152 |
-
suppression_score = spatial_proximity * semantic_score # (N, M)
|
| 153 |
-
max_suppression = suppression_score.max(axis=1) # (N,)
|
| 154 |
-
|
| 155 |
-
should_suppress = max_suppression > similarity_threshold
|
| 156 |
|
| 157 |
-
|
| 158 |
-
|
|
|
|
| 159 |
|
| 160 |
# === Filter ===
|
| 161 |
keep_mask = ~should_suppress
|
| 162 |
filtered_points = np.array(points)[keep_mask].tolist()
|
| 163 |
filtered_indices = np.where(keep_mask)[0].tolist()
|
| 164 |
|
| 165 |
-
# === Suppression Info
|
| 166 |
suppression_info = {
|
| 167 |
-
"
|
| 168 |
-
"
|
|
|
|
|
|
|
|
|
|
| 169 |
"suppressed_indices": np.where(should_suppress)[0].tolist(),
|
| 170 |
-
"suppressed_reasons": []
|
| 171 |
}
|
| 172 |
|
| 173 |
-
# Record why each point was suppressed
|
| 174 |
-
for i in np.where(should_suppress)[0]:
|
| 175 |
-
if mode == "and":
|
| 176 |
-
matching_negs = np.where(spatial_close[i] & semantic_similar[i])[0]
|
| 177 |
-
else:
|
| 178 |
-
matching_negs = [suppression_score[i].argmax()]
|
| 179 |
-
|
| 180 |
-
suppression_info["suppressed_reasons"].append({
|
| 181 |
-
"pos_idx": int(i),
|
| 182 |
-
"matched_neg_idx": matching_negs.tolist() if isinstance(matching_negs, np.ndarray) else matching_negs,
|
| 183 |
-
"min_spatial_dist": float(spatial_dist[i].min()),
|
| 184 |
-
"max_query_sim": float(query_sim[i].max())
|
| 185 |
-
})
|
| 186 |
-
|
| 187 |
return filtered_points, filtered_indices, suppression_info
|
| 188 |
|
| 189 |
def count_objects(image, pos_caption, neg_caption, box_threshold, point_radius, point_color):
|
|
@@ -259,7 +239,13 @@ def count_objects(image, pos_caption, neg_caption, box_threshold, point_radius,
|
|
| 259 |
outputs["pred_logits"] = outputs["logits"]
|
| 260 |
|
| 261 |
threshold = box_threshold if box_threshold > 0 else model.box_threshold
|
| 262 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 263 |
|
| 264 |
boxes = results["boxes"]
|
| 265 |
boxes = [box.tolist() for box in boxes]
|
|
@@ -273,17 +259,13 @@ def count_objects(image, pos_caption, neg_caption, box_threshold, point_radius,
|
|
| 273 |
neg_outputs["pred_points"] = outputs["neg_pred_boxes"][:, :, :2]
|
| 274 |
neg_outputs["pred_logits"] = outputs["neg_logits"]
|
| 275 |
|
| 276 |
-
neg_results =
|
| 277 |
neg_boxes = neg_results["boxes"]
|
| 278 |
neg_boxes = [box.tolist() for box in neg_boxes]
|
| 279 |
neg_points = [[box[0], box[1]] for box in neg_boxes]
|
| 280 |
|
| 281 |
-
pos_queries =
|
| 282 |
-
neg_queries =
|
| 283 |
-
pos_queries = pos_queries[-1].squeeze(0)
|
| 284 |
-
neg_queries = neg_queries[-1].squeeze(0)
|
| 285 |
-
pos_queries = pos_queries.cpu().numpy()
|
| 286 |
-
neg_queries = neg_queries.cpu().numpy()
|
| 287 |
|
| 288 |
img_size = image.size
|
| 289 |
# filtered_points, kept_indices = filter_points_by_negative(
|
|
@@ -299,8 +281,7 @@ def count_objects(image, pos_caption, neg_caption, box_threshold, point_radius,
|
|
| 299 |
neg_queries,
|
| 300 |
image_size=img_size,
|
| 301 |
pixel_threshold=5,
|
| 302 |
-
similarity_threshold=0.
|
| 303 |
-
mode="and"
|
| 304 |
)
|
| 305 |
|
| 306 |
filtered_boxes = [boxes[i] for i in kept_indices]
|
|
|
|
| 4 |
from PIL import Image, ImageDraw
|
| 5 |
from transformers import GroundingDinoProcessor
|
| 6 |
from hf_model import CountEX
|
| 7 |
+
from utils import post_process_grounded_object_detection, post_process_grounded_object_detection_with_queries
|
| 8 |
|
| 9 |
# Global variables for model and processor
|
| 10 |
model = None
|
|
|
|
| 75 |
|
| 76 |
return filtered_points, filtered_indices
|
| 77 |
|
| 78 |
+
|
| 79 |
+
import numpy as np
|
| 80 |
+
|
| 81 |
def discriminative_point_suppression(
|
| 82 |
points,
|
| 83 |
neg_points,
|
| 84 |
+
pos_queries, # (N, D) numpy array
|
| 85 |
+
neg_queries, # (M, D) numpy array
|
| 86 |
image_size,
|
| 87 |
pixel_threshold=5,
|
| 88 |
+
similarity_threshold=0.3,
|
|
|
|
| 89 |
):
|
| 90 |
"""
|
| 91 |
Discriminative Point Suppression (DPS):
|
|
|
|
|
|
|
| 92 |
|
| 93 |
+
Step 1: Find spatially closest negative point for each positive point
|
| 94 |
+
Step 2: If distance < pixel_threshold, check query similarity
|
| 95 |
+
Step 3: Suppress only if query similarity > similarity_threshold
|
| 96 |
+
|
| 97 |
+
This two-stage design ensures suppression only when predictions are
|
| 98 |
+
both spatially overlapping AND semantically conflicting.
|
| 99 |
|
| 100 |
Args:
|
| 101 |
points: List of [x, y] positive points (normalized, 0-1)
|
|
|
|
| 104 |
neg_queries: (M, D) query embeddings for negative predictions
|
| 105 |
image_size: (width, height) in pixels
|
| 106 |
pixel_threshold: spatial distance threshold in pixels
|
| 107 |
+
similarity_threshold: cosine similarity threshold for semantic conflict
|
|
|
|
| 108 |
|
| 109 |
Returns:
|
| 110 |
filtered_points: points after suppression
|
| 111 |
filtered_indices: indices of kept points
|
| 112 |
+
suppression_info: dict with detailed suppression decisions
|
| 113 |
"""
|
| 114 |
if not neg_points or not points:
|
| 115 |
return points, list(range(len(points))), {}
|
|
|
|
| 117 |
width, height = image_size
|
| 118 |
N, M = len(points), len(neg_points)
|
| 119 |
|
| 120 |
+
# === Step 1: Spatial Matching ===
|
| 121 |
points_arr = np.array(points) * np.array([width, height]) # (N, 2)
|
| 122 |
neg_points_arr = np.array(neg_points) * np.array([width, height]) # (M, 2)
|
| 123 |
|
| 124 |
+
# Compute pairwise distances
|
| 125 |
spatial_dist = np.linalg.norm(
|
| 126 |
points_arr[:, None, :] - neg_points_arr[None, :, :], axis=-1
|
| 127 |
) # (N, M)
|
| 128 |
|
| 129 |
+
# Find nearest negative for each positive
|
| 130 |
+
nearest_neg_idx = spatial_dist.argmin(axis=1) # (N,)
|
| 131 |
+
nearest_neg_dist = spatial_dist.min(axis=1) # (N,)
|
| 132 |
+
|
| 133 |
+
# Check spatial condition
|
| 134 |
+
spatially_close = nearest_neg_dist < pixel_threshold # (N,)
|
| 135 |
+
|
| 136 |
+
# === Step 2: Query Similarity Check (only for spatially close pairs) ===
|
| 137 |
# Normalize queries
|
| 138 |
pos_q = pos_queries / (np.linalg.norm(pos_queries, axis=-1, keepdims=True) + 1e-8)
|
| 139 |
neg_q = neg_queries / (np.linalg.norm(neg_queries, axis=-1, keepdims=True) + 1e-8)
|
| 140 |
|
| 141 |
+
# Compute similarity only for matched pairs
|
| 142 |
+
matched_neg_q = neg_q[nearest_neg_idx] # (N, D)
|
| 143 |
+
query_sim = (pos_q * matched_neg_q).sum(axis=-1) # (N,) cosine similarity
|
| 144 |
|
| 145 |
+
# Check semantic condition
|
| 146 |
+
semantically_similar = query_sim > similarity_threshold # (N,)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 147 |
|
| 148 |
+
# === Step 3: Joint Decision ===
|
| 149 |
+
# Suppress only if BOTH conditions are met
|
| 150 |
+
should_suppress = spatially_close & semantically_similar # (N,)
|
| 151 |
|
| 152 |
# === Filter ===
|
| 153 |
keep_mask = ~should_suppress
|
| 154 |
filtered_points = np.array(points)[keep_mask].tolist()
|
| 155 |
filtered_indices = np.where(keep_mask)[0].tolist()
|
| 156 |
|
| 157 |
+
# === Suppression Info ===
|
| 158 |
suppression_info = {
|
| 159 |
+
"nearest_neg_idx": nearest_neg_idx.tolist(),
|
| 160 |
+
"nearest_neg_dist": nearest_neg_dist.tolist(),
|
| 161 |
+
"query_similarity": query_sim.tolist(),
|
| 162 |
+
"spatially_close": spatially_close.tolist(),
|
| 163 |
+
"semantically_similar": semantically_similar.tolist(),
|
| 164 |
"suppressed_indices": np.where(should_suppress)[0].tolist(),
|
|
|
|
| 165 |
}
|
| 166 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 167 |
return filtered_points, filtered_indices, suppression_info
|
| 168 |
|
| 169 |
def count_objects(image, pos_caption, neg_caption, box_threshold, point_radius, point_color):
|
|
|
|
| 239 |
outputs["pred_logits"] = outputs["logits"]
|
| 240 |
|
| 241 |
threshold = box_threshold if box_threshold > 0 else model.box_threshold
|
| 242 |
+
pos_queries = outputs["pos_queries"].squeeze(0).float()
|
| 243 |
+
neg_queries = outputs["neg_queries"].squeeze(0).float()
|
| 244 |
+
pos_queries = pos_queries[-1].squeeze(0)
|
| 245 |
+
neg_queries = neg_queries[-1].squeeze(0)
|
| 246 |
+
pos_queries = pos_queries.cpu().numpy()
|
| 247 |
+
neg_queries = neg_queries.cpu().numpy()
|
| 248 |
+
results = post_process_grounded_object_detection_with_queries(outputs, pos_queries, box_threshold=threshold)[0]
|
| 249 |
|
| 250 |
boxes = results["boxes"]
|
| 251 |
boxes = [box.tolist() for box in boxes]
|
|
|
|
| 259 |
neg_outputs["pred_points"] = outputs["neg_pred_boxes"][:, :, :2]
|
| 260 |
neg_outputs["pred_logits"] = outputs["neg_logits"]
|
| 261 |
|
| 262 |
+
neg_results = post_process_grounded_object_detection_with_queries(neg_outputs, neg_queries, box_threshold=threshold)[0]
|
| 263 |
neg_boxes = neg_results["boxes"]
|
| 264 |
neg_boxes = [box.tolist() for box in neg_boxes]
|
| 265 |
neg_points = [[box[0], box[1]] for box in neg_boxes]
|
| 266 |
|
| 267 |
+
pos_queries = results["queries"]
|
| 268 |
+
neg_queries = neg_results["queries"]
|
|
|
|
|
|
|
|
|
|
|
|
|
| 269 |
|
| 270 |
img_size = image.size
|
| 271 |
# filtered_points, kept_indices = filter_points_by_negative(
|
|
|
|
| 281 |
neg_queries,
|
| 282 |
image_size=img_size,
|
| 283 |
pixel_threshold=5,
|
| 284 |
+
similarity_threshold=0.25,
|
|
|
|
| 285 |
)
|
| 286 |
|
| 287 |
filtered_boxes = [boxes[i] for i in kept_indices]
|
utils.py
CHANGED
|
@@ -45,6 +45,38 @@ def post_process_grounded_object_detection(
|
|
| 45 |
|
| 46 |
return results
|
| 47 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 48 |
|
| 49 |
class collator:
|
| 50 |
def __init__(self, processor=None, use_negative=True):
|
|
|
|
| 45 |
|
| 46 |
return results
|
| 47 |
|
| 48 |
+
def post_process_grounded_object_detection_with_queries(
|
| 49 |
+
outputs,
|
| 50 |
+
queries,
|
| 51 |
+
box_threshold: float = 0.4,
|
| 52 |
+
):
|
| 53 |
+
"""
|
| 54 |
+
Post-process grounded object detection outputs.
|
| 55 |
+
Now also returns the query embeddings for each kept prediction.
|
| 56 |
+
"""
|
| 57 |
+
logits, boxes = outputs.logits, outputs.pred_boxes
|
| 58 |
+
assert len(logits) == queries.shape[0], "logits and queries must have the same batch size"
|
| 59 |
+
|
| 60 |
+
probs = torch.sigmoid(logits) # (batch_size, num_queries, 256)
|
| 61 |
+
scores = torch.max(probs, dim=-1)[0] # (batch_size, num_queries)
|
| 62 |
+
|
| 63 |
+
results = []
|
| 64 |
+
for idx, (s, b, p) in enumerate(zip(scores, boxes, probs)):
|
| 65 |
+
mask = s > box_threshold
|
| 66 |
+
score = s[mask]
|
| 67 |
+
box = b[mask]
|
| 68 |
+
prob = p[mask]
|
| 69 |
+
|
| 70 |
+
result = {"scores": score, "boxes": box}
|
| 71 |
+
|
| 72 |
+
# 保存对应的 query embeddings
|
| 73 |
+
if queries is not None:
|
| 74 |
+
result["queries"] = queries[idx][mask] # (num_kept, D)
|
| 75 |
+
|
| 76 |
+
results.append(result)
|
| 77 |
+
assert len(results['scores']) == len(results['boxes']) == results['queries'].shape[0], "scores, boxes and queries must have the same length"
|
| 78 |
+
return results
|
| 79 |
+
|
| 80 |
|
| 81 |
class collator:
|
| 82 |
def __init__(self, processor=None, use_negative=True):
|