# StyleFusion / app.py
# Uploaded by escapist413 ("Upload existing project files", commit b4877a2)
import argparse
import os
from datetime import datetime, timedelta
from typing import List, Iterable
from PIL import Image
import torch
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from torchvision import transforms
from tqdm import tqdm
import cv2
import numpy as np
from models import VGG, TransNet
from datasets import COCODataset
from utils import load_image, save_image, make_transform, save_model, calculate_style_loss, calculate_content_loss, \
denormalize
import gradio as gr
from typing import List, Iterable
from PIL import Image
import torch
from torchvision import transforms
# Select GPU when available, otherwise fall back to CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# ---------------- path / model parameters ----------------
# Content feature layers and their loss weights: shallow VGG layers are used
# as content features to keep the generated image structurally similar.
content_layers = {'5': 0.5, '10': 0.5}
# Style feature layers and their loss weights: VGG features at several depths
# give the generated style richer, multi-scale texture.
style_layers = {'0': 0.2, '5': 0.2, '10': 0.2, '19': 0.2, '28': 0.2}
# Image transform — resize to 300x450 and normalize.
# NOTE(review): content/style layer weights are training-time loss settings;
# presumably unused at inference here — verify against VGG/TransNet usage.
transform = make_transform(size=(300, 450), normalize=True)
# Style reference image (moved to the compute device).
image_style = load_image('./data/udnie.jpg', transform=transform).to(device)
# Feature-extraction network — only used to extract features, never trained.
vgg = VGG(content_layers, style_layers).to(device)
# Image-generation network — produces the stylized image.
model = TransNet(input_size=(300, 450)).to(device)
# Load pretrained weights; map_location keeps CPU-only machines working.
model.load_state_dict(torch.load('./models/udnie.pth', map_location=device))
def process_images(image):
    """Apply the loaded style-transfer model to one input image.

    Args:
        image: input image as supplied by the Gradio ``gr.Image(type="pil")``
            component (a ``PIL.Image``).

    Returns:
        The stylized image as a ``PIL.Image``.
    """
    # Normalize/resize with the module-level transform and move to the device.
    tensor = transform(image).to(device)
    # Inference only: no_grad avoids building the autograd graph, saving
    # memory and time. (The redundant per-call model.to(device) was removed —
    # the model is already moved to the device at module load.)
    with torch.no_grad():
        generated = model(tensor)
    # Undo the normalization applied by `transform` before display.
    generated = denormalize(generated).detach().cpu()
    # NOTE(review): assumes the pipeline yields a (C, H, W) tensor here —
    # ToPILImage rejects a batched (N, C, H, W) input; confirm TransNet's
    # output shape for a single unbatched image.
    return transforms.ToPILImage()(generated)
# Build the Gradio interface: one image in (PIL), one stylized image out.
demo = gr.Interface(
    fn=process_images,
    inputs=gr.Image(type="pil", label="输入图像"),
    outputs=gr.Image(type="pil", label="生成图像")
)
# Launch the Gradio UI.
# NOTE(review): share=True opens a public tunnel URL to this app — on
# HF Spaces this flag is unnecessary; confirm it is intended.
demo.launch(share=True)