| """
|
| Titans + MIRAS Demo: A Brain That Changes Itself While Thinking
|
|
|
| This application demonstrates test-time learning using:
|
| - Titans: Test-time training framework
|
| - MIRAS: Associative memory with retention gate
|
| """
|
|
|
from datetime import datetime

import gradio as gr
import torch
import torch.nn as nn
from transformers import AutoTokenizer, AutoModelForCausalLM

from memory_store import MemoryStore
from miras_memory import MIRASMemory
from projections import KeyProjection, ValueProjection, OutputProjection
|
|
|
# Startup banner: a timestamped marker (local wall-clock time) so each
# application start is easy to spot in the logs.
print("=" * 50)
print("===== Application Startup at", datetime.now().strftime("%Y-%m-%d %H:%M:%S"), "=====")
print("=" * 50)
print()
|
|
|
|
|
# --- Configuration ---------------------------------------------------------

# Base language model: checkpoint name handed to the Hugging Face loaders,
# and the upper bound on tokens sampled per reply.
MODEL_NAME: str = "distilgpt2"
MAX_NEW_TOKENS: int = 50

# Vector sizes passed to the projection layers and the memory module.
# HIDDEN_DIM presumably has to match the hidden size of MODEL_NAME, since the
# model's hidden states are fed straight into the projections -- confirm
# before swapping the checkpoint.
HIDDEN_DIM: int = 768
MEMORY_DIM: int = 256

# Test-time training: base step size (scaled by the retention gate), number
# of gradient steps taken per message, and the weight of the memory read-out
# that is added to the hidden state.
LEARNING_RATE: float = 0.01
NUM_TRAIN_STEPS: int = 5
MEMORY_ALPHA: float = 1.0
|
|
|
|
|
print("🧠 Initializing Titans + MIRAS brain...")

# Tokenizer and base language model. EOS is reused as the pad token so that
# tokenizer calls made with padding=True have a token to pad with.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
tokenizer.pad_token = tokenizer.eos_token
model = AutoModelForCausalLM.from_pretrained(MODEL_NAME)
model.eval()
# eval() only switches layers such as dropout to inference behaviour; it does
# not turn off autograd. The base model is never trained in this file, so
# freeze it explicitly: otherwise every backward pass through model.lm_head
# also computes (and keeps accumulating) a vocabulary-sized gradient that
# nothing ever reads or zeroes.
model.requires_grad_(False)

# Projections between the model's hidden space and the memory space.
key_proj = KeyProjection(HIDDEN_DIM, MEMORY_DIM)
value_proj = ValueProjection(HIDDEN_DIM, MEMORY_DIM)
output_proj = OutputProjection(MEMORY_DIM, HIDDEN_DIM)

# The associative memory that is updated while chatting.
memory = MIRASMemory(memory_dim=MEMORY_DIM, init_scale=0.01)

# On-disk persistence for the memory.
# NOTE(review): presumably load() restores a previously saved state from
# save_dir and is a no-op on first run -- confirm in memory_store.
store = MemoryStore(save_dir="memory")
store.load(memory)

