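"""Run an interactive inference loop for a WavLM-based endpointing model."""
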
import torch
import torchaudio
import transformers
from safetensors import safe_open

from model import WavLMForEndpointing

MODEL_NAME = 'microsoft/wavlm-base-plus'

# The base checkpoint supplies only the feature extractor and model config.
processor = transformers.AutoFeatureExtractor.from_pretrained(MODEL_NAME)

config = transformers.AutoConfig.from_pretrained(MODEL_NAME)
model = WavLMForEndpointing(config)
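# (The model is randomly initialised at this point; the fine-tuned
# weights are loaded from the checkpoint below.)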

checkpoint_path = "/home/nikita/wavlm-endpointing-model/checkpoint-29000/model.safetensors"

# Read every tensor in the safetensors checkpoint into a CPU state dict.
with safe_open(checkpoint_path, framework="pt", device="cpu") as f:
    state_dict = {key: f.get_tensor(key) for key in f.keys()}

model.load_state_dict(state_dict)

model.eval()  # evaluation mode: disables dropout

while True:
    # Marker printed before each prompt; an audio file path is then
    # read from stdin.
    print('1234')
    audio_path = input().strip()
    waveform, sample_rate = torchaudio.load(audio_path)

    # WavLM expects 16 kHz mono input: resample and downmix if needed.
    if sample_rate != 16000:
        resampler = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)
        waveform = resampler(waveform)

    if waveform.shape[0] > 1:
        waveform = waveform.mean(dim=0, keepdim=True)

    # Feature-extract the full utterance; no padding or truncation since
    # a single clip is processed at a time.
    inputs = processor(
        waveform.squeeze().numpy(),
        sampling_rate=16000,
        return_tensors="pt",
        padding=False,
        truncation=False,
    )

    with torch.no_grad():  # no gradients needed for inference
        result = model(**inputs)

    print(result)
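
# Example exchange (audio path illustrative):
#   1234                    <- marker printed before each prompt
#   /path/to/utterance.wav  <- path entered on stdin
#   <model output>          <- result printed above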