# i3-4096ctx / app.py
# Page-scrape residue kept for provenance (not part of the program):
#   "FlameF0X's picture" / "Update app.py" / commit ca534e3 (verified)
import gradio as gr
import torch
import torch.nn as nn
import torch.nn.functional as F
from tokenizers import Tokenizer
from huggingface_hub import hf_hub_download
import json
import os
# ============================================================================
# MODEL ARCHITECTURE (from training code)
# ============================================================================
@torch.jit.script
def rwkv_linear_attention(B: int, T: int, C: int,
                          r: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
                          w: torch.Tensor, u: torch.Tensor,
                          state_init: torch.Tensor):
    """Run the WKV recurrence one time step at a time.

    A numerator/denominator pair is accumulated in float32 together with a
    running maximum exponent, so every ``exp`` is taken of a non-positive
    number.

    Args:
        B: Batch size; first dimension of the accumulators.
        T: Number of time steps to iterate over.
        C: Channel count; second dimension of the accumulators.
        r: Only its ``device`` is read here.
        k: Keys, indexed as ``k[:, t]``.
        v: Values, indexed as ``v[:, t]``; the result has its shape and dtype.
        w: Term added to the running exponent when the state is updated.
        u: Term added to the running exponent when the output is computed.
        state_init: Initial running exponent; it is cloned, not modified.

    Returns:
        A tensor shaped like ``v`` holding the WKV value of every step.

    NOTE(review): ``u`` is added to the running exponent rather than to the
    current key, which differs from the reference RWKV-4 recurrence. The
    section header says this code comes from training, so presumably the
    checkpoint depends on it -- confirm before changing.
    """
    out = torch.zeros_like(v)
    num = torch.zeros([B, C], dtype=torch.float32, device=r.device)
    den = torch.zeros([B, C], dtype=torch.float32, device=r.device)
    max_exp = state_init.clone()
    for step in range(T):
        k_t = k[:, step]
        v_t = v[:, step]

        # Output for this step: mix the accumulated state with the current token.
        shifted = u + max_exp
        pivot = torch.maximum(shifted, k_t)
        past_scale = torch.exp(shifted - pivot)
        cur_scale = torch.exp(k_t - pivot)
        out[:, step] = (num * past_scale + v_t * cur_scale) / (den * past_scale + cur_scale + 1e-6)

        # State update: fold the current token into the accumulators.
        shifted = w + max_exp
        pivot = torch.maximum(shifted, k_t)
        past_scale = torch.exp(shifted - pivot)
        cur_scale = torch.exp(k_t - pivot)
        num = num * past_scale + v_t * cur_scale
        den = den * past_scale + cur_scale
        max_exp = pivot
    return out
class RWKVTimeMix(nn.Module):
    """RWKV time-mixing layer.

    Each token is blended with its predecessor, projected to key, value and
    receptance, passed through ``rwkv_linear_attention`` and gated by the
    sigmoid receptance before the output projection.

    Args:
        d_model: Channel width of the input and output.
    """

    def __init__(self, d_model):
        super().__init__()
        self.d_model = d_model
        self.time_decay = nn.Parameter(torch.ones(d_model))
        self.time_first = nn.Parameter(torch.ones(d_model))
        # Per-channel interpolation weights between the current and previous token.
        self.time_mix_k = nn.Parameter(torch.ones(1, 1, d_model))
        self.time_mix_v = nn.Parameter(torch.ones(1, 1, d_model))
        self.time_mix_r = nn.Parameter(torch.ones(1, 1, d_model))
        self.key = nn.Linear(d_model, d_model, bias=False)
        self.value = nn.Linear(d_model, d_model, bias=False)
        self.receptance = nn.Linear(d_model, d_model, bias=False)
        self.output = nn.Linear(d_model, d_model, bias=False)
        # Overwrites the ones() initialisation above.
        self.time_decay.data.uniform_(-6, -3)

    def forward(self, x):
        """Apply time mixing to ``x`` of shape (B, T, C); returns the same shape."""
        B, T, C = x.size()
        # Previous-token view of x (zeros at t=0). new_zeros keeps the dtype and
        # device of x, so the projections below are not fed a promoted float32
        # tensor when the model runs in reduced precision.
        xx = torch.cat([x.new_zeros((B, 1, C)), x[:, :-1]], dim=1)
        xk = x * self.time_mix_k + xx * (1 - self.time_mix_k)
        xv = x * self.time_mix_v + xx * (1 - self.time_mix_v)
        xr = x * self.time_mix_r + xx * (1 - self.time_mix_r)
        k = self.key(xk)
        v = self.value(xv)
        r = torch.sigmoid(self.receptance(xr))
        # Negated exponential keeps the decay term strictly negative.
        w = -torch.exp(self.time_decay)
        u = self.time_first
        # Very negative initial exponent: the empty state gets ~zero weight.
        state_init = torch.full((B, C), -1e30, dtype=torch.float32, device=x.device)
        rwkv = rwkv_linear_attention(B, T, C, r, k, v, w, u, state_init)
        return self.output(r * rwkv)
