# i3-BERT / app.py
# Source: Hugging Face Space by FlameF0X, commit a8a99c3 ("Update app.py", verified).
import torch
import torch.nn as nn
import torch.nn.functional as F
import gradio as gr
from huggingface_hub import hf_hub_download
from tokenizers import Tokenizer
import os
# ============================================================================
# 1. MODEL ARCHITECTURE (Must match training code exactly)
# ============================================================================
@torch.jit.script
def rwkv_linear_attention(B: int, T: int, C: int,
                          r: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
                          w: torch.Tensor, u: torch.Tensor,
                          state_init: torch.Tensor) -> torch.Tensor:
    """Run the RWKV "WKV" recurrence sequentially over the time axis.

    For every step t it combines the current key/value with an exponentially
    decayed summary of steps < t. It keeps a running log-scale maximum
    (``pp``) and subtracts it before every ``exp`` so the exponentials stay
    in range.

    NOTE(review): ``u`` is added to the *past* state (``u + state_pp``) rather
    than to the current key, which differs from the usual RWKV-4 formulation
    (``time_first + k``). The checkpoint was trained with this exact code, so
    do not "fix" it without retraining.

    Args:
        B, T, C: batch size, number of time steps and channel count; the
            state tensors are allocated as (B, C) and the loop runs T steps.
        r: receptance. Only its device is used here; the caller applies the
            receptance gate to the result.
        k, v: keys and values, indexed as ``[:, t]``. Presumably (B, T, C), as
            produced by RWKVTimeMix — TODO confirm for any other caller.
        w: per-channel decay added to the log-scale state when advancing a
            step (RWKVTimeMix passes ``-exp(time_decay)``).
        u: per-channel bonus (RWKVTimeMix passes ``time_first``).
        state_init: initial value of ``pp``, shape (B, C). RWKVTimeMix fills it
            with -1e30 so the empty history contributes ~0.

    Returns:
        Tensor with the same shape and dtype as ``v`` holding the output of
        every time step.
    """
    y = torch.zeros_like(v)
    # Running numerator (aa) and denominator (bb) of the weighted average,
    # kept in float32 and scaled by exp(-pp) to avoid overflow.
    state_aa = torch.zeros(B, C, dtype=torch.float32, device=r.device)
    state_bb = torch.zeros(B, C, dtype=torch.float32, device=r.device)
    # pp: running log-scale maximum shared by aa and bb.
    state_pp = state_init.clone()
    for t in range(T):
        rt, kt, vt = r[:, t], k[:, t], v[:, t]  # rt is not used below
        # Output for step t: history (offset by u) blended with the current token.
        ww = u + state_pp
        p = torch.maximum(ww, kt)
        e1 = torch.exp(ww - p)
        e2 = torch.exp(kt - p)
        # 1e-6 keeps the denominator away from zero.
        wkv = (state_aa * e1 + vt * e2) / (state_bb * e1 + e2 + 1e-6)
        y[:, t] = wkv
        # Advance the state: decay the history by w, then fold in the current token.
        ww = w + state_pp
        p = torch.maximum(ww, kt)
        e1 = torch.exp(ww - p)
        e2 = torch.exp(kt - p)
        state_aa = state_aa * e1 + vt * e2
        state_bb = state_bb * e1 + e2
        state_pp = p
    return y
