# PAINE — predictor/training/train.py
# Training entry point for the noise-quality predictor.
import argparse
import json
import random
from pathlib import Path
from typing import Dict, Tuple
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from predictor.training.dataloader import prep_dataloaders, denormalize, AVAILABLE_TARGETS
from predictor.models import get_model, NOISE_ENCODERS, TEXT_ENCODERS
from predictor.configs.model_dims import MODEL_DIMS, get_dims
from predictor.training.losses import (
ndcg_at_k,
ndcg_at_k_per_prompt,
spearman_corrcoef,
pearson_corrcoef,
MAESRCCLoss,
MAELambdaRankLoss,
LambdaRankLoss,
)
def set_seed(seed: int):
    """Seed every RNG used in training (python, numpy, torch, CUDA) for reproducibility."""
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
def train_one_epoch(
    model: nn.Module,
    loader: torch.utils.data.DataLoader,
    criterion: nn.Module,
    optimizer: optim.Optimizer,
    device: torch.device,
    epoch: int = 0,
    loss_type: str = 'mae+srcc',
    use_grouped: bool = False,
) -> Dict[str, float]:
    """Run one optimization pass over ``loader`` and return epoch statistics.

    Args:
        model: predictor taking (noise, prompt_embeds, prompt_mask).
        loader: training DataLoader; batches are dicts with 'noise',
            'prompt_embeds', 'prompt_mask', 'y', and 'prompt_id' tensors.
        criterion: loss module. ``MAELambdaRankLoss`` instances drive their
            own gradient computation via ``criterion.backward``.
        optimizer: optimizer stepping ``model``'s parameters.
        device: device batches are moved to.
        epoch: current epoch index (kept for interface compatibility).
        loss_type: loss name (kept for interface compatibility).
        use_grouped: when True, pass per-prompt group ids to the criterion.

    Returns:
        Dict of sample-weighted mean losses and prediction/target statistics.
    """
    model.train()
    display_loss_sum = 0.0
    total_loss_sum = 0.0
    target_values = []
    pred_values = []
    is_lambdarank = isinstance(criterion, MAELambdaRankLoss)
    for batch in loader:
        noise = batch['noise'].to(device)
        prompt_embeds = batch['prompt_embeds'].to(device)
        prompt_mask = batch['prompt_mask'].to(device)
        optimizer.zero_grad()
        preds = model(noise, prompt_embeds, prompt_mask)
        targets = batch['y'].to(device).unsqueeze(1)
        group_ids = batch['prompt_id'].to(device) if use_grouped else None
        if is_lambdarank:
            # LambdaRank computes gradients itself instead of loss.backward().
            loss = criterion(preds, targets, group_ids=group_ids)
            criterion.backward(preds, targets, loss, group_ids=group_ids)
        elif group_ids is not None:
            loss = criterion(preds, targets, group_ids=group_ids)
            loss.backward()
        else:
            loss = criterion(preds, targets)
            loss.backward()
        step_loss = loss.item()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()
        batch_size = noise.size(0)
        display_loss_sum += step_loss * batch_size
        total_loss_sum += step_loss * batch_size
        target_values.extend(targets.squeeze(1).cpu().numpy())
        pred_values.extend(preds.squeeze(1).detach().cpu().numpy())
    n_samples = len(loader.dataset)
    return {
        'display_loss': display_loss_sum / n_samples,
        'total_loss': total_loss_sum / n_samples,
        'loss': display_loss_sum / n_samples,
        'target_mean': float(np.mean(target_values)),
        'target_std': float(np.std(target_values)),
        'pred_mean': float(np.mean(pred_values)),
        'pred_std': float(np.std(pred_values)),
    }
