Update README with pruning curve and dim20, fix person detector ext4 path
Browse files
README.md
CHANGED
|
@@ -161,13 +161,16 @@ The original design leads on precision. Additional scales, adaptive boundaries,
|
|
| 161 |
|
| 162 |
Two Cofiber Threshold variants trained on full COCO 2017 train (117,266 images), 8 epochs, batch 64, AdamW lr 1e-3, cosine schedule with 3% warmup. Frozen EUPE-ViT-B backbone. Evaluated with pycocotools on the standard 5000-image val set.
|
| 163 |
|
| 164 |
-
| Variant | Box regression | Params | Nonzero | mAP@[0.5:0.95] | mAP@0.50 | mAP@0.75 |
|
| 165 |
-
|---------|---------------|--------|---------|----------------|----------|----------|
|
| 166 |
-
| linear_70k | 768β4 | 69,976 | 69,976 | 4.0 | 15.8 | 0.8 |
|
| 167 |
-
| box32_92k | 768β32β4 | 91,640 | 91,640 | 5.7 | 20.6 | 1.3 |
|
| 168 |
-
|
|
| 169 |
-
|
| 170 |
-
|
|
|
|
|
|
|
|
|
|
| 171 |
|
| 172 |
## Repository Structure
|
| 173 |
|
|
|
|
| 161 |
|
| 162 |
Two Cofiber Threshold variants trained on full COCO 2017 train (117,266 images), 8 epochs, batch 64, AdamW lr 1e-3, cosine schedule with 3% warmup. Frozen EUPE-ViT-B backbone. Evaluated with pycocotools on the standard 5000-image val set.
|
| 163 |
|
| 164 |
+
| Variant | Box regression | Params | Nonzero | mAP@[0.5:0.95] | mAP@0.50 | mAP@0.75 |
|
| 165 |
+
|---------|---------------|--------|---------|----------------|----------|----------|
|
| 166 |
+
| linear_70k | 768→4 | 69,976 | 69,976 | 4.0 | 15.8 | 0.8 |
|
| 167 |
+
| box32_92k | 768→32→4 | 91,640 | 91,640 | 5.7 | 20.6 | 1.3 |
|
| 168 |
+
| box32 pruned R1 | 768→32→4 | 91,640 | 76,640 | 5.7 | 20.7 | 1.3 |
|
| 169 |
+
| box32 pruned R2 | 768→32→4 | 91,640 | ~62,000 | **5.9** | 20.4 | **1.5** |
|
| 170 |
+
| box32 pruned R3 | 768→32→4 | 91,640 | ~47,000 | 5.1 | 17.1 | 1.4 |
|
| 171 |
+
| dim20 (training) | 768→20→16→4 | 22,076 | 22,076 | pending | — | — |
|
| 172 |
+
|
| 173 |
+
Pruning improved mAP from 5.7 to 5.9 at R2 (~62K nonzero) by removing noisy prototype weights. R3 pushed past the degradation threshold. SVD analysis of the R2 prototypes showed effective rank ~20 for 72% energy retention, motivating the dim20 variant: a 768→20 bottleneck projection followed by 20→80 classification, initialized from the SVD vectors of the pruned prototypes. All variants are the smallest detection heads to produce standard COCO mAP numbers.
|
| 174 |
|
| 175 |
## Repository Structure
|
| 176 |
|
heads/cofiber_threshold/dim20_20k/checkpoint.pth
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:acd116819d27c1b4c8cfeece6195d6d22abe31b3382394c2af44ec509b7bf7ef
|
| 3 |
+
size 94325
|
heads/cofiber_threshold/dim20_20k/head.py
ADDED
|
@@ -0,0 +1,66 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Cofiber Threshold with dimension selection: 768β20β80 classification.
|
| 2 |
+
|
| 3 |
+
The bottleneck dimension K=20 was selected from SVD analysis of the pruned
|
| 4 |
+
prototype matrix, where rank 20 captures 72% of the energy. This is the
|
| 5 |
+
information bottleneck variant applied to detection: how few feature dimensions
|
| 6 |
+
does the backbone need to expose for 80-class detection?
|
| 7 |
+
|
| 8 |
+
~20K total params.
|
| 9 |
+
"""
|
| 10 |
+
|
| 11 |
+
import math
|
| 12 |
+
import torch
|
| 13 |
+
import torch.nn as nn
|
| 14 |
+
import torch.nn.functional as F
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
def cofiber_decompose(f, n_scales):
    """Split a (B, C, H, W) feature map into ``n_scales`` band-pass components.

    Each iteration removes a bilinear upsampling of the 2x average-pooled map
    from the current level, keeping the high-frequency residual; the final
    (coarsest) pooled map is appended as the last component.
    """
    levels = []
    current = f
    for _ in range(n_scales - 1):
        pooled = F.avg_pool2d(current, 2)
        upsampled = F.interpolate(
            pooled, size=current.shape[2:], mode="bilinear", align_corners=False
        )
        levels.append(current - upsampled)
        current = pooled
    levels.append(current)
    return levels
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
class CofiberThresholdDim20(nn.Module):
    """Cofiber decomposition + 768->20 projection + 20->80 classification. ~20K params.

    Per scale: LayerNorm -> bias-free bottleneck projection, then three heads
    on the 20-d code: a linear classifier, a tiny MLP box regressor, and a
    linear centerness predictor.  Attribute names are kept stable so saved
    state dicts remain loadable.
    """

    name = "cofiber_threshold_dim20"
    needs_intermediates = False

    def __init__(self, feat_dim=768, bottleneck_dim=20, num_classes=80, n_scales=3, reg_hidden=16):
        super().__init__()
        self.n_scales = n_scales
        # One LayerNorm per cofiber scale, applied channel-last.
        self.scale_norms = nn.ModuleList(nn.LayerNorm(feat_dim) for _ in range(n_scales))
        # Bottleneck projection (bias-free; may be overwritten by SVD init).
        self.project = nn.Linear(feat_dim, bottleneck_dim, bias=False)
        # Linear classifier over the bottleneck code.
        self.cls_weight = nn.Parameter(torch.randn(num_classes, bottleneck_dim) * 0.01)
        self.cls_bias = nn.Parameter(torch.zeros(num_classes))
        # Small MLP for box regression.
        self.reg_hidden = nn.Linear(bottleneck_dim, reg_hidden)
        self.reg_act = nn.GELU()
        self.reg_out = nn.Linear(reg_hidden, 4)
        # Linear centerness predictor.
        self.ctr_weight = nn.Parameter(torch.randn(1, bottleneck_dim) * 0.01)
        self.ctr_bias = nn.Parameter(torch.zeros(1))
        # Learnable per-scale multiplier on the raw regression output.
        self.scale_params = nn.Parameter(torch.ones(n_scales))

    def forward(self, spatial, inter=None):
        """Return per-scale lists of (cls, reg, ctr) maps in NCHW layout."""
        cls_outs, reg_outs, ctr_outs = [], [], []
        for idx, feat in enumerate(cofiber_decompose(spatial, self.n_scales)):
            b, c, h, w = feat.shape
            tokens = feat.permute(0, 2, 3, 1).reshape(-1, c)
            z = self.project(self.scale_norms[idx](tokens))  # (N, bottleneck_dim)
            logits = z @ self.cls_weight.T + self.cls_bias
            cls_outs.append(logits.reshape(b, h, w, -1).permute(0, 3, 1, 2))
            hidden = self.reg_act(self.reg_hidden(z))
            raw = (self.reg_out(hidden) * self.scale_params[idx]).clamp(-10, 10)
            reg_outs.append(torch.exp(raw).reshape(b, h, w, 4).permute(0, 3, 1, 2))
            ctr = z @ self.ctr_weight.T + self.ctr_bias
            ctr_outs.append(ctr.reshape(b, h, w, 1).permute(0, 3, 1, 2))
        return cls_outs, reg_outs, ctr_outs
|
heads/cofiber_threshold/dim20_20k/svd_init.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:6ec7e85396c35b2e0678220d2471286a7df583d4635b2553717fd107dd80a4b5
|
| 3 |
+
size 69741
|
heads/cofiber_threshold/dim20_20k/train.py
ADDED
|
@@ -0,0 +1,375 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Train Cofiber Threshold Dim20 (22K params) on full COCO 2017 train.
|
| 3 |
+
|
| 4 |
+
Initialized from SVD of the pruned prototype matrix β the projection starts
|
| 5 |
+
from the top-20 directions the pruned prototypes identified as important.
|
| 6 |
+
|
| 7 |
+
Same hyperparameters as box32: batch 64, lr 1e-3, cosine + 3% warmup, 8 epochs.
|
| 8 |
+
"""
|
| 9 |
+
|
| 10 |
+
import math
|
| 11 |
+
import os
|
| 12 |
+
import sys
|
| 13 |
+
import time
|
| 14 |
+
import json
|
| 15 |
+
|
| 16 |
+
import torch
|
| 17 |
+
import torch.nn as nn
|
| 18 |
+
import torch.nn.functional as F
|
| 19 |
+
from PIL import Image
|
| 20 |
+
from torch.utils.data import DataLoader, Dataset
|
| 21 |
+
from torchvision.transforms import v2
|
| 22 |
+
from torchvision.ops import generalized_box_iou
|
| 23 |
+
|
| 24 |
+
SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
|
| 25 |
+
sys.path.insert(0, os.path.join(SCRIPT_DIR, '..', '..', '..'))
|
| 26 |
+
|
| 27 |
+
EUPE_REPO = os.environ.get("ARENA_BACKBONE_REPO", "/home/zootest/EUPE")
|
| 28 |
+
EUPE_WEIGHTS = os.environ.get("ARENA_BACKBONE_WEIGHTS", "/home/zootest/weights/eupe_vitb/EUPE-ViT-B.pt")
|
| 29 |
+
COCO_ROOT = os.environ.get("ARENA_COCO_ROOT", "/home/zootest/datasets/coco")
|
| 30 |
+
OUTPUT_DIR = SCRIPT_DIR
|
| 31 |
+
|
| 32 |
+
if EUPE_REPO not in sys.path:
|
| 33 |
+
sys.path.insert(0, EUPE_REPO)
|
| 34 |
+
|
| 35 |
+
RESOLUTION = 640
|
| 36 |
+
NUM_CLASSES = 80
|
| 37 |
+
BATCH_SIZE = 64
|
| 38 |
+
LR = 1e-3
|
| 39 |
+
WEIGHT_DECAY = 1e-4
|
| 40 |
+
EPOCHS = 8
|
| 41 |
+
GRAD_CLIP = 5.0
|
| 42 |
+
WARMUP_FRACTION = 0.03
|
| 43 |
+
|
| 44 |
+
COCO_CONTIG_TO_CAT = [
|
| 45 |
+
1,2,3,4,5,6,7,8,9,10,11,13,14,15,16,17,18,19,20,21,22,23,24,25,27,28,31,32,
|
| 46 |
+
33,34,35,36,37,38,39,40,41,42,43,44,46,47,48,49,50,51,52,53,54,55,56,57,58,
|
| 47 |
+
59,60,61,62,63,64,65,67,70,72,73,74,75,76,77,78,79,80,81,82,84,85,86,87,88,89,90,
|
| 48 |
+
]
|
| 49 |
+
COCO_CAT_TO_CONTIG = {cat: i for i, cat in enumerate(COCO_CONTIG_TO_CAT)}
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
def letterbox(image, res):
    """Resize *image* so its longest side equals *res*, paste onto a black
    ``res x res`` canvas (top-left aligned), and return (canvas, scale)."""
    w0, h0 = image.size
    scale = res / max(h0, w0)
    target = (int(round(w0 * scale)), int(round(h0 * scale)))
    scaled = image.resize(target, Image.BILINEAR)
    canvas = Image.new("RGB", (res, res), (0, 0, 0))
    canvas.paste(scaled, (0, 0))
    return canvas, scale
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
class COCODetection(Dataset):
    """COCO 2017 detection dataset yielding letterboxed tensors and xyxy boxes.

    Keeps only non-crowd annotations whose category maps to one of the 80
    contiguous classes; drops degenerate boxes (width or height < 1 px) and
    images left with no boxes at all.
    """

    def __init__(self, root, split="train"):
        self.img_dir = os.path.join(root, f"{split}2017")
        ann_file = os.path.join(root, "annotations", f"instances_{split}2017.json")
        with open(ann_file) as fh:
            coco = json.load(fh)
        # ImageNet-style normalization after [0, 1] conversion.
        self.normalize = v2.Compose([
            v2.ToImage(), v2.ToDtype(torch.float32, scale=True),
            v2.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
        ])
        # Group usable annotations by image id.
        anns_by_image = {}
        for ann in coco["annotations"]:
            if ann["iscrowd"] or ann["category_id"] not in COCO_CAT_TO_CONTIG:
                continue
            anns_by_image.setdefault(ann["image_id"], []).append(ann)
        info_by_id = {rec["id"]: rec for rec in coco["images"]}
        self.items = []
        for image_id, anns in anns_by_image.items():
            info = info_by_id[image_id]
            boxes, labels = [], []
            for ann in anns:
                x, y, w, h = ann["bbox"]
                if w < 1 or h < 1:
                    continue
                boxes.append([x, y, x + w, y + h])
                labels.append(COCO_CAT_TO_CONTIG[ann["category_id"]])
            if boxes:
                self.items.append({"file": info["file_name"], "boxes": boxes, "labels": labels})
        print(f" COCO {split}: {len(self.items)} images", flush=True)

    def __len__(self):
        return len(self.items)

    def __getitem__(self, idx):
        record = self.items[idx]
        image = Image.open(os.path.join(self.img_dir, record["file"])).convert("RGB")
        canvas, scale = letterbox(image, RESOLUTION)
        tensor = self.normalize(canvas)
        # Boxes are scaled into letterboxed-image coordinates.
        boxes = torch.tensor(record["boxes"], dtype=torch.float32) * scale
        labels = torch.tensor(record["labels"], dtype=torch.long)
        return tensor, boxes, labels
