| """ |
| MARS v3: Complete rebuild for beating SASRec. |
| |
| Key fixes from research: |
| 1. Full softmax cross-entropy loss (not BCE with few negatives) |
| 2. DuoRec-style dropout contrastive learning |
| 3. FMLP-inspired frequency-domain filtering in long-term branch |
| 4. Proper max_seq_len=200 for ML-1M (avg 165 interactions) |
| 5. Proper leave-one-out evaluation protocol with full ranking |
| |
| Architecture: MARS v3 = FMLP filter (long-term, O(n log n)) |
| + Causal Attention (short-term) |
| + Compressive Memory + Adaptive Fusion |
| + DuoRec contrastive regularization |
| """ |
|
|
| import math, os, random, time, json |
| import numpy as np |
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from torch.utils.data import Dataset, DataLoader |
| from torch.optim import AdamW |
| from collections import defaultdict |
| from typing import Dict, List, Tuple, Optional |
|
|
|
|
| |
| |
| |
|
|
def download_movielens_1m(data_dir='./data/ml-1m'):
    """Make sure MovieLens-1M ``ratings.dat`` exists locally and return its path.

    When ``<data_dir>/ratings.dat`` is already present, nothing is downloaded.
    Otherwise the GroupLens zip is fetched into ``data_dir`` and extracted.
    The archive's nested ``ml-1m/`` folder is flattened into ``data_dir`` and
    the zip is deleted.

    Args:
        data_dir: Directory that holds (or will hold) the extracted files.

    Returns:
        Path of ``ratings.dat`` inside ``data_dir``.
    """
    import urllib.request
    import zipfile

    os.makedirs(data_dir, exist_ok=True)
    ratings_path = os.path.join(data_dir, 'ratings.dat')
    if os.path.exists(ratings_path):
        return ratings_path

    archive = os.path.join(data_dir, 'ml-1m.zip')
    print("Downloading ML-1M...")
    urllib.request.urlretrieve(
        'https://files.grouplens.org/datasets/movielens/ml-1m.zip', archive)
    with zipfile.ZipFile(archive, 'r') as zf:
        zf.extractall(data_dir)

    # The archive nests its files under "ml-1m/"; move them up one level.
    nested = os.path.join(data_dir, 'ml-1m')
    if os.path.exists(nested):
        for entry in os.listdir(nested):
            os.rename(os.path.join(nested, entry), os.path.join(data_dir, entry))
        os.rmdir(nested)
    os.remove(archive)
    return ratings_path
|
|
|
|
def _read_ratings(ratings_path):
    """Parse ``uid::iid::rating::ts`` lines into ``{uid: [(iid, ts), ...]}`` sorted by time."""
    user_items = defaultdict(list)
    with open(ratings_path, 'r') as f:
        for line in f:
            line = line.strip()
            if not line:  # tolerate blank / trailing lines
                continue
            parts = line.split('::')
            uid, iid, ts = int(parts[0]), int(parts[1]), int(parts[3])
            user_items[uid].append((iid, ts))
    for items in user_items.values():
        items.sort(key=lambda x: x[1])  # stable: equal timestamps keep file order
    return user_items


def _k_core_filter(user_items, k):
    """Repeatedly drop items and users with fewer than ``k`` interactions until stable."""
    while True:
        item_counts = defaultdict(int)
        for items in user_items.values():
            for iid, _ in items:
                item_counts[iid] += 1
        valid_items = {iid for iid, c in item_counts.items() if c >= k}

        filtered = {}
        for uid, items in user_items.items():
            kept = [(iid, ts) for iid, ts in items if iid in valid_items]
            if len(kept) >= k:
                filtered[uid] = kept

        before = sum(len(v) for v in user_items.values())
        after = sum(len(v) for v in filtered.values())
        if after == before:  # nothing removed -> fixed point reached
            return filtered
        user_items = filtered


