Update custom model files, README, and requirements
asr_modeling.py  +10 -12  CHANGED
@@ -733,30 +733,28 @@ class ASRModel(PreTrainedModel):
         """
         Stream generation by using the working generate() method with a TextIteratorStreamer.
         """
-        # Set up the streamer
-        #
-        #
+        # Set up the streamer - use skip_prompt=True like Ultravox
+        # The key is that when we return the full sequence from generate(),
+        # the streamer can properly identify and skip the prompt
         streamer = TextIteratorStreamer(
             self.tokenizer,
-            skip_prompt=True,
+            skip_prompt=True,  # Skip the prompt tokens
             skip_special_tokens=True,
-            timeout=30.0
+            timeout=30.0
         )
 
-        # Count prompt length for stats
-        # We need to encode just to get the prompt length
         audio_inputs = input_values if input_values is not None else input_features
         if audio_inputs is None:
             raise ValueError("input_values or input_features must be provided")
 
-        # Simple way to get prompt length - just count audio tokens
         import threading
         from concurrent import futures
 
         # Run generation in a thread with streamer
         def generation_thread(future: futures.Future):
             try:
-                #
+                # Call generate with the streamer
+                # Important: This now returns the FULL sequence when streaming
                 result = self.generate(
                     input_values=input_values,
                     input_features=input_features,
@@ -774,17 +772,17 @@ class ASRModel(PreTrainedModel):
         thread = threading.Thread(target=generation_thread, args=(future,))
         thread.start()
 
-        # Stream the output
+        # Stream the output - like Ultravox, just yield chunks as they come
         output_token_count = 0
         try:
             for chunk in streamer:
-                if chunk:
+                if chunk:  # Only yield non-empty chunks
                     output_token_count += 1
                     yield StreamChunk(chunk)
         except Exception as e:
             # Check if it's the Empty exception from queue
             if e.__class__.__name__ == "Empty":
-                # This
+                # This happens when generation completes before we start iterating
                 pass
             else:
                 # Re-raise other exceptions
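For reference, the threading pattern this diff refines is the standard transformers TextIteratorStreamer recipe: run generate() on a worker thread with the streamer attached, and iterate the streamer on the calling thread as decoded text chunks arrive. Below is a minimal sketch of that pattern only; the "gpt2" checkpoint and the text prompt are placeholders, and it does not reproduce this repo's StreamChunk wrapper or audio inputs.

    import threading

    from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

    # Placeholder checkpoint; ASRModel wraps its own decoder and tokenizer.
    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    model = AutoModelForCausalLM.from_pretrained("gpt2")

    inputs = tokenizer("Transcription:", return_tensors="pt")

    # skip_prompt=True drops the prompt tokens from the streamed text;
    # timeout bounds how long the consumer waits on the streamer's internal queue.
    streamer = TextIteratorStreamer(
        tokenizer,
        skip_prompt=True,
        skip_special_tokens=True,
        timeout=30.0,
    )

    # Producer: generate() runs on a worker thread and pushes decoded chunks
    # into the streamer as tokens are produced.
    thread = threading.Thread(
        target=model.generate,
        kwargs={**inputs, "max_new_tokens": 64, "streamer": streamer},
    )
    thread.start()

    # Consumer: iterate on the calling thread; each item is a decoded text chunk.
    for chunk in streamer:
        if chunk:  # skip empty flushes
            print(chunk, end="", flush=True)
    thread.join()

The generate_stream() in asr_modeling.py follows the same mechanics but feeds input_values/input_features instead of a text prompt, hands the final generate() result back through a concurrent.futures.Future, and yields each chunk wrapped in StreamChunk.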