# i3-200m / app.py
# FlameF0X's picture
# Update app.py
# cd310f6 verified
import torch
import torch.nn as nn
import torch.nn.functional as F
import json
import os
import gradio as gr
from tokenizers import Tokenizer
from huggingface_hub import hf_hub_download
# ============================================================================
# 1. MODEL ARCHITECTURE
# (Copied from inference.py to support custom weight loading)
# ============================================================================
@torch.jit.script
def rwkv_linear_attention(B: int, T: int, C: int,
                          r: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
                          w: torch.Tensor, u: torch.Tensor,
                          state_init: torch.Tensor):
    """Run the WKV recurrence one time step at a time.

    Args:
        B: Batch size used to allocate the running state.
        T: Number of time steps to iterate over.
        C: Channel count used to allocate the running state.
        r: Only consulted for its device; the state tensors are created there.
        k: Keys, indexed as ``k[:, t]`` (presumably ``(B, T, C)`` — confirm
            against the caller).
        v: Values, indexed as ``v[:, t]``; the result has the same shape/dtype.
        w: Per-channel term added to the running exponent when the state is
            updated.
        u: Per-channel term added to the running exponent when the output of
            the current step is computed.
        state_init: Initial running exponent; it is cloned, not modified.

    Returns:
        Tensor shaped like ``v`` holding the weighted average for each step.
    """
    out = torch.zeros_like(v)
    # Running numerator / denominator of the weighted average, kept in fp32.
    numer = torch.zeros(B, C, dtype=torch.float32, device=r.device)
    denom = torch.zeros(B, C, dtype=torch.float32, device=r.device)
    # Running exponent that the accumulated state is expressed relative to.
    exponent = state_init.clone()

    for step in range(T):
        key_t = k[:, step]
        val_t = v[:, step]

        # --- output for this step -------------------------------------
        # Both exponents are shifted by their maximum before exp(), so the
        # two weights never exceed 1.
        shifted = u + exponent
        pivot = torch.maximum(shifted, key_t)
        w_past = torch.exp(shifted - pivot)
        w_now = torch.exp(key_t - pivot)
        # The 1e-6 keeps the division finite when both weights underflow.
        out[:, step] = (numer * w_past + val_t * w_now) / (denom * w_past + w_now + 1e-6)

        # --- fold this step into the running state --------------------
        shifted = w + exponent
        pivot = torch.maximum(shifted, key_t)
        w_past = torch.exp(shifted - pivot)
        w_now = torch.exp(key_t - pivot)
        numer = numer * w_past + val_t * w_now
        denom = denom * w_past + w_now
        exponent = pivot

    return out
class RWKVTimeMix(nn.Module):
    """Time-mixing layer.

    Each position is blended with the previous position (zeros for the first
    one) using learned per-channel coefficients, projected to key / value /
    receptance, passed through ``rwkv_linear_attention`` and gated by the
    sigmoid of the receptance before the output projection.

    Attribute names and creation order are significant: they define the
    ``state_dict`` keys used when loading weights.
    """

    def __init__(self, d_model):
        super().__init__()
        self.d_model = d_model

        def ones_param(*shape):
            # Every learnable scalar table in this layer starts at 1.
            return nn.Parameter(torch.ones(*shape))

        def square_proj():
            return nn.Linear(d_model, d_model, bias=False)

        self.time_decay = ones_param(d_model)
        self.time_first = ones_param(d_model)
        self.time_mix_k = ones_param(1, 1, d_model)
        self.time_mix_v = ones_param(1, 1, d_model)
        self.time_mix_r = ones_param(1, 1, d_model)

        self.key = square_proj()
        self.value = square_proj()
        self.receptance = square_proj()
        self.output = square_proj()

    def forward(self, x):
        """Apply time mixing to ``x`` of shape ``(B, T, C)``; returns the same shape."""
        B, T, C = x.size()

        # Sequence shifted right by one step, with zeros in the first slot.
        leading_zeros = torch.zeros((B, 1, C), device=x.device)
        previous = torch.cat([leading_zeros, x[:, :-1]], dim=1)

        def blend(coeff):
            # coeff == 1 keeps the current token, coeff == 0 keeps the previous one.
            return x * coeff + previous * (1 - coeff)

        k = self.key(blend(self.time_mix_k))
        v = self.value(blend(self.time_mix_v))
        r = torch.sigmoid(self.receptance(blend(self.time_mix_r)))

        # Negated exponential => the decay term is always negative.
        decay = -torch.exp(self.time_decay)
        # A very negative starting exponent makes the empty initial state
        # contribute nothing to the first step.
        start = torch.full((B, C), -1e30, dtype=torch.float32, device=x.device)

        wkv = rwkv_linear_attention(B, T, C, r, k, v, decay, self.time_first, start)
        return self.output(r * wkv)
