zzsyppt committed on
Commit
d314be2
·
verified ·
1 Parent(s): 4617bad

Upload Adacrop MobileNetV3 distilled version

Browse files
Files changed (4) hide show
  1. common.py +493 -0
  2. student_best.pth +3 -0
  3. student_last.pth +3 -0
  4. train_mobilenet_distill.py +532 -0
common.py ADDED
@@ -0,0 +1,493 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import math
3
+ import pathlib
4
+ import random
5
+ from pathlib import Path
6
+ from typing import Dict, Iterable, List, Optional, Sequence, Tuple
7
+
8
+ import torch
9
+ import torch.nn as nn
10
+ import torch.nn.functional as F
11
+ import torchvision.transforms as T
12
+ from PIL import Image
13
+ from torch.utils.data import Dataset
14
+ from torchvision import models
15
+
16
+
17
+ ACTIONS = ["left", "right", "up", "down", "zoom_in", "zoom_out", "stop"]
18
+
19
+
20
def find_adacrop_root() -> Path:
    """Return the directory two levels above this module file (its grandparent)."""
    module_file = Path(__file__).resolve()
    return module_file.parents[1]
22
+
23
+
24
+ def _strip_adacrop_prefix(path_text: str) -> str:
25
+ path_text = path_text.replace("\\", "/")
26
+ if path_text.startswith("./"):
27
+ path_text = path_text[2:]
28
+ if path_text.startswith("Adacrop/"):
29
+ path_text = path_text[len("Adacrop/") :]
30
+ return path_text
31
+
32
+
33
def resolve_image_path(raw_path: str, adacrop_root: Path, source_file: Optional[Path] = None) -> Path:
    """Resolve mixed project paths, including JSONL paths like ./outpainted/a.png.

    An ordered list of candidate locations is built and the first one that
    exists on disk is returned (resolved). If none exists, the first
    candidate is returned resolved so the caller still gets a concrete path.

    Args:
        raw_path: Path text as stored in a JSON/JSONL record.
        adacrop_root: Project root used as the base for relative candidates.
        source_file: Optional file the record came from; its directory is
            tried as a base for relative paths.
    """
    raw = str(raw_path).replace("\\", "/")
    stripped = _strip_adacrop_prefix(raw)
    candidates: List[Path] = []

    as_given = Path(raw)
    if as_given.is_absolute():
        candidates.append(as_given)

    if source_file is not None:
        base_dir = source_file.parent
        candidates.append(base_dir / raw)
        if raw.startswith("./"):
            candidates.append(base_dir / raw[2:])

    candidates.append(adacrop_root / stripped)
    candidates.append(adacrop_root.parent / raw)

    # Old merged JSONs may contain Adacrop/data/outpainted/foo.png, while this
    # workspace stores those files under data/outpainted_dataset/outpainted.
    legacy_prefix = "data/outpainted/"
    if stripped.startswith(legacy_prefix):
        tail = stripped[len(legacy_prefix):]
        candidates.append(adacrop_root / "data" / "outpainted_dataset" / "outpainted" / tail)

    # The outpainted JSONL stores paths as ./outpainted/foo.png relative to the
    # JSONL file: data/outpainted_dataset/training_pairs.jsonl.
    if stripped.startswith("outpainted/"):
        candidates.append(adacrop_root / "data" / "outpainted_dataset" / stripped)

    first_existing = next((cand for cand in candidates if cand.exists()), None)
    chosen = candidates[0] if first_existing is None else first_existing
    return chosen.resolve()
66
+
67
+
68
def normalize_boxes(value) -> List[List[float]]:
    """Flatten a loosely-typed box annotation into a list of 4-float boxes.

    Accepted shapes:
      * dict with ``x1, y1, x2, y2`` -> one box, values taken as-is;
      * dict with ``x, y, w, h``     -> one box ``[x, y, x + w, y + h]``;
      * list/tuple of exactly four ints/floats -> one box;
      * any other list/tuple -> each item is normalized recursively.

    Anything else (including ``None`` and unrecognized dicts) yields ``[]``.
    """
    if isinstance(value, dict):
        corner_keys = ("x1", "y1", "x2", "y2")
        if all(key in value for key in corner_keys):
            return [[float(value[key]) for key in corner_keys]]
        if all(key in value for key in ("x", "y", "w", "h")):
            left, top = float(value["x"]), float(value["y"])
            box_w, box_h = float(value["w"]), float(value["h"])
            return [[left, top, left + box_w, top + box_h]]
        return []

    if isinstance(value, (list, tuple)):
        is_flat_box = len(value) == 4 and all(isinstance(item, (int, float)) for item in value)
        if is_flat_box:
            return [[float(item) for item in value]]
        return [box for item in value for box in normalize_boxes(item)]

    return []
86
+
87
+
88
def canonical_box_xyxy(box: Sequence[float], width: int, height: int, img_path: Optional[str] = None) -> List[float]:
    """Return a pixel-space [x1,y1,x2,y2] box.

    The outpainted JSONL is xyxy, while the CUHK split files in this workspace
    use yxyx-like coordinates. Use the image path when it is unambiguous, then
    fall back to bounds checks.

    Args:
        box: Four numbers in either xyxy or yxyx order.
        width: Image width in pixels.
        height: Image height in pixels.
        img_path: Optional image path; "cuhk_images" selects yxyx, while
            "outpainted" or "gaic_dataset" selects xyxy.

    Returns:
        ``[x1, y1, x2, y2]`` sorted and clamped to the image. A box that
        collapses after clamping is widened to one pixel, shifting the near
        edge inward when the box sits on the far image border.
    """
    a, b, c, d = [float(v) for v in box]
    path_text = (img_path or "").replace("\\", "/").lower()

    if "cuhk_images" in path_text:
        x1, y1, x2, y2 = b, a, d, c
    elif "outpainted" in path_text or "gaic_dataset" in path_text:
        x1, y1, x2, y2 = a, b, c, d
    else:
        # No path hint: only swap when yxyx fits the image and xyxy does not.
        xyxy_valid = 0 <= a < c <= width and 0 <= b < d <= height
        yxyx_valid = 0 <= b < d <= width and 0 <= a < c <= height
        if yxyx_valid and not xyxy_valid:
            x1, y1, x2, y2 = b, a, d, c
        else:
            x1, y1, x2, y2 = a, b, c, d

    x1, x2 = sorted([x1, x2])
    y1, y2 = sorted([y1, y2])
    x1 = min(max(0.0, x1), float(width))
    x2 = min(max(0.0, x2), float(width))
    y1 = min(max(0.0, y1), float(height))
    y2 = min(max(0.0, y2), float(height))
    if x2 <= x1:
        x2 = min(float(width), x1 + 1.0)
        # When x1 already sits on the right border, x1 + 1 is clamped back to
        # x1; pull x1 inward so the box keeps a non-zero width.
        x1 = min(x1, max(0.0, x2 - 1.0))
    if y2 <= y1:
        y2 = min(float(height), y1 + 1.0)
        # Same correction for a box collapsed onto the bottom border.
        y1 = min(y1, max(0.0, y2 - 1.0))
    return [x1, y1, x2, y2]
