SoSolaris committed on
Commit
f7c80df
·
verified ·
1 Parent(s): 2bcb732

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +9 -8
handler.py CHANGED
@@ -40,12 +40,9 @@ class EndpointHandler:
40
  except Exception as e2:
41
  print(f"Compilation failed: {e2}")
42
 
43
- # pre-compute decoder ids for french
44
- self.french_decoder_ids = torch.tensor(
45
- self.processor.get_decoder_prompt_ids(
46
- language="french", task="transcribe"
47
- ),
48
- device="cuda" if torch.cuda.is_available() else "cpu"
49
  )
50
 
51
  print("Model loaded and optimized successfully!")
@@ -58,6 +55,7 @@ class EndpointHandler:
58
  inputs = data.get("inputs", "")
59
  parameters = data.get("parameters", {})
60
 
 
61
  if isinstance(inputs, str):
62
  try:
63
  audio_bytes = base64.b64decode(inputs)
@@ -68,9 +66,11 @@ class EndpointHandler:
68
  else:
69
  return {"error": "Invalid input format. Expected base64 string or bytes"}
70
 
 
71
  if len(audio_bytes) > 25 * 1024 * 1024:
72
  return {"error": "File too large (max 25MB)"}
73
 
 
74
  audio_array, _ = librosa.load(
75
  io.BytesIO(audio_bytes),
76
  sr=16000,
@@ -86,20 +86,20 @@ class EndpointHandler:
86
  sampling_rate=16000,
87
  return_tensors="pt"
88
  )
89
-
90
  model_inputs = {
91
  k: v.to(self.model.device).half() if v.dtype == torch.float32 else v.to(self.model.device)
92
  for k, v in model_inputs.items()
93
  }
94
 
 
95
  max_length = parameters.get("max_length", 256)
96
  num_beams = parameters.get("num_beams", 6)
97
  temperature = parameters.get("temperature", 0.0)
98
 
 
99
  with torch.no_grad(), torch.inference_mode(), torch.autocast(device_type="cuda", dtype=torch.float16):
100
  predicted_ids = self.model.generate(
101
  **model_inputs,
102
- decoder_input_ids=self.french_decoder_ids,
103
  max_length=max_length,
104
  num_beams=num_beams,
105
  temperature=temperature,
@@ -110,6 +110,7 @@ class EndpointHandler:
110
  length_penalty=1.0,
111
  use_cache=True,
112
  pad_token_id=self.processor.tokenizer.eos_token_id,
 
113
  suppress_tokens=[],
114
  begin_suppress_tokens=[]
115
  )
 
40
  except Exception as e2:
41
  print(f"Compilation failed: {e2}")
42
 
43
+ # forced_decoder_ids pour français (comme fastapi)
44
+ self.french_decoder_ids = self.processor.get_decoder_prompt_ids(
45
+ language="french", task="transcribe"
 
 
 
46
  )
47
 
48
  print("Model loaded and optimized successfully!")
 
55
  inputs = data.get("inputs", "")
56
  parameters = data.get("parameters", {})
57
 
58
+ # decode audio (base64 string or bytes)
59
  if isinstance(inputs, str):
60
  try:
61
  audio_bytes = base64.b64decode(inputs)
 
66
  else:
67
  return {"error": "Invalid input format. Expected base64 string or bytes"}
68
 
69
+ # check size
70
  if len(audio_bytes) > 25 * 1024 * 1024:
71
  return {"error": "File too large (max 25MB)"}
72
 
73
+ # load audio
74
  audio_array, _ = librosa.load(
75
  io.BytesIO(audio_bytes),
76
  sr=16000,
 
86
  sampling_rate=16000,
87
  return_tensors="pt"
88
  )
 
89
  model_inputs = {
90
  k: v.to(self.model.device).half() if v.dtype == torch.float32 else v.to(self.model.device)
91
  for k, v in model_inputs.items()
92
  }
93
 
94
+ # params
95
  max_length = parameters.get("max_length", 256)
96
  num_beams = parameters.get("num_beams", 6)
97
  temperature = parameters.get("temperature", 0.0)
98
 
99
+ # generate
100
  with torch.no_grad(), torch.inference_mode(), torch.autocast(device_type="cuda", dtype=torch.float16):
101
  predicted_ids = self.model.generate(
102
  **model_inputs,
 
103
  max_length=max_length,
104
  num_beams=num_beams,
105
  temperature=temperature,
 
110
  length_penalty=1.0,
111
  use_cache=True,
112
  pad_token_id=self.processor.tokenizer.eos_token_id,
113
+ forced_decoder_ids=self.french_decoder_ids, # ✅ identique à fastapi
114
  suppress_tokens=[],
115
  begin_suppress_tokens=[]
116
  )