Pavantej commited on
Commit
fbca19f
·
verified ·
1 Parent(s): b75d8ec

Upload folder using huggingface_hub

Browse files
Files changed (6) hide show
  1. ESSAY.md +0 -256
  2. README.md +0 -74
  3. app.py +214 -232
  4. memory_store.py +0 -86
  5. miras_memory.py +0 -113
  6. projections.py +0 -83
ESSAY.md CHANGED
@@ -1,258 +1,3 @@
1
- <<<<<<< HEAD
2
- # When Models Learn While Thinking
3
-
4
- ---
5
-
6
- ## 01 / The Frozen Calculator Problem
7
-
8
- Every conversation you've had with ChatGPT, Claude, or any large language model follows the same pattern: the model thinks, predicts, and forgets. The weights that determine its behavior were set months ago, frozen in place after training. When you correct it, when you teach it something new, when you have a breakthrough conversation—none of that changes the model itself.
9
-
10
- This isn't a bug. It's the architecture.
11
-
12
- The model can *simulate* learning through in-context adaptation. It can act like it remembers. But the parameters that define its cognition remain untouched. When the context window ends, so does the illusion.
13
-
14
- This demo breaks that pattern.
15
-
16
- ---
17
-
18
- ## 02 / What This Actually Does
19
-
20
- This is a minimal reimplementation of two recent papers: **Titans** (test-time training) and **MIRAS** (associative memory with attentional bias). Together, they demonstrate something most production LLMs don't do: **learning during inference**.
21
-
22
- The architecture is simple:
23
- - A frozen language model (distilgpt2) generates text
24
- - Hidden states from that model are projected into a memory space
25
- - An associative memory module predicts what it should remember
26
- - The prediction error drives gradient descent
27
- - The memory weights update
28
- - The updated state persists to disk
29
-
30
- Every message you send changes the model's internal representations. Not through prompt engineering. Not through retrieval. Through actual gradient-based optimization—at inference time.
31
-
32
- ---
33
-
34
- ## 03 / The Text Doesn't Matter
35
-
36
- If you interact with this demo, you'll notice the text responses are... not good. Random. Sometimes incoherent. This is intentional.
37
-
38
- The text generator (distilgpt2) is frozen. We're not training it. The responses reflect what a small, untuned model produces when asked to continue arbitrary text. That's not the point.
39
-
40
- **The point is the numbers below each response.**
41
-
42
- Watch the loss. When you send the same message multiple times, the loss decreases. The memory is learning to predict the hidden state patterns associated with that input. When you send something completely different, the loss spikes—the memory is surprised.
43
-
44
- This is test-time learning. The model is changing itself while you use it.
45
-
46
- ---
47
-
48
- ## 04 / What the Stats Mean
49
-
50
- Each response shows four metrics:
51
-
52
- **Loss**: How surprised the memory is. Lower means the pattern is familiar. Higher means it's novel. This is the prediction error that drives learning.
53
-
54
- **Retention**: A multiplier on the learning rate. When loss is high (surprising input), retention is high (2.0x). The memory learns more aggressively from surprising events. This is the retention gate—a simple mechanism inspired by how human memory prioritizes novelty.
55
-
56
- **Updates**: The total number of times the memory has been updated. This persists across sessions. Refresh the page, send another message, and the count continues. The memory doesn't reset.
57
-
58
- **Avg Loss**: The running average of all losses. Over time, as the memory learns recurring patterns, this should trend downward.
59
-
60
- These aren't vanity metrics. They're the observable signature of gradient descent happening during inference.
61
-
62
- ---
63
-
64
- ## 05 / The Two Papers
65
-
66
- **Titans** (2025) introduces test-time training for language models. The core idea: instead of freezing weights after pre-training, allow a subset of parameters to update during inference. This creates a feedback loop—think, predict, update, think differently next time—that doesn't exist in standard LLMs.
67
-
68
- **MIRAS** (2024) reframes attention mechanisms as implicit optimization problems. It shows that dot-product attention, RNNs, and linear transformers are all solving online optimization with a specific loss function (L2). By making the loss function explicit and tunable, you can change the memory behavior. Different losses produce different cognition.
69
-
70
- This demo combines both: Titans' test-time learning with MIRAS's associative memory framework.
71
-
72
- ---
73
-
74
- ## 06 / What's Missing
75
-
76
- This is a minimal reimplementation. Several components from the papers are not included:
77
-
78
- **From Titans**:
79
- - Multi-layer test-time updates (we only update the memory module)
80
- - Task-specific memory partitioning (we use a single shared memory)
81
- - Adaptive learning rate schedules (we use a simple retention gate)
82
-
83
- **From MIRAS**:
84
- - Alternative loss functions (we use L2)
85
- - Multi-head memory (we use a single memory matrix)
86
- - Attention-based retrieval (we use direct key-value mapping)
87
-
88
- The goal was to demonstrate the core mechanism—learning during inference—not to replicate every detail. The full papers contain significantly more sophistication.
89
-
90
- ---
91
-
92
- ## 07 / The Difference from Standard LLMs
93
-
94
- Standard LLMs (ChatGPT, Claude, GPT-4) do this:
95
- ```
96
- Input → Frozen Weights → Output → Forget
97
- ```
98
-
99
- This demo does this:
100
- ```
101
- Input → Frozen LM → Hidden States → Memory (Learning) → Output → Save
102
- ```
103
-
104
- The frozen LM provides the text generation. The memory provides the learning. They're decoupled.
105
-
106
- This matters because:
107
- - **Weights update during use** (not just during training)
108
- - **Memory persists across sessions** (not just within a context window)
109
- - **Learning is explicit** (not simulated through in-context adaptation)
110
- - **The system becomes different** after each interaction
111
-
112
- In-context learning is pattern matching. This is optimization.
113
-
114
- ---
115
-
116
- ## 08 / What Problem This Solves
117
-
118
- The current paradigm for "adaptive" LLMs involves:
119
- - Vector databases for retrieval
120
- - Fine-tuning on user data (expensive, slow)
121
- - Prompt engineering (fragile, context-limited)
122
- - RAG systems (fetch, don't learn)
123
-
124
- None of these change the model itself. They work around the frozen weights.
125
-
126
- Test-time learning makes adaptation a first-class primitive. The model doesn't retrieve your preferences—it encodes them in its parameters. It doesn't simulate learning—it performs learning.
127
-
128
- This opens up:
129
- - **Personalization without fine-tuning** (the model adapts to you as you use it)
130
- - **Continual learning** (the model improves from every interaction)
131
- - **Transparent memory** (you can inspect what it learned)
132
- - **Efficient adaptation** (gradient descent is cheaper than retraining)
133
-
134
- ---
135
-
136
- ## 09 / What This Means for AI's Future
137
-
138
- The industry is converging on a model: train once, deploy frozen, scale through retrieval. This works. But it's not the only path.
139
-
140
- Test-time learning suggests a different trajectory: models that are **living systems**, not static calculators. Systems that don't just respond to you—they change because of you.
141
-
142
- This has implications:
143
- - **Privacy**: Your data updates your local model, not a shared cloud model
144
- - **Efficiency**: Learning happens incrementally, not in massive retraining runs
145
- - **Alignment**: The model adapts to your values through interaction, not through RLHF on aggregate data
146
- - **Transparency**: You can see what the model learned, reset it, or fork it
147
-
148
- The tradeoff is complexity. A model that changes during use is harder to reason about, harder to debug, harder to guarantee. But the benefits—true personalization, continual improvement, user-specific adaptation—may be worth it.
149
-
150
- ---
151
-
152
- ## 10 / The Retention Gate
153
-
154
- One detail worth highlighting: the retention gate.
155
-
156
- When the memory encounters a high-loss input (surprising, novel), it increases the learning rate. When it encounters a low-loss input (familiar, repeated), it decreases the learning rate.
157
-
158
- This is a simple heuristic, but it mirrors how human memory works. We remember surprising events more vividly than routine ones. The retention gate makes the memory selective—it learns more from what it doesn't already know.
159
-
160
- In this demo, retention is always 2.0x because the memory is fresh. Everything is surprising. After hundreds of interactions, you'd see retention vary—0.5x for familiar patterns, 2.0x for novel ones. The memory would become selective.
161
-
162
- ---
163
-
164
- ## 11 / Why the Memory is Shared
165
-
166
- This demo uses a single, shared memory across all users. This is intentional.
167
-
168
- It demonstrates that the memory is not user-specific. It's a collective brain. Every user's input updates the same weight matrix. This makes the learning observable—you can see the loss decrease as the memory encounters repeated patterns from different users.
169
-
170
- In a production system, you'd likely use per-user memory. But for a demo, shared memory makes the learning more visible and the privacy implications simpler (there are none—no user data is stored).
171
-
172
- ---
173
-
174
- ## 12 / The Bandwidth Constraint
175
-
176
- One reason LLMs feel static is that they operate at the wrong bandwidth. The only way to change their behavior is to retrain them—a process that costs millions and takes weeks. Users can't influence the model in real time.
177
-
178
- Test-time learning changes the bandwidth. The model updates with every message. The feedback loop tightens from months to milliseconds.
179
-
180
- This doesn't mean the model becomes smarter. It means the model becomes *responsive*. It adapts to the distribution of inputs it actually sees, not the distribution it was trained on.
181
-
182
- ---
183
-
184
- ## 13 / What You're Actually Watching
185
-
186
- When you interact with this demo, you're not chatting with a model. You're watching a memory module learn to compress hidden state patterns into a 256-dimensional space.
187
-
188
- The text generation is a side effect. The real process is:
189
- - Extract hidden states from the frozen LM
190
- - Project them into memory space
191
- - Predict what the memory should encode
192
- - Compute the error
193
- - Update the weights
194
- - Save the new state
195
-
196
- This happens for every message. The memory is always learning. The loss is always updating. The system is always changing.
197
-
198
- That's the difference. Standard LLMs are frozen calculators. This is a living system.
199
-
200
- ---
201
-
202
- ## 14 / The Horizon
203
-
204
- This demo is a proof of concept. It's not production-ready. It's not optimized. It's not aligned. But it demonstrates a principle: **models can learn while they think**.
205
-
206
- The implications ripple outward:
207
- - What if your AI assistant remembered how you corrected it?
208
- - What if your code completion tool learned your style over time?
209
- - What if your search engine adapted to your information needs?
210
- - What if alignment happened through interaction, not through pre-training?
211
-
212
- These aren't hypotheticals. They're design choices. The architecture exists. The papers are published. The code is open.
213
-
214
- The question is whether we build systems that simulate learning or systems that perform learning.
215
-
216
- This demo chooses the latter.
217
-
218
- ---
219
-
220
- ## 15 / A Note on Hype
221
-
222
- Test-time learning is not a silver bullet. It introduces complexity, instability, and new failure modes. A model that changes during use is harder to trust, harder to audit, harder to guarantee.
223
-
224
- But it's also more adaptive, more personal, more aligned with how humans actually learn.
225
-
226
- The industry will likely converge on hybrid systems: frozen base models with test-time learning in specific modules. The best of both worlds—stability where you need it, adaptation where you want it.
227
-
228
- This demo is a step in that direction. Not the destination. Just a clearer mental model of what's possible.
229
-
230
- ---
231
-
232
- ## 16 / The Core Metaphor
233
-
234
- Standard LLMs are like libraries. They contain vast knowledge, but they don't change when you visit them. You can check out books (retrieve information), but the library itself remains static.
235
-
236
- Test-time learning is like a brain. It changes with every experience. The connections strengthen or weaken based on what you encounter. The system becomes different because of you.
237
-
238
- Both are useful. But they're not the same thing.
239
-
240
- This demo is a brain, not a library.
241
-
242
- ---
243
-
244
- **Papers**:
245
- - Titans: Learning to (Learn at Test Time): RNNs with Expressive Hidden States ([arxiv.org/abs/2501.00663](https://arxiv.org/abs/2501.00663))
246
- - MIRAS: Associative Memory with Attentional Bias ([arxiv.org/abs/2504.13173](https://arxiv.org/abs/2504.13173))
247
-
248
- **Code**: Open source, minimal, educational.
249
- **Memory**: Shared, persistent, observable.
250
- **Learning**: Real, not simulated.
251
-
252
- ---
253
-
254
- *This is not a chatbot. This is a demonstration of what happens when models learn while thinking.*
255
- =======
256
  # When Models Learn While Thinking
257
 
258
  ---
@@ -506,4 +251,3 @@ This demo is a brain, not a library.
506
  ---
507
 
508
  *This is not a chatbot. This is a demonstration of what happens when models learn while thinking.*
509
- >>>>>>> origin/main
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  # When Models Learn While Thinking
2
 
3
  ---
 
251
  ---
252
 
253
  *This is not a chatbot. This is a demonstration of what happens when models learn while thinking.*
 
README.md CHANGED
@@ -1,76 +1,3 @@
1
- <<<<<<< HEAD
2
- ---
3
- title: Titans Miras Demo
4
- emoji: 🔬
5
- colorFrom: blue
6
- colorTo: purple
7
- sdk: gradio
8
- sdk_version: 4.36.1
9
- app_file: app.py
10
- pinned: false
11
- ---
12
-
13
- # Titans + MIRAS: A Brain That Changes Itself While Thinking
14
-
15
- A minimal but faithful reimplementation of **Titans** (test-time learning) and **MIRAS** (associative memory framework) using open-source models on Hugging Face.
16
-
17
- ## What is this?
18
-
19
- This demo showcases a neural architecture that can **learn and update its memory while generating responses** - a brain that literally changes itself while thinking!
20
-
21
- ### Key Features
22
-
23
- - 🔄 **Test-time learning**: Memory updates during inference (not just training)
24
- - 🎯 **Retention gate**: Surprising/novel inputs are more memorable (inspired by human memory)
25
- - 💾 **Persistent memory**: State is saved across sessions
26
- - 🤖 **Fully OSS**: Uses distilgpt2 and runs entirely on Hugging Face
27
-
28
- ## Architecture
29
-
30
- ```
31
- User Input
32
-
33
- [Base LM: distilgpt2] → Hidden States (768-dim)
34
-
35
- [Key/Value Projections] → Memory Space (256-dim)
36
-
37
- [MIRAS Memory Module] ← Test-time Gradient Updates
38
-
39
- [Text Generation] → Response + Memory Stats
40
- ```
41
-
42
- ### Components
43
-
44
- 1. **Base Language Model**: distilgpt2 (frozen, no training)
45
- 2. **Projection Layers**: Map hidden states to memory space
46
- 3. **MIRAS Memory**: Associative memory with learnable key→value mapping
47
- 4. **Retention Gate**: Adjusts learning rate based on surprise (loss magnitude)
48
- 5. **Memory Store**: Persists memory state to disk
49
-
50
- ## How It Works
51
-
52
- 1. Input text is processed through distilgpt2
53
- 2. Last hidden state is projected to key/value pairs
54
- 3. Memory predicts value from key
55
- 4. Loss (prediction error) indicates surprise
56
- 5. Higher surprise → higher retention → faster learning
57
- 6. Memory updated via gradient descent (1e-3 base LR)
58
- 7. Response generated and memory saved
59
-
60
- ## References
61
-
62
- - **Titans**: [Learning to Memorize at Test Time](https://arxiv.org/abs/2501.00663)
63
- - **MIRAS**: [Framework for Associative Memory with Attentional Bias](https://arxiv.org/abs/2504.13173)
64
-
65
- ## Running Locally
66
-
67
- ```bash
68
- pip install -r requirements.txt
69
- python app.py
70
- ```
71
-
72
- Built with ❤️ exploring the future of adaptive AI systems.
73
- =======
74
  ---
75
  title: Titans Miras Demo
76
  emoji: 🔬
@@ -142,4 +69,3 @@ python app.py
142
  ```
143
 
144
  Built with ❤️ exploring the future of adaptive AI systems.
145
- >>>>>>> origin/main
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  ---
2
  title: Titans Miras Demo
3
  emoji: 🔬
 
69
  ```
70
 
71
  Built with ❤️ exploring the future of adaptive AI systems.
 
app.py CHANGED
@@ -12,11 +12,7 @@ from transformers import AutoTokenizer, AutoModelForCausalLM
12
  import gradio as gr
13
 
14
  from miras_memory import MIRASMemory
15
- <<<<<<< HEAD
16
- from projections import KeyProjection, ValueProjection, OutputProjection
17
- =======
18
  from projections import KeyProjection, ValueProjection
19
- >>>>>>> origin/main
20
  from memory_store import MemoryStore
21
 
22
  print("=" * 50)
@@ -30,10 +26,6 @@ HIDDEN_DIM = 768 # distilgpt2 hidden dimension
30
  MEMORY_DIM = 256 # Memory space dimension
31
  LEARNING_RATE = 1e-3 # Base learning rate for test-time updates
32
  MAX_NEW_TOKENS = 50 # Max tokens to generate
33
- <<<<<<< HEAD
34
- MEMORY_ALPHA = 0.1 # Memory influence strength on generation
35
- =======
36
- >>>>>>> origin/main
37
 
38
  # ========== Initialize Components ==========
39
  print("🧠 Initializing Titans + MIRAS brain...")
@@ -47,10 +39,6 @@ model.eval() # Frozen - no training
47
  # Create projection layers
48
  key_proj = KeyProjection(HIDDEN_DIM, MEMORY_DIM)
49
  value_proj = ValueProjection(HIDDEN_DIM, MEMORY_DIM)
50
- <<<<<<< HEAD
51
- output_proj = OutputProjection(MEMORY_DIM, HIDDEN_DIM) # Map memory back to hidden space
52
- =======
53
- >>>>>>> origin/main
54
 
55
  # Create memory module
56
  memory = MIRASMemory(memory_dim=MEMORY_DIM, init_scale=0.01)
@@ -112,44 +100,6 @@ def chat(message, history):
112
  # Update stats
113
  memory.update_stats(loss)
114
 
115
- <<<<<<< HEAD
116
- # === Step 3: Memory-augmented generation (THE KEY CHANGE!) ===
117
- # Instead of model.generate(), we do token-by-token generation
118
- # where memory augments the hidden state before prediction.
119
-
120
- generated_ids = inputs['input_ids'].clone()
121
-
122
- with torch.no_grad():
123
- for _ in range(MAX_NEW_TOKENS):
124
- # Get hidden states from model
125
- outputs = model(generated_ids, output_hidden_states=True)
126
- h_last = outputs.hidden_states[-1][:, -1, :] # (1, hidden_dim)
127
-
128
- # Query memory with projected key
129
- k_gen = key_proj(h_last)
130
- memory_out = memory.query(k_gen) # (1, memory_dim)
131
-
132
- # Augment hidden state with memory output
133
- # h' = h + alpha * output_proj(memory(k))
134
- h_augmented = h_last + MEMORY_ALPHA * output_proj(memory_out)
135
-
136
- # Compute logits with augmented hidden state
137
- logits = model.lm_head(h_augmented) # (1, vocab_size)
138
-
139
- # Sample next token (temperature sampling)
140
- logits = logits / 0.8 # temperature
141
- probs = torch.softmax(logits, dim=-1)
142
- next_token = torch.multinomial(probs, num_samples=1)
143
-
144
- # Stop if EOS
145
- if next_token.item() == tokenizer.eos_token_id:
146
- break
147
-
148
- # Append to sequence
149
- generated_ids = torch.cat([generated_ids, next_token], dim=1)
150
-
151
- response = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
152
- =======
153
  # === Step 3: Generate response ===
154
  with torch.no_grad():
155
  output_ids = model.generate(
@@ -162,7 +112,6 @@ def chat(message, history):
162
  )
163
 
164
  response = tokenizer.decode(output_ids[0], skip_special_tokens=True)
165
- >>>>>>> origin/main
166
 
167
  # Remove the input prompt from response
168
  if response.startswith(message):
@@ -192,202 +141,235 @@ def chat(message, history):
192
  # ========== Gradio Interface ==========
193
  print("🚀 Launching Gradio interface...")
194
 
195
- with gr.Blocks(theme="soft", title="The Brain That Learns While Thinking") as demo:
 
 
 
 
196
 
197
- gr.Markdown("""
198
- # 🧠 The Brain That Learns While Thinking
199
- ### A Living System That Updates Its Neural Weights During Inference
200
 
201
- **What This Does**: Demonstrates test-time learning - the model's memory weights update via gradient descent with every message you send.
202
 
203
- **The Novel Thing**: Standard LLMs (ChatGPT, Claude) freeze their weights after training. This system performs **real gradient descent while you chat**.
204
 
205
- **Quick Test**: Send "hello world" 5 times and watch the **Loss** decrease below each response. That's learning happening in real-time!
206
- """)
207
 
208
- chatbot = gr.Chatbot(
209
- label="Chat & Watch the Memory Learn",
210
- height=500,
211
- )
212
 
213
- msg = gr.Textbox(
214
- label="Your Message",
215
- placeholder="Try: hello world (send it 5 times and watch loss decrease!)",
216
- lines=2,
217
- )
218
 
219
- with gr.Row():
220
- submit = gr.Button("Send", variant="primary")
221
- clear = gr.Button("Clear")
222
-
223
- gr.Examples(
224
- examples=[
225
- "hello world",
226
- "hello world",
227
- "Supercalifragilisticexpialidocious quantum entanglement",
228
- "my name is Pavan",
229
- ],
230
- inputs=msg,
231
- label="Try these (especially repeat 'hello world'!)",
232
- )
233
 
234
- # Built with section
235
- gr.Markdown("""
236
  ---
237
 
238
- **Built with**: [Titans](https://arxiv.org/abs/2501.00663) (test-time training) + [MIRAS](https://arxiv.org/abs/2504.13173) (associative memory)
239
- **📖 Deep Dive**: [Read the full essay](https://huggingface.co/spaces/Pavantej/titans-miras-demo/blob/main/ESSAY.md)
240
- """)
241
-
242
- # Detailed information in accordions
243
- with gr.Accordion("📊 What the Stats Mean", open=False):
244
- gr.Markdown("""
245
- **Loss** (e.g., 7.48 6.61 5.23)
246
- - Prediction error - how surprised the memory is
247
- - **Lower = memory is familiar** with this pattern
248
- - **Decreasing loss = learning is happening!**
249
-
250
- **Retention** (e.g., 2.00x)
251
- - Learning rate multiplier based on surprise
252
- - 2.0x = very surprising, learns aggressively
253
- - 0.5x = familiar, learns slowly
254
-
255
- **Updates** (e.g., 1 → 2 → 3...)
256
- - Total memory updates
257
- - **Persists across page refreshes!**
258
- - Proof that memory is permanent
259
-
260
- **Avg Loss** (e.g., 7.26)
261
- - Running average showing long-term learning progress
262
- """)
263
-
264
- with gr.Accordion("🧪 Interactive Experiments", open=False):
265
- gr.Markdown("""
266
- ### Experiment 1: Watch Loss Decrease
267
- 1. Send "hello world" 5 times
268
- 2. Watch loss: 7.5 → 6.0 → 5.0 → 4.0
269
- 3. **This proves learning!**
270
-
271
- ### Experiment 2: Trigger Surprise
272
- 1. After experiment 1, send something completely different
273
- 2. Watch loss spike back up (4.0 → 9.0+)
274
- 3. **Memory detects novelty!**
275
-
276
- ### Experiment 3: Test Persistence
277
- 1. Note the "Updates" count
278
- 2. Refresh this entire page
279
- 3. Send any message
280
- 4. Updates should continue, not reset!
281
- 5. **Memory survives refresh!**
282
- """)
283
-
284
- <<<<<<< HEAD
285
- with gr.Accordion("✨ Memory-Augmented Generation", open=False):
286
- gr.Markdown("""
287
- **Memory now influences text generation!**
288
-
289
- - At each token generation step, we query the memory
290
- - Memory output is added to the hidden state: `h' = h + α × memory(k)`
291
- - This augmented state determines the next token
292
-
293
- **What to expect:**
294
- - Repeated inputs → more consistent outputs
295
- - As memory learns patterns, it biases generation toward them
296
- - Novel inputs → more random outputs (memory has no prior)
297
- =======
298
- with gr.Accordion("⚠️ Important: Ignore the Text", open=False):
299
- gr.Markdown("""
300
- **The text responses are random** - this is expected!
301
-
302
- - We're NOT training the text generator (distilgpt2 is frozen)
303
- - **Focus on the numbers below each response**
304
- - The magic is in the decreasing loss, not the text
305
-
306
- **Why?** We're demonstrating **memory learning**, not text generation.
307
- >>>>>>> origin/main
308
- """)
309
-
310
- with gr.Accordion("🔬 How It Works", open=False):
311
- gr.Markdown("""
312
- ```
313
- Your Message
314
-
315
- <<<<<<< HEAD
316
- [distilgpt2: FROZEN]
317
-
318
- Hidden States (768-dim)
319
-
320
- [Key Projection] → Memory (256-dim)
321
-
322
- [MIRAS Memory: LEARNING!]
323
-
324
- [Output Projection] → Augmentation (768-dim)
325
-
326
- h' = h + α × memory_output ← THE KEY!
327
-
328
- LM Head → Next Token
329
- ```
330
-
331
- **Key**: Memory modifies hidden states before prediction!
332
- =======
333
- [distilgpt2: FROZEN] ← Not learning
334
-
335
- Hidden States (768-dim)
336
-
337
- [Projections] → Memory (256-dim)
338
-
339
- [MIRAS Memory: LEARNING!]
340
-
341
- Loss = Prediction error
342
-
343
- Gradient Descent → Weights change
344
-
345
- Saved to disk → Persists
346
- ```
347
-
348
- **Key**: We're training the **memory**, not the text generator!
349
- >>>>>>> origin/main
350
- """)
351
-
352
- with gr.Accordion("💡 Why This Matters", open=False):
353
- gr.Markdown("""
354
- **Standard LLMs** (ChatGPT, Claude):
355
- - Weights frozen after training
356
- - "Learning" is just pattern matching
357
- - Forget when context ends
358
-
359
- **This Demo** (Titans + MIRAS):
360
- - Weights update during inference
361
- - Real gradient descent
362
- - Memory persists across sessions
363
-
364
- That decreasing loss? **That's gradient descent during inference.**
365
- That's what ChatGPT doesn't do.
366
- """)
367
 
368
- # Footer
369
- gr.Markdown("""
370
  ---
371
 
372
- **Vibecoded by [Pavan Tej](https://github.com/thepavantejz)**
373
- """)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
374
 
375
- # Event handlers
376
- def user_submit(user_message, history):
377
- return "", history + [[user_message, None]]
378
 
379
- def bot_response(history):
380
- user_message = history[-1][0]
381
- bot_message = chat(user_message, history[:-1])
382
- history[-1][1] = bot_message
383
- return history
384
 
385
- msg.submit(user_submit, [msg, chatbot], [msg, chatbot], queue=False).then(
386
- bot_response, chatbot, chatbot
387
- )
388
- submit.click(user_submit, [msg, chatbot], [msg, chatbot], queue=False).then(
389
- bot_response, chatbot, chatbot
390
- )
391
- clear.click(lambda: None, None, chatbot, queue=False)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
392
 
393
  demo.launch()
 
12
  import gradio as gr
13
 
14
  from miras_memory import MIRASMemory
 
 
 
15
  from projections import KeyProjection, ValueProjection
 
16
  from memory_store import MemoryStore
17
 
18
  print("=" * 50)
 
26
  MEMORY_DIM = 256 # Memory space dimension
27
  LEARNING_RATE = 1e-3 # Base learning rate for test-time updates
28
  MAX_NEW_TOKENS = 50 # Max tokens to generate
 
 
 
 
29
 
30
  # ========== Initialize Components ==========
31
  print("🧠 Initializing Titans + MIRAS brain...")
 
39
  # Create projection layers
40
  key_proj = KeyProjection(HIDDEN_DIM, MEMORY_DIM)
41
  value_proj = ValueProjection(HIDDEN_DIM, MEMORY_DIM)
 
 
 
 
42
 
43
  # Create memory module
44
  memory = MIRASMemory(memory_dim=MEMORY_DIM, init_scale=0.01)
 
100
  # Update stats
101
  memory.update_stats(loss)
102
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
103
  # === Step 3: Generate response ===
104
  with torch.no_grad():
105
  output_ids = model.generate(
 
112
  )
113
 
114
  response = tokenizer.decode(output_ids[0], skip_special_tokens=True)
 
115
 
116
  # Remove the input prompt from response
117
  if response.startswith(message):
 
141
  # ========== Gradio Interface ==========
142
  print("🚀 Launching Gradio interface...")
143
 
144
+ demo = gr.ChatInterface(
145
+ fn=chat,
146
+ title="🧠 The Brain That Learns While Thinking",
147
+ description="""
148
+ # A Living System That Updates Its Weights During Inference
149
 
150
+ **The Novel Thing**: Standard LLMs freeze their weights after training. This system performs gradient descent *while you chat*.
 
 
151
 
152
+ ---
153
 
154
+ ## 🚀 The Revolutionary Difference
155
 
156
+ **Standard LLMs (ChatGPT, Claude, etc.)**: Think Predict **Forget**
157
+ **Titans + MIRAS**: Think → Predict → **Update** → **Remember** → Think Differently
158
 
159
+ ---
 
 
 
160
 
161
+ ### 💡 What Makes This Different?
 
 
 
 
162
 
163
+ | Feature | ChatGPT/Claude/GPT-4 | This Demo (Titans+MIRAS) |
164
+ |---------|---------------------|--------------------------|
165
+ | **Weights during chat** | 🔒 Frozen forever | ✅ Update with every message |
166
+ | **Learning** | ❌ Simulated (in-context only) | ✅ Real (gradient descent) |
167
+ | **Memory** | 📝 Token context only | 🧠 Neural parameters |
168
+ | **Persistence** | ❌ Forgets when context ends | ✅ Saves to disk |
169
+ | **Adaptation** | 🎭 Acts like it learned | 🔬 Actually learns |
 
 
 
 
 
 
 
170
 
 
 
171
  ---
172
 
173
+ ### 🎯 What You're Witnessing
174
+
175
+ **This is NOT a better chatbot** - it's a **learning demonstrator**.
176
+
177
+ 1. **The text responses are random** - that's expected! We're using a small, frozen model (distilgpt2)
178
+ 2. **The MAGIC is in the numbers below** - watch the "Loss" decrease when you repeat inputs!
179
+ 3. **Every message physically changes the brain** - the memory weights update via gradient descent
180
+ 4. **Refresh the page** - the update count continues (memory persists!)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
181
 
 
 
182
  ---
183
 
184
+ ### 🧪 How It Works (The Technical Truth)
185
+
186
+ ```
187
+ Your Message
188
+
189
+ [distilgpt2: FROZEN] ← Not learning, just generating
190
+
191
+ Hidden States (768-dim)
192
+
193
+ [Projections] → Memory Space (256-dim)
194
+
195
+ [MIRAS Memory: LEARNING!] ← This is what updates!
196
+
197
+ Loss = How surprised the memory is
198
+
199
+ Gradient Descent → Memory weights change
200
+
201
+ Saved to disk → Persists forever
202
+ ```
203
+
204
+ **Key Insight**: We're training the **memory**, not the text generator!
205
 
206
+ ---
 
 
207
 
208
+ ### 🔬 The Science: Why This Matters
 
 
 
 
209
 
210
+ **Standard LLMs**:
211
+ - Weights frozen after training (costs millions)
212
+ - "Learning" is just pattern matching in context
213
+ - Forget everything when context ends
214
+ - Same model for everyone
215
+
216
+ **Titans + MIRAS**:
217
+ - Weights update during inference (free!)
218
+ - Real optimization via gradient descent
219
+ - Memory persists across sessions
220
+ - Personalizes to each user
221
+
222
+ **This is test-time learning** - the future of adaptive AI.
223
+
224
+ ---
225
+
226
+ ### 📊 What the Stats Mean
227
+
228
+ - **Loss**: How surprised the memory is (lower = more familiar)
229
+ - **Retention**: Learning rate multiplier (2.0x = very surprising, 0.5x = familiar)
230
+ - **Updates**: Total number of memory updates (persists across sessions!)
231
+ - **Avg Loss**: Overall learning progress
232
+
233
+ ---
234
+
235
+ ### 🎮 Try This Experiment
236
+
237
+ 1. **Send "hello world" 5 times** → Watch loss decrease!
238
+ 2. **Send something completely different** → Loss spikes!
239
+ 3. **Refresh the page and send another message** → Update count continues!
240
+
241
+ **That decreasing loss is proof the neural weights are changing!**
242
+
243
+ ---
244
+
245
+ ### 🌟 The Bottom Line
246
+
247
+ **ChatGPT**: A frozen calculator that *simulates* adaptation
248
+ **This Demo**: A living system that *performs* adaptation
249
+
250
+ You're not chatting with a model.
251
+ **You're watching a brain rewire itself in real-time.** 🧠⚡
252
+
253
+ ---
254
+
255
+ ### 🧪 How to Test This (Interactive Experiments)
256
+
257
+ **Don't just chat—run experiments to see the learning happen!**
258
+
259
+ #### Experiment 1: Watch Loss Decrease (Proof of Learning)
260
+ ```
261
+ 1. Send "hello world"
262
+ 2. Send "hello world" again
263
+ 3. Send "hello world" again
264
+ 4. Send "hello world" again
265
+ 5. Send "hello world" again
266
+ ```
267
+ **What to watch**: Loss should decrease each time (7.5 → 6.0 → 5.0 → 4.0)
268
+ **Why it matters**: This proves the memory is learning the pattern!
269
+
270
+ #### Experiment 2: Trigger Surprise (Spike the Loss)
271
+ ```
272
+ 1. Send "hello world" 5 times (loss decreases)
273
+ 2. Then send: "Supercalifragilisticexpialidocious quantum entanglement"
274
+ ```
275
+ **What to watch**: Loss should spike back up (4.0 → 9.0+)
276
+ **Why it matters**: The memory detects novelty—it knows this is different!
277
+
278
+ #### Experiment 3: Test Persistence (Memory Survives)
279
+ ```
280
+ 1. Note the "Updates" count (e.g., 15)
281
+ 2. Refresh this page completely
282
+ 3. Send any message
283
+ 4. Check if Updates = 16 (not reset to 1!)
284
+ ```
285
+ **What to watch**: Update count should continue, not reset
286
+ **Why it matters**: Memory persists to disk—it's not just in RAM!
287
+
288
+ ---
289
+
290
+ ### 📊 What Each Stat Means (Decoder Ring)
291
+
292
+ **Loss** (e.g., 7.48 → 6.61 → 5.23)
293
+ - **What it is**: Prediction error (how surprised the memory is)
294
+ - **Lower = Better**: Memory is familiar with this pattern
295
+ - **Higher = Novel**: Memory hasn't seen this before
296
+ - **Why it matters**: Decreasing loss = learning is happening!
297
+
298
+ **Retention** (e.g., 2.00x)
299
+ - **What it is**: Learning rate multiplier based on surprise
300
+ - **2.0x = Very surprising**: Memory learns aggressively
301
+ - **0.5x = Very familiar**: Memory learns slowly (you won't see this yet)
302
+ - **Why it matters**: The brain learns more from surprising events (like humans!)
303
+
304
+ **Updates** (e.g., 1 → 2 → 3 → 4...)
305
+ - **What it is**: Total number of memory updates
306
+ - **Persists across sessions**: Survives page refreshes
307
+ - **Never resets**: Keeps counting forever
308
+ - **Why it matters**: Proof that memory is persistent, not ephemeral!
309
+
310
+ **Avg Loss** (e.g., 7.26)
311
+ - **What it is**: Running average of all losses
312
+ - **Trends downward**: As memory learns recurring patterns
313
+ - **Reflects overall learning**: Lower = memory is getting smarter
314
+ - **Why it matters**: Shows long-term learning progress!
315
+
316
+ ---
317
+
318
+ ### ⚠️ What to Ignore (Important!)
319
+
320
+ **The text responses are random and bad** - this is expected!
321
+ - We're NOT training the text generator (distilgpt2 is frozen)
322
+ - The responses don't matter—they're a side effect
323
+ - **Focus on the numbers below**, not the text above
324
+ - The magic is in the decreasing loss, not the generated text
325
+
326
+ **Why?** Because we're demonstrating **memory learning**, not text generation.
327
+ Standard LLMs train the text generator. This trains the memory. Different goals.
328
+
329
+ ---
330
+
331
+ ### 🎯 What Success Looks Like
332
+
333
+ ✅ **You're seeing it work if**:
334
+ - Loss decreases when you repeat inputs
335
+ - Loss spikes when you send something new
336
+ - Update count increments with each message
337
+ - Update count persists after page refresh
338
+ - Retention is 2.0x (everything is surprising to fresh memory)
339
+
340
+ ❌ **You're NOT seeing it work if**:
341
+ - Loss stays constant (not learning)
342
+ - Updates reset to 1 after refresh (not persisting)
343
+ - No stats appear below responses
344
+
345
+ ---
346
+
347
+ ### 🔬 Why This Matters (The Big Picture)
348
+
349
+ **Standard LLMs**: Frozen weights → No learning during use
350
+ **This Demo**: Live weights → Learning with every message
351
+
352
+ That decreasing loss you see? **That's gradient descent happening during inference.**
353
+ That's the revolution. That's what ChatGPT doesn't do.
354
+
355
+ You're not just using a model. **You're watching it change.**
356
+
357
+ ---
358
+
359
+ *Built with Titans (test-time training) + MIRAS (associative memory)*
360
+ *Papers: [Titans](https://arxiv.org/abs/2501.00663) | [MIRAS](https://arxiv.org/abs/2504.13173)*
361
+
362
+ **📖 [Read the full essay: "When Models Learn While Thinking"](https://huggingface.co/spaces/Pavantej/titans-miras-demo/blob/main/ESSAY.md)**
363
+ """,
364
+ examples=[
365
+ "hello world",
366
+ "hello world", # Repeat to show learning!
367
+ "Tell me about test-time learning",
368
+ "What is 2+2?",
369
+ "my name is [your name]",
370
+ ],
371
+ cache_examples=False,
372
+ theme="soft",
373
+ )
374
 
375
  demo.launch()
memory_store.py CHANGED
@@ -1,88 +1,3 @@
1
- <<<<<<< HEAD
2
- """
3
- Memory Persistence
4
-
5
- Handles saving and loading memory state to/from disk so the brain
6
- remembers across sessions.
7
- """
8
-
9
- import torch
10
- import json
11
- import os
12
- from pathlib import Path
13
- from datetime import datetime
14
-
15
-
16
- class MemoryStore:
17
- """Manages persistent storage of memory state."""
18
-
19
- def __init__(self, save_dir="memory"):
20
- self.save_dir = Path(save_dir)
21
- self.save_dir.mkdir(exist_ok=True)
22
- self.memory_path = self.save_dir / "memory.pt"
23
- self.metadata_path = self.save_dir / "metadata.json"
24
-
25
- def save(self, memory_module):
26
- """
27
- Save memory state to disk.
28
-
29
- Args:
30
- memory_module: MIRASMemory instance
31
- """
32
- # Save memory weights
33
- torch.save({
34
- 'W': memory_module.W.data,
35
- 'update_count': memory_module.update_count,
36
- 'total_loss': memory_module.total_loss,
37
- }, self.memory_path)
38
-
39
- # Save metadata
40
- metadata = {
41
- 'last_updated': datetime.now().isoformat(),
42
- 'memory_dim': memory_module.memory_dim,
43
- 'updates': memory_module.update_count.item(),
44
- 'avg_loss': (memory_module.total_loss / max(memory_module.update_count, 1)).item(),
45
- }
46
-
47
- with open(self.metadata_path, 'w') as f:
48
- json.dump(metadata, f, indent=2)
49
-
50
- print(f"💾 Memory saved: {memory_module.update_count.item()} updates")
51
-
52
- def load(self, memory_module):
53
- """
54
- Load memory state from disk.
55
-
56
- Args:
57
- memory_module: MIRASMemory instance to load into
58
-
59
- Returns:
60
- bool: True if loaded successfully, False otherwise
61
- """
62
- if not self.memory_path.exists():
63
- print("🆕 No saved memory found. Starting fresh!")
64
- return False
65
-
66
- try:
67
- checkpoint = torch.load(self.memory_path)
68
- memory_module.W.data = checkpoint['W']
69
- memory_module.update_count = checkpoint['update_count']
70
- memory_module.total_loss = checkpoint['total_loss']
71
-
72
- print(f"✅ Memory loaded: {memory_module.update_count.item()} updates")
73
- return True
74
- except Exception as e:
75
- print(f"⚠️ Error loading memory: {e}. Starting fresh!")
76
- return False
77
-
78
- def get_metadata(self):
79
- """Get metadata about saved memory."""
80
- if not self.metadata_path.exists():
81
- return None
82
-
83
- with open(self.metadata_path, 'r') as f:
84
- return json.load(f)
85
- =======
86
  """
87
  Memory Persistence
88
 
@@ -166,4 +81,3 @@ class MemoryStore:
166
 
167
  with open(self.metadata_path, 'r') as f:
168
  return json.load(f)
169
- >>>>>>> origin/main
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  """
2
  Memory Persistence
3
 
 
81
 
82
  with open(self.metadata_path, 'r') as f:
83
  return json.load(f)
 
miras_memory.py CHANGED
@@ -1,115 +1,3 @@
1
- <<<<<<< HEAD
2
- """
3
- MIRAS-inspired Associative Memory Module
4
-
5
- Implements an associative memory that learns key-value mappings
6
- through attentional bias objective during test time.
7
- """
8
-
9
- import torch
10
- import torch.nn as nn
11
-
12
-
13
- class MIRASMemory(nn.Module):
14
- """
15
- Associative memory module inspired by MIRAS framework.
16
-
17
- The memory learns to map keys to values using a simple linear projection
18
- and updates itself during test time via gradient descent.
19
-
20
- Args:
21
- memory_dim: Dimensionality of memory keys/values
22
- init_scale: Scale for random weight initialization
23
- """
24
-
25
- def __init__(self, memory_dim=256, init_scale=0.01):
26
- super().__init__()
27
- self.memory_dim = memory_dim
28
-
29
- # Memory matrix: maps keys to values
30
- # W: (memory_dim, memory_dim)
31
- self.W = nn.Parameter(
32
- torch.randn(memory_dim, memory_dim) * init_scale
33
- )
34
-
35
- # Track number of updates for retention gate
36
- self.register_buffer('update_count', torch.tensor(0))
37
- self.register_buffer('total_loss', torch.tensor(0.0))
38
-
39
- def forward(self, key):
40
- """
41
- Query memory with a key.
42
-
43
- Args:
44
- key: (batch_size, memory_dim) tensor
45
-
46
- Returns:
47
- predicted_value: (batch_size, memory_dim) tensor
48
- """
49
- # Simple linear mapping: pred_v = k @ W
50
- predicted_value = key @ self.W
51
- return predicted_value
52
-
53
- def query(self, key):
54
- """
55
- Query memory without computing gradients (for generation).
56
-
57
- Args:
58
- key: (batch_size, memory_dim) tensor
59
-
60
- Returns:
61
- memory_output: (batch_size, memory_dim) tensor
62
- """
63
- with torch.no_grad():
64
- return self.forward(key)
65
-
66
- def compute_loss(self, key, value):
67
- """
68
- Compute attentional bias loss between predicted and true value.
69
-
70
- Args:
71
- key: (batch_size, memory_dim)
72
- value: (batch_size, memory_dim)
73
-
74
- Returns:
75
- loss: scalar tensor
76
- """
77
- pred = self.forward(key)
78
- loss = ((pred - value) ** 2).mean()
79
- return loss
80
-
81
- def retention_gate(self, loss):
82
- """
83
- Simple retention gate: higher loss = more surprising = more memorable.
84
-
85
- Returns a scaling factor for the learning rate based on surprise.
86
- High loss (surprising) gets higher weight.
87
-
88
- Args:
89
- loss: scalar tensor
90
-
91
- Returns:
92
- retention_factor: scalar in range [0.5, 2.0]
93
- """
94
- # Normalize loss to a retention factor
95
- # If loss is high (surprising), learn more aggressively
96
- retention_factor = torch.clamp(loss / 0.1, 0.5, 2.0)
97
- return retention_factor.item()
98
-
99
- def update_stats(self, loss):
100
- """Track memory statistics."""
101
- self.update_count += 1
102
- self.total_loss += loss.item()
103
-
104
- def get_stats(self):
105
- """Get memory statistics."""
106
- avg_loss = self.total_loss / max(self.update_count, 1)
107
- return {
108
- 'updates': self.update_count.item(),
109
- 'avg_loss': avg_loss.item(),
110
- 'memory_size': self.W.numel()
111
- }
112
- =======
113
  """
114
  MIRAS-inspired Associative Memory Module
115
 
@@ -207,4 +95,3 @@ class MIRASMemory(nn.Module):
207
  'avg_loss': avg_loss.item(),
208
  'memory_size': self.W.numel()
209
  }
210
- >>>>>>> origin/main
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  """
2
  MIRAS-inspired Associative Memory Module
3
 
 
95
  'avg_loss': avg_loss.item(),
96
  'memory_size': self.W.numel()
97
  }
 
projections.py CHANGED
@@ -1,85 +1,3 @@
1
- <<<<<<< HEAD
2
- """
3
- Key and Value Projection Layers
4
-
5
- Maps hidden states from the base language model into memory-compatible
6
- representations for the MIRAS memory module.
7
- """
8
-
9
- import torch.nn as nn
10
-
11
-
12
- class KeyProjection(nn.Module):
13
- """
14
- Projects hidden states to memory keys.
15
-
16
- Args:
17
- hidden_dim: Dimension of LM hidden states (e.g., 768 for distilgpt2)
18
- memory_dim: Dimension of memory keys (e.g., 256)
19
- """
20
-
21
- def __init__(self, hidden_dim, memory_dim):
22
- super().__init__()
23
- self.projection = nn.Linear(hidden_dim, memory_dim, bias=False)
24
-
25
- def forward(self, hidden_state):
26
- """
27
- Args:
28
- hidden_state: (batch_size, hidden_dim)
29
- Returns:
30
- key: (batch_size, memory_dim)
31
- """
32
- return self.projection(hidden_state)
33
-
34
-
35
- class ValueProjection(nn.Module):
36
- """
37
- Projects hidden states to memory values.
38
-
39
- Args:
40
- hidden_dim: Dimension of LM hidden states (e.g., 768 for distilgpt2)
41
- memory_dim: Dimension of memory values (e.g., 256)
42
- """
43
-
44
- def __init__(self, hidden_dim, memory_dim):
45
- super().__init__()
46
- self.projection = nn.Linear(hidden_dim, memory_dim, bias=False)
47
-
48
- def forward(self, hidden_state):
49
- """
50
- Args:
51
- hidden_state: (batch_size, hidden_dim)
52
- Returns:
53
- value: (batch_size, memory_dim)
54
- """
55
- return self.projection(hidden_state)
56
-
57
-
58
- class OutputProjection(nn.Module):
59
- """
60
- Projects memory output back to hidden dimension for augmentation.
61
-
62
- This is the key bridge that allows memory to influence generation:
63
- h' = h + alpha * output_proj(memory(k))
64
-
65
- Args:
66
- memory_dim: Dimension of memory output (e.g., 256)
67
- hidden_dim: Dimension of LM hidden states (e.g., 768)
68
- """
69
-
70
- def __init__(self, memory_dim, hidden_dim):
71
- super().__init__()
72
- self.projection = nn.Linear(memory_dim, hidden_dim, bias=False)
73
-
74
- def forward(self, memory_output):
75
- """
76
- Args:
77
- memory_output: (batch_size, memory_dim)
78
- Returns:
79
- hidden_augmentation: (batch_size, hidden_dim)
80
- """
81
- return self.projection(memory_output)
82
- =======
83
  """
84
  Key and Value Projection Layers
85
 
@@ -134,4 +52,3 @@ class ValueProjection(nn.Module):
134
  value: (batch_size, memory_dim)
135
  """
136
  return self.projection(hidden_state)
137
- >>>>>>> origin/main
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  """
2
  Key and Value Projection Layers
3
 
 
52
  value: (batch_size, memory_dim)
53
  """
54
  return self.projection(hidden_state)