Files changed (1) hide show
  1. handler.py +176 -0
handler.py ADDED
@@ -0,0 +1,176 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Dict, List, Any
2
+ import torch
3
+ from transformers import pipeline, AutoModelForSpeechSeq2Seq, AutoProcessor
4
+ import logging
5
+ import base64
6
+ import tempfile
7
+ import os
8
+
9
+ logging.basicConfig(level=logging.INFO)
10
+ logger = logging.getLogger(__name__)
11
+
12
+ class EndpointHandler():
13
+ def __init__(self, path=""):
14
+ """
15
+ Initialize the handler with the Wolof Whisper model and fix the forced_decoder_ids issue
16
+ """
17
+ logger.info(f"Loading Wolof Whisper model from {path}")
18
+
19
+ try:
20
+ # Load the model and processor
21
+ self.model = AutoModelForSpeechSeq2Seq.from_pretrained(
22
+ path,
23
+ torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
24
+ low_cpu_mem_usage=True,
25
+ use_safetensors=True
26
+ )
27
+ self.processor = AutoProcessor.from_pretrained(path)
28
+
29
+ # Fix the deprecated forced_decoder_ids parameter
30
+ if hasattr(self.model, 'generation_config'):
31
+ logger.info("Fixing deprecated forced_decoder_ids parameter for Wolof model...")
32
+
33
+ # Remove deprecated parameters that cause 400 errors
34
+ self.model.generation_config.forced_decoder_ids = None
35
+
36
+ # Clear suppress tokens that might cause issues
37
+ if hasattr(self.model.generation_config, 'suppress_tokens'):
38
+ self.model.generation_config.suppress_tokens = []
39
+
40
+ # Set correct parameters for Wolof transcription
41
+ self.model.generation_config.language = "wo" # Wolof language code
42
+ self.model.generation_config.task = "transcribe"
43
+
44
+ # Ensure we don't have conflicting parameters
45
+ if hasattr(self.model.generation_config, 'decoder_input_ids'):
46
+ self.model.generation_config.decoder_input_ids = None
47
+ if hasattr(self.model.generation_config, 'input_ids'):
48
+ self.model.generation_config.input_ids = None
49
+
50
+ logger.info("Successfully fixed model configuration for Wolof transcription")
51
+
52
+ # Create pipeline with fixed model
53
+ self.pipe = pipeline(
54
+ "automatic-speech-recognition",
55
+ model=self.model,
56
+ tokenizer=self.processor.tokenizer,
57
+ feature_extractor=self.processor.feature_extractor,
58
+ torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
59
+ device=0 if torch.cuda.is_available() else -1
60
+ )
61
+
62
+ logger.info("Wolof Whisper model loaded successfully with fixed configuration")
63
+
64
+ except Exception as e:
65
+ logger.error(f"Error loading Wolof model: {e}")
66
+ raise e
67
+
68
+ def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
69
+ """
70
+ Process the audio input and return Wolof transcription
71
+ Args:
72
+ data: Input data containing audio (binary or base64)
73
+ Returns:
74
+ Transcription result in the expected format
75
+ """
76
+ try:
77
+ logger.info("Processing Wolof audio transcription request")
78
+
79
+ # Get the audio input
80
+ inputs = data.get("inputs", data)
81
+
82
+ # Handle different input types
83
+ if isinstance(inputs, str):
84
+ logger.info("Processing base64 encoded audio")
85
+ # Base64 encoded audio - decode and save to temp file
86
+ try:
87
+ audio_bytes = base64.b64decode(inputs)
88
+ except Exception as e:
89
+ logger.error(f"Failed to decode base64 audio: {e}")
90
+ return [{"error": f"Invalid base64 audio data: {str(e)}"}]
91
+
92
+ # Save to temporary file for processing
93
+ with tempfile.NamedTemporaryFile(delete=False, suffix='.wav') as temp_file:
94
+ temp_file.write(audio_bytes)
95
+ temp_path = temp_file.name
96
+
97
+ try:
98
+ result = self._transcribe_audio(temp_path)
99
+ finally:
100
+ # Clean up temp file
101
+ if os.path.exists(temp_path):
102
+ os.unlink(temp_path)
103
+
104
+ elif isinstance(inputs, bytes):
105
+ logger.info("Processing binary audio data")
106
+ # Direct binary audio data
107
+ with tempfile.NamedTemporaryFile(delete=False, suffix='.wav') as temp_file:
108
+ temp_file.write(inputs)
109
+ temp_path = temp_file.name
110
+
111
+ try:
112
+ result = self._transcribe_audio(temp_path)
113
+ finally:
114
+ # Clean up temp file
115
+ if os.path.exists(temp_path):
116
+ os.unlink(temp_path)
117
+
118
+ else:
119
+ logger.info("Processing direct audio path/data")
120
+ # Direct audio path or numpy array
121
+ result = self._transcribe_audio(inputs)
122
+
123
+ logger.info(f"Wolof transcription completed successfully")
124
+ return [result] if not isinstance(result, list) else result
125
+
126
+ except Exception as e:
127
+ logger.error(f"Error during Wolof transcription: {e}")
128
+ return [{"error": f"Wolof transcription failed: {str(e)}"}]
129
+
130
+ def _transcribe_audio(self, audio_input):
131
+ """
132
+ Internal method to transcribe audio using the fixed pipeline
133
+ """
134
+ try:
135
+ # Use the pipeline with explicit parameters to avoid forced_decoder_ids
136
+ result = self.pipe(
137
+ audio_input,
138
+ generate_kwargs={
139
+ "language": "wo", # Wolof language code
140
+ "task": "transcribe",
141
+ # Explicitly avoid deprecated parameters
142
+ "forced_decoder_ids": None,
143
+ "suppress_tokens": [],
144
+ # Use modern parameters
145
+ "max_length": 448,
146
+ "num_beams": 1,
147
+ "do_sample": False,
148
+ }
149
+ )
150
+
151
+ # Extract text from result
152
+ if isinstance(result, dict):
153
+ text = result.get("text", "")
154
+ elif isinstance(result, list) and len(result) > 0:
155
+ text = result[0].get("text", "") if isinstance(result[0], dict) else str(result[0])
156
+ else:
157
+ text = str(result)
158
+
159
+ # Return in expected format
160
+ return {
161
+ "text": text.strip(),
162
+ "language": "wo",
163
+ "model": "Alwaly/whisper-medium-wolof"
164
+ }
165
+
166
+ except Exception as e:
167
+ logger.error(f"Pipeline transcription error: {e}")
168
+ # If we get the forced_decoder_ids error, provide helpful message
169
+ if "forced_decoder_ids" in str(e):
170
+ raise Exception(
171
+ "forced_decoder_ids parameter is deprecated. "
172
+ "This handler.py file should fix this issue. "
173
+ "Please redeploy the endpoint."
174
+ )
175
+ else:
176
+ raise e