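"""Custom Hugging Face Inference Endpoints handler for French Whisper transcription.

Loads a Whisper checkpoint with Flash Attention 2 in bfloat16 when available
(falling back to eager attention in float16), optionally compiles the model
with torch.compile, and transcribes base64-encoded audio resampled to 16 kHz
and clipped to 30 seconds, with decoding forced to French.
"""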
import base64
import io
from typing import Any, Dict

import librosa
import torch
from transformers import WhisperForConditionalGeneration, WhisperProcessor
class EndpointHandler:
    def __init__(self, path=""):
        print("Loading Whisper model...")
        try:
            try:
                # Prefer Flash Attention 2 with bfloat16 on GPU 0.
                self.model = WhisperForConditionalGeneration.from_pretrained(
                    path,
                    torch_dtype=torch.bfloat16,
                    device_map={"": 0},
                    attn_implementation="flash_attention_2",
                )
                print("✅ Flash Attention 2 activated!")
            except ImportError:
                # flash-attn is not installed: fall back to eager attention in float16.
                print("⚠️ Flash Attention not available, falling back to eager")
                self.model = WhisperForConditionalGeneration.from_pretrained(
                    path,
                    torch_dtype=torch.float16,
                    device_map="auto",
                )

            self.processor = WhisperProcessor.from_pretrained(path)
            self.model.eval()

            # Try torch.compile, preferring max-autotune and falling back to
            # reduce-overhead if autotuning fails.
            if hasattr(torch, "compile"):
                try:
                    self.model = torch.compile(self.model, mode="max-autotune")
                    print("Model compiled with max-autotune!")
                except Exception as e:
                    print(f"Max-autotune compilation failed: {e}")
                    try:
                        self.model = torch.compile(self.model, mode="reduce-overhead")
                        print("Model compiled with reduce-overhead!")
                    except Exception as e2:
                        print(f"Compilation failed: {e2}")

            # forced_decoder_ids for French (same as the fastapi service)
            self.french_decoder_ids = self.processor.get_decoder_prompt_ids(
                language="french", task="transcribe"
            )

            print("Model loaded and optimized successfully!")
        except Exception as e:
            print(f"Error loading model: {e}")
            raise
    def __call__(self, data: Dict[str, Any]) -> Dict[str, str]:
        try:
            inputs = data.get("inputs", "")
            parameters = data.get("parameters", {})

            # Decode the audio payload (base64 string or raw bytes).
            if isinstance(inputs, str):
                try:
                    audio_bytes = base64.b64decode(inputs)
                except Exception:
                    return {"error": "Invalid base64 encoded audio"}
            elif isinstance(inputs, bytes):
                audio_bytes = inputs
            else:
                return {"error": "Invalid input format. Expected base64 string or bytes"}

            # Reject oversized payloads.
            if len(audio_bytes) > 25 * 1024 * 1024:
                return {"error": "File too large (max 25MB)"}

            # Resample to 16 kHz mono and clip to Whisper's 30-second window.
            audio_array, _ = librosa.load(
                io.BytesIO(audio_bytes),
                sr=16000,
                mono=True,
                duration=30,
            )
            if len(audio_array) == 0:
                return {"error": "Invalid or empty audio file"}

            model_inputs = self.processor(
                audio_array,
                sampling_rate=16000,
                return_tensors="pt",
            )
            # Move tensors to the model's device; cast float32 features to the
            # model's dtype (bfloat16 or float16) so input and weight dtypes match.
            model_inputs = {
                k: v.to(self.model.device, dtype=self.model.dtype)
                if v.dtype == torch.float32
                else v.to(self.model.device)
                for k, v in model_inputs.items()
            }

            # Generation parameters, overridable per request.
            max_length = parameters.get("max_length", 256)
            num_beams = parameters.get("num_beams", 6)
            temperature = parameters.get("temperature", 0.0)  # ignored when do_sample=False

            # Deterministic beam-search decode (torch.inference_mode already
            # implies no_grad; autocast dtype follows the model's dtype).
            with torch.inference_mode(), torch.autocast(device_type="cuda", dtype=self.model.dtype):
                predicted_ids = self.model.generate(
                    **model_inputs,
                    max_length=max_length,
                    num_beams=num_beams,
                    temperature=temperature,
                    do_sample=False,
                    early_stopping=True,
                    no_repeat_ngram_size=3,
                    repetition_penalty=1.1,
                    length_penalty=1.0,
                    use_cache=True,
                    pad_token_id=self.processor.tokenizer.eos_token_id,
                    forced_decoder_ids=self.french_decoder_ids,  # ✅ same as the fastapi service
                    suppress_tokens=[],
                    begin_suppress_tokens=[],
                )

            transcription = self.processor.batch_decode(predicted_ids, skip_special_tokens=True)
            return {"transcription": transcription[0]}
        except Exception as e:
            return {"error": f"Transcription error: {str(e)}"}