| """ |
| Mask regression detection head on cofiber features. |
| |
| Replaces FCOS's 4-distance box regression with K×K soft-membership prediction |
| per FCOS-positive location. Box is decoded via differentiable trapezoid-moment |
| inversion: box_width = sqrt(12 * Var(membership_marginal) - stride^2). |
| |
Loss: FCOS focal classification + centerness BCE, plus boundary-weighted MSE (per-cell mask) + GIoU (decoded box vs GT).
| |
| Predicted advantage: 4 distance outputs -> 81 mask outputs gives ~sqrt(81/4) = 4.5x |
| theoretical noise reduction in decoded box, empirically ~2-3x after nonlinear decoder. |
| Current split-tower at 20.7 mAP operates at high regression noise; mask regression |
| should hit the 28-40 mAP range at the same 4M parameter budget. |
| """ |
|
|
| import argparse |
| import json |
| import math |
| import os |
| import sys |
| import time |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from torch.cuda.amp import autocast, GradScaler |
|
|
# Make sibling modules that live next to this script importable.
SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
sys.path.insert(0, SCRIPT_DIR)

# Data locations come from the environment; each is None when its variable is unset.
CACHE_DIR, COCO_ROOT, VAL_CACHE = (
    os.environ.get(var) for var in ("ARENA_CACHE_DIR", "ARENA_COCO_ROOT", "ARENA_VAL_CACHE")
)

# Training resolution (main() derives the feature-grid size as RESOLUTION // 16),
# number of class logits predicted per location, and the side of the K×K mask grid.
RESOLUTION, NUM_CLASSES, K = 640, 80, 9
|
|
|
|
| |
| |
| |
def cofiber_decompose(f, n_scales):
    """Split ``f`` into band-pass detail maps plus a final low-pass residual.

    Each of the first ``n_scales - 1`` steps does the following:
      * 2x average-pools the current map;
      * emits ``current - upsample(pooled)`` as a detail band, using bilinear
        upsampling back to the current size;
      * continues with the pooled map.
    The last entry is the coarsest pooled map itself.

    Args:
        f: 4-D (N, C, H, W) tensor (bilinear interpolation requires a 4-D input).
        n_scales: number of maps to return.

    Returns:
        List of ``n_scales`` tensors, finest first; each successive entry is
        half the spatial size of the previous one (floor division by pooling).
    """
    bands = []
    current = f
    for _ in range(n_scales - 1):
        pooled = F.avg_pool2d(current, 2)
        upsampled = F.interpolate(pooled, size=current.shape[2:], mode="bilinear", align_corners=False)
        bands.append(current - upsampled)
        current = pooled
    return bands + [current]
|
|
|
|
class ConvGNBlock(nn.Module):
    """3x3 conv -> GroupNorm -> GELU; keeps channel count and spatial size."""

    def __init__(self, channels):
        super().__init__()
        # Attribute names (conv / norm / act) are state_dict keys; keep them stable
        # so that saved checkpoints keep loading.
        self.conv = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
        groups = min(32, channels)
        self.norm = nn.GroupNorm(groups, channels)
        self.act = nn.GELU()

    def forward(self, x):
        y = self.conv(x)
        y = self.norm(y)
        return self.act(y)
|
|
|
|
class DWResBlock(nn.Module):
    """Residual block computing ``x + GroupNorm(depthwise3x3(GELU(pointwise1x1(x))))``."""

    def __init__(self, channels):
        super().__init__()
        # Creation order (pw before dw) fixes the RNG draw order during weight init;
        # attribute names are state_dict keys.
        self.pw = nn.Conv2d(channels, channels, kernel_size=1)
        self.act = nn.GELU()
        self.dw = nn.Conv2d(channels, channels, kernel_size=3, padding=1, groups=channels)
        self.norm = nn.GroupNorm(min(32, channels), channels)

    def forward(self, x):
        branch = self.act(self.pw(x))
        branch = self.norm(self.dw(branch))
        return x + branch