@torch.no_grad()
def evaluate(
    model: nn.Module,
    loader: torch.utils.data.DataLoader,
    device: torch.device,
    ndcg_k: int = 5,
    y_mean: float = 0.0,
    y_std: float = 1.0,
    gain_type: str = 'exp2',
) -> Dict[str, float]:
    """Score ``model`` on ``loader`` in raw (denormalized) target units.

    Args:
        model: predictor taking (noise, prompt_embeds, prompt_mask).
        loader: evaluation DataLoader whose batches provide 'raw_y' (targets
            in original units) and 'prompt_id' alongside the model inputs.
        device: device batches are moved to.
        ndcg_k: cutoff for the per-prompt NDCG metric.
        y_mean: mean used to denormalize model outputs.
        y_std: std used to denormalize model outputs.
        gain_type: gain function forwarded to the NDCG metric.

    Returns:
        Dict with sample count, raw-unit MAE, SRCC, Pearson, NDCG@k, and
        target/prediction moments. Ranking metrics fall back to 0.0 when
        predictions are degenerate (single sample or near-zero variance).
    """
    model.eval()
    pred_chunks, target_chunks, prompt_chunks = [], [], []
    for batch in loader:
        noise = batch['noise'].to(device)
        prompt_embeds = batch['prompt_embeds'].to(device)
        prompt_mask = batch['prompt_mask'].to(device)
        # Predictions come out normalized; map them back to raw units.
        preds_norm = model(noise, prompt_embeds, prompt_mask).squeeze(1)
        pred_chunks.append(denormalize(preds_norm, y_mean, y_std))
        target_chunks.append(batch['raw_y'].to(device))
        prompt_chunks.append(batch['prompt_id'].to(device))
    preds_raw = torch.cat(pred_chunks, dim=0)
    targets_raw = torch.cat(target_chunks, dim=0)
    prompt_ids = torch.cat(prompt_chunks, dim=0)
    n_samples = len(preds_raw)
    mae_raw = (preds_raw - targets_raw).abs().mean().item()
    if n_samples > 1 and preds_raw.std() > 1e-9:
        srcc = spearman_corrcoef(preds_raw, targets_raw).item()
        pearson = pearson_corrcoef(preds_raw, targets_raw).item()
        ndcg = ndcg_at_k_per_prompt(
            preds_raw, targets_raw, prompt_ids,
            k=ndcg_k, gain_type=gain_type,
        )
    else:
        srcc = pearson = ndcg = 0.0
    return {
        'n_samples': n_samples,
        'mae_raw': mae_raw,
        'srcc': srcc,
        'pearson': pearson,
        f'ndcg_{ndcg_k}': ndcg,
        'target_mean': targets_raw.mean().item(),
        'target_std': targets_raw.std().item(),
        'pred_mean': preds_raw.mean().item(),
        'pred_std': preds_raw.std().item(),
    }
