File size: 7,419 Bytes
fd7e75b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
from typing import Dict, List, Any
import torch
from transformers import pipeline, AutoModelForSpeechSeq2Seq, AutoProcessor
import logging
import base64
import tempfile
import os

# Configure root logging once at import time so handler logs surface in the
# hosting platform's (e.g. HF Inference Endpoints) container log stream.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

class EndpointHandler():
    """Custom inference handler for the Alwaly/whisper-medium-wolof ASR model.

    Loads the Whisper checkpoint, clears the deprecated ``forced_decoder_ids``
    generation parameter (which newer transformers versions reject), pins the
    language to Wolof (``"wo"``) and the task to ``"transcribe"``, and exposes
    a ``__call__`` entry point that accepts base64 text, raw bytes, or a
    direct audio path/array.
    """

    def __init__(self, path=""):
        """
        Load model and processor from ``path`` and build the ASR pipeline.

        Args:
            path: Local directory (or hub repo id) holding the pretrained
                Wolof Whisper model.

        Raises:
            Exception: whatever transformers/torch raised while loading; the
                original traceback is preserved.
        """
        logger.info(f"Loading Wolof Whisper model from {path}")

        # fp16 on GPU, fp32 on CPU — computed once and reused for both the
        # model load and the pipeline construction.
        dtype = torch.float16 if torch.cuda.is_available() else torch.float32

        try:
            # Load the model and processor
            self.model = AutoModelForSpeechSeq2Seq.from_pretrained(
                path,
                torch_dtype=dtype,
                low_cpu_mem_usage=True,
                use_safetensors=True
            )
            self.processor = AutoProcessor.from_pretrained(path)

            # Fix the deprecated forced_decoder_ids parameter that causes
            # 400 errors on recent transformers versions.
            if hasattr(self.model, 'generation_config'):
                logger.info("Fixing deprecated forced_decoder_ids parameter for Wolof model...")

                gen_cfg = self.model.generation_config
                gen_cfg.forced_decoder_ids = None

                # Clear suppress tokens that might conflict with generation.
                if hasattr(gen_cfg, 'suppress_tokens'):
                    gen_cfg.suppress_tokens = []

                # Modern way to pin language/task for Whisper generation.
                gen_cfg.language = "wo"  # Wolof language code
                gen_cfg.task = "transcribe"

                # Drop any stale decoder inputs that could conflict.
                if hasattr(gen_cfg, 'decoder_input_ids'):
                    gen_cfg.decoder_input_ids = None
                if hasattr(gen_cfg, 'input_ids'):
                    gen_cfg.input_ids = None

                logger.info("Successfully fixed model configuration for Wolof transcription")

            # Create pipeline with the fixed model.
            self.pipe = pipeline(
                "automatic-speech-recognition",
                model=self.model,
                tokenizer=self.processor.tokenizer,
                feature_extractor=self.processor.feature_extractor,
                torch_dtype=dtype,
                device=0 if torch.cuda.is_available() else -1
            )

            logger.info("Wolof Whisper model loaded successfully with fixed configuration")

        except Exception as e:
            logger.error(f"Error loading Wolof model: {e}")
            # Bare raise keeps the original traceback (raise e would rewrite it).
            raise

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """
        Process the audio input and return the Wolof transcription.

        Args:
            data: Request payload; audio is read from ``data["inputs"]``
                (falling back to ``data`` itself) and may be a base64 string,
                raw bytes, or a path/array the pipeline accepts directly.

        Returns:
            A list with one result dict (``text``/``language``/``model``),
            or a list with one ``{"error": ...}`` dict on failure.
        """
        try:
            logger.info("Processing Wolof audio transcription request")

            # Get the audio input
            inputs = data.get("inputs", data)

            # Handle different input types
            if isinstance(inputs, str):
                logger.info("Processing base64 encoded audio")
                try:
                    audio_bytes = base64.b64decode(inputs)
                except Exception as e:
                    logger.error(f"Failed to decode base64 audio: {e}")
                    return [{"error": f"Invalid base64 audio data: {str(e)}"}]
                result = self._transcribe_bytes(audio_bytes)

            elif isinstance(inputs, bytes):
                logger.info("Processing binary audio data")
                result = self._transcribe_bytes(inputs)

            else:
                logger.info("Processing direct audio path/data")
                # Direct audio path or numpy array
                result = self._transcribe_audio(inputs)

            logger.info("Wolof transcription completed successfully")
            return [result] if not isinstance(result, list) else result

        except Exception as e:
            logger.error(f"Error during Wolof transcription: {e}")
            return [{"error": f"Wolof transcription failed: {str(e)}"}]

    def _transcribe_bytes(self, audio_bytes: bytes) -> Dict[str, Any]:
        """Write raw audio bytes to a temp .wav file, transcribe, clean up.

        Shared by the base64 and binary input branches of ``__call__`` so the
        temp-file lifecycle lives in exactly one place.
        """
        with tempfile.NamedTemporaryFile(delete=False, suffix='.wav') as temp_file:
            temp_file.write(audio_bytes)
            temp_path = temp_file.name

        try:
            return self._transcribe_audio(temp_path)
        finally:
            # Always remove the temp file, even if transcription raised.
            if os.path.exists(temp_path):
                os.unlink(temp_path)

    def _transcribe_audio(self, audio_input):
        """
        Transcribe one audio input via the fixed pipeline.

        Args:
            audio_input: A file path, raw array, or anything else the
                transformers ASR pipeline accepts.

        Returns:
            Dict with stripped ``text``, ``language`` ("wo"), and ``model``.
        """
        try:
            # Pass explicit generate_kwargs so the deprecated
            # forced_decoder_ids never sneaks back in from the config.
            result = self.pipe(
                audio_input,
                generate_kwargs={
                    "language": "wo",  # Wolof language code
                    "task": "transcribe",
                    # Explicitly avoid deprecated parameters
                    "forced_decoder_ids": None,
                    "suppress_tokens": [],
                    # Use modern parameters
                    "max_length": 448,
                    "num_beams": 1,
                    "do_sample": False,
                }
            )

            # The pipeline may return a dict, a list of dicts, or something
            # else entirely — normalize down to a plain text string.
            if isinstance(result, dict):
                text = result.get("text", "")
            elif isinstance(result, list) and len(result) > 0:
                text = result[0].get("text", "") if isinstance(result[0], dict) else str(result[0])
            else:
                text = str(result)

            # Return in expected format
            return {
                "text": text.strip(),
                "language": "wo",
                "model": "Alwaly/whisper-medium-wolof"
            }

        except Exception as e:
            logger.error(f"Pipeline transcription error: {e}")
            # If we get the forced_decoder_ids error, provide a helpful
            # message; chain the cause so the real error is not lost.
            if "forced_decoder_ids" in str(e):
                raise Exception(
                    "forced_decoder_ids parameter is deprecated. "
                    "This handler.py file should fix this issue. "
                    "Please redeploy the endpoint."
                ) from e
            else:
                raise