mazesmazes commited on
Commit
cf359ac
·
verified ·
1 Parent(s): 02845e7

Update custom model files, README, and requirements

Browse files
Files changed (1) hide show
  1. asr_modeling.py +15 -9
asr_modeling.py CHANGED
@@ -708,16 +708,22 @@ class ASRModel(PreTrainedModel):
708
  print(f"DEBUG generate (non-streaming): task={task}, system_prompt={system_prompt}, user_prompt={user_prompt}", file=sys.stderr)
709
  print(f"DEBUG generate (non-streaming): generate_kwargs={generate_kwargs}", file=sys.stderr)
710
 
711
- # Add streamer if provided
712
  if streamer is not None:
713
- generate_kwargs["streamer"] = streamer
714
-
715
- generated_ids = self.decoder.generate(
716
- input_ids=expanded_prompt_ids,
717
- inputs_embeds=inputs_embeds,
718
- attention_mask=attention_mask,
719
- **generate_kwargs,
720
- )
 
 
 
 
 
 
721
 
722
  return generated_ids[:, prompt_length:]
723
 
 
708
  print(f"DEBUG generate (non-streaming): task={task}, system_prompt={system_prompt}, user_prompt={user_prompt}", file=sys.stderr)
709
  print(f"DEBUG generate (non-streaming): generate_kwargs={generate_kwargs}", file=sys.stderr)
710
 
711
+ # Generate with or without streamer
712
  if streamer is not None:
713
+ generated_ids = self.decoder.generate(
714
+ input_ids=expanded_prompt_ids,
715
+ inputs_embeds=inputs_embeds,
716
+ attention_mask=attention_mask,
717
+ streamer=streamer,
718
+ **generate_kwargs,
719
+ )
720
+ else:
721
+ generated_ids = self.decoder.generate(
722
+ input_ids=expanded_prompt_ids,
723
+ inputs_embeds=inputs_embeds,
724
+ attention_mask=attention_mask,
725
+ **generate_kwargs,
726
+ )
727
 
728
  return generated_ids[:, prompt_length:]
729