| import argparse |
| import json |
| import random |
| from pathlib import Path |
| from typing import Dict, Tuple |
|
|
| import torch |
| import torch.nn as nn |
| import torch.optim as optim |
| import numpy as np |
|
|
| from predictor.training.dataloader import prep_dataloaders, denormalize, AVAILABLE_TARGETS |
| from predictor.models import get_model, NOISE_ENCODERS, TEXT_ENCODERS |
| from predictor.configs.model_dims import MODEL_DIMS, get_dims |
|
|
| from predictor.training.losses import ( |
| ndcg_at_k, |
| ndcg_at_k_per_prompt, |
| spearman_corrcoef, |
| pearson_corrcoef, |
| MAESRCCLoss, |
| MAELambdaRankLoss, |
| LambdaRankLoss, |
| ) |
|
|
|
|
def set_seed(seed: int):
    """Seed every RNG (Python, NumPy, PyTorch CPU/CUDA) for reproducible runs."""
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    # CUDA generators are separate; seed all visible devices when present.
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
|
|
|
|
def train_one_epoch(
    model: nn.Module,
    loader: torch.utils.data.DataLoader,
    criterion: nn.Module,
    optimizer: optim.Optimizer,
    device: torch.device,
    epoch: int = 0,
    loss_type: str = 'mae+srcc',
    use_grouped: bool = False,
) -> Dict[str, float]:
    """Run a single training epoch and return aggregate loss / distribution stats.

    Args:
        model: Predictor taking (noise, prompt_embeds, prompt_mask).
        loader: Yields dicts with 'noise', 'prompt_embeds', 'prompt_mask', 'y',
            and (when grouped) 'prompt_id'.
        criterion: Loss module. MAELambdaRankLoss instances drive their own
            backward pass via criterion.backward(); everything else uses
            the standard loss.backward().
        optimizer: Optimizer stepped once per batch.
        device: Target device for all batch tensors.
        epoch: Current epoch index (kept for interface compatibility; unused).
        loss_type: Loss name (kept for interface compatibility; dispatch is
            actually done via isinstance on `criterion`).
        use_grouped: If True, forward per-prompt group ids to the criterion.

    Returns:
        Dict with per-sample-averaged 'display_loss'/'total_loss'/'loss' and
        mean/std of targets and predictions over the epoch.
    """
    model.train()
    running_display_loss = 0.0
    running_total_loss = 0.0
    n_seen = 0  # samples actually consumed; may differ from len(loader.dataset)

    targetlist = []
    predictionlist = []

    # LambdaRank computes custom gradients itself instead of loss.backward().
    uses_lambdarank = isinstance(criterion, MAELambdaRankLoss)

    for batch in loader:
        noise = batch['noise'].to(device)
        prompt_embeds = batch['prompt_embeds'].to(device)
        prompt_mask = batch['prompt_mask'].to(device)

        optimizer.zero_grad()
        preds = model(noise, prompt_embeds, prompt_mask)

        targets = batch['y'].to(device).unsqueeze(1)
        group_ids = batch['prompt_id'].to(device) if use_grouped else None

        if uses_lambdarank:
            loss = criterion(preds, targets, group_ids=group_ids)
            criterion.backward(preds, targets, loss, group_ids=group_ids)
        else:
            if group_ids is not None:
                loss = criterion(preds, targets, group_ids=group_ids)
            else:
                loss = criterion(preds, targets)
            loss.backward()
        batch_loss = loss.item()

        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()

        batch_size = noise.size(0)
        n_seen += batch_size
        running_display_loss += batch_loss * batch_size
        running_total_loss += batch_loss * batch_size
        targetlist.extend(targets.squeeze(1).cpu().numpy())
        predictionlist.extend(preds.squeeze(1).detach().cpu().numpy())

    # Fix: average over samples actually seen rather than len(loader.dataset).
    # Grouped samplers or drop_last can yield fewer samples than the dataset
    # holds, which previously deflated the reported averages. max(..., 1) also
    # keeps an empty loader from raising ZeroDivisionError.
    denom = max(n_seen, 1)
    return {
        'display_loss': running_display_loss / denom,
        'total_loss': running_total_loss / denom,
        'loss': running_display_loss / denom,
        'target_mean': float(np.mean(targetlist)) if targetlist else 0.0,
        'target_std': float(np.std(targetlist)) if targetlist else 0.0,
        'pred_mean': float(np.mean(predictionlist)) if predictionlist else 0.0,
        'pred_std': float(np.std(predictionlist)) if predictionlist else 0.0,
    }
