import json
from collections.abc import Sequence
from pathlib import Path

import torch
import torch.nn as nn

# Supported activation names mapped to the nn.Module class instantiated per layer.
_ACTIVATIONS = {"relu": nn.ReLU, "gelu": nn.GELU, "silu": nn.SiLU}


class SmallAudioClassifierMLP(nn.Module):
    """Feed-forward classifier: a stack of hidden layers followed by a linear head.

    Each hidden layer is ``Linear -> LayerNorm -> activation -> Dropout``.
    The head maps the last hidden width (or ``input_dim`` when ``hidden_dims``
    is empty) to ``num_labels`` logits.

    Args:
        input_dim: Size of the last dimension of the input features.
        num_labels: Number of output logits.
        hidden_dims: Widths of the hidden layers, in order; may be empty.
        dropout: Dropout probability applied after every hidden activation.
        activation: One of ``"relu"``, ``"gelu"``, ``"silu"``.

    Raises:
        KeyError: If ``activation`` is not a supported name.
    """

    def __init__(
        self,
        input_dim: int = 6144,
        num_labels: int = 10,
        hidden_dims: Sequence[int] = (512, 256),
        dropout: float = 0.2,
        activation: str = "gelu",
    ):
        super().__init__()
        try:
            act = _ACTIVATIONS[activation]
        except KeyError:
            # Keep KeyError (what callers previously saw) but name the valid options.
            raise KeyError(
                f"unsupported activation {activation!r}; "
                f"expected one of {sorted(_ACTIVATIONS)}"
            ) from None

        dims = [input_dim, *hidden_dims]
        layers = []
        # Module order (and therefore state-dict key names like "mlp.0.weight")
        # must stay stable so existing checkpoints keep loading.
        for in_dim, out_dim in zip(dims, dims[1:]):
            layers += [
                nn.Linear(in_dim, out_dim),
                nn.LayerNorm(out_dim),
                act(),
                nn.Dropout(dropout),
            ]
        self.mlp = nn.Sequential(*layers)
        self.classifier = nn.Linear(dims[-1], num_labels)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map features of shape ``(..., input_dim)`` to logits of shape ``(..., num_labels)``."""
        return self.classifier(self.mlp(x))


def load_pretrained(model_dir: str, map_location="cpu"):
    """Build a ``SmallAudioClassifierMLP`` from a model directory and load its weights.

    Reads constructor arguments from ``<model_dir>/config.json`` and the state
    dict from ``<model_dir>/pytorch_model.bin``, then switches the model to
    eval mode.

    Args:
        model_dir: Directory containing ``config.json`` and ``pytorch_model.bin``.
            A ``pathlib.Path`` is accepted as well as ``str``.
        map_location: Forwarded to ``torch.load``; defaults to ``"cpu"``.

    Returns:
        ``(model, cfg)``: the loaded model in eval mode and the parsed config dict.

    Raises:
        FileNotFoundError: If either file is missing.
        KeyError: If the config lacks a required key.
        RuntimeError: If the checkpoint does not match the configured
            architecture (raised by ``load_state_dict``).
    """
    root = Path(model_dir)
    with open(root / "config.json", "r", encoding="utf-8") as f:
        cfg = json.load(f)

    model = SmallAudioClassifierMLP(
        input_dim=cfg["input_dim"],
        num_labels=cfg["num_labels"],
        hidden_dims=tuple(cfg["hidden_dims"]),
        dropout=cfg["dropout"],
        activation=cfg["activation"],
    )
    # weights_only=True restricts unpickling to tensors and plain containers, so a
    # tampered checkpoint cannot run arbitrary code (older torch defaulted to full
    # pickle). A plain state dict loads fine under this restriction.
    state_dict = torch.load(
        root / "pytorch_model.bin", map_location=map_location, weights_only=True
    )
    model.load_state_dict(state_dict)
    model.eval()
    return model, cfg