File size: 8,126 Bytes
faf011c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
"""Stage 4C: direct classifier-score supervision.

Same 3.27M student architecture as Stage 4. Same 40-D output. But the loss
is on the *classifier score* rather than the per-dim values:

    student_score = student_out[pos_dims].sum() - student_out[neg_dims].sum()
    teacher_score = teacher_target[pos_dims].sum() - teacher_target[neg_dims].sum()
    loss = (student_score - teacher_score) ** 2

The student is optimized to match the teacher's scalar classifier score (and
hence its binary decision at the classifier threshold), not to reproduce the
teacher's feature geometry dim-by-dim. If the Stage 4B plateau at F1 0.723 was
caused by small per-dim errors accumulating into a miscalibrated scalar score,
this should close the gap.
"""
import os, sys, time, json, math
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
from pycocotools.coco import COCO
from safetensors.torch import save_file

# Directory containing this script (not referenced by the code shown in this file).
HERE = os.path.dirname(os.path.abspath(__file__))
# Make the Stage 4 code importable so the exact same student architecture is reused.
# Must run before the `from student import ...` line below.
sys.path.insert(0, '/mnt/d/_tmp/1pc_repo/stage_4')
from student import SpecialistStudent

# --- Paths (machine-specific absolute paths) ---
COCO_ROOT = '/home/zootest/datasets/coco'
# torch-saved pack read by main(): must contain 'img_ids' and 'targets' (N, 40), aligned by index.
TARGETS = f'{COCO_ROOT}/stage4_teacher_targets/targets.pt'
# Not read by the code shown in this file.
CLASSIFIER = '/mnt/d/_tmp/1pc_repo/stage_0/classifier.json'
# Checkpoints (*.safetensors) and training_log.json are written here.
OUT_DIR = '/mnt/d/_tmp/1pc_repo/stage_4c'
DEVICE = 'cuda'
# --- Training hyper-parameters ---
RES = 768            # images are resized to RES x RES for both training and evaluation
BATCH = 16
LR = 3e-4            # peak AdamW learning rate
WD = 1e-4            # AdamW weight decay
EPOCHS = 15
WARMUP_FRAC = 0.03   # fraction of total steps used for linear warmup before cosine decay


class CocoImgDataset(torch.utils.data.Dataset):
    """COCO train2017 images paired with precomputed 40-D teacher targets.

    Each item is ``(image, target)``:

    * ``image`` is a (3, RES, RES) float tensor normalized with the ImageNet
      mean/std.
    * ``target`` is the (40,) teacher vector cast to float32.

    Items whose image id has no known file, or whose file fails to load, are
    returned as ``None`` so that ``collate`` can drop them (best-effort
    loading).

    Args:
        coco_root: COCO root containing ``train2017/`` and
            ``annotations/instances_train2017.json``.
        pack: mapping with ``'img_ids'`` (length N) and ``'targets'``
            (N, 40), aligned by index.
    """

    # ImageNet normalization constants, built once instead of on every item.
    _MEAN = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
    _STD = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)

    def __init__(self, coco_root, pack):
        self.root = f'{coco_root}/train2017'
        coco = COCO(f'{coco_root}/annotations/instances_train2017.json')
        self.img_ids = pack['img_ids']
        self.targets = pack['targets']         # (N, 40)
        self.id_to_file = {i['id']: i['file_name'] for i in coco.loadImgs(coco.getImgIds())}

    def __len__(self):
        return len(self.img_ids)

    def __getitem__(self, i):
        # int() normalizes ids stored as tensor / numpy scalars. A 0-d torch
        # tensor hashes by identity, so it would never match the int keys of
        # id_to_file and every sample would be silently dropped.
        img_id = int(self.img_ids[i])
        target = self.targets[i].float()         # (40,)
        fname = self.id_to_file.get(img_id)
        if fname is None:
            return None
        try:
            with Image.open(f'{self.root}/{fname}') as im:
                img = im.convert('RGB').resize((RES, RES), Image.BILINEAR)
        except Exception:
            # Best-effort: unreadable/corrupt files are skipped by collate().
            return None
        arr = np.asarray(img, dtype=np.uint8).copy()
        x = torch.from_numpy(arr).permute(2, 0, 1).float() / 255.0
        return (x - self._MEAN) / self._STD, target


def collate(batch):
    """Stack dataset items into a batch, skipping samples that failed to load.

    Args:
        batch: list of ``(image, target)`` pairs or ``None`` placeholders.

    Returns:
        ``(images, targets)`` stacked along a new leading dim, or ``None`` when
        every entry in ``batch`` was ``None`` (or ``batch`` was empty).
    """
    kept = [item for item in batch if item is not None]
    if len(kept) == 0:
        return None
    images = torch.stack([img for img, _ in kept])
    targets = torch.stack([tgt for _, tgt in kept])
    return images, targets