121
+
122
+
123
def load_records(path: Path, adacrop_root: Path, require_images: bool = True) -> List[Dict]:
    """Load annotation rows from a ``.jsonl`` or JSON file into record dicts.

    A ``.jsonl`` file is parsed one non-blank line at a time; any other
    suffix is parsed as a single JSON document and iterated directly.

    Each returned record has ``img`` (resolved path string), ``boxes``
    (normalized from the first truthy of ``box``, ``boxes``, ``orig_bbox``)
    and ``raw`` (the untouched row). Rows without ``img``/``file`` are
    skipped, as are rows whose image is missing when ``require_images`` is
    true.
    """
    path = Path(path)
    with path.open("r", encoding="utf-8") as f:
        if path.suffix.lower() == ".jsonl":
            stripped_lines = (line.strip() for line in f)
            rows = [json.loads(text) for text in stripped_lines if text]
        else:
            rows = json.load(f)

    records: List[Dict] = []
    for row in rows:
        raw_img = row.get("img") or row.get("file")
        if not raw_img:
            continue
        img_path = resolve_image_path(raw_img, adacrop_root, source_file=path)
        if require_images and not img_path.exists():
            continue
        raw_boxes = row.get("box") or row.get("boxes") or row.get("orig_bbox")
        records.append({"img": str(img_path), "boxes": normalize_boxes(raw_boxes), "raw": row})
    return records
147
+
148
+
149
def resnet50_no_weights():
    """Build a torchvision ResNet-50 without pretrained weights.

    Tries the ``weights=None`` keyword first and falls back to
    ``pretrained=False`` if that keyword raises ``TypeError``.
    """
    builder = models.resnet50
    try:
        net = builder(weights=None)
    except TypeError:
        net = builder(pretrained=False)
    return net
154
+
155
+
156
def mobilenet_v3_no_weights(arch: str):
    """Build an untrained torchvision MobileNetV3 for the given ``arch``.

    Args:
        arch: ``"mobilenet_v3_large"`` or ``"mobilenet_v3_small"``.

    Raises:
        ValueError: If ``arch`` is neither of the two supported names.
    """
    if arch not in ("mobilenet_v3_large", "mobilenet_v3_small"):
        raise ValueError(f"Unsupported student arch: {arch}")

    # Both supported names match the torchvision constructor of the same name.
    try:
        return getattr(models, arch)(weights=None)
    except TypeError:
        # Fallback for constructors that reject ``weights``.
        return getattr(models, arch)(pretrained=False)
168
+
169
+
170
class TeacherActorCritic(nn.Module):
    """ResNet-50 teacher with actor, critic and bbox-regression heads.

    The actor and critic consume the 2048-d backbone feature concatenated
    with a 4-value state vector; the bbox head consumes the backbone feature
    alone. Attribute names and layer order define the state_dict keys and
    must stay as they are for checkpoints to load.
    """

    def __init__(self, n_actions: int = len(ACTIONS)):
        super().__init__()
        feat_dim = 2048
        head_in = feat_dim + 4  # backbone feature + 4-value state vector

        trunk = resnet50_no_weights()
        trunk.fc = nn.Identity()  # expose pooled features instead of class logits
        self.backbone = trunk

        self.actor = nn.Sequential(
            nn.Linear(head_in, 1024),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(1024, 512),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(512, n_actions),
        )
        self.critic = nn.Sequential(
            nn.Linear(head_in, 1024),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(1024, 512),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(512, 1),
        )
        self.bbox_head = nn.Sequential(
            nn.Linear(feat_dim, 512),
            nn.ReLU(),
            nn.Linear(512, 4),
        )

    def forward(self, img_tensor: torch.Tensor, state: torch.Tensor):
        """Return ``(action_probs, value)``; probabilities are softmaxed over dim 1."""
        joint = torch.cat([self.backbone(img_tensor), state], dim=1)
        action_probs = F.softmax(self.actor(joint), dim=1)
        return action_probs, self.critic(joint)

    def backbone_forward(self, img_tensor: torch.Tensor):
        """Return the raw 4-value bbox head output (no activation applied here)."""
        return self.bbox_head(self.backbone(img_tensor))
205
+
206
+
207
class MobileNetPolicy(nn.Module):
    """MobileNetV3 student with an actor head and a bbox-regression head.

    Reuses the torchvision ``features`` and ``avgpool`` modules; the feature
    width is read from the first layer of the torchvision classifier.
    Attribute names and layer order define the state_dict keys.
    """

    def __init__(self, arch: str = "mobilenet_v3_small", n_actions: int = len(ACTIONS)):
        super().__init__()
        base = mobilenet_v3_no_weights(arch)
        self.arch = arch
        self.features = base.features
        self.avgpool = base.avgpool
        feat_dim = base.classifier[0].in_features
        actor_in = feat_dim + 4  # pooled feature + 4-value state vector

        self.actor = nn.Sequential(
            nn.Linear(actor_in, 512),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(256, n_actions),
        )
        self.bbox_head = nn.Sequential(
            nn.Linear(feat_dim, 256),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(256, 4),
        )

    def extract_feats(self, img_tensor: torch.Tensor):
        """Run features + avgpool and flatten everything after the batch dim."""
        pooled = self.avgpool(self.features(img_tensor))
        return torch.flatten(pooled, 1)

    def forward(self, img_tensor: torch.Tensor, state: torch.Tensor):
        """Return ``(action_probs, logits)`` for the given crop and state."""
        joint = torch.cat([self.extract_feats(img_tensor), state], dim=1)
        logits = self.actor(joint)
        return F.softmax(logits, dim=1), logits

    def backbone_forward(self, img_tensor: torch.Tensor):
        """Return the bbox head output squashed to (0, 1) by a sigmoid."""
        raw_box = self.bbox_head(self.extract_feats(img_tensor))
        return torch.sigmoid(raw_box)
244
+
245
+
246
def load_teacher(ckpt_path: Path, device: torch.device) -> TeacherActorCritic:
    """Load a teacher checkpoint into a ``TeacherActorCritic`` in eval mode.

    The checkpoint may be a dict wrapping ``model_state_dict`` or a bare
    state dict. Loading is non-strict: unexpected keys are printed (first
    eight), and missing keys are tolerated only under ``critic.`` or
    ``bbox_head.``.

    Raises:
        RuntimeError: If any other key is missing from the checkpoint.
    """
    ckpt = torch_load_portable(ckpt_path)
    if isinstance(ckpt, dict):
        state_dict = ckpt.get("model_state_dict", ckpt)
    else:
        state_dict = ckpt

    model = TeacherActorCritic(n_actions=len(ACTIONS))
    missing, unexpected = model.load_state_dict(state_dict, strict=False)
    if unexpected:
        print(f"[teacher] unexpected keys: {unexpected[:8]}")

    optional_prefixes = ("critic.", "bbox_head.")
    missing_required = [key for key in missing if not key.startswith(optional_prefixes)]
    if missing_required:
        raise RuntimeError(f"Teacher checkpoint missing required keys: {missing_required[:8]}")
    return model.to(device).eval()
257
+
258
+
259
def load_student(ckpt_path: Path, device: torch.device, arch: Optional[str] = None) -> MobileNetPolicy:
    """Load a student checkpoint into a ``MobileNetPolicy`` in eval mode.

    The architecture stored under the checkpoint's ``arch`` key wins; the
    ``arch`` argument, then ``"mobilenet_v3_small"``, are only fallbacks.
    Weights come from ``model_state_dict`` if present, otherwise the
    checkpoint itself is treated as the state dict. Loading is strict.
    """
    ckpt = torch_load_portable(ckpt_path)
    fallback_arch = arch or "mobilenet_v3_small"
    student = MobileNetPolicy(arch=ckpt.get("arch", fallback_arch), n_actions=len(ACTIONS))
    student.load_state_dict(ckpt.get("model_state_dict", ckpt))
    return student.to(device).eval()