class RWKVChannelMix(nn.Module):
    """Channel-mixing (feed-forward) layer.

    The input is blended with the previous position, expanded to
    ``d_model * ffn_mult`` channels through a squared ReLU, projected back to
    ``d_model`` and gated by a sigmoid "receptance" branch.

    Attribute names and creation order are significant: they define the
    ``state_dict`` keys used when loading weights.
    """

    def __init__(self, d_model, ffn_mult=4):
        super().__init__()
        self.time_mix_k = nn.Parameter(torch.ones(1, 1, d_model))
        self.time_mix_r = nn.Parameter(torch.ones(1, 1, d_model))

        inner_dim = d_model * ffn_mult
        self.key = nn.Linear(d_model, inner_dim, bias=False)
        self.receptance = nn.Linear(d_model, d_model, bias=False)
        self.value = nn.Linear(inner_dim, d_model, bias=False)

    def forward(self, x):
        """Apply channel mixing to ``x`` of shape ``(B, T, C)``; returns the same shape."""
        batch, _, channels = x.size()

        # Sequence shifted right by one step, with zeros in the first slot.
        leading_zeros = torch.zeros((batch, 1, channels), device=x.device)
        previous = torch.cat([leading_zeros, x[:, :-1]], dim=1)

        key_in = x * self.time_mix_k + previous * (1 - self.time_mix_k)
        gate_in = x * self.time_mix_r + previous * (1 - self.time_mix_r)

        # Squared-ReLU activation on the expanded representation.
        hidden = torch.relu(self.key(key_in)).square()
        projected = self.value(hidden)
        gate = self.receptance(gate_in).sigmoid()
        return gate * projected
class RWKVBlock(nn.Module):
    """Residual block: LayerNorm -> time mix, then LayerNorm -> channel mix.

    ``mask`` is accepted only so the block can be called with the same
    arguments as the attention blocks; it is not used here.
    """

    def __init__(self, d_model, ffn_mult=4):
        super().__init__()
        self.ln1 = nn.LayerNorm(d_model)
        self.att = RWKVTimeMix(d_model)
        self.ln2 = nn.LayerNorm(d_model)
        self.ffn = RWKVChannelMix(d_model, ffn_mult)

    def forward(self, x, mask=None):
        """Return ``x`` after both residual sub-layers (shape unchanged)."""
        # Out-of-place additions: the caller's tensor is never modified.
        after_time_mix = x + self.att(self.ln1(x))
        after_channel_mix = after_time_mix + self.ffn(self.ln2(after_time_mix))
        return after_channel_mix