class RWKVChannelMix(nn.Module):
    """RWKV channel-mixing (feed-forward) layer.

    Each token is blended with its predecessor, expanded through a
    squared-ReLU hidden layer and gated by a sigmoid receptance.

    Args:
        d_model: Channel width of the input and output.
        ffn_mult: Hidden width as a multiple of ``d_model``.
    """

    def __init__(self, d_model, ffn_mult=4):
        super().__init__()
        # Per-channel interpolation weights between the current and previous token.
        self.time_mix_k = nn.Parameter(torch.ones(1, 1, d_model))
        self.time_mix_r = nn.Parameter(torch.ones(1, 1, d_model))
        hidden_sz = d_model * ffn_mult
        self.key = nn.Linear(d_model, hidden_sz, bias=False)
        self.receptance = nn.Linear(d_model, d_model, bias=False)
        self.value = nn.Linear(hidden_sz, d_model, bias=False)

    def forward(self, x):
        """Apply channel mixing to ``x`` of shape (B, T, C); returns the same shape."""
        B, T, C = x.size()
        # Previous-token view of x (zeros at t=0). new_zeros keeps the dtype and
        # device of x; a default float32 zero row would promote the whole
        # activation and break the projections under half/bfloat16 weights.
        xx = torch.cat([x.new_zeros((B, 1, C)), x[:, :-1]], dim=1)
        xk = x * self.time_mix_k + xx * (1 - self.time_mix_k)
        xr = x * self.time_mix_r + xx * (1 - self.time_mix_r)
        hidden = torch.square(torch.relu(self.key(xk)))
        gate = torch.sigmoid(self.receptance(xr))
        return gate * self.value(hidden)
class RWKVBlock(nn.Module):
    """Pre-norm residual block: RWKV time mixing followed by channel mixing.

    Args:
        d_model: Channel width.
        ffn_mult: Hidden-width multiplier passed to the channel-mix layer.
    """

    def __init__(self, d_model, ffn_mult=4):
        super().__init__()
        self.ln1 = nn.LayerNorm(d_model)
        self.att = RWKVTimeMix(d_model)
        self.ln2 = nn.LayerNorm(d_model)
        self.ffn = RWKVChannelMix(d_model, ffn_mult)

    def forward(self, x, mask=None):
        """Return ``x`` after both residual sub-layers.

        ``mask`` is accepted so this block can be called like the attention
        blocks, but it is not used.
        """
        after_time_mix = x + self.att(self.ln1(x))
        return after_time_mix + self.ffn(self.ln2(after_time_mix))
