import os

import gradio as gr
import torch
import torch.nn as nn
import torch.nn.functional as F
from huggingface_hub import hf_hub_download
from tokenizers import Tokenizer

# ============================================================================
# 1. MODEL ARCHITECTURE (Must match training code exactly)
# ============================================================================


@torch.jit.script
def rwkv_linear_attention(B: int, T: int, C: int,
                          r: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
                          w: torch.Tensor, u: torch.Tensor,
                          state_init: torch.Tensor):
    """Run the sequential RWKV recurrence over the time axis.

    ``k`` and ``v`` are indexed as ``[:, t]`` and combined with ``(B, C)``
    state tensors, so they are expected to be ``(B, T, C)``. ``w`` and ``u``
    are added to the ``(B, C)`` exponent state and must broadcast against it.

    Args:
        B: Batch size (used to allocate the running state).
        T: Number of time steps to iterate over.
        C: Channel count (used to allocate the running state).
        r: Receptance tensor; accepted for signature compatibility but not
            read inside the recurrence (the caller applies it afterwards).
        k: Key tensor.
        v: Value tensor.
        w: Per-channel decay term added to the exponent state on update.
        u: Per-channel bonus term added to the exponent state on readout.
        state_init: Initial exponent state; cloned, never modified in place.

    Returns:
        A tensor shaped like ``v`` holding the weighted key/value average for
        every time step.
    """
    y = torch.zeros_like(v)
    # Running numerator / denominator of the weighted average (always float32).
    state_aa = torch.zeros(B, C, dtype=torch.float32, device=r.device)
    state_bb = torch.zeros(B, C, dtype=torch.float32, device=r.device)
    # Running exponent offset; subtracting the max keeps exp() from overflowing.
    state_pp = state_init.clone()

    for t in range(T):
        kt = k[:, t]
        vt = v[:, t]

        # Readout for step t: mix the accumulated state with the current token.
        ww = u + state_pp
        p = torch.maximum(ww, kt)
        e1 = torch.exp(ww - p)
        e2 = torch.exp(kt - p)
        wkv = (state_aa * e1 + vt * e2) / (state_bb * e1 + e2 + 1e-6)
        y[:, t] = wkv

        # State update for step t + 1, rescaled to the new max exponent.
        ww = w + state_pp
        p = torch.maximum(ww, kt)
        e1 = torch.exp(ww - p)
        e2 = torch.exp(kt - p)
        state_aa = state_aa * e1 + vt * e2
        state_bb = state_bb * e1 + e2
        state_pp = p

    return y


class RWKVTimeMix(nn.Module):
    """RWKV time-mixing layer: token-shift interpolation followed by the
    sequential recurrence in ``rwkv_linear_attention``."""

    def __init__(self, d_model):
        super().__init__()
        self.d_model = d_model
        self.time_decay = nn.Parameter(torch.ones(d_model))
        self.time_first = nn.Parameter(torch.ones(d_model))
        self.time_mix_k = nn.Parameter(torch.ones(1, 1, d_model))
        self.time_mix_v = nn.Parameter(torch.ones(1, 1, d_model))
        self.time_mix_r = nn.Parameter(torch.ones(1, 1, d_model))
        self.key = nn.Linear(d_model, d_model, bias=False)
        self.value = nn.Linear(d_model, d_model, bias=False)
        self.receptance = nn.Linear(d_model, d_model, bias=False)
        self.output = nn.Linear(d_model, d_model, bias=False)
        # Overwrites the ones() initialisation above with values in [-6, -3).
        self.time_decay.data.uniform_(-6, -3)

    def forward(self, x):
        """Apply time mixing to ``x`` of shape ``(B, T, C)``; returns the same shape."""
        B, T, C = x.size()
        # Token shift: position t sees x[t - 1]; position 0 sees zeros.
        xx = torch.cat([torch.zeros((B, 1, C), device=x.device), x[:, :-1]], dim=1)
        xk = x * self.time_mix_k + xx * (1 - self.time_mix_k)
        xv = x * self.time_mix_v + xx * (1 - self.time_mix_v)
        xr = x * self.time_mix_r + xx * (1 - self.time_mix_r)

        k = self.key(xk)
        v = self.value(xv)
        r = torch.sigmoid(self.receptance(xr))

        w = -torch.exp(self.time_decay)
        u = self.time_first
        # A very negative starting exponent makes the empty initial state
        # contribute (numerically) nothing to the first readout.
        state_init = torch.full((B, C), -1e30, dtype=torch.float32, device=x.device)

        rwkv = rwkv_linear_attention(B, T, C, r, k, v, w, u, state_init)
        return self.output(r * rwkv)


