| |
|
|
| from collections import defaultdict |
| from concurrent.futures import ThreadPoolExecutor |
| from pathlib import Path |
| import glob |
| import os |
| import random |
| import sys |
|
|
| from huggingface_hub import snapshot_download |
| import pandas as pd |
| from PIL import Image |
| from safetensors.torch import load_file, save_file |
| import schedulefree |
| import torch |
| import wandb |
| from torch import nn |
| from torch.utils.data import Dataset |
| from tqdm.auto import tqdm |
| import torchvision.transforms.v2 as v2 |
| from model import DINOv3ViTH, TaggerAestheticModel, _split_and_clean_state_dict |
|
|
# --- Image preprocessing geometry -------------------------------------------
# Resized image sides are snapped to multiples of PATCH_SIZE (see preprocess_pil).
PATCH_SIZE = 16
# Upper bound for the long edge of a preprocessed image.
MAX_SIZE = 512
# Number of tokens read right after the first token by pool_features.
N_REGISTERS = 4

# --- Pixel normalisation constants passed to v2.Normalize --------------------
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

# Maximum number of threads used to decode/preprocess images per batch.
IMAGE_PREPROCESS_WORKERS = 16

# --- Global runtime configuration (runs at import time) ---------------------
random.seed(3407)
torch.set_grad_enabled(True)
torch.set_float32_matmul_precision("high")
|
|
|
|
def load_votes_and_split(val_size=2000):
    """Download the pinned ratings dataset and split its votes into val/train.

    All ``ratings_log/*.parquet`` files are concatenated, votes whose ``md5a``
    or ``md5b`` is not listed in ``pool.parquet`` are dropped, and the
    remaining rows are shuffled with a fixed ``random_state`` before being cut
    into a validation head and a training tail.

    Args:
        val_size: Number of shuffled rows assigned to the validation split
            (default 2000, the value previously hard-coded).

    Returns:
        ``(val_df, train_df, pool_md5_list)`` where ``pool_md5_list`` is the
        sorted list of unique pool md5 strings.

    Raises:
        AssertionError: If no rating parquet files are found or the pool file
            has no ``md5`` column.
    """
    repo_id = "taigasan/e6-visual-ratings"
    local_repo_path = snapshot_download(
        repo_id=repo_id,
        repo_type="dataset",
        revision="daaa857ffab11075c2fc6912e7f23879d324dcc9",
    )
    print("Downloaded repo to:", local_repo_path)

    rating_log_dir = os.path.join(local_repo_path, "ratings_log")
    parquet_files = sorted(glob.glob(os.path.join(rating_log_dir, "*.parquet")))
    # Raised explicitly (instead of `assert`) so the check survives `python -O`.
    if not parquet_files:
        raise AssertionError(f"No rating parquet files found in {rating_log_dir}")
    combined_df = pd.concat(
        [pd.read_parquet(path) for path in parquet_files],
        ignore_index=True,
    )
    print("Total votes before pool filter:", len(combined_df))

    pool_path = os.path.join(local_repo_path, "pool.parquet")
    pool_df = pd.read_parquet(pool_path)
    if "md5" not in pool_df.columns:
        raise AssertionError(f"Pool file {pool_path} has no 'md5' column")
    # De-duplicate: a repeated md5 in the pool would otherwise be requested
    # twice from the feature builder, whose final length check then fails.
    valid_md5 = set(pool_df["md5"].astype(str).tolist())
    pool_md5_list = sorted(valid_md5)
    in_pool = combined_df["md5a"].isin(valid_md5) & combined_df["md5b"].isin(valid_md5)
    combined_df = combined_df[in_pool].reset_index(drop=True)
    print("Total votes after pool filter:", len(combined_df))
    print("Pool rows:", len(pool_df))

    # Fixed seed keeps the val/train membership reproducible across runs.
    df = combined_df.sample(frac=1, random_state=42).reset_index(drop=True)
    df_first = df.iloc[:val_size].reset_index(drop=True)
    df_second = df.iloc[val_size:].reset_index(drop=True)
    print("Val rows:", len(df_first))
    print("Train rows:", len(df_second))
    return df_first, df_second, pool_md5_list
