"""
MARS v2 Training Script — Improved architecture with linear attention.

Trains a SASRec baseline and the MARSv2 model on MovieLens-1M, keeps the
weights from the epoch with the best validation HR@10, reports full-ranking
test metrics for both models, and writes per-model checkpoints plus a JSON
summary under ./checkpoints/.
"""
import os
import sys
import time
import json
import random
import math

import numpy as np
import torch
import torch.nn as nn
from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingLR

random.seed(42)
np.random.seed(42)
torch.manual_seed(42)

device = torch.device('cpu')
print(f"Device: {device}")

from model_v2 import MARSv2, SASRecBaseline
from data import load_movielens_1m, ReindexedData, create_dataloaders
from evaluate import evaluate_model, print_comparison

# Experiment tracking is optional: if trackio is unavailable or fails to
# initialize, training continues with console logging only.
try:
    import trackio
    trackio.init(name="MARSv2-SeqRec-ML1M", project="mars-seqrec")
    use_trackio = True
    print("Trackio initialized")
except Exception as e:
    use_trackio = False
    print(f"Trackio disabled: {e}")

# Load data
print("\nLoading MovieLens-1M...")
sequences = load_movielens_1m(min_interactions=5)
seq_lens = [len(v['item_ids']) for v in sequences.values()]
print(f"{len(sequences)} users, seq mean={np.mean(seq_lens):.1f}, max={np.max(seq_lens)}")


def _count_params(model, trainable_only=False):
    """Return the number of parameters in *model*; only requires_grad ones if *trainable_only*."""
    return sum(p.numel() for p in model.parameters() if p.requires_grad or not trainable_only)


def train_model(model_name, model, config, device):
    """Train *model*, restore its best-validation weights, evaluate on test, and checkpoint it.

    Args:
        model_name: Short identifier used in logs, trackio keys and the
            checkpoint directory ``./checkpoints/<model_name>``.
        model: Module whose ``forward(batch)`` returns a loss tensor; it is
            moved to *device* before the optimizer is built.
        config: Dict with keys ``max_seq_len``, ``batch_size``, ``lr``,
            ``weight_decay``, ``epochs``, ``num_negatives``, ``eval_interval``.
        device: ``torch.device`` on which training runs.

    Returns:
        ``(test_metrics, total_param_count)`` where ``test_metrics`` is the dict
        returned by ``evaluate_model`` on the test split.
    """
    print(f"\n{'='*60}\nTraining: {model_name.upper()}\nParams: {_count_params(model, trainable_only=True):,}\n{'='*60}")

    # Batches are moved to `device` below, so the weights must live there too;
    # this must happen before the optimizer captures the parameters.
    model.to(device)

    data = ReindexedData(sequences, max_seq_len=config['max_seq_len'])
    train_loader, val_loader, test_loader = create_dataloaders(
        data, max_seq_len=config['max_seq_len'], batch_size=config['batch_size'],
        num_negatives=config['num_negatives'], num_workers=2)
    optimizer = AdamW(model.parameters(), lr=config['lr'], weight_decay=config['weight_decay'])

    # Warmup + cosine schedule: linear warmup, then cosine decay to 1% of the base LR.
    total_steps = config['epochs'] * len(train_loader)
    warmup_steps = min(500, total_steps // 10)
    # Clamp so very short runs (total_steps == warmup_steps) cannot divide by zero.
    decay_steps = max(1, total_steps - warmup_steps)

    def lr_lambda(step):
        if step < warmup_steps:
            return step / warmup_steps
        progress = min(1.0, (step - warmup_steps) / decay_steps)
        return 0.01 + 0.99 * 0.5 * (1 + math.cos(math.pi * progress))

    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)

    best_hr10, best_epoch, best_state = 0, 0, None
    for epoch in range(1, config['epochs'] + 1):
        model.train()
        total_loss, n = 0, 0
        t0 = time.time()
        for batch in train_loader:
            batch = {k: v.to(device) for k, v in batch.items()}
            optimizer.zero_grad()
            loss = model(batch)
            # Skip the update entirely: back-propagating a NaN/Inf loss would
            # corrupt every parameter it touches.
            if not torch.isfinite(loss).all():
                print(f"WARNING: non-finite loss at epoch {epoch}!")
                continue
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            scheduler.step()
            total_loss += loss.item()
            n += 1

        avg_loss = total_loss / max(n, 1)
        ep_time = time.time() - t0
        print(f"Epoch {epoch:3d}/{config['epochs']} | Loss: {avg_loss:.4f} | Time: {ep_time:.1f}s")
        if use_trackio:
            trackio.log({f"{model_name}/train_loss": avg_loss, "epoch": epoch})

        if epoch % config['eval_interval'] == 0 or epoch == config['epochs']:
            metrics = evaluate_model(model, val_loader, data.num_items, device,
                                     ks=[5, 10, 20, 50], full_ranking=True)
            print(f" Val | HR@10={metrics['HR@10']:.4f} NDCG@10={metrics['NDCG@10']:.4f} MRR@10={metrics['MRR@10']:.4f}")
            if use_trackio:
                trackio.log({f"{model_name}/val_{k}": v for k, v in metrics.items() if k != 'eval_time'})
            if metrics['HR@10'] > best_hr10:
                best_hr10 = metrics['HR@10']
                best_epoch = epoch
                # Snapshot on CPU so later training steps cannot mutate it.
                best_state = {k: v.cpu().clone() for k, v in model.state_dict().items()}
                print(f" ✓ New best! HR@10={best_hr10:.4f}")

    if best_state is not None:
        model.load_state_dict(best_state)

    test_metrics = evaluate_model(model, test_loader, data.num_items, device,
                                  ks=[5, 10, 20, 50], full_ranking=True)
    print(f"\nTest ({model_name}, best ep {best_epoch}):")
    for k, v in sorted(test_metrics.items()):
        if k != 'eval_time':
            print(f" {k}: {v:.4f}")

    save_dir = f'./checkpoints/{model_name}'
    os.makedirs(save_dir, exist_ok=True)
    torch.save({'model_state_dict': best_state if best_state is not None else model.state_dict(),
                'config': config,
                'test_metrics': test_metrics,
                'best_epoch': best_epoch,
                'num_items': data.num_items},
               os.path.join(save_dir, 'best_model.pt'))
    return test_metrics, _count_params(model)


