# MARS-SeqRec / train_v2.py
# Committed by CyberDancer in 3989f8c (verified): "MARS v2: Temporal-Gated Linear Attention for SeqRec"
"""
MARS v2 Training Script — improved architecture with linear attention.
"""
import json
import math
import os
import random
import sys
import time

import numpy as np
import torch
import torch.nn as nn
from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingLR
# Seed the Python, NumPy and PyTorch RNGs (in that order) so runs are repeatable.
for _seed_fn in (random.seed, np.random.seed, torch.manual_seed):
    _seed_fn(42)
# Everything in this script runs on the CPU.
device = torch.device('cpu')
print(f"Device: {device}")
from model_v2 import MARSv2, SASRecBaseline
from data import load_movielens_1m, ReindexedData, create_dataloaders
from evaluate import evaluate_model, print_comparison
# Optional experiment tracking: Trackio is used only if it imports and
# initialises cleanly; otherwise training continues without it.
try:
    import trackio
    trackio.init(name="MARSv2-SeqRec-ML1M", project="mars-seqrec")
    use_trackio = True
    print("Trackio initialized")
except Exception as e:
    # Best-effort by design, but report the reason so a missing package or a
    # failed init is not silently ignored.
    use_trackio = False
    print(f"Trackio disabled: {e}")
# Load data: one interaction sequence per user (users with < 5 interactions dropped
# by load_movielens_1m's min_interactions argument).
print("\nLoading MovieLens-1M...")
sequences = load_movielens_1m(min_interactions=5)
if not sequences:
    # np.max on an empty list raises an opaque ValueError; fail with a clear message instead.
    raise RuntimeError("No user sequences were loaded from MovieLens-1M; nothing to train on.")
seq_lens = [len(v['item_ids']) for v in sequences.values()]
print(f"{len(sequences)} users, seq mean={np.mean(seq_lens):.1f}, max={np.max(seq_lens)}")
def _warmup_cosine_lambda(total_steps, warmup_steps, min_ratio=0.01):
    """Build a ``LambdaLR`` multiplier: linear warmup, then cosine decay.

    For ``step < warmup_steps`` the multiplier is ``step / warmup_steps``.
    After that it follows a half cosine from 1.0 down to ``min_ratio`` at
    ``total_steps`` and stays at ``min_ratio`` if stepped further.

    Args:
        total_steps: Total number of scheduler steps planned for the run.
        warmup_steps: Length of the linear warmup (0 disables warmup).
        min_ratio: Multiplier reached at the end of the cosine decay.

    Returns:
        A callable ``step -> float`` suitable for ``LambdaLR``.
    """
    # At least 1 so an empty or all-warmup schedule never divides by zero.
    decay_steps = max(1, total_steps - warmup_steps)

    def lr_lambda(step):
        if step < warmup_steps:
            return step / warmup_steps
        progress = min(1.0, (step - warmup_steps) / decay_steps)
        return min_ratio + (1 - min_ratio) * 0.5 * (1 + math.cos(math.pi * progress))

    return lr_lambda


