File size: 6,835 Bytes
b30e7a3
 
 
374a0ef
58bb3a4
b30e7a3
58bb3a4
 
 
 
 
b30e7a3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
58bb3a4
 
 
 
b30e7a3
 
 
 
 
 
 
58bb3a4
 
 
b30e7a3
 
 
 
 
58bb3a4
 
 
 
 
 
 
 
 
 
 
 
 
 
b30e7a3
 
 
 
 
 
 
 
 
 
 
 
 
58bb3a4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
374a0ef
b30e7a3
58bb3a4
 
 
 
 
 
 
 
 
 
b30e7a3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
58bb3a4
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
from __future__ import annotations

import difflib
import functools
import logging
import re
from typing import Dict, Optional, Tuple

import numpy as np

logger = logging.getLogger(__name__)

# The 80 COCO object-detection class names.  Tuple order is significant:
# _semantic_coco_match indexes back into this tuple by embedding row, so do
# not reorder entries without regenerating the cached embeddings.
COCO_CLASSES: Tuple[str, ...] = (
    "person",
    "bicycle",
    "car",
    "motorcycle",
    "airplane",
    "bus",
    "train",
    "truck",
    "boat",
    "traffic light",
    "fire hydrant",
    "stop sign",
    "parking meter",
    "bench",
    "bird",
    "cat",
    "dog",
    "horse",
    "sheep",
    "cow",
    "elephant",
    "bear",
    "zebra",
    "giraffe",
    "backpack",
    "umbrella",
    "handbag",
    "tie",
    "suitcase",
    "frisbee",
    "skis",
    "snowboard",
    "sports ball",
    "kite",
    "baseball bat",
    "baseball glove",
    "skateboard",
    "surfboard",
    "tennis racket",
    "bottle",
    "wine glass",
    "cup",
    "fork",
    "knife",
    "spoon",
    "bowl",
    "banana",
    "apple",
    "sandwich",
    "orange",
    "broccoli",
    "carrot",
    "hot dog",
    "pizza",
    "donut",
    "cake",
    "chair",
    "couch",
    "potted plant",
    "bed",
    "dining table",
    "toilet",
    "tv",
    "laptop",
    "mouse",
    "remote",
    "keyboard",
    "cell phone",
    "microwave",
    "oven",
    "toaster",
    "sink",
    "refrigerator",
    "book",
    "clock",
    "vase",
    "scissors",
    "teddy bear",
    "hair drier",
    "toothbrush",
)


def coco_class_catalog() -> str:
    """Return all COCO class names joined into one comma-separated string.

    Intended for embedding the full class catalog into prompts.
    """
    separator = ", "
    return separator.join(COCO_CLASSES)


def _normalize(label: str) -> str:
    return re.sub(r"[^a-z0-9]+", " ", label.lower()).strip()


# Normalized canonical name -> official COCO class name (exact-match table).
_CANONICAL_LOOKUP: Dict[str, str] = {_normalize(name): name for name in COCO_CLASSES}
# Free-text synonyms mapped onto their nearest COCO class.  Several entries
# are deliberately lossy approximations (e.g. "helicopter"/"drone" ->
# "airplane", "warship" -> "boat") so detections still land on *some* label.
_COCO_SYNONYMS: Dict[str, str] = {
    "people": "person",
    "man": "person",
    "woman": "person",
    "men": "person",
    "women": "person",
    "pedestrian": "person",
    "soldier": "person",
    "infantry": "person",
    "civilian": "person",
    "motorbike": "motorcycle",
    "motor bike": "motorcycle",
    "bike": "bicycle",
    "aircraft": "airplane",
    "plane": "airplane",
    "jet": "airplane",
    "aeroplane": "airplane",
    "drone": "airplane",
    "uav": "airplane",
    "helicopter": "airplane",
    "pickup": "truck",
    "pickup truck": "truck",
    "semi": "truck",
    "lorry": "truck",
    "tractor trailer": "truck",
    "vehicle": "car",
    "sedan": "car",
    "suv": "car",
    "van": "car",
    "vessel": "boat",
    "ship": "boat",
    "warship": "boat",
    "speedboat": "boat",
    "cargo ship": "boat",
    "fishing boat": "boat",
    "yacht": "boat",
    "kayak": "boat",
    "canoe": "boat",
    "watercraft": "boat",
    "coach": "bus",
    "television": "tv",
    "tv monitor": "tv",
    "mobile phone": "cell phone",
    "smartphone": "cell phone",
    "cellphone": "cell phone",
    "dinner table": "dining table",
    "sofa": "couch",
    "cooker": "oven",
}
# Same synonym table, keyed by the *normalized* alias text so lookups survive
# punctuation/case differences in the input.
_ALIAS_LOOKUP: Dict[str, str] = {_normalize(alias): canonical for alias, canonical in _COCO_SYNONYMS.items()}


