Upload folder using huggingface_hub
Browse files- ESSAY.md +0 -256
- README.md +0 -74
- app.py +214 -232
- memory_store.py +0 -86
- miras_memory.py +0 -113
- projections.py +0 -83
ESSAY.md
CHANGED
|
@@ -1,258 +1,3 @@
|
|
| 1 |
-
<<<<<<< HEAD
|
| 2 |
-
# When Models Learn While Thinking
|
| 3 |
-
|
| 4 |
-
---
|
| 5 |
-
|
| 6 |
-
## 01 / The Frozen Calculator Problem
|
| 7 |
-
|
| 8 |
-
Every conversation you've had with ChatGPT, Claude, or any large language model follows the same pattern: the model thinks, predicts, and forgets. The weights that determine its behavior were set months ago, frozen in place after training. When you correct it, when you teach it something new, when you have a breakthrough conversation—none of that changes the model itself.
|
| 9 |
-
|
| 10 |
-
This isn't a bug. It's the architecture.
|
| 11 |
-
|
| 12 |
-
The model can *simulate* learning through in-context adaptation. It can act like it remembers. But the parameters that define its cognition remain untouched. When the context window ends, so does the illusion.
|
| 13 |
-
|
| 14 |
-
This demo breaks that pattern.
|
| 15 |
-
|
| 16 |
-
---
|
| 17 |
-
|
| 18 |
-
## 02 / What This Actually Does
|
| 19 |
-
|
| 20 |
-
This is a minimal reimplementation of two recent papers: **Titans** (test-time training) and **MIRAS** (associative memory with attentional bias). Together, they demonstrate something most production LLMs don't do: **learning during inference**.
|
| 21 |
-
|
| 22 |
-
The architecture is simple:
|
| 23 |
-
- A frozen language model (distilgpt2) generates text
|
| 24 |
-
- Hidden states from that model are projected into a memory space
|
| 25 |
-
- An associative memory module predicts what it should remember
|
| 26 |
-
- The prediction error drives gradient descent
|
| 27 |
-
- The memory weights update
|
| 28 |
-
- The updated state persists to disk
|
| 29 |
-
|
| 30 |
-
Every message you send changes the model's internal representations. Not through prompt engineering. Not through retrieval. Through actual gradient-based optimization—at inference time.
|
| 31 |
-
|
| 32 |
-
---
|
| 33 |
-
|
| 34 |
-
## 03 / The Text Doesn't Matter
|
| 35 |
-
|
| 36 |
-
If you interact with this demo, you'll notice the text responses are... not good. Random. Sometimes incoherent. This is intentional.
|
| 37 |
-
|
| 38 |
-
The text generator (distilgpt2) is frozen. We're not training it. The responses reflect what a small, untuned model produces when asked to continue arbitrary text. That's not the point.
|
| 39 |
-
|
| 40 |
-
**The point is the numbers below each response.**
|
| 41 |
-
|
| 42 |
-
Watch the loss. When you send the same message multiple times, the loss decreases. The memory is learning to predict the hidden state patterns associated with that input. When you send something completely different, the loss spikes—the memory is surprised.
|
| 43 |
-
|
| 44 |
-
This is test-time learning. The model is changing itself while you use it.
|
| 45 |
-
|
| 46 |
-
---
|
| 47 |
-
|
| 48 |
-
## 04 / What the Stats Mean
|
| 49 |
-
|
| 50 |
-
Each response shows four metrics:
|
| 51 |
-
|
| 52 |
-
**Loss**: How surprised the memory is. Lower means the pattern is familiar. Higher means it's novel. This is the prediction error that drives learning.
|
| 53 |
-
|
| 54 |
-
**Retention**: A multiplier on the learning rate. When loss is high (surprising input), retention is high (2.0x). The memory learns more aggressively from surprising events. This is the retention gate—a simple mechanism inspired by how human memory prioritizes novelty.
|
| 55 |
-
|
| 56 |
-
**Updates**: The total number of times the memory has been updated. This persists across sessions. Refresh the page, send another message, and the count continues. The memory doesn't reset.
|
| 57 |
-
|
| 58 |
-
**Avg Loss**: The running average of all losses. Over time, as the memory learns recurring patterns, this should trend downward.
|
| 59 |
-
|
| 60 |
-
These aren't vanity metrics. They're the observable signature of gradient descent happening during inference.
|
| 61 |
-
|
| 62 |
-
---
|
| 63 |
-
|
| 64 |
-
## 05 / The Two Papers
|
| 65 |
-
|
| 66 |
-
**Titans** (2025) introduces test-time training for language models. The core idea: instead of freezing weights after pre-training, allow a subset of parameters to update during inference. This creates a feedback loop—think, predict, update, think differently next time—that doesn't exist in standard LLMs.
|
| 67 |
-
|
| 68 |
-
**MIRAS** (2024) reframes attention mechanisms as implicit optimization problems. It shows that dot-product attention, RNNs, and linear transformers are all solving online optimization with a specific loss function (L2). By making the loss function explicit and tunable, you can change the memory behavior. Different losses produce different cognition.
|
| 69 |
-
|
| 70 |
-
This demo combines both: Titans' test-time learning with MIRAS's associative memory framework.
|
| 71 |
-
|
| 72 |
-
---
|
| 73 |
-
|
| 74 |
-
## 06 / What's Missing
|
| 75 |
-
|
| 76 |
-
This is a minimal reimplementation. Several components from the papers are not included:
|
| 77 |
-
|
| 78 |
-
**From Titans**:
|
| 79 |
-
- Multi-layer test-time updates (we only update the memory module)
|
| 80 |
-
- Task-specific memory partitioning (we use a single shared memory)
|
| 81 |
-
- Adaptive learning rate schedules (we use a simple retention gate)
|
| 82 |
-
|
| 83 |
-
**From MIRAS**:
|
| 84 |
-
- Alternative loss functions (we use L2)
|
| 85 |
-
- Multi-head memory (we use a single memory matrix)
|
| 86 |
-
- Attention-based retrieval (we use direct key-value mapping)
|
| 87 |
-
|
| 88 |
-
The goal was to demonstrate the core mechanism—learning during inference—not to replicate every detail. The full papers contain significantly more sophistication.
|
| 89 |
-
|
| 90 |
-
---
|
| 91 |
-
|
| 92 |
-
## 07 / The Difference from Standard LLMs
|
| 93 |
-
|
| 94 |
-
Standard LLMs (ChatGPT, Claude, GPT-4) do this:
|
| 95 |
-
```
|
| 96 |
-
Input → Frozen Weights → Output → Forget
|
| 97 |
-
```
|
| 98 |
-
|
| 99 |
-
This demo does this:
|
| 100 |
-
```
|
| 101 |
-
Input → Frozen LM → Hidden States → Memory (Learning) → Output → Save
|
| 102 |
-
```
|
| 103 |
-
|
| 104 |
-
The frozen LM provides the text generation. The memory provides the learning. They're decoupled.
|
| 105 |
-
|
| 106 |
-
This matters because:
|
| 107 |
-
- **Weights update during use** (not just during training)
|
| 108 |
-
- **Memory persists across sessions** (not just within a context window)
|
| 109 |
-
- **Learning is explicit** (not simulated through in-context adaptation)
|
| 110 |
-
- **The system becomes different** after each interaction
|
| 111 |
-
|
| 112 |
-
In-context learning is pattern matching. This is optimization.
|
| 113 |
-
|
| 114 |
-
---
|
| 115 |
-
|
| 116 |
-
## 08 / What Problem This Solves
|
| 117 |
-
|
| 118 |
-
The current paradigm for "adaptive" LLMs involves:
|
| 119 |
-
- Vector databases for retrieval
|
| 120 |
-
- Fine-tuning on user data (expensive, slow)
|
| 121 |
-
- Prompt engineering (fragile, context-limited)
|
| 122 |
-
- RAG systems (fetch, don't learn)
|
| 123 |
-
|
| 124 |
-
None of these change the model itself. They work around the frozen weights.
|
| 125 |
-
|
| 126 |
-
Test-time learning makes adaptation a first-class primitive. The model doesn't retrieve your preferences—it encodes them in its parameters. It doesn't simulate learning—it performs learning.
|
| 127 |
-
|
| 128 |
-
This opens up:
|
| 129 |
-
- **Personalization without fine-tuning** (the model adapts to you as you use it)
|
| 130 |
-
- **Continual learning** (the model improves from every interaction)
|
| 131 |
-
- **Transparent memory** (you can inspect what it learned)
|
| 132 |
-
- **Efficient adaptation** (gradient descent is cheaper than retraining)
|
| 133 |
-
|
| 134 |
-
---
|
| 135 |
-
|
| 136 |
-
## 09 / What This Means for AI's Future
|
| 137 |
-
|
| 138 |
-
The industry is converging on a model: train once, deploy frozen, scale through retrieval. This works. But it's not the only path.
|
| 139 |
-
|
| 140 |
-
Test-time learning suggests a different trajectory: models that are **living systems**, not static calculators. Systems that don't just respond to you—they change because of you.
|
| 141 |
-
|
| 142 |
-
This has implications:
|
| 143 |
-
- **Privacy**: Your data updates your local model, not a shared cloud model
|
| 144 |
-
- **Efficiency**: Learning happens incrementally, not in massive retraining runs
|
| 145 |
-
- **Alignment**: The model adapts to your values through interaction, not through RLHF on aggregate data
|
| 146 |
-
- **Transparency**: You can see what the model learned, reset it, or fork it
|
| 147 |
-
|
| 148 |
-
The tradeoff is complexity. A model that changes during use is harder to reason about, harder to debug, harder to guarantee. But the benefits—true personalization, continual improvement, user-specific adaptation—may be worth it.
|
| 149 |
-
|
| 150 |
-
---
|
| 151 |
-
|
| 152 |
-
## 10 / The Retention Gate
|
| 153 |
-
|
| 154 |
-
One detail worth highlighting: the retention gate.
|
| 155 |
-
|
| 156 |
-
When the memory encounters a high-loss input (surprising, novel), it increases the learning rate. When it encounters a low-loss input (familiar, repeated), it decreases the learning rate.
|
| 157 |
-
|
| 158 |
-
This is a simple heuristic, but it mirrors how human memory works. We remember surprising events more vividly than routine ones. The retention gate makes the memory selective—it learns more from what it doesn't already know.
|
| 159 |
-
|
| 160 |
-
In this demo, retention is always 2.0x because the memory is fresh. Everything is surprising. After hundreds of interactions, you'd see retention vary—0.5x for familiar patterns, 2.0x for novel ones. The memory would become selective.
|
| 161 |
-
|
| 162 |
-
---
|
| 163 |
-
|
| 164 |
-
## 11 / Why the Memory is Shared
|
| 165 |
-
|
| 166 |
-
This demo uses a single, shared memory across all users. This is intentional.
|
| 167 |
-
|
| 168 |
-
It demonstrates that the memory is not user-specific. It's a collective brain. Every user's input updates the same weight matrix. This makes the learning observable—you can see the loss decrease as the memory encounters repeated patterns from different users.
|
| 169 |
-
|
| 170 |
-
In a production system, you'd likely use per-user memory. But for a demo, shared memory makes the learning more visible and the privacy implications simpler (there are none—no user data is stored).
|
| 171 |
-
|
| 172 |
-
---
|
| 173 |
-
|
| 174 |
-
## 12 / The Bandwidth Constraint
|
| 175 |
-
|
| 176 |
-
One reason LLMs feel static is that they operate at the wrong bandwidth. The only way to change their behavior is to retrain them—a process that costs millions and takes weeks. Users can't influence the model in real time.
|
| 177 |
-
|
| 178 |
-
Test-time learning changes the bandwidth. The model updates with every message. The feedback loop tightens from months to milliseconds.
|
| 179 |
-
|
| 180 |
-
This doesn't mean the model becomes smarter. It means the model becomes *responsive*. It adapts to the distribution of inputs it actually sees, not the distribution it was trained on.
|
| 181 |
-
|
| 182 |
-
---
|
| 183 |
-
|
| 184 |
-
## 13 / What You're Actually Watching
|
| 185 |
-
|
| 186 |
-
When you interact with this demo, you're not chatting with a model. You're watching a memory module learn to compress hidden state patterns into a 256-dimensional space.
|
| 187 |
-
|
| 188 |
-
The text generation is a side effect. The real process is:
|
| 189 |
-
- Extract hidden states from the frozen LM
|
| 190 |
-
- Project them into memory space
|
| 191 |
-
- Predict what the memory should encode
|
| 192 |
-
- Compute the error
|
| 193 |
-
- Update the weights
|
| 194 |
-
- Save the new state
|
| 195 |
-
|
| 196 |
-
This happens for every message. The memory is always learning. The loss is always updating. The system is always changing.
|
| 197 |
-
|
| 198 |
-
That's the difference. Standard LLMs are frozen calculators. This is a living system.
|
| 199 |
-
|
| 200 |
-
---
|
| 201 |
-
|
| 202 |
-
## 14 / The Horizon
|
| 203 |
-
|
| 204 |
-
This demo is a proof of concept. It's not production-ready. It's not optimized. It's not aligned. But it demonstrates a principle: **models can learn while they think**.
|
| 205 |
-
|
| 206 |
-
The implications ripple outward:
|
| 207 |
-
- What if your AI assistant remembered how you corrected it?
|
| 208 |
-
- What if your code completion tool learned your style over time?
|
| 209 |
-
- What if your search engine adapted to your information needs?
|
| 210 |
-
- What if alignment happened through interaction, not through pre-training?
|
| 211 |
-
|
| 212 |
-
These aren't hypotheticals. They're design choices. The architecture exists. The papers are published. The code is open.
|
| 213 |
-
|
| 214 |
-
The question is whether we build systems that simulate learning or systems that perform learning.
|
| 215 |
-
|
| 216 |
-
This demo chooses the latter.
|
| 217 |
-
|
| 218 |
-
---
|
| 219 |
-
|
| 220 |
-
## 15 / A Note on Hype
|
| 221 |
-
|
| 222 |
-
Test-time learning is not a silver bullet. It introduces complexity, instability, and new failure modes. A model that changes during use is harder to trust, harder to audit, harder to guarantee.
|
| 223 |
-
|
| 224 |
-
But it's also more adaptive, more personal, more aligned with how humans actually learn.
|
| 225 |
-
|
| 226 |
-
The industry will likely converge on hybrid systems: frozen base models with test-time learning in specific modules. The best of both worlds—stability where you need it, adaptation where you want it.
|
| 227 |
-
|
| 228 |
-
This demo is a step in that direction. Not the destination. Just a clearer mental model of what's possible.
|
| 229 |
-
|
| 230 |
-
---
|
| 231 |
-
|
| 232 |
-
## 16 / The Core Metaphor
|
| 233 |
-
|
| 234 |
-
Standard LLMs are like libraries. They contain vast knowledge, but they don't change when you visit them. You can check out books (retrieve information), but the library itself remains static.
|
| 235 |
-
|
| 236 |
-
Test-time learning is like a brain. It changes with every experience. The connections strengthen or weaken based on what you encounter. The system becomes different because of you.
|
| 237 |
-
|
| 238 |
-
Both are useful. But they're not the same thing.
|
| 239 |
-
|
| 240 |
-
This demo is a brain, not a library.
|
| 241 |
-
|
| 242 |
-
---
|
| 243 |
-
|
| 244 |
-
**Papers**:
|
| 245 |
-
- Titans: Learning to (Learn at Test Time): RNNs with Expressive Hidden States ([arxiv.org/abs/2501.00663](https://arxiv.org/abs/2501.00663))
|
| 246 |
-
- MIRAS: Associative Memory with Attentional Bias ([arxiv.org/abs/2504.13173](https://arxiv.org/abs/2504.13173))
|
| 247 |
-
|
| 248 |
-
**Code**: Open source, minimal, educational.
|
| 249 |
-
**Memory**: Shared, persistent, observable.
|
| 250 |
-
**Learning**: Real, not simulated.
|
| 251 |
-
|
| 252 |
-
---
|
| 253 |
-
|
| 254 |
-
*This is not a chatbot. This is a demonstration of what happens when models learn while thinking.*
|
| 255 |
-
=======
|
| 256 |
# When Models Learn While Thinking
|
| 257 |
|
| 258 |
---
|
|
@@ -506,4 +251,3 @@ This demo is a brain, not a library.
|
|
| 506 |
---
|
| 507 |
|
| 508 |
*This is not a chatbot. This is a demonstration of what happens when models learn while thinking.*
|
| 509 |
-
>>>>>>> origin/main
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
# When Models Learn While Thinking
|
| 2 |
|
| 3 |
---
|
|
|
|
| 251 |
---
|
| 252 |
|
| 253 |
*This is not a chatbot. This is a demonstration of what happens when models learn while thinking.*
|
|
|
README.md
CHANGED
|
@@ -1,76 +1,3 @@
|
|
| 1 |
-
<<<<<<< HEAD
|
| 2 |
-
---
|
| 3 |
-
title: Titans Miras Demo
|
| 4 |
-
emoji: 🔬
|
| 5 |
-
colorFrom: blue
|
| 6 |
-
colorTo: purple
|
| 7 |
-
sdk: gradio
|
| 8 |
-
sdk_version: 4.36.1
|
| 9 |
-
app_file: app.py
|
| 10 |
-
pinned: false
|
| 11 |
-
---
|
| 12 |
-
|
| 13 |
-
# Titans + MIRAS: A Brain That Changes Itself While Thinking
|
| 14 |
-
|
| 15 |
-
A minimal but faithful reimplementation of **Titans** (test-time learning) and **MIRAS** (associative memory framework) using open-source models on Hugging Face.
|
| 16 |
-
|
| 17 |
-
## What is this?
|
| 18 |
-
|
| 19 |
-
This demo showcases a neural architecture that can **learn and update its memory while generating responses** - a brain that literally changes itself while thinking!
|
| 20 |
-
|
| 21 |
-
### Key Features
|
| 22 |
-
|
| 23 |
-
- 🔄 **Test-time learning**: Memory updates during inference (not just training)
|
| 24 |
-
- 🎯 **Retention gate**: Surprising/novel inputs are more memorable (inspired by human memory)
|
| 25 |
-
- 💾 **Persistent memory**: State is saved across sessions
|
| 26 |
-
- 🤖 **Fully OSS**: Uses distilgpt2 and runs entirely on Hugging Face
|
| 27 |
-
|
| 28 |
-
## Architecture
|
| 29 |
-
|
| 30 |
-
```
|
| 31 |
-
User Input
|
| 32 |
-
↓
|
| 33 |
-
[Base LM: distilgpt2] → Hidden States (768-dim)
|
| 34 |
-
↓
|
| 35 |
-
[Key/Value Projections] → Memory Space (256-dim)
|
| 36 |
-
↓
|
| 37 |
-
[MIRAS Memory Module] ← Test-time Gradient Updates
|
| 38 |
-
↓
|
| 39 |
-
[Text Generation] → Response + Memory Stats
|
| 40 |
-
```
|
| 41 |
-
|
| 42 |
-
### Components
|
| 43 |
-
|
| 44 |
-
1. **Base Language Model**: distilgpt2 (frozen, no training)
|
| 45 |
-
2. **Projection Layers**: Map hidden states to memory space
|
| 46 |
-
3. **MIRAS Memory**: Associative memory with learnable key→value mapping
|
| 47 |
-
4. **Retention Gate**: Adjusts learning rate based on surprise (loss magnitude)
|
| 48 |
-
5. **Memory Store**: Persists memory state to disk
|
| 49 |
-
|
| 50 |
-
## How It Works
|
| 51 |
-
|
| 52 |
-
1. Input text is processed through distilgpt2
|
| 53 |
-
2. Last hidden state is projected to key/value pairs
|
| 54 |
-
3. Memory predicts value from key
|
| 55 |
-
4. Loss (prediction error) indicates surprise
|
| 56 |
-
5. Higher surprise → higher retention → faster learning
|
| 57 |
-
6. Memory updated via gradient descent (1e-3 base LR)
|
| 58 |
-
7. Response generated and memory saved
|
| 59 |
-
|
| 60 |
-
## References
|
| 61 |
-
|
| 62 |
-
- **Titans**: [Learning to Memorize at Test Time](https://arxiv.org/abs/2501.00663)
|
| 63 |
-
- **MIRAS**: [Framework for Associative Memory with Attentional Bias](https://arxiv.org/abs/2504.13173)
|
| 64 |
-
|
| 65 |
-
## Running Locally
|
| 66 |
-
|
| 67 |
-
```bash
|
| 68 |
-
pip install -r requirements.txt
|
| 69 |
-
python app.py
|
| 70 |
-
```
|
| 71 |
-
|
| 72 |
-
Built with ❤️ exploring the future of adaptive AI systems.
|
| 73 |
-
=======
|
| 74 |
---
|
| 75 |
title: Titans Miras Demo
|
| 76 |
emoji: 🔬
|
|
@@ -142,4 +69,3 @@ python app.py
|
|
| 142 |
```
|
| 143 |
|
| 144 |
Built with ❤️ exploring the future of adaptive AI systems.
|
| 145 |
-
>>>>>>> origin/main
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
---
|
| 2 |
title: Titans Miras Demo
|
| 3 |
emoji: 🔬
|
|
|
|
| 69 |
```
|
| 70 |
|
| 71 |
Built with ❤️ exploring the future of adaptive AI systems.
|
|
|
app.py
CHANGED
|
@@ -12,11 +12,7 @@ from transformers import AutoTokenizer, AutoModelForCausalLM
|
|
| 12 |
import gradio as gr
|
| 13 |
|
| 14 |
from miras_memory import MIRASMemory
|
| 15 |
-
<<<<<<< HEAD
|
| 16 |
-
from projections import KeyProjection, ValueProjection, OutputProjection
|
| 17 |
-
=======
|
| 18 |
from projections import KeyProjection, ValueProjection
|
| 19 |
-
>>>>>>> origin/main
|
| 20 |
from memory_store import MemoryStore
|
| 21 |
|
| 22 |
print("=" * 50)
|
|
@@ -30,10 +26,6 @@ HIDDEN_DIM = 768 # distilgpt2 hidden dimension
|
|
| 30 |
MEMORY_DIM = 256 # Memory space dimension
|
| 31 |
LEARNING_RATE = 1e-3 # Base learning rate for test-time updates
|
| 32 |
MAX_NEW_TOKENS = 50 # Max tokens to generate
|
| 33 |
-
<<<<<<< HEAD
|
| 34 |
-
MEMORY_ALPHA = 0.1 # Memory influence strength on generation
|
| 35 |
-
=======
|
| 36 |
-
>>>>>>> origin/main
|
| 37 |
|
| 38 |
# ========== Initialize Components ==========
|
| 39 |
print("🧠 Initializing Titans + MIRAS brain...")
|
|
@@ -47,10 +39,6 @@ model.eval() # Frozen - no training
|
|
| 47 |
# Create projection layers
|
| 48 |
key_proj = KeyProjection(HIDDEN_DIM, MEMORY_DIM)
|
| 49 |
value_proj = ValueProjection(HIDDEN_DIM, MEMORY_DIM)
|
| 50 |
-
<<<<<<< HEAD
|
| 51 |
-
output_proj = OutputProjection(MEMORY_DIM, HIDDEN_DIM) # Map memory back to hidden space
|
| 52 |
-
=======
|
| 53 |
-
>>>>>>> origin/main
|
| 54 |
|
| 55 |
# Create memory module
|
| 56 |
memory = MIRASMemory(memory_dim=MEMORY_DIM, init_scale=0.01)
|
|
@@ -112,44 +100,6 @@ def chat(message, history):
|
|
| 112 |
# Update stats
|
| 113 |
memory.update_stats(loss)
|
| 114 |
|
| 115 |
-
<<<<<<< HEAD
|
| 116 |
-
# === Step 3: Memory-augmented generation (THE KEY CHANGE!) ===
|
| 117 |
-
# Instead of model.generate(), we do token-by-token generation
|
| 118 |
-
# where memory augments the hidden state before prediction.
|
| 119 |
-
|
| 120 |
-
generated_ids = inputs['input_ids'].clone()
|
| 121 |
-
|
| 122 |
-
with torch.no_grad():
|
| 123 |
-
for _ in range(MAX_NEW_TOKENS):
|
| 124 |
-
# Get hidden states from model
|
| 125 |
-
outputs = model(generated_ids, output_hidden_states=True)
|
| 126 |
-
h_last = outputs.hidden_states[-1][:, -1, :] # (1, hidden_dim)
|
| 127 |
-
|
| 128 |
-
# Query memory with projected key
|
| 129 |
-
k_gen = key_proj(h_last)
|
| 130 |
-
memory_out = memory.query(k_gen) # (1, memory_dim)
|
| 131 |
-
|
| 132 |
-
# Augment hidden state with memory output
|
| 133 |
-
# h' = h + alpha * output_proj(memory(k))
|
| 134 |
-
h_augmented = h_last + MEMORY_ALPHA * output_proj(memory_out)
|
| 135 |
-
|
| 136 |
-
# Compute logits with augmented hidden state
|
| 137 |
-
logits = model.lm_head(h_augmented) # (1, vocab_size)
|
| 138 |
-
|
| 139 |
-
# Sample next token (temperature sampling)
|
| 140 |
-
logits = logits / 0.8 # temperature
|
| 141 |
-
probs = torch.softmax(logits, dim=-1)
|
| 142 |
-
next_token = torch.multinomial(probs, num_samples=1)
|
| 143 |
-
|
| 144 |
-
# Stop if EOS
|
| 145 |
-
if next_token.item() == tokenizer.eos_token_id:
|
| 146 |
-
break
|
| 147 |
-
|
| 148 |
-
# Append to sequence
|
| 149 |
-
generated_ids = torch.cat([generated_ids, next_token], dim=1)
|
| 150 |
-
|
| 151 |
-
response = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
|
| 152 |
-
=======
|
| 153 |
# === Step 3: Generate response ===
|
| 154 |
with torch.no_grad():
|
| 155 |
output_ids = model.generate(
|
|
@@ -162,7 +112,6 @@ def chat(message, history):
|
|
| 162 |
)
|
| 163 |
|
| 164 |
response = tokenizer.decode(output_ids[0], skip_special_tokens=True)
|
| 165 |
-
>>>>>>> origin/main
|
| 166 |
|
| 167 |
# Remove the input prompt from response
|
| 168 |
if response.startswith(message):
|
|
@@ -192,202 +141,235 @@ def chat(message, history):
|
|
| 192 |
# ========== Gradio Interface ==========
|
| 193 |
print("🚀 Launching Gradio interface...")
|
| 194 |
|
| 195 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 196 |
|
| 197 |
-
|
| 198 |
-
# 🧠 The Brain That Learns While Thinking
|
| 199 |
-
### A Living System That Updates Its Neural Weights During Inference
|
| 200 |
|
| 201 |
-
|
| 202 |
|
| 203 |
-
|
| 204 |
|
| 205 |
-
**
|
| 206 |
-
|
| 207 |
|
| 208 |
-
|
| 209 |
-
label="Chat & Watch the Memory Learn",
|
| 210 |
-
height=500,
|
| 211 |
-
)
|
| 212 |
|
| 213 |
-
|
| 214 |
-
label="Your Message",
|
| 215 |
-
placeholder="Try: hello world (send it 5 times and watch loss decrease!)",
|
| 216 |
-
lines=2,
|
| 217 |
-
)
|
| 218 |
|
| 219 |
-
|
| 220 |
-
|
| 221 |
-
|
| 222 |
-
|
| 223 |
-
|
| 224 |
-
|
| 225 |
-
|
| 226 |
-
"hello world",
|
| 227 |
-
"Supercalifragilisticexpialidocious quantum entanglement",
|
| 228 |
-
"my name is Pavan",
|
| 229 |
-
],
|
| 230 |
-
inputs=msg,
|
| 231 |
-
label="Try these (especially repeat 'hello world'!)",
|
| 232 |
-
)
|
| 233 |
|
| 234 |
-
# Built with section
|
| 235 |
-
gr.Markdown("""
|
| 236 |
---
|
| 237 |
|
| 238 |
-
|
| 239 |
-
|
| 240 |
-
|
| 241 |
-
|
| 242 |
-
|
| 243 |
-
|
| 244 |
-
|
| 245 |
-
|
| 246 |
-
- Prediction error - how surprised the memory is
|
| 247 |
-
- **Lower = memory is familiar** with this pattern
|
| 248 |
-
- **Decreasing loss = learning is happening!**
|
| 249 |
-
|
| 250 |
-
**Retention** (e.g., 2.00x)
|
| 251 |
-
- Learning rate multiplier based on surprise
|
| 252 |
-
- 2.0x = very surprising, learns aggressively
|
| 253 |
-
- 0.5x = familiar, learns slowly
|
| 254 |
-
|
| 255 |
-
**Updates** (e.g., 1 → 2 → 3...)
|
| 256 |
-
- Total memory updates
|
| 257 |
-
- **Persists across page refreshes!**
|
| 258 |
-
- Proof that memory is permanent
|
| 259 |
-
|
| 260 |
-
**Avg Loss** (e.g., 7.26)
|
| 261 |
-
- Running average showing long-term learning progress
|
| 262 |
-
""")
|
| 263 |
-
|
| 264 |
-
with gr.Accordion("🧪 Interactive Experiments", open=False):
|
| 265 |
-
gr.Markdown("""
|
| 266 |
-
### Experiment 1: Watch Loss Decrease
|
| 267 |
-
1. Send "hello world" 5 times
|
| 268 |
-
2. Watch loss: 7.5 → 6.0 → 5.0 → 4.0
|
| 269 |
-
3. **This proves learning!**
|
| 270 |
-
|
| 271 |
-
### Experiment 2: Trigger Surprise
|
| 272 |
-
1. After experiment 1, send something completely different
|
| 273 |
-
2. Watch loss spike back up (4.0 → 9.0+)
|
| 274 |
-
3. **Memory detects novelty!**
|
| 275 |
-
|
| 276 |
-
### Experiment 3: Test Persistence
|
| 277 |
-
1. Note the "Updates" count
|
| 278 |
-
2. Refresh this entire page
|
| 279 |
-
3. Send any message
|
| 280 |
-
4. Updates should continue, not reset!
|
| 281 |
-
5. **Memory survives refresh!**
|
| 282 |
-
""")
|
| 283 |
-
|
| 284 |
-
<<<<<<< HEAD
|
| 285 |
-
with gr.Accordion("✨ Memory-Augmented Generation", open=False):
|
| 286 |
-
gr.Markdown("""
|
| 287 |
-
**Memory now influences text generation!**
|
| 288 |
-
|
| 289 |
-
- At each token generation step, we query the memory
|
| 290 |
-
- Memory output is added to the hidden state: `h' = h + α × memory(k)`
|
| 291 |
-
- This augmented state determines the next token
|
| 292 |
-
|
| 293 |
-
**What to expect:**
|
| 294 |
-
- Repeated inputs → more consistent outputs
|
| 295 |
-
- As memory learns patterns, it biases generation toward them
|
| 296 |
-
- Novel inputs → more random outputs (memory has no prior)
|
| 297 |
-
=======
|
| 298 |
-
with gr.Accordion("⚠️ Important: Ignore the Text", open=False):
|
| 299 |
-
gr.Markdown("""
|
| 300 |
-
**The text responses are random** - this is expected!
|
| 301 |
-
|
| 302 |
-
- We're NOT training the text generator (distilgpt2 is frozen)
|
| 303 |
-
- **Focus on the numbers below each response**
|
| 304 |
-
- The magic is in the decreasing loss, not the text
|
| 305 |
-
|
| 306 |
-
**Why?** We're demonstrating **memory learning**, not text generation.
|
| 307 |
-
>>>>>>> origin/main
|
| 308 |
-
""")
|
| 309 |
-
|
| 310 |
-
with gr.Accordion("🔬 How It Works", open=False):
|
| 311 |
-
gr.Markdown("""
|
| 312 |
-
```
|
| 313 |
-
Your Message
|
| 314 |
-
↓
|
| 315 |
-
<<<<<<< HEAD
|
| 316 |
-
[distilgpt2: FROZEN]
|
| 317 |
-
↓
|
| 318 |
-
Hidden States (768-dim)
|
| 319 |
-
↓
|
| 320 |
-
[Key Projection] → Memory (256-dim)
|
| 321 |
-
↓
|
| 322 |
-
[MIRAS Memory: LEARNING!]
|
| 323 |
-
↓
|
| 324 |
-
[Output Projection] → Augmentation (768-dim)
|
| 325 |
-
↓
|
| 326 |
-
h' = h + α × memory_output ← THE KEY!
|
| 327 |
-
↓
|
| 328 |
-
LM Head → Next Token
|
| 329 |
-
```
|
| 330 |
-
|
| 331 |
-
**Key**: Memory modifies hidden states before prediction!
|
| 332 |
-
=======
|
| 333 |
-
[distilgpt2: FROZEN] ← Not learning
|
| 334 |
-
↓
|
| 335 |
-
Hidden States (768-dim)
|
| 336 |
-
↓
|
| 337 |
-
[Projections] → Memory (256-dim)
|
| 338 |
-
↓
|
| 339 |
-
[MIRAS Memory: LEARNING!]
|
| 340 |
-
↓
|
| 341 |
-
Loss = Prediction error
|
| 342 |
-
↓
|
| 343 |
-
Gradient Descent → Weights change
|
| 344 |
-
↓
|
| 345 |
-
Saved to disk → Persists
|
| 346 |
-
```
|
| 347 |
-
|
| 348 |
-
**Key**: We're training the **memory**, not the text generator!
|
| 349 |
-
>>>>>>> origin/main
|
| 350 |
-
""")
|
| 351 |
-
|
| 352 |
-
with gr.Accordion("💡 Why This Matters", open=False):
|
| 353 |
-
gr.Markdown("""
|
| 354 |
-
**Standard LLMs** (ChatGPT, Claude):
|
| 355 |
-
- Weights frozen after training
|
| 356 |
-
- "Learning" is just pattern matching
|
| 357 |
-
- Forget when context ends
|
| 358 |
-
|
| 359 |
-
**This Demo** (Titans + MIRAS):
|
| 360 |
-
- Weights update during inference
|
| 361 |
-
- Real gradient descent
|
| 362 |
-
- Memory persists across sessions
|
| 363 |
-
|
| 364 |
-
That decreasing loss? **That's gradient descent during inference.**
|
| 365 |
-
That's what ChatGPT doesn't do.
|
| 366 |
-
""")
|
| 367 |
|
| 368 |
-
# Footer
|
| 369 |
-
gr.Markdown("""
|
| 370 |
---
|
| 371 |
|
| 372 |
-
|
| 373 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 374 |
|
| 375 |
-
|
| 376 |
-
def user_submit(user_message, history):
|
| 377 |
-
return "", history + [[user_message, None]]
|
| 378 |
|
| 379 |
-
|
| 380 |
-
user_message = history[-1][0]
|
| 381 |
-
bot_message = chat(user_message, history[:-1])
|
| 382 |
-
history[-1][1] = bot_message
|
| 383 |
-
return history
|
| 384 |
|
| 385 |
-
|
| 386 |
-
|
| 387 |
-
|
| 388 |
-
|
| 389 |
-
|
| 390 |
-
|
| 391 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 392 |
|
| 393 |
demo.launch()
|
|
|
|
| 12 |
import gradio as gr
|
| 13 |
|
| 14 |
from miras_memory import MIRASMemory
|
|
|
|
|
|
|
|
|
|
| 15 |
from projections import KeyProjection, ValueProjection
|
|
|
|
| 16 |
from memory_store import MemoryStore
|
| 17 |
|
| 18 |
print("=" * 50)
|
|
|
|
| 26 |
MEMORY_DIM = 256 # Memory space dimension
|
| 27 |
LEARNING_RATE = 1e-3 # Base learning rate for test-time updates
|
| 28 |
MAX_NEW_TOKENS = 50 # Max tokens to generate
|
|
|
|
|
|
|
|
|
|
|
|
|
| 29 |
|
| 30 |
# ========== Initialize Components ==========
|
| 31 |
print("🧠 Initializing Titans + MIRAS brain...")
|
|
|
|
| 39 |
# Create projection layers
|
| 40 |
key_proj = KeyProjection(HIDDEN_DIM, MEMORY_DIM)
|
| 41 |
value_proj = ValueProjection(HIDDEN_DIM, MEMORY_DIM)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 42 |
|
| 43 |
# Create memory module
|
| 44 |
memory = MIRASMemory(memory_dim=MEMORY_DIM, init_scale=0.01)
|
|
|
|
| 100 |
# Update stats
|
| 101 |
memory.update_stats(loss)
|
| 102 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 103 |
# === Step 3: Generate response ===
|
| 104 |
with torch.no_grad():
|
| 105 |
output_ids = model.generate(
|
|
|
|
| 112 |
)
|
| 113 |
|
| 114 |
response = tokenizer.decode(output_ids[0], skip_special_tokens=True)
|
|
|
|
| 115 |
|
| 116 |
# Remove the input prompt from response
|
| 117 |
if response.startswith(message):
|
|
|
|
| 141 |
# ========== Gradio Interface ==========
|
| 142 |
print("🚀 Launching Gradio interface...")
|
| 143 |
|
| 144 |
+
demo = gr.ChatInterface(
|
| 145 |
+
fn=chat,
|
| 146 |
+
title="🧠 The Brain That Learns While Thinking",
|
| 147 |
+
description="""
|
| 148 |
+
# A Living System That Updates Its Weights During Inference
|
| 149 |
|
| 150 |
+
**The Novel Thing**: Standard LLMs freeze their weights after training. This system performs gradient descent *while you chat*.
|
|
|
|
|
|
|
| 151 |
|
| 152 |
+
---
|
| 153 |
|
| 154 |
+
## 🚀 The Revolutionary Difference
|
| 155 |
|
| 156 |
+
**Standard LLMs (ChatGPT, Claude, etc.)**: Think → Predict → **Forget**
|
| 157 |
+
**Titans + MIRAS**: Think → Predict → **Update** → **Remember** → Think Differently
|
| 158 |
|
| 159 |
+
---
|
|
|
|
|
|
|
|
|
|
| 160 |
|
| 161 |
+
### 💡 What Makes This Different?
|
|
|
|
|
|
|
|
|
|
|
|
|
| 162 |
|
| 163 |
+
| Feature | ChatGPT/Claude/GPT-4 | This Demo (Titans+MIRAS) |
|
| 164 |
+
|---------|---------------------|--------------------------|
|
| 165 |
+
| **Weights during chat** | 🔒 Frozen forever | ✅ Update with every message |
|
| 166 |
+
| **Learning** | ❌ Simulated (in-context only) | ✅ Real (gradient descent) |
|
| 167 |
+
| **Memory** | 📝 Token context only | 🧠 Neural parameters |
|
| 168 |
+
| **Persistence** | ❌ Forgets when context ends | ✅ Saves to disk |
|
| 169 |
+
| **Adaptation** | 🎭 Acts like it learned | 🔬 Actually learns |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 170 |
|
|
|
|
|
|
|
| 171 |
---
|
| 172 |
|
| 173 |
+
### 🎯 What You're Witnessing
|
| 174 |
+
|
| 175 |
+
**This is NOT a better chatbot** - it's a **learning demonstrator**.
|
| 176 |
+
|
| 177 |
+
1. **The text responses are random** - that's expected! We're using a small, frozen model (distilgpt2)
|
| 178 |
+
2. **The MAGIC is in the numbers below** - watch the "Loss" decrease when you repeat inputs!
|
| 179 |
+
3. **Every message physically changes the brain** - the memory weights update via gradient descent
|
| 180 |
+
4. **Refresh the page** - the update count continues (memory persists!)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 181 |
|
|
|
|
|
|
|
| 182 |
---
|
| 183 |
|
| 184 |
+
### 🧪 How It Works (The Technical Truth)
|
| 185 |
+
|
| 186 |
+
```
|
| 187 |
+
Your Message
|
| 188 |
+
↓
|
| 189 |
+
[distilgpt2: FROZEN] ← Not learning, just generating
|
| 190 |
+
↓
|
| 191 |
+
Hidden States (768-dim)
|
| 192 |
+
↓
|
| 193 |
+
[Projections] → Memory Space (256-dim)
|
| 194 |
+
↓
|
| 195 |
+
[MIRAS Memory: LEARNING!] ← This is what updates!
|
| 196 |
+
↓
|
| 197 |
+
Loss = How surprised the memory is
|
| 198 |
+
↓
|
| 199 |
+
Gradient Descent → Memory weights change
|
| 200 |
+
↓
|
| 201 |
+
Saved to disk → Persists forever
|
| 202 |
+
```
|
| 203 |
+
|
| 204 |
+
**Key Insight**: We're training the **memory**, not the text generator!
|
| 205 |
|
| 206 |
+
---
|
|
|
|
|
|
|
| 207 |
|
| 208 |
+
### 🔬 The Science: Why This Matters
|
|
|
|
|
|
|
|
|
|
|
|
|
| 209 |
|
| 210 |
+
**Standard LLMs**:
|
| 211 |
+
- Weights frozen after training (costs millions)
|
| 212 |
+
- "Learning" is just pattern matching in context
|
| 213 |
+
- Forget everything when context ends
|
| 214 |
+
- Same model for everyone
|
| 215 |
+
|
| 216 |
+
**Titans + MIRAS**:
|
| 217 |
+
- Weights update during inference (free!)
|
| 218 |
+
- Real optimization via gradient descent
|
| 219 |
+
- Memory persists across sessions
|
| 220 |
+
- Personalizes to each user
|
| 221 |
+
|
| 222 |
+
**This is test-time learning** - the future of adaptive AI.
|
| 223 |
+
|
| 224 |
+
---
|
| 225 |
+
|
| 226 |
+
### 📊 What the Stats Mean
|
| 227 |
+
|
| 228 |
+
- **Loss**: How surprised the memory is (lower = more familiar)
|
| 229 |
+
- **Retention**: Learning rate multiplier (2.0x = very surprising, 0.5x = familiar)
|
| 230 |
+
- **Updates**: Total number of memory updates (persists across sessions!)
|
| 231 |
+
- **Avg Loss**: Overall learning progress
|
| 232 |
+
|
| 233 |
+
---
|
| 234 |
+
|
| 235 |
+
### 🎮 Try This Experiment
|
| 236 |
+
|
| 237 |
+
1. **Send "hello world" 5 times** → Watch loss decrease!
|
| 238 |
+
2. **Send something completely different** → Loss spikes!
|
| 239 |
+
3. **Refresh the page and send another message** → Update count continues!
|
| 240 |
+
|
| 241 |
+
**That decreasing loss is proof the neural weights are changing!**
|
| 242 |
+
|
| 243 |
+
---
|
| 244 |
+
|
| 245 |
+
### 🌟 The Bottom Line
|
| 246 |
+
|
| 247 |
+
**ChatGPT**: A frozen calculator that *simulates* adaptation
|
| 248 |
+
**This Demo**: A living system that *performs* adaptation
|
| 249 |
+
|
| 250 |
+
You're not chatting with a model.
|
| 251 |
+
**You're watching a brain rewire itself in real-time.** 🧠⚡
|
| 252 |
+
|
| 253 |
+
---
|
| 254 |
+
|
| 255 |
+
### 🧪 How to Test This (Interactive Experiments)
|
| 256 |
+
|
| 257 |
+
**Don't just chat—run experiments to see the learning happen!**
|
| 258 |
+
|
| 259 |
+
#### Experiment 1: Watch Loss Decrease (Proof of Learning)
|
| 260 |
+
```
|
| 261 |
+
1. Send "hello world"
|
| 262 |
+
2. Send "hello world" again
|
| 263 |
+
3. Send "hello world" again
|
| 264 |
+
4. Send "hello world" again
|
| 265 |
+
5. Send "hello world" again
|
| 266 |
+
```
|
| 267 |
+
**What to watch**: Loss should decrease each time (7.5 → 6.0 → 5.0 → 4.0)
|
| 268 |
+
**Why it matters**: This proves the memory is learning the pattern!
|
| 269 |
+
|
| 270 |
+
#### Experiment 2: Trigger Surprise (Spike the Loss)
|
| 271 |
+
```
|
| 272 |
+
1. Send "hello world" 5 times (loss decreases)
|
| 273 |
+
2. Then send: "Supercalifragilisticexpialidocious quantum entanglement"
|
| 274 |
+
```
|
| 275 |
+
**What to watch**: Loss should spike back up (4.0 → 9.0+)
|
| 276 |
+
**Why it matters**: The memory detects novelty—it knows this is different!
|
| 277 |
+
|
| 278 |
+
#### Experiment 3: Test Persistence (Memory Survives)
|
| 279 |
+
```
|
| 280 |
+
1. Note the "Updates" count (e.g., 15)
|
| 281 |
+
2. Refresh this page completely
|
| 282 |
+
3. Send any message
|
| 283 |
+
4. Check if Updates = 16 (not reset to 1!)
|
| 284 |
+
```
|
| 285 |
+
**What to watch**: Update count should continue, not reset
|
| 286 |
+
**Why it matters**: Memory persists to disk—it's not just in RAM!
|
| 287 |
+
|
| 288 |
+
---
|
| 289 |
+
|
| 290 |
+
### 📊 What Each Stat Means (Decoder Ring)
|
| 291 |
+
|
| 292 |
+
**Loss** (e.g., 7.48 → 6.61 → 5.23)
|
| 293 |
+
- **What it is**: Prediction error (how surprised the memory is)
|
| 294 |
+
- **Lower = Better**: Memory is familiar with this pattern
|
| 295 |
+
- **Higher = Novel**: Memory hasn't seen this before
|
| 296 |
+
- **Why it matters**: Decreasing loss = learning is happening!
|
| 297 |
+
|
| 298 |
+
**Retention** (e.g., 2.00x)
|
| 299 |
+
- **What it is**: Learning rate multiplier based on surprise
|
| 300 |
+
- **2.0x = Very surprising**: Memory learns aggressively
|
| 301 |
+
- **0.5x = Very familiar**: Memory learns slowly (you won't see this yet)
|
| 302 |
+
- **Why it matters**: The brain learns more from surprising events (like humans!)
|
| 303 |
+
|
| 304 |
+
**Updates** (e.g., 1 → 2 → 3 → 4...)
|
| 305 |
+
- **What it is**: Total number of memory updates
|
| 306 |
+
- **Persists across sessions**: Survives page refreshes
|
| 307 |
+
- **Never resets**: Keeps counting forever
|
| 308 |
+
- **Why it matters**: Proof that memory is persistent, not ephemeral!
|
| 309 |
+
|
| 310 |
+
**Avg Loss** (e.g., 7.26)
|
| 311 |
+
- **What it is**: Running average of all losses
|
| 312 |
+
- **Trends downward**: As memory learns recurring patterns
|
| 313 |
+
- **Reflects overall learning**: Lower = memory is getting smarter
|
| 314 |
+
- **Why it matters**: Shows long-term learning progress!
|
| 315 |
+
|
| 316 |
+
---
|
| 317 |
+
|
| 318 |
+
### ⚠️ What to Ignore (Important!)
|
| 319 |
+
|
| 320 |
+
**The text responses are random and bad** - this is expected!
|
| 321 |
+
- We're NOT training the text generator (distilgpt2 is frozen)
|
| 322 |
+
- The responses don't matter—they're a side effect
|
| 323 |
+
- **Focus on the numbers below**, not the text above
|
| 324 |
+
- The magic is in the decreasing loss, not the generated text
|
| 325 |
+
|
| 326 |
+
**Why?** Because we're demonstrating **memory learning**, not text generation.
|
| 327 |
+
Standard LLMs train the text generator. This trains the memory. Different goals.
|
| 328 |
+
|
| 329 |
+
---
|
| 330 |
+
|
| 331 |
+
### 🎯 What Success Looks Like
|
| 332 |
+
|
| 333 |
+
✅ **You're seeing it work if**:
|
| 334 |
+
- Loss decreases when you repeat inputs
|
| 335 |
+
- Loss spikes when you send something new
|
| 336 |
+
- Update count increments with each message
|
| 337 |
+
- Update count persists after page refresh
|
| 338 |
+
- Retention is 2.0x (everything is surprising to fresh memory)
|
| 339 |
+
|
| 340 |
+
❌ **You're NOT seeing it work if**:
|
| 341 |
+
- Loss stays constant (not learning)
|
| 342 |
+
- Updates reset to 1 after refresh (not persisting)
|
| 343 |
+
- No stats appear below responses
|
| 344 |
+
|
| 345 |
+
---
|
| 346 |
+
|
| 347 |
+
### 🔬 Why This Matters (The Big Picture)
|
| 348 |
+
|
| 349 |
+
**Standard LLMs**: Frozen weights → No learning during use
|
| 350 |
+
**This Demo**: Live weights → Learning with every message
|
| 351 |
+
|
| 352 |
+
That decreasing loss you see? **That's gradient descent happening during inference.**
|
| 353 |
+
That's the revolution. That's what ChatGPT doesn't do.
|
| 354 |
+
|
| 355 |
+
You're not just using a model. **You're watching it change.**
|
| 356 |
+
|
| 357 |
+
---
|
| 358 |
+
|
| 359 |
+
*Built with Titans (test-time training) + MIRAS (associative memory)*
|
| 360 |
+
*Papers: [Titans](https://arxiv.org/abs/2501.00663) | [MIRAS](https://arxiv.org/abs/2504.13173)*
|
| 361 |
+
|
| 362 |
+
**📖 [Read the full essay: "When Models Learn While Thinking"](https://huggingface.co/spaces/Pavantej/titans-miras-demo/blob/main/ESSAY.md)**
|
| 363 |
+
""",
|
| 364 |
+
examples=[
|
| 365 |
+
"hello world",
|
| 366 |
+
"hello world", # Repeat to show learning!
|
| 367 |
+
"Tell me about test-time learning",
|
| 368 |
+
"What is 2+2?",
|
| 369 |
+
"my name is [your name]",
|
| 370 |
+
],
|
| 371 |
+
cache_examples=False,
|
| 372 |
+
theme="soft",
|
| 373 |
+
)
|
| 374 |
|
| 375 |
demo.launch()
|
memory_store.py
CHANGED
|
@@ -1,88 +1,3 @@
|
|
| 1 |
-
<<<<<<< HEAD
|
| 2 |
-
"""
|
| 3 |
-
Memory Persistence
|
| 4 |
-
|
| 5 |
-
Handles saving and loading memory state to/from disk so the brain
|
| 6 |
-
remembers across sessions.
|
| 7 |
-
"""
|
| 8 |
-
|
| 9 |
-
import torch
|
| 10 |
-
import json
|
| 11 |
-
import os
|
| 12 |
-
from pathlib import Path
|
| 13 |
-
from datetime import datetime
|
| 14 |
-
|
| 15 |
-
|
| 16 |
-
class MemoryStore:
|
| 17 |
-
"""Manages persistent storage of memory state."""
|
| 18 |
-
|
| 19 |
-
def __init__(self, save_dir="memory"):
|
| 20 |
-
self.save_dir = Path(save_dir)
|
| 21 |
-
self.save_dir.mkdir(exist_ok=True)
|
| 22 |
-
self.memory_path = self.save_dir / "memory.pt"
|
| 23 |
-
self.metadata_path = self.save_dir / "metadata.json"
|
| 24 |
-
|
| 25 |
-
def save(self, memory_module):
|
| 26 |
-
"""
|
| 27 |
-
Save memory state to disk.
|
| 28 |
-
|
| 29 |
-
Args:
|
| 30 |
-
memory_module: MIRASMemory instance
|
| 31 |
-
"""
|
| 32 |
-
# Save memory weights
|
| 33 |
-
torch.save({
|
| 34 |
-
'W': memory_module.W.data,
|
| 35 |
-
'update_count': memory_module.update_count,
|
| 36 |
-
'total_loss': memory_module.total_loss,
|
| 37 |
-
}, self.memory_path)
|
| 38 |
-
|
| 39 |
-
# Save metadata
|
| 40 |
-
metadata = {
|
| 41 |
-
'last_updated': datetime.now().isoformat(),
|
| 42 |
-
'memory_dim': memory_module.memory_dim,
|
| 43 |
-
'updates': memory_module.update_count.item(),
|
| 44 |
-
'avg_loss': (memory_module.total_loss / max(memory_module.update_count, 1)).item(),
|
| 45 |
-
}
|
| 46 |
-
|
| 47 |
-
with open(self.metadata_path, 'w') as f:
|
| 48 |
-
json.dump(metadata, f, indent=2)
|
| 49 |
-
|
| 50 |
-
print(f"💾 Memory saved: {memory_module.update_count.item()} updates")
|
| 51 |
-
|
| 52 |
-
def load(self, memory_module):
|
| 53 |
-
"""
|
| 54 |
-
Load memory state from disk.
|
| 55 |
-
|
| 56 |
-
Args:
|
| 57 |
-
memory_module: MIRASMemory instance to load into
|
| 58 |
-
|
| 59 |
-
Returns:
|
| 60 |
-
bool: True if loaded successfully, False otherwise
|
| 61 |
-
"""
|
| 62 |
-
if not self.memory_path.exists():
|
| 63 |
-
print("🆕 No saved memory found. Starting fresh!")
|
| 64 |
-
return False
|
| 65 |
-
|
| 66 |
-
try:
|
| 67 |
-
checkpoint = torch.load(self.memory_path)
|
| 68 |
-
memory_module.W.data = checkpoint['W']
|
| 69 |
-
memory_module.update_count = checkpoint['update_count']
|
| 70 |
-
memory_module.total_loss = checkpoint['total_loss']
|
| 71 |
-
|
| 72 |
-
print(f"✅ Memory loaded: {memory_module.update_count.item()} updates")
|
| 73 |
-
return True
|
| 74 |
-
except Exception as e:
|
| 75 |
-
print(f"⚠️ Error loading memory: {e}. Starting fresh!")
|
| 76 |
-
return False
|
| 77 |
-
|
| 78 |
-
def get_metadata(self):
|
| 79 |
-
"""Get metadata about saved memory."""
|
| 80 |
-
if not self.metadata_path.exists():
|
| 81 |
-
return None
|
| 82 |
-
|
| 83 |
-
with open(self.metadata_path, 'r') as f:
|
| 84 |
-
return json.load(f)
|
| 85 |
-
=======
|
| 86 |
"""
|
| 87 |
Memory Persistence
|
| 88 |
|
|
@@ -166,4 +81,3 @@ class MemoryStore:
|
|
| 166 |
|
| 167 |
with open(self.metadata_path, 'r') as f:
|
| 168 |
return json.load(f)
|
| 169 |
-
>>>>>>> origin/main
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
"""
|
| 2 |
Memory Persistence
|
| 3 |
|
|
|
|
| 81 |
|
| 82 |
with open(self.metadata_path, 'r') as f:
|
| 83 |
return json.load(f)
|
|
|
miras_memory.py
CHANGED
|
@@ -1,115 +1,3 @@
|
|
| 1 |
-
<<<<<<< HEAD
|
| 2 |
-
"""
|
| 3 |
-
MIRAS-inspired Associative Memory Module
|
| 4 |
-
|
| 5 |
-
Implements an associative memory that learns key-value mappings
|
| 6 |
-
through attentional bias objective during test time.
|
| 7 |
-
"""
|
| 8 |
-
|
| 9 |
-
import torch
|
| 10 |
-
import torch.nn as nn
|
| 11 |
-
|
| 12 |
-
|
| 13 |
-
class MIRASMemory(nn.Module):
|
| 14 |
-
"""
|
| 15 |
-
Associative memory module inspired by MIRAS framework.
|
| 16 |
-
|
| 17 |
-
The memory learns to map keys to values using a simple linear projection
|
| 18 |
-
and updates itself during test time via gradient descent.
|
| 19 |
-
|
| 20 |
-
Args:
|
| 21 |
-
memory_dim: Dimensionality of memory keys/values
|
| 22 |
-
init_scale: Scale for random weight initialization
|
| 23 |
-
"""
|
| 24 |
-
|
| 25 |
-
def __init__(self, memory_dim=256, init_scale=0.01):
|
| 26 |
-
super().__init__()
|
| 27 |
-
self.memory_dim = memory_dim
|
| 28 |
-
|
| 29 |
-
# Memory matrix: maps keys to values
|
| 30 |
-
# W: (memory_dim, memory_dim)
|
| 31 |
-
self.W = nn.Parameter(
|
| 32 |
-
torch.randn(memory_dim, memory_dim) * init_scale
|
| 33 |
-
)
|
| 34 |
-
|
| 35 |
-
# Track number of updates for retention gate
|
| 36 |
-
self.register_buffer('update_count', torch.tensor(0))
|
| 37 |
-
self.register_buffer('total_loss', torch.tensor(0.0))
|
| 38 |
-
|
| 39 |
-
def forward(self, key):
|
| 40 |
-
"""
|
| 41 |
-
Query memory with a key.
|
| 42 |
-
|
| 43 |
-
Args:
|
| 44 |
-
key: (batch_size, memory_dim) tensor
|
| 45 |
-
|
| 46 |
-
Returns:
|
| 47 |
-
predicted_value: (batch_size, memory_dim) tensor
|
| 48 |
-
"""
|
| 49 |
-
# Simple linear mapping: pred_v = k @ W
|
| 50 |
-
predicted_value = key @ self.W
|
| 51 |
-
return predicted_value
|
| 52 |
-
|
| 53 |
-
def query(self, key):
|
| 54 |
-
"""
|
| 55 |
-
Query memory without computing gradients (for generation).
|
| 56 |
-
|
| 57 |
-
Args:
|
| 58 |
-
key: (batch_size, memory_dim) tensor
|
| 59 |
-
|
| 60 |
-
Returns:
|
| 61 |
-
memory_output: (batch_size, memory_dim) tensor
|
| 62 |
-
"""
|
| 63 |
-
with torch.no_grad():
|
| 64 |
-
return self.forward(key)
|
| 65 |
-
|
| 66 |
-
def compute_loss(self, key, value):
|
| 67 |
-
"""
|
| 68 |
-
Compute attentional bias loss between predicted and true value.
|
| 69 |
-
|
| 70 |
-
Args:
|
| 71 |
-
key: (batch_size, memory_dim)
|
| 72 |
-
value: (batch_size, memory_dim)
|
| 73 |
-
|
| 74 |
-
Returns:
|
| 75 |
-
loss: scalar tensor
|
| 76 |
-
"""
|
| 77 |
-
pred = self.forward(key)
|
| 78 |
-
loss = ((pred - value) ** 2).mean()
|
| 79 |
-
return loss
|
| 80 |
-
|
| 81 |
-
def retention_gate(self, loss):
|
| 82 |
-
"""
|
| 83 |
-
Simple retention gate: higher loss = more surprising = more memorable.
|
| 84 |
-
|
| 85 |
-
Returns a scaling factor for the learning rate based on surprise.
|
| 86 |
-
High loss (surprising) gets higher weight.
|
| 87 |
-
|
| 88 |
-
Args:
|
| 89 |
-
loss: scalar tensor
|
| 90 |
-
|
| 91 |
-
Returns:
|
| 92 |
-
retention_factor: scalar in range [0.5, 2.0]
|
| 93 |
-
"""
|
| 94 |
-
# Normalize loss to a retention factor
|
| 95 |
-
# If loss is high (surprising), learn more aggressively
|
| 96 |
-
retention_factor = torch.clamp(loss / 0.1, 0.5, 2.0)
|
| 97 |
-
return retention_factor.item()
|
| 98 |
-
|
| 99 |
-
def update_stats(self, loss):
|
| 100 |
-
"""Track memory statistics."""
|
| 101 |
-
self.update_count += 1
|
| 102 |
-
self.total_loss += loss.item()
|
| 103 |
-
|
| 104 |
-
def get_stats(self):
|
| 105 |
-
"""Get memory statistics."""
|
| 106 |
-
avg_loss = self.total_loss / max(self.update_count, 1)
|
| 107 |
-
return {
|
| 108 |
-
'updates': self.update_count.item(),
|
| 109 |
-
'avg_loss': avg_loss.item(),
|
| 110 |
-
'memory_size': self.W.numel()
|
| 111 |
-
}
|
| 112 |
-
=======
|
| 113 |
"""
|
| 114 |
MIRAS-inspired Associative Memory Module
|
| 115 |
|
|
@@ -207,4 +95,3 @@ class MIRASMemory(nn.Module):
|
|
| 207 |
'avg_loss': avg_loss.item(),
|
| 208 |
'memory_size': self.W.numel()
|
| 209 |
}
|
| 210 |
-
>>>>>>> origin/main
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
"""
|
| 2 |
MIRAS-inspired Associative Memory Module
|
| 3 |
|
|
|
|
| 95 |
'avg_loss': avg_loss.item(),
|
| 96 |
'memory_size': self.W.numel()
|
| 97 |
}
|
|
|
projections.py
CHANGED
|
@@ -1,85 +1,3 @@
|
|
| 1 |
-
<<<<<<< HEAD
|
| 2 |
-
"""
|
| 3 |
-
Key and Value Projection Layers
|
| 4 |
-
|
| 5 |
-
Maps hidden states from the base language model into memory-compatible
|
| 6 |
-
representations for the MIRAS memory module.
|
| 7 |
-
"""
|
| 8 |
-
|
| 9 |
-
import torch.nn as nn
|
| 10 |
-
|
| 11 |
-
|
| 12 |
-
class KeyProjection(nn.Module):
|
| 13 |
-
"""
|
| 14 |
-
Projects hidden states to memory keys.
|
| 15 |
-
|
| 16 |
-
Args:
|
| 17 |
-
hidden_dim: Dimension of LM hidden states (e.g., 768 for distilgpt2)
|
| 18 |
-
memory_dim: Dimension of memory keys (e.g., 256)
|
| 19 |
-
"""
|
| 20 |
-
|
| 21 |
-
def __init__(self, hidden_dim, memory_dim):
|
| 22 |
-
super().__init__()
|
| 23 |
-
self.projection = nn.Linear(hidden_dim, memory_dim, bias=False)
|
| 24 |
-
|
| 25 |
-
def forward(self, hidden_state):
|
| 26 |
-
"""
|
| 27 |
-
Args:
|
| 28 |
-
hidden_state: (batch_size, hidden_dim)
|
| 29 |
-
Returns:
|
| 30 |
-
key: (batch_size, memory_dim)
|
| 31 |
-
"""
|
| 32 |
-
return self.projection(hidden_state)
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
class ValueProjection(nn.Module):
|
| 36 |
-
"""
|
| 37 |
-
Projects hidden states to memory values.
|
| 38 |
-
|
| 39 |
-
Args:
|
| 40 |
-
hidden_dim: Dimension of LM hidden states (e.g., 768 for distilgpt2)
|
| 41 |
-
memory_dim: Dimension of memory values (e.g., 256)
|
| 42 |
-
"""
|
| 43 |
-
|
| 44 |
-
def __init__(self, hidden_dim, memory_dim):
|
| 45 |
-
super().__init__()
|
| 46 |
-
self.projection = nn.Linear(hidden_dim, memory_dim, bias=False)
|
| 47 |
-
|
| 48 |
-
def forward(self, hidden_state):
|
| 49 |
-
"""
|
| 50 |
-
Args:
|
| 51 |
-
hidden_state: (batch_size, hidden_dim)
|
| 52 |
-
Returns:
|
| 53 |
-
value: (batch_size, memory_dim)
|
| 54 |
-
"""
|
| 55 |
-
return self.projection(hidden_state)
|
| 56 |
-
|
| 57 |
-
|
| 58 |
-
class OutputProjection(nn.Module):
|
| 59 |
-
"""
|
| 60 |
-
Projects memory output back to hidden dimension for augmentation.
|
| 61 |
-
|
| 62 |
-
This is the key bridge that allows memory to influence generation:
|
| 63 |
-
h' = h + alpha * output_proj(memory(k))
|
| 64 |
-
|
| 65 |
-
Args:
|
| 66 |
-
memory_dim: Dimension of memory output (e.g., 256)
|
| 67 |
-
hidden_dim: Dimension of LM hidden states (e.g., 768)
|
| 68 |
-
"""
|
| 69 |
-
|
| 70 |
-
def __init__(self, memory_dim, hidden_dim):
|
| 71 |
-
super().__init__()
|
| 72 |
-
self.projection = nn.Linear(memory_dim, hidden_dim, bias=False)
|
| 73 |
-
|
| 74 |
-
def forward(self, memory_output):
|
| 75 |
-
"""
|
| 76 |
-
Args:
|
| 77 |
-
memory_output: (batch_size, memory_dim)
|
| 78 |
-
Returns:
|
| 79 |
-
hidden_augmentation: (batch_size, hidden_dim)
|
| 80 |
-
"""
|
| 81 |
-
return self.projection(memory_output)
|
| 82 |
-
=======
|
| 83 |
"""
|
| 84 |
Key and Value Projection Layers
|
| 85 |
|
|
@@ -134,4 +52,3 @@ class ValueProjection(nn.Module):
|
|
| 134 |
value: (batch_size, memory_dim)
|
| 135 |
"""
|
| 136 |
return self.projection(hidden_state)
|
| 137 |
-
>>>>>>> origin/main
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
"""
|
| 2 |
Key and Value Projection Layers
|
| 3 |
|
|
|
|
| 52 |
value: (batch_size, memory_dim)
|
| 53 |
"""
|
| 54 |
return self.projection(hidden_state)
|
|
|