"""Model definition and local inference helpers for a character-level meter classifier."""
import json

import numpy as np
import torch
import torch.nn as nn
from safetensors.torch import load_file

class TransformerBlock(nn.Module):
    """Post-LN Transformer encoder block: self-attention and a feed-forward
    network, each wrapped in a residual connection followed by LayerNorm."""
    def __init__(self, dim, heads, ff_dim, dropout):
        super().__init__()
        self.mha = nn.MultiheadAttention(dim, heads, dropout=dropout, batch_first=True)
        self.ln1 = nn.LayerNorm(dim)
        self.ln2 = nn.LayerNorm(dim)
        self.drop = nn.Dropout(dropout)
        self.ffn = nn.Sequential(
            nn.Linear(dim, ff_dim), nn.ReLU(), nn.Dropout(dropout), nn.Linear(ff_dim, dim)
        )

    def forward(self, x, key_padding_mask=None):
        # Self-attention over the sequence; padded positions are excluded via the mask.
        attn_out, _ = self.mha(x, x, x, key_padding_mask=key_padding_mask, need_weights=False)
        x = self.ln1(x + self.drop(attn_out))
        ff_out = self.ffn(x)
        return self.ln2(x + self.drop(ff_out))

class BiGRUResidualBlock(nn.Module):
    """Bidirectional GRU whose concatenated forward/backward outputs (dim//2 each)
    match the input width, so the block can be applied residually. Assumes `dim` is even."""
    def __init__(self, dim, dropout):
        super().__init__()
        self.gru = nn.GRU(dim, dim // 2, num_layers=1, batch_first=True, bidirectional=True)
        self.ln = nn.LayerNorm(dim)
        self.drop = nn.Dropout(dropout)

    def forward(self, x):
        out, _ = self.gru(x)
        return self.ln(x + self.drop(out))

class MeterModel(nn.Module):
    """Character-level meter classifier: token + learned positional embeddings,
    one Transformer block, a stack of residual BiGRU blocks, masked mean pooling,
    and an MLP classification head."""
    def __init__(self, vocab_size, num_classes, T, emb_dim=64, num_heads=4, ff_dim=256, gru_blocks=3, dropout=0.1):
        super().__init__()
        self.emb = nn.Embedding(vocab_size, emb_dim, padding_idx=0)
        self.pos = nn.Embedding(T, emb_dim)  # learned positions up to max length T
        self.drop = nn.Dropout(dropout)
        self.tr = TransformerBlock(emb_dim, num_heads, ff_dim, dropout)
        self.gru_blocks = nn.ModuleList([BiGRUResidualBlock(emb_dim, dropout) for _ in range(gru_blocks)])
        self.head = nn.Sequential(nn.Linear(emb_dim, 128), nn.ReLU(), nn.Dropout(dropout), nn.Linear(128, num_classes))

    def forward(self, x):
        B, T = x.shape
        pos = torch.arange(T, device=x.device).unsqueeze(0).expand(B, T)
        h = self.drop(self.emb(x) + self.pos(pos))
        pad_mask = (x == 0)  # token id 0 is padding
        h = self.tr(h, key_padding_mask=pad_mask)
        for blk in self.gru_blocks:
            h = blk(h)
        # Mean-pool over non-padding positions only; the clamp guards against
        # division by zero for all-padding rows.
        mask = (~pad_mask).float().unsqueeze(-1)
        pooled = (h * mask).sum(1) / mask.sum(1).clamp_min(1.0)
        return self.head(pooled)

def load_local(repo_dir="."):
    """Load config, char-to-id vocabulary, and weights from a local checkout."""
    with open(f"{repo_dir}/config.json", "r", encoding="utf-8") as f:
        cfg = json.load(f)
    with open(f"{repo_dir}/stoi.json", "r", encoding="utf-8") as f:
        stoi = json.load(f)
    classes = cfg["classes"]
    model = MeterModel(
        vocab_size=cfg["vocab_size"], num_classes=cfg["num_classes"], T=cfg["max_length"],
        emb_dim=cfg["emb_dim"], num_heads=cfg["num_heads"], ff_dim=cfg["ff_dim"],
        gru_blocks=cfg["gru_blocks"], dropout=cfg["dropout"]
    )
    sd = load_file(f"{repo_dir}/model.safetensors")
    model.load_state_dict(sd)
    model.eval()
    return model, stoi, classes, cfg["max_length"]

def encode(text, stoi, max_len):
    # Right-pad with zeros; characters missing from the vocabulary also map to 0,
    # i.e. they are treated like padding.
    ids = np.zeros((max_len,), dtype=np.int64)
    for i, ch in enumerate(text[:max_len]):
        ids[i] = stoi.get(ch, 0)
    return torch.from_numpy(ids).unsqueeze(0)

@torch.no_grad()
def predict(text, model, stoi, classes, max_len, topk=5):
    x = encode(text, stoi, max_len)
    logits = model(x)[0]
    probs = torch.softmax(logits, dim=-1)
    topv, topi = torch.topk(probs, k=min(topk, probs.numel()))
    pred = classes[int(topi[0])]  # topk sorts descending, so index 0 is the argmax
    top = [(classes[int(i)], float(v)) for v, i in zip(topv, topi)]
    return pred, top
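
# Minimal usage sketch, assuming config.json, stoi.json, and model.safetensors
# sit in the current directory; the input line below is a hypothetical example,
# not taken from this model's data.
if __name__ == "__main__":
    model, stoi, classes, max_len = load_local(".")
    pred, top = predict("Shall I compare thee to a summer's day?", model, stoi, classes, max_len)
    print("predicted meter:", pred)
    for name, p in top:
        print(f"  {name}: {p:.3f}")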