Update custom model files, README, and requirements
Browse files- asr_modeling.py +15 -9
asr_modeling.py
CHANGED
|
@@ -708,16 +708,22 @@ class ASRModel(PreTrainedModel):
|
|
| 708 |
print(f"DEBUG generate (non-streaming): task={task}, system_prompt={system_prompt}, user_prompt={user_prompt}", file=sys.stderr)
|
| 709 |
print(f"DEBUG generate (non-streaming): generate_kwargs={generate_kwargs}", file=sys.stderr)
|
| 710 |
|
| 711 |
-
#
|
| 712 |
if streamer is not None:
|
| 713 |
-
|
| 714 |
-
|
| 715 |
-
|
| 716 |
-
|
| 717 |
-
|
| 718 |
-
|
| 719 |
-
|
| 720 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 721 |
|
| 722 |
return generated_ids[:, prompt_length:]
|
| 723 |
|
|
|
|
| 708 |
print(f"DEBUG generate (non-streaming): task={task}, system_prompt={system_prompt}, user_prompt={user_prompt}", file=sys.stderr)
|
| 709 |
print(f"DEBUG generate (non-streaming): generate_kwargs={generate_kwargs}", file=sys.stderr)
|
| 710 |
|
| 711 |
+
# Generate with or without streamer
|
| 712 |
if streamer is not None:
|
| 713 |
+
generated_ids = self.decoder.generate(
|
| 714 |
+
input_ids=expanded_prompt_ids,
|
| 715 |
+
inputs_embeds=inputs_embeds,
|
| 716 |
+
attention_mask=attention_mask,
|
| 717 |
+
streamer=streamer,
|
| 718 |
+
**generate_kwargs,
|
| 719 |
+
)
|
| 720 |
+
else:
|
| 721 |
+
generated_ids = self.decoder.generate(
|
| 722 |
+
input_ids=expanded_prompt_ids,
|
| 723 |
+
inputs_embeds=inputs_embeds,
|
| 724 |
+
attention_mask=attention_mask,
|
| 725 |
+
**generate_kwargs,
|
| 726 |
+
)
|
| 727 |
|
| 728 |
return generated_ids[:, prompt_length:]
|
| 729 |
|