ArtistEmbeddingClassifier / scripts /train_style_ddp.py
iljung1106
Disabled loading CUDA on main process
c61411c
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import os, re, math, random, glob, time, subprocess, sys, zlib, gc, warnings, atexit
from dataclasses import dataclass
from pathlib import Path
from datetime import datetime
from typing import Optional, Dict, List
import numpy as np
from PIL import Image, ImageFile
from PIL.Image import DecompressionBombWarning
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.multiprocessing as mp
import torch.distributed as dist
from torch.utils.data import Dataset, DataLoader, Sampler
from torchvision import transforms
# tqdm (auto-install if missing)
try:
from tqdm.auto import tqdm
except Exception:
subprocess.check_call([sys.executable, "-m", "pip", "install", "-q", "tqdm"])
from tqdm.auto import tqdm
# ------------------------- Config -------------------------
@dataclass
class Cfg:
    """Training configuration defaults.

    ``folders`` (view name -> sub-directory of ``data_root``) and ``stages``
    (list of per-stage dicts) default to None rather than a mutable default;
    callers are expected to supply them.
    """
    data_root: str = "./"
    folders: Optional[dict] = None   # view name -> folder under data_root
    stages: Optional[list] = None    # per-stage dicts (sizes, epochs, lr, P, K, ...)
    P: int = 16                      # classes per PK batch (stage dicts may override)
    K: int = 2                       # images per class per PK batch
    embed_dim: int = 256
    workers: int = 8                 # DataLoader worker processes
    weight_decay: float = 0.01
    alpha_proxy: float = 32.0        # Proxy-Anchor scale
    margin_proxy: float = 0.2        # Proxy-Anchor margin
    supcon_tau: float = 0.07
    mv_tau: float = 0.10
    mixstyle_p: float = 0.10
    out_dir: str = "./checkpoints_style"
    seed: int = 1337
    max_steps_per_epoch: Optional[int] = None  # None: derive automatically from the dataset length
    print_every: int = 50
    use_compile: bool = False
# Active configuration. Across stages the crops grow while lr and P shrink.
cfg = Cfg(
    folders={"whole": "dataset", "face": "dataset_face", "eyes": "dataset_eyes"},
    stages=[
        {"sz_whole": 224, "sz_face": 192, "sz_eyes": 128, "epochs": 12, "lr": 3e-4, "P": 64, "K": 2},
        {"sz_whole": 384, "sz_face": 320, "sz_eyes": 192, "epochs": 12, "lr": 1.5e-4, "P": 24, "K": 2},
        {"sz_whole": 512, "sz_face": 384, "sz_eyes": 224, "epochs": 24, "lr": 8e-5, "P": 12, "K": 2},
    ],
)
# ------------------------- Device & determinism -------------------------
def seed_all(seed: int):
    """Seed Python's, NumPy's and PyTorch's global RNGs (all CUDA devices too, if present)."""
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed_all(seed)
seed_all(cfg.seed)

# Speed-oriented backend settings: cuDNN autotuning plus TF32 kernels for
# convolutions and float32 matmuls.
torch.backends.cudnn.benchmark = True
torch.backends.cudnn.allow_tf32 = True
torch.backends.cuda.matmul.allow_tf32 = True
if hasattr(torch, "set_float32_matmul_precision"):
    torch.set_float32_matmul_precision("high")
# Lazy AMP dtype detection: avoid touching CUDA at import time (required for
# HF Spaces ZeroGPU).
_amp_dtype_cache = None


def _get_amp_dtype():
    """Return the autocast dtype: bfloat16 if the current GPU supports it, else float16.

    The result is computed once and cached. The module-level ``amp_dtype``
    name is updated to the same value, so code that reads that global at run
    time (instead of calling this function) uses the same dtype as training.
    Any error while querying CUDA falls back to float16.
    """
    global _amp_dtype_cache, amp_dtype
    if _amp_dtype_cache is None:
        try:
            if torch.cuda.is_available() and torch.cuda.is_bf16_supported():
                _amp_dtype_cache = torch.bfloat16
            else:
                _amp_dtype_cache = torch.float16
        except Exception:
            # Deliberate best effort: a failed CUDA query must not crash training.
            _amp_dtype_cache = torch.float16
        # Keep the backwards-compatible global in sync with the detected dtype.
        amp_dtype = _amp_dtype_cache
    return _amp_dtype_cache


# Backwards-compatible module-level name. It starts as float16 (safe default)
# and is updated in-process by the first _get_amp_dtype() call. NOTE:
# `from <module> import amp_dtype` copies the value at import time and will
# not see that update.
amp_dtype = torch.float16
# --- PIL safety/verbosity tweaks ---
# Silence decompression-bomb and TIFF-plugin chatter.
warnings.filterwarnings("ignore", category=DecompressionBombWarning)
warnings.filterwarnings("ignore", category=UserWarning, module="PIL.TiffImagePlugin")
# Accept very large images (up to 300M pixels) and partially truncated files.
Image.MAX_IMAGE_PIXELS = 300_000_000
ImageFile.LOAD_TRUNCATED_IMAGES = True
# ------------------------- Robust multiprocessing for DataLoader -------------------------
def _init_mp_ctx():
    """Choose the multiprocessing start method (if unset) and return its context.

    When no start method has been chosen yet, 'fork' is preferred on Linux and
    'spawn' elsewhere. Failing to set it is tolerated (best effort).
    """
    current = mp.get_start_method(allow_none=True)
    if current is None:
        wanted = "spawn"
        if sys.platform.startswith("linux"):
            wanted = "fork"
        try:
            mp.set_start_method(wanted, force=True)
        except Exception:
            pass  # best effort: keep whatever the interpreter allows
        current = mp.get_start_method(allow_none=True) or wanted
    print(f"[mp] using '{current}'.")
    return mp.get_context(current)


MP_CTX = _init_mp_ctx()
# Registry of DataLoaders built by this module; the atexit hook below closes
# whatever is still registered when the interpreter exits.
_DL_TRACK = []


def _track_dl(dl):
    """Register ``dl`` for exit-time cleanup and return it unchanged."""
    _DL_TRACK.append(dl)
    return dl


def _close_dl(dl):
    """Best-effort shutdown of ``dl``'s cached worker iterator; also unregister ``dl``.

    Only an iterator cached on ``dl._iterator`` is shut down. NOTE(review):
    torch's DataLoader caches its iterator there only with
    ``persistent_workers=True`` (confirm for the installed torch version), so
    for non-persistent loaders this is essentially just the unregister step.
    Unregistering lets closed loaders, and the datasets they reference, be
    garbage-collected instead of accumulating in ``_DL_TRACK`` for the run.
    """
    try:
        it = getattr(dl, "_iterator", None)
        if it is not None:
            it._shutdown_workers()
            dl._iterator = None
    except Exception:
        pass  # deliberate best effort: cleanup must never crash training or exit
    # Remove by identity, in place (the atexit hook uses the same list object).
    _DL_TRACK[:] = [tracked for tracked in _DL_TRACK if tracked is not dl]


@atexit.register
def _cleanup_all_dls():
    """atexit hook: close every still-registered loader, then clear the registry."""
    for dl in list(_DL_TRACK):
        _close_dl(dl)
    _DL_TRACK.clear()
def _should_fallback_workers(err: Exception) -> bool:
s = str(err)
return ("Can't get attribute" in s or
"PicklingError" in s or
("AttributeError" in s and "__main__" in s))
# ------------------------- Helpers -------------------------
def stable_int(s: str) -> int:
    """Deterministic unsigned 32-bit hash of ``s`` (Adler-32 of its UTF-8 bytes).

    Unlike built-in ``hash`` this does not vary between processes or runs, so
    it is safe for train/val splits and sample ids.
    """
    raw = s.encode("utf-8")
    checksum = zlib.adler32(raw)
    return checksum & 0xFFFFFFFF
def l2n(x, eps=1e-8):
    """L2-normalise ``x`` along its last dimension (norm clamped below by ``eps``)."""
    return F.normalize(x, p=2.0, dim=-1, eps=eps)