class FullAttention(nn.Module):
    """Multi-head scaled dot-product self-attention with an optional mask.

    Args:
        d_model: Channel width of the input and output.
        n_heads: Number of attention heads; must divide ``d_model``.

    Raises:
        ValueError: If ``n_heads`` is not a positive divisor of ``d_model``.
    """

    def __init__(self, d_model, n_heads=16):
        super().__init__()
        # Fail at construction instead of with an opaque view() shape error
        # on the first forward pass.
        if n_heads <= 0 or d_model % n_heads != 0:
            raise ValueError(
                f"d_model ({d_model}) must be divisible by n_heads ({n_heads})"
            )
        self.d_model = d_model
        self.n_heads = n_heads
        self.head_dim = d_model // n_heads
        self.qkv = nn.Linear(d_model, d_model * 3)
        self.out_proj = nn.Linear(d_model, d_model)

    def forward(self, x, mask=None):
        """Attend over ``x`` of shape (B, T, C).

        Args:
            x: Input activations, shape (B, T, C).
            mask: Optional tensor broadcastable to (B, n_heads, T, T);
                positions where it equals 0 are excluded from attention.

        Returns:
            Tensor of shape (B, T, C).
        """
        B, T, C = x.shape
        q, k, v = self.qkv(x).chunk(3, dim=-1)
        # (B, T, C) -> (B, n_heads, T, head_dim)
        q = q.view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
        k = k.view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
        v = v.view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
        scores = (q @ k.transpose(-2, -1)) / (self.head_dim ** 0.5)
        if mask is not None:
            mask = mask.to(x.device)
            scores = scores.masked_fill(mask == 0, float('-inf'))
        weights = F.softmax(scores, dim=-1)
        out = weights @ v
        # Merge the heads back into the channel dimension.
        out = out.transpose(1, 2).contiguous().view(B, T, C)
        return self.out_proj(out)
class StandardAttentionBlock(nn.Module):
    """Pre-norm transformer block: self-attention followed by a GELU MLP.

    Args:
        d_model: Channel width.
        n_heads: Requested head count. If it does not divide ``d_model``, the
            largest candidate from (16, 12, 10, 8, 6, 4, 2, 1) that does is
            used instead.
        ffn_mult: Hidden-width multiplier of the MLP.
    """

    def __init__(self, d_model, n_heads=16, ffn_mult=4):
        super().__init__()
        if d_model % n_heads != 0:
            # 1 is included (as in LatentContextCompressor) so a valid head
            # count is always found, even for widths with no even divisor.
            n_heads = next(h for h in (16, 12, 10, 8, 6, 4, 2, 1) if d_model % h == 0)
        self.ln1 = nn.LayerNorm(d_model)
        self.attn = FullAttention(d_model, n_heads)
        self.ln2 = nn.LayerNorm(d_model)
        self.ffn = nn.Sequential(
            nn.Linear(d_model, d_model * ffn_mult),
            nn.GELU(),
            nn.Linear(d_model * ffn_mult, d_model),
        )

    def forward(self, x, mask=None):
        """Return ``x`` after the attention and MLP residual sub-layers.

        ``mask`` is forwarded unchanged to the attention layer.
        """
        x = x + self.attn(self.ln1(x), mask)
        x = x + self.ffn(self.ln2(x))
        return x
class LatentContextCompressor(nn.Module):
    """Summarise a token sequence into a fixed number of latent tokens.

    A learned set of query vectors cross-attends over the input sequence; the
    result passes through a residual feed-forward layer.

    Args:
        d_model: Channel width.
        compression_ratio: Stored on the module; not used in the computation.
        num_latent_tokens: Number of latent tokens produced per call.
        n_heads: Attention head count. When ``None``, the first value of
            (16, 12, 10, 8, 6, 4, 2, 1) that divides ``d_model`` is used.
    """

    def __init__(self, d_model, compression_ratio=4, num_latent_tokens=32, n_heads=None):
        super().__init__()
        self.d_model = d_model
        self.compression_ratio = compression_ratio
        self.num_latent_tokens = num_latent_tokens
        if n_heads is None:
            # 1 always divides d_model, so the search cannot come up empty.
            n_heads = next(h for h in (16, 12, 10, 8, 6, 4, 2, 1) if d_model % h == 0)
        self.n_heads = n_heads
        self.latent_queries = nn.Parameter(torch.randn(1, num_latent_tokens, d_model))
        self.compress_attn = nn.MultiheadAttention(
            embed_dim=d_model, num_heads=n_heads, batch_first=True
        )
        self.ln1 = nn.LayerNorm(d_model)
        self.ln2 = nn.LayerNorm(d_model)
        self.ffn = nn.Sequential(
            nn.Linear(d_model, d_model * 2),
            nn.GELU(),
            nn.Linear(d_model * 2, d_model),
        )

    def forward(self, x):
        """Compress ``x`` (B, T, C) into (B, num_latent_tokens, C)."""
        batch = x.shape[0]
        latents = self.latent_queries.expand(batch, -1, -1)
        attended, _ = self.compress_attn(
            query=self.ln1(latents), key=x, value=x, need_weights=False
        )
        # The residual path uses the un-normalised queries.
        summary = latents + attended
        return summary + self.ffn(self.ln2(summary))
