Commit
·
0c1e9f5
1
Parent(s):
a67c790
Minor fix
Browse files
app.py
CHANGED
|
@@ -151,29 +151,29 @@ def generate_with_embs(num_inference_steps, guidance_scale, seed, text_input, te
|
|
| 151 |
|
| 152 |
return latents_to_pil(latents)[0]
|
| 153 |
|
| 154 |
-
def guide_loss(images, loss_type='
|
| 155 |
# grayscale loss
|
| 156 |
-
if loss_type == '
|
| 157 |
transformed_imgs = grayscale_transformer(images)
|
| 158 |
error = torch.abs(transformed_imgs - images).mean()
|
| 159 |
|
| 160 |
# brightness loss
|
| 161 |
-
elif loss_type == '
|
| 162 |
transformed_imgs = tfms.functional.adjust_brightness(images, brightness_factor=3)
|
| 163 |
error = torch.abs(transformed_imgs - images).mean()
|
| 164 |
|
| 165 |
# contrast loss
|
| 166 |
-
elif loss_type == '
|
| 167 |
transformed_imgs = tfms.functional.adjust_contrast(images, contrast_factor=10)
|
| 168 |
error = torch.abs(transformed_imgs - images).mean()
|
| 169 |
|
| 170 |
# symmetry loss - Flip the image along the width
|
| 171 |
-
elif loss_type == "
|
| 172 |
flipped_image = torch.flip(images, [3])
|
| 173 |
error = F.mse_loss(images, flipped_image)
|
| 174 |
|
| 175 |
# saturation loss
|
| 176 |
-
elif loss_type == '
|
| 177 |
transformed_imgs = tfms.functional.adjust_saturation(images,saturation_factor = 10)
|
| 178 |
error = torch.abs(transformed_imgs - images).mean()
|
| 179 |
|
|
@@ -291,7 +291,7 @@ demo = gr.Interface(inference,
|
|
| 291 |
gr.Slider(1, 10, 7.5, step = 0.1, label="Guidance scale"),
|
| 292 |
gr.Slider(0, 10000, 1, step = 1, label="Seed"),
|
| 293 |
gr.Dropdown(label="Guidance method", choices=['Grayscale', 'Bright', 'Contrast',
|
| 294 |
-
'Symmetry', 'Saturation'], value="
|
| 295 |
gr.Slider(100, 10000, 100, step = 1, label="Loss scale")],
|
| 296 |
outputs= [gr.Image(width=320, height=320, label="Generated art"),
|
| 297 |
gr.Image(width=320, height=320, label="Generated art with guidance")],
|
|
|
|
| 151 |
|
| 152 |
return latents_to_pil(latents)[0]
|
| 153 |
|
| 154 |
+
def guide_loss(images, loss_type='Gayscale'):
|
| 155 |
# grayscale loss
|
| 156 |
+
if loss_type == 'Grayscale':
|
| 157 |
transformed_imgs = grayscale_transformer(images)
|
| 158 |
error = torch.abs(transformed_imgs - images).mean()
|
| 159 |
|
| 160 |
# brightness loss
|
| 161 |
+
elif loss_type == 'Bright':
|
| 162 |
transformed_imgs = tfms.functional.adjust_brightness(images, brightness_factor=3)
|
| 163 |
error = torch.abs(transformed_imgs - images).mean()
|
| 164 |
|
| 165 |
# contrast loss
|
| 166 |
+
elif loss_type == 'Contrast':
|
| 167 |
transformed_imgs = tfms.functional.adjust_contrast(images, contrast_factor=10)
|
| 168 |
error = torch.abs(transformed_imgs - images).mean()
|
| 169 |
|
| 170 |
# symmetry loss - Flip the image along the width
|
| 171 |
+
elif loss_type == "Symmetry":
|
| 172 |
flipped_image = torch.flip(images, [3])
|
| 173 |
error = F.mse_loss(images, flipped_image)
|
| 174 |
|
| 175 |
# saturation loss
|
| 176 |
+
elif loss_type == 'Saturation':
|
| 177 |
transformed_imgs = tfms.functional.adjust_saturation(images,saturation_factor = 10)
|
| 178 |
error = torch.abs(transformed_imgs - images).mean()
|
| 179 |
|
|
|
|
| 291 |
gr.Slider(1, 10, 7.5, step = 0.1, label="Guidance scale"),
|
| 292 |
gr.Slider(0, 10000, 1, step = 1, label="Seed"),
|
| 293 |
gr.Dropdown(label="Guidance method", choices=['Grayscale', 'Bright', 'Contrast',
|
| 294 |
+
'Symmetry', 'Saturation'], value="Grayscale"),
|
| 295 |
gr.Slider(100, 10000, 100, step = 1, label="Loss scale")],
|
| 296 |
outputs= [gr.Image(width=320, height=320, label="Generated art"),
|
| 297 |
gr.Image(width=320, height=320, label="Generated art with guidance")],
|