# ------------------------- Dataset -------------------------
class TriViewDataset(Dataset):
    """Artist-labelled dataset that returns whole / face / eyes views together.

    - Each view folder (whole / face / eyes) is split 9:1 into train / val
      independently, using a stable hash of the file path.
    - The sample index is built from the whole view. ``label`` / ``gid`` /
      ``path`` refer to that anchor whole image, and ``__getitem__`` loads
      exactly that image as the whole view.
    - Face and eyes views are drawn at random from the same artist's pools.
      No filename matching is used; only the artist label has to agree. For
      the val split the draw is seeded by the anchor's ``gid``, so every epoch
      validates on the same tri-views.
    - A view is None when the artist has no image for it or the file cannot
      be decoded.
    """

    def __init__(self, root, folders, split="train",
                 T_whole=None, T_face=None, T_eyes=None):
        assert split in ("train", "val")
        self.split = split
        self.root = Path(root)
        self.dirs = {k: self.root / v for k, v in folders.items()}
        self.T = dict(whole=T_whole, face=T_face, eyes=T_eyes)

        # Artists (classes) are the sub-directories of the whole-view folder.
        artists = sorted(d.name for d in self.dirs["whole"].iterdir() if d.is_dir())
        self.artist2id = {a: i for i, a in enumerate(artists)}
        self.id2artist = {v: k for k, v in self.artist2id.items()}
        self.num_classes = len(self.artist2id)

        # Per-artist image pools for this split, one dict per view.
        self.whole_paths_by_artist: Dict[int, List[Path]] = {}
        self.face_paths_by_artist: Dict[int, List[Path]] = {}
        self.eyes_paths_by_artist: Dict[int, List[Path]] = {}
        pools = {
            "whole": self.whole_paths_by_artist,
            "face": self.face_paths_by_artist,
            "eyes": self.eyes_paths_by_artist,
        }
        for artist_name, aid in self.artist2id.items():
            for view, pool in pools.items():
                files = self._list_files(self.dirs[view] / artist_name)
                pool[aid] = self._split_paths(files, split)

        # One index entry per whole-view image (the anchor of a tri-view sample).
        self.index = [
            {"label": aid, "whole": str(wp), "gid": stable_int(str(wp)), "path": str(wp)}
            for aid, w_list in self.whole_paths_by_artist.items()
            for wp in w_list
        ]

    @staticmethod
    def _list_files(directory: Path) -> List[Path]:
        """Return the sorted regular files in ``directory`` ([] if it is not a directory)."""
        if not directory.is_dir():
            return []
        return sorted(p for p in directory.iterdir() if p.is_file())

    @staticmethod
    def _split_paths(paths: List[Path], split: str) -> List[Path]:
        """Keep the paths whose hash bucket (0..9) belongs to ``split``.

        Buckets 0-8 are train and bucket 9 is val (a 9:1 split). The hash is
        taken over the full path string, so the assignment depends on how
        ``root`` is spelled (e.g. "./" vs an absolute path).
        """
        want_train = split == "train"
        return [p for p in paths if (stable_int(str(p)) % 10 < 9) == want_train]

    @staticmethod
    def _choose(pool: List[Path], rng) -> Optional[Path]:
        """Pick one path from ``pool`` with ``rng`` (None for an empty pool)."""
        return rng.choice(pool) if pool else None

    def __len__(self):
        return len(self.index)

    def _load_one(self, path: Optional[Path], T):
        """Load ``path`` as RGB and apply ``T`` (plain ToTensor when ``T`` is None).

        Returns None when ``path`` is None or the file cannot be opened or
        decoded, so the view is treated as missing downstream.
        """
        if path is None:
            return None
        try:
            # The context manager closes the file handle as soon as decoding is done.
            with Image.open(path) as src:
                im = src.convert("RGB")
        except Exception:
            return None
        return T(im) if T is not None else transforms.ToTensor()(im)

    def __getitem__(self, i):
        rec = self.index[i]
        aid = rec["label"]
        # Val draws are seeded per sample so metrics are comparable across epochs.
        rng = random.Random(rec["gid"]) if self.split == "val" else random
        pf = self._choose(self.face_paths_by_artist.get(aid, []), rng)
        pe = self._choose(self.eyes_paths_by_artist.get(aid, []), rng)
        return dict(
            whole=self._load_one(Path(rec["whole"]), self.T["whole"]),
            face=self._load_one(pf, self.T["face"]),
            eyes=self._load_one(pe, self.T["eyes"]),
            label=torch.tensor(aid, dtype=torch.long),
            gid=torch.tensor([rec["gid"]], dtype=torch.long),
            path=rec["path"],
        )
# ------------------------- PK batch sampler -------------------------
class PKBatchSampler(Sampler):
    """Infinite batch sampler: each batch holds K indices from each of P classes.

    Classes are drawn without replacement when at least P exist (otherwise
    with replacement). Within a class, K indices are drawn without
    replacement when possible (otherwise with replacement).
    """

    def __init__(self, dataset: TriViewDataset, P: int, K: int):
        from collections import defaultdict

        self.P = int(P)
        self.K = int(K)
        groups = defaultdict(list)
        for position, record in enumerate(dataset.index):
            groups[record["label"]].append(position)
        self.by_cls = groups
        self.labels = list(groups)
        for members in groups.values():
            random.shuffle(members)

    def _draw_classes(self):
        """Pick the P classes for one batch."""
        if len(self.labels) < self.P:
            return random.choices(self.labels, k=self.P)
        return random.sample(self.labels, self.P)

    def _draw_members(self, cls):
        """Pick the K dataset indices for class ``cls``."""
        members = self.by_cls[cls]
        if len(members) < self.K:
            return [random.choice(members) for _ in range(self.K)]
        return random.sample(members, self.K)

    def __iter__(self):
        while True:
            yield [idx for cls in self._draw_classes() for idx in self._draw_members(cls)]

    def __len__(self):  # nominal only; the sampler never ends on its own
        return 10**9
# ------------------------- Collate & transforms -------------------------
def collate_triview(batch):
    """Collate tri-view samples into batched tensors plus per-view presence masks.

    For each view, missing (None) entries are replaced by zeros shaped like a
    present entry, and ``masks[view]`` marks which rows are real. A view that
    is missing for every sample becomes None (its mask is all False).
    """
    views, masks = {}, {}
    for view in ("whole", "face", "eyes"):
        items = [sample[view] for sample in batch]
        masks[view] = torch.tensor([item is not None for item in items], dtype=torch.bool)
        template = next((item for item in items if item is not None), None)
        if template is None:
            views[view] = None
            continue
        filler = torch.zeros_like(template)
        views[view] = torch.stack([filler if item is None else item for item in items], dim=0)
    return dict(
        views=views,
        masks=masks,
        labels=torch.stack([sample["label"] for sample in batch]),
        gids=torch.stack([sample["gid"] for sample in batch]).squeeze(1),
        paths=[sample["path"] for sample in batch],
    )
def make_transforms(sz_w, sz_f, sz_e):
    """Return training augmentation pipelines for the whole, face and eyes views."""
    def aug(size):
        steps = [
            transforms.RandomResizedCrop(size, scale=(0.6, 1.0)),
            transforms.RandomHorizontalFlip(),
            transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.05, hue=0.02),
            transforms.RandomApply([transforms.GaussianBlur(3)], p=0.3),
            transforms.ToTensor(),
            transforms.Normalize([0.5] * 3, [0.5] * 3),  # maps [0, 1] to [-1, 1]
        ]
        return transforms.Compose(steps)

    return tuple(aug(size) for size in (sz_w, sz_f, sz_e))
def make_val_transforms(sz_w, sz_f, sz_e):
    """Return deterministic eval pipelines (resize to 1.15x, center crop) for the three views."""
    def val(size):
        steps = [
            transforms.Resize(int(size * 1.15)),
            transforms.CenterCrop(size),
            transforms.ToTensor(),
            transforms.Normalize([0.5] * 3, [0.5] * 3),
        ]
        return transforms.Compose(steps)

    return tuple(val(size) for size in (sz_w, sz_f, sz_e))