class RWKVChannelMix(nn.Module):
    """RWKV channel-mixing layer: token-shifted squared-ReLU feed-forward
    gated by a sigmoid receptance."""

    def __init__(self, d_model, ffn_mult=4):
        super().__init__()
        self.time_mix_k = nn.Parameter(torch.ones(1, 1, d_model))
        self.time_mix_r = nn.Parameter(torch.ones(1, 1, d_model))
        hidden_sz = d_model * ffn_mult
        self.key = nn.Linear(d_model, hidden_sz, bias=False)
        self.receptance = nn.Linear(d_model, d_model, bias=False)
        self.value = nn.Linear(hidden_sz, d_model, bias=False)

    def forward(self, x):
        """Apply channel mixing to ``x`` of shape ``(B, T, C)``; returns the same shape."""
        B, T, C = x.size()
        # Token shift: position t sees x[t - 1]; position 0 sees zeros.
        xx = torch.cat([torch.zeros((B, 1, C), device=x.device), x[:, :-1]], dim=1)
        xk = x * self.time_mix_k + xx * (1 - self.time_mix_k)
        xr = x * self.time_mix_r + xx * (1 - self.time_mix_r)

        k = torch.square(torch.relu(self.key(xk)))
        kv = self.value(k)
        r = torch.sigmoid(self.receptance(xr))
        return r * kv


class BiRWKVBlock(nn.Module):
    """Bidirectional RWKV block: a forward and a backward time-mix over the
    same normalised input, followed by a channel-mix, each with a residual."""

    def __init__(self, d_model, ffn_mult=4):
        super().__init__()
        self.ln1 = nn.LayerNorm(d_model)
        self.fwd_time_mix = RWKVTimeMix(d_model)
        self.bwd_time_mix = RWKVTimeMix(d_model)
        self.ln2 = nn.LayerNorm(d_model)
        self.channel_mix = RWKVChannelMix(d_model, ffn_mult)

    def forward(self, x, mask=None):
        """Run the block on ``x``.

        ``mask`` is accepted so every layer shares one call signature, but it
        is not used here: padded positions are not excluded from the
        recurrence.
        """
        x_norm = self.ln1(x)
        x_fwd = self.fwd_time_mix(x_norm)

        # Backward direction: reverse time, run the recurrence, reverse back.
        x_rev = torch.flip(x_norm, [1])
        x_bwd_rev = self.bwd_time_mix(x_rev)
        x_bwd = torch.flip(x_bwd_rev, [1])

        x = x + x_fwd + x_bwd
        x = x + self.channel_mix(self.ln2(x))
        return x


