Spaces:
Sleeping
Sleeping
fix(app): align input tensor dtypes with model dtypes during inference
Browse files
app.py
CHANGED
|
@@ -28,11 +28,21 @@ from gdf.schedulers import CosineSchedule
|
|
| 28 |
from gdf import VPScaler, CosineTNoiseCond, DDPMSampler, P2LossWeight, AdaptiveLossWeight
|
| 29 |
from gdf.targets import EpsilonTarget
|
| 30 |
import PIL
|
|
|
|
| 31 |
|
| 32 |
# Device configuration
|
| 33 |
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
| 34 |
print(device)
|
| 35 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 36 |
# Flag for low VRAM usage
|
| 37 |
# low_vram = False
|
| 38 |
|
|
@@ -53,10 +63,11 @@ def models_to(model, device="cpu", excepts=None):
|
|
| 53 |
continue
|
| 54 |
print(f"Change device of '{attr_name}' to {device}")
|
| 55 |
attr_value.to(device)
|
| 56 |
-
|
| 57 |
torch.cuda.empty_cache()
|
| 58 |
gc.collect()
|
| 59 |
|
|
|
|
| 60 |
# Stage C model configuration
|
| 61 |
config_file = 'third_party/StableCascade/configs/inference/stage_c_3b.yaml'
|
| 62 |
with open(config_file, "r", encoding="utf-8") as file:
|
|
@@ -68,7 +79,7 @@ core = WurstCoreCRBM(config_dict=loaded_config, device=device, training=False)
|
|
| 68 |
config_file_b = 'third_party/StableCascade/configs/inference/stage_b_3b.yaml'
|
| 69 |
with open(config_file_b, "r", encoding="utf-8") as file:
|
| 70 |
config_file_b = yaml.safe_load(file)
|
| 71 |
-
|
| 72 |
core_b = WurstCoreB(config_dict=config_file_b, device=device, training=False)
|
| 73 |
|
| 74 |
# Setup extras and models for Stage C
|
|
@@ -129,20 +140,20 @@ models_rbm = core.Models(
|
|
| 129 |
models_rbm.generator.eval().requires_grad_(False)
|
| 130 |
|
| 131 |
|
| 132 |
-
|
| 133 |
def infer(ref_style_file, style_description, caption, use_low_vram, progress):
|
| 134 |
global models_rbm, models_b, device
|
| 135 |
-
|
| 136 |
models_to(models_rbm, device=device)
|
| 137 |
-
|
| 138 |
try:
|
| 139 |
-
|
| 140 |
caption = f"{caption} in {style_description}"
|
| 141 |
-
height=1024
|
| 142 |
-
width=1024
|
| 143 |
-
batch_size=1
|
| 144 |
-
|
| 145 |
-
stage_c_latent_shape, stage_b_latent_shape = calculate_latent_sizes(
|
|
|
|
|
|
|
| 146 |
|
| 147 |
extras.sampling_configs['cfg'] = 4
|
| 148 |
extras.sampling_configs['shift'] = 2
|
|
@@ -155,26 +166,46 @@ def infer(ref_style_file, style_description, caption, use_low_vram, progress):
|
|
| 155 |
extras_b.sampling_configs['t_start'] = 1.0
|
| 156 |
|
| 157 |
progress(0.1, "Loading style reference image")
|
| 158 |
-
ref_style = resize_image(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 159 |
|
| 160 |
batch = {'captions': [caption] * batch_size}
|
| 161 |
-
batch['style'] =
|
| 162 |
|
| 163 |
progress(0.2, "Processing style reference image")
|
| 164 |
-
x0_style_forward = models_rbm.effnet(
|
|
|
|
|
|
|
| 165 |
|
| 166 |
progress(0.3, "Generating conditions")
|
| 167 |
-
conditions = core.get_conditions(
|
| 168 |
-
|
| 169 |
-
|
| 170 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 171 |
|
| 172 |
if use_low_vram:
|
| 173 |
# The sampling process uses more vram, so we offload everything except two modules to the cpu.
|
| 174 |
models_to(models_rbm, device="cpu", excepts=["generator", "previewer"])
|
| 175 |
|
| 176 |
progress(0.4, "Starting Stage C reverse process")
|
| 177 |
-
# Stage C reverse process.
|
| 178 |
sampling_c = extras.gdf.sample(
|
| 179 |
models_rbm.generator, conditions, stage_c_latent_shape,
|
| 180 |
unconditions, device=device,
|
|
@@ -186,74 +217,73 @@ def infer(ref_style_file, style_description, caption, use_low_vram, progress):
|
|
| 186 |
lam_style=1, lam_txt_alignment=1.0,
|
| 187 |
use_ddim_sampler=True,
|
| 188 |
)
|
| 189 |
-
|
| 190 |
-
|
| 191 |
-
|
| 192 |
-
|
|
|
|
| 193 |
sampled_c = sampled_c
|
| 194 |
|
| 195 |
progress(0.7, "Starting Stage B reverse process")
|
| 196 |
-
|
| 197 |
-
with torch.no_grad(), torch.cuda.amp.autocast(dtype=torch.bfloat16):
|
| 198 |
conditions_b['effnet'] = sampled_c
|
| 199 |
unconditions_b['effnet'] = torch.zeros_like(sampled_c)
|
| 200 |
-
|
| 201 |
sampling_b = extras_b.gdf.sample(
|
| 202 |
models_b.generator, conditions_b, stage_b_latent_shape,
|
| 203 |
unconditions_b, device=device, **extras_b.sampling_configs,
|
| 204 |
)
|
| 205 |
-
for sampled_b, _, _ in progress.tqdm(
|
| 206 |
-
|
| 207 |
-
|
| 208 |
-
|
| 209 |
sampled_b = sampled_b
|
| 210 |
sampled = models_b.stage_a.decode(sampled_b).float()
|
| 211 |
|
| 212 |
torch.cuda.empty_cache()
|
| 213 |
gc.collect()
|
| 214 |
-
|
| 215 |
progress(0.9, "Finalizing the output image")
|
| 216 |
sampled = torch.cat([
|
| 217 |
-
torch.nn.functional.interpolate(ref_style.cpu(), size=(height, width)),
|
| 218 |
-
sampled.cpu(),
|
| 219 |
], dim=0)
|
| 220 |
|
| 221 |
-
|
| 222 |
-
sampled = sampled[1] # This selects the generated image, discarding the reference style image
|
| 223 |
-
|
| 224 |
-
# Ensure the tensor values are in the correct range
|
| 225 |
sampled = torch.clamp(sampled, 0, 1)
|
| 226 |
|
| 227 |
-
# Ensure the tensor is in [C, H, W] format
|
| 228 |
if sampled.dim() == 3 and sampled.shape[0] == 3:
|
| 229 |
-
sampled_image = T.ToPILImage()(sampled)
|
| 230 |
else:
|
| 231 |
raise ValueError(f"Expected tensor of shape [3, H, W] but got {sampled.shape}")
|
| 232 |
|
| 233 |
progress(1.0, "Inference complete")
|
| 234 |
-
return sampled_image
|
| 235 |
|
| 236 |
finally:
|
| 237 |
if use_low_vram:
|
| 238 |
models_to(models_rbm, device=device)
|
| 239 |
-
# Clear CUDA cache
|
| 240 |
torch.cuda.empty_cache()
|
| 241 |
gc.collect()
|
| 242 |
|
|
|
|
| 243 |
def infer_compo(style_description, ref_style_file, caption, ref_sub_file, use_low_vram, progress):
|
| 244 |
global models_rbm, models_b, device
|
| 245 |
sam_model = LangSAM()
|
| 246 |
models_to(models_rbm, device=device)
|
| 247 |
models_to(sam_model, device=device)
|
| 248 |
models_to(sam_model.sam, device=device)
|
|
|
|
| 249 |
try:
|
| 250 |
caption = f"{caption} in {style_description}"
|
| 251 |
sam_prompt = f"{caption}"
|
| 252 |
use_sam_mask = False
|
| 253 |
-
|
| 254 |
batch_size = 1
|
| 255 |
height, width = 1024, 1024
|
| 256 |
-
stage_c_latent_shape, stage_b_latent_shape = calculate_latent_sizes(
|
|
|
|
|
|
|
| 257 |
|
| 258 |
extras.sampling_configs['cfg'] = 4
|
| 259 |
extras.sampling_configs['shift'] = 2
|
|
@@ -265,31 +295,58 @@ def infer_compo(style_description, ref_style_file, caption, ref_sub_file, use_lo
|
|
| 265 |
extras_b.sampling_configs['t_start'] = 1.0
|
| 266 |
|
| 267 |
progress(0.1, "Loading style and subject reference images")
|
| 268 |
-
ref_style = resize_image(
|
| 269 |
-
|
| 270 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 271 |
batch = {'captions': [caption] * batch_size}
|
| 272 |
-
batch['style'] =
|
| 273 |
-
batch['images'] =
|
| 274 |
|
| 275 |
progress(0.2, "Processing reference images")
|
| 276 |
-
x0_forward = models_rbm.effnet(
|
| 277 |
-
|
| 278 |
-
|
| 279 |
-
|
|
|
|
|
|
|
|
|
|
| 280 |
use_sam_mask = False
|
| 281 |
x0_preview = models_rbm.previewer(x0_forward)
|
| 282 |
-
|
| 283 |
-
x0_preview_pil = T.ToPILImage()(x0_preview[0].cpu())
|
| 284 |
sam_mask, boxes, phrases, logits = sam_model.predict(x0_preview_pil, sam_prompt)
|
| 285 |
-
# sam_mask, boxes, phrases, logits = sam_model.predict(transform(x0_preview[0]), sam_prompt)
|
| 286 |
sam_mask = sam_mask.detach().unsqueeze(dim=0).to(device)
|
| 287 |
|
| 288 |
progress(0.3, "Generating conditions")
|
| 289 |
-
conditions = core.get_conditions(
|
| 290 |
-
|
| 291 |
-
|
| 292 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 293 |
|
| 294 |
if use_low_vram:
|
| 295 |
models_to(models_rbm, device="cpu", excepts=["generator", "previewer"])
|
|
@@ -297,15 +354,14 @@ def infer_compo(style_description, ref_style_file, caption, ref_sub_file, use_lo
|
|
| 297 |
models_to(sam_model.sam, device="cpu")
|
| 298 |
|
| 299 |
progress(0.4, "Starting Stage C reverse process")
|
| 300 |
-
# Stage C reverse process.
|
| 301 |
sampling_c = extras.gdf.sample(
|
| 302 |
models_rbm.generator, conditions, stage_c_latent_shape,
|
| 303 |
unconditions, device=device,
|
| 304 |
**extras.sampling_configs,
|
| 305 |
x0_style_forward=x0_style_forward, x0_forward=x0_forward,
|
| 306 |
-
apply_pushforward=False, tau_pushforward=5, tau_pushforward_csd=10,
|
| 307 |
num_iter=3, eta=1e-1, tau=20, eval_sub_csd=True,
|
| 308 |
-
extras=extras, models=models_rbm,
|
| 309 |
use_attn_mask=use_sam_mask,
|
| 310 |
save_attn_mask=False,
|
| 311 |
lam_content=1, lam_style=1,
|
|
@@ -313,63 +369,58 @@ def infer_compo(style_description, ref_style_file, caption, ref_sub_file, use_lo
|
|
| 313 |
sam_prompt=sam_prompt
|
| 314 |
)
|
| 315 |
|
| 316 |
-
for sampled_c, _, _ in progress.tqdm(
|
| 317 |
-
|
| 318 |
-
|
| 319 |
-
|
| 320 |
sampled_c = sampled_c
|
| 321 |
|
| 322 |
progress(0.7, "Starting Stage B reverse process")
|
| 323 |
-
|
| 324 |
-
with torch.no_grad(), torch.cuda.amp.autocast(dtype=torch.bfloat16):
|
| 325 |
conditions_b['effnet'] = sampled_c
|
| 326 |
unconditions_b['effnet'] = torch.zeros_like(sampled_c)
|
| 327 |
-
|
| 328 |
sampling_b = extras_b.gdf.sample(
|
| 329 |
models_b.generator, conditions_b, stage_b_latent_shape,
|
| 330 |
unconditions_b, device=device, **extras_b.sampling_configs,
|
| 331 |
)
|
| 332 |
-
for sampled_b, _, _ in progress.tqdm(
|
| 333 |
-
|
| 334 |
-
|
| 335 |
-
|
| 336 |
sampled_b = sampled_b
|
| 337 |
sampled = models_b.stage_a.decode(sampled_b).float()
|
| 338 |
|
| 339 |
torch.cuda.empty_cache()
|
| 340 |
gc.collect()
|
| 341 |
-
|
| 342 |
progress(0.9, "Finalizing the output image")
|
| 343 |
sampled = torch.cat([
|
| 344 |
-
torch.nn.functional.interpolate(ref_images.cpu(), size=(height, width)),
|
| 345 |
-
torch.nn.functional.interpolate(ref_style.cpu(), size=(height, width)),
|
| 346 |
-
sampled.cpu(),
|
| 347 |
], dim=0)
|
| 348 |
|
| 349 |
-
|
| 350 |
-
sampled = sampled[2] # This selects the generated image, discarding the reference images
|
| 351 |
-
|
| 352 |
-
# Ensure the tensor values are in the correct range
|
| 353 |
sampled = torch.clamp(sampled, 0, 1)
|
| 354 |
|
| 355 |
-
# Ensure the tensor is in [C, H, W] format
|
| 356 |
if sampled.dim() == 3 and sampled.shape[0] == 3:
|
| 357 |
-
sampled_image = T.ToPILImage()(sampled)
|
| 358 |
else:
|
| 359 |
raise ValueError(f"Expected tensor of shape [3, H, W] but got {sampled.shape}")
|
| 360 |
|
| 361 |
progress(1.0, "Inference complete")
|
| 362 |
-
return sampled_image
|
| 363 |
|
| 364 |
finally:
|
| 365 |
if use_low_vram:
|
| 366 |
models_to(models_rbm, device=device, excepts=["generator", "previewer"])
|
| 367 |
models_to(sam_model, device=device)
|
| 368 |
models_to(sam_model.sam, device=device)
|
| 369 |
-
# Clear CUDA cache
|
| 370 |
torch.cuda.empty_cache()
|
| 371 |
gc.collect()
|
| 372 |
|
|
|
|
| 373 |
def run(style_reference_image, style_description, subject_prompt, subject_reference, use_subject_ref, use_low_vram):
|
| 374 |
result = None
|
| 375 |
progress = gr.Progress(track_tqdm=True)
|
|
@@ -379,13 +430,13 @@ def run(style_reference_image, style_description, subject_prompt, subject_refere
|
|
| 379 |
result = infer(style_reference_image, style_description, subject_prompt, use_low_vram, progress)
|
| 380 |
return result
|
| 381 |
|
|
|
|
| 382 |
def show_hide_subject_image_component(use_subject_ref):
|
| 383 |
if use_subject_ref is True:
|
| 384 |
return gr.update(open=True)
|
| 385 |
else:
|
| 386 |
return gr.update(open=False)
|
| 387 |
|
| 388 |
-
import gradio as gr
|
| 389 |
|
| 390 |
with gr.Blocks(analytics_enabled=False) as demo:
|
| 391 |
with gr.Column():
|
|
@@ -404,29 +455,28 @@ with gr.Blocks(analytics_enabled=False) as demo:
|
|
| 404 |
with gr.Row():
|
| 405 |
with gr.Column():
|
| 406 |
style_reference_image = gr.Image(
|
| 407 |
-
label
|
| 408 |
-
type
|
| 409 |
)
|
| 410 |
style_description = gr.Textbox(
|
| 411 |
-
label
|
| 412 |
)
|
| 413 |
subject_prompt = gr.Textbox(
|
| 414 |
-
label
|
| 415 |
)
|
| 416 |
with gr.Row():
|
| 417 |
use_subject_ref = gr.Checkbox(label="Use Subject Image as Reference", value=False)
|
| 418 |
use_low_vram = gr.Checkbox(label="Use Low-VRAM", value=False)
|
| 419 |
-
|
| 420 |
with gr.Accordion("Advanced Settings", open=False) as sub_img_panel:
|
| 421 |
subject_reference = gr.Image(label="Subject Reference", type="filepath")
|
| 422 |
-
|
| 423 |
submit_btn = gr.Button("Submit")
|
| 424 |
|
| 425 |
-
|
| 426 |
with gr.Column():
|
| 427 |
output_image = gr.Image(label="Output Image")
|
| 428 |
gr.Examples(
|
| 429 |
-
examples
|
| 430 |
["./data/cyberpunk.png", "cyberpunk art style", "a car", None, False, False],
|
| 431 |
["./data/mosaic.png", "mosaic art style", "a lighthouse", None, False, False],
|
| 432 |
["./data/glowing.png", "glowing style", "a dwarf", None, False, False],
|
|
@@ -436,21 +486,20 @@ with gr.Blocks(analytics_enabled=False) as demo:
|
|
| 436 |
inputs=[style_reference_image, style_description, subject_prompt, subject_reference, use_subject_ref, use_low_vram],
|
| 437 |
outputs=[output_image],
|
| 438 |
cache_examples=False
|
| 439 |
-
|
| 440 |
)
|
| 441 |
|
| 442 |
use_subject_ref.input(
|
| 443 |
-
fn
|
| 444 |
-
inputs
|
| 445 |
-
outputs
|
| 446 |
-
queue
|
| 447 |
api_visibility="private"
|
| 448 |
)
|
| 449 |
-
|
| 450 |
submit_btn.click(
|
| 451 |
-
fn
|
| 452 |
-
inputs
|
| 453 |
-
outputs
|
| 454 |
api_visibility="private"
|
| 455 |
)
|
| 456 |
|
|
|
|
| 28 |
from gdf import VPScaler, CosineTNoiseCond, DDPMSampler, P2LossWeight, AdaptiveLossWeight
|
| 29 |
from gdf.targets import EpsilonTarget
|
| 30 |
import PIL
|
| 31 |
+
import gradio as gr
|
| 32 |
|
| 33 |
# Device configuration
|
| 34 |
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
| 35 |
print(device)
|
| 36 |
|
| 37 |
+
|
| 38 |
+
def module_dtype(module):
|
| 39 |
+
return next(module.parameters()).dtype
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
def to_module_device_dtype(tensor, module, device=device):
|
| 43 |
+
return tensor.to(device=device, dtype=module_dtype(module))
|
| 44 |
+
|
| 45 |
+
|
| 46 |
# Flag for low VRAM usage
|
| 47 |
# low_vram = False
|
| 48 |
|
|
|
|
| 63 |
continue
|
| 64 |
print(f"Change device of '{attr_name}' to {device}")
|
| 65 |
attr_value.to(device)
|
| 66 |
+
|
| 67 |
torch.cuda.empty_cache()
|
| 68 |
gc.collect()
|
| 69 |
|
| 70 |
+
|
| 71 |
# Stage C model configuration
|
| 72 |
config_file = 'third_party/StableCascade/configs/inference/stage_c_3b.yaml'
|
| 73 |
with open(config_file, "r", encoding="utf-8") as file:
|
|
|
|
| 79 |
config_file_b = 'third_party/StableCascade/configs/inference/stage_b_3b.yaml'
|
| 80 |
with open(config_file_b, "r", encoding="utf-8") as file:
|
| 81 |
config_file_b = yaml.safe_load(file)
|
| 82 |
+
|
| 83 |
core_b = WurstCoreB(config_dict=config_file_b, device=device, training=False)
|
| 84 |
|
| 85 |
# Setup extras and models for Stage C
|
|
|
|
| 140 |
models_rbm.generator.eval().requires_grad_(False)
|
| 141 |
|
| 142 |
|
|
|
|
| 143 |
def infer(ref_style_file, style_description, caption, use_low_vram, progress):
|
| 144 |
global models_rbm, models_b, device
|
| 145 |
+
|
| 146 |
models_to(models_rbm, device=device)
|
| 147 |
+
|
| 148 |
try:
|
|
|
|
| 149 |
caption = f"{caption} in {style_description}"
|
| 150 |
+
height = 1024
|
| 151 |
+
width = 1024
|
| 152 |
+
batch_size = 1
|
| 153 |
+
|
| 154 |
+
stage_c_latent_shape, stage_b_latent_shape = calculate_latent_sizes(
|
| 155 |
+
height, width, batch_size=batch_size
|
| 156 |
+
)
|
| 157 |
|
| 158 |
extras.sampling_configs['cfg'] = 4
|
| 159 |
extras.sampling_configs['shift'] = 2
|
|
|
|
| 166 |
extras_b.sampling_configs['t_start'] = 1.0
|
| 167 |
|
| 168 |
progress(0.1, "Loading style reference image")
|
| 169 |
+
ref_style = resize_image(
|
| 170 |
+
PIL.Image.open(ref_style_file).convert("RGB")
|
| 171 |
+
).unsqueeze(0).expand(batch_size, -1, -1, -1)
|
| 172 |
+
|
| 173 |
+
ref_style_for_clip = to_module_device_dtype(ref_style, models_rbm.image_model)
|
| 174 |
+
ref_style_for_effnet = to_module_device_dtype(ref_style, models_rbm.effnet)
|
| 175 |
|
| 176 |
batch = {'captions': [caption] * batch_size}
|
| 177 |
+
batch['style'] = ref_style_for_clip
|
| 178 |
|
| 179 |
progress(0.2, "Processing style reference image")
|
| 180 |
+
x0_style_forward = models_rbm.effnet(
|
| 181 |
+
extras.effnet_preprocess(ref_style_for_effnet)
|
| 182 |
+
)
|
| 183 |
|
| 184 |
progress(0.3, "Generating conditions")
|
| 185 |
+
conditions = core.get_conditions(
|
| 186 |
+
batch, models_rbm, extras,
|
| 187 |
+
is_eval=True, is_unconditional=False,
|
| 188 |
+
eval_image_embeds=True, eval_style=True, eval_csd=False
|
| 189 |
+
)
|
| 190 |
+
unconditions = core.get_conditions(
|
| 191 |
+
batch, models_rbm, extras,
|
| 192 |
+
is_eval=True, is_unconditional=True,
|
| 193 |
+
eval_image_embeds=False
|
| 194 |
+
)
|
| 195 |
+
conditions_b = core_b.get_conditions(
|
| 196 |
+
batch, models_b, extras_b,
|
| 197 |
+
is_eval=True, is_unconditional=False
|
| 198 |
+
)
|
| 199 |
+
unconditions_b = core_b.get_conditions(
|
| 200 |
+
batch, models_b, extras_b,
|
| 201 |
+
is_eval=True, is_unconditional=True
|
| 202 |
+
)
|
| 203 |
|
| 204 |
if use_low_vram:
|
| 205 |
# The sampling process uses more vram, so we offload everything except two modules to the cpu.
|
| 206 |
models_to(models_rbm, device="cpu", excepts=["generator", "previewer"])
|
| 207 |
|
| 208 |
progress(0.4, "Starting Stage C reverse process")
|
|
|
|
| 209 |
sampling_c = extras.gdf.sample(
|
| 210 |
models_rbm.generator, conditions, stage_c_latent_shape,
|
| 211 |
unconditions, device=device,
|
|
|
|
| 217 |
lam_style=1, lam_txt_alignment=1.0,
|
| 218 |
use_ddim_sampler=True,
|
| 219 |
)
|
| 220 |
+
|
| 221 |
+
for (sampled_c, _, _) in progress.tqdm(
|
| 222 |
+
tqdm(sampling_c, total=extras.sampling_configs['timesteps']),
|
| 223 |
+
desc="Stage C reverse process"
|
| 224 |
+
):
|
| 225 |
sampled_c = sampled_c
|
| 226 |
|
| 227 |
progress(0.7, "Starting Stage B reverse process")
|
| 228 |
+
with torch.no_grad(), torch.cuda.amp.autocast(dtype=torch.bfloat16):
|
|
|
|
| 229 |
conditions_b['effnet'] = sampled_c
|
| 230 |
unconditions_b['effnet'] = torch.zeros_like(sampled_c)
|
| 231 |
+
|
| 232 |
sampling_b = extras_b.gdf.sample(
|
| 233 |
models_b.generator, conditions_b, stage_b_latent_shape,
|
| 234 |
unconditions_b, device=device, **extras_b.sampling_configs,
|
| 235 |
)
|
| 236 |
+
for sampled_b, _, _ in progress.tqdm(
|
| 237 |
+
tqdm(sampling_b, total=extras_b.sampling_configs['timesteps']),
|
| 238 |
+
desc="Stage B reverse process"
|
| 239 |
+
):
|
| 240 |
sampled_b = sampled_b
|
| 241 |
sampled = models_b.stage_a.decode(sampled_b).float()
|
| 242 |
|
| 243 |
torch.cuda.empty_cache()
|
| 244 |
gc.collect()
|
| 245 |
+
|
| 246 |
progress(0.9, "Finalizing the output image")
|
| 247 |
sampled = torch.cat([
|
| 248 |
+
torch.nn.functional.interpolate(ref_style.float().cpu(), size=(height, width)),
|
| 249 |
+
sampled.float().cpu(),
|
| 250 |
], dim=0)
|
| 251 |
|
| 252 |
+
sampled = sampled[1]
|
|
|
|
|
|
|
|
|
|
| 253 |
sampled = torch.clamp(sampled, 0, 1)
|
| 254 |
|
|
|
|
| 255 |
if sampled.dim() == 3 and sampled.shape[0] == 3:
|
| 256 |
+
sampled_image = T.ToPILImage()(sampled)
|
| 257 |
else:
|
| 258 |
raise ValueError(f"Expected tensor of shape [3, H, W] but got {sampled.shape}")
|
| 259 |
|
| 260 |
progress(1.0, "Inference complete")
|
| 261 |
+
return sampled_image
|
| 262 |
|
| 263 |
finally:
|
| 264 |
if use_low_vram:
|
| 265 |
models_to(models_rbm, device=device)
|
|
|
|
| 266 |
torch.cuda.empty_cache()
|
| 267 |
gc.collect()
|
| 268 |
|
| 269 |
+
|
| 270 |
def infer_compo(style_description, ref_style_file, caption, ref_sub_file, use_low_vram, progress):
|
| 271 |
global models_rbm, models_b, device
|
| 272 |
sam_model = LangSAM()
|
| 273 |
models_to(models_rbm, device=device)
|
| 274 |
models_to(sam_model, device=device)
|
| 275 |
models_to(sam_model.sam, device=device)
|
| 276 |
+
|
| 277 |
try:
|
| 278 |
caption = f"{caption} in {style_description}"
|
| 279 |
sam_prompt = f"{caption}"
|
| 280 |
use_sam_mask = False
|
| 281 |
+
|
| 282 |
batch_size = 1
|
| 283 |
height, width = 1024, 1024
|
| 284 |
+
stage_c_latent_shape, stage_b_latent_shape = calculate_latent_sizes(
|
| 285 |
+
height, width, batch_size=batch_size
|
| 286 |
+
)
|
| 287 |
|
| 288 |
extras.sampling_configs['cfg'] = 4
|
| 289 |
extras.sampling_configs['shift'] = 2
|
|
|
|
| 295 |
extras_b.sampling_configs['t_start'] = 1.0
|
| 296 |
|
| 297 |
progress(0.1, "Loading style and subject reference images")
|
| 298 |
+
ref_style = resize_image(
|
| 299 |
+
PIL.Image.open(ref_style_file).convert("RGB")
|
| 300 |
+
).unsqueeze(0).expand(batch_size, -1, -1, -1)
|
| 301 |
+
|
| 302 |
+
ref_images = resize_image(
|
| 303 |
+
PIL.Image.open(ref_sub_file).convert("RGB")
|
| 304 |
+
).unsqueeze(0).expand(batch_size, -1, -1, -1)
|
| 305 |
+
|
| 306 |
+
ref_style_for_clip = to_module_device_dtype(ref_style, models_rbm.image_model)
|
| 307 |
+
ref_images_for_clip = to_module_device_dtype(ref_images, models_rbm.image_model)
|
| 308 |
+
|
| 309 |
+
ref_style_for_effnet = to_module_device_dtype(ref_style, models_rbm.effnet)
|
| 310 |
+
ref_images_for_effnet = to_module_device_dtype(ref_images, models_rbm.effnet)
|
| 311 |
+
|
| 312 |
batch = {'captions': [caption] * batch_size}
|
| 313 |
+
batch['style'] = ref_style_for_clip
|
| 314 |
+
batch['images'] = ref_images_for_clip
|
| 315 |
|
| 316 |
progress(0.2, "Processing reference images")
|
| 317 |
+
x0_forward = models_rbm.effnet(
|
| 318 |
+
extras.effnet_preprocess(ref_images_for_effnet)
|
| 319 |
+
)
|
| 320 |
+
x0_style_forward = models_rbm.effnet(
|
| 321 |
+
extras.effnet_preprocess(ref_style_for_effnet)
|
| 322 |
+
)
|
| 323 |
+
|
| 324 |
use_sam_mask = False
|
| 325 |
x0_preview = models_rbm.previewer(x0_forward)
|
| 326 |
+
|
| 327 |
+
x0_preview_pil = T.ToPILImage()(x0_preview[0].float().cpu())
|
| 328 |
sam_mask, boxes, phrases, logits = sam_model.predict(x0_preview_pil, sam_prompt)
|
|
|
|
| 329 |
sam_mask = sam_mask.detach().unsqueeze(dim=0).to(device)
|
| 330 |
|
| 331 |
progress(0.3, "Generating conditions")
|
| 332 |
+
conditions = core.get_conditions(
|
| 333 |
+
batch, models_rbm, extras,
|
| 334 |
+
is_eval=True, is_unconditional=False,
|
| 335 |
+
eval_image_embeds=True, eval_subject_style=True, eval_csd=False
|
| 336 |
+
)
|
| 337 |
+
unconditions = core.get_conditions(
|
| 338 |
+
batch, models_rbm, extras,
|
| 339 |
+
is_eval=True, is_unconditional=True,
|
| 340 |
+
eval_image_embeds=False, eval_subject_style=True
|
| 341 |
+
)
|
| 342 |
+
conditions_b = core_b.get_conditions(
|
| 343 |
+
batch, models_b, extras_b,
|
| 344 |
+
is_eval=True, is_unconditional=False
|
| 345 |
+
)
|
| 346 |
+
unconditions_b = core_b.get_conditions(
|
| 347 |
+
batch, models_b, extras_b,
|
| 348 |
+
is_eval=True, is_unconditional=True
|
| 349 |
+
)
|
| 350 |
|
| 351 |
if use_low_vram:
|
| 352 |
models_to(models_rbm, device="cpu", excepts=["generator", "previewer"])
|
|
|
|
| 354 |
models_to(sam_model.sam, device="cpu")
|
| 355 |
|
| 356 |
progress(0.4, "Starting Stage C reverse process")
|
|
|
|
| 357 |
sampling_c = extras.gdf.sample(
|
| 358 |
models_rbm.generator, conditions, stage_c_latent_shape,
|
| 359 |
unconditions, device=device,
|
| 360 |
**extras.sampling_configs,
|
| 361 |
x0_style_forward=x0_style_forward, x0_forward=x0_forward,
|
| 362 |
+
apply_pushforward=False, tau_pushforward=5, tau_pushforward_csd=10,
|
| 363 |
num_iter=3, eta=1e-1, tau=20, eval_sub_csd=True,
|
| 364 |
+
extras=extras, models=models_rbm,
|
| 365 |
use_attn_mask=use_sam_mask,
|
| 366 |
save_attn_mask=False,
|
| 367 |
lam_content=1, lam_style=1,
|
|
|
|
| 369 |
sam_prompt=sam_prompt
|
| 370 |
)
|
| 371 |
|
| 372 |
+
for sampled_c, _, _ in progress.tqdm(
|
| 373 |
+
tqdm(sampling_c, total=extras.sampling_configs['timesteps']),
|
| 374 |
+
desc="Stage C reverse process"
|
| 375 |
+
):
|
| 376 |
sampled_c = sampled_c
|
| 377 |
|
| 378 |
progress(0.7, "Starting Stage B reverse process")
|
| 379 |
+
with torch.no_grad(), torch.cuda.amp.autocast(dtype=torch.bfloat16):
|
|
|
|
| 380 |
conditions_b['effnet'] = sampled_c
|
| 381 |
unconditions_b['effnet'] = torch.zeros_like(sampled_c)
|
| 382 |
+
|
| 383 |
sampling_b = extras_b.gdf.sample(
|
| 384 |
models_b.generator, conditions_b, stage_b_latent_shape,
|
| 385 |
unconditions_b, device=device, **extras_b.sampling_configs,
|
| 386 |
)
|
| 387 |
+
for sampled_b, _, _ in progress.tqdm(
|
| 388 |
+
tqdm(sampling_b, total=extras_b.sampling_configs['timesteps']),
|
| 389 |
+
desc="Stage B reverse process"
|
| 390 |
+
):
|
| 391 |
sampled_b = sampled_b
|
| 392 |
sampled = models_b.stage_a.decode(sampled_b).float()
|
| 393 |
|
| 394 |
torch.cuda.empty_cache()
|
| 395 |
gc.collect()
|
| 396 |
+
|
| 397 |
progress(0.9, "Finalizing the output image")
|
| 398 |
sampled = torch.cat([
|
| 399 |
+
torch.nn.functional.interpolate(ref_images.float().cpu(), size=(height, width)),
|
| 400 |
+
torch.nn.functional.interpolate(ref_style.float().cpu(), size=(height, width)),
|
| 401 |
+
sampled.float().cpu(),
|
| 402 |
], dim=0)
|
| 403 |
|
| 404 |
+
sampled = sampled[2]
|
|
|
|
|
|
|
|
|
|
| 405 |
sampled = torch.clamp(sampled, 0, 1)
|
| 406 |
|
|
|
|
| 407 |
if sampled.dim() == 3 and sampled.shape[0] == 3:
|
| 408 |
+
sampled_image = T.ToPILImage()(sampled)
|
| 409 |
else:
|
| 410 |
raise ValueError(f"Expected tensor of shape [3, H, W] but got {sampled.shape}")
|
| 411 |
|
| 412 |
progress(1.0, "Inference complete")
|
| 413 |
+
return sampled_image
|
| 414 |
|
| 415 |
finally:
|
| 416 |
if use_low_vram:
|
| 417 |
models_to(models_rbm, device=device, excepts=["generator", "previewer"])
|
| 418 |
models_to(sam_model, device=device)
|
| 419 |
models_to(sam_model.sam, device=device)
|
|
|
|
| 420 |
torch.cuda.empty_cache()
|
| 421 |
gc.collect()
|
| 422 |
|
| 423 |
+
|
| 424 |
def run(style_reference_image, style_description, subject_prompt, subject_reference, use_subject_ref, use_low_vram):
|
| 425 |
result = None
|
| 426 |
progress = gr.Progress(track_tqdm=True)
|
|
|
|
| 430 |
result = infer(style_reference_image, style_description, subject_prompt, use_low_vram, progress)
|
| 431 |
return result
|
| 432 |
|
| 433 |
+
|
| 434 |
def show_hide_subject_image_component(use_subject_ref):
|
| 435 |
if use_subject_ref is True:
|
| 436 |
return gr.update(open=True)
|
| 437 |
else:
|
| 438 |
return gr.update(open=False)
|
| 439 |
|
|
|
|
| 440 |
|
| 441 |
with gr.Blocks(analytics_enabled=False) as demo:
|
| 442 |
with gr.Column():
|
|
|
|
| 455 |
with gr.Row():
|
| 456 |
with gr.Column():
|
| 457 |
style_reference_image = gr.Image(
|
| 458 |
+
label="Style Reference Image",
|
| 459 |
+
type="filepath"
|
| 460 |
)
|
| 461 |
style_description = gr.Textbox(
|
| 462 |
+
label="Style Description"
|
| 463 |
)
|
| 464 |
subject_prompt = gr.Textbox(
|
| 465 |
+
label="Subject Prompt"
|
| 466 |
)
|
| 467 |
with gr.Row():
|
| 468 |
use_subject_ref = gr.Checkbox(label="Use Subject Image as Reference", value=False)
|
| 469 |
use_low_vram = gr.Checkbox(label="Use Low-VRAM", value=False)
|
| 470 |
+
|
| 471 |
with gr.Accordion("Advanced Settings", open=False) as sub_img_panel:
|
| 472 |
subject_reference = gr.Image(label="Subject Reference", type="filepath")
|
| 473 |
+
|
| 474 |
submit_btn = gr.Button("Submit")
|
| 475 |
|
|
|
|
| 476 |
with gr.Column():
|
| 477 |
output_image = gr.Image(label="Output Image")
|
| 478 |
gr.Examples(
|
| 479 |
+
examples=[
|
| 480 |
["./data/cyberpunk.png", "cyberpunk art style", "a car", None, False, False],
|
| 481 |
["./data/mosaic.png", "mosaic art style", "a lighthouse", None, False, False],
|
| 482 |
["./data/glowing.png", "glowing style", "a dwarf", None, False, False],
|
|
|
|
| 486 |
inputs=[style_reference_image, style_description, subject_prompt, subject_reference, use_subject_ref, use_low_vram],
|
| 487 |
outputs=[output_image],
|
| 488 |
cache_examples=False
|
|
|
|
| 489 |
)
|
| 490 |
|
| 491 |
use_subject_ref.input(
|
| 492 |
+
fn=show_hide_subject_image_component,
|
| 493 |
+
inputs=[use_subject_ref],
|
| 494 |
+
outputs=[sub_img_panel],
|
| 495 |
+
queue=False,
|
| 496 |
api_visibility="private"
|
| 497 |
)
|
| 498 |
+
|
| 499 |
submit_btn.click(
|
| 500 |
+
fn=run,
|
| 501 |
+
inputs=[style_reference_image, style_description, subject_prompt, subject_reference, use_subject_ref, use_low_vram],
|
| 502 |
+
outputs=[output_image],
|
| 503 |
api_visibility="private"
|
| 504 |
)
|
| 505 |
|