SoSolaris committed on
Commit
afe3147
·
verified ·
1 Parent(s): f7c80df

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +8 -9
handler.py CHANGED
@@ -40,9 +40,12 @@ class EndpointHandler:
40
  except Exception as e2:
41
  print(f"Compilation failed: {e2}")
42
 
43
- # forced_decoder_ids pour français (comme fastapi)
44
- self.french_decoder_ids = self.processor.get_decoder_prompt_ids(
45
- language="french", task="transcribe"
 
 
 
46
  )
47
 
48
  print("Model loaded and optimized successfully!")
@@ -55,7 +58,6 @@ class EndpointHandler:
55
  inputs = data.get("inputs", "")
56
  parameters = data.get("parameters", {})
57
 
58
- # decode audio (base64 string or bytes)
59
  if isinstance(inputs, str):
60
  try:
61
  audio_bytes = base64.b64decode(inputs)
@@ -66,11 +68,9 @@ class EndpointHandler:
66
  else:
67
  return {"error": "Invalid input format. Expected base64 string or bytes"}
68
 
69
- # check size
70
  if len(audio_bytes) > 25 * 1024 * 1024:
71
  return {"error": "File too large (max 25MB)"}
72
 
73
- # load audio
74
  audio_array, _ = librosa.load(
75
  io.BytesIO(audio_bytes),
76
  sr=16000,
@@ -91,15 +91,15 @@ class EndpointHandler:
91
  for k, v in model_inputs.items()
92
  }
93
 
94
- # params
95
  max_length = parameters.get("max_length", 256)
96
  num_beams = parameters.get("num_beams", 6)
97
  temperature = parameters.get("temperature", 0.0)
98
 
99
- # generate
100
  with torch.no_grad(), torch.inference_mode(), torch.autocast(device_type="cuda", dtype=torch.float16):
101
  predicted_ids = self.model.generate(
102
  **model_inputs,
 
 
103
  max_length=max_length,
104
  num_beams=num_beams,
105
  temperature=temperature,
@@ -110,7 +110,6 @@ class EndpointHandler:
110
  length_penalty=1.0,
111
  use_cache=True,
112
  pad_token_id=self.processor.tokenizer.eos_token_id,
113
- forced_decoder_ids=self.french_decoder_ids, # ✅ identique à fastapi
114
  suppress_tokens=[],
115
  begin_suppress_tokens=[]
116
  )
 
40
  except Exception as e2:
41
  print(f"Compilation failed: {e2}")
42
 
43
+ # compute decoder_input_ids for french
44
+ forced_ids = self.processor.get_decoder_prompt_ids(language="french", task="transcribe")
45
+ # convert to tensor [ [id1,id2,...] ]
46
+ self.french_decoder_input_ids = torch.tensor(
47
+ [[tok_id for _, tok_id in forced_ids]],
48
+ device="cuda" if torch.cuda.is_available() else "cpu"
49
  )
50
 
51
  print("Model loaded and optimized successfully!")
 
58
  inputs = data.get("inputs", "")
59
  parameters = data.get("parameters", {})
60
 
 
61
  if isinstance(inputs, str):
62
  try:
63
  audio_bytes = base64.b64decode(inputs)
 
68
  else:
69
  return {"error": "Invalid input format. Expected base64 string or bytes"}
70
 
 
71
  if len(audio_bytes) > 25 * 1024 * 1024:
72
  return {"error": "File too large (max 25MB)"}
73
 
 
74
  audio_array, _ = librosa.load(
75
  io.BytesIO(audio_bytes),
76
  sr=16000,
 
91
  for k, v in model_inputs.items()
92
  }
93
 
 
94
  max_length = parameters.get("max_length", 256)
95
  num_beams = parameters.get("num_beams", 6)
96
  temperature = parameters.get("temperature", 0.0)
97
 
 
98
  with torch.no_grad(), torch.inference_mode(), torch.autocast(device_type="cuda", dtype=torch.float16):
99
  predicted_ids = self.model.generate(
100
  **model_inputs,
101
+ decoder_input_ids=self.french_decoder_input_ids, # ✅ remplace forced_decoder_ids
102
+ forced_decoder_ids=None, # ✅ évite le conflit
103
  max_length=max_length,
104
  num_beams=num_beams,
105
  temperature=temperature,
 
110
  length_penalty=1.0,
111
  use_cache=True,
112
  pad_token_id=self.processor.tokenizer.eos_token_id,
 
113
  suppress_tokens=[],
114
  begin_suppress_tokens=[]
115
  )