SoSolaris committed on
Commit
2bcb732
·
verified ·
1 Parent(s): 8c0d808

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +29 -61
handler.py CHANGED
@@ -3,18 +3,12 @@ import torch
3
  import librosa
4
  import io
5
  import base64
6
- from typing import Dict, List, Any
7
- import json
8
 
9
  class EndpointHandler:
10
  def __init__(self, path=""):
11
- """
12
- Initialize the handler for Hugging Face Inference Endpoints
13
- """
14
  print("Loading Whisper model...")
15
-
16
  try:
17
- # Try Flash Attention 2 first
18
  try:
19
  self.model = WhisperForConditionalGeneration.from_pretrained(
20
  path,
@@ -30,55 +24,41 @@ class EndpointHandler:
30
  torch_dtype=torch.float16,
31
  device_map="auto"
32
  )
33
-
34
  self.processor = WhisperProcessor.from_pretrained(path)
35
-
36
- # Set to evaluation mode
37
  self.model.eval()
38
-
39
- # Compile model for optimization
40
  if hasattr(torch, 'compile'):
41
  try:
42
  self.model = torch.compile(self.model, mode="max-autotune")
43
  print("Model compiled with max-autotune!")
44
  except Exception as e:
45
- print(f"Max-autotune compilation failed, fallback: {e}")
46
  try:
47
  self.model = torch.compile(self.model, mode="reduce-overhead")
48
  print("Model compiled with reduce-overhead!")
49
  except Exception as e2:
50
  print(f"Compilation failed: {e2}")
51
-
52
- # Pre-compute French decoder IDs
53
- self.french_decoder_ids = self.processor.get_decoder_prompt_ids(
54
- language="french",
55
- task="transcribe"
 
 
56
  )
57
-
58
  print("Model loaded and optimized successfully!")
59
-
60
  except Exception as e:
61
  print(f"Error loading model: {e}")
62
  raise e
63
 
64
  def __call__(self, data: Dict[str, Any]) -> Dict[str, str]:
65
- """
66
- Process the request
67
- Args:
68
- data (Dict): The request payload containing:
69
- - "inputs": base64 encoded audio file or audio bytes
70
- - "parameters": optional parameters for generation
71
- Returns:
72
- Dict: The transcription result
73
- """
74
  try:
75
- # Extract inputs
76
  inputs = data.get("inputs", "")
77
  parameters = data.get("parameters", {})
78
-
79
- # Handle different input formats
80
  if isinstance(inputs, str):
81
- # Assume base64 encoded audio
82
  try:
83
  audio_bytes = base64.b64decode(inputs)
84
  except Exception:
@@ -87,48 +67,39 @@ class EndpointHandler:
87
  audio_bytes = inputs
88
  else:
89
  return {"error": "Invalid input format. Expected base64 string or bytes"}
90
-
91
- # Validate file size (max 25MB)
92
  if len(audio_bytes) > 25 * 1024 * 1024:
93
  return {"error": "File too large (max 25MB)"}
94
-
95
- # Load audio from bytes
96
- audio_array, sampling_rate = librosa.load(
97
- io.BytesIO(audio_bytes),
98
  sr=16000,
99
  mono=True,
100
- duration=30 # Limit to 30 seconds max
101
  )
102
-
103
- # Validate audio
104
  if len(audio_array) == 0:
105
  return {"error": "Invalid or empty audio file"}
106
-
107
- # Process audio for the model
108
  model_inputs = self.processor(
109
- audio_array,
110
- sampling_rate=16000,
111
  return_tensors="pt"
112
  )
113
-
114
- # Move inputs to same device and dtype as model
115
  model_inputs = {
116
- k: v.to(self.model.device).half() if v.dtype == torch.float32 else v.to(self.model.device)
117
  for k, v in model_inputs.items()
118
  }
119
-
120
- # Extract generation parameters
121
  max_length = parameters.get("max_length", 256)
122
  num_beams = parameters.get("num_beams", 6)
123
  temperature = parameters.get("temperature", 0.0)
124
-
125
- # Generate transcription with anti-hallucination parameters
126
  with torch.no_grad(), torch.inference_mode(), torch.autocast(device_type="cuda", dtype=torch.float16):