# ------------------------- Model & heads -------------------------
class MixStyle(nn.Module):
    """MixStyle: mix per-channel feature statistics between samples (training mode only).

    Each sample is selected independently with probability ``p``. A selected
    sample's normalized features are re-scaled with a Beta(alpha, alpha)
    interpolation of its own mean/std and those of a randomly permuted
    partner in the batch. Unselected samples pass through unchanged.
    """

    def __init__(self, p=0.3, alpha=0.1):
        super().__init__()
        self.p = p
        self.alpha = alpha

    def forward(self, x):
        if not self.training or self.p <= 0.0:
            return x
        batch, _, _, _ = x.shape
        mean = x.mean([2, 3], keepdim=True)
        std = (x.var([2, 3], unbiased=False, keepdim=True) + 1e-5).sqrt()
        order = torch.randperm(batch, device=x.device)
        lam = torch.distributions.Beta(self.alpha, self.alpha).sample((batch, 1, 1, 1)).to(x.device)
        mixed_mean = mean * lam + mean[order] * (1 - lam)
        mixed_std = std * lam + std[order] * (1 - lam)
        gate = (torch.rand(batch, 1, 1, 1, device=x.device) < self.p).float()
        stylised = (x - mean) / std * mixed_std + mixed_mean
        return stylised * gate + x * (1 - gate)
