| """ |
| Training script for MARS: Multi-scale Adaptive Recurrence with State compression |
| |
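| Trains either MARS or the SASRec baseline (selected with --model) so the two can be compared. |
| Uses the MovieLens-1M dataset by default (avg 164 interactions/user — well suited to long-sequence testing); |
| synthetic data and Amazon reviews are also available via --dataset. |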
| |
| Usage: |
| python train.py --model mars --max_seq_len 512 --epochs 50 |
| python train.py --model sasrec --max_seq_len 200 --epochs 50 |
| """ |
|
|
| import os |
| import sys |
| import time |
| import json |
| import argparse |
| import random |
| import numpy as np |
| import torch |
| import torch.nn as nn |
| from torch.optim import AdamW |
| from torch.optim.lr_scheduler import CosineAnnealingLR |
|
|
| from model import MARS, SASRecBaseline |
| from data import ( |
| load_movielens_1m, |
| generate_synthetic_data, |
| ReindexedData, |
| create_dataloaders, |
| save_data_config, |
| ) |
| from evaluate import evaluate_model, compute_metrics_full |
|
|
|
|
def set_seed(seed: int):
    """Seed the Python, NumPy and PyTorch random number generators.

    The same ``seed`` is applied to each generator in turn. When CUDA is
    available, the generators of all visible CUDA devices are seeded as well.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed_all(seed)
|
|
|
|
def count_parameters(model):
    """Return how many scalar values live in ``model``'s trainable parameters.

    Parameters with ``requires_grad=False`` (frozen) are not counted.
    """
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total
|
|
|
|
def train_epoch(model, train_loader, optimizer, device, epoch, log_interval=50):
    """Run one optimisation pass over ``train_loader``.

    Each batch is moved to ``device`` entry by entry and passed to
    ``model(batch)``, whose return value is treated as a scalar loss tensor.
    The loss is back-propagated, the global gradient norm is clipped to 1.0,
    and ``optimizer.step()`` is applied.

    Args:
        model: module whose forward takes the batch dict and returns the loss.
        train_loader: iterable yielding dict-like batches whose values have
            a ``.to(device)`` method (presumably tensors — TODO confirm
            against ``data.create_dataloaders``).
        optimizer: optimiser over ``model``'s parameters.
        device: device every batch entry is moved to.
        epoch: epoch number; used only in progress messages.
        log_interval: print the running mean loss every this many batches.
            Values <= 0 disable the periodic progress line.

    Returns:
        ``(avg_loss, epoch_time)``: the mean per-batch loss, and the
        wall-clock seconds spent in this call. ``avg_loss`` is NaN when the
        loader yields no batches.
    """
    model.train()
    total_loss = 0.0
    num_batches = 0
    start_time = time.time()

    # Only the progress message needs the loader length, and not every
    # iterable (e.g. an iterable-style DataLoader) defines __len__.
    try:
        total_batches = len(train_loader)
    except TypeError:
        total_batches = '?'

    for batch_idx, batch in enumerate(train_loader):
        batch = {k: v.to(device) for k, v in batch.items()}

        optimizer.zero_grad()
        loss = model(batch)
        loss.backward()

        # Cap the global gradient L2 norm at 1.0 to guard against exploding gradients.
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

        optimizer.step()

        total_loss += loss.item()
        num_batches += 1

        # The log_interval > 0 guard avoids a modulo-by-zero.
        if log_interval > 0 and (batch_idx + 1) % log_interval == 0:
            avg_loss = total_loss / num_batches
            elapsed = time.time() - start_time
            print(f" Epoch {epoch} | Batch {batch_idx+1}/{total_batches} | "
                  f"Loss: {avg_loss:.4f} | Time: {elapsed:.1f}s")

    # An empty loader would otherwise raise ZeroDivisionError here.
    avg_loss = total_loss / num_batches if num_batches else float('nan')
    epoch_time = time.time() - start_time
    return avg_loss, epoch_time
|
|
|
|
def _parse_args():
    """Define the command-line interface and return the parsed namespace."""
    parser = argparse.ArgumentParser(description='MARS Training')
    parser.add_argument('--model', type=str, default='mars', choices=['mars', 'sasrec'])
    parser.add_argument('--dataset', type=str, default='ml-1m',
                        choices=['ml-1m', 'synthetic', 'amazon'])
    parser.add_argument('--amazon_category', type=str, default='Movies_and_TV')
    parser.add_argument('--embed_dim', type=int, default=64)
    parser.add_argument('--max_seq_len', type=int, default=512)
    parser.add_argument('--short_term_len', type=int, default=50)
    parser.add_argument('--num_memory_tokens', type=int, default=8)
    parser.add_argument('--num_tadn_layers', type=int, default=3)
    parser.add_argument('--num_attn_layers', type=int, default=2)
    parser.add_argument('--num_heads', type=int, default=2)
    parser.add_argument('--state_dim', type=int, default=64)
    parser.add_argument('--dropout', type=float, default=0.1)
    parser.add_argument('--batch_size', type=int, default=128)
    parser.add_argument('--lr', type=float, default=1e-3)
    parser.add_argument('--weight_decay', type=float, default=0.01)
    parser.add_argument('--epochs', type=int, default=50)
    parser.add_argument('--num_negatives', type=int, default=4)
    parser.add_argument('--seed', type=int, default=42)
    parser.add_argument('--eval_interval', type=int, default=5)
    parser.add_argument('--save_dir', type=str, default='./checkpoints')
    parser.add_argument('--device', type=str, default='auto')
    parser.add_argument('--push_to_hub', action='store_true')
    parser.add_argument('--hub_model_id', type=str, default='')
    return parser.parse_args()


def _print_header(title):
    """Print ``title`` framed by two 60-character '=' rules."""
    print(f"\n{'='*60}")
    print(title)
    print(f"{'='*60}")


def _resolve_device(device_arg):
    """Map ``--device`` to a torch.device; ``'auto'`` picks CUDA when available."""
    if device_arg == 'auto':
        return torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    return torch.device(device_arg)


def _init_trackio(args):
    """Best-effort trackio setup.

    Returns the imported ``trackio`` module, or None if the import or
    ``trackio.init`` fails.
    """
    try:
        import trackio
        run_name = f"MARS-{args.model}-{args.dataset}-{args.max_seq_len}"
        trackio.init(
            name=run_name,
            project="mars-seqrec",
        )
    except Exception as e:  # tracking is optional; never abort training over it
        print(f"Trackio not available: {e}")
        return None
    print(f"Trackio initialized: {run_name}")
    return trackio


def _load_sequences(args):
    """Load interaction sequences for ``args.dataset``.

    Falls back to default synthetic data when the loader returns nothing.
    """
    if args.dataset == 'ml-1m':
        sequences = load_movielens_1m(min_interactions=5)
    elif args.dataset == 'synthetic':
        sequences = generate_synthetic_data(
            num_users=10000, num_items=5000,
            min_seq_len=50, max_seq_len=1000
        )
    elif args.dataset == 'amazon':
        from data import load_amazon_reviews
        sequences = load_amazon_reviews(
            category=args.amazon_category,
            min_interactions=20,
            max_users=50000
        )
    else:
        # argparse ``choices`` makes this unreachable from the CLI.
        raise ValueError(f"Unknown dataset: {args.dataset}")

    if not sequences:
        print("No data loaded! Using synthetic data as fallback.")
        sequences = generate_synthetic_data()
    return sequences


def _build_model(args, num_items):
    """Instantiate the model selected by ``args.model``; it is not moved to a device."""
    if args.model == 'mars':
        return MARS(
            num_items=num_items,
            embed_dim=args.embed_dim,
            max_seq_len=args.max_seq_len,
            short_term_len=args.short_term_len,
            num_memory_tokens=args.num_memory_tokens,
            num_tadn_layers=args.num_tadn_layers,
            num_attn_layers=args.num_attn_layers,
            num_heads=args.num_heads,
            state_dim=args.state_dim,
            dropout=args.dropout,
        )

    # The SASRec baseline is always built with at most 200 positions.
    sasrec_len = min(args.max_seq_len, 200)
    if sasrec_len < args.max_seq_len:
        # NOTE(review): the dataloaders still use args.max_seq_len; confirm
        # SASRecBaseline truncates longer inputs to its own max_seq_len.
        print(f"Note: SASRec max_seq_len capped at {sasrec_len} "
              f"(data uses max_seq_len={args.max_seq_len}).")
    return SASRecBaseline(
        num_items=num_items,
        embed_dim=args.embed_dim,
        max_seq_len=sasrec_len,
        num_heads=args.num_heads,
        num_layers=args.num_attn_layers,
        dropout=args.dropout,
    )


def _push_to_hub(args):
    """Best-effort upload of ``args.save_dir`` to the HF Hub repo ``args.hub_model_id``."""
    print(f"\nPushing to HF Hub: {args.hub_model_id}")
    try:
        from huggingface_hub import HfApi, upload_folder
        api = HfApi()
        api.create_repo(args.hub_model_id, exist_ok=True)
        upload_folder(
            folder_path=args.save_dir,
            repo_id=args.hub_model_id,
            commit_message=f"MARS training - {args.model} on {args.dataset}"
        )
        print(f"✓ Pushed to https://huggingface.co/{args.hub_model_id}")
    except Exception as e:  # a failed upload must not discard a finished run
        print(f"Failed to push: {e}")


def main():
    """Train MARS or SASRec, keep the best validation checkpoint, and test it.

    The best checkpoint is chosen by validation HR@10. Files written to
    ``--save_dir``:

    * ``data_config.json`` (via ``save_data_config``)
    * ``config.json``
    * ``best_model.pt``
    * ``results.json``

    With ``--push_to_hub`` and ``--hub_model_id``, the folder is then
    uploaded to the Hugging Face Hub.
    """
    args = _parse_args()
    set_seed(args.seed)

    device = _resolve_device(args.device)
    print(f"Using device: {device}")

    # Experiment tracking is optional; ``tracker`` is None when unavailable.
    tracker = _init_trackio(args)

    # ---- Data ----
    _print_header(f"Loading dataset: {args.dataset}")
    sequences = _load_sequences(args)

    data = ReindexedData(sequences, max_seq_len=args.max_seq_len)
    train_loader, val_loader, test_loader = create_dataloaders(
        data, max_seq_len=args.max_seq_len,
        batch_size=args.batch_size,
        num_negatives=args.num_negatives,
    )

    os.makedirs(args.save_dir, exist_ok=True)
    save_data_config(data, os.path.join(args.save_dir, 'data_config.json'))

    # ---- Model ----
    _print_header(f"Creating model: {args.model.upper()}")
    model = _build_model(args, data.num_items).to(device)
    num_params = count_parameters(model)
    print(f"Model parameters: {num_params:,}")
    print(f"Max sequence length: {args.max_seq_len}")

    optimizer = AdamW(
        model.parameters(),
        lr=args.lr,
        weight_decay=args.weight_decay,
    )
    # Stepped once per epoch: cosine decay from lr down to 1% of lr.
    scheduler = CosineAnnealingLR(optimizer, T_max=args.epochs, eta_min=args.lr * 0.01)

    config = {
        'model': args.model,
        'dataset': args.dataset,
        'num_items': data.num_items,
        'embed_dim': args.embed_dim,
        'max_seq_len': args.max_seq_len,
        'short_term_len': args.short_term_len,
        'num_memory_tokens': args.num_memory_tokens,
        'num_tadn_layers': args.num_tadn_layers,
        'num_attn_layers': args.num_attn_layers,
        'num_heads': args.num_heads,
        'state_dim': args.state_dim,
        'dropout': args.dropout,
        'batch_size': args.batch_size,
        'lr': args.lr,
        'weight_decay': args.weight_decay,
        'epochs': args.epochs,
        'num_negatives': args.num_negatives,
        'num_params': num_params,
    }

    with open(os.path.join(args.save_dir, 'config.json'), 'w') as f:
        json.dump(config, f, indent=2)

    if tracker is not None:
        tracker.log(config)

    # ---- Training ----
    _print_header(f"Starting training for {args.epochs} epochs")

    best_path = os.path.join(args.save_dir, 'best_model.pt')
    eval_ks = [5, 10, 20, 50]
    best_val_hr10 = 0
    best_epoch = 0  # 0 means "this run has not saved a checkpoint yet"
    results_history = []

    for epoch in range(1, args.epochs + 1):
        train_loss, epoch_time = train_epoch(
            model, train_loader, optimizer, device, epoch
        )
        scheduler.step()
        current_lr = scheduler.get_last_lr()[0]

        print(f"\nEpoch {epoch}/{args.epochs} | Loss: {train_loss:.4f} | "
              f"LR: {current_lr:.6f} | Time: {epoch_time:.1f}s")

        if tracker is not None:
            tracker.log({
                "train/loss": train_loss,
                "train/lr": current_lr,
                "train/epoch_time": epoch_time,
                "epoch": epoch,
            })

        # Evaluate every ``eval_interval`` epochs (when positive) and always
        # on the final epoch; the > 0 check avoids a modulo-by-zero.
        periodic = args.eval_interval > 0 and epoch % args.eval_interval == 0
        if not (periodic or epoch == args.epochs):
            continue

        print(f"\nEvaluating at epoch {epoch}...")
        metrics = evaluate_model(
            model, val_loader, data.num_items, device,
            ks=eval_ks
        )

        print(" Val Results:")
        for k, v in metrics.items():
            print(f" {k}: {v:.4f}")

        if tracker is not None:
            tracker.log({f"val/{k}": v for k, v in metrics.items()})
            tracker.log({"epoch": epoch})

        hr10 = metrics.get('HR@10', 0)
        # Always keep the first evaluated model. Otherwise an HR@10 that
        # never exceeds 0 leaves no checkpoint from this run: the final load
        # would fail, or pick up a stale best_model.pt from an older run.
        if best_epoch == 0 or hr10 > best_val_hr10:
            best_val_hr10 = hr10
            best_epoch = epoch

            checkpoint = {
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'config': config,
                'metrics': metrics,
            }
            torch.save(checkpoint, best_path)
            print(f" ✓ New best model! HR@10={hr10:.4f}")

        results_history.append({
            'epoch': epoch,
            'train_loss': train_loss,
            **metrics
        })

    # ---- Test ----
    _print_header(f"Final Test Evaluation (best epoch: {best_epoch})")

    if best_epoch > 0:
        # The file was written by this run, so full unpickling
        # (weights_only=False) is acceptable. map_location pins the restored
        # tensors to ``device``.
        checkpoint = torch.load(best_path, map_location=device, weights_only=False)
        model.load_state_dict(checkpoint['model_state_dict'])
    else:
        print("No checkpoint was saved in this run (no validation epochs); "
              "evaluating the current model weights.")

    test_metrics = evaluate_model(
        model, test_loader, data.num_items, device,
        ks=eval_ks
    )

    print("\nTest Results:")
    for k, v in test_metrics.items():
        print(f" {k}: {v:.4f}")

    if tracker is not None:
        tracker.log({f"test/{k}": v for k, v in test_metrics.items()})
        tracker.finish()  # flush and close the run

    final_results = {
        'model': args.model,
        'dataset': args.dataset,
        'best_epoch': best_epoch,
        'best_val_hr10': best_val_hr10,
        'test_metrics': test_metrics,
        'config': config,
        'history': results_history,
    }

    with open(os.path.join(args.save_dir, 'results.json'), 'w') as f:
        json.dump(final_results, f, indent=2)

    if args.push_to_hub:
        if args.hub_model_id:
            _push_to_hub(args)
        else:
            print("\n--push_to_hub was set but --hub_model_id is empty; skipping upload.")

    print(f"\n{'='*60}")
    print("Training complete!")
    print(f"Best Val HR@10: {best_val_hr10:.4f} (epoch {best_epoch})")
    print(f"Test HR@10: {test_metrics.get('HR@10', 0):.4f}")
    print(f"{'='*60}")
|
|
|
|
if __name__ == "__main__":
    # Run training only when executed as a script, not when imported.
    main()
|
|