from typing import Dict, List, Any

import torch
from transformers import pipeline, AutoModelForSpeechSeq2Seq, AutoProcessor
import logging
import base64
import tempfile
import os

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


class EndpointHandler:
    """Hugging Face Inference Endpoint handler for a Wolof Whisper ASR model.

    Works around the deprecated ``forced_decoder_ids`` generation parameter
    (which causes 400 errors on newer ``transformers`` versions) by clearing
    it on the model's generation config before building the ASR pipeline.
    """

    def __init__(self, path=""):
        """Load the Wolof Whisper model/processor and build the ASR pipeline.

        Args:
            path: Local path or repo id of the model checkpoint to load.

        Raises:
            Exception: any failure during model/processor loading is logged
                and re-raised unchanged.
        """
        logger.info(f"Loading Wolof Whisper model from {path}")
        try:
            # fp16 only when a GPU is available; fp32 keeps CPU inference stable.
            # Computed once and reused for both the model and the pipeline.
            dtype = torch.float16 if torch.cuda.is_available() else torch.float32

            self.model = AutoModelForSpeechSeq2Seq.from_pretrained(
                path,
                torch_dtype=dtype,
                low_cpu_mem_usage=True,
                use_safetensors=True,
            )
            self.processor = AutoProcessor.from_pretrained(path)

            # Fix the deprecated forced_decoder_ids parameter.
            if hasattr(self.model, 'generation_config'):
                logger.info("Fixing deprecated forced_decoder_ids parameter for Wolof model...")
                gen_cfg = self.model.generation_config
                # Remove deprecated parameters that cause 400 errors.
                gen_cfg.forced_decoder_ids = None
                # Clear suppress tokens that might cause issues.
                if hasattr(gen_cfg, 'suppress_tokens'):
                    gen_cfg.suppress_tokens = []
                # Pin language/task so the model always transcribes Wolof.
                gen_cfg.language = "wo"  # Wolof language code
                gen_cfg.task = "transcribe"
                # Ensure we don't have conflicting decoder-input parameters.
                if hasattr(gen_cfg, 'decoder_input_ids'):
                    gen_cfg.decoder_input_ids = None
                if hasattr(gen_cfg, 'input_ids'):
                    gen_cfg.input_ids = None
                logger.info("Successfully fixed model configuration for Wolof transcription")

            # Create pipeline with the fixed model.
            self.pipe = pipeline(
                "automatic-speech-recognition",
                model=self.model,
                tokenizer=self.processor.tokenizer,
                feature_extractor=self.processor.feature_extractor,
                torch_dtype=dtype,
                device=0 if torch.cuda.is_available() else -1,
            )
            logger.info("Wolof Whisper model loaded successfully with fixed configuration")
        except Exception as e:
            logger.error(f"Error loading Wolof model: {e}")
            raise  # bare raise preserves the original traceback (was `raise e`)

    def _transcribe_bytes(self, audio_bytes: bytes) -> Dict[str, Any]:
        """Write raw audio bytes to a temporary .wav file and transcribe it.

        Shared by the base64 and binary input paths of :meth:`__call__`;
        the temp file is always removed, even when transcription raises.

        Args:
            audio_bytes: Raw audio payload (assumed WAV-compatible —
                the pipeline's decoder does the actual format sniffing).

        Returns:
            The transcription dict produced by :meth:`_transcribe_audio`.
        """
        with tempfile.NamedTemporaryFile(delete=False, suffix='.wav') as temp_file:
            temp_file.write(audio_bytes)
            temp_path = temp_file.name
        try:
            return self._transcribe_audio(temp_path)
        finally:
            # Clean up temp file regardless of success or failure.
            if os.path.exists(temp_path):
                os.unlink(temp_path)

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Process the audio input and return a Wolof transcription.

        Args:
            data: Request payload. ``data["inputs"]`` (or ``data`` itself)
                may be a base64 string, raw ``bytes``, or a path/array the
                pipeline accepts directly.

        Returns:
            A list with one transcription dict, or one ``{"error": ...}``
            dict on failure (errors are returned, never raised, so the
            endpoint responds with a structured body).
        """
        try:
            logger.info("Processing Wolof audio transcription request")
            inputs = data.get("inputs", data)

            if isinstance(inputs, str):
                logger.info("Processing base64 encoded audio")
                # Base64 encoded audio — decode, then transcribe via temp file.
                try:
                    audio_bytes = base64.b64decode(inputs)
                except Exception as e:
                    logger.error(f"Failed to decode base64 audio: {e}")
                    return [{"error": f"Invalid base64 audio data: {str(e)}"}]
                result = self._transcribe_bytes(audio_bytes)
            elif isinstance(inputs, bytes):
                logger.info("Processing binary audio data")
                # Direct binary audio data — same temp-file path as base64.
                result = self._transcribe_bytes(inputs)
            else:
                logger.info("Processing direct audio path/data")
                # Direct audio path or numpy array; pipeline handles it natively.
                result = self._transcribe_audio(inputs)

            logger.info("Wolof transcription completed successfully")
            return result if isinstance(result, list) else [result]
        except Exception as e:
            logger.error(f"Error during Wolof transcription: {e}")
            return [{"error": f"Wolof transcription failed: {str(e)}"}]

    def _transcribe_audio(self, audio_input):
        """Transcribe one audio input using the fixed pipeline.

        Args:
            audio_input: File path, raw array, or any input the ASR
                pipeline accepts.

        Returns:
            Dict with ``text`` (stripped transcription), ``language`` and
            ``model`` keys.

        Raises:
            Exception: pipeline failures are re-raised; a failure caused by
            ``forced_decoder_ids`` gets a redeploy-hint message, chained to
            the original error.
        """
        try:
            # Explicit generate_kwargs avoid the deprecated forced_decoder_ids path.
            result = self.pipe(
                audio_input,
                generate_kwargs={
                    "language": "wo",  # Wolof language code
                    "task": "transcribe",
                    # Explicitly avoid deprecated parameters.
                    "forced_decoder_ids": None,
                    "suppress_tokens": [],
                    # Use modern parameters.
                    "max_length": 448,
                    "num_beams": 1,
                    "do_sample": False,
                },
            )

            # Extract text from whatever shape the pipeline returned.
            if isinstance(result, dict):
                text = result.get("text", "")
            elif isinstance(result, list) and result:
                first = result[0]
                text = first.get("text", "") if isinstance(first, dict) else str(first)
            else:
                text = str(result)

            # Return in the format the endpoint contract expects.
            return {
                "text": text.strip(),
                "language": "wo",
                "model": "Alwaly/whisper-medium-wolof",
            }
        except Exception as e:
            logger.error(f"Pipeline transcription error: {e}")
            # If we still hit the forced_decoder_ids error, give an actionable hint.
            if "forced_decoder_ids" in str(e):
                # `from e` keeps the original exception chained for debugging
                # (the original code dropped the chain).
                raise Exception(
                    "forced_decoder_ids parameter is deprecated. "
                    "This handler.py file should fix this issue. "
                    "Please redeploy the endpoint."
                ) from e
            raise