# Source: Hugging Face Spaces file view (Space status: Runtime error; file size: 13,028 bytes; revision 5106c86)
"""
实时风格迁移算法 - Fast Style Transfer
作者: [您的名字或者用户名]
描述: 这个脚本实现了实时风格迁移,支持训练模型、处理一批图像以及处理视频文件。它基于PyTorch实现,并使用VGG网络提取风格特征。
使用方法:
- 训练模式: python fast_style_transfer.py --mode train --style_image ./data/udnie.jpg --content_dataset data/train2017 --model_save_path ./models/new_model.pth --epochs 10
- 图像处理模式: python fast_style_transfer.py --mode image --input_images_dir ./data/train2014/default_class --output_images_dir ./output/images_generated --model_path ./models/udnie.pth
- 视频处理模式: python fast_style_transfer.py --mode video --video_input data/maigua.mp4 --video_output output/videos/maigua_udnie.mp4 --model_path ./models/udnie.pth
"""
import argparse
import os
from datetime import datetime, timedelta
from typing import List, Iterable
from PIL import Image
import torch
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from torchvision import transforms
from tqdm import tqdm
import cv2
import numpy as np
from models import VGG, TransNet
from datasets import COCODataset
from utils import load_image, save_image, make_transform, save_model, calculate_style_loss, calculate_content_loss, \
denormalize
def train(model,
          vgg,
          lr, epochs, batch_size, style_weight, content_weight, style_layers, content_layers,
          device, transform,
          image_style,
          content_dataset_root,
          save_path, output_dir,
          log_dir='./runs/fast_style_transfer',
          save_interval=timedelta(seconds=120)):
    """Train the image-generation network against a single, fixed style image.

    For every content batch the generated image is passed through ``vgg`` and
    compared with (a) the style features of ``image_style`` and (b) the content
    features of the input batch. The weighted sum of both losses is minimised
    with Adam; gradients are clipped to a total norm of 1.

    :param model: generator network being optimised (its parameters are the
        only ones handed to the optimizer)
    :param vgg: feature extractor; called as ``vgg(x)`` and expected to return
        ``(content_features, style_features)``, both mappings keyed by layer name
    :param lr: Adam learning rate
    :param epochs: number of passes over the content dataset
    :param batch_size: DataLoader batch size
    :param style_weight: global multiplier of the style loss
    :param content_weight: global multiplier of the content loss
    :param style_layers: mapping ``layer name -> weight`` for the style loss
    :param content_layers: mapping ``layer name -> weight`` for the content loss
    :param device: device the content batches are moved to
    :param transform: transform handed to ``COCODataset``
    :param image_style: style image tensor fed to ``vgg`` (presumably already
        batched and on ``device`` -- confirm against the caller)
    :param content_dataset_root: root folder of the content images
    :param save_path: checkpoint path passed to ``save_model``
    :param output_dir: folder for intermediate preview images
    :param log_dir: TensorBoard log directory
    :param save_interval: minimum wall-clock time between two periodic
        checkpoints/previews
    :return: None
    """
    writer = SummaryWriter(log_dir)
    try:
        optimizer = torch.optim.Adam(model.parameters(), lr=lr)
        dataset = COCODataset(content_dataset_root, transform=transform)
        dataloader = DataLoader(dataset, batch_size, shuffle=True)
        steps_per_epoch = len(dataloader)

        # The style targets are constants: compute them once without building an
        # autograd graph, so repeated backward passes never touch a freed graph.
        with torch.no_grad():
            _, style_features = vgg(image_style)

        p_bar = tqdm(range(epochs))
        # Start "one interval in the past" so the first step already checkpoints.
        last_save_time = datetime.now() - save_interval
        steps_done = 0
        for epoch in p_bar:
            running_content_loss, running_style_loss = 0.0, 0.0
            for i, content_img in enumerate(dataloader):
                global_step = epoch * steps_per_epoch + i
                content_img = content_img.to(device)
                # Only the content image is fed to the generator.
                image_generated = model(content_img)
                generated_content, generated_style = vgg(image_generated)
                style_loss = sum(
                    style_weight * style_layers[name] * calculate_style_loss(style_features[name], gen_style)
                    for name, gen_style in generated_style.items())
                # Content targets do not depend on the generator, so no graph is needed.
                with torch.no_grad():
                    content_features, _ = vgg(content_img)
                content_loss = sum(
                    content_weight * content_layers[name] * calculate_content_loss(content_features[name],
                                                                                   gen_content)
                    for name, gen_content in generated_content.items())
                total_loss = style_loss + content_loss

                optimizer.zero_grad()
                total_loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1)  # gradient clipping
                optimizer.step()
                steps_done += 1

                running_content_loss += content_loss.item()
                running_style_loss += style_loss.item()
                p_bar.set_postfix(progress=f'{(i + 1) / steps_per_epoch * 100:.3f}%',
                                  style_loss=f"{style_loss.item():.3f}",
                                  content_loss=f"{content_loss.item():.3f}",
                                  last_save_time=last_save_time)
                # Log the running mean of each loss within the current epoch.
                writer.add_scalar('Loss/content', running_content_loss / (i + 1), global_step)
                writer.add_scalar('Loss/style', running_style_loss / (i + 1), global_step)

                if datetime.now() - last_save_time > save_interval:
                    last_save_time = datetime.now()
                    writer.add_images('image_generated', denormalize(image_generated), global_step)
                    save_model(model, save_path)
                    # Generated and content batches are concatenated along dim 3
                    # (presumably width for NCHW tensors) for a side-by-side preview.
                    save_image(torch.cat((image_generated, content_img), 3), output_dir, f'{epoch}_{i}.jpg')

        # Periodic checkpoints are time based, so the steps after the last one
        # would otherwise be lost; persist the final weights if anything was trained.
        if steps_done:
            save_model(model, save_path)
    finally:
        # Flush and close the TensorBoard writer even if training is interrupted.
        writer.close()
