mazesmazes committed on
Commit
a8b0a76
·
verified ·
1 Parent(s): 8be8b54

Update custom model files, README, and requirements

Browse files
Files changed (1) hide show
  1. asr_modeling.py +7 -2
asr_modeling.py CHANGED
@@ -707,6 +707,9 @@ class ASRModel(PreTrainedModel):
707
  streamer=streamer,
708
  **generate_kwargs,
709
  )
 
 
 
710
  else:
711
  generated_ids = self.decoder.generate(
712
  input_ids=expanded_prompt_ids,
@@ -714,8 +717,8 @@ class ASRModel(PreTrainedModel):
714
  attention_mask=attention_mask,
715
  **generate_kwargs,
716
  )
717
-
718
- return generated_ids[:, prompt_length:]
719
 
720
  @torch.no_grad()
721
  def generate_stream(
@@ -789,6 +792,8 @@ class ASRModel(PreTrainedModel):
789
  import sys
790
  result = future.result()
791
  if result is not None:
 
 
792
  decoded = self.tokenizer.decode(result[0], skip_special_tokens=True)
793
  print(f"DEBUG: No chunks yielded but generated: {decoded}", file=sys.stderr)
794
 
 
707
  streamer=streamer,
708
  **generate_kwargs,
709
  )
710
+ # When using a streamer, return the full output (streamer will handle skipping prompt)
711
+ # The streamer needs the full sequence to properly identify what to skip
712
+ return generated_ids
713
  else:
714
  generated_ids = self.decoder.generate(
715
  input_ids=expanded_prompt_ids,
 
717
  attention_mask=attention_mask,
718
  **generate_kwargs,
719
  )
720
+ # When not streaming, return only the new tokens (without prompt)
721
+ return generated_ids[:, prompt_length:]
722
 
723
  @torch.no_grad()
724
  def generate_stream(
 
792
  import sys
793
  result = future.result()
794
  if result is not None:
795
+ # Note: result now includes the full sequence (including prompt)
796
+ # when streaming, so decode the full thing
797
  decoded = self.tokenizer.decode(result[0], skip_special_tokens=True)
798
  print(f"DEBUG: No chunks yielded but generated: {decoded}", file=sys.stderr)
799