""" Training script for MARS: Multi-scale Adaptive Recurrence with State compression Trains both MARS and SASRec baseline for comparison. Uses MovieLens-1M dataset (avg 164 interactions/user — ideal for long-sequence testing). Usage: python train.py --model mars --max_seq_len 512 --epochs 50 python train.py --model sasrec --max_seq_len 200 --epochs 50 """ import os import sys import time import json import argparse import random import numpy as np import torch import torch.nn as nn from torch.optim import AdamW from torch.optim.lr_scheduler import CosineAnnealingLR from model import MARS, SASRecBaseline from data import ( load_movielens_1m, generate_synthetic_data, ReindexedData, create_dataloaders, save_data_config, ) from evaluate import evaluate_model, compute_metrics_full def set_seed(seed: int): random.seed(seed) np.random.seed(seed) torch.manual_seed(seed) if torch.cuda.is_available(): torch.cuda.manual_seed_all(seed) def count_parameters(model): return sum(p.numel() for p in model.parameters() if p.requires_grad) def train_epoch(model, train_loader, optimizer, device, epoch, log_interval=50): model.train() total_loss = 0 num_batches = 0 start_time = time.time() for batch_idx, batch in enumerate(train_loader): batch = {k: v.to(device) for k, v in batch.items()} optimizer.zero_grad() loss = model(batch) loss.backward() # Gradient clipping torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0) optimizer.step() total_loss += loss.item() num_batches += 1 if (batch_idx + 1) % log_interval == 0: avg_loss = total_loss / num_batches elapsed = time.time() - start_time print(f" Epoch {epoch} | Batch {batch_idx+1}/{len(train_loader)} | " f"Loss: {avg_loss:.4f} | Time: {elapsed:.1f}s") avg_loss = total_loss / num_batches epoch_time = time.time() - start_time return avg_loss, epoch_time def main(): parser = argparse.ArgumentParser(description='MARS Training') parser.add_argument('--model', type=str, default='mars', choices=['mars', 'sasrec']) parser.add_argument('--dataset', type=str, default='ml-1m', choices=['ml-1m', 'synthetic', 'amazon']) parser.add_argument('--amazon_category', type=str, default='Movies_and_TV') parser.add_argument('--embed_dim', type=int, default=64) parser.add_argument('--max_seq_len', type=int, default=512) parser.add_argument('--short_term_len', type=int, default=50) parser.add_argument('--num_memory_tokens', type=int, default=8) parser.add_argument('--num_tadn_layers', type=int, default=3) parser.add_argument('--num_attn_layers', type=int, default=2) parser.add_argument('--num_heads', type=int, default=2) parser.add_argument('--state_dim', type=int, default=64) parser.add_argument('--dropout', type=float, default=0.1) parser.add_argument('--batch_size', type=int, default=128) parser.add_argument('--lr', type=float, default=1e-3) parser.add_argument('--weight_decay', type=float, default=0.01) parser.add_argument('--epochs', type=int, default=50) parser.add_argument('--num_negatives', type=int, default=4) parser.add_argument('--seed', type=int, default=42) parser.add_argument('--eval_interval', type=int, default=5) parser.add_argument('--save_dir', type=str, default='./checkpoints') parser.add_argument('--device', type=str, default='auto') parser.add_argument('--push_to_hub', action='store_true') parser.add_argument('--hub_model_id', type=str, default='') args = parser.parse_args() set_seed(args.seed) # Device if args.device == 'auto': device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') else: device = torch.device(args.device) print(f"Using device: {device}") # Initialize tracking try: import trackio run_name = f"MARS-{args.model}-{args.dataset}-{args.max_seq_len}" trackio.init( name=run_name, project="mars-seqrec", ) use_trackio = True print(f"Trackio initialized: {run_name}") except Exception as e: print(f"Trackio not available: {e}") use_trackio = False # Load data print(f"\n{'='*60}") print(f"Loading dataset: {args.dataset}") print(f"{'='*60}") if args.dataset == 'ml-1m': sequences = load_movielens_1m(min_interactions=5) elif args.dataset == 'synthetic': sequences = generate_synthetic_data( num_users=10000, num_items=5000, min_seq_len=50, max_seq_len=1000 ) elif args.dataset == 'amazon': from data import load_amazon_reviews sequences = load_amazon_reviews( category=args.amazon_category, min_interactions=20, max_users=50000 ) if not sequences: print("No data loaded! Using synthetic data as fallback.") sequences = generate_synthetic_data() # Process data data = ReindexedData(sequences, max_seq_len=args.max_seq_len) train_loader, val_loader, test_loader = create_dataloaders( data, max_seq_len=args.max_seq_len, batch_size=args.batch_size, num_negatives=args.num_negatives, ) # Save data config os.makedirs(args.save_dir, exist_ok=True) data_config = save_data_config(data, os.path.join(args.save_dir, 'data_config.json')) # Create model print(f"\n{'='*60}") print(f"Creating model: {args.model.upper()}") print(f"{'='*60}") if args.model == 'mars': model = MARS( num_items=data.num_items, embed_dim=args.embed_dim, max_seq_len=args.max_seq_len, short_term_len=args.short_term_len, num_memory_tokens=args.num_memory_tokens, num_tadn_layers=args.num_tadn_layers, num_attn_layers=args.num_attn_layers, num_heads=args.num_heads, state_dim=args.state_dim, dropout=args.dropout, ) else: model = SASRecBaseline( num_items=data.num_items, embed_dim=args.embed_dim, max_seq_len=min(args.max_seq_len, 200), # SASRec limited to 200 num_heads=args.num_heads, num_layers=args.num_attn_layers, dropout=args.dropout, ) model = model.to(device) num_params = count_parameters(model) print(f"Model parameters: {num_params:,}") print(f"Max sequence length: {args.max_seq_len}") # Optimizer optimizer = AdamW( model.parameters(), lr=args.lr, weight_decay=args.weight_decay, ) scheduler = CosineAnnealingLR(optimizer, T_max=args.epochs, eta_min=args.lr * 0.01) # Training config config = { 'model': args.model, 'dataset': args.dataset, 'num_items': data.num_items, 'embed_dim': args.embed_dim, 'max_seq_len': args.max_seq_len, 'short_term_len': args.short_term_len, 'num_memory_tokens': args.num_memory_tokens, 'num_tadn_layers': args.num_tadn_layers, 'num_attn_layers': args.num_attn_layers, 'num_heads': args.num_heads, 'state_dim': args.state_dim, 'dropout': args.dropout, 'batch_size': args.batch_size, 'lr': args.lr, 'weight_decay': args.weight_decay, 'epochs': args.epochs, 'num_negatives': args.num_negatives, 'num_params': num_params, } with open(os.path.join(args.save_dir, 'config.json'), 'w') as f: json.dump(config, f, indent=2) if use_trackio: trackio.log(config) # Training loop print(f"\n{'='*60}") print(f"Starting training for {args.epochs} epochs") print(f"{'='*60}") best_val_hr10 = 0 best_epoch = 0 results_history = [] for epoch in range(1, args.epochs + 1): # Train train_loss, epoch_time = train_epoch( model, train_loader, optimizer, device, epoch ) scheduler.step() current_lr = scheduler.get_last_lr()[0] print(f"\nEpoch {epoch}/{args.epochs} | Loss: {train_loss:.4f} | " f"LR: {current_lr:.6f} | Time: {epoch_time:.1f}s") if use_trackio: trackio.log({ "train/loss": train_loss, "train/lr": current_lr, "train/epoch_time": epoch_time, "epoch": epoch, }) # Evaluate if epoch % args.eval_interval == 0 or epoch == args.epochs: print(f"\nEvaluating at epoch {epoch}...") metrics = evaluate_model( model, val_loader, data.num_items, device, ks=[5, 10, 20, 50] ) print(f" Val Results:") for k, v in metrics.items(): print(f" {k}: {v:.4f}") if use_trackio: trackio.log({f"val/{k}": v for k, v in metrics.items()}) trackio.log({"epoch": epoch}) # Save best model hr10 = metrics.get('HR@10', 0) if hr10 > best_val_hr10: best_val_hr10 = hr10 best_epoch = epoch checkpoint = { 'epoch': epoch, 'model_state_dict': model.state_dict(), 'optimizer_state_dict': optimizer.state_dict(), 'config': config, 'metrics': metrics, } torch.save(checkpoint, os.path.join(args.save_dir, 'best_model.pt')) print(f" ✓ New best model! HR@10={hr10:.4f}") results_history.append({ 'epoch': epoch, 'train_loss': train_loss, **metrics }) # Final test evaluation with best model print(f"\n{'='*60}") print(f"Final Test Evaluation (best epoch: {best_epoch})") print(f"{'='*60}") checkpoint = torch.load(os.path.join(args.save_dir, 'best_model.pt'), weights_only=False) model.load_state_dict(checkpoint['model_state_dict']) test_metrics = evaluate_model( model, test_loader, data.num_items, device, ks=[5, 10, 20, 50] ) print(f"\nTest Results:") for k, v in test_metrics.items(): print(f" {k}: {v:.4f}") if use_trackio: trackio.log({f"test/{k}": v for k, v in test_metrics.items()}) # Save final results final_results = { 'model': args.model, 'dataset': args.dataset, 'best_epoch': best_epoch, 'best_val_hr10': best_val_hr10, 'test_metrics': test_metrics, 'config': config, 'history': results_history, } with open(os.path.join(args.save_dir, 'results.json'), 'w') as f: json.dump(final_results, f, indent=2) # Push to Hub if args.push_to_hub and args.hub_model_id: print(f"\nPushing to HF Hub: {args.hub_model_id}") try: from huggingface_hub import HfApi, upload_folder api = HfApi() api.create_repo(args.hub_model_id, exist_ok=True) upload_folder( folder_path=args.save_dir, repo_id=args.hub_model_id, commit_message=f"MARS training - {args.model} on {args.dataset}" ) print(f"✓ Pushed to https://huggingface.co/{args.hub_model_id}") except Exception as e: print(f"Failed to push: {e}") print(f"\n{'='*60}") print(f"Training complete!") print(f"Best Val HR@10: {best_val_hr10:.4f} (epoch {best_epoch})") print(f"Test HR@10: {test_metrics.get('HR@10', 0):.4f}") print(f"{'='*60}") if __name__ == '__main__': main()