Update custom model files, README, and requirements
asr_modeling.py  CHANGED  (+12 -39)
@@ -698,27 +698,18 @@ class ASRModel(PreTrainedModel):
         generate_kwargs.setdefault("pad_token_id", self.tokenizer.pad_token_id)
         prompt_length = expanded_prompt_ids.shape[1]

-        # Generate
-
-
-
-
-
-
-
-
-
-
-
-        else:
-            generated_ids = self.decoder.generate(
-                input_ids=expanded_prompt_ids,
-                inputs_embeds=inputs_embeds,
-                attention_mask=attention_mask,
-                **generate_kwargs,
-            )
-            # When not streaming, return only the new tokens (without prompt)
-            return generated_ids[:, prompt_length:]
+        # Generate (always returns full sequence, caller handles trimming)
+        generated_ids = self.decoder.generate(
+            input_ids=expanded_prompt_ids,
+            inputs_embeds=inputs_embeds,
+            attention_mask=attention_mask,
+            streamer=streamer,
+            **generate_kwargs,
+        )
+
+        # Always return only the new tokens (without prompt)
+        # The streamer already got the full sequence during generation
+        return generated_ids[:, prompt_length:]

     @torch.no_grad()
     def generate_stream(
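The rewritten hunk collapses the old if/else on streaming into one call. That works because a Hugging Face-style generate() pushes tokens into a passed streamer during decoding yet still returns the complete sequence (prompt included), so the same prompt-trimming return serves both paths. A minimal sketch of that contract with a stock transformers model (gpt2 is just a stand-in here, not this repo's decoder):

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

prompt_ids = tokenizer("The quick brown", return_tensors="pt").input_ids
prompt_length = prompt_ids.shape[1]

# TextIteratorStreamer buffers chunks in an unbounded queue, so for a short
# generation it can be drained after generate() returns.
streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

with torch.no_grad():
    generated_ids = model.generate(
        input_ids=prompt_ids,
        streamer=streamer,
        max_new_tokens=10,
        pad_token_id=tokenizer.eos_token_id,
    )

# generate() returned the full sequence; trim to the new tokens, as the diff does.
print(tokenizer.decode(generated_ids[0, prompt_length:], skip_special_tokens=True))

# The streamer received the same new tokens incrementally (skip_prompt=True).
print("".join(streamer))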
@@ -753,10 +744,7 @@ class ASRModel(PreTrainedModel):
         # Run generation in a thread with streamer
         def generation_thread(future: futures.Future):
             try:
-                import sys
-                print("DEBUG: Starting generation thread", file=sys.stderr)
                 # Call generate with the streamer
-                # Important: This now returns the FULL sequence when streaming
                 result = self.generate(
                     input_values=input_values,
                     input_features=input_features,
@@ -766,24 +754,18 @@ class ASRModel(PreTrainedModel):
                     streamer=streamer,
                     **generate_kwargs,
                 )
-                print("DEBUG: Generation complete", file=sys.stderr)
                 future.set_result(result)
             except Exception as e:
-                print(f"DEBUG: Generation error: {e}", file=sys.stderr)
                 future.set_exception(e)

         future: futures.Future = futures.Future()
         thread = threading.Thread(target=generation_thread, args=(future,))
         thread.start()
-        print("DEBUG: Thread started", file=sys.stderr)

         # Stream the output - like Ultravox, just yield chunks as they come
         output_token_count = 0
-        import sys
-        print("DEBUG: Starting streaming iteration", file=sys.stderr)
         try:
             for chunk in streamer:
-                print(f"DEBUG: Got chunk: {repr(chunk)}", file=sys.stderr)
                 if chunk:  # Only yield non-empty chunks
                     output_token_count += 1
                     yield StreamChunk(chunk)
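generate_stream keeps the standard thread-plus-streamer shape: generation runs in a worker thread, the caller iterates the streamer for text chunks, and a concurrent.futures.Future ferries the return value or exception back so a failure re-raises in the consumer rather than dying silently inside the thread. The bare pattern, sketched with a plain transformers model standing in for self.generate (all names here are illustrative, not the repo's):

import threading
from concurrent import futures

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

prompt_ids = tokenizer("Once upon a time", return_tensors="pt").input_ids
streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
future: futures.Future = futures.Future()

def generation_thread() -> None:
    # Mirror the diff: park the result or the exception on the Future so the
    # consuming side can inspect it once the stream is drained.
    try:
        with torch.no_grad():
            result = model.generate(
                input_ids=prompt_ids,
                streamer=streamer,
                max_new_tokens=20,
                pad_token_id=tokenizer.eos_token_id,
            )
        future.set_result(result)
    except Exception as e:
        future.set_exception(e)

threading.Thread(target=generation_thread).start()

for chunk in streamer:
    if chunk:  # only forward non-empty chunks, as the diff does
        print(chunk, end="", flush=True)
print()

if future.exception():  # re-raise a generation failure in the caller's frame
    raise future.exception()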
@@ -801,15 +783,6 @@ class ASRModel(PreTrainedModel):
         if future.exception():
             raise future.exception()

-        # Debug: If no chunks were yielded, check what was generated
-        if output_token_count == 0:
-            import sys
-            result = future.result()
-            if result is not None:
-                # Note: result now includes the full sequence (including prompt)
-                # when streaming, so decode the full thing
-                decoded = self.tokenizer.decode(result[0], skip_special_tokens=True)
-                print(f"DEBUG: No chunks yielded but generated: {decoded}", file=sys.stderr)

         # For stats, estimate input tokens (we can't easily get exact count without duplicating work)
         # Rough estimate: prompt is about 20 tokens + 750 audio tokens
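All of the removed diagnostics were bare stderr prints. If that visibility is ever wanted again, a module-level logger gives the same output behind a runtime switch instead of code edits; a minimal sketch (the logger name, messages, and helper are illustrative, not part of the repo):

import logging

logger = logging.getLogger("asr_modeling")

def stream_with_logging(streamer):
    # Same introspection the deleted prints provided, but off by default;
    # enable with logging.basicConfig(level=logging.DEBUG).
    logger.debug("Starting streaming iteration")
    count = 0
    for chunk in streamer:
        logger.debug("Got chunk: %r", chunk)
        if chunk:
            count += 1
            yield chunk
    logger.debug("Streaming finished after %d chunks", count)

if __name__ == "__main__":
    logging.basicConfig(level=logging.DEBUG)
    for piece in stream_with_logging(iter(["Hello", "", " world"])):
        print(piece, end="")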