FlameF0X committed on
Commit 14b19ae · verified · 1 Parent(s): be24077

Create app.py

Files changed (1):
  app.py +316 -0
app.py ADDED
import torch
import torch.nn as nn
import torch.nn.functional as F
import json
import os
import gradio as gr
from tokenizers import Tokenizer
from huggingface_hub import hf_hub_download

# ============================================================================
# 1. MODEL ARCHITECTURE
# (Copied from inference.py to support custom weight loading)
# ============================================================================

@torch.jit.script
def rwkv_linear_attention(B: int, T: int, C: int,
                          r: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
                          w: torch.Tensor, u: torch.Tensor,
                          state_init: torch.Tensor):
    y = torch.zeros_like(v)
    state_aa = torch.zeros(B, C, dtype=torch.float32, device=r.device)
    state_bb = torch.zeros(B, C, dtype=torch.float32, device=r.device)
    state_pp = state_init.clone()

    for t in range(T):
        rt, kt, vt = r[:, t], k[:, t], v[:, t]
        # Numerator/denominator of the WKV average, tracked in log space:
        # state_pp holds the running max exponent so exp() never overflows.
        ww = u + state_pp
        p = torch.maximum(ww, kt)
        e1 = torch.exp(ww - p)
        e2 = torch.exp(kt - p)
        wkv = (state_aa * e1 + vt * e2) / (state_bb * e1 + e2 + 1e-6)
        y[:, t] = wkv

        # Decay the state by w (strictly negative) before the next step.
        ww = w + state_pp
        p = torch.maximum(ww, kt)
        e1 = torch.exp(ww - p)
        e2 = torch.exp(kt - p)
        state_aa = state_aa * e1 + vt * e2
        state_bb = state_bb * e1 + e2
        state_pp = p

    return y
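
# Shape note (assumption, for readability only): r, k and v are (B, T, C);
# w and u are per-channel (C,) vectors broadcast over the batch; state_init
# is (B, C); the returned y is (B, T, C), matching v.
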
class RWKVTimeMix(nn.Module):
    def __init__(self, d_model):
        super().__init__()
        self.d_model = d_model
        self.time_decay = nn.Parameter(torch.ones(d_model))
        self.time_first = nn.Parameter(torch.ones(d_model))
        self.time_mix_k = nn.Parameter(torch.ones(1, 1, d_model))
        self.time_mix_v = nn.Parameter(torch.ones(1, 1, d_model))
        self.time_mix_r = nn.Parameter(torch.ones(1, 1, d_model))
        self.key = nn.Linear(d_model, d_model, bias=False)
        self.value = nn.Linear(d_model, d_model, bias=False)
        self.receptance = nn.Linear(d_model, d_model, bias=False)
        self.output = nn.Linear(d_model, d_model, bias=False)

    def forward(self, x):
        B, T, C = x.size()
        # Token shift: xx is x delayed by one position (zero-padded at t=0).
        xx = torch.cat([torch.zeros((B, 1, C), device=x.device), x[:, :-1]], dim=1)
        xk = x * self.time_mix_k + xx * (1 - self.time_mix_k)
        xv = x * self.time_mix_v + xx * (1 - self.time_mix_v)
        xr = x * self.time_mix_r + xx * (1 - self.time_mix_r)
        k = self.key(xk)
        v = self.value(xv)
        r = torch.sigmoid(self.receptance(xr))
        w = -torch.exp(self.time_decay)  # decay is kept strictly negative
        u = self.time_first
        state_init = torch.full((B, C), -1e30, dtype=torch.float32, device=x.device)
        rwkv = rwkv_linear_attention(B, T, C, r, k, v, w, u, state_init)
        return self.output(r * rwkv)  # receptance gates the WKV output

class RWKVChannelMix(nn.Module):
    def __init__(self, d_model, ffn_mult=4):
        super().__init__()
        self.time_mix_k = nn.Parameter(torch.ones(1, 1, d_model))
        self.time_mix_r = nn.Parameter(torch.ones(1, 1, d_model))
        hidden_sz = d_model * ffn_mult
        self.key = nn.Linear(d_model, hidden_sz, bias=False)
        self.receptance = nn.Linear(d_model, d_model, bias=False)
        self.value = nn.Linear(hidden_sz, d_model, bias=False)

    def forward(self, x):
        B, T, C = x.size()
        # Same token-shift trick as in RWKVTimeMix.
        xx = torch.cat([torch.zeros((B, 1, C), device=x.device), x[:, :-1]], dim=1)
        xk = x * self.time_mix_k + xx * (1 - self.time_mix_k)
        xr = x * self.time_mix_r + xx * (1 - self.time_mix_r)
        k = torch.square(torch.relu(self.key(xk)))  # squared-ReLU activation
        kv = self.value(k)
        r = torch.sigmoid(self.receptance(xr))
        return r * kv

class RWKVBlock(nn.Module):
    def __init__(self, d_model, ffn_mult=4):
        super().__init__()
        self.ln1 = nn.LayerNorm(d_model)
        self.att = RWKVTimeMix(d_model)
        self.ln2 = nn.LayerNorm(d_model)
        self.ffn = RWKVChannelMix(d_model, ffn_mult)

    def forward(self, x, mask=None):
        # mask is accepted for interface parity with the attention blocks but
        # is unused: the RWKV recurrence is causal by construction.
        x = x + self.att(self.ln1(x))
        x = x + self.ffn(self.ln2(x))
        return x

class FullAttention(nn.Module):
    def __init__(self, d_model, n_heads=16):
        super().__init__()
        self.d_model = d_model
        self.n_heads = n_heads
        self.head_dim = d_model // n_heads
        self.qkv = nn.Linear(d_model, d_model * 3)
        self.out_proj = nn.Linear(d_model, d_model)

    def forward(self, x, mask=None):
        B, T, C = x.shape
        qkv = self.qkv(x)
        q, k, v = qkv.chunk(3, dim=-1)
        q = q.view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
        k = k.view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
        v = v.view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
        attn = (q @ k.transpose(-2, -1)) / (self.head_dim ** 0.5)
        if mask is not None:
            mask = mask.to(x.device)
            attn = attn.masked_fill(mask == 0, float('-inf'))
        attn = F.softmax(attn, dim=-1)
        out = attn @ v
        out = out.transpose(1, 2).contiguous().view(B, T, C)
        return self.out_proj(out)
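
# Note (assumption): on PyTorch >= 2.0 the explicit softmax attention above
# could likely be replaced by F.scaled_dot_product_attention(q, k, v,
# is_causal=True); it is kept explicit here to mirror the original code.
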
class StandardAttentionBlock(nn.Module):
    def __init__(self, d_model, n_heads=16, ffn_mult=4):
        super().__init__()
        self.ln1 = nn.LayerNorm(d_model)
        self.attn = FullAttention(d_model, n_heads)
        self.ln2 = nn.LayerNorm(d_model)
        self.ffn = nn.Sequential(
            nn.Linear(d_model, d_model * ffn_mult),
            nn.GELU(),
            nn.Linear(d_model * ffn_mult, d_model)
        )

    def forward(self, x, mask=None):
        x = x + self.attn(self.ln1(x), mask)
        x = x + self.ffn(self.ln2(x))
        return x

class i3HybridModel(nn.Module):
    def __init__(self, vocab_size, d_model=1024, n_heads=16,
                 n_rwkv_layers=10, n_attn_layers=6, max_seq_len=512):
        super().__init__()
        self.vocab_size = vocab_size
        self.d_model = d_model
        self.max_seq_len = max_seq_len
        self.embed = nn.Embedding(vocab_size, d_model)
        self.pos_embed = nn.Embedding(max_seq_len, d_model)
        # RWKV blocks first, then standard attention blocks.
        self.layers = nn.ModuleList()
        for _ in range(n_rwkv_layers):
            self.layers.append(RWKVBlock(d_model, ffn_mult=4))
        for _ in range(n_attn_layers):
            self.layers.append(StandardAttentionBlock(d_model, n_heads=n_heads))
        self.ln_f = nn.LayerNorm(d_model)
        self.head = nn.Linear(d_model, vocab_size)

    def forward(self, idx):
        B, T = idx.shape
        if T > self.max_seq_len:
            idx = idx[:, -self.max_seq_len:]
            T = self.max_seq_len
        pos = torch.arange(0, T, dtype=torch.long, device=idx.device).unsqueeze(0)
        x = self.embed(idx) + self.pos_embed(pos)
        # Causal mask; only the attention blocks use it, RWKV blocks ignore it.
        mask = torch.tril(torch.ones(T, T, device=idx.device)).view(1, 1, T, T)
        for layer in self.layers:
            x = layer(x, mask)
        x = self.ln_f(x)
        logits = self.head(x)
        return logits
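
# Usage sketch (assumption, illustrative sizes only): given token ids of shape
# (B, T) with T <= max_seq_len, forward() returns logits of shape
# (B, T, vocab_size):
#   ids = torch.randint(0, 100, (1, 16))
#   tiny = i3HybridModel(vocab_size=100, d_model=64, n_heads=4,
#                        n_rwkv_layers=1, n_attn_layers=1, max_seq_len=32)
#   logits = tiny(ids)  # -> (1, 16, 100)
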
# ============================================================================
# 2. SPACE INFERENCE ENGINE
# ============================================================================

class SpaceInferenceEngine:
    def __init__(self, repo_id="FlameF0X/i3-200m-v2"):
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        print(f"Loading model on {self.device}...")

        # Download files from Hugging Face Hub
        try:
            config_path = hf_hub_download(repo_id=repo_id, filename="config.json")
            tokenizer_path = hf_hub_download(repo_id=repo_id, filename="tokenizer.json")
            weights_path = hf_hub_download(repo_id=repo_id, filename="pytorch_model.bin")
        except Exception as e:
            raise ValueError(f"Failed to download model files from {repo_id}: {e}")

        # Load config
        with open(config_path, 'r') as f:
            self.config = json.load(f)

        # Load tokenizer
        self.tokenizer = Tokenizer.from_file(tokenizer_path)

        # Initialize model
        print("Initializing model architecture...")

        # Use seq_len from the config, falling back to max_seq_len, then 256
        max_seq_len = self.config.get('seq_len', self.config.get('max_seq_len', 256))

        self.model = i3HybridModel(
            vocab_size=self.config['vocab_size'],
            d_model=self.config['d_model'],
            n_heads=self.config.get('n_heads', 12),
            n_rwkv_layers=self.config['rwkv_layers'],
            n_attn_layers=self.config['attn_layers'],
            max_seq_len=max_seq_len
        ).to(self.device)

        # Load weights
        print("Loading weights...")
        state_dict = torch.load(weights_path, map_location=self.device)
        self.model.load_state_dict(state_dict)
        self.model.eval()
        print("Model loaded successfully.")

    def generate_stream(self, prompt, max_new_tokens=100, temperature=1.0, top_k=50):
        # Encode the prompt
        input_ids = self.tokenizer.encode(prompt).ids
        x = torch.tensor([input_ids], dtype=torch.long, device=self.device)

        # For display purposes, we keep the original prompt + new tokens
        generated_text = prompt

        with torch.no_grad():
            for _ in range(max_new_tokens):
                # Crop the context to the model's window
                if x.size(1) > self.model.max_seq_len:
                    x_cond = x[:, -self.model.max_seq_len:]
                else:
                    x_cond = x

                # Forward pass; scale the last-step logits by temperature
                logits = self.model(x_cond)
                logits = logits[:, -1, :] / temperature

                # Top-k sampling: mask out everything below the k-th logit
                if top_k is not None:
                    v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                    logits[logits < v[:, [-1]]] = -float('Inf')

                # Probability distribution
                probs = F.softmax(logits, dim=-1)

                # Sample the next token
                idx_next = torch.multinomial(probs, num_samples=1)

                # Append to the sequence
                x = torch.cat((x, idx_next), dim=1)

                # Decode the new token
                new_token_id = idx_next.item()
                token_str = self.tokenizer.decode([new_token_id])

                # Update the text and yield for streaming
                generated_text += token_str
                yield generated_text

                # Optional: stop generation at an end-of-sequence token
                # if new_token_id == self.tokenizer.token_to_id("<EOS>"): break

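# Local usage sketch (assumption; the Space itself only drives this through
# the Gradio handler below):
#   eng = SpaceInferenceEngine()
#   for text in eng.generate_stream("The history of science is",
#                                   max_new_tokens=50, temperature=0.8, top_k=40):
#       print(text)  # prompt plus everything generated so far
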
# ============================================================================
# 3. GRADIO INTERFACE
# ============================================================================

# Initialize the engine globally so the model is loaded once at Space startup
print("Starting Engine...")
engine = SpaceInferenceEngine()

def predict(prompt, max_tokens, temperature, top_k):
    # predict() is a generator (it yields below), so the early exit for an
    # empty prompt must also yield rather than return a value
    if not prompt:
        yield "Please enter a prompt."
        return

    # Stream partial outputs from the engine's generator
    for current_text in engine.generate_stream(
        prompt,
        max_new_tokens=int(max_tokens),
        temperature=temperature,
        top_k=int(top_k)
    ):
        yield current_text

# Custom CSS for a cleaner look (CSS property is `max-width`, not `max_width`)
custom_css = """
#component-0 {max-width: 800px; margin: auto;}
"""

demo = gr.Interface(
    fn=predict,
    inputs=[
        gr.Textbox(lines=3, placeholder="Enter your prompt here...", label="Input Prompt"),
        gr.Slider(minimum=10, maximum=512, value=150, step=10, label="Max New Tokens"),
        gr.Slider(minimum=0.1, maximum=2.0, value=0.8, step=0.1, label="Temperature"),
        gr.Slider(minimum=1, maximum=100, value=40, step=1, label="Top-K"),
    ],
    outputs=gr.Textbox(lines=10, label="Generated Output"),
    title="i3-200m-v2 (RWKV-Hybrid)",
    description="A 200M parameter hybrid model combining RWKV (RNN) and standard attention layers.",
    css=custom_css,
    # Each example row supplies a value for every input component
    examples=[
        ["The history of science is", 150, 0.8, 40],
        ["Once upon a time in a digital world,", 150, 0.8, 40],
        ["The quick brown fox jumps over", 150, 0.8, 40]
    ],
    cache_examples=False
)

demo.queue()  # Enable queuing so streamed tokens reach the client incrementally
demo.launch()
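
# Note (assumption, based only on the imports above): the Space's
# requirements.txt needs at least torch, gradio, tokenizers and huggingface_hub.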