import json

import numpy as np
import torch
import torch.nn as nn
from safetensors.torch import load_file


class TransformerBlock(nn.Module):
    """Post-LN Transformer encoder block: multi-head self-attention and a feed-forward
    network, each followed by dropout, a residual connection, and LayerNorm."""

    def __init__(self, dim, heads, ff_dim, dropout):
        super().__init__()
        self.mha = nn.MultiheadAttention(dim, heads, dropout=dropout, batch_first=True)
        self.ln1 = nn.LayerNorm(dim)
        self.ln2 = nn.LayerNorm(dim)
        self.drop = nn.Dropout(dropout)
        self.ffn = nn.Sequential(
            nn.Linear(dim, ff_dim), nn.ReLU(), nn.Dropout(dropout), nn.Linear(ff_dim, dim)
        )

    def forward(self, x, key_padding_mask=None):
        # Self-attention over the whole sequence; padded positions are excluded as keys.
        attn_out, _ = self.mha(x, x, x, key_padding_mask=key_padding_mask, need_weights=False)
        x = self.ln1(x + self.drop(attn_out))
        ff_out = self.ffn(x)
        return self.ln2(x + self.drop(ff_out))


class BiGRUResidualBlock(nn.Module):
    """Bidirectional GRU whose two directions concatenate back to `dim`,
    wrapped in a residual connection and LayerNorm."""

    def __init__(self, dim, dropout):
        super().__init__()
        self.gru = nn.GRU(dim, dim // 2, num_layers=1, batch_first=True, bidirectional=True)
        self.ln = nn.LayerNorm(dim)
        self.drop = nn.Dropout(dropout)

    def forward(self, x):
        out, _ = self.gru(x)
        return self.ln(x + self.drop(out))


class MeterModel(nn.Module):
    """Character-level sequence classifier: token + learned positional embeddings,
    one Transformer block, a stack of residual BiGRU blocks, masked mean pooling,
    and an MLP classification head."""

    def __init__(self, vocab_size, num_classes, T, emb_dim=64, num_heads=4, ff_dim=256, gru_blocks=3, dropout=0.1):
        super().__init__()
        self.emb = nn.Embedding(vocab_size, emb_dim, padding_idx=0)
        self.pos = nn.Embedding(T, emb_dim)  # learned positions for sequences up to length T
        self.drop = nn.Dropout(dropout)
        self.tr = TransformerBlock(emb_dim, num_heads, ff_dim, dropout)
        self.gru_blocks = nn.ModuleList([BiGRUResidualBlock(emb_dim, dropout) for _ in range(gru_blocks)])
        self.head = nn.Sequential(nn.Linear(emb_dim, 128), nn.ReLU(), nn.Dropout(dropout), nn.Linear(128, num_classes))

    def forward(self, x):
        B, T = x.shape
        pos = torch.arange(T, device=x.device).unsqueeze(0).expand(B, T)
        h = self.drop(self.emb(x) + self.pos(pos))
        pad_mask = (x == 0)  # id 0 is the padding token
        h = self.tr(h, key_padding_mask=pad_mask)
        for blk in self.gru_blocks:
            h = blk(h)
        # Mean-pool over non-padding positions only.
        mask = (~pad_mask).float().unsqueeze(-1)
        pooled = (h * mask).sum(1) / mask.sum(1).clamp_min(1.0)
        return self.head(pooled)


def load_local(repo_dir="."):
    """Load config, character vocabulary, and safetensors weights from `repo_dir`."""
    with open(f"{repo_dir}/config.json", "r", encoding="utf-8") as f:
        cfg = json.load(f)
    with open(f"{repo_dir}/stoi.json", "r", encoding="utf-8") as f:
        stoi = json.load(f)
    classes = cfg["classes"]
    model = MeterModel(
        vocab_size=cfg["vocab_size"], num_classes=cfg["num_classes"], T=cfg["max_length"],
        emb_dim=cfg["emb_dim"], num_heads=cfg["num_heads"], ff_dim=cfg["ff_dim"],
        gru_blocks=cfg["gru_blocks"], dropout=cfg["dropout"],
    )
    sd = load_file(f"{repo_dir}/model.safetensors")
    model.load_state_dict(sd)
    model.eval()
    return model, stoi, classes, cfg["max_length"]


def encode(text, stoi, max_len):
    """Map characters to ids, truncating to `max_len` and zero-padding the rest;
    characters missing from the vocabulary fall back to id 0."""
    ids = np.zeros((max_len,), dtype=np.int64)
    for i, ch in enumerate(text[:max_len]):
        ids[i] = stoi.get(ch, 0)
    return torch.from_numpy(ids).unsqueeze(0)  # shape (1, max_len)


@torch.no_grad()
def predict(text, model, stoi, classes, max_len, topk=5):
    """Return the top-1 class name and the top-k (class, probability) pairs for `text`."""
    x = encode(text, stoi, max_len)
    logits = model(x)[0]
    probs = torch.softmax(logits, dim=-1)
    topv, topi = torch.topk(probs, k=min(topk, probs.numel()))
    pred = classes[int(topi[0])]
    top = [(classes[int(i)], float(v)) for v, i in zip(topv, topi)]
    return pred, top
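# Usage sketch (illustrative): assumes config.json, stoi.json, and model.safetensors sit in the
# current directory, matching what load_local() above expects; the input line is a made-up example.
if __name__ == "__main__":
    model, stoi, classes, max_len = load_local(".")
    pred, top = predict("Shall I compare thee to a summer's day?", model, stoi, classes, max_len, topk=5)
    print("predicted:", pred)
    for name, p in top:
        print(f"  {name}: {p:.3f}")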