File size: 1,307 Bytes
d8c5831 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 |
import json
import torch
import torch.nn as nn
class SmallAudioClassifierMLP(nn.Module):
    """Feed-forward classifier built from repeated hidden blocks and a linear head.

    Each hidden block is ``Linear -> LayerNorm -> activation -> Dropout``.
    The head is a single ``Linear`` layer mapping the last hidden width to
    ``num_labels`` outputs; no softmax or other normalization is applied.

    Args:
        input_dim: Size of the last dimension of the input features.
        num_labels: Number of outputs produced by the final linear layer.
        hidden_dims: Output width of each hidden block, in order. An empty
            sequence produces no hidden blocks, in which case the head is
            applied directly to the input.
        dropout: Dropout probability used after every activation.
        activation: One of ``"relu"``, ``"gelu"`` or ``"silu"``.

    Raises:
        KeyError: If ``activation`` is not one of the supported names.
    """

    def __init__(self, input_dim=6144, num_labels=10, hidden_dims=(512, 256), dropout=0.2, activation="gelu"):
        super().__init__()
        activations = {"relu": nn.ReLU, "gelu": nn.GELU, "silu": nn.SiLU}
        try:
            act = activations[activation]
        except KeyError:
            # Still a KeyError (as before), but now it names the valid options.
            raise KeyError(
                f"unsupported activation {activation!r}; "
                f"expected one of {sorted(activations)}"
            ) from None

        dims = [input_dim, *hidden_dims]
        layers = []
        # Pair each width with the next one: (input, h1), (h1, h2), ...
        for in_features, out_features in zip(dims, dims[1:]):
            layers += [
                nn.Linear(in_features, out_features),
                nn.LayerNorm(out_features),
                act(),
                nn.Dropout(dropout),
            ]
        # Attribute names `mlp` and `classifier` determine the state_dict keys,
        # so they must stay stable for saved checkpoints to load.
        self.mlp = nn.Sequential(*layers)
        self.classifier = nn.Linear(dims[-1], num_labels)

    def forward(self, x):
        """Run ``x`` through the hidden blocks and return the head's output."""
        h = self.mlp(x)
        return self.classifier(h)
def load_pretrained(model_dir: str, map_location="cpu"):
    """Rebuild a ``SmallAudioClassifierMLP`` from a directory and load its weights.

    Reads ``config.json`` and ``pytorch_model.bin`` from ``model_dir``.
    The config must contain the keys ``input_dim``, ``num_labels``,
    ``hidden_dims``, ``dropout`` and ``activation``.

    Args:
        model_dir: Directory holding the config and weight files.
        map_location: Passed through to ``torch.load``.

    Returns:
        A ``(model, cfg)`` tuple: the model switched to eval mode, and the
        parsed config dictionary.
    """
    config_path = f"{model_dir}/config.json"
    weights_path = f"{model_dir}/pytorch_model.bin"

    with open(config_path, "r", encoding="utf-8") as config_file:
        cfg = json.load(config_file)

    # Collected in the same key order the constructor call originally used.
    model_kwargs = {
        "input_dim": cfg["input_dim"],
        "num_labels": cfg["num_labels"],
        "hidden_dims": tuple(cfg["hidden_dims"]),
        "dropout": cfg["dropout"],
        "activation": cfg["activation"],
    }
    model = SmallAudioClassifierMLP(**model_kwargs)

    # NOTE(review): torch.load unpickles the file; only use with trusted
    # checkpoints, or consider passing weights_only=True where supported.
    state_dict = torch.load(weights_path, map_location=map_location)
    model.load_state_dict(state_dict)
    model.eval()
    return model, cfg
|