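"""Custom handler.py for a Hugging Face Inference Endpoint.

Loads a Whisper checkpoint with Flash Attention 2 in bfloat16 when available
(falling back to float16 otherwise), optionally compiles it with torch.compile,
and transcribes base64-encoded French audio payloads of the form
{"inputs": <base64 audio>, "parameters": {...}}.
"""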
from transformers import WhisperProcessor, WhisperForConditionalGeneration
import torch
import librosa
import io
import base64
from typing import Dict, Any

class EndpointHandler:
    def __init__(self, path=""):
        print("Loading Whisper model...")
        try:
            try:
                self.model = WhisperForConditionalGeneration.from_pretrained(
                    path,
                    torch_dtype=torch.bfloat16,
                    device_map={"": 0},
                    attn_implementation="flash_attention_2"
                )
                print("✅ Flash Attention 2 activated!")
            except ImportError:
                print("⚠️ Flash Attention not available, fallback to eager")
                self.model = WhisperForConditionalGeneration.from_pretrained(
                    path,
                    torch_dtype=torch.float16,
                    device_map="auto"
                )

            self.processor = WhisperProcessor.from_pretrained(path)
            self.model.eval()

            if hasattr(torch, 'compile'):
                try:
                    self.model = torch.compile(self.model, mode="max-autotune")
                    print("Model compiled with max-autotune!")
                except Exception as e:
                    print(f"Max-autotune compilation failed: {e}")
                    try:
                        self.model = torch.compile(self.model, mode="reduce-overhead")
                        print("Model compiled with reduce-overhead!")
                    except Exception as e2:
                        print(f"Compilation failed: {e2}")

            # forced_decoder_ids for French transcription (same as the FastAPI service)
            self.french_decoder_ids = self.processor.get_decoder_prompt_ids(
                language="french", task="transcribe"
            )

            print("Model loaded and optimized successfully!")
        except Exception as e:
            print(f"Error loading model: {e}")
            raise e

    def __call__(self, data: Dict[str, Any]) -> Dict[str, str]:
        try:
            inputs = data.get("inputs", "")
            parameters = data.get("parameters", {})

            # decode audio (base64 string or bytes)
            if isinstance(inputs, str):
                try:
                    audio_bytes = base64.b64decode(inputs)
                except Exception:
                    return {"error": "Invalid base64 encoded audio"}
            elif isinstance(inputs, bytes):
                audio_bytes = inputs
            else:
                return {"error": "Invalid input format. Expected base64 string or bytes"}

            # check size
            if len(audio_bytes) > 25 * 1024 * 1024:
                return {"error": "File too large (max 25MB)"}

            # load audio: Whisper expects 16 kHz mono and processes at most 30 s per window
            audio_array, _ = librosa.load(
                io.BytesIO(audio_bytes),
                sr=16000,
                mono=True,
                duration=30
            )

            if len(audio_array) == 0:
                return {"error": "Invalid or empty audio file"}

            model_inputs = self.processor(
                audio_array,
                sampling_rate=16000,
                return_tensors="pt"
            )
            # cast floating-point inputs to the model's own dtype (bfloat16 on the
            # Flash Attention path, float16 on the fallback) to avoid a dtype mismatch
            model_inputs = {
                k: v.to(self.model.device, dtype=self.model.dtype) if v.is_floating_point() else v.to(self.model.device)
                for k, v in model_inputs.items()
            }

            # generation parameters (per-request overrides; temperature is ignored when do_sample=False)
            max_length = parameters.get("max_length", 256)
            num_beams = parameters.get("num_beams", 6)
            temperature = parameters.get("temperature", 0.0)

            # generate (inference_mode already implies no_grad; autocast in the
            # model's own dtype so the bfloat16 path is not forced down to fp16)
            with torch.inference_mode(), torch.autocast(device_type="cuda", dtype=self.model.dtype):
                predicted_ids = self.model.generate(
                    **model_inputs,
                    max_length=max_length,
                    num_beams=num_beams,
                    temperature=temperature,
                    do_sample=False,
                    early_stopping=True,
                    no_repeat_ngram_size=3,
                    repetition_penalty=1.1,
                    length_penalty=1.0,
                    use_cache=True,
                    pad_token_id=self.processor.tokenizer.eos_token_id,
                    forced_decoder_ids=self.french_decoder_ids,  # ✅ same as the FastAPI service
                    suppress_tokens=[],
                    begin_suppress_tokens=[]
                )

            transcription = self.processor.batch_decode(predicted_ids, skip_special_tokens=True)
            return {"transcription": transcription[0]}
        except Exception as e:
            return {"error": f"Transcription error: {str(e)}"}