Text Generation
Transformers
PyTorch
Safetensors
English
i3
i3-architecture
hybrid-model
rwkv-mamba
custom_code
FlameF0X commited on
Commit
885e103
·
verified ·
1 Parent(s): d5710f4

Create i3_model.py

Browse files
Files changed (1) hide show
  1. i3_model.py +230 -0
i3_model.py ADDED
@@ -0,0 +1,230 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ import json
5
+ import os
6
+
7
+ # ======================================================================
8
+ # RWKV-Mamba Hybrid Recurrence
9
+ # ======================================================================
10
class RWKVMambaHybrid(nn.Module):
    """RWKV-style token mixing blended with a Mamba-style linear state-space recurrence.

    Two per-batch recurrent states are carried across the time axis:
      * ``h`` -- a per-channel exponential moving average of the inputs, gated by
        the learned ``w_mix`` (RWKV-like time mixing).
      * ``s`` -- a linear recurrence ``s' = s @ A^T + x_t @ B^T`` (SSM-like),
        read out through ``C``.

    The step output is ``s @ C^T + h * D``.

    Args:
        d_model: channel width of the input/output.
        d_state: width of the internal state-space recurrence.
    """

    def __init__(self, d_model, d_state=64):
        super().__init__()
        self.d_model = d_model
        self.d_state = d_state
        # Per-channel blend between the carried state h and the fresh input.
        self.w_mix = nn.Parameter(torch.ones(d_model) * 0.5)
        # State-space matrices; small init keeps the recurrence stable early in training.
        self.A = nn.Parameter(torch.randn(d_state, d_state) * 0.01)
        self.B = nn.Parameter(torch.randn(d_state, d_model) * 0.01)
        self.C = nn.Parameter(torch.randn(d_model, d_state) * 0.01)
        # Per-channel gain on the mixed (EMA) stream.
        self.D = nn.Parameter(torch.ones(d_model) * 0.1)

    def forward(self, x):
        """Run the recurrence over a (B, T, C) input; returns (B, T, C)."""
        B, T, C = x.shape
        # Match the input's dtype as well as its device so the module behaves
        # under autocast / non-default dtypes instead of silently promoting.
        h = torch.zeros(B, C, device=x.device, dtype=x.dtype)
        s = torch.zeros(B, self.d_state, device=x.device, dtype=x.dtype)
        outputs = []

        # Sequential scan: each step depends on the previous h and s.
        for t in range(T):
            x_t = x[:, t, :]
            h = self.w_mix * h + (1 - self.w_mix) * x_t
            s = s @ self.A.T + x_t @ self.B.T
            y_t = s @ self.C.T + h * self.D
            outputs.append(y_t)

        return torch.stack(outputs, dim=1)
35
+
36
+ # ======================================================================
37
+ # Full Multi-Head Attention
38
+ # ======================================================================
39
class FullAttention(nn.Module):
    """Standard multi-head scaled dot-product attention with an optional mask."""

    def __init__(self, d_model, n_heads=16):
        super().__init__()
        self.d_model = d_model
        self.n_heads = n_heads
        self.head_dim = d_model // n_heads
        assert d_model % n_heads == 0, "d_model must be divisible by n_heads"
        # One fused projection yields queries, keys and values together.
        self.qkv = nn.Linear(d_model, d_model * 3)
        self.out_proj = nn.Linear(d_model, d_model)

    def forward(self, x, mask=None):
        """Attend over (B, T, C) input; ``mask`` positions equal to 0 are blocked."""
        bsz, seq, dim = x.shape
        heads, hdim = self.n_heads, self.head_dim

        # Split the fused projection into per-head q/k/v: each (bsz, heads, seq, hdim).
        fused = self.qkv(x).view(bsz, seq, 3, heads, hdim).permute(2, 0, 3, 1, 4)
        q, k, v = fused.unbind(0)

        scores = (q @ k.transpose(-2, -1)) / (hdim ** 0.5)
        if mask is not None:
            keep = mask.expand(bsz, heads, seq, seq).bool()
            scores = scores.masked_fill(~keep, float('-inf'))
        weights = F.softmax(scores, dim=-1)

        # Merge heads back into the channel dimension before the output projection.
        context = (weights @ v).transpose(1, 2).reshape(bsz, seq, dim)
        return self.out_proj(context)