class RWKVTimeMix(nn.Module):
    """Unidirectional RWKV time-mixing layer (the attention replacement).

    Each position is interpolated with the previous position ("token shift"),
    projected to key / value / receptance, passed through
    ``rwkv_linear_attention``, gated by the sigmoid receptance and projected
    back. Parameter names and shapes must match the training checkpoint.
    """

    def __init__(self, d_model):
        super().__init__()
        self.d_model = d_model
        # Per-channel decay (transformed to -exp(time_decay) in forward) and
        # per-channel bonus handed to the WKV kernel.
        self.time_decay = nn.Parameter(torch.ones(d_model))
        self.time_first = nn.Parameter(torch.ones(d_model))
        # Interpolation weights between the current and the shifted token.
        self.time_mix_k = nn.Parameter(torch.ones(1, 1, d_model))
        self.time_mix_v = nn.Parameter(torch.ones(1, 1, d_model))
        self.time_mix_r = nn.Parameter(torch.ones(1, 1, d_model))
        self.key = nn.Linear(d_model, d_model, bias=False)
        self.value = nn.Linear(d_model, d_model, bias=False)
        self.receptance = nn.Linear(d_model, d_model, bias=False)
        self.output = nn.Linear(d_model, d_model, bias=False)
        # Random initial decay; in this app the loaded checkpoint overwrites it.
        self.time_decay.data.uniform_(-6, -3)

    def forward(self, x):
        """Apply time mixing to ``x`` of shape (B, T, C); returns (B, T, C)."""
        B, T, C = x.size()
        # Token shift: position t sees position t-1 (zeros at t = 0). The
        # padding takes x's dtype and device so non-float32 inputs (half,
        # bfloat16, float64) stay in one dtype for the Linear layers.
        xx = torch.cat([x.new_zeros((B, 1, C)), x[:, :-1]], dim=1)
        xk = x * self.time_mix_k + xx * (1 - self.time_mix_k)
        xv = x * self.time_mix_v + xx * (1 - self.time_mix_v)
        xr = x * self.time_mix_r + xx * (1 - self.time_mix_r)
        k = self.key(xk)
        v = self.value(xv)
        r = torch.sigmoid(self.receptance(xr))
        w = -torch.exp(self.time_decay)
        u = self.time_first
        # The WKV kernel keeps its running state in float32; -1e30 means
        # "no history yet".
        state_init = torch.full((B, C), -1e30, dtype=torch.float32, device=x.device)
        rwkv = rwkv_linear_attention(B, T, C, r, k, v, w, u, state_init)
        return self.output(r * rwkv)
class RWKVChannelMix(nn.Module):
    """RWKV channel-mixing (feed-forward) layer.

    Token-shifted input goes through a squared-ReLU MLP whose output is gated
    by a sigmoid receptance: ``sigmoid(R(xr)) * V(relu(K(xk)) ** 2)``.
    Parameter names and shapes must match the training checkpoint.
    """

    def __init__(self, d_model, ffn_mult=4):
        super().__init__()
        # Interpolation weights between the current and the shifted token.
        self.time_mix_k = nn.Parameter(torch.ones(1, 1, d_model))
        self.time_mix_r = nn.Parameter(torch.ones(1, 1, d_model))
        hidden_sz = d_model * ffn_mult
        self.key = nn.Linear(d_model, hidden_sz, bias=False)
        self.receptance = nn.Linear(d_model, d_model, bias=False)
        self.value = nn.Linear(hidden_sz, d_model, bias=False)

    def forward(self, x):
        """Apply channel mixing to ``x`` of shape (B, T, C); returns (B, T, C)."""
        B, T, C = x.size()
        # Token shift with padding in x's dtype/device, so non-float32 inputs
        # do not get a float32 row mixed in.
        xx = torch.cat([x.new_zeros((B, 1, C)), x[:, :-1]], dim=1)
        xk = x * self.time_mix_k + xx * (1 - self.time_mix_k)
        xr = x * self.time_mix_r + xx * (1 - self.time_mix_r)
        k = torch.square(torch.relu(self.key(xk)))
        kv = self.value(k)
        r = torch.sigmoid(self.receptance(xr))
        return r * kv
class BiRWKVBlock(nn.Module):
    """Pre-LayerNorm bidirectional RWKV block.

    One RWKVTimeMix reads the normalised sequence left-to-right and a second
    one reads it right-to-left (by flipping the time axis). Both results are
    added to the residual stream, followed by a channel-mix sub-layer.
    """

    def __init__(self, d_model, ffn_mult=4):
        super().__init__()
        # Attribute names are the checkpoint's state_dict keys, and creation
        # order fixes the RNG draw order of the random initialisation.
        self.ln1 = nn.LayerNorm(d_model)
        self.fwd_time_mix = RWKVTimeMix(d_model)
        self.bwd_time_mix = RWKVTimeMix(d_model)
        self.ln2 = nn.LayerNorm(d_model)
        self.channel_mix = RWKVChannelMix(d_model, ffn_mult)

    def forward(self, x, mask=None):
        """Return the block output for ``x`` of shape (B, T, C).

        ``mask`` is accepted so every layer in the model can be called as
        ``layer(x, mask)``; this block does not use it.
        """
        time_axis = [1]
        normed = self.ln1(x)
        left_to_right = self.fwd_time_mix(normed)
        right_to_left = self.bwd_time_mix(normed.flip(time_axis)).flip(time_axis)
        mixed = x + left_to_right + right_to_left
        return mixed + self.channel_mix(self.ln2(mixed))