|
|
|
|
def make_tower(hidden, n_std, n_dw):
    """Build a Sequential of ``n_std`` ConvGNBlocks followed by ``n_dw`` DWResBlocks.

    All blocks operate on ``hidden`` channels. Sequential indices (and therefore
    state_dict keys) are 0..n_std-1 for the ConvGN blocks, then the DW blocks.
    """
    blocks = []
    blocks.extend(ConvGNBlock(hidden) for _ in range(n_std))
    blocks.extend(DWResBlock(hidden) for _ in range(n_dw))
    return nn.Sequential(*blocks)
|
|
|
|
class MaskRegressionHead(nn.Module):
    """Split-tower head with K×K mask regression instead of 4-distance regression.

    forward() returns three per-level lists:
      * class logits with NUM_CLASSES channels;
      * raw (unactivated) K*K mask predictions;
      * 1-channel centerness logits.
    There are 1 + n_scales levels: a learned-upsampled P3 level first, then the
    cofiber levels from finest to coarsest.
    """

    def __init__(self, feat_dim=768, hidden=192, n_std_layers=5, n_dw_layers=4, n_scales=3):
        super().__init__()
        # NOTE: the construction order below determines RNG consumption during
        # weight init, and attribute names are state_dict keys — keep both stable.
        norm_groups = min(32, hidden)
        self.n_scales = n_scales
        # One single-group GroupNorm per cofiber band of the raw backbone features.
        self.scale_norms = nn.ModuleList(nn.GroupNorm(1, feat_dim) for _ in range(n_scales))
        self.stem = nn.Conv2d(feat_dim, hidden, 1)
        self.stem_act = nn.GELU()
        # Learned 2x upsampling of the finest cofiber level yields the extra P3 level.
        self.p3_upsample = nn.ConvTranspose2d(hidden, hidden, 2, stride=2)
        self.p3_norm = nn.GroupNorm(norm_groups, hidden)
        # Top-down lateral fusion between adjacent cofiber levels.
        self.lateral_convs = nn.ModuleList(nn.Conv2d(hidden, hidden, 1) for _ in range(n_scales - 1))
        self.lateral_norms = nn.ModuleList(nn.GroupNorm(norm_groups, hidden) for _ in range(n_scales - 1))
        self.cls_tower = make_tower(hidden, n_std_layers, n_dw_layers)
        self.reg_tower = make_tower(hidden, n_std_layers, n_dw_layers)
        self.cls_pred = nn.Conv2d(hidden, NUM_CLASSES, 1)
        self.mask_pred = nn.Conv2d(hidden, K * K, 1)
        # Start mask outputs near zero: zero bias, tiny random weights.
        nn.init.zeros_(self.mask_pred.bias)
        nn.init.normal_(self.mask_pred.weight, std=0.01)
        self.ctr_pred = nn.Conv2d(hidden, 1, 1)
        # Classification prior: sigmoid(-log 99) = 0.01 initial foreground probability.
        nn.init.constant_(self.cls_pred.bias, -math.log(99))

    def forward(self, spatial):
        """Run the head on a backbone feature map.

        Args:
            spatial: feature map with ``feat_dim`` channels (consumed by
                ``cofiber_decompose`` and the 1x1 stem).

        Returns:
            ``(cls_logits, mask_preds, ctr_logits)``, each a list with one tensor per
            level ordered [P3, cofiber levels finest -> coarsest].
        """
        bands = cofiber_decompose(spatial, self.n_scales)
        feats = [self.stem_act(self.stem(norm(band))) for norm, band in zip(self.scale_norms, bands)]

        # Top-down pass: fold each (already fused) coarser level into the next finer one.
        for lvl in reversed(range(len(feats) - 1)):
            coarse = F.interpolate(feats[lvl + 1], size=feats[lvl].shape[2:],
                                   mode="bilinear", align_corners=False)
            feats[lvl] = self.lateral_norms[lvl](feats[lvl] + self.lateral_convs[lvl](coarse))

        pyramid = [self.p3_norm(self.p3_upsample(feats[0]))] + feats

        cls_out, mask_out, ctr_out = [], [], []
        for level in pyramid:
            cls_feat = self.cls_tower(level)
            reg_feat = self.reg_tower(level)
            cls_out.append(self.cls_pred(cls_feat))
            mask_out.append(self.mask_pred(reg_feat))
            ctr_out.append(self.ctr_pred(reg_feat))
        return cls_out, mask_out, ctr_out
