File size: 2,674 Bytes
"""
Inference script for HowFar-Caarma: distance estimation from speech using HuBERT.
Usage:
python inference.py --ckpt epoch18_val_acc7997.ckpt --audio path/to/audio.wav
"""
import argparse
import torch
import torch.nn as nn
import torchaudio
import pytorch_lightning as pl
from transformers import HubertModel, AutoFeatureExtractor
class Hubert_Model(nn.Module):
    """HuBERT encoder with a mean-pooled, normalised embedding head.

    The encoder's last hidden states are layer-normalised, passed through
    dropout, averaged over dimension 1 and then batch-normalised.

    Args:
        hubert_model_name: Identifier handed to ``HubertModel.from_pretrained``.
        cache_dir: Forwarded unchanged as ``cache_dir`` to
            ``HubertModel.from_pretrained``.
    """

    def __init__(self, hubert_model_name="facebook/hubert-large-ls960-ft", cache_dir=""):
        super().__init__()
        # NOTE(review): the default empty string is passed through as-is
        # (not converted to None) -- confirm that is the intended cache
        # location behaviour.
        self.encoder = HubertModel.from_pretrained(
            hubert_model_name,
            cache_dir=cache_dir,
        )
        # Attribute names below double as state_dict key prefixes, so they
        # are kept exactly as they were.
        width = self.encoder.config.hidden_size
        self.layer_norm = nn.LayerNorm(width)
        self.dropout = nn.Dropout(0.1)
        self.bn = nn.BatchNorm1d(width)

    def forward(self, input_values, labels=None):
        """Return one pooled embedding per input item.

        Args:
            input_values: Tensor passed directly to the encoder.
            labels: Accepted but not used by this method.

        Returns:
            The batch-normalised mean (over dimension 1) of the encoder's
            ``last_hidden_state``; presumably shaped (batch, hidden_size)
            -- TODO confirm against the encoder's output layout.
        """
        frames = self.encoder(input_values).last_hidden_state
        frames = self.dropout(self.layer_norm(frames))
        return self.bn(frames.mean(dim=1))
class CassavaPLModule(pl.LightningModule):
    """Lightning wrapper that delegates its forward pass to a wrapped model.

    Args:
        hparams: Accepted but not read by this class.
        model: Module stored as ``self.model`` and invoked by ``forward``.
    """

    def __init__(self, hparams, model):
        super().__init__()
        self.model = model

    def forward(self, x):
        """Return the result of calling the wrapped model on ``x``."""
        wrapped = self.model
        return wrapped(x)
def load_model(ckpt_path, device="cuda"):
    """Restore a ``CassavaPLModule`` from a checkpoint, ready for inference.

    A fresh ``Hubert_Model`` is built and handed to
    ``CassavaPLModule.load_from_checkpoint`` together with fixed ``hparams``;
    the checkpoint is loaded with ``strict=False`` and mapped onto ``device``.

    Args:
        ckpt_path: Path of the checkpoint file to load.
        device: Device used both as ``map_location`` and as the target of
            ``.to(...)``.

    Returns:
        The loaded module after ``eval()``, ``to(device)`` and ``freeze()``
        have been called on it.
    """
    backbone = Hubert_Model()
    wrapper = CassavaPLModule.load_from_checkpoint(
        ckpt_path,
        map_location=device,
        strict=False,
        model=backbone,
        hparams={"lr": 0.001, "batch_size": 1},
    )

    # Same call sequence as before: eval mode, move to device, then freeze.
    wrapper.eval()
    wrapper.to(device)
    wrapper.freeze()
    return wrapper
def extract_embedding(model, audio_path, device="cuda"):
    """Compute the model's embedding for a single audio file.

    The audio is loaded with ``torchaudio.load``, resampled to 16000 Hz when
    its sample rate differs, and averaged over dimension 0 when that
    dimension has more than one entry (dimension 0 is treated as the channel
    axis). The result is run through the
    ``facebook/hubert-large-ls960-ft`` feature extractor and then through
    ``model`` with gradients disabled.

    Args:
        model: Callable applied to the prepared ``input_values`` tensor.
        audio_path: Path of the audio file to read.
        device: Device the input tensor is moved to before the forward pass.

    Returns:
        Whatever ``model`` returns for the prepared input.
    """
    target_rate = 16000

    signal, source_rate = torchaudio.load(audio_path)
    if source_rate != target_rate:
        signal = torchaudio.functional.resample(signal, source_rate, target_rate)

    channel_count = signal.shape[0]
    if channel_count > 1:
        # Downmix while keeping the leading axis so squeeze(0) below works.
        signal = signal.mean(dim=0, keepdim=True)

    feature_extractor = AutoFeatureExtractor.from_pretrained("facebook/hubert-large-ls960-ft")
    encoded = feature_extractor(
        signal.squeeze(0).numpy(),
        sampling_rate=target_rate,
        return_tensors="pt",
    )
    model_input = encoded.input_values.to(device)

    with torch.no_grad():
        return model(model_input)
if __name__ == "__main__":
    # Command-line entry point: load the checkpoint given by --ckpt, embed
    # the audio file given by --audio, and print the embedding's shape.
    parser = argparse.ArgumentParser()
    parser.add_argument("--ckpt", required=True, help="Path to the checkpoint file to load.")
    parser.add_argument("--audio", required=True, help="Path to the audio file to embed.")
    args = parser.parse_args()

    # Use the GPU when one is available, otherwise run on the CPU.
    device = "cuda" if torch.cuda.is_available() else "cpu"

    model = load_model(args.ckpt, device=device)
    emb = extract_embedding(model, args.audio, device=device)
    # The stray trailing "|" that followed this call (a dangling binary
    # operator, i.e. a syntax error) has been removed.
    print(f"Embedding shape: {emb.shape}")