jsbeaudry committed on
Commit
768357c
·
verified ·
1 Parent(s): 82d8904

Create handler.py

Browse files
Files changed (1) hide show
  1. handler.py +60 -0
handler.py ADDED
@@ -0,0 +1,60 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from unsloth import FastModel
2
+ from transformers import WhisperForConditionalGeneration, pipeline
3
+ import torch
4
+ import tempfile
5
+ import os
6
+
7
class EndpointHandler:
    """Inference-endpoint handler for an Unsloth-tuned Whisper ASR model.

    Loads the model once at construction time and exposes transcription via
    ``__call__``, following the Hugging Face custom-handler contract.
    """

    def __init__(self, model_path):
        """Load the Whisper model and build the ASR pipeline.

        Args:
            model_path: Hub repo id or local directory of the fine-tuned model.
        """
        # Load the Unsloth Whisper model; `tokenizer` here is a processor-like
        # wrapper that exposes .tokenizer and .feature_extractor.
        model, tokenizer = FastModel.from_pretrained(
            model_name=model_path,
            dtype=None,  # let Unsloth choose the best dtype for the hardware
            load_in_4bit=False,
            auto_model=WhisperForConditionalGeneration,
            whisper_language="Haitian",
            whisper_task="transcribe",
        )

        # Prepare model for inference (Unsloth fast kernels + eval mode).
        FastModel.for_inference(model)
        model.eval()

        # Build the ASR pipeline around the loaded model/processor.
        self.pipeline = pipeline(
            "automatic-speech-recognition",
            model=model,
            tokenizer=tokenizer.tokenizer,
            feature_extractor=tokenizer.feature_extractor,
            processor=tokenizer,
            return_language=True,
            torch_dtype=torch.float16,
        )

        # Remove forced_decoder_ids from the generation config: recent
        # transformers versions raise at generate() time when these are set
        # alongside task/language.
        gen_config = self.pipeline.model.generation_config
        if hasattr(gen_config, "forced_decoder_ids"):
            del gen_config.forced_decoder_ids
        if hasattr(gen_config, "is_forced_decoder_ids_init"):
            del gen_config.is_forced_decoder_ids_init

    def __call__(self, data):
        """Transcribe the audio in ``data["inputs"]``.

        Args:
            data: dict whose "inputs" key holds raw audio bytes or a path to
                an audio file on disk.

        Returns:
            The transcription text, or an error-message string on failure
            (the endpoint contract reports errors as strings, not raises).
        """
        audio = data.get("inputs")
        if audio is None:
            return "Error: No input audio provided."

        temp_path = None  # set only when we materialize bytes to a temp file
        try:
            # Handle byte input (e.g., uploaded or streamed audio)
            if isinstance(audio, bytes):
                # Persist bytes so the pipeline can read them as a file.
                with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as f:
                    f.write(audio)
                    temp_path = f.name
                file_path = temp_path
            elif isinstance(audio, str) and os.path.isfile(audio):
                file_path = audio
            else:
                return "Error: Invalid input. Expected audio bytes or file path."

            result = self.pipeline(file_path)
            return result["text"]

        except Exception as e:
            # Boundary handler: return the failure as a message rather than
            # raising, so the endpoint responds with something useful.
            return f"Error during transcription: {str(e)}"
        finally:
            # BUGFIX: the original leaked the NamedTemporaryFile — created
            # with delete=False and never removed, so every bytes request
            # left a .wav on disk. Clean it up here; a caller-supplied file
            # path is intentionally never deleted.
            if temp_path is not None:
                try:
                    os.remove(temp_path)
                except OSError:
                    pass