266
+
267
+
268
def torch_load_portable(ckpt_path: Path):
    """Load a checkpoint onto the CPU, tolerating pickled ``WindowsPath`` objects.

    NOTE(review): ``weights_only=False`` unpickles arbitrary objects; only
    use this on checkpoints from a trusted source.

    If the first attempt raises ``NotImplementedError`` mentioning
    "WindowsPath", ``pathlib.WindowsPath`` is rebound to
    ``pathlib.PosixPath`` and the load is retried. The rebinding is
    process-wide and is not undone afterwards.
    """

    def _load():
        return torch.load(ckpt_path, map_location="cpu", weights_only=False)

    try:
        return _load()
    except NotImplementedError as exc:
        if "WindowsPath" in str(exc):
            # Checkpoints saved on Windows may pickle pathlib.WindowsPath inside
            # metadata such as args. On POSIX, remap it before loading.
            pathlib.WindowsPath = pathlib.PosixPath
            return _load()
        raise
278
+
279
+
280
def xyxy_to_xywh(box: Sequence[float]) -> List[float]:
    """Convert ``[x1, y1, x2, y2]`` to ``[x, y, w, h]``.

    Corners are sorted per axis first, and width/height are floored at 1.0.
    """
    xa, ya, xb, yb = (float(v) for v in box)
    left, right = sorted([xa, xb])
    top, bottom = sorted([ya, yb])
    box_w = max(1.0, right - left)
    box_h = max(1.0, bottom - top)
    return [left, top, box_w, box_h]
285
+
286
+
287
def xywh_to_xyxy(box: Sequence[float]) -> List[float]:
    """Convert ``[x, y, w, h]`` to ``[x, y, x + w, y + h]`` (no sorting or clamping)."""
    left, top, box_w, box_h = (float(v) for v in box)
    right = left + box_w
    bottom = top + box_h
    return [left, top, right, bottom]
290
+
291
+
292
def box_iou_xyxy(a: Sequence[float], b: Sequence[float]) -> float:
    """Return the intersection-over-union of two ``[x1, y1, x2, y2]`` boxes.

    Negative extents count as zero area. Returns 0.0 when the union is
    at most 1e-8, which covers two degenerate boxes.
    """
    ax1, ay1, ax2, ay2 = (float(v) for v in a)
    bx1, by1, bx2, by2 = (float(v) for v in b)

    overlap_w = max(0.0, min(ax2, bx2) - max(ax1, bx1))
    overlap_h = max(0.0, min(ay2, by2) - max(ay1, by1))
    inter = overlap_w * overlap_h

    area_a = max(0.0, ax2 - ax1) * max(0.0, ay2 - ay1)
    area_b = max(0.0, bx2 - bx1) * max(0.0, by2 - by1)
    union = area_a + area_b - inter
    if union <= 1e-8:
        return 0.0
    return inter / union
303
+
304
+
305
def clamp_xywh(box: Sequence[float], width: int, height: int, delta: float = 0.05) -> List[float]:
    """Clamp an ``[x, y, w, h]`` box to the image and enforce a minimum size.

    The minimum side is ``max(10, 5% of the shorter image side)``. Size is
    bounded first, then the origin is moved so the box fits, and finally the
    size is re-bounded against the remaining room and ``delta`` times the
    image dimension.
    """
    x, y, w, h = (float(v) for v in box)
    img_w, img_h = float(width), float(height)
    min_size = max(10.0, min(width, height) * 0.05)

    # Pass 1: bound the size to [min_size, image dimension].
    w = max(min_size, min(w, img_w))
    h = max(min_size, min(h, img_h))

    # Move the origin so the box does not overhang the right/bottom border.
    x = min(max(0.0, x), img_w - w)
    y = min(max(0.0, y), img_h - h)

    # Pass 2: re-bound the size against the room left from the new origin.
    w = max(min_size, min(img_w - x, max(w, delta * width)))
    h = max(min_size, min(img_h - y, max(h, delta * height)))
    return [x, y, w, h]
315
+
316
+
317
def random_box(width: int, height: int) -> List[float]:
    """Sample a random ``[x, y, w, h]`` box with the image's aspect ratio.

    The longer image side is scaled by a uniform factor in [0.3, 0.8] and
    the other side follows from the aspect ratio (both floored at 10). The
    origin is drawn uniformly, then the box goes through ``clamp_xywh``.
    """
    ratio = width / max(1, height)
    scale = random.uniform(0.3, 0.8)
    landscape = ratio >= 1
    if landscape:
        box_w = max(10.0, width * scale)
        box_h = max(10.0, box_w / ratio)
    else:
        box_h = max(10.0, height * scale)
        box_w = max(10.0, box_h * ratio)
    # Upper bounds are floored at 1.0; clamp_xywh corrects any overhang.
    left = random.uniform(0.0, max(1.0, width - box_w))
    top = random.uniform(0.0, max(1.0, height - box_h))
    return clamp_xywh([left, top, box_w, box_h], width, height)
329
+
330
+
331
def jitter_box(box_xywh: Sequence[float], width: int, height: int, jitter: float = 0.12) -> List[float]:
    """Randomly perturb an ``[x, y, w, h]`` box and clamp it to the image.

    The origin shifts by up to ``jitter`` times the image width/height, and
    width/height are scaled by a factor in ``[1 - jitter, 1 + jitter]``.
    """
    x, y, w, h = (float(v) for v in box_xywh)
    # Draw order (x shift, y shift, w scale, h scale) fixes the RNG stream.
    shift_x = random.uniform(-jitter, jitter) * width
    x += shift_x
    shift_y = random.uniform(-jitter, jitter) * height
    y += shift_y
    w *= random.uniform(1.0 - jitter, 1.0 + jitter)
    h *= random.uniform(1.0 - jitter, 1.0 + jitter)
    return clamp_xywh([x, y, w, h], width, height)
338
+
339
+
340
def box_state(box_xywh: Sequence[float], width: int, height: int) -> torch.Tensor:
    """Encode an ``[x, y, w, h]`` box as a normalized ``(cx, cy, w, h)`` float32 tensor.

    Image dimensions are floored at 1.0 before dividing. If any component is
    non-finite, the fixed fallback ``[0.5, 0.5, 0.6, 0.6]`` is returned.
    """
    x, y, w, h = (float(v) for v in box_xywh)
    norm_w = max(1.0, width)
    norm_h = max(1.0, height)
    state = [
        (x + 0.5 * w) / norm_w,
        (y + 0.5 * h) / norm_h,
        w / norm_w,
        h / norm_h,
    ]
    if any(not math.isfinite(component) for component in state):
        state = [0.5, 0.5, 0.6, 0.6]
    return torch.tensor(state, dtype=torch.float32)
351
+
352
+
353
def render_crop(img: Image.Image, box_xywh: Sequence[float], img_size: int) -> torch.Tensor:
    """Crop ``img`` to the ``[x, y, w, h]`` box, resize to a square, and convert to a tensor."""
    left, top, box_w, box_h = (float(v) for v in box_xywh)
    region = (left, top, left + box_w, top + box_h)
    patch = img.crop(region)
    patch = patch.resize((img_size, img_size))
    return T.ToTensor()(patch)
357
+
358
+
359
def render_full_image(img: Image.Image, img_size: int) -> torch.Tensor:
    """Resize the whole image to ``img_size`` x ``img_size`` and convert to a tensor."""
    resized = img.resize((img_size, img_size))
    to_tensor = T.ToTensor()
    return to_tensor(resized)