def load_and_process_ml1m(max_seq_len=200, min_interactions=5, ratings_path=None):
    """Load ML-1M as implicit feedback with k-core filtering and leave-one-out splits.

    Every rating is treated as a positive interaction. Items and users with
    fewer than ``min_interactions`` interactions are removed iteratively until
    the data is stable (a true k-core, not a fixed number of passes). Item ids
    are re-indexed to ``1..num_items``, and id 0 is reserved for padding.

    For each user's chronological sequence ``s``, users with fewer than 4
    interactions are skipped. The three splits are:

    * train: input ``s[:-3]`` -> target ``s[-3]``
    * val:   input ``s[:-2]`` -> target ``s[-2]``
    * test:  input ``s[:-1]`` -> target ``s[-1]``

    The validation and test targets are therefore never used for training.
    Model selection on ``val`` never sees the test item. Every input is cut
    to its most recent ``max_seq_len`` items.

    Args:
        max_seq_len: Maximum history length kept per input.
        min_interactions: ``k`` of the k-core filter, applied to users and items.
        ratings_path: Optional path to an existing ``ratings.dat``. When
            ``None``, the file is obtained via ``download_movielens_1m()``.

    Returns:
        ``(train_seqs, val_seqs, test_seqs, num_items)``. Each split is a list
        of ``{'items': List[int], 'target': int}`` dicts.
    """
    if ratings_path is None:
        ratings_path = download_movielens_1m()

    user_items = _k_core_filter(_read_ratings(ratings_path), min_interactions)

    all_items = sorted({iid for items in user_items.values() for iid, _ in items})
    item2idx = {iid: idx for idx, iid in enumerate(all_items, start=1)}
    num_items = len(item2idx)

    train_seqs, val_seqs, test_seqs = [], [], []
    for items in user_items.values():
        seq = [item2idx[iid] for iid, _ in items]
        n = len(seq)
        if n < 4:  # need >=1 training input item + train/val/test targets
            continue
        for split, t in ((train_seqs, n - 3), (val_seqs, n - 2), (test_seqs, n - 1)):
            split.append({'items': seq[max(0, t - max_seq_len):t], 'target': seq[t]})

    print(f"ML-1M: {len(user_items)} users, {num_items} items")
    print(f"Train: {len(train_seqs)}, Val: {len(val_seqs)}, Test: {len(test_seqs)}")
    if train_seqs:
        seq_lens = [len(d['items']) for d in train_seqs]
        print(f"Seq len: mean={np.mean(seq_lens):.0f}, p50={np.median(seq_lens):.0f}, "
              f"p90={np.percentile(seq_lens, 90):.0f}, max={max(seq_lens)}")

    return train_seqs, val_seqs, test_seqs, num_items
|
|
|
|
class SeqRecDataset(Dataset):
    """Right-padded sequence dataset for last-position next-item prediction.

    Each example keeps the most recent ``max_seq_len`` entries of
    ``record['items']`` and right-pads them with 0 up to ``max_seq_len``. It
    returns the padded ids, the pre-padding length and the target id, all as
    ``torch.long`` tensors. No negatives are sampled because the models use a
    full-softmax loss.
    """

    def __init__(self, data, max_seq_len):
        self.data = data
        self.max_seq_len = max_seq_len

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        record = self.data[idx]
        window = record['items'][-self.max_seq_len:]
        n_real = len(window)
        padded = window + [0] * (self.max_seq_len - n_real)
        as_long = lambda value: torch.tensor(value, dtype=torch.long)
        return dict(
            input_ids=as_long(padded),
            lengths=as_long(n_real),
            target=as_long(record['target']),
        )
|
|
|
|
| |
| |
| |
|
|
class FilterLayer(nn.Module):
    """FMLP-Rec style learnable frequency-domain filter (O(T log T) in length T).

    The input is taken to the frequency domain with a real FFT along dim 1 and
    multiplied element-wise by a learnable complex filter. It is then brought
    back with an inverse real FFT. The result goes through dropout and is
    added to the input, followed by LayerNorm.

    Args:
        max_seq_len: Sequence length the filter is sized for; the filter
            stores ``max_seq_len // 2 + 1`` frequency bins.
        hidden_size: Feature dimension of the input's last axis.
        dropout: Dropout applied to the filtered signal before the residual.
    """

    def __init__(self, max_seq_len, hidden_size, dropout=0.1):
        super().__init__()
        n_bins = max_seq_len // 2 + 1
        # Real and imaginary parts stacked on the last axis; viewed as complex in forward().
        self.complex_weight = nn.Parameter(torch.randn(1, n_bins, hidden_size, 2) * 0.02)
        self.dropout = nn.Dropout(dropout)
        self.norm = nn.LayerNorm(hidden_size)

    def forward(self, x):
        # Assumes x is (batch, seq_len, hidden_size) with seq_len <= max_seq_len,
        # as in FMLPBlock/MARSv3. Shorter inputs use only the first seq_len//2+1 bins.
        seq_len = x.shape[1]
        spectrum = torch.fft.rfft(x, dim=1, norm='ortho')
        kernel = torch.view_as_complex(self.complex_weight)[:, :spectrum.shape[1], :]
        filtered = torch.fft.irfft(spectrum * kernel, n=seq_len, dim=1, norm='ortho')
        return self.norm(x + self.dropout(filtered))
