File size: 1,307 Bytes
d8c5831
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
import json
from pathlib import Path

import torch
import torch.nn as nn

class SmallAudioClassifierMLP(nn.Module):
    """Feed-forward classifier over fixed-size feature vectors.

    The network is a stack of ``Linear -> LayerNorm -> activation -> Dropout``
    blocks, one per entry of ``hidden_dims``, held in ``self.mlp``. It is
    followed by a final ``Linear`` projection to ``num_labels`` outputs in
    ``self.classifier``.

    Args:
        input_dim: Size of the last dimension of the input tensor.
        num_labels: Number of output scores produced per input vector.
        hidden_dims: Widths of the hidden layers, in order. May be empty, in
            which case ``self.mlp`` is an empty ``nn.Sequential`` (identity)
            and the classifier is applied directly to the input.
        dropout: Dropout probability applied after every hidden activation.
        activation: One of ``"relu"``, ``"gelu"`` or ``"silu"``.

    Raises:
        KeyError: If ``activation`` is not one of the supported names.
    """

    # Supported activation names -> module constructors. This is a plain
    # class attribute (not a submodule), so it never appears in state_dict.
    _ACTIVATIONS = {"relu": nn.ReLU, "gelu": nn.GELU, "silu": nn.SiLU}

    def __init__(self, input_dim=6144, num_labels=10, hidden_dims=(512, 256), dropout=0.2, activation="gelu"):
        super().__init__()
        try:
            act = self._ACTIVATIONS[activation]
        except KeyError:
            # Same exception type as the plain dict lookup, but the message
            # names the accepted values instead of just echoing the bad key.
            raise KeyError(
                f"unsupported activation {activation!r}; "
                f"expected one of {sorted(self._ACTIVATIONS)}"
            ) from None

        dims = [input_dim, *hidden_dims]
        layers = []
        # Each hidden block contributes exactly four modules in this order, so
        # the state-dict keys are mlp.0 (Linear), mlp.1 (LayerNorm), mlp.4
        # (next Linear), and so on. Strict load_state_dict depends on this
        # layout, so do not reorder or insert modules here.
        for d_in, d_out in zip(dims, dims[1:]):
            layers += [
                nn.Linear(d_in, d_out),
                nn.LayerNorm(d_out),
                act(),
                nn.Dropout(dropout),
            ]
        self.mlp = nn.Sequential(*layers)
        self.classifier = nn.Linear(dims[-1], num_labels)

    def forward(self, x):
        """Return raw (un-normalised) per-label scores for ``x``.

        Args:
            x: Tensor whose last dimension is ``input_dim``. Leading
                dimensions are carried through unchanged, because
                ``nn.Linear`` and ``nn.LayerNorm`` act on the last dimension.

        Returns:
            Tensor with the same leading dimensions as ``x`` and a last
            dimension of ``num_labels``. No softmax is applied.
        """
        return self.classifier(self.mlp(x))

def load_pretrained(model_dir: str, map_location="cpu") -> "tuple[SmallAudioClassifierMLP, dict]":
    """Build a ``SmallAudioClassifierMLP`` from a saved directory and load its weights.

    The directory must contain:

    * ``config.json``: a JSON object with the keys ``input_dim``,
      ``num_labels``, ``hidden_dims``, ``dropout`` and ``activation``. All of
      them are required and are passed to the model constructor.
    * ``pytorch_model.bin``: a state dict loadable with ``torch.load``.

    Args:
        model_dir: Path to the model directory. An ``os.PathLike`` also works.
        map_location: Forwarded to ``torch.load``; it controls where the
            loaded tensors are placed. ``load_state_dict`` then copies them
            into the freshly built model, whose parameters stay on the device
            they were created on. This argument therefore does not move the
            returned model.

    Returns:
        A ``(model, cfg)`` tuple: the model switched to eval mode, and the
        parsed config dict.

    Raises:
        FileNotFoundError: If either file is missing.
        KeyError: If a required config key is missing, or ``activation`` is
            not supported by the model.
        RuntimeError: If the checkpoint's keys or tensor shapes do not match
            the model built from the config (``load_state_dict`` is strict).
    """
    root = Path(model_dir)
    cfg = json.loads((root / "config.json").read_text(encoding="utf-8"))
    m = SmallAudioClassifierMLP(
        input_dim=cfg["input_dim"],
        num_labels=cfg["num_labels"],
        hidden_dims=tuple(cfg["hidden_dims"]),
        dropout=cfg["dropout"],
        activation=cfg["activation"],
    )
    # NOTE(security): torch.load unpickles the checkpoint. On PyTorch versions
    # where ``weights_only`` defaults to False (before 2.6), a crafted file can
    # execute arbitrary code. Only load directories you trust, or pass
    # ``weights_only=True`` where the installed PyTorch supports it.
    sd = torch.load(root / "pytorch_model.bin", map_location=map_location)
    m.load_state_dict(sd)
    m.eval()  # inference mode: disables the Dropout layers
    return m, cfg