# Configs
SASREC_CFG = {'max_seq_len': 128, 'batch_size': 128, 'lr': 1e-3, 'weight_decay': 0.0,
              'epochs': 25, 'num_negatives': 4, 'eval_interval': 5}
MARS_CFG = {'max_seq_len': 128, 'batch_size': 64, 'lr': 5e-4, 'weight_decay': 0.01,
            'epochs': 25, 'num_negatives': 4, 'eval_interval': 5}

# Precompute data for num_items (the same item count is passed to both models).
data_tmp = ReindexedData(sequences, max_seq_len=128)
num_items = data_tmp.num_items

# Models
sasrec = SASRecBaseline(num_items=num_items, embed_dim=64, max_seq_len=128,
                        num_heads=2, num_layers=2, dropout=0.1)
marsv2 = MARSv2(num_items=num_items, embed_dim=64, max_seq_len=128, short_term_len=30,
                num_memory_tokens=8, num_long_layers=3, num_short_layers=2,
                num_heads=2, dropout=0.1)

# Train
sasrec_results, sasrec_params = train_model('sasrec', sasrec, SASREC_CFG, device)
mars_results, mars_params = train_model('marsv2', marsv2, MARS_CFG, device)

# Compare
print_comparison(mars_results, sasrec_results, ks=[5, 10, 20, 50])

# Save
final = {
    'marsv2': {'metrics': mars_results, 'config': MARS_CFG, 'params': mars_params},
    'sasrec': {'metrics': sasrec_results, 'config': SASREC_CFG, 'params': sasrec_params},
    'dataset': 'MovieLens-1M',
}
os.makedirs('./checkpoints', exist_ok=True)
with open('./checkpoints/final_results.json', 'w') as f:
    json.dump(final, f, indent=2, default=str)

