Trans_for_doctors / stt /whisper_transcriber.py
Mintik24's picture
asd
b216c95
"""
Whisper Transcriber - транскрибация аудио с использованием Whisper
"""
from pathlib import Path
from typing import Dict, Optional
import torch
from transformers import pipeline, AutoProcessor
# Import common utilities
from common import (
get_logger,
TranscriptionException,
AudioFileException
)
logger = get_logger(__name__)
class WhisperTranscriber:
"""
Транскрибер на базе Whisper для медицинских диктовок.
Поддерживает:
- Загрузку локальной модели
- GPU/CPU выполнение
- Различные форматы аудио
- Медицинские промпты для улучшения точности
"""
def __init__(
self,
model_path: Path,
device: str = "auto",
dtype: str = "float32",
language: str = "russian"
):
"""
Инициализация транскрибера.
Args:
model_path: Путь к папке с моделью Whisper
device: Устройство ('auto', 'cuda', 'cpu')
dtype: Тип данных ('float32', 'float16', 'bfloat16')
language: Язык транскрибации
"""
self.model_path = Path(model_path)
self.device = self._resolve_device(device)
self.dtype = self._resolve_dtype(dtype)
self.language = language
logger.info(f"Initializing WhisperTranscriber")
logger.info(f"Model: {self.model_path}")
logger.info(f"Device: {self.device}")
logger.info(f"Dtype: {self.dtype}")
# Загрузка модели
self.processor = AutoProcessor.from_pretrained(str(self.model_path))
self.pipe = pipeline(
"automatic-speech-recognition",
model=str(self.model_path),
device=self.device,
torch_dtype=self.dtype
)
logger.info("WhisperTranscriber initialized successfully")
def _resolve_device(self, device: str) -> str:
"""
Определить устройство для вычислений.
Args:
device: Желаемое устройство
Returns:
Реальное устройство
"""
if device == "auto":
if torch.cuda.is_available():
return "cuda"
elif torch.backends.mps.is_available():
return "mps"
else:
return "cpu"
return device
def _resolve_dtype(self, dtype: str) -> torch.dtype:
"""
Преобразовать строковое представление dtype в torch.dtype.
Args:
dtype: Строковое представление
Returns:
torch.dtype
"""
dtype_map = {
"float32": torch.float32,
"float16": torch.float16,
"bfloat16": torch.bfloat16
}
return dtype_map.get(dtype, torch.float32)
def transcribe(
self,
audio: any,
medical_prompt: Optional[str] = None,
return_timestamps: bool = False
) -> Dict:
"""
Транскрибировать аудио.
Args:
audio: Аудио данные (массив numpy или путь к файлу)
medical_prompt: Медицинский промпт для улучшения точности
return_timestamps: Возвращать временные метки
Returns:
Словарь с результатом транскрибации
"""
try:
logger.info("Starting transcription...")
# Подготовка параметров
generate_kwargs = {
"language": self.language,
"task": "transcribe"
}
# Добавляем медицинский промпт если есть
if medical_prompt:
logger.info(f"Using medical prompt (length: {len(medical_prompt)} chars)")
generate_kwargs["prompt_ids"] = self.processor.get_prompt_ids(
medical_prompt,
return_tensors="pt"
).to(self.device)
# Транскрибация
result = self.pipe(
audio,
generate_kwargs=generate_kwargs,
return_timestamps=return_timestamps
)
transcription = result["text"].strip()
logger.info(f"Transcription completed: {len(transcription)} characters")
return {
"text": transcription,
"language": self.language,
"timestamps": result.get("chunks", []) if return_timestamps else []
}
except Exception as e:
logger.error(f"Transcription failed: {e}")
raise
def transcribe_file(
self,
audio_path: Path,
medical_prompt: Optional[str] = None,
return_timestamps: bool = False
) -> Dict:
"""
Транскрибировать аудио файл.
Args:
audio_path: Путь к аудио файлу
medical_prompt: Медицинский промпт
return_timestamps: Возвращать временные метки
Returns:
Словарь с результатом
"""
from .audio_processor import load_audio
logger.info(f"Loading audio from {audio_path}")
audio = load_audio(audio_path)
return self.transcribe(
audio,
medical_prompt=medical_prompt,
return_timestamps=return_timestamps
)
def get_model_info(self) -> Dict:
"""
Получить информацию о модели.
Returns:
Словарь с информацией о модели
"""
return {
"model_path": str(self.model_path),
"device": self.device,
"dtype": str(self.dtype),
"language": self.language,
"cuda_available": torch.cuda.is_available(),
"cuda_device_name": torch.cuda.get_device_name(0) if torch.cuda.is_available() else None
}