import gradio as gr
import torch
import torch.nn as nn
import torch.nn.functional as F
from tokenizers import Tokenizer
from huggingface_hub import hf_hub_download
import json

# ============================================================================
# MODEL ARCHITECTURE (from training code)
# ============================================================================
def rwkv_linear_attention(B: int, T: int, C: int,
                          r: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
                          w: torch.Tensor, u: torch.Tensor,
                          state_init: torch.Tensor):
    """Numerically stable WKV recurrence (RWKV-4 style).

    Works in exp-space with a running maximum `state_pp` so the exponentials
    never overflow. The receptance `r` gates the output in the caller
    (RWKVTimeMix), so it is unused inside this loop.
    """
    y = torch.zeros_like(v)
    state_aa = torch.zeros(B, C, dtype=torch.float32, device=k.device)
    state_bb = torch.zeros(B, C, dtype=torch.float32, device=k.device)
    state_pp = state_init.clone()
    for t in range(T):
        kt, vt = k[:, t], v[:, t]
        # output step: u ("time_first") enters here only
        ww = u + state_pp
        p = torch.maximum(ww, kt)
        e1 = torch.exp(ww - p)
        e2 = torch.exp(kt - p)
        wkv = (state_aa * e1 + vt * e2) / (state_bb * e1 + e2 + 1e-6)
        y[:, t] = wkv
        # state update: the decay w replaces u
        ww = w + state_pp
        p = torch.maximum(ww, kt)
        e1 = torch.exp(ww - p)
        e2 = torch.exp(kt - p)
        state_aa = state_aa * e1 + vt * e2
        state_bb = state_bb * e1 + e2
        state_pp = p
    return y
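
# In closed form, each channel of the recurrence above computes
#
#   wkv_t = (e^u * sum_{i<t} e^{w*(t-1-i) + k_i} * v_i  +  e^{k_t} * v_t)
#           / (e^u * sum_{i<t} e^{w*(t-1-i) + k_i}      +  e^{k_t} + eps)
#
# i.e. an exponentially decaying weighted average of past values (w < 0, set
# by the caller). The running maximum state_pp exists purely for numerical
# stability. Note that in this variant u scales the history term; the
# reference RWKV-4 kernel instead applies the bonus to the current token.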

class RWKVTimeMix(nn.Module):
    def __init__(self, d_model):
        super().__init__()
        self.d_model = d_model
        self.time_decay = nn.Parameter(torch.ones(d_model))
        self.time_first = nn.Parameter(torch.ones(d_model))
        self.time_mix_k = nn.Parameter(torch.ones(1, 1, d_model))
        self.time_mix_v = nn.Parameter(torch.ones(1, 1, d_model))
        self.time_mix_r = nn.Parameter(torch.ones(1, 1, d_model))
        self.key = nn.Linear(d_model, d_model, bias=False)
        self.value = nn.Linear(d_model, d_model, bias=False)
        self.receptance = nn.Linear(d_model, d_model, bias=False)
        self.output = nn.Linear(d_model, d_model, bias=False)
        self.time_decay.data.uniform_(-6, -3)

    def forward(self, x):
        B, T, C = x.size()
        xx = torch.cat([torch.zeros((B, 1, C), device=x.device), x[:, :-1]], dim=1)
        xk = x * self.time_mix_k + xx * (1 - self.time_mix_k)
        xv = x * self.time_mix_v + xx * (1 - self.time_mix_v)
        xr = x * self.time_mix_r + xx * (1 - self.time_mix_r)
        k = self.key(xk)
        v = self.value(xv)
        r = torch.sigmoid(self.receptance(xr))
        w = -torch.exp(self.time_decay)
        u = self.time_first
        state_init = torch.full((B, C), -1e30, dtype=torch.float32, device=x.device)
        rwkv = rwkv_linear_attention(B, T, C, r, k, v, w, u, state_init)
        return self.output(r * rwkv)
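
# Token shift, illustrated: each projection sees a learned per-channel blend
# of the current and previous position (xx is x delayed by one step,
# zero-padded at t=0). With mix = 0.7, for example:
#   x  = [x0, x1, x2]  ->  xx = [0, x0, x1]
#   xk = 0.7 * x + 0.3 * xx
# The decay is parameterized as w = -exp(time_decay), guaranteeing w < 0 so
# each channel's per-step history weight e^w lies in (0, 1).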

class RWKVChannelMix(nn.Module):
    def __init__(self, d_model, ffn_mult=4):
        super().__init__()
        self.time_mix_k = nn.Parameter(torch.ones(1, 1, d_model))
        self.time_mix_r = nn.Parameter(torch.ones(1, 1, d_model))
        hidden_sz = d_model * ffn_mult
        self.key = nn.Linear(d_model, hidden_sz, bias=False)
        self.receptance = nn.Linear(d_model, d_model, bias=False)
        self.value = nn.Linear(hidden_sz, d_model, bias=False)

    def forward(self, x):
        B, T, C = x.size()
        xx = torch.cat([torch.zeros((B, 1, C), device=x.device), x[:, :-1]], dim=1)
        xk = x * self.time_mix_k + xx * (1 - self.time_mix_k)
        xr = x * self.time_mix_r + xx * (1 - self.time_mix_r)
        k = torch.square(torch.relu(self.key(xk)))
        kv = self.value(k)
        r = torch.sigmoid(self.receptance(xr))
        return r * kv
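
# The channel mix is RWKV's gated FFN: squared-ReLU activation on the up
# projection, sigmoid receptance gating the output, roughly
#   out = sigmoid(W_r xr) * W_v relu(W_k xk)^2
# with xk and xr token-shifted blends as in the time-mix block.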

class RWKVBlock(nn.Module):
    def __init__(self, d_model, ffn_mult=4):
        super().__init__()
        self.ln1 = nn.LayerNorm(d_model)
        self.att = RWKVTimeMix(d_model)
        self.ln2 = nn.LayerNorm(d_model)
        self.ffn = RWKVChannelMix(d_model, ffn_mult)

    def forward(self, x, mask=None):
        # mask is accepted for interface parity with the attention blocks but
        # unused: the RWKV recurrence is causal by construction
        x = x + self.att(self.ln1(x))
        x = x + self.ffn(self.ln2(x))
        return x

class FullAttention(nn.Module):
    def __init__(self, d_model, n_heads=16):
        super().__init__()
        self.d_model = d_model
        self.n_heads = n_heads
        self.head_dim = d_model // n_heads
        self.qkv = nn.Linear(d_model, d_model * 3)
        self.out_proj = nn.Linear(d_model, d_model)

    def forward(self, x, mask=None):
        B, T, C = x.shape
        qkv = self.qkv(x)
        q, k, v = qkv.chunk(3, dim=-1)
        q = q.view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
        k = k.view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
        v = v.view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
        attn = (q @ k.transpose(-2, -1)) / (self.head_dim ** 0.5)
        if mask is not None:
            mask = mask.to(x.device)
            attn = attn.masked_fill(mask == 0, float('-inf'))
        attn = F.softmax(attn, dim=-1)
        out = attn @ v
        out = out.transpose(1, 2).contiguous().view(B, T, C)
        return self.out_proj(out)
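
# Shape walk-through (hypothetical sizes): with B=2, T=512, d_model=1024,
# n_heads=16 (head_dim=64):
#   q, k, v: (2, 16, 512, 64)
#   attn:    (2, 16, 512, 512), future positions set to -inf before softmax
#   out:     (2, 512, 1024) after merging heads
# The causal mask arrives from the model broadcastable as (1, 1, T, T).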

class StandardAttentionBlock(nn.Module):
    def __init__(self, d_model, n_heads=16, ffn_mult=4):
        super().__init__()
        if d_model % n_heads != 0:
            # fall back to the largest head count in the list that divides d_model
            for h in [16, 12, 10, 8, 6, 4, 2]:
                if d_model % h == 0:
                    n_heads = h
                    break
        self.ln1 = nn.LayerNorm(d_model)
        self.attn = FullAttention(d_model, n_heads)
        self.ln2 = nn.LayerNorm(d_model)
        self.ffn = nn.Sequential(
            nn.Linear(d_model, d_model * ffn_mult),
            nn.GELU(),
            nn.Linear(d_model * ffn_mult, d_model)
        )

    def forward(self, x, mask=None):
        x = x + self.attn(self.ln1(x), mask)
        x = x + self.ffn(self.ln2(x))
        return x

class LatentContextCompressor(nn.Module):
    def __init__(self, d_model, compression_ratio=4, num_latent_tokens=32, n_heads=None):
        super().__init__()
        self.d_model = d_model
        self.compression_ratio = compression_ratio  # kept for config parity; not used in forward
        self.num_latent_tokens = num_latent_tokens
        if n_heads is None:
            for h in [16, 12, 10, 8, 6, 4, 2, 1]:
                if d_model % h == 0:
                    n_heads = h
                    break
        self.n_heads = n_heads
        self.latent_queries = nn.Parameter(torch.randn(1, num_latent_tokens, d_model))
        self.compress_attn = nn.MultiheadAttention(
            embed_dim=d_model, num_heads=n_heads, batch_first=True
        )
        self.ln1 = nn.LayerNorm(d_model)
        self.ln2 = nn.LayerNorm(d_model)
        self.ffn = nn.Sequential(
            nn.Linear(d_model, d_model * 2),
            nn.GELU(),
            nn.Linear(d_model * 2, d_model)
        )

    def forward(self, x):
        B, T, C = x.shape
        # learned queries cross-attend over the full input sequence
        queries = self.latent_queries.expand(B, -1, -1)
        compressed, _ = self.compress_attn(
            query=self.ln1(queries), key=x, value=x, need_weights=False
        )
        compressed = queries + compressed
        compressed = compressed + self.ffn(self.ln2(compressed))
        return compressed
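
# Perceiver-style pooling: a fixed set of learned latent queries cross-attends
# over an arbitrary-length sequence and returns a constant-size summary, e.g.
#   x: (B, 512, 1024) -> compressed: (B, 32, 1024)
# regardless of T, so old context costs a fixed num_latent_tokens positions
# per chunk.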

class i3HybridModelWithCompression(nn.Module):
    def __init__(self, vocab_size, d_model=1024, n_heads=16, n_rwkv_layers=10,
                 n_attn_layers=6, kernel_size=512, max_latent_context=2048,
                 compression_ratio=4, num_latent_tokens=32, enable_compression=True):
        super().__init__()
        self.vocab_size = vocab_size
        self.d_model = d_model
        self.kernel_size = kernel_size
        self.max_latent_context = max_latent_context
        self.enable_compression = enable_compression
        self.num_latent_tokens = num_latent_tokens
        if d_model % n_heads != 0:
            for h in [16, 12, 10, 8, 6, 4, 2]:
                if d_model % h == 0:
                    n_heads = h
                    break
        self.n_heads = n_heads
        self.max_compressed_chunks = max_latent_context // kernel_size
        self.embed = nn.Embedding(vocab_size, d_model)
        self.pos_embed = nn.Embedding(max(kernel_size, max_latent_context), d_model)
        if enable_compression:
            self.compressor = LatentContextCompressor(
                d_model=d_model, compression_ratio=compression_ratio,
                num_latent_tokens=num_latent_tokens, n_heads=n_heads
            )
        self.layers = nn.ModuleList()
        for _ in range(n_rwkv_layers):
            self.layers.append(RWKVBlock(d_model, ffn_mult=4))
        for _ in range(n_attn_layers):
            self.layers.append(StandardAttentionBlock(d_model, n_heads=n_heads))
        self.ln_f = nn.LayerNorm(d_model)
        self.head = nn.Linear(d_model, vocab_size)
        self.apply(self._init_weights)

    def _init_weights(self, module):
        if isinstance(module, (nn.Linear, nn.Embedding)):
            module.weight.data.normal_(mean=0.0, std=0.02)
        if isinstance(module, nn.Linear) and module.bias is not None:
            module.bias.data.zero_()

    def forward(self, idx, targets=None, compressed_history=None):
        B, T = idx.shape
        if T > self.kernel_size:
            # the model only ever processes the last kernel_size raw tokens
            idx = idx[:, -self.kernel_size:]
            if targets is not None:
                targets = targets[:, -self.kernel_size:]
            T = self.kernel_size
        if self.enable_compression and compressed_history is not None:
            # prepend compressed latents; current tokens take positions
            # offset past the history
            history_len = compressed_history.size(1)
            pos = torch.arange(history_len, history_len + T, dtype=torch.long, device=idx.device).unsqueeze(0)
            x = torch.cat([compressed_history, self.embed(idx) + self.pos_embed(pos)], dim=1)
            T = history_len + T
        else:
            pos = torch.arange(0, T, dtype=torch.long, device=idx.device).unsqueeze(0)
            x = self.embed(idx) + self.pos_embed(pos)
        mask = torch.tril(torch.ones(T, T, device=idx.device)).view(1, 1, T, T)
        for layer in self.layers:
            x = layer(x, mask)
        x = self.ln_f(x)
        if self.enable_compression and compressed_history is not None:
            history_len = compressed_history.size(1)
            logits = self.head(x[:, history_len:])
        else:
            logits = self.head(x)
        loss = None
        if targets is not None:
            loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)), targets.reshape(-1))
        new_compressed = None
        if self.enable_compression:
            # compress only the current tokens, not the already-compressed history
            if compressed_history is not None:
                current_tokens = x[:, history_len:]
            else:
                current_tokens = x
            new_compressed = self.compressor(current_tokens)
        return logits, loss, new_compressed
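
    # Return contract for forward (a summary of the method above):
    #   logits:          (B, T_window, vocab_size), current window only
    #   loss:            cross-entropy if targets were given, else None
    #   new_compressed:  (B, num_latent_tokens, d_model) summary of this
    #                    window, to be carried forward as compressed_history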

    @torch.no_grad()
    def generate_stream(self, idx, max_new_tokens, temperature=1.0, top_k=50, top_p=0.9):
        """Generator that yields tokens one at a time for streaming."""
        compressed_history = None
        for _ in range(max_new_tokens):
            idx_cond = idx if idx.size(1) <= self.kernel_size else idx[:, -self.kernel_size:]
            logits, _, new_compressed = self(idx_cond, compressed_history=compressed_history)
            logits = logits[:, -1, :] / temperature
            if top_k is not None:
                # keep only the top_k highest-scoring logits
                v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                logits[logits < v[:, [-1]]] = -float('Inf')
            probs = F.softmax(logits, dim=-1)
            if top_p < 1.0:
                # nucleus sampling: drop the tail once cumulative mass exceeds top_p
                sorted_probs, sorted_indices = torch.sort(probs, descending=True)
                cumulative_probs = torch.cumsum(sorted_probs, dim=-1)
                sorted_indices_to_remove = cumulative_probs > top_p
                sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
                sorted_indices_to_remove[..., 0] = 0
                indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
                probs[indices_to_remove] = 0
                probs = probs / probs.sum(dim=-1, keepdim=True)
            idx_next = torch.multinomial(probs, num_samples=1)
            idx = torch.cat((idx, idx_next), dim=1)
            if self.enable_compression and new_compressed is not None:
                if compressed_history is None:
                    compressed_history = new_compressed
                else:
                    compressed_history = torch.cat([compressed_history, new_compressed], dim=1)
                # cap the rolling history at max_compressed_chunks summaries
                max_history_tokens = self.max_compressed_chunks * self.num_latent_tokens
                if compressed_history.size(1) > max_history_tokens:
                    compressed_history = compressed_history[:, -max_history_tokens:]
            yield idx_next.item()
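
# A rough usage sketch of the streaming generator above (assumes `model` and
# a (1, T) LongTensor `prompt_ids` on the same device):
#
#   for token_id in model.generate_stream(prompt_ids, max_new_tokens=100):
#       print(tokenizer.decode([token_id]), end="", flush=True)
#
# Each step feeds at most kernel_size raw tokens; older text survives only as
# num_latent_tokens-sized compressed summaries, which is how the model sees
# past the 512-token window.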

# ============================================================================
# MODEL LOADING
# ============================================================================

class ModelLoader:
    def __init__(self, repo_id="i3-lab/i3-4096ctx"):
        self.repo_id = repo_id
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.model = None
        self.tokenizer = None

    def load(self):
        print(f"Loading model from {self.repo_id}...")
        # Download files
        config_path = hf_hub_download(repo_id=self.repo_id, filename="config.json")
        model_path = hf_hub_download(repo_id=self.repo_id, filename="pytorch_model.bin")
        tokenizer_path = hf_hub_download(repo_id=self.repo_id, filename="tokenizer.json")
        # Load config
        with open(config_path, 'r') as f:
            config = json.load(f)
        # Load tokenizer
        self.tokenizer = Tokenizer.from_file(tokenizer_path)
        # Create model
        self.model = i3HybridModelWithCompression(
            vocab_size=config['vocab_size'],
            d_model=config['d_model'],
            n_heads=config.get('n_heads', 8),  # fall back to 8 if absent from config
            n_rwkv_layers=config.get('rwkv_layers', 12),
            n_attn_layers=config.get('attn_layers', 2),
            kernel_size=config.get('kernel_size', 512),
            max_latent_context=config.get('max_latent_context', 4096),
            num_latent_tokens=32,
            enable_compression=config.get('compression_enabled', True)
        )
        # Load weights
        state_dict = torch.load(model_path, map_location=self.device)
        self.model.load_state_dict(state_dict)
        self.model.to(self.device)
        self.model.eval()
        print(f"Model loaded successfully on {self.device}")
        return self.model, self.tokenizer

# ============================================================================
# GRADIO INTERFACE
# ============================================================================

# Initialize model
loader = ModelLoader()
model, tokenizer = loader.load()


def generate_text(prompt, temperature, top_k, top_p, max_tokens):
    """Generate text completion with streaming."""
    # Gradio sliders may deliver floats; top-k and token counts must be ints
    top_k, max_tokens = int(top_k), int(max_tokens)
    # Encode the prompt
    input_ids = tokenizer.encode(prompt).ids
    input_tensor = torch.tensor([input_ids], dtype=torch.long, device=loader.device)
    # Start with the prompt
    output_text = prompt
    # Generate with streaming
    for token_id in model.generate_stream(
        input_tensor,
        max_new_tokens=max_tokens,
        temperature=temperature,
        top_k=top_k,
        top_p=top_p
    ):
        token_text = tokenizer.decode([token_id])
        output_text += token_text
        yield output_text
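
# generate_text is a generator, so Gradio streams each yielded string into the
# output textbox as it grows; demo.queue() at the bottom enables the queueing
# that this streaming relies on.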

# Example prompts
examples = [
    ["The future of artificial intelligence is", 0.8, 50, 0.9, 200],
    ["In a world where technology has advanced beyond our wildest dreams,", 0.9, 40, 0.95, 300],
    ["The key principles of quantum mechanics include", 0.7, 50, 0.9, 250],
    ["Once upon a time in a distant galaxy,", 1.0, 50, 0.95, 200],
    ["The most important factors in climate change are", 0.7, 50, 0.9, 200],
]

# Create Gradio interface
with gr.Blocks(title="i3-4096ctx Text Completion", theme=gr.themes.Soft()) as demo:
    gr.Markdown("""
    # 🚀 i3-4096ctx Language Model - Text Completion

    A hybrid RWKV-Attention pre-trained model with latent context compression, supporting up to 4096 tokens of context.

    **Note**: This is a pre-trained base model, not an instruction-tuned chat model. It performs **text completion** - give it a prompt and it will continue the text.
    """)
    with gr.Row():
        with gr.Column(scale=2):
            prompt_input = gr.Textbox(
                label="Prompt",
                placeholder="Enter your prompt here... The model will continue from where you leave off.",
                lines=5
            )
            output_text = gr.Textbox(
                label="Generated Text",
                lines=15,
                interactive=False
            )
            with gr.Row():
                generate_btn = gr.Button("Generate", variant="primary", scale=2)
                clear_btn = gr.Button("Clear", scale=1)
        with gr.Column(scale=1):
            gr.Markdown("### Generation Settings")
            temperature = gr.Slider(
                minimum=0.1,
                maximum=2.0,
                value=0.8,
                step=0.1,
                label="Temperature",
                info="Higher = more creative, random"
            )
            top_k = gr.Slider(
                minimum=1,
                maximum=100,
                value=50,
                step=1,
                label="Top-k",
                info="Sample from top k tokens"
            )
            top_p = gr.Slider(
                minimum=0.1,
                maximum=1.0,
                value=0.9,
                step=0.05,
                label="Top-p (nucleus)",
                info="Cumulative probability cutoff"
            )
            max_tokens = gr.Slider(
                minimum=50,
                maximum=500,
                value=200,
                step=10,
                label="Max new tokens",
                info="Maximum length to generate"
            )
            gr.Markdown("""
            ### Model Info
            - **Type**: Pre-trained base model
            - **Architecture**: Hybrid RWKV-Attention
            - **Context**: 4096 tokens (compressed)
            - **Kernel**: 512 tokens direct
            - **Compression**: 32 latent tokens/chunk

            ### Tips for Better Results
            - Start with a clear, specific prompt
            - Lower temperature (0.5-0.8) for factual text
            - Higher temperature (0.9-1.2) for creative writing
            - Adjust top-k and top-p for diversity control
            """)
    gr.Markdown("### Example Prompts")
    gr.Examples(
        examples=examples,
        inputs=[prompt_input, temperature, top_k, top_p, max_tokens],
        outputs=output_text,
        fn=generate_text,
        cache_examples=False
    )
    generate_btn.click(
        fn=generate_text,
        inputs=[prompt_input, temperature, top_k, top_p, max_tokens],
        outputs=output_text
    )
    clear_btn.click(
        fn=lambda: ("", ""),
        inputs=None,
        outputs=[prompt_input, output_text]
    )

# Launch
if __name__ == "__main__":
    demo.queue()
    demo.launch()