|
| 107 |
+
|
| 108 |
+
|
| 109 |
+
def collate_fn(batch):
    """Stack images into one tensor; keep per-image boxes/labels as ragged lists."""
    images, boxes, labels = zip(*batch)
    return torch.stack(images), list(boxes), list(labels)
|
| 111 |
+
|
| 112 |
+
|
| 113 |
+
# Inline head
|
| 114 |
+
def cofiber_decompose(f, n_scales):
    """Band-pass decomposition of a (B, C, H, W) map into ``n_scales`` levels.

    High-frequency residuals (input minus upsampled 2x average pool) come
    first; the coarsest pooled map is the last element.
    """
    parts = []
    level = f
    for _ in range(n_scales - 1):
        down = F.avg_pool2d(level, 2)
        up = F.interpolate(down, size=level.shape[2:], mode="bilinear", align_corners=False)
        parts.append(level - up)
        level = down
    parts.append(level)
    return parts
|
| 124 |
+
|
| 125 |
+
|
| 126 |
+
class CofiberThresholdDim20(nn.Module):
    """Inline dim20 head: per-scale LayerNorm -> bias-free 768->20 projection,
    with a linear classifier, tiny MLP box regressor, and linear centerness
    predictor on the bottleneck code.  Attribute names match the saved
    checkpoints, so they must not be renamed."""

    def __init__(self, feat_dim=768, bottleneck_dim=20, num_classes=80, n_scales=3, reg_hidden=16):
        super().__init__()
        self.n_scales = n_scales
        self.scale_norms = nn.ModuleList(nn.LayerNorm(feat_dim) for _ in range(n_scales))
        self.project = nn.Linear(feat_dim, bottleneck_dim, bias=False)
        self.cls_weight = nn.Parameter(torch.randn(num_classes, bottleneck_dim) * 0.01)
        self.cls_bias = nn.Parameter(torch.zeros(num_classes))
        self.reg_hidden_layer = nn.Linear(bottleneck_dim, reg_hidden)
        self.reg_act = nn.GELU()
        self.reg_out = nn.Linear(reg_hidden, 4)
        self.ctr_weight = nn.Parameter(torch.randn(1, bottleneck_dim) * 0.01)
        self.ctr_bias = nn.Parameter(torch.zeros(1))
        self.scale_params = nn.Parameter(torch.ones(n_scales))

    def forward(self, spatial):
        """Return per-scale lists of (cls, reg, ctr) NCHW maps."""
        cls_outs, reg_outs, ctr_outs = [], [], []
        for idx, feat in enumerate(cofiber_decompose(spatial, self.n_scales)):
            b, c, h, w = feat.shape
            normed = self.scale_norms[idx](feat.permute(0, 2, 3, 1).reshape(-1, c))
            z = self.project(normed)
            logits = (z @ self.cls_weight.T + self.cls_bias).reshape(b, h, w, -1)
            cls_outs.append(logits.permute(0, 3, 1, 2))
            raw = self.reg_out(self.reg_act(self.reg_hidden_layer(z))) * self.scale_params[idx]
            reg_outs.append(torch.exp(raw.clamp(-10, 10)).reshape(b, h, w, 4).permute(0, 3, 1, 2))
            ctr = (z @ self.ctr_weight.T + self.ctr_bias).reshape(b, h, w, 1)
            ctr_outs.append(ctr.permute(0, 3, 1, 2))
        return cls_outs, reg_outs, ctr_outs
