| import logging |
| import os |
| from pydantic import BaseModel,Field |
| from pydantic_settings import BaseSettings |
| from typing import Optional, Literal |
|
|
| logger = logging.getLogger(__name__) |
|
|
|
|
class ModelSettings(BaseSettings):
    """Hugging Face model identifiers and the access token used by the service.

    Each field is bound to an upper-case, environment-style name through its
    ``alias`` (e.g. ``ASR_MODEL``). pydantic-settings uses that alias as the
    environment variable name when a value is not passed explicitly.
    """

    # Hub id of the speech-recognition checkpoint.
    asr_model: str = Field(alias='ASR_MODEL')
    # Hub id of the assistant checkpoint; presumably used for assisted
    # generation when ``InferenceConfig.assisted`` is true — confirm with callers.
    assistant_model: str = Field(alias='ASSISTANT_MODEL')
    # Hub id of the speaker-diarization pipeline.
    diarization_model: str = Field(alias='DIARIZATION_MODEL')
    # Hugging Face access token. Optional because the module-level settings
    # are built from ``os.environ.get("HF_TOKEN")``, which yields ``None`` when
    # the variable is unset; a required ``str`` made importing this module
    # fail with a validation error in that case. ``repr=False`` keeps the
    # secret out of the model's repr so it cannot leak through logging.
    hf_token: Optional[str] = Field(default=None, alias='HF_TOKEN', repr=False)
|
|
|
|
class InferenceConfig(BaseModel):
    """Options controlling a single transcription / diarization request.

    Every field has a default, so ``InferenceConfig()`` is a complete
    configuration on its own.
    """

    # Either "transcribe" or "translate"; presumably forwarded to the ASR
    # model's generation settings — confirm with the caller.
    task: Literal["transcribe", "translate"] = Field(default="transcribe")
    # Number of items processed per batch (default 24).
    batch_size: int = Field(default=24)
    # Whether to use the assistant model (see ModelSettings.assistant_model).
    assisted: bool = Field(default=False)
    # Chunk length in seconds (the ``_s`` suffix suggests seconds — confirm).
    chunk_length_s: int = Field(default=30)
    # Audio sampling rate, presumably in Hz.
    sampling_rate: int = Field(default=16_000)
    # Language hint; None leaves the choice to downstream code.
    language: Optional[str] = Field(default=None)
    # Speaker-count hints for diarization; None means "not specified".
    num_speakers: Optional[int] = Field(default=None)
    min_speakers: Optional[int] = Field(default=None)
    max_speakers: Optional[int] = Field(default=None)
|
|
| |
| |
# Raw values for ModelSettings, keyed by the field aliases. Each model id can
# be overridden through the environment variable of the same name; previously
# these hard-coded values always won, because explicit init kwargs take
# priority over the environment in BaseSettings. Defaults are unchanged.
model_settings_data = {
    "DIARIZATION_MODEL": os.environ.get(
        "DIARIZATION_MODEL", "pyannote/speaker-diarization-3.1"
    ),
    "HF_TOKEN": os.environ.get("HF_TOKEN"),
    "ASR_MODEL": os.environ.get("ASR_MODEL", "openai/whisper-small"),
    "ASSISTANT_MODEL": os.environ.get(
        "ASSISTANT_MODEL", "distil-whisper/distil-large-v3"
    ),
}

# Make a missing credential visible in the logs before the settings object is
# built, instead of leaving it to be discovered by whatever fails later.
if model_settings_data["HF_TOKEN"] is None:
    logger.warning("HF_TOKEN environment variable is not set")

model_settings = ModelSettings(**model_settings_data)
|
|
# Record which checkpoints were selected. Lazy %-style arguments defer string
# formatting until a handler actually emits the record; output is identical to
# the previous f-string messages.
logger.info("asr model: %s", model_settings.asr_model)
logger.info("assist model: %s", model_settings.assistant_model)
logger.info("diar model: %s", model_settings.diarization_model)