class FullAttention(nn.Module):
    """Multi-head scaled dot-product self-attention with a fused QKV projection."""

    def __init__(self, d_model, n_heads=16):
        super().__init__()
        if d_model % n_heads != 0:
            # Without this the .view() in forward() fails with an opaque shape error.
            raise ValueError(
                f"d_model ({d_model}) must be divisible by n_heads ({n_heads})"
            )
        self.d_model = d_model
        self.n_heads = n_heads
        self.head_dim = d_model // n_heads
        self.qkv = nn.Linear(d_model, d_model * 3)
        self.out_proj = nn.Linear(d_model, d_model)

    def forward(self, x, mask=None):
        """Attend over ``x`` of shape ``(B, T, C)``.

        Positions where ``mask == 0`` receive ``-inf`` scores before the
        softmax; ``mask`` must broadcast against the ``(B, heads, T, T)``
        score tensor.
        """
        B, T, C = x.shape
        qkv = self.qkv(x)
        q, k, v = qkv.chunk(3, dim=-1)

        # (B, T, C) -> (B, heads, T, head_dim)
        q = q.view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
        k = k.view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
        v = v.view(B, T, self.n_heads, self.head_dim).transpose(1, 2)

        attn = (q @ k.transpose(-2, -1)) / (self.head_dim ** 0.5)
        if mask is not None:
            attn = attn.masked_fill(mask == 0, float('-inf'))
        attn = F.softmax(attn, dim=-1)

        out = attn @ v
        out = out.transpose(1, 2).contiguous().view(B, T, C)
        return self.out_proj(out)


class StandardAttentionBlock(nn.Module):
    """Pre-LayerNorm transformer block: self-attention then a GELU feed-forward,
    each wrapped in a residual connection."""

    def __init__(self, d_model, n_heads=16, ffn_mult=4):
        super().__init__()
        self.ln1 = nn.LayerNorm(d_model)
        self.attn = FullAttention(d_model, n_heads)
        self.ln2 = nn.LayerNorm(d_model)
        self.ffn = nn.Sequential(
            nn.Linear(d_model, d_model * ffn_mult),
            nn.GELU(),
            nn.Linear(d_model * ffn_mult, d_model),
        )

    def forward(self, x, mask=None):
        """Run the block on ``x``; ``mask`` is forwarded to the attention layer."""
        x = x + self.attn(self.ln1(x), mask)
        x = x + self.ffn(self.ln2(x))
        return x


class HybridBertEmbeddings(nn.Module):
    """Sum of word, learned-position and token-type embeddings, followed by
    LayerNorm and dropout."""

    def __init__(self, vocab_size, d_model, max_len=512):
        super().__init__()
        self.word_embeddings = nn.Embedding(vocab_size, d_model)
        self.position_embeddings = nn.Embedding(max_len, d_model)
        self.token_type_embeddings = nn.Embedding(2, d_model)
        self.ln = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(0.1)

    def forward(self, input_ids, token_type_ids):
        """Embed ``input_ids``; the sequence length must not exceed ``max_len``
        because positions index the position-embedding table directly."""
        seq_len = input_ids.size(1)
        pos_ids = torch.arange(seq_len, device=input_ids.device).unsqueeze(0)
        embeddings = (self.word_embeddings(input_ids)
                      + self.position_embeddings(pos_ids)
                      + self.token_type_embeddings(token_type_ids))
        return self.dropout(self.ln(embeddings))


class HybridBertModel(nn.Module):
    """BERT-style encoder: a stack of bidirectional RWKV blocks followed by a
    stack of full-attention blocks, with a masked-LM head on top.

    Args:
        vocab_size: Size of the token vocabulary.
        d_model: Hidden width.
        n_rwkv_layers: Number of ``BiRWKVBlock`` layers (placed first).
        n_attn_layers: Number of ``StandardAttentionBlock`` layers (placed last).
        n_heads: Attention heads per attention block.
        max_len: Size of the position-embedding table.
        pad_token_id: Token id treated as padding when building the attention
            mask. Defaults to 1, the value previously hard-coded in
            ``forward``.
    """

    def __init__(self, vocab_size, d_model=768, n_rwkv_layers=6, n_attn_layers=6,
                 n_heads=12, max_len=512, pad_token_id=1):
        super().__init__()
        self.pad_token_id = pad_token_id
        self.embeddings = HybridBertEmbeddings(vocab_size, d_model, max_len)

        # Order matters: layer indices are part of the checkpoint's key names.
        self.layers = nn.ModuleList()
        for _ in range(n_rwkv_layers):
            self.layers.append(BiRWKVBlock(d_model, ffn_mult=4))
        for _ in range(n_attn_layers):
            self.layers.append(StandardAttentionBlock(d_model, n_heads=n_heads))

        self.mlm_head = nn.Sequential(
            nn.Linear(d_model, d_model),
            nn.GELU(),
            nn.LayerNorm(d_model),
            nn.Linear(d_model, vocab_size),
        )
        # Not used by forward(), but they must exist so that a checkpoint
        # containing these weights loads with strict key matching.
        self.pooler_dense = nn.Linear(d_model, d_model)
        self.nsp_head = nn.Linear(d_model, 2)

    def forward(self, input_ids, segment_ids):
        """Return masked-LM logits of shape ``(B, T, vocab_size)``."""
        # (B, T) -> (B, 1, 1, T) so it broadcasts over heads and query positions.
        mask = (input_ids != self.pad_token_id).unsqueeze(1).unsqueeze(2)
        x = self.embeddings(input_ids, segment_ids)
        for layer in self.layers:
            x = layer(x, mask)
        prediction_scores = self.mlm_head(x)
        return prediction_scores


