# omniasr-transcriptions / server / inference / mms_model_pipeline.py
# Omnilingual ASR transcription demo (Hugging Face Space, commit ae238b3, author jeanma)
"""
Pipeline-based MMS Model using the official MMS library.
This implementation uses Wav2Vec2LlamaInferencePipeline to avoid Seq2SeqBatch complexity.
"""
import logging
import os
import torch
from typing import List, Dict, Any, Optional
from omnilingual_asr.models.inference.pipeline import Wav2Vec2InferencePipeline
from omnilingual_asr.models.wav2vec2_llama.lang_ids import supported_langs
from inference.audio_reading_tools import wav_to_bytes
from env_vars import MODEL_NAME
logger = logging.getLogger(__name__)
class MMSModel:
    """Singleton wrapper around the official omnilingual ASR inference pipeline.

    The pipeline is loaded eagerly on first construction; every later
    ``MMSModel(...)`` call returns the same fully initialized instance
    (constructor arguments on later calls are silently ignored).
    """

    _instance = None      # the single shared instance
    _initialized = False  # guards against re-running __init__ on the singleton

    def __new__(cls, *args, **kwargs):
        # Classic singleton: create the instance once, then always return it.
        if cls._instance is None:
            logger.info("Creating new MMSModel singleton instance")
            cls._instance = super().__new__(cls)
        else:
            logger.info("Using existing MMSModel singleton instance")
        return cls._instance

    def __init__(self, model_card: Optional[str] = None, device=None):
        """
        Initialize the MMS model and eagerly load the inference pipeline.

        Args:
            model_card: Model card to use (omniASR_LLM_1B, omniASR_LLM_300M, etc.)
                        If None, uses MODEL_NAME from environment variables.
            device: Device to use (torch.device object, "cuda", "cpu", etc.).
                    If None, the pipeline chooses its own default.

        Raises:
            Exception: re-raised from pipeline loading after logging.
        """
        # Only initialize once; repeat constructions hit the singleton no-op path.
        if self._initialized:
            return
        self.model_card = model_card or MODEL_NAME
        self.device = device
        # Load eagerly so failures surface at startup rather than on first request.
        self._load_pipeline()
        self._initialized = True

    def _load_pipeline(self):
        """Load the MMS inference pipeline; logs and re-raises on failure."""
        logger.info("Loading MMS pipeline: %s", self.model_card)
        logger.info("Target device: %s", self.device)
        # Debug aid: fairseq2 resolves model assets relative to this cache dir.
        logger.info("DEBUG: FAIRSEQ2_CACHE_DIR = %s", os.environ.get('FAIRSEQ2_CACHE_DIR'))
        try:
            # BUG FIX: the original conditional had two identical branches
            # (both `str(self.device)`), and a None device became the bogus
            # string "None". Pass None through so the pipeline can pick its
            # own default device.
            device_str = str(self.device) if self.device is not None else None
            self.pipeline = Wav2Vec2InferencePipeline(
                model_card=self.model_card,
                device=device_str
            )
            logger.info("βœ“ MMS pipeline loaded successfully")
        except Exception:
            logger.exception("Failed to load MMS pipeline")
            raise

    def transcribe_audio(self, audio_tensor: torch.Tensor, batch_size: int = 1,
                         language_with_scripts: Optional[List[str]] = None) -> List[Dict[str, Any]]:
        """
        Transcribe an audio tensor using the MMS pipeline.

        Args:
            audio_tensor: Audio tensor (1D waveform) to transcribe.
                          Encoded as 16 kHz WAV before being handed to the
                          pipeline — assumes callers resample to 16 kHz
                          beforehand (TODO confirm).
            batch_size: Batch size for processing.
            language_with_scripts: List of language codes for transcription
                          (3-letter ISO codes with script). If None, the
                          pipeline's auto-detection is used.

        Returns:
            List of transcription results from the pipeline.

        Raises:
            Exception: re-raised from the pipeline after logging.
        """
        logger.info("Converting tensor (shape: %s) to bytes", audio_tensor.shape)
        # ROBUSTNESS FIX: unconditional .cpu() is a no-op for CPU tensors and
        # also covers non-CUDA accelerators (e.g. MPS), which the original
        # `is_cuda` check missed.
        tensor_cpu = audio_tensor.cpu()
        # Convert to bytes using wav_to_bytes with 16kHz sample rate.
        audio_bytes = wav_to_bytes(tensor_cpu, sample_rate=16000, format="wav")
        logger.info("Transcribing audio tensor with batch_size=%s, language_with_scripts=%s",
                    batch_size, language_with_scripts)
        try:
            # Only pass `lang` when explicitly requested so the pipeline's
            # auto-detection default stays in effect otherwise.
            extra_kwargs = {}
            if language_with_scripts is not None:
                extra_kwargs["lang"] = language_with_scripts
            transcriptions = self.pipeline.transcribe(
                [audio_bytes], batch_size=batch_size, **extra_kwargs
            )
            logger.info("βœ“ Successfully transcribed audio tensor")
            return transcriptions
        except Exception:
            logger.exception("Transcription failed")
            raise

    @classmethod
    def get_instance(cls, model_card: Optional[str] = None, device=None):
        """
        Get the singleton instance of MMSModel, creating it on first call.

        Args:
            model_card: Model card to use (omniASR_LLM_1B, omniASR_LLM_300M, etc.)
                        If None, uses MODEL_NAME from environment variables.
            device: Device to use (torch.device object, "cuda", "cpu", etc.)

        Returns:
            MMSModel: The singleton instance.
        """
        # cls(...) itself is singleton-safe via __new__/_initialized, but keep
        # the explicit guard so later calls skip constructor logging entirely.
        if cls._instance is None:
            cls._instance = cls(model_card=model_card, device=device)
        return cls._instance