# ---------------------------------------------------------------------------
# Semantic similarity fallback (lazy-loaded)
# ---------------------------------------------------------------------------

# Populated on first use by _get_semantic_model(); set to False after a
# failed load so we never retry the import.
_SEMANTIC_MODEL = None
# Normalized embeddings of "a photo of a <class>" phrases, one row per
# COCO_CLASSES entry (row index == tuple index).
_COCO_EMBEDDINGS: Optional[np.ndarray] = None
_SEMANTIC_THRESHOLD = 0.65  # Minimum cosine similarity to accept a match


def _get_semantic_model():
    """Lazily initialize the sentence-transformer used for semantic matching.

    Returns a ``(model, embeddings)`` pair.  After a failed load the model
    slot holds ``False`` (a sentinel) so later calls skip the import attempt
    and the fallback stays disabled.
    """
    global _SEMANTIC_MODEL, _COCO_EMBEDDINGS
    if _SEMANTIC_MODEL is None:
        try:
            from sentence_transformers import SentenceTransformer

            model = SentenceTransformer("all-MiniLM-L6-v2")
            # "a photo of a ..." anchors the embeddings in visual/object space.
            phrases = [f"a photo of a {cls}" for cls in COCO_CLASSES]
            _COCO_EMBEDDINGS = model.encode(phrases, normalize_embeddings=True)
            _SEMANTIC_MODEL = model
            logger.info("Loaded semantic similarity model for COCO class mapping")
        except Exception:
            logger.warning("sentence-transformers unavailable; semantic COCO mapping disabled", exc_info=True)
            _SEMANTIC_MODEL = False  # Sentinel: tried and failed
            _COCO_EMBEDDINGS = None
    return _SEMANTIC_MODEL, _COCO_EMBEDDINGS


def _semantic_coco_match(value: str) -> Optional[str]:
    """Map *value* to the nearest COCO class by embedding cosine similarity.

    Returns the matched class name when the best score clears
    ``_SEMANTIC_THRESHOLD``, otherwise ``None``.
    """
    model, class_embeddings = _get_semantic_model()
    if model is False or class_embeddings is None:
        return None

    probe = model.encode(
        [f"a photo of a {value}"], normalize_embeddings=True
    )
    scores = probe @ class_embeddings.T  # (1, 80)
    winner = int(np.argmax(scores))
    winner_score = float(scores[0, winner])

    if winner_score < _SEMANTIC_THRESHOLD:
        logger.debug(
            "Semantic COCO match failed: '%s' best='%s' (score=%.3f < %.2f)",
            value, COCO_CLASSES[winner], winner_score, _SEMANTIC_THRESHOLD,
        )
        return None

    matched = COCO_CLASSES[winner]
    logger.info(
        "Semantic COCO match: '%s' -> '%s' (score=%.3f)",
        value, matched, winner_score,
    )
    return matched


@functools.lru_cache(maxsize=512)
def canonicalize_coco_name(value: str | None) -> str | None:
    """Map an arbitrary string to the closest COCO class name if possible.

    Matching cascade (first hit wins):
    1. Exact normalized match
    2. Synonym lookup
    3. Substring match (alias then canonical)
    4. Token-level match
    5. Fuzzy string match (difflib)
    6. Semantic embedding similarity (sentence-transformers)
    """
    if not value:
        return None
    needle = _normalize(value)
    if not needle:
        return None

    # Steps 1-2: exact hits on the canonical then alias tables.
    direct = _CANONICAL_LOOKUP.get(needle) or _ALIAS_LOOKUP.get(needle)
    if direct is not None:
        return direct

    # Step 3: substring containment, aliases taking precedence.
    for table in (_ALIAS_LOOKUP, _CANONICAL_LOOKUP):
        for key, target in table.items():
            if key and key in needle:
                return target

    # Step 4: per-token lookup, canonical names before aliases.
    for token in needle.split():
        token_hit = _CANONICAL_LOOKUP.get(token) or _ALIAS_LOOKUP.get(token)
        if token_hit is not None:
            return token_hit

    # Step 5: fuzzy spelling match against canonical names only.
    fuzzy = difflib.get_close_matches(needle, list(_CANONICAL_LOOKUP.keys()), n=1, cutoff=0.82)
    if fuzzy:
        return _CANONICAL_LOOKUP[fuzzy[0]]

    # Step 6: last resort -- semantic embedding similarity.
    return _semantic_coco_match(value)