""" 实时风格迁移算法 - Fast Style Transfer 作者: [您的名字或者用户名] 描述: 这个脚本实现了实时风格迁移,支持训练模型、处理一批图像以及处理视频文件。它基于PyTorch实现,并使用VGG网络提取风格特征。 使用方法: - 训练模式: python fast_style_transfer.py --mode train --style_image ./data/udnie.jpg --content_dataset data/train2017 --model_save_path ./models/new_model.pth --epochs 10 - 图像处理模式: python fast_style_transfer.py --mode image --input_images_dir ./data/train2014/default_class --output_images_dir ./output/images_generated --model_path ./models/udnie.pth - 视频处理模式: python fast_style_transfer.py --mode video --video_input data/maigua.mp4 --video_output output/videos/maigua_udnie.mp4 --model_path ./models/udnie.pth """ import argparse import os from datetime import datetime, timedelta from typing import List, Iterable from PIL import Image import torch from torch.utils.data import DataLoader from torch.utils.tensorboard import SummaryWriter from torchvision import transforms from tqdm import tqdm import cv2 import numpy as np from models import VGG, TransNet from datasets import COCODataset from utils import load_image, save_image, make_transform, save_model, calculate_style_loss, calculate_content_loss, \ denormalize def train(model, vgg, lr, epochs, batch_size, style_weight, content_weight, style_layers, content_layers, device, transform, image_style, content_dataset_root, save_path, output_dir, log_dir='./runs/fast_style_transfer', save_interval=timedelta(seconds=120)): """ :param model: 内容生成模型,输出与输出图像尺寸相同 :param vgg: 特征提取网络,以vgg19为例 :param lr: :param epochs: :param batch_size: :param style_weight: 风格损失权重 :param content_weight: 内容损失权重 :param style_layers: 选取的风格层及其权重 :param content_layers: 选取的内容层及其权重 :param device: 计算设备 :param transform: 图像变换 :param image_style: 风格图片 :param content_dataset_root: 内容图片文件夹路径 :param save_path: 模型保存路径 :param output_dir: 中间结果输出路径 :param log_dir: tensorboard日志的保存路径 :param save_interval: 保存时间间隔 :return: """ writer = SummaryWriter(log_dir) # autodl平台tensorboard默认日志路径 optimizer = torch.optim.Adam(model.parameters(), lr=lr) # 对目标网络进行优化 dataset = COCODataset(content_dataset_root, transform=transform) dataloader = DataLoader(dataset, batch_size, shuffle=True) _, style_features = vgg(image_style) p_bar = tqdm(range(epochs)) last_save_time = datetime.now() - save_interval for epoch in p_bar: running_content_loss, running_style_loss = 0.0, 0.0 for i, content_img in enumerate(dataloader): content_img = content_img.to(device) image_generated = model(content_img) # 只使用内容图像进行风格迁移 generated_content, generated_style = vgg(image_generated) style_loss = sum( style_weight * style_layers[name] * calculate_style_loss(style_features[name], gen_style) for name, gen_style in generated_style.items()) content_features, _ = vgg(content_img) # 计算内容图的内容特征 content_loss = sum( content_weight * content_layers[name] * calculate_content_loss(content_features[name], gen_content) for name, gen_content in generated_content.items()) total_loss = style_loss + content_loss optimizer.zero_grad() total_loss.backward() torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1) # 梯度裁剪 optimizer.step() running_content_loss += content_loss.item() running_style_loss += style_loss.item() p_bar.set_postfix(progress=f'{(i + 1) / len(dataloader) * 100:.3f}%', style_loss=f"{style_loss.item():.3f}", content_loss=f"{content_loss.item():.3f}", last_save_time=last_save_time) writer.add_scalar('Loss/content', running_content_loss / (i + 1), epoch * len(dataloader) + i) writer.add_scalar('Loss/style', running_style_loss / (i + 1), epoch * len(dataloader) + i) if datetime.now() - last_save_time > save_interval: last_save_time = datetime.now() writer.add_images('image_generated', denormalize(image_generated), epoch * len(dataloader) + i) save_model(model, save_path) # 'fast_style_transfer.pth' save_image(torch.cat((image_generated, content_img), 3), output_dir, f'{epoch}_{i}.jpg') writer.close() def process_images(images: Iterable[Image.Image], transform, model, device) -> List[Image.Image]: images = torch.stack([transform(image) for image in images]).to(device) model.to(device) batch_generated = model(images) batch_generated = denormalize(batch_generated).detach().cpu() batch_generated = [transforms.ToPILImage()(image) for image in batch_generated] return batch_generated def process_video(video_path, output_path, transform, model, device, batch_size=4): # 打开视频文件 cap = cv2.VideoCapture(video_path) output_dir, filename = os.path.split(output_path) if output_dir and not os.path.exists(output_dir): os.makedirs(output_dir) # 获取视频属性 frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)) frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) fps = cap.get(cv2.CAP_PROP_FPS) total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) # 定义视频编码器和创建 VideoWriter 对象 fourcc = cv2.VideoWriter_fourcc(*'mp4v') out = cv2.VideoWriter(output_path, fourcc, fps, (frame_width, frame_height)) # 初始化 tqdm 进度条 pbar = tqdm(total=total_frames, desc="Processing Video") # 读取视频并批量处理 frames = [] while True: ret, frame = cap.read() if not ret: break frames.append(Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))) if len(frames) == batch_size: batch_generated = process_images(frames, transform, model, device) for gen_frame in batch_generated: gen = cv2.cvtColor(np.array(gen_frame), cv2.COLOR_RGB2BGR) cv2.imshow("output", gen) gen = cv2.resize(gen, (frame_width, frame_height)) out.write(gen) if cv2.waitKey(1) & 0xFF == ord('s'): break frames.clear() pbar.update(batch_size) if frames: batch_generated = process_images(frames, transform, model, device) for gen_frame in batch_generated: out.write(cv2.cvtColor(np.array(gen_frame), cv2.COLOR_RGB2BGR)) pbar.update(len(frames)) print(f'video successfully saved to: {output_path}') pbar.close() cap.release() out.release() def parse_args(): parser = argparse.ArgumentParser(description='实时风格迁移算法 - Fast Style Transfer') parser.add_argument('--mode', type=str, choices=['train', 'image', 'video'], default='video', help='运行模式: train, image, video') parser.add_argument('--image_size', type=int, nargs=2, default=[300, 450], help='图像尺寸 (高度, 宽度)') parser.add_argument('--style_image', type=str, default='./data/udnie.jpg', help='风格图像路径 (仅训练模式)') # data/train2017/default_class下包含了若干个图片,由于ImageFolder的格式需要,我们需要用类别文件夹包含图片,尽管类别标签没有使用到 parser.add_argument('--content_dataset', type=str, default='data/train2014', help='内容图像数据集路径 (仅训练模式)') parser.add_argument('--content_weight', type=float, default=1., help='内容权重(仅训练模式)') parser.add_argument('--style_weight', type=float, default=15., help='风格权重(仅训练模式)') parser.add_argument('--model_save_path', type=str, default=f'./models/{datetime.now().strftime("%Y-%m-%d %H:%M:%S")}.pth', help='模型保存路径 (训练模式)') parser.add_argument('--pretrained_model_path', type=str, help='预训练模型加载路径(仅训练模式)') parser.add_argument('--epochs', type=int, default=20, help='训练周期数(训练模式)') parser.add_argument('--save_interval', type=int, default=120, help='保存时间间隔(秒)(训练模式)') parser.add_argument('--learning_rate', type=float, default=0.001, help='学习率(训练模式)') parser.add_argument('--output_dir', type=str, default='./output/realtime_transfer', help='输出目录 (训练模式)') parser.add_argument('--log_dir', type=str, default='./runs/fast_style_transfer', help='tensorboard日志路径 (训练模式)') parser.add_argument('--model_path', type=str, default='./models/udnie.pth', help='模型加载路径(图像和视频模式)') parser.add_argument('--input_images_dir', type=str, help='输入图像根目录(仅图像模式)') parser.add_argument('--output_images_dir', type=str, default='./output/fast_style_transfer/image_generated.jpg', help='输出图像根目录 (仅图像模式)') parser.add_argument('--video_input', type=str, default='data/maigua.mp4', help='输入视频路径 (仅视频模式)') parser.add_argument('--video_output', type=str, default='output/videos/maigua.mp4', help='输出视频路径 (仅视频模式)') parser.add_argument('--batch_size', type=int, default=4, help='批次大小') return parser.parse_args() if __name__ == '__main__': args = parse_args() # print(args) device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') # ----------------路径参数---------------- # 内容特征层及loss加权系数 content_layers = {'5': 0.5, '10': 0.5} # 使用vgg的较浅层特征作为内容特征,保证生成图片内容结构相似性 # 风格特征层及loss加权系数 style_layers = {'0': 0.2, '5': 0.2, '10': 0.2, '19': 0.2, '28': 0.2} # 使用vgg不同深度的风格特征,生成风格更加层次丰富 transform = make_transform(size=args.image_size, normalize=True) # 图像变换 image_style = load_image(args.style_image, transform=transform).to(device) # 风格图像 vgg = VGG(content_layers, style_layers).to(device) # 特征提取网络,只用来提取特征,不进行训练 model = TransNet(input_size=args.image_size).to(device) # 内容生成网络,用于生成风格图片,进行训练 if args.mode != 'train' and getattr(args, 'model_path'): if not os.path.exists(args.model_path): raise FileNotFoundError(f'{args.model_path}不存在!') model.load_state_dict(torch.load(args.model_path)) elif args.mode == 'train' and getattr(args, 'pretrained_model_path') and os.path.exists(args.pretrained_model_path): if not os.path.exists(args.pretrained_model_path): raise FileNotFoundError(f'{args.pretrained_model_path}不存在!') model.load_state_dict(torch.load(args.pretrained_model_path)) if args.mode == 'train': # 训练模式 # 使用大规模内容图像数据训练快速图像风格迁移网络,比如COCO2017数据集 train(model, vgg, args.learning_rate, args.epochs, args.batch_size, args.style_weight, args.content_weight, style_layers, content_layers, device, transform, image_style, args.content_dataset, args.model_save_path, args.output_dir, log_dir=args.log_dir, save_interval=timedelta(seconds=args.save_interval)) elif args.mode == 'image': # 使用训练好的风格迁移模型演示批量处理图片 if not os.path.exists(args.output_images_dir): os.makedirs(args.output_images_dir) for filename in tqdm(os.listdir(args.input_images_dir), desc='Processing Images'): try: filepath = os.path.join(args.input_images_dir, filename) images_generated = process_images([Image.open(filepath)], transform, model, device) images_generated[0].save(os.path.join(args.output_images_dir, filename)) except Exception as e: pass elif args.mode == 'video': # 视频处理模式 process_video(args.video_input, args.video_output, transform, model, device, batch_size=args.batch_size) else: raise ValueError("未知的运行模式")