CyberDancer commited on
Commit
3989f8c
Β·
verified Β·
1 Parent(s): 2319f81

MARS v2: Temporal-Gated Linear Attention for SeqRec

Browse files
Files changed (6) hide show
  1. README.md +35 -80
  2. final_results.json +30 -30
  3. marsv2/best_model.pt +3 -0
  4. model_v2.py +411 -0
  5. sasrec/best_model.pt +1 -1
  6. train_v2.py +240 -0
README.md CHANGED
@@ -2,110 +2,65 @@
2
 
3
  An innovative method for **super long sequence modeling** in sequential recommendation.
4
 
5
- ## Key Innovations
6
-
7
- 1. **Temporal-Aware Delta Network (TADN)** β€” O(n) linear complexity recurrent layer with explicit temporal decay gating in the delta rule state update
8
- 2. **Compressive Memory Tokens** β€” Fixed-size learnable memory via cross-attention that acts as information bottleneck (denoising effect)
9
- 3. **Dual-Branch Architecture** β€” Long-term (TADN, O(n)) + Short-term (Causal Self-Attention) with adaptive per-user fusion gate
10
- 4. **Multi-Scale Temporal Encoding** β€” Captures daily/weekly/seasonal patterns via periodic components + log-scaled time deltas
11
-
12
  ## Architecture
13
 
14
  ```
15
- Input: Full user interaction sequence + timestamps
16
- |
17
- v
18
- [Item Embedding + Multi-Scale Temporal Encoding]
19
- |
20
- +---- Long-term Branch (TADN layers, O(n) complexity)
21
- | |
22
- | [Compressive Memory] β†’ fixed-size memory tokens
23
- | |
24
- +---- Short-term Branch (Causal Self-Attention, recent K items)
25
- |
26
- v
27
- [Adaptive Fusion Gate (per-user learned)]
28
- |
29
- v
30
- [Prediction Head] β†’ next item scores
31
  ```
32
 
33
- ## Results on MovieLens-1M
 
 
 
 
 
 
 
34
 
35
  | Model | Params | HR@5 | HR@10 | HR@20 | NDCG@10 | MRR@10 |
36
  |-------|--------|------|-------|-------|---------|--------|
37
- | SASRec (baseline) | 345,664 | 0.0338 | 0.0601 | 0.0922 | 0.0272 | 0.0173 |
38
- | **MARS (ours)** | 426,180 | 0.0182 | 0.0329 | 0.0575 | 0.0156 | 0.0104 |
39
 
40
- ## TADN: Temporal-Aware Delta Network
41
 
42
- The core innovation: a recurrent layer with O(n) complexity that uses a delta rule with temporal gating:
43
 
 
44
  ```
45
- State Update:
46
- S_t = S_{t-1} * (1 - g_t βŠ™ Ξ²_t βŠ— k_t) + Ξ²_t βŠ— v_t βŠ— k_t
47
-
48
- Temporal Gating:
49
- g_t = Ξ± Β· Οƒ(W_g Β· [h_t; Ξ”h_t]) Β· Ο„_t + (1-Ξ±) Β· g_static
50
- Ο„_t = exp(-(t_now - t_behavior) / T_learnable)
51
  ```
 
52
 
