dmartu commited on
Commit
d0e224d
·
verified ·
1 Parent(s): 3134440

Create handler.py

Browse files
Files changed (1) hide show
  1. handler.py +48 -0
handler.py ADDED
@@ -0,0 +1,48 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Dict, Any
2
+ import torch
3
+ import soundfile as sf
4
+ import io
5
+ import base64
6
+ import numpy as np
7
+ from transformers import AutoProcessor, AutoModelForSpeechSeq2Seq
8
+
9
+
10
class EndpointHandler:
    """Inference-endpoint handler for a speech-to-text seq2seq model.

    Loads a processor/model pair once at construction time and transcribes
    a single audio clip per ``__call__``.
    """

    def __init__(self, path=""):
        # The model repo may ship custom modeling code, hence trust_remote_code.
        self.processor = AutoProcessor.from_pretrained(path, trust_remote_code=True)
        self.model = AutoModelForSpeechSeq2Seq.from_pretrained(
            path,
            trust_remote_code=True,
            torch_dtype=torch.float16,
            device_map="auto",
        )
        self.model.eval()

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Transcribe one audio clip.

        Args:
            data: payload dict with an ``"inputs"`` key holding either a
                base64-encoded string or raw audio bytes, in any container
                format soundfile can decode (wav, flac, ogg, ...).

        Returns:
            ``{"text": <transcription string>}``

        Raises:
            ValueError: if the payload has no ``"inputs"`` field.
        """
        audio_input = data.get("inputs")
        if audio_input is None:
            # Fail fast with a clear message instead of an opaque TypeError
            # from base64/io.BytesIO further down.
            raise ValueError('Request payload must contain an "inputs" field with audio data.')

        # Accept either a base64-encoded string or raw bytes.
        if isinstance(audio_input, str):
            audio_bytes = base64.b64decode(audio_input)
        else:
            audio_bytes = audio_input

        audio_array, sample_rate = sf.read(io.BytesIO(audio_bytes))

        # Down-mix multi-channel audio to mono by averaging channels;
        # the feature extractor expects a 1-D waveform.
        if audio_array.ndim > 1:
            audio_array = audio_array.mean(axis=1)

        features = self.processor(
            audio_array,
            sampling_rate=sample_rate,
            return_tensors="pt",
        )
        # BUG FIX: the model was loaded in float16, but the processor emits
        # float32 tensors. The original `.to(self.model.device)` moved the
        # tensors to the right device but left a dtype mismatch that breaks
        # generate() on GPU. Cast floating-point tensors to the model's
        # dtype; leave integer tensors (e.g. attention masks) untouched.
        features = {
            k: (v.to(self.model.device, dtype=self.model.dtype)
                if torch.is_floating_point(v)
                else v.to(self.model.device))
            for k, v in features.items()
        }

        with torch.no_grad():
            generated_ids = self.model.generate(**features)

        transcription = self.processor.batch_decode(
            generated_ids,
            skip_special_tokens=True,
        )[0]

        return {"text": transcription}