File size: 7,624 Bytes
864ba61
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
"""Stage 4 training loop.

Train the compact specialist student to reproduce the 100 target dims that
EUPE-ViT-B produces for each COCO train image, using the per-image raw image
(resized to 768) as input. Target tensor is pre-computed by prepare_targets.py.

Loss: MSE on the 100-D output.
Optimizer: AdamW.
Schedule: cosine with 3% warmup.

Saves (under OUT_DIR):
  student_ep{N}.safetensors  — student weights after each epoch N
  student_final.safetensors  — student weights after the last epoch
  training_log.json          — per-epoch loss + held-out F1 via Stage 0 classifier
"""
import os, sys, time, json, math
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
from pycocotools.coco import COCO
from safetensors.torch import save_file

HERE = os.path.dirname(os.path.abspath(__file__))
sys.path.insert(0, HERE)
from student import SpecialistStudent

# Root of the local COCO 2017 copy; train2017/, val2017/ and annotations/ are
# read from underneath it.
COCO_ROOT = '/home/zootest/datasets/coco'
# Pre-computed teacher targets, loaded with torch.load in main(). The pack is
# read through its 'img_ids' and 'targets' entries.
TARGETS = f'{COCO_ROOT}/stage4_teacher_targets/targets.pt'
# Stage 0 classifier JSON; eval_f1() reads its 'pos_dims' and 'neg_dims' lists.
CLASSIFIER = '/mnt/d/_tmp/1pc_repo/stage_0/classifier.json'
# Directory receiving the per-epoch checkpoints, the final weights and
# training_log.json (created by main() if missing).
OUT_DIR = '/mnt/d/_tmp/1pc_repo/stage_4'
# Device the student and every batch are moved to.
DEVICE = 'cuda'
# Images are resized to RES x RES before being fed to the student.
RES = 768
# DataLoader batch size.
BATCH = 16
# AdamW peak learning rate and weight decay.
LR = 3e-4
WD = 1e-4
# Number of passes over the training set.
EPOCHS = 10
# Fraction of the total optimisation steps spent in linear warmup.
WARMUP_FRAC = 0.03


class CocoImgDataset(torch.utils.data.Dataset):
    """COCO train2017 images paired with pre-computed teacher targets.

    Item ``i`` pairs the image whose id is ``targets_pack['img_ids'][i]`` with
    the row ``targets_pack['targets'][i]``. Items whose image id is unknown to
    the annotation file, or whose image cannot be read, are returned as
    ``None`` so that a filtering collate function can drop them.
    """

    # Per-channel mean/std used to normalise the RGB input, shaped (3, 1, 1) to
    # broadcast over (3, H, W). Built once instead of on every __getitem__.
    # (The values look like the usual ImageNet statistics.)
    _MEAN = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
    _STD = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)

    def __init__(self, coco_root, targets_pack):
        """Index the train2017 annotation file and keep the target pack.

        Args:
            coco_root: directory containing ``train2017/`` and
                ``annotations/instances_train2017.json``.
            targets_pack: mapping with an ``'img_ids'`` sequence and a
                ``'targets'`` tensor indexed in the same order.
        """
        self.root = f'{coco_root}/train2017'
        self.coco = COCO(f'{coco_root}/annotations/instances_train2017.json')
        self.img_ids = targets_pack['img_ids']
        self.targets = targets_pack['targets']
        # Build filename lookup
        self.id_to_file = {info['id']: info['file_name']
                           for info in self.coco.loadImgs(self.coco.getImgIds())}

    def __len__(self):
        """Return the number of (image id, target) pairs in the pack."""
        return len(self.img_ids)

    def __getitem__(self, i):
        """Return ``(x, target)`` for item ``i``, or ``None`` if unusable.

        ``x`` is a normalised float tensor of shape (3, RES, RES); ``target``
        is row ``i`` of the target tensor cast to float32.
        """
        img_id = self.img_ids[i]
        # Ids stored as a tensor / numpy array come back as 0-d scalars. A
        # torch tensor hashes by identity, so using it directly as a dict key
        # would never match and every sample would be silently dropped.
        if hasattr(img_id, 'item'):
            img_id = img_id.item()
        target = self.targets[i].float()
        fname = self.id_to_file.get(img_id)
        if fname is None:
            return None
        path = f'{self.root}/{fname}'
        try:
            img = Image.open(path).convert('RGB').resize((RES, RES), Image.BILINEAR)
        except Exception:
            # Deliberately broad: any unreadable or corrupt image is skipped
            # rather than aborting a long training run.
            return None
        # Copy so the array is writable and owned before handing it to torch.
        arr = np.asarray(img, dtype=np.uint8).copy()
        # HWC uint8 -> CHW float in [0, 1], then per-channel normalisation.
        x = torch.from_numpy(arr).permute(2, 0, 1).float() / 255.0
        x = (x - self._MEAN) / self._STD
        return x, target


