#!/usr/bin/env python3
"""Train a Bradley-Terry aesthetic scoring head on frozen DINOv3 embeddings.

Pairwise votes are downloaded from a pinned snapshot of a Hugging Face Hub
dataset. Image features are extracted once with a frozen backbone and cached
to disk. A small scoring head is then trained so that
``score(a) - score(b)`` is the logit of "image A won".
"""
from collections import defaultdict
from concurrent.futures import ThreadPoolExecutor
from pathlib import Path
import glob
import os
import random
import sys

from huggingface_hub import snapshot_download
import pandas as pd
from PIL import Image
from safetensors.torch import load_file, save_file
import schedulefree
import torch
import wandb
from torch import nn
from torch.utils.data import Dataset
from tqdm.auto import tqdm
import torchvision.transforms.v2 as v2

from model import DINOv3ViTH, TaggerAestheticModel, _split_and_clean_state_dict

PATCH_SIZE = 16
N_REGISTERS = 4
MAX_SIZE = 512
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]
IMAGE_PREPROCESS_WORKERS = 16

# Vote dataset on the Hub, pinned to one snapshot so the pool (and therefore
# the embed cache and the seeded train/val split) stays fixed.
DATASET_REPO_ID = "taigasan/e6-visual-ratings"
DATASET_REVISION = "daaa857ffab11075c2fc6912e7f23879d324dcc9"

random.seed(3407)
torch.set_grad_enabled(True)
torch.set_float32_matmul_precision("high")


def load_votes_and_split(val_size: int = 2000):
    """Download the pinned vote dataset and split it into val/train votes.

    Votes are kept only when both images are in the pool. The remaining votes
    are shuffled with a fixed seed, so the split is reproducible.

    Args:
        val_size: Number of shuffled votes reserved for validation.

    Returns:
        ``(val_df, train_df, pool_md5_list)``. ``pool_md5_list`` is the
        sorted, de-duplicated list of pool md5 strings.
    """
    local_repo_path = snapshot_download(
        repo_id=DATASET_REPO_ID, repo_type="dataset", revision=DATASET_REVISION
    )
    print("Downloaded repo to:", local_repo_path)

    rating_log_dir = os.path.join(local_repo_path, "ratings_log")
    parquet_files = sorted(glob.glob(os.path.join(rating_log_dir, "*.parquet")))
    assert len(parquet_files) > 0, f"No vote parquet files found in {rating_log_dir}"
    combined_df = pd.concat([pd.read_parquet(path) for path in parquet_files], ignore_index=True)
    print("Total votes before pool filter:", len(combined_df))

    pool_df = pd.read_parquet(os.path.join(local_repo_path, "pool.parquet"))
    assert "md5" in pool_df.columns
    valid_md5 = set(pool_df["md5"].astype(str).tolist())
    # De-duplicated: a repeated md5 would otherwise trip the size check when
    # the embed cache is built.
    pool_md5_list = sorted(valid_md5)

    combined_df = combined_df[
        combined_df["md5a"].isin(valid_md5) & combined_df["md5b"].isin(valid_md5)
    ].reset_index(drop=True)
    print("Total votes after pool filter:", len(combined_df))
    print("Pool rows:", len(pool_df))

    df = combined_df.sample(frac=1, random_state=42).reset_index(drop=True)
    df_first = df.iloc[:val_size].reset_index(drop=True)
    df_second = df.iloc[val_size:].reset_index(drop=True)
    print("Val rows:", len(df_first))
    print("Train rows:", len(df_second))
    return df_first, df_second, pool_md5_list


