|
|
import argparse |
|
|
from typing import List, Union |
|
|
|
|
|
import hydra |
|
|
import soundfile |
|
|
import torch |
|
|
from omegaconf import DictConfig, ListConfig, OmegaConf |
|
|
|
|
|
|
|
|
class SpecScaler(torch.nn.Module):
    """Log-compress a tensor after clamping its values to ``[1e-9, 1e9]``."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the element-wise natural log of the clamped input.

        Args:
            x: Input tensor. It is left unmodified.

        Returns:
            A new tensor holding ``log(clamp(x, 1e-9, 1e9))``.
        """
        # Out-of-place clamp: the in-place variant would overwrite the
        # caller's tensor as a hidden side effect of calling forward().
        return torch.log(x.clamp(1e-9, 1e9))
|
|
|
|
|
|
|
|
class GigaAMEmo(torch.nn.Module):
    """Classifier made of a feature extractor, an encoder and a linear head.

    The three sub-modules are created with ``hydra.utils.instantiate`` from
    the matching sections of the config, so their concrete types are chosen
    by the config rather than by this class.
    """

    def __init__(self, conf: Union[DictConfig, ListConfig]):
        """Build the sub-modules described by ``conf``.

        Args:
            conf: Config exposing ``id2name``, ``feature_extractor``,
                ``encoder`` and ``classification_head`` entries.
        """
        super().__init__()
        self.id2name = conf.id2name
        self.feature_extractor = hydra.utils.instantiate(conf.feature_extractor)
        self.conformer = hydra.utils.instantiate(conf.encoder)
        self.linear_head = hydra.utils.instantiate(conf.classification_head)

    @property
    def device(self):
        """Device of the first parameter registered on this module."""
        return next(self.parameters()).device

    def forward(self, features, features_length=None):
        """Compute classification logits for a batch of features.

        Args:
            features: Feature tensor. A 2-D input is treated as a single
                item and gets a leading batch dimension.
            features_length: Optional lengths forwarded to the encoder. When
                ``None``, every item is given the full size of the last
                dimension of ``features``.

        Returns:
            The output of the classification head applied to the encoder
            output averaged over its last dimension.
        """
        if features.dim() == 2:
            features = features.unsqueeze(0)

        # Compare with None explicitly: truth-testing a tensor that holds
        # more than one element raises RuntimeError, and a one-element tensor
        # containing 0 would be mistaken for "not provided".
        if features_length is None:
            features_length = (
                torch.ones(features.shape[0], device=self.device)
                * features.shape[-1]
            )

        # The encoder returns two values; the second one is not used here.
        encoded, _ = self.conformer(audio_signal=features, length=features_length)

        # A pooling window as wide as the last axis averages over all of it;
        # squeeze(-1) then drops the resulting size-1 axis.
        encoded_pooled = torch.nn.functional.avg_pool1d(
            encoded, kernel_size=encoded.shape[-1]
        ).squeeze(-1)

        logits = self.linear_head(encoded_pooled)
        return logits

    def get_probs(self, audio_path: str) -> List[List[float]]:
        """Return class probabilities for the audio file at ``audio_path``.

        The file is read with ``soundfile`` as float32, moved to the model's
        device, converted to features and passed through ``forward``. Softmax
        is taken over dimension 1 of the logits.

        Args:
            audio_path: Path to the audio file to classify.

        Returns:
            Nested list of probabilities, one inner list per batch item.
        """
        # NOTE(review): the sample rate returned by soundfile is discarded
        # and no resampling or channel handling happens here; presumably the
        # file must already match what the feature extractor expects.
        audio_signal, _ = soundfile.read(audio_path, dtype="float32")
        audio_tensor = torch.tensor(audio_signal).float().to(self.device)
        features = self.feature_extractor(audio_tensor)
        logits = self.forward(features)
        probs = torch.nn.functional.softmax(logits, dim=1).detach().tolist()
        return probs
|
|
|
|
|
|
|
|
def _parse_args():
    """Parse the command-line options for GigaAM-Emo inference.

    The three path options are mandatory, so a missing one is reported by
    argparse up front instead of surfacing later as an obscure error when
    ``None`` reaches a loader. ``--device`` falls back to ``"cpu"``.

    Returns:
        An ``argparse.Namespace`` with ``model_config``, ``model_weights``,
        ``audio_path`` and ``device`` attributes.
    """
    parser = argparse.ArgumentParser(
        description="Run inference using GigaAM-Emo checkpoint"
    )
    parser.add_argument(
        "--model_config",
        required=True,
        help="Path to GigaAM-Emo config file (.yaml)",
    )
    parser.add_argument(
        "--model_weights",
        required=True,
        help="Path to GigaAM-Emo checkpoint file (.ckpt)",
    )
    parser.add_argument("--audio_path", required=True, help="Path to audio signal")
    parser.add_argument("--device", default="cpu", help="Device: cpu / cuda")
    return parser.parse_args()
|
|
|
|
|
|
|
|
def main(model_config: str, model_weights: str, device: str, audio_path: str):
    """Classify one audio file and print a probability for every class.

    Args:
        model_config: Path to the config file loaded with ``OmegaConf.load``.
        model_weights: Path to the checkpoint passed to ``torch.load``.
        device: Device the model is moved to before inference.
        audio_path: Path to the audio file to classify.
    """
    emo_model = GigaAMEmo(OmegaConf.load(model_config))

    # NOTE(review): torch.load unpickles the file, so only checkpoints from
    # a trusted source should be passed here.
    state = torch.load(model_weights, map_location="cpu")
    # strict=False: key mismatches between checkpoint and model do not raise.
    emo_model.load_state_dict(state, strict=False)

    emo_model = emo_model.to(device)
    emo_model.eval()

    with torch.no_grad():
        batch_probs = emo_model.get_probs(audio_path)
    # Only the first item of the returned batch is reported.
    first_item = batch_probs[0]

    report = []
    for idx, prob in enumerate(first_item):
        report.append(f"{emo_model.id2name[idx]}: {prob:.3f}")
    print(", ".join(report))
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point: read the CLI options and hand them to main() in
    # the order of its signature (config, weights, device, audio).
    cli = _parse_args()
    main(cli.model_config, cli.model_weights, cli.device, cli.audio_path)
|
|
|