|
|
|
|
| |
| |
| |
def decode_mask_to_box(mask, stride, center_y, center_x):
    """Decode soft-membership masks into boxes by trapezoid-moment inversion.

    Each mask is collapsed to its row (y) and column (x) marginals.
      * The marginal mean gives the box center.
      * The marginal variance gives the box extent via ``Var = (W**2 + stride**2) / 12``,
        i.e. ``W = sqrt(12 * Var - stride**2)`` (clamped at 0).

    Args:
        mask: (B, k, k) tensor of cell memberships, expected in [0, 1]. Cell (i, j)
            is centred at ``center + ((i, j) - (k - 1) / 2) * stride``.
        stride: cell size in pixels.
        center_y, center_x: (B,) coordinates of each grid's center cell.

    Returns:
        (B, 4) tensor of boxes as (y0, x0, y1, x1).

    Raises:
        ValueError: if ``mask`` is not a batch of square grids.
    """
    if mask.dim() != 3 or mask.shape[1] != mask.shape[2]:
        raise ValueError(f"mask must have shape (B, k, k), got {tuple(mask.shape)}")
    k = mask.shape[-1]
    eps = 1e-6

    # Accumulate in at least float32: squared pixel offsets (up to ~(k * stride)**2)
    # overflow float16 under mixed precision.
    mask = mask.to(torch.promote_types(mask.dtype, torch.float32))
    offsets = (torch.arange(k, device=mask.device, dtype=mask.dtype) - (k - 1) / 2) * stride

    col = mask.sum(dim=1)  # (B, k) x-marginal
    row = mask.sum(dim=2)  # (B, k) y-marginal
    col_sum = col.sum(dim=1, keepdim=True).clamp_min(eps)
    row_sum = row.sum(dim=1, keepdim=True).clamp_min(eps)

    # Moments are taken relative to the grid center, so an (almost) empty mask
    # decodes to a box at its own location instead of at the image origin.
    dx = (col * offsets).sum(dim=1, keepdim=True) / col_sum
    dy = (row * offsets).sum(dim=1, keepdim=True) / row_sum
    var_x = (col * (offsets - dx) ** 2).sum(dim=1, keepdim=True) / col_sum
    var_y = (row * (offsets - dy) ** 2).sum(dim=1, keepdim=True) / row_sum

    # +eps keeps sqrt's gradient finite when the clamp is active.
    w = torch.sqrt((12 * var_x - stride ** 2).clamp_min(0) + eps)
    h = torch.sqrt((12 * var_y - stride ** 2).clamp_min(0) + eps)

    mu_x = center_x[:, None] + dx
    mu_y = center_y[:, None] + dy
    return torch.cat([mu_y - h / 2, mu_x - w / 2, mu_y + h / 2, mu_x + w / 2], dim=1)
|
|
|
|
| |
| |
| |
def compute_gt_mask(boxes, center_y, center_x, stride, k=None):
    """Ground-truth k×k soft-membership masks for boxes around given grid centers.

    Cell (i, j) of mask n is the stride×stride square centred at
    ``(center_y[n] + (i - (k - 1) / 2) * stride, center_x[n] + (j - (k - 1) / 2) * stride)``.
    Its value is the fraction of that square covered by ``boxes[n]``, computed as
    the product of the per-axis covered fractions.

    Args:
        boxes: (N, 4) boxes as (y0, x0, y1, x1).
        center_y, center_x: (N,) patch-center coordinates.
        stride: cell size in pixels.
        k: grid side length; defaults to the module-level ``K``.

    Returns:
        (N, k, k) tensor with values in [0, 1].
    """
    if k is None:
        k = K
    offsets = (torch.arange(k, device=boxes.device, dtype=torch.float32) - (k - 1) / 2) * stride
    cys = center_y[:, None] + offsets[None, :]  # (N, k) cell-center y per row
    cxs = center_x[:, None] + offsets[None, :]  # (N, k) cell-center x per column
    y0, x0, y1, x1 = boxes.unbind(dim=1)
    half_cell = stride / 2

    def covered_fraction(lo, hi, centers):
        # |[lo, hi] ∩ [c - s/2, c + s/2]| / s for every cell center c.
        overlap = (torch.minimum(hi[:, None], centers + half_cell)
                   - torch.maximum(lo[:, None], centers - half_cell))
        return overlap.clamp_min(0) / stride

    fy = covered_fraction(y0, y1, cys)
    fx = covered_fraction(x0, x1, cxs)
    return fy[:, :, None] * fx[:, None, :]
