# app.py — uploaded by Richards-Sheehy-sudo (commit 9421fed, verified).
import gradio as gr
import torch
import torchvision.transforms as transforms
import os
from PIL import Image
import mynet
# Preprocessing applied to each uploaded PIL image before inference:
# ToTensor converts it to a float tensor (C, H, W) scaled to [0, 1].
transform = transforms.Compose([
    transforms.ToTensor(),
])
# Checkpoint filenames (looked up under the "models" directory) offered in the
# UI dropdown. A list rather than a set so the dropdown order is stable.
models = [
    "kodak_gold_200.pth",
    "fuji_color_200.pth",
    "kodak_E100.pth",
]
# Run inference on the default CUDA device when one is available, else the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
# Largest accepted upload size in pixels; each side is checked separately.
MAX_WIDTH = 2457
MAX_HEIGHT = 1843
def resize_to_nearest_multiple_of_32(image):
    """Stretch *image* so both sides are rounded up to a multiple of 32.

    The network presumably needs spatial dimensions divisible by 32 (TODO
    confirm against mynet's downsampling depth). The image is resized with
    bilinear interpolation rather than padded, so it is slightly stretched.

    Args:
        image: a PIL image.

    Returns:
        A new PIL image whose width and height are the smallest multiples of
        32 that are >= the original width and height.

    Raises:
        gr.Error: if the width exceeds MAX_WIDTH or the height exceeds
            MAX_HEIGHT.
    """
    width, height = image.size
    # Either side over its limit makes the image too large. gr.Error only
    # takes effect when it is raised.
    if width > MAX_WIDTH or height > MAX_HEIGHT:
        raise gr.Error("Image too large!")
    # Round each side up to the next multiple of 32 (unchanged if already one).
    new_width = (width + 31) // 32 * 32
    new_height = (height + 31) // 32 * 32
    return image.resize((new_width, new_height), Image.BILINEAR)
def load_model(model_path):
    """Build a generator network, load its weights from *model_path*, move it to ``device``.

    Paths ending in ``.pth`` are loaded as a bare ``state_dict`` into
    ``mynet.ResNetGenerator(3, 3)`` (presumably 3 input / 3 output channels —
    confirm in mynet). Any other path is loaded as a checkpoint dict whose
    ``'model_state_dict'`` entry holds the weights of
    ``mynet.FilmStyleTransfer()``.

    Args:
        model_path: path to the weights file.

    Returns:
        The model with weights loaded, on ``device``. Calling ``.eval()`` is
        left to the caller.
    """
    # NOTE(review): torch.load unpickles arbitrary objects; only point this at
    # trusted files (consider weights_only=True on torch >= 1.13).
    if model_path.endswith(".pth"):
        net = mynet.ResNetGenerator(3, 3)
        weights = torch.load(model_path, map_location=device)
        net.load_state_dict(weights)
    else:
        net = mynet.FilmStyleTransfer()
        checkpoint = torch.load(model_path, map_location=device)
        # Checkpoint format: the weights are nested under 'model_state_dict'.
        net.load_state_dict(checkpoint['model_state_dict'])
    net = net.to(device)
    print("[INFO] Open model successfully!")
    return net
def process_images(image, model_choice):
    """Apply the selected film-style model to an uploaded image (Gradio click handler).

    Args:
        image: the uploaded PIL image, or None if nothing was uploaded.
        model_choice: checkpoint filename from the dropdown, resolved under the
            ``models`` directory; None/empty if nothing was selected.

    Returns:
        The stylised PIL image, resized back to the uploaded image's size.

    Raises:
        gr.Error: if no image or model was provided, or if the image's width
            exceeds MAX_WIDTH or its height exceeds MAX_HEIGHT.
    """
    # Clicking "Run Model" before uploading/selecting passes None here.
    if image is None:
        raise gr.Error("Please upload an image.")
    if not model_choice:
        raise gr.Error("Please select a model.")
    width, height = image.size
    # Either side over its limit is too large (matches the upload label).
    if width > MAX_WIDTH or height > MAX_HEIGHT:
        raise gr.Error("Image too large!")
    # The models presumably expect 3-channel input (ResNetGenerator(3, 3));
    # normalise RGBA / greyscale / palette uploads to RGB.
    image = image.convert("RGB")
    tensor = transform(resize_to_nearest_multiple_of_32(image))
    model_path = os.path.join("models", model_choice)
    print(model_path)
    model = load_model(model_path)
    model.eval()
    with torch.no_grad():
        batch = tensor.unsqueeze(0).to(device)
        # The model returns a pair; only the first element (the image) is used.
        output, _ = model(batch)
    output = output.squeeze().cpu().clamp(0, 1)
    result = transforms.ToPILImage()(output)
    # Undo the multiple-of-32 stretch so the result matches the upload's size.
    if result.size != (width, height):
        result = result.resize((width, height), Image.BILINEAR)
    return result
def main():
    """Build the DigitalFilm Gradio interface and launch it."""
    with gr.Blocks(title="DigitalFilm App") as app:
        # Components are laid out in creation order: upload, picker, output, button.
        upload = gr.Image(type="pil", label="Upload Image(w <= 2457 & h <= 1843)")
        picker = gr.Dropdown(models, label="Select Model", allow_custom_value=False)
        result = gr.Image(type="pil", label="Generated Image")
        run = gr.Button("Run Model")
        # Wire the button: process_images(upload, picker) -> result.
        run.click(process_images, inputs=[upload, picker], outputs=result)
    # Start the web server once the layout is complete.
    app.launch()


if __name__ == "__main__":
    main()