| import os |
| import argparse |
| import torch |
| from pathlib import Path |
| from AdaIN import AdaINNet |
| from PIL import Image |
| from utils import transform, adaptive_instance_normalization,linear_histogram_matching, Range |
| import cv2 |
| import imageio |
| import numpy as np |
| from tqdm import tqdm |
|
|
|
|
# --- Command-line interface -------------------------------------------------
# Module-level side effects: parses sys.argv at import time and exposes
# `args` and `device` to the rest of the script.
parser = argparse.ArgumentParser()
parser.add_argument('--content_video', type=str, required=True, help='Content video file path')
parser.add_argument('--style_image', type=str, required=True, help='Style image file path')
parser.add_argument('--decoder_weight', type=str, default='decoder.pth', help='Decoder weight file path')
# Range (from utils) is used as a choices-validator so any float in [0.0, 1.0] is accepted.
parser.add_argument('--alpha', type=float, default=1.0, choices=[Range(0.0, 1.0)], help='Alpha [0.0, 1.0] controls style transfer level')
parser.add_argument('--cuda', action='store_true', help='Use CUDA')
parser.add_argument('--color_control', action='store_true', help='Preserve content color')
args = parser.parse_args()

# Fall back to CPU unless --cuda was requested AND a CUDA device is available.
device = torch.device('cuda' if args.cuda and torch.cuda.is_available() else 'cpu')
|
|
|
|
def style_transfer(content_tensor, style_tensor, encoder, decoder, alpha=1.0):
	"""
	Stylize a content image with a style image via AdaIN.

	Both images are passed through the encoder, the content feature map's
	channel-wise statistics are re-normalized to match the style feature
	map's (adaptive instance normalization), and the decoder turns the
	(optionally alpha-blended) features back into an image.

	Args:
		content_tensor (torch.FloatTensor): Content image
		style_tensor (torch.FloatTensor): Style Image
		encoder: Encoder (vgg19) network
		decoder: Decoder network
		alpha (float, default=1.0): Weight of style image feature

	Return:
		output_tensor (torch.FloatTensor): Style Transfer output image
	"""
	# Extract feature maps for both inputs.
	content_feat = encoder(content_tensor)
	style_feat = encoder(style_tensor)

	# Re-normalize content features to the style feature statistics,
	# then interpolate between stylized and original content features.
	stylized_feat = adaptive_instance_normalization(content_feat, style_feat)
	blended_feat = alpha * stylized_feat + (1 - alpha) * content_feat

	return decoder(blended_feat)
|
|
|
|
def main():
	"""
	Run AdaIN style transfer frame-by-frame over a content video.

	Reads the video given by --content_video, stylizes every frame with
	--style_image, and writes the result to ./results_video/. Uses the
	module-level `args` and `device`.
	"""
	content_video_pth = Path(args.content_video)
	content_video = cv2.VideoCapture(str(content_video_pth))
	style_image_pth = Path(args.style_image)
	style_image = Image.open(style_image_pth)

	# Source video metadata (fps/size reused for the output video).
	fps = int(content_video.get(cv2.CAP_PROP_FPS))
	frame_count = int(content_video.get(cv2.CAP_PROP_FRAME_COUNT))
	video_height = int(content_video.get(cv2.CAP_PROP_FRAME_HEIGHT))
	video_width = int(content_video.get(cv2.CAP_PROP_FRAME_WIDTH))

	# FIX: tqdm(frame_count) passed the int as the *iterable* argument,
	# leaving the bar without a total; it must be the `total=` keyword.
	video_tqdm = tqdm(total=frame_count)

	# Build the output path: <content>_style_<style>[_colorcontrol]<ext>
	out_dir = './results_video/'
	os.makedirs(out_dir, exist_ok=True)
	out_pth = out_dir + content_video_pth.stem + '_style_' + style_image_pth.stem
	if args.color_control: out_pth += '_colorcontrol'
	out_pth += content_video_pth.suffix
	out_pth = Path(out_pth)
	writer = imageio.get_writer(out_pth, mode='I', fps=fps)

	# FIX: map_location so GPU-saved checkpoints still load on CPU-only hosts.
	vgg = torch.load('vgg_normalized.pth', map_location=device)
	model = AdaINNet(vgg).to(device)
	model.decoder.load_state_dict(torch.load(args.decoder_weight, map_location=device))
	model.eval()

	t = transform(512)

	style_tensor = t(style_image).unsqueeze(0).to(device)

	try:
		while content_video.isOpened():
			ret, content_image = content_video.read()
			if not ret:
				break

			# FIX: cv2 decodes frames as BGR, but PIL/the network and the
			# imageio writer expect RGB — convert before tensorizing.
			content_image = cv2.cvtColor(content_image, cv2.COLOR_BGR2RGB)
			content_tensor = t(Image.fromarray(content_image)).unsqueeze(0).to(device)

			# FIX: match against the ORIGINAL style tensor each frame instead
			# of overwriting it, which compounded the matching over frames.
			if args.color_control:
				frame_style_tensor = linear_histogram_matching(content_tensor, style_tensor)
			else:
				frame_style_tensor = style_tensor

			with torch.no_grad():
				out_tensor = style_transfer(content_tensor, frame_style_tensor, model.encoder
					, model.decoder, args.alpha).cpu().numpy()

			# (N, C, H, W) -> (H, W, C), rescale to uint8, resize to source size.
			out_frame = np.squeeze(out_tensor, axis=0)
			out_frame = np.transpose(out_frame, (1, 2, 0))
			out_frame = cv2.normalize(src=out_frame, dst=None, alpha=0, beta=255, norm_type=cv2.NORM_MINMAX, dtype=cv2.CV_8U)
			out_frame = cv2.resize(out_frame, (video_width, video_height), interpolation=cv2.INTER_CUBIC)

			writer.append_data(out_frame)
			video_tqdm.update(1)
	finally:
		# FIX: release/close all handles; the imageio writer in particular
		# must be closed or the output video can be truncated.
		content_video.release()
		writer.close()
		video_tqdm.close()

	print("\nContent: " + content_video_pth.stem + ". Style: " + style_image_pth.stem +'\n')
|
|
# Script entry point: only run when executed directly, not when imported.
if __name__ == '__main__':
	main()