def collate(batch):
    """Stack the usable samples of a batch, skipping ``None`` entries.

    Each usable sample is an ``(input, target)`` pair. Returns a tuple
    ``(inputs, targets)`` of stacked tensors, or ``None`` when no sample in
    the batch is usable.
    """
    usable = [sample for sample in batch if sample is not None]
    if not usable:
        return None
    inputs = torch.stack([x for x, _ in usable])
    targets = torch.stack([t for _, t in usable])
    return inputs, targets


def eval_f1(student, classifier_json):
    """Score the student as an image-level person detector on COCO val2017.

    The first 1000 val2017 image ids (in sorted order) are each resized to
    RES x RES, normalised and passed through ``student`` one at a time. The
    per-image score is the sum of the "positive" output entries minus the sum
    of the "negative" ones; the label is whether the image has a non-crowd
    annotation with ``category_id == 1``. A threshold sweep over the observed
    scores then picks the best F1.

    Args:
        student: model mapping a (1, 3, RES, RES) tensor on DEVICE to its
            output vector. The caller is responsible for putting it in eval
            mode.
        classifier_json: path to a JSON file with ``pos_dims`` and
            ``neg_dims`` lists; only their lengths are used here.

    Returns:
        ``(f1, threshold, precision, recall)`` for the best threshold, or
        ``(0, 0, 0, 0)`` if no threshold gives a positive F1.
    """
    with open(classifier_json) as fh:
        spec = json.load(fh)
    n_pos = len(spec['pos_dims'])
    n_neg = len(spec['neg_dims'])
    # The student's output is indexed as: the first n_pos entries are the
    # positive dims, the following n_neg entries are the negative dims.
    pos_idx = list(range(n_pos))
    neg_idx = list(range(n_pos, n_pos + n_neg))

    val = COCO(f'{COCO_ROOT}/annotations/instances_val2017.json')
    eval_ids = sorted(val.getImgIds())[:1000]
    file_of = {rec['id']: rec['file_name'] for rec in val.loadImgs(val.getImgIds())}
    norm_mean = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1).to(DEVICE)
    norm_std = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1).to(DEVICE)

    scores, labels = [], []
    with torch.inference_mode():
        for image_id in eval_ids:
            name = file_of.get(image_id)
            if name is None:
                continue
            picture = Image.open(f'{COCO_ROOT}/val2017/{name}').convert('RGB')
            picture = picture.resize((RES, RES), Image.BILINEAR)
            pixels = np.asarray(picture, dtype=np.uint8).copy()
            # HWC uint8 -> (1, C, H, W) float in [0, 1], then normalise.
            x = torch.from_numpy(pixels).permute(2, 0, 1).unsqueeze(0).to(DEVICE).float() / 255.0
            x = (x - norm_mean) / norm_std
            out = student(x).squeeze(0)
            margin = out[pos_idx].sum() - out[neg_idx].sum()
            scores.append(margin.item())
            anns = val.loadAnns(val.getAnnIds(imgIds=image_id, iscrowd=False))
            labels.append(any(a['category_id'] == 1 for a in anns))

    scores = torch.tensor(scores)
    labels = torch.tensor(labels, dtype=torch.bool)

    # Sweep candidate thresholds taken from the observed scores, subsampled so
    # that roughly 500 of them are tried at most.
    candidates = torch.unique(scores).sort().values
    stride = max(1, len(candidates) // 500)
    best = (0, 0, 0, 0)
    for thr in candidates.tolist()[::stride]:
        predicted = scores > thr
        tp = (predicted & labels).sum().float()
        fp = (predicted & ~labels).sum().float()
        fn = (~predicted & labels).sum().float()
        # clamp() guards the divisions when a denominator would be zero.
        precision = tp / (tp + fp).clamp(min=1)
        recall = tp / (tp + fn).clamp(min=1)
        f1 = (2 * precision * recall / (precision + recall).clamp(min=1e-9)).item()
        if f1 > best[0]:
            best = (f1, thr, precision.item(), recall.item())
    return best


def main():
    """Train the student against the pre-computed teacher targets.

    Loads the target pack, builds the dataset/loader, then runs EPOCHS epochs
    of AdamW with linear warmup followed by cosine decay. After every epoch
    the student is evaluated with eval_f1(), a checkpoint is written and
    training_log.json is rewritten; the weights after the last epoch are also
    saved as student_final.safetensors.
    """
    os.makedirs(OUT_DIR, exist_ok=True)
    print('[init] loading targets', flush=True)
    # NOTE: weights_only=False unpickles arbitrary objects, so TARGETS must be
    # a trusted, locally produced file.
    pack = torch.load(TARGETS, map_location='cpu', weights_only=False)
    print(f'  {pack["targets"].shape[0]} teacher targets', flush=True)

    train_set = CocoImgDataset(COCO_ROOT, pack)
    loader = torch.utils.data.DataLoader(
        train_set,
        batch_size=BATCH,
        shuffle=True,
        num_workers=4,
        pin_memory=True,
        collate_fn=collate,
        drop_last=True,
    )

    student = SpecialistStudent().to(DEVICE)
    print(f'[student] {sum(p.numel() for p in student.parameters())/1e6:.2f}M params', flush=True)

    total_steps = EPOCHS * len(loader)
    warmup_steps = int(total_steps * WARMUP_FRAC)

    def lr_factor(s):
        """Multiplier on LR at scheduler step s: linear warmup, then cosine."""
        if s < warmup_steps:
            return s / max(1, warmup_steps)
        return 0.5 * (1 + math.cos(
            math.pi * (s - warmup_steps) / max(1, total_steps - warmup_steps)))

    opt = torch.optim.AdamW(student.parameters(), lr=LR, weight_decay=WD)
    scheduler = torch.optim.lr_scheduler.LambdaLR(opt, lr_factor)

    history = {
        'epochs': [],
        'student_params': int(sum(p.numel() for p in student.parameters())),
    }
    step = 0
    t0 = time.time()
    for ep in range(EPOCHS):
        student.train()
        running_loss = 0.0
        batches_seen = 0
        for batch in loader:
            # collate() yields None when every sample of the batch was unusable.
            if batch is None:
                continue
            images, targets = batch
            images = images.to(DEVICE, non_blocking=True)
            targets = targets.to(DEVICE, non_blocking=True)
            # Forward pass in bf16; the loss itself is computed in float32.
            with torch.autocast('cuda', dtype=torch.bfloat16):
                pred = student(images)
            loss = F.mse_loss(pred.float(), targets)

            opt.zero_grad(set_to_none=True)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(student.parameters(), 1.0)
            opt.step()
            scheduler.step()

            running_loss += loss.item()
            batches_seen += 1
            step += 1
            if step % 200 == 0:
                print(f'  ep {ep+1}/{EPOCHS}  step {step}/{total_steps}  '
                      f'loss={loss.item():.4f}  lr={opt.param_groups[0]["lr"]:.2e}  '
                      f'{(time.time()-t0)/60:.1f} min', flush=True)

        avg = running_loss / max(1, batches_seen)
        student.eval()
        f1, thr, p, r = eval_f1(student, CLASSIFIER)
        print(f'[ep {ep+1}] loss={avg:.4f}  F1={f1:.4f}  P={p:.4f}  R={r:.4f}  θ={thr:.3f}',
              flush=True)
        history['epochs'].append({
            'epoch': ep + 1,
            'loss': avg,
            'F1': f1,
            'precision': p,
            'recall': r,
            'threshold': thr,
        })
        # Checkpoint and rewrite the log after every epoch so that an
        # interrupted run still leaves usable artefacts.
        save_file(student.state_dict(), f'{OUT_DIR}/student_ep{ep+1}.safetensors')
        with open(f'{OUT_DIR}/training_log.json', 'w') as f:
            json.dump(history, f, indent=2)

    # Weights after the last epoch.
    save_file(student.state_dict(), f'{OUT_DIR}/student_final.safetensors')
    print(f'[done] total time {(time.time()-t0)/60:.1f} min', flush=True)


# Run the training loop only when this file is executed as a script, not when
# it is imported.
if __name__ == '__main__':
    main()