Spaces:
Sleeping
Sleeping
| #!/usr/bin/env python3 | |
| # -*- coding: utf-8 -*- | |
| import os, re, math, random, glob, time, subprocess, sys, zlib, gc, warnings, atexit | |
| from dataclasses import dataclass | |
| from pathlib import Path | |
| from datetime import datetime | |
| from typing import Optional, Dict, List | |
| import numpy as np | |
| from PIL import Image, ImageFile | |
| from PIL.Image import DecompressionBombWarning | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| import torch.multiprocessing as mp | |
| import torch.distributed as dist | |
| from torch.utils.data import Dataset, DataLoader, Sampler | |
| from torchvision import transforms | |
| # tqdm (auto-install if missing) | |
| try: | |
| from tqdm.auto import tqdm | |
| except Exception: | |
| subprocess.check_call([sys.executable, "-m", "pip", "install", "-q", "tqdm"]) | |
| from tqdm.auto import tqdm | |
| # ------------------------- Config ------------------------- | |
| class Cfg: | |
| data_root: str = "./" | |
| folders: dict = None | |
| stages: list = None | |
| P: int = 16 | |
| K: int = 2 | |
| embed_dim: int = 256 | |
| workers: int = 8 | |
| weight_decay: float = 0.01 | |
| alpha_proxy: float = 32.0 | |
| margin_proxy: float = 0.2 | |
| supcon_tau: float = 0.07 | |
| mv_tau: float = 0.10 | |
| mixstyle_p: float = 0.10 | |
| out_dir: str = "./checkpoints_style" | |
| seed: int = 1337 | |
| max_steps_per_epoch: Optional[int] = None # None이면 데이터 길이에 따라 자동 | |
| print_every: int = 50 | |
| use_compile: bool = False | |
| cfg = Cfg( | |
| folders=dict(whole="dataset", face="dataset_face", eyes="dataset_eyes"), | |
| stages=[ | |
| dict(sz_whole=224, sz_face=192, sz_eyes=128, epochs=12, lr=3e-4, P=64, K=2), | |
| dict(sz_whole=384, sz_face=320, sz_eyes=192, epochs=12, lr=1.5e-4, P=24, K=2), | |
| dict(sz_whole=512, sz_face=384, sz_eyes=224, epochs=24, lr=8e-5, P=12, K=2), | |
| ], | |
| ) | |
| # ------------------------- Device & determinism ------------------------- | |
| def seed_all(seed: int): | |
| random.seed(seed) | |
| np.random.seed(seed) | |
| torch.manual_seed(seed) | |
| if torch.cuda.is_available(): | |
| torch.cuda.manual_seed_all(seed) | |
| seed_all(cfg.seed) | |
| torch.backends.cuda.matmul.allow_tf32 = True | |
| torch.backends.cudnn.allow_tf32 = True | |
| torch.backends.cudnn.benchmark = True | |
| if hasattr(torch, "set_float32_matmul_precision"): | |
| torch.set_float32_matmul_precision("high") | |
| # Lazy amp_dtype detection to avoid CUDA init at import time (required for HF Spaces ZeroGPU) | |
| _amp_dtype_cache = None | |
| def _get_amp_dtype(): | |
| global _amp_dtype_cache | |
| if _amp_dtype_cache is None: | |
| try: | |
| if torch.cuda.is_available() and torch.cuda.is_bf16_supported(): | |
| _amp_dtype_cache = torch.bfloat16 | |
| else: | |
| _amp_dtype_cache = torch.float16 | |
| except Exception: | |
| _amp_dtype_cache = torch.float16 | |
| return _amp_dtype_cache | |
| # For backwards compatibility, amp_dtype is accessed via property-like usage | |
| # but we keep a module-level name that can be imported (defaults to float16, updated on first GPU use) | |
| amp_dtype = torch.float16 # safe default; actual dtype picked at runtime via _get_amp_dtype() | |
| # --- PIL safety/verbosity tweaks --- | |
| ImageFile.LOAD_TRUNCATED_IMAGES = True | |
| Image.MAX_IMAGE_PIXELS = 300_000_000 | |
| warnings.filterwarnings("ignore", category=DecompressionBombWarning) | |
| warnings.filterwarnings("ignore", category=UserWarning, module="PIL.TiffImagePlugin") | |
| # ------------------------- Robust multiprocessing for DataLoader ------------------------- | |
| def _init_mp_ctx(): | |
| method = mp.get_start_method(allow_none=True) | |
| if method is None: | |
| preferred = 'fork' if sys.platform.startswith('linux') else 'spawn' | |
| try: | |
| mp.set_start_method(preferred, force=True) | |
| except Exception: | |
| pass | |
| method = mp.get_start_method(allow_none=True) or preferred | |
| print(f"[mp] using '{method}'.") | |
| return mp.get_context(method) | |
| MP_CTX = _init_mp_ctx() | |
| _DL_TRACK = [] | |
| def _track_dl(dl): | |
| _DL_TRACK.append(dl); return dl | |
| def _close_dl(dl): | |
| try: | |
| it = getattr(dl, "_iterator", None) | |
| if it is not None: | |
| it._shutdown_workers() | |
| dl._iterator = None | |
| except Exception: | |
| pass | |
| def _cleanup_all_dls(): | |
| for dl in list(_DL_TRACK): | |
| _close_dl(dl) | |
| _DL_TRACK.clear() | |
| def _should_fallback_workers(err: Exception) -> bool: | |
| s = str(err) | |
| return ("Can't get attribute" in s or | |
| "PicklingError" in s or | |
| ("AttributeError" in s and "__main__" in s)) | |
| # ------------------------- Helpers ------------------------- | |
| def stable_int(s: str) -> int: | |
| return zlib.adler32(s.encode("utf-8")) & 0xffffffff | |
| def l2n(x, eps=1e-8): | |
| return F.normalize(x, dim=-1, eps=eps) | |
| # ------------------------- Dataset ------------------------- | |
| class TriViewDataset(Dataset): | |
| """ | |
| - whole / face / eyes 각각에 대해 9:1로 train/val split (경로 해시 기반). | |
| - __getitem__에서는 해당 작가의 view pool에서 랜덤으로 뽑아서 tri-view 구성. | |
| - 파일명 매칭 전혀 사용 X, 작가(label)만 동일하면 아무 이미지나 조합. | |
| - index는 whole 기반으로 만들고, label/gid/path 는 whole 기준. | |
| """ | |
| def __init__(self, root, folders, split="train", | |
| T_whole=None, T_face=None, T_eyes=None): | |
| assert split in ("train", "val") | |
| self.split = split | |
| self.root = Path(root) | |
| self.dirs = {k: self.root / v for k, v in folders.items()} | |
| self.T = dict(whole=T_whole, face=T_face, eyes=T_eyes) | |
| # artist 목록 | |
| whole_root = self.dirs["whole"] | |
| artists = sorted([d.name for d in whole_root.iterdir() if d.is_dir()]) | |
| self.artist2id = {a: i for i, a in enumerate(artists)} | |
| self.id2artist = {v: k for k, v in self.artist2id.items()} | |
| self.num_classes = len(self.artist2id) | |
| # artist별 view pool (split 별) | |
| self.whole_paths_by_artist: Dict[int, List[Path]] = {aid: [] for aid in self.id2artist.keys()} | |
| self.face_paths_by_artist: Dict[int, List[Path]] = {aid: [] for aid in self.id2artist.keys()} | |
| self.eyes_paths_by_artist: Dict[int, List[Path]] = {aid: [] for aid in self.id2artist.keys()} | |
| def view_split(paths: List[Path], split: str) -> List[Path]: | |
| train_list, val_list = [], [] | |
| for p in paths: | |
| h = stable_int(str(p)) % 10 | |
| if split == "train": | |
| if h < 9: # 0~8 => train | |
| train_list.append(p) | |
| else: | |
| if h >= 9: # 9 => val | |
| val_list.append(p) | |
| return train_list if split == "train" else val_list | |
| # whole / face / eyes 각각에 대해 artist별 split | |
| for artist_name, aid in self.artist2id.items(): | |
| # whole | |
| w_dir = self.dirs["whole"] / artist_name | |
| if w_dir.is_dir(): | |
| w_all = sorted([p for p in w_dir.iterdir() if p.is_file()]) | |
| else: | |
| w_all = [] | |
| self.whole_paths_by_artist[aid] = view_split(w_all, split) | |
| # face | |
| f_dir = self.dirs["face"] / artist_name | |
| if f_dir.is_dir(): | |
| f_all = sorted([p for p in f_dir.iterdir() if p.is_file()]) | |
| else: | |
| f_all = [] | |
| self.face_paths_by_artist[aid] = view_split(f_all, split) | |
| # eyes | |
| e_dir = self.dirs["eyes"] / artist_name | |
| if e_dir.is_dir(): | |
| e_all = sorted([p for p in e_dir.iterdir() if p.is_file()]) | |
| else: | |
| e_all = [] | |
| self.eyes_paths_by_artist[aid] = view_split(e_all, split) | |
| # index: whole 기반 anchor | |
| self.index = [] | |
| for aid, w_list in self.whole_paths_by_artist.items(): | |
| for wp in w_list: | |
| rec = { | |
| "label": aid, | |
| "whole": str(wp), | |
| "gid": stable_int(str(wp)), | |
| "path": str(wp), | |
| } | |
| self.index.append(rec) | |
| def __len__(self): | |
| return len(self.index) | |
| def _load_one(self, path: Optional[Path], T): | |
| if path is None: | |
| return None | |
| try: | |
| im = Image.open(path).convert("RGB") | |
| except Exception: | |
| return None | |
| if T is not None: | |
| return T(im) | |
| else: | |
| return transforms.ToTensor()(im) | |
| def __getitem__(self, i): | |
| rec = self.index[i] | |
| aid = rec["label"] | |
| W_pool = self.whole_paths_by_artist.get(aid, []) | |
| F_pool = self.face_paths_by_artist.get(aid, []) | |
| E_pool = self.eyes_paths_by_artist.get(aid, []) | |
| pw = random.choice(W_pool) if W_pool else None | |
| pf = random.choice(F_pool) if F_pool else None | |
| pe = random.choice(E_pool) if E_pool else None | |
| xw = self._load_one(pw, self.T["whole"]) if pw is not None else None | |
| xf = self._load_one(pf, self.T["face"]) if pf is not None else None | |
| xe = self._load_one(pe, self.T["eyes"]) if pe is not None else None | |
| gid = torch.tensor([rec["gid"]], dtype=torch.long) | |
| return dict( | |
| whole=xw, | |
| face=xf, | |
| eyes=xe, | |
| label=torch.tensor(aid, dtype=torch.long), | |
| gid=gid, | |
| path=rec["path"], | |
| ) | |
| # ------------------------- PK batch sampler ------------------------- | |
| class PKBatchSampler(Sampler): | |
| """P개 클래스 × K개 이미지를 한 배치로 뽑는 샘플러.""" | |
| def __init__(self, dataset: TriViewDataset, P: int, K: int): | |
| self.P, self.K = int(P), int(K) | |
| from collections import defaultdict | |
| self.by_cls = defaultdict(list) | |
| for idx, rec in enumerate(dataset.index): | |
| self.by_cls[rec["label"]].append(idx) | |
| self.labels = list(self.by_cls.keys()) | |
| for lst in self.by_cls.values(): | |
| random.shuffle(lst) | |
| def __iter__(self): | |
| while True: | |
| P, K = self.P, self.K | |
| if len(self.labels) >= P: | |
| classes = random.sample(self.labels, P) | |
| else: | |
| classes = random.choices(self.labels, k=P) | |
| batch = [] | |
| for c in classes: | |
| pool = self.by_cls[c] | |
| if len(pool) >= K: | |
| picks = random.sample(pool, K) | |
| else: | |
| picks = [random.choice(pool) for _ in range(K)] | |
| batch.extend(picks) | |
| yield batch | |
| def __len__(self): # not used | |
| return 10**9 | |
| # ------------------------- Collate & transforms ------------------------- | |
| def collate_triview(batch): | |
| labels = torch.stack([b["label"] for b in batch]) | |
| gids = torch.stack([b["gid"] for b in batch]).squeeze(1) | |
| paths = [b["path"] for b in batch] | |
| views, masks = {}, {} | |
| for k in ("whole", "face", "eyes"): | |
| xs = [b[k] for b in batch] | |
| mask = torch.tensor([x is not None for x in xs], dtype=torch.bool) | |
| if any(mask): | |
| ex = next(x for x in xs if x is not None) | |
| zeros = torch.zeros_like(ex) | |
| xs = [x if x is not None else zeros for x in xs] | |
| views[k] = torch.stack(xs, dim=0) | |
| else: | |
| views[k] = None | |
| masks[k] = mask | |
| return dict(views=views, masks=masks, labels=labels, gids=gids, paths=paths) | |
| def make_transforms(sz_w, sz_f, sz_e): | |
| def aug(s): | |
| return transforms.Compose([ | |
| transforms.RandomResizedCrop(s, scale=(0.6, 1.0)), | |
| transforms.RandomHorizontalFlip(), | |
| transforms.ColorJitter(brightness=0.1, contrast=0.1, | |
| saturation=0.05, hue=0.02), | |
| transforms.RandomApply([transforms.GaussianBlur(3)], p=0.3), | |
| transforms.ToTensor(), | |
| transforms.Normalize([0.5]*3, [0.5]*3), | |
| ]) | |
| return aug(sz_w), aug(sz_f), aug(sz_e) | |
| def make_val_transforms(sz_w, sz_f, sz_e): | |
| def val(s): | |
| return transforms.Compose([ | |
| transforms.Resize(int(s*1.15)), | |
| transforms.CenterCrop(s), | |
| transforms.ToTensor(), | |
| transforms.Normalize([0.5]*3, [0.5]*3), | |
| ]) | |
| return val(sz_w), val(sz_f), val(sz_e) | |
| # ------------------------- Model & heads ------------------------- | |
| class MixStyle(nn.Module): | |
| def __init__(self, p=0.3, alpha=0.1): | |
| super().__init__() | |
| self.p = p; self.alpha = alpha | |
| def forward(self, x): | |
| if not self.training or self.p <= 0.0: | |
| return x | |
| B,C,H,W = x.shape | |
| mu = x.mean([2,3], keepdim=True) | |
| var = x.var([2,3], unbiased=False, keepdim=True) | |
| sigma = (var+1e-5).sqrt() | |
| perm = torch.randperm(B, device=x.device) | |
| mu2, sigma2 = mu[perm], sigma[perm] | |
| lam = torch.distributions.Beta(self.alpha, self.alpha).sample((B,1,1,1)).to(x.device) | |
| mu_mix = mu*lam + mu2*(1-lam) | |
| sigma_mix = sigma*lam + sigma2*(1-lam) | |
| x_norm = (x - mu)/sigma | |
| apply = (torch.rand(B,1,1,1, device=x.device) < self.p).float() | |
| mixed = x_norm * sigma_mix + mu_mix | |
| return mixed*apply + x*(1-apply) | |
| class SqueezeExcite(nn.Module): | |
| def __init__(self, c, r=16): | |
| super().__init__() | |
| m = max(8, c//r) | |
| self.net = nn.Sequential( | |
| nn.AdaptiveAvgPool2d(1), | |
| nn.Conv2d(c, m, 1), nn.GELU(), | |
| nn.Conv2d(m, c, 1), nn.Sigmoid() | |
| ) | |
| def forward(self, x): | |
| return x * self.net(x) | |
| class ConvBlock(nn.Module): | |
| def __init__(self, ci, co, k=3, s=1, p=1): | |
| super().__init__() | |
| self.conv = nn.Conv2d(ci, co, k, s, p, bias=False) | |
| self.gn = nn.GroupNorm(16, co) | |
| self.act = nn.GELU() | |
| def forward(self, x): | |
| return self.act(self.gn(self.conv(x))) | |
| class ResBlock(nn.Module): | |
| def __init__(self, ci, co, down=False, mix=None): | |
| super().__init__() | |
| self.c1 = ConvBlock(ci, co, 3, 1, 1) | |
| self.c2 = ConvBlock(co, co, 3, 1, 1) | |
| self.se = SqueezeExcite(co) | |
| self.down = down | |
| self.pool = nn.AvgPool2d(2) if down else nn.Identity() | |
| self.proj = nn.Conv2d(ci, co, 1, 1, 0, bias=False) if ci != co else nn.Identity() | |
| self.mix = mix | |
| def forward(self, x): | |
| h = self.c1(x) | |
| if self.mix is not None: | |
| h = self.mix(h) | |
| h = self.c2(h) | |
| h = self.se(h) | |
| if self.down: | |
| h = self.pool(h); x = self.pool(x) | |
| return F.gelu(h + self.proj(x)) | |
| def matrix_sqrt_newton_schulz(A, iters=5): | |
| B,C,_ = A.shape | |
| normA = A.reshape(B, -1).norm(dim=1).view(B,1,1).clamp(min=1e-8) | |
| Y = A / normA | |
| I = torch.eye(C, device=A.device).expand(B, C, C) | |
| Z = I.clone() | |
| for _ in range(iters): | |
| T = 0.5 * (3.0*I - Z.bmm(Y)) | |
| Y = Y.bmm(T) | |
| Z = T.bmm(Z) | |
| return Y * (normA.sqrt()) | |
| class GramHead(nn.Module): | |
| def __init__(self, c_in, c_red=64, proj=128): | |
| super().__init__() | |
| self.red = nn.Conv2d(c_in, c_red, 1, bias=False) | |
| self.proj = nn.Linear(c_red*c_red, proj) | |
| def forward(self, x): | |
| f = self.red(x) | |
| B,C,H,W = f.shape | |
| Fm = f.flatten(2) | |
| G = torch.bmm(Fm, Fm.transpose(1,2)) / (H*W) | |
| return self.proj(G.reshape(B, C*C)) | |
| class CovISqrtHead(nn.Module): | |
| def __init__(self, c_in, c_red=64, proj=128): | |
| super().__init__() | |
| self.red = nn.Conv2d(c_in, c_red, 1, bias=False) | |
| self.proj = nn.Linear(c_red*c_red, proj) | |
| def forward(self, x): | |
| with torch.amp.autocast('cuda', enabled=False): | |
| f = self.red(x.float()) | |
| B,C,H,W = f.shape | |
| Fm = f.flatten(2) | |
| mu = Fm.mean(-1, keepdim=True) | |
| Xc = Fm - mu | |
| cov = torch.bmm(Xc, Xc.transpose(1,2)) / (H*W - 1 + 1e-5) | |
| cov = matrix_sqrt_newton_schulz(cov.float(), iters=5) | |
| return self.proj(cov.reshape(B, C*C)) | |
| def spectrum_hist(x, K=16, O=8): | |
| B,C,H,W = x.shape | |
| spec = torch.fft.rfft2(x, norm='ortho').abs().mean(1) | |
| H2, W2 = spec.shape[-2], spec.shape[-1] | |
| yy, xx = torch.meshgrid( | |
| torch.linspace(-1, 1, H2, device=x.device), | |
| torch.linspace(0, 1, W2, device=x.device), | |
| indexing="ij" | |
| ) | |
| rr = (yy**2 + xx**2).sqrt().clamp(0, 1 - 1e-8) | |
| th = (torch.atan2(yy, xx + 1e-9) + math.pi/2) | |
| rb = (rr * K).long().clamp(0, K-1) | |
| ob = (th / math.pi * O).long().clamp(0, O-1) | |
| mag = torch.log1p(spec) | |
| rad = torch.zeros(B, K, device=x.device) | |
| ang = torch.zeros(B, O, device=x.device) | |
| rbf = rb.reshape(-1); obf = ob.reshape(-1) | |
| for b in range(B): | |
| m = mag[b].reshape(-1) | |
| rad[b].scatter_add_(0, rbf, m) | |
| ang[b].scatter_add_(0, obf, m) | |
| rad = rad / (rad.sum(-1, keepdim=True)+1e-6) | |
| ang = ang / (ang.sum(-1, keepdim=True)+1e-6) | |
| return torch.cat([rad, ang], dim=1) | |
| class SpectrumHead(nn.Module): | |
| def __init__(self, c_in, proj=64, K=16, O=8): | |
| super().__init__() | |
| self.proj = nn.Linear(K+O, proj) | |
| def forward(self, x): | |
| with torch.amp.autocast('cuda', enabled=False): | |
| h = spectrum_hist(x.float()) | |
| return self.proj(h) | |
| class StatsHead(nn.Module): | |
| def __init__(self, c_in, proj=64): | |
| super().__init__() | |
| c = min(64, c_in) | |
| self.red = nn.Conv2d(c_in, c, 1, bias=False) | |
| self.mlp = nn.Sequential( | |
| nn.Linear(c*2, 128), | |
| nn.GELU(), | |
| nn.Linear(128, proj), | |
| ) | |
| def forward(self, x): | |
| f = self.red(x) | |
| mu = f.mean([2,3]) | |
| lv = torch.log(f.var([2,3], unbiased=False)+1e-5) | |
| return self.mlp(torch.cat([mu, lv], dim=1)) | |
| class ViewEncoder(nn.Module): | |
| """ | |
| - Normalize([0.5],[0.5])된 RGB 입력 | |
| - RGB -> Lab 변환 | |
| - backbone + 스타일 헤드 4개 (Gram/Cov/Spectrum/Stats) | |
| - 브랜치 attention | |
| """ | |
| def __init__(self, mix_p=0.3, out_dim=256): | |
| super().__init__() | |
| self.mix = MixStyle(p=mix_p, alpha=0.1) | |
| ch = [32, 64, 128, 192, 256] | |
| self.stem = nn.Sequential( | |
| ConvBlock(3, ch[0], 3, 1, 1), | |
| ConvBlock(ch[0], ch[0], 3, 1, 1), | |
| ) | |
| self.b1 = ResBlock(ch[0], ch[1], down=True, mix=self.mix) | |
| self.b2 = ResBlock(ch[1], ch[2], down=True, mix=self.mix) | |
| self.b3 = ResBlock(ch[2], ch[3], down=True, mix=None) | |
| self.b4 = ResBlock(ch[3], ch[4], down=True, mix=None) | |
| self.h_gram3 = GramHead(ch[3]) | |
| self.h_cov3 = CovISqrtHead(ch[3]) | |
| self.h_sp3 = SpectrumHead(ch[3]) | |
| self.h_st3 = StatsHead(ch[3]) | |
| self.h_gram4 = GramHead(ch[4]) | |
| self.h_cov4 = CovISqrtHead(ch[4]) | |
| self.h_sp4 = SpectrumHead(ch[4]) | |
| self.h_st4 = StatsHead(ch[4]) | |
| fdim = (128+128+64+64)*2 # 768 | |
| self.fdim = fdim | |
| self.branch_gate = nn.Sequential( | |
| nn.LayerNorm(fdim), | |
| nn.Linear(fdim, 4, bias=True), | |
| ) | |
| self.fuse = nn.Sequential( | |
| nn.Linear(fdim, 512), | |
| nn.GELU(), | |
| nn.Linear(512, out_dim), | |
| ) | |
| def _rgb_to_lab(self, x: torch.Tensor) -> torch.Tensor: | |
| with torch.amp.autocast('cuda', enabled=False): | |
| x_f = x.float() | |
| rgb = (x_f * 0.5 + 0.5).clamp(0.0, 1.0) | |
| thresh = 0.04045 | |
| low = rgb / 12.92 | |
| high = ((rgb + 0.055) / 1.055).pow(2.4) | |
| rgb_lin = torch.where(rgb <= thresh, low, high) | |
| rgb_lin = rgb_lin.permute(0, 2, 3, 1) | |
| M = rgb_lin.new_tensor([ | |
| [0.4124564, 0.3575761, 0.1804375], | |
| [0.2126729, 0.7151522, 0.0721750], | |
| [0.0193339, 0.1191920, 0.9503041], | |
| ]) | |
| xyz = torch.matmul(rgb_lin, M.T) | |
| Xn, Yn, Zn = 0.95047, 1.00000, 1.08883 | |
| xyz = xyz / rgb_lin.new_tensor([Xn, Yn, Zn]) | |
| eps = 0.008856 | |
| kappa = 903.3 | |
| def f(t): | |
| t = t.clamp(min=1e-6) | |
| return torch.where( | |
| t > eps, | |
| t.pow(1.0 / 3.0), | |
| (kappa * t + 16.0) / 116.0, | |
| ) | |
| f_xyz = f(xyz) | |
| fx, fy, fz = f_xyz[..., 0], f_xyz[..., 1], f_xyz[..., 2] | |
| L = 116.0 * fy - 16.0 | |
| a = 500.0 * (fx - fy) | |
| b = 200.0 * (fy - fz) | |
| L_scaled = L / 100.0 | |
| a_scaled = (a + 128.0) / 255.0 | |
| b_scaled = (b + 128.0) / 255.0 | |
| lab = torch.stack([L_scaled, a_scaled, b_scaled], dim=-1) | |
| lab = lab.permute(0, 3, 1, 2) | |
| return lab.to(dtype=x.dtype) | |
| def forward(self, x): | |
| x_lab = self._rgb_to_lab(x) | |
| f0 = self.stem(x_lab) | |
| f1 = self.b1(f0) | |
| f2 = self.b2(f1) | |
| f3 = self.b3(f2) | |
| f4 = self.b4(f3) | |
| g3 = self.h_gram3(f3) | |
| c3 = self.h_cov3(f3) | |
| sp3 = self.h_sp3(f3) | |
| st3 = self.h_st3(f3) | |
| g4 = self.h_gram4(f4) | |
| c4 = self.h_cov4(f4) | |
| sp4 = self.h_sp4(f4) | |
| st4 = self.h_st4(f4) | |
| b_gram = torch.cat([g3, g4], dim=1) | |
| b_cov = torch.cat([c3, c4], dim=1) | |
| b_sp = torch.cat([sp3, sp4], dim=1) | |
| b_st = torch.cat([st3, st4], dim=1) | |
| flat = torch.cat([b_gram, b_cov, b_sp, b_st], dim=1) # [B,768] | |
| gate_logits = self.branch_gate(flat) | |
| w = torch.softmax(gate_logits, dim=-1) | |
| w0, w1, w2, w3 = w[:,0:1], w[:,1:2], w[:,2:3], w[:,3:4] | |
| flat_weighted = torch.cat([ | |
| b_gram * w0, | |
| b_cov * w1, | |
| b_sp * w2, | |
| b_st * w3, | |
| ], dim=1) | |
| view_vec = self.fuse(flat_weighted) | |
| return view_vec | |
| class TriViewStyleNet(nn.Module): | |
| def __init__(self, out_dim=256, mix_p=0.3, share_backbone: bool = True): | |
| super().__init__() | |
| if share_backbone: | |
| shared = ViewEncoder(mix_p=mix_p, out_dim=out_dim) | |
| self.enc_whole = shared | |
| self.enc_face = shared | |
| self.enc_eyes = shared | |
| else: | |
| self.enc_whole = ViewEncoder(mix_p=mix_p, out_dim=out_dim) | |
| self.enc_face = ViewEncoder(mix_p=mix_p, out_dim=out_dim) | |
| self.enc_eyes = ViewEncoder(mix_p=mix_p, out_dim=out_dim) | |
| self.view_gate = nn.Sequential( | |
| nn.LayerNorm(out_dim), | |
| nn.Linear(out_dim, 1, bias=True), | |
| ) | |
| def forward(self, views, masks): | |
| outs, alphas = {}, [] | |
| for k, enc in (("whole", self.enc_whole), | |
| ("face", self.enc_face), | |
| ("eyes", self.enc_eyes)): | |
| if views[k] is None: | |
| outs[k] = None | |
| alphas.append(None) | |
| continue | |
| vk = enc(views[k].to(memory_format=torch.channels_last)) | |
| outs[k] = l2n(vk) | |
| score = self.view_gate(outs[k]).squeeze(1) | |
| score = torch.where( | |
| masks[k].to(score.device), | |
| score, | |
| torch.full_like(score, -1e4), | |
| ) | |
| alphas.append(score) | |
| scores = [a for a in alphas if a is not None] | |
| if len(scores) == 0: | |
| raise RuntimeError("All views are missing in this batch.") | |
| A = torch.stack(scores, dim=1) # [B, num_views] | |
| W = F.softmax(A, dim=1) | |
| present = [outs[k] for k in ("whole","face","eyes") if outs[k] is not None] | |
| Z = torch.stack(present, dim=1) # [B, num_views, dim] | |
| fused = l2n((W.unsqueeze(-1) * Z).sum(dim=1)) # [B, dim] | |
| return fused, outs, W | |
| # ------------------------- Losses ------------------------- | |
| class ProxyAnchorLoss(nn.Module): | |
| def __init__(self, num_classes, dim, alpha=16.0, margin=0.1, neg_weight=0.25): | |
| super().__init__() | |
| self.proxies = nn.Parameter(torch.randn(num_classes, dim)) | |
| nn.init.normal_(self.proxies, std=0.01) | |
| self.alpha = float(alpha) | |
| self.margin = float(margin) | |
| self.neg_weight = float(neg_weight) | |
| def forward(self, z, y): | |
| with torch.amp.autocast('cuda', enabled=False): | |
| z = F.normalize(z.float(), dim=-1) | |
| P = F.normalize(self.proxies.float(), dim=-1) | |
| sim = z @ P.t() | |
| C = sim.size(1) | |
| yOH = F.one_hot(y, num_classes=C).float() | |
| pos_e = torch.clamp(-self.alpha * (sim - self.margin), | |
| min=-60.0, max=60.0) | |
| neg_e = torch.clamp( self.alpha * (sim + self.margin), | |
| min=-60.0, max=60.0) | |
| pos_term = torch.exp(pos_e) * yOH | |
| neg_term = torch.exp(neg_e) * (1.0 - yOH) | |
| pos_sum = pos_term.sum(0) | |
| neg_sum = neg_term.sum(0) | |
| num_pos = (yOH.sum(0) > 0) | |
| L_pos = torch.log1p(pos_sum[num_pos]).sum() / (num_pos.sum().clamp_min(1.0)) | |
| L_neg = torch.log1p(neg_sum).sum() / C | |
| return L_pos + self.neg_weight * L_neg | |
| class SupConLoss(nn.Module): | |
| def __init__(self, tau=0.07): | |
| super().__init__() | |
| self.tau = tau | |
| def forward(self, feats, labels): | |
| feats = l2n(feats) | |
| sim = feats @ feats.t() / self.tau | |
| logits = sim - torch.eye(sim.size(0), device=sim.device) * 1e9 | |
| pos_mask = (labels.unsqueeze(1) == labels.unsqueeze(0)) & \ | |
| (~torch.eye(len(labels), device=labels.device, dtype=torch.bool)) | |
| numer = (torch.exp(logits) * pos_mask).sum(1) | |
| denom = torch.exp(logits).sum(1).clamp_min(1e-8) | |
| valid = (pos_mask.sum(1) > 0) | |
| loss = -torch.log((numer+1e-12) / denom) | |
| return (loss[valid].mean() if valid.any() else torch.tensor(0.0, device=feats.device)) | |
| class MultiViewInfoNCE(nn.Module): | |
| def __init__(self, tau=0.1): | |
| super().__init__() | |
| self.tau = tau | |
| def forward(self, feats, gids): | |
| feats = l2n(feats) | |
| sim = feats @ feats.t() / self.tau | |
| logits = sim - torch.eye(sim.size(0), device=sim.device) * 1e9 | |
| pos_mask = (gids.unsqueeze(1) == gids.unsqueeze(0)) & \ | |
| (~torch.eye(len(gids), device=gids.device, dtype=torch.bool)) | |
| numer = (torch.exp(logits) * pos_mask).sum(1) | |
| denom = torch.exp(logits).sum(1).clamp_min(1e-8) | |
| valid = (pos_mask.sum(1) > 0) | |
| loss = -torch.log((numer+1e-12) / denom) | |
| return (loss[valid].mean() if valid.any() else torch.tensor(0.0, device=feats.device)) | |
| # --------------------- Logging / checkpoints / schedulers ----------------- | |
| os.makedirs(cfg.out_dir, exist_ok=True) | |
| LOG_TXT = os.path.join(cfg.out_dir, "train.log") | |
| METRICS_CSV = os.path.join(cfg.out_dir, "metrics_epoch.csv") | |
| if not os.path.exists(METRICS_CSV): | |
| with open(METRICS_CSV, "w", encoding="utf-8") as f: | |
| f.write("timestamp,stage,epoch,steps,P,K,train_loss,train_proxy,train_sup,train_mv," | |
| "val_proxy,proxy_top1,knn_r1,knn_r5,kmeans_acc,nmi,ari\n") | |
| def wlog_global(msg, also_print=False): | |
| ts_str = datetime.now().strftime("%Y-%m-%d %H:%M:%S") | |
| line = f"[{ts_str}] {msg}" | |
| with open(LOG_TXT, "a", encoding="utf-8", buffering=1) as _logf: | |
| _logf.write(line + "\n") | |
| if also_print: | |
| tqdm.write(line) | |
| def write_epoch_metrics(stage_i, epoch_i, steps, P, K, | |
| tr_mean, tr_p, tr_s, tr_m, | |
| val_proxy, proxy_top1, | |
| knn_r1, knn_r5, | |
| kmeans_acc, nmi, ari): | |
| ts_str = datetime.now().strftime("%Y-%m-%d %H:%M:%S") | |
| def fmt(x): | |
| if x is None: | |
| return "nan" | |
| try: | |
| if hasattr(x, "item"): | |
| x = float(x.item()) | |
| else: | |
| x = float(x) | |
| except Exception: | |
| return "nan" | |
| if np.isnan(x) or np.isinf(x): | |
| return "nan" | |
| return f"{x:.6f}" | |
| with open(METRICS_CSV, "a", encoding="utf-8") as fh: | |
| fh.write( | |
| f"{ts_str},{stage_i},{epoch_i},{steps},{P},{K}," | |
| f"{fmt(tr_mean)},{fmt(tr_p)},{fmt(tr_s)},{fmt(tr_m)}," | |
| f"{fmt(val_proxy)},{fmt(proxy_top1)}," | |
| f"{fmt(knn_r1)},{fmt(knn_r5)}," | |
| f"{fmt(kmeans_acc)},{fmt(nmi)},{fmt(ari)}\n" | |
| ) | |
| def save_ckpt(path, model, proxy_loss, optim, sched, meta, is_main: bool): | |
| if not is_main: | |
| return | |
| base_model = model.module if isinstance(model, nn.parallel.DistributedDataParallel) else model | |
| torch.save({ | |
| "model": base_model.state_dict(), | |
| "proxies": proxy_loss.state_dict(), | |
| "optim": optim.state_dict() if optim else None, | |
| "sched": sched.state_dict() if sched else None, | |
| "meta": meta, | |
| }, path) | |
| def find_latest_checkpoint(out_dir): | |
| paths = glob.glob(os.path.join(out_dir, "stage*_epoch*.pt")) | |
| best, best_stage, best_epoch = None, -1, -1 | |
| for p in paths: | |
| m = re.search(r"stage(\d+)_epoch(\d+)\.pt$", os.path.basename(p)) | |
| if not m: | |
| continue | |
| si, ep = int(m.group(1)), int(m.group(2)) | |
| if (si > best_stage) or (si == best_stage and ep > best_epoch): | |
| best, best_stage, best_epoch = p, si, ep | |
| return best, best_stage, best_epoch | |
| def _pick_from_schedule(sched, default_val, ep): | |
| if not sched: | |
| return int(default_val) | |
| if isinstance(sched, dict): | |
| items = sorted([(int(k), int(v)) for k,v in sched.items()], key=lambda x: x[0]) | |
| else: | |
| items = sorted([(int(k), int(v)) for k,v in sched], key=lambda x: x[0]) | |
| val = int(default_val) | |
| for k,v in items: | |
| if ep >= k: | |
| val = int(v) | |
| return int(val) | |
| def resolve_epoch_PK(stage: dict, ep: int): | |
| P = int(stage.get("P", cfg.P)) | |
| K = int(stage.get("K", cfg.K)) | |
| P = _pick_from_schedule(stage.get("P_schedule"), P, ep) | |
| K = _pick_from_schedule(stage.get("K_schedule"), K, ep) | |
| bs_sched = stage.get("bs_schedule") | |
| if bs_sched: | |
| bs = _pick_from_schedule(bs_sched, P*K, ep) | |
| if bs % K != 0: | |
| wlog_global(f"[batch] bs_schedule value {bs} not divisible by K={K}; rounding down to {bs//K*K}", also_print=True) | |
| bs = (bs // K) * K | |
| P = max(1, bs // K) | |
| return int(P), int(K) | |
| def estimate_steps_per_epoch(train_len: int, global_batch: int, max_steps: Optional[int]): | |
| if max_steps is not None: | |
| return int(max_steps) | |
| return max(1, math.ceil(train_len / max(1, global_batch))) | |
| def build_train_loader(ds: TriViewDataset, P: int, K: int): | |
| bs = PKBatchSampler(ds, P, K) | |
| dl = DataLoader( | |
| ds, | |
| batch_sampler=bs, | |
| num_workers=cfg.workers, | |
| pin_memory=True, | |
| collate_fn=collate_triview, | |
| persistent_workers=False, | |
| prefetch_factor=2 if cfg.workers > 0 else None, | |
| multiprocessing_context=MP_CTX, | |
| ) | |
| return _track_dl(dl) | |
| def make_cosine_with_warmup(optimizer, warmup_steps, total_steps): | |
| def lr_lambda(step): | |
| if step < warmup_steps: | |
| return float(step) / max(1, warmup_steps) | |
| rem = max(1, total_steps - warmup_steps) | |
| progress = (step - warmup_steps) / rem | |
| return 0.5 * (1.0 + math.cos(math.pi * progress)) | |
| return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_lambda) | |
| # ------------------------------ DDP worker -------------------------------- | |
| def ddp_train_worker(rank: int, world_size: int): | |
| torch.cuda.set_device(rank) | |
| device = torch.device("cuda", rank) | |
| os.environ.setdefault("MASTER_ADDR", "127.0.0.1") | |
| os.environ.setdefault("MASTER_PORT", "29500") | |
| os.environ["RANK"] = str(rank) | |
| os.environ["WORLD_SIZE"] = str(world_size) | |
| dist.init_process_group(backend="nccl", rank=rank, world_size=world_size) | |
| seed_all(cfg.seed + rank) | |
| is_main = (rank == 0) | |
| # class count | |
| artists_dir = os.path.join(cfg.data_root, cfg.folders['whole']) | |
| num_classes_total = len([ | |
| d for d in os.listdir(artists_dir) | |
| if os.path.isdir(os.path.join(artists_dir, d)) | |
| ]) | |
| if is_main: | |
| wlog_global(f"[DDP] world_size={world_size}, num_classes_total={num_classes_total}", also_print=True) | |
| # model & losses | |
| base_model = TriViewStyleNet( | |
| out_dim=cfg.embed_dim, | |
| mix_p=cfg.mixstyle_p, | |
| share_backbone=True, | |
| ).to(device) | |
| base_model = base_model.to(memory_format=torch.channels_last) | |
| if cfg.use_compile and hasattr(torch, "compile"): | |
| try: | |
| base_model = torch.compile(base_model, mode="reduce-overhead", fullgraph=False) | |
| except Exception: | |
| pass | |
| model = nn.parallel.DistributedDataParallel( | |
| base_model, | |
| device_ids=[rank], | |
| output_device=rank, | |
| find_unused_parameters=False, | |
| ) | |
| proxy_loss = ProxyAnchorLoss( | |
| num_classes=num_classes_total, | |
| dim=cfg.embed_dim, | |
| alpha=cfg.alpha_proxy, | |
| margin=cfg.margin_proxy, | |
| neg_weight=0.25, | |
| ).to(device) | |
| supcon = SupConLoss(tau=cfg.supcon_tau).to(device) | |
| mv_infonce = MultiViewInfoNCE(tau=cfg.mv_tau).to(device) | |
| # resume | |
| resume_info = None | |
| ckpt_path, ck_stage, ck_epoch = find_latest_checkpoint(cfg.out_dir) | |
| if ckpt_path is not None: | |
| ck = torch.load(ckpt_path, map_location="cpu") | |
| try: | |
| model.module.load_state_dict(ck["model"], strict=False) | |
| except Exception as e: | |
| if is_main: | |
| wlog_global(f"[resume] WARNING: model state load failed: {e}", also_print=True) | |
| try: | |
| proxy_loss.load_state_dict(ck["proxies"]) | |
| except Exception as e: | |
| if is_main: | |
| wlog_global(f"[resume] WARNING: proxy state load failed: {e}", also_print=True) | |
| meta = ck.get("meta", {}) | |
| last_stage = int(meta.get("stage", ck_stage or 1)) | |
| last_epoch = int(meta.get("epoch", ck_epoch or 0)) | |
| start_stage = last_stage | |
| start_epoch = last_epoch + 1 | |
| if start_stage <= len(cfg.stages) and start_epoch > cfg.stages[start_stage-1]["epochs"]: | |
| start_stage += 1 | |
| start_epoch = 1 | |
| resume_info = dict( | |
| ckpt=ck, | |
| path=ckpt_path, | |
| last_stage=last_stage, | |
| last_epoch=last_epoch, | |
| start_stage=start_stage, | |
| start_epoch=start_epoch, | |
| ) | |
| if is_main: | |
| wlog_global( | |
| f"[resume] Found {ckpt_path} (stage {last_stage}, epoch {last_epoch}). " | |
| f"Resuming at stage {start_stage}, epoch {start_epoch}.", | |
| also_print=True, | |
| ) | |
| else: | |
| if is_main: | |
| wlog_global("[resume] No checkpoint found; training from scratch.", also_print=True) | |
| scaler = torch.amp.GradScaler('cuda', enabled=torch.cuda.is_available()) | |
| global_step = 0 | |
| proxy_lr_mult = 5.0 | |
| RAMP_EPOCHS = 3 | |
| WARMUP_EPOCHS = 1 | |
| VALIDATE_EVERY = 4 # N epoch마다 검증 | |
| from tqdm.auto import tqdm as tqdm_local | |
| # Stage loop | |
| for si, stage in enumerate(cfg.stages, 1): | |
| if resume_info and si < resume_info["start_stage"]: | |
| if is_main: | |
| wlog_global(f"[resume] Skipping stage {si}; already completed.", also_print=True) | |
| continue | |
| # datasets per stage | |
| T_w_tr, T_f_tr, T_e_tr = make_transforms(stage["sz_whole"], stage["sz_face"], stage["sz_eyes"]) | |
| T_w_val, T_f_val, T_e_val = make_val_transforms(stage["sz_whole"], stage["sz_face"], stage["sz_eyes"]) | |
| train_ds = TriViewDataset(cfg.data_root, cfg.folders, split="train", | |
| T_whole=T_w_tr, T_face=T_f_tr, T_eyes=T_e_tr) | |
| val_ds = TriViewDataset(cfg.data_root, cfg.folders, split="val", | |
| T_whole=T_w_val, T_face=T_f_val, T_eyes=T_e_val) | |
| # steps_per_epoch schedule (global batch 기준) | |
| steps_list = [] | |
| for ep_tmp in range(1, stage["epochs"]+1): | |
| P_tmp, K_tmp = resolve_epoch_PK(stage, ep_tmp) | |
| global_batch = P_tmp * K_tmp * world_size | |
| steps = estimate_steps_per_epoch( | |
| len(train_ds), | |
| global_batch, | |
| cfg.max_steps_per_epoch, | |
| ) | |
| steps_list.append(steps) | |
| total_steps_stage = int(sum(steps_list)) | |
| warmup_steps = int(steps_list[0] * WARMUP_EPOCHS) | |
| params = [ | |
| {"params": model.parameters(), "lr": stage["lr"]}, | |
| {"params": proxy_loss.parameters(), "lr": stage["lr"] * proxy_lr_mult}, | |
| ] | |
| optim = torch.optim.AdamW(params, weight_decay=cfg.weight_decay) | |
| sched = make_cosine_with_warmup(optim, warmup_steps=warmup_steps, total_steps=total_steps_stage) | |
| start_ep = 1 | |
| if resume_info and si == resume_info["start_stage"]: | |
| start_ep = resume_info["start_epoch"] | |
| if resume_info["last_stage"] == si and start_ep > 1: | |
| try: | |
| if resume_info["ckpt"].get("optim") is not None: | |
| optim.load_state_dict(resume_info["ckpt"]["optim"]) | |
| if resume_info["ckpt"].get("sched") is not None: | |
| sched.load_state_dict(resume_info["ckpt"]["sched"]) | |
| if is_main: | |
| wlog_global(f"[resume] Loaded optimizer/scheduler from {resume_info['path']}.", also_print=True) | |
| except Exception as e: | |
| if is_main: | |
| wlog_global(f"[resume] WARNING: could not load optimizer/scheduler state: {e}", also_print=True) | |
| stage_msg = (f"\n=== [DDP] Stage {si}/{len(cfg.stages)} :: " | |
| f"wh/face/eyes={stage['sz_whole']}/{stage['sz_face']}/{stage['sz_eyes']} | " | |
| f"epochs={stage['epochs']} | lr={stage['lr']} | classes={num_classes_total} ===") | |
| if is_main: | |
| print(stage_msg) | |
| wlog_global(stage_msg) | |
| # epoch loop | |
| for ep in range(start_ep, stage["epochs"]+1): | |
| P_e, K_e = resolve_epoch_PK(stage, ep) | |
| B_e = P_e * K_e # local batch | |
| train_dl = build_train_loader(train_ds, P_e, K_e) | |
| steps_per_epoch = steps_list[ep-1] | |
| model.train() | |
| proxy_loss.train() | |
| running = {"proxy":0.0, "supcon":0.0, "mv":0.0, "tot":0.0} | |
| ep_sum_tot = ep_sum_p = ep_sum_s = ep_sum_m = 0.0 | |
| ramp = min(1.0, ep / RAMP_EPOCHS) | |
| if is_main: | |
| tbar = tqdm_local(range(1, steps_per_epoch+1), | |
| desc=f"[train-DDP] stage{si} ep{ep} (P={P_e},K={K_e},B={B_e},rank={rank})", | |
| leave=True) | |
| else: | |
| tbar = range(1, steps_per_epoch+1) | |
| train_iter = iter(train_dl) | |
| for it in tbar: | |
| try: | |
| batch = next(train_iter) | |
| except Exception as e: | |
| if _should_fallback_workers(e) and cfg.workers > 0: | |
| if is_main: | |
| print("[mp] Worker pickling error detected. Rebuilding loaders with num_workers=0.") | |
| cfg.workers = 0 | |
| train_dl = build_train_loader(train_ds, P_e, K_e) | |
| train_iter = iter(train_dl) | |
| batch = next(train_iter) | |
| else: | |
| raise | |
| labels = batch["labels"].to(device, non_blocking=True) | |
| gids = batch["gids"].to(device, non_blocking=True) | |
| views = { | |
| k: (v.to(device, non_blocking=True).to(memory_format=torch.channels_last) | |
| if v is not None else None) | |
| for k,v in batch["views"].items() | |
| } | |
| masks = {k: v.to(device, non_blocking=True) for k,v in batch["masks"].items()} | |
| with torch.amp.autocast('cuda', dtype=_get_amp_dtype()): | |
| z_fused, z_views_dict, W = model(views, masks) | |
| Z_all, Y_all, G_all = [], [], [] | |
| for vk in ("whole","face","eyes"): | |
| zk = z_views_dict.get(vk) | |
| if zk is None: | |
| continue | |
| mk = masks[vk] | |
| if mk.any(): | |
| Z_all.append(zk[mk]) | |
| Y_all.append(labels[mk]) | |
| G_all.append(gids[mk]) | |
| if len(Z_all) == 0: | |
| Z_all, Y_all, G_all = [z_fused], [labels], [gids] | |
| Z_all = torch.cat(Z_all, dim=0) | |
| Y_all = torch.cat(Y_all, dim=0) | |
| G_all = torch.cat(G_all, dim=0) | |
| L_proxy = proxy_loss(z_fused, labels) | |
| L_sup = supcon(Z_all, Y_all) | |
| L_mv = mv_infonce(Z_all, G_all) | |
| L_total = L_proxy + (0.5 * ramp) * L_sup + (0.5 * ramp) * L_mv | |
| optim.zero_grad(set_to_none=True) | |
| scaler.scale(L_total).backward() | |
| scaler.step(optim) | |
| scaler.update() | |
| sched.step() | |
| global_step += 1 | |
| running["proxy"] += L_proxy.item() | |
| running["supcon"] += L_sup.item() | |
| running["mv"] += L_mv.item() | |
| running["tot"] += L_total.item() | |
| ep_sum_tot += L_total.item() | |
| ep_sum_p += L_proxy.item() | |
| ep_sum_s += L_sup.item() | |
| ep_sum_m += L_mv.item() | |
| if is_main and (it % cfg.print_every == 0 or it == steps_per_epoch): | |
| denom = min(cfg.print_every, it % cfg.print_every or cfg.print_every) | |
| tbar.set_postfix({ | |
| "L": f"{running['tot']/denom:.3f}", | |
| "proxy": f"{running['proxy']/denom:.3f}", | |
| "sup": f"{running['supcon']/denom:.3f}", | |
| "mv": f"{running['mv']/denom:.3f}", | |
| "lr": f"{optim.param_groups[0]['lr']:.2e}", | |
| }) | |
| msg = (f"stage{si} ep{ep:02d} it{it:05d}/{steps_per_epoch} | " | |
| f"P={P_e} K={K_e} B={B_e} | " | |
| f"L={running['tot']/denom:.3f} " | |
| f"(proxy={running['proxy']/denom:.3f}, " | |
| f"sup={running['supcon']/denom:.3f}, " | |
| f"mv={running['mv']/denom:.3f}) | " | |
| f"lr={optim.param_groups[0]['lr']:.2e}") | |
| wlog_global(msg) | |
| running = {k:0.0 for k in running} | |
| # ===== 검증 (proxy loss + proxy Top1만) ===== | |
| proxy_top1 = float("nan") | |
| kmeans_acc = float("nan") # 사용 안 하지만 CSV 포맷 때문에 남겨둠 | |
| nmi = float("nan") | |
| ari = float("nan") | |
| knn_r1 = float("nan") | |
| knn_r5 = float("nan") | |
| val_proxy_mean = float("nan") | |
| do_val = (VALIDATE_EVERY <= 0) or (ep % VALIDATE_EVERY == 0) or (ep == stage["epochs"]) | |
| if do_val: | |
| from torch.utils.data.distributed import DistributedSampler | |
| val_sampler = DistributedSampler( | |
| val_ds, | |
| num_replicas=world_size, | |
| rank=rank, | |
| shuffle=False, | |
| drop_last=False, | |
| ) | |
| val_sampler.set_epoch(ep) | |
| val_dl_ddp = DataLoader( | |
| val_ds, | |
| batch_size=B_e, | |
| sampler=val_sampler, | |
| num_workers=min(8, cfg.workers), | |
| pin_memory=True, | |
| collate_fn=collate_triview, | |
| persistent_workers=False, | |
| multiprocessing_context=MP_CTX, | |
| ) | |
| model.eval() | |
| proxy_loss.eval() | |
| local_loss_sum = 0.0 | |
| local_loss_cnt = 0.0 | |
| local_correct = 0.0 | |
| local_total = 0.0 | |
| with torch.no_grad(): | |
| Pn = F.normalize(proxy_loss.proxies.detach(), dim=1).to(device) | |
| with torch.no_grad(), torch.amp.autocast('cuda', dtype=amp_dtype): | |
| for batch in val_dl_ddp: | |
| labels = batch["labels"].to(device, non_blocking=True) | |
| views = { | |
| k: (v.to(device).to(memory_format=torch.channels_last) if v is not None else None) | |
| for k, v in batch["views"].items() | |
| } | |
| masks = {k: v.to(device, non_blocking=True) for k, v in batch["masks"].items()} | |
| z_fused, _, _ = model(views, masks) | |
| L = proxy_loss(z_fused, labels) | |
| z_norm = F.normalize(z_fused, dim=1) | |
| logits = z_norm @ Pn.t() | |
| pred = logits.argmax(dim=1) | |
| correct = (pred == labels).float().sum().item() | |
| bs = float(labels.size(0)) | |
| local_loss_sum += L.item() | |
| local_loss_cnt += 1.0 | |
| local_correct += correct | |
| local_total += bs | |
| t = torch.tensor( | |
| [local_loss_sum, local_loss_cnt, local_correct, local_total], | |
| device=device, | |
| ) | |
| dist.all_reduce(t, op=dist.ReduceOp.SUM) | |
| total_loss_sum = float(t[0].item()) | |
| total_loss_cnt = max(1.0, float(t[1].item())) | |
| total_correct = float(t[2].item()) | |
| total_total = max(1.0, float(t[3].item())) | |
| val_proxy_mean = total_loss_sum / total_loss_cnt | |
| proxy_top1 = total_correct / total_total | |
| if is_main: | |
| print(f"[val] ep{ep:02d} proxy-loss ~ {val_proxy_mean:.3f}, Top1={proxy_top1:.4f}") | |
| wlog_global(f"[val] ep{ep:02d} proxy-loss ~ {val_proxy_mean:.3f}, Top1={proxy_top1:.4f}") | |
| dist.barrier() | |
| _close_dl(val_dl_ddp) | |
| del val_dl_ddp | |
| gc.collect() | |
| time.sleep(0.05) | |
| # ----- Epoch metrics & checkpoint (rank0) ----- | |
| train_mean = ep_sum_tot / steps_per_epoch | |
| train_p = ep_sum_p / steps_per_epoch | |
| train_s = ep_sum_s / steps_per_epoch | |
| train_m = ep_sum_m / steps_per_epoch | |
| write_epoch_metrics( | |
| si, ep, steps_per_epoch, P_e, K_e, | |
| train_mean, train_p, train_s, train_m, | |
| val_proxy_mean, proxy_top1, | |
| knn_r1, knn_r5, | |
| kmeans_acc, nmi, ari, | |
| ) | |
| ck = os.path.join(cfg.out_dir, f"stage{si}_epoch{ep}.pt") | |
| save_ckpt( | |
| ck, model, proxy_loss, optim, sched, | |
| meta=dict( | |
| stage=si, epoch=ep, | |
| P=P_e, K=K_e, steps=steps_per_epoch, | |
| val_every=VALIDATE_EVERY, | |
| proxy_top1=proxy_top1, | |
| knn_r1=knn_r1, knn_r5=knn_r5, | |
| ), | |
| is_main=is_main, | |
| ) | |
| if is_main: | |
| print(f"Saved: {ck}") | |
| wlog_global(f"Saved: {ck}") | |
| _close_dl(train_dl) | |
| del train_dl | |
| gc.collect() | |
| time.sleep(0.1) | |
| dist.destroy_process_group() | |
| if is_main: | |
| print("\n[DDP] Training finished or paused. Checkpoints in:", cfg.out_dir) | |
| print("Logs:", LOG_TXT, " | CSV:", METRICS_CSV) | |
| print("Tip: Re-run this script to RESUME (DDP).") | |
| # ------------------------------ entry point -------------------------------- | |
| def run_ddp_training(): | |
| if not torch.cuda.is_available(): | |
| print("CUDA not available; DDP training requires GPU.") | |
| return | |
| world_size = torch.cuda.device_count() | |
| print(f"[DDP] Launching training on {world_size} GPUs...") | |
| mp.spawn( | |
| ddp_train_worker, | |
| args=(world_size,), | |
| nprocs=world_size, | |
| join=True, | |
| ) | |
| if __name__ == "__main__": | |
| run_ddp_training() | |