|
|
import torch |
|
|
import torch.nn as nn |
|
|
import torch.nn.functional as F |
|
|
import gradio as gr |
|
|
from huggingface_hub import hf_hub_download |
|
|
from tokenizers import Tokenizer |
|
|
import os |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@torch.jit.script
def rwkv_linear_attention(B: int, T: int, C: int,
                          r: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
                          w: torch.Tensor, u: torch.Tensor,
                          state_init: torch.Tensor) -> torch.Tensor:
    """Sequential RWKV (WKV) linear-attention recurrence.

    Args:
        B, T, C: batch size, sequence length, channel count.
        r: receptance, shape (B, T, C). Only its device is used here; the
           sigmoid gate is applied by the caller (``r * wkv``).
        k, v: key / value tensors, shape (B, T, C).
        w: per-channel decay added to the running exponent each step.
        u: per-channel "current token" bonus applied at output time.
        state_init: initial running-max exponent, shape (B, C);
           typically filled with -1e30 so the first token dominates.

    Returns:
        y: the WKV mixing output, shape (B, T, C).

    The state triple (state_aa, state_bb, state_pp) holds the
    exp-weighted value sum, the exp-weighted normalizer, and the running
    maximum exponent; subtracting the max before exp() keeps the
    exponentials numerically stable (log-sum-exp trick).
    """
    y = torch.zeros_like(v)
    state_aa = torch.zeros(B, C, dtype=torch.float32, device=r.device)
    state_bb = torch.zeros(B, C, dtype=torch.float32, device=r.device)
    state_pp = state_init.clone()

    for t in range(T):
        # Fix: the original also unpacked r[:, t] into an unused local.
        kt, vt = k[:, t], v[:, t]
        # Output at step t: the current key gets the one-off bonus `u`.
        ww = u + state_pp
        p = torch.maximum(ww, kt)
        e1 = torch.exp(ww - p)
        e2 = torch.exp(kt - p)
        # 1e-6 guards against a zero denominator.
        wkv = (state_aa * e1 + vt * e2) / (state_bb * e1 + e2 + 1e-6)
        y[:, t] = wkv

        # State update: decay the history by `w` before folding in step t.
        ww = w + state_pp
        p = torch.maximum(ww, kt)
        e1 = torch.exp(ww - p)
        e2 = torch.exp(kt - p)
        state_aa = state_aa * e1 + vt * e2
        state_bb = state_bb * e1 + e2
        state_pp = p

    return y
|
|
|
|
|
class RWKVTimeMix(nn.Module):
    """Single-direction RWKV time-mixing layer (token shift + WKV recurrence)."""

    def __init__(self, d_model):
        super().__init__()
        self.d_model = d_model
        # Per-channel decay and "current token" bonus for the WKV kernel.
        self.time_decay = nn.Parameter(torch.ones(d_model))
        self.time_first = nn.Parameter(torch.ones(d_model))
        # Interpolation weights between the current and the previous token.
        self.time_mix_k = nn.Parameter(torch.ones(1, 1, d_model))
        self.time_mix_v = nn.Parameter(torch.ones(1, 1, d_model))
        self.time_mix_r = nn.Parameter(torch.ones(1, 1, d_model))
        self.key = nn.Linear(d_model, d_model, bias=False)
        self.value = nn.Linear(d_model, d_model, bias=False)
        self.receptance = nn.Linear(d_model, d_model, bias=False)
        self.output = nn.Linear(d_model, d_model, bias=False)
        # Start decays in a negative range (effective decay w = -exp(time_decay)).
        self.time_decay.data.uniform_(-6, -3)

    def forward(self, x):
        B, T, C = x.size()
        # Token shift: the sequence delayed by one step, zero for position 0.
        prev = torch.cat([torch.zeros((B, 1, C), device=x.device), x[:, :-1]], dim=1)
        mix_k = x * self.time_mix_k + prev * (1 - self.time_mix_k)
        mix_v = x * self.time_mix_v + prev * (1 - self.time_mix_v)
        mix_r = x * self.time_mix_r + prev * (1 - self.time_mix_r)

        k = self.key(mix_k)
        v = self.value(mix_v)
        r = torch.sigmoid(self.receptance(mix_r))

        w = -torch.exp(self.time_decay)
        u = self.time_first
        state_init = torch.full((B, C), -1e30, dtype=torch.float32, device=x.device)

        wkv = rwkv_linear_attention(B, T, C, r, k, v, w, u, state_init)
        # Receptance gates the recurrence output before the final projection.
        return self.output(r * wkv)
