import torch
import torch.nn as nn
import torch.nn.functional as F
import json
import os
import gradio as gr
from tokenizers import Tokenizer
from huggingface_hub import hf_hub_download

# ============================================================================
# 1. MODEL ARCHITECTURE
#    (Copied from inference.py to support custom weight loading)
# ============================================================================

@torch.jit.script
def rwkv_linear_attention(B: int, T: int, C: int,
                          r: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
                          w: torch.Tensor, u: torch.Tensor,
                          state_init: torch.Tensor):
    y = torch.zeros_like(v)
    state_aa = torch.zeros(B, C, dtype=torch.float32, device=r.device)
    state_bb = torch.zeros(B, C, dtype=torch.float32, device=r.device)
    state_pp = state_init.clone()
    for t in range(T):
        rt, kt, vt = r[:, t], k[:, t], v[:, t]
        # Numerically stable WKV update in log-space (p tracks the running max)
        ww = u + state_pp
        p = torch.maximum(ww, kt)
        e1 = torch.exp(ww - p)
        e2 = torch.exp(kt - p)
        wkv = (state_aa * e1 + vt * e2) / (state_bb * e1 + e2 + 1e-6)
        y[:, t] = wkv
        ww = w + state_pp
        p = torch.maximum(ww, kt)
        e1 = torch.exp(ww - p)
        e2 = torch.exp(kt - p)
        state_aa = state_aa * e1 + vt * e2
        state_bb = state_bb * e1 + e2
        state_pp = p
    return y


class RWKVTimeMix(nn.Module):
    def __init__(self, d_model):
        super().__init__()
        self.d_model = d_model
        self.time_decay = nn.Parameter(torch.ones(d_model))
        self.time_first = nn.Parameter(torch.ones(d_model))
        self.time_mix_k = nn.Parameter(torch.ones(1, 1, d_model))
        self.time_mix_v = nn.Parameter(torch.ones(1, 1, d_model))
        self.time_mix_r = nn.Parameter(torch.ones(1, 1, d_model))
        self.key = nn.Linear(d_model, d_model, bias=False)
        self.value = nn.Linear(d_model, d_model, bias=False)
        self.receptance = nn.Linear(d_model, d_model, bias=False)
        self.output = nn.Linear(d_model, d_model, bias=False)

    def forward(self, x):
        B, T, C = x.size()
        # Token shift: mix each position with the previous one
        xx = torch.cat([torch.zeros((B, 1, C), device=x.device), x[:, :-1]], dim=1)
        xk = x * self.time_mix_k + xx * (1 - self.time_mix_k)
        xv = x * self.time_mix_v + xx * (1 - self.time_mix_v)
        xr = x * self.time_mix_r + xx * (1 - self.time_mix_r)
        k = self.key(xk)
        v = self.value(xv)
        r = torch.sigmoid(self.receptance(xr))
        w = -torch.exp(self.time_decay)  # decay is kept strictly negative
        u = self.time_first
        state_init = torch.full((B, C), -1e30, dtype=torch.float32, device=x.device)
        rwkv = rwkv_linear_attention(B, T, C, r, k, v, w, u, state_init)
        return self.output(r * rwkv)


class RWKVChannelMix(nn.Module):
    def __init__(self, d_model, ffn_mult=4):
        super().__init__()
        self.time_mix_k = nn.Parameter(torch.ones(1, 1, d_model))
        self.time_mix_r = nn.Parameter(torch.ones(1, 1, d_model))
        hidden_sz = d_model * ffn_mult
        self.key = nn.Linear(d_model, hidden_sz, bias=False)
        self.receptance = nn.Linear(d_model, d_model, bias=False)
        self.value = nn.Linear(hidden_sz, d_model, bias=False)

    def forward(self, x):
        B, T, C = x.size()
        xx = torch.cat([torch.zeros((B, 1, C), device=x.device), x[:, :-1]], dim=1)
        xk = x * self.time_mix_k + xx * (1 - self.time_mix_k)
        xr = x * self.time_mix_r + xx * (1 - self.time_mix_r)
        k = torch.square(torch.relu(self.key(xk)))  # squared-ReLU FFN
        kv = self.value(k)
        r = torch.sigmoid(self.receptance(xr))
        return r * kv


class RWKVBlock(nn.Module):
    def __init__(self, d_model, ffn_mult=4):
        super().__init__()
        self.ln1 = nn.LayerNorm(d_model)
        self.att = RWKVTimeMix(d_model)
        self.ln2 = nn.LayerNorm(d_model)
        self.ffn = RWKVChannelMix(d_model, ffn_mult)

    def forward(self, x, mask=None):
        x = x + self.att(self.ln1(x))
        x = x + self.ffn(self.ln2(x))
        return x
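
# ----------------------------------------------------------------------------
# Optional sanity check (a minimal sketch, not part of the original app, and
# never called by the Space): it exercises rwkv_linear_attention on random
# tensors with the same shapes and initial state that RWKVTimeMix.forward
# uses. The helper name and default sizes are assumptions for illustration.
# ----------------------------------------------------------------------------
def _rwkv_kernel_smoke_test(B: int = 2, T: int = 4, C: int = 8) -> torch.Tensor:
    r = torch.sigmoid(torch.randn(B, T, C))  # receptance in (0, 1)
    k = torch.randn(B, T, C)                 # keys
    v = torch.randn(B, T, C)                 # values
    w = -torch.exp(torch.ones(C))            # negative decay, as in RWKVTimeMix
    u = torch.ones(C)                        # "time_first" bonus
    state_init = torch.full((B, C), -1e30)   # empty-history log-space state
    y = rwkv_linear_attention(B, T, C, r, k, v, w, u, state_init)
    assert y.shape == (B, T, C)
    return y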
class FullAttention(nn.Module):
    def __init__(self, d_model, n_heads=16):
        super().__init__()
        self.d_model = d_model
        self.n_heads = n_heads
        self.head_dim = d_model // n_heads
        self.qkv = nn.Linear(d_model, d_model * 3)
        self.out_proj = nn.Linear(d_model, d_model)

    def forward(self, x, mask=None):
        B, T, C = x.shape
        qkv = self.qkv(x)
        q, k, v = qkv.chunk(3, dim=-1)
        q = q.view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
        k = k.view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
        v = v.view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
        attn = (q @ k.transpose(-2, -1)) / (self.head_dim ** 0.5)
        if mask is not None:
            mask = mask.to(x.device)
            attn = attn.masked_fill(mask == 0, float('-inf'))
        attn = F.softmax(attn, dim=-1)
        out = attn @ v
        out = out.transpose(1, 2).contiguous().view(B, T, C)
        return self.out_proj(out)


class StandardAttentionBlock(nn.Module):
    def __init__(self, d_model, n_heads=16, ffn_mult=4):
        super().__init__()
        self.ln1 = nn.LayerNorm(d_model)
        self.attn = FullAttention(d_model, n_heads)
        self.ln2 = nn.LayerNorm(d_model)
        self.ffn = nn.Sequential(
            nn.Linear(d_model, d_model * ffn_mult),
            nn.GELU(),
            nn.Linear(d_model * ffn_mult, d_model)
        )

    def forward(self, x, mask=None):
        x = x + self.attn(self.ln1(x), mask)
        x = x + self.ffn(self.ln2(x))
        return x


class i3HybridModel(nn.Module):
    def __init__(self, vocab_size, d_model=1024, n_heads=16,
                 n_rwkv_layers=10, n_attn_layers=6, max_seq_len=512):
        super().__init__()
        self.vocab_size = vocab_size
        self.d_model = d_model
        self.max_seq_len = max_seq_len
        self.embed = nn.Embedding(vocab_size, d_model)
        self.pos_embed = nn.Embedding(max_seq_len, d_model)
        self.layers = nn.ModuleList()
        # RWKV blocks first, then full-attention blocks
        for _ in range(n_rwkv_layers):
            self.layers.append(RWKVBlock(d_model, ffn_mult=4))
        for _ in range(n_attn_layers):
            self.layers.append(StandardAttentionBlock(d_model, n_heads=n_heads))
        self.ln_f = nn.LayerNorm(d_model)
        self.head = nn.Linear(d_model, vocab_size)

    def forward(self, idx):
        B, T = idx.shape
        if T > self.max_seq_len:
            idx = idx[:, -self.max_seq_len:]
            T = self.max_seq_len
        pos = torch.arange(0, T, dtype=torch.long, device=idx.device).unsqueeze(0)
        x = self.embed(idx) + self.pos_embed(pos)
        mask = torch.tril(torch.ones(T, T, device=idx.device)).view(1, 1, T, T)
        for layer in self.layers:
            x = layer(x, mask)
        x = self.ln_f(x)
        logits = self.head(x)
        return logits
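
# ----------------------------------------------------------------------------
# Hedged usage sketch (not executed by the Space): builds a tiny i3HybridModel
# and checks that logits come out as (batch, seq, vocab). All sizes here are
# made up for illustration and are far smaller than the real 200M config.
# ----------------------------------------------------------------------------
def _model_shape_check() -> None:
    model = i3HybridModel(vocab_size=100, d_model=32, n_heads=4,
                          n_rwkv_layers=1, n_attn_layers=1, max_seq_len=16)
    idx = torch.randint(0, 100, (2, 8))  # batch of 2, sequence length 8
    logits = model(idx)
    assert logits.shape == (2, 8, 100)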
# ============================================================================
# 2. SPACE INFERENCE ENGINE
# ============================================================================

class SpaceInferenceEngine:
    def __init__(self, repo_id="FlameF0X/i3-200m-v2"):
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        print(f"Loading model on {self.device}...")

        # Download files from the Hugging Face Hub
        try:
            config_path = hf_hub_download(repo_id=repo_id, filename="config.json")
            tokenizer_path = hf_hub_download(repo_id=repo_id, filename="tokenizer.json")
            weights_path = hf_hub_download(repo_id=repo_id, filename="pytorch_model.bin")
        except Exception as e:
            raise ValueError(f"Failed to download model files from {repo_id}: {e}")

        # Load config
        with open(config_path, 'r') as f:
            self.config = json.load(f)

        # Load tokenizer
        self.tokenizer = Tokenizer.from_file(tokenizer_path)

        # Initialize model
        print("Initializing model architecture...")
        # Use the config's seq_len; fall back to 256
        max_seq_len = self.config.get('seq_len', self.config.get('max_seq_len', 256))
        self.model = i3HybridModel(
            vocab_size=self.config['vocab_size'],
            d_model=self.config['d_model'],
            n_heads=self.config.get('n_heads', 12),
            n_rwkv_layers=self.config['rwkv_layers'],
            n_attn_layers=self.config['attn_layers'],
            max_seq_len=max_seq_len
        ).to(self.device)

        # Load weights
        print("Loading weights...")
        state_dict = torch.load(weights_path, map_location=self.device)
        self.model.load_state_dict(state_dict)
        self.model.eval()
        print("Model loaded successfully.")

    def generate_stream(self, prompt, max_new_tokens=100, temperature=1.0, top_k=50):
        # Encode the prompt
        input_ids = self.tokenizer.encode(prompt).ids
        x = torch.tensor([input_ids], dtype=torch.long, device=self.device)

        # For display purposes, keep the original prompt + new tokens
        generated_text = prompt

        with torch.no_grad():
            for _ in range(max_new_tokens):
                # Context window handling
                if x.size(1) > self.model.max_seq_len:
                    x_cond = x[:, -self.model.max_seq_len:]
                else:
                    x_cond = x

                # Forward pass
                logits = self.model(x_cond)
                logits = logits[:, -1, :] / temperature

                # Top-k sampling
                if top_k is not None:
                    v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                    logits[logits < v[:, [-1]]] = -float('Inf')

                # Probability distribution
                probs = F.softmax(logits, dim=-1)

                # Sample the next token
                idx_next = torch.multinomial(probs, num_samples=1)

                # Append to the sequence
                x = torch.cat((x, idx_next), dim=1)

                # Decode the new token
                new_token_id = idx_next.item()
                token_str = self.tokenizer.decode([new_token_id])

                # Update text and yield for streaming
                generated_text += token_str
                yield generated_text

# ============================================================================
# 3. GRADIO INTERFACE (UI Upgrade)
# ============================================================================

# Initialize the engine globally
print("Starting Engine...")
engine = SpaceInferenceEngine()
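
# ----------------------------------------------------------------------------
# For reference, a standalone sketch of the top-k filter used inside
# generate_stream above: everything below the k-th largest logit is pushed to
# -inf so softmax assigns it zero probability. The helper is illustrative
# only; the engine applies the same trick inline.
# ----------------------------------------------------------------------------
def _top_k_filter(logits: torch.Tensor, top_k: int) -> torch.Tensor:
    v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
    logits = logits.clone()                   # avoid mutating the caller's tensor
    logits[logits < v[:, [-1]]] = -float('Inf')
    return logits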
def predict(prompt, max_tokens, temperature, top_k):
    if not prompt.strip():
        yield "⚠️ Please enter a prompt to generate text."
        return

    # Use the generator for streaming
    for current_text in engine.generate_stream(
        prompt,
        max_new_tokens=int(max_tokens),
        temperature=temperature,
        top_k=int(top_k)
    ):
        yield current_text

# Custom CSS
custom_css = """
.gradio-container { max-width: 1200px !important; }
.main-header { text-align: center; margin-bottom: 2rem; }
"""

with gr.Blocks() as demo:
    # Inject CSS via an HTML component to avoid the Blocks() keyword argument error
    gr.HTML(f"<style>{custom_css}</style>")

    # Header
    with gr.Row():
        gr.Markdown(
            """
            # 🚀 i3-200M Text Generation
            ### Powered by RWKV-Hybrid Architecture
            Generate creative text using the i3-200M language model, combining RNN efficiency with attention precision.
            """,
            elem_classes="main-header"
        )

    # Main generation area
    with gr.Row():
        # Left column: inputs
        with gr.Column(scale=2):
            prompt_input = gr.Textbox(
                label="✍️ Enter Your Prompt",
                placeholder="Once upon a time in a distant galaxy...",
                lines=4,
                max_lines=8
            )

            with gr.Accordion("⚙️ Generation Parameters", open=True):
                with gr.Row():
                    max_tokens_input = gr.Slider(
                        minimum=10, maximum=512, value=150, step=10,
                        label="Max Tokens",
                        info="Maximum number of tokens to generate"
                    )
                    temp_input = gr.Slider(
                        minimum=0.1, maximum=2.0, value=0.8, step=0.1,
                        label="Temperature",
                        info="Higher = more creative, lower = more focused"
                    )
                    topk_input = gr.Slider(
                        minimum=1, maximum=100, value=40, step=1,
                        label="Top-k Sampling",
                        info="Number of top tokens to consider"
                    )

            with gr.Row():
                generate_btn = gr.Button("🎨 Generate Text", variant="primary", size="lg")
                clear_btn = gr.ClearButton(components=[prompt_input], value="🗑️ Clear", size="lg")

        # Right column: output
        with gr.Column(scale=2):
            output_text = gr.Textbox(
                label="📝 Generated Output",
                lines=12,
                max_lines=20
            )

    # Examples section
    with gr.Row():
        gr.Examples(
            examples=[
                ["The history of science is", 150, 0.7, 50],
                ["In a world where technology and nature coexist", 200, 0.9, 40],
                ["The scientist discovered something remarkable", 120, 0.8, 45],
            ],
            inputs=[prompt_input, max_tokens_input, temp_input, topk_input],
            label="💡 Try These Examples"
        )

    # Developer panel
    with gr.Accordion("🔧 Developer Info", open=False):
        total_params = sum(p.numel() for p in engine.model.parameters())
        with gr.Row():
            with gr.Column():
                gr.Markdown(f"""
                **Model Architecture:**
                - **Model:** i3-200M Hybrid
                - **Device:** {engine.device}
                - **Vocab Size:** {engine.config['vocab_size']:,}
                - **Parameters:** {total_params:,} ({total_params/1e6:.2f}M)
                """)
            with gr.Column():
                gr.Markdown(f"""
                **Configuration:**
                - **d_model:** {engine.config['d_model']}
                - **RWKV Layers:** {engine.config['rwkv_layers']}
                - **Attention Layers:** {engine.config['attn_layers']}
                - **Max Seq Len:** {engine.model.max_seq_len}
                """)

    # Footer
    gr.Markdown(
        """
        ---

        Built with ❤️ using Gradio | Model: FlameF0X/i3-200m-v2
        """
    )

    # Connect UI
    generate_btn.click(
        predict,
        inputs=[prompt_input, max_tokens_input, temp_input, topk_input],
        outputs=[output_text]
    )

if __name__ == "__main__":
    demo.queue()
    demo.launch()