Zhen Ye committed on
Commit
58bb3a4
·
1 Parent(s): 7e7e04e

Add semantic COCO class matching via sentence-transformers and expand synonym table

Browse files
Files changed (2) hide show
  1. coco_classes.py +103 -3
  2. requirements.txt +1 -0
coco_classes.py CHANGED
@@ -1,8 +1,13 @@
1
  from __future__ import annotations
2
 
3
  import difflib
 
4
  import re
5
- from typing import Dict, Tuple
 
 
 
 
6
 
7
  COCO_CLASSES: Tuple[str, ...] = (
8
  "person",
@@ -105,6 +110,10 @@ _COCO_SYNONYMS: Dict[str, str] = {
105
  "woman": "person",
106
  "men": "person",
107
  "women": "person",
 
 
 
 
108
  "motorbike": "motorcycle",
109
  "motor bike": "motorcycle",
110
  "bike": "bicycle",
@@ -112,11 +121,28 @@ _COCO_SYNONYMS: Dict[str, str] = {
112
  "plane": "airplane",
113
  "jet": "airplane",
114
  "aeroplane": "airplane",
 
 
 
115
  "pickup": "truck",
116
  "pickup truck": "truck",
117
  "semi": "truck",
118
  "lorry": "truck",
119
  "tractor trailer": "truck",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
120
  "coach": "bus",
121
  "television": "tv",
122
  "tv monitor": "tv",
@@ -130,8 +156,80 @@ _COCO_SYNONYMS: Dict[str, str] = {
130
  _ALIAS_LOOKUP: Dict[str, str] = {_normalize(alias): canonical for alias, canonical in _COCO_SYNONYMS.items()}
131
 
132
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
133
  def canonicalize_coco_name(value: str | None) -> str | None:
134
- """Map an arbitrary string to the closest COCO class name if possible."""
 
 
 
 
 
 
 
 
 
135
 
136
  if not value:
137
  return None
@@ -160,4 +258,6 @@ def canonicalize_coco_name(value: str | None) -> str | None:
160
  close = difflib.get_close_matches(normalized, list(_CANONICAL_LOOKUP.keys()), n=1, cutoff=0.82)
161
  if close:
162
  return _CANONICAL_LOOKUP[close[0]]
163
- return None
 
 
 
1
  from __future__ import annotations
2
 
3
  import difflib
4
+ import logging
5
  import re
6
+ from typing import Dict, Optional, Tuple
7
+
8
+ import numpy as np
9
+
10
+ logger = logging.getLogger(__name__)
11
 
12
  COCO_CLASSES: Tuple[str, ...] = (
13
  "person",
 
110
  "woman": "person",
111
  "men": "person",
112
  "women": "person",
113
+ "pedestrian": "person",
114
+ "soldier": "person",
115
+ "infantry": "person",
116
+ "civilian": "person",
117
  "motorbike": "motorcycle",
118
  "motor bike": "motorcycle",
119
  "bike": "bicycle",
 
121
  "plane": "airplane",
122
  "jet": "airplane",
123
  "aeroplane": "airplane",
124
+ "drone": "airplane",
125
+ "uav": "airplane",
126
+ "helicopter": "airplane",
127
  "pickup": "truck",
128
  "pickup truck": "truck",
129
  "semi": "truck",
130
  "lorry": "truck",
131
  "tractor trailer": "truck",
132
+ "vehicle": "car",
133
+ "sedan": "car",
134
+ "suv": "car",
135
+ "van": "car",
136
+ "vessel": "boat",
137
+ "ship": "boat",
138
+ "warship": "boat",
139
+ "speedboat": "boat",
140
+ "cargo ship": "boat",
141
+ "fishing boat": "boat",
142
+ "yacht": "boat",
143
+ "kayak": "boat",
144
+ "canoe": "boat",
145
+ "watercraft": "boat",
146
  "coach": "bus",
147
  "television": "tv",
148
  "tv monitor": "tv",
 
156
  _ALIAS_LOOKUP: Dict[str, str] = {_normalize(alias): canonical for alias, canonical in _COCO_SYNONYMS.items()}
157
 
158
 
159
+ # ---------------------------------------------------------------------------
160
+ # Semantic similarity fallback (lazy-loaded)
161
+ # ---------------------------------------------------------------------------
162
+
163
+ _SEMANTIC_MODEL = None
164
+ _COCO_EMBEDDINGS: Optional[np.ndarray] = None
165
+ _SEMANTIC_THRESHOLD = 0.65 # Minimum cosine similarity to accept a match
166
+
167
+
168
def _get_semantic_model():
    """Return the cached ``(model, embeddings)`` pair, loading on first use.

    The first call tries to import sentence-transformers, instantiate the
    MiniLM model, and pre-encode one anchor phrase per COCO class. A failed
    load is cached via the ``False`` sentinel so subsequent calls do not
    retry the expensive import.
    """
    global _SEMANTIC_MODEL, _COCO_EMBEDDINGS

    # Fast path: a previous call already loaded the model or recorded failure.
    if _SEMANTIC_MODEL is None:
        try:
            from sentence_transformers import SentenceTransformer

            model = SentenceTransformer("all-MiniLM-L6-v2")
            # Prefix with "a photo of a" to anchor embeddings in visual/object space
            phrases = [f"a photo of a {cls}" for cls in COCO_CLASSES]
            embeddings = model.encode(phrases, normalize_embeddings=True)
        except Exception:
            logger.warning("sentence-transformers unavailable; semantic COCO mapping disabled", exc_info=True)
            _SEMANTIC_MODEL = False  # Sentinel: tried and failed
            _COCO_EMBEDDINGS = None
        else:
            _SEMANTIC_MODEL = model
            _COCO_EMBEDDINGS = embeddings
            logger.info("Loaded semantic similarity model for COCO class mapping")

    return _SEMANTIC_MODEL, _COCO_EMBEDDINGS
189
+
190
+
191
def _semantic_coco_match(value: str) -> Optional[str]:
    """Return the COCO class whose embedding is closest to *value*.

    The query is embedded with the same "a photo of a" prompt used for the
    class anchors; a match is accepted only when the cosine similarity
    reaches ``_SEMANTIC_THRESHOLD``. Returns ``None`` when no class clears
    the threshold or when the semantic model is unavailable.
    """
    model, class_embeddings = _get_semantic_model()
    if class_embeddings is None or model is False:
        return None

    encoded = model.encode(
        [f"a photo of a {value}"], normalize_embeddings=True
    )
    # Embeddings are L2-normalized, so the dot product is cosine similarity.
    scores = encoded @ class_embeddings.T  # shape (1, num_classes)
    winner = int(np.argmax(scores))
    score = float(scores[0, winner])
    candidate = COCO_CLASSES[winner]

    if score < _SEMANTIC_THRESHOLD:
        logger.debug(
            "Semantic COCO match failed: '%s' best='%s' (score=%.3f < %.2f)",
            value, candidate, score, _SEMANTIC_THRESHOLD,
        )
        return None

    logger.info(
        "Semantic COCO match: '%s' -> '%s' (score=%.3f)",
        value, candidate, score,
    )
    return candidate
220
+
221
+
222
  def canonicalize_coco_name(value: str | None) -> str | None:
223
+ """Map an arbitrary string to the closest COCO class name if possible.
224
+
225
+ Matching cascade:
226
+ 1. Exact normalized match
227
+ 2. Synonym lookup
228
+ 3. Substring match (alias then canonical)
229
+ 4. Token-level match
230
+ 5. Fuzzy string match (difflib)
231
+ 6. Semantic embedding similarity (sentence-transformers)
232
+ """
233
 
234
  if not value:
235
  return None
 
258
  close = difflib.get_close_matches(normalized, list(_CANONICAL_LOOKUP.keys()), n=1, cutoff=0.82)
259
  if close:
260
  return _CANONICAL_LOOKUP[close[0]]
261
+
262
+ # Last resort: semantic embedding similarity
263
+ return _semantic_coco_match(value)
requirements.txt CHANGED
@@ -9,3 +9,4 @@ huggingface-hub
9
  ultralytics
10
  python-dotenv
11
  einops
 
 
9
  ultralytics
10
  python-dotenv
11
  einops
12
+ sentence-transformers