|
|
|
|
|
class RWKVChannelMix(nn.Module):
    """RWKV channel-mixing (FFN) layer: token shift, squared-ReLU, sigmoid gate."""

    def __init__(self, d_model, ffn_mult=4):
        super().__init__()
        self.time_mix_k = nn.Parameter(torch.ones(1, 1, d_model))
        self.time_mix_r = nn.Parameter(torch.ones(1, 1, d_model))
        inner = d_model * ffn_mult
        self.key = nn.Linear(d_model, inner, bias=False)
        self.receptance = nn.Linear(d_model, d_model, bias=False)
        self.value = nn.Linear(inner, d_model, bias=False)

    def forward(self, x):
        B, T, C = x.size()
        # Token shift: previous position's features (zeros at position 0).
        prev = torch.cat([torch.zeros((B, 1, C), device=x.device), x[:, :-1]], dim=1)
        mix_k = x * self.time_mix_k + prev * (1 - self.time_mix_k)
        mix_r = x * self.time_mix_r + prev * (1 - self.time_mix_r)

        # Squared-ReLU activation on the expanded hidden representation.
        hidden = torch.square(torch.relu(self.key(mix_k)))
        projected = self.value(hidden)
        gate = torch.sigmoid(self.receptance(mix_r))
        return gate * projected
|
|
|
|
|
class BiRWKVBlock(nn.Module):
    """Bidirectional RWKV block: forward + time-reversed mixers, then channel-mix."""

    def __init__(self, d_model, ffn_mult=4):
        super().__init__()
        self.ln1 = nn.LayerNorm(d_model)
        self.fwd_time_mix = RWKVTimeMix(d_model)
        self.bwd_time_mix = RWKVTimeMix(d_model)
        self.ln2 = nn.LayerNorm(d_model)
        self.channel_mix = RWKVChannelMix(d_model, ffn_mult)

    def forward(self, x, mask=None):
        # `mask` is accepted for interface parity with attention blocks
        # but is not used by the RWKV mixers.
        normed = self.ln1(x)
        forward_out = self.fwd_time_mix(normed)
        # Backward pass: run the second mixer on the time-reversed
        # sequence, then flip its output back into original order.
        backward_out = torch.flip(self.bwd_time_mix(torch.flip(normed, [1])), [1])
        x = x + forward_out + backward_out
        return x + self.channel_mix(self.ln2(x))
|
|
|
|
|
class FullAttention(nn.Module):
    """Standard (non-causal) multi-head softmax attention."""

    def __init__(self, d_model, n_heads=16):
        super().__init__()
        self.d_model = d_model
        self.n_heads = n_heads
        self.head_dim = d_model // n_heads
        self.qkv = nn.Linear(d_model, d_model * 3)
        self.out_proj = nn.Linear(d_model, d_model)

    def _split_heads(self, t, B, T):
        # (B, T, C) -> (B, n_heads, T, head_dim)
        return t.view(B, T, self.n_heads, self.head_dim).transpose(1, 2)

    def forward(self, x, mask=None):
        B, T, C = x.shape
        q, k, v = self.qkv(x).chunk(3, dim=-1)
        q = self._split_heads(q, B, T)
        k = self._split_heads(k, B, T)
        v = self._split_heads(v, B, T)

        # Scaled dot-product scores.
        scores = (q @ k.transpose(-2, -1)) / (self.head_dim ** 0.5)
        if mask is not None:
            # Positions where mask == 0 are excluded from the softmax.
            scores = scores.masked_fill(mask == 0, float('-inf'))
        weights = F.softmax(scores, dim=-1)

        context = (weights @ v).transpose(1, 2).contiguous().view(B, T, C)
        return self.out_proj(context)