def process_images(images: Iterable["Image.Image"], transform, model, device) -> List["Image.Image"]:
    """Stylise a batch of PIL images with ``model`` and return PIL images.

    Each image is passed through ``transform``, the results are stacked into one
    batch, run through ``model`` on ``device``, de-normalised and converted back
    to PIL images.

    :param images: iterable of PIL images; all transformed tensors must have the
        same shape because they are stacked into a single batch
    :param transform: callable turning one PIL image into a tensor
    :param model: generator network; it is moved to ``device`` as a side effect
    :param device: device used for inference
    :return: list of generated PIL images, in input order; an empty list when
        ``images`` is empty
    """
    tensors = [transform(image) for image in images]
    if not tensors:
        # torch.stack rejects an empty list; nothing to do for empty input.
        return []
    batch = torch.stack(tensors).to(device)
    model.to(device)
    # Inference only: skip autograd bookkeeping to save memory and time.
    with torch.no_grad():
        batch_generated = model(batch)
    batch_generated = denormalize(batch_generated).detach().cpu()
    to_pil = transforms.ToPILImage()
    return [to_pil(image) for image in batch_generated]
def process_video(video_path, output_path, transform, model, device, batch_size=4):
    """Stylise every frame of a video and write the result as an mp4v-encoded file.

    Frames are read with OpenCV, converted to RGB PIL images, stylised in
    batches via ``process_images``, resized back to the source resolution and
    written to ``output_path``. While processing, frames are previewed in an
    OpenCV window when a GUI backend is available; pressing ``s`` in that window
    stops processing after the current batch.

    :param video_path: path of the input video
    :param output_path: path of the output video; its parent folder is created
        if needed
    :param transform: per-frame transform forwarded to ``process_images``
    :param model: generator network forwarded to ``process_images``
    :param device: inference device forwarded to ``process_images``
    :param batch_size: number of frames stylised per forward pass (>= 1)
    :raises ValueError: if ``batch_size`` is smaller than 1
    :raises OSError: if the input video cannot be opened
    """
    if batch_size < 1:
        # With batch_size < 1 the "batch is full" test would never fire and the
        # whole video would be buffered in memory.
        raise ValueError(f'batch_size must be >= 1, got {batch_size}')

    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        cap.release()
        raise OSError(f'cannot open video: {video_path}')

    output_dir = os.path.dirname(output_path)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    # Source video properties, reused for the writer so the output matches the input.
    frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    fps = cap.get(cv2.CAP_PROP_FPS)
    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))

    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    out = cv2.VideoWriter(output_path, fourcc, fps, (frame_width, frame_height))
    pbar = tqdm(total=total_frames, desc="Processing Video")

    preview_enabled = True
    stop_requested = False

    def write_batch(batch_frames):
        """Stylise ``batch_frames`` and append them to the output video."""
        nonlocal preview_enabled, stop_requested
        for gen_frame in process_images(batch_frames, transform, model, device):
            gen = cv2.cvtColor(np.array(gen_frame), cv2.COLOR_RGB2BGR)
            if preview_enabled and not stop_requested:
                try:
                    cv2.imshow("output", gen)
                    if cv2.waitKey(1) & 0xFF == ord('s'):
                        stop_requested = True
                except cv2.error:
                    # Headless OpenCV builds raise here; keep converting without preview.
                    preview_enabled = False
            # The writer only accepts frames of the size it was created with.
            out.write(cv2.resize(gen, (frame_width, frame_height)))
        pbar.update(len(batch_frames))

    frames = []
    try:
        while not stop_requested:
            ret, frame = cap.read()
            if not ret:
                break
            frames.append(Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)))
            if len(frames) == batch_size:
                write_batch(frames)
                frames.clear()
        if frames:
            # Final partial batch.
            write_batch(frames)
    finally:
        pbar.close()
        cap.release()
        out.release()
        if preview_enabled:
            try:
                cv2.destroyAllWindows()
            except cv2.error:
                # No GUI backend available: there is no window to close.
                pass
    print(f'video successfully saved to: {output_path}')
