FlameF0X committed on
Commit 2401a0f · verified · 1 Parent(s): 03760bf

Update app.py

Files changed (1)
  1. app.py +224 -77
app.py CHANGED
@@ -1,83 +1,230 @@
- import gradio as gr
  import torch
  import json
- from safetensors.torch import load_file as safe_load
- from huggingface_hub import hf_hub_download
- from app_classes import i3Model, ChunkTokenizer  # Make sure your classes file is importable
-
- # ------------------------------
- # Hugging Face Repo & Files
- # ------------------------------
- REPO_ID = "FlameF0X/i3-80m"  # Replace with your HF repo
-
- print("Downloading model files from Hugging Face...")
- model_file = hf_hub_download(REPO_ID, "model.safetensors")
- vocab_file = hf_hub_download(REPO_ID, "chunk_vocab_combined.json")
- config_file = hf_hub_download(REPO_ID, "config.json")
-
- # ------------------------------
- # Load Config
- # ------------------------------
- with open(config_file, "r") as f:
-     config = json.load(f)
-
- # ------------------------------
- # Load Tokenizer
- # ------------------------------
- tokenizer = ChunkTokenizer()
- tokenizer.load(vocab_file)
-
- # ------------------------------
- # Initialize Model
- # ------------------------------
- device = "cuda" if torch.cuda.is_available() else "cpu"
- model = i3Model(vocab_size=tokenizer.vocab_size,
-                 d_model=config.get("d_model", 512),
-                 n_heads=config.get("n_heads", 16),
-                 max_seq_len=config.get("max_seq_len", 512),
-                 d_state=config.get("d_state", 32)).to(device)
-
- # Load weights
- state_dict = safe_load(model_file, device=device)
- model.load_state_dict(state_dict)
  model.eval()

- # ------------------------------
- # Generation Function
- # ------------------------------
- def generate_text(prompt, max_tokens=100, temperature=1.0, top_k=40):
-     idx = torch.tensor([tokenizer.encode(prompt)], dtype=torch.long).to(device)
-     with torch.no_grad():
-         out_idx = model.generate(idx, max_new_tokens=int(max_tokens),
-                                  temperature=float(temperature),
-                                  top_k=int(top_k))
-     return tokenizer.decode(out_idx[0].cpu())
-
- # ------------------------------
- # Gradio UI
- # ------------------------------
  with gr.Blocks() as demo:
-     gr.Markdown("## i3 Model Text Generator")
-
      with gr.Row():
-         prompt_input = gr.Textbox(label="Prompt", placeholder="Type your text here...", lines=3)
-         generate_btn = gr.Button("Generate")
-
-     output_box = gr.Textbox(label="Generated Text", lines=10)
-
-     with gr.Accordion("Dev Panel", open=False):
-         max_tokens_input = gr.Slider(10, 500, value=100, label="Max Tokens")
-         temperature_input = gr.Slider(0.1, 2.0, value=1.0, step=0.05, label="Temperature")
-         top_k_input = gr.Slider(1, tokenizer.vocab_size, value=40, step=1, label="Top-k Sampling")
-
-     # Connect button
-     generate_btn.click(
-         generate_text,
-         inputs=[prompt_input, max_tokens_input, temperature_input, top_k_input],
-         outputs=[output_box]
-     )
-
- # ------------------------------
- # Launch App
- # ------------------------------
- demo.launch(share=True)
 
+ import os
  import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ import gradio as gr
  import json
+ import requests
+
+ # ============================================================
+ # ==================== MODEL + TOKENIZER =====================
+ # ============================================================
+
+ class RWKVMambaHybrid(nn.Module):
14
+ def __init__(self, d_model, d_state=64):
15
+ super().__init__()
16
+ self.d_model = d_model
17
+ self.d_state = d_state
18
+ self.w_mix = nn.Parameter(torch.ones(d_model) * 0.5)
19
+ self.A = nn.Parameter(torch.randn(d_state, d_state) * 0.01)
20
+ self.B = nn.Parameter(torch.randn(d_state, d_model) * 0.01)
21
+ self.C = nn.Parameter(torch.randn(d_model, d_state) * 0.01)
22
+ self.D = nn.Parameter(torch.ones(d_model) * 0.1)
23
+
24
+ def forward(self, x):
25
+ B, T, C = x.shape
26
+ h = torch.zeros(B, C, device=x.device)
27
+ s = torch.zeros(B, self.d_state, device=x.device)
28
+ outputs = []
29
+ for t in range(T):
30
+ x_t = x[:, t, :]
31
+ h = self.w_mix * h + (1 - self.w_mix) * x_t
32
+ s = s @ self.A.T + x_t @ self.B.T
33
+ y_t = s @ self.C.T + h * self.D
34
+ outputs.append(y_t)
35
+ return torch.stack(outputs, dim=1)
36
+
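+ # Shape sketch for the recurrence above: x is (B, T, d_model). Per step,
+ # h is a per-channel RWKV-style moving average (w_mix gates old state vs.
+ # new input) and s is a d_state-dim Mamba-style linear state: s @ A.T
+ # carries the state forward, x_t @ B.T writes the input in, s @ C.T reads
+ # it out, and h * D adds a gated skip path. The loop is causal by construction.
+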
+ class FullAttention(nn.Module):
+     def __init__(self, d_model, n_heads=16):
+         super().__init__()
+         self.d_model = d_model
+         self.n_heads = n_heads
+         self.head_dim = d_model // n_heads
+         self.qkv = nn.Linear(d_model, d_model * 3)
+         self.out_proj = nn.Linear(d_model, d_model)
+
+     def forward(self, x, mask=None):
+         B, T, C = x.shape
+         qkv = self.qkv(x)
+         q, k, v = qkv.chunk(3, dim=-1)
+         q = q.view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
+         k = k.view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
+         v = v.view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
+         attn = (q @ k.transpose(-2, -1)) / (self.head_dim ** 0.5)
+         if mask is not None:
+             mask = mask.expand(B, self.n_heads, T, T).bool()
+             attn = attn.masked_fill(mask == 0, float('-inf'))
+         attn = F.softmax(attn, dim=-1)
+         out = attn @ v
+         out = out.transpose(1, 2).contiguous().view(B, T, C)
+         return self.out_proj(out)
+
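+ # The mask is broadcast to (B, n_heads, T, T); zero entries (future positions
+ # under the causal tril mask built in i3Model.forward) are filled with -inf
+ # before the softmax, so each position attends only to itself and the past.
+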
+ class i3HybridBlock(nn.Module):
+     def __init__(self, d_model, d_state=64, ffn_mult=4):
+         super().__init__()
+         self.ln1 = nn.LayerNorm(d_model)
+         self.hybrid = RWKVMambaHybrid(d_model, d_state)
+         self.ln2 = nn.LayerNorm(d_model)
+         d_ff = d_model * ffn_mult
+         self.ffn = nn.Sequential(nn.Linear(d_model, d_ff), nn.GELU(), nn.Linear(d_ff, d_model))
+
+     def forward(self, x, mask=None):
+         x = x + self.hybrid(self.ln1(x))
+         x = x + self.ffn(self.ln2(x))
+         return x
+
+ class i3AttentionBlock(nn.Module):
+     def __init__(self, d_model, n_heads=16, ffn_mult=4):
+         super().__init__()
+         self.ln1 = nn.LayerNorm(d_model)
+         self.attn = FullAttention(d_model, n_heads)
+         self.ln2 = nn.LayerNorm(d_model)
+         d_ff = d_model * ffn_mult
+         self.ffn = nn.Sequential(nn.Linear(d_model, d_ff), nn.GELU(), nn.Linear(d_ff, d_model))
+
+     def forward(self, x, mask=None):
+         x = x + self.attn(self.ln1(x), mask)
+         x = x + self.ffn(self.ln2(x))
+         return x
+
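+ # Both blocks are pre-norm residual. i3HybridBlock accepts `mask` only for a
+ # uniform call signature and ignores it: the recurrence is already causal.
+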
+ class i3Model(nn.Module):
+     def __init__(self, vocab_size, d_model=512, n_heads=16, max_seq_len=256, d_state=32):
+         super().__init__()
+         self.vocab_size = vocab_size
+         self.d_model = d_model
+         self.max_seq_len = max_seq_len
+         self.embed = nn.Embedding(vocab_size, d_model)
+         self.pos_embed = nn.Embedding(max_seq_len, d_model)
+         hybrid_layers = [i3HybridBlock(d_model, d_state) for _ in range(10)]
+         attention_layers = [i3AttentionBlock(d_model, n_heads) for _ in range(6)]
+         self.layers = nn.ModuleList(hybrid_layers + attention_layers)
+         self.ln_f = nn.LayerNorm(d_model)
+         self.head = nn.Linear(d_model, vocab_size)
+         self.apply(self._init_weights)
+
+     def _init_weights(self, module):
+         if isinstance(module, (nn.Linear, nn.Embedding)):
+             module.weight.data.normal_(0, 0.02)
+         if isinstance(module, nn.Linear) and module.bias is not None:
+             module.bias.data.zero_()
+
+     def forward(self, idx, targets=None):
+         B, T = idx.shape
+         pos = torch.arange(0, T, device=idx.device).unsqueeze(0)
+         x = self.embed(idx) + self.pos_embed(pos)
+         mask = torch.tril(torch.ones(T, T, device=idx.device)).view(1, 1, T, T)
+         for layer in self.layers:
+             x = layer(x, mask)
+         x = self.ln_f(x)
+         logits = self.head(x)
+         loss = None
+         if targets is not None:
+             loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
+         return logits, loss
+
+     @torch.no_grad()
+     def generate(self, idx, max_new_tokens=100, temperature=1.0, top_k=None):
+         for _ in range(max_new_tokens):
+             idx_cond = idx if idx.size(1) <= self.max_seq_len else idx[:, -self.max_seq_len:]
+             logits, _ = self(idx_cond)
+             logits = logits[:, -1, :] / temperature
+             if top_k is not None:
+                 v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
+                 logits[logits < v[:, [-1]]] = -float('Inf')
+             probs = F.softmax(logits, dim=-1)
+             idx_next = torch.multinomial(probs, 1)
+             idx = torch.cat((idx, idx_next), dim=1)
+         return idx
+
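+ # Decoding: the context is clipped to max_seq_len, last-position logits are
+ # scaled by 1/temperature, logits below the k-th largest are masked to -inf,
+ # and the next token is drawn with torch.multinomial from the softmax.
+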
+ class ChunkTokenizer:
+     def __init__(self, vocab_path=None):
+         self.chunk_to_idx = {}
+         self.idx_to_chunk = {}
+         self.unk_token = '<UNK>'
+         self.unk_idx = 0
+         if vocab_path and os.path.exists(vocab_path):
+             with open(vocab_path, 'r') as f:
+                 data = json.load(f)
+             self.chunk_to_idx = data['chunk_to_idx']
+             self.idx_to_chunk = {int(k): v for k, v in data['idx_to_chunk'].items()}
+             self.vocab_size = data['vocab_size']
+         else:
+             # minimal fallback vocab: <UNK> plus lowercase letters and space
+             self.chunk_to_idx = {'<UNK>': 0}
+             for i, ch in enumerate('abcdefghijklmnopqrstuvwxyz '):
+                 self.chunk_to_idx[ch] = i + 1
+             self.idx_to_chunk = {v: k for k, v in self.chunk_to_idx.items()}
+             self.vocab_size = len(self.chunk_to_idx)
+
+     def encode(self, text):
+         text = text.lower()
+         idxs = []
+         pos = 0
+         while pos < len(text):
+             # Greedy longest match: try a 3-char chunk first, then shorter ones,
+             # so single characters in the vocab are used before falling back to <UNK>.
+             for size in (3, 2, 1):
+                 chunk = text[pos:pos + size]
+                 if chunk and chunk in self.chunk_to_idx:
+                     idxs.append(self.chunk_to_idx[chunk])
+                     pos += len(chunk)
+                     break
+             else:
+                 idxs.append(self.unk_idx)
+                 pos += 1
+         return idxs
+
+     def decode(self, indices):
+         return ''.join([self.idx_to_chunk.get(int(i), self.unk_token) for i in indices])
+
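+ # With the fallback vocab, encode("hi") -> [8, 9] ('h' = 8, 'i' = 9) and
+ # decode([8, 9]) -> "hi"; characters outside the vocab round-trip as '<UNK>'.
+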
+ # ============================================================
+ # ===================== LOAD MODEL ===========================
+ # ============================================================
+
+ MODEL_NAME = "your-hf-username/i3-80m"  # Replace with HF repo ID
+ DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+
+ vocab_file = "chunk_vocab_combined.json"
+ tokenizer = ChunkTokenizer(vocab_file)
+ vocab_size = tokenizer.vocab_size
+
+ model = i3Model(vocab_size=vocab_size)
+ # load local safetensors or pytorch_model.bin if one exists
+ if os.path.exists("model.safetensors"):
+     from safetensors.torch import load_file
+     state_dict = load_file("model.safetensors")
+     model.load_state_dict(state_dict)
+ elif os.path.exists("pytorch_model.bin"):
+     state_dict = torch.load("pytorch_model.bin", map_location=DEVICE)
+     model.load_state_dict(state_dict)
+ else:
+     # otherwise download the weights from the Hugging Face Hub
+     url_bin = f"https://huggingface.co/{MODEL_NAME}/resolve/main/pytorch_model.bin"
+     r = requests.get(url_bin)
+     r.raise_for_status()  # fail loudly instead of writing an HTML error page to disk
+     with open("pytorch_model.bin", 'wb') as f:
+         f.write(r.content)
+     state_dict = torch.load("pytorch_model.bin", map_location=DEVICE)
+     model.load_state_dict(state_dict)
+
+ model.to(DEVICE)
  model.eval()

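+ # Note: the previous revision used huggingface_hub.hf_hub_download here, which
+ # also handles caching and auth; the raw requests fallback assumes a public repo.
+ # If chunk_vocab_combined.json is missing, the 28-token fallback vocab will
+ # generally not match the checkpoint's vocab_size and load_state_dict will fail.
+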
+ # ============================================================
+ # ===================== GRADIO UI ============================
+ # ============================================================
+
+ def generate_text(prompt, max_tokens, temperature, top_k):
+     if not prompt:
+         prompt = " "  # avoid an empty context, which generate() cannot extend
+     idx = torch.tensor([tokenizer.encode(prompt)], dtype=torch.long).to(DEVICE)
+     # Gradio sliders pass floats; generate() needs integer token counts and k
+     out_idx = model.generate(idx, max_new_tokens=int(max_tokens),
+                              temperature=float(temperature), top_k=int(top_k))
+     text = tokenizer.decode(out_idx[0].cpu())
+     return text
+
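+ # e.g. generate_text("hello ", 100, 0.8, 40) returns the prompt followed by
+ # 100 newly sampled chunks, decoded back to text.
+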
  with gr.Blocks() as demo:
+     gr.Markdown("### i3-80M Model Demo")
      with gr.Row():
+         with gr.Column(scale=3):
+             prompt = gr.Textbox(label="Prompt", lines=3)
+             generate_btn = gr.Button("Generate")
+             output = gr.Textbox(label="Generated Text", lines=10)
+         with gr.Column(scale=1):
+             gr.Markdown("#### Dev Panel")
+             max_tokens = gr.Slider(10, 512, value=100, step=1, label="Max Tokens")
+             temperature = gr.Slider(0.1, 2.0, value=0.8, step=0.05, label="Temperature")
+             top_k = gr.Slider(1, 100, value=40, step=1, label="Top-k")
+     generate_btn.click(generate_text, inputs=[prompt, max_tokens, temperature, top_k], outputs=output)
+
+ demo.launch()
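+ # To run locally (assumed setup): pip install torch gradio requests safetensors,
+ # place chunk_vocab_combined.json and the weights next to app.py (or let the
+ # script download pytorch_model.bin), then start the app with `python app.py`.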