iitolstykh committed on
Commit
32a5ef3
·
verified ·
1 Parent(s): 7d37585

Update inferencer.py

Browse files
Files changed (1) hide show
  1. inferencer.py +42 -37
inferencer.py CHANGED
@@ -51,7 +51,8 @@ class InterleaveInferencer:
51
  new_token_ids=self.new_token_ids,
52
  )
53
 
54
- past_key_values = self.model.forward_cache_update_text(past_key_values, **generation_input)
 
55
  gen_context['kv_lens'] = kv_lens
56
  gen_context['ropes'] = ropes
57
  gen_context['past_key_values'] = past_key_values
@@ -76,7 +77,8 @@ class InterleaveInferencer:
76
  transforms=self.vae_transform,
77
  new_token_ids=self.new_token_ids,
78
  )
79
- past_key_values = self.model.forward_cache_update_vae(self.vae_model, past_key_values, **generation_input)
 
80
 
81
  if vit:
82
  ## update vit
@@ -87,7 +89,8 @@ class InterleaveInferencer:
87
  transforms=self.vit_transform,
88
  new_token_ids=self.new_token_ids,
89
  )
90
- past_key_values = self.model.forward_cache_update_vit(past_key_values, **generation_input)
 
91
 
92
  gen_context['kv_lens'] = kv_lens
93
  gen_context['ropes'] = ropes
@@ -143,27 +146,28 @@ class InterleaveInferencer:
143
  image_sizes=[image_shape],
144
  )
145
 
146
- unpacked_latent = self.model.generate_image(
147
- past_key_values=past_key_values,
148
- cfg_text_past_key_values=cfg_text_past_key_values,
149
- cfg_img_past_key_values=cfg_img_past_key_values,
150
- num_timesteps=num_timesteps,
151
- cfg_text_scale=cfg_text_scale,
152
- cfg_img_scale=cfg_img_scale,
153
- cfg_interval=cfg_interval,
154
- cfg_renorm_min=cfg_renorm_min,
155
- cfg_renorm_type=cfg_renorm_type,
156
- timestep_shift=timestep_shift,
157
- **generation_input,
158
- cfg_text_packed_position_ids=generation_input_cfg_text['cfg_packed_position_ids'],
159
- cfg_text_packed_query_indexes=generation_input_cfg_text['cfg_packed_query_indexes'],
160
- cfg_text_key_values_lens=generation_input_cfg_text['cfg_key_values_lens'],
161
- cfg_text_packed_key_value_indexes=generation_input_cfg_text['cfg_packed_key_value_indexes'],
162
- cfg_img_packed_position_ids=generation_input_cfg_img['cfg_packed_position_ids'],
163
- cfg_img_packed_query_indexes=generation_input_cfg_img['cfg_packed_query_indexes'],
164
- cfg_img_key_values_lens=generation_input_cfg_img['cfg_key_values_lens'],
165
- cfg_img_packed_key_value_indexes=generation_input_cfg_img['cfg_packed_key_value_indexes'],
166
- )
 
167
 
168
  image = self.decode_image(unpacked_latent[0], image_shape)
169
  return image
@@ -189,19 +193,20 @@ class InterleaveInferencer:
189
  kv_lens = gen_context['kv_lens']
190
  ropes = gen_context['ropes']
191
 
192
- generation_input = self.model.prepare_start_tokens(kv_lens, ropes, self.new_token_ids)
193
- for unpacked_latent in self.model.generate_text(
194
- past_key_values=past_key_values,
195
- max_length=max_length,
196
- do_sample=do_sample,
197
- temperature=temperature,
198
- end_token_id=self.new_token_ids['eos_token_id'],
199
- **generation_input,
200
- ):
201
- output = self.tokenizer.decode(unpacked_latent)
202
- if output != "<|im_end|>":
203
- yield output
204
-
 
205
  @torch.no_grad()
