Spaces:
Runtime error
Runtime error
| import argparse | |
| import os | |
| from datetime import datetime, timedelta | |
| from typing import List, Iterable | |
| from PIL import Image | |
| import torch | |
| from torch.utils.data import DataLoader | |
| from torch.utils.tensorboard import SummaryWriter | |
| from torchvision import transforms | |
| from tqdm import tqdm | |
| import cv2 | |
| import numpy as np | |
| from models import VGG, TransNet | |
| from datasets import COCODataset | |
| from utils import load_image, save_image, make_transform, save_model, calculate_style_loss, calculate_content_loss, \ | |
| denormalize | |
| import gradio as gr | |
| from typing import List, Iterable | |
| from PIL import Image | |
| import torch | |
| from torchvision import transforms | |
| device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') | |
| # ----------------路径参数---------------- | |
| # 内容特征层及loss加权系数 | |
| content_layers = {'5': 0.5, '10': 0.5} # 使用vgg的较浅层特征作为内容特征,保证生成图片内容结构相似性 | |
| # 风格特征层及loss加权系数 | |
| style_layers = {'0': 0.2, '5': 0.2, '10': 0.2, '19': 0.2, '28': 0.2} # 使用vgg不同深度的风格特征,生成风格更加层次丰富 | |
| transform = make_transform(size=(300, 450), normalize=True) # 图像变换 | |
| image_style = load_image('./data/udnie.jpg', transform=transform).to(device) # 风格图像 | |
| vgg = VGG(content_layers, style_layers).to(device) # 特征提取网络,只用来提取特征,不进行训练 | |
| model = TransNet(input_size=(300, 450)).to(device) # 内容生成网络,用于生成风格图片,进行训练 | |
| model.load_state_dict(torch.load('./models/udnie.pth', map_location=device)) | |
| def process_images(image) : | |
| image = transform(image).to(device) | |
| model.to(device) | |
| batch_generated = model(image) | |
| batch_generated = denormalize(batch_generated).detach().cpu() | |
| batch_generated = transforms.ToPILImage()(batch_generated) | |
| return batch_generated | |
| # 创建 Gradio 接口 | |
| demo = gr.Interface( | |
| fn=process_images, | |
| inputs=gr.Image(type="pil", label="输入图像"), | |
| outputs=gr.Image(type="pil", label="生成图像") | |
| ) | |
| # 启动 Gradio 界面 | |
| demo.launch(share=True) | |