def main():
    """CLI entry point: train the predictor, keep the best-on-val checkpoint,
    then evaluate that checkpoint on the test split.

    Fixes over the original:
      * the test-set metrics were computed but discarded — they are now
        printed and persisted to ``<exp_dir>/test_metrics.json``;
      * the checkpoint dict (with a half-precision copy of the full state
        dict) was rebuilt every epoch — it is now built only on improvement.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--model_type', type=str, required=True,
                        choices=list(MODEL_DIMS.keys()))
    parser.add_argument('--data_dir', type=str, required=True)
    parser.add_argument('--noise_enc', type=str, default='residualconv', choices=NOISE_ENCODERS)
    parser.add_argument('--text_enc', type=str, default='attnpool', choices=TEXT_ENCODERS)
    parser.add_argument('--target', type=str, default='pick_score', choices=AVAILABLE_TARGETS)
    parser.add_argument('--lr', type=float, default=1e-4)
    parser.add_argument('--weight_decay', type=float, default=1e-8)
    parser.add_argument('--batch_size', type=int, default=256)
    parser.add_argument('--epochs', type=int, default=30)
    parser.add_argument('--loss', type=str, default='mae+srcc',
                        choices=['mae+srcc', 'mae+lambdarank'])
    parser.add_argument('--dropout', type=float, default=0.3)
    parser.add_argument('--exp_name', type=str, default='baseline')
    parser.add_argument('--output_dir', type=str, default='./experiments')
    parser.add_argument('--seed', type=int, default=42)
    parser.add_argument('--num_workers', type=int, default=2)
    parser.add_argument('--max_prompts', type=int, default=-1)
    parser.add_argument('--k_prompts', type=int, default=2)
    parser.add_argument('--ndcg_k', type=int, default=3)
    parser.add_argument('--primary_metric', type=str, default='srcc',
                        choices=['ndcg', 'srcc'])
    args = parser.parse_args()

    # Model-family geometry (latent shape, text-embedding dims).
    dims = get_dims(args.model_type)
    spatial_size = dims['spatial_size']
    in_channels = dims['latent_shape'][0]
    embed_dim = dims['embed_dim']
    seq_len = dims['seq_len']

    set_seed(args.seed)
    exp_dir = Path(args.output_dir) / f"{args.exp_name}"
    exp_dir.mkdir(parents=True, exist_ok=True)
    # Persist the full run configuration for reproducibility.
    with open(exp_dir / "config.json", "w") as f:
        json.dump(vars(args), f, indent=4)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    if torch.cuda.is_available():
        torch.set_float32_matmul_precision('high')

    use_grouped = args.k_prompts > 0
    train_loader, val_loader, test_loader, stats = prep_dataloaders(
        data_dir=args.data_dir,
        model_type=args.model_type,
        target=args.target,
        split_by='prompt',
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        seed=args.seed,
        k_prompts_per_batch=args.k_prompts,
        max_prompts=args.max_prompts,
    )
    y_mean, y_std = stats['y_mean'], stats['y_std']

    model = get_model(
        noise_enc=args.noise_enc,
        text_enc=args.text_enc,
        dropout=args.dropout,
        num_heads=1,
        spatial_size=spatial_size,
        in_channels=in_channels,
        embed_dim=embed_dim,
        seq_len=seq_len,
        pos_encoding='sinusoidal',
    ).to(device)
    optimizer = optim.AdamW(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)

    if args.loss == 'mae+srcc':
        criterion = MAESRCCLoss(srcc_weight=1.0, regularization_strength=1e-2)
    elif args.loss == 'mae+lambdarank':
        criterion = MAELambdaRankLoss(lambdarank_weight=1.0, sigma=1.0, gain_type='exp2')
    else:
        raise ValueError(f"Unknown loss: {args.loss}")

    # Both selectable metrics (srcc, ndcg) are higher-is-better.
    primary_higher_better = (args.primary_metric != 'mae')
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='max' if primary_higher_better else 'min', factor=0.5, patience=5
    )
    best_primary_value = float('-inf') if primary_higher_better else float('inf')
    ndcg_key = f'ndcg_{args.ndcg_k}'

    for epoch in range(args.epochs):
        train_one_epoch(
            model, train_loader, criterion, optimizer, device,
            epoch=epoch,
            loss_type=args.loss,
            use_grouped=use_grouped,
        )
        val_metrics = evaluate(
            model, val_loader, device, args.ndcg_k,
            y_mean=y_mean, y_std=y_std,
            gain_type='exp2',
        )
        if args.primary_metric == 'ndcg':
            current_primary = val_metrics[ndcg_key]
        elif args.primary_metric == 'srcc':
            current_primary = val_metrics['srcc']
        print(f"Epoch {epoch+1}/{args.epochs} SRCC={val_metrics['srcc']:.4f} NDCG@{args.ndcg_k}={val_metrics[ndcg_key]:.4f} MAE={val_metrics['mae_raw']:.4f}")
        scheduler.step(current_primary)

        improved = (primary_higher_better and current_primary > best_primary_value) or \
                   (not primary_higher_better and current_primary < best_primary_value)
        if improved:
            best_primary_value = current_primary
            # Build the checkpoint only when it will actually be saved;
            # weights are stored in fp16 to halve checkpoint size.
            checkpoint = {
                'model_state_dict': {k: v.half() for k, v in model.state_dict().items()},
                'model_config': {
                    'noise_enc': args.noise_enc,
                    'text_enc': args.text_enc,
                    'dropout': args.dropout,
                    'num_heads': 1,
                    'model_type': args.model_type,
                    'spatial_size': spatial_size,
                    'in_channels': in_channels,
                    'embed_dim': embed_dim,
                    'seq_len': seq_len,
                    'pos_encoding': 'sinusoidal',
                },
                'normalization': {
                    'target': args.target,
                    'y_mean': y_mean,
                    'y_std': y_std,
                },
            }
            torch.save(checkpoint, exp_dir / "best_model.pth")

    # Reload the best checkpoint (fp16 on disk -> fp32 for inference).
    checkpoint = torch.load(exp_dir / "best_model.pth", weights_only=False)
    state_dict = {k: v.float() for k, v in checkpoint['model_state_dict'].items()}
    model.load_state_dict(state_dict)
    test_metrics = evaluate(
        model, test_loader, device, args.ndcg_k,
        y_mean=y_mean, y_std=y_std,
        gain_type='exp2',
    )
    print(f"Test: SRCC={test_metrics['srcc']:.4f} NDCG@{args.ndcg_k}={test_metrics[ndcg_key]:.4f} MAE={test_metrics['mae_raw']:.4f}")
    with open(exp_dir / "test_metrics.json", "w") as f:
        json.dump(test_metrics, f, indent=4)


if __name__ == "__main__":
    main()