class FullAttention(nn.Module):
    """Multi-head scaled dot-product self-attention with a fused QKV projection.

    Args:
        d_model: Width of the input and output features.
        n_heads: Number of attention heads; must divide ``d_model`` evenly.

    Raises:
        ValueError: If ``n_heads`` is not positive or does not divide
            ``d_model``. Without this check the floor division below would
            silently truncate and ``forward`` would fail later with an opaque
            reshape error.
    """

    def __init__(self, d_model, n_heads=16):
        super().__init__()
        if n_heads <= 0 or d_model % n_heads != 0:
            raise ValueError(
                f"d_model ({d_model}) must be divisible by a positive n_heads ({n_heads})"
            )
        self.d_model = d_model
        self.n_heads = n_heads
        self.head_dim = d_model // n_heads
        self.qkv = nn.Linear(d_model, d_model * 3)
        self.out_proj = nn.Linear(d_model, d_model)

    def forward(self, x, mask=None):
        """Attend over ``x`` of shape ``(B, T, C)`` and return the same shape.

        Args:
            x: Input features, ``(B, T, C)`` with ``C == d_model``.
            mask: Optional tensor broadcastable to ``(B, n_heads, T, T)``.
                Positions where it equals 0 are excluded from attention.
        """
        B, T, C = x.shape
        q, k, v = self.qkv(x).chunk(3, dim=-1)

        def split_heads(t):
            # (B, T, C) -> (B, n_heads, T, head_dim)
            return t.view(B, T, self.n_heads, self.head_dim).transpose(1, 2)

        q, k, v = split_heads(q), split_heads(k), split_heads(v)

        scores = (q @ k.transpose(-2, -1)) / (self.head_dim ** 0.5)
        if mask is not None:
            mask = mask.to(x.device)
            # -inf scores become zero probability after the softmax.
            scores = scores.masked_fill(mask == 0, float('-inf'))
        weights = F.softmax(scores, dim=-1)

        # (B, n_heads, T, head_dim) -> (B, T, C)
        merged = (weights @ v).transpose(1, 2).contiguous().view(B, T, C)
        return self.out_proj(merged)
class StandardAttentionBlock(nn.Module):
    """Pre-norm transformer block: self-attention followed by a GELU MLP.

    Both sub-layers are wrapped in residual connections. The MLP expands to
    ``d_model * ffn_mult`` features and projects back to ``d_model``.
    """

    def __init__(self, d_model, n_heads=16, ffn_mult=4):
        super().__init__()
        mlp_width = d_model * ffn_mult
        self.ln1 = nn.LayerNorm(d_model)
        self.attn = FullAttention(d_model, n_heads)
        self.ln2 = nn.LayerNorm(d_model)
        self.ffn = nn.Sequential(
            nn.Linear(d_model, mlp_width),
            nn.GELU(),
            nn.Linear(mlp_width, d_model),
        )

    def forward(self, x, mask=None):
        """Return ``x`` after the attention and MLP residual steps; ``mask`` is forwarded to the attention layer."""
        attended = x + self.attn(self.ln1(x), mask)
        return attended + self.ffn(self.ln2(attended))
class i3HybridModel(nn.Module):
    """Language model made of RWKV blocks followed by attention blocks.

    Token and learned position embeddings are summed, passed through
    ``n_rwkv_layers`` ``RWKVBlock`` layers and then ``n_attn_layers``
    ``StandardAttentionBlock`` layers, normalised and projected to vocabulary
    logits.

    Sub-module names and creation order are significant: they define the
    ``state_dict`` keys used when loading weights.
    """

    def __init__(self, vocab_size, d_model=1024, n_heads=16,
                 n_rwkv_layers=10, n_attn_layers=6, max_seq_len=512):
        super().__init__()
        self.vocab_size = vocab_size
        self.d_model = d_model
        self.max_seq_len = max_seq_len

        self.embed = nn.Embedding(vocab_size, d_model)
        self.pos_embed = nn.Embedding(max_seq_len, d_model)

        # Recurrent-style blocks first, attention blocks on top.
        self.layers = nn.ModuleList()
        self.layers.extend(
            RWKVBlock(d_model, ffn_mult=4) for _ in range(n_rwkv_layers)
        )
        self.layers.extend(
            StandardAttentionBlock(d_model, n_heads=n_heads) for _ in range(n_attn_layers)
        )

        self.ln_f = nn.LayerNorm(d_model)
        self.head = nn.Linear(d_model, vocab_size)

    def forward(self, idx):
        """Map token ids ``(B, T)`` to logits ``(B, min(T, max_seq_len), vocab_size)``.

        Inputs longer than ``max_seq_len`` are cropped to their last
        ``max_seq_len`` tokens because the position table has no more rows.
        """
        _, seq_len = idx.shape
        if seq_len > self.max_seq_len:
            seq_len = self.max_seq_len
            idx = idx[:, -seq_len:]

        positions = torch.arange(0, seq_len, dtype=torch.long, device=idx.device).unsqueeze(0)
        hidden = self.embed(idx) + self.pos_embed(positions)

        # Lower-triangular mask: position i may attend to positions <= i.
        causal = torch.ones(seq_len, seq_len, device=idx.device).tril().view(1, 1, seq_len, seq_len)

        for block in self.layers:
            hidden = block(hidden, causal)

        return self.head(self.ln_f(hidden))