def parse_args(argv=None):
    """Parse the command-line options of the script.

    :param argv: optional list of argument strings; when ``None`` (the default)
        ``sys.argv[1:]`` is parsed, exactly as before
    :return: ``argparse.Namespace`` with one attribute per option
    """
    parser = argparse.ArgumentParser(description='实时风格迁移算法 - Fast Style Transfer')
    parser.add_argument('--mode', type=str, choices=['train', 'image', 'video'], default='video',
                        help='运行模式: train, image, video')
    parser.add_argument('--image_size', type=int, nargs=2, default=[300, 450], help='图像尺寸 (高度, 宽度)')
    parser.add_argument('--style_image', type=str, default='./data/udnie.jpg', help='风格图像路径 (仅训练模式)')
    # The dataset root (e.g. data/train2014/default_class) holds the images inside a
    # class sub-folder: the ImageFolder layout requires one even though the class
    # label itself is never used.
    parser.add_argument('--content_dataset', type=str, default='data/train2014', help='内容图像数据集路径 (仅训练模式)')
    parser.add_argument('--content_weight', type=float, default=1., help='内容权重(仅训练模式)')
    parser.add_argument('--style_weight', type=float, default=15., help='风格权重(仅训练模式)')
    # The timestamp avoids ':' and spaces so the default file name is valid on
    # every platform (':' is not allowed in Windows file names).
    parser.add_argument('--model_save_path', type=str,
                        default=f'./models/{datetime.now().strftime("%Y-%m-%d_%H-%M-%S")}.pth',
                        help='模型保存路径 (训练模式)')
    parser.add_argument('--pretrained_model_path', type=str, help='预训练模型加载路径(仅训练模式)')
    parser.add_argument('--epochs', type=int, default=20, help='训练周期数(训练模式)')
    parser.add_argument('--save_interval', type=int, default=120, help='保存时间间隔(秒)(训练模式)')
    parser.add_argument('--learning_rate', type=float, default=0.001, help='学习率(训练模式)')
    parser.add_argument('--output_dir', type=str, default='./output/realtime_transfer', help='输出目录 (训练模式)')
    parser.add_argument('--log_dir', type=str, default='./runs/fast_style_transfer', help='tensorboard日志路径 (训练模式)')
    parser.add_argument('--model_path', type=str, default='./models/udnie.pth', help='模型加载路径(图像和视频模式)')
    parser.add_argument('--input_images_dir', type=str, help='输入图像根目录(仅图像模式)')
    # NOTE(review): this default looks like a file name but the script uses it as
    # a directory -- confirm whether a folder name was intended.
    parser.add_argument('--output_images_dir', type=str, default='./output/fast_style_transfer/image_generated.jpg',
                        help='输出图像根目录 (仅图像模式)')
    parser.add_argument('--video_input', type=str, default='data/maigua.mp4', help='输入视频路径 (仅视频模式)')
    parser.add_argument('--video_output', type=str, default='output/videos/maigua.mp4',
                        help='输出视频路径 (仅视频模式)')
    parser.add_argument('--batch_size', type=int, default=4, help='批次大小')
    return parser.parse_args(argv)