66
+
67
+ # ======================================================================
68
+ # i3 Hybrid Block
69
+ # ======================================================================
70
class i3HybridBlock(nn.Module):
    """Pre-norm residual block: RWKV/Mamba hybrid mixer followed by a GELU FFN.

    ``mask`` is accepted so every block shares one call signature, but it is
    not used here -- the recurrent mixer is causal by construction.
    """

    def __init__(self, d_model, d_state=64, ffn_mult=4):
        super().__init__()
        self.ln1 = nn.LayerNorm(d_model)
        self.hybrid = RWKVMambaHybrid(d_model, d_state)
        self.ln2 = nn.LayerNorm(d_model)
        hidden = d_model * ffn_mult
        self.ffn = nn.Sequential(
            nn.Linear(d_model, hidden),
            nn.GELU(),
            nn.Linear(hidden, d_model),
        )

    def forward(self, x, mask=None):
        x = x + self.hybrid(self.ln1(x))
        return x + self.ffn(self.ln2(x))
87
+
88
+ # ======================================================================
89
+ # i3 Attention Block
90
+ # ======================================================================
91
class i3AttentionBlock(nn.Module):
    """Pre-norm residual block: multi-head attention followed by a GELU FFN."""

    def __init__(self, d_model, n_heads=16, ffn_mult=4):
        super().__init__()
        self.ln1 = nn.LayerNorm(d_model)
        self.attn = FullAttention(d_model, n_heads)
        self.ln2 = nn.LayerNorm(d_model)
        hidden = d_model * ffn_mult
        self.ffn = nn.Sequential(
            nn.Linear(d_model, hidden),
            nn.GELU(),
            nn.Linear(hidden, d_model),
        )

    def forward(self, x, mask=None):
        x = x + self.attn(self.ln1(x), mask)
        return x + self.ffn(self.ln2(x))
108
+
109
+ # ======================================================================
110
+ # Full i3 Model
111
+ # ======================================================================
112
class i3Model(nn.Module):
    """i3 hybrid language model.

    Architecture: token + learned positional embeddings, then 10 recurrent
    (RWKV/Mamba hybrid) blocks followed by 6 full-attention blocks, a final
    LayerNorm, and a linear LM head.

    Args:
        vocab_size: size of the token vocabulary.
        d_model: embedding / residual width.
        n_heads: attention heads used by the attention blocks.
        max_seq_len: maximum sequence length (size of the positional table).
        d_state: state width of the recurrent mixer blocks.
    """

    def __init__(self, vocab_size, d_model=512, n_heads=16, max_seq_len=256, d_state=32):
        super().__init__()
        self.vocab_size = vocab_size
        self.d_model = d_model
        self.max_seq_len = max_seq_len

        self.embed = nn.Embedding(vocab_size, d_model)
        self.pos_embed = nn.Embedding(max_seq_len, d_model)

        # Recurrent blocks first, attention blocks last.
        hybrid_layers = [i3HybridBlock(d_model, d_state=d_state) for _ in range(10)]
        attention_layers = [i3AttentionBlock(d_model, n_heads=n_heads) for _ in range(6)]
        self.layers = nn.ModuleList(hybrid_layers + attention_layers)

        self.ln_f = nn.LayerNorm(d_model)
        self.head = nn.Linear(d_model, vocab_size)
        self.apply(self._init_weights)

    def _init_weights(self, module):
        # GPT-style init: N(0, 0.02) weights, zeroed linear biases.
        if isinstance(module, (nn.Linear, nn.Embedding)):
            module.weight.data.normal_(0.0, 0.02)
        if isinstance(module, nn.Linear) and module.bias is not None:
            module.bias.data.zero_()

    def forward(self, idx, targets=None):
        """Forward pass.

        Args:
            idx: (B, T) token indices, T <= max_seq_len.
            targets: optional (B, T) indices for a cross-entropy loss.

        Returns:
            (logits, loss) where logits is (B, T, vocab_size) and loss is
            None when ``targets`` is None.

        Raises:
            ValueError: if T exceeds ``max_seq_len``.
        """
        B, T = idx.shape
        if T > self.max_seq_len:
            # Fail loudly rather than with an opaque embedding index error.
            raise ValueError(
                f"sequence length {T} exceeds max_seq_len {self.max_seq_len}"
            )
        pos = torch.arange(T, device=idx.device).unsqueeze(0)
        x = self.embed(idx) + self.pos_embed(pos)
        # Causal mask consumed by the attention blocks (hybrid blocks ignore it).
        mask = torch.tril(torch.ones(T, T, device=idx.device)).view(1, 1, T, T)

        for layer in self.layers:
            x = layer(x, mask)

        x = self.ln_f(x)
        logits = self.head(x)
        loss = None
        if targets is not None:
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
        return logits, loss

    @torch.no_grad()
    def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None):
        """Autoregressively sample ``max_new_tokens`` tokens.

        The context is cropped to the last ``max_seq_len`` positions at each
        step; ``temperature`` scales the logits and ``top_k`` (if given)
        restricts sampling to the k most likely tokens.
        """
        for _ in range(max_new_tokens):
            idx_cond = idx if idx.size(1) <= self.max_seq_len else idx[:, -self.max_seq_len:]
            logits, _ = self(idx_cond)
            logits = logits[:, -1, :] / temperature
            if top_k is not None:
                v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                logits[logits < v[:, [-1]]] = -float('Inf')
            probs = F.softmax(logits, dim=-1)
            idx_next = torch.multinomial(probs, 1)
            idx = torch.cat((idx, idx_next), dim=1)
        return idx
