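"""Custom handler for a Hugging Face Inference Endpoint.

Loads a Whisper checkpoint (Flash Attention 2 + bfloat16 when available,
eager float16 otherwise), optionally compiles it with torch.compile, and
transcribes base64-encoded audio to French text. The endpoint runtime
instantiates EndpointHandler once at startup, then calls it per request with
{"inputs": <base64 audio>, "parameters": {...}}.
"""
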
import base64
import io
from typing import Any, Dict

import librosa
import torch
from transformers import WhisperProcessor, WhisperForConditionalGeneration


class EndpointHandler:
    def __init__(self, path=""):
        """Load and optimize the Whisper model from `path` (the repository root)."""
        print("Loading Whisper model...")
        try:
            try:
                # Preferred path: bfloat16 weights with Flash Attention 2 on GPU 0.
                self.model = WhisperForConditionalGeneration.from_pretrained(
                    path,
                    torch_dtype=torch.bfloat16,
                    device_map={"": 0},
                    attn_implementation="flash_attention_2",
                )
                print("✅ Flash Attention 2 activated!")
            except ImportError:
                # flash-attn is not installed: fall back to eager attention in float16.
                print("⚠️ Flash Attention not available, falling back to eager")
                self.model = WhisperForConditionalGeneration.from_pretrained(
                    path,
                    torch_dtype=torch.float16,
                    device_map="auto",
                )
            self.processor = WhisperProcessor.from_pretrained(path)
            self.model.eval()
            # torch.compile (PyTorch 2.x): try the most aggressive mode first,
            # then a cheaper one; keep the uncompiled model if both fail.
            if hasattr(torch, "compile"):
                try:
                    self.model = torch.compile(self.model, mode="max-autotune")
                    print("Model compiled with max-autotune!")
                except Exception as e:
                    print(f"Max-autotune compilation failed: {e}")
                    try:
                        self.model = torch.compile(self.model, mode="reduce-overhead")
                        print("Model compiled with reduce-overhead!")
                    except Exception as e2:
                        print(f"Compilation failed: {e2}")
            # Precompute the decoder prompt for French transcription:
            # <|startoftranscript|> <|fr|> <|transcribe|> <|notimestamps|>.
            # get_decoder_prompt_ids() omits the start token, so prepend it
            # explicitly to keep generate() behavior stable across versions.
            forced_ids = self.processor.get_decoder_prompt_ids(language="french", task="transcribe")
            prompt_ids = [self.model.config.decoder_start_token_id] + [tok_id for _, tok_id in forced_ids]
            self.french_decoder_input_ids = torch.tensor([prompt_ids], device=self.model.device)
            print("Model loaded and optimized successfully!")
        except Exception as e:
            print(f"Error loading model: {e}")
            raise
    def __call__(self, data: Dict[str, Any]) -> Dict[str, str]:
        """Handle one request: {"inputs": <base64 str or raw bytes>, "parameters": {...}}.

        Returns {"transcription": <text>} on success, {"error": <message>} otherwise.
        """
        try:
            inputs = data.get("inputs", "")
            parameters = data.get("parameters", {})
            # Decode the audio payload
            if isinstance(inputs, str):
                try:
                    audio_bytes = base64.b64decode(inputs)
                except Exception:
                    return {"error": "Invalid base64-encoded audio"}
            elif isinstance(inputs, bytes):
                audio_bytes = inputs
            else:
                return {"error": "Invalid input format. Expected base64 string or bytes"}
            if len(audio_bytes) > 25 * 1024 * 1024:
                return {"error": "File too large (max 25 MB)"}
            # Load audio: resample to 16 kHz mono and cap at 30 s (Whisper's input window)
            audio_array, _ = librosa.load(
                io.BytesIO(audio_bytes),
                sr=16000,
                mono=True,
                duration=30,
            )
            if len(audio_array) == 0:
                return {"error": "Invalid or empty audio file"}
            # Extract log-mel features WITHOUT language/task specification, so the
            # processor does not attach forced_decoder_ids of its own
            model_inputs = self.processor(
                audio_array,
                sampling_rate=16000,
                return_tensors="pt",
            )
            # Defensive: drop forced_decoder_ids if a processor version adds them
            if "forced_decoder_ids" in model_inputs:
                del model_inputs["forced_decoder_ids"]
            # Move to the model's device and cast floats to the model's dtype
            # (bfloat16 or float16, depending on which load path succeeded)
            model_inputs = {
                k: v.to(self.model.device, dtype=self.model.dtype) if v.is_floating_point() else v.to(self.model.device)
                for k, v in model_inputs.items()
            }
            # Parameters (temperature is accepted but ignored while do_sample=False)
            max_length = parameters.get("max_length", 256)
            num_beams = parameters.get("num_beams", 6)
            temperature = parameters.get("temperature", 0.0)
            # Generate with the precomputed French decoder prompt
            with torch.inference_mode(), torch.autocast(device_type=self.model.device.type, dtype=self.model.dtype):
                predicted_ids = self.model.generate(
                    **model_inputs,
                    decoder_input_ids=self.french_decoder_input_ids,
                    max_length=max_length,
                    num_beams=num_beams,
                    temperature=temperature,
                    do_sample=False,
                    early_stopping=True,
                    no_repeat_ngram_size=3,
                    repetition_penalty=1.1,
                    length_penalty=1.0,
                    use_cache=True,
                    pad_token_id=self.processor.tokenizer.eos_token_id,
                )
            transcription = self.processor.batch_decode(predicted_ids, skip_special_tokens=True)
            return {"transcription": transcription[0]}
        except Exception as e:
            return {"error": f"Transcription error: {str(e)}"}