"""Inference script for HowFar-Caarma: distance estimation from speech using HuBERT.

Usage:
    python inference.py --ckpt epoch18_val_acc7997.ckpt --audio path/to/audio.wav
"""
import argparse

import torch
import torch.nn as nn
import torchaudio
import pytorch_lightning as pl
from transformers import HubertModel, AutoFeatureExtractor

# Sample rate the HuBERT feature extractor expects; all audio is resampled to this.
SAMPLE_RATE = 16000
# Single source of truth for the pretrained model id (was duplicated in two places).
HUBERT_MODEL_NAME = "facebook/hubert-large-ls960-ft"


class Hubert_Model(nn.Module):
    """HuBERT encoder with a layer-norm / dropout / batch-norm pooling head.

    Produces one fixed-size embedding per utterance by mean-pooling the
    encoder's frame-level hidden states over time.
    """

    def __init__(self, hubert_model_name=HUBERT_MODEL_NAME, cache_dir=""):
        super().__init__()
        self.encoder = HubertModel.from_pretrained(hubert_model_name, cache_dir=cache_dir)
        hidden_size = self.encoder.config.hidden_size
        self.layer_norm = nn.LayerNorm(hidden_size)
        self.dropout = nn.Dropout(0.1)
        # BatchNorm over the pooled embedding; in eval mode it uses running stats,
        # so batch size 1 at inference is safe.
        self.bn = nn.BatchNorm1d(hidden_size)

    def forward(self, input_values, labels=None):
        """Return a (batch, hidden_size) utterance embedding.

        ``labels`` is accepted for checkpoint/interface compatibility but is
        unused at inference time.
        """
        outputs = self.encoder(input_values)
        hidden_states = outputs.last_hidden_state  # (batch, frames, hidden)
        hidden_states = self.layer_norm(hidden_states)
        hidden_states = self.dropout(hidden_states)
        pooled = torch.mean(hidden_states, dim=1)  # temporal mean pooling
        pooled = self.bn(pooled)
        return pooled


class CassavaPLModule(pl.LightningModule):
    """Thin LightningModule wrapper so the training checkpoint can be restored.

    ``hparams`` is accepted to match the training-time constructor signature;
    it is not used for inference.
    """

    def __init__(self, hparams, model):
        super().__init__()
        self.model = model

    def forward(self, x):
        return self.model(x)


# Lazily-created, process-wide feature extractor. The original re-instantiated
# the extractor (including a potential network fetch) on EVERY call to
# extract_embedding; caching it makes repeated calls cheap.
_PROCESSOR = None


def _get_processor():
    """Return the shared HuBERT feature extractor, creating it on first use."""
    global _PROCESSOR
    if _PROCESSOR is None:
        _PROCESSOR = AutoFeatureExtractor.from_pretrained(HUBERT_MODEL_NAME)
    return _PROCESSOR


def load_model(ckpt_path, device="cuda"):
    """Load the Lightning checkpoint and prepare the model for inference.

    Args:
        ckpt_path: Path to the ``.ckpt`` file written during training.
        device: Torch device string the weights are mapped to and moved onto.

    Returns:
        A frozen, eval-mode ``CassavaPLModule`` ready for forward passes.
    """
    base = Hubert_Model()
    # strict=False tolerates training-only keys (e.g. optimizer heads) that are
    # absent from this inference-time module.
    model = CassavaPLModule.load_from_checkpoint(
        ckpt_path,
        hparams={"lr": 0.001, "batch_size": 1},
        model=base,
        strict=False,
        map_location=device,
    )
    model.eval()
    model.to(device)
    model.freeze()
    return model


def extract_embedding(model, audio_path, device="cuda"):
    """Load an audio file, preprocess it, and return the model's embedding.

    The waveform is resampled to 16 kHz and downmixed to mono if needed
    before feature extraction; the forward pass runs under ``no_grad``.
    """
    waveform, sr = torchaudio.load(audio_path)
    if sr != SAMPLE_RATE:
        waveform = torchaudio.functional.resample(waveform, sr, SAMPLE_RATE)
    if waveform.shape[0] > 1:
        # Downmix stereo/multichannel to mono by averaging channels.
        waveform = waveform.mean(dim=0, keepdim=True)
    processor = _get_processor()
    inputs = processor(
        waveform.squeeze(0).numpy(),
        sampling_rate=SAMPLE_RATE,
        return_tensors="pt",
    ).input_values.to(device)
    with torch.no_grad():
        embedding = model(inputs)
    return embedding


def main():
    """CLI entry point: load checkpoint, embed one audio file, print the shape."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--ckpt", required=True)
    parser.add_argument("--audio", required=True)
    args = parser.parse_args()
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = load_model(args.ckpt, device=device)
    emb = extract_embedding(model, args.audio, device=device)
    print(f"Embedding shape: {emb.shape}")


if __name__ == "__main__":
    main()