# ============================================================================
# 2. INITIALIZATION
# ============================================================================
REPO_ID = "i3-lab/i3-BERT-v2"
MODEL_FILENAME = "i3-bert.pt"
TOKENIZER_FILENAME = "tokenizer_bert.json"

print("Downloading model and tokenizer from Hugging Face Hub...")
try:
    model_path = hf_hub_download(repo_id=REPO_ID, filename=MODEL_FILENAME)
    tokenizer_path = hf_hub_download(repo_id=REPO_ID, filename=TOKENIZER_FILENAME)
except Exception as e:
    print(f"Error downloading files: {e}")
    print("Ensure 'i3-bert.pt' and 'tokenizer_bert.json' exist in 'i3-lab/i3-BERT-v2'")
    raise  # bare raise keeps the original traceback intact

# Load Tokenizer
tokenizer = Tokenizer.from_file(tokenizer_path)
vocab_size = tokenizer.get_vocab_size()


def _resolve_special_token_id(tok, candidates, required=True):
    """Return the id of the first spelling in *candidates* that *tok* knows.

    Raises ``ValueError`` when none match and *required* is true; returns
    ``None`` when none match and *required* is false.
    """
    for spelling in candidates:
        token_id = tok.token_to_id(spelling)
        if token_id is not None:
            return token_id
    if required:
        raise ValueError(
            f"None of the special-token spellings {candidates!r} exist in "
            f"'{TOKENIZER_FILENAME}'; update the candidate list to match the tokenizer."
        )
    return None


# Special Token IDs (based on your training code)
# NOTE(review): the literals here had been reduced to empty strings (looks like
# angle-bracket tokens were stripped as markup), which made every id None.
# The exact spellings are not recoverable from this file, so the common
# variants are tried in order -- confirm against the training tokenizer.
CLS_ID = _resolve_special_token_id(tokenizer, ("<CLS>", "<cls>", "[CLS]", "<s>"))
SEP_ID = _resolve_special_token_id(tokenizer, ("<SEP>", "<sep>", "[SEP]", "</s>"))
MASK_ID = _resolve_special_token_id(tokenizer, ("<MASK>", "<mask>", "[MASK]"))
# PAD_ID is informational only: the model builds its mask from pad_token_id
# (default 1), so a missing pad token must not prevent start-up.
PAD_ID = _resolve_special_token_id(tokenizer, ("<PAD>", "<pad>", "[PAD]"), required=False)

# Load Model
# Config matching the training parameters provided
config = {
    "d_model": 768,
    "n_rwkv_layers": 4,
    "n_attn_layers": 4,
    "n_heads": 12,
    "seq_len": 128,
}

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model = HybridBertModel(
    vocab_size=vocab_size,
    d_model=config['d_model'],
    n_rwkv_layers=config['n_rwkv_layers'],
    n_attn_layers=config['n_attn_layers'],
    n_heads=config['n_heads'],
    max_len=config['seq_len'],
).to(device)

