Pavantej commited on
Commit
939db07
·
verified ·
1 Parent(s): 6fc70f4

Upload app.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. app.py +68 -104
app.py CHANGED
@@ -1,67 +1,72 @@
1
  """
2
  Titans + MIRAS Demo: A Brain That Changes Itself While Thinking
3
 
4
- This implements a minimal version of Titans (test-time learning) and MIRAS
5
- (associative memory) using distilgpt2 running on Hugging Face.
6
-
7
- Key features:
8
- - Test-time learning: Memory updates while generating responses
9
- - Retention gate: Surprising events are more memorable
10
- - Persistent memory: Remembers across sessions
11
  """
12
 
13
- import gradio as gr
14
  import torch
15
- from transformers import AutoModelForCausalLM, AutoTokenizer
 
 
16
 
17
  from miras_memory import MIRASMemory
18
  from projections import KeyProjection, ValueProjection
19
  from memory_store import MemoryStore
20
 
 
 
 
 
21
 
22
  # ========== Configuration ==========
23
  MODEL_NAME = "distilgpt2"
24
  HIDDEN_DIM = 768 # distilgpt2 hidden dimension
25
- MEMORY_DIM = 256 # memory dimension
26
- LEARNING_RATE = 1e-3 # test-time learning rate
27
- MAX_NEW_TOKENS = 50 # max tokens to generate
28
-
29
 
30
  # ========== Initialize Components ==========
31
  print("🧠 Initializing Titans + MIRAS brain...")
32
 
33
- # Base language model
34
  tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
35
  tokenizer.pad_token = tokenizer.eos_token
36
  model = AutoModelForCausalLM.from_pretrained(MODEL_NAME)
37
- model.eval()
38
 
39
- # Memory system
40
- memory = MIRASMemory(memory_dim=MEMORY_DIM, init_scale=0.01)
41
  key_proj = KeyProjection(HIDDEN_DIM, MEMORY_DIM)
42
  value_proj = ValueProjection(HIDDEN_DIM, MEMORY_DIM)
43
 
44
- # Memory persistence
 
 
 
45
  store = MemoryStore(save_dir="memory")
46
  store.load(memory)
47
 
48
  print("✅ Brain initialized!")
49
 
50
 
51
- # ========== Core Logic ==========
52
- def chat(user_input, history):
53
  """
54
- Main chat function that:
55
- 1. Processes input through base LM
56
- 2. Updates memory via test-time learning
57
- 3. Generates response
58
- 4. Returns response + memory stats
 
 
 
59
  """
60
- if not user_input.strip():
61
- return history
62
 
63
  # === Step 1: Extract hidden states from input ===
64
- inputs = tokenizer(user_input, return_tensors="pt", padding=True)
65
 
66
  with torch.no_grad():