|
|
|
|
| |
| |
| |
def giou_loss(pred, gt):
    """Generalized-IoU loss ``1 - GIoU`` between element-wise paired boxes.

    Args:
        pred, gt: (..., 4) tensors of boxes as (y0, x0, y1, x1).

    Returns:
        Tensor of shape ``pred.shape[:-1]`` with the per-pair ``1 - GIoU``
        (0 for identical boxes, growing as the boxes separate).
    """
    def _extent(lo, hi):
        return (hi - lo).clamp_min(0)

    py0, px0, py1, px1 = pred.unbind(-1)
    gy0, gx0, gy1, gx1 = gt.unbind(-1)

    area_pred = _extent(py0, py1) * _extent(px0, px1)
    area_gt = _extent(gy0, gy1) * _extent(gx0, gx1)
    overlap = (_extent(torch.maximum(py0, gy0), torch.minimum(py1, gy1))
               * _extent(torch.maximum(px0, gx0), torch.minimum(px1, gx1)))
    union = area_pred + area_gt - overlap
    iou = overlap / union.clamp_min(1e-9)

    # Smallest axis-aligned box enclosing both.
    hull = (_extent(torch.minimum(py0, gy0), torch.maximum(py1, gy1))
            * _extent(torch.minimum(px0, gx0), torch.maximum(px1, gx1)))
    return 1 - (iou - (hull - union) / hull.clamp_min(1e-9))
|
|
|
|
| |
| |
| |
def _assign_fcos_targets(boxes, labels, all_locs, level_ranges, cls_like):
    """FCOS positive assignment for one image.

    A location is positive for a GT box when all three conditions hold:
      * it lies strictly inside the box;
      * it is within 1.5 * stride of the box center;
      * the box's max l/t/r/b distance falls in the level's size range.
    When several boxes qualify, the smallest-area box wins.

    Args:
        boxes: (M, 4) GT boxes as (x0, y0, x1, y1).
        labels: (M,) class indices.
        all_locs: (L, 2) flattened (x, y) locations over all levels.
        level_ranges: list of (lo, hi, stride, (size_lo, size_hi)) slices into all_locs.
        cls_like: (L, C) tensor whose shape/dtype/device the class target copies.

    Returns:
        (cls_target, pos_mask, pos_box, pos_ctrness) indexed by flattened location:
        one-hot class targets, bool positive mask, matched box (x0, y0, x1, y1),
        and centerness target.
    """
    device = cls_like.device
    n_locs = all_locs.shape[0]
    cls_target = torch.zeros_like(cls_like)
    pos_mask = torch.zeros(n_locs, dtype=torch.bool, device=device)
    pos_box = torch.zeros(n_locs, 4, device=device)
    pos_ctrness = torch.zeros(n_locs, device=device)

    # Box-only quantities are identical for every level.
    areas = (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])
    cx = (boxes[:, 0] + boxes[:, 2]) / 2
    cy = (boxes[:, 1] + boxes[:, 3]) / 2

    for lo, hi, stride, (slo, shi) in level_ranges:
        loc = all_locs[lo:hi]
        ltrb = torch.stack([
            loc[:, None, 0] - boxes[None, :, 0],
            loc[:, None, 1] - boxes[None, :, 1],
            boxes[None, :, 2] - loc[:, None, 0],
            boxes[None, :, 3] - loc[:, None, 1],
        ], dim=-1)  # (n, M, 4)
        in_box = ltrb.min(dim=-1).values > 0
        rad = stride * 1.5
        in_center = ((loc[:, None, 0] >= cx - rad) & (loc[:, None, 0] <= cx + rad) &
                     (loc[:, None, 1] >= cy - rad) & (loc[:, None, 1] <= cy + rad))
        max_d = ltrb.max(dim=-1).values
        pos = in_box & in_center & (max_d >= slo) & (max_d <= shi)

        # Each location takes the smallest-area candidate box.
        cand_area = areas[None, :].expand_as(pos).clone()
        cand_area[~pos] = float("inf")
        matched = cand_area.argmin(dim=-1)
        is_pos = cand_area.gather(1, matched[:, None]).squeeze(1) < float("inf")

        pos_mask[lo:hi] = is_pos
        if not is_pos.any():
            continue
        rows = is_pos.nonzero(as_tuple=True)[0]
        box_ids = matched[rows]
        pos_box[lo + rows] = boxes[box_ids]
        cls_target[lo + rows, labels[box_ids]] = 1
        lp, tp, rp, bp = ltrb[rows, box_ids].unbind(-1)
        pos_ctrness[lo + rows] = torch.sqrt(
            (torch.minimum(lp, rp) / torch.maximum(lp, rp).clamp(min=1e-6)) *
            (torch.minimum(tp, bp) / torch.maximum(tp, bp).clamp(min=1e-6)))

    return cls_target, pos_mask, pos_box, pos_ctrness