361
+
362
+
363
def bbox_target_from_xyxy(box_xyxy: Sequence[float], width: int, height: int) -> torch.Tensor:
    """Encode a pixel ``[x1, y1, x2, y2]`` box as a normalized ``(cx, cy, w, h)`` target.

    Corners are sorted per axis, box extents are floored at one pixel, image
    dimensions are floored at 1.0, and every component is clipped to [0, 1].
    Returns a float32 tensor of four values.
    """
    xa, ya, xb, yb = (float(v) for v in box_xyxy)
    left, right = sorted([xa, xb])
    top, bottom = sorted([ya, yb])
    norm_w = max(1.0, width)
    norm_h = max(1.0, height)
    target = [
        ((left + right) * 0.5) / norm_w,
        ((top + bottom) * 0.5) / norm_h,
        max(1.0, right - left) / norm_w,
        max(1.0, bottom - top) / norm_h,
    ]
    clipped = [min(1.0, max(0.0, component)) for component in target]
    return torch.tensor(clipped, dtype=torch.float32)
374
+
375
+
376
def bbox_cxcywh_to_xyxy(box_cxcywh: Sequence[float], width: int, height: int) -> List[float]:
    """Convert a normalized ``(cx, cy, w, h)`` box to pixel ``[x1, y1, x2, y2]``.

    Each coordinate is clamped independently to the image bounds.
    """
    cx, cy, rel_w, rel_h = (float(v) for v in box_cxcywh)
    box_w = rel_w * width
    box_h = rel_h * height
    left = cx * width - 0.5 * box_w
    top = cy * height - 0.5 * box_h
    right = left + box_w
    bottom = top + box_h

    def _clip(value: float, upper: int) -> float:
        return min(max(0.0, value), float(upper))

    return [_clip(left, width), _clip(top, height), _clip(right, width), _clip(bottom, height)]
390
+
391
+
392
def step_box(box_xywh: Sequence[float], action_idx: int, width: int, height: int, delta: float = 0.05) -> List[float]:
    """Apply one action from ``ACTIONS`` to an ``[x, y, w, h]`` box.

    Moves shift the box by ``delta`` times its own width/height, bounded by
    the image. Zooms scale both sides by ``1 -/+ delta`` around the box
    centre. Any other action (e.g. "stop") leaves the box as is. The result
    is always passed through ``clamp_xywh``.
    """
    act = ACTIONS[int(action_idx)]
    x, y, w, h = (float(v) for v in box_xywh)
    step_x, step_y = delta * w, delta * h
    centre_x, centre_y = x + 0.5 * w, y + 0.5 * h

    if act == "left":
        x = max(0.0, x - step_x)
    elif act == "right":
        x = min(width - w, x + step_x)
    elif act == "up":
        y = max(0.0, y - step_y)
    elif act == "down":
        y = min(height - h, y + step_y)
    elif act in ("zoom_in", "zoom_out"):
        factor = 1.0 - delta if act == "zoom_in" else 1.0 + delta
        w *= factor
        h *= factor
        # Re-anchor so the centre computed before scaling is preserved.
        x = centre_x - 0.5 * w
        y = centre_y - 0.5 * h

    return clamp_xywh([x, y, w, h], width, height, delta=delta)
416
+
417
+
418
class PolicyStateDataset(Dataset):
    """Yield ``(crop_tensor, state_tensor)`` pairs for policy distillation.

    Each item picks a crop box for one image: either a jittered ground-truth
    box or a fully random box, then renders that crop and its normalized
    state vector.
    """

    def __init__(
        self,
        records: Sequence[Dict],
        img_size: int = 224,
        samples_per_image: int = 1,
        random_box_prob: float = 0.65,
        jitter: float = 0.12,
    ):
        self.records = list(records)
        self.img_size = int(img_size)
        # At least one sample per image, whatever the caller passes.
        self.samples_per_image = max(1, int(samples_per_image))
        self.random_box_prob = float(random_box_prob)
        self.jitter = float(jitter)

    def __len__(self) -> int:
        return self.samples_per_image * len(self.records)

    def __getitem__(self, idx: int):
        # Indices wrap so every image is visited ``samples_per_image`` times.
        rec = self.records[idx % len(self.records)]
        img = Image.open(rec["img"]).convert("RGB")
        width, height = img.size
        boxes = rec.get("boxes") or []

        # The RNG is consulted only when the record actually has boxes.
        use_gt_box = bool(boxes) and random.random() > self.random_box_prob
        if use_gt_box:
            gt_box = canonical_box_xyxy(random.choice(boxes), width, height, img_path=rec["img"])
            box = jitter_box(xyxy_to_xywh(gt_box), width, height, jitter=self.jitter)
        else:
            box = random_box(width, height)

        crop = render_crop(img, box, self.img_size)
        return crop, box_state(box, width, height)
449
+
450
+
451
class BBoxDataset(Dataset):
    """Yield ``(full_image_tensor, bbox_target)`` pairs for bbox training.

    Records without boxes are dropped at construction. Each item picks one
    of the record's boxes at random as the target.
    """

    def __init__(self, records: Sequence[Dict], img_size: int = 224, samples_per_image: int = 1):
        self.records = [rec for rec in records if rec.get("boxes")]
        self.img_size = int(img_size)
        self.samples_per_image = max(1, int(samples_per_image))

    def __len__(self) -> int:
        return self.samples_per_image * len(self.records)

    def __getitem__(self, idx: int):
        rec = self.records[idx % len(self.records)]
        img = Image.open(rec["img"]).convert("RGB")
        width, height = img.size
        picked = random.choice(rec["boxes"])
        box = canonical_box_xyxy(picked, width, height, img_path=rec["img"])
        image_tensor = render_full_image(img, self.img_size)
        return image_tensor, bbox_target_from_xyxy(box, width, height)
466
+
467
+
468
class BBoxEvalDataset(Dataset):
    """Yield ``(full_image_tensor, all_bbox_targets)`` pairs for evaluation.

    Records without boxes are dropped. Each item stacks one normalized
    target per box in the record.
    NOTE(review): records with differing box counts produce tensors of
    differing first dimension — confirm the caller batches accordingly.
    """

    def __init__(self, records: Sequence[Dict], img_size: int = 224):
        self.records = [rec for rec in records if rec.get("boxes")]
        self.img_size = int(img_size)

    def __len__(self) -> int:
        return len(self.records)

    def __getitem__(self, idx: int):
        rec = self.records[idx]
        img = Image.open(rec["img"]).convert("RGB")
        width, height = img.size

        per_box_targets = []
        for box in rec["boxes"]:
            pixel_box = canonical_box_xyxy(box, width, height, img_path=rec["img"])
            per_box_targets.append(bbox_target_from_xyxy(pixel_box, width, height))
        targets = torch.stack(per_box_targets)

        return render_full_image(img, self.img_size), targets
487
+
488
+
489
def soften_probs(probs: torch.Tensor, temperature: float) -> torch.Tensor:
    """Flatten a probability distribution by raising it to ``1 / temperature``.

    For ``temperature <= 1.0`` the input is returned unchanged. Otherwise
    probabilities are floored at 1e-8, raised to the power, and
    renormalized along dim 1.
    """
    if temperature > 1.0:
        powered = probs.clamp_min(1e-8).pow(1.0 / temperature)
        row_totals = powered.sum(dim=1, keepdim=True)
        return powered / row_totals
    return probs
