mazesmazes committed
Commit 7265045 · verified · 1 parent: dda3d8c

Update custom model files, README, and requirements

Files changed (1):
asr_modeling.py (+10 -12)
asr_modeling.py CHANGED
```diff
@@ -733,30 +733,28 @@ class ASRModel(PreTrainedModel):
         """
         Stream generation by using the working generate() method with a TextIteratorStreamer.
         """
-        # Set up the streamer
-        # Note: skip_prompt=True means it won't output the prompt tokens
-        # This should start streaming from the first NEW generated token
+        # Set up the streamer - use skip_prompt=True like Ultravox
+        # The key is that when we return the full sequence from generate(),
+        # the streamer can properly identify and skip the prompt
         streamer = TextIteratorStreamer(
             self.tokenizer,
-            skip_prompt=True,
+            skip_prompt=True,  # Skip the prompt tokens
             skip_special_tokens=True,
-            timeout=30.0  # Add timeout to prevent hanging
+            timeout=30.0
         )
 
-        # Count prompt length for stats
-        # We need to encode just to get the prompt length
         audio_inputs = input_values if input_values is not None else input_features
         if audio_inputs is None:
             raise ValueError("input_values or input_features must be provided")
 
-        # Simple way to get prompt length - just count audio tokens
         import threading
         from concurrent import futures
 
         # Run generation in a thread with streamer
         def generation_thread(future: futures.Future):
             try:
-                # Just call the working generate method with the streamer
+                # Call generate with the streamer
+                # Important: This now returns the FULL sequence when streaming
                 result = self.generate(
                     input_values=input_values,
                     input_features=input_features,
@@ -774,17 +772,17 @@ class ASRModel(PreTrainedModel):
         thread = threading.Thread(target=generation_thread, args=(future,))
         thread.start()
 
-        # Stream the output
+        # Stream the output - like Ultravox, just yield chunks as they come
         output_token_count = 0
         try:
             for chunk in streamer:
-                if chunk:
+                if chunk:  # Only yield non-empty chunks
                     output_token_count += 1
                     yield StreamChunk(chunk)
         except Exception as e:
            # Check if it's the Empty exception from queue
             if e.__class__.__name__ == "Empty":
-                # This is expected when streaming completes quickly
+                # This happens when generation completes before we start iterating
                 pass
             else:
                 # Re-raise other exceptions
```
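
For context, the streaming path touched by this diff follows the standard transformers pattern: `generate()` runs in a background thread and pushes decoded text into a `TextIteratorStreamer`, which the caller iterates over. Below is a minimal, self-contained sketch of that pattern using a plain text-only causal LM; the model name and prompt are placeholders rather than anything from this repo, and the repo's `ASRModel` wires the same streamer into its own `generate(input_values=..., input_features=...)` call instead.

```python
# Minimal sketch of the TextIteratorStreamer pattern the change above relies on.
# "gpt2" and the prompt are illustrative placeholders, not this repo's model.
from threading import Thread

from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")
inputs = tokenizer("Streaming generation works by", return_tensors="pt")

# skip_prompt=True drops the prompt tokens so only newly generated text is yielded;
# the timeout keeps the consumer from blocking forever if the generation thread stalls.
streamer = TextIteratorStreamer(
    tokenizer,
    skip_prompt=True,
    skip_special_tokens=True,
    timeout=30.0,
)

# generate() blocks until it finishes, so it runs in a background thread
# while the main thread consumes decoded chunks from the streamer.
thread = Thread(
    target=model.generate,
    kwargs={**inputs, "streamer": streamer, "max_new_tokens": 32},
)
thread.start()

for chunk in streamer:
    if chunk:  # skip empty decode chunks, as the diff does
        print(chunk, end="", flush=True)
thread.join()
```

The `Empty` check in the except block corresponds to `queue.Empty`, which the streamer's iterator raises when its internal queue `get()` hits the configured timeout; comparing `e.__class__.__name__` avoids importing `queue` in that module, and a direct `except queue.Empty:` would be the more explicit, roughly equivalent alternative.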