print("Loading state dict...")
# SECURITY NOTE: torch.load unpickles the downloaded file. If the checkpoint
# is a plain tensor state dict, pass weights_only=True (available in recent
# PyTorch releases) to avoid executing arbitrary pickled code.
state_dict = torch.load(model_path, map_location=device)
model.load_state_dict(state_dict)
model.eval()  # inference mode: disables the dropout in the embedding layer
print("Model loaded successfully!")

# ============================================================================
# 3. GRADIO INFERENCE FUNCTION
# ============================================================================

# Spelling of the mask token as the tokenizer knows it, derived from MASK_ID so
# that the text check, the messages and the UI examples can never disagree
# with the id actually searched for. None when no mask id is configured.
MASK_TOKEN = tokenizer.id_to_token(MASK_ID) if MASK_ID is not None else None
# Label used in UI copy; falls back to a conventional spelling when the
# tokenizer could not provide one.
_MASK_LABEL = MASK_TOKEN or "<MASK>"

# Number of candidate tokens reported for each masked position.
_TOP_K = 5


def predict_mask(text):
    """Predict the most likely tokens for every mask token in *text*.

    Args:
        text: Raw user input that should contain at least one mask token.

    Returns:
        One line per masked position listing the top candidates and their
        raw logit scores, or a human-readable message when the input is
        empty, contains no mask token, or the mask token was lost (for
        example through truncation to the model's sequence length).
    """
    if not text:
        return "Please enter text."

    if MASK_TOKEN is None:
        return "The tokenizer does not define a mask token, so nothing can be predicted."

    # Ensure the user provided a mask token. (The previous check compared
    # against an empty string, which is contained in every string and so
    # never triggered.)
    if MASK_TOKEN not in text:
        return f"Please include a {MASK_TOKEN} token in your text to predict."

    # Tokenize, then truncate so that CLS and SEP still fit in seq_len.
    ids = tokenizer.encode(text).ids
    ids = ids[:config['seq_len'] - 2]

    input_ids = [CLS_ID, *ids, SEP_ID]
    segment_ids = [0] * len(input_ids)  # Single sentence segment

    mask_indices = [pos for pos, token_id in enumerate(input_ids) if token_id == MASK_ID]
    if not mask_indices:
        return f"No {MASK_TOKEN} token found after tokenization."

    input_tensor = torch.tensor([input_ids], device=device)
    segment_tensor = torch.tensor([segment_ids], device=device)

    with torch.no_grad():
        logits = model(input_tensor, segment_tensor)

    # Never ask topk for more entries than the vocabulary holds.
    k = min(_TOP_K, logits.size(-1))

    results = []
    for idx in mask_indices:
        top_k = torch.topk(logits[0, idx, :], k)
        candidates = [
            f"{tokenizer.decode([token_id.item()])} ({score.item():.2f})"
            for score, token_id in zip(top_k.values, top_k.indices)
        ]
        results.append(f"Mask at pos {idx}: " + ", ".join(candidates))

    return "\n".join(results)


# ============================================================================
# 4. LAUNCH UI
# ============================================================================
with gr.Blocks() as demo:
    gr.Markdown("# i3-BERT: Hybrid RWKV + Attention Model")
    gr.Markdown("A custom 10M parameter model combining Bi-Directional RWKV and Attention layers.")
    gr.Markdown(f"Type a sentence with `{_MASK_LABEL}` to see predictions.")

    with gr.Row():
        inp = gr.Textbox(
            placeholder=f"The capital of France is {_MASK_LABEL}.",
            label="Input Text",
        )
        out = gr.Textbox(label="Predictions")

    btn = gr.Button("Predict")
    btn.click(fn=predict_mask, inputs=inp, outputs=out)

    # Every example must contain the mask token, otherwise predict_mask has
    # nothing to predict and only returns its "please include" message.
    examples = [
        [f"The quick brown fox jumps over the {_MASK_LABEL} dog."],
        [f"I want to eat a {_MASK_LABEL} for lunch."],
        [f"Python is a great programming {_MASK_LABEL}."],
    ]
    gr.Examples(examples, inp)

demo.launch()