student_best.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f8754f42dba8ec738701aaca6893803bd8ebb6ce212f75e42da8e6186c54ebb1
3
+ size 18336390
student_last.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:14da7c12373975c86deb5d99cecedb17a9e2c98a5868a38e5f78e53394203225
3
+ size 18336390
train_mobilenet_distill.py ADDED
@@ -0,0 +1,532 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import csv
3
+ import time
4
+ from itertools import cycle
5
+ from pathlib import Path
6
+
7
+ import torch
8
+ import torch.nn.functional as F
9
+ from torch.utils.data import DataLoader
10
+
11
+ from common import (
12
+ ACTIONS,
13
+ BBoxDataset,
14
+ BBoxEvalDataset,
15
+ MobileNetPolicy,
16
+ PolicyStateDataset,
17
+ bbox_cxcywh_to_xyxy,
18
+ box_iou_xyxy,
19
+ find_adacrop_root,
20
+ load_records,
21
+ load_teacher,
22
+ soften_probs,
23
+ )
24
+
25
+
26
def parse_args():
    """Parse command-line options for the two-stage distillation run.

    Default paths are derived from ``find_adacrop_root()``. Returns the
    ``argparse.Namespace`` produced by ``parse_args()``.
    """
    root = find_adacrop_root()
    parser = argparse.ArgumentParser(description="Two-stage distillation: BBox head + PPO actor policy.")
    add = parser.add_argument

    # Paths, architecture and stage selection.
    add("--teacher-ckpt", type=Path, default=root.parent / "ppo_best_val_final_score.pth")
    add("--train-jsonl", type=Path, default=root / "data" / "outpainted_dataset" / "training_pairs.jsonl")
    add("--val-json", type=Path, default=root / "data" / "splits" / "val_mixed.json")
    add("--output-dir", type=Path, default=root / "distillation" / "runs")
    add("--arch", choices=["mobilenet_v3_small", "mobilenet_v3_large"], default="mobilenet_v3_small")
    add("--resume-student", type=Path, default=None, help="Load an existing student checkpoint before training.")
    add("--skip-bbox-stage", action="store_true", help="Skip Stage 1 and go directly to Stage 2 policy distillation.")

    # Schedule, optimizer and data-loading settings.
    add("--bbox-epochs", type=int, default=5, help="Stage 1 epochs for bbox head distillation/supervision.")
    add("--epochs", type=int, default=10, help="Stage 2 epochs for actor policy distillation.")
    add("--batch-size", type=int, default=64)
    add("--bbox-batch-size", type=int, default=0, help="Stage 2 bbox regularization batch size; 0 uses --batch-size.")
    add("--lr", type=float, default=1e-4)
    add("--bbox-lr", type=float, default=1e-4)
    add("--weight-decay", type=float, default=1e-4)
    add("--num-workers", type=int, default=4)
    add("--pin-memory", action="store_true", help="Enable DataLoader pinned memory. Off by default to reduce Windows CUDA OOM risk.")
    add("--samples-per-image", type=int, default=1)
    add("--max-train-images", type=int, default=0)
    add("--max-val-images", type=int, default=512)
    add("--img-size", type=int, default=224)

    # Sampling and loss weighting.
    add("--random-box-prob", type=float, default=0.65)
    add("--jitter", type=float, default=0.12)
    add("--temperature", type=float, default=2.0)
    add("--ce-weight", type=float, default=0.25)
    add("--bbox-gt-weight", type=float, default=1.0)
    add("--bbox-teacher-weight", type=float, default=0.25)
    add("--stage2-bbox-weight", type=float, default=0.10)

    # Checkpointing, early stopping, reproducibility and device.
    add("--save-every", type=int, default=5)
    add("--patience", type=int, default=8, help="Stage 2 early-stop patience in epochs; <=0 disables.")
    add("--min-delta", type=float, default=1e-4)
    add("--seed", type=int, default=42)
    add("--device", default="cuda" if torch.cuda.is_available() else "cpu")
    return parser.parse_args()
65
+
66
+
67
def make_loader(dataset, batch_size, shuffle, num_workers, pin_memory=False, drop_last=False):
    """Build a ``DataLoader`` from the given settings; ``pin_memory`` is coerced to bool."""
    loader_kwargs = {
        "batch_size": batch_size,
        "shuffle": shuffle,
        "num_workers": num_workers,
        "pin_memory": bool(pin_memory),
        "drop_last": drop_last,
    }
    return DataLoader(dataset, **loader_kwargs)
76
+
77
+
78
def iou_from_cxcywh_batch(preds, targets):
    """Return the mean IoU between paired normalized ``(cx, cy, w, h)`` boxes.

    Both tensors are detached, moved to the CPU and clamped to [0, 1]; boxes
    are converted on a unit image. An empty batch yields 0.0.
    """
    pred_rows = preds.detach().cpu().clamp(0.0, 1.0)
    target_rows = targets.detach().cpu().clamp(0.0, 1.0)
    ious = [
        box_iou_xyxy(
            bbox_cxcywh_to_xyxy(pred.tolist(), 1, 1),
            bbox_cxcywh_to_xyxy(target.tolist(), 1, 1),
        )
        for pred, target in zip(pred_rows, target_rows)
    ]
    return sum(ious) / max(1, len(ious))
85
+
86
+
87
def best_iou_against_targets(pred_box, target_boxes):
    """Return the highest IoU between one predicted box and any of the target boxes.

    All boxes are normalized ``(cx, cy, w, h)`` and converted on a unit image.
    """
    pred_xyxy = bbox_cxcywh_to_xyxy(pred_box.tolist(), 1, 1)
    candidate_ious = (
        box_iou_xyxy(pred_xyxy, bbox_cxcywh_to_xyxy(target.tolist(), 1, 1))
        for target in target_boxes
    )
    return max(candidate_ious)
90
+
91
+
92
@torch.no_grad()
def validate_bbox(student, teacher, loader, device, bbox_gt_weight, bbox_teacher_weight):
    """Evaluate the student's bbox head against ground truth and the teacher.

    Both models are put in eval mode. Losses are smooth-L1; the combined
    loss weights the ground-truth and teacher terms by the given factors.
    All returned metrics are averaged per sample.

    Returns:
        Dict with ``bbox_loss``, ``bbox_gt_loss``, ``bbox_teacher_loss``,
        ``bbox_gt_iou``, ``bbox_teacher_iou`` and ``bbox_samples``.
    """
    student.eval()
    teacher.eval()
    total = 0
    total_loss = 0.0
    gt_loss_sum = 0.0
    teacher_loss_sum = 0.0
    gt_iou_sum = 0.0
    teacher_iou_sum = 0.0

    for imgs, targets in loader:
        imgs = imgs.to(device, non_blocking=True)
        targets = targets.to(device, non_blocking=True)
        preds = student.backbone_forward(imgs)
        teacher_preds = teacher.backbone_forward(imgs).clamp(0.0, 1.0)
        multi_gt = targets.ndim == 3

        if multi_gt:
            # Evaluation records can have multiple acceptable GT boxes. Use the
            # closest GT for loss, and best IoU for reporting.
            per_box_l1 = torch.abs(preds.unsqueeze(1) - targets).mean(dim=2)
            nearest = per_box_l1.argmin(dim=1)
            batch_rows = torch.arange(targets.size(0), device=targets.device)
            chosen_targets = targets[batch_rows, nearest]
        else:
            chosen_targets = targets

        gt_loss = F.smooth_l1_loss(preds, chosen_targets)
        teacher_loss = F.smooth_l1_loss(preds, teacher_preds)
        loss = bbox_gt_weight * gt_loss + bbox_teacher_weight * teacher_loss

        # Batch means are re-weighted by batch size so uneven batches average correctly.
        bs = imgs.size(0)
        total += bs
        total_loss += loss.item() * bs
        gt_loss_sum += gt_loss.item() * bs
        teacher_loss_sum += teacher_loss.item() * bs

        if multi_gt:
            preds_cpu = preds.detach().cpu().clamp(0.0, 1.0)
            teacher_cpu = teacher_preds.detach().cpu().clamp(0.0, 1.0)
            targets_cpu = targets.detach().cpu().clamp(0.0, 1.0)
            gt_iou_sum += sum(best_iou_against_targets(p, ts) for p, ts in zip(preds_cpu, targets_cpu))
            teacher_iou_sum += sum(best_iou_against_targets(p, ts) for p, ts in zip(teacher_cpu, targets_cpu))
        else:
            gt_iou_sum += iou_from_cxcywh_batch(preds, chosen_targets) * bs
            teacher_iou_sum += iou_from_cxcywh_batch(teacher_preds, chosen_targets) * bs

    denom = max(1, total)
    return {
        "bbox_loss": total_loss / denom,
        "bbox_gt_loss": gt_loss_sum / denom,
        "bbox_teacher_loss": teacher_loss_sum / denom,
        "bbox_gt_iou": gt_iou_sum / denom,
        "bbox_teacher_iou": teacher_iou_sum / denom,
        "bbox_samples": total,
    }
