Spaces:
Sleeping
Sleeping
Upload app.py
Browse files
app.py
CHANGED
|
@@ -388,7 +388,7 @@ def generate_with_prompt_style_guidance(prompt, style, seed,num_inference_steps,
|
|
| 388 |
guidance_scale = guidance_scale # # Scale for classifier-free guidance
|
| 389 |
generator = torch.manual_seed(seed) # Seed generator to create the inital latent noise
|
| 390 |
batch_size = 1
|
| 391 |
-
|
| 392 |
|
| 393 |
# Prep text
|
| 394 |
text_input = tokenizer([prompt], padding="max_length", max_length=tokenizer.model_max_length, truncation=True, return_tensors="pt")
|
|
@@ -463,17 +463,22 @@ def generate_with_prompt_style_guidance(prompt, style, seed,num_inference_steps,
|
|
| 463 |
# Calculate loss
|
| 464 |
# "contrast", "blue_original", "blue_modified","ymca_loss","cymk_loss"
|
| 465 |
if loss_function == "contrast":
|
| 466 |
-
|
|
|
|
| 467 |
elif loss_function == "blue_original":
|
| 468 |
-
|
|
|
|
| 469 |
elif loss_function == "blue_modified":
|
| 470 |
-
|
| 471 |
-
|
| 472 |
-
|
| 473 |
-
|
| 474 |
-
loss =
|
|
|
|
|
|
|
|
|
|
| 475 |
else :
|
| 476 |
-
loss = ymca_loss(denoised_images) *
|
| 477 |
|
| 478 |
# # Occasionally print it out
|
| 479 |
# if i%10==0:
|
|
@@ -537,7 +542,7 @@ demo = gr.Interface(inference,
|
|
| 537 |
step=8,
|
| 538 |
label="Select Guidance Scale",
|
| 539 |
interactive=True,
|
| 540 |
-
),gr.Radio(["contrast", "blue_original", "blue_modified","
|
| 541 |
],
|
| 542 |
outputs = [
|
| 543 |
gr.Image(label="Stable Diffusion Output"),
|
|
|
|
| 388 |
guidance_scale = guidance_scale # # Scale for classifier-free guidance
|
| 389 |
generator = torch.manual_seed(seed) # Seed generator to create the inital latent noise
|
| 390 |
batch_size = 1
|
| 391 |
+
|
| 392 |
|
| 393 |
# Prep text
|
| 394 |
text_input = tokenizer([prompt], padding="max_length", max_length=tokenizer.model_max_length, truncation=True, return_tensors="pt")
|
|
|
|
| 463 |
# Calculate loss
|
| 464 |
# "contrast", "blue_original", "blue_modified","ymca_loss","cymk_loss"
|
| 465 |
if loss_function == "contrast":
|
| 466 |
+
loss_scale = 200 #
|
| 467 |
+
loss = contrast_loss(denoised_images) * loss_scale
|
| 468 |
elif loss_function == "blue_original":
|
| 469 |
+
loss_scale = 200 #
|
| 470 |
+
loss = blue_loss(denoised_images) * loss_scale
|
| 471 |
elif loss_function == "blue_modified":
|
| 472 |
+
loss_scale = 200 #
|
| 473 |
+
loss = blue_loss_variant(denoised_images) * loss_scale
|
| 474 |
+
elif loss_function == "ymca":
|
| 475 |
+
loss_scale = 200 #
|
| 476 |
+
loss = ymca_loss(denoised_images) * loss_scale
|
| 477 |
+
elif loss_function == "cmyk":
|
| 478 |
+
loss_scale = 10 #
|
| 479 |
+
loss = cymk_loss(denoised_images) * loss_scale
|
| 480 |
else :
|
| 481 |
+
loss = ymca_loss(denoised_images) * loss_scale
|
| 482 |
|
| 483 |
# # Occasionally print it out
|
| 484 |
# if i%10==0:
|
|
|
|
| 542 |
step=8,
|
| 543 |
label="Select Guidance Scale",
|
| 544 |
interactive=True,
|
| 545 |
+
),gr.Radio(["contrast", "blue_original", "blue_modified","ymca","cmyk"], label="loss-function", info="loss-function" , value="ymca_loss"),
|
| 546 |
],
|
| 547 |
outputs = [
|
| 548 |
gr.Image(label="Stable Diffusion Output"),
|