mazesmazes committed
Commit 8eec783 · verified · 1 Parent(s): ae56aa2

Update custom model files, README, and requirements

Files changed (1):
  1. asr_modeling.py +12 -39
asr_modeling.py CHANGED
@@ -698,27 +698,18 @@ class ASRModel(PreTrainedModel):
         generate_kwargs.setdefault("pad_token_id", self.tokenizer.pad_token_id)
         prompt_length = expanded_prompt_ids.shape[1]
 
-        # Generate with or without streamer
-        if streamer is not None:
-            generated_ids = self.decoder.generate(
-                input_ids=expanded_prompt_ids,
-                inputs_embeds=inputs_embeds,
-                attention_mask=attention_mask,
-                streamer=streamer,
-                **generate_kwargs,
-            )
-            # When using a streamer, return the full output (streamer will handle skipping prompt)
-            # The streamer needs the full sequence to properly identify what to skip
-            return generated_ids
-        else:
-            generated_ids = self.decoder.generate(
-                input_ids=expanded_prompt_ids,
-                inputs_embeds=inputs_embeds,
-                attention_mask=attention_mask,
-                **generate_kwargs,
-            )
-            # When not streaming, return only the new tokens (without prompt)
-            return generated_ids[:, prompt_length:]
+        # Generate (always returns full sequence, caller handles trimming)
+        generated_ids = self.decoder.generate(
+            input_ids=expanded_prompt_ids,
+            inputs_embeds=inputs_embeds,
+            attention_mask=attention_mask,
+            streamer=streamer,
+            **generate_kwargs,
+        )
+
+        # Always return only the new tokens (without prompt)
+        # The streamer already got the full sequence during generation
+        return generated_ids[:, prompt_length:]
 
     @torch.no_grad()
     def generate_stream(
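The simplification above relies on how Hugging Face streamers work: `generate()` pushes tokens into the streamer as they are produced, so the caller can trim the prompt from the returned tensor without changing what the streamer saw. A minimal standalone sketch of that behavior (not this repo's code; "gpt2" is just a placeholder model):

```python
# Sketch: prompt trimming on the return value is independent of streaming,
# because generate() feeds the streamer token-by-token during decoding.
import threading

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # placeholder model
model = AutoModelForCausalLM.from_pretrained("gpt2")

prompt_ids = tokenizer("Transcribe:", return_tensors="pt").input_ids
prompt_length = prompt_ids.shape[1]

# skip_prompt=True makes the streamer drop the prompt tokens itself, which is
# why the trimmed return value and the streamed text stay consistent.
streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

def run():
    with torch.no_grad():
        generated = model.generate(prompt_ids, streamer=streamer, max_new_tokens=20)
        # Caller-side trimming, as in the updated ASRModel.generate():
        new_tokens = generated[:, prompt_length:]

thread = threading.Thread(target=run)
thread.start()
for text_chunk in streamer:  # chunks arrive while generation is still running
    print(text_chunk, end="", flush=True)
thread.join()
```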
@@ -753,10 +744,7 @@ class ASRModel(PreTrainedModel):
         # Run generation in a thread with streamer
         def generation_thread(future: futures.Future):
             try:
-                import sys
-                print("DEBUG: Starting generation thread", file=sys.stderr)
                 # Call generate with the streamer
-                # Important: This now returns the FULL sequence when streaming
                 result = self.generate(
                     input_values=input_values,
                     input_features=input_features,
@@ -766,24 +754,18 @@ class ASRModel(PreTrainedModel):
                     streamer=streamer,
                     **generate_kwargs,
                 )
-                print("DEBUG: Generation complete", file=sys.stderr)
                 future.set_result(result)
             except Exception as e:
-                print(f"DEBUG: Generation error: {e}", file=sys.stderr)
                 future.set_exception(e)
 
         future: futures.Future = futures.Future()
         thread = threading.Thread(target=generation_thread, args=(future,))
         thread.start()
-        print("DEBUG: Thread started", file=sys.stderr)
 
         # Stream the output - like Ultravox, just yield chunks as they come
         output_token_count = 0
-        import sys
-        print("DEBUG: Starting streaming iteration", file=sys.stderr)
         try:
             for chunk in streamer:
-                print(f"DEBUG: Got chunk: {repr(chunk)}", file=sys.stderr)
                 if chunk:  # Only yield non-empty chunks
                     output_token_count += 1
                     yield StreamChunk(chunk)
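With the debug prints stripped, the surviving thread-plus-Future structure reduces to the sketch below; `expensive_generation()` is a hypothetical stand-in for the `self.generate(..., streamer=streamer)` call. The point of the pattern is that an exception raised in the worker thread is re-raised in the consumer after the stream is drained:

```python
# Sketch of the Future-based error handoff kept by this change: the worker
# records success or failure, the consumer re-raises after the stream ends.
import threading
from concurrent import futures

def expensive_generation() -> str:
    return "transcript"  # hypothetical stand-in for self.generate(...)

def generation_thread(future: futures.Future):
    try:
        future.set_result(expensive_generation())
    except Exception as e:
        future.set_exception(e)

future: futures.Future = futures.Future()
thread = threading.Thread(target=generation_thread, args=(future,))
thread.start()
# ... yield streamer chunks to the caller here ...
thread.join()
if future.exception():
    raise future.exception()
```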
@@ -801,15 +783,6 @@ class ASRModel(PreTrainedModel):
         if future.exception():
             raise future.exception()
 
-        # Debug: If no chunks were yielded, check what was generated
-        if output_token_count == 0:
-            import sys
-            result = future.result()
-            if result is not None:
-                # Note: result now includes the full sequence (including prompt)
-                # when streaming, so decode the full thing
-                decoded = self.tokenizer.decode(result[0], skip_special_tokens=True)
-                print(f"DEBUG: No chunks yielded but generated: {decoded}", file=sys.stderr)
 
         # For stats, estimate input tokens (we can't easily get exact count without duplicating work)
         # Rough estimate: prompt is about 20 tokens + 750 audio tokens
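For reference, the estimate described in the surviving comment amounts to the following; the constants come straight from that comment and are not measured values:

```python
# Rough input-token estimate from the comment above (not measured).
PROMPT_TOKENS_ESTIMATE = 20
AUDIO_TOKENS_ESTIMATE = 750
estimated_input_tokens = PROMPT_TOKENS_ESTIMATE + AUDIO_TOKENS_ESTIMATE  # ~770
```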
 