"""Gradio demo that applies a pre-trained style-transfer network to an image.

The script loads the ``TransNet`` weights stored in ``./models/udnie.pth`` and
exposes a single-image web interface: the uploaded picture is transformed,
passed through the network, de-normalised and returned as a PIL image.
"""

import argparse
import os
from datetime import datetime, timedelta
from typing import List, Iterable

import cv2
import gradio as gr
import numpy as np
import torch
from PIL import Image
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from torchvision import transforms
from tqdm import tqdm

from datasets import COCODataset
from models import VGG, TransNet
from utils import (
    load_image,
    save_image,
    make_transform,
    save_model,
    calculate_style_loss,
    calculate_content_loss,
    denormalize,
)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# ---------------- Path / model parameters ----------------
# Content feature layers and their loss weights.
# Relatively shallow VGG features are used as content features so that the
# generated image keeps a similar content structure.
content_layers = {'5': 0.5, '10': 0.5}
# Style feature layers and their loss weights.
# Style features from several VGG depths are used for a richer, more layered
# style.
style_layers = {'0': 0.2, '5': 0.2, '10': 0.2, '19': 0.2, '28': 0.2}

# Image transform applied to every input before it is fed to the network.
transform = make_transform(size=(300, 450), normalize=True)

# NOTE(review): `image_style` and `vgg` are not referenced anywhere else in
# this script; they are kept so that module-level behaviour is unchanged.
# Style image.
image_style = load_image('./data/udnie.jpg', transform=transform).to(device)
# Feature-extraction network: only used to extract features, never trained.
vgg = VGG(content_layers, style_layers).to(device)

# Generator network that produces the stylised image (the trained network).
model = TransNet(input_size=(300, 450)).to(device)
model.load_state_dict(torch.load('./models/udnie.pth', map_location=device))
# NOTE(review): the model is left in its default (training) mode, exactly as
# before; confirm whether `model.eval()` is required for TransNet's layers.


def process_images(image):
    """Stylise one image with the pre-trained generator network.

    Args:
        image: The PIL image delivered by the Gradio input component, or
            ``None`` when the user submitted without providing an image.

    Returns:
        The generated image as a PIL image, or ``None`` if ``image`` was
        ``None``.
    """
    if image is None:
        # Nothing was uploaded: return an empty result instead of crashing
        # inside the transform.
        return None

    # NOTE(review): the transformed tensor is passed to the model without an
    # explicit batch dimension, as in the original code; confirm that
    # `make_transform` / `TransNet` / `denormalize` expect this layout.
    tensor = transform(image).to(device)

    # Inference only: skip autograd bookkeeping. The model already lives on
    # `device` (moved once at module level), so no per-call `.to()` is needed.
    with torch.no_grad():
        generated = model(tensor)

    generated = denormalize(generated).cpu()
    return transforms.ToPILImage()(generated)


# Build the Gradio interface.
demo = gr.Interface(
    fn=process_images,
    inputs=gr.Image(type="pil", label="输入图像"),
    outputs=gr.Image(type="pil", label="生成图像"),
)

# Launch the Gradio UI only when run as a script, so that importing this
# module does not start a publicly shared server.
if __name__ == "__main__":
    demo.launch(share=True)