e6-visual-ratings / train.py
taigasan's picture
Pin dataset revision
9091fa0
Raw
History Blame Contribute Delete
17.5 kB
#!/usr/bin/env python3
from collections import defaultdict
from concurrent.futures import ThreadPoolExecutor
from pathlib import Path
import glob
import os
import random
import sys
from huggingface_hub import snapshot_download
import pandas as pd
from PIL import Image
from safetensors.torch import load_file, save_file
import schedulefree
import torch
import wandb
from torch import nn
from torch.utils.data import Dataset
from tqdm.auto import tqdm
import torchvision.transforms.v2 as v2
from model import DINOv3ViTH, TaggerAestheticModel, _split_and_clean_state_dict
# Backbone geometry: preprocessed images are sized in multiples of PATCH_SIZE,
# and pool_features reads N_REGISTERS tokens directly after the CLS token.
PATCH_SIZE = 16
N_REGISTERS = 4
# Upper bound (pixels) on the long edge of a preprocessed image.
MAX_SIZE = 512
# ImageNet normalization statistics applied after scaling pixels to [0, 1].
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]
# Threads used to load/resize images in parallel while building features.
IMAGE_PREPROCESS_WORKERS = 16
SEED = 3407
# Seed torch as well as `random`: nothing visible in this script draws from
# `random`, while DataLoader shuffling and head initialization use torch's RNG,
# so seeding `random` alone left training runs non-reproducible.
random.seed(SEED)
torch.manual_seed(SEED)
torch.set_grad_enabled(True)
# Allow faster (TF32-class) float32 matmuls where the hardware supports them.
torch.set_float32_matmul_precision("high")
def load_votes_and_split(
    repo_id="taigasan/e6-visual-ratings",
    revision="daaa857ffab11075c2fc6912e7f23879d324dcc9",
    val_size=2000,
    seed=42,
):
    """Download the pinned vote dataset, restrict it to the image pool, and split it.

    Every ``ratings_log/*.parquet`` file of the dataset snapshot is
    concatenated. Only votes whose two images (``md5a`` and ``md5b``) are both
    listed in ``pool.parquet`` are kept. The surviving votes are shuffled with
    a fixed ``random_state``, and the first ``val_size`` rows become the
    validation split.

    Args:
        repo_id: Hugging Face dataset repository to download.
        revision: Commit hash the snapshot is pinned to.
        val_size: Number of shuffled vote rows reserved for validation.
        seed: ``random_state`` used when shuffling the votes.

    Returns:
        ``(val_df, train_df, pool_md5_list)``. ``pool_md5_list`` is the sorted,
        de-duplicated list of pool md5 strings.

    Raises:
        FileNotFoundError: if the snapshot has no rating parquet files.
        ValueError: if ``pool.parquet`` has no ``md5`` column.
    """
    local_repo_path = snapshot_download(repo_id=repo_id, repo_type="dataset", revision=revision)
    print("Downloaded repo to:", local_repo_path)

    rating_log_dir = os.path.join(local_repo_path, "ratings_log")
    parquet_files = sorted(glob.glob(os.path.join(rating_log_dir, "*.parquet")))
    if not parquet_files:
        raise FileNotFoundError(f"No rating parquet files found in {rating_log_dir}")
    combined_df = pd.concat([pd.read_parquet(path) for path in parquet_files], ignore_index=True)
    print("Total votes before pool filter:", len(combined_df))

    pool_df = pd.read_parquet(os.path.join(local_repo_path, "pool.parquet"))
    if "md5" not in pool_df.columns:
        raise ValueError("pool.parquet has no 'md5' column")
    # De-duplicate so each pool image is requested, embedded and cached once.
    pool_md5_list = sorted(set(pool_df["md5"].astype(str)))
    valid_md5 = set(pool_md5_list)
    combined_df = combined_df[
        combined_df["md5a"].isin(valid_md5) & combined_df["md5b"].isin(valid_md5)
    ].reset_index(drop=True)
    print("Total votes after pool filter:", len(combined_df))
    print("Pool rows:", len(pool_df))

    df = combined_df.sample(frac=1, random_state=seed).reset_index(drop=True)
    df_first = df.iloc[:val_size].reset_index(drop=True)
    df_second = df.iloc[val_size:].reset_index(drop=True)
    print("Val rows:", len(df_first))
    print("Train rows:", len(df_second))
    return df_first, df_second, pool_md5_list
