# Source: Hugging Face Space upload by Pavantej
# Commit b6df69a (verified): "Upload folder using huggingface_hub"
"""
Titans + MIRAS Demo: A Brain That Changes Itself While Thinking
This application demonstrates test-time learning using:
- Titans: Test-time training framework
- MIRAS: Associative memory with retention gate
"""
import datetime

import gradio as gr
import torch
import torch.nn as nn
from transformers import AutoTokenizer, AutoModelForCausalLM

from memory_store import MemoryStore
from miras_memory import MIRASMemory
from projections import KeyProjection, ValueProjection, OutputProjection
# Startup banner: a timestamped header so restarts are easy to spot in logs.
_BANNER_RULE = "=" * 50
print(_BANNER_RULE)
print(
    "===== Application Startup at",
    datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
    "=====",
)
print(_BANNER_RULE)
print()
# ========== Configuration ==========
# --- Base model ---------------------------------------------------------------
MODEL_NAME: str = "distilgpt2"  # frozen causal LM: hidden states + decoding head
HIDDEN_DIM: int = 768  # width of distilgpt2's final hidden states

# --- Associative memory -------------------------------------------------------
MEMORY_DIM: int = 256  # dimensionality of the key/value memory space
MEMORY_ALPHA: float = 1.0  # weight of the memory read-out added to hidden states
LEARNING_RATE: float = 0.01  # base SGD step, multiplied by the retention gate
NUM_TRAIN_STEPS: int = 5  # gradient steps taken per multi-token message

# --- Generation ---------------------------------------------------------------
MAX_NEW_TOKENS: int = 50  # upper bound on the sampled reply length
# ========== Initialize Components ==========
print("🧠 Initializing Titans + MIRAS brain...")

# Seed used only while constructing the projection layers (see below).
_PROJECTION_SEED = 0

# Load the base language model. It is used for inference only, so its weights
# are frozen: gradients still flow *through* ``model.lm_head`` into the memory
# during test-time training, but no gradient buffers are allocated for (or
# accumulated into) the model's own parameters.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
tokenizer.pad_token = tokenizer.eos_token  # reuse EOS so padding=True has a pad token
model = AutoModelForCausalLM.from_pretrained(MODEL_NAME)
model.eval()  # inference mode (e.g. disables dropout)
model.requires_grad_(False)

# Fail fast if MODEL_NAME is changed without updating HIDDEN_DIM; otherwise the
# mismatch only surfaces as a shape error on the first chat message.
_model_hidden_dim = getattr(model.config, "hidden_size", HIDDEN_DIM)
if _model_hidden_dim != HIDDEN_DIM:
    raise ValueError(
        f"HIDDEN_DIM={HIDDEN_DIM} does not match the hidden size "
        f"({_model_hidden_dim}) of {MODEL_NAME!r}; update HIDDEN_DIM."
    )

# Create projection layers. Only ``memory`` is handed to ``store`` for
# persistence, so the projections are built under a fixed seed to keep the key
# space the stored memory was trained in across process restarts. This assumes
# their initialisation draws from torch's CPU RNG. ``fork_rng`` restores the RNG
# state afterwards, so the seed does not make later sampling deterministic.
# NOTE(review): ``output_proj`` is trained during chat but is not saved, so its
# learned weights are still lost on restart.
with torch.random.fork_rng(devices=[]):
    torch.manual_seed(_PROJECTION_SEED)
    key_proj = KeyProjection(HIDDEN_DIM, MEMORY_DIM)
    value_proj = ValueProjection(HIDDEN_DIM, MEMORY_DIM)
    output_proj = OutputProjection(MEMORY_DIM, HIDDEN_DIM)  # map memory back to hidden dim

# Create the memory module, then restore its persisted state from disk.
memory = MIRASMemory(memory_dim=MEMORY_DIM, init_scale=0.01)
store = MemoryStore(save_dir="memory")
store.load(memory)
print("✅ Brain initialized!")
# ========== Chat Function ==========
# Softmax temperature used when sampling during memory-augmented generation.
_SAMPLING_TEMPERATURE = 0.8


