File size: 13,028 Bytes
5106c86
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
"""
实时风格迁移算法 - Fast Style Transfer
作者: [您的名字或者用户名]
描述: 这个脚本实现了实时风格迁移,支持训练模型、处理一批图像以及处理视频文件。它基于PyTorch实现,并使用VGG网络提取风格特征。

使用方法:
    - 训练模式: python fast_style_transfer.py --mode train --style_image ./data/udnie.jpg --content_dataset data/train2017 --model_save_path ./models/new_model.pth --epochs 10
    - 图像处理模式: python fast_style_transfer.py --mode image --input_images_dir ./data/train2014/default_class --output_images_dir ./output/images_generated --model_path ./models/udnie.pth
    - 视频处理模式: python fast_style_transfer.py --mode video --video_input data/maigua.mp4 --video_output output/videos/maigua_udnie.mp4 --model_path ./models/udnie.pth
"""

import argparse
import os
from datetime import datetime, timedelta
from typing import List, Iterable

from PIL import Image

import torch
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from torchvision import transforms
from tqdm import tqdm
import cv2
import numpy as np

from models import VGG, TransNet
from datasets import COCODataset
from utils import load_image, save_image, make_transform, save_model, calculate_style_loss, calculate_content_loss, \
    denormalize


def train(model,
          vgg,
          lr, epochs, batch_size, style_weight, content_weight, style_layers, content_layers,
          device, transform,
          image_style,
          content_dataset_root,
          save_path, output_dir,
          log_dir='./runs/fast_style_transfer',
          save_interval=timedelta(seconds=120)):
    """
    Train the feed-forward style-transfer network ``model`` against one fixed style image.

    For every batch of content images the generated image is passed through ``vgg``; the
    loss is a weighted sum of a style loss (against the features of ``image_style``) and a
    content loss (against the features of the original content batch). Running per-epoch
    averages of both losses are logged to TensorBoard. At most once per ``save_interval``
    the model is saved and a preview (generated image next to its input) is written; the
    model is saved once more after the last epoch so no training is lost.

    :param model: generator network being trained (the only network passed to the optimizer)
    :param vgg: feature-extraction network (e.g. VGG19), called as
        ``content_feats, style_feats = vgg(x)``; both results are indexed by layer name.
        It is only used to compute losses; its parameters are not optimized here.
    :param lr: Adam learning rate
    :param epochs: number of passes over the content dataset
    :param batch_size: DataLoader batch size
    :param style_weight: global style-loss weight
    :param content_weight: global content-loss weight
    :param style_layers: dict mapping style layer name -> per-layer loss weight
    :param content_layers: dict mapping content layer name -> per-layer loss weight
    :param device: device the content batches are moved to
    :param transform: image transform handed to ``COCODataset``
    :param image_style: preprocessed style image tensor (expected to already be on ``device``)
    :param content_dataset_root: root folder of the content image dataset
    :param save_path: path passed to ``save_model`` for the weights
    :param output_dir: directory passed to ``save_image`` for intermediate previews
    :param log_dir: TensorBoard log directory
    :param save_interval: minimum wall-clock time between periodic saves
    :return: None
    """
    writer = SummaryWriter(log_dir)
    try:
        optimizer = torch.optim.Adam(model.parameters(), lr=lr)  # only the generator is optimized
        dataset = COCODataset(content_dataset_root, transform=transform)
        dataloader = DataLoader(dataset, batch_size, shuffle=True)
        num_batches = len(dataloader)

        # The style targets are constant for the whole run: compute them once and without an
        # autograd graph, so later backward() calls never try to traverse a freed graph.
        with torch.no_grad():
            _, style_features = vgg(image_style)

        p_bar = tqdm(range(epochs))
        # Backdate the timestamp so the very first iteration triggers a save.
        last_save_time = datetime.now() - save_interval

        for epoch in p_bar:
            running_content_loss, running_style_loss = 0.0, 0.0
            for i, content_img in enumerate(dataloader):
                step = epoch * num_batches + i
                content_img = content_img.to(device)
                image_generated = model(content_img)  # the generator only sees the content image
                generated_content, generated_style = vgg(image_generated)
                style_loss = sum(
                    style_weight * style_layers[name] * calculate_style_loss(style_features[name], gen_style)
                    for name, gen_style in generated_style.items())

                # Content targets are fixed loss inputs; no gradient is needed through them.
                with torch.no_grad():
                    content_features, _ = vgg(content_img)
                content_loss = sum(
                    content_weight * content_layers[name] * calculate_content_loss(content_features[name], gen_content)
                    for name, gen_content in generated_content.items())

                total_loss = style_loss + content_loss

                optimizer.zero_grad()
                total_loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1)  # gradient clipping
                optimizer.step()

                running_content_loss += content_loss.item()
                running_style_loss += style_loss.item()

                p_bar.set_postfix(progress=f'{(i + 1) / num_batches * 100:.3f}%',
                                  style_loss=f"{style_loss.item():.3f}",
                                  content_loss=f"{content_loss.item():.3f}",
                                  last_save_time=last_save_time)

                # Running averages over the current epoch.
                writer.add_scalar('Loss/content', running_content_loss / (i + 1), step)
                writer.add_scalar('Loss/style', running_style_loss / (i + 1), step)

                if datetime.now() - last_save_time > save_interval:
                    last_save_time = datetime.now()
                    writer.add_images('image_generated', denormalize(image_generated), step)
                    save_model(model, save_path)
                    # Preview: generated image and its content input side by side (concatenated on dim 3).
                    save_image(torch.cat((image_generated, content_img), 3), output_dir, f'{epoch}_{i}.jpg')

        # Persist the final weights; periodic saves alone can miss the last save_interval of training.
        save_model(model, save_path)
    finally:
        writer.close()