145
+
146
+
147
@torch.no_grad()
def validate_policy(student, teacher, loader, device, temperature):
    """Evaluate how closely the student's action policy matches the teacher's.

    Reports the temperature-scaled KL divergence (multiplied by
    ``temperature ** 2``), cross-entropy against the teacher's argmax
    action, and top-1 agreement, each averaged per sample.

    Returns:
        Dict with ``policy_kl``, ``policy_ce``, ``policy_top1_agreement``
        and ``policy_samples``.
    """
    student.eval()
    teacher.eval()
    total = 0
    total_kl = 0.0
    total_ce = 0.0
    total_agree = 0.0

    for imgs, states in loader:
        imgs = imgs.to(device, non_blocking=True)
        states = states.to(device, non_blocking=True)
        teacher_probs, _ = teacher(imgs, states)
        student_probs, student_logits = student(imgs, states)
        teacher_actions = teacher_probs.argmax(dim=1)

        target_probs = soften_probs(teacher_probs, temperature)
        student_log_probs = F.log_softmax(student_logits / temperature, dim=1)
        kl = F.kl_div(student_log_probs, target_probs, reduction="batchmean")
        kl = kl * (temperature * temperature)
        ce = F.cross_entropy(student_logits, teacher_actions)
        agree = (student_probs.argmax(dim=1) == teacher_actions).float().mean()

        bs = imgs.size(0)
        total += bs
        total_kl += kl.item() * bs
        total_ce += ce.item() * bs
        total_agree += agree.item() * bs

    denom = max(1, total)
    return {
        "policy_kl": total_kl / denom,
        "policy_ce": total_ce / denom,
        "policy_top1_agreement": total_agree / denom,
        "policy_samples": total,
    }
179
+
180
+
181
def save_ckpt(path, student, optimizer, args, epoch, stage, metrics):
    """Write a training checkpoint for the student model to ``path``.

    Args:
        path: Destination file for ``torch.save``.
        student: Model whose ``state_dict()`` is stored.
        optimizer: Optimizer whose state is stored, or ``None``.
        args: Parsed CLI namespace; ``args.arch`` and a copy of all values
            are stored.
        epoch: Epoch number recorded in the checkpoint.
        stage: Stage label recorded in the checkpoint.
        metrics: Metrics object stored as-is.

    ``pathlib.Path`` values in ``args`` are stored as plain strings. Pickled
    platform-specific Path objects cannot be instantiated on another OS,
    which breaks loading the checkpoint there. ``args`` itself is not
    modified.
    """
    portable_args = {
        key: str(value) if isinstance(value, Path) else value
        for key, value in vars(args).items()
    }
    torch.save(
        {
            "arch": args.arch,
            "epoch": epoch,
            "stage": stage,
            "model_state_dict": student.state_dict(),
            "optimizer_state_dict": optimizer.state_dict() if optimizer is not None else None,
            "args": portable_args,
            "metrics": metrics,
        },
        path,
    )
194
+
195
+
196
def load_student_checkpoint(student, ckpt_path: Path, device: torch.device):
    """Load weights from ``ckpt_path`` into ``student`` and move it to ``device``.

    Weights come from ``model_state_dict`` if present, otherwise the
    checkpoint itself is used as the state dict. Loading is non-strict:
    missing and unexpected keys are printed (first eight each) rather than
    raising. The student's train/eval mode is left unchanged.

    Returns:
        The same ``student`` module, moved to ``device``.
    """
    # Use the shared loader so a checkpoint saved on Windows, with pickled
    # WindowsPath metadata, can also be resumed on POSIX.
    from common import torch_load_portable

    ckpt = torch_load_portable(ckpt_path)
    state_dict = ckpt.get("model_state_dict", ckpt)
    missing, unexpected = student.load_state_dict(state_dict, strict=False)
    if missing:
        print(f"[resume] missing keys: {missing[:8]}")
    if unexpected:
        print(f"[resume] unexpected keys: {unexpected[:8]}")
    print(
        f"[resume] loaded student checkpoint: {ckpt_path} "
        f"(stage={ckpt.get('stage', 'unknown')}, epoch={ckpt.get('epoch', 'unknown')})"
    )
    return student.to(device)
209
+
210
+
211
def train_bbox_stage(args, student, teacher, train_loader, val_loader, device, run_dir, writer, csv_file):
    """Stage 1: train the student's bbox head on ground truth and teacher outputs.

    Each epoch trains on ``train_loader`` with a weighted sum of smooth-L1
    losses (ground truth and clamped teacher predictions), validates with
    ``validate_bbox``, writes one CSV row, and saves the "last" checkpoint.
    The "best" checkpoint is saved when validation ground-truth IoU improves
    by more than ``args.min_delta``.
    """
    print(f"[stage1] bbox distillation/supervision for {args.bbox_epochs} epoch(s)")
    optimizer = torch.optim.AdamW(student.parameters(), lr=args.bbox_lr, weight_decay=args.weight_decay)
    use_amp = device.type == "cuda"
    scaler = torch.amp.GradScaler("cuda", enabled=use_amp)
    best_iou = -1.0
    last_path = run_dir / "student_bbox_stage1_last.pth"
    best_path = run_dir / "student_bbox_stage1_best.pth"

    for epoch in range(1, args.bbox_epochs + 1):
        student.train()
        total = 0
        loss_sum = 0.0
        gt_loss_sum = 0.0
        teacher_loss_sum = 0.0

        for imgs, targets in train_loader:
            imgs = imgs.to(device, non_blocking=True)
            targets = targets.to(device, non_blocking=True)
            with torch.no_grad():
                teacher_targets = teacher.backbone_forward(imgs).clamp(0.0, 1.0)

            optimizer.zero_grad(set_to_none=True)
            with torch.amp.autocast("cuda", enabled=use_amp):
                preds = student.backbone_forward(imgs)
                gt_loss = F.smooth_l1_loss(preds, targets)
                teacher_loss = F.smooth_l1_loss(preds, teacher_targets)
                loss = args.bbox_gt_weight * gt_loss + args.bbox_teacher_weight * teacher_loss

            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()

            # Batch means are re-weighted by batch size for the epoch average.
            bs = imgs.size(0)
            total += bs
            loss_sum += loss.item() * bs
            gt_loss_sum += gt_loss.item() * bs
            teacher_loss_sum += teacher_loss.item() * bs

        val_bbox = validate_bbox(student, teacher, val_loader, device, args.bbox_gt_weight, args.bbox_teacher_weight)
        denom = max(1, total)
        row = {
            "stage": "bbox",
            "epoch": epoch,
            "train_loss": loss_sum / denom,
            "train_bbox_gt_loss": gt_loss_sum / denom,
            "train_bbox_teacher_loss": teacher_loss_sum / denom,
            "val_bbox_loss": val_bbox["bbox_loss"],
            "val_bbox_gt_loss": val_bbox["bbox_gt_loss"],
            "val_bbox_teacher_loss": val_bbox["bbox_teacher_loss"],
            "val_bbox_gt_iou": val_bbox["bbox_gt_iou"],
            "val_bbox_teacher_iou": val_bbox["bbox_teacher_iou"],
            "val_bbox_samples": val_bbox["bbox_samples"],
        }
        writer.writerow(row)
        csv_file.flush()

        save_ckpt(last_path, student, optimizer, args, epoch, "bbox", row)
        if val_bbox["bbox_gt_iou"] > best_iou + args.min_delta:
            best_iou = val_bbox["bbox_gt_iou"]
            save_ckpt(best_path, student, optimizer, args, epoch, "bbox", row)
            print(f"[stage1][save] best bbox: {run_dir / 'student_bbox_stage1_best.pth'}")

        print(
            f"[stage1][epoch {epoch}] loss={row['train_loss']:.4f} "
            f"val_bbox_iou={row['val_bbox_gt_iou']:.3f} "
            f"val_teacher_iou={row['val_bbox_teacher_iou']:.3f}"
        )
        # NOTE(review): original indentation was lost in the diff view; this
        # per-epoch placement of the cache release is assumed — confirm.
        if device.type == "cuda":
            torch.cuda.empty_cache()
