Upload handler.py
Browse files — handler.py (+3 −1)
handler.py
CHANGED
|
@@ -124,6 +124,9 @@ class EndpointHandler:
|
|
| 124 |
|
| 125 |
# Generate transcription with anti-hallucination parameters
|
| 126 |
with torch.no_grad(), torch.inference_mode(), torch.autocast(device_type="cuda", dtype=torch.float16):
|
|
|
|
|
|
|
|
|
|
| 127 |
predicted_ids = self.model.generate(
|
| 128 |
**model_inputs,
|
| 129 |
max_length=max_length,
|
|
@@ -136,7 +139,6 @@ class EndpointHandler:
|
|
| 136 |
length_penalty=1.0,
|
| 137 |
use_cache=True,
|
| 138 |
pad_token_id=self.processor.tokenizer.eos_token_id,
|
| 139 |
-
forced_decoder_ids=self.french_decoder_ids,
|
| 140 |
suppress_tokens=[],
|
| 141 |
begin_suppress_tokens=[]
|
| 142 |
)
|
|
|
|
| 124 |
|
| 125 |
# Generate transcription with anti-hallucination parameters
|
| 126 |
with torch.no_grad(), torch.inference_mode(), torch.autocast(device_type="cuda", dtype=torch.float16):
|
| 127 |
+
# Add language forcing to inputs instead of generation params
|
| 128 |
+
model_inputs.update(self.processor.get_decoder_prompt_ids(language="french", task="transcribe"))
|
| 129 |
+
|
| 130 |
predicted_ids = self.model.generate(
|
| 131 |
**model_inputs,
|
| 132 |
max_length=max_length,
|
|
|
|
| 139 |
length_penalty=1.0,
|
| 140 |
use_cache=True,
|
| 141 |
pad_token_id=self.processor.tokenizer.eos_token_id,
|
|
|
|
| 142 |
suppress_tokens=[],
|
| 143 |
begin_suppress_tokens=[]
|
| 144 |
)
|