# audio_detect / test.py
# Uploaded by NikiPshg via huggingface_hub (commit 59171ac, verified)
from model import WavLMForEndpointing
import torchaudio
import transformers
import numpy as np
from safetensors import safe_open
import torch
MODEL_NAME = 'microsoft/wavlm-base-plus'

# Feature extractor and backbone config both come from the same pretrained hub id.
processor = transformers.AutoFeatureExtractor.from_pretrained(MODEL_NAME)
config = transformers.AutoConfig.from_pretrained(MODEL_NAME)

# Build the endpointing model on top of the WavLM backbone configuration.
model = WavLMForEndpointing(config)

# Load fine-tuned weights from a local safetensors checkpoint onto CPU.
# NOTE(review): hard-coded absolute path — only works on the author's machine.
checkpoint_path = "/home/nikita/wavlm-endpointing-model/checkpoint-29000/model.safetensors"
with safe_open(checkpoint_path, framework="pt", device="cpu") as ckpt:
    weights = {name: ckpt.get_tensor(name) for name in ckpt.keys()}
model.load_state_dict(weights)

# Inference mode: disables dropout / switches batchnorm-style layers to eval.
model.eval()
# Interactive serving loop: print a readiness marker, read an audio path from
# stdin, run the endpointing model on it, and print the raw model output.
while True:
    # Readiness marker — presumably the caller waits for this line before
    # sending a path (TODO confirm against the consuming process).
    print('1234')
    # input() already returns str (the original str(...) cast was redundant);
    # strip stray whitespace so pasted paths still resolve.
    audio_path = input().strip()
    try:
        waveform, sample_rate = torchaudio.load(audio_path)
    except Exception as exc:
        # A missing/unreadable file must not kill the infinite serving loop:
        # report the problem and go back to waiting for the next path.
        print(f"Failed to load '{audio_path}': {exc}")
        continue
    # The WavLM feature extractor expects 16 kHz mono audio.
    if sample_rate != 16000:
        resampler = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)
        waveform = resampler(waveform)
    if waveform.shape[0] > 1:
        # Downmix multi-channel audio to mono by averaging channels.
        waveform = waveform.mean(dim=0, keepdim=True)
    inputs = processor(
        waveform.squeeze().numpy(),
        sampling_rate=16000,
        return_tensors="pt",
        padding=False,
        truncation=False,
    )
    # Inference only — no gradient tracking needed.
    with torch.no_grad():
        result = model(**inputs)
    print(result)