277
+
278
+
279
def train_policy_stage(args, student, teacher, policy_loader, bbox_loader, val_policy_loader, val_bbox_loader, device, run_dir, writer, csv_file):
    """Run stage 2: distill the teacher's action policy into the student.

    Each optimisation step combines:

    * a KL term between the student's temperature-scaled log-softmax and the
      softened teacher distribution, multiplied by ``temperature ** 2``;
    * a cross-entropy term against the teacher's argmax action, weighted by
      ``args.ce_weight``;
    * optionally (``args.stage2_bbox_weight > 0`` and a non-empty
      ``bbox_loader``) a smooth-L1 bbox term on a batch drawn from
      ``bbox_loader``, weighted by ``args.stage2_bbox_weight``.

    After every epoch the student is validated, ``student_last.pth`` is
    written, ``student_best.pth`` is refreshed when validation top-1
    agreement improves by more than ``args.min_delta``, a periodic checkpoint
    is written every ``args.save_every`` epochs (when > 0), and one metrics
    row is appended through ``writer``. Training stops early after
    ``args.patience`` consecutive non-improving epochs (when > 0).

    Args:
        args: Parsed options. Reads ``epochs``, ``lr``, ``weight_decay``,
            ``temperature``, ``ce_weight``, ``stage2_bbox_weight``,
            ``bbox_gt_weight``, ``bbox_teacher_weight``, ``min_delta``,
            ``patience`` and ``save_every``.
        student: Model being trained; called as ``student(imgs, states)`` and
            ``student.backbone_forward(imgs)``.
        teacher: Frozen reference model; called as ``teacher(imgs, states)``
            under ``torch.no_grad()``.
        policy_loader: Iterable of ``(imgs, states)`` training batches.
        bbox_loader: Iterable of ``(imgs, targets)`` bbox batches used for the
            auxiliary regression term.
        val_policy_loader: Loader passed to ``validate_policy``.
        val_bbox_loader: Loader passed to ``validate_bbox``.
        device: ``torch.device`` on which training runs.
        run_dir: Directory (``pathlib.Path``) receiving checkpoints.
        writer: ``csv.DictWriter``-like object; ``writerow(row)`` is called
            once per epoch.
        csv_file: File object behind ``writer``; flushed after each row.
    """
    print(f"[stage2] actor policy distillation for {args.epochs} epoch(s)")
    optimizer = torch.optim.AdamW(student.parameters(), lr=args.lr, weight_decay=args.weight_decay)
    scaler = torch.amp.GradScaler("cuda", enabled=device.type == "cuda")
    # NOTE(review): if `cycle` is itertools.cycle, it replays the batches cached
    # during the first pass (no re-shuffling, batches kept in memory) - confirm
    # against the definition/import of `cycle` in this file.
    bbox_iter = cycle(bbox_loader) if args.stage2_bbox_weight > 0 and len(bbox_loader) > 0 else None

    best_agreement = -1.0
    epochs_without_improvement = 0

    for epoch in range(1, args.epochs + 1):
        student.train()
        total = 0  # number of policy samples seen this epoch
        bbox_total = 0  # number of bbox samples seen this epoch
        loss_sum = 0.0
        kl_sum = 0.0
        ce_sum = 0.0
        bbox_sum = 0.0
        agree_sum = 0.0

        for step, (imgs, states) in enumerate(policy_loader, start=1):
            imgs = imgs.to(device, non_blocking=True)
            states = states.to(device, non_blocking=True)

            # Teacher targets are computed without gradients.
            with torch.no_grad():
                teacher_probs, _ = teacher(imgs, states)
                target_probs = soften_probs(teacher_probs, args.temperature)
                hard_targets = teacher_probs.argmax(dim=1)

            # Defaults used when the auxiliary bbox term is disabled.
            bbox_loss = torch.zeros((), device=device)
            bbox_bs = 0
            if bbox_iter is not None:
                bbox_imgs, bbox_targets = next(bbox_iter)
                bbox_imgs = bbox_imgs.to(device, non_blocking=True)
                bbox_targets = bbox_targets.to(device, non_blocking=True)
                bbox_bs = bbox_imgs.size(0)

            optimizer.zero_grad(set_to_none=True)
            with torch.amp.autocast("cuda", enabled=device.type == "cuda"):
                student_probs, student_logits = student(imgs, states)
                kl = F.kl_div(F.log_softmax(student_logits / args.temperature, dim=1), target_probs, reduction="batchmean")
                # Usual distillation scaling of the softened KL term by T^2.
                kl = kl * (args.temperature * args.temperature)
                ce = F.cross_entropy(student_logits, hard_targets)
                policy_loss = kl + args.ce_weight * ce

                if bbox_iter is not None:
                    bbox_preds = student.backbone_forward(bbox_imgs)
                    bbox_loss = F.smooth_l1_loss(bbox_preds, bbox_targets)
                loss = policy_loss + args.stage2_bbox_weight * bbox_loss

            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()

            bs = imgs.size(0)
            total += bs
            loss_sum += loss.item() * bs
            kl_sum += kl.item() * bs
            ce_sum += ce.item() * bs
            # The bbox loss is a mean over the bbox batch, so it is weighted and
            # later normalised by the bbox sample count, not the policy one.
            bbox_total += bbox_bs
            bbox_sum += bbox_loss.item() * bbox_bs
            agree_sum += (student_probs.argmax(dim=1) == hard_targets).float().mean().item() * bs

            if step % 50 == 0:
                print(
                    f"[stage2][epoch {epoch}] step {step}/{len(policy_loader)} "
                    f"loss={loss_sum / total:.4f} kl={kl_sum / total:.4f} "
                    f"agree={agree_sum / total:.3f}"
                )

        val_policy = validate_policy(student, teacher, val_policy_loader, device, args.temperature)
        val_bbox = validate_bbox(student, teacher, val_bbox_loader, device, args.bbox_gt_weight, args.bbox_teacher_weight)
        row = {
            "stage": "policy",
            "epoch": epoch,
            "train_loss": loss_sum / max(1, total),
            "train_policy_kl": kl_sum / max(1, total),
            "train_policy_ce": ce_sum / max(1, total),
            "train_policy_top1_agreement": agree_sum / max(1, total),
            "train_stage2_bbox_loss": bbox_sum / max(1, bbox_total),
            "val_policy_kl": val_policy["policy_kl"],
            "val_policy_ce": val_policy["policy_ce"],
            "val_policy_top1_agreement": val_policy["policy_top1_agreement"],
            "val_policy_samples": val_policy["policy_samples"],
            "val_bbox_loss": val_bbox["bbox_loss"],
            "val_bbox_gt_iou": val_bbox["bbox_gt_iou"],
            "val_bbox_teacher_iou": val_bbox["bbox_teacher_iou"],
        }

        # Early-stopping bookkeeping is driven by validation top-1 agreement.
        improved = row["val_policy_top1_agreement"] > best_agreement + args.min_delta
        if improved:
            best_agreement = row["val_policy_top1_agreement"]
            epochs_without_improvement = 0
        else:
            epochs_without_improvement += 1
        should_stop = args.patience > 0 and epochs_without_improvement >= args.patience

        row["best_val_policy_top1_agreement"] = best_agreement
        row["epochs_without_improvement"] = epochs_without_improvement
        row["early_stop"] = bool(should_stop)

        save_ckpt(run_dir / "student_last.pth", student, optimizer, args, epoch, "policy", row)
        if improved:
            save_ckpt(run_dir / "student_best.pth", student, optimizer, args, epoch, "policy", row)
            print(f"[stage2][save] best policy: {run_dir / 'student_best.pth'}")
        if args.save_every > 0 and epoch % args.save_every == 0:
            path = run_dir / f"student_epoch_{epoch:03d}.pth"
            save_ckpt(path, student, optimizer, args, epoch, "policy", row)
            print(f"[stage2][save] periodic checkpoint: {path}")

        writer.writerow(row)
        # Flush so the metrics file stays current while training is running.
        csv_file.flush()

        print(
            f"[stage2][epoch {epoch}] loss={row['train_loss']:.4f} "
            f"val_agree={row['val_policy_top1_agreement']:.3f} "
            f"val_bbox_iou={row['val_bbox_gt_iou']:.3f} "
            f"best={best_agreement:.3f} stale={epochs_without_improvement}/{args.patience if args.patience > 0 else 'off'}"
        )

        if should_stop:
            print(f"[early-stop] no policy agreement improvement for {args.patience} epoch(s).")
            break
        if device.type == "cuda":
            torch.cuda.empty_cache()