67
  outputs = model(
@@ -109,8 +114,8 @@ def chat(user_input, history):
109
  response = tokenizer.decode(output_ids[0], skip_special_tokens=True)
110
 
111
  # Remove the input prompt from response
112
- if response.startswith(user_input):
113
- response = response[len(user_input):].strip()
114
 
115
  if not response:
116
  response = "..."
@@ -118,89 +123,48 @@ def chat(user_input, history):
118
  # === Step 4: Save memory ===
119
  store.save(memory)
120
 
121
- # === Step 5: Format output ===
122
  stats = memory.get_stats()
123
 
124
  memory_info = (
125
- f"**Memory Update**: Loss={loss.item():.4f} | "
126
- f"Retention={retention:.2f}x | "
127
- f"Updates={stats['updates']} | "
128
- f"Avg Loss={stats['avg_loss']:.4f}"
 
 
129
  )
130
 
131
- # Build response with memory stats
132
- bot_message = f"{response}\n\n---\n*{memory_info}*"
133
-
134
- # Update history with simple tuple format (Gradio 4.x compatible)
135
- history = history + [[user_input, bot_message]]
136
-
137
- return history
138
-
139
-
140
- def clear_conversation():
141
- """Clear the conversation but keep memory."""
142
- return []
143
-
144
 
145
 
146
  # ========== Gradio Interface ==========
147
- with gr.Blocks(title="Titans + MIRAS: Self-Modifying Brain") as demo:
148
- gr.Markdown("""
149
- # 🧠 Titans + MIRAS: A Brain That Changes Itself While Thinking
150
-
151
- This is a minimal implementation of **Titans** (test-time learning) and **MIRAS** (associative memory).
152
-
153
- **What makes this special:**
154
- - 🔄 **Test-time learning**: The memory updates with every interaction
155
- - 🎯 **Retention gate**: Surprising inputs are more memorable
156
- - 💾 **Persistent memory**: Remembers across sessions
157
 
158
  **How it works:**
159
- 1. Your input is processed through distilgpt2
160
- 2. Hidden states are projected to memory key/value space
161
- 3. Memory learns via gradient descent (learning rate adjusted by surprise)
162
- 4. Model generates a response
163
- 5. Memory is saved to disk
164
-
165
- *Watch the memory loss decrease as it learns from your conversations!*
166
- """)
167
-
168
- chatbot = gr.Chatbot(
169
- label="Conversation",
170
- height=400
171
- )
172
-
173
- with gr.Row():
174
- msg = gr.Textbox(
175
- label="Your Message",
176
- placeholder="Type your message here...",
177
- scale=4,
178
- )
179
- submit = gr.Button("Send", scale=1, variant="primary")
180
-
181
- with gr.Row():
182
- clear = gr.Button("Clear Conversation (Keep Memory)")
183
-
184
- gr.Markdown("""
185
- ### 📊 Memory Stats
186
- - **Loss**: How well memory predicts values (lower = better)
187
- - **Retention**: Learning rate multiplier (higher for surprising inputs)
188
- - **Updates**: Total number of memory updates
189
- - **Avg Loss**: Average loss across all updates
190
-
191
- ### 📚 References
192
- - **Titans**: [arxiv.org/abs/2501.00663](https://arxiv.org/abs/2501.00663)
193
- - **MIRAS**: [arxiv.org/abs/2504.13173](https://arxiv.org/abs/2504.13173)
194
- """)
195
-
196
- # Event handlers
197
- msg.submit(chat, [msg, chatbot], [chatbot]).then(
198
- lambda: "", None, msg
199
- )
200
- submit.click(chat, [msg, chatbot], [chatbot]).then(
201
- lambda: "", None, msg
202
- )
203
- clear.click(clear_conversation, None, [chatbot])
204
 
205
- print("🚀 Launching Gradio interface...")
206
  demo.launch()
 
1
  """
2
  Titans + MIRAS Demo: A Brain That Changes Itself While Thinking
3
 
4
+ This application demonstrates test-time learning using:
5
+ - Titans: Test-time training framework
6
+ - MIRAS: Associative memory with retention gate
 
 
 
 
7
  """
8
 
 
9
  import torch
10
+ import torch.nn as nn
11
+ from transformers import AutoTokenizer, AutoModelForCausalLM
12
+ import gradio as gr
13
 
14
  from miras_memory import MIRASMemory
15
  from projections import KeyProjection, ValueProjection
16
  from memory_store import MemoryStore
17
 
18
+ print("=" * 50)
19
+ print("===== Application Startup at", __import__('datetime').datetime.now().strftime("%Y-%m-%d %H:%M:%S"), "=====")
20
+ print("=" * 50)
21
+ print()
22
 
23
  # ========== Configuration ==========
24
  MODEL_NAME = "distilgpt2"
25
  HIDDEN_DIM = 768 # distilgpt2 hidden dimension
26
+ MEMORY_DIM = 256 # Memory space dimension
27
+ LEARNING_RATE = 1e-3 # Base learning rate for test-time updates
28
+ MAX_NEW_TOKENS = 50 # Max tokens to generate
 
29
 
30
  # ========== Initialize Components ==========
31
  print("🧠 Initializing Titans + MIRAS brain...")
32
 
33
+ # Load base language model (frozen)
34
  tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
35
  tokenizer.pad_token = tokenizer.eos_token
36
  model = AutoModelForCausalLM.from_pretrained(MODEL_NAME)
37
+ model.eval() # Frozen - no training
38
 
39
+ # Create projection layers
 
40
  key_proj = KeyProjection(HIDDEN_DIM, MEMORY_DIM)
41
  value_proj = ValueProjection(HIDDEN_DIM, MEMORY_DIM)
42
 
43
+ # Create memory module
44
+ memory = MIRASMemory(memory_dim=MEMORY_DIM, init_scale=0.01)
45
+
46
+ # Load persistent memory
47
  store = MemoryStore(save_dir="memory")
48
  store.load(memory)
49
 
50
  print("✅ Brain initialized!")
51
 
52
 
53
+ # ========== Chat Function ==========
54
+ def chat(message, history):
55
  """
56
+ Main chat function for gr.ChatInterface.
57
+
58
+ Args:
59
+ message: str - user's current message
60
+ history: list of dicts with 'role' and 'content' keys
61
+
62
+ Returns:
63
+ str - assistant's response with memory stats
64
  """
65
+ if not message.strip():
66
+ return "Please enter a message."
67
 
68
  # === Step 1: Extract hidden states from input ===
69
+ inputs = tokenizer(message, return_tensors="pt", padding=True)
70
 
71
  with torch.no_grad():
72
  outputs = model(
 
114
  response = tokenizer.decode(output_ids[0], skip_special_tokens=True)
115
 
116
  # Remove the input prompt from response
117
+ if response.startswith(message):
118
+ response = response[len(message):].strip()
119
 
120
  if not response:
121
  response = "..."
 
123
  # === Step 4: Save memory ===
124
  store.save(memory)
125
 
126
+ # === Step 5: Format output with memory stats ===
127
  stats = memory.get_stats()
128
 
129
  memory_info = (
130
+ f"\n\n---\n"
131
+ f"**🧠 Memory Update**\n"
132
+ f"- Loss: {loss.item():.4f} (lower = better prediction)\n"
133
+ f"- Retention: {retention:.2f}x (surprise factor)\n"
134
+ f"- Total Updates: {stats['updates']}\n"
135
+ f"- Avg Loss: {stats['avg_loss']:.4f}"
136
  )
137
 
138
+ return response + memory_info
 
 
 
 
 
 
 
 
 
 
 
 
139
 
140
 
141
  # ========== Gradio Interface ==========
142
+ print("🚀 Launching Gradio interface...")
143
+
144
+ demo = gr.ChatInterface(
145
+ fn=chat,
146
+ title="🧠 Titans + MIRAS: A Brain That Changes Itself While Thinking",
147
+ description="""
148
+ This chatbot uses **test-time learning** - it updates its memory with every message!
 
 
 
149
 
150
  **How it works:**
151
+ 1. Your message is processed through distilgpt2
152
+ 2. Memory predicts what it should remember
153
+ 3. Prediction error (loss) indicates surprise
154
+ 4. Higher surprise stronger memory formation
155
+ 5. Memory weights update via gradient descent
156
+ 6. Response generated and memory saved to disk
157
+
158
+ **Watch the stats below each response to see the brain learning!**
159
+ """,
160
+ examples=[
161
+ "Hello! What can you do?",
162
+ "Tell me about test-time learning",
163
+ "What is 2+2?",
164
+ "Repeat this exact phrase: The quick brown fox",
165
+ ],
166
+ cache_examples=False,
167
+ theme="soft",
168
+ )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
169
 
 
170
  demo.launch()