127
- # Add language forcing to inputs instead of generation params
128
- model_inputs.update(self.processor.get_decoder_prompt_ids(language="french", task="transcribe"))
129
-
130
  predicted_ids = self.model.generate(
131
  **model_inputs,
 
132
  max_length=max_length,
133
  num_beams=num_beams,
134
  temperature=temperature,
@@ -142,11 +113,8 @@ class EndpointHandler:
142
  suppress_tokens=[],
143
  begin_suppress_tokens=[]
144
  )
145
-
146
- # Decode the transcription
147
  transcription = self.processor.batch_decode(predicted_ids, skip_special_tokens=True)
148
-
149
  return {"transcription": transcription[0]}
150
-
151
  except Exception as e:
152
- return {"error": f"Transcription error: {str(e)}"}
 
3
  import librosa
4
  import io
5
  import base64
6
+ from typing import Dict, Any
 
7
 
8
  class EndpointHandler:
9
  def __init__(self, path=""):
 
 
 
10
  print("Loading Whisper model...")
 
11
  try:
 
12
  try:
13
  self.model = WhisperForConditionalGeneration.from_pretrained(
14
  path,
 
24
  torch_dtype=torch.float16,
25
  device_map="auto"
26
  )
27
+
28
  self.processor = WhisperProcessor.from_pretrained(path)
 
 
29
  self.model.eval()
30
+
 
31
  if hasattr(torch, 'compile'):
32
  try:
33
  self.model = torch.compile(self.model, mode="max-autotune")
34
  print("Model compiled with max-autotune!")
35
  except Exception as e:
36
+ print(f"Max-autotune compilation failed: {e}")
37
  try:
38
  self.model = torch.compile(self.model, mode="reduce-overhead")
39
  print("Model compiled with reduce-overhead!")
40
  except Exception as e2:
41
  print(f"Compilation failed: {e2}")
42
+
43
+ # pre-compute decoder ids for french
44
+ self.french_decoder_ids = torch.tensor(
45
+ self.processor.get_decoder_prompt_ids(
46
+ language="french", task="transcribe"
47
+ ),
48
+ device="cuda" if torch.cuda.is_available() else "cpu"
49
  )
50
+
51
  print("Model loaded and optimized successfully!")
 
52
  except Exception as e:
53
  print(f"Error loading model: {e}")
54
  raise e
55
 
56
  def __call__(self, data: Dict[str, Any]) -> Dict[str, str]:
 
 
 
 
 
 
 
 
 
57
  try:
 
58
  inputs = data.get("inputs", "")
59
  parameters = data.get("parameters", {})
60
+
 
61
  if isinstance(inputs, str):
 
62
  try:
63
  audio_bytes = base64.b64decode(inputs)
64
  except Exception:
 
67
  audio_bytes = inputs
68
  else:
69
  return {"error": "Invalid input format. Expected base64 string or bytes"}
70
+
 
71
  if len(audio_bytes) > 25 * 1024 * 1024:
72
  return {"error": "File too large (max 25MB)"}
73
+
74
+ audio_array, _ = librosa.load(
75
+ io.BytesIO(audio_bytes),
 
76
  sr=16000,
77
  mono=True,
78
+ duration=30
79
  )
80
+
 
81
  if len(audio_array) == 0:
82
  return {"error": "Invalid or empty audio file"}
83
+
 
84
  model_inputs = self.processor(
85
+ audio_array,
86
+ sampling_rate=16000,
87
  return_tensors="pt"
88
  )
89
+
 
90
  model_inputs = {
91
+ k: v.to(self.model.device).half() if v.dtype == torch.float32 else v.to(self.model.device)
92
  for k, v in model_inputs.items()
93
  }
94
+
 
95
  max_length = parameters.get("max_length", 256)
96
  num_beams = parameters.get("num_beams", 6)
97
  temperature = parameters.get("temperature", 0.0)
98
+
 
99
  with torch.no_grad(), torch.inference_mode(), torch.autocast(device_type="cuda", dtype=torch.float16):
 
 
 
100
  predicted_ids = self.model.generate(
101
  **model_inputs,
102
+ decoder_input_ids=self.french_decoder_ids,
103
  max_length=max_length,
104
  num_beams=num_beams,
105
  temperature=temperature,
 
113
  suppress_tokens=[],
114
  begin_suppress_tokens=[]
115
  )
116
+
 
117
  transcription = self.processor.batch_decode(predicted_ids, skip_special_tokens=True)
 
118
  return {"transcription": transcription[0]}
 
119
  except Exception as e:
120
+ return {"error": f"Transcription error: {str(e)}"}