class SqueezeExcite(nn.Module):
    """Channel attention: global average pool -> 1x1 bottleneck (>= 8 channels) -> sigmoid gate."""

    def __init__(self, c, r=16):
        super().__init__()
        hidden = c // r
        if hidden < 8:
            hidden = 8
        layers = [
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(c, hidden, 1),
            nn.GELU(),
            nn.Conv2d(hidden, c, 1),
            nn.Sigmoid(),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        gate = self.net(x)
        return x * gate
class ConvBlock(nn.Module):
    """Conv2d (no bias) -> GroupNorm with 16 groups -> GELU."""

    def __init__(self, ci, co, k=3, s=1, p=1):
        super().__init__()
        self.conv = nn.Conv2d(ci, co, kernel_size=k, stride=s, padding=p, bias=False)
        self.gn = nn.GroupNorm(num_groups=16, num_channels=co)
        self.act = nn.GELU()

    def forward(self, x):
        y = self.conv(x)
        y = self.gn(y)
        return self.act(y)
class ResBlock(nn.Module):
    """Residual block built from two ConvBlocks and squeeze-excite.

    An optional ``mix`` module (e.g. MixStyle) runs between the two convs.
    With ``down`` both branches are 2x average-pooled. The skip path gets a
    1x1 projection only when the channel counts differ.
    """

    def __init__(self, ci, co, down=False, mix=None):
        super().__init__()
        self.c1 = ConvBlock(ci, co, 3, 1, 1)
        self.c2 = ConvBlock(co, co, 3, 1, 1)
        self.se = SqueezeExcite(co)
        self.down = down
        self.pool = nn.AvgPool2d(2) if down else nn.Identity()
        self.proj = nn.Identity() if ci == co else nn.Conv2d(ci, co, 1, 1, 0, bias=False)
        self.mix = mix

    def forward(self, x):
        residual = self.c1(x)
        if self.mix is not None:
            residual = self.mix(residual)
        residual = self.se(self.c2(residual))
        shortcut = x
        if self.down:
            residual = self.pool(residual)
            shortcut = self.pool(shortcut)
        return F.gelu(residual + self.proj(shortcut))
def matrix_sqrt_newton_schulz(A, iters=5):
    """Approximate the square root of a batch of matrices with Newton-Schulz iterations.

    Args:
        A: tensor of shape [B, C, C]. The iteration is meant for symmetric
           positive (semi-)definite matrices, e.g. covariances.
        iters: number of coupled Newton-Schulz iterations.

    Returns:
        [B, C, C] tensor ``Y`` with ``Y @ Y ~= A``, in ``A``'s dtype.

    Each matrix is first scaled by its Frobenius norm (clamped to >= 1e-8) so
    the iteration converges, and the result is rescaled by sqrt(norm). The
    identity is built in ``A``'s dtype and device, so float64 / half inputs
    work as well as float32.
    """
    B, C, _ = A.shape
    normA = A.reshape(B, -1).norm(dim=1).view(B, 1, 1).clamp(min=1e-8)
    Y = A / normA
    I = torch.eye(C, device=A.device, dtype=A.dtype).expand(B, C, C)
    Z = I.clone()
    for _ in range(iters):
        # Y converges to sqrt(A / ||A||), Z to its inverse.
        T = 0.5 * (3.0 * I - Z.bmm(Y))
        Y = Y.bmm(T)
        Z = T.bmm(Z)
    return Y * normA.sqrt()
class GramHead(nn.Module):
    """Style head: 1x1 channel reduction, spatially averaged Gram matrix, linear projection."""

    def __init__(self, c_in, c_red=64, proj=128):
        super().__init__()
        self.red = nn.Conv2d(c_in, c_red, 1, bias=False)
        self.proj = nn.Linear(c_red * c_red, proj)

    def forward(self, x):
        fmap = self.red(x)
        batch, chans, height, width = fmap.shape
        flat = fmap.flatten(2)                                  # [B, C, H*W]
        gram = torch.bmm(flat, flat.transpose(1, 2)) / (height * width)
        return self.proj(gram.reshape(batch, chans * chans))
class CovISqrtHead(nn.Module):
    """Style head on the channel covariance.

    Applies a 1x1 reduction, then the channel covariance, then its
    Newton-Schulz matrix square root, then a linear projection.

    NOTE: despite the "ISqrt" name, ``matrix_sqrt_newton_schulz``'s returned
    value is the square root, not the inverse square root.
    """

    def __init__(self, c_in, c_red=64, proj=128):
        super().__init__()
        self.red = nn.Conv2d(c_in, c_red, 1, bias=False)
        self.proj = nn.Linear(c_red * c_red, proj)

    def forward(self, x):
        # Covariance and matrix sqrt are precision sensitive: keep them in float32.
        with torch.amp.autocast('cuda', enabled=False):
            fmap = self.red(x.float())
            batch, chans, height, width = fmap.shape
            flat = fmap.flatten(2)
            centred = flat - flat.mean(-1, keepdim=True)
            cov = torch.bmm(centred, centred.transpose(1, 2)) / (height * width - 1 + 1e-5)
            root = matrix_sqrt_newton_schulz(cov.float(), iters=5)
            return self.proj(root.reshape(batch, chans * chans))
def spectrum_hist(x, K=16, O=8):
    """Radial / angular histogram of the channel-averaged 2-D FFT magnitude.

    Args:
        x: feature map of shape [B, C, H, W].
        K: number of radial bins.
        O: number of angular bins.

    Returns:
        Tensor [B, K + O] in ``x``'s dtype. It holds the K radial bins followed
        by the O angular bins, each group normalised to sum to ~1 per sample.
        Bins accumulate log1p(|rfft2(x)|) averaged over channels.

    NOTE(review): the frequency grid is ``linspace(-1, 1, H)`` laid over the
    rows of ``rfft2``'s output. Those rows are in FFT order (DC first, then
    positive, then negative frequencies) rather than centred, so a bin's
    "radius" and "angle" do not match the true vertical frequency; e.g. the
    DC term falls into the outermost radial bin. This is kept unchanged
    because trained checkpoints depend on these exact features. Correcting it
    (e.g. with ``torch.fft.fftfreq``) requires retraining.
    """
    B, _, _, _ = x.shape
    spec = torch.fft.rfft2(x, norm='ortho').abs().mean(1)          # [B, H, W//2+1]
    H2, W2 = spec.shape[-2], spec.shape[-1]
    yy, xx = torch.meshgrid(
        torch.linspace(-1, 1, H2, device=x.device),
        torch.linspace(0, 1, W2, device=x.device),
        indexing="ij",
    )
    rr = (yy**2 + xx**2).sqrt().clamp(0, 1 - 1e-8)
    th = torch.atan2(yy, xx + 1e-9) + math.pi / 2                   # in [0, pi]
    rb = (rr * K).long().clamp(0, K - 1)
    ob = (th / math.pi * O).long().clamp(0, O - 1)
    mag = torch.log1p(spec).reshape(B, -1)                          # [B, H2*W2]
    # The bin layout is shared by all samples, so the whole batch is binned
    # in one scatter. Accumulators follow mag's dtype (float64 input works).
    rad = mag.new_zeros(B, K).scatter_add_(1, rb.reshape(1, -1).expand(B, -1), mag)
    ang = mag.new_zeros(B, O).scatter_add_(1, ob.reshape(1, -1).expand(B, -1), mag)
    rad = rad / (rad.sum(-1, keepdim=True) + 1e-6)
    ang = ang / (ang.sum(-1, keepdim=True) + 1e-6)
    return torch.cat([rad, ang], dim=1)
class SpectrumHead(nn.Module):
    """Style head: FFT-magnitude radial/angular histogram followed by a linear projection.

    Args:
        c_in: input channels (unused; the histogram averages over channels).
        proj: output dimension.
        K, O: number of radial / angular histogram bins. They are forwarded
            to ``spectrum_hist`` so the histogram width always matches the
            Linear layer's K + O inputs.
    """

    def __init__(self, c_in, proj=64, K=16, O=8):
        super().__init__()
        self.K = K
        self.O = O
        self.proj = nn.Linear(K + O, proj)

    def forward(self, x):
        # The FFT histogram is computed in float32 with autocast disabled.
        with torch.amp.autocast('cuda', enabled=False):
            h = spectrum_hist(x.float(), K=self.K, O=self.O)
            return self.proj(h)
class StatsHead(nn.Module):
    """Style head on per-channel statistics.

    A 1x1 conv reduces to at most 64 channels. The spatial mean and
    log-variance of each channel are concatenated and fed through a small MLP.
    """

    def __init__(self, c_in, proj=64):
        super().__init__()
        reduced = c_in if c_in < 64 else 64
        self.red = nn.Conv2d(c_in, reduced, 1, bias=False)
        self.mlp = nn.Sequential(
            nn.Linear(2 * reduced, 128),
            nn.GELU(),
            nn.Linear(128, proj),
        )

    def forward(self, x):
        feats = self.red(x)
        spatial = (2, 3)
        mean = feats.mean(spatial)
        log_var = (feats.var(spatial, unbiased=False) + 1e-5).log()
        return self.mlp(torch.cat((mean, log_var), dim=1))
class ViewEncoder(nn.Module):
    """
    Per-view style encoder.

    - Input: RGB images normalised with Normalize([0.5], [0.5]).
    - The RGB input is converted to (rescaled) CIE Lab before the backbone.
    - Backbone (stem + four downsampling ResBlocks) plus four style heads
      (Gram / Cov / Spectrum / Stats), each applied to the stage-3 and
      stage-4 feature maps.
    - Branch attention: a per-sample softmax gate over the four head branches,
      followed by an MLP to ``out_dim``.
    """
    def __init__(self, mix_p=0.3, out_dim=256):
        super().__init__()
        # A single MixStyle module, used inside b1 and b2 only.
        self.mix = MixStyle(p=mix_p, alpha=0.1)
        ch = [32, 64, 128, 192, 256]
        # Stem keeps full resolution; each ResBlock below (down=True) halves H and W.
        self.stem = nn.Sequential(
            ConvBlock(3, ch[0], 3, 1, 1),
            ConvBlock(ch[0], ch[0], 3, 1, 1),
        )
        self.b1 = ResBlock(ch[0], ch[1], down=True, mix=self.mix)
        self.b2 = ResBlock(ch[1], ch[2], down=True, mix=self.mix)
        self.b3 = ResBlock(ch[2], ch[3], down=True, mix=None)
        self.b4 = ResBlock(ch[3], ch[4], down=True, mix=None)
        # NOTE: module registration order fixes the parameter order, which the
        # saved optimizer state relies on; do not reorder these attributes.
        # Style heads on the b3 output (ch[3] channels)...
        self.h_gram3 = GramHead(ch[3])
        self.h_cov3 = CovISqrtHead(ch[3])
        self.h_sp3 = SpectrumHead(ch[3])
        self.h_st3 = StatsHead(ch[3])
        # ...and on the b4 output (ch[4] channels).
        self.h_gram4 = GramHead(ch[4])
        self.h_cov4 = CovISqrtHead(ch[4])
        self.h_sp4 = SpectrumHead(ch[4])
        self.h_st4 = StatsHead(ch[4])
        # Per stage: gram 128 + cov 128 + spectrum 64 + stats 64 (head defaults), times 2 stages.
        fdim = (128+128+64+64)*2  # 768
        self.fdim = fdim
        # One logit per branch (gram / cov / spectrum / stats).
        self.branch_gate = nn.Sequential(
            nn.LayerNorm(fdim),
            nn.Linear(fdim, 4, bias=True),
        )
        self.fuse = nn.Sequential(
            nn.Linear(fdim, 512),
            nn.GELU(),
            nn.Linear(512, out_dim),
        )
    def _rgb_to_lab(self, x: torch.Tensor) -> torch.Tensor:
        """Convert Normalize(0.5, 0.5) RGB input to Lab, rescaled per channel.

        Output channels are L / 100, (a + 128) / 255 and (b + 128) / 255. The
        conversion runs in float32 with autocast disabled, and the result is
        cast back to ``x.dtype``.
        """
        with torch.amp.autocast('cuda', enabled=False):
            x_f = x.float()
            # Undo Normalize([0.5],[0.5]) and clamp to [0, 1].
            rgb = (x_f * 0.5 + 0.5).clamp(0.0, 1.0)
            # sRGB gamma expansion to linear RGB.
            thresh = 0.04045
            low = rgb / 12.92
            high = ((rgb + 0.055) / 1.055).pow(2.4)
            rgb_lin = torch.where(rgb <= thresh, low, high)
            # Channels-last so the 3x3 colour matrix applies along the last dim.
            rgb_lin = rgb_lin.permute(0, 2, 3, 1)
            # Linear sRGB -> XYZ (D65).
            M = rgb_lin.new_tensor([
                [0.4124564, 0.3575761, 0.1804375],
                [0.2126729, 0.7151522, 0.0721750],
                [0.0193339, 0.1191920, 0.9503041],
            ])
            xyz = torch.matmul(rgb_lin, M.T)
            # Normalise by the D65 reference white.
            Xn, Yn, Zn = 0.95047, 1.00000, 1.08883
            xyz = xyz / rgb_lin.new_tensor([Xn, Yn, Zn])
            # CIE Lab nonlinearity: cube root above eps, linear segment below.
            eps = 0.008856
            kappa = 903.3
            def f(t):
                t = t.clamp(min=1e-6)
                return torch.where(
                    t > eps,
                    t.pow(1.0 / 3.0),
                    (kappa * t + 16.0) / 116.0,
                )
            f_xyz = f(xyz)
            fx, fy, fz = f_xyz[..., 0], f_xyz[..., 1], f_xyz[..., 2]
            L = 116.0 * fy - 16.0
            a = 500.0 * (fx - fy)
            b = 200.0 * (fy - fz)
            # Rescale L, a, b to roughly unit range for the backbone.
            L_scaled = L / 100.0
            a_scaled = (a + 128.0) / 255.0
            b_scaled = (b + 128.0) / 255.0
            lab = torch.stack([L_scaled, a_scaled, b_scaled], dim=-1)
            # Back to channels-first.
            lab = lab.permute(0, 3, 1, 2)
            return lab.to(dtype=x.dtype)
    def forward(self, x):
        # The backbone runs on the Lab representation.
        x_lab = self._rgb_to_lab(x)
        f0 = self.stem(x_lab)
        f1 = self.b1(f0)
        f2 = self.b2(f1)
        f3 = self.b3(f2)
        f4 = self.b4(f3)
        # Stage-3 head outputs.
        g3 = self.h_gram3(f3)
        c3 = self.h_cov3(f3)
        sp3 = self.h_sp3(f3)
        st3 = self.h_st3(f3)
        # Stage-4 head outputs.
        g4 = self.h_gram4(f4)
        c4 = self.h_cov4(f4)
        sp4 = self.h_sp4(f4)
        st4 = self.h_st4(f4)
        # One branch per head type, concatenated across the two stages.
        b_gram = torch.cat([g3, g4], dim=1)
        b_cov = torch.cat([c3, c4], dim=1)
        b_sp = torch.cat([sp3, sp4], dim=1)
        b_st = torch.cat([st3, st4], dim=1)
        flat = torch.cat([b_gram, b_cov, b_sp, b_st], dim=1)  # [B,768]
        # Per-sample softmax weights over the 4 branches, predicted from all features.
        gate_logits = self.branch_gate(flat)
        w = torch.softmax(gate_logits, dim=-1)
        w0, w1, w2, w3 = w[:,0:1], w[:,1:2], w[:,2:3], w[:,3:4]
        # Re-weight each branch, then fuse into the (unnormalised) view embedding.
        flat_weighted = torch.cat([
            b_gram * w0,
            b_cov * w1,
            b_sp * w2,
            b_st * w3,
        ], dim=1)
        view_vec = self.fuse(flat_weighted)
        return view_vec
class TriViewStyleNet(nn.Module):
    """Encode whole / face / eyes views and fuse them with a learned view gate.

    ``forward(views, masks)`` returns ``(fused, per_view, weights)``:
    - ``per_view`` maps each view name to its L2-normalised embedding, or
      None when that view's tensor is absent.
    - ``weights`` [B, num_present_views] is the softmax gate. Rows whose mask
      is False get a -1e4 logit, i.e. ~0 weight.
    - ``fused`` is the L2-normalised weighted sum of the present embeddings.
    Raises RuntimeError when every view tensor is None.
    """

    def __init__(self, out_dim=256, mix_p=0.3, share_backbone: bool = True):
        super().__init__()
        if not share_backbone:
            self.enc_whole = ViewEncoder(mix_p=mix_p, out_dim=out_dim)
            self.enc_face = ViewEncoder(mix_p=mix_p, out_dim=out_dim)
            self.enc_eyes = ViewEncoder(mix_p=mix_p, out_dim=out_dim)
        else:
            # One encoder instance registered under all three names (shared weights).
            encoder = ViewEncoder(mix_p=mix_p, out_dim=out_dim)
            self.enc_whole = encoder
            self.enc_face = encoder
            self.enc_eyes = encoder
        self.view_gate = nn.Sequential(
            nn.LayerNorm(out_dim),
            nn.Linear(out_dim, 1, bias=True),
        )

    def forward(self, views, masks):
        encoders = {"whole": self.enc_whole, "face": self.enc_face, "eyes": self.enc_eyes}
        outs = {}
        scores = []
        present = []
        for name, encoder in encoders.items():
            batch_view = views[name]
            if batch_view is None:
                outs[name] = None
                continue
            emb = l2n(encoder(batch_view.to(memory_format=torch.channels_last)))
            outs[name] = emb
            logit = self.view_gate(emb).squeeze(1)
            keep = masks[name].to(logit.device)
            scores.append(torch.where(keep, logit, torch.full_like(logit, -1e4)))
            present.append(emb)
        if not scores:
            raise RuntimeError("All views are missing in this batch.")
        weights = F.softmax(torch.stack(scores, dim=1), dim=1)    # [B, num_views]
        stacked = torch.stack(present, dim=1)                     # [B, num_views, dim]
        fused = l2n((weights.unsqueeze(-1) * stacked).sum(dim=1))  # [B, dim]
        return fused, outs, weights
# ------------------------- Losses -------------------------
class ProxyAnchorLoss(nn.Module):
    """Proxy-Anchor loss with one learnable proxy per class.

    With cosine similarities s(x, p) between embeddings and proxies:
    - positive term: for each class present in the batch,
      log(1 + sum_{x in class} exp(-alpha * (s - margin))), averaged over the
      present classes;
    - negative term: for every class,
      log(1 + sum_{x not in class} exp(alpha * (s + margin))), averaged over
      all classes and scaled by ``neg_weight``.
    Exponents are clamped to [-60, 60]; everything runs in float32 with
    autocast disabled.
    """

    def __init__(self, num_classes, dim, alpha=16.0, margin=0.1, neg_weight=0.25):
        super().__init__()
        self.proxies = nn.Parameter(torch.randn(num_classes, dim))
        nn.init.normal_(self.proxies, std=0.01)
        self.alpha = float(alpha)
        self.margin = float(margin)
        self.neg_weight = float(neg_weight)

    def forward(self, z, y):
        with torch.amp.autocast('cuda', enabled=False):
            emb = F.normalize(z.float(), dim=-1)
            prox = F.normalize(self.proxies.float(), dim=-1)
            cos = emb @ prox.t()                                  # [N, C]
            n_cls = cos.size(1)
            member = F.one_hot(y, num_classes=n_cls).float()
            non_member = 1.0 - member
            pos_exp = (-self.alpha * (cos - self.margin)).clamp(min=-60.0, max=60.0).exp()
            neg_exp = (self.alpha * (cos + self.margin)).clamp(min=-60.0, max=60.0).exp()
            pos_sum = (pos_exp * member).sum(0)
            neg_sum = (neg_exp * non_member).sum(0)
            has_pos = member.sum(0) > 0
            loss_pos = torch.log1p(pos_sum[has_pos]).sum() / (has_pos.sum().clamp_min(1.0))
            loss_neg = torch.log1p(neg_sum).sum() / n_cls
            return loss_pos + self.neg_weight * loss_neg
class SupConLoss(nn.Module):
    """Supervised contrastive loss over a batch of embeddings.

    With s_ij = cos(f_i, f_j) / tau, each anchor i that has at least one
    other sample with the same label contributes
        -log( sum_{j != i, y_j == y_i} exp(s_ij) / sum_{j != i} exp(s_ij) ).
    The loss is the mean over such anchors, and 0 when there are none.

    Computed as a difference of masked log-sum-exps, so ``exp`` cannot
    overflow for small ``tau`` (the naive form overflows float32 once
    1/tau exceeds ~88).
    """

    def __init__(self, tau=0.07):
        super().__init__()
        self.tau = tau

    def forward(self, feats, labels):
        feats = F.normalize(feats, dim=-1, eps=1e-8)  # same as l2n
        # At least float32, even under autocast, for a stable log-sum-exp.
        work_dtype = torch.promote_types(feats.dtype, torch.float32)
        sim = (feats @ feats.t()).to(work_dtype) / self.tau
        n = sim.size(0)
        self_mask = torch.eye(n, device=sim.device, dtype=torch.bool)
        # A large finite fill (not -inf) keeps gradients finite for rows that
        # end up fully masked (anchors without positives).
        logits = sim.masked_fill(self_mask, -1e9)
        pos_mask = (labels.unsqueeze(1) == labels.unsqueeze(0)) & ~self_mask.to(labels.device)
        valid = pos_mask.sum(1) > 0
        if not valid.any():
            return torch.tensor(0.0, device=feats.device)
        log_denom = torch.logsumexp(logits, dim=1)
        log_numer = torch.logsumexp(logits.masked_fill(~pos_mask, -1e9), dim=1)
        return (log_denom - log_numer)[valid].mean()
class MultiViewInfoNCE(nn.Module):
    """Multi-view InfoNCE: rows sharing a group id (``gids``) are positives.

    With s_ij = cos(f_i, f_j) / tau, each anchor i that has at least one
    other row with the same gid contributes
        -log( sum_{j != i, g_j == g_i} exp(s_ij) / sum_{j != i} exp(s_ij) ).
    The loss is the mean over such anchors, and 0 when there are none.

    Computed as a difference of masked log-sum-exps, so ``exp`` cannot
    overflow for small ``tau``.
    """

    def __init__(self, tau=0.1):
        super().__init__()
        self.tau = tau

    def forward(self, feats, gids):
        feats = F.normalize(feats, dim=-1, eps=1e-8)  # same as l2n
        # At least float32, even under autocast, for a stable log-sum-exp.
        work_dtype = torch.promote_types(feats.dtype, torch.float32)
        sim = (feats @ feats.t()).to(work_dtype) / self.tau
        n = sim.size(0)
        self_mask = torch.eye(n, device=sim.device, dtype=torch.bool)
        # A large finite fill (not -inf) keeps gradients finite for fully masked rows.
        logits = sim.masked_fill(self_mask, -1e9)
        pos_mask = (gids.unsqueeze(1) == gids.unsqueeze(0)) & ~self_mask.to(gids.device)
        valid = pos_mask.sum(1) > 0
        if not valid.any():
            return torch.tensor(0.0, device=feats.device)
        log_denom = torch.logsumexp(logits, dim=1)
        log_numer = torch.logsumexp(logits.masked_fill(~pos_mask, -1e9), dim=1)
        return (log_denom - log_numer)[valid].mean()
# --------------------- Logging / checkpoints / schedulers -----------------
os.makedirs(cfg.out_dir, exist_ok=True)
LOG_TXT = os.path.join(cfg.out_dir, "train.log")              # free-text training log
METRICS_CSV = os.path.join(cfg.out_dir, "metrics_epoch.csv")  # one row per epoch
# Only write the CSV header when the file does not exist yet; later runs append.
if not os.path.exists(METRICS_CSV):
    with open(METRICS_CSV, "w", encoding="utf-8") as f:
        f.write(
            "timestamp,stage,epoch,steps,P,K,train_loss,train_proxy,train_sup,train_mv,"
            "val_proxy,proxy_top1,knn_r1,knn_r5,kmeans_acc,nmi,ari\n"
        )
def wlog_global(msg, also_print=False):
    """Append ``msg`` with a ``[YYYY-mm-dd HH:MM:SS]`` prefix to LOG_TXT; optionally echo via tqdm.write."""
    stamped = f"[{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] {msg}"
    with open(LOG_TXT, "a", encoding="utf-8", buffering=1) as log_file:
        log_file.write(stamped + "\n")
    if also_print:
        tqdm.write(stamped)
def write_epoch_metrics(stage_i, epoch_i, steps, P, K,
                        tr_mean, tr_p, tr_s, tr_m,
                        val_proxy, proxy_top1,
                        knn_r1, knn_r5,
                        kmeans_acc, nmi, ari):
    """Append one timestamped row of epoch metrics to METRICS_CSV.

    Metric values may be None, Python numbers or anything with ``.item()``
    (e.g. tensors). They are written with 6 decimals. None, non-numeric,
    NaN and infinite values are written as "nan".
    """
    def fmt(value):
        if value is None:
            return "nan"
        try:
            number = float(value.item()) if hasattr(value, "item") else float(value)
        except Exception:
            return "nan"
        if not np.isfinite(number):
            return "nan"
        return f"{number:.6f}"

    stamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    metrics = (tr_mean, tr_p, tr_s, tr_m, val_proxy, proxy_top1,
               knn_r1, knn_r5, kmeans_acc, nmi, ari)
    fields = [f"{v}" for v in (stamp, stage_i, epoch_i, steps, P, K)]
    fields.extend(fmt(v) for v in metrics)
    with open(METRICS_CSV, "a", encoding="utf-8") as fh:
        fh.write(",".join(fields) + "\n")
def save_ckpt(path, model, proxy_loss, optim, sched, meta, is_main: bool):
    """Save a training checkpoint (no-op unless ``is_main``).

    Stores the unwrapped model state (``model.module`` for DDP), the proxy-loss
    state, optimizer and scheduler state (None when not given), and ``meta``.

    When ``path`` is a filesystem path, the data is written to ``<path>.tmp``
    and then renamed over ``path``. An interrupted save therefore never
    leaves a truncated file under a ``stage*_epoch*.pt`` name, which resume
    logic would otherwise pick as the newest checkpoint. Any other target
    (e.g. a file-like object) is passed straight to ``torch.save``.
    """
    if not is_main:
        return
    base_model = model.module if isinstance(model, nn.parallel.DistributedDataParallel) else model
    payload = {
        "model": base_model.state_dict(),
        "proxies": proxy_loss.state_dict(),
        "optim": optim.state_dict() if optim else None,
        "sched": sched.state_dict() if sched else None,
        "meta": meta,
    }
    if not isinstance(path, (str, os.PathLike)):
        torch.save(payload, path)
        return
    tmp_path = f"{os.fspath(path)}.tmp"
    try:
        torch.save(payload, tmp_path)
        os.replace(tmp_path, path)  # atomic rename on the same filesystem
    except BaseException:
        # Don't leave a partial temp file behind; re-raise the original error.
        try:
            os.remove(tmp_path)
        except OSError:
            pass
        raise
def find_latest_checkpoint(out_dir):
    """Return ``(path, stage, epoch)`` of the newest ``stage<S>_epoch<E>.pt`` in ``out_dir``.

    "Newest" means the highest stage, then the highest epoch within it.
    Returns ``(None, -1, -1)`` when no matching file exists.
    """
    pattern = re.compile(r"stage(\d+)_epoch(\d+)\.pt$")
    candidates = []
    for path in glob.glob(os.path.join(out_dir, "stage*_epoch*.pt")):
        match = pattern.search(os.path.basename(path))
        if match is None:
            continue
        candidates.append(((int(match.group(1)), int(match.group(2))), path))
    if not candidates:
        return None, -1, -1
    (stage, epoch), path = max(candidates, key=lambda item: item[0])
    return path, stage, epoch
def _pick_from_schedule(sched, default_val, ep):
if not sched:
return int(default_val)
if isinstance(sched, dict):
items = sorted([(int(k), int(v)) for k,v in sched.items()], key=lambda x: x[0])
else:
items = sorted([(int(k), int(v)) for k,v in sched], key=lambda x: x[0])
val = int(default_val)
for k,v in items:
if ep >= k:
val = int(v)
return int(val)
def resolve_epoch_PK(stage: dict, ep: int):
    """Resolve the per-rank (P, K) for epoch ``ep`` of ``stage``.

    Base values come from the stage dict (falling back to ``cfg.P`` /
    ``cfg.K``) and may be overridden by ``P_schedule`` / ``K_schedule``. An
    optional ``bs_schedule`` sets the batch size directly. It is rounded down
    to a multiple of K (with a logged warning) and P becomes max(1, bs // K).
    """
    base_P = int(stage.get("P", cfg.P))
    base_K = int(stage.get("K", cfg.K))
    P = _pick_from_schedule(stage.get("P_schedule"), base_P, ep)
    K = _pick_from_schedule(stage.get("K_schedule"), base_K, ep)
    bs_sched = stage.get("bs_schedule")
    if not bs_sched:
        return int(P), int(K)
    bs = _pick_from_schedule(bs_sched, P * K, ep)
    remainder = bs % K
    if remainder:
        wlog_global(f"[batch] bs_schedule value {bs} not divisible by K={K}; rounding down to {bs//K*K}", also_print=True)
        bs -= remainder
    return int(max(1, bs // K)), int(K)
def estimate_steps_per_epoch(train_len: int, global_batch: int, max_steps: Optional[int]):
    """Steps per epoch: ``max_steps`` if given, else ceil(train_len / global_batch), at least 1."""
    if max_steps is None:
        per_step = max(1, global_batch)
        return max(1, math.ceil(train_len / per_step))
    return int(max_steps)
def build_train_loader(ds: TriViewDataset, P: int, K: int):
    """Build the PK-sampled training DataLoader and register it for exit-time cleanup.

    Batches come from an infinite PKBatchSampler (P classes x K images) and
    are loaded with ``cfg.workers`` worker processes.

    ``prefetch_factor`` and ``multiprocessing_context`` are only passed when
    ``cfg.workers > 0``: DataLoader raises ValueError for a multiprocessing
    context with num_workers=0. That previously broke the num_workers=0
    fallback, which rebuilds the loader after setting ``cfg.workers = 0``.
    """
    workers = cfg.workers
    worker_kwargs = {}
    if workers > 0:
        worker_kwargs = dict(prefetch_factor=2, multiprocessing_context=MP_CTX)
    dl = DataLoader(
        ds,
        batch_sampler=PKBatchSampler(ds, P, K),
        num_workers=workers,
        pin_memory=True,
        collate_fn=collate_triview,
        persistent_workers=False,
        **worker_kwargs,
    )
    return _track_dl(dl)
def make_cosine_with_warmup(optimizer, warmup_steps, total_steps):
    """LambdaLR schedule: linear warmup, then cosine decay to 0.

    The multiplier on each group's base LR is:
    - ``step / warmup_steps`` for ``step < warmup_steps`` (0 at step 0);
    - otherwise ``0.5 * (1 + cos(pi * progress))``, where progress runs from
      0 to 1 over ``total_steps - warmup_steps`` steps.
    progress is clamped at 1, so stepping past ``total_steps`` keeps the
    multiplier at 0 instead of climbing back up the cosine.
    """
    def lr_lambda(step):
        if step < warmup_steps:
            return float(step) / max(1, warmup_steps)
        decay_span = max(1, total_steps - warmup_steps)
        progress = min(1.0, (step - warmup_steps) / decay_span)
        return 0.5 * (1.0 + math.cos(math.pi * progress))
    return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_lambda)
# ------------------------------ DDP worker --------------------------------
def ddp_train_worker(rank: int, world_size: int):
torch.cuda.set_device(rank)
device = torch.device("cuda", rank)
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
os.environ["RANK"] = str(rank)
os.environ["WORLD_SIZE"] = str(world_size)
dist.init_process_group(backend="nccl", rank=rank, world_size=world_size)
seed_all(cfg.seed + rank)
is_main = (rank == 0)
# class count
artists_dir = os.path.join(cfg.data_root, cfg.folders['whole'])
num_classes_total = len([
d for d in os.listdir(artists_dir)
if os.path.isdir(os.path.join(artists_dir, d))
])
if is_main:
wlog_global(f"[DDP] world_size={world_size}, num_classes_total={num_classes_total}", also_print=True)
# model & losses
base_model = TriViewStyleNet(
out_dim=cfg.embed_dim,
mix_p=cfg.mixstyle_p,
share_backbone=True,
).to(device)
base_model = base_model.to(memory_format=torch.channels_last)
if cfg.use_compile and hasattr(torch, "compile"):
try:
base_model = torch.compile(base_model, mode="reduce-overhead", fullgraph=False)
except Exception:
pass
model = nn.parallel.DistributedDataParallel(
base_model,
device_ids=[rank],
output_device=rank,
find_unused_parameters=False,
)
proxy_loss = ProxyAnchorLoss(
num_classes=num_classes_total,
dim=cfg.embed_dim,
alpha=cfg.alpha_proxy,
margin=cfg.margin_proxy,
neg_weight=0.25,
).to(device)
supcon = SupConLoss(tau=cfg.supcon_tau).to(device)
mv_infonce = MultiViewInfoNCE(tau=cfg.mv_tau).to(device)
# resume
resume_info = None
ckpt_path, ck_stage, ck_epoch = find_latest_checkpoint(cfg.out_dir)
if ckpt_path is not None:
ck = torch.load(ckpt_path, map_location="cpu")
try:
model.module.load_state_dict(ck["model"], strict=False)
except Exception as e:
if is_main:
wlog_global(f"[resume] WARNING: model state load failed: {e}", also_print=True)
try:
proxy_loss.load_state_dict(ck["proxies"])
except Exception as e:
if is_main:
wlog_global(f"[resume] WARNING: proxy state load failed: {e}", also_print=True)
meta = ck.get("meta", {})
last_stage = int(meta.get("stage", ck_stage or 1))
last_epoch = int(meta.get("epoch", ck_epoch or 0))
start_stage = last_stage
start_epoch = last_epoch + 1
if start_stage <= len(cfg.stages) and start_epoch > cfg.stages[start_stage-1]["epochs"]:
start_stage += 1
start_epoch = 1
resume_info = dict(
ckpt=ck,
path=ckpt_path,
last_stage=last_stage,
last_epoch=last_epoch,
start_stage=start_stage,
start_epoch=start_epoch,
)
if is_main:
wlog_global(
f"[resume] Found {ckpt_path} (stage {last_stage}, epoch {last_epoch}). "
f"Resuming at stage {start_stage}, epoch {start_epoch}.",
also_print=True,
)
else:
if is_main:
wlog_global("[resume] No checkpoint found; training from scratch.", also_print=True)
scaler = torch.amp.GradScaler('cuda', enabled=torch.cuda.is_available())
global_step = 0
proxy_lr_mult = 5.0
RAMP_EPOCHS = 3
WARMUP_EPOCHS = 1
VALIDATE_EVERY = 4 # N epoch마다 검증
from tqdm.auto import tqdm as tqdm_local
# Stage loop
for si, stage in enumerate(cfg.stages, 1):
if resume_info and si < resume_info["start_stage"]:
if is_main:
wlog_global(f"[resume] Skipping stage {si}; already completed.", also_print=True)
continue
# datasets per stage
T_w_tr, T_f_tr, T_e_tr = make_transforms(stage["sz_whole"], stage["sz_face"], stage["sz_eyes"])
T_w_val, T_f_val, T_e_val = make_val_transforms(stage["sz_whole"], stage["sz_face"], stage["sz_eyes"])
train_ds = TriViewDataset(cfg.data_root, cfg.folders, split="train",
T_whole=T_w_tr, T_face=T_f_tr, T_eyes=T_e_tr)
val_ds = TriViewDataset(cfg.data_root, cfg.folders, split="val",
T_whole=T_w_val, T_face=T_f_val, T_eyes=T_e_val)
# steps_per_epoch schedule (global batch 기준)
steps_list = []
for ep_tmp in range(1, stage["epochs"]+1):
P_tmp, K_tmp = resolve_epoch_PK(stage, ep_tmp)
global_batch = P_tmp * K_tmp * world_size
steps = estimate_steps_per_epoch(
len(train_ds),
global_batch,
cfg.max_steps_per_epoch,
)
steps_list.append(steps)
total_steps_stage = int(sum(steps_list))
warmup_steps = int(steps_list[0] * WARMUP_EPOCHS)
params = [
{"params": model.parameters(), "lr": stage["lr"]},
{"params": proxy_loss.parameters(), "lr": stage["lr"] * proxy_lr_mult},
]
optim = torch.optim.AdamW(params, weight_decay=cfg.weight_decay)
sched = make_cosine_with_warmup(optim, warmup_steps=warmup_steps, total_steps=total_steps_stage)
start_ep = 1
if resume_info and si == resume_info["start_stage"]:
start_ep = resume_info["start_epoch"]
if resume_info["last_stage"] == si and start_ep > 1:
try:
if resume_info["ckpt"].get("optim") is not None:
optim.load_state_dict(resume_info["ckpt"]["optim"])
if resume_info["ckpt"].get("sched") is not None:
sched.load_state_dict(resume_info["ckpt"]["sched"])
if is_main:
wlog_global(f"[resume] Loaded optimizer/scheduler from {resume_info['path']}.", also_print=True)
except Exception as e:
if is_main:
wlog_global(f"[resume] WARNING: could not load optimizer/scheduler state: {e}", also_print=True)
stage_msg = (f"\n=== [DDP] Stage {si}/{len(cfg.stages)} :: "
f"wh/face/eyes={stage['sz_whole']}/{stage['sz_face']}/{stage['sz_eyes']} | "
f"epochs={stage['epochs']} | lr={stage['lr']} | classes={num_classes_total} ===")
if is_main:
print(stage_msg)
wlog_global(stage_msg)
# epoch loop
for ep in range(start_ep, stage["epochs"]+1):
P_e, K_e = resolve_epoch_PK(stage, ep)
B_e = P_e * K_e # local batch
train_dl = build_train_loader(train_ds, P_e, K_e)
steps_per_epoch = steps_list[ep-1]
model.train()
proxy_loss.train()
running = {"proxy":0.0, "supcon":0.0, "mv":0.0, "tot":0.0}
ep_sum_tot = ep_sum_p = ep_sum_s = ep_sum_m = 0.0
ramp = min(1.0, ep / RAMP_EPOCHS)
if is_main:
tbar = tqdm_local(range(1, steps_per_epoch+1),
desc=f"[train-DDP] stage{si} ep{ep} (P={P_e},K={K_e},B={B_e},rank={rank})",
leave=True)
else:
tbar = range(1, steps_per_epoch+1)
train_iter = iter(train_dl)
for it in tbar:
try:
batch = next(train_iter)
except Exception as e:
if _should_fallback_workers(e) and cfg.workers > 0:
if is_main:
print("[mp] Worker pickling error detected. Rebuilding loaders with num_workers=0.")
cfg.workers = 0
train_dl = build_train_loader(train_ds, P_e, K_e)
train_iter = iter(train_dl)
batch = next(train_iter)
else:
raise
labels = batch["labels"].to(device, non_blocking=True)
gids = batch["gids"].to(device, non_blocking=True)
views = {
k: (v.to(device, non_blocking=True).to(memory_format=torch.channels_last)
if v is not None else None)
for k,v in batch["views"].items()
}
masks = {k: v.to(device, non_blocking=True) for k,v in batch["masks"].items()}
with torch.amp.autocast('cuda', dtype=_get_amp_dtype()):
z_fused, z_views_dict, W = model(views, masks)
Z_all, Y_all, G_all = [], [], []
for vk in ("whole","face","eyes"):
zk = z_views_dict.get(vk)
if zk is None:
continue
mk = masks[vk]
if mk.any():
Z_all.append(zk[mk])
Y_all.append(labels[mk])
G_all.append(gids[mk])
if len(Z_all) == 0:
Z_all, Y_all, G_all = [z_fused], [labels], [gids]
Z_all = torch.cat(Z_all, dim=0)
Y_all = torch.cat(Y_all, dim=0)
G_all = torch.cat(G_all, dim=0)
L_proxy = proxy_loss(z_fused, labels)
L_sup = supcon(Z_all, Y_all)
L_mv = mv_infonce(Z_all, G_all)
L_total = L_proxy + (0.5 * ramp) * L_sup + (0.5 * ramp) * L_mv
optim.zero_grad(set_to_none=True)
scaler.scale(L_total).backward()
scaler.step(optim)
scaler.update()
sched.step()
global_step += 1
running["proxy"] += L_proxy.item()
running["supcon"] += L_sup.item()
running["mv"] += L_mv.item()
running["tot"] += L_total.item()
ep_sum_tot += L_total.item()
ep_sum_p += L_proxy.item()
ep_sum_s += L_sup.item()
ep_sum_m += L_mv.item()
if is_main and (it % cfg.print_every == 0 or it == steps_per_epoch):
denom = min(cfg.print_every, it % cfg.print_every or cfg.print_every)
tbar.set_postfix({
"L": f"{running['tot']/denom:.3f}",
"proxy": f"{running['proxy']/denom:.3f}",
"sup": f"{running['supcon']/denom:.3f}",
"mv": f"{running['mv']/denom:.3f}",
"lr": f"{optim.param_groups[0]['lr']:.2e}",
})
msg = (f"stage{si} ep{ep:02d} it{it:05d}/{steps_per_epoch} | "
f"P={P_e} K={K_e} B={B_e} | "
f"L={running['tot']/denom:.3f} "
f"(proxy={running['proxy']/denom:.3f}, "
f"sup={running['supcon']/denom:.3f}, "
f"mv={running['mv']/denom:.3f}) | "
f"lr={optim.param_groups[0]['lr']:.2e}")
wlog_global(msg)
running = {k:0.0 for k in running}
# ===== 검증 (proxy loss + proxy Top1만) =====
proxy_top1 = float("nan")
kmeans_acc = float("nan") # 사용 안 하지만 CSV 포맷 때문에 남겨둠
nmi = float("nan")
ari = float("nan")
knn_r1 = float("nan")
knn_r5 = float("nan")
val_proxy_mean = float("nan")
do_val = (VALIDATE_EVERY <= 0) or (ep % VALIDATE_EVERY == 0) or (ep == stage["epochs"])
if do_val:
from torch.utils.data.distributed import DistributedSampler
val_sampler = DistributedSampler(
val_ds,
num_replicas=world_size,
rank=rank,
shuffle=False,
drop_last=False,
)
val_sampler.set_epoch(ep)
val_dl_ddp = DataLoader(
val_ds,
batch_size=B_e,
sampler=val_sampler,
num_workers=min(8, cfg.workers),
pin_memory=True,
collate_fn=collate_triview,
persistent_workers=False,
multiprocessing_context=MP_CTX,
)
model.eval()
proxy_loss.eval()
local_loss_sum = 0.0
local_loss_cnt = 0.0
local_correct = 0.0
local_total = 0.0
with torch.no_grad():
Pn = F.normalize(proxy_loss.proxies.detach(), dim=1).to(device)
with torch.no_grad(), torch.amp.autocast('cuda', dtype=amp_dtype):
for batch in val_dl_ddp:
labels = batch["labels"].to(device, non_blocking=True)
views = {
k: (v.to(device).to(memory_format=torch.channels_last) if v is not None else None)
for k, v in batch["views"].items()
}
masks = {k: v.to(device, non_blocking=True) for k, v in batch["masks"].items()}
z_fused, _, _ = model(views, masks)
L = proxy_loss(z_fused, labels)
z_norm = F.normalize(z_fused, dim=1)
logits = z_norm @ Pn.t()
pred = logits.argmax(dim=1)
correct = (pred == labels).float().sum().item()
bs = float(labels.size(0))
local_loss_sum += L.item()
local_loss_cnt += 1.0
local_correct += correct
local_total += bs
t = torch.tensor(
[local_loss_sum, local_loss_cnt, local_correct, local_total],
device=device,
)
dist.all_reduce(t, op=dist.ReduceOp.SUM)
total_loss_sum = float(t[0].item())
total_loss_cnt = max(1.0, float(t[1].item()))
total_correct = float(t[2].item())
total_total = max(1.0, float(t[3].item()))
val_proxy_mean = total_loss_sum / total_loss_cnt
proxy_top1 = total_correct / total_total
if is_main:
print(f"[val] ep{ep:02d} proxy-loss ~ {val_proxy_mean:.3f}, Top1={proxy_top1:.4f}")
wlog_global(f"[val] ep{ep:02d} proxy-loss ~ {val_proxy_mean:.3f}, Top1={proxy_top1:.4f}")
dist.barrier()
_close_dl(val_dl_ddp)
del val_dl_ddp
gc.collect()
time.sleep(0.05)
# ----- Epoch metrics & checkpoint (rank0) -----
train_mean = ep_sum_tot / steps_per_epoch
train_p = ep_sum_p / steps_per_epoch
train_s = ep_sum_s / steps_per_epoch
train_m = ep_sum_m / steps_per_epoch
write_epoch_metrics(
si, ep, steps_per_epoch, P_e, K_e,
train_mean, train_p, train_s, train_m,
val_proxy_mean, proxy_top1,
knn_r1, knn_r5,
kmeans_acc, nmi, ari,
)
ck = os.path.join(cfg.out_dir, f"stage{si}_epoch{ep}.pt")
save_ckpt(
ck, model, proxy_loss, optim, sched,
meta=dict(
stage=si, epoch=ep,
P=P_e, K=K_e, steps=steps_per_epoch,
val_every=VALIDATE_EVERY,
proxy_top1=proxy_top1,
knn_r1=knn_r1, knn_r5=knn_r5,
),
is_main=is_main,
)
if is_main:
print(f"Saved: {ck}")
wlog_global(f"Saved: {ck}")
_close_dl(train_dl)
del train_dl
gc.collect()
time.sleep(0.1)
dist.destroy_process_group()
if is_main:
print("\n[DDP] Training finished or paused. Checkpoints in:", cfg.out_dir)
print("Logs:", LOG_TXT, " | CSV:", METRICS_CSV)
print("Tip: Re-run this script to RESUME (DDP).")
# ------------------------------ entry point --------------------------------
def run_ddp_training():
    """Launch DDP training with one worker process per visible CUDA device.

    If CUDA is unavailable, prints a notice and returns without training.
    Otherwise, starts ``torch.cuda.device_count()`` copies of
    ``ddp_train_worker`` through ``mp.spawn`` and blocks until every worker
    exits (``join=True``). ``mp.spawn`` passes each worker its process index
    as the first positional argument, followed by the world size.

    Returns:
        None in every case.
    """
    cuda_ok = torch.cuda.is_available()
    if cuda_ok:
        n_gpus = torch.cuda.device_count()
        print(f"[DDP] Launching training on {n_gpus} GPUs...")
        # One process per GPU; the world size is forwarded as the worker's
        # extra positional argument.
        spawn_kwargs = dict(args=(n_gpus,), nprocs=n_gpus, join=True)
        mp.spawn(ddp_train_worker, **spawn_kwargs)
    else:
        print("CUDA not available; DDP training requires GPU.")
if __name__ == "__main__":
    # Script entry point. torch.multiprocessing.spawn uses the 'spawn' start
    # method, and spawned children re-import this module (as __mp_main__).
    # The guard keeps those children from launching training again on import.
    run_ddp_training()