| """ |
| Whisper Transcriber - транскрибация аудио с использованием Whisper |
| """ |
|
|
| from pathlib import Path |
| from typing import Dict, Optional |
| import torch |
| from transformers import pipeline, AutoProcessor |
|
|
| |
| from common import ( |
| get_logger, |
| TranscriptionException, |
| AudioFileException |
| ) |
|
|
| logger = get_logger(__name__) |
|
|
|
|
| class WhisperTranscriber: |
| """ |
| Транскрибер на базе Whisper для медицинских диктовок. |
| |
| Поддерживает: |
| - Загрузку локальной модели |
| - GPU/CPU выполнение |
| - Различные форматы аудио |
| - Медицинские промпты для улучшения точности |
| """ |
| |
| def __init__( |
| self, |
| model_path: Path, |
| device: str = "auto", |
| dtype: str = "float32", |
| language: str = "russian" |
| ): |
| """ |
| Инициализация транскрибера. |
| |
| Args: |
| model_path: Путь к папке с моделью Whisper |
| device: Устройство ('auto', 'cuda', 'cpu') |
| dtype: Тип данных ('float32', 'float16', 'bfloat16') |
| language: Язык транскрибации |
| """ |
| self.model_path = Path(model_path) |
| self.device = self._resolve_device(device) |
| self.dtype = self._resolve_dtype(dtype) |
| self.language = language |
| |
| logger.info(f"Initializing WhisperTranscriber") |
| logger.info(f"Model: {self.model_path}") |
| logger.info(f"Device: {self.device}") |
| logger.info(f"Dtype: {self.dtype}") |
| |
| |
| self.processor = AutoProcessor.from_pretrained(str(self.model_path)) |
| self.pipe = pipeline( |
| "automatic-speech-recognition", |
| model=str(self.model_path), |
| device=self.device, |
| torch_dtype=self.dtype |
| ) |
| |
| logger.info("WhisperTranscriber initialized successfully") |
| |
| def _resolve_device(self, device: str) -> str: |
| """ |
| Определить устройство для вычислений. |
| |
| Args: |
| device: Желаемое устройство |
| |
| Returns: |
| Реальное устройство |
| """ |
| if device == "auto": |
| if torch.cuda.is_available(): |
| return "cuda" |
| elif torch.backends.mps.is_available(): |
| return "mps" |
| else: |
| return "cpu" |
| return device |
| |
| def _resolve_dtype(self, dtype: str) -> torch.dtype: |
| """ |
| Преобразовать строковое представление dtype в torch.dtype. |
| |
| Args: |
| dtype: Строковое представление |
| |
| Returns: |
| torch.dtype |
| """ |
| dtype_map = { |
| "float32": torch.float32, |
| "float16": torch.float16, |
| "bfloat16": torch.bfloat16 |
| } |
| return dtype_map.get(dtype, torch.float32) |
| |
| def transcribe( |
| self, |
| audio: any, |
| medical_prompt: Optional[str] = None, |
| return_timestamps: bool = False |
| ) -> Dict: |
| """ |
| Транскрибировать аудио. |
| |
| Args: |
| audio: Аудио данные (массив numpy или путь к файлу) |
| medical_prompt: Медицинский промпт для улучшения точности |
| return_timestamps: Возвращать временные метки |
| |
| Returns: |
| Словарь с результатом транскрибации |
| """ |
| try: |
| logger.info("Starting transcription...") |
| |
| |
| generate_kwargs = { |
| "language": self.language, |
| "task": "transcribe" |
| } |
| |
| |
| if medical_prompt: |
| logger.info(f"Using medical prompt (length: {len(medical_prompt)} chars)") |
| generate_kwargs["prompt_ids"] = self.processor.get_prompt_ids( |
| medical_prompt, |
| return_tensors="pt" |
| ).to(self.device) |
| |
| |
| result = self.pipe( |
| audio, |
| generate_kwargs=generate_kwargs, |
| return_timestamps=return_timestamps |
| ) |
| |
| transcription = result["text"].strip() |
| logger.info(f"Transcription completed: {len(transcription)} characters") |
| |
| return { |
| "text": transcription, |
| "language": self.language, |
| "timestamps": result.get("chunks", []) if return_timestamps else [] |
| } |
| |
| except Exception as e: |
| logger.error(f"Transcription failed: {e}") |
| raise |
| |
| def transcribe_file( |
| self, |
| audio_path: Path, |
| medical_prompt: Optional[str] = None, |
| return_timestamps: bool = False |
| ) -> Dict: |
| """ |
| Транскрибировать аудио файл. |
| |
| Args: |
| audio_path: Путь к аудио файлу |
| medical_prompt: Медицинский промпт |
| return_timestamps: Возвращать временные метки |
| |
| Returns: |
| Словарь с результатом |
| """ |
| from .audio_processor import load_audio |
| |
| logger.info(f"Loading audio from {audio_path}") |
| audio = load_audio(audio_path) |
| |
| return self.transcribe( |
| audio, |
| medical_prompt=medical_prompt, |
| return_timestamps=return_timestamps |
| ) |
| |
| def get_model_info(self) -> Dict: |
| """ |
| Получить информацию о модели. |
| |
| Returns: |
| Словарь с информацией о модели |
| """ |
| return { |
| "model_path": str(self.model_path), |
| "device": self.device, |
| "dtype": str(self.dtype), |
| "language": self.language, |
| "cuda_available": torch.cuda.is_available(), |
| "cuda_device_name": torch.cuda.get_device_name(0) if torch.cuda.is_available() else None |
| } |
|
|