|
|
|
|
class FMLPBlock(nn.Module):
    """One FMLP layer: a FilterLayer followed by a residual feed-forward network.

    Args:
        max_seq_len: Forwarded to FilterLayer to size its frequency filter.
        hidden_size: Model width.
        inner_size: Hidden width of the feed-forward network.
        dropout: Dropout used in the filter and inside the feed-forward network.
    """

    def __init__(self, max_seq_len, hidden_size, inner_size, dropout=0.1):
        super().__init__()
        self.filter = FilterLayer(max_seq_len, hidden_size, dropout)
        # Pre-norm FFN branch; forward() adds its output back and post-normalises.
        ffn_layers = [
            nn.LayerNorm(hidden_size),
            nn.Linear(hidden_size, inner_size),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(inner_size, hidden_size),
            nn.Dropout(dropout),
        ]
        self.ffn = nn.Sequential(*ffn_layers)
        self.norm = nn.LayerNorm(hidden_size)

    def forward(self, x):
        filtered = self.filter(x)
        residual = filtered + self.ffn(filtered)
        return self.norm(residual)
|
|
|
|
class CompressiveMemory(nn.Module):
    """Summarise a sequence into one vector per example via learned query slots.

    ``num_tokens`` learnable queries cross-attend over ``seq``, with padded
    keys excluded through ``mask``. A residual connection and LayerNorm follow,
    and the slots are averaged into a single ``hidden_size`` vector.

    Args:
        hidden_size: Feature dimension (nn.MultiheadAttention requires it to be
            divisible by ``num_heads``).
        num_tokens: Number of learned query slots.
        num_heads: Attention heads.
        dropout: Attention dropout.
    """

    def __init__(self, hidden_size, num_tokens=8, num_heads=2, dropout=0.1):
        super().__init__()
        self.queries = nn.Parameter(torch.randn(num_tokens, hidden_size) * 0.02)
        self.attn = nn.MultiheadAttention(hidden_size, num_heads, dropout=dropout, batch_first=True)
        self.norm = nn.LayerNorm(hidden_size)

    def forward(self, seq, mask=None):
        """Return a (batch, hidden_size) summary.

        ``seq`` is (batch, seq_len, hidden_size). ``mask`` is an optional bool
        (batch, seq_len) tensor that is True at positions to attend to; it is
        inverted into MultiheadAttention's ``key_padding_mask``.
        """
        slots = self.queries.expand(seq.size(0), -1, -1)
        padding = None if mask is None else ~mask
        attended = self.attn(slots, seq, seq, key_padding_mask=padding)[0]
        return self.norm(slots + attended).mean(dim=1)