def _mask_and_box_losses(mask_pred, pos_idx, all_locs, pos_box, level_ranges):
    """Mask-regression and GIoU losses for one image's positive locations.

    Args:
        mask_pred: (L, K, K) raw mask predictions for the image.
        pos_idx: sorted flattened indices of the positive locations.
        all_locs: (L, 2) flattened (x, y) locations.
        pos_box: (L, 4) matched GT boxes as (x0, y0, x1, y1).
        level_ranges: list of (lo, hi, stride, size_range) slices.

    Returns:
        (weighted squared-error sum, GIoU-loss sum), or None if no level has positives.
    """
    gt_masks, raw_preds, decoded, target_boxes = [], [], [], []
    for lo, hi, stride, _ in level_ranges:
        ids = pos_idx[(pos_idx >= lo) & (pos_idx < hi)]
        if ids.numel() == 0:
            continue
        locs = all_locs[ids]
        cys, cxs = locs[:, 1], locs[:, 0]
        b = pos_box[ids]
        boxes_yxyx = torch.stack([b[:, 1], b[:, 0], b[:, 3], b[:, 2]], dim=1)
        raw = mask_pred[ids]
        gt_masks.append(compute_gt_mask(boxes_yxyx, cys, cxs, stride))
        raw_preds.append(raw)
        # The box is decoded from predictions clamped to valid memberships.
        decoded.append(decode_mask_to_box(raw.clamp(0, 1), stride, cys, cxs))
        target_boxes.append(boxes_yxyx)

    if not gt_masks:
        return None

    gt = torch.cat(gt_masks, dim=0)
    raw = torch.cat(raw_preds, dim=0)
    # Partially covered (boundary) cells carry the box-edge information: weight them 5x.
    is_boundary = (gt > 0.05) & (gt < 0.95)
    weights = torch.where(is_boundary, torch.full_like(gt, 5.0), torch.ones_like(gt))
    mask_err = ((raw - gt) ** 2 * weights).sum()
    giou = giou_loss(torch.cat(decoded, dim=0), torch.cat(target_boxes, dim=0)).sum()
    return mask_err, giou