def process_images(images: Iterable[Image.Image], transform, model, device) -> List[Image.Image]:
    """
    Stylize a batch of PIL images with ``model`` and return the results as PIL images.

    Every image is converted to RGB if it is in another mode (RGBA, grayscale, palette, ...),
    passed through ``transform`` and stacked into a single batch, so ``transform`` must yield
    tensors of identical shape. The forward pass runs under ``torch.no_grad()`` because this
    is inference only.

    :param images: iterable of PIL images
    :param transform: callable mapping a PIL image to a tensor
    :param model: style-transfer network; it is moved to ``device``
    :param device: device used for inference
    :return: generated PIL images, one per input and in input order; ``[]`` for empty input
    """
    tensors = [transform(image if image.mode == 'RGB' else image.convert('RGB')) for image in images]
    if not tensors:
        # torch.stack raises on an empty list; nothing to stylize.
        return []
    batch = torch.stack(tensors).to(device)
    model.to(device)
    with torch.no_grad():
        batch_generated = model(batch)
    batch_generated = denormalize(batch_generated).detach().cpu()
    to_pil = transforms.ToPILImage()
    return [to_pil(image) for image in batch_generated]


def process_video(video_path, output_path, transform, model, device, batch_size=4, show=True):
    """
    Stylize every frame of a video and write the result to ``output_path``.

    Frames are read with OpenCV, converted BGR -> RGB, stylized in groups of ``batch_size``
    via ``process_images``, converted back to BGR, resized to the source resolution and
    written with the ``mp4v`` codec at the source FPS. The final, partial batch goes through
    exactly the same path as full batches.

    When ``show`` is True, each generated frame is also displayed in an OpenCV window.
    Pressing ``s`` in that window stops processing after the current batch has been
    written; everything written so far is kept.

    :param video_path: path of the input video
    :param output_path: path of the output video; its parent directory is created if missing
    :param transform: transform applied to each frame (see ``process_images``)
    :param model: style-transfer network
    :param device: inference device
    :param batch_size: number of frames stylized per forward pass (must be >= 1)
    :param show: display frames while processing (requires an OpenCV build with GUI support)
    :raises ValueError: if ``batch_size`` < 1
    :raises OSError: if the input cannot be opened or the output writer cannot be created
    """
    if batch_size < 1:
        raise ValueError(f'batch_size must be >= 1, got {batch_size}')

    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        cap.release()
        raise OSError(f'cannot open video: {video_path}')

    output_dir = os.path.dirname(output_path)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    # Source video properties; the output keeps the same resolution and FPS.
    frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    fps = cap.get(cv2.CAP_PROP_FPS)
    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))  # container-reported; may be approximate

    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    out = cv2.VideoWriter(output_path, fourcc, fps, (frame_width, frame_height))
    if not out.isOpened():
        cap.release()
        raise OSError(f'cannot create video writer: {output_path}')

    pbar = tqdm(total=total_frames, desc="Processing Video")

    def write_batch(batch):
        """Stylize ``batch``, write every frame, and return True if the user pressed 's'."""
        stop_requested = False
        for gen_frame in process_images(batch, transform, model, device):
            gen = cv2.cvtColor(np.array(gen_frame), cv2.COLOR_RGB2BGR)
            if show:
                cv2.imshow("output", gen)
            # VideoWriter silently drops frames whose size differs from the one it was opened with.
            out.write(cv2.resize(gen, (frame_width, frame_height)))
            if show and cv2.waitKey(1) & 0xFF == ord('s'):
                stop_requested = True  # keep writing this batch, stop afterwards
        pbar.update(len(batch))
        return stop_requested

    frames = []
    stopped = False
    try:
        while not stopped:
            ret, frame = cap.read()
            if not ret:
                break
            frames.append(Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)))
            if len(frames) == batch_size:
                stopped = write_batch(frames)
                frames.clear()

        # Flush the trailing partial batch (empty when stopped early, since frames was cleared).
        if frames:
            write_batch(frames)
    finally:
        pbar.close()
        cap.release()
        out.release()
        if show:
            try:
                cv2.destroyAllWindows()
            except cv2.error:
                pass  # best effort: headless OpenCV builds do not implement window functions
    print(f'video successfully saved to: {output_path}')


