Pavantej committed on
Commit
b6df69a
·
verified ·
1 Parent(s): dd93f43

Upload folder using huggingface_hub

Browse files
Files changed (1) hide show
  1. app.py +48 -44
app.py CHANGED
@@ -24,9 +24,10 @@ print()
24
  MODEL_NAME = "distilgpt2"
25
  HIDDEN_DIM = 768 # distilgpt2 hidden dimension
26
  MEMORY_DIM = 256 # Memory space dimension
27
- LEARNING_RATE = 1e-3 # Base learning rate for test-time updates
28
  MAX_NEW_TOKENS = 50 # Max tokens to generate
29
- MEMORY_ALPHA = 0.1 # Memory influence strength on generation
 
30
 
31
  # ========== Initialize Components ==========
32
  print("🧠 Initializing Titans + MIRAS brain...")
@@ -88,53 +89,56 @@ def chat(message, history):
88
 
89
  if seq_len > 1:
90
  # We have context - train on predicting each next token
91
- with torch.enable_grad():
92
- total_lm_loss = 0.0
93
-
94
- # For each position (except last), predict next token
95
- for pos in range(seq_len - 1):
96
- h_pos = all_hidden[:, pos, :] # Hidden at position pos
97
-
98
- # Project to memory space
99
- k = key_proj(h_pos)
100
-
101
- # Query memory and augment hidden state
102
- memory_out = memory(k)
103
- h_augmented = h_pos + MEMORY_ALPHA * output_proj(memory_out)
104
 
105
- # Compute logits for next token
106
- logits = model.lm_head(h_augmented) # (1, vocab_size)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
107
 
108
- # Target is the actual next token
109
- target = inputs['input_ids'][:, pos + 1]
110
 
111
- # Cross-entropy loss
112
- lm_loss = nn.functional.cross_entropy(logits, target)
113
- total_lm_loss = total_lm_loss + lm_loss
114
-
115
- # Average loss over positions
116
- memory_loss = total_lm_loss / (seq_len - 1)
117
-
118
- # Get retention factor
119
- retention = memory.retention_gate(memory_loss)
120
- effective_lr = LEARNING_RATE * retention
121
-
122
- # Backprop and update
123
- memory_loss.backward()
124
-
125
- with torch.no_grad():
126
- # Update memory
127
- if memory.W.grad is not None:
128
- memory.W -= effective_lr * memory.W.grad
129
- memory.W.grad.zero_()
130
 
131
- # Update output projection
132
- if output_proj.projection.weight.grad is not None:
133
- output_proj.projection.weight -= effective_lr * output_proj.projection.weight.grad
134
- output_proj.projection.weight.grad.zero_()
135
 
136
- # Update stats
137
- memory.update_stats(memory_loss)
 
 
 
 
 
 
 
 
 
 
 
 
138
  else:
139
  # Single token - just compute MSE for stats
140
  with torch.no_grad():
 
24
  MODEL_NAME = "distilgpt2"
25
  HIDDEN_DIM = 768 # distilgpt2 hidden dimension
26
  MEMORY_DIM = 256 # Memory space dimension
27
+ LEARNING_RATE = 0.01 # Increased learning rate for faster adaptation
28
  MAX_NEW_TOKENS = 50 # Max tokens to generate
29
+ MEMORY_ALPHA = 1.0 # Increased from 0.1 - stronger memory influence
30
+ NUM_TRAIN_STEPS = 5 # Multiple gradient steps per input for better learning
31
 
32
  # ========== Initialize Components ==========
33
  print("🧠 Initializing Titans + MIRAS brain...")
 
89
 
90
  if seq_len > 1:
91
  # We have context - train on predicting each next token
92
+ # Run multiple training steps for faster learning
93
+ for train_step in range(NUM_TRAIN_STEPS):
94
+ with torch.enable_grad():
95
+ total_lm_loss = 0.0
 
 
 
 
 
 
 
 
 
96
 
97
+ # For each position (except last), predict next token
98
+ for pos in range(seq_len - 1):
99
+ h_pos = all_hidden[:, pos, :] # Hidden at position pos
100
+
101
+ # Project to memory space
102
+ k = key_proj(h_pos)
103
+
104
+ # Query memory and augment hidden state
105
+ memory_out = memory(k)
106
+ h_augmented = h_pos + MEMORY_ALPHA * output_proj(memory_out)
107
+
108
+ # Compute logits for next token
109
+ logits = model.lm_head(h_augmented) # (1, vocab_size)
110
+
111
+ # Target is the actual next token
112
+ target = inputs['input_ids'][:, pos + 1]
113
+
114
+ # Cross-entropy loss
115
+ lm_loss = nn.functional.cross_entropy(logits, target)
116
+ total_lm_loss = total_lm_loss + lm_loss
117
 
118
+ # Average loss over positions
119
+ memory_loss = total_lm_loss / (seq_len - 1)
120
 
121
+ # Get retention factor
122
+ retention = memory.retention_gate(memory_loss)
123
+ effective_lr = LEARNING_RATE * retention
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
124
 
125
+ # Backprop and update
126
+ memory_loss.backward()
 
 
127
 
128
+ with torch.no_grad():
129
+ # Update memory
130
+ if memory.W.grad is not None:
131
+ memory.W -= effective_lr * memory.W.grad
132
+ memory.W.grad.zero_()
133
+
134
+ # Update output projection
135
+ if output_proj.projection.weight.grad is not None:
136
+ output_proj.projection.weight -= effective_lr * output_proj.projection.weight.grad
137
+ output_proj.projection.weight.grad.zero_()
138
+
139
+ # Update stats after all training steps (use final loss)
140
+ with torch.no_grad():
141
+ memory.update_stats(memory_loss)
142
  else:
143
  # Single token - just compute MSE for stats
144
  with torch.no_grad():