|
| 154 |
+
|
| 155 |
+
|
| 156 |
+
# Inline loss (same as other scripts)
|
| 157 |
+
def make_locations(feature_sizes, strides, device):
    """Return, per pyramid level, the (x, y) image-plane centers of all cells.

    Each level yields an (H*W, 2) tensor in row-major order, with centers at
    ``(i + 0.5) * stride`` along each axis.
    """
    levels = []
    for (h, w), stride in zip(feature_sizes, strides):
        ys = (torch.arange(h, device=device, dtype=torch.float32) + 0.5) * stride
        xs = (torch.arange(w, device=device, dtype=torch.float32) + 0.5) * stride
        grid_y, grid_x = torch.meshgrid(ys, xs, indexing="ij")
        levels.append(torch.stack([grid_x.flatten(), grid_y.flatten()], dim=-1))
    return levels
|
| 165 |
+
|
| 166 |
+
|
| 167 |
+
def assign_targets(locations, boxes, labels, strides, size_ranges):
    """FCOS-style target assignment for a single image.

    For each pyramid level, a location is positive when it (a) lies strictly
    inside a box, (b) lies within a 1.5*stride window of that box's center,
    and (c) the box's max regression distance falls in the level's size range.
    Ties are broken by smallest box area.  Returns per-level lists:
    class targets (-1 = background), ltrb regression targets, centerness.
    """
    cls_t, reg_t, ctr_t = [], [], []
    # No ground truth: everything is background with zeroed regression targets.
    if boxes.numel() == 0:
        for loc in locations:
            n = loc.shape[0]
            cls_t.append(torch.full((n,), -1, dtype=torch.long, device=loc.device))
            reg_t.append(torch.zeros(n, 4, device=loc.device))
            ctr_t.append(torch.zeros(n, device=loc.device))
        return cls_t, reg_t, ctr_t
    areas = (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])
    for loc, stride, sr in zip(locations, strides, size_ranges):
        n = loc.shape[0]
        # Signed distances from each location to each box's four sides (n, M).
        l = loc[:, None, 0] - boxes[None, :, 0]
        t = loc[:, None, 1] - boxes[None, :, 1]
        r = boxes[None, :, 2] - loc[:, None, 0]
        b = boxes[None, :, 3] - loc[:, None, 1]
        ltrb = torch.stack([l, t, r, b], dim=-1)
        # Inside the box means all four distances are positive.
        in_box = ltrb.min(dim=-1).values > 0
        cx = (boxes[:, 0] + boxes[:, 2]) / 2
        cy = (boxes[:, 1] + boxes[:, 3]) / 2
        # Center-sampling radius scales with the level's stride.
        rad = stride * 1.5
        in_center = ((loc[:, None, 0] >= cx - rad) & (loc[:, None, 0] <= cx + rad) &
                     (loc[:, None, 1] >= cy - rad) & (loc[:, None, 1] <= cy + rad))
        # Level assignment: max regression distance must fit the size range.
        max_d = ltrb.max(dim=-1).values
        in_level = (max_d >= sr[0]) & (max_d <= sr[1])
        pos = in_box & in_center & in_level
        # Ambiguous locations take the smallest-area candidate box.
        a = areas[None, :].expand_as(pos).clone()
        a[~pos] = float("inf")
        matched = a.argmin(dim=-1)
        # A location is positive only if at least one candidate had finite area.
        is_pos = a.gather(1, matched[:, None]).squeeze(1) < float("inf")
        ct = torch.full((n,), -1, dtype=torch.long, device=loc.device)
        ct[is_pos] = labels[matched[is_pos]]
        rt = torch.zeros(n, 4, device=loc.device)
        if is_pos.any():
            rt[is_pos] = ltrb[torch.arange(n, device=loc.device)[is_pos], matched[is_pos]]
        ctrt = torch.zeros(n, device=loc.device)
        if is_pos.any():
            # FCOS centerness: sqrt(min/max ratio along each axis).
            lp, tp, rp, bp = rt[is_pos].unbind(-1)
            ctrt[is_pos] = torch.sqrt(
                (torch.minimum(lp, rp) / torch.maximum(lp, rp).clamp(min=1e-6)) *
                (torch.minimum(tp, bp) / torch.maximum(tp, bp).clamp(min=1e-6)))
        cls_t.append(ct); reg_t.append(rt); ctr_t.append(ctrt)
    return cls_t, reg_t, ctr_t
