Update custom model files, README, and requirements
Changed files: asr_modeling.py (+7 −2)

--- a/asr_modeling.py
+++ b/asr_modeling.py
@@ -707,6 +707,9 @@ class ASRModel(PreTrainedModel):
                     streamer=streamer,
                     **generate_kwargs,
                 )
+                # When using a streamer, return the full output (streamer will handle skipping prompt)
+                # The streamer needs the full sequence to properly identify what to skip
+                return generated_ids
             else:
                 generated_ids = self.decoder.generate(
                     input_ids=expanded_prompt_ids,
@@ -714,8 +717,8 @@ class ASRModel(PreTrainedModel):
                     attention_mask=attention_mask,
                     **generate_kwargs,
                 )
-
-
+                # When not streaming, return only the new tokens (without prompt)
+                return generated_ids[:, prompt_length:]
 
     @torch.no_grad()
     def generate_stream(
@@ -789,6 +792,8 @@ class ASRModel(PreTrainedModel):
         import sys
         result = future.result()
         if result is not None:
+            # Note: result now includes the full sequence (including prompt)
+            # when streaming, so decode the full thing
             decoded = self.tokenizer.decode(result[0], skip_special_tokens=True)
             print(f"DEBUG: No chunks yielded but generated: {decoded}", file=sys.stderr)