def _max_prompt_tokens():
    """Return how many prompt tokens fit alongside MAX_NEW_TOKENS generated ones.

    The base model has a fixed number of position embeddings; a longer input
    would fail inside the model, so prompts are trimmed to leave room for the
    reply.
    """
    config = model.config
    context = getattr(config, "n_positions", None) or getattr(
        config, "max_position_embeddings", 1024
    )
    return max(1, context - MAX_NEW_TOKENS)


def _lm_memory_loss(all_hidden, input_ids):
    """Return the mean next-token cross-entropy of the memory-augmented LM head.

    Args:
        all_hidden: final-layer hidden states indexed as ``[:, pos, :]``
            (presumably shaped (1, seq_len, hidden_dim) - TODO confirm).
        input_ids: token ids of the same prompt; ``input_ids[:, pos + 1]`` is
            the target for hidden state ``pos``.

    Returns:
        Scalar loss averaged over the ``seq_len - 1`` predicted positions.
        Requires ``seq_len >= 2``.
    """
    num_targets = all_hidden.shape[1] - 1
    total_loss = 0.0
    for pos in range(num_targets):
        h_pos = all_hidden[:, pos, :]
        # h' = h + alpha * output_proj(memory(key_proj(h)))
        h_augmented = h_pos + MEMORY_ALPHA * output_proj(memory(key_proj(h_pos)))
        logits = model.lm_head(h_augmented)
        total_loss = total_loss + nn.functional.cross_entropy(
            logits, input_ids[:, pos + 1]
        )
    return total_loss / num_targets


def _train_memory(all_hidden, input_ids):
    """Run NUM_TRAIN_STEPS plain SGD steps on the memory and the output projection.

    Only ``memory.W`` and ``output_proj.projection.weight`` are updated.
    Gradients come from ``torch.autograd.grad`` for exactly those tensors, so
    nothing accumulates in the ``.grad`` buffers of the key projection or the
    base model.

    Returns:
        ``(loss, retention)`` from the last step. The loss is measured *before*
        that step's update is applied.
    """
    memory_loss = None
    retention = 1.0
    with torch.enable_grad():
        for _ in range(NUM_TRAIN_STEPS):
            memory_loss = _lm_memory_loss(all_hidden, input_ids)
            # Loss-dependent step-size multiplier from the retention gate.
            retention = memory.retention_gate(memory_loss)
            effective_lr = LEARNING_RATE * retention
            params = [
                p for p in (memory.W, output_proj.projection.weight) if p.requires_grad
            ]
            if not params:
                continue
            grads = torch.autograd.grad(memory_loss, params, allow_unused=True)
            with torch.no_grad():
                for param, grad in zip(params, grads):
                    if grad is not None:
                        param.sub_(effective_lr * grad)
    if memory_loss is None:
        # NUM_TRAIN_STEPS <= 0: report the loss without updating anything.
        with torch.no_grad():
            memory_loss = _lm_memory_loss(all_hidden, input_ids)
    return memory_loss.detach(), retention


def _single_token_loss(h_last):
    """Return the MSE between the memory's recall and the value projection.

    A one-token prompt has no next-token target, so the memory is only
    evaluated here, not trained.
    NOTE(review): this MSE and the cross-entropy from ``_lm_memory_loss`` are
    on different scales, yet both feed ``memory.update_stats``.
    """
    with torch.no_grad():
        k = key_proj(h_last)
        v = value_proj(h_last)
        return ((memory(k) - v) ** 2).mean()


