SoSolaris committed on
Commit
bb23501
·
verified ·
1 Parent(s): 1a35a23

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +10 -3
handler.py CHANGED
@@ -40,7 +40,7 @@ class EndpointHandler:
40
  except Exception as e2:
41
  print(f"Compilation failed: {e2}")
42
 
43
- # compute decoder_input_ids for french
44
  forced_ids = self.processor.get_decoder_prompt_ids(language="french", task="transcribe")
45
  self.french_decoder_input_ids = torch.tensor(
46
  [[tok_id for _, tok_id in forced_ids]],
@@ -57,6 +57,7 @@ class EndpointHandler:
57
  inputs = data.get("inputs", "")
58
  parameters = data.get("parameters", {})
59
 
 
60
  if isinstance(inputs, str):
61
  try:
62
  audio_bytes = base64.b64decode(inputs)
@@ -70,34 +71,40 @@ class EndpointHandler:
70
  if len(audio_bytes) > 25 * 1024 * 1024:
71
  return {"error": "File too large (max 25MB)"}
72
 
 
73
  audio_array, _ = librosa.load(
74
  io.BytesIO(audio_bytes),
75
  sr=16000,
76
  mono=True,
77
  duration=30
78
  )
79
-
80
  if len(audio_array) == 0:
81
  return {"error": "Invalid or empty audio file"}
82
 
 
83
  model_inputs = self.processor(
84
  audio_array,
85
  sampling_rate=16000,
86
  return_tensors="pt"
87
  )
 
 
 
88
  model_inputs = {
89
  k: v.to(self.model.device).half() if v.dtype == torch.float32 else v.to(self.model.device)
90
  for k, v in model_inputs.items()
91
  }
92
 
 
93
  max_length = parameters.get("max_length", 256)
94
  num_beams = parameters.get("num_beams", 6)
95
  temperature = parameters.get("temperature", 0.0)
96
 
 
97
  with torch.no_grad(), torch.inference_mode(), torch.autocast(device_type="cuda", dtype=torch.float16):
98
  predicted_ids = self.model.generate(
99
  **model_inputs,
100
- decoder_input_ids=self.french_decoder_input_ids, # ✅ seul forçage
101
  max_length=max_length,
102
  num_beams=num_beams,
103
  temperature=temperature,
 
40
  except Exception as e2:
41
  print(f"Compilation failed: {e2}")
42
 
43
+ # precompute decoder_input_ids for French transcription
44
  forced_ids = self.processor.get_decoder_prompt_ids(language="french", task="transcribe")
45
  self.french_decoder_input_ids = torch.tensor(
46
  [[tok_id for _, tok_id in forced_ids]],
 
57
  inputs = data.get("inputs", "")
58
  parameters = data.get("parameters", {})
59
 
60
+ # decode audio
61
  if isinstance(inputs, str):
62
  try:
63
  audio_bytes = base64.b64decode(inputs)
 
71
  if len(audio_bytes) > 25 * 1024 * 1024:
72
  return {"error": "File too large (max 25MB)"}
73
 
74
+ # load audio
75
  audio_array, _ = librosa.load(
76
  io.BytesIO(audio_bytes),
77
  sr=16000,
78
  mono=True,
79
  duration=30
80
  )
 
81
  if len(audio_array) == 0:
82
  return {"error": "Invalid or empty audio file"}
83
 
84
+ # processor injecte forced_decoder_ids -> on les enlève
85
  model_inputs = self.processor(
86
  audio_array,
87
  sampling_rate=16000,
88
  return_tensors="pt"
89
  )
90
+ if "forced_decoder_ids" in model_inputs:
91
+ del model_inputs["forced_decoder_ids"]
92
+
93
  model_inputs = {
94
  k: v.to(self.model.device).half() if v.dtype == torch.float32 else v.to(self.model.device)
95
  for k, v in model_inputs.items()
96
  }
97
 
98
+ # params
99
  max_length = parameters.get("max_length", 256)
100
  num_beams = parameters.get("num_beams", 6)
101
  temperature = parameters.get("temperature", 0.0)
102
 
103
+ # generate
104
  with torch.no_grad(), torch.inference_mode(), torch.autocast(device_type="cuda", dtype=torch.float16):
105
  predicted_ids = self.model.generate(
106
  **model_inputs,
107
+ decoder_input_ids=self.french_decoder_input_ids, # ✅ seul forçage langue
108
  max_length=max_length,
109
  num_beams=num_beams,
110
  temperature=temperature,