srd4 committed on
Commit
12879c9
·
verified ·
1 Parent(s): e59837f

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +25 -18
handler.py CHANGED
@@ -1,26 +1,33 @@
 
1
  from faster_whisper import WhisperModel
2
- import torch
3
 
4
class EndpointHandler:
    """Inference endpoint handler wrapping a faster-whisper model.

    Expects a request payload dict with the raw audio under the key
    ``"inputs"`` and returns the joined transcript plus the detected
    language.
    """

    def __init__(self, model_size="large-v2", device="cpu"):
        """Load the Whisper model.

        Args:
            model_size: faster-whisper model identifier. Defaults to the
                original hard-coded "large-v2".
            device: Inference device. Defaults to "cpu" (the original
                choice for the CPU-only Azure deployment).
        """
        self.model = WhisperModel(model_size, device=device)

    def __call__(self, data):
        """Transcribe the audio carried in ``data["inputs"]``.

        Args:
            data: Request payload; ``data.get("inputs")`` is handed
                directly to ``WhisperModel.transcribe`` (bytes or a
                file-like object — no on-disk temp file is used).

        Returns:
            dict with:
            - "text": all segment texts joined with single spaces.
            - "language": the detected language code.
        """
        audio_bytes = data.get("inputs")

        # NOTE(review): if "inputs" is absent this passes None through to
        # transcribe(); additional decoding may be needed depending on the
        # incoming audio format — confirm against the caller.
        segments, info = self.model.transcribe(audio_bytes)

        # Combine the text from all segments.
        text = " ".join(segment.text for segment in segments)

        # Bug fix: faster-whisper's TranscriptionInfo exposes `.language`,
        # not `.language_code` — the original raised AttributeError here.
        return {"text": text, "language": info.language}
 
 
 
 
1
+ from typing import Dict
2
  from faster_whisper import WhisperModel
 
3
 
4
class EndpointHandler:
    """Inference endpoint handler wrapping a faster-whisper model.

    Expects a request payload dict with the raw audio under the key
    ``"inputs"`` and returns per-segment transcriptions plus the detected
    language and its probability.
    """

    def __init__(self, model_size: str = "large-v2"):
        """Load the Whisper model.

        Args:
            model_size: faster-whisper model identifier. Defaults to the
                original hard-coded "large-v2".
        """
        self.model = WhisperModel(model_size)

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Transcribe the audio carried in ``data["inputs"]``.

        Args:
            data: Request payload; ``data["inputs"]`` is handed directly
                to ``WhisperModel.transcribe`` (bytes or a file-like
                object).

        Returns:
            dict with:
            - "transcription": list of ``{"start", "end", "text"}`` dicts,
              one per transcribed segment.
            - "language": ``{"code", "probability"}`` for the detected
              language.

        Raises:
            KeyError: if ``"inputs"`` is missing from the payload
                (original behavior, preserved).
        """
        audio_bytes = data["inputs"]

        segments, info = self.model.transcribe(audio_bytes)

        # Comprehension replaces the original manual append loop;
        # segments is a lazy generator, consumed exactly once here.
        results = [
            {"start": segment.start, "end": segment.end, "text": segment.text}
            for segment in segments
        ]

        return {
            "transcription": results,
            "language": {
                "code": info.language,
                "probability": info.language_probability,
            },
        }