|
| 210 |
+
|
| 211 |
+
|
| 212 |
+
def focal_loss(logits, targets, alpha=0.25, gamma=2.0):
    """Sigmoid focal loss (RetinaNet-style), summed over all elements."""
    prob = torch.sigmoid(logits)
    bce = F.binary_cross_entropy_with_logits(logits, targets, reduction="none")
    # Probability assigned to the true class and its alpha weighting.
    p_t = prob * targets + (1.0 - prob) * (1.0 - targets)
    alpha_t = alpha * targets + (1.0 - alpha) * (1.0 - targets)
    modulator = (1.0 - p_t) ** gamma
    return (alpha_t * modulator * bce).sum()
|
| 218 |
+
|
| 219 |
+
|
| 220 |
+
def compute_loss(cls_per, reg_per, ctr_per, locs_per, boxes_batch, labels_batch):
    """Total detection loss: focal classification + GIoU regression + centerness BCE.

    All losses are normalized by the number of positive locations (min 1).
    Strides and size ranges are hard-coded to match the 3-scale cofiber head.
    """
    B = cls_per[0].shape[0]
    device = cls_per[0].device
    num_classes = cls_per[0].shape[1]
    strides = [16, 32, 64]
    size_ranges = [(-1, 128), (128, 256), (256, float("inf"))]
    # Flatten each level's NCHW maps to (B, H*W, C) and concatenate levels.
    flat_cls, flat_reg, flat_ctr = [], [], []
    for cl, rg, ct in zip(cls_per, reg_per, ctr_per):
        b, c, h, w = cl.shape
        flat_cls.append(cl.permute(0, 2, 3, 1).reshape(b, h * w, c))
        flat_reg.append(rg.permute(0, 2, 3, 1).reshape(b, h * w, 4))
        flat_ctr.append(ct.permute(0, 2, 3, 1).reshape(b, h * w))
    pred_cls = torch.cat(flat_cls, 1)
    pred_reg = torch.cat(flat_reg, 1)
    pred_ctr = torch.cat(flat_ctr, 1)
    all_locs = torch.cat(locs_per, 0)
    # Per-image target assignment, then stack into batch tensors.
    all_ct, all_rt, all_ctt = [], [], []
    for i in range(B):
        ct, rt, ctt = assign_targets(locs_per, boxes_batch[i], labels_batch[i], strides, size_ranges)
        all_ct.append(torch.cat(ct)); all_rt.append(torch.cat(rt)); all_ctt.append(torch.cat(ctt))
    tgt_cls = torch.stack(all_ct)
    tgt_reg = torch.stack(all_rt)
    tgt_ctr = torch.stack(all_ctt)
    pos = tgt_cls >= 0
    npos = max(pos.sum().item(), 1)
    # One-hot classification targets (background rows stay all-zero).
    oh = torch.zeros_like(pred_cls)
    pi = pos.nonzero(as_tuple=True)
    oh[pi[0], pi[1], tgt_cls[pos]] = 1.0
    loss_cls = focal_loss(pred_cls.reshape(-1, num_classes), oh.reshape(-1, num_classes)) / npos
    if pos.any():
        # Decode ltrb distances back to xyxy boxes around each positive location.
        pp = pred_reg[pos]; tp = tgt_reg[pos]; pl = all_locs[None].expand(B, -1, -1)[pos]
        pb = torch.stack([pl[:,0]-pp[:,0], pl[:,1]-pp[:,1], pl[:,0]+pp[:,2], pl[:,1]+pp[:,3]], -1)
        tb = torch.stack([pl[:,0]-tp[:,0], pl[:,1]-tp[:,1], pl[:,0]+tp[:,2], pl[:,1]+tp[:,3]], -1)
        # NOTE(review): generalized_box_iou builds the full NxN matrix but only
        # the diagonal is used — O(N^2) memory for N positives; confirm N stays small.
        giou = generalized_box_iou(pb, tb)
        loss_reg = (1 - giou.diagonal()).sum() / npos
        loss_ctr = F.binary_cross_entropy_with_logits(pred_ctr[pos], tgt_ctr[pos], reduction="sum") / npos
    else:
        loss_reg = loss_ctr = torch.tensor(0.0, device=device)
    return loss_cls + loss_reg + loss_ctr