def _generate(input_ids):
    """Sample up to MAX_NEW_TOKENS tokens with memory-augmented hidden states.

    Each step takes the base model's last hidden state ``h`` and replaces it
    with ``h + MEMORY_ALPHA * output_proj(memory.query(key_proj(h)))``. It then
    samples from the LM head at temperature ``_SAMPLING_TEMPERATURE`` and stops
    early on EOS. The model's key/value cache is reused, so each step only
    encodes the newest token instead of the whole sequence.

    Returns:
        Tensor holding only the newly generated token ids, shape (batch, n_new).
    """
    new_tokens = []
    step_input = input_ids
    past_key_values = None
    with torch.no_grad():
        for _ in range(MAX_NEW_TOKENS):
            outputs = model(
                input_ids=step_input,
                past_key_values=past_key_values,
                use_cache=True,
                output_hidden_states=True,
            )
            past_key_values = outputs.past_key_values
            h_last = outputs.hidden_states[-1][:, -1, :]
            memory_out = memory.query(key_proj(h_last))
            h_augmented = h_last + MEMORY_ALPHA * output_proj(memory_out)
            logits = model.lm_head(h_augmented) / _SAMPLING_TEMPERATURE
            probs = torch.softmax(logits, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)
            if next_token.item() == tokenizer.eos_token_id:
                break
            new_tokens.append(next_token)
            step_input = next_token
    if not new_tokens:
        return input_ids.new_empty((input_ids.shape[0], 0))
    return torch.cat(new_tokens, dim=1)


def chat(message, history):
    """Handle one chat turn: learn from the message, reply, and report stats.

    Main callback for ``gr.ChatInterface``.

    Args:
        message: the user's current message (str).
        history: prior turns supplied by Gradio; not used here.

    Returns:
        The generated reply followed by a Markdown summary of the memory
        update (loss, retention, total updates, average loss).

    Side effects:
        Updates the shared ``memory`` (and ``output_proj`` for multi-token
        messages) in place and persists ``memory`` via ``store.save``.
    """
    if not message.strip():
        return "Please enter a message."

    # === Step 1: encode the prompt and get final-layer hidden states ===
    # Keep only the most recent tokens if the prompt would overflow the
    # model's context once MAX_NEW_TOKENS are appended.
    encoded = tokenizer(message, return_tensors="pt", padding=True)
    input_ids = encoded["input_ids"][:, -_max_prompt_tokens():]
    with torch.no_grad():
        outputs = model(input_ids=input_ids, output_hidden_states=True)
    all_hidden = outputs.hidden_states[-1]

    # === Step 2: test-time memory learning ===
    # Multi-token prompts train the memory with a language-modeling loss;
    # a single token has no next-token target and is only measured.
    if all_hidden.shape[1] > 1:
        memory_loss, retention = _train_memory(all_hidden, input_ids)
    else:
        memory_loss, retention = _single_token_loss(all_hidden[:, -1, :]), 1.0
    with torch.no_grad():
        memory.update_stats(memory_loss)

    # === Step 3: memory-augmented generation ===
    # Decode only the new tokens; re-decoding the prompt and stripping it by
    # string prefix leaks the prompt whenever decoding does not round-trip.
    new_ids = _generate(input_ids)
    response = tokenizer.decode(new_ids[0], skip_special_tokens=True).strip()
    if not response:
        response = "..."

    # === Step 4: persist memory ===
    store.save(memory)

    # === Step 5: format output with memory stats ===
    stats = memory.get_stats()
    memory_info = (
        f"\n\n---\n"
        f"**🧠 Memory Update**\n"
        f"- Loss: {float(memory_loss):.4f} (lower = better prediction)\n"
        f"- Retention: {float(retention):.2f}x (surprise factor)\n"
        f"- Total Updates: {stats['updates']}\n"
        f"- Avg Loss: {stats['avg_loss']:.4f}"
    )
    return response + memory_info
