"""Interactive speech-endpointing worker.

Loads a fine-tuned WavLM endpointing model once at startup, then loops:
print a ready marker, read an audio file path from stdin, run inference,
and print the raw model output.

NOTE(review): the literal ``1234`` printed each iteration looks like a
readiness handshake consumed by a parent process — confirm before changing.
"""
import sys

from model import WavLMForEndpointing
import torchaudio
import transformers
import numpy as np  # NOTE(review): unused in this file; kept in case other tooling relies on it
import torch
from safetensors import safe_open

MODEL_NAME = 'microsoft/wavlm-base-plus'
# Default checkpoint; may be overridden with the first command-line argument.
DEFAULT_CHECKPOINT = "/home/nikita/wavlm-endpointing-model/checkpoint-29000/model.safetensors"


def load_model(checkpoint_path):
    """Build the endpointing model and load fine-tuned weights.

    Args:
        checkpoint_path: path to a ``model.safetensors`` checkpoint file.

    Returns:
        The model in ``eval()`` mode with the checkpoint's state dict loaded.
    """
    config = transformers.AutoConfig.from_pretrained(MODEL_NAME)
    model = WavLMForEndpointing(config)
    with safe_open(checkpoint_path, framework="pt", device="cpu") as f:
        state_dict = {key: f.get_tensor(key) for key in f.keys()}
    model.load_state_dict(state_dict)
    model.eval()
    return model


def load_audio(audio_path):
    """Load an audio file, resample to 16 kHz, and downmix to mono.

    Args:
        audio_path: filesystem path readable by ``torchaudio.load``.

    Returns:
        A ``(1, num_samples)`` waveform tensor at 16 kHz.
    """
    waveform, sample_rate = torchaudio.load(audio_path)
    if sample_rate != 16000:
        # The feature extractor is called with sampling_rate=16000, so any
        # other input rate must be resampled first.
        resampler = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)
        waveform = resampler(waveform)
    if waveform.shape[0] > 1:
        # Average the channels to mono; keepdim preserves the channel axis.
        waveform = waveform.mean(dim=0, keepdim=True)
    return waveform


def main():
    """Run the stdin-driven inference loop until stdin is closed."""
    checkpoint_path = sys.argv[1] if len(sys.argv) > 1 else DEFAULT_CHECKPOINT
    processor = transformers.AutoFeatureExtractor.from_pretrained(
        MODEL_NAME
    )
    model = load_model(checkpoint_path)
    while True:
        print('1234')
        try:
            audio_path = input().strip()
        except EOFError:
            # stdin closed by the parent process — exit cleanly rather than
            # crashing with a traceback.
            break
        try:
            waveform = load_audio(audio_path)
        except Exception as exc:
            # A bad or missing file must not kill the worker loop; report
            # on stderr and wait for the next path.
            print(f"error: could not load {audio_path!r}: {exc}", file=sys.stderr)
            continue
        inputs = processor(
            waveform.squeeze().numpy(),
            sampling_rate=16000,
            return_tensors="pt",
            padding=False,
            truncation=False
        )
        with torch.no_grad():
            result = model(**inputs)
        print(result)


if __name__ == "__main__":
    main()