class FullAttention(nn.Module):
    """Multi-head scaled dot-product self-attention with an optional mask.

    Raises:
        ValueError: if ``n_heads`` is not positive or does not divide
            ``d_model`` (the per-head reshape in ``forward`` needs both).
    """

    def __init__(self, d_model, n_heads=16):
        super().__init__()
        # Validate up front; otherwise the mismatch only shows up later as a
        # cryptic shape error inside .view() in forward.
        if n_heads <= 0 or d_model % n_heads != 0:
            raise ValueError(
                f"d_model ({d_model}) must be divisible by n_heads ({n_heads})"
            )
        self.d_model = d_model
        self.n_heads = n_heads
        self.head_dim = d_model // n_heads
        # One fused projection producing queries, keys and values.
        self.qkv = nn.Linear(d_model, d_model * 3)
        self.out_proj = nn.Linear(d_model, d_model)

    def forward(self, x, mask=None):
        """Attend over ``x`` of shape (B, T, C) and return (B, T, C).

        Args:
            x: input sequence; C must equal ``d_model``.
            mask: optional tensor broadcastable to (B, n_heads, T, T).
                Entries equal to 0 get -inf scores before the softmax.
                HybridBertModel passes a (B, 1, 1, T) key-padding mask.
        """
        B, T, C = x.shape
        qkv = self.qkv(x)
        q, k, v = qkv.chunk(3, dim=-1)
        # Split channels into heads: (B, T, C) -> (B, n_heads, T, head_dim).
        q = q.view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
        k = k.view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
        v = v.view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
        attn = (q @ k.transpose(-2, -1)) / (self.head_dim ** 0.5)
        if mask is not None:
            attn = attn.masked_fill(mask == 0, float('-inf'))
        attn = F.softmax(attn, dim=-1)
        out = attn @ v
        # Merge heads back: (B, n_heads, T, head_dim) -> (B, T, C).
        out = out.transpose(1, 2).contiguous().view(B, T, C)
        return self.out_proj(out)
class StandardAttentionBlock(nn.Module):
    """Pre-LayerNorm Transformer encoder block: self-attention, then a GELU MLP.

    Both sub-layers are residual: ``x + attn(ln1(x))`` followed by
    ``h + ffn(ln2(h))``.
    """

    def __init__(self, d_model, n_heads=16, ffn_mult=4):
        super().__init__()
        hidden = d_model * ffn_mult
        self.ln1 = nn.LayerNorm(d_model)
        self.attn = FullAttention(d_model, n_heads)
        self.ln2 = nn.LayerNorm(d_model)
        # The Sequential indices (0: expand, 1: GELU, 2: project back) appear
        # in the checkpoint's key names, so the layout must stay as is.
        self.ffn = nn.Sequential(
            nn.Linear(d_model, hidden),
            nn.GELU(),
            nn.Linear(hidden, d_model),
        )

    def forward(self, x, mask=None):
        """Return the block output; ``mask`` is forwarded to the attention."""
        attended = x + self.attn(self.ln1(x), mask)
        return attended + self.ffn(self.ln2(attended))
class HybridBertEmbeddings(nn.Module):
    """BERT-style input embeddings.

    Sums token, learned absolute-position and segment (token-type)
    embeddings, then applies LayerNorm and dropout (p=0.1).
    """

    def __init__(self, vocab_size, d_model, max_len=512):
        super().__init__()
        self.word_embeddings = nn.Embedding(vocab_size, d_model)
        self.position_embeddings = nn.Embedding(max_len, d_model)
        # Only two segment ids exist (0 and 1).
        self.token_type_embeddings = nn.Embedding(2, d_model)
        self.ln = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(0.1)

    def forward(self, input_ids, token_type_ids):
        """Embed a batch of token ids.

        Args:
            input_ids: integer token ids; dimension 1 is treated as the
                sequence length (presumably (B, T)). Sequences longer than
                ``max_len`` index past the position table and fail.
            token_type_ids: segment ids (0 or 1), same shape as ``input_ids``.

        Returns:
            Embeddings with a trailing ``d_model`` dimension.
        """
        positions = torch.arange(input_ids.size(1), device=input_ids.device)[None, :]
        summed = self.word_embeddings(input_ids)
        summed = summed + self.position_embeddings(positions)
        summed = summed + self.token_type_embeddings(token_type_ids)
        return self.dropout(self.ln(summed))