def compute_loss_mask(cls_per, mask_per, ctr_per, locs_per, boxes_list, labels_list,
                      bce_weight=1.0, giou_weight=2.0):
    """FCOS-style detection loss with K×K mask regression in place of box distances.

    Loss terms, each normalised by the total number of positives in the batch:
      * sigmoid focal loss on class logits over all locations;
      * ``bce_weight`` * mask loss. This is the squared error between raw mask
        predictions and the GT soft-membership masks from ``compute_gt_mask``,
        with 5x weight on boundary cells, additionally divided by K*K. Despite
        the parameter name it is an MSE, not a BCE;
      * ``giou_weight`` * GIoU loss between the box decoded from the clamped mask
        and the GT box;
      * BCE-with-logits centerness loss on positives.

    Args:
        cls_per, mask_per, ctr_per: per-level (B, C, H, W) head outputs for exactly
            4 levels (strides 8/16/32/64).
        locs_per: per-level (H*W, 2) location tensors as (x, y).
        boxes_list: per-image (M, 4) GT boxes as (x0, y0, x1, y1).
        labels_list: per-image (M,) integer class labels.

    Returns:
        Scalar loss tensor.

    Raises:
        ValueError: if the number of levels is not 4.
    """
    n_levels = len(cls_per)
    if n_levels != 4:
        raise ValueError(f"Expected 4 levels, got {n_levels}")
    strides = [8, 16, 32, 64]
    size_ranges = [(-1, 64), (64, 128), (128, 256), (256, float("inf"))]

    B = cls_per[0].shape[0]

    # Flatten every level to (B, H*W, ...) and concatenate along the location axis.
    flat_cls, flat_mask, flat_ctr = [], [], []
    for cl, mk, ct in zip(cls_per, mask_per, ctr_per):
        b, c, h, w = cl.shape
        flat_cls.append(cl.permute(0, 2, 3, 1).reshape(b, h * w, c))
        flat_mask.append(mk.permute(0, 2, 3, 1).reshape(b, h * w, K, K))
        flat_ctr.append(ct.permute(0, 2, 3, 1).reshape(b, h * w))
    pred_cls = torch.cat(flat_cls, 1)
    pred_mask = torch.cat(flat_mask, 1)
    pred_ctr = torch.cat(flat_ctr, 1)
    all_locs = torch.cat(locs_per, 0)

    # (lo, hi, stride, size_range) slice of the flattened locations for each level.
    level_ranges = []
    start = 0
    for loc, stride, size_range in zip(locs_per, strides, size_ranges):
        level_ranges.append((start, start + loc.shape[0], stride, size_range))
        start += loc.shape[0]

    total_cls_loss = 0.0
    total_mask_loss = 0.0
    total_giou_loss = 0.0
    total_ctr_loss = 0.0
    n_pos_total = 0

    for b in range(B):
        boxes = boxes_list[b]
        labels = labels_list[b]
        if boxes.numel() == 0:
            # No objects: every location is a classification negative.
            total_cls_loss = total_cls_loss + focal_loss(pred_cls[b], torch.zeros_like(pred_cls[b]))
            continue

        cls_target, pos_mask, pos_box, pos_ctrness = _assign_fcos_targets(
            boxes, labels, all_locs, level_ranges, pred_cls[b])
        total_cls_loss = total_cls_loss + focal_loss(pred_cls[b], cls_target)

        if not pos_mask.any():
            continue
        pos_idx = pos_mask.nonzero(as_tuple=True)[0]

        losses = _mask_and_box_losses(pred_mask[b], pos_idx, all_locs, pos_box, level_ranges)
        if losses is not None:
            total_mask_loss = total_mask_loss + losses[0]
            total_giou_loss = total_giou_loss + losses[1]

        total_ctr_loss = total_ctr_loss + F.binary_cross_entropy_with_logits(
            pred_ctr[b, pos_idx], pos_ctrness[pos_idx], reduction="sum")
        n_pos_total += int(pos_idx.numel())

    n_pos_total = max(1, n_pos_total)
    return (total_cls_loss / n_pos_total +
            bce_weight * total_mask_loss / (n_pos_total * K * K) +
            giou_weight * total_giou_loss / n_pos_total +
            total_ctr_loss / n_pos_total)
|
|
|
|
def focal_loss(logits, targets, alpha=0.25, gamma=2.0):
    """Sigmoid focal loss, summed over all elements.

    Args:
        logits: raw scores.
        targets: same-shape tensor of targets in [0, 1].
        alpha: weight on the positive part; the negative part gets ``1 - alpha``.
        gamma: focusing exponent applied to ``1 - p_t``.

    Returns:
        Scalar tensor: sum of ``alpha_t * (1 - p_t) ** gamma * BCE(logits, targets)``.
    """
    prob = logits.sigmoid()
    bce = F.binary_cross_entropy_with_logits(logits, targets, reduction="none")
    neg = 1 - targets
    p_t = prob * targets + (1 - prob) * neg
    alpha_t = alpha * targets + (1 - alpha) * neg
    modulator = (1 - p_t) ** gamma
    return (alpha_t * modulator * bce).sum()
