phanerozoic committed on
Commit
3a6f6a2
·
verified ·
1 Parent(s): a8ed813

Update README with pruning curve and dim20, fix person detector ext4 path

README.md CHANGED
@@ -161,13 +161,16 @@ The original design leads on precision. Additional scales, adaptive boundaries,
 
 Two Cofiber Threshold variants trained on full COCO 2017 train (117,266 images), 8 epochs, batch 64, AdamW lr 1e-3, cosine schedule with 3% warmup. Frozen EUPE-ViT-B backbone. Evaluated with pycocotools on the standard 5000-image val set.
 
-| Variant | Box regression | Params | Nonzero | mAP@[0.5:0.95] | mAP@0.50 | mAP@0.75 | mAP small | mAP medium | mAP large |
-|---------|---------------|--------|---------|----------------|----------|----------|-----------|------------|-----------|
-| linear_70k | 768→4 | 69,976 | 69,976 | 4.0 | 15.8 | 0.8 | 1.3 | 4.1 | 6.3 |
-| box32_92k | 768→32→4 | 91,640 | 91,640 | 5.7 | 20.6 | 1.3 | 2.5 | 6.4 | 7.9 |
-| box32_92k pruned | 768→32→4 | 91,640 | 76,640 | 5.7 | 20.7 | 1.3 | 2.5 | 6.5 | 8.0 |
-
-The pruned variant zeros 15,000 prototype weights with no mAP degradation. All three are the smallest detection heads to produce standard COCO mAP numbers. Both unpruned variants remain smaller than the NanoDet-m-0.5x detection head (94K parameters). Pruning is ongoing.
+| Variant | Box regression | Params | Nonzero | mAP@[0.5:0.95] | mAP@0.50 | mAP@0.75 |
+|---------|---------------|--------|---------|----------------|----------|----------|
+| linear_70k | 768→4 | 69,976 | 69,976 | 4.0 | 15.8 | 0.8 |
+| box32_92k | 768→32→4 | 91,640 | 91,640 | 5.7 | 20.6 | 1.3 |
+| box32 pruned R1 | 768→32→4 | 91,640 | 76,640 | 5.7 | 20.7 | 1.3 |
+| box32 pruned R2 | 768→32→4 | 91,640 | ~62,000 | **5.9** | 20.4 | **1.5** |
+| box32 pruned R3 | 768→32→4 | 91,640 | ~47,000 | 5.1 | 17.1 | 1.4 |
+| dim20 (training) | 768→20→16→4 | 22,076 | 22,076 | pending | — | — |
+
+Pruning improved mAP from 5.7 to 5.9 at R2 (~62K nonzero weights) by removing noisy prototype weights; R3 pushed past the degradation threshold. SVD analysis of the R2 prototypes showed effective rank ~20 at 72% energy retention, motivating the dim20 variant: a 768→20 bottleneck projection followed by 20→80 classification, initialized from the SVD vectors of the pruned prototypes. All variants are the smallest detection heads to produce standard COCO mAP numbers.
 
 ## Repository Structure
 
 
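The SVD initialization described above can be sketched in a few lines. This is a minimal illustration, not the repo's extraction script: `W` is a random stand-in for the pruned box32 prototype matrix, and the `project`/`cls_weight` key names mirror what `train.py` loads from `svd_init.pt`.

```python
import torch

# Illustrative sketch of the SVD-based dim20 init. W stands in for the
# pruned classifier prototype matrix (80 classes x 768 backbone dims).
torch.manual_seed(0)
W = torch.randn(80, 768)

U, S, Vh = torch.linalg.svd(W, full_matrices=False)
k = 20
energy = (S[:k] ** 2).sum() / (S ** 2).sum()   # energy retained at rank 20

project = Vh[:k]               # (20, 768): weight for nn.Linear(768, 20, bias=False)
cls_weight = U[:, :k] * S[:k]  # (80, 20): prototypes expressed in the bottleneck

# cls_weight @ project is exactly the best rank-20 approximation of W,
# so the dim20 head starts from the subspace the pruned prototypes spanned.
payload = {"project": project, "cls_weight": cls_weight}
```

Saving `payload` with `torch.save` would produce a file with the same structure `train.py` expects from `svd_init.pt`.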
heads/cofiber_threshold/dim20_20k/checkpoint.pth ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:acd116819d27c1b4c8cfeece6195d6d22abe31b3382394c2af44ec509b7bf7ef
+size 94325
heads/cofiber_threshold/dim20_20k/head.py ADDED
@@ -0,0 +1,66 @@
+"""Cofiber Threshold with dimension selection: 768→20→80 classification.
+
+The bottleneck dimension K=20 was selected from SVD analysis of the pruned
+prototype matrix, where rank 20 captures 72% of the energy. This is the
+information bottleneck variant applied to detection: how few feature dimensions
+does the backbone need to expose for 80-class detection?
+
+~20K total params.
+"""
+
+import math
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+def cofiber_decompose(f, n_scales):
+    cofibers = []
+    residual = f
+    for _ in range(n_scales - 1):
+        omega = F.avg_pool2d(residual, 2)
+        sigma_omega = F.interpolate(omega, size=residual.shape[2:], mode="bilinear", align_corners=False)
+        cofibers.append(residual - sigma_omega)
+        residual = omega
+    cofibers.append(residual)
+    return cofibers
+
+
+class CofiberThresholdDim20(nn.Module):
+    """Cofiber decomposition + 768→20 projection + 20→80 classification. ~20K params."""
+    name = "cofiber_threshold_dim20"
+    needs_intermediates = False
+
+    def __init__(self, feat_dim=768, bottleneck_dim=20, num_classes=80, n_scales=3, reg_hidden=16):
+        super().__init__()
+        self.n_scales = n_scales
+        self.scale_norms = nn.ModuleList([nn.LayerNorm(feat_dim) for _ in range(n_scales)])
+        # Bottleneck projection
+        self.project = nn.Linear(feat_dim, bottleneck_dim, bias=False)
+        # Classification from bottleneck
+        self.cls_weight = nn.Parameter(torch.randn(num_classes, bottleneck_dim) * 0.01)
+        self.cls_bias = nn.Parameter(torch.zeros(num_classes))
+        # Box regression from bottleneck (small hidden layer)
+        self.reg_hidden_layer = nn.Linear(bottleneck_dim, reg_hidden)  # name matches train.py state_dict keys
+        self.reg_act = nn.GELU()
+        self.reg_out = nn.Linear(reg_hidden, 4)
+        # Centerness from bottleneck
+        self.ctr_weight = nn.Parameter(torch.randn(1, bottleneck_dim) * 0.01)
+        self.ctr_bias = nn.Parameter(torch.zeros(1))
+        self.scale_params = nn.Parameter(torch.ones(n_scales))
+
+    def forward(self, spatial, inter=None):
+        cofibers = cofiber_decompose(spatial, self.n_scales)
+        cls_l, reg_l, ctr_l = [], [], []
+        for i, cof in enumerate(cofibers):
+            B, C, H, W = cof.shape
+            f = self.scale_norms[i](cof.permute(0, 2, 3, 1).reshape(-1, C))
+            z = self.project(f)  # (N, 20)
+            cls = (z @ self.cls_weight.T + self.cls_bias).reshape(B, H, W, -1).permute(0, 3, 1, 2)
+            reg_raw = (self.reg_out(self.reg_act(self.reg_hidden_layer(z))) * self.scale_params[i]).clamp(-10, 10)
+            reg = torch.exp(reg_raw).reshape(B, H, W, 4).permute(0, 3, 1, 2)
+            ctr = (z @ self.ctr_weight.T + self.ctr_bias).reshape(B, H, W, 1).permute(0, 3, 1, 2)
+            cls_l.append(cls)
+            reg_l.append(reg)
+            ctr_l.append(ctr)
+        return cls_l, reg_l, ctr_l
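The cofiber decomposition above is lossless: the detail maps plus the coarsest residual sum back (through the same bilinear upsampling) to the input exactly. A minimal check with a small random tensor standing in for backbone features:

```python
import torch
import torch.nn.functional as F

# Same decomposition as head.py: a residual pyramid of pooled differences.
def cofiber_decompose(f, n_scales):
    cofibers = []
    residual = f
    for _ in range(n_scales - 1):
        omega = F.avg_pool2d(residual, 2)
        sigma_omega = F.interpolate(omega, size=residual.shape[2:], mode="bilinear", align_corners=False)
        cofibers.append(residual - sigma_omega)  # detail lost by pooling at this scale
        residual = omega
    cofibers.append(residual)  # coarsest residual
    return cofibers

def up(x, size):
    return F.interpolate(x, size=size, mode="bilinear", align_corners=False)

f = torch.randn(1, 8, 16, 16)          # stand-in features (not real EUPE output)
c = cofiber_decompose(f, 3)            # scales: 16x16, 8x8, 4x4
# Reconstruct: f = c0 + up(c1 + up(c2))
recon = c[0] + up(c[1] + up(c[2], c[1].shape[2:]), c[0].shape[2:])
```

Each scale therefore sees only the information the next-coarser scale discards, which is why the per-scale heads can share one bottleneck projection.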
heads/cofiber_threshold/dim20_20k/svd_init.pt ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:6ec7e85396c35b2e0678220d2471286a7df583d4635b2553717fd107dd80a4b5
+size 69741
heads/cofiber_threshold/dim20_20k/train.py ADDED
@@ -0,0 +1,375 @@
+"""
+Train Cofiber Threshold Dim20 (22K params) on full COCO 2017 train.
+
+Initialized from SVD of the pruned prototype matrix — the projection starts
+from the top-20 directions the pruned prototypes identified as important.
+
+Same hyperparameters as box32: batch 64, lr 1e-3, cosine + 3% warmup, 8 epochs.
+"""
+
+import math
+import os
+import sys
+import time
+import json
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from PIL import Image
+from torch.utils.data import DataLoader, Dataset
+from torchvision.transforms import v2
+from torchvision.ops import generalized_box_iou
+
+SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
+sys.path.insert(0, os.path.join(SCRIPT_DIR, '..', '..', '..'))
+
+EUPE_REPO = os.environ.get("ARENA_BACKBONE_REPO", "/home/zootest/EUPE")
+EUPE_WEIGHTS = os.environ.get("ARENA_BACKBONE_WEIGHTS", "/home/zootest/weights/eupe_vitb/EUPE-ViT-B.pt")
+COCO_ROOT = os.environ.get("ARENA_COCO_ROOT", "/home/zootest/datasets/coco")
+OUTPUT_DIR = SCRIPT_DIR
+
+if EUPE_REPO not in sys.path:
+    sys.path.insert(0, EUPE_REPO)
+
+RESOLUTION = 640
+NUM_CLASSES = 80
+BATCH_SIZE = 64
+LR = 1e-3
+WEIGHT_DECAY = 1e-4
+EPOCHS = 8
+GRAD_CLIP = 5.0
+WARMUP_FRACTION = 0.03
+
+COCO_CONTIG_TO_CAT = [
+    1,2,3,4,5,6,7,8,9,10,11,13,14,15,16,17,18,19,20,21,22,23,24,25,27,28,31,32,
+    33,34,35,36,37,38,39,40,41,42,43,44,46,47,48,49,50,51,52,53,54,55,56,57,58,
+    59,60,61,62,63,64,65,67,70,72,73,74,75,76,77,78,79,80,81,82,84,85,86,87,88,89,90,
+]
+COCO_CAT_TO_CONTIG = {cat: i for i, cat in enumerate(COCO_CONTIG_TO_CAT)}
+
+
+def letterbox(image, res):
+    W0, H0 = image.size
+    scale = res / max(H0, W0)
+    new_w, new_h = int(round(W0 * scale)), int(round(H0 * scale))
+    resized = image.resize((new_w, new_h), Image.BILINEAR)
+    canvas = Image.new("RGB", (res, res), (0, 0, 0))
+    canvas.paste(resized, (0, 0))
+    return canvas, scale
+
+
+class COCODetection(Dataset):
+    def __init__(self, root, split="train"):
+        img_dir = os.path.join(root, f"{split}2017")
+        ann_file = os.path.join(root, "annotations", f"instances_{split}2017.json")
+        with open(ann_file) as f:
+            coco = json.load(f)
+        self.img_dir = img_dir
+        self.normalize = v2.Compose([
+            v2.ToImage(), v2.ToDtype(torch.float32, scale=True),
+            v2.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
+        ])
+        id_to_anns = {}
+        for a in coco["annotations"]:
+            if a["iscrowd"]:
+                continue
+            cat = a["category_id"]
+            if cat not in COCO_CAT_TO_CONTIG:
+                continue
+            id_to_anns.setdefault(a["image_id"], []).append(a)
+        self.items = []
+        id_to_info = {img["id"]: img for img in coco["images"]}
+        for iid, anns in id_to_anns.items():
+            info = id_to_info[iid]
+            boxes, labels = [], []
+            for a in anns:
+                x, y, w, h = a["bbox"]
+                if w < 1 or h < 1:
+                    continue
+                boxes.append([x, y, x + w, y + h])
+                labels.append(COCO_CAT_TO_CONTIG[a["category_id"]])
+            if boxes:
+                self.items.append({"file": info["file_name"], "boxes": boxes, "labels": labels})
+        print(f"  COCO {split}: {len(self.items)} images", flush=True)
+
+    def __len__(self):
+        return len(self.items)
+
+    def __getitem__(self, idx):
+        item = self.items[idx]
+        img = Image.open(os.path.join(self.img_dir, item["file"])).convert("RGB")
+        canvas, scale = letterbox(img, RESOLUTION)
+        x = self.normalize(canvas)
+        boxes = torch.tensor(item["boxes"], dtype=torch.float32) * scale
+        labels = torch.tensor(item["labels"], dtype=torch.long)
+        return x, boxes, labels
+
+
+def collate_fn(batch):
+    return torch.stack([b[0] for b in batch]), [b[1] for b in batch], [b[2] for b in batch]
+
+
+# Inline head
+def cofiber_decompose(f, n_scales):
+    cofibers = []
+    residual = f
+    for _ in range(n_scales - 1):
+        omega = F.avg_pool2d(residual, 2)
+        sigma_omega = F.interpolate(omega, size=residual.shape[2:], mode="bilinear", align_corners=False)
+        cofibers.append(residual - sigma_omega)
+        residual = omega
+    cofibers.append(residual)
+    return cofibers
+
+
+class CofiberThresholdDim20(nn.Module):
+    def __init__(self, feat_dim=768, bottleneck_dim=20, num_classes=80, n_scales=3, reg_hidden=16):
+        super().__init__()
+        self.n_scales = n_scales
+        self.scale_norms = nn.ModuleList([nn.LayerNorm(feat_dim) for _ in range(n_scales)])
+        self.project = nn.Linear(feat_dim, bottleneck_dim, bias=False)
+        self.cls_weight = nn.Parameter(torch.randn(num_classes, bottleneck_dim) * 0.01)
+        self.cls_bias = nn.Parameter(torch.zeros(num_classes))
+        self.reg_hidden_layer = nn.Linear(bottleneck_dim, reg_hidden)
+        self.reg_act = nn.GELU()
+        self.reg_out = nn.Linear(reg_hidden, 4)
+        self.ctr_weight = nn.Parameter(torch.randn(1, bottleneck_dim) * 0.01)
+        self.ctr_bias = nn.Parameter(torch.zeros(1))
+        self.scale_params = nn.Parameter(torch.ones(n_scales))
+
+    def forward(self, spatial):
+        cofibers = cofiber_decompose(spatial, self.n_scales)
+        cls_l, reg_l, ctr_l = [], [], []
+        for i, cof in enumerate(cofibers):
+            B, C, H, W = cof.shape
+            f = self.scale_norms[i](cof.permute(0, 2, 3, 1).reshape(-1, C))
+            z = self.project(f)
+            cls = (z @ self.cls_weight.T + self.cls_bias).reshape(B, H, W, -1).permute(0, 3, 1, 2)
+            reg_raw = (self.reg_out(self.reg_act(self.reg_hidden_layer(z))) * self.scale_params[i]).clamp(-10, 10)
+            reg = torch.exp(reg_raw).reshape(B, H, W, 4).permute(0, 3, 1, 2)
+            ctr = (z @ self.ctr_weight.T + self.ctr_bias).reshape(B, H, W, 1).permute(0, 3, 1, 2)
+            cls_l.append(cls); reg_l.append(reg); ctr_l.append(ctr)
+        return cls_l, reg_l, ctr_l
+
+
+# Inline loss (same as other scripts)
+def make_locations(feature_sizes, strides, device):
+    locs = []
+    for (h, w), s in zip(feature_sizes, strides):
+        ys = (torch.arange(h, device=device, dtype=torch.float32) + 0.5) * s
+        xs = (torch.arange(w, device=device, dtype=torch.float32) + 0.5) * s
+        gy, gx = torch.meshgrid(ys, xs, indexing="ij")
+        locs.append(torch.stack([gx.flatten(), gy.flatten()], -1))
+    return locs
+
+
+def assign_targets(locations, boxes, labels, strides, size_ranges):
+    cls_t, reg_t, ctr_t = [], [], []
+    if boxes.numel() == 0:
+        for loc in locations:
+            n = loc.shape[0]
+            cls_t.append(torch.full((n,), -1, dtype=torch.long, device=loc.device))
+            reg_t.append(torch.zeros(n, 4, device=loc.device))
+            ctr_t.append(torch.zeros(n, device=loc.device))
+        return cls_t, reg_t, ctr_t
+    areas = (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])
+    for loc, stride, sr in zip(locations, strides, size_ranges):
+        n = loc.shape[0]
+        l = loc[:, None, 0] - boxes[None, :, 0]
+        t = loc[:, None, 1] - boxes[None, :, 1]
+        r = boxes[None, :, 2] - loc[:, None, 0]
+        b = boxes[None, :, 3] - loc[:, None, 1]
+        ltrb = torch.stack([l, t, r, b], dim=-1)
+        in_box = ltrb.min(dim=-1).values > 0
+        cx = (boxes[:, 0] + boxes[:, 2]) / 2
+        cy = (boxes[:, 1] + boxes[:, 3]) / 2
+        rad = stride * 1.5
+        in_center = ((loc[:, None, 0] >= cx - rad) & (loc[:, None, 0] <= cx + rad) &
+                     (loc[:, None, 1] >= cy - rad) & (loc[:, None, 1] <= cy + rad))
+        max_d = ltrb.max(dim=-1).values
+        in_level = (max_d >= sr[0]) & (max_d <= sr[1])
+        pos = in_box & in_center & in_level
+        a = areas[None, :].expand_as(pos).clone()
+        a[~pos] = float("inf")
+        matched = a.argmin(dim=-1)
+        is_pos = a.gather(1, matched[:, None]).squeeze(1) < float("inf")
+        ct = torch.full((n,), -1, dtype=torch.long, device=loc.device)
+        ct[is_pos] = labels[matched[is_pos]]
+        rt = torch.zeros(n, 4, device=loc.device)
+        if is_pos.any():
+            rt[is_pos] = ltrb[torch.arange(n, device=loc.device)[is_pos], matched[is_pos]]
+        ctrt = torch.zeros(n, device=loc.device)
+        if is_pos.any():
+            lp, tp, rp, bp = rt[is_pos].unbind(-1)
+            ctrt[is_pos] = torch.sqrt(
+                (torch.minimum(lp, rp) / torch.maximum(lp, rp).clamp(min=1e-6)) *
+                (torch.minimum(tp, bp) / torch.maximum(tp, bp).clamp(min=1e-6)))
+        cls_t.append(ct); reg_t.append(rt); ctr_t.append(ctrt)
+    return cls_t, reg_t, ctr_t
+
+
+def focal_loss(logits, targets, alpha=0.25, gamma=2.0):
+    p = torch.sigmoid(logits)
+    ce = F.binary_cross_entropy_with_logits(logits, targets, reduction="none")
+    pt = p * targets + (1 - p) * (1 - targets)
+    at = alpha * targets + (1 - alpha) * (1 - targets)
+    return (at * (1 - pt) ** gamma * ce).sum()
+
+
+def compute_loss(cls_per, reg_per, ctr_per, locs_per, boxes_batch, labels_batch):
+    B = cls_per[0].shape[0]
+    device = cls_per[0].device
+    num_classes = cls_per[0].shape[1]
+    strides = [16, 32, 64]
+    size_ranges = [(-1, 128), (128, 256), (256, float("inf"))]
+    flat_cls, flat_reg, flat_ctr = [], [], []
+    for cl, rg, ct in zip(cls_per, reg_per, ctr_per):
+        b, c, h, w = cl.shape
+        flat_cls.append(cl.permute(0, 2, 3, 1).reshape(b, h * w, c))
+        flat_reg.append(rg.permute(0, 2, 3, 1).reshape(b, h * w, 4))
+        flat_ctr.append(ct.permute(0, 2, 3, 1).reshape(b, h * w))
+    pred_cls = torch.cat(flat_cls, 1)
+    pred_reg = torch.cat(flat_reg, 1)
+    pred_ctr = torch.cat(flat_ctr, 1)
+    all_locs = torch.cat(locs_per, 0)
+    all_ct, all_rt, all_ctt = [], [], []
+    for i in range(B):
+        ct, rt, ctt = assign_targets(locs_per, boxes_batch[i], labels_batch[i], strides, size_ranges)
+        all_ct.append(torch.cat(ct)); all_rt.append(torch.cat(rt)); all_ctt.append(torch.cat(ctt))
+    tgt_cls = torch.stack(all_ct)
+    tgt_reg = torch.stack(all_rt)
+    tgt_ctr = torch.stack(all_ctt)
+    pos = tgt_cls >= 0
+    npos = max(pos.sum().item(), 1)
+    oh = torch.zeros_like(pred_cls)
+    pi = pos.nonzero(as_tuple=True)
+    oh[pi[0], pi[1], tgt_cls[pos]] = 1.0
+    loss_cls = focal_loss(pred_cls.reshape(-1, num_classes), oh.reshape(-1, num_classes)) / npos
+    if pos.any():
+        pp = pred_reg[pos]; tp = tgt_reg[pos]; pl = all_locs[None].expand(B, -1, -1)[pos]
+        pb = torch.stack([pl[:,0]-pp[:,0], pl[:,1]-pp[:,1], pl[:,0]+pp[:,2], pl[:,1]+pp[:,3]], -1)
+        tb = torch.stack([pl[:,0]-tp[:,0], pl[:,1]-tp[:,1], pl[:,0]+tp[:,2], pl[:,1]+tp[:,3]], -1)
+        giou = generalized_box_iou(pb, tb)
+        loss_reg = (1 - giou.diagonal()).sum() / npos
+        loss_ctr = F.binary_cross_entropy_with_logits(pred_ctr[pos], tgt_ctr[pos], reduction="sum") / npos
+    else:
+        loss_reg = loss_ctr = torch.tensor(0.0, device=device)
+    return loss_cls + loss_reg + loss_ctr
+
+
+def train():
+    os.makedirs(OUTPUT_DIR, exist_ok=True)
+    print("=" * 60)
+    print("Cofiber Threshold Dim20: 22K params, SVD-initialized, 8 epochs")
+    print("=" * 60, flush=True)
+
+    print("\n[1/4] Loading backbone...", flush=True)
+    backbone = torch.hub.load(EUPE_REPO, "eupe_vitb16", source="local", weights=EUPE_WEIGHTS)
+    backbone = backbone.cuda().eval()
+    for p in backbone.parameters():
+        p.requires_grad = False
+
+    print("\n[2/4] Building head with SVD initialization...", flush=True)
+    head = CofiberThresholdDim20().cuda()
+
+    # Initialize from SVD of pruned prototypes
+    svd_init_path = os.path.join(SCRIPT_DIR, "svd_init.pt")
+    if os.path.isfile(svd_init_path):
+        svd_init = torch.load(svd_init_path, map_location="cuda")
+        head.project.weight.data = svd_init["project"]
+        head.cls_weight.data = svd_init["cls_weight"]
+        print("  SVD initialization loaded", flush=True)
+    else:
+        print("  No SVD init found, using random", flush=True)
+
+    n_params = sum(p.numel() for p in head.parameters())
+    print(f"  {n_params:,} params", flush=True)
+
+    print("\n[3/4] Loading COCO...", flush=True)
+    train_ds = COCODetection(COCO_ROOT, "train")
+    train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True,
+                              num_workers=4, pin_memory=True, drop_last=True, collate_fn=collate_fn)
+    steps_per_epoch = len(train_loader)
+    total_steps = steps_per_epoch * EPOCHS
+    warmup_steps = int(total_steps * WARMUP_FRACTION)
+    print(f"  {len(train_ds)} images, {steps_per_epoch} steps/epoch, {total_steps} total", flush=True)
+
+    optimizer = torch.optim.AdamW(head.parameters(), lr=LR, weight_decay=WEIGHT_DECAY)
+    def lr_lambda(step):
+        if step < warmup_steps:
+            return step / max(warmup_steps, 1)
+        progress = (step - warmup_steps) / max(total_steps - warmup_steps, 1)
+        return 0.5 * (1.0 + math.cos(math.pi * progress))
+    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
+
+    strides = [16, 32, 64]
+    H = RESOLUTION // 16
+    locs = make_locations([(H, H), (H//2, H//2), (H//4, H//4)], strides, torch.device("cuda"))
+
+    print("\n[4/4] Training...", flush=True)
+    log_file = open(os.path.join(OUTPUT_DIR, "train.log"), "a")
+    head.train()
+    global_step = 0
+    running_loss = 0.0
+    running_count = 0
+    t0 = time.time()
+
+    for epoch in range(EPOCHS):
+        print(f"\n  Epoch {epoch+1}/{EPOCHS} starting (step {global_step})", flush=True)
+        for images, boxes_b, labels_b in train_loader:
+            if global_step >= total_steps:
+                break
+            images = images.cuda(non_blocking=True)
+            boxes_b = [b.cuda(non_blocking=True) for b in boxes_b]
+            labels_b = [l.cuda(non_blocking=True) for l in labels_b]
+            try:
+                with torch.no_grad():
+                    with torch.autocast("cuda", dtype=torch.bfloat16):
+                        out = backbone.forward_features(images)
+                    patches = out["x_norm_patchtokens"].float()
+                B, N, D = patches.shape
+                h = w = int(N ** 0.5)
+                spatial = patches.permute(0, 2, 1).reshape(B, D, h, w)
+                cls_l, reg_l, ctr_l = head(spatial)
+                loss = compute_loss(cls_l, reg_l, ctr_l, locs, boxes_b, labels_b)
+                if torch.isnan(loss) or torch.isinf(loss):
+                    print(f"  WARNING: NaN/Inf loss at step {global_step}", flush=True)
+                    optimizer.zero_grad(); global_step += 1; scheduler.step(); continue
+                optimizer.zero_grad()
+                loss.backward()
+                torch.nn.utils.clip_grad_norm_(head.parameters(), GRAD_CLIP)
+                optimizer.step()
+                scheduler.step()
+                global_step += 1
+                running_loss += loss.item()
+                running_count += 1
+                if global_step % 100 == 0:
+                    elapsed = time.time() - t0
+                    avg = running_loss / max(running_count, 1)
+                    lr_now = scheduler.get_last_lr()[0]
+                    msg = f"step {global_step}/{total_steps} (epoch {epoch+1}) loss={loss.item():.4f} avg={avg:.4f} lr={lr_now:.2e} {running_count/elapsed:.1f} it/s"
+                    print(msg, flush=True)
+                    log_file.write(msg + "\n"); log_file.flush()
+                if global_step % 1000 == 0:
+                    torch.save({"head": head.state_dict(), "global_step": global_step},
+                               os.path.join(OUTPUT_DIR, "checkpoint.pth"))
+                    print(f"  Checkpoint saved at step {global_step}", flush=True)
+            except Exception as e:
+                import traceback
+                print(f"\n  ERROR at step {global_step}: {e}", flush=True)
+                traceback.print_exc()
+                if "out of memory" in str(e):
+                    torch.cuda.empty_cache(); optimizer.zero_grad(); global_step += 1; scheduler.step(); continue
+                raise
+        print(f"  Epoch {epoch+1}/{EPOCHS} complete (step {global_step})", flush=True)
+
+    final_path = os.path.join(OUTPUT_DIR, "cofiber_threshold_dim20_coco_8ep_22k.pth")
+    torch.save(head.state_dict(), final_path)
+    print(f"\nSaved: {final_path}")
+    print(f"Training complete: {total_steps} steps, {(time.time()-t0)/3600:.1f} hours", flush=True)
+    log_file.close()
+
+
+if __name__ == "__main__":
+    train()
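The warmup-plus-cosine schedule used by train.py can be checked in isolation. A stdlib-only sketch of the same `lr_lambda`, with an illustrative step count (the real run derives `total_steps` from the dataloader):

```python
import math

# Illustrative step counts; train.py computes these from the dataset.
total_steps = 1000
warmup_steps = int(total_steps * 0.03)  # 3% linear warmup

def lr_lambda(step):
    # Linear ramp from 0 to 1 over the warmup, then cosine decay to 0.
    if step < warmup_steps:
        return step / max(warmup_steps, 1)
    progress = (step - warmup_steps) / max(total_steps - warmup_steps, 1)
    return 0.5 * (1.0 + math.cos(math.pi * progress))
```

The multiplier reaches exactly 1.0 at the end of warmup and decays to 0 at `total_steps`, so the peak learning rate is the configured `LR = 1e-3`.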
heads/cofiber_threshold_person/linear_9k/train.py CHANGED
@@ -22,7 +22,7 @@ sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', '..', '..'))
 
 EUPE_REPO = os.environ.get("ARENA_BACKBONE_REPO", "/home/zootest/EUPE")
 EUPE_WEIGHTS = os.environ.get("ARENA_BACKBONE_WEIGHTS", "/home/zootest/weights/eupe_vitb/EUPE-ViT-B.pt")
-COCO_ROOT = os.environ.get("ARENA_COCO_ROOT", "/mnt/d/JacobProject/datasets/llava_instruct/coco")
+COCO_ROOT = os.environ.get("ARENA_COCO_ROOT", "/home/zootest/datasets/coco")
 OUTPUT_DIR = os.path.join(os.path.dirname(__file__))
 
 if EUPE_REPO not in sys.path: