parrvv committed on
Commit
0243bd1
Β·
verified Β·
1 Parent(s): 08099a9

Upload 2 files

Browse files
Files changed (2) hide show
  1. predictor.py +363 -0
  2. validator_local.py +671 -0
predictor.py ADDED
@@ -0,0 +1,363 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ predictor.py β€” Student inference file for hidden evaluation.
3
+
4
+ ╔══════════════════════════════════════════════════════════════════╗
5
+ β•‘ DO NOT RENAME ANY FUNCTION. β•‘
6
+ β•‘ DO NOT CHANGE FUNCTION SIGNATURES. β•‘
7
+ β•‘ DO NOT REMOVE ANY FUNCTION. β•‘
8
+ β•‘ DO NOT RENAME CLS_CLASS_MAPPING or SEG_CLASS_MAPPING. β•‘
9
+ β•‘ You may add helper functions / imports as needed. β•‘
10
+ β•šβ•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•
11
+
12
+ Tasks
13
+ -----
14
+ Task 3.1 β€” Multi-label image-level classification (5 classes).
15
+ Task 3.2 β€” Object detection + instance segmentation (5 classes).
16
+
17
+ You must implement ALL FOUR functions below.
18
+
19
+ Class Mappings
20
+ --------------
21
+ Fill in the two dictionaries below (CLS_CLASS_MAPPING, SEG_CLASS_MAPPING)
22
+ to map your model's output indices to the canonical category names.
23
+
24
+ The canonical 5 categories (from the DeepFashion2 subset) are:
25
+ short sleeve top, long sleeve top, trousers, shorts, skirt
26
+
27
+ Your indices can be in any order, but the category name strings
28
+ must match exactly (case-insensitive). Background class is optional
29
+ but recommended for detection/segmentation models β€” the evaluator
30
+ will automatically ignore it.
31
+
32
+ Important: Masks must be at the ORIGINAL image resolution.
33
+ If your model internally resizes images, resize the masks back
34
+ to the input image dimensions before returning them.
35
+
36
+ Model Weights
37
+ -------------
38
+ Place your trained weights inside model_files/ as:
39
+ model_files/cls.pt (or cls.pth) β€” classification model
40
+ model_files/seg.pt (or seg.pth) β€” detection + segmentation model
41
+
42
+ Evaluation Metrics
43
+ ------------------
44
+ Classification : Macro F1-score + Per-label macro accuracy
45
+ Detection : mAP @ [0.5 : 0.05 : 0.95]
46
+ Segmentation : Per-class mIoU (macro-averaged)
47
+ """
48
+
49
+ from __future__ import annotations
50
+
51
+ import json
52
+ from pathlib import Path
53
+ from typing import Any, Dict, List
54
+
55
+ import numpy as np
56
+ import torch
57
+ import torch.nn as nn
58
+ import torchvision.models as models
59
+ import torchvision.transforms as T
60
+ from PIL import Image
61
+ from ultralytics import YOLO
62
+ import cv2
63
+
64
+
65
+ # ═══════════════════════════════════════════════════════════════════
66
+ # CLASS MAPPINGS β€” FILL THESE IN
67
+ # ═══════════════════════════════════════════════════════════════════
68
+
69
+ # Classification: maps your model's output index β†’ canonical class name.
70
+ # Must have exactly 5 entries (one per clothing class, NO background).
71
+ # Example:
72
+ # CLS_CLASS_MAPPING = {
73
+ # 0: "short sleeve top",
74
+ # 1: "long sleeve top",
75
+ # 2: "trousers",
76
+ # 3: "shorts",
77
+ # 4: "skirt",
78
+ # }
79
+ CLS_CLASS_MAPPING: Dict[int, str] = {
80
+ 0: "short sleeve top",
81
+ 1: "long sleeve top",
82
+ 2: "shorts",
83
+ 3: "trousers",
84
+ 4: "skirt",
85
+ }
86
+
87
+ # Detection + Segmentation: maps your model's output index β†’ class name.
88
+ # Include background if your model outputs it (evaluator will ignore it).
89
+ # Example:
90
+ # SEG_CLASS_MAPPING = {
91
+ # 0: "background",
92
+ # 1: "short sleeve top",
93
+ # 2: "long sleeve top",
94
+ # 3: "trousers",
95
+ # 4: "shorts",
96
+ # 5: "skirt",
97
+ # }
98
+ SEG_CLASS_MAPPING: Dict[int, str] = {
99
+ 0: "short sleeve top",
100
+ 1: "long sleeve top",
101
+ 2: "shorts",
102
+ 3: "trousers",
103
+ 4: "skirt",
104
+ }
105
+
106
+
107
+ # ═══════════════════════════════════════════════════════════════════
108
+ # Helper utilities (you may modify or add more)
109
+ # ═══════════════════════════════════════════════════════════════════
110
+
111
+ def _find_weights(folder: Path, stem: str) -> Path:
112
+ """Return the first existing weights file matching stem.pt or stem.pth."""
113
+ for ext in (".pt", ".pth"):
114
+ candidate = folder / "model_files" / (stem + ext)
115
+ if candidate.exists():
116
+ return candidate
117
+ raise FileNotFoundError(
118
+ f"No weights file found for '{stem}' in {folder / 'model_files'}"
119
+ )
120
+
121
+
122
+ def _load_json(path: Path) -> Dict[str, Any]:
123
+ with open(path, "r", encoding="utf-8") as f:
124
+ return json.load(f)
125
+
126
+
127
+ # ═══════════════════════════════════════════════════════════════════
128
+ # TASK 3.1 β€” CLASSIFICATION
129
+ # ═══════════════════════════════════════════════════════════════════
130
+
131
+ def load_classification_model(folder: str, device: str) -> Any:
132
+ """
133
+ Load your trained classification model.
134
+
135
+ Parameters
136
+ ----------
137
+ folder : str
138
+ Absolute path to your submission folder (the one containing
139
+ this predictor.py, model_files/, class_mapping_cls.json, etc.).
140
+ device : str
141
+ PyTorch device string, e.g. "cuda", "mps", or "cpu".
142
+
143
+ Returns
144
+ -------
145
+ model : Any
146
+ Whatever object your predict_classification function needs.
147
+ This is passed directly as the first argument to
148
+ predict_classification().
149
+
150
+ Notes
151
+ -----
152
+ - Load weights from <folder>/model_files/cls.pt (or .pth).
153
+ - Use CLS_CLASS_MAPPING defined above to map output indices.
154
+ - The returned object can be a dict, a nn.Module, or anything
155
+ your prediction function expects.
156
+ """
157
+ model_path = _find_weights(Path(folder), "cls")
158
+
159
+ # Initialize EfficientNet B0 model
160
+ model = models.efficientnet_b0(weights=None)
161
+ in_features = model.classifier[1].in_features
162
+ # We have 5 classes
163
+ model.classifier[1] = nn.Linear(in_features, 5)
164
+
165
+ # Load weights
166
+ state_dict = torch.load(model_path, map_location=device)
167
+ model.load_state_dict(state_dict)
168
+
169
+ model.to(device)
170
+ model.eval()
171
+
172
+ return model
173
+
174
+
175
+ def predict_classification(model: Any, images: List[Image.Image]) -> List[Dict]:
176
+ """
177
+ Run multi-label classification on a list of images.
178
+
179
+ Parameters
180
+ ----------
181
+ model : Any
182
+ The object returned by load_classification_model().
183
+ images : list of PIL.Image.Image
184
+ A list of RGB PIL images.
185
+
186
+ Returns
187
+ -------
188
+ results : list of dict
189
+ One dict per image, with the key "labels":
190
+
191
+ [
192
+ {"labels": [int, int, int, int, int]},
193
+ {"labels": [int, int, int, int, int]},
194
+ ...
195
+ ]
196
+
197
+ Each "labels" list has exactly 5 elements (one per class,
198
+ in the order defined by your CLS_CLASS_MAPPING dictionary).
199
+ Each element is 0 or 1.
200
+
201
+ Example
202
+ -------
203
+ >>> results = predict_classification(model, [img1, img2])
204
+ >>> results[0]
205
+ {"labels": [1, 0, 0, 1, 0]}
206
+ """
207
+ # Equivalent to the val_transform in albumentations used during training
208
+ transform = T.Compose([
209
+ T.Resize((256, 256)),
210
+ T.CenterCrop((224, 224)),
211
+ T.ToTensor(),
212
+ T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
213
+ ])
214
+
215
+ device = next(model.parameters()).device
216
+ results = []
217
+
218
+ with torch.no_grad():
219
+ for img in images:
220
+ # Ensure image is in RGB
221
+ if img.mode != "RGB":
222
+ img = img.convert("RGB")
223
+
224
+ img_tensor = transform(img).unsqueeze(0).to(device)
225
+ out = model(img_tensor)
226
+
227
+ prob = torch.sigmoid(out).squeeze(0)
228
+ # Threshold matches your compute_f1 logic
229
+ pred = (prob > 0.4).int().tolist()
230
+
231
+ results.append({"labels": pred})
232
+
233
+ return results
234
+
235
+
236
+ # ═══════════════════════════════════════════════════════════════════
237
+ # TASK 3.2 β€” DETECTION + INSTANCE SEGMENTATION
238
+ # ═══════════════════════════════════════════════════════════════════
239
+
240
+ def load_detection_model(folder: str, device: str) -> Any:
241
+ """
242
+ Load your trained detection + segmentation model.
243
+
244
+ Parameters
245
+ ----------
246
+ folder : str
247
+ Absolute path to your submission folder.
248
+ device : str
249
+ PyTorch device string, e.g. "cuda", "mps", or "cpu".
250
+
251
+ Returns
252
+ -------
253
+ model : Any
254
+ Whatever object your predict_detection_segmentation function
255
+ needs. Passed directly as the first argument.
256
+
257
+ Notes
258
+ -----
259
+ - Load weights from <folder>/model_files/seg.pt (or .pth).
260
+ - Use SEG_CLASS_MAPPING defined above to map output indices.
261
+ """
262
+ model_path = _find_weights(Path(folder), "seg")
263
+ model = YOLO(model_path)
264
+ model.to(device)
265
+ return model
266
+
267
+
268
+ def predict_detection_segmentation(
269
+ model: Any,
270
+ images: List[Image.Image],
271
+ ) -> List[Dict]:
272
+ """
273
+ Run detection + instance segmentation on a list of images.
274
+
275
+ Parameters
276
+ ----------
277
+ model : Any
278
+ The object returned by load_detection_model().
279
+ images : list of PIL.Image.Image
280
+ A list of RGB PIL images.
281
+
282
+ Returns
283
+ -------
284
+ results : list of dict
285
+ One dict per image with keys "boxes", "scores", "labels", "masks":
286
+
287
+ [
288
+ {
289
+ "boxes": [[x1, y1, x2, y2], ...], # list of float coords
290
+ "scores": [float, ...], # confidence in [0, 1]
291
+ "labels": [int, ...], # class indices (see mapping)
292
+ "masks": [np.ndarray, ...] # binary masks, HΓ—W, uint8
293
+ },
294
+ ...
295
+ ]
296
+
297
+ Output contract
298
+ ---------------
299
+ - boxes / scores / labels / masks must all have the same length
300
+ (= number of detected instances in that image).
301
+ - Each box is [x1, y1, x2, y2] with x1 < x2, y1 < y2.
302
+ - Coordinates must be within image bounds (0 ≀ x ≀ width, 0 ≀ y ≀ height).
303
+ - Each score is a float in [0, 1].
304
+ - Each label is an int index matching your SEG_CLASS_MAPPING.
305
+ - Each mask is a 2-D numpy array of shape (image_height, image_width)
306
+ with dtype uint8, containing only 0 and 1.
307
+ - If no objects are detected, return empty lists for all keys.
308
+
309
+ Example
310
+ -------
311
+ >>> results = predict_detection_segmentation(model, [img])
312
+ >>> results[0]["boxes"]
313
+ [[100.0, 40.0, 300.0, 420.0], [50.0, 200.0, 250.0, 600.0]]
314
+ >>> results[0]["masks"][0].shape
315
+ (height, width)
316
+ """
317
+ results = []
318
+
319
+ for img in images:
320
+ if img.mode != "RGB":
321
+ img = img.convert("RGB")
322
+
323
+ w, h = img.size
324
+
325
+ # YOLO prediction on PIL image directly
326
+ # We use retina_masks=True for higher resolution masks and correct sizing
327
+ preds = model.predict(source=img, imgsz=640, conf=0.25, verbose=False, retina_masks=True)
328
+ pred = preds[0]
329
+
330
+ boxes = []
331
+ scores = []
332
+ labels = []
333
+ masks_list = []
334
+
335
+ if pred.boxes is not None and len(pred.boxes) > 0:
336
+ boxes = pred.boxes.xyxy.cpu().numpy().tolist()
337
+ scores = pred.boxes.conf.cpu().numpy().tolist()
338
+ labels = pred.boxes.cls.cpu().numpy().astype(int).tolist()
339
+
340
+ if pred.masks is not None and len(pred.masks) > 0:
341
+ masks_data = pred.masks.data.cpu().numpy() # Extract masks (N, H, W)
342
+
343
+ for m in masks_data:
344
+ # Explicitly ensure it matches original image shape (h, w)
345
+ if m.shape != (h, w):
346
+ m = cv2.resize(m, (w, h), interpolation=cv2.INTER_NEAREST)
347
+
348
+ # Convert to strict binary uint8 (0 or 1)
349
+ m_binary = (m > 0.5).astype(np.uint8)
350
+ masks_list.append(m_binary)
351
+
352
+ # Fallback if masks were missing but boxes were detected
353
+ if len(masks_list) != len(boxes):
354
+ masks_list = [np.zeros((h, w), dtype=np.uint8) for _ in boxes]
355
+
356
+ results.append({
357
+ "boxes": boxes,
358
+ "scores": scores,
359
+ "labels": labels,
360
+ "masks": masks_list
361
+ })
362
+
363
+ return results
validator_local.py ADDED
@@ -0,0 +1,671 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ validator_local.py β€” Robust student self-check script.
3
+
4
+ Place this inside your VRMP1_<roll_number>/ folder and run:
5
+
6
+ python validator_local.py
7
+
8
+ This validates with 100% coverage:
9
+ βœ“ All required files and weights exist
10
+ βœ“ predictor.py imports without errors
11
+ βœ“ CLS_CLASS_MAPPING and SEG_CLASS_MAPPING are correctly filled
12
+ βœ“ All 4 functions are implemented (not NotImplementedError)
13
+ βœ“ Models load successfully
14
+ βœ“ Classification output format is correct on a REAL image
15
+ βœ“ Detection + segmentation output format is correct on a REAL image
16
+ βœ“ Mask dimensions match the original image
17
+ βœ“ All value ranges and types are correct
18
+ """
19
+
20
+ from __future__ import annotations
21
+
22
+ import importlib.util
23
+ import json
24
+ import sys
25
+ import traceback
26
+ from pathlib import Path
27
+ from typing import Any, Dict, List
28
+
29
+ import numpy as np
30
+ from PIL import Image, ImageDraw
31
+ from sklearn.metrics import f1_score
32
+
33
+ # DeepFashion2 category_id β†’ name (dataset constant)
34
+ DEEPFASHION_CATID_TO_NAME: Dict[int, str] = {
35
+ 1: "short sleeve top",
36
+ 2: "long sleeve top",
37
+ 3: "short sleeve outwear",
38
+ 4: "long sleeve outwear",
39
+ 5: "vest",
40
+ 6: "sling",
41
+ 7: "shorts",
42
+ 8: "trousers",
43
+ 9: "skirt",
44
+ 10: "short sleeve dress",
45
+ 11: "long sleeve dress",
46
+ 12: "vest dress",
47
+ 13: "sling dress",
48
+ }
49
+
50
+ # Populated at runtime from the student's CLS_CLASS_MAPPING
51
+ CANONICAL_CLASSES: set = set()
52
+ CANONICAL_CLASSES_LIST: list = []
53
+ NUM_CLASSES: int = 0
54
+ CANONICAL_NAME_TO_IDX: Dict[str, int] = {}
55
+ CATEGORY_ID_TO_CANONICAL: Dict[int, int] = {}
56
+
57
+ # ─── Counters ─────────────────────────────────────────────────────
58
+ _pass_count = 0
59
+ _fail_count = 0
60
+ _warn_count = 0
61
+
62
+
63
+ def _pass(msg: str):
64
+ global _pass_count
65
+ _pass_count += 1
66
+ print(f" [PASS] {msg}")
67
+
68
+
69
+ def _fail(msg: str):
70
+ global _fail_count
71
+ _fail_count += 1
72
+ print(f" [FAIL] {msg}")
73
+
74
+
75
+ def _warn(msg: str):
76
+ global _warn_count
77
+ _warn_count += 1
78
+ print(f" [WARN] {msg}")
79
+
80
+
81
+ def _check(condition: bool, pass_msg: str, fail_msg: str) -> bool:
82
+ if condition:
83
+ _pass(pass_msg)
84
+ return True
85
+ else:
86
+ _fail(fail_msg)
87
+ return False
88
+
89
+
90
+ # ─── Locate the real test image ──────────────────────────────────
91
+
92
+ def _find_test_image(folder: Path) -> Path | None:
93
+ """Walk up from the student folder to find hidden_dataset/images/000001.jpg."""
94
+ search = folder.parent # workspace root (one level up from VRMP1_*)
95
+ candidate = search / "hidden_dataset" / "images" / "000001.jpg"
96
+ if candidate.exists():
97
+ return candidate
98
+ # Try any image in hidden_dataset
99
+ hd = search / "hidden_dataset" / "images"
100
+ if hd.is_dir():
101
+ imgs = sorted(hd.glob("*.jpg"))
102
+ if imgs:
103
+ return imgs[0]
104
+ return None
105
+
106
+
107
+ def _find_test_annotation(img_path: Path) -> Path | None:
108
+ """Find the annotation JSON matching the test image."""
109
+ annos_dir = img_path.parent.parent / "annos"
110
+ anno_path = annos_dir / (img_path.stem + ".json")
111
+ return anno_path if anno_path.exists() else None
112
+
113
+
114
+ # ─── GT loading & metric helpers ─────────────────────────────────
115
+
116
+ def load_annotation(anno_path: Path) -> List[Dict[str, Any]]:
117
+ """Parse annotation JSON β†’ list of GT items (only canonical classes)."""
118
+ with open(anno_path, "r", encoding="utf-8") as f:
119
+ data = json.load(f)
120
+ items = []
121
+ for val in data.values():
122
+ if not isinstance(val, dict) or "bounding_box" not in val:
123
+ continue
124
+ cat_id = val["category_id"]
125
+ if cat_id not in CATEGORY_ID_TO_CANONICAL:
126
+ continue
127
+ items.append({
128
+ "box": val["bounding_box"],
129
+ "segmentation": val["segmentation"],
130
+ "category_id": cat_id,
131
+ "category_name": val.get("category_name", ""),
132
+ "canonical_idx": CATEGORY_ID_TO_CANONICAL[cat_id],
133
+ })
134
+ return items
135
+
136
+
137
+ def rasterize_polygons(segmentation: list, width: int, height: int) -> np.ndarray:
138
+ """Render polygon coordinate lists into a binary (H, W) mask."""
139
+ canvas = Image.new("L", (width, height), 0)
140
+ draw = ImageDraw.Draw(canvas)
141
+ for poly in segmentation:
142
+ coords = [(poly[i], poly[i + 1]) for i in range(0, len(poly) - 1, 2)]
143
+ if len(coords) >= 3:
144
+ draw.polygon(coords, fill=1)
145
+ return np.array(canvas, dtype=np.uint8)
146
+
147
+
148
+ def build_remap(student_mapping: dict) -> Dict[int, int]:
149
+ """Map student class index β†’ canonical class index by name matching."""
150
+ remap: Dict[int, int] = {}
151
+ for s_idx, s_name in student_mapping.items():
152
+ name = str(s_name).strip().lower()
153
+ if name in CANONICAL_NAME_TO_IDX:
154
+ remap[int(s_idx)] = CANONICAL_NAME_TO_IDX[name]
155
+ return remap
156
+
157
+
158
+ # ─── Class mapping validation ────────────────────────────────────
159
+
160
+ def validate_class_mapping(mapping, label: str, allow_background: bool) -> bool:
161
+ if not _check(isinstance(mapping, dict),
162
+ f"{label} is a dict.",
163
+ f"{label} must be a dict, got {type(mapping).__name__}."):
164
+ return False
165
+
166
+ if not _check(len(mapping) > 0,
167
+ f"{label} is non-empty ({len(mapping)} entries).",
168
+ f"{label} is empty β€” you must fill in your class mapping!"):
169
+ return False
170
+
171
+ # Check keys are ints
172
+ all_int_keys = all(isinstance(k, int) for k in mapping.keys())
173
+ _check(all_int_keys,
174
+ f"{label} keys are all integers.",
175
+ f"{label} keys must be integers. Got: {[type(k).__name__ for k in mapping.keys()]}")
176
+
177
+ # Check values are strings
178
+ all_str_vals = all(isinstance(v, str) for v in mapping.values())
179
+ _check(all_str_vals,
180
+ f"{label} values are all strings.",
181
+ f"{label} values must be strings.")
182
+
183
+ # Check canonical class coverage
184
+ clothing_names = set()
185
+ for k, v in mapping.items():
186
+ name = str(v).strip().lower()
187
+ if name == "background":
188
+ if not allow_background:
189
+ _warn(f"{label}: index {k} is 'background' β€” not expected in CLS_CLASS_MAPPING.")
190
+ continue
191
+ clothing_names.add(name)
192
+
193
+ missing = CANONICAL_CLASSES - clothing_names
194
+ extra = clothing_names - CANONICAL_CLASSES
195
+ if extra:
196
+ _warn(f"{label}: unrecognized classes (will be ignored by evaluator): {extra}")
197
+
198
+ if not _check(len(missing) == 0,
199
+ f"{label} covers all 5 canonical classes.",
200
+ f"{label} missing canonical classes: {missing}"):
201
+ return False
202
+
203
+ if not allow_background:
204
+ expected = 5
205
+ _check(len(mapping) == expected,
206
+ f"{label} has exactly {expected} entries (no background).",
207
+ f"{label} should have {expected} entries for classification, got {len(mapping)}.")
208
+ return True
209
+
210
+
211
+ # ─── Classification output validation ────────────────────────────
212
+
213
+ def validate_cls_output(outputs: list, num_images: int, num_classes: int) -> bool:
214
+ ok = True
215
+ if not _check(isinstance(outputs, list),
216
+ "Classification returns a list.",
217
+ f"Classification must return a list, got {type(outputs).__name__}."):
218
+ return False
219
+
220
+ if not _check(len(outputs) == num_images,
221
+ f"Classification returned {num_images} result(s) for {num_images} image(s).",
222
+ f"Expected {num_images} results, got {len(outputs)}."):
223
+ return False
224
+
225
+ for idx, out in enumerate(outputs):
226
+ prefix = f"cls_output[{idx}]"
227
+ if not _check(isinstance(out, dict),
228
+ f"{prefix} is a dict.",
229
+ f"{prefix} must be a dict, got {type(out).__name__}."):
230
+ ok = False
231
+ continue
232
+
233
+ if not _check("labels" in out,
234
+ f"{prefix} has 'labels' key.",
235
+ f"{prefix} missing 'labels' key. Keys found: {list(out.keys())}"):
236
+ ok = False
237
+ continue
238
+
239
+ labels = out["labels"]
240
+ if not _check(isinstance(labels, list),
241
+ f"{prefix}['labels'] is a list.",
242
+ f"{prefix}['labels'] must be a list, got {type(labels).__name__}."):
243
+ ok = False
244
+ continue
245
+
246
+ if not _check(len(labels) == num_classes,
247
+ f"{prefix}['labels'] has length {num_classes}.",
248
+ f"{prefix}['labels'] must have length {num_classes}, got {len(labels)}."):
249
+ ok = False
250
+ continue
251
+
252
+ all_valid = True
253
+ for i, l in enumerate(labels):
254
+ if not isinstance(l, int):
255
+ _fail(f"{prefix}['labels'][{i}] must be int, got {type(l).__name__}.")
256
+ ok = False
257
+ all_valid = False
258
+ elif l not in (0, 1):
259
+ _fail(f"{prefix}['labels'][{i}] must be 0 or 1, got {l}.")
260
+ ok = False
261
+ all_valid = False
262
+
263
+ if all_valid:
264
+ _pass(f"{prefix}: all label values are valid (binary 0/1). Output: {labels}")
265
+ return ok
266
+
267
+
268
+ # ─── Detection output validation ────────────────────────────────
269
+
270
+ def validate_det_output(outputs: list, num_images: int, img_sizes: list, max_label: int) -> bool:
271
+ ok = True
272
+ if not _check(isinstance(outputs, list),
273
+ "Detection returns a list.",
274
+ f"Detection must return a list, got {type(outputs).__name__}."):
275
+ return False
276
+
277
+ if not _check(len(outputs) == num_images,
278
+ f"Detection returned {num_images} result(s) for {num_images} image(s).",
279
+ f"Expected {num_images} results, got {len(outputs)}."):
280
+ return False
281
+
282
+ for idx, out in enumerate(outputs):
283
+ w, h = img_sizes[idx]
284
+ prefix = f"det_output[{idx}]"
285
+
286
+ if not _check(isinstance(out, dict),
287
+ f"{prefix} is a dict.",
288
+ f"{prefix} must be a dict."):
289
+ ok = False
290
+ continue
291
+
292
+ required_keys = {"boxes", "scores", "labels", "masks"}
293
+ present_keys = set(out.keys())
294
+ missing_keys = required_keys - present_keys
295
+ if not _check(len(missing_keys) == 0,
296
+ f"{prefix} has all required keys (boxes, scores, labels, masks).",
297
+ f"{prefix} missing keys: {missing_keys}"):
298
+ ok = False
299
+ continue
300
+
301
+ n = len(out["boxes"])
302
+ lengths_ok = (len(out["scores"]) == n and len(out["labels"]) == n
303
+ and len(out["masks"]) == n)
304
+ if not _check(lengths_ok,
305
+ f"{prefix}: all arrays have same length ({n} detections).",
306
+ f"{prefix}: length mismatch β€” boxes={n}, scores={len(out['scores'])}, "
307
+ f"labels={len(out['labels'])}, masks={len(out['masks'])}."):
308
+ ok = False
309
+ continue
310
+
311
+ if n == 0:
312
+ _warn(f"{prefix}: zero detections β€” model may be undertrained or image has no objects.")
313
+ continue
314
+
315
+ # Boxes
316
+ boxes_valid = True
317
+ for i, box in enumerate(out["boxes"]):
318
+ if not (isinstance(box, (list, tuple)) and len(box) == 4):
319
+ _fail(f"{prefix}/boxes[{i}] must be [x1,y1,x2,y2].")
320
+ ok = False
321
+ boxes_valid = False
322
+ else:
323
+ x1, y1, x2, y2 = [float(c) for c in box]
324
+ if not (x1 < x2 and y1 < y2):
325
+ _fail(f"{prefix}/boxes[{i}]: need x1<x2 and y1<y2, got [{x1:.1f},{y1:.1f},{x2:.1f},{y2:.1f}].")
326
+ ok = False
327
+ boxes_valid = False
328
+ if boxes_valid:
329
+ _pass(f"{prefix}: all {n} boxes have valid [x1,y1,x2,y2] format.")
330
+
331
+ # Scores
332
+ scores_valid = True
333
+ for i, s in enumerate(out["scores"]):
334
+ if not isinstance(s, (int, float)):
335
+ _fail(f"{prefix}/scores[{i}] must be numeric, got {type(s).__name__}.")
336
+ ok = False
337
+ scores_valid = False
338
+ elif not (0.0 <= float(s) <= 1.0):
339
+ _fail(f"{prefix}/scores[{i}] must be in [0,1], got {s}.")
340
+ ok = False
341
+ scores_valid = False
342
+ if scores_valid:
343
+ _pass(f"{prefix}: all {n} scores in [0, 1].")
344
+
345
+ # Labels
346
+ labels_valid = True
347
+ for i, l in enumerate(out["labels"]):
348
+ if not isinstance(l, int):
349
+ _fail(f"{prefix}/labels[{i}] must be int, got {type(l).__name__}.")
350
+ ok = False
351
+ labels_valid = False
352
+ elif not (0 <= l <= max_label):
353
+ _fail(f"{prefix}/labels[{i}] must be in [0, {max_label}], got {l}.")
354
+ ok = False
355
+ labels_valid = False
356
+ if labels_valid:
357
+ _pass(f"{prefix}: all {n} labels are valid integers in [0, {max_label}].")
358
+
359
+ # Masks
360
+ masks_valid = True
361
+ for i, mask in enumerate(out["masks"]):
362
+ arr = np.asarray(mask)
363
+ if arr.ndim != 2:
364
+ _fail(f"{prefix}/masks[{i}] must be 2D, got {arr.ndim}D shape={arr.shape}.")
365
+ ok = False
366
+ masks_valid = False
367
+ continue
368
+ if arr.shape != (h, w):
369
+ _fail(f"{prefix}/masks[{i}] shape {arr.shape} != image size ({h}, {w}). "
370
+ "You must resize masks back to the original image resolution!")
371
+ ok = False
372
+ masks_valid = False
373
+ uniq = set(np.unique(arr).tolist())
374
+ if not uniq.issubset({0, 1}):
375
+ _fail(f"{prefix}/masks[{i}] must be binary (0/1), got values {uniq}.")
376
+ ok = False
377
+ masks_valid = False
378
+ if masks_valid and n > 0:
379
+ _pass(f"{prefix}: all {n} masks are binary and match image size ({h}x{w}).")
380
+
381
+ return ok
382
+
383
+
384
+ # ═══════════════════════════════════════════════════════════════════
385
+ # Main
386
+ # ═══════════════════════════════════════════════════════════════════
387
+
388
+ def main():
389
+ folder = Path(__file__).resolve().parent
390
+ print("=" * 60)
391
+ print(f" VALIDATOR β€” {folder.name}")
392
+ print("=" * 60)
393
+
394
+ # ─── 1. Required files ────────────────────────────────────────
395
+ print("\n[1/6] Checking required files ...")
396
+ abort = False
397
+ if not _check((folder / "predictor.py").exists(),
398
+ "predictor.py found.",
399
+ "predictor.py NOT found!"):
400
+ abort = True
401
+
402
+ has_cls_weights = (folder / "model_files" / "cls.pt").exists() or \
403
+ (folder / "model_files" / "cls.pth").exists()
404
+ has_seg_weights = (folder / "model_files" / "seg.pt").exists() or \
405
+ (folder / "model_files" / "seg.pth").exists()
406
+
407
+ if has_cls_weights:
408
+ _pass("model_files/cls.pt(h) found.")
409
+ else:
410
+ _warn("model_files/cls.pt(h) not found β€” OK if classification reuses the seg model.")
411
+
412
+ if not _check(has_seg_weights,
413
+ "model_files/seg.pt(h) found.",
414
+ "model_files/seg.pt(h) NOT found!"):
415
+ abort = True
416
+
417
+ if abort:
418
+ print("\n[ABORT] Fix missing files before continuing.")
419
+ sys.exit(1)
420
+
421
+ # ─── 2. Find test image + annotation ──────────────────────────
422
+ print("\n[2/6] Locating test image ...")
423
+ test_img_path = _find_test_image(folder)
424
+ if test_img_path is None:
425
+ _fail("Cannot find hidden_dataset/images/000001.jpg β€” "
426
+ "make sure hidden_dataset/ is in the parent directory.")
427
+ sys.exit(1)
428
+ else:
429
+ test_img = Image.open(test_img_path).convert("RGB")
430
+ img_w, img_h = test_img.size
431
+ _pass(f"Using real test image: {test_img_path.name} ({img_w}x{img_h})")
432
+
433
+ anno_path = _find_test_annotation(test_img_path)
434
+
435
+ # ─── 3. Import predictor ─────────────────────────────────────
436
+ print("\n[3/6] Importing predictor.py ...")
437
+ try:
438
+ spec = importlib.util.spec_from_file_location("predictor", folder / "predictor.py")
439
+ predictor = importlib.util.module_from_spec(spec)
440
+ spec.loader.exec_module(predictor)
441
+ _pass("predictor.py imported successfully.")
442
+ except Exception as e:
443
+ _fail(f"predictor.py import error: {e}")
444
+ traceback.print_exc()
445
+ sys.exit(1)
446
+
447
+ # ─── 4. Validate class mappings + function existence ─────────
448
+ print("\n[4/6] Validating class mappings and function signatures ...")
449
+
450
+ has_cls_map = hasattr(predictor, "CLS_CLASS_MAPPING")
451
+ has_seg_map = hasattr(predictor, "SEG_CLASS_MAPPING")
452
+
453
+ if not _check(has_cls_map,
454
+ "CLS_CLASS_MAPPING attribute exists.",
455
+ "CLS_CLASS_MAPPING not found in predictor.py!"):
456
+ sys.exit(1)
457
+ if not _check(has_seg_map,
458
+ "SEG_CLASS_MAPPING attribute exists.",
459
+ "SEG_CLASS_MAPPING not found in predictor.py!"):
460
+ sys.exit(1)
461
+
462
+ # Build canonical class structures from CLS_CLASS_MAPPING
463
+ global CANONICAL_CLASSES, CANONICAL_CLASSES_LIST, NUM_CLASSES
464
+ global CANONICAL_NAME_TO_IDX, CATEGORY_ID_TO_CANONICAL
465
+
466
+ cls_names = []
467
+ for idx in sorted(predictor.CLS_CLASS_MAPPING.keys()):
468
+ name = str(predictor.CLS_CLASS_MAPPING[idx]).strip().lower()
469
+ if name != "background":
470
+ cls_names.append(name)
471
+ CANONICAL_CLASSES_LIST = cls_names
472
+ CANONICAL_CLASSES = set(cls_names)
473
+ NUM_CLASSES = len(cls_names)
474
+ CANONICAL_NAME_TO_IDX = {name: i for i, name in enumerate(cls_names)}
475
+
476
+ CATEGORY_ID_TO_CANONICAL = {}
477
+ for cat_id, cat_name in DEEPFASHION_CATID_TO_NAME.items():
478
+ if cat_name in CANONICAL_NAME_TO_IDX:
479
+ CATEGORY_ID_TO_CANONICAL[cat_id] = CANONICAL_NAME_TO_IDX[cat_name]
480
+
481
+ _pass(f"Derived {NUM_CLASSES} canonical classes from CLS_CLASS_MAPPING: {cls_names}")
482
+
483
+ validate_class_mapping(predictor.CLS_CLASS_MAPPING, "CLS_CLASS_MAPPING", allow_background=False)
484
+ validate_class_mapping(predictor.SEG_CLASS_MAPPING, "SEG_CLASS_MAPPING", allow_background=True)
485
+
486
+ # Load GT annotation now that canonical mapping is ready
487
+ gt_items: List[Dict[str, Any]] = []
488
+ if anno_path is not None:
489
+ gt_items = load_annotation(anno_path)
490
+ _pass(f"Loaded GT annotation: {anno_path.name} ({len(gt_items)} objects)")
491
+ else:
492
+ _warn("No annotation found β€” metrics (F1, mIoU) will be skipped.")
493
+
494
+ max_label = max(int(k) for k in predictor.SEG_CLASS_MAPPING.keys()) if predictor.SEG_CLASS_MAPPING else 5
495
+ num_cls_classes = len(predictor.CLS_CLASS_MAPPING)
496
+
497
+ # Check all 4 required functions exist and are callable
498
+ required_fns = [
499
+ "load_classification_model",
500
+ "predict_classification",
501
+ "load_detection_model",
502
+ "predict_detection_segmentation",
503
+ ]
504
+ for fn_name in required_fns:
505
+ if not _check(hasattr(predictor, fn_name) and callable(getattr(predictor, fn_name)),
506
+ f"{fn_name}() exists and is callable.",
507
+ f"{fn_name}() NOT found or not callable!"):
508
+ sys.exit(1)
509
+
510
+ # ─── 5. Test classification pipeline ──────────���──────────────
511
+ print(f"\n[5/6] Testing classification on real image ({img_w}x{img_h}) ...")
512
+ device = "cpu"
513
+
514
+ # 5a. load_classification_model β€” must NOT raise NotImplementedError
515
+ cls_model = None
516
+ try:
517
+ cls_model = predictor.load_classification_model(str(folder), device)
518
+ _pass("load_classification_model() returned successfully.")
519
+ except NotImplementedError:
520
+ _fail("load_classification_model() raises NotImplementedError β€” "
521
+ "you MUST implement this function!")
522
+ except Exception as e:
523
+ _fail(f"load_classification_model() raised: {e}")
524
+ traceback.print_exc()
525
+
526
+ # 5b. predict_classification β€” must NOT raise NotImplementedError
527
+ cls_out = None
528
+ if cls_model is not None:
529
+ try:
530
+ cls_out = predictor.predict_classification(cls_model, [test_img])
531
+ _pass("predict_classification() returned successfully.")
532
+ validate_cls_output(cls_out, num_images=1, num_classes=num_cls_classes)
533
+ except NotImplementedError:
534
+ _fail("predict_classification() raises NotImplementedError β€” "
535
+ "you MUST implement this function!")
536
+ cls_out = None
537
+ except Exception as e:
538
+ _fail(f"predict_classification() raised: {e}")
539
+ traceback.print_exc()
540
+ cls_out = None
541
+
542
+ # 5c. Compute macro F1 if GT annotation is available
543
+ if cls_out is not None and anno_path is not None:
544
+ try:
545
+ remap_cls = build_remap(predictor.CLS_CLASS_MAPPING)
546
+ gt_vec = np.zeros(NUM_CLASSES, dtype=np.int32)
547
+ for item in gt_items:
548
+ gt_vec[item["canonical_idx"]] = 1
549
+ pred_vec = np.zeros(NUM_CLASSES, dtype=np.int32)
550
+ student_labels = cls_out[0]["labels"]
551
+ for s_idx, val in enumerate(student_labels):
552
+ canonical = remap_cls.get(s_idx)
553
+ if canonical is not None:
554
+ pred_vec[canonical] = val
555
+ macro_f1 = float(f1_score(
556
+ gt_vec.reshape(1, -1), pred_vec.reshape(1, -1),
557
+ average="macro", zero_division=0.0,
558
+ ))
559
+ print(f"\n ** Classification Macro F1: {macro_f1:.4f} **")
560
+ except Exception as e:
561
+ _warn(f"Could not compute macro F1: {e}")
562
+
563
+ # ─── 6. Test detection + segmentation pipeline ───────────────
564
+ print(f"\n[6/6] Testing detection + segmentation on real image ({img_w}x{img_h}) ...")
565
+
566
+ # 6a. load_detection_model β€” must NOT raise NotImplementedError
567
+ det_model = None
568
+ try:
569
+ det_model = predictor.load_detection_model(str(folder), device)
570
+ _pass("load_detection_model() returned successfully.")
571
+ except NotImplementedError:
572
+ _fail("load_detection_model() raises NotImplementedError β€” "
573
+ "you MUST implement this function!")
574
+ except Exception as e:
575
+ _fail(f"load_detection_model() raised: {e}")
576
+ traceback.print_exc()
577
+
578
+ # 6b. predict_detection_segmentation β€” must NOT raise NotImplementedError
579
+ det_out = None
580
+ if det_model is not None:
581
+ try:
582
+ det_out = predictor.predict_detection_segmentation(det_model, [test_img])
583
+ _pass("predict_detection_segmentation() returned successfully.")
584
+ validate_det_output(
585
+ det_out,
586
+ num_images=1,
587
+ img_sizes=[(img_w, img_h)],
588
+ max_label=max_label,
589
+ )
590
+ except NotImplementedError:
591
+ _fail("predict_detection_segmentation() raises NotImplementedError β€” "
592
+ "you MUST implement this function!")
593
+ det_out = None
594
+ except Exception as e:
595
+ _fail(f"predict_detection_segmentation() raised: {e}")
596
+ traceback.print_exc()
597
+ det_out = None
598
+
599
+ # 6c. Compute mIoU if GT annotation is available
600
+ if det_out is not None and anno_path is not None and len(det_out) > 0:
601
+ try:
602
+ remap_seg = build_remap(predictor.SEG_CLASS_MAPPING)
603
+ pred = det_out[0]
604
+ IGNORE_LABEL = 255
605
+
606
+ # Build predicted semantic map (highest-confidence per pixel)
607
+ pred_sem = np.full((img_h, img_w), IGNORE_LABEL, dtype=np.uint8)
608
+ pred_conf = np.full((img_h, img_w), -1.0, dtype=np.float32)
609
+ for mask, score, label in zip(
610
+ pred["masks"], pred["scores"], pred["labels"]
611
+ ):
612
+ canonical = remap_seg.get(label)
613
+ if canonical is None:
614
+ continue
615
+ binary = np.asarray(mask, dtype=np.uint8)
616
+ if binary.shape != (img_h, img_w):
617
+ mask_pil = Image.fromarray(binary * 255)
618
+ mask_pil = mask_pil.resize((img_w, img_h), Image.NEAREST)
619
+ binary = (np.array(mask_pil) > 127).astype(np.uint8)
620
+ higher = (binary == 1) & (score > pred_conf)
621
+ pred_sem[higher] = canonical
622
+ pred_conf[higher] = score
623
+
624
+ # Build GT semantic map from polygon annotations
625
+ gt_sem = np.full((img_h, img_w), IGNORE_LABEL, dtype=np.uint8)
626
+ for item in gt_items:
627
+ gt_mask = rasterize_polygons(item["segmentation"], img_w, img_h)
628
+ gt_sem[gt_mask == 1] = item["canonical_idx"]
629
+
630
+ # Per-class IoU
631
+ intersection = np.zeros(NUM_CLASSES, dtype=np.float64)
632
+ union = np.zeros(NUM_CLASSES, dtype=np.float64)
633
+ for c in range(NUM_CLASSES):
634
+ pred_c = (pred_sem == c)
635
+ gt_c = (gt_sem == c)
636
+ intersection[c] = np.logical_and(pred_c, gt_c).sum()
637
+ union[c] = np.logical_or(pred_c, gt_c).sum()
638
+
639
+ per_class_iou = []
640
+ for c in range(NUM_CLASSES):
641
+ if union[c] > 0:
642
+ per_class_iou.append(float(intersection[c] / union[c]))
643
+ else:
644
+ per_class_iou.append(float("nan"))
645
+
646
+ valid_ious = [v for v in per_class_iou if not np.isnan(v)]
647
+ miou = float(np.mean(valid_ious)) if valid_ious else 0.0
648
+
649
+ print(f"\n ** Segmentation mIoU: {miou:.4f} **")
650
+ for c in range(NUM_CLASSES):
651
+ iou_str = f"{per_class_iou[c]:.4f}" if not np.isnan(per_class_iou[c]) else "N/A"
652
+ print(f" {CANONICAL_CLASSES_LIST[c]:20s}: {iou_str}")
653
+ except Exception as e:
654
+ _warn(f"Could not compute mIoU: {e}")
655
+
656
+ # ─── Summary ─────────────────────────────────────────────────
657
+ print("\n" + "=" * 60)
658
+ print(f" RESULTS: {_pass_count} passed, {_fail_count} failed, "
659
+ f"{_warn_count} warnings")
660
+ print("=" * 60)
661
+ if _fail_count > 0:
662
+ print("\n VALIDATION FAILED β€” fix the [FAIL] items above before submitting.\n")
663
+ sys.exit(1)
664
+ elif _warn_count > 0:
665
+ print("\n VALIDATION PASSED WITH WARNINGS β€” review [WARN] items above.\n")
666
+ else:
667
+ print("\n ALL CHECKS PASSED β€” your submission looks good!\n")
668
+
669
+
670
# Script entry point: run the full local validation suite when executed
# directly (e.g. `python validator_local.py`); importing the module has
# no side effects.
if __name__ == "__main__":
    main()