|
|
|
|
class MARSv3(nn.Module):
    """
    MARS v3: FMLP filter (long-term) + Causal Attention (short-term)
    + Memory compression + Adaptive fusion + CE loss + CL loss

    For a right-padded batch ``input_ids`` (B, T) with true ``lengths`` (B,):

    * Long-term branch: item + position embeddings pass through
      ``n_filter_layers`` FMLPBlocks, with padding zeroed before every block.
      The output at the last valid position is ``long_last``. A
      CompressiveMemory summary over the valid positions is ``mem``.
    * Short-term branch: the last ``short_len`` valid items are re-embedded,
      with positions restarting at 0, and run through a causal Transformer
      encoder. The output at the last valid window position is ``short_last``.
    * Fusion: a sigmoid gate computed from ``[long_last, short_last, mem]``
      mixes ``long_last`` and ``short_last`` element-wise. ``mem`` only
      influences the result through this gate.

    Items 1..num_items are scored by dot product with the item-embedding
    table (row 0 is padding).
    """
    def __init__(self, num_items, hidden_size=64, max_seq_len=200,
                 n_filter_layers=2, n_attn_layers=1, n_heads=2,
                 inner_size=256, short_len=50, n_memory=8, dropout=0.2):
        super().__init__()
        self.num_items = num_items
        self.hidden_size = hidden_size
        self.max_seq_len = max_seq_len
        self.short_len = short_len

        self.item_emb = nn.Embedding(num_items + 1, hidden_size, padding_idx=0)
        self.pos_emb = nn.Embedding(max_seq_len, hidden_size)
        self.emb_dropout = nn.Dropout(dropout)
        self.emb_norm = nn.LayerNorm(hidden_size)

        # Long-term branch.
        self.filter_blocks = nn.ModuleList([
            FMLPBlock(max_seq_len, hidden_size, inner_size, dropout)
            for _ in range(n_filter_layers)
        ])

        # Fixed-size summary of the long-term branch.
        self.memory = CompressiveMemory(hidden_size, n_memory, n_heads, dropout)

        # Short-term branch.
        enc_layer = nn.TransformerEncoderLayer(
            d_model=hidden_size, nhead=n_heads, dim_feedforward=inner_size,
            dropout=dropout, activation='gelu', batch_first=True, norm_first=True)
        self.short_encoder = nn.TransformerEncoder(enc_layer, num_layers=n_attn_layers)

        # Fusion gate over [long_last, short_last, mem].
        self.gate = nn.Sequential(
            nn.Linear(hidden_size * 3, hidden_size), nn.GELU(),
            nn.Linear(hidden_size, hidden_size), nn.Sigmoid())

        self.output_norm = nn.LayerNorm(hidden_size)
        self._init_weights()

    def _init_weights(self):
        """Truncated-normal(0.02) init for every >=2-D parameter; zero the padding row."""
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.trunc_normal_(p, std=0.02)
        nn.init.zeros_(self.item_emb.weight[0])

    def _embed(self, input_ids, lengths):
        """Return (normalised item+position embeddings, bool mask of valid positions)."""
        T = input_ids.shape[1]
        x = self.item_emb(input_ids)
        pos = torch.arange(T, device=input_ids.device).unsqueeze(0).clamp(max=self.max_seq_len-1)
        x = self.emb_norm(self.emb_dropout(x + self.pos_emb(pos)))
        mask = torch.arange(T, device=input_ids.device).unsqueeze(0) < lengths.unsqueeze(1)
        return x, mask

    @staticmethod
    def _last_k_window(input_ids, lengths, k):
        """Gather each row's last ``k`` valid items, left-aligned and zero-padded.

        Row b receives ``input_ids[b, max(0, L_b - k):L_b]`` followed by zeros
        up to length ``k``. This is done with a single gather instead of a
        per-row Python loop, which forced one host/device sync per row.
        Requires ``k <= input_ids.shape[1]``.

        Returns:
            ``(window_ids, window_mask)``, both (B, k). The mask is True on real items.
        """
        offsets = torch.arange(k, device=input_ids.device)
        n_valid = lengths.clamp(max=k)
        start = (lengths - k).clamp(min=0)
        idx = (start.unsqueeze(1) + offsets.unsqueeze(0)).clamp(max=input_ids.shape[1] - 1)
        window_mask = offsets.unsqueeze(0) < n_valid.unsqueeze(1)
        window_ids = input_ids.gather(1, idx).masked_fill(~window_mask, 0)
        return window_ids, window_mask

    def encode(self, input_ids, lengths):
        """Encode sequence → user representation (B, D).

        Args:
            input_ids: (B, T) right-padded item ids (0 = padding).
            lengths: (B,) number of real items per row.
        """
        B, T = input_ids.shape
        device = input_ids.device
        rows = torch.arange(B, device=device)
        x, mask = self._embed(input_ids, lengths)

        # Long-term branch: zero padding before every block so padded positions
        # never enter the frequency-domain mixing.
        long_x = x
        for block in self.filter_blocks:
            long_x = block(long_x * mask.unsqueeze(-1).float())

        # Memory summary over the valid positions of the long-term branch.
        mem = self.memory(long_x, mask)

        # Long-term state at the last valid position of each row.
        long_last = long_x[rows, (lengths - 1).clamp(min=0)]

        # Short-term branch: last K items, re-embedded with positions 0..K-1.
        K = min(self.short_len, T)
        short_ids, short_masks = self._last_k_window(input_ids, lengths, K)
        short_pos = torch.arange(K, device=device).unsqueeze(0).clamp(max=self.max_seq_len - 1)
        short_x = self.emb_norm(self.emb_dropout(self.item_emb(short_ids) + self.pos_emb(short_pos)))

        causal = torch.triu(torch.ones(K, K, device=device, dtype=torch.bool), diagonal=1)
        short_out = self.short_encoder(short_x, mask=causal, src_key_padding_mask=~short_masks)
        short_lens = short_masks.sum(1).long()
        short_last = short_out[rows, (short_lens - 1).clamp(min=0)]

        # Adaptive fusion: the gate sees all three signals but mixes only long/short.
        g = self.gate(torch.cat([long_last, short_last, mem], dim=-1))
        user = g * long_last + (1 - g) * short_last
        return self.output_norm(user)

    def forward(self, input_ids, lengths, targets=None, cl_lambda=0.1):
        """
        Full softmax CE loss + DuoRec dropout contrastive loss.

        Returns ``logits`` (B, num_items) when ``targets`` is None, otherwise
        ``(loss, logits)``. ``targets`` are 1-based item ids. In training mode
        with ``cl_lambda > 0``, a second dropout-perturbed encoding of the same
        batch is computed and ``cl_lambda * InfoNCE`` is added to the CE loss.
        """
        user1 = self.encode(input_ids, lengths)
        logits = user1 @ self.item_emb.weight[1:].t()

        if targets is None:
            return logits

        ce_loss = F.cross_entropy(logits, targets - 1)
        if self.training and cl_lambda > 0:
            user2 = self.encode(input_ids, lengths)
            cl_loss = self._contrastive_loss(user1, user2)
            return ce_loss + cl_lambda * cl_loss, logits
        return ce_loss, logits

    def _contrastive_loss(self, h1, h2, temperature=0.1):
        """Symmetric InfoNCE between two dropout views; row i of h1 pairs with row i of h2."""
        h1 = F.normalize(h1, dim=-1)
        h2 = F.normalize(h2, dim=-1)
        logits = h1 @ h2.t() / temperature
        labels = torch.arange(h1.shape[0], device=h1.device)
        return (F.cross_entropy(logits, labels) + F.cross_entropy(logits.t(), labels)) / 2
