|
|
from typing import Dict, List, Any |
|
|
import torch |
|
|
from transformers import pipeline, AutoModelForSpeechSeq2Seq, AutoProcessor |
|
|
import logging |
|
|
import base64 |
|
|
import tempfile |
|
|
import os |
|
|
|
|
|
logging.basicConfig(level=logging.INFO) |
|
|
logger = logging.getLogger(__name__) |
|
|
|
|
|
class EndpointHandler():
    """Inference endpoint handler for the Wolof Whisper ASR model.

    Wraps a Hugging Face ``automatic-speech-recognition`` pipeline and patches
    the model's ``generation_config`` so the deprecated ``forced_decoder_ids``
    parameter is never used (language/task are set explicitly instead).
    """

    def __init__(self, path=""):
        """Load the Wolof Whisper model and build the fixed ASR pipeline.

        Args:
            path: Local directory (or hub id) containing the pretrained
                model and processor.

        Raises:
            Exception: re-raised unchanged if model/processor loading fails.
        """
        logger.info(f"Loading Wolof Whisper model from {path}")

        try:
            # Hoist the CUDA check: dtype and device must agree.
            use_cuda = torch.cuda.is_available()
            dtype = torch.float16 if use_cuda else torch.float32

            self.model = AutoModelForSpeechSeq2Seq.from_pretrained(
                path,
                torch_dtype=dtype,
                low_cpu_mem_usage=True,
                use_safetensors=True,
            )
            self.processor = AutoProcessor.from_pretrained(path)

            if hasattr(self.model, 'generation_config'):
                logger.info("Fixing deprecated forced_decoder_ids parameter for Wolof model...")

                # forced_decoder_ids is deprecated in recent transformers
                # releases; language/task on generation_config replace it.
                self.model.generation_config.forced_decoder_ids = None
                if hasattr(self.model.generation_config, 'suppress_tokens'):
                    self.model.generation_config.suppress_tokens = []

                self.model.generation_config.language = "wo"
                self.model.generation_config.task = "transcribe"

                # Clear any stale decoder prompts left in the saved config.
                if hasattr(self.model.generation_config, 'decoder_input_ids'):
                    self.model.generation_config.decoder_input_ids = None
                if hasattr(self.model.generation_config, 'input_ids'):
                    self.model.generation_config.input_ids = None

                logger.info("Successfully fixed model configuration for Wolof transcription")

            self.pipe = pipeline(
                "automatic-speech-recognition",
                model=self.model,
                tokenizer=self.processor.tokenizer,
                feature_extractor=self.processor.feature_extractor,
                torch_dtype=dtype,
                device=0 if use_cuda else -1,
            )

            logger.info("Wolof Whisper model loaded successfully with fixed configuration")

        except Exception:
            # BUG FIX: `raise e` re-raised from here and truncated the
            # traceback; bare `raise` preserves the original one.
            logger.exception("Error loading Wolof model")
            raise

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Process the audio input and return the Wolof transcription.

        Args:
            data: Request payload. Audio is taken from ``data["inputs"]``
                (base64 string, raw bytes, or anything the pipeline accepts
                directly, e.g. a file path); if the key is absent, ``data``
                itself is treated as the audio input.

        Returns:
            A list with one result dict (``text`` / ``language`` / ``model``),
            or ``[{"error": ...}]`` on failure — never raises to the caller.
        """
        try:
            logger.info("Processing Wolof audio transcription request")

            inputs = data.get("inputs", data)

            if isinstance(inputs, str):
                logger.info("Processing base64 encoded audio")
                try:
                    audio_bytes = base64.b64decode(inputs)
                except Exception as e:
                    logger.error(f"Failed to decode base64 audio: {e}")
                    return [{"error": f"Invalid base64 audio data: {str(e)}"}]
                result = self._transcribe_bytes(audio_bytes)

            elif isinstance(inputs, bytes):
                logger.info("Processing binary audio data")
                result = self._transcribe_bytes(inputs)

            else:
                logger.info("Processing direct audio path/data")
                result = self._transcribe_audio(inputs)

            logger.info("Wolof transcription completed successfully")
            return [result] if not isinstance(result, list) else result

        except Exception as e:
            logger.exception("Error during Wolof transcription")
            return [{"error": f"Wolof transcription failed: {str(e)}"}]

    def _transcribe_bytes(self, audio_bytes):
        """Persist raw audio bytes to a temp .wav file, transcribe, clean up.

        Extracted helper: the base64 and binary branches previously duplicated
        this temp-file write / transcribe / unlink sequence verbatim.
        """
        with tempfile.NamedTemporaryFile(delete=False, suffix='.wav') as temp_file:
            temp_file.write(audio_bytes)
            temp_path = temp_file.name

        try:
            return self._transcribe_audio(temp_path)
        finally:
            # Always remove the temp file, even if transcription raised.
            if os.path.exists(temp_path):
                os.unlink(temp_path)

    def _transcribe_audio(self, audio_input):
        """Run the ASR pipeline on *audio_input* and normalize the result.

        Args:
            audio_input: Whatever the pipeline accepts (file path, array,
                dict, ...).

        Returns:
            Dict with stripped ``text``, ``language`` and ``model`` keys.

        Raises:
            Exception: if the underlying pipeline call fails; a
            ``forced_decoder_ids`` failure is wrapped in a redeploy hint.
        """
        try:
            result = self.pipe(
                audio_input,
                generate_kwargs={
                    "language": "wo",
                    "task": "transcribe",
                    # Explicitly disable the deprecated parameters.
                    "forced_decoder_ids": None,
                    "suppress_tokens": [],
                    # Deterministic greedy decoding; 448 is Whisper's
                    # maximum decoder length.
                    "max_length": 448,
                    "num_beams": 1,
                    "do_sample": False,
                }
            )

            # The pipeline may return a dict, a list of dicts, or raw text.
            if isinstance(result, dict):
                text = result.get("text", "")
            elif isinstance(result, list) and result:
                first = result[0]
                text = first.get("text", "") if isinstance(first, dict) else str(first)
            else:
                text = str(result)

            return {
                "text": text.strip(),
                "language": "wo",
                "model": "Alwaly/whisper-medium-wolof"
            }

        except Exception as e:
            logger.exception("Pipeline transcription error")

            if "forced_decoder_ids" in str(e):
                # BUG FIX: chain with `from e` so the root cause survives
                # in the traceback instead of being silently replaced.
                raise Exception(
                    "forced_decoder_ids parameter is deprecated. "
                    "This handler.py file should fix this issue. "
                    "Please redeploy the endpoint."
                ) from e
            # BUG FIX: bare `raise` instead of `raise e` keeps the traceback.
            raise