|
| 259 |
+
|
| 260 |
+
|
| 261 |
+
def train():
    """Train the dim20 head on COCO train2017 against a frozen EUPE backbone.

    Side effects: writes checkpoint.pth every 1000 steps, appends progress to
    train.log, and saves the final state dict into OUTPUT_DIR.  Requires CUDA.
    """
    os.makedirs(OUTPUT_DIR, exist_ok=True)
    print("=" * 60)
    print("Cofiber Threshold Dim20: 22K params, SVD-initialized, 8 epochs")
    print("=" * 60, flush=True)

    # Backbone is frozen: eval mode, no gradients.
    print("\n[1/4] Loading backbone...", flush=True)
    backbone = torch.hub.load(EUPE_REPO, "eupe_vitb16", source="local", weights=EUPE_WEIGHTS)
    backbone = backbone.cuda().eval()
    for p in backbone.parameters():
        p.requires_grad = False

    print("\n[2/4] Building head with SVD initialization...", flush=True)
    head = CofiberThresholdDim20().cuda()

    # Initialize projection + classifier from SVD of the pruned prototypes,
    # falling back to the random init when the file is absent.
    svd_init_path = os.path.join(SCRIPT_DIR, "svd_init.pt")
    if os.path.isfile(svd_init_path):
        svd_init = torch.load(svd_init_path, map_location="cuda")
        head.project.weight.data = svd_init["project"]
        head.cls_weight.data = svd_init["cls_weight"]
        print(" SVD initialization loaded", flush=True)
    else:
        print(" No SVD init found, using random", flush=True)

    n_params = sum(p.numel() for p in head.parameters())
    print(f" {n_params:,} params", flush=True)

    print("\n[3/4] Loading COCO...", flush=True)
    train_ds = COCODetection(COCO_ROOT, "train")
    train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True,
                              num_workers=4, pin_memory=True, drop_last=True, collate_fn=collate_fn)
    steps_per_epoch = len(train_loader)
    total_steps = steps_per_epoch * EPOCHS
    warmup_steps = int(total_steps * WARMUP_FRACTION)
    print(f" {len(train_ds)} images, {steps_per_epoch} steps/epoch, {total_steps} total", flush=True)

    optimizer = torch.optim.AdamW(head.parameters(), lr=LR, weight_decay=WEIGHT_DECAY)
    # Linear warmup followed by cosine decay to zero.
    def lr_lambda(step):
        if step < warmup_steps:
            return step / max(warmup_steps, 1)
        progress = (step - warmup_steps) / max(total_steps - warmup_steps, 1)
        return 0.5 * (1.0 + math.cos(math.pi * progress))
    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)

    # Location grids for the three cofiber scales (strides 16/32/64).
    strides = [16, 32, 64]
    H = RESOLUTION // 16
    locs = make_locations([(H, H), (H//2, H//2), (H//4, H//4)], strides, torch.device("cuda"))

    print(f"\n[4/4] Training...", flush=True)
    # NOTE(review): log_file is opened without a context manager; it is closed
    # at the end but leaks if training raises — consider try/finally.
    log_file = open(os.path.join(OUTPUT_DIR, "train.log"), "a")
    head.train()
    global_step = 0
    running_loss = 0.0
    running_count = 0
    t0 = time.time()

    for epoch in range(EPOCHS):
        print(f"\n Epoch {epoch+1}/{EPOCHS} starting (step {global_step})", flush=True)
        for images, boxes_b, labels_b in train_loader:
            if global_step >= total_steps:
                break
            images = images.cuda(non_blocking=True)
            boxes_b = [b.cuda(non_blocking=True) for b in boxes_b]
            labels_b = [l.cuda(non_blocking=True) for l in labels_b]
            try:
                # Backbone features under bf16 autocast, no gradient.
                with torch.no_grad():
                    with torch.autocast("cuda", dtype=torch.bfloat16):
                        out = backbone.forward_features(images)
                    patches = out["x_norm_patchtokens"].float()
                # Reshape patch tokens to a square NCHW feature map.
                B, N, D = patches.shape
                h = w = int(N ** 0.5)
                spatial = patches.permute(0, 2, 1).reshape(B, D, h, w)
                cls_l, reg_l, ctr_l = head(spatial)
                loss = compute_loss(cls_l, reg_l, ctr_l, locs, boxes_b, labels_b)
                # Skip the step (but keep the schedule moving) on NaN/Inf loss.
                if torch.isnan(loss) or torch.isinf(loss):
                    print(f" WARNING: NaN/Inf loss at step {global_step}", flush=True)
                    optimizer.zero_grad(); global_step += 1; scheduler.step(); continue
                optimizer.zero_grad()
                loss.backward()
                torch.nn.utils.clip_grad_norm_(head.parameters(), GRAD_CLIP)
                optimizer.step()
                scheduler.step()
                global_step += 1
                running_loss += loss.item()
                running_count += 1
                if global_step % 100 == 0:
                    elapsed = time.time() - t0
                    avg = running_loss / max(running_count, 1)
                    lr_now = scheduler.get_last_lr()[0]
                    msg = f"step {global_step}/{total_steps} (epoch {epoch+1}) loss={loss.item():.4f} avg={avg:.4f} lr={lr_now:.2e} {running_count/elapsed:.1f} it/s"
                    print(msg, flush=True)
                    log_file.write(msg + "\n"); log_file.flush()
                if global_step % 1000 == 0:
                    torch.save({"head": head.state_dict(), "global_step": global_step},
                               os.path.join(OUTPUT_DIR, "checkpoint.pth"))
                    print(f" Checkpoint saved at step {global_step}", flush=True)
            except Exception as e:
                import traceback
                print(f"\n ERROR at step {global_step}: {e}", flush=True)
                traceback.print_exc()
                # OOM is survivable: free the cache and skip this batch.
                if "out of memory" in str(e):
                    torch.cuda.empty_cache(); optimizer.zero_grad(); global_step += 1; scheduler.step(); continue
                raise
        print(f" Epoch {epoch+1}/{EPOCHS} complete (step {global_step})", flush=True)

    final_path = os.path.join(OUTPUT_DIR, "cofiber_threshold_dim20_coco_8ep_22k.pth")
    torch.save(head.state_dict(), final_path)
    print(f"\nSaved: {final_path}")
    print(f"Training complete: {total_steps} steps, {(time.time()-t0)/3600:.1f} hours", flush=True)
    log_file.close()


if __name__ == "__main__":
    train()
|
heads/cofiber_threshold_person/linear_9k/train.py
CHANGED
|
@@ -22,7 +22,7 @@ sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', '..', '..'))
|
|
| 22 |
|
| 23 |
EUPE_REPO = os.environ.get("ARENA_BACKBONE_REPO", "/home/zootest/EUPE")
|
| 24 |
EUPE_WEIGHTS = os.environ.get("ARENA_BACKBONE_WEIGHTS", "/home/zootest/weights/eupe_vitb/EUPE-ViT-B.pt")
|
| 25 |
-
COCO_ROOT = os.environ.get("ARENA_COCO_ROOT", "/
|
| 26 |
OUTPUT_DIR = os.path.join(os.path.dirname(__file__))
|
| 27 |
|
| 28 |
if EUPE_REPO not in sys.path:
|
|
|
|
| 22 |
|
| 23 |
EUPE_REPO = os.environ.get("ARENA_BACKBONE_REPO", "/home/zootest/EUPE")
|
| 24 |
EUPE_WEIGHTS = os.environ.get("ARENA_BACKBONE_WEIGHTS", "/home/zootest/weights/eupe_vitb/EUPE-ViT-B.pt")
|
| 25 |
+
COCO_ROOT = os.environ.get("ARENA_COCO_ROOT", "/home/zootest/datasets/coco")
|
| 26 |
OUTPUT_DIR = os.path.join(os.path.dirname(__file__))
|
| 27 |
|
| 28 |
if EUPE_REPO not in sys.path:
|