import json
import os

import numpy as np
import torch
import torch.nn as nn

# safetensors is only needed by load_local(); keep the rest of the module
# (model definition, encode, predict) importable without it.
try:
    from safetensors.torch import load_file
except ImportError as _safetensors_err:
    load_file = None
    _SAFETENSORS_IMPORT_ERROR = _safetensors_err
else:
    _SAFETENSORS_IMPORT_ERROR = None


class TransformerBlock(nn.Module):
    """Single post-norm Transformer encoder layer.

    Self-attention and a ReLU feed-forward network. Each sub-layer's output is
    dropped out, added back to its input, and the sum is layer-normalized.
    """

    def __init__(self, dim, heads, ff_dim, dropout):
        super().__init__()
        self.mha = nn.MultiheadAttention(dim, heads, dropout=dropout, batch_first=True)
        self.ln1 = nn.LayerNorm(dim)
        self.ln2 = nn.LayerNorm(dim)
        self.drop = nn.Dropout(dropout)
        self.ffn = nn.Sequential(
            nn.Linear(dim, ff_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(ff_dim, dim),
        )

    def forward(self, x, key_padding_mask=None):
        """Apply attention + feed-forward to ``x`` of shape (batch, seq, dim).

        ``key_padding_mask`` is forwarded to ``nn.MultiheadAttention``; True
        entries are keys that attention ignores.
        """
        attn_out, _ = self.mha(x, x, x, key_padding_mask=key_padding_mask, need_weights=False)
        x = self.ln1(x + self.drop(attn_out))
        ff_out = self.ffn(x)
        return self.ln2(x + self.drop(ff_out))


class BiGRUResidualBlock(nn.Module):
    """Bidirectional GRU with a residual connection and LayerNorm.

    Each direction has ``dim // 2`` hidden units. The concatenated output
    therefore has width ``dim`` only when ``dim`` is even.
    """

    def __init__(self, dim, dropout):
        super().__init__()
        if dim % 2:
            # An odd dim would produce (dim - 1)-wide GRU output and fail the
            # residual add in forward() with an obscure shape error.
            raise ValueError(f"BiGRUResidualBlock requires an even dim, got {dim}")
        self.gru = nn.GRU(dim, dim // 2, num_layers=1, batch_first=True, bidirectional=True)
        self.ln = nn.LayerNorm(dim)
        self.drop = nn.Dropout(dropout)

    def forward(self, x):
        """Return ``LayerNorm(x + dropout(BiGRU(x)))`` for ``x`` of shape (batch, seq, dim)."""
        out, _ = self.gru(x)
        return self.ln(x + self.drop(out))


class MeterModel(nn.Module):
    """Token-sequence classifier.

    Pipeline: token + learned positional embeddings, then one Transformer
    block, then a stack of residual BiGRU blocks, then a masked mean-pool over
    non-padding positions, then an MLP head producing class logits.

    Token id 0 is padding: it is the embedding ``padding_idx`` and is masked
    from attention and from pooling.
    """

    def __init__(self, vocab_size, num_classes, T, emb_dim=64, num_heads=4, ff_dim=256,
                 gru_blocks=3, dropout=0.1):
        super().__init__()
        self.emb = nn.Embedding(vocab_size, emb_dim, padding_idx=0)
        self.pos = nn.Embedding(T, emb_dim)  # learned positions; supports sequences up to length T
        self.drop = nn.Dropout(dropout)
        self.tr = TransformerBlock(emb_dim, num_heads, ff_dim, dropout)
        self.gru_blocks = nn.ModuleList(
            [BiGRUResidualBlock(emb_dim, dropout) for _ in range(gru_blocks)]
        )
        self.head = nn.Sequential(
            nn.Linear(emb_dim, 128),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(128, num_classes),
        )

    def forward(self, x):
        """Map token ids ``x`` of shape (B, T) to logits of shape (B, num_classes)."""
        _, seq_len = x.shape
        positions = torch.arange(seq_len, device=x.device)  # broadcasts over the batch
        h = self.drop(self.emb(x) + self.pos(positions))

        pad_mask = x == 0
        # A row with no real tokens (e.g. empty text) would have every attention
        # key masked. The attention softmax over such a row can yield NaN, and
        # NaN survives the masked pooling below (NaN * 0 is NaN). Let such rows
        # attend over their padding instead; their pooled vector is still zero
        # because pooling uses the original pad_mask.
        attn_pad_mask = pad_mask & ~pad_mask.all(dim=1, keepdim=True)
        h = self.tr(h, key_padding_mask=attn_pad_mask)

        # NOTE(review): the GRUs run over padded positions as well (no packing).
        # Presumably this matches how the weights were trained — confirm before changing.
        for blk in self.gru_blocks:
            h = blk(h)

        # Mean-pool over non-padding positions only; clamp avoids division by zero.
        keep = (~pad_mask).float().unsqueeze(-1)
        pooled = (h * keep).sum(1) / keep.sum(1).clamp_min(1.0)
        return self.head(pooled)


def _read_json(path):
    """Read a UTF-8 JSON file, closing the handle afterwards."""
    with open(path, "r", encoding="utf-8") as fh:
        return json.load(fh)


def load_local(repo_dir="."):
    """Build a ``MeterModel`` from files in ``repo_dir`` and load its weights.

    Reads ``config.json`` (hyper-parameters, ``classes`` and ``max_length``),
    ``stoi.json`` (character -> token id) and ``model.safetensors`` (weights).

    Returns:
        ``(model, stoi, classes, max_length)``, with the model in eval mode.

    Raises:
        ImportError: if the ``safetensors`` package is not installed.
    """
    if load_file is None:
        raise ImportError("load_local() requires the 'safetensors' package") from _SAFETENSORS_IMPORT_ERROR
    cfg = _read_json(os.path.join(repo_dir, "config.json"))
    stoi = _read_json(os.path.join(repo_dir, "stoi.json"))
    classes = cfg["classes"]
    model = MeterModel(
        vocab_size=cfg["vocab_size"],
        num_classes=cfg["num_classes"],
        T=cfg["max_length"],
        emb_dim=cfg["emb_dim"],
        num_heads=cfg["num_heads"],
        ff_dim=cfg["ff_dim"],
        gru_blocks=cfg["gru_blocks"],
        dropout=cfg["dropout"],
    )
    sd = load_file(os.path.join(repo_dir, "model.safetensors"))
    model.load_state_dict(sd)
    model.eval()
    return model, stoi, classes, cfg["max_length"]


def encode(text, stoi, max_len):
    """Encode ``text`` character-by-character into a (1, max_len) int64 tensor.

    The text is truncated to ``max_len`` characters and right-padded with 0.
    Characters missing from ``stoi`` also become 0, i.e. they are treated as
    padding by ``MeterModel``.
    NOTE(review): if stoi has a dedicated unknown-token id, mapping to it may
    be preferable — confirm against training.
    """
    ids = np.zeros((max_len,), dtype=np.int64)
    for i, ch in enumerate(text[:max_len]):
        ids[i] = stoi.get(ch, 0)
    return torch.from_numpy(ids).unsqueeze(0)


@torch.no_grad()
def predict(text, model, stoi, classes, max_len, topk=5):
    """Classify ``text`` and return ``(predicted_label, top)``.

    ``top`` is a list of up to ``topk`` ``(label, probability)`` pairs sorted by
    descending probability; it is empty when ``topk <= 0``. The input is moved
    to the device of the model's parameters.
    """
    x = encode(text, stoi, max_len)
    first_param = next(model.parameters(), None)
    if first_param is not None:
        x = x.to(first_param.device)
    logits = model(x)[0]
    probs = torch.softmax(logits, dim=-1)
    # Always fetch at least the best class so `pred` exists even when topk <= 0.
    k = min(max(topk, 1), probs.numel())
    topv, topi = torch.topk(probs, k=k)
    pred = classes[int(topi[0])]
    top = [(classes[int(i)], float(v)) for v, i in zip(topv, topi)][: max(topk, 0)]
    return pred, top