|
|
|
|
| |
| |
| |
def make_locations(feature_sizes, strides, device):
    """Pixel-space (x, y) cell centers for every pyramid level.

    Args:
        feature_sizes: iterable of (h, w) per level.
        strides: matching per-level strides in pixels.
        device: device for the returned tensors.

    Returns:
        One float32 (h * w, 2) tensor per level, in row-major order. The entry for
        (row, col) is ``((col + 0.5) * s, (row + 0.5) * s)``.
    """
    def _level_centers(h, w, s):
        row_c = (torch.arange(h, device=device, dtype=torch.float32) + 0.5) * s
        col_c = (torch.arange(w, device=device, dtype=torch.float32) + 0.5) * s
        grid_y, grid_x = torch.meshgrid(row_c, col_c, indexing="ij")
        return torch.stack((grid_x.reshape(-1), grid_y.reshape(-1)), dim=-1)

    return [_level_centers(h, w, s) for (h, w), s in zip(feature_sizes, strides)]
|
|
|
|
| |
| |
| |
def main():
    """Train a MaskRegressionHead on cached backbone features and save the final head.

    Expects ``manifest.json`` (with ``n_shards`` / ``n_images``) and
    ``shard_XXXX.pt`` files under ``$ARENA_CACHE_DIR``. Every 4000 steps a
    checkpoint is written containing the head, optimizer and AMP-scaler state.
    This lets ``--resume`` restore AdamW moments and the loss scale instead of
    restarting them. Checkpoints that only hold ``head``/``step`` still resume.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--hidden", type=int, default=192)
    parser.add_argument("--std-layers", type=int, default=5)
    parser.add_argument("--dw-layers", type=int, default=4)
    parser.add_argument("--epochs", type=int, default=8)
    parser.add_argument("--batch-size", type=int, default=16)
    parser.add_argument("--lr", type=float, default=5e-4)
    parser.add_argument("--bce-weight", type=float, default=1.0)
    parser.add_argument("--giou-weight", type=float, default=2.0)
    parser.add_argument("--resume", type=str, default=None)
    args = parser.parse_args()

    if CACHE_DIR is None:
        # Fail fast with a clear message rather than a TypeError from os.path.join.
        raise SystemExit("ARENA_CACHE_DIR must be set to the cached feature-shard directory")

    out_dir = os.path.join(SCRIPT_DIR, "heads", "cofiber_threshold", "mask_regression")

    head = MaskRegressionHead(hidden=args.hidden, n_std_layers=args.std_layers,
                              n_dw_layers=args.dw_layers).cuda()
    n_params = sum(p.numel() for p in head.parameters())
    print("=" * 60)
    print(f"Mask Regression Head: {args.hidden} hidden, {args.std_layers} std + {args.dw_layers} dw per tower")
    print(f" K = {K} (mask grid), {K*K} output channels per location")
    print(f" {n_params:,} params")
    print(f" Loss: BCE (weight {args.bce_weight}) + GIoU (weight {args.giou_weight})")
    print("=" * 60, flush=True)

    start_step = 0
    resume_ckpt = None
    if args.resume:
        # weights_only=False unpickles arbitrary objects: only resume from trusted files.
        resume_ckpt = torch.load(args.resume, map_location="cuda", weights_only=False)
        head.load_state_dict(resume_ckpt["head"])
        start_step = resume_ckpt["step"]
        print(f"Resumed from step {start_step}", flush=True)

    with open(os.path.join(CACHE_DIR, "manifest.json")) as fh:
        manifest = json.load(fh)
    n_shards = manifest["n_shards"]
    n_images = manifest["n_images"]
    steps_per_epoch = n_images // args.batch_size
    total_steps = steps_per_epoch * args.epochs
    warmup = int(total_steps * 0.03)

    optimizer = torch.optim.AdamW(head.parameters(), lr=args.lr, weight_decay=1e-4)
    # Linear warmup, then cosine decay to zero at total_steps.
    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lambda s:
        s / max(warmup, 1) if s < warmup else
        0.5 * (1 + math.cos(math.pi * (s - warmup) / max(total_steps - warmup, 1))))
    scaler = GradScaler()

    if resume_ckpt is not None:
        # Older checkpoints stored only the head; restore the rest when available.
        if "optimizer" in resume_ckpt:
            optimizer.load_state_dict(resume_ckpt["optimizer"])
        if "scaler" in resume_ckpt:
            scaler.load_state_dict(resume_ckpt["scaler"])
        del resume_ckpt  # free the GPU copy of the checkpoint

    if start_step > 0:
        for _ in range(start_step):
            scheduler.step()
        print(f" Scheduler advanced to step {start_step}", flush=True)

    H = RESOLUTION // 16
    strides = [8, 16, 32, 64]
    locs = make_locations([(H*2, H*2), (H, H), (H//2, H//2), (H//4, H//4)], strides, torch.device("cuda"))
    shard_paths = [os.path.join(CACHE_DIR, f"shard_{i:04d}.pt") for i in range(n_shards)]

    print(f" {n_images} images, batch {args.batch_size}, {total_steps} steps, {args.epochs} epochs")
    print(f" fp16 mixed precision")
    print(f" Training...\n", flush=True)

    head.train()
    global_step = start_step
    t0 = time.time()

    for epoch in range(args.epochs):
        shard_order = torch.randperm(n_shards).tolist()
        epoch_t0 = time.time()
        for shard_idx in shard_order:
            if global_step >= total_steps:
                break
            # Shards are unpickled (weights_only=False): they must come from a trusted cache.
            shard = torch.load(shard_paths[shard_idx], map_location="cpu", weights_only=False)
            within = torch.randperm(len(shard)).tolist()
            for batch_start in range(0, len(shard), args.batch_size):
                if global_step >= total_steps:
                    break
                batch_idx = within[batch_start:batch_start + args.batch_size]
                if len(batch_idx) < 2:
                    continue

                spatial = torch.stack([shard[i]["spatial"] for i in batch_idx]).float().cuda()
                boxes = [shard[i]["boxes"].cuda() for i in batch_idx]
                labels = [shard[i]["labels"].cuda() for i in batch_idx]

                oom = False
                try:
                    with autocast():
                        cls_l, mask_l, ctr_l = head(spatial)
                        loss = compute_loss_mask(cls_l, mask_l, ctr_l, locs, boxes, labels,
                                                 bce_weight=args.bce_weight,
                                                 giou_weight=args.giou_weight)
                    optimizer.zero_grad()
                    scaler.scale(loss).backward()
                    scaler.unscale_(optimizer)
                    torch.nn.utils.clip_grad_norm_(head.parameters(), 5.0)
                    scaler.step(optimizer)
                    scaler.update()
                except RuntimeError as e:
                    if "out of memory" not in str(e):
                        raise
                    oom = True

                # Skipped (OOM) batches still consume a step, as before.
                scheduler.step()
                global_step += 1

                if oom:
                    # Recover outside the except clause so the traceback (and the
                    # activations its frames reference) is released before emptying the cache.
                    optimizer.zero_grad()
                    torch.cuda.empty_cache()
                    continue

                if global_step % 100 == 0:
                    lr = scheduler.get_last_lr()[0]
                    elapsed = time.time() - t0
                    print(f" step {global_step}/{total_steps} (ep {epoch+1}) "
                          f"loss={loss.item():.4f} lr={lr:.2e} "
                          f"{global_step/elapsed:.1f} it/s", flush=True)

                if global_step % 4000 == 0:
                    os.makedirs(out_dir, exist_ok=True)
                    ckpt_path = os.path.join(out_dir, f"checkpoint_step{global_step}.pth")
                    torch.save({"head": head.state_dict(), "step": global_step,
                                "optimizer": optimizer.state_dict(),
                                "scaler": scaler.state_dict()}, ckpt_path)
            del shard
        print(f" Epoch {epoch+1}/{args.epochs} complete ({time.time()-epoch_t0:.0f}s)\n", flush=True)

    os.makedirs(out_dir, exist_ok=True)
    out = os.path.join(out_dir, f"mask_reg_{args.hidden}h_{args.std_layers}std_{args.dw_layers}dw_{args.epochs}ep.pth")
    torch.save({"head": head.state_dict(), "step": -1, "config": {
        "hidden": args.hidden, "std_layers": args.std_layers,
        "dw_layers": args.dw_layers, "K": K,
    }}, out)
    print(f"Saved: {out}")
    print(f"{n_params:,} params, {(time.time()-t0)/60:.1f} minutes")


if __name__ == "__main__":
    main()
|
|