# wav2vec2-server / model/wav2vec2.py
# Source: Hugging Face repo file (commit 77e0eca, verified, by bigeco —
# "Update model/wav2vec2.py", 3.79 kB). Page chrome ("raw", "history blame")
# removed so the module parses as Python.
import torch
import torchaudio
import numpy as np
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
import warnings
import io
# Module-level side effect: globally silence all warnings (e.g. transformers /
# torchaudio deprecation noise). NOTE(review): this hides warnings for every
# importer of this module — consider narrowing the filter.
warnings.filterwarnings("ignore")
class Wav2Vec2:
    """Speech-to-text wrapper around a HuggingFace Wav2Vec2 CTC model.

    Loads a pretrained processor/model pair, preprocesses raw audio into the
    mono float32 format the model expects, and exposes file- and bytes-based
    transcription entry points.
    """

    def __init__(self, config: dict):
        """Load the processor and model and place the model on a device.

        Args:
            config: dict with a "model" section providing "id" (HF model id),
                "device" ("cuda" or "cpu") and "sampling_rate" (e.g. 16000).
        """
        self.config = config
        self.model_id = config["model"]["id"]
        self.device = config["model"]["device"]
        self.sampling_rate = config["model"]["sampling_rate"]
        # Load model and processor from the HuggingFace hub (or local cache).
        self.processor = Wav2Vec2Processor.from_pretrained(self.model_id)
        self.model = Wav2Vec2ForCTC.from_pretrained(self.model_id)
        # Honor the requested device, but fall back to CPU when CUDA was
        # asked for and is not actually available.
        if self.device == "cuda" and torch.cuda.is_available():
            self.model = self.model.to("cuda")
        else:
            self.model = self.model.to("cpu")
        self.model.eval()

    def preprocess_audio(self, audio_data: torch.Tensor, original_sr: int) -> np.ndarray:
        """Resample, downmix, cast, and peak-normalize a waveform.

        Args:
            audio_data: waveform tensor shaped (channels, samples) or (samples,),
                as returned by torchaudio.load.
            original_sr: sampling rate of ``audio_data``.

        Returns:
            Mono float32 numpy array at ``self.sampling_rate``, peak-normalized
            into [-1, 1]. Silent or empty input is returned unnormalized.
        """
        # Resample to the model's expected rate when needed.
        if original_sr != self.sampling_rate:
            resampler = torchaudio.transforms.Resample(original_sr, self.sampling_rate)
            audio_data = resampler(audio_data)
        # Convert to numpy for the HF processor.
        if isinstance(audio_data, torch.Tensor):
            audio_data = audio_data.numpy()
        # Downmix multi-channel audio to mono by averaging channels.
        if len(audio_data.shape) > 1:
            audio_data = np.mean(audio_data, axis=0)
        # Ensure float32, the dtype Wav2Vec2 expects.
        if audio_data.dtype != np.float32:
            audio_data = audio_data.astype(np.float32)
        # Peak-normalize; hoist the peak computation (was computed twice) and
        # guard against empty input, where np.max would raise.
        peak = np.max(np.abs(audio_data)) if audio_data.size else 0.0
        if peak > 0:
            audio_data = audio_data / peak
        return audio_data

    def transcribe(self, audio_file_path: str) -> str:
        """Transcribe an audio file to text via greedy CTC decoding.

        Args:
            audio_file_path: path to an audio file readable by torchaudio.

        Returns:
            The decoded transcription, whitespace-stripped.

        Raises:
            Exception: wrapping any failure during load, inference, or decoding
                (original exception chained as __cause__).
        """
        try:
            # Load and preprocess the audio file.
            audio_data, sample_rate = torchaudio.load(audio_file_path)
            audio_data = self.preprocess_audio(audio_data, sample_rate)
            # Prepare model inputs.
            inputs = self.processor(
                audio_data,
                sampling_rate=self.sampling_rate,
                return_tensors="pt",
                padding=True,
            )
            # Move tensors to wherever the model actually lives.
            inputs = {k: v.to(self.model.device) for k, v in inputs.items()}
            # Inference without gradient tracking.
            with torch.no_grad():
                logits = self.model(**inputs).logits
            # Greedy CTC decoding: argmax over the vocabulary at each frame.
            predicted_ids = torch.argmax(logits, dim=-1)
            transcription = self.processor.batch_decode(predicted_ids)[0]
            return transcription.strip()
        except Exception as e:
            # Chain the original exception so the root-cause traceback survives.
            raise Exception(f"Audio transcription failed: {str(e)}") from e

    def transcribe_from_bytes(self, audio_bytes: bytes, filename: str = "temp.wav") -> str:
        """Transcribe raw audio bytes by spooling them to a temporary file.

        Args:
            audio_bytes: encoded audio file content.
            filename: original filename; its extension is used as the temp-file
                suffix so the audio backend can pick the right decoder
                (falls back to ".wav" when there is no extension).

        Returns:
            The decoded transcription.

        Raises:
            Exception: wrapping any failure (original exception chained).
        """
        import tempfile
        import os
        temp_file_path = None
        try:
            # Previously the filename argument was ignored; use its suffix.
            suffix = os.path.splitext(filename)[1] or ".wav"
            with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as temp_file:
                temp_file.write(audio_bytes)
                temp_file_path = temp_file.name
            return self.transcribe(temp_file_path)
        except Exception as e:
            raise Exception(f"Audio transcription from bytes failed: {str(e)}") from e
        finally:
            # Always remove the temp file — the original leaked it whenever
            # transcribe() raised.
            if temp_file_path and os.path.exists(temp_file_path):
                os.unlink(temp_file_path)