def _snap(x: int, m: int) -> int:
    """Round ``x`` down to a multiple of ``m``, but never below ``m``."""
    return max(m, (x // m) * m)


def preprocess_pil(img: Image.Image, max_size: int = MAX_SIZE) -> torch.Tensor:
    """Resize an image to patch-aligned dimensions and ImageNet-normalize it.

    The long edge is capped at ``max_size``. Both sides are snapped down to
    multiples of ``PATCH_SIZE`` (minimum one patch). The aspect ratio is
    approximately preserved.

    Returns:
        A float32 tensor of shape ``[1, 3, H, W]``.
    """
    img = img.convert("RGB")
    w, h = img.size
    long_edge = max(w, h)
    target_long = _snap(min(long_edge, max_size), PATCH_SIZE)
    scale = target_long / long_edge
    new_w = _snap(max(PATCH_SIZE, round(w * scale)), PATCH_SIZE)
    new_h = _snap(max(PATCH_SIZE, round(h * scale)), PATCH_SIZE)
    transform = v2.Compose(
        [
            v2.Resize((new_h, new_w), interpolation=v2.InterpolationMode.LANCZOS),
            v2.ToImage(),
            v2.ToDtype(torch.float32, scale=True),
            v2.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
        ]
    )
    return transform(img).unsqueeze(0)


def find_local_image_path(md5: str, image_dir) -> Path:
    """Return the first existing ``<md5>.<ext>`` file in ``image_dir``.

    Raises:
        AssertionError: if no webp/jpg/jpeg/png file exists for ``md5``.
    """
    image_dir = Path(image_dir)
    for ext in ("webp", "jpg", "jpeg", "png"):
        p = image_dir / f"{md5}.{ext}"
        if p.exists():
            return p
    raise AssertionError(f"Missing local image for md5={md5}; checked dir={image_dir}")


def load_and_preprocess_image(item):
    """Load and preprocess one image for a thread pool worker.

    Args:
        item: ``(md5, image_dir)`` tuple.

    Returns:
        ``(md5, (h, w), pixel_values)``, where ``pixel_values`` is ``[3, h, w]``.
    """
    md5, image_dir = item
    p = find_local_image_path(md5, image_dir)
    with Image.open(p) as img:
        pixel_values = preprocess_pil(img)[0]
    h = int(pixel_values.shape[1])
    w = int(pixel_values.shape[2])
    return md5, (h, w), pixel_values


def pool_features(outputs) -> torch.Tensor:
    """Concatenate the CLS token and the flattened register tokens.

    Token 0 is taken as CLS and the next ``N_REGISTERS`` tokens as registers.
    That token ordering is assumed from the backbone; confirm it in ``model``.

    Returns:
        A float32 tensor of shape ``[b, (1 + N_REGISTERS) * d]``.
    """
    hidden = outputs if isinstance(outputs, torch.Tensor) else outputs.last_hidden_state
    assert hidden.ndim == 3 and hidden.shape[1] >= 1 + N_REGISTERS
    cls = hidden[:, 0, :]
    regs = hidden[:, 1 : 1 + N_REGISTERS, :].flatten(1)  # [b, reg, d] -> [b, reg*d]
    return torch.cat([cls, regs], dim=-1).to(torch.float32)


class PreferenceDataset(Dataset):
    """Pairwise preference samples backed by precomputed embeddings.

    Each item is ``{"embed_a", "embed_b", "outcome"}``. ``outcome`` is a
    float32 tensor of shape ``[1]``: 1.0 if image A won, 0.0 if image B won.
    Votes whose winner is neither A nor B are skipped.
    """

    def __init__(self, pair_df, embed_cache):
        self.embed_cache = embed_cache
        pairs = []
        has_a_won = False
        has_b_won = False
        for row in pair_df.itertuples(index=False):
            md5a = row.md5a
            md5b = row.md5b
            winner_md5 = row.winner_md5
            # Match both sides explicitly. A missing winner read from parquet
            # may be None/NaN (NaN is truthy), and must not become a B win.
            if winner_md5 == md5a:
                outcome = 1
                has_a_won = True
            elif winner_md5 == md5b:
                outcome = 0
                has_b_won = True
            else:
                continue
            assert md5a in self.embed_cache, f"Missing embedding for md5a={md5a}"
            assert md5b in self.embed_cache, f"Missing embedding for md5b={md5b}"
            pairs.append({"md5a": md5a, "md5b": md5b, "outcome": outcome})
        # Both outcomes must occur, or the labels are almost certainly mis-encoded.
        assert has_a_won
        assert has_b_won
        self.pairs = pairs

    def __len__(self):
        return len(self.pairs)

    def __getitem__(self, idx):
        sample = self.pairs[idx]
        return {
            "embed_a": self.embed_cache[sample["md5a"]],
            "embed_b": self.embed_cache[sample["md5b"]],
            "outcome": torch.tensor([sample["outcome"]], dtype=torch.float32),
        }


def build_feature_table(md5_list, image_dir, backbone, batch_size=64):
    """Run the frozen backbone over local images and return ``{md5: feature}``.

    Images are loaded in a thread pool. Within each batch they are grouped by
    their preprocessed ``(h, w)``, because differently sized images cannot be
    stacked into one tensor. Duplicate md5s are processed once.

    Raises:
        AssertionError: on missing images, non-finite features, or
            inconsistent feature widths.
    """
    md5_list = list(dict.fromkeys(md5_list))
    out = {}
    total = len(md5_list)
    assert total > 0
    progress_step = max(1, total // 100)
    next_progress = progress_step
    feature_dim = None
    backbone.eval()
    num_batches = (total + batch_size - 1) // batch_size
    with torch.no_grad():
        for start in tqdm(range(0, total, batch_size), total=num_batches, leave=True):
            batch_md5 = md5_list[start : start + batch_size]
            size_to_md5 = defaultdict(list)
            size_to_tensors = defaultdict(list)
            worker_count = min(IMAGE_PREPROCESS_WORKERS, len(batch_md5))
            assert worker_count > 0
            with ThreadPoolExecutor(max_workers=worker_count) as executor:
                preprocessed = list(
                    executor.map(load_and_preprocess_image, [(m, image_dir) for m in batch_md5])
                )
            for md5, size, pixel_values in preprocessed:
                size_to_md5[size].append(md5)
                size_to_tensors[size].append(pixel_values)

            for size, md5_group in size_to_md5.items():
                tensor_group = size_to_tensors[size]
                for gstart in range(0, len(md5_group), batch_size):
                    gend = min(gstart + batch_size, len(md5_group))
                    pixel_values = torch.stack(tensor_group[gstart:gend], dim=0).to("cuda")
                    outputs = backbone(pixel_values=pixel_values)
                    features = pool_features(outputs).cpu()
                    assert features.ndim == 2 and features.shape[0] == (gend - gstart)
                    assert torch.isfinite(features).all(), f"Non-finite features in size group {size}"
                    if feature_dim is None:
                        feature_dim = int(features.shape[1])
                    assert int(features.shape[1]) == feature_dim
                    for md5, feat in zip(md5_group[gstart:gend], features):
                        out[md5] = feat

            done = len(out)
            if done >= next_progress or done == total:
                print(f"Feature progress: {done}/{total}")
                while next_progress <= done:
                    next_progress += progress_step
    assert len(out) == total
    return out


def load_or_build_embed_cache(md5_list, image_dir, backbone, cache_path, batch_size=64, backbone_factory=None):
    """Return ``{md5: feature}`` for ``md5_list``, reusing an on-disk cache.

    Embeddings missing from ``cache_path`` are computed with the backbone and
    merged in. The cache file is rewritten only when something new was built.

    Args:
        md5_list: md5s whose embeddings are required.
        image_dir: Directory with ``<md5>.<ext>`` images, used only for misses.
        backbone: Frozen feature extractor, or ``None`` to defer loading it.
        cache_path: ``.pt`` file holding ``{"md5": [...], "features": Tensor}``.
        batch_size: Backbone batch size for missing embeddings.
        backbone_factory: Optional zero-arg callable that returns a backbone.
            It is called only when ``backbone`` is ``None`` and some embeddings
            are missing, so a complete cache never loads the backbone.
    """
    assert len(md5_list) > 0
    cache = {}
    cache_path = Path(cache_path)
    if cache_path.exists():
        # The payload is a plain dict of a str list and a tensor, so the safe
        # weights-only loader is sufficient.
        payload = torch.load(cache_path, map_location="cpu", weights_only=True)
        cached_md5 = payload["md5"]
        cached_features = payload["features"]
        assert len(cached_md5) == len(cached_features)
        assert cached_features.ndim == 2 and torch.isfinite(cached_features).all()
        cache = {m: cached_features[i] for i, m in enumerate(cached_md5)}
        print(f"Loaded embed cache: {len(cache)} items from {cache_path}")

    missing = list(dict.fromkeys(m for m in md5_list if m not in cache))
    if missing:
        print(f"Building missing embeddings: {len(missing)}")
        if backbone is None:
            assert backbone_factory is not None, "Embeddings are missing and no backbone was provided"
            backbone = backbone_factory()
        built = build_feature_table(missing, image_dir=image_dir, backbone=backbone, batch_size=batch_size)
        cache.update(built)

        ordered = sorted(cache.keys())
        features = torch.stack([cache[m] for m in ordered], dim=0)
        assert features.ndim == 2 and torch.isfinite(features).all()
        cache_path.parent.mkdir(parents=True, exist_ok=True)
        torch.save({"md5": ordered, "features": features}, cache_path)
        print(f"Saved embed cache: {cache_path}")

    return {m: cache[m] for m in md5_list}


def load_frozen_backbone(backbone_path):
    """Load the DINOv3 ViT-H backbone from a safetensors file, frozen on CUDA.

    Only the first part returned by ``_split_and_clean_state_dict`` is loaded
    (strictly). That part is presumably the backbone weights; confirm in ``model``.
    """
    backbone_path = Path(backbone_path).resolve()
    assert backbone_path.exists(), f"Missing backbone checkpoint: {backbone_path}"
    sd = load_file(str(backbone_path), device="cpu")
    backbone_sd, _ = _split_and_clean_state_dict(sd)
    backbone = DINOv3ViTH()
    backbone.load_state_dict(backbone_sd, strict=True)
    backbone = backbone.to("cuda").eval().to(torch.float32)
    for p in backbone.parameters():
        p.requires_grad_(False)
    return backbone


def test_code(df_second, pool_md5_list):
    """Smoke-test cache building and the dataloader by printing one batch's shapes."""
    batch_size = 4
    backbone_path = Path("./tagger_proto.safetensors")
    backbone = load_frozen_backbone(backbone_path)
    embed_cache = load_or_build_embed_cache(
        md5_list=pool_md5_list,
        image_dir=Path("./samples").resolve(),
        backbone=backbone,
        cache_path="tagger_scorer/frozen_embed_cache_512.pt",
        batch_size=64,
    )
    train_set = PreferenceDataset(pair_df=df_second, embed_cache=embed_cache)
    train_loader = torch.utils.data.DataLoader(
        train_set,
        batch_size=batch_size,
        shuffle=True,
        num_workers=0,
    )
    print(len(train_set))
    for sample in train_loader:
        print(sample["embed_a"].shape, sample["embed_b"].shape, sample["outcome"].shape)
        break


def count_parameters(model):
    """Return the number of trainable (``requires_grad``) parameters."""
    return sum(p.numel() for p in model.parameters() if p.requires_grad)


class AestheticPredictor(nn.Module):
    """Scoring head taken from ``TaggerAestheticModel``, applied to cached features."""

    def __init__(self, feature_dim: int):
        super().__init__()
        self.scoring_head = TaggerAestheticModel(feature_dim).scoring_head

    def forward(self, features):
        return self.scoring_head(features)


class dotdict(dict):
    """dict with attribute access. Missing keys read as ``None`` (``dict.get``)."""

    __getattr__ = dict.get
    __setattr__ = dict.__setitem__
    __delattr__ = dict.__delitem__


def eval_model(model, optimizer, val_loader, wandb_run, step):
    """Log sample-weighted validation BCE loss and pairwise accuracy.

    Schedule-free optimizers must be in eval mode for evaluation. Train mode is
    restored afterwards even if evaluation raises.
    """
    model.eval()
    optimizer.eval()
    criterion = nn.BCEWithLogitsLoss(reduction="sum")
    total_loss = 0.0
    total_correct = 0
    total_count = 0
    try:
        with torch.no_grad():
            for data in tqdm(val_loader, total=len(val_loader), leave=True):
                data = {k: v.to("cuda") for k, v in data.items()}
                pred_outcome = model(data["embed_a"]) - model(data["embed_b"])
                total_loss += criterion(pred_outcome, data["outcome"]).item()
                total_correct += int(((pred_outcome > 0).float() == data["outcome"]).sum().item())
                total_count += data["outcome"].numel()
    finally:
        model.train()
        optimizer.train()

    if total_count == 0:
        print("Validation loader is empty; skipping val logging")
        return
    wandb_run.log(
        {"val_loss": total_loss / total_count, "val_acc": total_correct / total_count},
        step=step,
    )


def save_model(model, step, model_name, optimizer=None):
    """Write ``model``'s weights to ``./checkpoints/<model_name>_<step>.safetensors``.

    If a schedule-free ``optimizer`` is passed, it is switched to eval mode
    during the save, so the checkpoint holds the averaged weights. It is then
    switched back to train mode, so call this while in train mode.
    """
    out_dir = Path("./checkpoints")
    out_dir.mkdir(parents=True, exist_ok=True)
    out_path = out_dir / f"{model_name}_{step}.safetensors"
    if optimizer is not None:
        optimizer.eval()
    try:
        # Serialize before switching back: eval()/train() rewrite the
        # parameters in place, and these tensors may alias them.
        state = {k: v.detach().to(torch.float32).contiguous() for k, v in model.state_dict().items()}
        save_file(state, str(out_path))
    finally:
        if optimizer is not None:
            optimizer.train()


def _build_optimizer(model, s):
    """Create the schedule-free AdamW optimizer for the scoring head.

    Parameters whose names contain ``bias``, ``embed`` or ``norm`` get no
    weight decay.
    """
    decay_head = []
    no_decay_head = []
    for name, param in model.scoring_head.named_parameters():
        param.requires_grad = True
        if "bias" in name or "embed" in name or "norm" in name:
            no_decay_head.append(param)
        else:
            decay_head.append(param)
    return schedulefree.AdamWScheduleFree(
        [
            {"params": decay_head, "lr": s.lr},
            {"params": no_decay_head, "weight_decay": 0.0, "lr": s.lr},
        ],
        lr=s.lr,
        warmup_steps=s.warmup_steps,
        weight_decay=s.weight_decay,
        betas=s.betas,
    )


def train_main():
    """Train the scoring head on cached embeddings and log to Weights & Biases.

    Step accounting: ``step_counter`` counts micro-batches, so
    ``train_target_steps``, ``save_every`` and ``eval_every`` are in
    micro-batch units. For example, ``6000 * 8`` is 6000 optimizer steps at
    8-way accumulation. A periodic save or eval runs after the optimizer step
    that crosses each multiple of its interval. One final save and eval always
    runs on exit.
    """
    print(
        f"WARNING: Using pinned dataset revision {DATASET_REVISION}. "
        "This is an old snapshot — embed_cache.pt covers this pool, so you do not need "
        "the sample images or the backbone checkpoint locally. If you want to use the "
        "latest version, have the sample images locally in ./samples."
    )
    df_first, df_second, pool_md5_list = load_votes_and_split()

    settings = {
        "model_name": "classifer_head",
        "remarks": "Training head only with cached DINOv3 embeddings at 512.",
        "optimizer": "AdamW schedulefree",
        "batch_size": 4,
        "val_batch_size": 16,
        "image_dir": "./samples",
        "do_eval": True,
        "backbone_path": "./tagger_proto.safetensors",
        "embed_cache_path": "./embed_cache.pt",
        "lr": 1e-3,
        "warmup_steps": 100,
        "weight_decay": 1e-2,
        "betas": (0.9, 0.999),
        "save_every": 5000,
        "eval_every": 2000,
        "num_accumulation_steps": 8,
        "train_target_steps": 6000 * 8,
        "embed_cache_batch_size": 64,
    }
    s = dotdict(settings)

    embed_cache = load_or_build_embed_cache(
        md5_list=pool_md5_list,
        image_dir=Path(s.image_dir).resolve(),
        backbone=None,
        cache_path=s.embed_cache_path,
        batch_size=s.embed_cache_batch_size,
        # Only load the large backbone if the cache is missing embeddings.
        backbone_factory=lambda: load_frozen_backbone(s.backbone_path),
    )
    print("Cached embeds:", len(embed_cache))

    feature_dim = int(next(iter(embed_cache.values())).shape[0])
    model = AestheticPredictor(feature_dim).to(torch.float32).to("cuda")
    print(count_parameters(model))

    train_set = PreferenceDataset(pair_df=df_second, embed_cache=embed_cache)
    print("Train set size:", len(train_set))
    train_loader = torch.utils.data.DataLoader(
        train_set,
        batch_size=s.batch_size,
        shuffle=True,
        num_workers=12,
        drop_last=True,
    )

    val_loader = None
    if s.do_eval:
        val_set = PreferenceDataset(pair_df=df_first, embed_cache=embed_cache)
        # Fixed order, no dropped samples: every eval scores the same pairs.
        val_loader = torch.utils.data.DataLoader(
            val_set,
            batch_size=s.val_batch_size,
            shuffle=False,
            num_workers=12,
            drop_last=False,
        )

    wandb_api_key = os.getenv("WANDB_API_KEY")
    assert wandb_api_key is not None and wandb_api_key != "", "WANDB_API_KEY env var is required"
    wandb.login(key=wandb_api_key)
    wandb_run = wandb.init(
        project=os.getenv("WANDB_PROJECT", "aesthetic_bradley_terry"),
        name=settings["model_name"],
        config=settings,
    )
    assert wandb_run is not None

    optimizer = _build_optimizer(model, s)
    criterion = nn.BCEWithLogitsLoss()
    model.train()
    optimizer.train()

    step_counter = 0  # micro-batches processed so far
    saved_intervals = 0  # save_every boundaries already handled
    evaluated_intervals = 0  # eval_every boundaries already handled
    reached_target = False
    optimizer.zero_grad()
    try:
        for _ in range(200000):
            acc_loss = 0.0
            for i, data in tqdm(enumerate(train_loader), total=len(train_loader)):
                data = {k: v.to("cuda") for k, v in data.items()}
                pred_outcome = model(data["embed_a"]) - model(data["embed_b"])
                loss = criterion(pred_outcome, data["outcome"])
                pre_acc_loss = loss.item()
                wandb_run.log({"pre_acc_train_loss": pre_acc_loss}, step=step_counter)
                acc_loss += pre_acc_loss / s.num_accumulation_steps
                (loss / s.num_accumulation_steps).backward()

                is_last_batch = i + 1 == len(train_loader)
                if (i + 1) % s.num_accumulation_steps == 0 or is_last_batch:
                    grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                    optimizer.step()
                    optimizer.zero_grad()
                    wandb_run.log(
                        {
                            "train_loss": acc_loss,
                            "lr": s.lr,
                            "grad_norm": float(grad_norm),
                        },
                        step=step_counter,
                    )
                    acc_loss = 0.0

                    seen = step_counter + 1
                    if seen >= s.train_target_steps:
                        # The final save + eval happen once, in `finally`.
                        reached_target = True
                        break
                    # Optimizer steps only land on some residues of step_counter,
                    # so `step_counter % N == 0` may never fire. Instead, test
                    # whether another multiple of N micro-batches has been crossed.
                    if seen // s.save_every > saved_intervals:
                        saved_intervals = seen // s.save_every
                        save_model(model, step_counter, s.model_name, optimizer=optimizer)
                    if s.do_eval and seen // s.eval_every > evaluated_intervals:
                        evaluated_intervals = seen // s.eval_every
                        eval_model(model, optimizer, val_loader, wandb_run, step_counter)
                step_counter += 1
            if reached_target:
                break
    except KeyboardInterrupt:
        pass
    finally:
        save_model(model, step_counter, s.model_name, optimizer=optimizer)
        if s.do_eval:
            eval_model(model, optimizer, val_loader, wandb_run, step_counter)
        wandb.finish()
    print("Finished Training")


if __name__ == "__main__":
    train_main()