# (n_images, category_id) -> list of (file_name, has_category); avoids
# re-parsing the val annotation JSON on every evaluation call.
_VAL_ITEMS_CACHE = {}


def _val_items(n_images, category_id):
    """Return cached ``(file_name, has_category)`` pairs for the first val2017 images.

    Images are taken in ascending id order. ``has_category`` is True when the
    image has at least one non-crowd annotation with ``category_id``.
    """
    key = (n_images, category_id)
    if key not in _VAL_ITEMS_CACHE:
        coco = COCO(f'{COCO_ROOT}/annotations/instances_val2017.json')
        id_to_file = {i['id']: i['file_name'] for i in coco.loadImgs(coco.getImgIds())}
        items = []
        for img_id in sorted(coco.getImgIds())[:n_images]:
            fname = id_to_file.get(img_id)
            if not fname:
                continue
            anns = coco.loadAnns(coco.getAnnIds(imgIds=img_id, iscrowd=False))
            items.append((fname, any(a['category_id'] == category_id for a in anns)))
        _VAL_ITEMS_CACHE[key] = items
    return _VAL_ITEMS_CACHE[key]


def _best_f1_threshold(scores, labels, max_thresholds=500):
    """Sweep thresholds and return ``(f1, threshold, precision, recall)`` at the best F1.

    A sample is predicted positive when ``score > threshold``. The candidates are:

    * a subsample of the unique scores (about ``max_thresholds`` of them);
    * one value strictly below the minimum score, so the "predict everything
      positive" operating point is also evaluated. Without it, the lowest-scoring
      sample is always predicted negative.

    Returns ``(0, 0, 0, 0)`` when ``scores`` is empty or no threshold gives F1 > 0.
    """
    best = (0, 0, 0, 0)
    if scores.numel() == 0:
        return best
    uniq = torch.unique(scores).sort().values
    candidates = uniq.tolist()[::max(1, len(uniq) // max_thresholds)]
    lowest = uniq[0].item()
    candidates.insert(0, lowest - 1.0 - abs(lowest))  # strictly below every score
    for t in candidates:
        pred = scores > t
        tp = (pred & labels).sum().float()
        fp = (pred & ~labels).sum().float()
        fn = (~pred & labels).sum().float()
        prec = tp / (tp + fp).clamp(min=1)
        rec = tp / (tp + fn).clamp(min=1)
        f1 = (2 * prec * rec / (prec + rec).clamp(min=1e-9)).item()
        if f1 > best[0]:  # strict: the first (lowest) threshold wins ties
            best = (f1, t, prec.item(), rec.item())
    return best


def eval_f1(student, pos_idx, neg_idx, n_images=500, category_id=1):
    """Evaluate the student as a binary image classifier on a fixed val2017 subset.

    The score of an image is ``out[pos_idx].sum() - out[neg_idx].sum()`` of the
    student output. The label is whether the image has a non-crowd annotation
    with ``category_id`` (1 is presumably 'person' in the standard COCO map).

    The threshold is chosen on this same subset, so the reported F1 is an
    optimistic, oracle-threshold number.

    The student's train/eval mode is restored before returning.

    Returns:
        ``(f1, threshold, precision, recall)`` at the best threshold.
    """
    items = _val_items(n_images, category_id)
    mean = torch.tensor([0.485, 0.456, 0.406], device=DEVICE).view(1, 3, 1, 1)
    std = torch.tensor([0.229, 0.224, 0.225], device=DEVICE).view(1, 3, 1, 1)
    scores, labels = [], []
    was_training = student.training
    student.eval()
    try:
        with torch.inference_mode():
            for fname, has_cat in items:
                with Image.open(f'{COCO_ROOT}/val2017/{fname}') as im:
                    img = im.convert('RGB').resize((RES, RES), Image.BILINEAR)
                arr = np.asarray(img, dtype=np.uint8).copy()
                x = torch.from_numpy(arr).permute(2, 0, 1).unsqueeze(0).to(DEVICE).float() / 255.0
                x = (x - mean) / std
                with torch.autocast('cuda', dtype=torch.bfloat16):
                    out = student(x).float()[0]
                scores.append((out[pos_idx].sum() - out[neg_idx].sum()).item())
                labels.append(has_cat)
    finally:
        student.train(was_training)
    return _best_f1_threshold(torch.tensor(scores), torch.tensor(labels, dtype=torch.bool))


def main():
    """Train the Stage 4C student with MSE on the classifier scalar score.

    Steps:

    * Load the teacher target pack from TARGETS.
    * Train SpecialistStudent for EPOCHS epochs with AdamW and a linear-warmup +
      cosine LR schedule.
    * After every epoch, evaluate F1 on a val2017 subset (``eval_f1``).
    * Write periodic/final safetensors checkpoints and ``training_log.json``
      to OUT_DIR.
    """
    os.makedirs(OUT_DIR, exist_ok=True)
    # NOTE(review): weights_only=False lets torch.load unpickle arbitrary
    # objects; only point TARGETS at a file produced by this project's pipeline.
    pack = torch.load(TARGETS, map_location='cpu', weights_only=False)
    print(f'[init] {pack["targets"].shape[0]} targets  shape {tuple(pack["targets"].shape)}',
          flush=True)

    # In the 40-D target vector, [0..19] are pos dims, [20..39] are neg dims (built that way by prepare_targets)
    pos_idx = torch.arange(0, 20, device=DEVICE)
    neg_idx = torch.arange(20, 40, device=DEVICE)

    # Pre-compute teacher scalar scores: (N,). Used here only for the stats
    # print; the training loop recomputes the per-batch scalar from the targets.
    targets_f = pack['targets'].float()
    teacher_scalar = targets_f[:, :20].sum(1) - targets_f[:, 20:].sum(1)
    pack['teacher_scalar'] = teacher_scalar
    print(f'[init] teacher scalar stats: mean={teacher_scalar.mean():.3f}  '
          f'std={teacher_scalar.std():.3f}', flush=True)

    ds = CocoImgDataset(COCO_ROOT, pack)
    loader = torch.utils.data.DataLoader(
        ds, batch_size=BATCH, shuffle=True, num_workers=4,
        pin_memory=True, collate_fn=collate, drop_last=True)

    student = SpecialistStudent().to(DEVICE)
    nparams = sum(p.numel() for p in student.parameters())
    print(f'[student] {nparams:,} params = {nparams/1e6:.2f}M', flush=True)

    total_steps = EPOCHS * len(loader)
    warmup = int(total_steps * WARMUP_FRAC)

    def lr_lambda(s):
        # Linear warmup from 0, then cosine decay to 0 over the remaining steps.
        if s < warmup:
            return s / max(1, warmup)
        return 0.5 * (1 + math.cos(math.pi * (s - warmup) / max(1, total_steps - warmup)))

    opt = torch.optim.AdamW(student.parameters(), lr=LR, weight_decay=WD)
    sched = torch.optim.lr_scheduler.LambdaLR(opt, lr_lambda)

    log = {'student_params': nparams, 'loss': 'MSE_on_classifier_scalar', 'epochs': []}
    step = 0; t0 = time.time()
    for ep in range(EPOCHS):
        student.train()
        ep_loss, n_batches = 0.0, 0
        for batch in loader:
            if batch is None:  # every sample in this batch failed to load
                continue
            x, y = batch
            x = x.to(DEVICE, non_blocking=True); y = y.to(DEVICE, non_blocking=True)
            with torch.autocast('cuda', dtype=torch.bfloat16):
                pred = student(x)                              # (B, 40)
            pred = pred.float()
            # Scalar classifier score: sum(pos dims) - sum(neg dims), for student and teacher.
            student_scalar = pred[:, :20].sum(1) - pred[:, 20:].sum(1)    # (B,)
            teacher_scalar_b = y[:, :20].sum(1) - y[:, 20:].sum(1)
            loss = F.mse_loss(student_scalar, teacher_scalar_b)
            opt.zero_grad(set_to_none=True)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(student.parameters(), 1.0)
            opt.step(); sched.step()
            ep_loss += loss.item(); n_batches += 1; step += 1
            if step % 500 == 0:
                print(f'  ep {ep+1}/{EPOCHS}  step {step}/{total_steps}  '
                      f'loss={loss.item():.4f}  lr={opt.param_groups[0]["lr"]:.2e}  '
                      f'{(time.time()-t0)/60:.1f} min', flush=True)
        if n_batches == 0:
            # Without this, an epoch with no usable data would silently log loss=0.0.
            print(f'[warn] ep {ep+1}: no valid batches (every sample failed to load); '
                  f'check COCO_ROOT and the img_ids in {TARGETS}', flush=True)
        avg = ep_loss / max(1, n_batches)
        f1, thr, p, r = eval_f1(student, pos_idx, neg_idx)
        print(f'[ep {ep+1}] loss={avg:.4f}  F1={f1:.4f}  P={p:.4f}  R={r:.4f}  '
              f'θ={thr:.3f}  {(time.time()-t0)/60:.1f} min', flush=True)
        log['epochs'].append({'epoch': ep + 1, 'loss': avg, 'batches': n_batches,
                              'F1': f1, 'precision': p, 'recall': r, 'threshold': thr})
        if (ep + 1) % 5 == 0 or ep == EPOCHS - 1:
            save_file(student.state_dict(), f'{OUT_DIR}/student_ep{ep+1}.safetensors')
        with open(f'{OUT_DIR}/training_log.json', 'w') as f:
            json.dump(log, f, indent=2)

    save_file(student.state_dict(), f'{OUT_DIR}/student_final.safetensors')
    print(f'[done] total {(time.time()-t0)/60:.1f} min', flush=True)


if __name__ == '__main__':
    main()