class HybridBertModel(nn.Module):
    """Hybrid encoder: bidirectional RWKV blocks followed by attention blocks.

    ``forward`` returns masked-language-model logits only. ``pooler_dense``
    and ``nsp_head`` are not used by ``forward``. They are kept so the
    checkpoint loads with ``load_state_dict``'s default strict checking;
    presumably the checkpoint contains their weights — TODO confirm before
    removing them.

    Args:
        vocab_size: size of the token vocabulary (embedding rows / MLM outputs).
        d_model: hidden width.
        n_rwkv_layers: number of BiRWKVBlock layers (placed first).
        n_attn_layers: number of StandardAttentionBlock layers (placed after).
        n_heads: attention heads per attention block.
        max_len: size of the learned position table.
        pad_token_id: token id treated as padding when building the attention
            mask. Defaults to 1, the value this model previously hard-coded.
    """

    def __init__(self, vocab_size, d_model=768, n_rwkv_layers=6, n_attn_layers=6,
                 n_heads=12, max_len=512, pad_token_id=1):
        super().__init__()
        # Plain attribute (not a buffer), so state_dict keys are unchanged.
        self.pad_token_id = pad_token_id
        self.embeddings = HybridBertEmbeddings(vocab_size, d_model, max_len)
        # RWKV blocks first, then attention blocks; this order defines the
        # "layers.<i>" indices in the checkpoint.
        self.layers = nn.ModuleList()
        for _ in range(n_rwkv_layers):
            self.layers.append(BiRWKVBlock(d_model, ffn_mult=4))
        for _ in range(n_attn_layers):
            self.layers.append(StandardAttentionBlock(d_model, n_heads=n_heads))
        self.mlm_head = nn.Sequential(
            nn.Linear(d_model, d_model),
            nn.GELU(),
            nn.LayerNorm(d_model),
            nn.Linear(d_model, vocab_size)
        )
        self.pooler_dense = nn.Linear(d_model, d_model)
        self.nsp_head = nn.Linear(d_model, 2)

    def forward(self, input_ids, segment_ids=None):
        """Return MLM logits for every position.

        Args:
            input_ids: token ids, presumably (B, T).
            segment_ids: segment ids with the same shape; when omitted, all
                positions are treated as segment 0 (single-sentence input).

        Returns:
            Logits with a trailing ``vocab_size`` dimension.
        """
        if segment_ids is None:
            segment_ids = torch.zeros_like(input_ids)
        # Key-padding mask, unsqueezed to (B, 1, 1, T) so it broadcasts over
        # heads and query positions in FullAttention.
        mask = (input_ids != self.pad_token_id).unsqueeze(1).unsqueeze(2)
        x = self.embeddings(input_ids, segment_ids)
        for layer in self.layers:
            x = layer(x, mask)
        prediction_scores = self.mlm_head(x)
        return prediction_scores
# ============================================================================
# 2. INITIALIZATION
# ============================================================================
REPO_ID = "i3-lab/i3-BERT-v2"
MODEL_FILENAME = "i3-bert.pt"
TOKENIZER_FILENAME = "tokenizer_bert.json"

print("Downloading model and tokenizer from Hugging Face Hub...")
try:
    model_path = hf_hub_download(repo_id=REPO_ID, filename=MODEL_FILENAME)
    tokenizer_path = hf_hub_download(repo_id=REPO_ID, filename=TOKENIZER_FILENAME)
except Exception as e:
    print(f"Error downloading files: {e}")
    print(f"Ensure '{MODEL_FILENAME}' and '{TOKENIZER_FILENAME}' exist in '{REPO_ID}'")
    # A bare raise re-raises the original exception with its traceback intact.
    raise

# Load Tokenizer
tokenizer = Tokenizer.from_file(tokenizer_path)
vocab_size = tokenizer.get_vocab_size()

# Special Token IDs (based on the training code).
CLS_ID = tokenizer.token_to_id("<CLS>")
SEP_ID = tokenizer.token_to_id("<SEP>")
MASK_ID = tokenizer.token_to_id("<MASK>")
PAD_ID = tokenizer.token_to_id("<PAD>")

# token_to_id returns None for unknown tokens. Fail at startup instead of
# crashing inside torch.tensor (CLS/SEP) or answering every request with
# "no mask found" (MASK).
_missing_tokens = [
    name
    for name, token_id in (("<CLS>", CLS_ID), ("<SEP>", SEP_ID), ("<MASK>", MASK_ID))
    if token_id is None
]
if _missing_tokens:
    raise ValueError(
        f"Tokenizer '{TOKENIZER_FILENAME}' is missing special tokens: "
        f"{', '.join(_missing_tokens)}"
    )
if PAD_ID is not None and PAD_ID != 1:
    print(
        f"Warning: tokenizer <PAD> id is {PAD_ID}, but the model's attention "
        "mask treats id 1 as padding."
    )

