# luismidv's picture
# new
# 930f9e5
from torch.nn.functional import mse_loss
import torch
import dataprepare
from model import VGG16
def white_noise_img(model, original_image, content_image, epochs=5, lr=0.01):
    """Optimize a white-noise image so its content features match ``content_image``.

    A random image shaped like ``original_image`` is created. Any leading
    size-1 batch dimension is removed first. The image is then optimized
    with Adam so that ``model(...)["content"]`` of the generated image
    matches that of ``content_image``. The loss is the sum of per-entry
    mean-squared errors.

    Args:
        model: Callable returning a mapping with a ``"content"`` entry that
            can be iterated and indexed by what it yields. Presumably this is
            a dict of layer name -> activation tensor; TODO confirm against
            ``VGG16``.
        original_image: Tensor used only for the shape/dtype/device of the
            generated image. Dimension 0 is squeezed away if it has size 1.
        content_image: Tensor whose content activations are the target. It is
            passed to ``model`` unchanged, so it is assumed to have the same
            layout as the squeezed ``original_image`` (TODO confirm).
        epochs: Number of optimization steps.
        lr: Adam learning rate.

    Returns:
        The optimized image, detached from the autograd graph.
    """
    # Remove the batch dimension (squeeze(0) is a no-op if dim 0 is not 1).
    original_image = original_image.squeeze(0)
    # Leaf tensor optimized directly. Only this tensor is given to Adam, so
    # the model weights are never updated here. If the model's parameters
    # require grad, their .grad still accumulates through backward().
    generated_image = torch.rand_like(original_image, requires_grad=True)
    optimizer = torch.optim.Adam([generated_image], lr=lr)

    # The target does not depend on the generated image, so compute it once,
    # outside the loop, without building an autograd graph.
    with torch.no_grad():
        content_targets = model(content_image)["content"]

    for _ in range(epochs):
        optimizer.zero_grad()
        generated_content = model(generated_image)["content"]
        content_loss = sum(
            mse_loss(generated_content[layer], content_targets[layer])
            for layer in content_targets
        )
        # TODO: add a style loss (MSE between Gram matrices of style-layer
        # activations) once the style layers are wired up.
        # Without backward()/step() the generated image would never change.
        content_loss.backward()
        optimizer.step()
        print(f"Loss at content type: {content_loss.item()}")

    return generated_image.detach()
if __name__ == "__main__":
    # Script entry point. The guard keeps importing this module (e.g. to reuse
    # white_noise_img) from loading images and building a model.
    # NOTE(review): these are absolute paths under /data; confirm they are not
    # meant to be relative to the project directory.
    image = dataprepare.image_process('/data/16546923557574.jpg')
    content_image = dataprepare.image_process('/data/imagen2.png')
    print(image.shape)
    print(content_image.shape)
    # Add a leading batch dimension of size 1; white_noise_img squeezes it
    # away again. Unlike view(), unsqueeze() also works on non-contiguous
    # tensors.
    image = image.unsqueeze(0)
    model = VGG16(num_features=5, num_classes=5).to("cpu")
    white_noise_img(model, image, content_image)