# whisper-medium-wolof / handler.py
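"""
Custom Inference Endpoints handler for the Alwaly/whisper-medium-wolof speech
recognition model.

Loads the Whisper seq2seq model and processor, clears the deprecated
forced_decoder_ids / suppress_tokens generation parameters, and exposes an
EndpointHandler that accepts base64-encoded or raw binary audio and returns
the Wolof transcription.
"""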
from typing import Dict, List, Any
import torch
from transformers import pipeline, AutoModelForSpeechSeq2Seq, AutoProcessor
import logging
import base64
import tempfile
import os
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


class EndpointHandler():
    def __init__(self, path=""):
        """
        Initialize the handler with the Wolof Whisper model and fix the forced_decoder_ids issue
        """
        logger.info(f"Loading Wolof Whisper model from {path}")
        try:
            # Load the model and processor
            self.model = AutoModelForSpeechSeq2Seq.from_pretrained(
                path,
                torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
                low_cpu_mem_usage=True,
                use_safetensors=True
            )
            self.processor = AutoProcessor.from_pretrained(path)

            # Fix the deprecated forced_decoder_ids parameter
            if hasattr(self.model, 'generation_config'):
                logger.info("Fixing deprecated forced_decoder_ids parameter for Wolof model...")
                # Remove deprecated parameters that cause 400 errors
                self.model.generation_config.forced_decoder_ids = None
                # Clear suppress tokens that might cause issues
                if hasattr(self.model.generation_config, 'suppress_tokens'):
                    self.model.generation_config.suppress_tokens = []
                # Set correct parameters for Wolof transcription
                self.model.generation_config.language = "wo"  # Wolof language code
                self.model.generation_config.task = "transcribe"
                # Ensure we don't have conflicting parameters
                if hasattr(self.model.generation_config, 'decoder_input_ids'):
                    self.model.generation_config.decoder_input_ids = None
                if hasattr(self.model.generation_config, 'input_ids'):
                    self.model.generation_config.input_ids = None
                logger.info("Successfully fixed model configuration for Wolof transcription")

            # Create pipeline with fixed model
            self.pipe = pipeline(
                "automatic-speech-recognition",
                model=self.model,
                tokenizer=self.processor.tokenizer,
                feature_extractor=self.processor.feature_extractor,
                torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
                device=0 if torch.cuda.is_available() else -1
            )
            logger.info("Wolof Whisper model loaded successfully with fixed configuration")
        except Exception as e:
            logger.error(f"Error loading Wolof model: {e}")
            raise e

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """
        Process the audio input and return the Wolof transcription.

        Args:
            data: Input data containing audio (binary or base64)

        Returns:
            Transcription result in the expected format
        """
        try:
            logger.info("Processing Wolof audio transcription request")
            # Get the audio input
            inputs = data.get("inputs", data)

            # Handle different input types
            if isinstance(inputs, str):
                logger.info("Processing base64 encoded audio")
                # Base64 encoded audio - decode and save to temp file
                try:
                    audio_bytes = base64.b64decode(inputs)
                except Exception as e:
                    logger.error(f"Failed to decode base64 audio: {e}")
                    return [{"error": f"Invalid base64 audio data: {str(e)}"}]

                # Save to temporary file for processing
                with tempfile.NamedTemporaryFile(delete=False, suffix='.wav') as temp_file:
                    temp_file.write(audio_bytes)
                    temp_path = temp_file.name
                try:
                    result = self._transcribe_audio(temp_path)
                finally:
                    # Clean up temp file
                    if os.path.exists(temp_path):
                        os.unlink(temp_path)
            elif isinstance(inputs, bytes):
                logger.info("Processing binary audio data")
                # Direct binary audio data
                with tempfile.NamedTemporaryFile(delete=False, suffix='.wav') as temp_file:
                    temp_file.write(inputs)
                    temp_path = temp_file.name
                try:
                    result = self._transcribe_audio(temp_path)
                finally:
                    # Clean up temp file
                    if os.path.exists(temp_path):
                        os.unlink(temp_path)
            else:
                logger.info("Processing direct audio path/data")
                # Direct audio path or numpy array
                result = self._transcribe_audio(inputs)

            logger.info("Wolof transcription completed successfully")
            return [result] if not isinstance(result, list) else result
        except Exception as e:
            logger.error(f"Error during Wolof transcription: {e}")
            return [{"error": f"Wolof transcription failed: {str(e)}"}]

    def _transcribe_audio(self, audio_input):
        """
        Internal method to transcribe audio using the fixed pipeline
        """
        try:
            # Use the pipeline with explicit parameters to avoid forced_decoder_ids
            result = self.pipe(
                audio_input,
                generate_kwargs={
                    "language": "wo",  # Wolof language code
                    "task": "transcribe",
                    # Explicitly avoid deprecated parameters
                    "forced_decoder_ids": None,
                    "suppress_tokens": [],
                    # Use modern parameters
                    "max_length": 448,
                    "num_beams": 1,
                    "do_sample": False,
                }
            )

            # Extract text from result
            if isinstance(result, dict):
                text = result.get("text", "")
            elif isinstance(result, list) and len(result) > 0:
                text = result[0].get("text", "") if isinstance(result[0], dict) else str(result[0])
            else:
                text = str(result)

            # Return in expected format
            return {
                "text": text.strip(),
                "language": "wo",
                "model": "Alwaly/whisper-medium-wolof"
            }
        except Exception as e:
            logger.error(f"Pipeline transcription error: {e}")
            # If we get the forced_decoder_ids error, provide helpful message
            if "forced_decoder_ids" in str(e):
                raise Exception(
                    "forced_decoder_ids parameter is deprecated. "
                    "This handler.py file should fix this issue. "
                    "Please redeploy the endpoint."
                )
            else:
                raise e
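

# ---------------------------------------------------------------------------
# Illustrative local smoke test (not part of the deployed handler): a minimal
# sketch of how the endpoint payload could be simulated on a local machine,
# assuming a local copy of the model at "./whisper-medium-wolof" and an audio
# file "sample.wav"; both paths are hypothetical placeholders.
if __name__ == "__main__":
    import json

    handler = EndpointHandler(path="./whisper-medium-wolof")

    # Base64-encode the audio exactly as a JSON request body would carry it
    with open("sample.wav", "rb") as f:
        payload = {"inputs": base64.b64encode(f.read()).decode("utf-8")}

    predictions = handler(payload)
    print(json.dumps(predictions, ensure_ascii=False, indent=2))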