|
|
|
|
| def _snap(x: int, m: int) -> int: |
| return max(m, (x // m) * m) |
|
|
|
|
def preprocess_pil(img: Image.Image, max_size: int = MAX_SIZE) -> torch.Tensor:
    """Resize and normalise a PIL image into a single-item float32 batch.

    The long edge is capped at ``max_size`` and both sides are snapped down to
    multiples of ``PATCH_SIZE`` (at least one patch each), so the aspect ratio
    is only approximately preserved.

    Returns:
        A tensor with a leading batch dimension of 1, normalised with
        ``IMAGENET_MEAN`` / ``IMAGENET_STD``.
    """
    rgb = img.convert("RGB")
    width, height = rgb.size
    longest = max(width, height)
    target_long = _snap(min(longest, max_size), PATCH_SIZE)
    ratio = target_long / longest
    out_w = _snap(max(PATCH_SIZE, round(width * ratio)), PATCH_SIZE)
    out_h = _snap(max(PATCH_SIZE, round(height * ratio)), PATCH_SIZE)

    # Resize comes first so that Lanczos runs on the PIL image itself.
    pipeline = v2.Compose(
        [
            v2.Resize((out_h, out_w), interpolation=v2.InterpolationMode.LANCZOS),
            v2.ToImage(),
            v2.ToDtype(torch.float32, scale=True),
            v2.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
        ]
    )
    tensor = pipeline(rgb)
    return tensor.unsqueeze(0)
|
|
|
|
def find_local_image_path(md5: str, image_dir) -> Path:
    """Return the first existing ``<image_dir>/<md5>.<ext>`` file.

    Extensions are tried in the order webp, jpg, jpeg, png.

    Raises:
        AssertionError: If none of the candidate files exists.
    """
    root = Path(image_dir)
    candidates = (root / f"{md5}.{ext}" for ext in ("webp", "jpg", "jpeg", "png"))
    match = next((path for path in candidates if path.exists()), None)
    if match is None:
        raise AssertionError(f"Missing local image for md5={md5}; checked dir={root}")
    return match
|
|
|
|
def load_and_preprocess_image(item):
    """Load one image from disk and preprocess it (thread-pool work unit).

    Args:
        item: ``(md5, image_dir)`` tuple.

    Returns:
        ``(md5, (height, width), pixel_values)`` where ``pixel_values`` is the
        preprocessed tensor with the batch dimension removed.
    """
    md5, image_dir = item
    path = find_local_image_path(md5, image_dir)
    with Image.open(path) as img:
        # preprocess_pil returns a batch of one; take the single item.
        pixel_values = preprocess_pil(img)[0]
    height, width = (int(dim) for dim in pixel_values.shape[1:3])
    return md5, (height, width), pixel_values
|
|
|
|
def pool_features(outputs) -> torch.Tensor:
    """Concatenate the first token with the following ``N_REGISTERS`` tokens.

    Args:
        outputs: Either a 3-D hidden-state tensor or an object exposing
            ``last_hidden_state``.

    Returns:
        A float32 tensor of shape ``(batch, (1 + N_REGISTERS) * dim)``.
    """
    if isinstance(outputs, torch.Tensor):
        hidden = outputs
    else:
        hidden = outputs.last_hidden_state
    assert hidden.ndim == 3 and hidden.shape[1] >= 1 + N_REGISTERS

    first_token = hidden[:, 0, :]
    register_tokens = hidden[:, 1 : 1 + N_REGISTERS, :].flatten(1)
    pooled = torch.cat([first_token, register_tokens], dim=-1)
    return pooled.to(torch.float32)
|
|
|
|
class PreferenceDataset(Dataset):
    """Pairwise preference samples backed by a precomputed embedding cache.

    Each usable row of ``pair_df`` (columns ``md5a``, ``md5b``, ``winner_md5``)
    becomes one sample with outcome 1 when ``md5a`` won and 0 when ``md5b`` won.
    Rows without a winner, or whose winner matches neither side of the pair,
    are skipped instead of being labelled as a win for ``md5b``.

    Args:
        pair_df: DataFrame of votes, read with ``itertuples``.
        embed_cache: Mapping from md5 to the embedding returned for that image.

    Raises:
        AssertionError: If a referenced md5 is missing from ``embed_cache`` or
            if the data does not contain at least one win for each side.
    """

    def __init__(self, pair_df, embed_cache):
        self.embed_cache = embed_cache

        pairs = []
        has_a_won = False
        has_b_won = False
        skipped_unknown_winner = 0
        for row in pair_df.itertuples(index=False):
            md5a = row.md5a
            md5b = row.md5b
            winner_md5 = row.winner_md5

            # No recorded winner (None / NaN / empty string): nothing to learn from.
            if pd.isna(winner_md5) or not winner_md5:
                continue
            if winner_md5 == md5a:
                outcome = 1
                has_a_won = True
            elif winner_md5 == md5b:
                outcome = 0
                has_b_won = True
            else:
                # A winner that is neither image of the pair is not a valid
                # label; counting it as "b won" would inject label noise.
                skipped_unknown_winner += 1
                continue

            assert md5a in self.embed_cache, f"md5a={md5a} missing from embed cache"
            assert md5b in self.embed_cache, f"md5b={md5b} missing from embed cache"
            pairs.append(
                {
                    "md5a": md5a,
                    "md5b": md5b,
                    "outcome": outcome,
                }
            )

        if skipped_unknown_winner:
            print(
                "PreferenceDataset: skipped "
                f"{skipped_unknown_winner} rows whose winner_md5 matches neither md5a nor md5b"
            )
        # Both labels must be present, otherwise the pairwise loss is degenerate.
        assert has_a_won
        assert has_b_won
        self.pairs = pairs

    def __len__(self):
        return len(self.pairs)

    def __getitem__(self, idx):
        sample = self.pairs[idx]
        return {
            "embed_a": self.embed_cache[sample["md5a"]],
            "embed_b": self.embed_cache[sample["md5b"]],
            # Shape (1,) float so it lines up with the model's per-sample logit.
            "outcome": torch.tensor([sample["outcome"]], dtype=torch.float32),
        }
|
|
|
|
def build_feature_table(md5_list, image_dir, backbone, batch_size=64):
    """Compute pooled backbone features for every md5 in ``md5_list``.

    Images are decoded and preprocessed on a thread pool, grouped by their
    ``(height, width)`` so that equally sized tensors can be stacked, and run
    through ``backbone`` on CUDA without gradients.

    Args:
        md5_list: Image ids to embed; duplicates are ignored.
        image_dir: Directory searched by ``find_local_image_path``.
        backbone: Callable module invoked as ``backbone(pixel_values=...)``.
        batch_size: Maximum number of images loaded (and forwarded) at once.

    Returns:
        Dict mapping each unique md5 to its CPU float32 feature vector.

    Raises:
        AssertionError: On empty input, non-positive ``batch_size``,
            non-finite features, inconsistent feature width, or a missing image.
    """
    # Order-preserving de-duplication: a repeated id would be embedded twice
    # and would make the final completeness check fail after all the work.
    md5_list = list(dict.fromkeys(md5_list))
    total = len(md5_list)
    assert total > 0
    assert batch_size > 0, f"batch_size must be positive, got {batch_size}"

    out = {}
    progress_step = max(1, total // 100)
    next_progress = progress_step
    feature_dim = None
    worker_count = min(IMAGE_PREPROCESS_WORKERS, batch_size, total)

    backbone.eval()
    # One pool for the whole run instead of spawning threads for every batch.
    with torch.no_grad(), ThreadPoolExecutor(max_workers=worker_count) as executor:
        num_batches = (total + batch_size - 1) // batch_size
        for start in tqdm(range(0, total, batch_size), total=num_batches, leave=True):
            batch_md5 = md5_list[start : start + batch_size]
            jobs = [(m, image_dir) for m in batch_md5]

            # Group by spatial size; only same-sized tensors can be stacked.
            by_size = defaultdict(list)
            for md5, size, pixel_values in executor.map(load_and_preprocess_image, jobs):
                by_size[size].append((md5, pixel_values))

            # A group never exceeds batch_size (it is a subset of the batch),
            # so each group is forwarded in a single call.
            for size, members in by_size.items():
                group_md5 = [md5 for md5, _ in members]
                pixel_values = torch.stack([t for _, t in members], dim=0).to("cuda")
                outputs = backbone(pixel_values=pixel_values)
                features = pool_features(outputs).cpu()
                assert features.ndim == 2 and features.shape[0] == len(group_md5)
                assert torch.isfinite(features).all(), f"Non-finite features in size group {size}"
                if feature_dim is None:
                    feature_dim = int(features.shape[1])
                assert int(features.shape[1]) == feature_dim
                for md5, feat in zip(group_md5, features):
                    out[md5] = feat

            done = len(out)
            if done >= next_progress or done == total:
                print(f"Feature progress: {done}/{total}")
                while next_progress <= done:
                    next_progress += progress_step

    assert len(out) == total
    return out
|
|
|
|
def load_or_build_embed_cache(md5_list, image_dir, backbone, cache_path, batch_size=64):
    """Return embeddings for ``md5_list``, computing and caching missing ones.

    An existing cache file (a dict with ``"md5"`` and ``"features"``) is loaded
    first. Only ids absent from it are embedded with ``build_feature_table``.
    The cache file is rewritten only when new embeddings were built, and the
    write goes through a temporary file so an interrupted save cannot leave a
    truncated cache behind.

    Args:
        md5_list: Ids whose embeddings are required (non-empty).
        image_dir: Directory containing the images of missing ids.
        backbone: Model used to embed missing ids; unused if nothing is missing.
        cache_path: Location of the ``torch.save`` cache file.
        batch_size: Batch size forwarded to ``build_feature_table``.

    Returns:
        Dict mapping every md5 in ``md5_list`` to its feature vector.
    """
    assert len(md5_list) > 0

    cache = {}
    cache_path = Path(cache_path)
    if cache_path.exists():
        # The payload is plain data (list of str + tensor); weights_only=True
        # avoids unpickling arbitrary objects from the cache file.
        payload = torch.load(cache_path, map_location="cpu", weights_only=True)
        cached_md5 = payload["md5"]
        cached_features = payload["features"]
        assert len(cached_md5) == len(cached_features)
        assert cached_features.ndim == 2 and torch.isfinite(cached_features).all()
        cache = {m: cached_features[i] for i, m in enumerate(cached_md5)}
        print(f"Loaded embed cache: {len(cache)} items from {cache_path}")

    # Order-preserving de-duplication so each missing id is embedded once.
    missing = list(dict.fromkeys(m for m in md5_list if m not in cache))
    if missing:
        print(f"Building missing embeddings: {len(missing)}")
        built = build_feature_table(missing, image_dir=image_dir, backbone=backbone, batch_size=batch_size)
        cache.update(built)

        ordered = sorted(cache.keys())
        features = torch.stack([cache[m] for m in ordered], dim=0)
        assert features.ndim == 2 and torch.isfinite(features).all()
        cache_path.parent.mkdir(parents=True, exist_ok=True)
        # Write to a sibling temp file, then swap it in, so readers never see
        # a partially written cache.
        tmp_path = cache_path.with_name(cache_path.name + ".tmp")
        torch.save({"md5": ordered, "features": features}, tmp_path)
        tmp_path.replace(cache_path)
        print(f"Saved embed cache: {cache_path}")

    return {m: cache[m] for m in md5_list}
|
|
|
|
def load_frozen_backbone(backbone_path):
    """Load the DINOv3 ViT-H backbone from a safetensors file and freeze it.

    The first element returned by ``_split_and_clean_state_dict`` is loaded
    strictly into a fresh ``DINOv3ViTH``; the second element is ignored here.

    Returns:
        The backbone on CUDA, in eval mode, float32, with gradients disabled.

    Raises:
        AssertionError: If the checkpoint file does not exist.
    """
    checkpoint = Path(backbone_path).resolve()
    assert checkpoint.exists(), f"Missing backbone checkpoint: {checkpoint}"

    full_state = load_file(str(checkpoint), device="cpu")
    backbone_state, _unused = _split_and_clean_state_dict(full_state)

    model = DINOv3ViTH()
    model.load_state_dict(backbone_state, strict=True)
    model = model.to("cuda")
    model = model.eval()
    model = model.to(torch.float32)
    # Freeze every parameter; the backbone is only used for feature extraction.
    model.requires_grad_(False)
    return model
|
|
|
|
def test_code(df_second, pool_md5_list):
    """Smoke-check the data path: build the cache, wrap it in a loader, print one batch's shapes."""
    backbone = load_frozen_backbone(Path("./tagger_proto.safetensors"))
    embed_cache = load_or_build_embed_cache(
        md5_list=pool_md5_list,
        image_dir=Path("./samples").resolve(),
        backbone=backbone,
        cache_path="tagger_scorer/frozen_embed_cache_512.pt",
        batch_size=64,
    )

    train_set = PreferenceDataset(pair_df=df_second, embed_cache=embed_cache)
    train_loader = torch.utils.data.DataLoader(
        train_set,
        batch_size=4,
        shuffle=True,
        num_workers=0,
    )
    print(len(train_set))

    # Only the first batch is inspected.
    first_batch = next(iter(train_loader), None)
    if first_batch is not None:
        print(
            first_batch["embed_a"].shape,
            first_batch["embed_b"].shape,
            first_batch["outcome"].shape,
        )
|
|
|
|
def count_parameters(model):
    """Return the total number of elements across the model's trainable parameters."""
    trainable = 0
    for param in model.parameters():
        if param.requires_grad:
            trainable += param.numel()
    return trainable
|
|
|
|
class AestheticPredictor(nn.Module):
    """Wraps only the ``scoring_head`` of a freshly built ``TaggerAestheticModel``."""

    def __init__(self, feature_dim: int):
        super().__init__()
        # Build the full model but keep just its scoring head as a submodule.
        tagger = TaggerAestheticModel(feature_dim)
        self.scoring_head = tagger.scoring_head

    def forward(self, features):
        """Apply the scoring head to precomputed features."""
        scores = self.scoring_head(features)
        return scores
|
|
|
|
class dotdict(dict):
    """Dict whose keys are also readable, writable and deletable as attributes.

    Reading an attribute for a missing key yields ``None`` rather than raising;
    deleting a missing key raises ``KeyError``.
    """

    def __getattr__(self, name):
        # Only reached when normal attribute lookup fails.
        return self.get(name)

    def __setattr__(self, name, value):
        self[name] = value

    def __delattr__(self, name):
        del self[name]
|
|
|
|
def eval_model(model, optimizer, val_loader, wandb_run, step):
    """Log the mean per-batch validation loss to wandb as ``val_loss``.

    The loss is BCE-with-logits on ``model(embed_a) - model(embed_b)`` against
    the batch's ``outcome``. Model and optimizer are switched to eval mode for
    the pass and always switched back to train mode afterwards, even if the
    evaluation raises.

    Args:
        model: Scoring model applied to ``embed_a`` and ``embed_b``.
        optimizer: Optimizer exposing ``eval()`` / ``train()``.
        val_loader: Iterable of dict batches with ``embed_a``, ``embed_b``
            and ``outcome`` tensors; must support ``len``.
        wandb_run: Object whose ``log(dict, step=...)`` receives the metric.
        step: Step value passed through to ``wandb_run.log``.
    """
    model.eval()
    optimizer.eval()
    criterion = nn.BCEWithLogitsLoss()  # stateless, so built once per call
    acc_loss = 0.0
    num_steps = 0
    try:
        with torch.no_grad():
            for data in tqdm(val_loader, total=len(val_loader), leave=True):
                for k in data:
                    data[k] = data[k].to("cuda")
                output_one = model(data["embed_a"])
                output_two = model(data["embed_b"])
                pred_outcome = output_one - output_two
                loss = criterion(pred_outcome, data["outcome"])
                acc_loss += loss.item()
                num_steps += 1

        if num_steps > 0:
            wandb_run.log({"val_loss": acc_loss / num_steps}, step=step)
        else:
            # An empty loader (e.g. drop_last with too few samples) used to
            # raise ZeroDivisionError here; skip logging instead.
            print("eval_model: validation loader yielded no batches; val_loss not logged")
    finally:
        # Never leave the caller's training loop stuck in eval mode.
        model.train()
        optimizer.train()
|
|
|
|
def save_model(model, step, model_name):
    """Write the model's state dict to ``./checkpoints/<model_name>_<step>.safetensors``.

    Every tensor is detached, cast to float32 and made contiguous before saving.
    """
    checkpoint_dir = Path("./checkpoints")
    checkpoint_dir.mkdir(parents=True, exist_ok=True)

    tensors = {}
    for key, value in model.state_dict().items():
        tensors[key] = value.detach().to(torch.float32).contiguous()

    target = checkpoint_dir / f"{model_name}_{step}.safetensors"
    save_file(tensors, str(target))
|
|
|
|
def train_main():
    """Train the aesthetic scoring head on cached backbone embeddings.

    Pipeline: download and split the votes, load or build the embedding cache,
    then optimise the scoring head with a pairwise logistic loss
    (BCE-with-logits on ``score_a - score_b``) using gradient accumulation and
    a schedule-free AdamW optimizer. Metrics go to Weights & Biases.

    ``step_counter`` counts micro-batches, not optimizer steps. Training stops
    once ``train_target_steps`` micro-batches have been processed, on Ctrl-C,
    or after the epoch limit; in every case a final checkpoint and (if enabled)
    a final evaluation are produced exactly once and the wandb run is closed.

    Raises:
        AssertionError: If ``WANDB_API_KEY`` is unset or wandb fails to start.
    """
    print(
        "WARNING: Using pinned dataset revision daaa857ffab11075c2fc6912e7f23879d324dcc9. "
        "This is an old snapshot — embed_cache.pt covers this pool, so you do not need the "
        "sample images locally. If you want to use the latest version, have the sample "
        "images locally in ./samples."
    )
    df_first, df_second, pool_md5_list = load_votes_and_split()

    settings = {
        "model_name": "classifer_head",
        "remarks": "Training head only with cached DINOv3 embeddings at 512.",
        "optimizer": "AdamW schedulefree",
        "batch_size": 4,
        "val_batch_size": 16,
        "image_dir": "./samples",
        "do_eval": True,
        "backbone_path": "./tagger_proto.safetensors",
        "embed_cache_path": "./embed_cache.pt",
        "lr": 1e-3,
        "warmup_steps": 100,
        "weight_decay": 1e-2,
        "betas": (0.9, 0.999),
        "save_every": 5000,
        "eval_every": 2000,
        "num_accumulation_steps": 8,
        "train_target_steps": 6000 * 8,
        "embed_cache_batch_size": 64,
    }
    s = dotdict(settings)

    backbone = load_frozen_backbone(s.backbone_path)
    image_dir = Path(s.image_dir).resolve()
    embed_cache = load_or_build_embed_cache(
        md5_list=pool_md5_list,
        image_dir=image_dir,
        backbone=backbone,
        cache_path=s.embed_cache_path,
        batch_size=s.embed_cache_batch_size,
    )
    print("Cached embeds:", len(embed_cache))

    # The backbone is only needed to fill the cache; release its GPU memory
    # before training the head.
    del backbone
    torch.cuda.empty_cache()

    feature_dim = int(next(iter(embed_cache.values())).shape[0])
    model = AestheticPredictor(feature_dim).to(torch.float32).to("cuda")
    print(count_parameters(model))

    train_set = PreferenceDataset(
        pair_df=df_second,
        embed_cache=embed_cache,
    )
    print("Train set size:", len(train_set))
    train_loader = torch.utils.data.DataLoader(
        train_set,
        batch_size=s.batch_size,
        shuffle=True,
        num_workers=12,
        drop_last=True,
    )

    val_loader = None
    if s.do_eval:
        val_set = PreferenceDataset(
            pair_df=df_first,
            embed_cache=embed_cache,
        )
        val_loader = torch.utils.data.DataLoader(
            val_set,
            batch_size=s.val_batch_size,
            shuffle=True,
            num_workers=12,
            drop_last=True,
        )

    wandb_api_key = os.getenv("WANDB_API_KEY")
    assert wandb_api_key is not None and wandb_api_key != "", "WANDB_API_KEY env var is required"
    wandb.login(key=wandb_api_key)
    wandb_run = wandb.init(
        project=os.getenv("WANDB_PROJECT", "aesthetic_bradley_terry"),
        name=settings["model_name"],
        config=settings,
    )
    assert wandb_run is not None

    # Biases, embeddings and norm parameters are excluded from weight decay.
    decay_head = []
    no_decay_head = []
    for name, param in model.scoring_head.named_parameters():
        param.requires_grad = True
        if "bias" in name or "embed" in name or "norm" in name:
            no_decay_head.append(param)
        else:
            decay_head.append(param)

    optimizer = schedulefree.AdamWScheduleFree(
        [
            {"params": decay_head, "lr": s.lr},
            {"params": no_decay_head, "weight_decay": 0.0, "lr": s.lr},
        ],
        lr=s.lr,
        warmup_steps=s.warmup_steps,
        weight_decay=s.weight_decay,
        betas=s.betas,
    )
    model.train()
    optimizer.train()
    criterion = nn.BCEWithLogitsLoss()  # stateless, so built once
    step_counter = 0
    reached_target = False

    optimizer.zero_grad()
    try:
        for _ in range(200000):
            acc_loss = 0.0
            for i, data in tqdm(enumerate(train_loader, 0), total=len(train_loader)):
                for k in data:
                    data[k] = data[k].to("cuda")
                output_one = model(data["embed_a"])
                output_two = model(data["embed_b"])
                pred_outcome = output_one - output_two
                loss = criterion(pred_outcome, data["outcome"])
                pre_acc_loss = loss.item()
                wandb_run.log({"pre_acc_train_loss": pre_acc_loss}, step=step_counter)

                # Scale so the accumulated gradient is the mean over the window.
                acc_loss += pre_acc_loss / s.num_accumulation_steps
                loss = loss / s.num_accumulation_steps
                loss.backward()

                # Step at the end of each accumulation window and at epoch end.
                if ((i + 1) % s.num_accumulation_steps == 0) or (i + 1 == len(train_loader)):
                    grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                    optimizer.step()
                    optimizer.zero_grad()
                    wandb_run.log(
                        {
                            "train_loss": acc_loss,
                            "lr": s.lr,
                            "grad_norm": float(grad_norm),
                        },
                        step=step_counter,
                    )
                    acc_loss = 0.0

                if step_counter + 1 >= s.train_target_steps:
                    # The final checkpoint and evaluation happen once, in the
                    # `finally` block below; doing them here as well would
                    # duplicate both.
                    reached_target = True
                    break

                if step_counter % s.save_every == 0:
                    save_model(model, step_counter, s.model_name)
                if s.do_eval and step_counter % s.eval_every == 0:
                    eval_model(model, optimizer, val_loader, wandb_run, step_counter)
                step_counter += 1
            if reached_target:
                break
    except KeyboardInterrupt:
        # Deliberate: Ctrl-C ends training early but still saves/evaluates below.
        pass
    finally:
        try:
            save_model(model, step_counter, s.model_name)
            if s.do_eval:
                eval_model(model, optimizer, val_loader, wandb_run, step_counter)
        finally:
            # Close the run even if the final save or evaluation fails.
            wandb.finish()

    print("Finished Training")
|
|
|
|
# Start training only when this file is executed as a script, not on import.
if __name__ == "__main__":
    train_main()