Spaces: Build error
Update app.py
app.py CHANGED
@@ -150,15 +150,14 @@ models_rbm.generator.eval().requires_grad_(False)
 
 def infer(style_description, ref_style_file, caption):
     try:
-        #
-        models_rbm.effnet.to(device)
-        models_rbm.previewer.to(device)
-        models_rbm.generator.to(device)
-        models_rbm.text_model.to(device)
+        # Move all model components to the same device and set to the same precision
+        models_rbm.effnet.to(device).bfloat16()
+        models_rbm.previewer.to(device).bfloat16()
+        models_rbm.generator.to(device).bfloat16()
+        models_rbm.text_model.to(device).bfloat16()
 
-
-        models_b.generator.to(device)
-        models_b.stage_a.to(device)
+        models_b.generator.to(device).bfloat16()
+        models_b.stage_a.to(device).bfloat16()
 
         clear_gpu_cache()  # Clear cache before inference
 
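The chained calls above work because Module.to() and Module.bfloat16() both return the module itself, so every weight ends up on device in torch.bfloat16. The clear_gpu_cache() helper is called but never defined in this diff; a minimal sketch of what such a helper typically looks like, assuming it only needs to release cached CUDA memory:

import gc

import torch

def clear_gpu_cache():
    # Drop unreachable Python objects first so their tensors become reclaimable.
    gc.collect()
    if torch.cuda.is_available():
        # Return cached allocator blocks to the driver; live tensors are untouched.
        torch.cuda.empty_cache()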
@@ -179,13 +178,11 @@ def infer(style_description, ref_style_file, caption):
         extras_b.sampling_configs['timesteps'] = 10
         extras_b.sampling_configs['t_start'] = 1.0
 
-        ref_style = resize_image(PIL.Image.open(ref_style_file).convert("RGB")).unsqueeze(0).expand(batch_size, -1, -1, -1).to(device)
+        ref_style = resize_image(PIL.Image.open(ref_style_file).convert("RGB")).unsqueeze(0).expand(batch_size, -1, -1, -1).to(device).bfloat16()
 
         batch = {'captions': [caption] * batch_size}
         batch['style'] = ref_style
 
-        # Ensure effnet is on the same device as the input
-        models_rbm.effnet.to(device)
         x0_style_forward = models_rbm.effnet(extras.effnet_preprocess(ref_style))
 
         conditions = core.get_conditions(batch, models_rbm, extras, is_eval=True, is_unconditional=False, eval_image_embeds=True, eval_style=True, eval_csd=False)
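Casting ref_style to bfloat16 matters because the forward pass through models_rbm.effnet runs outside any autocast region here: PyTorch raises a dtype-mismatch error when a float32 tensor meets bfloat16 weights. A standalone illustration with a toy layer (not the Space's effnet):

import torch

net = torch.nn.Linear(4, 4).bfloat16()  # weights are now torch.bfloat16
x = torch.randn(2, 4)                   # float32 input

# net(x) would raise a RuntimeError (input/weight dtype mismatch)
y = net(x.bfloat16())                   # cast the input first, as the diff does
print(y.dtype)                          # torch.bfloat16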
@@ -198,7 +195,7 @@ def infer(style_description, ref_style_file, caption):
         models_to(models_rbm, device="cpu", excepts=["generator", "previewer"])
 
         # Stage C reverse process
-        with torch.cuda.amp.autocast():  # Use mixed precision
+        with torch.cuda.amp.autocast(dtype=torch.bfloat16):  # Use mixed precision with bfloat16
             sampling_c = extras.gdf.sample(
                 models_rbm.generator, conditions, stage_c_latent_shape,
                 unconditions, device=device,
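Pinning autocast to torch.bfloat16, rather than its float16 default on CUDA, keeps the sampling loop numerically closer to float32: bfloat16 has the same exponent range as float32, so it avoids the overflow problems that often make float16 need gradient scaling. A minimal sketch of the same context manager around a toy module (torch.autocast("cuda", dtype=torch.bfloat16) is the equivalent newer spelling):

import torch

if torch.cuda.is_available():
    model = torch.nn.Linear(16, 16).cuda()   # weights stay float32
    x = torch.randn(8, 16, device="cuda")
    with torch.no_grad(), torch.cuda.amp.autocast(dtype=torch.bfloat16):
        y = model(x)  # matmuls run in bfloat16 inside the autocast region
    print(y.dtype)    # torch.bfloat16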
@@ -216,7 +213,7 @@ def infer(style_description, ref_style_file, caption):
         clear_gpu_cache()  # Clear cache between stages
 
         # Ensure all models are on the right device again
-        models_b.generator.to(device)
+        models_b.generator.to(device).bfloat16()
 
         # Stage B reverse process
         with torch.no_grad(), torch.cuda.amp.autocast(dtype=torch.bfloat16):
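The models_to(models_rbm, device="cpu", excepts=["generator", "previewer"]) call in the earlier hunk offloads Stage C modules to the CPU so Stage B fits in VRAM, which is why models_b.generator must be moved (and defensively re-cast) here. The helper's definition is not part of this diff; a hypothetical sketch, assuming the bundle exposes its sub-models as plain attributes:

import torch

def models_to(models, device="cpu", excepts=None):
    # Hypothetical reconstruction: move every nn.Module attribute of the
    # bundle to `device`, skipping names listed in `excepts`.
    excepts = excepts or []
    for name, value in vars(models).items():
        if name in excepts:
            continue
        if isinstance(value, torch.nn.Module):
            value.to(device)

Shuttling models between host and GPU like this trades transfer time for lower peak VRAM, the usual reason a multi-stage Space pipeline ping-pongs its stages.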