# Push to Hub
# Best-effort upload: any failure is reported and the script still finishes.
try:
    from huggingface_hub import HfApi, upload_folder
    import shutil

    hub_id = 'CyberDancer/MARS-SeqRec'
    api = HfApi()
    api.create_repo(hub_id, exist_ok=True)

    # Copy the source files (when present under /app) next to the checkpoints
    # so the uploaded folder is self-contained.
    for src_name in ['model.py', 'model_v2.py', 'data.py', 'evaluate.py',
                     'train.py', 'train_gpu.py', 'train_v2.py']:
        if os.path.exists(f'/app/{src_name}'):
            shutil.copy(f'/app/{src_name}', f'./checkpoints/{src_name}')

    readme = f"""# MARS: Multi-scale Adaptive Recurrence with State compression

An innovative method for **super long sequence modeling** in sequential recommendation.

## Architecture

```
Input: User interaction sequence + timestamps
  │
  ├── Long-term Branch (Temporal-Gated Linear Attention, O(n))
  │     │
  │   [Compressive Memory] → fixed-size memory tokens
  │     │
  ├── Short-term Branch (Causal Self-Attention, last K items)
  │
  └── Adaptive Fusion Gate → User Embedding → Next Item Prediction
```

## Key Innovations

1. **Temporal-Gated Linear Attention** — O(n) complexity via kernel trick (ELU+1 feature map) with learned temporal decay weighting per attention head
2. **Compressive Memory Tokens** — Cross-attention bottleneck compresses full history into M fixed tokens
3. **Dual-Branch with Adaptive Fusion** — Per-user gating balances long-term preferences and short-term intent
4. **Multi-Scale Temporal Encoding** — Log-scaled time deltas + periodic components for daily/weekly patterns

## Results on MovieLens-1M (Full Ranking, 3706 items)

| Model | Params | HR@5 | HR@10 | HR@20 | NDCG@10 | MRR@10 |
|-------|--------|------|-------|-------|---------|--------|
| SASRec | {sasrec_params:,} | {sasrec_results.get('HR@5',0):.4f} | {sasrec_results.get('HR@10',0):.4f} | {sasrec_results.get('HR@20',0):.4f} | {sasrec_results.get('NDCG@10',0):.4f} | {sasrec_results.get('MRR@10',0):.4f} |
| **MARS v2** | {mars_params:,} | {mars_results.get('HR@5',0):.4f} | {mars_results.get('HR@10',0):.4f} | {mars_results.get('HR@20',0):.4f} | {mars_results.get('NDCG@10',0):.4f} | {mars_results.get('MRR@10',0):.4f} |

## Core Method: Temporal-Gated Linear Attention

Standard linear attention: `Attn(Q,K,V) = φ(Q)(φ(K)^T V) / φ(Q)φ(K)^T 1`

Our enhancement adds temporal gating:

```
K_gated = K ⊙ σ(W_decay · log(1 + Δt/3600))
```

where `Δt` is the inter-action time gap and `W_decay` is learned per attention head. This gives O(n) complexity while explicitly modeling temporal dynamics — recent interactions get higher attention weight, with the decay rate learned per head.

## Based On

- **HyTRec** (2602.18283) — Temporal-aware dual-branch architecture
- **Rec2PM** (2602.11605) — Compressive memory as information bottleneck
- **Linear Transformers** (Katharopoulos et al.) — Kernel-based linear attention
- **SASRec** (1808.09781) — Self-attentive sequential recommendation baseline

## Usage

```python
from model_v2 import MARSv2

model = MARSv2(
    num_items=10000,
    embed_dim=64,
    max_seq_len=2048,  # Handles very long sequences at O(n) cost
    short_term_len=50,
    num_memory_tokens=8,
    num_long_layers=3,
    num_short_layers=2,
)
```
"""
    # The README contains non-ASCII characters (box drawing, Greek letters,
    # arrows); write UTF-8 explicitly rather than relying on the platform's
    # locale encoding, which can raise UnicodeEncodeError (e.g. cp1252).
    with open('./checkpoints/README.md', 'w', encoding='utf-8') as f:
        f.write(readme)

    upload_folder(folder_path='./checkpoints', repo_id=hub_id,
                  commit_message="MARS v2: Temporal-Gated Linear Attention for SeqRec")
    print(f"\n✓ Pushed to https://huggingface.co/{hub_id}")
except Exception as e:
    print(f"Hub push: {e}")

print("\nDone!")