# ========== Gradio Interface ==========
# Markdown shown above the chat box. Blank lines between blocks matter: without
# them, a line directly followed by "---" renders as a heading and consecutive
# lines merge into a single paragraph.
_DESCRIPTION = """
# A Living System That Updates Its Weights During Inference

**The Novel Thing**: Standard LLMs freeze their weights after training. This system performs gradient descent *while you chat*.

---

## 🚀 The Revolutionary Difference

**Standard LLMs (ChatGPT, Claude, etc.)**: Think → Predict → **Forget**

**Titans + MIRAS**: Think → Predict → **Update** → **Remember** → Think Differently

---

### 💡 What Makes This Different?

| Feature | ChatGPT/Claude/GPT-4 | This Demo (Titans+MIRAS) |
|---------|---------------------|--------------------------|
| **Weights during chat** | 🔒 Frozen forever | ✅ Update with every message |
| **Learning** | ❌ Simulated (in-context only) | ✅ Real (gradient descent) |
| **Memory** | 📝 Token context only | 🧠 Neural parameters |
| **Persistence** | ❌ Forgets when context ends | ✅ Saves to disk |
| **Adaptation** | 🎭 Acts like it learned | 🔬 Actually learns |

---

### 🎯 What You're Witnessing

**This is NOT a better chatbot** - it's a **learning demonstrator**.

1. **The text responses are random** - that's expected! We're using a small, frozen model (distilgpt2)
2. **The MAGIC is in the numbers below** - watch the "Loss" decrease when you repeat inputs!
3. **Every message physically changes the brain** - the memory weights update via gradient descent (a message of a single token is only measured, not learned from)
4. **Refresh the page** - the update count continues (memory persists!)

---

### 🧪 How It Works (The Technical Truth)

```
Your Message
    ↓
[distilgpt2: FROZEN] ← Not learning, just generating
    ↓
Hidden States (768-dim)
    ↓
[Projections] → Memory Space (256-dim)
    ↓
[MIRAS Memory: LEARNING!] ← This is what updates!
    ↓
Loss = How surprised the memory is
    ↓
Gradient Descent → Memory weights change
    ↓
Saved to disk → Persists forever
```

**Key Insight**: We're training the **memory**, not the text generator!

---

### 🔬 The Science: Why This Matters

**Standard LLMs**:

- Weights frozen after training (costs millions)
- "Learning" is just pattern matching in context
- Forget everything when context ends
- Same model for everyone

**Titans + MIRAS**:

- Weights update during inference (free!)
- Real optimization via gradient descent
- Memory persists across sessions
- Adapts to the conversations it sees (one memory shared by everyone using this demo)

**This is test-time learning** - the future of adaptive AI.

---

### 📊 What the Stats Mean

- **Loss**: How surprised the memory is (lower = more familiar)
- **Retention**: Learning rate multiplier (2.0x = very surprising, 0.5x = familiar)
- **Updates**: Total number of memory updates (persists across sessions!)
- **Avg Loss**: Overall learning progress

---

### 🎮 Try This Experiment

1. **Send "hello world" 5 times** → Watch loss decrease!
2. **Send something completely different** → Loss spikes!
3. **Refresh the page and send another message** → Update count continues!

**That decreasing loss is proof the neural weights are changing!**

---

### 🌟 The Bottom Line

**ChatGPT**: A frozen calculator that *simulates* adaptation

**This Demo**: A living system that *performs* adaptation

You're not chatting with a model.

**You're watching a brain rewire itself in real-time.** 🧠⚡

---

### 🧪 How to Test This (Interactive Experiments)

**Don't just chat—run experiments to see the learning happen!**

#### Experiment 1: Watch Loss Decrease (Proof of Learning)

```
1. Send "hello world"
2. Send "hello world" again
3. Send "hello world" again
4. Send "hello world" again
5. Send "hello world" again
```

**What to watch**: Loss should decrease each time (7.5 → 6.0 → 5.0 → 4.0)

**Why it matters**: This proves the memory is learning the pattern!

#### Experiment 2: Trigger Surprise (Spike the Loss)

```
1. Send "hello world" 5 times (loss decreases)
2. Then send: "Supercalifragilisticexpialidocious quantum entanglement"
```

**What to watch**: Loss should spike back up (4.0 → 9.0+)

**Why it matters**: The memory detects novelty—it knows this is different!

#### Experiment 3: Test Persistence (Memory Survives)

```
1. Note the "Updates" count (e.g., 15)
2. Refresh this page completely
3. Send any message
4. Check if Updates = 16 (not reset to 1!)
```

**What to watch**: Update count should continue, not reset

**Why it matters**: Memory persists to disk—it's not just in RAM!

---

### 📊 What Each Stat Means (Decoder Ring)

**Loss** (e.g., 7.48 → 6.61 → 5.23)

- **What it is**: Prediction error (how surprised the memory is)
- **Lower = Better**: Memory is familiar with this pattern
- **Higher = Novel**: Memory hasn't seen this before
- **Why it matters**: Decreasing loss = learning is happening!

**Retention** (e.g., 2.00x)

- **What it is**: Learning rate multiplier based on surprise
- **2.0x = Very surprising**: Memory learns aggressively
- **0.5x = Very familiar**: Memory learns slowly (you won't see this yet)
- **Why it matters**: The brain learns more from surprising events (like humans!)

**Updates** (e.g., 1 → 2 → 3 → 4...)

- **What it is**: Total number of memory updates
- **Persists across sessions**: Survives page refreshes
- **Never resets**: Keeps counting forever
- **Why it matters**: Proof that memory is persistent, not ephemeral!

**Avg Loss** (e.g., 7.26)

- **What it is**: Running average of all losses
- **Trends downward**: As memory learns recurring patterns
- **Reflects overall learning**: Lower = memory is getting smarter
- **Why it matters**: Shows long-term learning progress!

---

### ⚠️ What to Ignore (Important!)

**The text responses are random and bad** - this is expected!

- We're NOT training the text generator (distilgpt2 is frozen)
- The responses don't matter—they're a side effect
- **Focus on the numbers below**, not the text above
- The magic is in the decreasing loss, not the generated text

**Why?** Because we're demonstrating **memory learning**, not text generation.

Standard LLMs train the text generator. This trains the memory. Different goals.

---

### 🎯 What Success Looks Like

✅ **You're seeing it work if**:

- Loss decreases when you repeat inputs
- Loss spikes when you send something new
- Update count increments with each message
- Update count persists after page refresh
- Retention is 2.0x (everything is surprising to fresh memory)

❌ **You're NOT seeing it work if**:

- Loss stays constant (not learning)
- Updates reset to 1 after refresh (not persisting)
- No stats appear below responses

---

### 🔬 Why This Matters (The Big Picture)

**Standard LLMs**: Frozen weights → No learning during use

**This Demo**: Live weights → Learning with every message

That decreasing loss you see? **That's gradient descent happening during inference.**

That's the revolution. That's what ChatGPT doesn't do.

You're not just using a model. **You're watching it change.**

---

*Built with Titans (test-time training) + MIRAS (associative memory)*

*Papers: [Titans](https://arxiv.org/abs/2501.00663) | [MIRAS](https://arxiv.org/abs/2504.13173)*

**📖 [Read the full essay: "When Models Learn While Thinking"](https://huggingface.co/spaces/Pavantej/titans-miras-demo/blob/main/ESSAY.md)**
"""

# ``demo`` stays at module level so tools that import this file (e.g. Gradio
# reload mode) can find it; the server only starts when run as a script.
demo = gr.ChatInterface(
    fn=chat,
    title="🧠 The Brain That Learns While Thinking",
    description=_DESCRIPTION,
    examples=[
        "hello world",
        "hello world",  # Repeat to show learning!
        "Tell me about test-time learning",
        "What is 2+2?",
        "my name is [your name]",
    ],
    cache_examples=False,
    theme="soft",
)

if __name__ == "__main__":
    print("🚀 Launching Gradio interface...")
    demo.launch()