|
|
import torch
import torch.nn as nn
import torch.nn.functional as F
import json

import gradio as gr
from tokenizers import Tokenizer
from huggingface_hub import hf_hub_download


@torch.jit.script
def rwkv_linear_attention(B: int, T: int, C: int,
                          r: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
                          w: torch.Tensor, u: torch.Tensor,
                          state_init: torch.Tensor):
    # Sequential WKV recurrence. state_aa / state_bb hold the exponentially
    # decayed numerator and denominator; state_pp tracks the running maximum
    # exponent so every exp() argument stays non-positive (the usual RWKV
    # numerical-stability trick). The receptance r is applied by the caller.
    y = torch.zeros_like(v)
    state_aa = torch.zeros(B, C, dtype=torch.float32, device=r.device)
    state_bb = torch.zeros(B, C, dtype=torch.float32, device=r.device)
    state_pp = state_init.clone()

    for t in range(T):
        kt, vt = k[:, t], v[:, t]

        # Output for this step: combine the accumulated state with the
        # current token, offset by the "bonus" u.
        ww = u + state_pp
        p = torch.maximum(ww, kt)
        e1 = torch.exp(ww - p)
        e2 = torch.exp(kt - p)
        wkv = (state_aa * e1 + vt * e2) / (state_bb * e1 + e2 + 1e-6)
        y[:, t] = wkv

        # Decay the state by w and fold in the current token.
        ww = w + state_pp
        p = torch.maximum(ww, kt)
        e1 = torch.exp(ww - p)
        e2 = torch.exp(kt - p)
        state_aa = state_aa * e1 + vt * e2
        state_bb = state_bb * e1 + e2
        state_pp = p

    return y
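
# Up to the running-max bookkeeping, the loop above computes, per channel,
#
#   wkv_t = ( e^u * sum_{i<t} e^{(t-1-i)w + k_i} v_i + e^{k_t} v_t )
#           / ( e^u * sum_{i<t} e^{(t-1-i)w + k_i} + e^{k_t} )
#
# an exponentially decayed weighted average of past values (w < 0, so older
# tokens fade), with u trading off accumulated history against the current
# token. The formula is read directly off this implementation.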


class RWKVTimeMix(nn.Module):
    def __init__(self, d_model):
        super().__init__()
        self.d_model = d_model
        self.time_decay = nn.Parameter(torch.ones(d_model))
        self.time_first = nn.Parameter(torch.ones(d_model))
        self.time_mix_k = nn.Parameter(torch.ones(1, 1, d_model))
        self.time_mix_v = nn.Parameter(torch.ones(1, 1, d_model))
        self.time_mix_r = nn.Parameter(torch.ones(1, 1, d_model))
        self.key = nn.Linear(d_model, d_model, bias=False)
        self.value = nn.Linear(d_model, d_model, bias=False)
        self.receptance = nn.Linear(d_model, d_model, bias=False)
        self.output = nn.Linear(d_model, d_model, bias=False)

    def forward(self, x):
        B, T, C = x.size()
        # Token shift: xx is x shifted right by one step, so every position
        # can interpolate between its own features and the previous token's.
        xx = torch.cat([torch.zeros((B, 1, C), device=x.device), x[:, :-1]], dim=1)
        xk = x * self.time_mix_k + xx * (1 - self.time_mix_k)
        xv = x * self.time_mix_v + xx * (1 - self.time_mix_v)
        xr = x * self.time_mix_r + xx * (1 - self.time_mix_r)
        k = self.key(xk)
        v = self.value(xv)
        r = torch.sigmoid(self.receptance(xr))
        w = -torch.exp(self.time_decay)  # decay is kept strictly negative
        u = self.time_first
        state_init = torch.full((B, C), -1e30, dtype=torch.float32, device=x.device)
        rwkv = rwkv_linear_attention(B, T, C, r, k, v, w, u, state_init)
        # Sigmoid receptance gates the WKV output before the final projection.
        return self.output(r * rwkv)
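
# The zero-pad-and-crop above is the standard RWKV token shift; an
# equivalent one-liner, if preferred:
#
#   xx = F.pad(x, (0, 0, 1, -1))  # shift the time axis right by one step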


class RWKVChannelMix(nn.Module):
    def __init__(self, d_model, ffn_mult=4):
        super().__init__()
        self.time_mix_k = nn.Parameter(torch.ones(1, 1, d_model))
        self.time_mix_r = nn.Parameter(torch.ones(1, 1, d_model))
        hidden_sz = d_model * ffn_mult
        self.key = nn.Linear(d_model, hidden_sz, bias=False)
        self.receptance = nn.Linear(d_model, d_model, bias=False)
        self.value = nn.Linear(hidden_sz, d_model, bias=False)

    def forward(self, x):
        B, T, C = x.size()
        # Same token-shift trick as in the time-mix block.
        xx = torch.cat([torch.zeros((B, 1, C), device=x.device), x[:, :-1]], dim=1)
        xk = x * self.time_mix_k + xx * (1 - self.time_mix_k)
        xr = x * self.time_mix_r + xx * (1 - self.time_mix_r)
        k = torch.square(torch.relu(self.key(xk)))
        kv = self.value(k)
        r = torch.sigmoid(self.receptance(xr))
        return r * kv
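
# In formula form, the channel mix is a gated squared-ReLU MLP:
#
#   out = sigmoid(x_r W_r) * (relu(x_k W_k)^2) W_v
#
# where x_k and x_r are the token-shifted mixes computed above.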


class RWKVBlock(nn.Module):
    def __init__(self, d_model, ffn_mult=4):
        super().__init__()
        self.ln1 = nn.LayerNorm(d_model)
        self.att = RWKVTimeMix(d_model)
        self.ln2 = nn.LayerNorm(d_model)
        self.ffn = RWKVChannelMix(d_model, ffn_mult)

    def forward(self, x, mask=None):
        # mask is accepted (and ignored) so RWKV and attention blocks share
        # one call signature in the hybrid stack.
        x = x + self.att(self.ln1(x))
        x = x + self.ffn(self.ln2(x))
        return x


class FullAttention(nn.Module):
    def __init__(self, d_model, n_heads=16):
        super().__init__()
        self.d_model = d_model
        self.n_heads = n_heads
        self.head_dim = d_model // n_heads
        self.qkv = nn.Linear(d_model, d_model * 3)
        self.out_proj = nn.Linear(d_model, d_model)

    def forward(self, x, mask=None):
        B, T, C = x.shape
        qkv = self.qkv(x)
        q, k, v = qkv.chunk(3, dim=-1)
        # (B, T, C) -> (B, n_heads, T, head_dim)
        q = q.view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
        k = k.view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
        v = v.view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
        # Scaled dot-product attention with an optional (causal) mask.
        attn = (q @ k.transpose(-2, -1)) / (self.head_dim ** 0.5)
        if mask is not None:
            mask = mask.to(x.device)
            attn = attn.masked_fill(mask == 0, float('-inf'))
        attn = F.softmax(attn, dim=-1)
        out = attn @ v
        out = out.transpose(1, 2).contiguous().view(B, T, C)
        return self.out_proj(out)
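
# On PyTorch >= 2.0, the score/softmax/value sequence above could be replaced
# by the fused kernel; a behaviour-equivalent sketch for the causal mask used
# in this model:
#
#   out = F.scaled_dot_product_attention(q, k, v, is_causal=True)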


class StandardAttentionBlock(nn.Module):
    def __init__(self, d_model, n_heads=16, ffn_mult=4):
        super().__init__()
        self.ln1 = nn.LayerNorm(d_model)
        self.attn = FullAttention(d_model, n_heads)
        self.ln2 = nn.LayerNorm(d_model)
        self.ffn = nn.Sequential(
            nn.Linear(d_model, d_model * ffn_mult),
            nn.GELU(),
            nn.Linear(d_model * ffn_mult, d_model)
        )

    def forward(self, x, mask=None):
        x = x + self.attn(self.ln1(x), mask)
        x = x + self.ffn(self.ln2(x))
        return x


class i3HybridModel(nn.Module):
    def __init__(self, vocab_size, d_model=1024, n_heads=16,
                 n_rwkv_layers=10, n_attn_layers=6, max_seq_len=512):
        super().__init__()
        self.vocab_size = vocab_size
        self.d_model = d_model
        self.max_seq_len = max_seq_len
        self.embed = nn.Embedding(vocab_size, d_model)
        self.pos_embed = nn.Embedding(max_seq_len, d_model)
        # Hybrid stack: RWKV blocks first, full-attention blocks on top.
        self.layers = nn.ModuleList()
        for _ in range(n_rwkv_layers):
            self.layers.append(RWKVBlock(d_model, ffn_mult=4))
        for _ in range(n_attn_layers):
            self.layers.append(StandardAttentionBlock(d_model, n_heads=n_heads))
        self.ln_f = nn.LayerNorm(d_model)
        self.head = nn.Linear(d_model, vocab_size)

    def forward(self, idx):
        B, T = idx.shape
        # Learned positional embeddings cap the context: keep only the most
        # recent max_seq_len tokens.
        if T > self.max_seq_len:
            idx = idx[:, -self.max_seq_len:]
            T = self.max_seq_len
        pos = torch.arange(0, T, dtype=torch.long, device=idx.device).unsqueeze(0)
        x = self.embed(idx) + self.pos_embed(pos)
        # Causal mask for the attention blocks; RWKV blocks ignore it.
        mask = torch.tril(torch.ones(T, T, device=idx.device)).view(1, 1, T, T)
        for layer in self.layers:
            x = layer(x, mask)
        x = self.ln_f(x)
        logits = self.head(x)
        return logits
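
# Quick shape check with toy sizes (hypothetical numbers, illustration only):
#
#   m = i3HybridModel(vocab_size=1000, d_model=64, n_heads=4,
#                     n_rwkv_layers=2, n_attn_layers=1, max_seq_len=32)
#   logits = m(torch.randint(0, 1000, (2, 16)))  # -> shape (2, 16, 1000)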


class SpaceInferenceEngine:
    def __init__(self, repo_id="FlameF0X/i3-200m-v2"):
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        print(f"Loading model on {self.device}...")

        # Fetch config, tokenizer, and weights from the Hugging Face Hub.
        try:
            config_path = hf_hub_download(repo_id=repo_id, filename="config.json")
            tokenizer_path = hf_hub_download(repo_id=repo_id, filename="tokenizer.json")
            weights_path = hf_hub_download(repo_id=repo_id, filename="pytorch_model.bin")
        except Exception as e:
            raise ValueError(f"Failed to download model files from {repo_id}: {e}")

        with open(config_path, 'r') as f:
            self.config = json.load(f)

        self.tokenizer = Tokenizer.from_file(tokenizer_path)

        print("Initializing model architecture...")
        # Some checkpoints store the context length as 'seq_len', others as
        # 'max_seq_len'; fall back to 256 if neither is present.
        max_seq_len = self.config.get('seq_len', self.config.get('max_seq_len', 256))

        self.model = i3HybridModel(
            vocab_size=self.config['vocab_size'],
            d_model=self.config['d_model'],
            n_heads=self.config.get('n_heads', 12),
            n_rwkv_layers=self.config['rwkv_layers'],
            n_attn_layers=self.config['attn_layers'],
            max_seq_len=max_seq_len
        ).to(self.device)

        print("Loading weights...")
        state_dict = torch.load(weights_path, map_location=self.device)
        self.model.load_state_dict(state_dict)
        self.model.eval()
        print("Model loaded successfully.")
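
        # Note: on recent PyTorch releases, torch.load(..., weights_only=True)
        # is the safer option for loading checkpoints you did not produce.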

    def generate_stream(self, prompt, max_new_tokens=100, temperature=1.0, top_k=50):
        input_ids = self.tokenizer.encode(prompt).ids
        x = torch.tensor([input_ids], dtype=torch.long, device=self.device)

        generated_text = prompt

        with torch.no_grad():
            for _ in range(max_new_tokens):
                # Crop the context to the model's maximum sequence length.
                if x.size(1) > self.model.max_seq_len:
                    x_cond = x[:, -self.model.max_seq_len:]
                else:
                    x_cond = x

                logits = self.model(x_cond)
                logits = logits[:, -1, :] / temperature

                # Top-k filtering: drop everything below the k-th largest logit.
                if top_k is not None:
                    v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                    logits[logits < v[:, [-1]]] = -float('Inf')

                probs = F.softmax(logits, dim=-1)
                idx_next = torch.multinomial(probs, num_samples=1)
                x = torch.cat((x, idx_next), dim=1)

                # Decode the new token and stream the running text.
                new_token_id = idx_next.item()
                token_str = self.tokenizer.decode([new_token_id])
                generated_text += token_str
                yield generated_text
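
    # Example usage outside Gradio (hypothetical, for illustration):
    #
    #   engine = SpaceInferenceEngine()
    #   for text in engine.generate_stream("Once upon a time", max_new_tokens=20):
    #       print(text, end="\r")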


print("Starting Engine...")
engine = SpaceInferenceEngine()


def predict(prompt, max_tokens, temperature, top_k):
    if not prompt.strip():
        yield "⚠️ Please enter a prompt to generate text."
        return

    # Stream partial completions back to the UI as they are produced.
    for current_text in engine.generate_stream(
        prompt,
        max_new_tokens=int(max_tokens),
        temperature=temperature,
        top_k=int(top_k)
    ):
        yield current_text


custom_css = """
.gradio-container {
    max-width: 1200px !important;
}
.main-header {
    text-align: center;
    margin-bottom: 2rem;
}
"""


with gr.Blocks() as demo:
    # Inject the custom CSS (gr.Blocks(css=custom_css) would also work).
    gr.HTML(f"<style>{custom_css}</style>")

    with gr.Row():
        gr.Markdown(
            """
            # 🚀 i3-200M Text Generation
            ### Powered by RWKV-Hybrid Architecture
            Generate creative text with the i3-200M language model, which combines RNN efficiency with attention precision.
            """,
            elem_classes="main-header"
        )

    with gr.Row():
        with gr.Column(scale=2):
            prompt_input = gr.Textbox(
                label="✏️ Enter Your Prompt",
                placeholder="Once upon a time in a distant galaxy...",
                lines=4,
                max_lines=8
            )

            with gr.Accordion("⚙️ Generation Parameters", open=True):
                with gr.Row():
                    max_tokens_input = gr.Slider(
                        minimum=10,
                        maximum=512,
                        value=150,
                        step=10,
                        label="Max Tokens",
                        info="Maximum number of tokens to generate"
                    )
                    temp_input = gr.Slider(
                        minimum=0.1,
                        maximum=2.0,
                        value=0.8,
                        step=0.1,
                        label="Temperature",
                        info="Higher = more creative, lower = more focused"
                    )

                topk_input = gr.Slider(
                    minimum=1,
                    maximum=100,
                    value=40,
                    step=1,
                    label="Top-k Sampling",
                    info="Number of top tokens to consider"
                )

            with gr.Row():
                generate_btn = gr.Button("🎨 Generate Text", variant="primary", size="lg")
                clear_btn = gr.ClearButton(components=[prompt_input], value="🗑️ Clear", size="lg")

        with gr.Column(scale=2):
            output_text = gr.Textbox(
                label="📝 Generated Output",
                lines=12,
                max_lines=20
            )

    with gr.Row():
        gr.Examples(
            examples=[
                ["The history of science is", 150, 0.7, 50],
                ["In a world where technology and nature coexist", 200, 0.9, 40],
                ["The scientist discovered something remarkable", 120, 0.8, 45],
            ],
            inputs=[prompt_input, max_tokens_input, temp_input, topk_input],
            label="💡 Try These Examples"
        )

    with gr.Accordion("🔧 Developer Info", open=False):
        total_params = sum(p.numel() for p in engine.model.parameters())

        with gr.Row():
            with gr.Column():
                gr.Markdown(f"""
                **Model Architecture:**
                - **Model:** i3-200M Hybrid
                - **Device:** {engine.device}
                - **Vocab Size:** {engine.config['vocab_size']:,}
                - **Parameters:** {total_params:,} ({total_params/1e6:.2f}M)
                """)

            with gr.Column():
                gr.Markdown(f"""
                **Configuration:**
                - **d_model:** {engine.config['d_model']}
                - **RWKV Layers:** {engine.config['rwkv_layers']}
                - **Attention Layers:** {engine.config['attn_layers']}
                - **Max Seq Len:** {engine.model.max_seq_len}
                """)

    gr.Markdown(
        """
        ---
        <div style="text-align: center; color: #666;">
        <p>Built with ❤️ using Gradio | Model: FlameF0X/i3-200m-v2</p>
        </div>
        """
    )

    # Wire the generate button to the streaming predict function.
    generate_btn.click(
        predict,
        inputs=[prompt_input, max_tokens_input, temp_input, topk_input],
        outputs=[output_text]
    )

if __name__ == "__main__":
    # Queueing enables streamed (generator) outputs in the UI.
    demo.queue()
    demo.launch()