av-codes committed on
Commit 1ee2101 · verified · 1 Parent(s): 90dd792

Upload folder using huggingface_hub

Files changed (4)
  1. README.md +75 -0
  2. config.json +77 -0
  3. model.pt +3 -0
  4. modeling_miras.py +214 -0
README.md ADDED
@@ -0,0 +1,75 @@
# MIRAS Language Model

A character-level language model trained on Shakespeare using the MIRAS (Memory-Integrated Recurrent Attention System) architecture.

## Model Details
- **Embedding dimension**: 384
- **Layers**: 4
- **Block size**: 128
- **Memory type**: deep
- **Attentional bias**: l2
- **Retention**: l2
- **Vocabulary size**: 65

## Installation

```bash
pip install torch huggingface_hub
```

## Usage

### Quick Start

```python
from huggingface_hub import hf_hub_download
import torch

# Download the model code, weights, and config
for f in ["modeling_miras.py", "model.pt", "config.json"]:
    hf_hub_download(repo_id="av-codes/miras-shakespeare", filename=f, local_dir="./miras")

# Make the downloaded module importable, then load the model
import sys
sys.path.insert(0, "./miras")
from modeling_miras import load_miras_model

model, encode, decode, config = load_miras_model("./miras")
model.eval()

# Generate text (token 0 is "\n", so this starts from a newline context)
context = torch.zeros((1, 1), dtype=torch.long)
output = model.generate(context, max_new_tokens=200, temperature=0.8)
print(decode(output[0].tolist()))
```

### Loading Directly from the Hub

Once `modeling_miras.py` is importable (e.g., after the Quick Start download), the helper can fetch the weights and config straight from the Hub by repo ID:

```python
from modeling_miras import load_miras_model

# Load weights and config directly from the Hub
model, encode, decode, config = load_miras_model("av-codes/miras-shakespeare")

# Generate
import torch
context = torch.zeros((1, 1), dtype=torch.long)
generated = model.generate(context, max_new_tokens=100)
print(decode(generated[0].tolist()))
```

## Files

- `model.pt` - Model weights and architecture config
- `config.json` - Full configuration, including the character vocabulary
- `modeling_miras.py` - Complete model architecture code

## Training
Trained for 5000 iterations on the TinyShakespeare dataset.

## Architecture

MIRAS uses a memory-based attention mechanism with three configurable choices (a sketch of how they map to the model constructor follows this list):
- **Memory type**: `linear` (matrix memory) or `deep` (MLP memory)
- **Attentional bias**: `l2`, `lp`, or `huber` loss functions
- **Retention**: `l2`, `kl`, or `elastic` weight-update rules
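
A minimal sketch of how these options map to the constructor in `modeling_miras.py` (the second configuration is illustrative only; the published checkpoint uses the first):

```python
from modeling_miras import MIRASLanguageModel

# The shipped configuration (matches config.json)
model = MIRASLanguageModel(
    vocab_size=65, d_model=384, n_layers=4,
    memory_type="deep", attentional_bias="l2", retention="l2",
    block_size=128,
)

# Illustrative variant: linear matrix memory, Huber attentional bias,
# elastic-net retention. Constructs fine, but no trained weights exist
# for it in this repo.
variant = MIRASLanguageModel(
    vocab_size=65, d_model=384, n_layers=4,
    memory_type="linear", attentional_bias="huber", retention="elastic",
    block_size=128,
)
```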
config.json ADDED
@@ -0,0 +1,77 @@
{
  "model_type": "miras",
  "vocab_size": 65,
  "d_model": 384,
  "n_layers": 4,
  "block_size": 128,
  "memory_type": "deep",
  "attentional_bias": "l2",
  "retention": "l2",
  "chars": [
    "\n",
    " ",
    "!",
    "$",
    "&",
    "'",
    ",",
    "-",
    ".",
    "3",
    ":",
    ";",
    "?",
    "A",
    "B",
    "C",
    "D",
    "E",
    "F",
    "G",
    "H",
    "I",
    "J",
    "K",
    "L",
    "M",
    "N",
    "O",
    "P",
    "Q",
    "R",
    "S",
    "T",
    "U",
    "V",
    "W",
    "X",
    "Y",
    "Z",
    "a",
    "b",
    "c",
    "d",
    "e",
    "f",
    "g",
    "h",
    "i",
    "j",
    "k",
    "l",
    "m",
    "n",
    "o",
    "p",
    "q",
    "r",
    "s",
    "t",
    "u",
    "v",
    "w",
    "x",
    "y",
    "z"
  ]
}
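
The `chars` array doubles as the tokenizer: a character's index in the list is its token ID. A minimal sketch of how `load_miras_model` derives `encode`/`decode` from it (mirroring the code at the end of `modeling_miras.py`):

```python
import json

with open("config.json") as f:
    config = json.load(f)

# Index in `chars` == token ID
stoi = {ch: i for i, ch in enumerate(config["chars"])}
itos = {i: ch for i, ch in enumerate(config["chars"])}
encode = lambda s: [stoi[c] for c in s]
decode = lambda ids: "".join(itos[i] for i in ids)

assert decode(encode("To be, or not to be")) == "To be, or not to be"
```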
model.pt ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:42e0315925c3efca2ab74185c640a64ba0d460e873cdf77a74a5dbccb8a021cf
size 45215151
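
The three lines above are a Git LFS pointer, not the weights themselves (the actual file is ~45 MB). `load_miras_model` only requires that the checkpoint be a dict containing a `model_state_dict` key, so the save side was presumably something like this sketch (the training script is not included in this commit; any extra keys are an assumption):

```python
import torch

# Hypothetical save call, matching what load_miras_model reads
torch.save({"model_state_dict": model.state_dict()}, "model.pt")
```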
modeling_miras.py ADDED
@@ -0,0 +1,214 @@
"""MIRAS Language Model - Custom Architecture"""

import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional

def l2_loss(pred, target):
    return 0.5 * ((pred - target) ** 2).sum(dim=-1)

def lp_loss(pred, target, p=3):
    return (torch.abs(pred - target) ** p).sum(dim=-1)

def huber_loss(pred, target, delta):
    diff = pred - target
    abs_diff = torch.abs(diff)
    return torch.where(abs_diff <= delta, 0.5 * diff ** 2, delta * (abs_diff - 0.5 * delta)).sum(dim=-1)

def l2_retention_update(W, grad, alpha, eta):
    return alpha * W - eta * grad

def kl_retention_update(log_W, grad, alpha, eta, c=1.0):
    log_W_new = alpha * log_W - eta * grad
    return log_W_new, c * F.softmax(log_W_new, dim=-1)

def elastic_net_update(W, grad, lambda_decay, zeta_lr, gamma_l1):
    z = lambda_decay * W - zeta_lr * grad
    return torch.sign(z) * F.relu(torch.abs(z) - gamma_l1)

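# Retention rules, in brief: each pairs the attentional-bias gradient
# step with a different regularizer.
#   l2:      W <- alpha * W - eta * grad   (multiplicative decay toward 0)
#   kl:      step taken in log-space, renormalized with a softmax
#            (mirror descent over the simplex, scaled by c)
#   elastic: soft-thresholding sign(z) * relu(|z| - gamma), adding an
#            L1 sparsity term on top of the decay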

class KeyValueProjection(nn.Module):
    def __init__(self, d_in, d_out):
        super().__init__()
        self.W_K = nn.Linear(d_in, d_out, bias=False)
        self.W_V = nn.Linear(d_in, d_out, bias=False)
        self.W_Q = nn.Linear(d_in, d_out, bias=False)

    def forward(self, x):
        return self.W_K(x), self.W_V(x), self.W_Q(x)


class MIRASLayer(nn.Module):
    def __init__(self, d, memory_type='deep', attentional_bias='l2', retention='l2', expansion=4, p=3, q=4):
        super().__init__()
        self.d, self.memory_type, self.attentional_bias, self.retention = d, memory_type, attentional_bias, retention
        self.p, self.q = p, q
        self.kv_proj = KeyValueProjection(d, d)

        if memory_type == 'linear':
            self.register_buffer('M_init', torch.zeros(d, d))
        else:
            self.W1_init = nn.Parameter(torch.randn(d, d * expansion) * 0.02)
            self.W2_init = nn.Parameter(torch.randn(d * expansion, d) * 0.02)
            self.ln = nn.LayerNorm(d)

        if attentional_bias == 'huber':
            self.delta_proj = nn.Linear(d, 1)

        self.alpha = nn.Parameter(torch.ones(1) * 0.9)
        self.eta = nn.Parameter(torch.ones(1) * 0.1)
        if retention == 'kl':
            self.c = nn.Parameter(torch.ones(1))
        if retention == 'elastic':
            self.gamma = nn.Parameter(torch.ones(1) * 0.01)

    def memory_forward_deep(self, x, W1, W2):
        h = F.gelu(x @ W2.transpose(-2, -1))
        return x + self.ln(h @ W1.transpose(-2, -1))

    def get_loss(self, pred, target, x_t=None):
        if self.attentional_bias == 'l2':
            return l2_loss(pred, target).sum()
        elif self.attentional_bias == 'lp':
            return lp_loss(pred, target, self.p).sum()
        else:
            return huber_loss(pred, target, F.softplus(self.delta_proj(x_t))).sum()

    def apply_retention(self, W, grad, log_W=None):
        alpha, eta = torch.sigmoid(self.alpha), F.softplus(self.eta)
        if self.retention == 'l2':
            return l2_retention_update(W, grad, alpha, eta), None
        elif self.retention == 'kl':
            log_W = log_W if log_W is not None else torch.log(W.clamp(min=1e-10))
            log_W_new, W_new = kl_retention_update(log_W, grad, alpha, eta, self.c)
            return W_new, log_W_new
        else:
            return elastic_net_update(W, grad, alpha, eta, self.gamma), None

    def forward(self, x):
        k, v, q = self.kv_proj(x)
        B, T, D = k.shape
        outputs = []

        # Gradients must flow through the memory update even when the
        # caller runs under torch.no_grad() (e.g., inside generate()).
        with torch.enable_grad():
            if self.memory_type == 'linear':
                M = self.M_init.unsqueeze(0).expand(B, -1, -1).contiguous()
                for t in range(T):
                    # Online update: predict v_t from k_t, take one gradient
                    # step on the attentional-bias loss, then apply retention.
                    k_t, v_t, q_t = k[:, t], v[:, t], q[:, t]
                    M_leaf = M.detach().requires_grad_(True)
                    pred = torch.einsum('bde,be->bd', M_leaf, k_t)
                    loss = self.get_loss(pred, v_t, x[:, t] if self.attentional_bias == 'huber' else None)
                    grad = torch.autograd.grad(loss, M_leaf)[0]
                    M, _ = self.apply_retention(M, grad)
                    outputs.append(torch.einsum('bde,be->bd', M, q_t))
            else:
                W1 = self.W1_init.unsqueeze(0).expand(B, -1, -1).contiguous()
                W2 = self.W2_init.unsqueeze(0).expand(B, -1, -1).contiguous()
                log_W1, log_W2 = None, None
                if self.retention == 'kl':
                    W1, W2 = F.softmax(W1, dim=-1), F.softmax(W2, dim=-1)
                    log_W1, log_W2 = torch.log(W1.clamp(min=1e-10)), torch.log(W2.clamp(min=1e-10))

                for t in range(T):
                    k_t, v_t, q_t = k[:, t], v[:, t], q[:, t]
                    W1_leaf, W2_leaf = W1.detach().requires_grad_(True), W2.detach().requires_grad_(True)
                    pred = self.memory_forward_deep(k_t.unsqueeze(1), W1_leaf, W2_leaf).squeeze(1)
                    loss = self.get_loss(pred, v_t, x[:, t] if self.attentional_bias == 'huber' else None)
                    grad1, grad2 = torch.autograd.grad(loss, [W1_leaf, W2_leaf])
                    W1, log_W1 = self.apply_retention(W1, grad1, log_W1)
                    W2, log_W2 = self.apply_retention(W2, grad2, log_W2)
                    # Read from the updated memory with the query
                    outputs.append(self.memory_forward_deep(q_t.unsqueeze(1), W1.detach(), W2.detach()).squeeze(1))

        return torch.stack(outputs, dim=1)


class MIRASBlock(nn.Module):
    def __init__(self, d_model, memory_type, attentional_bias, retention, ffn_mult=4):
        super().__init__()
        self.ln1 = nn.LayerNorm(d_model)
        self.memory = MIRASLayer(d_model, memory_type, attentional_bias, retention)
        self.ln2 = nn.LayerNorm(d_model)
        self.ffn = nn.Sequential(nn.Linear(d_model, d_model * ffn_mult), nn.GELU(), nn.Linear(d_model * ffn_mult, d_model))

    def forward(self, x):
        x = x + self.memory(self.ln1(x))
        return x + self.ffn(self.ln2(x))


class MIRASLanguageModel(nn.Module):
    def __init__(self, vocab_size, d_model, n_layers, memory_type='deep', attentional_bias='l2', retention='l2', block_size=128):
        super().__init__()
        self.block_size = block_size
        self.token_embedding = nn.Embedding(vocab_size, d_model)
        self.position_embedding = nn.Embedding(block_size, d_model)
        self.layers = nn.ModuleList([MIRASBlock(d_model, memory_type, attentional_bias, retention) for _ in range(n_layers)])
        self.ln_f = nn.LayerNorm(d_model)
        self.lm_head = nn.Linear(d_model, vocab_size, bias=False)
        self.token_embedding.weight = self.lm_head.weight  # weight tying
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            torch.nn.init.normal_(m.weight, mean=0.0, std=0.02)
            if m.bias is not None:
                torch.nn.init.zeros_(m.bias)
        elif isinstance(m, nn.Embedding):
            torch.nn.init.normal_(m.weight, mean=0.0, std=0.02)

    def forward(self, idx, targets=None):
        B, T = idx.shape
        x = self.token_embedding(idx) + self.position_embedding(torch.arange(T, device=idx.device))
        for layer in self.layers:
            x = layer(x)
        logits = self.lm_head(self.ln_f(x))
        loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1)) if targets is not None else None
        return logits, loss

    @torch.no_grad()
    def generate(self, idx, max_new_tokens, temperature=1.0):
        for _ in range(max_new_tokens):
            logits, _ = self(idx[:, -self.block_size:])
            probs = F.softmax(logits[:, -1, :] / temperature, dim=-1)
            idx = torch.cat((idx, torch.multinomial(probs, num_samples=1)), dim=1)
        return idx


def load_miras_model(repo_id_or_path, device='cpu'):
    """Load a MIRAS model from HuggingFace Hub or local path."""
    import json
    from pathlib import Path

    if Path(repo_id_or_path).exists():
        base_path = Path(repo_id_or_path)
        config_path = base_path / "config.json"
        model_path = base_path / "model.pt"
    else:
        from huggingface_hub import hf_hub_download
        config_path = hf_hub_download(repo_id=repo_id_or_path, filename="config.json")
        model_path = hf_hub_download(repo_id=repo_id_or_path, filename="model.pt")

    with open(config_path) as f:
        config = json.load(f)

    model = MIRASLanguageModel(
        vocab_size=config['vocab_size'],
        d_model=config['d_model'],
        n_layers=config['n_layers'],
        memory_type=config['memory_type'],
        attentional_bias=config['attentional_bias'],
        retention=config['retention'],
        block_size=config['block_size'],
    )

    checkpoint = torch.load(model_path, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    model.to(device)
    model.eval()

    # Build the character-level tokenizer from the vocabulary in config.json
    stoi = {ch: i for i, ch in enumerate(config['chars'])}
    itos = {i: ch for i, ch in enumerate(config['chars'])}
    encode = lambda s: [stoi[c] for c in s]
    decode = lambda l: ''.join([itos[i] for i in l])

    return model, encode, decode, config
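
A quick way to smoke-test the architecture code without downloading the checkpoint (untrained weights, so generated text is gibberish); a minimal sketch, assuming `modeling_miras.py` is on the import path:

```python
import torch
from modeling_miras import MIRASLanguageModel

# Tiny untrained instance, just to exercise forward() and generate()
model = MIRASLanguageModel(vocab_size=65, d_model=32, n_layers=1, block_size=16)
model.eval()

idx = torch.randint(0, 65, (2, 8))       # batch of random token IDs
logits, loss = model(idx, targets=idx)   # forward pass with a loss
print(logits.shape, loss.item())         # torch.Size([2, 8, 65]) and a float

out = model.generate(torch.zeros((1, 1), dtype=torch.long), max_new_tokens=5)
print(out.shape)                         # torch.Size([1, 6])
```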