def train_model(model_name, model, config, device):
    """Train a model, keep its best validation checkpoint, and evaluate it on test.

    Builds data loaders from the module-level ``sequences`` and trains with AdamW.
    The LR schedule is a linear warmup followed by cosine decay, and gradients are
    clipped to norm 1.0. Validation runs every ``config['eval_interval']`` epochs
    and on the final epoch. Afterwards the weights with the best validation HR@10
    are restored and evaluated on the test split. A checkpoint is written to
    ``./checkpoints/<model_name>/best_model.pt``.

    Args:
        model_name: Label used in log lines, Trackio keys and the checkpoint path.
        model: Module whose ``model(batch)`` call returns a scalar training loss.
        config: Dict with keys ``max_seq_len``, ``batch_size``, ``num_negatives``,
            ``lr``, ``weight_decay``, ``epochs`` and ``eval_interval``.
        device: Device the model and every batch are moved to.

    Returns:
        ``(test_metrics, param_count)`` where ``param_count`` counts all
        parameters (the banner printed at start counts trainable ones only).
    """
    # Batches are moved to `device` below, so the weights must live there too.
    model.to(device)
    n_trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"\n{'='*60}\nTraining: {model_name.upper()}\nParams: {n_trainable:,}\n{'='*60}")

    data = ReindexedData(sequences, max_seq_len=config['max_seq_len'])
    train_loader, val_loader, test_loader = create_dataloaders(
        data, max_seq_len=config['max_seq_len'], batch_size=config['batch_size'],
        num_negatives=config['num_negatives'], num_workers=2)

    optimizer = AdamW(model.parameters(), lr=config['lr'], weight_decay=config['weight_decay'])
    # Warmup + cosine schedule, stepped once per successful optimizer update.
    total_steps = config['epochs'] * len(train_loader)
    warmup_steps = min(500, total_steps // 10)
    scheduler = torch.optim.lr_scheduler.LambdaLR(
        optimizer, _warmup_cosine_lambda(total_steps, warmup_steps))

    best_hr10, best_epoch, best_state = 0, 0, None
    for epoch in range(1, config['epochs'] + 1):
        model.train()
        total_loss, n = 0, 0
        t0 = time.time()
        for batch in train_loader:
            batch = {k: v.to(device) for k, v in batch.items()}
            optimizer.zero_grad()
            loss = model(batch)
            # Skip the update for NaN *and* +/-inf losses; either would corrupt the weights.
            if not torch.isfinite(loss):
                print(f"WARNING: non-finite loss ({loss.item()}) at epoch {epoch}!")
                continue
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            scheduler.step()
            total_loss += loss.item()
            n += 1
        avg_loss = total_loss / max(n, 1)
        ep_time = time.time() - t0
        print(f"Epoch {epoch:3d}/{config['epochs']} | Loss: {avg_loss:.4f} | Time: {ep_time:.1f}s")
        if use_trackio:
            trackio.log({f"{model_name}/train_loss": avg_loss, "epoch": epoch})

        if epoch % config['eval_interval'] == 0 or epoch == config['epochs']:
            # Disable dropout etc. for evaluation (harmless if evaluate_model also does this);
            # model.train() at the top of the next epoch switches back.
            model.eval()
            metrics = evaluate_model(model, val_loader, data.num_items, device, ks=[5, 10, 20, 50], full_ranking=True)
            print(f" Val | HR@10={metrics['HR@10']:.4f} NDCG@10={metrics['NDCG@10']:.4f} MRR@10={metrics['MRR@10']:.4f}")
            if use_trackio:
                trackio.log({f"{model_name}/val_{k}": v for k, v in metrics.items() if k != 'eval_time'})
            if metrics['HR@10'] > best_hr10:
                best_hr10 = metrics['HR@10']
                best_epoch = epoch
                # Snapshot on CPU so later updates cannot mutate the saved weights.
                best_state = {k: v.cpu().clone() for k, v in model.state_dict().items()}
                print(f" ✓ New best! HR@10={best_hr10:.4f}")

    # Restore the best validation weights (if any evaluation improved on 0) before testing.
    if best_state is not None:
        model.load_state_dict(best_state)
    model.eval()
    test_metrics = evaluate_model(model, test_loader, data.num_items, device, ks=[5, 10, 20, 50], full_ranking=True)
    print(f"\nTest ({model_name}, best ep {best_epoch}):")
    for k, v in sorted(test_metrics.items()):
        if k != 'eval_time':
            print(f" {k}: {v:.4f}")

    save_dir = f'./checkpoints/{model_name}'
    os.makedirs(save_dir, exist_ok=True)
    torch.save({'model_state_dict': best_state if best_state is not None else model.state_dict(),
                'config': config,
                'test_metrics': test_metrics,
                'best_epoch': best_epoch,
                'num_items': data.num_items},
               os.path.join(save_dir, 'best_model.pt'))
    return test_metrics, sum(p.numel() for p in model.parameters())
# Per-model training configs. Both share max_seq_len, epochs, num_negatives and
# eval_interval; they differ in batch size, learning rate and weight decay.
SASREC_CFG = dict(
    max_seq_len=128,
    batch_size=128,
    lr=1e-3,
    weight_decay=0.0,
    epochs=25,
    num_negatives=4,
    eval_interval=5,
)
MARS_CFG = dict(
    max_seq_len=128,
    batch_size=64,
    lr=5e-4,
    weight_decay=0.01,
    epochs=25,
    num_negatives=4,
    eval_interval=5,
)
# Precompute data for num_items. train_model() builds its own ReindexedData per
# run, so this instance is only used to size the models via num_items.
data_tmp = ReindexedData(sequences, max_seq_len=MARS_CFG['max_seq_len'])
num_items = data_tmp.num_items
# Models. max_seq_len comes from each model's config so the architecture always
# matches the sequence length its data loaders are built with in train_model(),
# instead of repeating a hard-coded 128.
sasrec = SASRecBaseline(
    num_items=num_items, embed_dim=64, max_seq_len=SASREC_CFG['max_seq_len'],
    num_heads=2, num_layers=2, dropout=0.1)
marsv2 = MARSv2(
    num_items=num_items, embed_dim=64, max_seq_len=MARS_CFG['max_seq_len'],
    short_term_len=30, num_memory_tokens=8, num_long_layers=3,
    num_short_layers=2, num_heads=2, dropout=0.1)
# Train the SASRec baseline first, then MARS v2, each with its own config;
# every run yields (test_metrics, param_count).
_runs = {}
for _name, _net, _cfg in (('sasrec', sasrec, SASREC_CFG), ('marsv2', marsv2, MARS_CFG)):
    _runs[_name] = train_model(_name, _net, _cfg, device)
sasrec_results, sasrec_params = _runs['sasrec']
mars_results, mars_params = _runs['marsv2']
# Side-by-side comparison of MARS v2 against the baseline.
print_comparison(mars_results, sasrec_results, ks=[5, 10, 20, 50])
# Save both runs' test metrics, configs and parameter counts to one JSON file.
_CKPT_DIR = './checkpoints'
final = {
    run_name: {'metrics': run_metrics, 'config': run_cfg, 'params': run_params}
    for run_name, run_metrics, run_cfg, run_params in (
        ('marsv2', mars_results, MARS_CFG, mars_params),
        ('sasrec', sasrec_results, SASREC_CFG, sasrec_params),
    )
}
final['dataset'] = 'MovieLens-1M'
os.makedirs(_CKPT_DIR, exist_ok=True)
with open(f'{_CKPT_DIR}/final_results.json', 'w') as f:
    # default=str: anything json can't encode natively is written as its str().
    json.dump(final, f, indent=2, default=str)
# Push to Hub
def _build_readme(sasrec_results, sasrec_params, mars_results, mars_params, num_items):
    """Render the Hub model-card README from the final test results.

    Args:
        sasrec_results: SASRec test-metric dict; missing metrics render as 0.
        sasrec_params: SASRec parameter count (shown with thousands separators).
        mars_results: MARS v2 test-metric dict; missing metrics render as 0.
        mars_params: MARS v2 parameter count (shown with thousands separators).
        num_items: Item count shown in the results heading.

    Returns:
        The README as a Markdown string.
    """
    def _row(label, params, res):
        # One results-table row: model name, parameter count, then five metrics.
        cells = [f"{res.get(k, 0):.4f}" for k in ('HR@5', 'HR@10', 'HR@20', 'NDCG@10', 'MRR@10')]
        return f"| {label} | {params:,} | " + " | ".join(cells) + " |"

    return f"""# MARS: Multi-scale Adaptive Recurrence with State compression

An innovative method for **super long sequence modeling** in sequential recommendation.

## Architecture

```
Input: User interaction sequence + timestamps
    │
    ├── Long-term Branch (Temporal-Gated Linear Attention, O(n))
    │       │
    │   [Compressive Memory] → fixed-size memory tokens
    │       │
    ├── Short-term Branch (Causal Self-Attention, last K items)
    │
    └── Adaptive Fusion Gate → User Embedding → Next Item Prediction
```

## Key Innovations

1. **Temporal-Gated Linear Attention** — O(n) complexity via kernel trick (ELU+1 feature map) with learned temporal decay weighting per attention head
2. **Compressive Memory Tokens** — Cross-attention bottleneck compresses full history into M fixed tokens
3. **Dual-Branch with Adaptive Fusion** — Per-user gating balances long-term preferences and short-term intent
4. **Multi-Scale Temporal Encoding** — Log-scaled time deltas + periodic components for daily/weekly patterns

## Results on MovieLens-1M (Full Ranking, {num_items} items)

| Model | Params | HR@5 | HR@10 | HR@20 | NDCG@10 | MRR@10 |
|-------|--------|------|-------|-------|---------|--------|
{_row('SASRec', sasrec_params, sasrec_results)}
{_row('**MARS v2**', mars_params, mars_results)}

## Core Method: Temporal-Gated Linear Attention

Standard linear attention: `Attn(Q,K,V) = φ(Q)(φ(K)^T V) / φ(Q)φ(K)^T 1`

Our enhancement adds temporal gating:

```
K_gated = K ⊙ σ(W_decay · log(1 + Δt/3600))
```

where `Δt` is the inter-action time gap and `W_decay` is learned per attention head.
This gives O(n) complexity while explicitly modeling temporal dynamics — recent interactions get higher attention weight, with the decay rate learned per head.

## Based On

- **HyTRec** (2602.18283) — Temporal-aware dual-branch architecture
- **Rec2PM** (2602.11605) — Compressive memory as information bottleneck
- **Linear Transformers** (Katharopoulos et al.) — Kernel-based linear attention
- **SASRec** (1808.09781) — Self-attentive sequential recommendation baseline

## Usage

```python
from model_v2 import MARSv2

model = MARSv2(
    num_items=10000,
    embed_dim=64,
    max_seq_len=2048,  # Handles very long sequences at O(n) cost
    short_term_len=50,
    num_memory_tokens=8,
    num_long_layers=3,
    num_short_layers=2,
)
```
"""


# Best-effort upload: any failure (missing package, auth, network) is reported
# and the script still finishes with its local results saved.
try:
    from huggingface_hub import HfApi, upload_folder
    import shutil
    hub_id = 'CyberDancer/MARS-SeqRec'
    api = HfApi()
    api.create_repo(hub_id, exist_ok=True)
    # Ship whichever of the project's source files are present under /app.
    for fname in ['model.py', 'model_v2.py', 'data.py', 'evaluate.py', 'train.py', 'train_gpu.py', 'train_v2.py']:
        src = f'/app/{fname}'
        if os.path.exists(src):
            shutil.copy(src, f'./checkpoints/{fname}')
    readme = _build_readme(sasrec_results, sasrec_params, mars_results, mars_params, num_items)
    # The card contains non-ASCII text (box drawing, arrows, Greek letters), so
    # write UTF-8 explicitly rather than relying on the platform default encoding.
    with open('./checkpoints/README.md', 'w', encoding='utf-8') as f:
        f.write(readme)
    upload_folder(folder_path='./checkpoints', repo_id=hub_id,
                  commit_message="MARS v2: Temporal-Gated Linear Attention for SeqRec")
    print(f"\n✓ Pushed to https://huggingface.co/{hub_id}")
except Exception as e:
    print(f"Hub push: {e}")
print("\nDone!")