53
- Key properties:
54
- - **O(n) complexity** for training and O(1) per-step for inference
55
- - **Explicit temporal modeling** via learnable exponential decay in the gate
56
- - **Selective memory** via input-dependent gating (inspired by HyTRec's TADN)
57
- - **Change detection** via Ξ”h_t = h_t - h_{t-1} in the gate input
58
-
59
- ## Compressive Memory
60
-
61
- Cross-attention memory queries compress the full TADN-encoded history into M fixed tokens:
62
- - Acts as information bottleneck (denoising, per Rec2PM theory)
63
- - Memory size is constant regardless of sequence length
64
- - Enables processing of arbitrarily long histories
65
 
66
- ## Files
67
 
68
- - `model.py` β€” Full MARS architecture + SASRec baseline
69
- - `data.py` β€” Data pipeline (MovieLens-1M, Amazon Reviews, synthetic)
70
- - `evaluate.py` β€” Evaluation (HR@K, NDCG@K, MRR@K)
71
- - `train.py` β€” CLI training script
72
- - `train_gpu.py` β€” GPU training with both models + comparison
73
-
74
- ## Based on Research
75
-
76
- Combines ideas from:
77
- - **HyTRec** (arxiv:2602.18283) β€” Temporal-Aware Delta Network concept
78
- - **Rec2PM** (arxiv:2602.11605) β€” Compressive memory as information bottleneck
79
- - **SIGMA** (arxiv:2408.11451) β€” Bidirectional gating for recommendation SSMs
80
- - **HSTU** (arxiv:2402.17152) β€” Generative Recommenders at scale
81
- - **SASRec** (arxiv:1808.09781) β€” Self-Attentive Sequential Recommendation baseline
82
 
83
  ## Usage
84
 
85
  ```python
86
- from model import MARS
87
 
88
- model = MARS(
89
  num_items=10000,
90
  embed_dim=64,
91
- max_seq_len=2048, # Can handle very long sequences
92
  short_term_len=50,
93
  num_memory_tokens=8,
94
- num_tadn_layers=3,
95
- num_attn_layers=2,
96
  )
97
-
98
- # Training
99
- batch = {
100
- 'item_ids': item_ids, # (B, T) padded sequences
101
- 'timestamps': timestamps, # (B, T) timestamps in seconds
102
- 'mask': mask, # (B, T) boolean mask
103
- 'positive_ids': pos_ids, # (B,) next items
104
- 'negative_ids': neg_ids, # (B, num_neg) negative items
105
- }
106
- loss = model(batch)
107
-
108
- # Inference
109
- model.eval()
110
- user_emb = model(batch) # (B, embed_dim)
111
  ```
 
2
 
3
  An innovative method for **super long sequence modeling** in sequential recommendation.
4
 
 
 
 
 
 
 
 
5
  ## Architecture
6
 
7
  ```
8
+ Input: User interaction sequence + timestamps
9
+ β”‚
10
+ β”œβ”€β”€ Long-term Branch (Temporal-Gated Linear Attention, O(n))
11
+ β”‚ β”‚
12
+ β”‚ [Compressive Memory] β†’ fixed-size memory tokens
13
+ β”‚ β”‚
14
+ β”œβ”€β”€ Short-term Branch (Causal Self-Attention, last K items)
15
+ β”‚
16
+ └── Adaptive Fusion Gate β†’ User Embedding β†’ Next Item Prediction
 
 
 
 
 
 
 
17
  ```
18
 
19
+ ## Key Innovations
20
+
21
+ 1. **Temporal-Gated Linear Attention** β€” O(n) complexity via kernel trick (ELU+1 feature map) with learned temporal decay weighting per attention head
22
+ 2. **Compressive Memory Tokens** β€” Cross-attention bottleneck compresses full history into M fixed tokens
23
+ 3. **Dual-Branch with Adaptive Fusion** β€” Per-user gating balances long-term preferences and short-term intent
24
+ 4. **Multi-Scale Temporal Encoding** β€” Log-scaled time deltas + periodic components for daily/weekly patterns
25
+
26
+ ## Results on MovieLens-1M (Full Ranking, 3706 items)
27
 
28
  | Model | Params | HR@5 | HR@10 | HR@20 | NDCG@10 | MRR@10 |
29
  |-------|--------|------|-------|-------|---------|--------|
30
+ | SASRec | 345,664 | 0.0338 | 0.0594 | 0.0995 | 0.0266 | 0.0166 |
31
+ | **MARS v2** | 567,628 | 0.0253 | 0.0414 | 0.0656 | 0.0201 | 0.0136 |
32
 
33
+ ## Core Method: Temporal-Gated Linear Attention
34
 
35
+ Standard linear attention: `Attn(Q,K,V) = Ο†(Q)(Ο†(K)^T V) / Ο†(Q)Ο†(K)^T 1`
36
 
37
+ Our enhancement adds temporal gating:
38
  ```
39
+ K_gated = K βŠ™ Οƒ(W_decay Β· log(1 + Ξ”t/3600))
 
 
 
 
 
40
  ```
41
+ where `Ξ”t` is the inter-action time gap and `W_decay` is learned per attention head.
42
 
43
+ This gives O(n) complexity while explicitly modeling temporal dynamics β€” recent interactions get higher attention weight, with the decay rate learned per head.
 
 
 
 
 
 
 
 
 
 
 
44
 
45
+ ## Based On
46
 
47
+ - **HyTRec** (2602.18283) β€” Temporal-aware dual-branch architecture
48
+ - **Rec2PM** (2602.11605) β€” Compressive memory as information bottleneck
49
+ - **Linear Transformers** (Katharopoulos et al.) β€” Kernel-based linear attention
50
+ - **SASRec** (1808.09781) β€” Self-attentive sequential recommendation baseline
 
 
 
 
 
 
 
 
 
 
51
 
52
  ## Usage
53
 
54
  ```python
55
+ from model_v2 import MARSv2
56
 
57
+ model = MARSv2(
58
  num_items=10000,
59
  embed_dim=64,
60
+ max_seq_len=2048, # Handles very long sequences at O(n) cost
61
  short_term_len=50,
62
  num_memory_tokens=8,
63
+ num_long_layers=3,
64
+ num_short_layers=2,
65
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
66
  ```
final_results.json CHANGED
@@ -1,53 +1,53 @@
1
  {
2
- "mars": {
3
  "metrics": {
4
- "HR@5": 0.018211920529801324,
5
- "NDCG@5": 0.010866539168206853,
6
- "MRR@5": 0.00847682119205298,
7
- "HR@10": 0.03294701986754967,
8
- "NDCG@10": 0.015587167767389802,
9
- "MRR@10": 0.010399716177861874,
10
- "HR@20": 0.057450331125827814,
11
- "NDCG@20": 0.021729804814576637,
12
- "MRR@20": 0.012058079955347656,
13
- "HR@50": 0.10943708609271523,
14
- "NDCG@50": 0.03200992960974106,
15
- "MRR@50": 0.013693084857615116,
16
- "eval_time": 21.104490041732788
17
  },
18
  "config": {
19
  "max_seq_len": 128,
20
  "batch_size": 64,
21
- "lr": 0.001,
22
  "weight_decay": 0.01,
23
- "epochs": 20,
24
  "num_negatives": 4,
25
  "eval_interval": 5
26
  },
27
- "params": 426180
28
  },
29
  "sasrec": {
30
  "metrics": {
31
  "HR@5": 0.03377483443708609,
32
- "NDCG@5": 0.0187428358177548,
33
- "MRR@5": 0.013846578366445915,
34
- "HR@10": 0.06009933774834437,
35
- "NDCG@10": 0.027174775884652287,
36
- "MRR@10": 0.017280432040365813,
37
- "HR@20": 0.09221854304635761,
38
- "NDCG@20": 0.035199293168162935,
39
- "MRR@20": 0.019431891083746364,
40
- "HR@50": 0.1566225165562914,
41
- "NDCG@50": 0.047938563477553944,
42
- "MRR@50": 0.02145968348171206,
43
- "eval_time": 6.248645305633545
44
  },
45
  "config": {
46
  "max_seq_len": 128,
47
  "batch_size": 128,
48
  "lr": 0.001,
49
  "weight_decay": 0.0,
50
- "epochs": 20,
51
  "num_negatives": 4,
52
  "eval_interval": 5
53
  },
 
1
  {
2
+ "marsv2": {
3
  "metrics": {
4
+ "HR@5": 0.02533112582781457,
5
+ "NDCG@5": 0.014835237558963535,
6
+ "MRR@5": 0.011410044150110373,
7
+ "HR@10": 0.041390728476821195,
8
+ "NDCG@10": 0.020070716381011464,
9
+ "MRR@10": 0.013596657205928729,
10
+ "HR@20": 0.06556291390728476,
11
+ "NDCG@20": 0.026056864980031683,
12
+ "MRR@20": 0.015173197924560101,
13
+ "HR@50": 0.12350993377483444,
14
+ "NDCG@50": 0.03741163215681034,
15
+ "MRR@50": 0.01693633649883963,
16
+ "eval_time": 8.468570232391357
17
  },
18
  "config": {
19
  "max_seq_len": 128,
20
  "batch_size": 64,
21
+ "lr": 0.0005,
22
  "weight_decay": 0.01,
23
+ "epochs": 25,
24
  "num_negatives": 4,
25
  "eval_interval": 5
26
  },
27
+ "params": 567628
28
  },
29
  "sasrec": {
30
  "metrics": {
31
  "HR@5": 0.03377483443708609,
32
+ "NDCG@5": 0.018333244425315455,
33
+ "MRR@5": 0.013275386313465785,
34
+ "HR@10": 0.05943708609271523,
35
+ "NDCG@10": 0.02657590673542354,
36
+ "MRR@10": 0.016644591611479027,
37
+ "HR@20": 0.09950331125827815,
38
+ "NDCG@20": 0.03672212773625359,
39
+ "MRR@20": 0.01943707237075238,
40
+ "HR@50": 0.16622516556291392,
41
+ "NDCG@50": 0.04983449691479723,
42
+ "MRR@50": 0.021489499293137433,
43
+ "eval_time": 6.591589450836182
44
  },
45
  "config": {
46
  "max_seq_len": 128,
47
  "batch_size": 128,
48
  "lr": 0.001,
49
  "weight_decay": 0.0,
50
+ "epochs": 25,
51
  "num_negatives": 4,
52
  "eval_interval": 5
53
  },
marsv2/best_model.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:82835f47af21ef06936ff1d287a89456d6901bb60560c3c977a7762d9fd57704
3
+ size 2306047
model_v2.py ADDED
@@ -0,0 +1,411 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ MARS v2: Simplified and stabilized architecture.
3
+
4
+ Key changes from v1:
5
+ 1. Replace unstable delta-rule state with temporal-gated linear attention
6
+ 2. Simpler but more robust long-term branch
7
+ 3. FFN layers for capacity
8
+ """
9
+
10
+ import math
11
+ import torch
12
+ import torch.nn as nn
13
+ import torch.nn.functional as F
14
+ from typing import Optional, Dict
15
+
16
+
17
+ class TemporalEncoding(nn.Module):
18
+ """Multi-scale temporal encoding."""
19
+
20
+ def __init__(self, embed_dim: int, max_periods: int = 4):
21
+ super().__init__()
22
+ self.time_delta_proj = nn.Linear(1, embed_dim)
23
+ periods = [3600, 86400, 604800, 2592000][:max_periods]
24
+ self.register_buffer('periods', torch.tensor(periods, dtype=torch.float32))
25
+ self.periodic_proj = nn.Linear(max_periods * 2, embed_dim)
26
+ self.layernorm = nn.LayerNorm(embed_dim)
27
+
28
+ def forward(self, timestamps: torch.Tensor) -> torch.Tensor:
29
+ B, T = timestamps.shape
30
+ time_deltas = torch.zeros_like(timestamps)
31
+ time_deltas[:, 1:] = timestamps[:, 1:] - timestamps[:, :-1]
32
+ time_deltas = time_deltas.clamp(min=0)
33
+ log_deltas = torch.log1p(time_deltas).unsqueeze(-1)
34
+ delta_emb = self.time_delta_proj(log_deltas)
35
+
36
+ ts_expanded = timestamps.unsqueeze(-1)
37
+ periods = self.periods.view(1, 1, -1)
38
+ angles = 2 * math.pi * ts_expanded / periods
39
+ periodic_features = torch.cat([torch.sin(angles), torch.cos(angles)], dim=-1)
40
+ periodic_emb = self.periodic_proj(periodic_features)
41
+
42
+ return self.layernorm(delta_emb + periodic_emb)
43
+
44
+
45
+ class TemporalGatedLinearAttention(nn.Module):
46
+ """
47
+ Temporal-Gated Linear Attention: O(n) attention with temporal decay.
48
+
49
+ Uses the kernel trick: softmax(QK^T)V β‰ˆ Ο†(Q) * (Ο†(K)^T * V)
50
+ where Ο† is ELU + 1, making it O(n*dΒ²) instead of O(nΒ²*d).
51
+
52
+ Added temporal gating: each step's contribution is weighted by
53
+ a learnable temporal decay function.
54
+ """
55
+
56
+ def __init__(self, embed_dim: int, num_heads: int = 2, dropout: float = 0.1):
57
+ super().__init__()
58
+ self.embed_dim = embed_dim
59
+ self.num_heads = num_heads
60
+ self.head_dim = embed_dim // num_heads
61
+
62
+ self.q_proj = nn.Linear(embed_dim, embed_dim)
63
+ self.k_proj = nn.Linear(embed_dim, embed_dim)
64
+ self.v_proj = nn.Linear(embed_dim, embed_dim)
65
+ self.out_proj = nn.Linear(embed_dim, embed_dim)
66
+
67
+ # Temporal decay: learned per head
68
+ self.decay_proj = nn.Linear(1, num_heads) # log-delta β†’ per-head decay weight
69
+
70
+ self.norm = nn.LayerNorm(embed_dim)
71
+ self.dropout = nn.Dropout(dropout)
72
+
73
+ # FFN
74
+ self.ffn = nn.Sequential(
75
+ nn.LayerNorm(embed_dim),
76
+ nn.Linear(embed_dim, embed_dim * 4),
77
+ nn.GELU(),
78
+ nn.Dropout(dropout),
79
+ nn.Linear(embed_dim * 4, embed_dim),
80
+ nn.Dropout(dropout),
81
+ )
82
+
83
+ def _feature_map(self, x):
84
+ """ELU + 1 feature map for linear attention."""
85
+ return F.elu(x) + 1
86
+
87
+ def forward(self, x, timestamps=None, mask=None):
88
+ B, T, D = x.shape
89
+ H = self.num_heads
90
+ d = self.head_dim
91
+
92
+ # Project and reshape
93
+ q = self._feature_map(self.q_proj(x)).view(B, T, H, d)
94
+ k = self._feature_map(self.k_proj(x)).view(B, T, H, d)
95
+ v = self.v_proj(x).view(B, T, H, d)
96
+
97
+ # Temporal decay weights
98
+ if timestamps is not None:
99
+ time_deltas = torch.zeros_like(timestamps)
100
+ time_deltas[:, 1:] = timestamps[:, 1:] - timestamps[:, :-1]
101
+ time_deltas = time_deltas.clamp(min=0)
102
+ log_deltas = torch.log1p(time_deltas / 3600.0).unsqueeze(-1) # (B, T, 1)
103
+ decay_weights = torch.sigmoid(self.decay_proj(log_deltas)) # (B, T, H)
104
+ # Weight keys by temporal decay
105
+ k = k * decay_weights.unsqueeze(-1) # (B, T, H, d)
106
+
107
+ # Mask padding
108
+ if mask is not None:
109
+ mask_expanded = mask.unsqueeze(-1).unsqueeze(-1).float() # (B, T, 1, 1)
110
+ k = k * mask_expanded
111
+ v = v * mask_expanded
112
+
113
+ # Linear attention: O(n*dΒ²)
114
+ # Causal version using cumulative sum
115
+ # KV = cumsum(k βŠ— v) β†’ (B, T, H, d, d) β€” too expensive
116
+ # Instead, use the simpler cumulative state approach:
117
+
118
+ # Non-causal linear attention (bidirectional for long-term modeling)
119
+ # attn = Ο†(Q)(Ο†(K)^T V) / Ο†(Q)(Ο†(K)^T 1)
120
+ kv = torch.einsum('bthd,bthe->bhde', k, v) # (B, H, d, d)
121
+ k_sum = k.sum(dim=1) # (B, H, d)
122
+
123
+ # Output: q @ kv / (q @ k_sum)
124
+ numerator = torch.einsum('bthd,bhde->bthe', q, kv) # (B, T, H, d)
125
+ denominator = torch.einsum('bthd,bhd->bth', q, k_sum).unsqueeze(-1) # (B, T, H, 1)
126
+
127
+ attn_out = numerator / (denominator + 1e-6)
128
+ attn_out = attn_out.reshape(B, T, D)
129
+ attn_out = self.out_proj(self.dropout(attn_out))
130
+
131
+ # Residual + LayerNorm
132
+ x = self.norm(x + attn_out)
133
+
134
+ # FFN with residual
135
+ x = x + self.ffn(x)
136
+
137
+ return x
138
+
139
+
140
+ class CompressiveMemory(nn.Module):
141
+ """Cross-attention memory compression."""
142
+
143
+ def __init__(self, embed_dim: int, num_memory_tokens: int = 8, num_heads: int = 2):
144
+ super().__init__()
145
+ self.memory_queries = nn.Parameter(torch.randn(num_memory_tokens, embed_dim) * 0.02)
146
+ self.cross_attn = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True, dropout=0.1)
147
+ self.ffn = nn.Sequential(
148
+ nn.Linear(embed_dim, embed_dim * 4), nn.GELU(), nn.Dropout(0.1),
149
+ nn.Linear(embed_dim * 4, embed_dim), nn.Dropout(0.1),
150
+ )
151
+ self.norm1 = nn.LayerNorm(embed_dim)
152
+ self.norm2 = nn.LayerNorm(embed_dim)
153
+
154
+ def forward(self, sequence, mask=None):
155
+ B = sequence.shape[0]
156
+ queries = self.memory_queries.unsqueeze(0).expand(B, -1, -1)
157
+ key_padding_mask = ~mask if mask is not None else None
158
+ attn_out, _ = self.cross_attn(queries, sequence, sequence, key_padding_mask=key_padding_mask)
159
+ memory = self.norm1(queries + attn_out)
160
+ memory = self.norm2(memory + self.ffn(memory))
161
+ return memory
162
+
163
+
164
+ class AdaptiveFusionGate(nn.Module):
165
+ """Learned fusion of long-term and short-term signals."""
166
+
167
+ def __init__(self, embed_dim: int):
168
+ super().__init__()
169
+ self.gate = nn.Sequential(
170
+ nn.Linear(embed_dim * 3, embed_dim),
171
+ nn.GELU(),
172
+ nn.Linear(embed_dim, embed_dim),
173
+ nn.Sigmoid()
174
+ )
175
+
176
+ def forward(self, long_term, short_term, memory):
177
+ g = self.gate(torch.cat([long_term, short_term, memory], dim=-1))
178
+ return g * long_term + (1 - g) * short_term
179
+
180
+
181
+ class MARSv2(nn.Module):
182
+ """
183
+ MARS v2: Multi-scale Adaptive Recurrence with State compression
184
+
185
+ Uses temporal-gated linear attention (O(n)) for long-term branch
186
+ and standard causal self-attention for short-term branch.
187
+ """
188
+
189
+ def __init__(
190
+ self,
191
+ num_items: int,
192
+ embed_dim: int = 64,
193
+ max_seq_len: int = 512,
194
+ short_term_len: int = 50,
195
+ num_memory_tokens: int = 8,
196
+ num_long_layers: int = 3,
197
+ num_short_layers: int = 2,
198
+ num_heads: int = 2,
199
+ dropout: float = 0.1,
200
+ ):
201
+ super().__init__()
202
+ self.num_items = num_items
203
+ self.embed_dim = embed_dim
204
+ self.max_seq_len = max_seq_len
205
+ self.short_term_len = short_term_len
206
+
207
+ self.item_embedding = nn.Embedding(num_items + 1, embed_dim, padding_idx=0)
208
+ self.temporal_encoding = TemporalEncoding(embed_dim)
209
+ self.position_embedding = nn.Embedding(max_seq_len, embed_dim)
210
+ self.input_norm = nn.LayerNorm(embed_dim)
211
+ self.input_dropout = nn.Dropout(dropout)
212
+
213
+ # Long-term branch: temporal-gated linear attention (O(n))
214
+ self.long_layers = nn.ModuleList([
215
+ TemporalGatedLinearAttention(embed_dim, num_heads, dropout)
216
+ for _ in range(num_long_layers)
217
+ ])
218
+
219
+ # Compressive memory
220
+ self.compressive_memory = CompressiveMemory(embed_dim, num_memory_tokens, num_heads)
221
+
222
+ # Short-term branch: standard causal attention
223
+ encoder_layer = nn.TransformerEncoderLayer(
224
+ d_model=embed_dim, nhead=num_heads, dim_feedforward=embed_dim * 4,
225
+ dropout=dropout, activation='gelu', batch_first=True, norm_first=True
226
+ )
227
+ self.short_encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_short_layers)
228
+
229
+ # Fusion
230
+ self.fusion_gate = AdaptiveFusionGate(embed_dim)
231
+ self.output_norm = nn.LayerNorm(embed_dim)
232
+ self.output_proj = nn.Linear(embed_dim, embed_dim)
233
+
234
+ self._init_weights()
235
+
236
+ def _init_weights(self):
237
+ for name, param in self.named_parameters():
238
+ if 'weight' in name and param.dim() >= 2:
239
+ nn.init.trunc_normal_(param, std=0.02)
240
+ elif 'bias' in name:
241
+ nn.init.zeros_(param)
242
+ nn.init.zeros_(self.item_embedding.weight[0])
243
+
244
+ @property
245
+ def item_embeddings(self):
246
+ return self.item_embedding
247
+
248
+ def encode(self, item_ids, timestamps=None, mask=None):
249
+ B, T = item_ids.shape
250
+ if mask is None:
251
+ mask = (item_ids != 0)
252
+
253
+ # Embeddings
254
+ item_emb = self.item_embedding(item_ids)
255
+ if timestamps is not None:
256
+ item_emb = item_emb + self.temporal_encoding(timestamps.float())
257
+
258
+ positions = torch.arange(T, device=item_ids.device).unsqueeze(0).clamp(max=self.max_seq_len - 1)
259
+ item_emb = self.input_norm(item_emb + self.position_embedding(positions))
260
+ item_emb = self.input_dropout(item_emb)
261
+
262
+ # Long-term branch
263
+ long_repr = item_emb
264
+ for layer in self.long_layers:
265
+ long_repr = layer(long_repr, timestamps, mask)
266
+
267
+ # Memory compression
268
+ memory = self.compressive_memory(long_repr, mask)
269
+ memory_summary = memory.mean(dim=1)
270
+
271
+ # Last valid long-term
272
+ lengths = mask.sum(dim=1).long()
273
+ long_last = long_repr[torch.arange(B, device=item_ids.device), (lengths - 1).clamp(min=0)]
274
+
275
+ # Short-term branch: extract last K valid items
276
+ K = min(self.short_term_len, T)
277
+ short_ids_list, short_ts_list, short_mask_list = [], [], []
278
+
279
+ for b in range(B):
280
+ sl = lengths[b].item()
281
+ actual_k = min(K, sl)
282
+ start = max(0, sl - K)
283
+ ids = item_ids[b, start:sl]
284
+ pad = K - actual_k
285
+ if pad > 0:
286
+ ids = torch.cat([ids, torch.zeros(pad, dtype=ids.dtype, device=ids.device)])
287
+ short_ids_list.append(ids)
288
+
289
+ if timestamps is not None:
290
+ ts = timestamps[b, start:sl]
291
+ if pad > 0:
292
+ ts = torch.cat([ts, torch.zeros(pad, dtype=ts.dtype, device=ts.device)])
293
+ short_ts_list.append(ts)
294
+
295
+ m = torch.zeros(K, dtype=torch.bool, device=item_ids.device)
296
+ m[:actual_k] = True
297
+ short_mask_list.append(m)
298
+
299
+ short_ids = torch.stack(short_ids_list)
300
+ short_mask = torch.stack(short_mask_list)
301
+
302
+ short_emb = self.item_embedding(short_ids)
303
+ if timestamps is not None:
304
+ short_ts = torch.stack(short_ts_list)
305
+ short_emb = short_emb + self.temporal_encoding(short_ts.float())
306
+
307
+ short_pos = torch.arange(K, device=item_ids.device).unsqueeze(0).clamp(max=self.max_seq_len - 1)
308
+ short_emb = self.input_norm(short_emb + self.position_embedding(short_pos))
309
+
310
+ causal_mask = torch.triu(torch.ones(K, K, device=item_ids.device, dtype=torch.bool), diagonal=1)
311
+ short_repr = self.short_encoder(short_emb, mask=causal_mask, src_key_padding_mask=~short_mask)
312
+
313
+ short_lengths = short_mask.sum(dim=1).long()
314
+ short_last = short_repr[torch.arange(B, device=item_ids.device), (short_lengths - 1).clamp(min=0)]
315
+
316
+ # Fusion
317
+ user_emb = self.fusion_gate(long_last, short_last, memory_summary)
318
+ return self.output_proj(self.output_norm(user_emb))
319
+
320
+ def forward(self, batch):
321
+ if self.training:
322
+ item_ids = batch['item_ids']
323
+ timestamps = batch.get('timestamps')
324
+ mask = batch.get('mask')
325
+ pos_ids = batch['positive_ids']
326
+ neg_ids = batch['negative_ids']
327
+
328
+ user_emb = self.encode(item_ids, timestamps, mask)
329
+ pos_emb = self.item_embedding(pos_ids)
330
+ neg_emb = self.item_embedding(neg_ids)
331
+
332
+ pos_scores = (user_emb * pos_emb).sum(dim=-1)
333
+ neg_scores = torch.einsum('bd,bnd->bn', user_emb, neg_emb)
334
+
335
+ loss_pos = F.binary_cross_entropy_with_logits(pos_scores, torch.ones_like(pos_scores))
336
+ loss_neg = F.binary_cross_entropy_with_logits(neg_scores, torch.zeros_like(neg_scores))
337
+ return loss_pos + loss_neg
338
+ else:
339
+ return self.encode(batch['item_ids'], batch.get('timestamps'), batch.get('mask'))
340
+
341
+
342
+ class SASRecBaseline(nn.Module):
343
+ """SASRec baseline."""
344
+
345
+ def __init__(self, num_items, embed_dim=64, max_seq_len=200, num_heads=2, num_layers=2, dropout=0.1):
346
+ super().__init__()
347
+ self.num_items = num_items
348
+ self.embed_dim = embed_dim
349
+ self.max_seq_len = max_seq_len
350
+
351
+ self.item_embedding = nn.Embedding(num_items + 1, embed_dim, padding_idx=0)
352
+ self.position_embedding = nn.Embedding(max_seq_len, embed_dim)
353
+ self.input_norm = nn.LayerNorm(embed_dim)
354
+ self.input_dropout = nn.Dropout(dropout)
355
+
356
+ encoder_layer = nn.TransformerEncoderLayer(
357
+ d_model=embed_dim, nhead=num_heads, dim_feedforward=embed_dim * 4,
358
+ dropout=dropout, activation='gelu', batch_first=True, norm_first=True
359
+ )
360
+ self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
361
+ self.output_norm = nn.LayerNorm(embed_dim)
362
+ self._init_weights()
363
+
364
+ def _init_weights(self):
365
+ for name, param in self.named_parameters():
366
+ if 'weight' in name and param.dim() >= 2:
367
+ nn.init.trunc_normal_(param, std=0.02)
368
+ elif 'bias' in name:
369
+ nn.init.zeros_(param)
370
+ nn.init.zeros_(self.item_embedding.weight[0])
371
+
372
+ @property
373
+ def item_embeddings(self):
374
+ return self.item_embedding
375
+
376
+ def encode(self, item_ids, timestamps=None, mask=None):
377
+ B, T = item_ids.shape
378
+ if mask is None:
379
+ mask = (item_ids != 0)
380
+
381
+ item_emb = self.item_embedding(item_ids)
382
+ positions = torch.arange(T, device=item_ids.device).unsqueeze(0).clamp(max=self.max_seq_len - 1)
383
+ item_emb = self.input_norm(item_emb + self.position_embedding(positions))
384
+ item_emb = self.input_dropout(item_emb)
385
+
386
+ causal_mask = torch.triu(torch.ones(T, T, device=item_ids.device, dtype=torch.bool), diagonal=1)
387
+ output = self.encoder(item_emb, mask=causal_mask, src_key_padding_mask=~mask)
388
+
389
+ lengths = mask.sum(dim=1).long()
390
+ user_emb = output[torch.arange(B, device=item_ids.device), (lengths - 1).clamp(min=0)]
391
+ return self.output_norm(user_emb)
392
+
393
+ def forward(self, batch):
394
+ if self.training:
395
+ item_ids = batch['item_ids']
396
+ mask = batch.get('mask')
397
+ pos_ids = batch['positive_ids']
398
+ neg_ids = batch['negative_ids']
399
+
400
+ user_emb = self.encode(item_ids, mask=mask)
401
+ pos_emb = self.item_embedding(pos_ids)
402
+ neg_emb = self.item_embedding(neg_ids)
403
+
404
+ pos_scores = (user_emb * pos_emb).sum(dim=-1)
405
+ neg_scores = torch.einsum('bd,bnd->bn', user_emb, neg_emb)
406
+
407
+ loss_pos = F.binary_cross_entropy_with_logits(pos_scores, torch.ones_like(pos_scores))
408
+ loss_neg = F.binary_cross_entropy_with_logits(neg_scores, torch.zeros_like(neg_scores))
409
+ return loss_pos + loss_neg
410
+ else:
411
+ return self.encode(batch['item_ids'], mask=batch.get('mask'))
sasrec/best_model.pt CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:00b40a57b7d6f4c3b3047f2279539cf2191512a3109a45838267f175db6ec4a4
3
  size 1393845
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4aa1e3c48943ea04823362e4b3a5c567984d095f3489286d82fd7e24e0f8e9cc
3
  size 1393845
train_v2.py ADDED
@@ -0,0 +1,240 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ MARS v2 Training Script β€” Improved architecture with linear attention.
3
+ """
4
+
5
+ import os, sys, time, json, random
6
+ import numpy as np
7
+ import torch
8
+ import torch.nn as nn
9
+ from torch.optim import AdamW
10
+ from torch.optim.lr_scheduler import CosineAnnealingLR
11
+
12
+ random.seed(42); np.random.seed(42); torch.manual_seed(42)
13
+ device = torch.device('cpu')
14
+ print(f"Device: {device}")
15
+
16
+ from model_v2 import MARSv2, SASRecBaseline
17
+ from data import load_movielens_1m, ReindexedData, create_dataloaders
18
+ from evaluate import evaluate_model, print_comparison
19
+
20
+ try:
21
+ import trackio
22
+ trackio.init(name="MARSv2-SeqRec-ML1M", project="mars-seqrec")
23
+ use_trackio = True
24
+ print("Trackio initialized")
25
+ except Exception as e:
26
+ use_trackio = False
27
+
28
+ # Load data
29
+ print("\nLoading MovieLens-1M...")
30
+ sequences = load_movielens_1m(min_interactions=5)
31
+ seq_lens = [len(v['item_ids']) for v in sequences.values()]
32
+ print(f"{len(sequences)} users, seq mean={np.mean(seq_lens):.1f}, max={np.max(seq_lens)}")
33
+
34
+
35
+ def train_model(model_name, model, config, device):
36
+ print(f"\n{'='*60}\nTraining: {model_name.upper()}\nParams: {sum(p.numel() for p in model.parameters() if p.requires_grad):,}\n{'='*60}")
37
+
38
+ data = ReindexedData(sequences, max_seq_len=config['max_seq_len'])
39
+ train_loader, val_loader, test_loader = create_dataloaders(
40
+ data, max_seq_len=config['max_seq_len'], batch_size=config['batch_size'],
41
+ num_negatives=config['num_negatives'], num_workers=2)
42
+
43
+ optimizer = AdamW(model.parameters(), lr=config['lr'], weight_decay=config['weight_decay'])
44
+
45
+ # Warmup + cosine schedule
46
+ total_steps = config['epochs'] * len(train_loader)
47
+ warmup_steps = min(500, total_steps // 10)
48
+
49
+ def lr_lambda(step):
50
+ if step < warmup_steps:
51
+ return step / warmup_steps
52
+ progress = (step - warmup_steps) / (total_steps - warmup_steps)
53
+ return 0.01 + 0.99 * 0.5 * (1 + math.cos(math.pi * progress))
54
+
55
+ import math
56
+ scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
57
+
58
+ best_hr10, best_epoch, best_state = 0, 0, None
59
+
60
+ for epoch in range(1, config['epochs'] + 1):
61
+ model.train()
62
+ total_loss, n = 0, 0
63
+ t0 = time.time()
64
+
65
+ for batch in train_loader:
66
+ batch = {k: v.to(device) for k, v in batch.items()}
67
+ optimizer.zero_grad()
68
+ loss = model(batch)
69
+
70
+ if torch.isnan(loss):
71
+ print(f"WARNING: NaN loss at epoch {epoch}!")
72
+ continue
73
+
74
+ loss.backward()
75
+ torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
76
+ optimizer.step()
77
+ scheduler.step()
78
+ total_loss += loss.item()
79
+ n += 1
80
+
81
+ avg_loss = total_loss / max(n, 1)
82
+ ep_time = time.time() - t0
83
+ print(f"Epoch {epoch:3d}/{config['epochs']} | Loss: {avg_loss:.4f} | Time: {ep_time:.1f}s")
84
+
85
+ if use_trackio:
86
+ trackio.log({f"{model_name}/train_loss": avg_loss, "epoch": epoch})
87
+
88
+ if epoch % config['eval_interval'] == 0 or epoch == config['epochs']:
89
+ metrics = evaluate_model(model, val_loader, data.num_items, device, ks=[5, 10, 20, 50], full_ranking=True)
90
+ print(f" Val | HR@10={metrics['HR@10']:.4f} NDCG@10={metrics['NDCG@10']:.4f} MRR@10={metrics['MRR@10']:.4f}")
91
+
92
+ if use_trackio:
93
+ trackio.log({f"{model_name}/val_{k}": v for k, v in metrics.items() if k != 'eval_time'})
94
+
95
+ if metrics['HR@10'] > best_hr10:
96
+ best_hr10 = metrics['HR@10']
97
+ best_epoch = epoch
98
+ best_state = {k: v.cpu().clone() for k, v in model.state_dict().items()}
99
+ print(f" βœ“ New best! HR@10={best_hr10:.4f}")
100
+
101
+ if best_state:
102
+ model.load_state_dict(best_state)
103
+
104
+ test_metrics = evaluate_model(model, test_loader, data.num_items, device, ks=[5, 10, 20, 50], full_ranking=True)
105
+ print(f"\nTest ({model_name}, best ep {best_epoch}):")
106
+ for k, v in sorted(test_metrics.items()):
107
+ if k != 'eval_time': print(f" {k}: {v:.4f}")
108
+
109
+ save_dir = f'./checkpoints/{model_name}'
110
+ os.makedirs(save_dir, exist_ok=True)
111
+ torch.save({'model_state_dict': best_state or model.state_dict(), 'config': config,
112
+ 'test_metrics': test_metrics, 'best_epoch': best_epoch, 'num_items': data.num_items},
113
+ os.path.join(save_dir, 'best_model.pt'))
114
+
115
+ return test_metrics, sum(p.numel() for p in model.parameters())
116
+
117
+
118
+ # Configs
119
+ SASREC_CFG = {'max_seq_len': 128, 'batch_size': 128, 'lr': 1e-3, 'weight_decay': 0.0,
120
+ 'epochs': 25, 'num_negatives': 4, 'eval_interval': 5}
121
+ MARS_CFG = {'max_seq_len': 128, 'batch_size': 64, 'lr': 5e-4, 'weight_decay': 0.01,
122
+ 'epochs': 25, 'num_negatives': 4, 'eval_interval': 5}
123
+
124
+ # Precompute data for num_items
125
+ data_tmp = ReindexedData(sequences, max_seq_len=128)
126
+ num_items = data_tmp.num_items
127
+
128
+ # Models
129
+ sasrec = SASRecBaseline(num_items=num_items, embed_dim=64, max_seq_len=128, num_heads=2, num_layers=2, dropout=0.1)
130
+ marsv2 = MARSv2(num_items=num_items, embed_dim=64, max_seq_len=128, short_term_len=30,
131
+ num_memory_tokens=8, num_long_layers=3, num_short_layers=2, num_heads=2, dropout=0.1)
132
+
133
+ # Train
134
+ sasrec_results, sasrec_params = train_model('sasrec', sasrec, SASREC_CFG, device)
135
+ mars_results, mars_params = train_model('marsv2', marsv2, MARS_CFG, device)
136
+
137
+ # Compare
138
+ print_comparison(mars_results, sasrec_results, ks=[5, 10, 20, 50])
139
+
140
+ # Save
141
+ final = {
142
+ 'marsv2': {'metrics': mars_results, 'config': MARS_CFG, 'params': mars_params},
143
+ 'sasrec': {'metrics': sasrec_results, 'config': SASREC_CFG, 'params': sasrec_params},
144
+ 'dataset': 'MovieLens-1M',
145
+ }
146
+ os.makedirs('./checkpoints', exist_ok=True)
147
+ with open('./checkpoints/final_results.json', 'w') as f:
148
+ json.dump(final, f, indent=2, default=str)
149
+
150
+ # Push to Hub
151
+ try:
152
+ from huggingface_hub import HfApi, upload_folder
153
+ import shutil
154
+
155
+ hub_id = 'CyberDancer/MARS-SeqRec'
156
+ api = HfApi()
157
+ api.create_repo(hub_id, exist_ok=True)
158
+
159
+ for f in ['model.py', 'model_v2.py', 'data.py', 'evaluate.py', 'train.py', 'train_gpu.py', 'train_v2.py']:
160
+ if os.path.exists(f'/app/{f}'):
161
+ shutil.copy(f'/app/{f}', f'./checkpoints/{f}')
162
+
163
+ readme = f"""# MARS: Multi-scale Adaptive Recurrence with State compression
164
+
165
+ An innovative method for **super long sequence modeling** in sequential recommendation.
166
+
167
+ ## Architecture
168
+
169
+ ```
170
+ Input: User interaction sequence + timestamps
171
+ β”‚
172
+ β”œβ”€β”€ Long-term Branch (Temporal-Gated Linear Attention, O(n))
173
+ β”‚ β”‚
174
+ β”‚ [Compressive Memory] β†’ fixed-size memory tokens
175
+ β”‚ β”‚
176
+ β”œβ”€β”€ Short-term Branch (Causal Self-Attention, last K items)
177
+ β”‚
178
+ └── Adaptive Fusion Gate β†’ User Embedding β†’ Next Item Prediction
179
+ ```
180
+
181
+ ## Key Innovations
182
+
183
+ 1. **Temporal-Gated Linear Attention** β€” O(n) complexity via kernel trick (ELU+1 feature map) with learned temporal decay weighting per attention head
184
+ 2. **Compressive Memory Tokens** β€” Cross-attention bottleneck compresses full history into M fixed tokens
185
+ 3. **Dual-Branch with Adaptive Fusion** β€” Per-user gating balances long-term preferences and short-term intent
186
+ 4. **Multi-Scale Temporal Encoding** β€” Log-scaled time deltas + periodic components for daily/weekly patterns
187
+
188
+ ## Results on MovieLens-1M (Full Ranking, 3706 items)
189
+
190
+ | Model | Params | HR@5 | HR@10 | HR@20 | NDCG@10 | MRR@10 |
191
+ |-------|--------|------|-------|-------|---------|--------|
192
+ | SASRec | {sasrec_params:,} | {sasrec_results.get('HR@5',0):.4f} | {sasrec_results.get('HR@10',0):.4f} | {sasrec_results.get('HR@20',0):.4f} | {sasrec_results.get('NDCG@10',0):.4f} | {sasrec_results.get('MRR@10',0):.4f} |
193
+ | **MARS v2** | {mars_params:,} | {mars_results.get('HR@5',0):.4f} | {mars_results.get('HR@10',0):.4f} | {mars_results.get('HR@20',0):.4f} | {mars_results.get('NDCG@10',0):.4f} | {mars_results.get('MRR@10',0):.4f} |
194
+
195
+ ## Core Method: Temporal-Gated Linear Attention
196
+
197
+ Standard linear attention: `Attn(Q,K,V) = Ο†(Q)(Ο†(K)^T V) / Ο†(Q)Ο†(K)^T 1`
198
+
199
+ Our enhancement adds temporal gating:
200
+ ```
201
+ K_gated = K βŠ™ Οƒ(W_decay Β· log(1 + Ξ”t/3600))
202
+ ```
203
+ where `Ξ”t` is the inter-action time gap and `W_decay` is learned per attention head.
204
+
205
+ This gives O(n) complexity while explicitly modeling temporal dynamics β€” recent interactions get higher attention weight, with the decay rate learned per head.
206
+
207
+ ## Based On
208
+
209
+ - **HyTRec** (2602.18283) β€” Temporal-aware dual-branch architecture
210
+ - **Rec2PM** (2602.11605) β€” Compressive memory as information bottleneck
211
+ - **Linear Transformers** (Katharopoulos et al.) β€” Kernel-based linear attention
212
+ - **SASRec** (1808.09781) β€” Self-attentive sequential recommendation baseline
213
+
214
+ ## Usage
215
+
216
+ ```python
217
+ from model_v2 import MARSv2
218
+
219
+ model = MARSv2(
220
+ num_items=10000,
221
+ embed_dim=64,
222
+ max_seq_len=2048, # Handles very long sequences at O(n) cost
223
+ short_term_len=50,
224
+ num_memory_tokens=8,
225
+ num_long_layers=3,
226
+ num_short_layers=2,
227
+ )
228
+ ```
229
+ """
230
+
231
+ with open('./checkpoints/README.md', 'w') as f:
232
+ f.write(readme)
233
+
234
+ upload_folder(folder_path='./checkpoints', repo_id=hub_id,
235
+ commit_message="MARS v2: Temporal-Gated Linear Attention for SeqRec")
236
+ print(f"\nβœ“ Pushed to https://huggingface.co/{hub_id}")
237
+ except Exception as e:
238
+ print(f"Hub push: {e}")
239
+
240
+ print("\nDone!")