Spaces:
Runtime error
Runtime error
| import torch | |
| from matplotlib import pyplot as plt | |
def get_style_embeddings(style_file, *, map_location=None):
    """Load a saved embedding dict and return the value stored under its first key.

    NOTE(review): the file presumably holds a ``{token_name: embedding}`` dict
    (e.g. a textual-inversion ``learned_embeds.bin``); only the first entry,
    in the dict's iteration order, is returned.

    Args:
        style_file: Path or file-like object accepted by ``torch.load``.
        map_location: Forwarded to ``torch.load``. Pass ``"cpu"`` to load
            embeddings that were saved on a GPU into a CPU-only process.
            ``None`` (the default) keeps ``torch.load``'s own device handling.

    Returns:
        The value stored under the first key of the loaded dict.

    Raises:
        ValueError: If the loaded dict contains no entries.
    """
    style_embed = torch.load(style_file, map_location=map_location)
    try:
        # First key in insertion order; same as list(keys())[0] without the copy.
        style_name = next(iter(style_embed))
    except StopIteration:
        raise ValueError(f"No embeddings found in {style_file!r}") from None
    return style_embed[style_name]
def get_EOS_pos_in_prompt(prompt):
    """Estimate the index of the end-of-sequence token for ``prompt``.

    The estimate is one more than the number of whitespace-separated words.
    NOTE(review): this presumably assumes a start token at position 0 and
    exactly one token per word; words that a tokenizer splits into several
    sub-word tokens will make the estimate too small. Confirm against the
    tokenizer in use.
    """
    word_count = len(prompt.split())
    return word_count + 1
def invert_loss(gen_image):
    """Return the summed pairwise MSE between channels 0, 1 and 2 of ``gen_image``.

    The loss is zero exactly when the three channels are identical, so
    minimising it pushes the image towards grey.
    NOTE(review): assumes ``gen_image`` is laid out as (batch, channel, ...)
    with at least three channels. TODO confirm.
    """
    mse = torch.nn.functional.mse_loss
    ch0 = gen_image[:, 0]
    ch1 = gen_image[:, 1]
    ch2 = gen_image[:, 2]
    # Keep the original pair order and summation order.
    return mse(ch0, ch2) + mse(ch2, ch1) + mse(ch0, ch1)
def blue_loss(images, *, target=0.9, channel=2):
    """Return how far one channel of ``images`` is from a target value.

    The result is the mean absolute difference between ``images[:, channel]``
    and ``target``, averaged over the batch and all remaining dimensions. With
    the defaults, this is the distance of the blue channel (index 2) from 0.9.

    NOTE(review): assumes ``images`` is (batch, channel, ...) with RGB channel
    order. TODO confirm against callers.

    Args:
        images: Image batch tensor.
        target: Value the selected channel is compared against.
        channel: Index along dim 1 of the channel to measure.

    Returns:
        A 0-dim tensor holding the mean absolute error.
    """
    # [:, channel] selects every image in the batch, but only that channel.
    return torch.abs(images[:, channel] - target).mean()
def show_images(images_list):
    """Display the given images side by side in one row, then show the figure.

    Args:
        images_list: Sequence of images, each accepted by ``Axes.imshow``.
            An empty sequence draws nothing.
    """
    n_images = len(images_list)
    if n_images == 0:
        # plt.subplots rejects ncols=0, so there is nothing to draw.
        return
    # squeeze=False always returns a 2-D array of Axes. Without it, a single
    # image yields a bare Axes object, and indexing it raises TypeError.
    _, axs = plt.subplots(1, n_images, figsize=(16, 4), squeeze=False)
    for ax, image in zip(axs[0], images_list):
        ax.imshow(image)
    plt.show()