mazesmazes committed on
Commit
d9e53e1
·
verified ·
1 Parent(s): cf359ac

Update custom model files, README, and requirements

Browse files
Files changed (1) hide show
  1. asr_modeling.py +10 -2
asr_modeling.py CHANGED
@@ -710,6 +710,7 @@ class ASRModel(PreTrainedModel):
710
 
711
  # Generate with or without streamer
712
  if streamer is not None:
 
713
  generated_ids = self.decoder.generate(
714
  input_ids=expanded_prompt_ids,
715
  inputs_embeds=inputs_embeds,
@@ -717,13 +718,20 @@ class ASRModel(PreTrainedModel):
717
  streamer=streamer,
718
  **generate_kwargs,
719
  )
 
 
 
720
  else:
 
721
  generated_ids = self.decoder.generate(
722
  input_ids=expanded_prompt_ids,
723
  inputs_embeds=inputs_embeds,
724
  attention_mask=attention_mask,
725
  **generate_kwargs,
726
  )
 
 
 
727
 
728
  return generated_ids[:, prompt_length:]
729
 
@@ -740,11 +748,11 @@ class ASRModel(PreTrainedModel):
740
  """
741
  Stream generation by using the working generate() method with a TextIteratorStreamer.
742
  """
743
- # Set up the streamer
744
  streamer = TextIteratorStreamer(
745
  self.tokenizer,
746
  skip_prompt=True,
747
- skip_special_tokens=True
748
  )
749
 
750
  # Count prompt length for stats
 
710
 
711
  # Generate with or without streamer
712
  if streamer is not None:
713
+ print(f"DEBUG generate: Using streamer", file=sys.stderr)
714
  generated_ids = self.decoder.generate(
715
  input_ids=expanded_prompt_ids,
716
  inputs_embeds=inputs_embeds,
 
718
  streamer=streamer,
719
  **generate_kwargs,
720
  )
721
+ # Debug what was generated
722
+ generated_text = self.tokenizer.decode(generated_ids[0, prompt_length:], skip_special_tokens=True)
723
+ print(f"DEBUG generate with streamer: Generated text: {generated_text[:100]}", file=sys.stderr)
724
  else:
725
+ print(f"DEBUG generate: No streamer", file=sys.stderr)
726
  generated_ids = self.decoder.generate(
727
  input_ids=expanded_prompt_ids,
728
  inputs_embeds=inputs_embeds,
729
  attention_mask=attention_mask,
730
  **generate_kwargs,
731
  )
732
+ # Debug what was generated
733
+ generated_text = self.tokenizer.decode(generated_ids[0, prompt_length:], skip_special_tokens=True)
734
+ print(f"DEBUG generate without streamer: Generated text: {generated_text[:100]}", file=sys.stderr)
735
 
736
  return generated_ids[:, prompt_length:]
737
 
 
748
  """
749
  Stream generation by using the working generate() method with a TextIteratorStreamer.
750
  """
751
+ # Set up the streamer - don't skip special tokens as it might affect audio token processing
752
  streamer = TextIteratorStreamer(
753
  self.tokenizer,
754
  skip_prompt=True,
755
+ skip_special_tokens=False # Changed from True - audio token is special
756
  )
757
 
758
  # Count prompt length for stats