|
|
|
|
class SASRecV3(nn.Module):
    """SASRec baseline trained with full-softmax cross-entropy (fair baseline).

    A causal, pre-norm Transformer encoder runs over item + position
    embeddings of a right-padded batch. The user vector is the normalised
    output at each row's last valid position. Items 1..num_items are scored
    by dot product with the item-embedding table (row 0 is padding).
    """

    def __init__(self, num_items, hidden_size=64, max_seq_len=200,
                 n_layers=2, n_heads=2, inner_size=256, dropout=0.2):
        super().__init__()
        self.num_items, self.hidden_size, self.max_seq_len = num_items, hidden_size, max_seq_len

        self.item_emb = nn.Embedding(num_items + 1, hidden_size, padding_idx=0)
        self.pos_emb = nn.Embedding(max_seq_len, hidden_size)
        self.emb_dropout = nn.Dropout(dropout)
        self.emb_norm = nn.LayerNorm(hidden_size)

        layer = nn.TransformerEncoderLayer(
            d_model=hidden_size, nhead=n_heads, dim_feedforward=inner_size,
            dropout=dropout, activation='gelu', batch_first=True, norm_first=True)
        self.encoder = nn.TransformerEncoder(layer, num_layers=n_layers)
        self.output_norm = nn.LayerNorm(hidden_size)

        self._init_weights()

    def _init_weights(self):
        """Truncated-normal(0.02) for all >=2-D parameters, then zero the padding row."""
        matrices = (param for param in self.parameters() if param.dim() > 1)
        for weight in matrices:
            nn.init.trunc_normal_(weight, std=0.02)
        nn.init.zeros_(self.item_emb.weight[0])

    def encode(self, input_ids, lengths):
        """Map (B, T) right-padded ids and (B,) lengths to (B, hidden_size) user vectors."""
        dev = input_ids.device
        batch, seq_len = input_ids.shape
        steps = torch.arange(seq_len, device=dev)

        h = self.item_emb(input_ids) + self.pos_emb(steps.unsqueeze(0).clamp(max=self.max_seq_len - 1))
        h = self.emb_norm(self.emb_dropout(h))

        valid = steps.unsqueeze(0) < lengths.unsqueeze(1)
        future = torch.ones(seq_len, seq_len, device=dev, dtype=torch.bool).triu(diagonal=1)
        h = self.encoder(h, mask=future, src_key_padding_mask=~valid)

        last = (lengths - 1).clamp(min=0)
        return self.output_norm(h[torch.arange(batch, device=dev), last])

    def forward(self, input_ids, lengths, targets=None):
        """Return logits, or ``(CE loss, logits)`` when 1-based ``targets`` are given."""
        logits = self.encode(input_ids, lengths) @ self.item_emb.weight[1:].t()
        if targets is None:
            return logits
        return F.cross_entropy(logits, targets - 1), logits