|
|
|
|
@torch.no_grad()
def evaluate(
    model: nn.Module,
    loader: torch.utils.data.DataLoader,
    device: torch.device,
    ndcg_k: int = 5,
    y_mean: float = 0.0,
    y_std: float = 1.0,
    gain_type: str = 'exp2',
) -> Dict[str, float]:
    """Score `model` over `loader` and report metrics in raw target units.

    Predictions are de-normalized with (y_mean, y_std) before comparison
    against the raw targets. Rank metrics (SRCC, Pearson, NDCG@k) are zeroed
    when they are degenerate: a single sample or near-constant predictions.
    """
    model.eval()

    pred_chunks = []
    label_chunks = []
    pid_chunks = []

    for batch in loader:
        scores_norm = model(
            batch['noise'].to(device),
            batch['prompt_embeds'].to(device),
            batch['prompt_mask'].to(device),
        ).squeeze(1)
        pred_chunks.append(denormalize(scores_norm, y_mean, y_std))
        label_chunks.append(batch['raw_y'].to(device))
        pid_chunks.append(batch['prompt_id'].to(device))

    preds = torch.cat(pred_chunks, dim=0)
    labels = torch.cat(label_chunks, dim=0)
    prompt_ids = torch.cat(pid_chunks, dim=0)

    n_samples = len(preds)
    mae_raw = (preds - labels).abs().mean().item()

    # Correlation/ranking metrics are meaningless for <2 points or a
    # (near-)constant predictor — report 0.0 in that case.
    if n_samples <= 1 or preds.std() <= 1e-9:
        srcc = 0.0
        pearson = 0.0
        ndcg = 0.0
    else:
        srcc = spearman_corrcoef(preds, labels).item()
        pearson = pearson_corrcoef(preds, labels).item()
        ndcg = ndcg_at_k_per_prompt(
            preds, labels, prompt_ids,
            k=ndcg_k, gain_type=gain_type,
        )

    return {
        'n_samples': n_samples,
        'mae_raw': mae_raw,
        'srcc': srcc,
        'pearson': pearson,
        f'ndcg_{ndcg_k}': ndcg,
        'target_mean': labels.mean().item(),
        'target_std': labels.std().item(),
        'pred_mean': preds.mean().item(),
        'pred_std': preds.std().item(),
    }