165
+
166
+ # ======================================================================
167
+ # ChunkTokenizer
168
+ # ======================================================================
169
class ChunkTokenizer:
    """Greedy 2/3-character chunk tokenizer.

    Text is lowercased and consumed left to right in chunks of two or three
    characters. Three-character chunks are preferred for a fixed set of common
    English trigrams and at word boundaries; everything else falls back to two
    characters. Chunks missing from the vocabulary map to the <UNK> index.
    """

    def __init__(self, vocab_only=False):
        # NOTE: ``vocab_only`` is kept for caller compatibility; nothing in
        # this class reads it.
        self.chunk_to_idx = {}
        self.idx_to_chunk = {}
        self.vocab_size = 0
        self.unk_token = '<UNK>'
        self.unk_idx = 0
        self.common_trigrams = frozenset({
            'the', 'and', 'ing', 'ion', 'tio', 'for', 'tha',
            'ter', 'hat', 'his', 'ere', 'ent', 'her', 'was',
            'you', 'are', 'not', 'but', 'can', 'all', 'whi',
            'one', 'our', 'out', 'whe', 'hav', 'thi', 'wit'
        })

    def _should_use_3char(self, pos, text):
        """Return True when the chunk starting at ``pos`` should span 3 chars."""
        if pos + 3 > len(text):
            return False
        candidate = text[pos:pos + 3]
        starts_word = pos > 0 and text[pos - 1] == ' '
        ends_word = pos + 3 < len(text) and text[pos + 3] == ' '
        return candidate in self.common_trigrams or starts_word or ends_word

    def encode(self, text):
        """Map ``text`` to a list of chunk indices (unknown chunks -> unk_idx)."""
        text = text.lower()
        indices = []
        pos = 0
        length = len(text)
        while pos < length:
            size = 3 if self._should_use_3char(pos, text) else 2
            size = min(size, length - pos)  # don't run past the end of the text
            piece = text[pos:pos + size]
            indices.append(self.chunk_to_idx.get(piece, self.unk_idx))
            pos += size
        return indices

    def decode(self, indices):
        """Join chunk strings back into text; unknown indices render as <UNK>."""
        lookup = self.idx_to_chunk
        return ''.join(lookup.get(int(i), self.unk_token) for i in indices)

    def save(self, path):
        """Serialize the vocabulary and UNK settings to JSON at ``path``."""
        payload = {
            'chunk_to_idx': self.chunk_to_idx,
            'idx_to_chunk': {int(k): v for k, v in self.idx_to_chunk.items()},
            'vocab_size': self.vocab_size,
            'unk_token': self.unk_token,
            'unk_idx': self.unk_idx
        }
        with open(path, 'w') as f:
            json.dump(payload, f)

    def load(self, path):
        """Restore the vocabulary from JSON written by :meth:`save`."""
        with open(path, 'r') as f:
            data = json.load(f)
        self.chunk_to_idx = data['chunk_to_idx']
        # JSON stringifies integer keys, so coerce them back.
        self.idx_to_chunk = {int(k): v for k, v in data['idx_to_chunk'].items()}
        self.vocab_size = data['vocab_size']
        self.unk_token = data.get('unk_token', '<UNK>')
        self.unk_idx = data.get('unk_idx', 0)