# ============================================================================
# 2. SPACE INFERENCE ENGINE
# ============================================================================
class SpaceInferenceEngine:
    """Load the i3 checkpoint from the Hugging Face Hub and stream sampled text.

    Attributes set by ``__init__``:
        device: ``cuda`` when available, otherwise ``cpu``.
        config: Dict parsed from the repository's ``config.json``.
        tokenizer: Tokenizer loaded from the repository's ``tokenizer.json``.
        model: ``i3HybridModel`` with the downloaded weights, in eval mode.
    """

    def __init__(self, repo_id="FlameF0X/i3-200m-v2"):
        """Download config, tokenizer and weights from ``repo_id`` and build the model.

        Raises:
            ValueError: If any of the three files cannot be downloaded. The
                underlying exception is chained as ``__cause__``.
            KeyError: If the config lacks ``vocab_size``, ``d_model``,
                ``rwkv_layers`` or ``attn_layers``.
        """
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        print(f"Loading model on {self.device}...")

        # Download files from Hugging Face Hub
        try:
            config_path = hf_hub_download(repo_id=repo_id, filename="config.json")
            tokenizer_path = hf_hub_download(repo_id=repo_id, filename="tokenizer.json")
            weights_path = hf_hub_download(repo_id=repo_id, filename="pytorch_model.bin")
        except Exception as e:
            # Chain the cause so the original traceback is not lost.
            raise ValueError(f"Failed to download model files from {repo_id}: {e}") from e

        # Load Config (explicit encoding: do not depend on the platform default)
        with open(config_path, 'r', encoding='utf-8') as f:
            self.config = json.load(f)

        # Load Tokenizer
        self.tokenizer = Tokenizer.from_file(tokenizer_path)

        # Initialize Model
        print("Initializing model architecture...")
        # Use config for seq_len, fallback to 256
        max_seq_len = self.config.get('seq_len', self.config.get('max_seq_len', 256))
        self.model = i3HybridModel(
            vocab_size=self.config['vocab_size'],
            d_model=self.config['d_model'],
            n_heads=self.config.get('n_heads', 12),
            n_rwkv_layers=self.config['rwkv_layers'],
            n_attn_layers=self.config['attn_layers'],
            max_seq_len=max_seq_len
        ).to(self.device)

        # Load Weights
        print("Loading weights...")
        # NOTE(security): torch.load unpickles the file, so this trusts whatever
        # `repo_id` serves. Consider passing `weights_only=True` on PyTorch
        # versions that support it.
        state_dict = torch.load(weights_path, map_location=self.device)
        self.model.load_state_dict(state_dict)
        self.model.eval()
        print("Model loaded successfully.")

    def generate_stream(self, prompt, max_new_tokens=100, temperature=1.0, top_k=50):
        """Sample up to ``max_new_tokens`` tokens, yielding the text after each one.

        Args:
            prompt: Text to continue.
            max_new_tokens: Number of tokens to sample.
            temperature: Positive divisor applied to the logits before sampling.
            top_k: Keep only the ``top_k`` highest logits when sampling.
                ``None`` or a non-positive value disables the filter.

        Yields:
            ``prompt`` followed by everything decoded so far, once per new token.

        Raises:
            ValueError: If ``temperature`` is not positive, or if the prompt
                encodes to zero tokens (the model needs at least one token of
                context). As with any generator, these surface on the first
                ``next()`` call.
        """
        if temperature <= 0:
            raise ValueError(f"temperature must be > 0, got {temperature}")

        # Encode
        input_ids = self.tokenizer.encode(prompt).ids
        if not input_ids:
            raise ValueError("Prompt produced no tokens; cannot generate from an empty context.")
        x = torch.tensor([input_ids], dtype=torch.long, device=self.device)

        # For display purposes, we keep the original prompt + new tokens
        generated_text = prompt
        context_len = self.model.max_seq_len

        for _ in range(max_new_tokens):
            # no_grad is entered per step and left *before* yielding, so the
            # consumer of this generator never runs with gradients disabled.
            with torch.no_grad():
                # Context window handling: only the most recent tokens are used.
                x_cond = x[:, -context_len:]

                # Forward pass (division creates a fresh tensor, so the
                # in-place top-k masking below does not touch model output)
                logits = self.model(x_cond)[:, -1, :] / temperature

                # Top-K Sampling
                if top_k is not None and top_k > 0:
                    kept, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                    logits[logits < kept[:, [-1]]] = -float('Inf')

                # Probability distribution -> sample next token
                probs = F.softmax(logits, dim=-1)
                idx_next = torch.multinomial(probs, num_samples=1)

                # Append to the (already cropped) context so the buffer stays
                # bounded; older tokens would be discarded next step anyway.
                x = torch.cat((x_cond, idx_next), dim=1)

            # Decode the new token on its own.
            # NOTE(review): for tokenizers whose pieces only decode correctly
            # in context this may differ from decoding the whole sequence —
            # confirm against the tokenizer in use.
            token_str = self.tokenizer.decode([idx_next.item()])

            # Update text and yield for streaming
            generated_text += token_str
            yield generated_text