|
|
|
|
def main():
    """CLI entry point: train a predictor, checkpoint the best validation
    epoch, then evaluate that checkpoint on the test split.

    Fixes vs. the previous version:
      * The test-set evaluation result was computed but silently discarded;
        it is now printed and persisted to ``<exp_dir>/test_metrics.json``.
      * Loading ``best_model.pth`` is guarded so ``--epochs 0`` (no checkpoint
        ever written) no longer crashes.
      * The checkpoint dict is only built when the epoch actually improved.
    """
    parser = argparse.ArgumentParser()

    parser.add_argument('--model_type', type=str, required=True,
                        choices=list(MODEL_DIMS.keys()))
    parser.add_argument('--data_dir', type=str, required=True)

    parser.add_argument('--noise_enc', type=str, default='residualconv', choices=NOISE_ENCODERS)
    parser.add_argument('--text_enc', type=str, default='attnpool', choices=TEXT_ENCODERS)

    parser.add_argument('--target', type=str, default='pick_score', choices=AVAILABLE_TARGETS)
    parser.add_argument('--lr', type=float, default=1e-4)
    parser.add_argument('--weight_decay', type=float, default=1e-8)
    parser.add_argument('--batch_size', type=int, default=256)
    parser.add_argument('--epochs', type=int, default=30)
    parser.add_argument('--loss', type=str, default='mae+srcc',
                        choices=['mae+srcc', 'mae+lambdarank'])
    parser.add_argument('--dropout', type=float, default=0.3)
    parser.add_argument('--exp_name', type=str, default='baseline')
    parser.add_argument('--output_dir', type=str, default='./experiments')
    parser.add_argument('--seed', type=int, default=42)
    parser.add_argument('--num_workers', type=int, default=2)
    parser.add_argument('--max_prompts', type=int, default=-1)

    # Number of prompts grouped per batch; <=0 disables grouped batching.
    parser.add_argument('--k_prompts', type=int, default=2)

    parser.add_argument('--ndcg_k', type=int, default=3)

    parser.add_argument('--primary_metric', type=str, default='srcc',
                        choices=['ndcg', 'srcc'])

    args = parser.parse_args()

    # Resolve model-specific tensor dimensions.
    dims = get_dims(args.model_type)
    spatial_size = dims['spatial_size']
    in_channels = dims['latent_shape'][0]
    embed_dim = dims['embed_dim']
    seq_len = dims['seq_len']

    set_seed(args.seed)
    exp_dir = Path(args.output_dir) / f"{args.exp_name}"
    exp_dir.mkdir(parents=True, exist_ok=True)

    # Persist the full run configuration for reproducibility.
    with open(exp_dir / "config.json", "w") as f:
        json.dump(vars(args), f, indent=4)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    if torch.cuda.is_available():
        torch.set_float32_matmul_precision('high')

    use_grouped = args.k_prompts > 0

    train_loader, val_loader, test_loader, stats = prep_dataloaders(
        data_dir=args.data_dir,
        model_type=args.model_type,
        target=args.target,
        split_by='prompt',
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        seed=args.seed,
        k_prompts_per_batch=args.k_prompts,
        max_prompts=args.max_prompts,
    )
    y_mean, y_std = stats['y_mean'], stats['y_std']

    model = get_model(
        noise_enc=args.noise_enc,
        text_enc=args.text_enc,
        dropout=args.dropout,
        num_heads=1,
        spatial_size=spatial_size,
        in_channels=in_channels,
        embed_dim=embed_dim,
        seq_len=seq_len,
        pos_encoding='sinusoidal',
    ).to(device)

    optimizer = optim.AdamW(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)

    if args.loss == 'mae+srcc':
        criterion = MAESRCCLoss(srcc_weight=1.0, regularization_strength=1e-2)
    elif args.loss == 'mae+lambdarank':
        criterion = MAELambdaRankLoss(lambdarank_weight=1.0, sigma=1.0, gain_type='exp2')
    else:
        raise ValueError(f"Unknown loss: {args.loss}")

    # Both supported primary metrics (ndcg, srcc) are higher-is-better; the
    # 'mae' comparison keeps the code ready for a lower-is-better metric.
    primary_higher_better = (args.primary_metric != 'mae')
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='max' if primary_higher_better else 'min', factor=0.5, patience=5
    )

    best_primary_value = float('-inf') if primary_higher_better else float('inf')
    ndcg_key = f'ndcg_{args.ndcg_k}'
    best_path = exp_dir / "best_model.pth"

    for epoch in range(args.epochs):
        train_one_epoch(
            model, train_loader, criterion, optimizer, device,
            epoch=epoch,
            loss_type=args.loss,
            use_grouped=use_grouped,
        )
        val_metrics = evaluate(
            model, val_loader, device, args.ndcg_k,
            y_mean=y_mean, y_std=y_std,
            gain_type='exp2',
        )

        if args.primary_metric == 'ndcg':
            current_primary = val_metrics[ndcg_key]
        else:  # 'srcc' (argparse choices guarantee one of the two)
            current_primary = val_metrics['srcc']

        print(f"Epoch {epoch+1}/{args.epochs} SRCC={val_metrics['srcc']:.4f} NDCG@{args.ndcg_k}={val_metrics[ndcg_key]:.4f} MAE={val_metrics['mae_raw']:.4f}")

        scheduler.step(current_primary)

        improved = (primary_higher_better and current_primary > best_primary_value) or \
                   (not primary_higher_better and current_primary < best_primary_value)

        if improved:
            best_primary_value = current_primary
            # Weights are stored in fp16 to halve checkpoint size; they are
            # cast back to fp32 on load below.
            checkpoint = {
                'model_state_dict': {k: v.half() for k, v in model.state_dict().items()},
                'model_config': {
                    'noise_enc': args.noise_enc,
                    'text_enc': args.text_enc,
                    'dropout': args.dropout,
                    'num_heads': 1,
                    'model_type': args.model_type,
                    'spatial_size': spatial_size,
                    'in_channels': in_channels,
                    'embed_dim': embed_dim,
                    'seq_len': seq_len,
                    'pos_encoding': 'sinusoidal',
                },
                'normalization': {
                    'target': args.target,
                    'y_mean': y_mean,
                    'y_std': y_std,
                },
            }
            torch.save(checkpoint, best_path)

    # Reload the best checkpoint if one was written (guards --epochs 0).
    # NOTE: weights_only=False is required to unpickle the nested config
    # dicts; the file is one we just wrote ourselves, so this is safe.
    if best_path.exists():
        checkpoint = torch.load(best_path, weights_only=False)
        state_dict = {k: v.float() for k, v in checkpoint['model_state_dict'].items()}
        model.load_state_dict(state_dict)

    # Fix: test metrics were previously computed and thrown away.
    test_metrics = evaluate(
        model, test_loader, device, args.ndcg_k,
        y_mean=y_mean, y_std=y_std,
        gain_type='exp2',
    )
    print(f"Test SRCC={test_metrics['srcc']:.4f} NDCG@{args.ndcg_k}={test_metrics[ndcg_key]:.4f} MAE={test_metrics['mae_raw']:.4f}")
    with open(exp_dir / "test_metrics.json", "w") as f:
        json.dump(test_metrics, f, indent=4)
|
|
|
|
# Script entry point: run the full train/validate/test pipeline when
# executed directly (not on import).
if __name__ == "__main__":
    main()
|
|