def parse_args():
    """
    Parse the command-line options for all three modes (train / image / video).

    Options marked "训练模式" are only used for training, "图像模式" only for image mode, and
    "视频模式" only for video mode. ``--batch_size`` is shared by training and video processing.

    :return: ``argparse.Namespace`` with the parsed options
    """
    parser = argparse.ArgumentParser(description='实时风格迁移算法 - Fast Style Transfer')
    parser.add_argument('--mode', type=str, choices=['train', 'image', 'video'], default='video',
                        help='运行模式: train, image, video')
    parser.add_argument('--image_size', type=int, nargs=2, default=[300, 450], help='图像尺寸 (高度, 宽度)')
    parser.add_argument('--style_image', type=str, default='./data/udnie.jpg', help='风格图像路径 (仅训练模式)')
    # The content dataset folder (e.g. data/train2014/default_class) must wrap the images in a
    # class sub-folder, as ImageFolder-style datasets require, even though the label is unused.
    parser.add_argument('--content_dataset', type=str, default='data/train2014', help='内容图像数据集路径 (仅训练模式)')
    parser.add_argument('--content_weight', type=float, default=1., help='内容权重(仅训练模式)')
    parser.add_argument('--style_weight', type=float, default=15., help='风格权重(仅训练模式)')
    # Timestamped default file name; avoid ':' and spaces, which are invalid or awkward in file names (e.g. on Windows).
    parser.add_argument('--model_save_path', type=str,
                        default=f'./models/{datetime.now().strftime("%Y-%m-%d_%H-%M-%S")}.pth',
                        help='模型保存路径 (训练模式)')

    parser.add_argument('--pretrained_model_path', type=str, help='预训练模型加载路径(仅训练模式)')
    parser.add_argument('--epochs', type=int, default=20, help='训练周期数(训练模式)')
    parser.add_argument('--save_interval', type=int, default=120, help='保存时间间隔(秒)(训练模式)')
    parser.add_argument('--learning_rate', type=float, default=0.001, help='学习率(训练模式)')
    parser.add_argument('--output_dir', type=str, default='./output/realtime_transfer', help='输出目录 (训练模式)')
    parser.add_argument('--log_dir', type=str, default='./runs/fast_style_transfer', help='tensorboard日志路径 (训练模式)')

    parser.add_argument('--model_path', type=str, default='./models/udnie.pth', help='模型加载路径(图像和视频模式)')
    parser.add_argument('--input_images_dir', type=str, help='输入图像根目录(仅图像模式)')
    # This is a directory (it is created with makedirs), so the default must not look like a file name.
    parser.add_argument('--output_images_dir', type=str, default='./output/fast_style_transfer/images_generated',
                        help='输出图像根目录 (仅图像模式)')

    parser.add_argument('--video_input', type=str, default='data/maigua.mp4', help='输入视频路径 (仅视频模式)')
    parser.add_argument('--video_output', type=str, default='output/videos/maigua.mp4',
                        help='输出视频路径 (仅视频模式)')

    parser.add_argument('--batch_size', type=int, default=4, help='批次大小')

    return parser.parse_args()


