FlameF0X commited on
Commit
3364c14
·
verified ·
1 Parent(s): cc0bbd6

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +331 -0
app.py ADDED
@@ -0,0 +1,331 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ import gradio as gr
5
+ from huggingface_hub import hf_hub_download
6
+ from tokenizers import Tokenizer
7
+ import os
8
+
9
+ # ============================================================================
10
+ # 1. MODEL ARCHITECTURE (Must match training code exactly)
11
+ # ============================================================================
12
+
13
+ @torch.jit.script
14
+ def rwkv_linear_attention(B: int, T: int, C: int,
15
+ r: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
16
+ w: torch.Tensor, u: torch.Tensor,
17
+ state_init: torch.Tensor):
18
+ y = torch.zeros_like(v)
19
+ state_aa = torch.zeros(B, C, dtype=torch.float32, device=r.device)
20
+ state_bb = torch.zeros(B, C, dtype=torch.float32, device=r.device)
21
+ state_pp = state_init.clone()
22
+
23
+ for t in range(T):
24
+ rt, kt, vt = r[:, t], k[:, t], v[:, t]
25
+ ww = u + state_pp
26
+ p = torch.maximum(ww, kt)
27
+ e1 = torch.exp(ww - p)
28
+ e2 = torch.exp(kt - p)
29
+ wkv = (state_aa * e1 + vt * e2) / (state_bb * e1 + e2 + 1e-6)
30
+ y[:, t] = wkv
31
+
32
+ ww = w + state_pp
33
+ p = torch.maximum(ww, kt)
34
+ e1 = torch.exp(ww - p)
35
+ e2 = torch.exp(kt - p)
36
+ state_aa = state_aa * e1 + vt * e2
37
+ state_bb = state_bb * e1 + e2
38
+ state_pp = p
39
+
40
+ return y
41
+
42
+ class RWKVTimeMix(nn.Module):
43
+ def __init__(self, d_model):
44
+ super().__init__()
45
+ self.d_model = d_model
46
+ self.time_decay = nn.Parameter(torch.ones(d_model))
47
+ self.time_first = nn.Parameter(torch.ones(d_model))
48
+ self.time_mix_k = nn.Parameter(torch.ones(1, 1, d_model))
49
+ self.time_mix_v = nn.Parameter(torch.ones(1, 1, d_model))
50
+ self.time_mix_r = nn.Parameter(torch.ones(1, 1, d_model))
51
+ self.key = nn.Linear(d_model, d_model, bias=False)
52
+ self.value = nn.Linear(d_model, d_model, bias=False)
53
+ self.receptance = nn.Linear(d_model, d_model, bias=False)
54
+ self.output = nn.Linear(d_model, d_model, bias=False)
55
+ self.time_decay.data.uniform_(-6, -3)
56
+
57
+ def forward(self, x):
58
+ B, T, C = x.size()
59
+ xx = torch.cat([torch.zeros((B, 1, C), device=x.device), x[:, :-1]], dim=1)
60
+ xk = x * self.time_mix_k + xx * (1 - self.time_mix_k)
61
+ xv = x * self.time_mix_v + xx * (1 - self.time_mix_v)
62
+ xr = x * self.time_mix_r + xx * (1 - self.time_mix_r)
63
+
64
+ k = self.key(xk)
65
+ v = self.value(xv)
66
+ r = torch.sigmoid(self.receptance(xr))
67
+
68
+ w = -torch.exp(self.time_decay)
69
+ u = self.time_first
70
+ state_init = torch.full((B, C), -1e30, dtype=torch.float32, device=x.device)
71
+
72
+ rwkv = rwkv_linear_attention(B, T, C, r, k, v, w, u, state_init)
73
+ return self.output(r * rwkv)
74
+
75
+ class RWKVChannelMix(nn.Module):
76
+ def __init__(self, d_model, ffn_mult=4):
77
+ super().__init__()
78
+ self.time_mix_k = nn.Parameter(torch.ones(1, 1, d_model))
79
+ self.time_mix_r = nn.Parameter(torch.ones(1, 1, d_model))
80
+ hidden_sz = d_model * ffn_mult
81
+ self.key = nn.Linear(d_model, hidden_sz, bias=False)
82
+ self.receptance = nn.Linear(d_model, d_model, bias=False)
83
+ self.value = nn.Linear(hidden_sz, d_model, bias=False)
84
+
85
+ def forward(self, x):
86
+ B, T, C = x.size()
87
+ xx = torch.cat([torch.zeros((B, 1, C), device=x.device), x[:, :-1]], dim=1)
88
+ xk = x * self.time_mix_k + xx * (1 - self.time_mix_k)
89
+ xr = x * self.time_mix_r + xx * (1 - self.time_mix_r)
90
+
91
+ k = torch.square(torch.relu(self.key(xk)))
92
+ kv = self.value(k)
93
+ r = torch.sigmoid(self.receptance(xr))
94
+ return r * kv
95
+
96
+ class BiRWKVBlock(nn.Module):
97
+ def __init__(self, d_model, ffn_mult=4):
98
+ super().__init__()
99
+ self.ln1 = nn.LayerNorm(d_model)
100
+ self.fwd_time_mix = RWKVTimeMix(d_model)
101
+ self.bwd_time_mix = RWKVTimeMix(d_model)
102
+ self.ln2 = nn.LayerNorm(d_model)
103
+ self.channel_mix = RWKVChannelMix(d_model, ffn_mult)
104
+
105
+ def forward(self, x, mask=None):
106
+ x_norm = self.ln1(x)
107
+ x_fwd = self.fwd_time_mix(x_norm)
108
+ x_rev = torch.flip(x_norm, [1])
109
+ x_bwd_rev = self.bwd_time_mix(x_rev)
110
+ x_bwd = torch.flip(x_bwd_rev, [1])
111
+ x = x + x_fwd + x_bwd
112
+ x = x + self.channel_mix(self.ln2(x))
113
+ return x
114
+
115
+ class FullAttention(nn.Module):
116
+ def __init__(self, d_model, n_heads=16):
117
+ super().__init__()
118
+ self.d_model = d_model
119
+ self.n_heads = n_heads
120
+ self.head_dim = d_model // n_heads
121
+ self.qkv = nn.Linear(d_model, d_model * 3)
122
+ self.out_proj = nn.Linear(d_model, d_model)
123
+
124
+ def forward(self, x, mask=None):
125
+ B, T, C = x.shape
126
+ qkv = self.qkv(x)
127
+ q, k, v = qkv.chunk(3, dim=-1)
128
+ q = q.view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
129
+ k = k.view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
130
+ v = v.view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
131
+
132
+ attn = (q @ k.transpose(-2, -1)) / (self.head_dim ** 0.5)
133
+ if mask is not None:
134
+ attn = attn.masked_fill(mask == 0, float('-inf'))
135
+ attn = F.softmax(attn, dim=-1)
136
+ out = attn @ v
137
+ out = out.transpose(1, 2).contiguous().view(B, T, C)
138
+ return self.out_proj(out)
139
+
140
+ class StandardAttentionBlock(nn.Module):
141
+ def __init__(self, d_model, n_heads=16, ffn_mult=4):
142
+ super().__init__()
143
+ self.ln1 = nn.LayerNorm(d_model)
144
+ self.attn = FullAttention(d_model, n_heads)
145
+ self.ln2 = nn.LayerNorm(d_model)
146
+ self.ffn = nn.Sequential(
147
+ nn.Linear(d_model, d_model * ffn_mult),
148
+ nn.GELU(),
149
+ nn.Linear(d_model * ffn_mult, d_model)
150
+ )
151
+
152
+ def forward(self, x, mask=None):
153
+ x = x + self.attn(self.ln1(x), mask)
154
+ x = x + self.ffn(self.ln2(x))
155
+ return x
156
+
157
+ class HybridBertEmbeddings(nn.Module):
158
+ def __init__(self, vocab_size, d_model, max_len=512):
159
+ super().__init__()
160
+ self.word_embeddings = nn.Embedding(vocab_size, d_model)
161
+ self.position_embeddings = nn.Embedding(max_len, d_model)
162
+ self.token_type_embeddings = nn.Embedding(2, d_model)
163
+ self.ln = nn.LayerNorm(d_model)
164
+ self.dropout = nn.Dropout(0.1)
165
+
166
+ def forward(self, input_ids, token_type_ids):
167
+ seq_len = input_ids.size(1)
168
+ pos_ids = torch.arange(seq_len, device=input_ids.device).unsqueeze(0)
169
+ embeddings = (self.word_embeddings(input_ids) +
170
+ self.position_embeddings(pos_ids) +
171
+ self.token_type_embeddings(token_type_ids))
172
+ return self.dropout(self.ln(embeddings))
173
+
174
+ class HybridBertModel(nn.Module):
175
+ def __init__(self, vocab_size, d_model=768, n_rwkv_layers=6, n_attn_layers=6, n_heads=12, max_len=512):
176
+ super().__init__()
177
+ self.embeddings = HybridBertEmbeddings(vocab_size, d_model, max_len)
178
+ self.layers = nn.ModuleList()
179
+ for _ in range(n_rwkv_layers):
180
+ self.layers.append(BiRWKVBlock(d_model, ffn_mult=4))
181
+ for _ in range(n_attn_layers):
182
+ self.layers.append(StandardAttentionBlock(d_model, n_heads=n_heads))
183
+
184
+ self.mlm_head = nn.Sequential(
185
+ nn.Linear(d_model, d_model),
186
+ nn.GELU(),
187
+ nn.LayerNorm(d_model),
188
+ nn.Linear(d_model, vocab_size)
189
+ )
190
+ self.pooler_dense = nn.Linear(d_model, d_model)
191
+ self.nsp_head = nn.Linear(d_model, 2)
192
+
193
+ def forward(self, input_ids, segment_ids):
194
+ mask = (input_ids != 1).unsqueeze(1).unsqueeze(2) # 1 is PAD_TOKEN_ID
195
+ x = self.embeddings(input_ids, segment_ids)
196
+ for layer in self.layers:
197
+ x = layer(x, mask)
198
+ prediction_scores = self.mlm_head(x)
199
+ return prediction_scores
200
+
201
+ # ============================================================================
202
+ # 2. INITIALIZATION
203
+ # ============================================================================
204
+
205
+ REPO_ID = "FlameF0X/i3-BERT"
206
+ MODEL_FILENAME = "i3-bert.pt"
207
+ TOKENIZER_FILENAME = "tokenizer_bert.json"
208
+
209
+ print("Downloading model and tokenizer from Hugging Face Hub...")
210
+ try:
211
+ model_path = hf_hub_download(repo_id=REPO_ID, filename=MODEL_FILENAME)
212
+ tokenizer_path = hf_hub_download(repo_id=REPO_ID, filename=TOKENIZER_FILENAME)
213
+ except Exception as e:
214
+ print(f"Error downloading files: {e}")
215
+ print("Ensure 'i3-bert.pt' and 'tokenizer_bert.json' exist in 'FlameF0X/i3-BERT'")
216
+ raise e
217
+
218
+ # Load Tokenizer
219
+ tokenizer = Tokenizer.from_file(tokenizer_path)
220
+ vocab_size = tokenizer.get_vocab_size()
221
+
222
+ # Special Token IDs (based on your training code)
223
+ CLS_ID = tokenizer.token_to_id("<CLS>")
224
+ SEP_ID = tokenizer.token_to_id("<SEP>")
225
+ MASK_ID = tokenizer.token_to_id("<MASK>")
226
+ PAD_ID = tokenizer.token_to_id("<PAD>")
227
+
228
+ # Load Model
229
+ # Config matching the training parameters provided
230
+ config = {
231
+ "d_model": 768,
232
+ "n_rwkv_layers": 4,
233
+ "n_attn_layers": 4,
234
+ "n_heads": 12,
235
+ "seq_len": 128
236
+ }
237
+
238
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
239
+ model = HybridBertModel(
240
+ vocab_size=vocab_size,
241
+ d_model=config['d_model'],
242
+ n_rwkv_layers=config['n_rwkv_layers'],
243
+ n_attn_layers=config['n_attn_layers'],
244
+ n_heads=config['n_heads'],
245
+ max_len=config['seq_len']
246
+ ).to(device)
247
+
248
+ print("Loading state dict...")
249
+ state_dict = torch.load(model_path, map_location=device)
250
+ model.load_state_dict(state_dict)
251
+ model.eval()
252
+ print("Model loaded successfully!")
253
+
254
+ # ============================================================================
255
+ # 3. GRADIO INFERENCE FUNCTION
256
+ # ============================================================================
257
+
258
+ def predict_mask(text):
259
+ if not text:
260
+ return "Please enter text."
261
+
262
+ # Ensure the user provided a <mask> token
263
+ if "<MASK>" not in text:
264
+ return "Please include a <MASK> token in your text to predict."
265
+
266
+ # Tokenize
267
+ encoded = tokenizer.encode(text)
268
+ ids = encoded.ids
269
+
270
+ # Truncate if necessary (keeping space for CLS and SEP)
271
+ max_len = config['seq_len'] - 2
272
+ if len(ids) > max_len:
273
+ ids = ids[:max_len]
274
+
275
+ # Add CLS and SEP
276
+ input_ids = [CLS_ID] + ids + [SEP_ID]
277
+ segment_ids = [0] * len(input_ids) # Single sentence segment
278
+
279
+ # Find MASK indices
280
+ mask_indices = [i for i, token_id in enumerate(input_ids) if token_id == MASK_ID]
281
+
282
+ if not mask_indices:
283
+ return "No <MASK> token found after tokenization."
284
+
285
+ # Convert to Tensor
286
+ input_tensor = torch.tensor([input_ids], device=device)
287
+ segment_tensor = torch.tensor([segment_ids], device=device)
288
+
289
+ # Inference
290
+ with torch.no_grad():
291
+ logits = model(input_tensor, segment_tensor)
292
+
293
+ # Process results for each mask
294
+ results = []
295
+ for idx in mask_indices:
296
+ mask_logits = logits[0, idx, :]
297
+ top_k = torch.topk(mask_logits, 5)
298
+
299
+ candidates = []
300
+ for score, token_id in zip(top_k.values, top_k.indices):
301
+ word = tokenizer.decode([token_id.item()])
302
+ candidates.append(f"{word} ({score.item():.2f})")
303
+
304
+ results.append(f"Mask at pos {idx}: " + ", ".join(candidates))
305
+
306
+ return "\n".join(results)
307
+
308
+ # ============================================================================
309
+ # 4. LAUNCH UI
310
+ # ============================================================================
311
+
312
+ with gr.Blocks() as demo:
313
+ gr.Markdown("# i3-BERT: Hybrid RWKV + Attention Model")
314
+ gr.Markdown("A custom 10M parameter model combining Bi-Directional RWKV and Attention layers.")
315
+ gr.Markdown("Type a sentence with `<MASK>` to see predictions.")
316
+
317
+ with gr.Row():
318
+ inp = gr.Textbox(placeholder="The capital of France is <MASK>.", label="Input Text")
319
+ out = gr.Textbox(label="Predictions")
320
+
321
+ btn = gr.Button("Predict")
322
+ btn.click(fn=predict_mask, inputs=inp, outputs=out)
323
+
324
+ examples = [
325
+ ["The quick brown fox jumps over the <MASK> dog."],
326
+ ["I want to eat a <MASK> for lunch."],
327
+ ["Python is a great programming <MASK>."]
328
+ ]
329
+ gr.Examples(examples, inp)
330
+
331
+ demo.launch()