| """ |
| Inference script for HowFar-Caarma: distance estimation from speech using HuBERT. |
| |
| Usage: |
| python inference.py --ckpt epoch18_val_acc7997.ckpt --audio path/to/audio.wav |
| """ |
import argparse
from functools import lru_cache

import torch
import torch.nn as nn
import torchaudio
import pytorch_lightning as pl
from transformers import HubertModel, AutoFeatureExtractor
|
|
|
|
class Hubert_Model(nn.Module):
    """HuBERT encoder head: layer-norm + dropout, temporal mean pooling, batch-norm.

    Produces a single fixed-size embedding per utterance from raw 16 kHz audio.
    """

    def __init__(self, hubert_model_name="facebook/hubert-large-ls960-ft", cache_dir=""):
        super().__init__()
        # Pretrained HuBERT backbone; the head dimensions follow its config.
        self.encoder = HubertModel.from_pretrained(hubert_model_name, cache_dir=cache_dir)
        dim = self.encoder.config.hidden_size
        self.layer_norm = nn.LayerNorm(dim)
        self.dropout = nn.Dropout(0.1)
        self.bn = nn.BatchNorm1d(dim)

    def forward(self, input_values, labels=None):
        """Encode `input_values` and return a (batch, hidden_size) pooled embedding.

        `labels` is accepted for caller compatibility but is unused here.
        """
        frames = self.encoder(input_values).last_hidden_state
        frames = self.dropout(self.layer_norm(frames))
        # Mean-pool over the time (frame) axis, then normalize across the batch.
        return self.bn(frames.mean(dim=1))
|
|
|
|
class CassavaPLModule(pl.LightningModule):
    """Thin LightningModule shell around an arbitrary torch model.

    Exists so that Lightning checkpoints can be loaded; `hparams` is accepted
    for checkpoint compatibility but not otherwise used here.
    """

    def __init__(self, hparams, model):
        super().__init__()
        self.model = model

    def forward(self, x):
        """Delegate the forward pass to the wrapped model."""
        return self.model(x)
|
|
|
|
def load_model(ckpt_path, device="cuda", hubert_model_name="facebook/hubert-large-ls960-ft"):
    """Load a trained checkpoint and prepare it for inference.

    Args:
        ckpt_path: Path to a PyTorch Lightning ``.ckpt`` file.
        device: Target device string (``"cuda"`` or ``"cpu"``).
        hubert_model_name: HF hub id of the HuBERT backbone to instantiate;
            checkpoint weights are loaded on top of it. Defaults to the
            backbone used at training time.

    Returns:
        A frozen ``CassavaPLModule`` in eval mode on ``device``.
    """
    base = Hubert_Model(hubert_model_name=hubert_model_name)
    model = CassavaPLModule.load_from_checkpoint(
        ckpt_path,
        # Placeholder hparams; values are irrelevant at inference time.
        hparams={"lr": 0.001, "batch_size": 1},
        model=base,
        strict=False,  # tolerate checkpoint keys absent from this wrapper
        map_location=device,
    )
    model.to(device)
    # freeze() disables grads AND switches the module to eval mode,
    # so a separate model.eval() call is redundant.
    model.freeze()
    return model
|
|
|
|
@lru_cache(maxsize=4)
def _get_feature_extractor(model_name):
    """Load and memoize the HF feature extractor.

    The original code re-instantiated the extractor (a hub lookup plus config
    parse) on every call; caching makes repeated extractions cheap.
    """
    return AutoFeatureExtractor.from_pretrained(model_name)


def extract_embedding(model, audio_path, device="cuda",
                      extractor_name="facebook/hubert-large-ls960-ft"):
    """Compute a pooled HuBERT embedding for a single audio file.

    Args:
        model: Frozen module returned by ``load_model()``.
        audio_path: Path to any audio file torchaudio can decode.
        device: Device the model lives on; inputs are moved there.
        extractor_name: HF hub id of the feature extractor matching the
            backbone the model was built from.

    Returns:
        Embedding tensor of shape (1, hidden_size) on ``device``.
    """
    waveform, sr = torchaudio.load(audio_path)
    # HuBERT expects 16 kHz mono input: resample and downmix if needed.
    if sr != 16000:
        waveform = torchaudio.functional.resample(waveform, sr, 16000)
    if waveform.shape[0] > 1:
        waveform = waveform.mean(dim=0, keepdim=True)

    processor = _get_feature_extractor(extractor_name)
    inputs = processor(
        waveform.squeeze(0).numpy(),
        sampling_rate=16000,
        return_tensors="pt",
    ).input_values.to(device)

    with torch.no_grad():
        embedding = model(inputs)
    return embedding
|
|
|
|
def _parse_args():
    """Parse CLI arguments for the inference script."""
    parser = argparse.ArgumentParser(
        description="Distance estimation from speech: embed one audio file with HuBERT."
    )
    parser.add_argument("--ckpt", required=True, help="Path to Lightning checkpoint (.ckpt)")
    parser.add_argument("--audio", required=True, help="Path to the input audio file")
    return parser.parse_args()


def main():
    """Entry point: load the checkpoint, embed the audio, print the shape."""
    args = _parse_args()
    # Fall back to CPU so the script also runs on GPU-less machines.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = load_model(args.ckpt, device=device)
    emb = extract_embedding(model, args.audio, device=device)
    print(f"Embedding shape: {emb.shape}")


if __name__ == "__main__":
    main()