# Brain_Emotion_Decoder / Src / Processing_img.py
# Neural-style-transfer image-processing utilities.
# Provenance (from the hosting page): Ihssane123 — initial commit 3b6d764
from PIL import Image
import torch
import torch.optim as optim
from torchvision import transforms
import torch.nn as nn
from Models_Class.NST_class import (
ContentLoss,
Normalization,
StyleLoss,
)
import copy
# Relative weights of the style vs. content terms in the total NST loss
# (read inside run_style_transfer's closure); style dominates by seven
# orders of magnitude.
style_weight = 1e8
content_weight = 1e1
def image_loader(image_path, loader, device):
    """Load the image at *image_path*, apply the *loader* transform, and
    return it as a float tensor on *device* with a batch dimension of 1."""
    pil_img = Image.open(image_path).convert('RGB')
    batched = loader(pil_img).unsqueeze(0)
    return batched.to(device, torch.float)
def save_image(tensor, path="output.png"):
    """Write a (1, C, H, W) image tensor to *path* as a picture file."""
    # Copy to host memory and drop the batch dimension before conversion.
    img = tensor.cpu().clone().squeeze(0)
    transforms.ToPILImage()(img).save(path)
def gram_matrix(input):
    """Return the Gram matrix of a (B, C, H, W) feature map, normalized by
    the total number of elements so its scale is independent of image size."""
    b, ch, h, w = input.size()
    # Flatten each channel's spatial map into a row vector.
    flat = input.view(b * ch, h * w)
    gram = torch.matmul(flat, flat.t())
    return gram / (b * ch * h * w)
def get_style_model_and_losses(cnn, normalization_mean, normalization_std,
                               style_img, content_img, content_layers, style_layers, device):
    """Copy *cnn*'s layers into a new nn.Sequential, inserting ContentLoss /
    StyleLoss modules directly after the layers named in *content_layers* /
    *style_layers*.

    Returns (model, style_losses, content_losses); the two lists hold the
    inserted loss modules so the optimizer loop can read their .loss values.
    """
    # Deep-copy so the caller's pretrained network is left untouched.
    cnn = copy.deepcopy(cnn)
    normalization = Normalization(normalization_mean, normalization_std).to(device)
    content_losses = []
    style_losses = []
    # The rebuilt network starts with input normalization.
    model = nn.Sequential(normalization)
    i = 0  # conv counter; relu/pool/bn reuse the index of the preceding conv
    for layer in cnn.children():
        if isinstance(layer, nn.Conv2d):
            i += 1
            name = f'conv_{i}'
        elif isinstance(layer, nn.ReLU):
            name = f'relu_{i}'
            # Replace in-place ReLU: it would overwrite the activations that
            # the loss modules need to keep.
            layer = nn.ReLU(inplace=False)
        elif isinstance(layer, nn.MaxPool2d):
            name = f'pool_{i}'
        elif isinstance(layer, nn.BatchNorm2d):
            name = f'bn_{i}'
        else:
            raise RuntimeError(f'Unrecognized layer: {layer.__class__.__name__}')
        model.add_module(name, layer)
        if name in content_layers:
            # Run the partially-built model up to this layer to capture the
            # content image's response; detach it so it acts as a fixed target.
            target = model(content_img).detach()
            content_loss = ContentLoss(target)
            model.add_module(f"content_loss_{i}", content_loss)
            content_losses.append(content_loss)
        if name in style_layers:
            # Same idea for the style image at the style layers.
            target_feature = model(style_img).detach()
            style_loss = StyleLoss(target_feature)
            model.add_module(f"style_loss_{i}", style_loss)
            style_losses.append(style_loss)
    # Trim every layer after the last loss module; layers beyond it cannot
    # affect the optimization objective.
    for i in range(len(model) - 1, -1, -1):
        if isinstance(model[i], (ContentLoss, StyleLoss)):
            break
    model = model[:i+1]
    return model, style_losses, content_losses
def run_style_transfer(cnn, normalization_mean, normalization_std,
                       content_img, style_img, input_img, content_layers,
                       style_layers, device, num_steps=300,
                       style_weight=style_weight, content_weight=content_weight):
    """Optimize *input_img* in place so its features match *content_img* in
    content and *style_img* in style, using L-BFGS.

    Args:
        cnn: pretrained feature network (deep-copied by the model builder).
        normalization_mean, normalization_std: per-channel input stats.
        content_img, style_img: target image tensors.
        input_img: starting image; modified in place and returned.
        content_layers, style_layers: layer names where losses are attached.
        device: torch device for the normalization module.
        num_steps: number of closure evaluations to run.
        style_weight, content_weight: loss-term weights; default to the
            module-level constants, so existing callers are unaffected.

    Returns:
        The optimized input_img, clamped to [0, 1].
    """
    print("Building the style transfer model..")
    model, style_losses, content_losses = get_style_model_and_losses(
        cnn, normalization_mean, normalization_std,
        style_img, content_img, content_layers, style_layers, device)
    # Optimize the pixels of input_img directly.
    optimizer = optim.LBFGS([input_img.requires_grad_()])
    print("Optimizing..")
    run = [0]  # mutable step counter shared with the closure
    while run[0] <= num_steps:
        def closure():
            # Keep pixel values in the valid image range between L-BFGS
            # steps without touching the autograd graph.
            with torch.no_grad():
                input_img.clamp_(0, 1)
            optimizer.zero_grad()
            model(input_img)  # forward pass fills each loss module's .loss
            style_score = sum(sl.loss for sl in style_losses)
            content_score = sum(cl.loss for cl in content_losses)
            loss = style_weight * style_score + content_weight * content_score
            loss.backward()
            if run[0] % 50 == 0:
                print(f"Step {run[0]}:")
                print(f" Style Loss: {style_score.item():.4f}")
                print(f" Content Loss: {content_score.item():.4f}")
                print(f" Total Loss: {loss.item():.4f}\n")
            run[0] += 1
            return loss
        optimizer.step(closure)
    # Final clamp so the returned tensor is a valid image.
    with torch.no_grad():
        input_img.clamp_(0, 1)
    return input_img