# MARS-SeqRec / train.py
# Uploaded by: CyberDancer
# MARS: Multi-scale Adaptive Recurrence with State compression
# Commit: 2319f81 (verified)
"""
Training script for MARS: Multi-scale Adaptive Recurrence with State compression
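Trains either MARS or the SASRec baseline (selected with --model); run it once per model to compare them.
Uses the MovieLens-1M dataset by default (avg. 164 interactions/user, which makes it well suited to
long-sequence testing); synthetic data and Amazon Reviews are also available via --dataset.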
Usage:
python train.py --model mars --max_seq_len 512 --epochs 50
python train.py --model sasrec --max_seq_len 200 --epochs 50
"""
import os
import sys
import time
import json
import argparse
import random
import numpy as np
import torch
import torch.nn as nn
from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingLR
from model import MARS, SASRecBaseline
from data import (
load_movielens_1m,
generate_synthetic_data,
ReindexedData,
create_dataloaders,
save_data_config,
)
from evaluate import evaluate_model, compute_metrics_full
def set_seed(seed: int):
    """Seed every random number generator used during training.

    Seeds Python's ``random`` module, NumPy's global generator and torch's
    CPU generator, plus every CUDA device when CUDA is available, so that
    repeated runs with the same seed are reproducible.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
def count_parameters(model):
    """Return how many trainable scalar parameters ``model`` holds.

    Parameters whose ``requires_grad`` is False (frozen) are not counted.
    """
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total
def train_epoch(model, train_loader, optimizer, device, epoch, log_interval=50,
                max_grad_norm=1.0):
    """Run one optimisation pass over ``train_loader``.

    Args:
        model: Module whose ``forward(batch)`` returns a single-element loss
            tensor (it is back-propagated and read with ``.item()``).
        train_loader: Iterable of dict batches. Values exposing ``.to`` (e.g.
            tensors) are moved to ``device``; any other values are passed
            through unchanged.
        optimizer: Optimizer over ``model``'s parameters.
        device: Device that batch values are moved to.
        epoch: Epoch number; only used in progress log lines.
        log_interval: Print the running mean loss every this many batches.
            A value <= 0 disables per-batch logging.
        max_grad_norm: Max total gradient norm used for clipping (default 1.0,
            the value that was previously hard-coded). ``None`` disables
            clipping.

    Returns:
        ``(avg_loss, epoch_time)``: the mean per-batch loss (``nan`` if the
        loader yielded no batches) and the wall-clock seconds the epoch took.
    """
    model.train()
    total_loss = 0.0
    num_batches = 0
    start_time = time.time()
    for batch_idx, batch in enumerate(train_loader):
        batch = {
            k: v.to(device) if hasattr(v, 'to') else v
            for k, v in batch.items()
        }
        optimizer.zero_grad()
        loss = model(batch)
        loss.backward()
        # Gradient clipping keeps occasional large updates from destabilising training.
        if max_grad_norm is not None:
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=max_grad_norm)
        optimizer.step()
        total_loss += loss.item()
        num_batches += 1
        # Guarding log_interval > 0 avoids a modulo-by-zero when logging is disabled.
        if log_interval > 0 and (batch_idx + 1) % log_interval == 0:
            avg_loss = total_loss / num_batches
            elapsed = time.time() - start_time
            print(f" Epoch {epoch} | Batch {batch_idx+1}/{len(train_loader)} | "
                  f"Loss: {avg_loss:.4f} | Time: {elapsed:.1f}s")
    epoch_time = time.time() - start_time
    # An empty loader has no defined mean loss; report nan instead of crashing.
    avg_loss = total_loss / num_batches if num_batches else float('nan')
    return avg_loss, epoch_time
# Upper bound on the positional window the SASRec baseline is built with
# (the model construction caps it at 200: "SASRec limited to 200").
_SASREC_MAX_SEQ_LEN = 200


def _parse_args():
    """Define the training CLI and return the parsed ``argparse.Namespace``."""
    parser = argparse.ArgumentParser(description='MARS Training')
    parser.add_argument('--model', type=str, default='mars', choices=['mars', 'sasrec'])
    parser.add_argument('--dataset', type=str, default='ml-1m',
                        choices=['ml-1m', 'synthetic', 'amazon'])
    parser.add_argument('--amazon_category', type=str, default='Movies_and_TV')
    parser.add_argument('--embed_dim', type=int, default=64)
    parser.add_argument('--max_seq_len', type=int, default=512)
    parser.add_argument('--short_term_len', type=int, default=50)
    parser.add_argument('--num_memory_tokens', type=int, default=8)
    parser.add_argument('--num_tadn_layers', type=int, default=3)
    parser.add_argument('--num_attn_layers', type=int, default=2)
    parser.add_argument('--num_heads', type=int, default=2)
    parser.add_argument('--state_dim', type=int, default=64)
    parser.add_argument('--dropout', type=float, default=0.1)
    parser.add_argument('--batch_size', type=int, default=128)
    parser.add_argument('--lr', type=float, default=1e-3)
    parser.add_argument('--weight_decay', type=float, default=0.01)
    parser.add_argument('--epochs', type=int, default=50)
    parser.add_argument('--num_negatives', type=int, default=4)
    parser.add_argument('--seed', type=int, default=42)
    parser.add_argument('--eval_interval', type=int, default=5,
                        help='Validate every N epochs (<= 0: only after the final epoch)')
    parser.add_argument('--save_dir', type=str, default='./checkpoints')
    parser.add_argument('--device', type=str, default='auto')
    parser.add_argument('--push_to_hub', action='store_true')
    parser.add_argument('--hub_model_id', type=str, default='')
    return parser.parse_args()


def _resolve_device(spec):
    """Map a ``--device`` value to a ``torch.device``; 'auto' prefers CUDA."""
    if spec == 'auto':
        return torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    return torch.device(spec)


def _init_trackio(run_name):
    """Start a trackio run and return the module, or None if it is unavailable.

    Tracking is best-effort: any import or init failure is reported and
    training continues without it.
    """
    try:
        import trackio
        trackio.init(
            name=run_name,
            project="mars-seqrec",
        )
    except Exception as e:  # best-effort: never block training on tracking
        print(f"Trackio not available: {e}")
        return None
    print(f"Trackio initialized: {run_name}")
    return trackio


def _load_sequences(args):
    """Load raw user sequences for ``args.dataset``.

    Falls back to default synthetic data when the loader returns nothing.
    """
    if args.dataset == 'ml-1m':
        sequences = load_movielens_1m(min_interactions=5)
    elif args.dataset == 'synthetic':
        sequences = generate_synthetic_data(
            num_users=10000, num_items=5000,
            min_seq_len=50, max_seq_len=1000,
        )
    elif args.dataset == 'amazon':
        from data import load_amazon_reviews  # imported only when this dataset is used
        sequences = load_amazon_reviews(
            category=args.amazon_category,
            min_interactions=20,
            max_users=50000,
        )
    else:
        raise ValueError(f"Unknown dataset: {args.dataset!r}")
    if not sequences:
        print("No data loaded! Using synthetic data as fallback.")
        sequences = generate_synthetic_data()
    return sequences


def _build_model(args, num_items, max_seq_len):
    """Instantiate MARS or the SASRec baseline from the CLI hyper-parameters."""
    if args.model == 'mars':
        return MARS(
            num_items=num_items,
            embed_dim=args.embed_dim,
            max_seq_len=max_seq_len,
            short_term_len=args.short_term_len,
            num_memory_tokens=args.num_memory_tokens,
            num_tadn_layers=args.num_tadn_layers,
            num_attn_layers=args.num_attn_layers,
            num_heads=args.num_heads,
            state_dim=args.state_dim,
            dropout=args.dropout,
        )
    return SASRecBaseline(
        num_items=num_items,
        embed_dim=args.embed_dim,
        max_seq_len=max_seq_len,
        num_heads=args.num_heads,
        num_layers=args.num_attn_layers,
        dropout=args.dropout,
    )


def _push_to_hub(args):
    """Upload ``args.save_dir`` to the Hub repo ``args.hub_model_id`` (best-effort)."""
    print(f"\nPushing to HF Hub: {args.hub_model_id}")
    try:
        from huggingface_hub import HfApi, upload_folder
        api = HfApi()
        api.create_repo(args.hub_model_id, exist_ok=True)
        upload_folder(
            folder_path=args.save_dir,
            repo_id=args.hub_model_id,
            commit_message=f"MARS training - {args.model} on {args.dataset}"
        )
        print(f"✓ Pushed to https://huggingface.co/{args.hub_model_id}")
    except Exception as e:  # upload failure must not lose the local results
        print(f"Failed to push: {e}")


def main():
    """Train one model (MARS or SASRec) and report its test metrics.

    The checkpoint with the best validation HR@10 is kept and later used for
    the final test evaluation.

    Writes ``data_config.json``, ``config.json``, ``best_model.pt`` and
    ``results.json`` into ``--save_dir``. It can optionally upload that
    directory to the Hugging Face Hub.
    """
    args = _parse_args()
    set_seed(args.seed)

    device = _resolve_device(args.device)
    print(f"Using device: {device}")

    tracker = _init_trackio(f"MARS-{args.model}-{args.dataset}-{args.max_seq_len}")

    # Load data
    print(f"\n{'='*60}")
    print(f"Loading dataset: {args.dataset}")
    print(f"{'='*60}")
    sequences = _load_sequences(args)

    # The data pipeline and the model must use the same window length.
    # SASRec's model is capped at _SASREC_MAX_SEQ_LEN, so cap the loaders too.
    # Previously they used the uncapped --max_seq_len (default 512).
    # NOTE(review): if SASRecBaseline already truncates longer inputs itself,
    # this cap is harmless.
    max_seq_len = args.max_seq_len
    if args.model == 'sasrec':
        max_seq_len = min(max_seq_len, _SASREC_MAX_SEQ_LEN)

    data = ReindexedData(sequences, max_seq_len=max_seq_len)
    train_loader, val_loader, test_loader = create_dataloaders(
        data, max_seq_len=max_seq_len,
        batch_size=args.batch_size,
        num_negatives=args.num_negatives,
    )

    os.makedirs(args.save_dir, exist_ok=True)
    save_data_config(data, os.path.join(args.save_dir, 'data_config.json'))

    # Create model
    print(f"\n{'='*60}")
    print(f"Creating model: {args.model.upper()}")
    print(f"{'='*60}")
    model = _build_model(args, data.num_items, max_seq_len).to(device)
    num_params = count_parameters(model)
    print(f"Model parameters: {num_params:,}")
    print(f"Max sequence length: {max_seq_len}")

    optimizer = AdamW(
        model.parameters(),
        lr=args.lr,
        weight_decay=args.weight_decay,
    )
    scheduler = CosineAnnealingLR(optimizer, T_max=args.epochs, eta_min=args.lr * 0.01)

    config = {
        'model': args.model,
        'dataset': args.dataset,
        'num_items': data.num_items,
        'embed_dim': args.embed_dim,
        'max_seq_len': max_seq_len,  # window actually used by the data and the model
        'short_term_len': args.short_term_len,
        'num_memory_tokens': args.num_memory_tokens,
        'num_tadn_layers': args.num_tadn_layers,
        'num_attn_layers': args.num_attn_layers,
        'num_heads': args.num_heads,
        'state_dim': args.state_dim,
        'dropout': args.dropout,
        'batch_size': args.batch_size,
        'lr': args.lr,
        'weight_decay': args.weight_decay,
        'epochs': args.epochs,
        'num_negatives': args.num_negatives,
        'num_params': num_params,
    }
    with open(os.path.join(args.save_dir, 'config.json'), 'w') as f:
        json.dump(config, f, indent=2)
    if tracker is not None:
        tracker.log(config)

    # Training loop
    print(f"\n{'='*60}")
    print(f"Starting training for {args.epochs} epochs")
    print(f"{'='*60}")
    eval_ks = [5, 10, 20, 50]
    best_path = os.path.join(args.save_dir, 'best_model.pt')
    best_val_hr10 = 0
    best_epoch = 0
    saved_best = False  # True once THIS run has written best_path
    results_history = []
    for epoch in range(1, args.epochs + 1):
        train_loss, epoch_time = train_epoch(
            model, train_loader, optimizer, device, epoch
        )
        scheduler.step()
        current_lr = scheduler.get_last_lr()[0]
        print(f"\nEpoch {epoch}/{args.epochs} | Loss: {train_loss:.4f} | "
              f"LR: {current_lr:.6f} | Time: {epoch_time:.1f}s")
        if tracker is not None:
            tracker.log({
                "train/loss": train_loss,
                "train/lr": current_lr,
                "train/epoch_time": epoch_time,
                "epoch": epoch,
            })

        # eval_interval <= 0 would divide by zero; treat it as "final epoch only".
        periodic = args.eval_interval > 0 and epoch % args.eval_interval == 0
        if not (periodic or epoch == args.epochs):
            continue

        print(f"\nEvaluating at epoch {epoch}...")
        metrics = evaluate_model(
            model, val_loader, data.num_items, device,
            ks=eval_ks
        )
        print(f" Val Results:")
        for k, v in metrics.items():
            print(f" {k}: {v:.4f}")
        if tracker is not None:
            tracker.log({f"val/{k}": v for k, v in metrics.items()})
            tracker.log({"epoch": epoch})

        # Always save on the first evaluation. Otherwise, if HR@10 never rises
        # above 0, the final test would load a stale checkpoint from an earlier
        # run in save_dir, or crash because no checkpoint exists.
        hr10 = metrics.get('HR@10', 0)
        if not saved_best or hr10 > best_val_hr10:
            best_val_hr10 = hr10
            best_epoch = epoch
            checkpoint = {
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'config': config,
                'metrics': metrics,
            }
            torch.save(checkpoint, best_path)
            saved_best = True
            print(f" ✓ New best model! HR@10={hr10:.4f}")
        results_history.append({
            'epoch': epoch,
            'train_loss': train_loss,
            **metrics
        })

    # Final test evaluation with best model
    print(f"\n{'='*60}")
    print(f"Final Test Evaluation (best epoch: {best_epoch})")
    print(f"{'='*60}")
    if saved_best:
        # weights_only=False: the checkpoint was written by this run and holds
        # non-tensor entries (config, metrics). Never point this at untrusted files.
        checkpoint = torch.load(best_path, map_location=device, weights_only=False)
        model.load_state_dict(checkpoint['model_state_dict'])
    else:
        print("No checkpoint was saved in this run; evaluating the current model weights.")
    test_metrics = evaluate_model(
        model, test_loader, data.num_items, device,
        ks=eval_ks
    )
    print(f"\nTest Results:")
    for k, v in test_metrics.items():
        print(f" {k}: {v:.4f}")
    if tracker is not None:
        tracker.log({f"test/{k}": v for k, v in test_metrics.items()})

    final_results = {
        'model': args.model,
        'dataset': args.dataset,
        'best_epoch': best_epoch,
        'best_val_hr10': best_val_hr10,
        'test_metrics': test_metrics,
        'config': config,
        'history': results_history,
    }
    with open(os.path.join(args.save_dir, 'results.json'), 'w') as f:
        json.dump(final_results, f, indent=2)

    if args.push_to_hub:
        if args.hub_model_id:
            _push_to_hub(args)
        else:
            print("--push_to_hub was set but --hub_model_id is empty; skipping upload.")

    print(f"\n{'='*60}")
    print(f"Training complete!")
    print(f"Best Val HR@10: {best_val_hr10:.4f} (epoch {best_epoch})")
    print(f"Test HR@10: {test_metrics.get('HR@10', 0):.4f}")
    print(f"{'='*60}")

    # Close the tracking run explicitly (guarded in case the API lacks finish()).
    finish = getattr(tracker, 'finish', None)
    if callable(finish):
        finish()
# Script entry point: run training only when executed directly, not on import.
if __name__ == "__main__":
    main()