print("✅ Brain initialized!")
|
|
|
|
|
|
|
def chat(message, history):
    """
    Main chat function for gr.ChatInterface.

    For each message this (1) runs the base model once to get the hidden
    states of the prompt, (2) takes NUM_TRAIN_STEPS gradient steps on the
    memory and the output projection so that the memory-augmented hidden
    states predict the prompt's own next tokens better, (3) samples a reply
    from the memory-augmented model, and (4) saves the memory to disk.

    Args:
        message: str - user's current message
        history: list of dicts with 'role' and 'content' keys
            (not used here; accepted because the chat interface supplies it)

    Returns:
        str - assistant's response with memory stats
    """
    if not message.strip():
        return "Please enter a message."

    # Bound the prompt so that prompt + sampled tokens stays inside the
    # model's context window; an over-long message would otherwise overflow
    # the position embeddings during the forward pass.
    tokenize_kwargs = {"return_tensors": "pt", "padding": True}
    max_positions = getattr(model.config, "max_position_embeddings", None)
    if max_positions is not None:
        tokenize_kwargs["truncation"] = True
        tokenize_kwargs["max_length"] = max(1, max_positions - MAX_NEW_TOKENS)
    inputs = tokenizer(message, **tokenize_kwargs)
    prompt_len = inputs['input_ids'].shape[1]

    # One frozen forward pass over the prompt; these hidden states are the
    # fixed inputs for the test-time training below.
    with torch.no_grad():
        outputs = model(
            **inputs,
            output_hidden_states=True
        )

    all_hidden = outputs.hidden_states[-1]
    h_last = all_hidden[:, -1, :]
    seq_len = all_hidden.shape[1]

    if seq_len > 1:
        # Only these two tensors are ever stepped. Asking autograd for just
        # their gradients (instead of calling backward()) avoids computing and
        # accumulating gradients on every other parameter in the graph (LM
        # head, key projection, ...), which nothing here would ever zero.
        trainable = [
            param
            for param in (memory.W, output_proj.projection.weight)
            if param.requires_grad
        ]

        for _ in range(NUM_TRAIN_STEPS):
            with torch.enable_grad():
                total_lm_loss = 0.0

                # Next-token prediction over the prompt itself: position
                # `pos` must predict the token at `pos + 1`.
                for pos in range(seq_len - 1):
                    h_pos = all_hidden[:, pos, :]

                    k = key_proj(h_pos)

                    # Read from memory and add the read-out to the hidden state.
                    memory_out = memory(k)
                    h_augmented = h_pos + MEMORY_ALPHA * output_proj(memory_out)

                    logits = model.lm_head(h_augmented)

                    target = inputs['input_ids'][:, pos + 1]

                    lm_loss = nn.functional.cross_entropy(logits, target)
                    total_lm_loss = total_lm_loss + lm_loss

                memory_loss = total_lm_loss / (seq_len - 1)

                # The retention gate turns the loss into a multiplier on the
                # base learning rate.
                retention = memory.retention_gate(memory_loss)
                effective_lr = LEARNING_RATE * retention

                grads = (
                    torch.autograd.grad(memory_loss, trainable, allow_unused=True)
                    if trainable
                    else ()
                )

            # Plain gradient-descent step, done outside autograd.
            with torch.no_grad():
                for param, grad in zip(trainable, grads):
                    if grad is not None:
                        param -= effective_lr * grad

        # Stats are recorded once per message (not once per gradient step),
        # using the loss measured at the start of the final step.
        with torch.no_grad():
            memory.update_stats(memory_loss)
    else:
        # A single-token prompt has no next token to predict, so no gradient
        # step is taken; only a key -> value reconstruction error is recorded.
        with torch.no_grad():
            k = key_proj(h_last)
            v = value_proj(h_last)
            memory_pred = memory(k)
            memory_loss = ((memory_pred - v) ** 2).mean()
            retention = 1.0
            memory.update_stats(memory_loss)

    # --- Generation: sample token by token from the memory-augmented model ---
    sampling_temperature = 0.8
    generated_ids = inputs['input_ids'].clone()

    with torch.no_grad():
        for _ in range(MAX_NEW_TOKENS):
            outputs = model(generated_ids, output_hidden_states=True)
            h_last = outputs.hidden_states[-1][:, -1, :]

            k_gen = key_proj(h_last)
            memory_out = memory.query(k_gen)

            h_augmented = h_last + MEMORY_ALPHA * output_proj(memory_out)

            logits = model.lm_head(h_augmented)

            logits = logits / sampling_temperature
            probs = torch.softmax(logits, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)

            if next_token.item() == tokenizer.eos_token_id:
                break

            generated_ids = torch.cat([generated_ids, next_token], dim=1)

    # Decode only the newly sampled tokens. Stripping the prompt by comparing
    # decoded text against `message` is fragile: decoding can normalise
    # whitespace/punctuation, the prefix check then fails, and the prompt is
    # echoed back as part of the reply.
    response = tokenizer.decode(
        generated_ids[0, prompt_len:], skip_special_tokens=True
    ).strip()

    if not response:
        response = "..."

    # Persist the updated memory after every message.
    store.save(memory)

    stats = memory.get_stats()

    # float() accepts both a plain number and a one-element tensor, so the
    # ":.2f" format works whichever the retention gate returns.
    memory_info = (
        f"\n\n---\n"
        f"**🧠 Memory Update**\n"
        f"- Loss: {memory_loss.item():.4f} (lower = better prediction)\n"
        f"- Retention: {float(retention):.2f}x (surprise factor)\n"
        f"- Total Updates: {stats['updates']}\n"
        f"- Avg Loss: {stats['avg_loss']:.4f}"
    )

    return response + memory_info
