import base64
import io
from typing import Dict, Any

import librosa
import torch
from transformers import WhisperProcessor, WhisperForConditionalGeneration
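

# Custom handler for a Hugging Face Inference Endpoints deployment: the service
# instantiates EndpointHandler(path) once at startup and calls __call__(data)
# for each incoming request.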
class EndpointHandler:
    def __init__(self, path=""):
        print("Loading Whisper model...")
        try:
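            # Try Flash Attention 2 first (requires the flash-attn package and
            # a supported GPU); from_pretrained raises ImportError when the
            # package is missing, which triggers the fallback below.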
            try:
                self.model = WhisperForConditionalGeneration.from_pretrained(
                    path,
                    torch_dtype=torch.bfloat16,
                    device_map={"": 0},
                    attn_implementation="flash_attention_2",
                )
                print("✅ Flash Attention 2 activated!")
            except ImportError:
                print("⚠️ Flash Attention not available, falling back to the default attention implementation")
                self.model = WhisperForConditionalGeneration.from_pretrained(
                    path,
                    torch_dtype=torch.float16,
                    device_map="auto",
                )

            self.processor = WhisperProcessor.from_pretrained(path)
            self.model.eval()

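            # torch.compile specializes on input shapes, so the first request
            # after startup typically pays a one-time compilation cost.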
            if hasattr(torch, "compile"):
                try:
                    self.model = torch.compile(self.model, mode="max-autotune")
                    print("Model compiled with max-autotune!")
                except Exception as e:
                    print(f"Max-autotune compilation failed: {e}")
                    try:
                        self.model = torch.compile(self.model, mode="reduce-overhead")
                        print("Model compiled with reduce-overhead!")
                    except Exception as e2:
                        print(f"Compilation failed: {e2}")

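            # get_decoder_prompt_ids returns (position, token_id) pairs for the
            # language/task prompt; keep only the token ids and reuse them as a
            # fixed decoder prompt that forces French transcription.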
            forced_ids = self.processor.get_decoder_prompt_ids(language="french", task="transcribe")
            self.french_decoder_input_ids = torch.tensor(
                [[tok_id for _, tok_id in forced_ids]],
                device=self.model.device,
            )

            print("Model loaded and optimized successfully!")
        except Exception as e:
            print(f"Error loading model: {e}")
            raise

    def __call__(self, data: Dict[str, Any]) -> Dict[str, str]:
        try:
            inputs = data.get("inputs", "")
            parameters = data.get("parameters", {})

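            # The payload may carry audio either as a base64 string (JSON
            # requests) or as raw bytes (binary requests).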
            if isinstance(inputs, str):
                try:
                    audio_bytes = base64.b64decode(inputs)
                except Exception:
                    return {"error": "Invalid base64-encoded audio"}
            elif isinstance(inputs, bytes):
                audio_bytes = inputs
            else:
                return {"error": "Invalid input format. Expected base64 string or bytes"}

            if len(audio_bytes) > 25 * 1024 * 1024:
                return {"error": "File too large (max 25 MB)"}

            audio_array, _ = librosa.load(
                io.BytesIO(audio_bytes),
                sr=16000,
                mono=True,
                duration=30,
            )
            if len(audio_array) == 0:
                return {"error": "Invalid or empty audio file"}

            model_inputs = self.processor(
                audio_array,
                sampling_rate=16000,
                return_tensors="pt",
            )

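            # Defensive: drop any forced_decoder_ids the processor may have
            # attached, since an explicit decoder prompt is passed below.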
            if "forced_decoder_ids" in model_inputs:
                del model_inputs["forced_decoder_ids"]

            # Move tensors to the model's device and cast float32 features to
            # the model's dtype (bfloat16 on the Flash Attention path, float16
            # on the fallback); a blanket .half() would mismatch a bfloat16 model.
            model_inputs = {
                k: v.to(self.model.device, dtype=self.model.dtype) if v.dtype == torch.float32 else v.to(self.model.device)
                for k, v in model_inputs.items()
            }

            max_length = parameters.get("max_length", 256)
            num_beams = parameters.get("num_beams", 6)
            # A "temperature" parameter would be ignored by beam search with
            # do_sample=False, so it is not forwarded to generate().

            # inference_mode subsumes no_grad, and an explicit float16 autocast
            # is unnecessary (and fails on CPU-only hosts) since weights and
            # features already share a reduced-precision dtype.
            with torch.inference_mode():
                predicted_ids = self.model.generate(
                    **model_inputs,
                    decoder_input_ids=self.french_decoder_input_ids,
                    max_length=max_length,
                    num_beams=num_beams,
                    do_sample=False,
                    early_stopping=True,
                    no_repeat_ngram_size=3,
                    repetition_penalty=1.1,
                    length_penalty=1.0,
                    use_cache=True,
                    pad_token_id=self.processor.tokenizer.eos_token_id,
                )

            transcription = self.processor.batch_decode(predicted_ids, skip_special_tokens=True)
            return {"transcription": transcription[0]}
        except Exception as e:
            return {"error": f"Transcription error: {str(e)}"}