if __name__ == '__main__':
    # Script entry point: build the generator, optionally restore a checkpoint,
    # then dispatch to training, batch image processing or video processing.
    args = parse_args()
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Content layers and their loss weights: shallower VGG features are used as
    # content features to keep the structure of the generated image similar.
    content_layers = {'5': 0.5, '10': 0.5}
    # Style layers and their loss weights: features from several VGG depths are
    # combined for a richer style.
    style_layers = {'0': 0.2, '5': 0.2, '10': 0.2, '19': 0.2, '28': 0.2}

    transform = make_transform(size=args.image_size, normalize=True)  # image transform
    model = TransNet(input_size=args.image_size).to(device)  # generator network (the one being trained)

    # Training optionally resumes from --pretrained_model_path; the other modes
    # load the stylisation weights from --model_path.
    checkpoint_path = args.pretrained_model_path if args.mode == 'train' else args.model_path
    if checkpoint_path:
        if not os.path.exists(checkpoint_path):
            raise FileNotFoundError(f'{checkpoint_path}不存在!')
        # map_location lets checkpoints saved on a GPU load on CPU-only hosts.
        model.load_state_dict(torch.load(checkpoint_path, map_location=device))

    if args.mode == 'train':
        # Training mode: fit the generator on a large content dataset (e.g. COCO).
        # The style image and the VGG feature extractor are only needed here, so
        # the inference modes do not require the style image to exist.
        image_style = load_image(args.style_image, transform=transform).to(device)
        vgg = VGG(content_layers, style_layers).to(device)  # feature extractor, not trained
        train(model, vgg, args.learning_rate, args.epochs, args.batch_size, args.style_weight, args.content_weight,
              style_layers,
              content_layers, device, transform, image_style, args.content_dataset, args.model_save_path,
              args.output_dir,
              log_dir=args.log_dir,
              save_interval=timedelta(seconds=args.save_interval))
    elif args.mode == 'image':
        # Image mode: stylise every file found in the input folder.
        if not args.input_images_dir:
            # os.listdir(None) would silently list the current working directory.
            raise ValueError('--input_images_dir is required in image mode')
        os.makedirs(args.output_images_dir, exist_ok=True)
        for filename in tqdm(os.listdir(args.input_images_dir), desc='Processing Images'):
            filepath = os.path.join(args.input_images_dir, filename)
            try:
                with Image.open(filepath) as image:
                    # Force 3-channel input so grayscale/RGBA files are handled too.
                    images_generated = process_images([image.convert('RGB')], transform, model, device)
                images_generated[0].save(os.path.join(args.output_images_dir, filename))
            except Exception as e:
                # Best effort: report and skip files that cannot be processed
                # (non-image files, sub-folders, ...) instead of hiding the failure.
                tqdm.write(f'skipped {filepath}: {e}')
    elif args.mode == 'video':
        # Video mode.
        process_video(args.video_input, args.video_output, transform, model, device,
                      batch_size=args.batch_size)
    else:
        raise ValueError("未知的运行模式")
|