"""
MARS v2 Training Script — Improved architecture with linear attention.
"""
|
|
import json
import math
import os
import random
import sys
import time

import numpy as np
import torch
import torch.nn as nn
from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingLR
|
|
# Seed Python's, NumPy's and PyTorch's RNGs with the same value so runs are repeatable.
for _seed_fn in (random.seed, np.random.seed, torch.manual_seed):
    _seed_fn(42)
del _seed_fn

# Training is pinned to the CPU.
device = torch.device('cpu')
print("Device:", device)
|
|
| from model_v2 import MARSv2, SASRecBaseline |
| from data import load_movielens_1m, ReindexedData, create_dataloaders |
| from evaluate import evaluate_model, print_comparison |
|
|
# Optional experiment tracking: trackio is used only if it imports and initialises
# cleanly. This is deliberately best-effort, but the reason for disabling it is now
# reported instead of being silently swallowed.
try:
    import trackio

    trackio.init(name="MARSv2-SeqRec-ML1M", project="mars-seqrec")
except Exception as e:
    use_trackio = False
    print(f"Trackio unavailable, continuing without tracking: {e}")
else:
    use_trackio = True
    print("Trackio initialized")
|
|
| |
print("\nLoading MovieLens-1M...")
sequences = load_movielens_1m(min_interactions=5)

# History length per user (one entry per value in `sequences`); only used for the
# summary line printed below.
seq_lens = [len(user_record['item_ids']) for user_record in sequences.values()]
print(
    f"{len(sequences)} users, "
    f"seq mean={np.mean(seq_lens):.1f}, "
    f"max={np.max(seq_lens)}"
)
|
|
|
|
def _count_params(model, trainable_only=False):
    """Return the number of parameter elements in *model* (only ``requires_grad`` ones if asked)."""
    return sum(p.numel() for p in model.parameters() if p.requires_grad or not trainable_only)


def _warmup_cosine_lambda(total_steps, warmup_steps, min_factor=0.01):
    """Build a ``LambdaLR`` multiplier: linear warmup, then half-cosine decay to *min_factor*.

    The returned callable maps a scheduler step to ``step / warmup_steps`` during
    warmup and afterwards to a value decaying from 1.0 to ``min_factor`` over the
    remaining ``total_steps - warmup_steps`` steps (held at ``min_factor`` beyond that).
    """
    # max(..., 1) prevents a 0/0 when there is no decay phase (e.g. an empty train loader).
    decay_steps = max(total_steps - warmup_steps, 1)

    def lr_lambda(step):
        if step < warmup_steps:
            return step / warmup_steps
        progress = min((step - warmup_steps) / decay_steps, 1.0)
        return min_factor + (1.0 - min_factor) * 0.5 * (1.0 + math.cos(math.pi * progress))

    return lr_lambda


def _train_one_epoch(model, train_loader, optimizer, scheduler, device, epoch):
    """Run one optimisation pass over *train_loader*; return the mean loss of applied steps.

    Batches whose loss is NaN or +/-inf are skipped entirely (no backward, optimizer or
    scheduler step) so a single bad batch cannot corrupt the weights.
    """
    model.train()
    total_loss, n_steps = 0.0, 0
    for batch in train_loader:
        # NOTE(review): assumes every value in the batch dict is a tensor (``.to`` is called on all).
        batch = {k: v.to(device) for k, v in batch.items()}
        optimizer.zero_grad()
        loss = model(batch)  # the model's forward returns the training loss directly

        # isfinite also rejects inf, which an isnan-only check would pass into backward().
        if not torch.isfinite(loss).all():
            print(f"WARNING: non-finite loss at epoch {epoch}!")
            continue

        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        scheduler.step()
        total_loss += loss.item()
        n_steps += 1
    return total_loss / max(n_steps, 1)