206
  def interleave_inference(
207
  self,
 
51
  new_token_ids=self.new_token_ids,
52
  )
53
 
54
+ with torch.amp.autocast("cuda", enabled=True, dtype=torch.bfloat16):
55
+ past_key_values = self.model.forward_cache_update_text(past_key_values, **generation_input)
56
  gen_context['kv_lens'] = kv_lens
57
  gen_context['ropes'] = ropes
58
  gen_context['past_key_values'] = past_key_values
 
77
  transforms=self.vae_transform,
78
  new_token_ids=self.new_token_ids,
79
  )
80
+ with torch.amp.autocast("cuda", enabled=True, dtype=torch.bfloat16):
81
+ past_key_values = self.model.forward_cache_update_vae(self.vae_model, past_key_values, **generation_input)
82
 
83
  if vit:
84
  ## update vit
 
89
  transforms=self.vit_transform,
90
  new_token_ids=self.new_token_ids,
91
  )
92
+ with torch.amp.autocast("cuda", enabled=True, dtype=torch.bfloat16):
93
+ past_key_values = self.model.forward_cache_update_vit(past_key_values, **generation_input)
94
 
95
  gen_context['kv_lens'] = kv_lens
96
  gen_context['ropes'] = ropes
 
146
  image_sizes=[image_shape],
147
  )
148
 
149
+ with torch.amp.autocast("cuda", enabled=True, dtype=torch.bfloat16):
150
+ unpacked_latent = self.model.generate_image(
151
+ past_key_values=past_key_values,
152
+ cfg_text_past_key_values=cfg_text_past_key_values,
153
+ cfg_img_past_key_values=cfg_img_past_key_values,
154
+ num_timesteps=num_timesteps,
155
+ cfg_text_scale=cfg_text_scale,
156
+ cfg_img_scale=cfg_img_scale,
157
+ cfg_interval=cfg_interval,
158
+ cfg_renorm_min=cfg_renorm_min,
159
+ cfg_renorm_type=cfg_renorm_type,
160
+ timestep_shift=timestep_shift,
161
+ **generation_input,
162
+ cfg_text_packed_position_ids=generation_input_cfg_text['cfg_packed_position_ids'],
163
+ cfg_text_packed_query_indexes=generation_input_cfg_text['cfg_packed_query_indexes'],
164
+ cfg_text_key_values_lens=generation_input_cfg_text['cfg_key_values_lens'],
165
+ cfg_text_packed_key_value_indexes=generation_input_cfg_text['cfg_packed_key_value_indexes'],
166
+ cfg_img_packed_position_ids=generation_input_cfg_img['cfg_packed_position_ids'],
167
+ cfg_img_packed_query_indexes=generation_input_cfg_img['cfg_packed_query_indexes'],
168
+ cfg_img_key_values_lens=generation_input_cfg_img['cfg_key_values_lens'],
169
+ cfg_img_packed_key_value_indexes=generation_input_cfg_img['cfg_packed_key_value_indexes'],
170
+ )
171
 
172
  image = self.decode_image(unpacked_latent[0], image_shape)
173
  return image
 
193
  kv_lens = gen_context['kv_lens']
194
  ropes = gen_context['ropes']
195
 
196
+ with torch.amp.autocast("cuda", enabled=True, dtype=torch.bfloat16):
197
+ generation_input = self.model.prepare_start_tokens(kv_lens, ropes, self.new_token_ids)
198
+ for unpacked_latent in self.model.generate_text(
199
+ past_key_values=past_key_values,
200
+ max_length=max_length,
201
+ do_sample=do_sample,
202
+ temperature=temperature,
203
+ end_token_id=self.new_token_ids['eos_token_id'],
204
+ **generation_input,
205
+ ):
206
+ output = self.tokenizer.decode(unpacked_latent)
207
+ if output != "<|im_end|>":
208
+ yield output
209
+
210
  @torch.no_grad()
211
  def interleave_inference(
212
  self,