# Load Model
# Config matching the training parameters; it must agree with the checkpoint
# or load_state_dict below fails.
config = {
    "d_model": 768,
    "n_rwkv_layers": 4,
    "n_attn_layers": 4,
    "n_heads": 12,
    "seq_len": 128
}
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = HybridBertModel(
    vocab_size=vocab_size,
    d_model=config['d_model'],
    n_rwkv_layers=config['n_rwkv_layers'],
    n_attn_layers=config['n_attn_layers'],
    n_heads=config['n_heads'],
    max_len=config['seq_len']
).to(device)

print("Loading state dict...")
# The checkpoint comes from the network. weights_only=True limits
# unpickling to tensors and plain containers, so a tampered file cannot run
# arbitrary code while loading.
state_dict = torch.load(model_path, map_location=device, weights_only=True)
model.load_state_dict(state_dict)
model.eval()
print("Model loaded successfully!")
# ============================================================================
# 3. GRADIO INFERENCE FUNCTION
# ============================================================================
def predict_mask(text):
    """Predict the top-5 candidates for every ``<MASK>`` token in ``text``.

    Args:
        text: user input; must contain the literal string ``<MASK>``.

    Returns:
        One line per mask, ``"Mask at pos i: word (score), ..."``, where ``i``
        is the index in the CLS-prefixed input and ``score`` is the raw
        (unnormalised) logit. If no prediction could be made, a short
        explanatory message is returned instead.
    """
    if not text:
        return "Please enter text."
    # Ensure the user provided a <MASK> token
    if "<MASK>" not in text:
        return "Please include a <MASK> token in your text to predict."

    # CLS/SEP are added manually below, so request the bare ids. With the
    # default add_special_tokens=True, a tokenizer that has a post-processor
    # would add them a second time.
    ids = tokenizer.encode(text, add_special_tokens=False).ids

    # Truncate if necessary (keeping space for CLS and SEP)
    max_len = config['seq_len'] - 2
    ids = ids[:max_len]
    input_ids = [CLS_ID, *ids, SEP_ID]
    segment_ids = [0] * len(input_ids)  # Single sentence segment

    mask_indices = [i for i, token_id in enumerate(input_ids) if token_id == MASK_ID]
    if not mask_indices:
        # The tokenizer split "<MASK>" into pieces, or truncation removed it.
        return "No <MASK> token found after tokenization."

    input_tensor = torch.tensor([input_ids], device=device)
    segment_tensor = torch.tensor([segment_ids], device=device)
    with torch.no_grad():
        logits = model(input_tensor, segment_tensor)

    # Guard against vocabularies smaller than 5 (topk would raise).
    k = min(5, logits.size(-1))
    results = []
    for idx in mask_indices:
        top_k = torch.topk(logits[0, idx, :], k)
        candidates = []
        for score, token_id in zip(top_k.values.tolist(), top_k.indices.tolist()):
            # decode() drops special tokens by default; fall back to the raw
            # vocabulary entry so such predictions are not shown as blanks.
            word = tokenizer.decode([token_id]) or tokenizer.id_to_token(token_id)
            candidates.append(f"{word} ({score:.2f})")
        results.append(f"Mask at pos {idx}: " + ", ".join(candidates))
    return "\n".join(results)
# ============================================================================
# 4. LAUNCH UI
# ============================================================================
# Compute the size shown in the description from the loaded model; a
# hard-coded figure can silently disagree with the configured architecture.
n_params = sum(p.numel() for p in model.parameters())

with gr.Blocks() as demo:
    gr.Markdown("# i3-BERT: Hybrid RWKV + Attention Model")
    gr.Markdown(
        f"A custom {n_params / 1e6:.1f}M parameter model combining "
        "Bi-Directional RWKV and Attention layers."
    )
    gr.Markdown("Type a sentence with `<MASK>` to see predictions.")
    with gr.Row():
        inp = gr.Textbox(placeholder="The capital of France is <MASK>.", label="Input Text")
        out = gr.Textbox(label="Predictions")
    btn = gr.Button("Predict")
    btn.click(fn=predict_mask, inputs=inp, outputs=out)
    # Pressing Enter in the textbox triggers the same prediction as the button.
    inp.submit(fn=predict_mask, inputs=inp, outputs=out)
    examples = [
        ["The quick brown fox jumps over the <MASK> dog."],
        ["I want to eat a <MASK> for lunch."],
        ["Python is a great programming <MASK>."]
    ]
    gr.Examples(examples=examples, inputs=inp)

demo.launch()