class i3HybridModelWithCompression(nn.Module):
    """Hybrid language model: RWKV blocks, then attention blocks, plus a latent compressor.

    Tokens are embedded with learned position embeddings, passed through
    ``n_rwkv_layers`` RWKV blocks and ``n_attn_layers`` attention blocks, and
    projected to vocabulary logits. When compression is enabled, the final
    hidden states of the current window are also summarised into latent
    tokens that can be fed back through ``compressed_history``.

    Args:
        vocab_size: Size of the token vocabulary.
        d_model: Channel width.
        n_heads: Requested attention head count; replaced by the largest of
            (16, 12, 10, 8, 6, 4, 2, 1) dividing ``d_model`` when it does not
            divide ``d_model`` itself.
        n_rwkv_layers: Number of RWKV blocks.
        n_attn_layers: Number of attention blocks.
        kernel_size: Maximum number of raw tokens processed per forward pass.
        max_latent_context: Used to size the position table and the cap on
            retained compressed chunks.
        compression_ratio: Passed through to the compressor.
        num_latent_tokens: Latent tokens produced per compression call.
        enable_compression: Whether to build and use the compressor.
    """

    def __init__(self, vocab_size, d_model=1024, n_heads=16, n_rwkv_layers=10,
                 n_attn_layers=6, kernel_size=512, max_latent_context=2048,
                 compression_ratio=4, num_latent_tokens=32, enable_compression=True):
        super().__init__()
        self.vocab_size = vocab_size
        self.d_model = d_model
        self.kernel_size = kernel_size
        self.max_latent_context = max_latent_context
        self.enable_compression = enable_compression
        self.num_latent_tokens = num_latent_tokens
        if d_model % n_heads != 0:
            # 1 is included (as in LatentContextCompressor) so the fallback
            # always yields a head count that divides d_model.
            n_heads = next(h for h in (16, 12, 10, 8, 6, 4, 2, 1) if d_model % h == 0)
        self.n_heads = n_heads
        self.max_compressed_chunks = max_latent_context // kernel_size
        self.embed = nn.Embedding(vocab_size, d_model)
        self.pos_embed = nn.Embedding(max(kernel_size, max_latent_context), d_model)
        if enable_compression:
            self.compressor = LatentContextCompressor(
                d_model=d_model, compression_ratio=compression_ratio,
                num_latent_tokens=num_latent_tokens, n_heads=n_heads
            )
        self.layers = nn.ModuleList()
        for _ in range(n_rwkv_layers):
            self.layers.append(RWKVBlock(d_model, ffn_mult=4))
        for _ in range(n_attn_layers):
            self.layers.append(StandardAttentionBlock(d_model, n_heads=n_heads))
        self.ln_f = nn.LayerNorm(d_model)
        self.head = nn.Linear(d_model, vocab_size)
        self.apply(self._init_weights)

    def _init_weights(self, module):
        """Normal(0, 0.02) init for Linear/Embedding weights; zero Linear biases."""
        if isinstance(module, (nn.Linear, nn.Embedding)):
            module.weight.data.normal_(mean=0.0, std=0.02)
        if isinstance(module, nn.Linear) and module.bias is not None:
            module.bias.data.zero_()

    def forward(self, idx, targets=None, compressed_history=None):
        """Run the model on a window of token ids.

        Args:
            idx: Token ids of shape (B, T). Only the last ``kernel_size``
                tokens are used when T is larger.
            targets: Optional target ids aligned with ``idx``; truncated the
                same way and used for a cross-entropy loss.
            compressed_history: Optional latent tokens (B, H, d_model) that are
                prepended to the embedded window. Ignored when compression is
                disabled.

        Returns:
            ``(logits, loss, new_compressed)``: logits for the raw-token
            positions only, the loss or ``None``, and the latent summary of
            the current window or ``None`` when compression is disabled.
        """
        _, T = idx.shape
        if T > self.kernel_size:
            idx = idx[:, -self.kernel_size:]
            if targets is not None:
                targets = targets[:, -self.kernel_size:]
            T = self.kernel_size

        use_history = self.enable_compression and compressed_history is not None
        history_len = compressed_history.size(1) if use_history else 0

        # Raw tokens take the positions right after the history slots. The
        # embedding is computed once here, not once without the offset and
        # then again with it.
        pos = torch.arange(
            history_len, history_len + T, dtype=torch.long, device=idx.device
        ).unsqueeze(0)
        x = self.embed(idx) + self.pos_embed(pos)
        if use_history:
            x = torch.cat([compressed_history, x], dim=1)

        total_len = history_len + T
        # Lower-triangular (causal) mask over history + window positions.
        mask = torch.tril(
            torch.ones(total_len, total_len, device=idx.device)
        ).view(1, 1, total_len, total_len)
        for layer in self.layers:
            x = layer(x, mask)
        x = self.ln_f(x)

        # Drop the history positions: only raw tokens get logits / compression.
        current_tokens = x[:, history_len:]
        logits = self.head(current_tokens)

        loss = None
        if targets is not None:
            loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)), targets.reshape(-1))

        new_compressed = None
        if self.enable_compression:
            new_compressed = self.compressor(current_tokens)
        return logits, loss, new_compressed

    @torch.no_grad()
    def generate_stream(self, idx, max_new_tokens, temperature=1.0, top_k=50, top_p=0.9):
        """Generator that yields sampled token ids one at a time for streaming.

        Args:
            idx: Prompt token ids of shape (1, T). ``.item()`` is called on
                each sampled token, so only batch size 1 is supported.
            max_new_tokens: Number of tokens to sample (coerced to ``int``).
            temperature: Divisor applied to the logits before sampling.
            top_k: Keep only the ``top_k`` highest logits (coerced to ``int``);
                ``None`` disables the filter.
            top_p: Nucleus threshold; values below 1.0 enable the filter.

        Yields:
            The sampled token id as a Python int.
        """
        compressed_history = None
        max_history_tokens = self.max_compressed_chunks * self.num_latent_tokens
        # UI widgets may hand over floats; range() and topk() need ints.
        for _step in range(int(max_new_tokens)):
            idx_cond = idx if idx.size(1) <= self.kernel_size else idx[:, -self.kernel_size:]
            logits, _loss, new_compressed = self(idx_cond, compressed_history=compressed_history)
            logits = logits[:, -1, :] / temperature

            if top_k is not None:
                kth_values, _ = torch.topk(logits, min(int(top_k), logits.size(-1)))
                logits[logits < kth_values[:, [-1]]] = -float('Inf')
            probs = F.softmax(logits, dim=-1)

            if top_p < 1.0:
                sorted_probs, sorted_indices = torch.sort(probs, descending=True)
                cumulative_probs = torch.cumsum(sorted_probs, dim=-1)
                sorted_indices_to_remove = cumulative_probs > top_p
                # Shift right so the token that crosses the threshold is kept,
                # and always keep the most likely token.
                sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
                sorted_indices_to_remove[..., 0] = 0
                indices_to_remove = sorted_indices_to_remove.scatter(
                    1, sorted_indices, sorted_indices_to_remove
                )
                probs[indices_to_remove] = 0
                probs = probs / probs.sum(dim=-1, keepdim=True)

            idx_next = torch.multinomial(probs, num_samples=1)
            idx = torch.cat((idx, idx_next), dim=1)

            # The latent summary of this step's window is appended after every
            # sampled token; only the most recent max_history_tokens are kept.
            if self.enable_compression and new_compressed is not None:
                if compressed_history is None:
                    compressed_history = new_compressed
                else:
                    compressed_history = torch.cat([compressed_history, new_compressed], dim=1)
                if compressed_history.size(1) > max_history_tokens:
                    compressed_history = compressed_history[:, -max_history_tokens:]
            yield idx_next.item()
