Update inferencer.py
inferencer.py (CHANGED, +42 -37)
@@ -51,7 +51,8 @@ class InterleaveInferencer:
             new_token_ids=self.new_token_ids,
         )

-        past_key_values = self.model.forward_cache_update_text(past_key_values, **generation_input)
+        with torch.amp.autocast("cuda", enabled=True, dtype=torch.bfloat16):
+            past_key_values = self.model.forward_cache_update_text(past_key_values, **generation_input)
         gen_context['kv_lens'] = kv_lens
         gen_context['ropes'] = ropes
         gen_context['past_key_values'] = past_key_values
@@ -76,7 +77,8 @@ class InterleaveInferencer:
                 transforms=self.vae_transform,
                 new_token_ids=self.new_token_ids,
             )
-            past_key_values = self.model.forward_cache_update_vae(self.vae_model, past_key_values, **generation_input)
+            with torch.amp.autocast("cuda", enabled=True, dtype=torch.bfloat16):
+                past_key_values = self.model.forward_cache_update_vae(self.vae_model, past_key_values, **generation_input)

         if vit:
             ## update vit
@@ -87,7 +89,8 @@ class InterleaveInferencer:
                 transforms=self.vit_transform,
                 new_token_ids=self.new_token_ids,
             )
-            past_key_values = self.model.forward_cache_update_vit(past_key_values, **generation_input)
+            with torch.amp.autocast("cuda", enabled=True, dtype=torch.bfloat16):
+                past_key_values = self.model.forward_cache_update_vit(past_key_values, **generation_input)

         gen_context['kv_lens'] = kv_lens
         gen_context['ropes'] = ropes
@@ -143,27 +146,28 @@ class InterleaveInferencer:
             image_sizes=[image_shape],
         )

-        unpacked_latent = self.model.generate_image(
-            past_key_values=past_key_values,
-            cfg_text_past_key_values=cfg_text_past_key_values,
-            cfg_img_past_key_values=cfg_img_past_key_values,
-            num_timesteps=num_timesteps,
-            cfg_text_scale=cfg_text_scale,
-            cfg_img_scale=cfg_img_scale,
-            cfg_interval=cfg_interval,
-            cfg_renorm_min=cfg_renorm_min,
-            cfg_renorm_type=cfg_renorm_type,
-            timestep_shift=timestep_shift,
-            **generation_input,
-            cfg_text_packed_position_ids=generation_input_cfg_text['cfg_packed_position_ids'],
-            cfg_text_packed_query_indexes=generation_input_cfg_text['cfg_packed_query_indexes'],
-            cfg_text_key_values_lens=generation_input_cfg_text['cfg_key_values_lens'],
-            cfg_text_packed_key_value_indexes=generation_input_cfg_text['cfg_packed_key_value_indexes'],
-            cfg_img_packed_position_ids=generation_input_cfg_img['cfg_packed_position_ids'],
-            cfg_img_packed_query_indexes=generation_input_cfg_img['cfg_packed_query_indexes'],
-            cfg_img_key_values_lens=generation_input_cfg_img['cfg_key_values_lens'],
-            cfg_img_packed_key_value_indexes=generation_input_cfg_img['cfg_packed_key_value_indexes'],
-        )
+        with torch.amp.autocast("cuda", enabled=True, dtype=torch.bfloat16):
+            unpacked_latent = self.model.generate_image(
+                past_key_values=past_key_values,
+                cfg_text_past_key_values=cfg_text_past_key_values,
+                cfg_img_past_key_values=cfg_img_past_key_values,
+                num_timesteps=num_timesteps,
+                cfg_text_scale=cfg_text_scale,
+                cfg_img_scale=cfg_img_scale,
+                cfg_interval=cfg_interval,
+                cfg_renorm_min=cfg_renorm_min,
+                cfg_renorm_type=cfg_renorm_type,
+                timestep_shift=timestep_shift,
+                **generation_input,
+                cfg_text_packed_position_ids=generation_input_cfg_text['cfg_packed_position_ids'],
+                cfg_text_packed_query_indexes=generation_input_cfg_text['cfg_packed_query_indexes'],
+                cfg_text_key_values_lens=generation_input_cfg_text['cfg_key_values_lens'],
+                cfg_text_packed_key_value_indexes=generation_input_cfg_text['cfg_packed_key_value_indexes'],
+                cfg_img_packed_position_ids=generation_input_cfg_img['cfg_packed_position_ids'],
+                cfg_img_packed_query_indexes=generation_input_cfg_img['cfg_packed_query_indexes'],
+                cfg_img_key_values_lens=generation_input_cfg_img['cfg_key_values_lens'],
+                cfg_img_packed_key_value_indexes=generation_input_cfg_img['cfg_packed_key_value_indexes'],
+            )

         image = self.decode_image(unpacked_latent[0], image_shape)
         return image
@@ -189,19 +193,20 @@ class InterleaveInferencer:
         kv_lens = gen_context['kv_lens']
         ropes = gen_context['ropes']

-        generation_input = self.model.prepare_start_tokens(kv_lens, ropes, self.new_token_ids)
-        for unpacked_latent in self.model.generate_text(
-            past_key_values=past_key_values,
-            max_length=max_length,
-            do_sample=do_sample,
-            temperature=temperature,
-            end_token_id=self.new_token_ids['eos_token_id'],
-            **generation_input,
-        ):
-            output = self.tokenizer.decode(unpacked_latent)
-            if output != "<|im_end|>":
-                yield output
-
+        with torch.amp.autocast("cuda", enabled=True, dtype=torch.bfloat16):
+            generation_input = self.model.prepare_start_tokens(kv_lens, ropes, self.new_token_ids)
+            for unpacked_latent in self.model.generate_text(
+                past_key_values=past_key_values,
+                max_length=max_length,
+                do_sample=do_sample,
+                temperature=temperature,
+                end_token_id=self.new_token_ids['eos_token_id'],
+                **generation_input,
+            ):
+                output = self.tokenizer.decode(unpacked_latent)
+                if output != "<|im_end|>":
+                    yield output
+
     @torch.no_grad()
     def interleave_inference(
         self,
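
Every hunk in this commit applies the same change: the calls into self.model (forward_cache_update_text / _vae / _vit, generate_image, and generate_text) are wrapped in a torch.amp.autocast context so they execute in bfloat16 on CUDA, presumably to resolve the dtype or memory problem behind the Space's runtime error. Below is a minimal sketch of the pattern; the nn.Linear and the names model and x are illustrative stand-ins, not part of the commit.

import torch
import torch.nn as nn

model = nn.Linear(16, 16).cuda()        # stand-in for self.model
x = torch.randn(1, 16, device="cuda")   # stand-in for a prepared generation input

with torch.no_grad():
    # Inside autocast, eligible ops (matmul, linear, conv) run in bfloat16
    # while the stored weights keep their original dtype, so only the
    # activations produced here come out as bf16.
    with torch.amp.autocast("cuda", enabled=True, dtype=torch.bfloat16):
        y = model(x)

print(y.dtype)  # torch.bfloat16

Unlike float16, bfloat16 keeps float32's exponent range, so this wrapping generally needs no gradient scaler or loss-scale tuning, which is why it is a common one-line fix for half-precision inference failures.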