# ============================================================================
# 3. GRADIO INTERFACE (UI Upgrade)
# ============================================================================
# Initialize engine globally.
# This runs at import time: constructing SpaceInferenceEngine downloads the
# model files and loads the weights, and the resulting module-level `engine`
# is the single instance used by `predict` and by the UI below.
print("Starting Engine...")
engine = SpaceInferenceEngine()
def predict(prompt, max_tokens, temperature, top_k):
    """Gradio callback: stream generated text for ``prompt``.

    Args:
        prompt: Text from the prompt box. ``None`` is treated like an empty
            string (NOTE(review): Gradio appears to deliver ``None`` for a
            cleared Textbox — confirm for the installed version).
        max_tokens: Number of tokens to generate; cast to ``int`` because
            sliders may deliver floats.
        temperature: Sampling temperature, passed through unchanged.
        top_k: Top-k cutoff; cast to ``int``.

    Yields:
        A single warning message when the prompt is missing or blank,
        otherwise each partial text produced by ``engine.generate_stream``.
    """
    if prompt is None or not prompt.strip():
        yield "⚠️ Please enter a prompt to generate text."
        return

    # Use the generator for streaming
    yield from engine.generate_stream(
        prompt,
        max_new_tokens=int(max_tokens),
        temperature=temperature,
        top_k=int(top_k),
    )
# Custom CSS for the page: caps the width of the Gradio container and centres
# the element carrying the "main-header" class. It is injected into the page
# through a <style> tag in the UI definition below.
custom_css = """
.gradio-container {
max-width: 1200px !important;
}
.main-header {
text-align: center;
margin-bottom: 2rem;
}
"""
# ---------------------------------------------------------------------------
# Static text for the page. Kept at module level so the layout code below
# stays readable. Emoji that had been corrupted into mojibake are restored.
# ---------------------------------------------------------------------------
_HEADER_MD = """
# 🚀 i3-200M Text Generation
### Powered by RWKV-Hybrid Architecture
Generate creative text using the i3-200M language model combining RNN efficiency with Attention precision.
"""

