Pavantej committed
Commit afa8aff · verified · 1 Parent(s): 75f8e79

Upload folder using huggingface_hub

.gitattributes CHANGED
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+2501.00663v1.pdf filter=lfs diff=lfs merge=lfs -text
+2504.13173v1.pdf filter=lfs diff=lfs merge=lfs -text
2501.00663v1.pdf ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:a65e4a7d02784df1a040b487127e6dd09fff4474e5caf94d93263af3d50cfbc2
+size 3657065
2504.13173v1.pdf ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:faceb861d46d65fd7098cbdb97aa400081c7b5cb7e048ac8d010f01537915ab2
+size 1987057
README.md CHANGED
@@ -1,12 +1,72 @@
----
-title: Titans Miras Demo
-emoji: 🔥
-colorFrom: blue
-colorTo: purple
-sdk: gradio
-sdk_version: 6.2.0
-app_file: app.py
-pinned: false
----
-
-Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+---
+title: Titans Miras Demo
+emoji: 🧠
+colorFrom: blue
+colorTo: purple
+sdk: gradio
+sdk_version: 6.2.0
+app_file: app.py
+pinned: false
+---
+
+# 🧠 Titans + MIRAS: A Brain That Changes Itself While Thinking
+
+A minimal but faithful reimplementation of **Titans** (test-time learning) and **MIRAS** (associative memory framework) using open-source models on Hugging Face.
+
+## What is this?
+
+This demo showcases a neural architecture that can **learn and update its memory while generating responses** - a brain that literally changes itself while thinking!
+
+### Key Features
+
+- 🔄 **Test-time learning**: Memory updates during inference (not just training)
+- 🎯 **Retention gate**: Surprising/novel inputs are more memorable (inspired by human memory)
+- 💾 **Persistent memory**: State is saved across sessions
+- 🤖 **Fully OSS**: Uses distilgpt2 and runs entirely on Hugging Face
+
+## Architecture
+
+```
+User Input
+
+[Base LM: distilgpt2] → Hidden States (768-dim)
+
+[Key/Value Projections] → Memory Space (256-dim)
+
+[MIRAS Memory Module] ← Test-time Gradient Updates
+
+[Text Generation] → Response + Memory Stats
+```
+
+### Components
+
+1. **Base Language Model**: distilgpt2 (frozen, no training)
+2. **Projection Layers**: Map hidden states to memory space
+3. **MIRAS Memory**: Associative memory with learnable key→value mapping
+4. **Retention Gate**: Adjusts learning rate based on surprise (loss magnitude)
+5. **Memory Store**: Persists memory state to disk
+
+## How It Works
+
+1. Input text is processed through distilgpt2
+2. Last hidden state is projected to key/value pairs
+3. Memory predicts value from key
+4. Loss (prediction error) indicates surprise
+5. Higher surprise → higher retention → faster learning
+6. Memory updated via gradient descent (1e-3 base LR)
+7. Response generated and memory saved
+
+## References
+
+- **Titans**: [Learning to Memorize at Test Time](https://arxiv.org/abs/2501.00663)
+- **MIRAS**: [It's All Connected: A Journey Through Test-Time Memorization, Attentional Bias, Retention, and Online Optimization](https://arxiv.org/abs/2504.13173)
+
+## Running Locally
+
+```bash
+pip install -r requirements.txt
+python app.py
+```
+
+Built with ❤️ exploring the future of adaptive AI systems.
+
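
Review note: steps 3 to 6 of the README's "How It Works" list can be illustrated without any dependencies. The sketch below is not code from this commit. It uses plain Python lists in place of torch tensors, and the names `predict`, `mse`, `retention` and `update`, as well as the toy 2-dimensional key/value pair and the larger base learning rate, are assumptions chosen for illustration. It repeats the same key/value pair and applies a gradient step whose learning rate is scaled by the clamped loss.

```python
"""Toy test-time update: linear memory, MSE surprise, surprise-scaled step."""


def predict(W, key):
    # pred[j] = sum_i key[i] * W[i][j], i.e. pred = key @ W
    return [sum(key[i] * W[i][j] for i in range(len(key)))
            for j in range(len(W[0]))]


def mse(pred, value):
    # Mean squared prediction error, used as the "surprise" signal
    return sum((p - v) ** 2 for p, v in zip(pred, value)) / len(value)


def retention(loss, scale=0.1, low=0.5, high=2.0):
    # Clamp loss / scale into [low, high]
    return min(max(loss / scale, low), high)


def update(W, key, value, base_lr=1e-3):
    """Apply one gradient step to W in place and return the pre-step loss."""
    pred = predict(W, key)
    loss = mse(pred, value)
    lr = base_lr * retention(loss)
    n = len(value)
    # d(loss)/dW[i][j] = (2 / n) * key[i] * (pred[j] - value[j])
    for i in range(len(key)):
        for j in range(n):
            W[i][j] -= lr * (2.0 / n) * key[i] * (pred[j] - value[j])
    return loss


W = [[0.0, 0.0], [0.0, 0.0]]
key, value = [1.0, 2.0], [0.5, -1.0]
# base_lr is larger than the demo's 1e-3 so the effect shows in 20 steps
losses = [update(W, key, value, base_lr=0.05) for _ in range(20)]
print(f"first loss {losses[0]:.4f}, last loss {losses[-1]:.4f}")
```

Because the same pair is presented every time, the loss shrinks at each step. In the demo each chat turn yields a different key/value pair, so the reported loss need not fall monotonically.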
app.py CHANGED
@@ -1,40 +1,210 @@
-import gradio as gr
-import torch
-from transformers import AutoModelForCausalLM, AutoTokenizer
-
-model_name = "distilgpt2"
-
-tokenizer = AutoTokenizer.from_pretrained(model_name)
-model = AutoModelForCausalLM.from_pretrained(model_name)
-model.eval()
-
-def chat(text):
-    inputs = tokenizer(text, return_tensors="pt")
-    outputs = model(
-        **inputs,
-        output_hidden_states=True
-    )
-
-    h_last = outputs.hidden_states[-1][:, -1]
-
-    k = key_proj(h_last)
-    v = value_proj(h_last)
-
-    pred = memory(k)
-    loss = ((pred - v) ** 2).mean()
-
-    loss.backward()
-
-    with torch.no_grad():
-        memory.W -= 1e-2 * memory.W.grad
-        memory.W.grad.zero_()
-
-    return f"Loss: {loss.item():.4f}"
-
-gr.Interface(
-    fn=chat,
-    inputs="text",
-    outputs="text",
-    title="Base LM (no memory yet)"
-).launch()
-
+"""
+Titans + MIRAS Demo: A Brain That Changes Itself While Thinking
+
+This implements a minimal version of Titans (test-time learning) and MIRAS
+(associative memory) using distilgpt2 running on Hugging Face.
+
+Key features:
+- Test-time learning: Memory updates while generating responses
+- Retention gate: Surprising events are more memorable
+- Persistent memory: Remembers across sessions
+"""
+
+import gradio as gr
+import torch
+from transformers import AutoModelForCausalLM, AutoTokenizer
+
+from miras_memory import MIRASMemory
+from projections import KeyProjection, ValueProjection
+from memory_store import MemoryStore
+
+
+# ========== Configuration ==========
+MODEL_NAME = "distilgpt2"
+HIDDEN_DIM = 768  # distilgpt2 hidden dimension
+MEMORY_DIM = 256  # memory dimension
+LEARNING_RATE = 1e-3  # test-time learning rate
+MAX_NEW_TOKENS = 50  # max tokens to generate
+
+
+# ========== Initialize Components ==========
+print("🧠 Initializing Titans + MIRAS brain...")
+
+# Base language model
+tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
+tokenizer.pad_token = tokenizer.eos_token
+model = AutoModelForCausalLM.from_pretrained(MODEL_NAME)
+model.eval()
+
+# Memory system
+memory = MIRASMemory(memory_dim=MEMORY_DIM, init_scale=0.01)
+
+# The projections are randomly initialized and never saved, so build them from
+# a fixed seed; otherwise a reloaded memory.W would not match them after restart.
+with torch.random.fork_rng():
+    torch.manual_seed(0)
+    key_proj = KeyProjection(HIDDEN_DIM, MEMORY_DIM)
+    value_proj = ValueProjection(HIDDEN_DIM, MEMORY_DIM)
+
+# Memory persistence
+store = MemoryStore(save_dir="memory")
+store.load(memory)
+
+print("✅ Brain initialized!")
+
+
+# ========== Core Logic ==========
+def chat(user_input, conversation_history):
+    """
+    Main chat function that:
+    1. Processes input through base LM
+    2. Updates memory via test-time learning
+    3. Generates response
+    4. Returns response + memory stats
+    """
+    if not user_input.strip():
+        return conversation_history, conversation_history
+
+    # === Step 1: Extract hidden states from input ===
+    inputs = tokenizer(user_input, return_tensors="pt", padding=True)
+
+    with torch.no_grad():
+        outputs = model(
+            **inputs,
+            output_hidden_states=True
+        )
+
+    # Get last hidden state of the last token
+    h_last = outputs.hidden_states[-1][:, -1, :]  # (1, hidden_dim)
+
+    # === Step 2: Test-time memory learning ===
+    with torch.enable_grad():
+        # Project to key/value space (projections are fixed, so detach them)
+        k = key_proj(h_last).detach()
+        v = value_proj(h_last).detach()
+
+        # Compute memory loss
+        loss = memory.compute_loss(k, v)
+
+        # Get retention factor (higher for surprising events)
+        retention = memory.retention_gate(loss)
+        effective_lr = LEARNING_RATE * retention
+
+        # Backprop and update memory
+        loss.backward()
+
+    with torch.no_grad():
+        memory.W -= effective_lr * memory.W.grad
+        memory.W.grad.zero_()
+
+    # Update stats
+    memory.update_stats(loss)
+
+    # === Step 3: Generate response ===
+    with torch.no_grad():
+        output_ids = model.generate(
+            **inputs,  # input_ids plus attention_mask
+            max_new_tokens=MAX_NEW_TOKENS,
+            do_sample=True,
+            temperature=0.8,
+            top_p=0.9,
+            pad_token_id=tokenizer.eos_token_id,
+        )
+
+    response = tokenizer.decode(output_ids[0], skip_special_tokens=True)
+
+    # Remove the input prompt from response
+    if response.startswith(user_input):
+        response = response[len(user_input):].strip()
+
+    # === Step 4: Save memory ===
+    store.save(memory)
+
+    # === Step 5: Format output ===
+    stats = memory.get_stats()
+
+    memory_info = (
+        f"**Memory Update**: Loss={loss.item():.4f} | "
+        f"Retention={retention:.2f}x | "
+        f"Updates={stats['updates']} | "
+        f"Avg Loss={stats['avg_loss']:.4f}"
+    )
+
+    # Build conversation
+    bot_message = f"{response}\n\n---\n*{memory_info}*"
+
+    # Update conversation history (role/content dicts, as Gradio 6 expects)
+    conversation_history.append({"role": "user", "content": user_input})
+    conversation_history.append({"role": "assistant", "content": bot_message})
+
+    return conversation_history, conversation_history
+
+
+def clear_conversation():
+    """Clear the conversation but keep memory."""
+    return [], []
+
+
+# ========== Gradio Interface ==========
+with gr.Blocks(title="Titans + MIRAS: Self-Modifying Brain") as demo:
+    gr.Markdown("""
+    # 🧠 Titans + MIRAS: A Brain That Changes Itself While Thinking
+
+    This is a minimal implementation of **Titans** (test-time learning) and **MIRAS** (associative memory).
+
+    **What makes this special:**
+    - 🔄 **Test-time learning**: The memory updates with every interaction
+    - 🎯 **Retention gate**: Surprising inputs are more memorable
+    - 💾 **Persistent memory**: Remembers across sessions
+
+    **How it works:**
+    1. Your input is processed through distilgpt2
+    2. Hidden states are projected to memory key/value space
+    3. Memory learns via gradient descent (learning rate adjusted by surprise)
+    4. Model generates a response
+    5. Memory is saved to disk
+
+    *Watch the memory loss decrease as it learns from your conversations!*
+    """)
+
+    chatbot = gr.Chatbot(
+        label="Conversation",
+        height=400,
+    )
+
+    state = gr.State([])
+
+    with gr.Row():
+        msg = gr.Textbox(
+            label="Your Message",
+            placeholder="Type your message here...",
+            scale=4,
+        )
+        submit = gr.Button("Send", scale=1, variant="primary")
+
+    with gr.Row():
+        clear = gr.Button("Clear Conversation (Keep Memory)")
+
+    gr.Markdown("""
+    ### 📊 Memory Stats
+    - **Loss**: How well memory predicts values (lower = better)
+    - **Retention**: Learning rate multiplier (higher for surprising inputs)
+    - **Updates**: Total number of memory updates
+    - **Avg Loss**: Average loss across all updates
+
+    ### 📚 References
+    - **Titans**: [arxiv.org/abs/2501.00663](https://arxiv.org/abs/2501.00663)
+    - **MIRAS**: [arxiv.org/abs/2504.13173](https://arxiv.org/abs/2504.13173)
+    """)
+
+    # Event handlers
+    msg.submit(chat, [msg, state], [chatbot, state]).then(
+        lambda: "", None, msg
+    )
+    submit.click(chat, [msg, state], [chatbot, state]).then(
+        lambda: "", None, msg
+    )
+    clear.click(clear_conversation, None, [chatbot, state])
+
+print("🚀 Launching Gradio interface...")
+demo.launch()
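
Review note: chat histories commonly come in two shapes, a list of `(user, bot)` pairs and a flat list of role/content dicts (the "messages" format used by recent Gradio `Chatbot` versions). The helper below is a hypothetical illustration of converting the first shape into the second. `pairs_to_messages` is not part of this commit or of Gradio.

```python
def pairs_to_messages(pairs):
    """Convert [(user, bot), ...] into a flat list of role/content dicts."""
    messages = []
    for user_text, bot_text in pairs:
        messages.append({"role": "user", "content": user_text})
        messages.append({"role": "assistant", "content": bot_text})
    return messages


history = pairs_to_messages([("hi", "hello"), ("how are you?", "fine")])
for message in history:
    print(message["role"], "->", message["content"])
```

Each turn becomes two entries, so a history of N pairs yields 2N messages in chronological order.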
memory_store.py ADDED
@@ -0,0 +1,82 @@
+"""
+Memory Persistence
+
+Handles saving and loading memory state to/from disk so the brain
+remembers across sessions.
+"""
+
+import torch
+import json
+from pathlib import Path
+from datetime import datetime
+
+
+class MemoryStore:
+    """Manages persistent storage of memory state."""
+
+    def __init__(self, save_dir="memory"):
+        self.save_dir = Path(save_dir)
+        self.save_dir.mkdir(parents=True, exist_ok=True)
+        self.memory_path = self.save_dir / "memory.pt"
+        self.metadata_path = self.save_dir / "metadata.json"
+
+    def save(self, memory_module):
+        """
+        Save memory state to disk.
+
+        Args:
+            memory_module: MIRASMemory instance
+        """
+        # Save memory weights
+        torch.save({
+            'W': memory_module.W.data,
+            'update_count': memory_module.update_count,
+            'total_loss': memory_module.total_loss,
+        }, self.memory_path)
+
+        # Save metadata
+        metadata = {
+            'last_updated': datetime.now().isoformat(),
+            'memory_dim': memory_module.memory_dim,
+            'updates': memory_module.update_count.item(),
+            'avg_loss': (memory_module.total_loss / max(memory_module.update_count, 1)).item(),
+        }
+
+        with open(self.metadata_path, 'w') as f:
+            json.dump(metadata, f, indent=2)
+
+        print(f"💾 Memory saved: {memory_module.update_count.item()} updates")
+
+    def load(self, memory_module):
+        """
+        Load memory state from disk.
+
+        Args:
+            memory_module: MIRASMemory instance to load into
+
+        Returns:
+            bool: True if loaded successfully, False otherwise
+        """
+        if not self.memory_path.exists():
+            print("🆕 No saved memory found. Starting fresh!")
+            return False
+
+        try:
+            checkpoint = torch.load(self.memory_path)
+            memory_module.W.data = checkpoint['W']
+            memory_module.update_count = checkpoint['update_count']
+            memory_module.total_loss = checkpoint['total_loss']
+
+            print(f"✅ Memory loaded: {memory_module.update_count.item()} updates")
+            return True
+        except Exception as e:
+            print(f"⚠️ Error loading memory: {e}. Starting fresh!")
+            return False
+
+    def get_metadata(self):
+        """Get metadata about saved memory."""
+        if not self.metadata_path.exists():
+            return None
+
+        with open(self.metadata_path, 'r') as f:
+            return json.load(f)
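
Review note: `MemoryStore.save` rewrites its files in place on every chat turn, so a process killed mid-write could leave a truncated `metadata.json`. One common mitigation, shown here only as a sketch and not as part of this commit, is to write to a temporary file in the same directory and then rename it over the target. The helper name `atomic_write_json` is hypothetical. The calls it uses (`tempfile.mkstemp`, `os.fdopen`, `os.fsync`, `os.replace`) are standard library.

```python
import json
import os
import tempfile
from pathlib import Path


def atomic_write_json(path, payload):
    """Write JSON to a temp file in the same directory, then rename over path."""
    path = Path(path)
    path.parent.mkdir(parents=True, exist_ok=True)
    fd, tmp_name = tempfile.mkstemp(dir=path.parent, suffix=".tmp")
    try:
        with os.fdopen(fd, "w") as f:
            json.dump(payload, f, indent=2)
            f.flush()
            os.fsync(f.fileno())  # push the bytes to disk before renaming
        os.replace(tmp_name, path)  # atomic when both paths share a filesystem
    except BaseException:
        # Clean up the temp file if anything failed before the rename
        if os.path.exists(tmp_name):
            os.unlink(tmp_name)
        raise


with tempfile.TemporaryDirectory() as d:
    target = Path(d) / "memory" / "metadata.json"
    atomic_write_json(target, {"memory_dim": 256, "updates": 0, "avg_loss": 0.0})
    loaded = json.loads(target.read_text())
    print(loaded)
```

A reader then sees either the previous complete file or the new complete file, never a partial one. The same pattern could be applied to the `torch.save` call by saving to a temporary path first.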
memory_test/memory.pt ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:112699910bd87e5a20fb5ea40d87869fe3f3f987d70d6f45c2ec6b1cf8fca32a
+size 264152
memory_test/metadata.json ADDED
@@ -0,0 +1,6 @@
+{
+  "last_updated": "2025-12-20T19:57:11.523657",
+  "memory_dim": 256,
+  "updates": 0,
+  "avg_loss": 0.0
+}
miras_memory.py ADDED
@@ -0,0 +1,97 @@
+"""
+MIRAS-inspired Associative Memory Module
+
+Implements an associative memory that learns key-value mappings
+through an attentional-bias objective at test time.
+"""
+
+import torch
+import torch.nn as nn
+
+
+class MIRASMemory(nn.Module):
+    """
+    Associative memory module inspired by the MIRAS framework.
+
+    The memory learns to map keys to values using a simple linear projection
+    and updates itself during test time via gradient descent.
+
+    Args:
+        memory_dim: Dimensionality of memory keys/values
+        init_scale: Scale for random weight initialization
+    """
+
+    def __init__(self, memory_dim=256, init_scale=0.01):
+        super().__init__()
+        self.memory_dim = memory_dim
+
+        # Memory matrix: maps keys to values
+        # W: (memory_dim, memory_dim)
+        self.W = nn.Parameter(
+            torch.randn(memory_dim, memory_dim) * init_scale
+        )
+
+        # Running update count and summed loss, reported by get_stats()
+        self.register_buffer('update_count', torch.tensor(0))
+        self.register_buffer('total_loss', torch.tensor(0.0))
+
+    def forward(self, key):
+        """
+        Query memory with a key.
+
+        Args:
+            key: (batch_size, memory_dim) tensor
+
+        Returns:
+            predicted_value: (batch_size, memory_dim) tensor
+        """
+        # Simple linear mapping: pred_v = k @ W
+        predicted_value = key @ self.W
+        return predicted_value
+
+    def compute_loss(self, key, value):
+        """
+        Compute the attentional-bias loss (mean squared error) between predicted and true value.
+
+        Args:
+            key: (batch_size, memory_dim)
+            value: (batch_size, memory_dim)
+
+        Returns:
+            loss: scalar tensor
+        """
+        pred = self.forward(key)
+        loss = ((pred - value) ** 2).mean()
+        return loss
+
+    def retention_gate(self, loss):
+        """
+        Simple retention gate: higher loss = more surprising = more memorable.
+
+        Returns a scaling factor for the learning rate based on surprise.
+        High loss (surprising) gets higher weight.
+
+        Args:
+            loss: scalar tensor
+
+        Returns:
+            retention_factor: Python float in range [0.5, 2.0]
+        """
+        # Normalize loss by a reference scale of 0.1, then clamp to [0.5, 2.0]
+        # If loss is high (surprising), learn more aggressively
+        retention_factor = torch.clamp(loss / 0.1, 0.5, 2.0)
+        return retention_factor.item()
+
+    def update_stats(self, loss):
+        """Track memory statistics."""
+        self.update_count += 1
+        self.total_loss += loss.item()
+
+    def get_stats(self):
+        """Get memory statistics."""
+        avg_loss = self.total_loss / max(self.update_count, 1)
+        return {
+            'updates': self.update_count.item(),
+            'avg_loss': avg_loss.item(),
+            'memory_size': self.W.numel()
+        }
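
Review note: `retention_gate` divides the loss by 0.1 and clamps the result to [0.5, 2.0], so the multiplier saturates quickly on both sides. The sketch below mirrors that arithmetic in plain Python to show where it saturates. The function name `retention_factor` and the sample loss values are illustrative assumptions, and 1e-3 is the base learning rate from `app.py`.

```python
def retention_factor(loss, scale=0.1, low=0.5, high=2.0):
    """Plain-Python equivalent of clamp(loss / scale, low, high)."""
    return min(max(loss / scale, low), high)


BASE_LR = 1e-3  # LEARNING_RATE in app.py

for loss in (0.01, 0.1, 0.15, 1.0):
    factor = retention_factor(loss)
    print(f"loss={loss:<5} factor={factor:.2f} effective_lr={BASE_LR * factor:.1e}")
```

Any loss at or below about 0.05 gets the minimum factor of 0.5, and any loss at or above about 0.2 gets the maximum of 2.0. Only losses between those two values change the step size proportionally.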
projections.py ADDED
@@ -0,0 +1,54 @@
+"""
+Key and Value Projection Layers
+
+Maps hidden states from the base language model into memory-compatible
+representations for the MIRAS memory module.
+"""
+
+import torch.nn as nn
+
+
+class KeyProjection(nn.Module):
+    """
+    Projects hidden states to memory keys.
+
+    Args:
+        hidden_dim: Dimension of LM hidden states (e.g., 768 for distilgpt2)
+        memory_dim: Dimension of memory keys (e.g., 256)
+    """
+
+    def __init__(self, hidden_dim, memory_dim):
+        super().__init__()
+        self.projection = nn.Linear(hidden_dim, memory_dim, bias=False)
+
+    def forward(self, hidden_state):
+        """
+        Args:
+            hidden_state: (batch_size, hidden_dim)
+        Returns:
+            key: (batch_size, memory_dim)
+        """
+        return self.projection(hidden_state)
+
+
+class ValueProjection(nn.Module):
+    """
+    Projects hidden states to memory values.
+
+    Args:
+        hidden_dim: Dimension of LM hidden states (e.g., 768 for distilgpt2)
+        memory_dim: Dimension of memory values (e.g., 256)
+    """
+
+    def __init__(self, hidden_dim, memory_dim):
+        super().__init__()
+        self.projection = nn.Linear(hidden_dim, memory_dim, bias=False)
+
+    def forward(self, hidden_state):
+        """
+        Args:
+            hidden_state: (batch_size, hidden_dim)
+        Returns:
+            value: (batch_size, memory_dim)
+        """
+        return self.projection(hidden_state)
requirements.txt CHANGED
@@ -1,3 +1,4 @@
-torch
-transformers
-gradio
+torch
+transformers
+gradio
+numpy
test_components.py ADDED
@@ -0,0 +1,84 @@
+"""
+Quick test script to verify Titans+MIRAS components
+"""
+
+import torch
+from miras_memory import MIRASMemory
+from projections import KeyProjection, ValueProjection
+from memory_store import MemoryStore
+
+print("=" * 50)
+print("Testing Titans + MIRAS Components")
+print("=" * 50)
+
+# Test 1: Memory Module
+print("\n✓ Test 1: Memory Module")
+memory = MIRASMemory(memory_dim=256, init_scale=0.01)
+key_test = torch.randn(1, 256)
+value_test = torch.randn(1, 256)
+
+pred = memory(key_test)
+print(f" - Forward pass: {pred.shape}")
+assert pred.shape == (1, 256)
+
+loss = memory.compute_loss(key_test, value_test)
+print(f" - Loss computation: {loss.item():.4f}")
+
+retention = memory.retention_gate(loss)
+print(f" - Retention gate: {retention:.2f}x")
+assert 0.5 <= retention <= 2.0
+
+stats = memory.get_stats()
+print(f" - Stats: {stats}")
+
+# Test 2: Projections
+print("\n✓ Test 2: Projection Layers")
+key_proj = KeyProjection(768, 256)
+value_proj = ValueProjection(768, 256)
+
+hidden = torch.randn(1, 768)
+k = key_proj(hidden)
+v = value_proj(hidden)
+print(f" - Key projection: {k.shape}")
+print(f" - Value projection: {v.shape}")
+assert k.shape == v.shape == (1, 256)
+
+# Test 3: Memory Store
+print("\n✓ Test 3: Memory Persistence")
+store = MemoryStore(save_dir="memory_test")
+
+# Save
+store.save(memory)
+print(" - Memory saved")
+
+# Create new memory and load
+memory2 = MIRASMemory(memory_dim=256, init_scale=0.01)
+loaded = store.load(memory2)
+print(f" - Memory loaded: {loaded}")
+assert loaded and torch.equal(memory2.W, memory.W)
+
+# Test 4: Full Pipeline
+print("\n✓ Test 4: Full Test-Time Learning Pipeline")
+memory3 = MIRASMemory(memory_dim=256, init_scale=0.01)
+
+for i in range(5):
+    # Simulate learning
+    k = torch.randn(1, 256)
+    v = torch.randn(1, 256)
+
+    loss = memory3.compute_loss(k, v)
+    retention = memory3.retention_gate(loss)
+    lr = 1e-3 * retention
+
+    loss.backward()
+    with torch.no_grad():
+        memory3.W -= lr * memory3.W.grad
+        memory3.W.grad.zero_()
+    memory3.update_stats(loss)
+
+    stats = memory3.get_stats()
+    print(f" - Step {i+1}: Loss={loss.item():.4f}, Retention={retention:.2f}x, Avg={stats['avg_loss']:.4f}")
+
+print("\n" + "=" * 50)
+print("✅ ALL TESTS PASSED!")
+print("=" * 50)