|
|
|
|
| |
| |
| |
|
|
@torch.no_grad()
def evaluate(model, loader, num_items, device, ks=(5, 10, 20, 50)):
    """Full-ranking next-item evaluation: HR@k, NDCG@k and MRR@k.

    Every catalogue item (ids 1..N) is scored as
    ``model.encode(ids, lens) @ model.item_emb.weight[1:].T``. The target's
    rank is 1 + the number of items scoring strictly higher, so ties go in
    the target's favour.

    Args:
        model: Module exposing ``encode(input_ids, lengths)`` and ``item_emb``.
        loader: Iterable of batches with ``input_ids``, ``lengths`` and
            ``target`` tensors. Targets are 1-based item ids.
        num_items: Catalogue size. Unused; kept for API compatibility.
        device: Device the batch tensors are moved to.
        ks: Cut-offs to report.

    Returns:
        Dict mapping ``'HR@k'``, ``'NDCG@k'`` and ``'MRR@k'`` (grouped per k,
        in ``ks`` order) to the mean over all evaluated examples. Values are
        NaN if the loader is empty.
    """
    model.eval()
    sums = {f'{m}@{k}': 0.0 for k in ks for m in ('HR', 'NDCG', 'MRR')}
    n_seen = 0

    for batch in loader:
        ids = batch['input_ids'].to(device)
        lens = batch['lengths'].to(device)
        tgt = batch['target'].to(device)

        # Score once. The previous version also ran model(...) and then discarded the result.
        logits = model.encode(ids, lens) @ model.item_emb.weight[1:].t()
        gt_scores = logits.gather(1, (tgt - 1).unsqueeze(1))
        ranks = ((logits > gt_scores).sum(dim=1) + 1).double()
        zeros = torch.zeros_like(ranks)

        for k in ks:
            in_top = ranks <= k
            sums[f'HR@{k}'] += in_top.double().sum().item()
            sums[f'NDCG@{k}'] += torch.where(in_top, 1.0 / torch.log2(ranks + 1), zeros).sum().item()
            sums[f'MRR@{k}'] += torch.where(in_top, 1.0 / ranks, zeros).sum().item()
        n_seen += tgt.shape[0]

    return {name: (total / n_seen if n_seen else float('nan')) for name, total in sums.items()}
|
|
|
|
| |
| |
| |
|
|
def _log_to_trackio(metrics):
    """Best-effort trackio logging; a no-op unless the script entry point enabled it."""
    # `use_trackio` / `trackio` are module globals created by the __main__ block.
    # They do not exist when this module is imported, so look them up safely.
    if not globals().get('use_trackio', False):
        return
    tracker = globals().get('trackio')
    if tracker is None:
        return
    try:
        tracker.log(metrics)
    except Exception as exc:  # logging must never abort training
        print(f"\n[trackio] log failed: {exc}")


def _build_loaders(train_data, val_data, test_data, config, device):
    """Create train/val/test DataLoaders; eval loaders use twice the train batch size."""
    msl, bs = config['max_seq_len'], config['batch_size']
    workers = config.get('num_workers', 2)
    # Pinned host memory only helps CUDA transfers (and warns on CPU-only machines).
    pin = torch.device(device).type == 'cuda'
    train_loader = DataLoader(SeqRecDataset(train_data, msl), batch_size=bs, shuffle=True,
                              num_workers=workers, drop_last=True, pin_memory=pin)
    val_loader = DataLoader(SeqRecDataset(val_data, msl), batch_size=bs * 2,
                            num_workers=workers, pin_memory=pin)
    test_loader = DataLoader(SeqRecDataset(test_data, msl), batch_size=bs * 2,
                             num_workers=workers, pin_memory=pin)
    return train_loader, val_loader, test_loader


def _warmup_cosine_scheduler(optimizer, total_steps):
    """Linear warmup (min(500, 10% of steps)), then cosine decay floored at 1% of base lr."""
    warmup = min(500, total_steps // 10)

    def lr_fn(step):
        if step < warmup:
            # step + 1 so the very first optimizer step does not run at lr = 0.
            return (step + 1) / warmup
        progress = (step - warmup) / max(total_steps - warmup, 1)
        return max(0.01, 0.5 * (1 + math.cos(math.pi * progress)))

    return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_fn)


def _train_one_epoch(model, loader, optimizer, scheduler, config, device):
    """Run one pass over ``loader`` and return the mean loss of the applied steps."""
    model.train()
    use_cl = hasattr(model, '_contrastive_loss')  # MARS-style models accept cl_lambda
    total_loss, n = 0.0, 0
    for batch in loader:
        ids = batch['input_ids'].to(device)
        lens = batch['lengths'].to(device)
        tgt = batch['target'].to(device)

        optimizer.zero_grad()
        if use_cl:
            loss, _ = model(ids, lens, tgt, cl_lambda=config.get('cl_lambda', 0.1))
        else:
            loss, _ = model(ids, lens, tgt)

        if not torch.isfinite(loss):  # skip NaN/inf batches instead of corrupting weights
            continue

        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 5.0)
        optimizer.step()
        scheduler.step()
        total_loss += loss.item()
        n += 1
    return total_loss / max(n, 1)


