bigeco committed on
Commit
77e0eca
·
verified ·
1 Parent(s): 9c54070

Update model/wav2vec2.py

Browse files
Files changed (1) hide show
  1. model/wav2vec2.py +15 -5
model/wav2vec2.py CHANGED
@@ -1,8 +1,9 @@
1
  import torch
2
- import librosa
3
  import numpy as np
4
  from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
5
  import warnings
 
6
 
7
  warnings.filterwarnings("ignore")
8
 
@@ -25,13 +26,22 @@ class Wav2Vec2:
25
 
26
  self.model.eval()
27
 
28
- def preprocess_audio(self, audio_data: np.ndarray, original_sr: int) -> np.ndarray:
29
  """오디오 데이터 전처리"""
30
  # 샘플링 레이트 변환
31
  if original_sr != self.sampling_rate:
32
- audio_data = librosa.resample(audio_data, orig_sr=original_sr, target_sr=self.sampling_rate)
 
33
 
34
- # 정규화
 
 
 
 
 
 
 
 
35
  if audio_data.dtype != np.float32:
36
  audio_data = audio_data.astype(np.float32)
37
 
@@ -45,7 +55,7 @@ class Wav2Vec2:
45
  """오디오 파일을 텍스트로 변환"""
46
  try:
47
  # 오디오 파일 로드
48
- audio_data, sample_rate = librosa.load(audio_file_path, sr=None)
49
 
50
  # 전처리
51
  audio_data = self.preprocess_audio(audio_data, sample_rate)
 
1
  import torch
2
+ import torchaudio
3
  import numpy as np
4
  from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
5
  import warnings
6
+ import io
7
 
8
  warnings.filterwarnings("ignore")
9
 
 
26
 
27
  self.model.eval()
28
 
29
+ def preprocess_audio(self, audio_data: torch.Tensor, original_sr: int) -> np.ndarray:
30
  """오디오 데이터 전처리"""
31
  # 샘플링 레이트 변환
32
  if original_sr != self.sampling_rate:
33
+ resampler = torchaudio.transforms.Resample(original_sr, self.sampling_rate)
34
+ audio_data = resampler(audio_data)
35
 
36
+ # numpy로 변환
37
+ if isinstance(audio_data, torch.Tensor):
38
+ audio_data = audio_data.numpy()
39
+
40
+ # 스테레오를 모노로 변환 (필요한 경우)
41
+ if len(audio_data.shape) > 1:
42
+ audio_data = np.mean(audio_data, axis=0)
43
+
44
+ # float32로 변환
45
  if audio_data.dtype != np.float32:
46
  audio_data = audio_data.astype(np.float32)
47
 
 
55
  """오디오 파일을 텍스트로 변환"""
56
  try:
57
  # 오디오 파일 로드
58
+ audio_data, sample_rate = torchaudio.load(audio_file_path)
59
 
60
  # 전처리
61
  audio_data = self.preprocess_audio(audio_data, sample_rate)