Spaces:
Sleeping
Sleeping
Upload app.py
Browse files
app.py
CHANGED
|
@@ -202,7 +202,6 @@ def blue_loss(images):
|
|
| 202 |
|
| 203 |
return -variance
|
| 204 |
|
| 205 |
-
import torch
|
| 206 |
|
| 207 |
def ymca_loss(images, weights=(1.0, 1.0, 1.0, 1.0)):
|
| 208 |
"""
|
|
@@ -263,6 +262,82 @@ def ymca_loss(images, weights=(1.0, 1.0, 1.0, 1.0)):
|
|
| 263 |
return loss
|
| 264 |
|
| 265 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 266 |
def blue_loss_variant(images, use_mean=False, alpha=1.0):
|
| 267 |
"""
|
| 268 |
Computes the blue loss for a batch of images with an optional mean component.
|
|
@@ -301,7 +376,7 @@ def blue_loss_variant(images, use_mean=False, alpha=1.0):
|
|
| 301 |
|
| 302 |
return loss
|
| 303 |
|
| 304 |
-
def generate_with_prompt_style_guidance(prompt, style, seed,num_inference_steps,guidance_scale):
|
| 305 |
|
| 306 |
prompt = prompt + ' in style of s'
|
| 307 |
|
|
@@ -386,7 +461,19 @@ def generate_with_prompt_style_guidance(prompt, style, seed,num_inference_steps,
|
|
| 386 |
denoised_images = vae.decode((1 / 0.18215) * latents_x0).sample / 2 + 0.5 # range (0, 1)
|
| 387 |
|
| 388 |
# Calculate loss
|
| 389 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 390 |
|
| 391 |
# # Occasionally print it out
|
| 392 |
# if i%10==0:
|
|
@@ -423,7 +510,7 @@ def inference(prompt, seed, style,num_inference_steps,guidance_scale,loss_functi
|
|
| 423 |
print(loss_function)
|
| 424 |
style = dict_styles[style]
|
| 425 |
torch.manual_seed(seed)
|
| 426 |
-
result = generate_with_prompt_style_guidance(prompt, style,seed,num_inference_steps,guidance_scale)
|
| 427 |
return np.array(result)
|
| 428 |
else:
|
| 429 |
return None
|
|
@@ -450,7 +537,7 @@ demo = gr.Interface(inference,
|
|
| 450 |
step=8,
|
| 451 |
label="Select Guidance Scale",
|
| 452 |
interactive=True,
|
| 453 |
-
),gr.Radio(["contrast", "
|
| 454 |
],
|
| 455 |
outputs = [
|
| 456 |
gr.Image(label="Stable Diffusion Output"),
|
|
|
|
| 202 |
|
| 203 |
return -variance
|
| 204 |
|
|
|
|
| 205 |
|
| 206 |
def ymca_loss(images, weights=(1.0, 1.0, 1.0, 1.0)):
|
| 207 |
"""
|
|
|
|
| 262 |
return loss
|
| 263 |
|
| 264 |
|
| 265 |
+
|
| 266 |
+
def rgb_to_cmyk(images):
|
| 267 |
+
"""
|
| 268 |
+
Converts an RGB image tensor to CMYK.
|
| 269 |
+
|
| 270 |
+
Parameters:
|
| 271 |
+
images (torch.Tensor): A batch of images in RGB format. Expected shape is (N, 3, H, W).
|
| 272 |
+
|
| 273 |
+
Returns:
|
| 274 |
+
torch.Tensor: A tensor containing the CMYK channels.
|
| 275 |
+
"""
|
| 276 |
+
R = images[:, 0, :, :]
|
| 277 |
+
G = images[:, 1, :, :]
|
| 278 |
+
B = images[:, 2, :, :]
|
| 279 |
+
|
| 280 |
+
# Convert RGB to CMY
|
| 281 |
+
C = 1 - R
|
| 282 |
+
M = 1 - G
|
| 283 |
+
Y = 1 - B
|
| 284 |
+
|
| 285 |
+
# Convert CMY to CMYK
|
| 286 |
+
K = torch.min(torch.min(C, M), Y)
|
| 287 |
+
C = (C - K) / (1 - K + 1e-8)
|
| 288 |
+
M = (M - K) / (1 - K + 1e-8)
|
| 289 |
+
Y = (Y - K) / (1 - K + 1e-8)
|
| 290 |
+
|
| 291 |
+
CMYK = torch.stack([C, M, Y, K], dim=1)
|
| 292 |
+
return CMYK
|
| 293 |
+
|
| 294 |
+
def cymk_loss(images, weights=(1.0, 1.0, 1.0, 1.0)):
|
| 295 |
+
"""
|
| 296 |
+
Computes the CYMK loss for a batch of images.
|
| 297 |
+
|
| 298 |
+
The CYMK loss is a custom loss function combining the variance of the Cyan channel,
|
| 299 |
+
the mean value of the Yellow channel, the variance of the Magenta channel, and the
|
| 300 |
+
absolute sum of the Black channel.
|
| 301 |
+
|
| 302 |
+
Parameters:
|
| 303 |
+
images (torch.Tensor): A batch of images. Expected shape is (N, 3, H, W) for RGB input.
|
| 304 |
+
weights (tuple): A tuple of four floats representing the weights for each component of the loss
|
| 305 |
+
(default is (1.0, 1.0, 1.0, 1.0)).
|
| 306 |
+
|
| 307 |
+
Returns:
|
| 308 |
+
torch.Tensor: The CYMK loss, combining the specified components.
|
| 309 |
+
"""
|
| 310 |
+
# Ensure the input tensor has the correct shape
|
| 311 |
+
if images.shape[1] != 3:
|
| 312 |
+
raise ValueError("Expected images with 3 channels (RGB), but got shape {}".format(images.shape))
|
| 313 |
+
|
| 314 |
+
# Convert RGB to CMYK
|
| 315 |
+
cmyk_images = rgb_to_cmyk(images)
|
| 316 |
+
|
| 317 |
+
# Extract CMYK channels
|
| 318 |
+
C = cmyk_images[:, 0, :, :]
|
| 319 |
+
M = cmyk_images[:, 1, :, :]
|
| 320 |
+
Y = cmyk_images[:, 2, :, :]
|
| 321 |
+
K = cmyk_images[:, 3, :, :]
|
| 322 |
+
|
| 323 |
+
# Compute the variance of the C channel
|
| 324 |
+
variance_C = torch.var(C)
|
| 325 |
+
|
| 326 |
+
# Compute the mean of the Y channel
|
| 327 |
+
mean_Y = torch.mean(Y)
|
| 328 |
+
|
| 329 |
+
# Compute the variance of the M channel
|
| 330 |
+
variance_M = torch.var(M)
|
| 331 |
+
|
| 332 |
+
# Compute the absolute sum of the K channel
|
| 333 |
+
abs_sum_K = torch.sum(torch.abs(K))
|
| 334 |
+
|
| 335 |
+
# Combine the components with the given weights
|
| 336 |
+
loss = (weights[0] * variance_C) + (weights[1] * mean_Y) + (weights[2] * variance_M) + (weights[3] * abs_sum_K)
|
| 337 |
+
|
| 338 |
+
return loss
|
| 339 |
+
|
| 340 |
+
|
| 341 |
def blue_loss_variant(images, use_mean=False, alpha=1.0):
|
| 342 |
"""
|
| 343 |
Computes the blue loss for a batch of images with an optional mean component.
|
|
|
|
| 376 |
|
| 377 |
return loss
|
| 378 |
|
| 379 |
+
def generate_with_prompt_style_guidance(prompt, style, seed,num_inference_steps,guidance_scale,loss_function):
|
| 380 |
|
| 381 |
prompt = prompt + ' in style of s'
|
| 382 |
|
|
|
|
| 461 |
denoised_images = vae.decode((1 / 0.18215) * latents_x0).sample / 2 + 0.5 # range (0, 1)
|
| 462 |
|
| 463 |
# Calculate loss
|
| 464 |
+
# "contrast", "blue_original", "blue_modified","ymca_loss","cymk_loss"
|
| 465 |
+
if loss_function == "contrast":
|
| 466 |
+
loss = contrast_loss(denoised_images) * contrast_loss_scale
|
| 467 |
+
elif loss_function == "blue_original":
|
| 468 |
+
loss = blue_loss(denoised_images) * contrast_loss_scale
|
| 469 |
+
elif loss_function == "blue_modified":
|
| 470 |
+
loss = blue_loss_variant(denoised_images) * contrast_loss_scale
|
| 471 |
+
elif loss_function == "ymca_loss":
|
| 472 |
+
loss = ymca_loss(denoised_images) * contrast_loss_scale
|
| 473 |
+
elif loss_function == "cymk_loss":
|
| 474 |
+
loss = cymk_loss(denoised_images) * contrast_loss_scale
|
| 475 |
+
else :
|
| 476 |
+
loss = ymca_loss(denoised_images) * contrast_loss_scale
|
| 477 |
|
| 478 |
# # Occasionally print it out
|
| 479 |
# if i%10==0:
|
|
|
|
| 510 |
print(loss_function)
|
| 511 |
style = dict_styles[style]
|
| 512 |
torch.manual_seed(seed)
|
| 513 |
+
result = generate_with_prompt_style_guidance(prompt, style,seed,num_inference_steps,guidance_scale,loss_function)
|
| 514 |
return np.array(result)
|
| 515 |
else:
|
| 516 |
return None
|
|
|
|
| 537 |
step=8,
|
| 538 |
label="Select Guidance Scale",
|
| 539 |
interactive=True,
|
| 540 |
+
),gr.Radio(["contrast", "blue_original", "blue_modified","ymca_loss","cymk_loss"], label="loss-function", info="loss-function" , value="ymca_loss"),
|
| 541 |
],
|
| 542 |
outputs = [
|
| 543 |
gr.Image(label="Stable Diffusion Output"),
|