fix: chat does not return early when stream is enabled

#16
by airlsyn - opened
Files changed (1) hide show
  1. modeling_minicpmo.py +5 -0
modeling_minicpmo.py CHANGED
@@ -1176,6 +1176,8 @@ class MiniCPMO(MiniCPMOPreTrainedModel):
1176
  max_length=max_inp_length,
1177
  ).to(self.device)
1178
 
 
 
1179
  generation_config = self.prepare_generation_config(
1180
  do_sample=do_sample, max_new_tokens=max_new_tokens, min_new_tokens=min_new_tokens, **kwargs
1181
  )
@@ -1194,6 +1196,9 @@ class MiniCPMO(MiniCPMOPreTrainedModel):
1194
  **generation_config,
1195
  )
1196
 
 
 
 
1197
  # spk bound and tts bound
1198
  tts_bos_token = self.processor.tokenizer.convert_tokens_to_ids("<|tts_bos|>")
1199
  tts_eos_token = self.processor.tokenizer.convert_tokens_to_ids("<|tts_eos|>")
 
1176
  max_length=max_inp_length,
1177
  ).to(self.device)
1178
 
1179
+ if stream:
1180
+ kwargs["num_beams"] = 1
1181
  generation_config = self.prepare_generation_config(
1182
  do_sample=do_sample, max_new_tokens=max_new_tokens, min_new_tokens=min_new_tokens, **kwargs
1183
  )
 
1196
  **generation_config,
1197
  )
1198
 
1199
+ if stream:
1200
+ return res
1201
+
1202
  # spk bound and tts bound
1203
  tts_bos_token = self.processor.tokenizer.convert_tokens_to_ids("<|tts_bos|>")
1204
  tts_eos_token = self.processor.tokenizer.convert_tokens_to_ids("<|tts_eos|>")