Update custom model files, README, and requirements
Browse files- asr_modeling.py +10 -2
asr_modeling.py
CHANGED
|
@@ -710,6 +710,7 @@ class ASRModel(PreTrainedModel):
|
|
| 710 |
|
| 711 |
# Generate with or without streamer
|
| 712 |
if streamer is not None:
|
|
|
|
| 713 |
generated_ids = self.decoder.generate(
|
| 714 |
input_ids=expanded_prompt_ids,
|
| 715 |
inputs_embeds=inputs_embeds,
|
|
@@ -717,13 +718,20 @@ class ASRModel(PreTrainedModel):
|
|
| 717 |
streamer=streamer,
|
| 718 |
**generate_kwargs,
|
| 719 |
)
|
|
|
|
|
|
|
|
|
|
| 720 |
else:
|
|
|
|
| 721 |
generated_ids = self.decoder.generate(
|
| 722 |
input_ids=expanded_prompt_ids,
|
| 723 |
inputs_embeds=inputs_embeds,
|
| 724 |
attention_mask=attention_mask,
|
| 725 |
**generate_kwargs,
|
| 726 |
)
|
|
|
|
|
|
|
|
|
|
| 727 |
|
| 728 |
return generated_ids[:, prompt_length:]
|
| 729 |
|
|
@@ -740,11 +748,11 @@ class ASRModel(PreTrainedModel):
|
|
| 740 |
"""
|
| 741 |
Stream generation by using the working generate() method with a TextIteratorStreamer.
|
| 742 |
"""
|
| 743 |
-
# Set up the streamer
|
| 744 |
streamer = TextIteratorStreamer(
|
| 745 |
self.tokenizer,
|
| 746 |
skip_prompt=True,
|
| 747 |
-
skip_special_tokens=True
|
| 748 |
)
|
| 749 |
|
| 750 |
# Count prompt length for stats
|
|
|
|
| 710 |
|
| 711 |
# Generate with or without streamer
|
| 712 |
if streamer is not None:
|
| 713 |
+
print(f"DEBUG generate: Using streamer", file=sys.stderr)
|
| 714 |
generated_ids = self.decoder.generate(
|
| 715 |
input_ids=expanded_prompt_ids,
|
| 716 |
inputs_embeds=inputs_embeds,
|
|
|
|
| 718 |
streamer=streamer,
|
| 719 |
**generate_kwargs,
|
| 720 |
)
|
| 721 |
+
# Debug what was generated
|
| 722 |
+
generated_text = self.tokenizer.decode(generated_ids[0, prompt_length:], skip_special_tokens=True)
|
| 723 |
+
print(f"DEBUG generate with streamer: Generated text: {generated_text[:100]}", file=sys.stderr)
|
| 724 |
else:
|
| 725 |
+
print(f"DEBUG generate: No streamer", file=sys.stderr)
|
| 726 |
generated_ids = self.decoder.generate(
|
| 727 |
input_ids=expanded_prompt_ids,
|
| 728 |
inputs_embeds=inputs_embeds,
|
| 729 |
attention_mask=attention_mask,
|
| 730 |
**generate_kwargs,
|
| 731 |
)
|
| 732 |
+
# Debug what was generated
|
| 733 |
+
generated_text = self.tokenizer.decode(generated_ids[0, prompt_length:], skip_special_tokens=True)
|
| 734 |
+
print(f"DEBUG generate without streamer: Generated text: {generated_text[:100]}", file=sys.stderr)
|
| 735 |
|
| 736 |
return generated_ids[:, prompt_length:]
|
| 737 |
|
|
|
|
| 748 |
"""
|
| 749 |
Stream generation by using the working generate() method with a TextIteratorStreamer.
|
| 750 |
"""
|
| 751 |
+
# Set up the streamer - don't skip special tokens as it might affect audio token processing
|
| 752 |
streamer = TextIteratorStreamer(
|
| 753 |
self.tokenizer,
|
| 754 |
skip_prompt=True,
|
| 755 |
+
skip_special_tokens=False # Changed from True - audio token is special
|
| 756 |
)
|
| 757 |
|
| 758 |
# Count prompt length for stats
|