_FOOTER_MD = """
---
<div style="text-align: center; color: #666;">
<p>Built with ❤️ using Gradio | Model: FlameF0X/i3-200m-v2</p>
</div>
"""

# Figures shown in the "Developer Info" panel, read from the loaded engine.
total_params = sum(p.numel() for p in engine.model.parameters())

_ARCHITECTURE_MD = f"""
**Model Architecture:**
- **Model:** i3-200M Hybrid
- **Device:** {engine.device}
- **Vocab Size:** {engine.config['vocab_size']:,}
- **Parameters:** {total_params:,} ({total_params/1e6:.2f}M)
"""

_CONFIGURATION_MD = f"""
**Configuration:**
- **d_model:** {engine.config['d_model']}
- **RWKV Layers:** {engine.config['rwkv_layers']}
- **Attention Layers:** {engine.config['attn_layers']}
- **Max Seq Len:** {engine.model.max_seq_len}
"""

with gr.Blocks() as demo:
    # Inject CSS via HTML component to avoid Blocks() keyword argument error
    gr.HTML(f"<style>{custom_css}</style>")

    # Header
    with gr.Row():
        gr.Markdown(_HEADER_MD, elem_classes="main-header")

    # Main Generation Area
    with gr.Row():
        # Left Column: Inputs
        with gr.Column(scale=2):
            prompt_input = gr.Textbox(
                label="✍️ Enter Your Prompt",
                placeholder="Once upon a time in a distant galaxy...",
                lines=4,
                max_lines=8
            )
            with gr.Accordion("⚙️ Generation Parameters", open=True):
                with gr.Row():
                    max_tokens_input = gr.Slider(
                        minimum=10,
                        maximum=512,
                        value=150,
                        step=10,
                        label="Max Tokens",
                        info="Maximum number of tokens to generate"
                    )
                    temp_input = gr.Slider(
                        minimum=0.1,
                        maximum=2.0,
                        value=0.8,
                        step=0.1,
                        label="Temperature",
                        info="Higher = more creative, Lower = more focused"
                    )
                    topk_input = gr.Slider(
                        minimum=1,
                        maximum=100,
                        value=40,
                        step=1,
                        label="Top-k Sampling",
                        info="Number of top tokens to consider"
                    )
            with gr.Row():
                generate_btn = gr.Button("🎨 Generate Text", variant="primary", size="lg")
                clear_btn = gr.ClearButton(components=[prompt_input], value="🗑️ Clear", size="lg")

        # Right Column: Output
        with gr.Column(scale=2):
            output_text = gr.Textbox(
                label="📝 Generated Output",
                lines=12,
                max_lines=20
            )

    # Examples Section: each row is (prompt, max tokens, temperature, top-k)
    with gr.Row():
        gr.Examples(
            examples=[
                ["The history of science is", 150, 0.7, 50],
                ["In a world where technology and nature coexist", 200, 0.9, 40],
                ["The scientist discovered something remarkable", 120, 0.8, 45],
            ],
            inputs=[prompt_input, max_tokens_input, temp_input, topk_input],
            label="💡 Try These Examples"
        )

    # Developer Panel
    with gr.Accordion("🔧 Developer Info", open=False):
        with gr.Row():
            with gr.Column():
                gr.Markdown(_ARCHITECTURE_MD)
            with gr.Column():
                gr.Markdown(_CONFIGURATION_MD)

    # Footer
    gr.Markdown(_FOOTER_MD)

    # Connect UI: `predict` is a generator, so the output box updates as
    # each partial text is yielded.
    generate_btn.click(
        predict,
        inputs=[prompt_input, max_tokens_input, temp_input, topk_input],
        outputs=[output_text]
    )
# Script entry point: only start the web server when the file is executed
# directly, not when it is imported.
if __name__ == "__main__":
    # Enable Gradio's request queue before launching; `predict` is a
    # generator that streams partial results.
    demo.queue()
    demo.launch()