# ============================================================================
# MODEL LOADING
# ============================================================================
class ModelLoader:
    """Download the checkpoint files from the Hugging Face Hub and build the model.

    Args:
        repo_id: Hub repository holding ``config.json``, ``pytorch_model.bin``
            and ``tokenizer.json``.
    """

    def __init__(self, repo_id="i3-lab/i3-4096ctx"):
        self.repo_id = repo_id
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.model = None
        self.tokenizer = None

    def load(self):
        """Fetch config, weights and tokenizer, then return ``(model, tokenizer)``.

        The model is moved to ``self.device`` and put in eval mode. Both
        objects are also stored on the loader instance.
        """
        print(f"Loading model from {self.repo_id}...")
        # Download files
        config_path = hf_hub_download(repo_id=self.repo_id, filename="config.json")
        model_path = hf_hub_download(repo_id=self.repo_id, filename="pytorch_model.bin")
        tokenizer_path = hf_hub_download(repo_id=self.repo_id, filename="tokenizer.json")
        # Load config. JSON is UTF-8 by specification; do not depend on the
        # platform's default locale encoding.
        with open(config_path, 'r', encoding='utf-8') as f:
            config = json.load(f)
        # Load tokenizer
        self.tokenizer = Tokenizer.from_file(tokenizer_path)
        # Create model
        # NOTE(review): n_heads and num_latent_tokens are hard-coded instead of
        # read from the config -- confirm they match the trained checkpoint.
        self.model = i3HybridModelWithCompression(
            vocab_size=config['vocab_size'],
            d_model=config['d_model'],
            n_heads=8,  # Adjust based on your config
            n_rwkv_layers=config.get('rwkv_layers', 12),
            n_attn_layers=config.get('attn_layers', 2),
            kernel_size=config.get('kernel_size', 512),
            max_latent_context=config.get('max_latent_context', 4096),
            num_latent_tokens=32,
            enable_compression=config.get('compression_enabled', True)
        )
        # Load weights
        # NOTE(security): torch.load unpickles the file, which can execute
        # arbitrary code. Only point repo_id at a trusted repository, or pass
        # weights_only=True on PyTorch versions that support it.
        state_dict = torch.load(model_path, map_location=self.device)
        self.model.load_state_dict(state_dict)
        self.model.to(self.device)
        self.model.eval()
        print(f"Model loaded successfully on {self.device}")
        return self.model, self.tokenizer