def train_model(model_name, model, config, device):
    """Train *model*, keep the best-validation weights, evaluate them on test and save them.

    Args:
        model_name: label used in console output, trackio metric keys and the checkpoint
            path ``./checkpoints/<model_name>/best_model.pt``.
        model: module whose forward takes a batch dict and returns the training loss.
        config: dict with keys ``max_seq_len``, ``batch_size``, ``num_negatives``,
            ``lr``, ``weight_decay``, ``epochs`` and ``eval_interval``.
        device: torch device the model and every batch are moved to.

    Returns:
        ``(test_metrics, n_params)``: the metrics dict returned by ``evaluate_model`` on
        the test split for the best-validation (HR@10) weights, and the total number of
        parameter elements in *model*.
    """
    # Batches are moved to `device` below, so the model must live there too.
    model.to(device)
    print(f"\n{'='*60}\nTraining: {model_name.upper()}\nParams: {_count_params(model, trainable_only=True):,}\n{'='*60}")

    data = ReindexedData(sequences, max_seq_len=config['max_seq_len'])
    train_loader, val_loader, test_loader = create_dataloaders(
        data, max_seq_len=config['max_seq_len'], batch_size=config['batch_size'],
        num_negatives=config['num_negatives'], num_workers=2)

    optimizer = AdamW(model.parameters(), lr=config['lr'], weight_decay=config['weight_decay'])

    # Per-batch schedule: warmup over 10% of all steps (capped at 500), then cosine
    # decay down to 1% of the base learning rate.
    total_steps = config['epochs'] * len(train_loader)
    warmup_steps = min(500, total_steps // 10)
    scheduler = torch.optim.lr_scheduler.LambdaLR(
        optimizer, _warmup_cosine_lambda(total_steps, warmup_steps))

    eval_ks = [5, 10, 20, 50]
    best_hr10, best_epoch, best_state = 0, 0, None

    for epoch in range(1, config['epochs'] + 1):
        t0 = time.time()
        avg_loss = _train_one_epoch(model, train_loader, optimizer, scheduler, device, epoch)
        ep_time = time.time() - t0
        print(f"Epoch {epoch:3d}/{config['epochs']} | Loss: {avg_loss:.4f} | Time: {ep_time:.1f}s")

        if use_trackio:
            trackio.log({f"{model_name}/train_loss": avg_loss, "epoch": epoch})

        # Validate every `eval_interval` epochs and always after the final epoch.
        if epoch % config['eval_interval'] != 0 and epoch != config['epochs']:
            continue

        metrics = evaluate_model(model, val_loader, data.num_items, device, ks=eval_ks, full_ranking=True)
        print(f" Val | HR@10={metrics['HR@10']:.4f} NDCG@10={metrics['NDCG@10']:.4f} MRR@10={metrics['MRR@10']:.4f}")

        if use_trackio:
            val_log = {f"{model_name}/val_{k}": v for k, v in metrics.items() if k != 'eval_time'}
            val_log["epoch"] = epoch  # align validation points with the training-loss series
            trackio.log(val_log)

        if metrics['HR@10'] > best_hr10:
            best_hr10 = metrics['HR@10']
            best_epoch = epoch
            # Detached CPU copy so later optimizer steps cannot alter the snapshot.
            best_state = {k: v.cpu().clone() for k, v in model.state_dict().items()}
            print(f" ✓ New best! HR@10={best_hr10:.4f}")

    # If validation HR@10 never rose above 0, the final-epoch weights are used as-is.
    if best_state is not None:
        model.load_state_dict(best_state)

    test_metrics = evaluate_model(model, test_loader, data.num_items, device, ks=eval_ks, full_ranking=True)
    print(f"\nTest ({model_name}, best ep {best_epoch}):")
    for k, v in sorted(test_metrics.items()):
        if k != 'eval_time':
            print(f" {k}: {v:.4f}")

    save_dir = f'./checkpoints/{model_name}'
    os.makedirs(save_dir, exist_ok=True)
    torch.save(
        {
            'model_state_dict': best_state if best_state is not None else model.state_dict(),
            'config': config,
            'test_metrics': test_metrics,
            'best_epoch': best_epoch,
            'num_items': data.num_items,
        },
        os.path.join(save_dir, 'best_model.pt'),
    )

    return test_metrics, _count_params(model)
|
|
|
|
| |
# Per-model hyper-parameters handed to train_model(); the two runs differ only in
# batch size, learning rate and weight decay.
SASREC_CFG = dict(
    max_seq_len=128,
    batch_size=128,
    lr=1e-3,
    weight_decay=0.0,
    epochs=25,
    num_negatives=4,
    eval_interval=5,
)
MARS_CFG = dict(
    max_seq_len=128,
    batch_size=64,
    lr=5e-4,
    weight_decay=0.01,
    epochs=25,
    num_negatives=4,
    eval_interval=5,
)
|
|
| |
# Build the dataset wrapper once up-front only to read the item count that both model
# constructors require; train_model() builds its own ReindexedData for each run.
data_tmp = ReindexedData(sequences, max_seq_len=128)
num_items = data_tmp.num_items

# Both models use embed_dim=64, max_seq_len=128, 2 heads and dropout 0.1.
sasrec = SASRecBaseline(
    num_items=num_items,
    embed_dim=64,
    max_seq_len=128,
    num_heads=2,
    num_layers=2,
    dropout=0.1,
)
marsv2 = MARSv2(
    num_items=num_items,
    embed_dim=64,
    max_seq_len=128,
    short_term_len=30,
    num_memory_tokens=8,
    num_long_layers=3,
    num_short_layers=2,
    num_heads=2,
    dropout=0.1,
)

# Baseline first, then MARS v2; each call yields (test_metrics, parameter_count).
sasrec_results, sasrec_params = train_model('sasrec', sasrec, SASREC_CFG, device)
mars_results, mars_params = train_model('marsv2', marsv2, MARS_CFG, device)
|
|
| |
print_comparison(mars_results, sasrec_results, ks=[5, 10, 20, 50])

# Summary of both runs (metrics, config, parameter count) written as JSON next to the
# checkpoints; default=str turns any non-JSON-native value into its string form.
final = {
    'marsv2': dict(metrics=mars_results, config=MARS_CFG, params=mars_params),
    'sasrec': dict(metrics=sasrec_results, config=SASREC_CFG, params=sasrec_params),
    'dataset': 'MovieLens-1M',
}
os.makedirs('./checkpoints', exist_ok=True)
with open('./checkpoints/final_results.json', 'w') as f:
    f.write(json.dumps(final, indent=2, default=str))
|
|
| |
# Best-effort publication of ./checkpoints (weights, results, sources and a generated
# model card) to the Hugging Face Hub. Any failure is printed and the script still
# finishes, so a missing package or missing credentials never loses the training results.
try:
    from huggingface_hub import HfApi, upload_folder
    import shutil

    hub_id = 'CyberDancer/MARS-SeqRec'
    api = HfApi()
    api.create_repo(hub_id, exist_ok=True)

    # Copy whichever of these project source files exist under /app into the upload folder.
    for src_name in ['model.py', 'model_v2.py', 'data.py', 'evaluate.py', 'train.py', 'train_gpu.py', 'train_v2.py']:
        if os.path.exists(f'/app/{src_name}'):
            shutil.copy(f'/app/{src_name}', f'./checkpoints/{src_name}')

    readme = f"""# MARS: Multi-scale Adaptive Recurrence with State compression

An innovative method for **super long sequence modeling** in sequential recommendation.

## Architecture

```
Input: User interaction sequence + timestamps
  │
  ├── Long-term Branch (Temporal-Gated Linear Attention, O(n))
  │        ↓
  │   [Compressive Memory] → fixed-size memory tokens
  │        ↓
  ├── Short-term Branch (Causal Self-Attention, last K items)
  │
  └── Adaptive Fusion Gate → User Embedding → Next Item Prediction
```

## Key Innovations

1. **Temporal-Gated Linear Attention** — O(n) complexity via kernel trick (ELU+1 feature map) with learned temporal decay weighting per attention head
2. **Compressive Memory Tokens** — Cross-attention bottleneck compresses full history into M fixed tokens
3. **Dual-Branch with Adaptive Fusion** — Per-user gating balances long-term preferences and short-term intent
4. **Multi-Scale Temporal Encoding** — Log-scaled time deltas + periodic components for daily/weekly patterns

## Results on MovieLens-1M (Full Ranking, 3706 items)

| Model | Params | HR@5 | HR@10 | HR@20 | NDCG@10 | MRR@10 |
|-------|--------|------|-------|-------|---------|--------|
| SASRec | {sasrec_params:,} | {sasrec_results.get('HR@5',0):.4f} | {sasrec_results.get('HR@10',0):.4f} | {sasrec_results.get('HR@20',0):.4f} | {sasrec_results.get('NDCG@10',0):.4f} | {sasrec_results.get('MRR@10',0):.4f} |
| **MARS v2** | {mars_params:,} | {mars_results.get('HR@5',0):.4f} | {mars_results.get('HR@10',0):.4f} | {mars_results.get('HR@20',0):.4f} | {mars_results.get('NDCG@10',0):.4f} | {mars_results.get('MRR@10',0):.4f} |

## Core Method: Temporal-Gated Linear Attention

Standard linear attention: `Attn(Q,K,V) = φ(Q)(φ(K)^T V) / φ(Q)φ(K)^T 1`

Our enhancement adds temporal gating:
```
K_gated = K ⊙ σ(W_decay · log(1 + Δt/3600))
```
where `Δt` is the inter-action time gap and `W_decay` is learned per attention head.

This gives O(n) complexity while explicitly modeling temporal dynamics — recent interactions get higher attention weight, with the decay rate learned per head.

## Based On

- **HyTRec** (2602.18283) — Temporal-aware dual-branch architecture
- **Rec2PM** (2602.11605) — Compressive memory as information bottleneck
- **Linear Transformers** (Katharopoulos et al.) — Kernel-based linear attention
- **SASRec** (1808.09781) — Self-attentive sequential recommendation baseline

## Usage

```python
from model_v2 import MARSv2

model = MARSv2(
    num_items=10000,
    embed_dim=64,
    max_seq_len=2048,  # Handles very long sequences at O(n) cost
    short_term_len=50,
    num_memory_tokens=8,
    num_long_layers=3,
    num_short_layers=2,
)
```
"""

    # The card contains non-ASCII characters (box drawing, arrows, Greek letters), so the
    # encoding is explicit instead of depending on the platform's locale default.
    with open('./checkpoints/README.md', 'w', encoding='utf-8') as f:
        f.write(readme)

    upload_folder(folder_path='./checkpoints', repo_id=hub_id,
                  commit_message="MARS v2: Temporal-Gated Linear Attention for SeqRec")
    print(f"\n✓ Pushed to https://huggingface.co/{hub_id}")
except Exception as e:
    print(f"Hub push: {e}")


print("\nDone!")
|
|