Upload folder using huggingface_hub
Browse files
app.py
CHANGED
|
@@ -24,9 +24,10 @@ print()
|
|
| 24 |
MODEL_NAME = "distilgpt2"
|
| 25 |
HIDDEN_DIM = 768 # distilgpt2 hidden dimension
|
| 26 |
MEMORY_DIM = 256 # Memory space dimension
|
| 27 |
-
LEARNING_RATE =
|
| 28 |
MAX_NEW_TOKENS = 50 # Max tokens to generate
|
| 29 |
-
MEMORY_ALPHA =
|
|
|
|
| 30 |
|
| 31 |
# ========== Initialize Components ==========
|
| 32 |
print("🧠 Initializing Titans + MIRAS brain...")
|
|
@@ -88,53 +89,56 @@ def chat(message, history):
|
|
| 88 |
|
| 89 |
if seq_len > 1:
|
| 90 |
# We have context - train on predicting each next token
|
| 91 |
-
|
| 92 |
-
|
| 93 |
-
|
| 94 |
-
|
| 95 |
-
for pos in range(seq_len - 1):
|
| 96 |
-
h_pos = all_hidden[:, pos, :] # Hidden at position pos
|
| 97 |
-
|
| 98 |
-
# Project to memory space
|
| 99 |
-
k = key_proj(h_pos)
|
| 100 |
-
|
| 101 |
-
# Query memory and augment hidden state
|
| 102 |
-
memory_out = memory(k)
|
| 103 |
-
h_augmented = h_pos + MEMORY_ALPHA * output_proj(memory_out)
|
| 104 |
|
| 105 |
-
#
|
| 106 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 107 |
|
| 108 |
-
#
|
| 109 |
-
|
| 110 |
|
| 111 |
-
#
|
| 112 |
-
|
| 113 |
-
|
| 114 |
-
|
| 115 |
-
# Average loss over positions
|
| 116 |
-
memory_loss = total_lm_loss / (seq_len - 1)
|
| 117 |
-
|
| 118 |
-
# Get retention factor
|
| 119 |
-
retention = memory.retention_gate(memory_loss)
|
| 120 |
-
effective_lr = LEARNING_RATE * retention
|
| 121 |
-
|
| 122 |
-
# Backprop and update
|
| 123 |
-
memory_loss.backward()
|
| 124 |
-
|
| 125 |
-
with torch.no_grad():
|
| 126 |
-
# Update memory
|
| 127 |
-
if memory.W.grad is not None:
|
| 128 |
-
memory.W -= effective_lr * memory.W.grad
|
| 129 |
-
memory.W.grad.zero_()
|
| 130 |
|
| 131 |
-
#
|
| 132 |
-
|
| 133 |
-
output_proj.projection.weight -= effective_lr * output_proj.projection.weight.grad
|
| 134 |
-
output_proj.projection.weight.grad.zero_()
|
| 135 |
|
| 136 |
-
|
| 137 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 138 |
else:
|
| 139 |
# Single token - just compute MSE for stats
|
| 140 |
with torch.no_grad():
|
|
|
|
| 24 |
MODEL_NAME = "distilgpt2"
|
| 25 |
HIDDEN_DIM = 768 # distilgpt2 hidden dimension
|
| 26 |
MEMORY_DIM = 256 # Memory space dimension
|
| 27 |
+
LEARNING_RATE = 0.01 # Increased learning rate for faster adaptation
|
| 28 |
MAX_NEW_TOKENS = 50 # Max tokens to generate
|
| 29 |
+
MEMORY_ALPHA = 1.0 # Increased from 0.1 - stronger memory influence
|
| 30 |
+
NUM_TRAIN_STEPS = 5 # Multiple gradient steps per input for better learning
|
| 31 |
|
| 32 |
# ========== Initialize Components ==========
|
| 33 |
print("🧠 Initializing Titans + MIRAS brain...")
|
|
|
|
| 89 |
|
| 90 |
if seq_len > 1:
|
| 91 |
# We have context - train on predicting each next token
|
| 92 |
+
# Run multiple training steps for faster learning
|
| 93 |
+
for train_step in range(NUM_TRAIN_STEPS):
|
| 94 |
+
with torch.enable_grad():
|
| 95 |
+
total_lm_loss = 0.0
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 96 |
|
| 97 |
+
# For each position (except last), predict next token
|
| 98 |
+
for pos in range(seq_len - 1):
|
| 99 |
+
h_pos = all_hidden[:, pos, :] # Hidden at position pos
|
| 100 |
+
|
| 101 |
+
# Project to memory space
|
| 102 |
+
k = key_proj(h_pos)
|
| 103 |
+
|
| 104 |
+
# Query memory and augment hidden state
|
| 105 |
+
memory_out = memory(k)
|
| 106 |
+
h_augmented = h_pos + MEMORY_ALPHA * output_proj(memory_out)
|
| 107 |
+
|
| 108 |
+
# Compute logits for next token
|
| 109 |
+
logits = model.lm_head(h_augmented) # (1, vocab_size)
|
| 110 |
+
|
| 111 |
+
# Target is the actual next token
|
| 112 |
+
target = inputs['input_ids'][:, pos + 1]
|
| 113 |
+
|
| 114 |
+
# Cross-entropy loss
|
| 115 |
+
lm_loss = nn.functional.cross_entropy(logits, target)
|
| 116 |
+
total_lm_loss = total_lm_loss + lm_loss
|
| 117 |
|
| 118 |
+
# Average loss over positions
|
| 119 |
+
memory_loss = total_lm_loss / (seq_len - 1)
|
| 120 |
|
| 121 |
+
# Get retention factor
|
| 122 |
+
retention = memory.retention_gate(memory_loss)
|
| 123 |
+
effective_lr = LEARNING_RATE * retention
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 124 |
|
| 125 |
+
# Backprop and update
|
| 126 |
+
memory_loss.backward()
|
|
|
|
|
|
|
| 127 |
|
| 128 |
+
with torch.no_grad():
|
| 129 |
+
# Update memory
|
| 130 |
+
if memory.W.grad is not None:
|
| 131 |
+
memory.W -= effective_lr * memory.W.grad
|
| 132 |
+
memory.W.grad.zero_()
|
| 133 |
+
|
| 134 |
+
# Update output projection
|
| 135 |
+
if output_proj.projection.weight.grad is not None:
|
| 136 |
+
output_proj.projection.weight -= effective_lr * output_proj.projection.weight.grad
|
| 137 |
+
output_proj.projection.weight.grad.zero_()
|
| 138 |
+
|
| 139 |
+
# Update stats after all training steps (use final loss)
|
| 140 |
+
with torch.no_grad():
|
| 141 |
+
memory.update_stats(memory_loss)
|
| 142 |
else:
|
| 143 |
# Single token - just compute MSE for stats
|
| 144 |
with torch.no_grad():
|