# ============================================================================
# GRADIO INTERFACE
# ============================================================================
# Initialize model
# This runs at import time: load() calls hf_hub_download for the config,
# weights and tokenizer, so importing this module triggers those fetches and
# builds the full model before any UI code runs.
# `loader`, `model` and `tokenizer` are module-level globals used by
# generate_text below.
loader = ModelLoader()
model, tokenizer = loader.load()
def generate_text(prompt, temperature, top_k, top_p, max_tokens):
    """Generate a text completion, streaming the growing result.

    Args:
        prompt: Text to continue.
        temperature: Sampling temperature passed to the model.
        top_k: Top-k cutoff; coerced to ``int`` (``None`` is passed through).
        top_p: Nucleus cutoff passed to the model.
        max_tokens: Maximum number of new tokens; coerced to ``int``.

    Yields:
        The prompt followed by everything generated so far, once per new
        token. If the prompt encodes to no tokens, the prompt is yielded
        unchanged and generation is skipped.
    """
    # Encode the prompt
    input_ids = tokenizer.encode(prompt).ids
    if not input_ids:
        # generate_stream indexes the last position of the model output, so an
        # empty sequence would raise; there is nothing to continue from.
        yield prompt
        return
    input_tensor = torch.tensor([input_ids], dtype=torch.long, device=loader.device)

    generated_ids = []
    # Generate with streaming
    for token_id in model.generate_stream(
        input_tensor,
        max_new_tokens=int(max_tokens),
        temperature=temperature,
        top_k=None if top_k is None else int(top_k),
        top_p=top_p,
    ):
        generated_ids.append(token_id)
        # Decode the whole continuation each time, not token by token: a
        # character or space can span several tokens, and decoding the pieces
        # in isolation can yield broken characters or lost spacing.
        yield prompt + tokenizer.decode(generated_ids)