401
+
402
+
403
def main():
    """Entry point: resolve data, build loaders and models, run both stages.

    Steps, in order: parse the CLI arguments and seed torch, create a
    timestamped run directory under ``args.output_dir``, load the train and
    (if the file exists) validation records, build the bbox and policy
    datasets and loaders, load the teacher and create the student, then run
    the bbox stage (unless skipped or ``bbox_epochs`` is 0) and the policy
    stage (when ``epochs`` > 0) while writing per-epoch rows to
    ``metrics.csv`` inside the run directory.

    Raises:
        RuntimeError: If no training records were resolved, or if the bbox
            training dataset is empty.
    """
    args = parse_args()
    torch.manual_seed(args.seed)
    device = torch.device(args.device)
    root = find_adacrop_root()

    run_dir = args.output_dir / f"{args.arch}_twostage_{time.strftime('%Y%m%d_%H%M%S')}"
    run_dir.mkdir(parents=True, exist_ok=True)

    train_records = load_records(args.train_jsonl, root, require_images=True)
    if args.val_json.exists():
        val_records = load_records(args.val_json, root, require_images=True)
    else:
        val_records = []

    # Optional truncation of both splits.
    if args.max_train_images > 0:
        train_records = train_records[: args.max_train_images]
    if args.max_val_images > 0:
        val_records = val_records[: args.max_val_images]
    if not train_records:
        raise RuntimeError("No training images were resolved. Check --train-jsonl and path handling.")

    print(f"[data] train images: {len(train_records)}")
    print(f"[data] val images: {len(val_records)}")
    print(f"[data] first train image: {train_records[0]['img']}")

    def eval_records():
        # Without validation records, evaluate on (at most) the first 256
        # training records instead.
        return val_records or train_records[:256]

    # Options common to the train and validation policy datasets.
    policy_ds_kwargs = {
        "img_size": args.img_size,
        "random_box_prob": args.random_box_prob,
        "jitter": args.jitter,
    }

    bbox_train_ds = BBoxDataset(train_records, img_size=args.img_size, samples_per_image=args.samples_per_image)
    bbox_val_ds = BBoxEvalDataset(eval_records(), img_size=args.img_size)
    policy_train_ds = PolicyStateDataset(
        train_records,
        samples_per_image=args.samples_per_image,
        **policy_ds_kwargs,
    )
    policy_val_ds = PolicyStateDataset(
        eval_records(),
        samples_per_image=1,
        **policy_ds_kwargs,
    )
    if len(bbox_train_ds) == 0:
        raise RuntimeError("No bbox labels found for Stage 1. Check box/orig_bbox fields.")

    # A non-positive --bbox-batch-size falls back to the policy batch size.
    bbox_batch_size = args.batch_size
    if args.bbox_batch_size > 0:
        bbox_batch_size = args.bbox_batch_size
    # Evaluation loaders use at most four workers.
    eval_workers = max(0, min(args.num_workers, 4))

    bbox_train_loader = make_loader(
        bbox_train_ds, bbox_batch_size, True, args.num_workers,
        pin_memory=args.pin_memory, drop_last=True,
    )
    bbox_val_loader = make_loader(
        bbox_val_ds, bbox_batch_size, False, eval_workers,
        pin_memory=args.pin_memory,
    )
    policy_train_loader = make_loader(
        policy_train_ds, args.batch_size, True, args.num_workers,
        pin_memory=args.pin_memory, drop_last=True,
    )
    policy_val_loader = make_loader(
        policy_val_ds, args.batch_size, False, eval_workers,
        pin_memory=args.pin_memory,
    )

    teacher = load_teacher(args.teacher_ckpt, device)
    student = MobileNetPolicy(arch=args.arch, n_actions=len(ACTIONS)).to(device)
    if args.resume_student is not None:
        student = load_student_checkpoint(student, args.resume_student, device)

    # Column order of metrics.csv; keys a stage does not emit are left blank
    # and unknown keys are dropped (extrasaction="ignore").
    fieldnames = ["stage", "epoch", "train_loss"]
    fieldnames += ["train_bbox_gt_loss", "train_bbox_teacher_loss"]
    fieldnames += [
        "train_policy_kl",
        "train_policy_ce",
        "train_policy_top1_agreement",
        "train_stage2_bbox_loss",
    ]
    fieldnames += [
        "val_bbox_loss",
        "val_bbox_gt_loss",
        "val_bbox_teacher_loss",
        "val_bbox_gt_iou",
        "val_bbox_teacher_iou",
        "val_bbox_samples",
    ]
    fieldnames += [
        "val_policy_kl",
        "val_policy_ce",
        "val_policy_top1_agreement",
        "val_policy_samples",
    ]
    fieldnames += [
        "best_val_policy_top1_agreement",
        "epochs_without_improvement",
        "early_stop",
    ]

    metrics_path = run_dir / "metrics.csv"
    with metrics_path.open("w", newline="", encoding="utf-8") as metrics_file:
        writer = csv.DictWriter(metrics_file, fieldnames=fieldnames, extrasaction="ignore")
        writer.writeheader()

        # Stage 1: bbox regression.
        if args.skip_bbox_stage:
            print("[stage1] skipped by --skip-bbox-stage")
        elif args.bbox_epochs > 0:
            train_bbox_stage(
                args, student, teacher, bbox_train_loader, bbox_val_loader,
                device, run_dir, writer, metrics_file,
            )

        # Stage 2: policy distillation.
        if args.epochs > 0:
            train_policy_stage(
                args, student, teacher,
                policy_train_loader, bbox_train_loader,
                policy_val_loader, bbox_val_loader,
                device, run_dir, writer, metrics_file,
            )

    print(f"[done] run dir: {run_dir}")
529
+
530
+
531
# Run the two-stage training pipeline only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == "__main__":
    main()