|
|
|
|
|
class StandardAttentionBlock(nn.Module):
    """Pre-norm transformer block: full attention plus a GELU feed-forward,
    each wrapped in a residual connection."""

    def __init__(self, d_model, n_heads=16, ffn_mult=4):
        super().__init__()
        self.ln1 = nn.LayerNorm(d_model)
        self.attn = FullAttention(d_model, n_heads)
        self.ln2 = nn.LayerNorm(d_model)
        hidden = d_model * ffn_mult
        self.ffn = nn.Sequential(
            nn.Linear(d_model, hidden),
            nn.GELU(),
            nn.Linear(hidden, d_model),
        )

    def forward(self, x, mask=None):
        x = x + self.attn(self.ln1(x), mask)
        return x + self.ffn(self.ln2(x))
|
|
|
|
|
class HybridBertEmbeddings(nn.Module):
    """BERT-style input embeddings: word + learned position + segment,
    followed by LayerNorm and dropout."""

    def __init__(self, vocab_size, d_model, max_len=512):
        super().__init__()
        self.word_embeddings = nn.Embedding(vocab_size, d_model)
        self.position_embeddings = nn.Embedding(max_len, d_model)
        # Two segments (sentence A / sentence B), as in BERT.
        self.token_type_embeddings = nn.Embedding(2, d_model)
        self.ln = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(0.1)

    def forward(self, input_ids, token_type_ids):
        # Position ids 0..T-1, broadcast across the batch dimension.
        positions = torch.arange(input_ids.size(1), device=input_ids.device).unsqueeze(0)
        summed = (
            self.word_embeddings(input_ids)
            + self.position_embeddings(positions)
            + self.token_type_embeddings(token_type_ids)
        )
        return self.dropout(self.ln(summed))
|
|
|
|
|
class HybridBertModel(nn.Module):
    """Hybrid encoder: n_rwkv_layers bidirectional RWKV blocks followed by
    n_attn_layers standard attention blocks, with an MLM prediction head."""

    def __init__(self, vocab_size, d_model=768, n_rwkv_layers=6, n_attn_layers=6, n_heads=12, max_len=512):
        super().__init__()
        self.embeddings = HybridBertEmbeddings(vocab_size, d_model, max_len)
        # RWKV blocks first, then attention blocks (same order as checkpoints expect).
        rwkv_blocks = [BiRWKVBlock(d_model, ffn_mult=4) for _ in range(n_rwkv_layers)]
        attn_blocks = [StandardAttentionBlock(d_model, n_heads=n_heads) for _ in range(n_attn_layers)]
        self.layers = nn.ModuleList(rwkv_blocks + attn_blocks)

        self.mlm_head = nn.Sequential(
            nn.Linear(d_model, d_model),
            nn.GELU(),
            nn.LayerNorm(d_model),
            nn.Linear(d_model, vocab_size),
        )
        # NOTE(review): pooler/NSP parameters are registered (so checkpoints
        # containing them load) but are not used by forward().
        self.pooler_dense = nn.Linear(d_model, d_model)
        self.nsp_head = nn.Linear(d_model, 2)

    def forward(self, input_ids, segment_ids):
        # Attention padding mask, broadcastable to (B, heads, T, T).
        # NOTE(review): assumes the PAD token id is literally 1 -- verify this
        # matches the tokenizer's actual <PAD> id.
        mask = (input_ids != 1).unsqueeze(1).unsqueeze(2)
        x = self.embeddings(input_ids, segment_ids)
        for layer in self.layers:
            x = layer(x, mask)
        # Per-token vocabulary logits for masked-language-model prediction.
        return self.mlm_head(x)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# --- Artifact locations on the Hugging Face Hub ---
REPO_ID = "i3-lab/i3-BERT-v2"
MODEL_FILENAME = "i3-bert.pt"
TOKENIZER_FILENAME = "tokenizer_bert.json"

# Fetch (or reuse the local cache of) the checkpoint and tokenizer files.
print("Downloading model and tokenizer from Hugging Face Hub...")
try:
    model_path = hf_hub_download(repo_id=REPO_ID, filename=MODEL_FILENAME)
    tokenizer_path = hf_hub_download(repo_id=REPO_ID, filename=TOKENIZER_FILENAME)
except Exception as e:
    print(f"Error downloading files: {e}")
    print("Ensure 'i3-bert.pt' and 'tokenizer_bert.json' exist in 'i3-lab/i3-BERT-v2'")
    raise e

