para-lost commited on
Commit
85ee426
·
1 Parent(s): 3ad4866

fix bug on device

Browse files
Files changed (1) hide show
  1. pipeline.py +2 -1
pipeline.py CHANGED
@@ -6819,8 +6819,9 @@ class InterleaveInferencer:
6819
  past_key_values = gen_context['past_key_values']
6820
  kv_lens = gen_context['kv_lens']
6821
  ropes = gen_context['ropes']
6822
-
6823
  generation_input = self.model.prepare_start_tokens(kv_lens, ropes, self.new_token_ids)
 
6824
  unpacked_latent = self.model.generate_text(
6825
  past_key_values=past_key_values,
6826
  max_length=max_length,
 
6819
  past_key_values = gen_context['past_key_values']
6820
  kv_lens = gen_context['kv_lens']
6821
  ropes = gen_context['ropes']
6822
+ device = next(self.model.parameters()).device
6823
  generation_input = self.model.prepare_start_tokens(kv_lens, ropes, self.new_token_ids)
6824
+ generation_input = self._to_device(generation_input, device)
6825
  unpacked_latent = self.model.generate_text(
6826
  past_key_values=past_key_values,
6827
  max_length=max_length,