def _snap(x: int, m: int) -> int:
return max(m, (x // m) * m)
def preprocess_pil(img: Image.Image, max_size: int = MAX_SIZE) -> torch.Tensor:
    """Resize and normalize a PIL image into a single-image batch.

    The image is converted to RGB. Its long edge is capped at ``max_size``.
    Both sides are then rounded down to a multiple of ``PATCH_SIZE`` (at least
    one patch each) and resized with Lanczos filtering.

    Returns:
        A float32 tensor with a leading batch axis of size 1. Pixels are scaled
        to [0, 1] and normalized with ``IMAGENET_MEAN`` / ``IMAGENET_STD``.
    """
    rgb = img.convert("RGB")
    width, height = rgb.size
    longest = max(width, height)
    ratio = _snap(min(longest, max_size), PATCH_SIZE) / longest
    out_h = _snap(max(PATCH_SIZE, round(height * ratio)), PATCH_SIZE)
    out_w = _snap(max(PATCH_SIZE, round(width * ratio)), PATCH_SIZE)
    pipeline = v2.Compose(
        [
            v2.Resize((out_h, out_w), interpolation=v2.InterpolationMode.LANCZOS),
            v2.ToImage(),
            v2.ToDtype(torch.float32, scale=True),
            v2.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
        ]
    )
    single = pipeline(rgb)
    return single.unsqueeze(0)
def find_local_image_path(md5: str, image_dir) -> Path:
    """Return the first existing ``<image_dir>/<md5>.<ext>`` file.

    Extensions are tried in the order webp, jpg, jpeg, png.

    Raises:
        AssertionError: if none of those files exists.
    """
    folder = Path(image_dir)
    candidates = (folder / f"{md5}.{suffix}" for suffix in ("webp", "jpg", "jpeg", "png"))
    hit = next((path for path in candidates if path.exists()), None)
    if hit is None:
        raise AssertionError(f"Missing local image for md5={md5}; checked dir={folder}")
    return hit
def load_and_preprocess_image(item):
    """Load and preprocess one image; shaped for use with ``executor.map``.

    Args:
        item: ``(md5, image_dir)`` tuple.

    Returns:
        ``(md5, (height, width), pixel_values)``. ``pixel_values`` is the output
        of ``preprocess_pil`` with its leading batch axis removed, and
        height/width are read from its dimensions 1 and 2.
    """
    md5, image_dir = item
    path = find_local_image_path(md5, image_dir)
    with Image.open(path) as opened:
        tensor = preprocess_pil(opened).squeeze(0)
    height, width = int(tensor.shape[1]), int(tensor.shape[2])
    return md5, (height, width), tensor
def pool_features(outputs) -> torch.Tensor:
    """Concatenate the CLS token and the register tokens into one vector per item.

    ``outputs`` is either a 3-D hidden-state tensor or an object exposing
    ``last_hidden_state``. Token 0 is used as CLS and tokens
    ``1 .. N_REGISTERS`` as registers.

    Returns:
        A float32 tensor of shape ``(batch, (1 + N_REGISTERS) * dim)``.
    """
    if isinstance(outputs, torch.Tensor):
        hidden = outputs
    else:
        hidden = outputs.last_hidden_state
    assert hidden.ndim == 3 and hidden.shape[1] >= 1 + N_REGISTERS
    cls_token = hidden[:, 0, :]
    register_tokens = hidden[:, 1 : 1 + N_REGISTERS, :]
    # [b, reg, d] -> [b, reg * d] so the registers sit side by side after CLS.
    flat_registers = register_tokens.flatten(start_dim=1)
    combined = torch.cat((cls_token, flat_registers), dim=-1)
    return combined.to(torch.float32)
class PreferenceDataset(Dataset):
    """Pairwise preference votes served as cached-embedding pairs.

    Each item is a dict with ``embed_a``, ``embed_b`` (looked up in
    ``embed_cache``) and ``outcome``. ``outcome`` is a float tensor of shape
    ``(1,)``: 1.0 when ``md5a`` won, 0.0 when ``md5b`` won.
    """

    def __init__(self, pair_df, embed_cache):
        """Collect the labelled pairs from ``pair_df``.

        Args:
            pair_df: DataFrame with ``md5a``, ``md5b`` and ``winner_md5`` columns.
            embed_cache: mapping md5 -> embedding tensor. It must contain both
                images of every kept vote.

        Rows whose ``winner_md5`` is empty/missing, or names neither image, are
        skipped; their count is stored in ``self.num_skipped``.
        """
        self.embed_cache = embed_cache
        pairs = []
        skipped = 0
        for row in pair_df.itertuples(index=False):
            md5a, md5b, winner_md5 = row.md5a, row.md5b, row.winner_md5
            # pd.isna covers None / NaN / pd.NA; `not pd.NA` would raise.
            if pd.isna(winner_md5) or not winner_md5:
                skipped += 1
                continue
            # Match both sides explicitly: a winner naming neither image must
            # not be silently counted as a win for md5b.
            if winner_md5 == md5a:
                outcome = 1
            elif winner_md5 == md5b:
                outcome = 0
            else:
                skipped += 1
                continue
            assert md5a in self.embed_cache
            assert md5b in self.embed_cache
            pairs.append({"md5a": md5a, "md5b": md5b, "outcome": outcome})
        outcomes = {pair["outcome"] for pair in pairs}
        # Both labels must be present, otherwise the pairwise loss is degenerate.
        assert 1 in outcomes, "no vote where md5a won"
        assert 0 in outcomes, "no vote where md5b won"
        if skipped:
            print(f"PreferenceDataset: skipped {skipped} votes without a valid winner")
        self.pairs = pairs
        self.num_skipped = skipped

    def __len__(self):
        return len(self.pairs)

    def __getitem__(self, idx):
        """Return the cached embeddings of pair ``idx`` and its outcome tensor."""
        sample = self.pairs[idx]
        return {
            "embed_a": self.embed_cache[sample["md5a"]],
            "embed_b": self.embed_cache[sample["md5b"]],
            "outcome": torch.tensor([sample["outcome"]], dtype=torch.float32),
        }
def build_feature_table(md5_list, image_dir, backbone, batch_size=64):
    """Embed every image in ``md5_list`` with the frozen backbone.

    Images are loaded and preprocessed on a thread pool in chunks of
    ``batch_size``. Within a chunk they are grouped by preprocessed (H, W) so
    each group can be stacked into one tensor. Each group is run through
    ``backbone(pixel_values=...)`` and reduced with ``pool_features``.

    Args:
        md5_list: md5 hashes to embed; duplicates are embedded only once.
        image_dir: directory containing ``<md5>.<ext>`` image files.
        backbone: frozen module; inputs are moved to the device of its
            parameters (CUDA if it has none).
        batch_size: images per loading chunk and maximum forward batch size.

    Returns:
        dict md5 -> 1-D float32 CPU feature tensor, one entry per unique md5.
    """
    # Duplicates would otherwise make the final `len(out) == total` check fail.
    md5_list = list(dict.fromkeys(md5_list))
    out = {}
    total = len(md5_list)
    assert total > 0
    first_param = next(backbone.parameters(), None)
    device = first_param.device if first_param is not None else torch.device("cuda")
    progress_step = max(1, total // 100)
    next_progress = progress_step
    feature_dim = None
    worker_count = min(IMAGE_PREPROCESS_WORKERS, batch_size, total)
    n_chunks = (total + batch_size - 1) // batch_size
    backbone.eval()
    # One thread pool for the whole run instead of a new one per chunk.
    with torch.no_grad(), ThreadPoolExecutor(max_workers=worker_count) as executor:
        for start in tqdm(range(0, total, batch_size), total=n_chunks, leave=True):
            batch_md5 = md5_list[start : start + batch_size]
            jobs = [(m, image_dir) for m in batch_md5]
            # Only equal-sized tensors can be stacked, so group by (H, W).
            groups = defaultdict(list)
            for md5, size, pixel_values in executor.map(load_and_preprocess_image, jobs):
                groups[size].append((md5, pixel_values))
            for size, members in groups.items():
                for gstart in range(0, len(members), batch_size):
                    chunk = members[gstart : gstart + batch_size]
                    pixel_values = torch.stack([pv for _, pv in chunk], dim=0).to(device)
                    features = pool_features(backbone(pixel_values=pixel_values)).cpu()
                    assert features.ndim == 2 and features.shape[0] == len(chunk)
                    assert torch.isfinite(features).all(), f"Non-finite features in size group {size}"
                    if feature_dim is None:
                        feature_dim = int(features.shape[1])
                    assert int(features.shape[1]) == feature_dim
                    for (md5, _), feat in zip(chunk, features):
                        out[md5] = feat
            done = len(out)
            if done >= next_progress or done == total:
                print(f"Feature progress: {done}/{total}")
                while next_progress <= done:
                    next_progress += progress_step
    assert len(out) == total
    return out
def load_or_build_embed_cache(md5_list, image_dir, backbone, cache_path, batch_size=64):
    """Return embeddings for ``md5_list``, reusing and extending an on-disk cache.

    If ``cache_path`` exists, it is loaded as ``{"md5": [...], "features":
    Tensor[N, D]}``. Any md5 it lacks is embedded with ``build_feature_table``;
    ``image_dir`` and ``backbone`` are used only in that case. The merged cache,
    including entries not in ``md5_list``, is then written back sorted by md5.
    The file is left untouched when nothing was missing.

    Returns:
        dict md5 -> feature tensor, keyed in ``md5_list`` order.
    """
    assert len(md5_list) > 0
    cache = {}
    cache_path = Path(cache_path)
    if cache_path.exists():
        # The payload holds only strings and one tensor, so the restricted
        # tensors-only unpickler is sufficient; arbitrary objects are refused.
        payload = torch.load(cache_path, map_location="cpu", weights_only=True)
        cached_md5 = payload["md5"]
        cached_features = payload["features"]
        assert len(cached_md5) == len(cached_features)
        assert cached_features.ndim == 2 and torch.isfinite(cached_features).all()
        cache = dict(zip(cached_md5, cached_features))
        print(f"Loaded embed cache: {len(cache)} items from {cache_path}")
    # Keep first-seen order, but embed each missing image only once.
    missing = list(dict.fromkeys(m for m in md5_list if m not in cache))
    if missing:
        print(f"Building missing embeddings: {len(missing)}")
        built = build_feature_table(missing, image_dir=image_dir, backbone=backbone, batch_size=batch_size)
        cache.update(built)
        ordered = sorted(cache)
        features = torch.stack([cache[m] for m in ordered], dim=0)
        assert features.ndim == 2 and torch.isfinite(features).all()
        cache_path.parent.mkdir(parents=True, exist_ok=True)
        torch.save({"md5": ordered, "features": features}, cache_path)
        print(f"Saved embed cache: {cache_path}")
    return {m: cache[m] for m in md5_list}
def load_frozen_backbone(backbone_path):
    """Load the ``DINOv3ViTH`` backbone from a safetensors checkpoint, frozen on CUDA.

    The checkpoint is split with ``_split_and_clean_state_dict``, and the first
    returned state dict is loaded strictly. The model is moved to CUDA, put in
    eval mode, cast to float32, and has gradients disabled for all parameters.
    """
    resolved = Path(backbone_path).resolve()
    assert resolved.exists(), f"Missing backbone checkpoint: {resolved}"
    state_dict = load_file(str(resolved), device="cpu")
    backbone_state, _ = _split_and_clean_state_dict(state_dict)
    model = DINOv3ViTH()
    model.load_state_dict(backbone_state, strict=True)
    model = model.to("cuda")
    model.eval()
    model = model.to(torch.float32)
    model.requires_grad_(False)
    return model
def test_code(df_second, pool_md5_list):
    """Smoke-test the embedding cache and the preference DataLoader.

    Builds or loads the cache at ``tagger_scorer/frozen_embed_cache_512.pt``
    for the whole pool and wraps ``df_second`` in a ``PreferenceDataset``.
    Prints the dataset size and the tensor shapes of one batch.
    """
    loader_batch = 4
    frozen = load_frozen_backbone(Path("./tagger_proto.safetensors"))
    cache = load_or_build_embed_cache(
        md5_list=pool_md5_list,
        image_dir=Path("./samples").resolve(),
        backbone=frozen,
        cache_path="tagger_scorer/frozen_embed_cache_512.pt",
        batch_size=64,
    )
    dataset = PreferenceDataset(pair_df=df_second, embed_cache=cache)
    loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=loader_batch,
        shuffle=True,
        num_workers=0,
    )
    print(len(dataset))
    first_batch = next(iter(loader), None)
    if first_batch is not None:
        print(first_batch["embed_a"].shape, first_batch["embed_b"].shape, first_batch["outcome"].shape)
def count_parameters(model):
    """Return the number of scalar parameters in ``model`` with ``requires_grad`` set."""
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total
class AestheticPredictor(nn.Module):
    """Scoring head applied directly to precomputed backbone features.

    Only the ``scoring_head`` submodule of a freshly built
    ``TaggerAestheticModel`` is kept, so the head architecture matches that
    model's head.
    """

    def __init__(self, feature_dim: int):
        super().__init__()
        # NOTE(review): a whole TaggerAestheticModel is built just to take its
        # head; if its constructor also allocates a backbone, that work is
        # thrown away here. Confirm in model.py.
        template = TaggerAestheticModel(feature_dim)
        self.scoring_head = template.scoring_head

    def forward(self, features):
        """Return the scoring head's output for a batch of feature vectors."""
        return self.scoring_head(features)
class dotdict(dict):
    """Dictionary whose keys can also be read, set and deleted as attributes.

    Missing keys raise ``AttributeError`` rather than returning ``None``. A
    mistyped setting name therefore fails loudly, and ``hasattr`` /
    ``getattr(obj, name, default)`` behave as they do for normal objects.
    """

    def __getattr__(self, name):
        try:
            return self[name]
        except KeyError:
            raise AttributeError(name) from None

    __setattr__ = dict.__setitem__

    def __delattr__(self, name):
        try:
            del self[name]
        except KeyError:
            raise AttributeError(name) from None
def eval_model(model, optimizer, val_loader, wandb_run, step):
    """Evaluate pairwise loss/accuracy on ``val_loader`` and log them at ``step``.

    Both ``model`` and ``optimizer`` are switched to eval mode; the
    schedule-free optimizer used here needs this mode switch around
    evaluation. ``score(a) - score(b)`` is treated as the logit that ``a`` won.
    ``val_loss`` (mean of per-batch BCE-with-logits) and ``val_accuracy``
    (fraction of pairs whose logit sign matches the outcome) are logged to
    ``wandb_run``. Train mode is restored afterwards, even if evaluation raises.
    Nothing is logged when ``val_loader`` yields no batches.
    """
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device("cuda")
    criterion = nn.BCEWithLogitsLoss()
    total_loss = 0.0
    correct = 0
    seen = 0
    num_batches = 0
    model.eval()
    optimizer.eval()
    try:
        with torch.no_grad():
            for data in tqdm(val_loader, total=len(val_loader), leave=True):
                outcome = data["outcome"].to(device)
                logits = model(data["embed_a"].to(device)) - model(data["embed_b"].to(device))
                total_loss += criterion(logits, outcome).item()
                # A positive logit predicts that a won (outcome 1.0).
                correct += int(((logits > 0).float() == outcome).sum().item())
                seen += outcome.numel()
                num_batches += 1
    finally:
        model.train()
        optimizer.train()
    if num_batches == 0:
        print("eval_model: validation loader is empty; nothing logged")
        return
    wandb_run.log({"val_loss": total_loss / num_batches, "val_accuracy": correct / seen}, step=step)
def save_model(model, step, model_name):
    """Write ``model``'s state dict to ``./checkpoints/<model_name>_<step>.safetensors``.

    Every tensor is detached, cast to float32 and made contiguous before it is
    handed to ``save_file``. The checkpoint directory is created if needed.
    """
    checkpoint_dir = Path("./checkpoints")
    checkpoint_dir.mkdir(parents=True, exist_ok=True)
    tensors = {}
    for key, value in model.state_dict().items():
        tensors[key] = value.detach().to(torch.float32).contiguous()
    target = checkpoint_dir / f"{model_name}_{step}.safetensors"
    save_file(tensors, str(target))
def train_main():
    """Train the aesthetic scoring head on cached backbone embeddings.

    Steps:
    - Download and split the pinned vote data.
    - Embed every pool image with the frozen backbone, or load those
      embeddings from the cache.
    - Fit the head Bradley-Terry style: BCE-with-logits on
      ``score(a) - score(b)``, with gradient accumulation and schedule-free
      AdamW.

    Metrics are logged to Weights & Biases, and ``WANDB_API_KEY`` must be set.
    Training ends after ``train_target_steps`` optimizer steps or on Ctrl-C.
    Either way, the final checkpoint and evaluation are produced exactly once.
    """
    print(
        "WARNING: Using pinned dataset revision daaa857ffab11075c2fc6912e7f23879d324dcc9. "
        "This is an old snapshot; embed_cache.pt covers this pool, so you do not need the "
        "sample images locally. If you want to use the latest version, have the sample "
        "images locally in ./samples."
    )
    df_first, df_second, pool_md5_list = load_votes_and_split()
    settings = {
        "model_name": "classifer_head",
        "remarks": "Training head only with cached DINOv3 embeddings at 512.",
        "optimizer": "AdamW schedulefree",
        "batch_size": 4,
        "val_batch_size": 16,
        "image_dir": "./samples",
        "do_eval": True,
        "backbone_path": "./tagger_proto.safetensors",
        "embed_cache_path": "./embed_cache.pt",
        "lr": 1e-3,
        "warmup_steps": 100,
        "weight_decay": 1e-2,
        "betas": (0.9, 0.999),
        "save_every": 5000,
        "eval_every": 2000,
        "num_accumulation_steps": 8,
        "train_target_steps": 6000 * 8,
        "embed_cache_batch_size": 64,
    }
    s = dotdict(settings)

    backbone = load_frozen_backbone(s.backbone_path)
    image_dir = Path(s.image_dir).resolve()
    embed_cache = load_or_build_embed_cache(
        md5_list=pool_md5_list,
        image_dir=image_dir,
        backbone=backbone,
        cache_path=s.embed_cache_path,
        batch_size=s.embed_cache_batch_size,
    )
    # The backbone is only needed to fill the cache; free its GPU memory
    # before training the head.
    del backbone
    torch.cuda.empty_cache()
    print("Cached embeds:", len(embed_cache))

    feature_dim = int(next(iter(embed_cache.values())).shape[0])
    model = AestheticPredictor(feature_dim).to(torch.float32).to("cuda")
    print(count_parameters(model))

    train_set = PreferenceDataset(pair_df=df_second, embed_cache=embed_cache)
    print("Train set size:", len(train_set))
    train_loader = torch.utils.data.DataLoader(
        train_set,
        batch_size=s.batch_size,
        shuffle=True,
        num_workers=12,
        drop_last=True,
    )
    val_loader = None
    if s.do_eval:
        val_set = PreferenceDataset(pair_df=df_first, embed_cache=embed_cache)
        # Fixed order: with drop_last=True, a shuffled loader would drop a
        # different tail every evaluation and make val_loss noisier.
        val_loader = torch.utils.data.DataLoader(
            val_set,
            batch_size=s.val_batch_size,
            shuffle=False,
            num_workers=12,
            drop_last=True,
        )

    wandb_api_key = os.getenv("WANDB_API_KEY")
    assert wandb_api_key is not None and wandb_api_key != "", "WANDB_API_KEY env var is required"
    wandb.login(key=wandb_api_key)
    wandb_run = wandb.init(
        project=os.getenv("WANDB_PROJECT", "aesthetic_bradley_terry"),
        name=settings["model_name"],
        config=settings,
    )
    assert wandb_run is not None

    # Parameters whose names contain bias/embed/norm are trained without weight decay.
    decay_head = []
    no_decay_head = []
    for name, param in model.scoring_head.named_parameters():
        param.requires_grad = True
        if "bias" in name or "embed" in name or "norm" in name:
            no_decay_head.append(param)
        else:
            decay_head.append(param)
    optimizer = schedulefree.AdamWScheduleFree(
        [
            {"params": decay_head, "lr": s.lr},
            {"params": no_decay_head, "weight_decay": 0.0, "lr": s.lr},
        ],
        lr=s.lr,
        warmup_steps=s.warmup_steps,
        weight_decay=s.weight_decay,
        betas=s.betas,
    )

    criterion = nn.BCEWithLogitsLoss()
    model.train()
    optimizer.train()
    step_counter = 0
    reached_target = False
    optimizer.zero_grad()
    try:
        for _ in range(200000):
            acc_loss = 0.0
            for i, data in tqdm(enumerate(train_loader, 0), total=len(train_loader)):
                for k in data:
                    data[k] = data[k].to("cuda")
                pred_outcome = model(data["embed_a"]) - model(data["embed_b"])
                loss = criterion(pred_outcome, data["outcome"])
                pre_acc_loss = loss.item()
                wandb_run.log({"pre_acc_train_loss": pre_acc_loss}, step=step_counter)
                acc_loss += pre_acc_loss / s.num_accumulation_steps
                (loss / s.num_accumulation_steps).backward()
                if ((i + 1) % s.num_accumulation_steps == 0) or (i + 1 == len(train_loader)):
                    grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                    optimizer.step()
                    optimizer.zero_grad()
                    wandb_run.log(
                        {
                            "train_loss": acc_loss,
                            "lr": s.lr,
                            "grad_norm": float(grad_norm),
                        },
                        step=step_counter,
                    )
                    acc_loss = 0.0
                    if step_counter + 1 >= s.train_target_steps:
                        # The final save/eval happen once, in `finally` below.
                        reached_target = True
                        break
                    if step_counter % s.save_every == 0:
                        save_model(model, step_counter, s.model_name)
                    if s.do_eval and step_counter % s.eval_every == 0:
                        eval_model(model, optimizer, val_loader, wandb_run, step_counter)
                    step_counter += 1
            if reached_target:
                break
    except KeyboardInterrupt:
        pass
    finally:
        # Single exit path for the last checkpoint/eval. Previously, reaching
        # the target saved and evaluated, then did both again here.
        save_model(model, step_counter, s.model_name)
        if s.do_eval:
            eval_model(model, optimizer, val_loader, wandb_run, step_counter)
        wandb.finish()
    print("Finished Training")
if __name__ == "__main__":
    # Start training only when the file is executed as a script, not on import.
    train_main()