SoSolaris committed on
Commit
42ab4b8
·
verified ·
1 Parent(s): bb23501

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +12 -11
handler.py CHANGED
@@ -40,7 +40,7 @@ class EndpointHandler:
40
  except Exception as e2:
41
  print(f"Compilation failed: {e2}")
42
 
43
- # precompute decoder_input_ids for French transcription
44
  forced_ids = self.processor.get_decoder_prompt_ids(language="french", task="transcribe")
45
  self.french_decoder_input_ids = torch.tensor(
46
  [[tok_id for _, tok_id in forced_ids]],
@@ -57,7 +57,7 @@ class EndpointHandler:
57
  inputs = data.get("inputs", "")
58
  parameters = data.get("parameters", {})
59
 
60
- # decode audio
61
  if isinstance(inputs, str):
62
  try:
63
  audio_bytes = base64.b64decode(inputs)
@@ -71,7 +71,7 @@ class EndpointHandler:
71
  if len(audio_bytes) > 25 * 1024 * 1024:
72
  return {"error": "File too large (max 25MB)"}
73
 
74
- # load audio
75
  audio_array, _ = librosa.load(
76
  io.BytesIO(audio_bytes),
77
  sr=16000,
@@ -81,30 +81,33 @@ class EndpointHandler:
81
  if len(audio_array) == 0:
82
  return {"error": "Invalid or empty audio file"}
83
 
84
- # processor injecte forced_decoder_ids -> on les enlève
85
  model_inputs = self.processor(
86
  audio_array,
87
  sampling_rate=16000,
88
  return_tensors="pt"
89
  )
 
 
90
  if "forced_decoder_ids" in model_inputs:
91
  del model_inputs["forced_decoder_ids"]
92
 
 
93
  model_inputs = {
94
  k: v.to(self.model.device).half() if v.dtype == torch.float32 else v.to(self.model.device)
95
  for k, v in model_inputs.items()
96
  }
97
 
98
- # params
99
  max_length = parameters.get("max_length", 256)
100
  num_beams = parameters.get("num_beams", 6)
101
  temperature = parameters.get("temperature", 0.0)
102
 
103
- # generate
104
  with torch.no_grad(), torch.inference_mode(), torch.autocast(device_type="cuda", dtype=torch.float16):
105
  predicted_ids = self.model.generate(
106
  **model_inputs,
107
- decoder_input_ids=self.french_decoder_input_ids, # ✅ seul forçage langue
108
  max_length=max_length,
109
  num_beams=num_beams,
110
  temperature=temperature,
@@ -114,12 +117,10 @@ class EndpointHandler:
114
  repetition_penalty=1.1,
115
  length_penalty=1.0,
116
  use_cache=True,
117
- pad_token_id=self.processor.tokenizer.eos_token_id,
118
- suppress_tokens=[],
119
- begin_suppress_tokens=[]
120
  )
121
 
122
  transcription = self.processor.batch_decode(predicted_ids, skip_special_tokens=True)
123
  return {"transcription": transcription[0]}
124
  except Exception as e:
125
- return {"error": f"Transcription error: {str(e)}"}
 
40
  except Exception as e2:
41
  print(f"Compilation failed: {e2}")
42
 
43
+ # Precompute decoder_input_ids for French transcription
44
  forced_ids = self.processor.get_decoder_prompt_ids(language="french", task="transcribe")
45
  self.french_decoder_input_ids = torch.tensor(
46
  [[tok_id for _, tok_id in forced_ids]],
 
57
  inputs = data.get("inputs", "")
58
  parameters = data.get("parameters", {})
59
 
60
+ # Decode audio
61
  if isinstance(inputs, str):
62
  try:
63
  audio_bytes = base64.b64decode(inputs)
 
71
  if len(audio_bytes) > 25 * 1024 * 1024:
72
  return {"error": "File too large (max 25MB)"}
73
 
74
+ # Load audio
75
  audio_array, _ = librosa.load(
76
  io.BytesIO(audio_bytes),
77
  sr=16000,
 
81
  if len(audio_array) == 0:
82
  return {"error": "Invalid or empty audio file"}
83
 
84
+ # Process audio WITHOUT language/task specification to avoid forced_decoder_ids
85
  model_inputs = self.processor(
86
  audio_array,
87
  sampling_rate=16000,
88
  return_tensors="pt"
89
  )
90
+
91
+ # Remove any forced_decoder_ids that might have been added
92
  if "forced_decoder_ids" in model_inputs:
93
  del model_inputs["forced_decoder_ids"]
94
 
95
+ # Move to device and convert dtype
96
  model_inputs = {
97
  k: v.to(self.model.device).half() if v.dtype == torch.float32 else v.to(self.model.device)
98
  for k, v in model_inputs.items()
99
  }
100
 
101
+ # Parameters
102
  max_length = parameters.get("max_length", 256)
103
  num_beams = parameters.get("num_beams", 6)
104
  temperature = parameters.get("temperature", 0.0)
105
 
106
+ # Generate with explicit decoder_input_ids
107
  with torch.no_grad(), torch.inference_mode(), torch.autocast(device_type="cuda", dtype=torch.float16):
108
  predicted_ids = self.model.generate(
109
  **model_inputs,
110
+ decoder_input_ids=self.french_decoder_input_ids,
111
  max_length=max_length,
112
  num_beams=num_beams,
113
  temperature=temperature,
 
117
  repetition_penalty=1.1,
118
  length_penalty=1.0,
119
  use_cache=True,
120
+ pad_token_id=self.processor.tokenizer.eos_token_id
 
 
121
  )
122
 
123
  transcription = self.processor.batch_decode(predicted_ids, skip_special_tokens=True)
124
  return {"transcription": transcription[0]}
125
  except Exception as e:
126
+ return {"error": f"Transcription error: {str(e)}"}