# NOTE: scraped HuggingFace Space page metadata (status: "Sleeping") — not part of the module code.
import torch
import torchaudio
import numpy as np
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
import warnings
import io
# Silence all warnings globally (e.g. transformers/torchaudio deprecation noise).
# NOTE(review): blanket suppression hides real issues — consider narrowing by category.
warnings.filterwarnings("ignore")
class Wav2Vec2:
    """Speech-to-text wrapper around a HuggingFace Wav2Vec2 CTC model.

    Expects ``config`` shaped like::

        {"model": {"id": "<hub id>", "device": "cuda"|"cpu", "sampling_rate": 16000}}
    """

    def __init__(self, config: dict):
        self.config = config
        self.model_id = config["model"]["id"]
        self.device = config["model"]["device"]
        self.sampling_rate = config["model"]["sampling_rate"]
        # Load the processor (feature extractor + tokenizer) and the CTC model.
        self.processor = Wav2Vec2Processor.from_pretrained(self.model_id)
        self.model = Wav2Vec2ForCTC.from_pretrained(self.model_id)
        # Honor the requested device, falling back to CPU when CUDA is unavailable.
        if self.device == "cuda" and torch.cuda.is_available():
            self.model = self.model.to("cuda")
        else:
            self.model = self.model.to("cpu")
        self.model.eval()

    def preprocess_audio(self, audio_data: torch.Tensor, original_sr: int) -> np.ndarray:
        """Resample, downmix, and peak-normalize a waveform for the model.

        Args:
            audio_data: waveform as a (channels, samples) tensor (as produced
                by ``torchaudio.load``) or an ndarray.
            original_sr: sampling rate of ``audio_data``.

        Returns:
            Mono float32 ndarray at ``self.sampling_rate``, peak-normalized to
            [-1, 1]; an all-zero signal is returned unscaled.
        """
        # Resample only when the source rate differs from the model's rate.
        if original_sr != self.sampling_rate:
            resampler = torchaudio.transforms.Resample(original_sr, self.sampling_rate)
            audio_data = resampler(audio_data)
        # The processor works on numpy arrays.
        if isinstance(audio_data, torch.Tensor):
            audio_data = audio_data.numpy()
        # Downmix stereo/multichannel to mono by averaging channels.
        if audio_data.ndim > 1:
            audio_data = np.mean(audio_data, axis=0)
        # The model expects float32 input.
        if audio_data.dtype != np.float32:
            audio_data = audio_data.astype(np.float32)
        # Peak-normalize; compute the peak once (was computed twice) and
        # skip silent input to avoid division by zero.
        peak = np.max(np.abs(audio_data))
        if peak > 0:
            audio_data = audio_data / peak
        return audio_data

    def transcribe(self, audio_file_path: str) -> str:
        """Transcribe an audio file to text via greedy CTC decoding.

        Raises:
            Exception: wraps any underlying failure; the original error is
                chained via ``from`` so the root cause is preserved.
        """
        try:
            # Load and preprocess the waveform.
            audio_data, sample_rate = torchaudio.load(audio_file_path)
            audio_data = self.preprocess_audio(audio_data, sample_rate)
            # Build model inputs.
            inputs = self.processor(
                audio_data,
                sampling_rate=self.sampling_rate,
                return_tensors="pt",
                padding=True,
            )
            # Move tensors to the same device as the model.
            inputs = {k: v.to(self.model.device) for k, v in inputs.items()}
            # Inference.
            with torch.no_grad():
                logits = self.model(**inputs).logits
            # Greedy decoding: argmax over the vocabulary, then CTC collapse.
            predicted_ids = torch.argmax(logits, dim=-1)
            transcription = self.processor.batch_decode(predicted_ids)[0]
            return transcription.strip()
        except Exception as e:
            # Chain the cause so the traceback is not lost (was `raise Exception(...)`
            # without `from`, which discarded it).
            raise Exception(f"Audio transcription failed: {str(e)}") from e

    def transcribe_from_bytes(self, audio_bytes: bytes, filename: str = "temp.wav") -> str:
        """Transcribe raw audio bytes by spooling them to a temporary file.

        ``filename`` is used only for its extension so the audio backend picks
        the right decoder (previously the parameter was ignored and ".wav"
        was always assumed; the default keeps the old behavior).
        """
        import tempfile
        import os
        # Preserve the caller-supplied extension; fall back to ".wav".
        suffix = os.path.splitext(filename)[1] or ".wav"
        try:
            with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as temp_file:
                temp_file.write(audio_bytes)
                temp_file_path = temp_file.name
            try:
                return self.transcribe(temp_file_path)
            finally:
                # Always remove the temp file — the original leaked it when
                # transcription raised, since unlink was only on the success path.
                os.unlink(temp_file_path)
        except Exception as e:
            raise Exception(f"Audio transcription from bytes failed: {str(e)}") from e