|
|
|
|
|
|
|
print("🚀 Launching Gradio interface...")

# Markdown rendered by the chat interface as its description.
APP_DESCRIPTION = """
# A Living System That Updates Its Weights During Inference

**The Novel Thing**: Standard LLMs freeze their weights after training. This system performs gradient descent *while you chat*.

---

## 🚀 The Revolutionary Difference

**Standard LLMs (ChatGPT, Claude, etc.)**: Think → Predict → **Forget**
**Titans + MIRAS**: Think → Predict → **Update** → **Remember** → Think Differently

---

### 💡 What Makes This Different?

| Feature | ChatGPT/Claude/GPT-4 | This Demo (Titans+MIRAS) |
|---------|---------------------|--------------------------|
| **Weights during chat** | 🔒 Frozen forever | ✅ Update with every message |
| **Learning** | ❌ Simulated (in-context only) | ✅ Real (gradient descent) |
| **Memory** | 📝 Token context only | 🧠 Neural parameters |
| **Persistence** | ❌ Forgets when context ends | ✅ Saves to disk |
| **Adaptation** | 🎭 Acts like it learned | 🔬 Actually learns |

---

### 🎯 What You're Witnessing

**This is NOT a better chatbot** - it's a **learning demonstrator**.

1. **The text responses are random** - that's expected! We're using a small, frozen model (distilgpt2)
2. **The MAGIC is in the numbers below** - watch the "Loss" decrease when you repeat inputs!
3. **Every message physically changes the brain** - the memory weights update via gradient descent
4. **Refresh the page** - the update count continues (memory persists!)

---

### 🧪 How It Works (The Technical Truth)

```
Your Message
↓
[distilgpt2: FROZEN] ← Not learning, just generating
↓
Hidden States (768-dim)
↓
[Projections] → Memory Space (256-dim)
↓
[MIRAS Memory: LEARNING!] ← This is what updates!
↓
Loss = How surprised the memory is
↓
Gradient Descent → Memory weights change
↓
Saved to disk → Persists forever
```

**Key Insight**: We're training the **memory**, not the text generator!

---

### 🔬 The Science: Why This Matters

**Standard LLMs**:
- Weights frozen after training (costs millions)
- "Learning" is just pattern matching in context
- Forget everything when context ends
- Same model for everyone

**Titans + MIRAS**:
- Weights update during inference (free!)
- Real optimization via gradient descent
- Memory persists across sessions
- Personalizes to each user

**This is test-time learning** - the future of adaptive AI.

---

### 📊 What the Stats Mean

- **Loss**: How surprised the memory is (lower = more familiar)
- **Retention**: Learning rate multiplier (2.0x = very surprising, 0.5x = familiar)
- **Updates**: Total number of memory updates (persists across sessions!)
- **Avg Loss**: Overall learning progress

---

### 🎮 Try This Experiment

1. **Send "hello world" 5 times** → Watch loss decrease!
2. **Send something completely different** → Loss spikes!
3. **Refresh the page and send another message** → Update count continues!

**That decreasing loss is proof the neural weights are changing!**

---

### 🌟 The Bottom Line

**ChatGPT**: A frozen calculator that *simulates* adaptation
**This Demo**: A living system that *performs* adaptation

You're not chatting with a model.
**You're watching a brain rewire itself in real-time.** 🧠⚡

---

### 🧪 How to Test This (Interactive Experiments)

**Don't just chat—run experiments to see the learning happen!**

#### Experiment 1: Watch Loss Decrease (Proof of Learning)
```
1. Send "hello world"
2. Send "hello world" again
3. Send "hello world" again
4. Send "hello world" again
5. Send "hello world" again
```
**What to watch**: Loss should decrease each time (7.5 → 6.0 → 5.0 → 4.0)
**Why it matters**: This proves the memory is learning the pattern!

#### Experiment 2: Trigger Surprise (Spike the Loss)
```
1. Send "hello world" 5 times (loss decreases)
2. Then send: "Supercalifragilisticexpialidocious quantum entanglement"
```
**What to watch**: Loss should spike back up (4.0 → 9.0+)
**Why it matters**: The memory detects novelty—it knows this is different!

#### Experiment 3: Test Persistence (Memory Survives)
```
1. Note the "Updates" count (e.g., 15)
2. Refresh this page completely
3. Send any message
4. Check if Updates = 16 (not reset to 1!)
```
**What to watch**: Update count should continue, not reset
**Why it matters**: Memory persists to disk—it's not just in RAM!

---

### 📊 What Each Stat Means (Decoder Ring)

**Loss** (e.g., 7.48 → 6.61 → 5.23)
- **What it is**: Prediction error (how surprised the memory is)
- **Lower = Better**: Memory is familiar with this pattern
- **Higher = Novel**: Memory hasn't seen this before
- **Why it matters**: Decreasing loss = learning is happening!

**Retention** (e.g., 2.00x)
- **What it is**: Learning rate multiplier based on surprise
- **2.0x = Very surprising**: Memory learns aggressively
- **0.5x = Very familiar**: Memory learns slowly (you won't see this yet)
- **Why it matters**: The brain learns more from surprising events (like humans!)

**Updates** (e.g., 1 → 2 → 3 → 4...)
- **What it is**: Total number of memory updates
- **Persists across sessions**: Survives page refreshes
- **Never resets**: Keeps counting forever
- **Why it matters**: Proof that memory is persistent, not ephemeral!

**Avg Loss** (e.g., 7.26)
- **What it is**: Running average of all losses
- **Trends downward**: As memory learns recurring patterns
- **Reflects overall learning**: Lower = memory is getting smarter
- **Why it matters**: Shows long-term learning progress!

---

### ⚠️ What to Ignore (Important!)

**The text responses are random and bad** - this is expected!
- We're NOT training the text generator (distilgpt2 is frozen)
- The responses don't matter—they're a side effect
- **Focus on the numbers below**, not the text above
- The magic is in the decreasing loss, not the generated text

**Why?** Because we're demonstrating **memory learning**, not text generation.
Standard LLMs train the text generator. This trains the memory. Different goals.

---

### 🎯 What Success Looks Like

✅ **You're seeing it work if**:
- Loss decreases when you repeat inputs
- Loss spikes when you send something new
- Update count increments with each message
- Update count persists after page refresh
- Retention is 2.0x (everything is surprising to fresh memory)

❌ **You're NOT seeing it work if**:
- Loss stays constant (not learning)
- Updates reset to 1 after refresh (not persisting)
- No stats appear below responses

---

### 🔬 Why This Matters (The Big Picture)

**Standard LLMs**: Frozen weights → No learning during use
**This Demo**: Live weights → Learning with every message

That decreasing loss you see? **That's gradient descent happening during inference.**
That's the revolution. That's what ChatGPT doesn't do.

You're not just using a model. **You're watching it change.**

---

*Built with Titans (test-time training) + MIRAS (associative memory)*
*Papers: [Titans](https://arxiv.org/abs/2501.00663) | [MIRAS](https://arxiv.org/abs/2504.13173)*

**📖 [Read the full essay: "When Models Learn While Thinking"](https://huggingface.co/spaces/Pavantej/titans-miras-demo/blob/main/ESSAY.md)**
"""

# Clickable starter prompts. "hello world" appears twice, exactly as in the
# original list -- presumably on purpose, to invite the repeat-the-input
# experiment the description talks about.
EXAMPLE_PROMPTS = [
    "hello world",
    "hello world",
    "Tell me about test-time learning",
    "What is 2+2?",
    "my name is [your name]",
]

# Wire the chat callback into a Gradio chat UI and start serving.
demo = gr.ChatInterface(
    fn=chat,
    title="🧠 The Brain That Learns While Thinking",
    description=APP_DESCRIPTION,
    examples=EXAMPLE_PROMPTS,
    cache_examples=False,
    theme="soft",
)

demo.launch()
|
|
|