# Fast tokenizer from the `tokenizers` library; vocab size drives the model's
# embedding and output-head dimensions.
tokenizer = Tokenizer.from_file(tokenizer_path)
vocab_size = tokenizer.get_vocab_size()

# Special-token ids used when assembling model inputs.
# NOTE(review): token_to_id returns None for tokens missing from the vocab;
# presumably all four exist in this tokenizer -- verify against the repo.
CLS_ID = tokenizer.token_to_id("<CLS>")
SEP_ID = tokenizer.token_to_id("<SEP>")
MASK_ID = tokenizer.token_to_id("<MASK>")
PAD_ID = tokenizer.token_to_id("<PAD>")

# Architecture hyperparameters; these must match the checkpoint being loaded.
config = {
    "d_model": 768,
    "n_rwkv_layers": 4,
    "n_attn_layers": 4,
    "n_heads": 12,
    "seq_len": 128
}

# Use GPU when available; the state dict is mapped onto the same device.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = HybridBertModel(
    vocab_size=vocab_size,
    d_model=config['d_model'],
    n_rwkv_layers=config['n_rwkv_layers'],
    n_attn_layers=config['n_attn_layers'],
    n_heads=config['n_heads'],
    max_len=config['seq_len']
).to(device)

print("Loading state dict...")
# NOTE(review): torch.load unpickles arbitrary objects from a remote file;
# consider torch.load(..., weights_only=True) if the checkpoint permits it.
state_dict = torch.load(model_path, map_location=device)
model.load_state_dict(state_dict)
# Inference-only: disables dropout in the embeddings.
model.eval()
print("Model loaded successfully!")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def predict_mask(text):
    """Fill-mask inference for the Gradio demo.

    Tokenizes `text`, wraps it in <CLS>/<SEP>, runs the model, and returns a
    human-readable string with the top-5 candidates (and their logits) for
    every <MASK> position. Returns an instructional message on bad input.
    """
    if not text:
        return "Please enter text."

    if "<MASK>" not in text:
        return "Please include a <MASK> token in your text to predict."

    # Tokenize and truncate, reserving two slots for <CLS> and <SEP>.
    token_ids = tokenizer.encode(text).ids
    body_limit = config['seq_len'] - 2
    token_ids = token_ids[:body_limit]

    input_ids = [CLS_ID] + token_ids + [SEP_ID]
    segment_ids = [0] * len(input_ids)

    # Locate every mask position in the final id sequence.
    mask_positions = [pos for pos, tok in enumerate(input_ids) if tok == MASK_ID]
    if not mask_positions:
        return "No <MASK> token found after tokenization."

    input_tensor = torch.tensor([input_ids], device=device)
    segment_tensor = torch.tensor([segment_ids], device=device)

    with torch.no_grad():
        logits = model(input_tensor, segment_tensor)

    # Format the top-5 predictions (raw logit scores) for each mask.
    lines = []
    for pos in mask_positions:
        top = torch.topk(logits[0, pos, :], 5)
        picks = []
        for score, tok in zip(top.values, top.indices):
            word = tokenizer.decode([tok.item()])
            picks.append(f"{word} ({score.item():.2f})")
        lines.append(f"Mask at pos {pos}: " + ", ".join(picks))

    return "\n".join(lines)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# --- Gradio demo UI ---
with gr.Blocks() as demo:
    gr.Markdown("# i3-BERT: Hybrid RWKV + Attention Model")
    gr.Markdown("A custom 10M parameter model combining Bi-Directional RWKV and Attention layers.")
    gr.Markdown("Type a sentence with `<MASK>` to see predictions.")

    # Input and output text boxes, side by side.
    with gr.Row():
        inp = gr.Textbox(placeholder="The capital of France is <MASK>.", label="Input Text")
        out = gr.Textbox(label="Predictions")

    # Wire the button to the fill-mask inference function above.
    btn = gr.Button("Predict")
    btn.click(fn=predict_mask, inputs=inp, outputs=out)

    # Clickable example prompts (they populate the input box only).
    examples = [
        ["The quick brown fox jumps over the <MASK> dog."],
        ["I want to eat a <MASK> for lunch."],
        ["Python is a great programming <MASK>."]
    ]
    gr.Examples(examples, inp)

# Start the web server (blocking call).
demo.launch()