# Example prompts
# Each row is [prompt, temperature, top_k, top_p, max_tokens], in the same
# order as the `inputs` list given to gr.Examples below.
examples = [
["The future of artificial intelligence is", 0.8, 50, 0.9, 200],
["In a world where technology has advanced beyond our wildest dreams,", 0.9, 40, 0.95, 300],
["The key principles of quantum mechanics include", 0.7, 50, 0.9, 250],
["Once upon a time in a distant galaxy,", 1.0, 50, 0.95, 200],
["The most important factors in climate change are", 0.7, 50, 0.9, 200],
]
# Build the Gradio UI: prompt and output on the left, sampling controls on the right.
with gr.Blocks(title="i3-4096ctx Text Completion", theme=gr.themes.Soft()) as demo:
    gr.Markdown("""
# 🚀 i3-4096ctx Language Model - Text Completion
A hybrid RWKV-Attention pre-trained model with latent context compression, supporting up to 4096 tokens of context.
**Note**: This is a pre-trained base model, not an instruction-tuned chat model. It performs **text completion** - give it a prompt and it will continue the text.
""")

    with gr.Row():
        # Left column: prompt, generated text and action buttons.
        with gr.Column(scale=2):
            prompt_input = gr.Textbox(
                lines=5,
                label="Prompt",
                placeholder="Enter your prompt here... The model will continue from where you leave off.",
            )
            output_text = gr.Textbox(
                lines=15,
                label="Generated Text",
                interactive=False,
            )
            with gr.Row():
                generate_btn = gr.Button("Generate", variant="primary", scale=2)
                clear_btn = gr.Button("Clear", scale=1)

        # Right column: sampling parameters and static help text.
        with gr.Column(scale=1):
            gr.Markdown("### Generation Settings")
            temperature = gr.Slider(
                label="Temperature",
                info="Higher = more creative, random",
                minimum=0.1,
                maximum=2.0,
                step=0.1,
                value=0.8,
            )
            top_k = gr.Slider(
                label="Top-k",
                info="Sample from top k tokens",
                minimum=1,
                maximum=100,
                step=1,
                value=50,
            )
            top_p = gr.Slider(
                label="Top-p (nucleus)",
                info="Cumulative probability cutoff",
                minimum=0.1,
                maximum=1.0,
                step=0.05,
                value=0.9,
            )
            max_tokens = gr.Slider(
                label="Max new tokens",
                info="Maximum length to generate",
                minimum=50,
                maximum=500,
                step=10,
                value=200,
            )
            gr.Markdown("""
### Model Info
- **Type**: Pre-trained base model
- **Architecture**: Hybrid RWKV-Attention
- **Context**: 4096 tokens (compressed)
- **Kernel**: 512 tokens direct
- **Compression**: 32 latent tokens/chunk
### Tips for Better Results
- Start with a clear, specific prompt
- Lower temperature (0.5-0.8) for factual text
- Higher temperature (0.9-1.2) for creative writing
- Adjust top-k and top-p for diversity control
""")

    # The examples table and the Generate button feed the same components,
    # in the positional order generate_text expects.
    generation_inputs = [prompt_input, temperature, top_k, top_p, max_tokens]

    def _reset_fields():
        """Return empty strings for the prompt and output boxes."""
        return "", ""

    gr.Markdown("### Example Prompts")
    gr.Examples(
        examples=examples,
        inputs=generation_inputs,
        outputs=output_text,
        fn=generate_text,
        cache_examples=False,
    )
    generate_btn.click(
        fn=generate_text,
        inputs=generation_inputs,
        outputs=output_text,
    )
    clear_btn.click(
        fn=_reset_fields,
        inputs=None,
        outputs=[prompt_input, output_text],
    )
# Start the app only when run as a script. The request queue is enabled first;
# queue() returns the Blocks instance, so the calls can be chained.
if __name__ == "__main__":
    demo.queue().launch()