# Helper scripts: style-embedding loading, guidance losses and image display.
# Source: "Uploaded bins and helper scripts." by rohithb (commit 1b341bf, 980 bytes).
import torch
from matplotlib import pyplot as plt
def get_style_embeddings(style_file, map_location=None):
    """Load a saved embedding mapping and return the value under its first key.

    Args:
        style_file: Path or file-like object accepted by ``torch.load``.
        map_location: Forwarded to ``torch.load``. For example, ``"cpu"`` lets a
            file saved from GPU tensors load on a CPU-only machine. ``None``
            (the default) keeps the stored device placement, as before.

    Returns:
        The object stored under the first key of the loaded mapping. This is
        presumably a single embedding tensor keyed by its token string; TODO
        confirm against the files this is used with.

    Raises:
        IndexError: if the loaded mapping contains no entries.
    """
    # NOTE: torch.load unpickles its input; only load files from trusted sources.
    style_embed = torch.load(style_file, map_location=map_location)
    # Dicts preserve insertion order, so the "first" key is the first one saved.
    style_names = list(style_embed.keys())
    if not style_names:
        raise IndexError(f"no embeddings found in {style_file!r}")
    return style_embed[style_names[0]]
def get_EOS_pos_in_prompt(prompt):
    """Estimate the index of the end-of-sequence token for *prompt*.

    The estimate is the number of whitespace-separated words plus one.

    NOTE(review): the +1 presumably skips a start-of-sequence token at index 0.
    Counting words also assumes each word is exactly one tokenizer token, so a
    word split into several sub-word tokens makes this undercount. Confirm
    against the tokenizer in use.
    """
    word_count = len(prompt.split())
    return word_count + 1
def invert_loss(gen_image):
    """Return the sum of pairwise MSEs between slices 0, 1 and 2 along dim 1.

    The loss is zero exactly when those three slices are identical. Minimizing
    it therefore pushes the three channels toward equal values, which is
    grayscale if the layout is (batch, RGB, ...); TODO confirm.

    NOTE(review): despite its name, this does not measure colour inversion.
    Any slices beyond index 2 along dim 1 are ignored.
    """
    mse = torch.nn.functional.mse_loss
    ch0, ch1, ch2 = (gen_image[:, idx] for idx in range(3))
    # Pairs summed in the order (0, 2), (2, 1), (0, 1) so float results match exactly.
    return mse(ch0, ch2) + mse(ch2, ch1) + mse(ch0, ch1)
def blue_loss(images, target=0.9):
    """Return the mean absolute distance of the blue channel from *target*.

    Args:
        images: Tensor indexed as ``images[:, 2]``, i.e. every image in the
            batch with channel 2 taken as blue. Assumes a (batch, channel, ...)
            layout; TODO confirm with callers.
        target: Value the blue channel is pulled toward. The default 0.9 is the
            previously hard-coded value.

    Returns:
        0-dim tensor: the mean over all elements of ``|images[:, 2] - target|``.
    """
    blue = images[:, 2]  # all images in batch, only the blue channel
    return (blue - target).abs().mean()
def show_images(images_list):
    """Draw the images side by side in one row of subplots, then show the figure.

    Args:
        images_list: Sequence of anything ``Axes.imshow`` accepts, such as
            arrays or PIL images. An empty sequence draws nothing.
    """
    # len() rather than truthiness so a stacked ndarray of images also works.
    if len(images_list) == 0:
        return  # plt.subplots rejects a zero-column grid
    # squeeze=False always returns a 2-D array of Axes. Without it, a single
    # image yields a bare Axes object that cannot be indexed.
    _, axes = plt.subplots(1, len(images_list), figsize=(16, 4), squeeze=False)
    for ax, image in zip(axes[0], images_list):
        ax.imshow(image)
    plt.show()