def train_model(name, model, train_data, val_data, test_data, num_items, config, device):
    """Train ``model`` with early stopping on validation HR@10, then report test metrics.

    Uses AdamW with a warmup + cosine LR schedule and gradient clipping at 5.0.
    Validation runs every ``config['eval_every']`` epochs (default 3), on the
    first three epochs and on the last epoch. Training stops after
    ``config['patience']`` (default 8) evaluations without improvement. The
    best-validation weights are restored before the test evaluation.

    Args:
        name: Label used in logs.
        model: Model exposing ``encode``/``item_emb`` and returning
            ``(loss, logits)`` when called with targets.
        train_data, val_data, test_data: Lists of ``{'items', 'target'}`` dicts.
        num_items: Catalogue size (forwarded to ``evaluate``).
        config: Requires ``max_seq_len``, ``batch_size``, ``lr``, ``wd`` and
            ``epochs``. Optional keys are ``patience``, ``eval_every``,
            ``cl_lambda`` and ``num_workers``.
        device: Target device.

    Returns:
        ``(test_metrics, best_state)``. ``best_state`` is a CPU state_dict of
        the best-validation epoch, or ``None`` if no evaluation happened.
    """
    n_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"\n{'='*60}\n{name} | {n_params:,} params\n{'='*60}")

    model = model.to(device)
    train_loader, val_loader, test_loader = _build_loaders(
        train_data, val_data, test_data, config, device)

    optimizer = AdamW(model.parameters(), lr=config['lr'], weight_decay=config['wd'])
    scheduler = _warmup_cosine_scheduler(optimizer, config['epochs'] * len(train_loader))

    eval_every = config.get('eval_every', 3)
    patience = config.get('patience', 8)
    # -1 so the first evaluation always produces a checkpoint, even at HR@10 == 0.
    best_hr10, best_ep, best_state, no_imp = -1.0, 0, None, 0

    for epoch in range(1, config['epochs'] + 1):
        t0 = time.time()
        avg_loss = _train_one_epoch(model, train_loader, optimizer, scheduler, config, device)
        print(f"Ep {epoch:3d}/{config['epochs']} | Loss: {avg_loss:.4f} | {time.time()-t0:.0f}s", end='')
        _log_to_trackio({f"{name}/loss": avg_loss, "epoch": epoch})

        if epoch % eval_every == 0 or epoch <= 3 or epoch == config['epochs']:
            m = evaluate(model, val_loader, num_items, device, ks=[5, 10, 20])
            print(f" | HR@10={m['HR@10']:.4f} NDCG@10={m['NDCG@10']:.4f}", end='')
            _log_to_trackio({f"{name}/{k}": v for k, v in m.items()})

            if m['HR@10'] > best_hr10:
                best_hr10 = m['HR@10']
                best_ep = epoch
                best_state = {k: v.cpu().clone() for k, v in model.state_dict().items()}
                no_imp = 0
                print(" ← BEST", end='')
            else:
                no_imp += 1
                if no_imp >= patience:
                    print(f"\n  Early stop at ep {epoch}")
                    break
        print()

    if best_state is not None:
        model.load_state_dict(best_state)
        model = model.to(device)

    test_m = evaluate(model, test_loader, num_items, device, ks=[5, 10, 20, 50])
    print(f"\nTest ({name}, best ep {best_ep}):")
    for k in sorted(test_m):
        print(f"  {k}: {test_m[k]:.4f}")

    return test_m, best_state