if __name__ == '__main__':
    args = parse_args()
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Content layers and their loss weights: shallower VGG features keep the generated
    # image structurally close to the content image.
    content_layers = {'5': 0.5, '10': 0.5}
    # Style layers and their loss weights: features from several VGG depths give a richer,
    # multi-scale style.
    style_layers = {'0': 0.2, '5': 0.2, '10': 0.2, '19': 0.2, '28': 0.2}

    transform = make_transform(size=args.image_size, normalize=True)  # image preprocessing
    model = TransNet(input_size=args.image_size).to(device)  # generator network; the one that is trained

    # Checkpoint to load: --model_path for image/video mode, optional --pretrained_model_path for training.
    # A path that was given but does not exist is an error in every mode.
    checkpoint = args.model_path if args.mode != 'train' else args.pretrained_model_path
    if checkpoint:
        if not os.path.exists(checkpoint):
            raise FileNotFoundError(f'{checkpoint}不存在!')
        # map_location allows loading a checkpoint saved on GPU on a CPU-only machine.
        model.load_state_dict(torch.load(checkpoint, map_location=device))

    if args.mode == 'train':
        # Train the fast style-transfer network on a large content dataset such as COCO 2017.
        # The style image and the VGG feature extractor are only needed here.
        image_style = load_image(args.style_image, transform=transform).to(device)  # style image
        vgg = VGG(content_layers, style_layers).to(device)  # feature extractor only; not optimized
        train(model, vgg, args.learning_rate, args.epochs, args.batch_size, args.style_weight, args.content_weight,
              style_layers,
              content_layers, device, transform, image_style, args.content_dataset, args.model_save_path,
              args.output_dir,
              log_dir=args.log_dir,
              save_interval=timedelta(seconds=args.save_interval))

    elif args.mode == 'image':
        # Stylize every file in a directory with the trained model.
        if not args.input_images_dir:
            raise ValueError('--input_images_dir is required in image mode')
        os.makedirs(args.output_images_dir, exist_ok=True)
        with torch.no_grad():
            for filename in tqdm(sorted(os.listdir(args.input_images_dir)), desc='Processing Images'):
                filepath = os.path.join(args.input_images_dir, filename)
                if not os.path.isfile(filepath):
                    continue  # skip sub-directories and other non-files
                try:
                    with Image.open(filepath) as image:  # close the file handle after use
                        images_generated = process_images([image], transform, model, device)
                    images_generated[0].save(os.path.join(args.output_images_dir, filename))
                except Exception as e:
                    # Best-effort batch processing: report the failure instead of hiding it.
                    tqdm.write(f'skipped {filepath}: {e}')

    elif args.mode == 'video':
        # Video processing mode.
        with torch.no_grad():
            process_video(args.video_input, args.video_output, transform, model, device,
                          batch_size=args.batch_size)

    else:
        raise ValueError("未知的运行模式")