File size: 5,122 Bytes
5cf4223
 
 
 
 
2bcb732
5cf4223
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2bcb732
5cf4223
 
2bcb732
5cf4223
 
 
 
 
2bcb732
5cf4223
 
 
 
 
2bcb732
42ab4b8
afe3147
 
 
 
5cf4223
2bcb732
5cf4223
 
 
 
 
 
 
 
 
2bcb732
42ab4b8
5cf4223
 
 
 
 
 
 
 
 
2bcb732
5cf4223
 
2bcb732
42ab4b8
2bcb732
 
5cf4223
 
2bcb732
5cf4223
 
 
2bcb732
42ab4b8
5cf4223
2bcb732
 
5cf4223
 
42ab4b8
 
bb23501
 
 
42ab4b8
5cf4223
2bcb732
5cf4223
 
2bcb732
42ab4b8
5cf4223
 
 
2bcb732
42ab4b8
5cf4223
 
 
42ab4b8
5cf4223
 
 
 
 
 
 
 
 
42ab4b8
5cf4223
2bcb732
5cf4223
 
 
42ab4b8
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
from transformers import WhisperProcessor, WhisperForConditionalGeneration
import torch
import librosa
import io
import base64
from typing import Dict, Any

class EndpointHandler:
    """Inference-endpoint handler that transcribes French audio with a Whisper model.

    The model is loaded once in ``__init__``; each call to the instance decodes one audio
    payload (base64 string or raw bytes) and returns ``{"transcription": ...}`` on success
    or ``{"error": ...}`` on failure.
    """

    def __init__(self, path=""):
        """Load the Whisper model and processor from ``path`` and precompute the French prompt.

        Args:
            path: Local directory or hub id passed to ``from_pretrained``.

        Raises:
            Any exception raised while loading the model or processor is printed and re-raised.
        """
        print("Loading Whisper model...")
        try:
            try:
                self.model = WhisperForConditionalGeneration.from_pretrained(
                    path,
                    torch_dtype=torch.bfloat16,
                    device_map={"": 0},
                    attn_implementation="flash_attention_2"
                )
                print("✅ Flash Attention 2 activated!")
            except (ImportError, ValueError):
                # ImportError: flash_attn not installed. ValueError is also caught because
                # transformers may reject the flash-attention setup with one
                # (NOTE(review): confirm against the pinned transformers version).
                # No attn_implementation is passed here, so transformers uses its own default.
                print("⚠️ Flash Attention not available, fallback to eager")
                self.model = WhisperForConditionalGeneration.from_pretrained(
                    path,
                    torch_dtype=torch.float16,
                    device_map="auto"
                )

            self.processor = WhisperProcessor.from_pretrained(path)
            self.model.eval()

            if hasattr(torch, 'compile'):
                # NOTE(review): torch.compile is lazy, so most compilation errors surface on the
                # first forward call rather than here; and generate() is presumably forwarded to
                # the uncompiled module. Verify this actually speeds anything up.
                try:
                    self.model = torch.compile(self.model, mode="max-autotune")
                    print("Model compiled with max-autotune!")
                except Exception as compile_err:
                    print(f"Max-autotune compilation failed: {compile_err}")
                    try:
                        self.model = torch.compile(self.model, mode="reduce-overhead")
                        print("Model compiled with reduce-overhead!")
                    except Exception as fallback_err:
                        print(f"Compilation failed: {fallback_err}")

            # get_decoder_prompt_ids returns (position, token_id) pairs starting at position 1,
            # i.e. WITHOUT <|startoftranscript|>. decoder_input_ids must start with it.
            forced_ids = self.processor.get_decoder_prompt_ids(language="french", task="transcribe")
            start_token_id = self.processor.tokenizer.convert_tokens_to_ids("<|startoftranscript|>")
            self.french_decoder_input_ids = torch.tensor(
                [self._build_decoder_prompt(start_token_id, forced_ids)],
                device=self.model.device,
            )

            print("Model loaded and optimized successfully!")
        except Exception as e:
            print(f"Error loading model: {e}")
            raise

    @staticmethod
    def _build_decoder_prompt(start_token_id, forced_ids):
        """Return ``[start_token_id, *token_ids]`` from ``(position, token_id)`` pairs.

        Args:
            start_token_id: Id of the decoder start token (``<|startoftranscript|>``).
            forced_ids: Pairs as returned by ``get_decoder_prompt_ids``, in position order.
        """
        return [start_token_id] + [tok_id for _, tok_id in forced_ids]

    @staticmethod
    def _prepare_tensors(model_inputs, device, dtype):
        """Move every tensor to ``device`` and cast floating-point tensors to ``dtype``.

        Non-floating tensors (e.g. integer masks) keep their dtype.
        """
        return {
            key: value.to(device=device, dtype=dtype) if value.is_floating_point() else value.to(device)
            for key, value in model_inputs.items()
        }

    @staticmethod
    def _autocast_context(device, dtype):
        """Return a CUDA autocast context matching ``dtype``, or a no-op context otherwise."""
        import contextlib

        if device.type == "cuda" and dtype in (torch.float16, torch.bfloat16):
            return torch.autocast(device_type="cuda", dtype=dtype)
        return contextlib.nullcontext()

    def __call__(self, data: Dict[str, Any]) -> Dict[str, str]:
        """Transcribe one audio payload.

        Args:
            data: ``{"inputs": <base64 str | bytes>, "parameters": {...}}``. Optional
                parameters: ``max_length`` (256), ``num_beams`` (6), ``temperature`` (0.0).

        Returns:
            ``{"transcription": text}`` on success, otherwise ``{"error": message}``.
            Only the first 30 seconds of audio are transcribed.
        """
        try:
            inputs = data.get("inputs", "")
            # Tolerate an explicit "parameters": null as well as a missing key.
            parameters = data.get("parameters") or {}

            # Decode audio
            if isinstance(inputs, str):
                try:
                    audio_bytes = base64.b64decode(inputs)
                except ValueError:  # binascii.Error and non-ASCII input are ValueErrors
                    return {"error": "Invalid base64 encoded audio"}
            elif isinstance(inputs, bytes):
                audio_bytes = inputs
            else:
                return {"error": "Invalid input format. Expected base64 string or bytes"}

            if not audio_bytes:
                return {"error": "Invalid or empty audio file"}
            if len(audio_bytes) > 25 * 1024 * 1024:
                return {"error": "File too large (max 25MB)"}

            # Resample to 16 kHz mono; anything past 30 s is dropped.
            audio_array, _ = librosa.load(
                io.BytesIO(audio_bytes),
                sr=16000,
                mono=True,
                duration=30
            )
            if len(audio_array) == 0:
                return {"error": "Invalid or empty audio file"}

            # Language/task are not passed to the processor; the French prompt is supplied
            # explicitly via decoder_input_ids instead.
            model_inputs = self.processor(
                audio_array,
                sampling_rate=16000,
                return_tensors="pt"
            )
            model_inputs.pop("forced_decoder_ids", None)

            # Match the model's device AND dtype (bfloat16 or float16 depending on load path).
            device = self.model.device
            dtype = self.model.dtype
            model_inputs = self._prepare_tensors(model_inputs, device, dtype)

            max_length = int(parameters.get("max_length", 256))
            num_beams = int(parameters.get("num_beams", 6))
            temperature = float(parameters.get("temperature", 0.0))

            with torch.inference_mode(), self._autocast_context(device, dtype):
                predicted_ids = self.model.generate(
                    **model_inputs,
                    decoder_input_ids=self.french_decoder_input_ids.to(device),
                    max_length=max_length,
                    num_beams=num_beams,
                    temperature=temperature,
                    do_sample=False,
                    early_stopping=True,
                    no_repeat_ngram_size=3,
                    repetition_penalty=1.1,
                    length_penalty=1.0,
                    use_cache=True,
                    pad_token_id=self.processor.tokenizer.eos_token_id
                )

            transcription = self.processor.batch_decode(predicted_ids, skip_special_tokens=True)
            return {"transcription": transcription[0]}
        except Exception as e:
            # Endpoint boundary: report any failure to the client instead of crashing.
            return {"error": f"Transcription error: {e}"}