|
|
|
|
| |
| |
| |
|
|
if __name__ == '__main__':
    # Script entry point: train SASRec and MARS v3 on ML-1M, compare, save, and push.
    random.seed(42)
    np.random.seed(42)
    torch.manual_seed(42)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Device: {device}")

    # Optional experiment tracking. `use_trackio` / `trackio` stay module globals
    # because train_model reads them.
    try:
        import trackio
        trackio.init(name="MARSv3-vs-SASRec", project="mars-seqrec")
        use_trackio = True
    except Exception as e:
        print(f"trackio disabled: {e}")
        use_trackio = False

    MSL = 200
    train, val, test, num_items = load_and_process_ml1m(max_seq_len=MSL)

    # Baseline.
    sasrec = SASRecV3(num_items, hidden_size=64, max_seq_len=MSL, n_layers=2,
                      n_heads=2, inner_size=256, dropout=0.2)
    sasrec_cfg = {'max_seq_len': MSL, 'batch_size': 256, 'lr': 1e-3, 'wd': 0.0,
                  'epochs': 50, 'patience': 8, 'eval_every': 2}

    sasrec_results, sasrec_state = train_model(
        'SASRec', sasrec, train, val, test, num_items, sasrec_cfg, device)

    # MARS v3.
    mars = MARSv3(num_items, hidden_size=64, max_seq_len=MSL,
                  n_filter_layers=2, n_attn_layers=1, n_heads=2,
                  inner_size=256, short_len=50, n_memory=8, dropout=0.2)
    mars_cfg = {'max_seq_len': MSL, 'batch_size': 256, 'lr': 1e-3, 'wd': 0.0,
                'epochs': 50, 'patience': 8, 'eval_every': 2, 'cl_lambda': 0.1}

    mars_results, mars_state = train_model(
        'MARSv3', mars, train, val, test, num_items, mars_cfg, device)

    # Side-by-side comparison.
    print(f"\n{'='*70}")
    print(f"{'Metric':<12} | {'SASRec':>8} | {'MARS v3':>8} | {'Delta':>8} | {'%':>8}")
    print(f"{'-'*70}")
    for k in sorted(sasrec_results):
        s, m = sasrec_results[k], mars_results[k]
        d = m - s
        pct = d / max(s, 1e-8) * 100
        # Previously both branches were the same garbled character, so the marker was meaningless.
        mark = '↑' if d > 0 else ('↓' if d < 0 else '=')
        print(f"{k:<12} | {s:>8.4f} | {m:>8.4f} | {d:>+8.4f} | {mark}{abs(pct):>6.1f}%")
    print(f"{'='*70}")

    # Persist metrics and weights.
    os.makedirs('./checkpoints', exist_ok=True)
    results = {'sasrec': sasrec_results, 'marsv3': mars_results,
               'sasrec_params': sum(p.numel() for p in sasrec.parameters()),
               'mars_params': sum(p.numel() for p in mars.parameters())}
    with open('./checkpoints/results_v3.json', 'w') as f:
        json.dump(results, f, indent=2, default=str)

    torch.save({'sasrec': sasrec_state, 'marsv3': mars_state, 'num_items': num_items,
                'results': results}, './checkpoints/models_v3.pt')

    # Best-effort push to the Hugging Face Hub.
    try:
        from huggingface_hub import HfApi, upload_folder
        import shutil
        hub_id = 'CyberDancer/MARS-SeqRec'
        api = HfApi()
        api.create_repo(hub_id, exist_ok=True)
        # Copy this script itself; a hard-coded /app path made the whole push fail elsewhere.
        shutil.copy(os.path.abspath(__file__), './checkpoints/mars_v3.py')

        sp = results['sasrec_params']
        mp = results['mars_params']
        readme = f"""# MARS v3: Multi-scale Adaptive Recurrence with State compression

## Architecture
```
Long-term Branch: FMLP Filter (FFT → learnable filter → IFFT, O(n log n))
        ↓
  [Compressive Memory] ← fixed-size bottleneck
        ↓
Short-term Branch: Causal Self-Attention (last K items)
        ↓
  [Adaptive Fusion Gate]
        ↓
Training: Full Softmax CE + DuoRec Dropout Contrastive Loss
```

## Results on MovieLens-1M (Full Ranking, {num_items} items)

| Model | Params | HR@5 | HR@10 | HR@20 | NDCG@10 | MRR@10 |
|-------|--------|------|-------|-------|---------|--------|
| SASRec+CE | {sp:,} | {sasrec_results.get('HR@5',0):.4f} | {sasrec_results.get('HR@10',0):.4f} | {sasrec_results.get('HR@20',0):.4f} | {sasrec_results.get('NDCG@10',0):.4f} | {sasrec_results.get('MRR@10',0):.4f} |
| **MARS v3** | {mp:,} | {mars_results.get('HR@5',0):.4f} | {mars_results.get('HR@10',0):.4f} | {mars_results.get('HR@20',0):.4f} | {mars_results.get('NDCG@10',0):.4f} | {mars_results.get('MRR@10',0):.4f} |

## Key Innovations
1. **FMLP Filter (long-term)**: FFT-based learnable frequency filter denoises user history at O(n log n)
2. **Compressive Memory**: Cross-attention bottleneck → constant-size summary of arbitrarily long history
3. **DuoRec Contrastive Learning**: Two dropout-augmented views of same sequence → InfoNCE regularization
4. **Full Softmax CE**: Scores against ALL items, not sampled negatives — critical for quality
5. **Adaptive Fusion Gate**: Per-user learned balance of long-term preferences vs short-term intent
"""
        with open('./checkpoints/README.md', 'w') as f:
            f.write(readme)

        upload_folder(folder_path='./checkpoints', repo_id=hub_id,
                      commit_message="MARS v3: CE loss + contrastive learning + FMLP filters")
        print(f"✓ Pushed to https://huggingface.co/{hub_id}")
    except Exception as e:
        print(f"Hub: {e}")
|
|