NourFakih's picture
Upload folder using huggingface_hub
5edd0fc verified
import json, numpy as np, torch
import torch.nn as nn
from safetensors.torch import load_file
class TransformerBlock(nn.Module):
    """Post-norm transformer encoder layer.

    Self-attention and a two-layer ReLU feed-forward network are each followed
    by dropout, a residual connection and a LayerNorm (norm applied *after*
    the residual add).

    Args:
        dim: Model width (input and output feature size).
        heads: Number of attention heads; ``nn.MultiheadAttention`` requires
            ``dim`` to be divisible by it.
        ff_dim: Hidden width of the feed-forward network.
        dropout: Dropout probability used in attention, the FFN and on both
            residual branches.
    """

    def __init__(self, dim, heads, ff_dim, dropout):
        super().__init__()
        # Attribute names and creation order are kept stable: they determine
        # the state_dict keys and the order in which random init is drawn.
        self.mha = nn.MultiheadAttention(dim, heads, dropout=dropout, batch_first=True)
        self.ln1 = nn.LayerNorm(dim)
        self.ln2 = nn.LayerNorm(dim)
        self.drop = nn.Dropout(dropout)
        # The two Linear layers live at Sequential indices 0 and 3
        # ("ffn.0.*" / "ffn.3.*" in the state_dict).
        ffn_layers = [
            nn.Linear(dim, ff_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(ff_dim, dim),
        ]
        self.ffn = nn.Sequential(*ffn_layers)

    def forward(self, x, key_padding_mask=None):
        """Apply the layer.

        Args:
            x: Input batch, batch-first (``batch_first=True`` on the attention).
            key_padding_mask: Optional boolean mask; ``True`` marks key
                positions that attention must ignore.

        Returns:
            Tensor with the same shape as ``x``.
        """
        attended = self.mha(x, x, x, key_padding_mask=key_padding_mask, need_weights=False)[0]
        after_attn = self.ln1(x + self.drop(attended))
        widened = self.ffn(after_attn)
        return self.ln2(after_attn + self.drop(widened))
class BiGRUResidualBlock(nn.Module):
    """Bidirectional GRU wrapped in dropout + residual + LayerNorm.

    Each direction has ``dim // 2`` hidden units, so the concatenated output
    is ``2 * (dim // 2)`` wide. The residual add therefore only lines up when
    ``dim`` is even.

    Args:
        dim: Feature size of the input and output.
        dropout: Dropout probability applied to the GRU output before the
            residual add.
    """

    def __init__(self, dim, dropout):
        super().__init__()
        per_direction = dim // 2
        self.gru = nn.GRU(
            input_size=dim,
            hidden_size=per_direction,
            num_layers=1,
            batch_first=True,
            bidirectional=True,
        )
        self.ln = nn.LayerNorm(dim)
        self.drop = nn.Dropout(dropout)

    def forward(self, x):
        """Run the GRU over the whole batch-first sequence and add it back to ``x``.

        The final hidden state returned by the GRU is discarded.
        """
        sequence_output = self.gru(x)[0]
        return self.ln(x + self.drop(sequence_output))
class MeterModel(nn.Module):
    """Sequence classifier over integer token ids.

    Pipeline: token embedding + learned absolute position embedding ->
    one ``TransformerBlock`` -> ``gru_blocks`` ``BiGRUResidualBlock`` layers ->
    mean-pool over non-padding positions -> two-layer MLP head producing
    ``num_classes`` logits.

    Token id 0 is padding: the embedding uses ``padding_idx=0`` and positions
    where ``x == 0`` are masked in attention and excluded from pooling.

    Args:
        vocab_size: Number of token ids (including padding id 0).
        num_classes: Number of output logits.
        T: Number of learned positions; inputs longer than this cannot be
            position-embedded.
        emb_dim: Model width (must suit both attention heads and the BiGRU).
        num_heads: Attention heads in the transformer block.
        ff_dim: Feed-forward hidden width in the transformer block.
        gru_blocks: Number of stacked BiGRU residual blocks.
        dropout: Dropout probability used throughout.
    """

    def __init__(self, vocab_size, num_classes, T, emb_dim=64, num_heads=4, ff_dim=256, gru_blocks=3, dropout=0.1):
        super().__init__()
        self.emb = nn.Embedding(vocab_size, emb_dim, padding_idx=0)
        self.pos = nn.Embedding(T, emb_dim)
        self.drop = nn.Dropout(dropout)
        self.tr = TransformerBlock(emb_dim, num_heads, ff_dim, dropout)
        self.gru_blocks = nn.ModuleList([BiGRUResidualBlock(emb_dim, dropout) for _ in range(gru_blocks)])
        self.head = nn.Sequential(nn.Linear(emb_dim, 128), nn.ReLU(), nn.Dropout(dropout), nn.Linear(128, num_classes))

    def forward(self, x):
        """Return class logits of shape ``(batch, num_classes)`` for a 2-D id tensor ``x``.

        Rows that contain only padding yield finite logits (the head applied
        to a zero vector) rather than NaN.
        """
        batch, seq_len = x.shape
        positions = torch.arange(seq_len, device=x.device).unsqueeze(0).expand(batch, seq_len)
        h = self.drop(self.emb(x) + self.pos(positions))
        pad_mask = (x == 0)
        # If every key in a row is masked, the attention softmax is taken over
        # all -inf scores and produces NaN, which then spreads through the GRUs
        # and survives pooling (NaN * 0 == NaN). For such rows only, let
        # attention see every position; pooling below still uses pad_mask, so
        # their pooled vector is exactly zero. Rows with at least one real
        # token are unaffected.
        attn_pad_mask = pad_mask & ~pad_mask.all(dim=1, keepdim=True)
        h = self.tr(h, key_padding_mask=attn_pad_mask)
        # NOTE: the GRU blocks run over padded positions too; only pooling
        # excludes them.
        for blk in self.gru_blocks:
            h = blk(h)
        mask = (~pad_mask).float().unsqueeze(-1)
        # clamp_min avoids division by zero for all-padding rows.
        pooled = (h * mask).sum(1) / mask.sum(1).clamp_min(1.0)
        return self.head(pooled)
def load_local(repo_dir="."):
    """Build a ``MeterModel`` from files exported to a local directory.

    Reads ``config.json`` (hyper-parameters, ``classes``, ``max_length``),
    ``stoi.json`` (character -> token-id mapping) and ``model.safetensors``
    (weights) from ``repo_dir``.

    Args:
        repo_dir: Directory containing the three files (default: cwd).

    Returns:
        Tuple ``(model, stoi, classes, max_length)``: the model with weights
        loaded and switched to eval mode, the token mapping, the class-label
        list from the config, and the configured maximum input length.

    Raises:
        OSError: if a file is missing or unreadable.
        KeyError: if ``config.json`` lacks a required field.
    """
    # Use context managers so the JSON file handles are closed immediately
    # instead of whenever the garbage collector gets to them.
    with open(f"{repo_dir}/config.json", "r", encoding="utf-8") as fh:
        cfg = json.load(fh)
    with open(f"{repo_dir}/stoi.json", "r", encoding="utf-8") as fh:
        stoi = json.load(fh)
    classes = cfg["classes"]
    model = MeterModel(
        vocab_size=cfg["vocab_size"], num_classes=cfg["num_classes"], T=cfg["max_length"],
        emb_dim=cfg["emb_dim"], num_heads=cfg["num_heads"], ff_dim=cfg["ff_dim"],
        gru_blocks=cfg["gru_blocks"], dropout=cfg["dropout"]
    )
    sd = load_file(f"{repo_dir}/model.safetensors")
    model.load_state_dict(sd)
    model.eval()  # disable dropout for inference
    return model, stoi, classes, cfg["max_length"]
def encode(text, stoi, max_len):
    """Turn ``text`` into a ``(1, max_len)`` int64 tensor of token ids.

    Only the first ``max_len`` characters are used. Each character is looked
    up in ``stoi``; characters not found map to 0, which is the same id used
    to right-pad the sequence to ``max_len``.
    """
    buffer = np.zeros((max_len,), dtype=np.int64)
    token_ids = [stoi.get(char, 0) for char in text[:max_len]]
    buffer[:len(token_ids)] = token_ids
    # Add a leading batch dimension of size 1.
    return torch.from_numpy(buffer)[None, :]
@torch.no_grad()
def predict(text, model, stoi, classes, max_len, topk=5):
    """Classify a single string.

    Args:
        text: Input string; encoded with ``encode(text, stoi, max_len)``.
        model: Classifier mapping a ``(1, max_len)`` id tensor to logits.
        stoi: Character -> token-id mapping passed to ``encode``.
        classes: Sequence of labels indexed by class id.
        max_len: Encoding length passed to ``encode``.
        topk: Number of ranked ``(label, probability)`` pairs to return.
            Clamped to ``[0, len(probabilities)]``.

    Returns:
        ``(pred, top)`` where ``pred`` is the highest-probability label and
        ``top`` is a list of ``(label, probability)`` pairs in descending
        probability order.
    """
    x = encode(text, stoi, max_len)
    # encode() builds a CPU tensor; move it to the model's device so a
    # GPU-hosted model does not fail with a device mismatch.
    first_param = next(model.parameters(), None)
    if first_param is not None:
        x = x.to(first_param.device)
    logits = model(x)[0]
    probs = torch.softmax(logits, dim=-1)
    k = max(0, min(topk, probs.numel()))
    # Always fetch at least the best entry so `pred` exists even when k == 0.
    topv, topi = torch.topk(probs, k=max(k, 1))
    pred = classes[int(topi[0])]
    top = [(classes[int(i)], float(v)) for v, i in zip(topv[:k], topi[:k])]
    return pred, top