|
|
|
|
| """
|
| Pipeline-based MMS Model using the official MMS library.
|
| This implementation uses Wav2Vec2LlamaInferencePipeline to avoid Seq2SeqBatch complexity.
|
| """
|
|
|
| import logging
|
| import os
|
| import torch
|
| from typing import List, Dict, Any, Optional
|
|
|
| from omnilingual_asr.models.inference.pipeline import ASRInferencePipeline
|
|
|
| from omnilingual_asr.models.wav2vec2_llama.lang_ids import supported_langs
|
|
|
| from inference.audio_reading_tools import wav_to_bytes
|
| from env_vars import MODEL_NAME
|
|
|
| logger = logging.getLogger(__name__)
|
|
|
|
|
class MMSModel:
    """Singleton wrapper around the official ASR inference pipeline.

    Repeated construction returns the same process-wide instance, and the
    (expensive) pipeline load in ``__init__`` runs only once.
    """

    # Class-level singleton state: the shared instance and an init guard.
    _instance = None
    _initialized = False

    def __new__(cls, *args, **kwargs):
        # Create the instance exactly once; later calls reuse it.
        if cls._instance is None:
            logger.info("Creating new MMSModel singleton instance")
            cls._instance = super().__new__(cls)
        else:
            logger.info("Using existing MMSModel singleton instance")
        return cls._instance

    def __init__(self, model_card: Optional[str] = None, device=None):
        """
        Initialize the MMS model with the official pipeline.

        Args:
            model_card: Model card to use (omniASR_LLM_1B, omniASR_LLM_300M, etc.)
                If None, uses MODEL_NAME from environment variables
            device: Device to use (torch.device object, "cuda", "cpu", etc.)
        """
        # __init__ runs on every construction of the singleton; skip the
        # expensive pipeline load if it has already happened.
        if self._initialized:
            return

        self.model_card = model_card or MODEL_NAME
        self.device = device

        self._load_pipeline()

        self._initialized = True

    def _load_pipeline(self):
        """Load the MMS pipeline during initialization.

        Raises:
            Exception: re-raises any failure from ASRInferencePipeline so the
                caller sees the original error after it has been logged.
        """
        logger.info(f"Loading MMS pipeline: {self.model_card}")
        logger.info(f"Target device: {self.device}")

        # Logged for debugging only; fairseq2 reads this variable itself.
        fairseq2_cache_dir = os.environ.get('FAIRSEQ2_CACHE_DIR', "./models")
        logger.info(f"DEBUG: FAIRSEQ2_CACHE_DIR = {fairseq2_cache_dir}")

        try:
            # BUG FIX: the original conditional
            #   str(self.device) if hasattr(self.device, 'type') else str(self.device)
            # returned the same value on both branches; a plain str() is equivalent.
            device_str = str(self.device)

            self.pipeline = ASRInferencePipeline(
                model_card=self.model_card,
                device=device_str
            )
            logger.info("✓ MMS pipeline loaded successfully")
        except Exception as e:
            logger.error(f"Failed to load MMS pipeline: {e}")
            raise

    def transcribe_audio(self, audio_tensor: torch.Tensor, batch_size: int = 1, language_with_scripts: Optional[List[str]] = None) -> List[Dict[str, Any]]:
        """
        Transcribe audio tensor using the MMS pipeline.

        Args:
            audio_tensor: Audio tensor (1D waveform, assumed 16 kHz) to transcribe
            batch_size: Batch size for processing
            language_with_scripts: List of language_with_scripts codes for transcription
                (3-letter ISO codes with script). If None, uses auto-detection.

        Returns:
            List of transcription results as produced by the pipeline.

        Raises:
            Exception: re-raises any pipeline failure after logging it.
        """
        logger.info(f"Converting tensor (shape: {audio_tensor.shape}) to bytes")

        # .cpu() is a no-op for tensors already on the CPU and also covers
        # non-CUDA accelerators (e.g. MPS) that the original is_cuda check missed.
        tensor_cpu = audio_tensor.cpu()

        # The pipeline consumes encoded audio, so serialize the waveform as 16 kHz WAV.
        audio_bytes = wav_to_bytes(tensor_cpu, sample_rate=16000, format="wav")

        logger.info(f"Transcribing audio tensor with batch_size={batch_size}, language_with_scripts={language_with_scripts}")

        try:
            # Forward `lang` only when explicitly requested so the pipeline's
            # auto-detection default is preserved (single call site instead of
            # the original duplicated if/else branches).
            kwargs = {}
            if language_with_scripts is not None:
                kwargs["lang"] = language_with_scripts
            transcriptions = self.pipeline.transcribe([audio_bytes], batch_size=batch_size, **kwargs)

            logger.info("✓ Successfully transcribed audio tensor")
            return transcriptions
        except Exception as e:
            logger.error(f"Transcription failed: {e}")
            raise

    @classmethod
    def get_instance(cls, model_card: Optional[str] = None, device=None):
        """
        Get the singleton instance of MMSModel.

        Args:
            model_card: Model card to use (omniASR_LLM_1B, omniASR_LLM_300M, etc.)
                If None, uses MODEL_NAME from environment variables
            device: Device to use (torch.device object, "cuda", "cpu", etc.)

        Returns:
            MMSModel: The singleton instance
        """
        if cls._instance is None:
            cls._instance = cls(model_card=model_card, device=device)
        return cls._instance
|
|
|