import functools
import os

import gradio as gr
import torch
import torchvision.transforms as transforms
from PIL import Image

import mynet

# Converts a PIL image into a float tensor of shape (C, H, W) scaled to [0, 1].
transform = transforms.Compose([
    transforms.ToTensor()
])

# Checkpoint file names (looked up inside the "models" directory) offered in
# the model dropdown. A list rather than a set keeps the menu order stable
# across runs (set iteration order of str depends on hash randomization).
models = [
    "kodak_gold_200.pth",
    "fuji_color_200.pth",
    "kodak_E100.pth",
]

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Largest accepted input size; both limits must hold (see the upload label).
MAX_WIDTH = 2457
MAX_HEIGHT = 1843


def _check_image_size(image):
    """Raise ``gr.Error`` if *image* is wider than MAX_WIDTH or taller than MAX_HEIGHT."""
    width, height = image.size
    if width > MAX_WIDTH or height > MAX_HEIGHT:
        raise gr.Error("Image too large!")


def resize_to_nearest_multiple_of_32(image):
    """Resize *image* so both sides are rounded up to a multiple of 32.

    The multiple-of-32 constraint is presumably required by the network's
    down/up-sampling stages — TODO confirm against ``mynet``.

    Args:
        image: A PIL image.

    Returns:
        A new PIL image whose width and height are the smallest multiples of
        32 that are >= the original size (bilinear resampling, so the aspect
        ratio may shift slightly).

    Raises:
        gr.Error: If the image exceeds MAX_WIDTH or MAX_HEIGHT.
    """
    _check_image_size(image)

    width, height = image.size
    # Round each dimension up to the nearest multiple of 32.
    new_width = (width + 31) // 32 * 32
    new_height = (height + 31) // 32 * 32

    return image.resize((new_width, new_height), Image.BILINEAR)


@functools.lru_cache(maxsize=len(models))
def load_model(model_path):
    """Load a generator checkpoint, move it to ``device`` and set eval mode.

    Results are cached per path so repeated requests do not re-read the
    checkpoint from disk.

    ``.pth`` files are treated as a bare ``state_dict`` for
    ``mynet.ResNetGenerator(3, 3)``; any other file is treated as a dict with
    a ``'model_state_dict'`` entry for ``mynet.FilmStyleTransfer``.

    NOTE(review): ``torch.load`` unpickles the file and can execute arbitrary
    code. Only load trusted checkpoints. Consider ``weights_only=True`` if the
    installed torch version supports it.
    """
    if model_path.endswith(".pth"):
        model = mynet.ResNetGenerator(3, 3)
        state_dict = torch.load(model_path, map_location=device)
    else:
        model = mynet.FilmStyleTransfer()
        checkpoint = torch.load(model_path, map_location=device)
        state_dict = checkpoint['model_state_dict']
    model.load_state_dict(state_dict)
    model = model.to(device)
    model.eval()
    print("[INFO] Open model successfully!")
    return model


def process_images(image, model_choice):
    """Run the selected film-style model on *image*.

    Args:
        image: PIL image from the Gradio upload widget (``None`` if empty).
        model_choice: One of the file names in ``models``.

    Returns:
        The generated PIL image. Its size is the input size rounded up to
        multiples of 32; it is not resized back to the original dimensions.

    Raises:
        gr.Error: If no image/model is selected, the model name is unknown,
            or the image is too large.
    """
    if image is None:
        raise gr.Error("Please upload an image first.")
    # Only accept known checkpoints. This also blocks path traversal into
    # torch.load (which unpickles) via a crafted API request.
    if model_choice not in models:
        raise gr.Error("Please select a valid model.")
    _check_image_size(image)

    # Uploads may be RGBA / grayscale / palette; the model takes 3 channels.
    image = image.convert("RGB")
    image = resize_to_nearest_multiple_of_32(image)
    batch = transform(image).unsqueeze(0).to(device)

    model_path = os.path.join("models", model_choice)
    print(model_path)
    model = load_model(model_path)

    with torch.no_grad():
        # The model is assumed to return (output, extra); the second value is unused.
        output, _ = model(batch)

    # Drop only the batch axis, then clamp into the valid image range.
    output = output.squeeze(0).cpu().clamp(0, 1)
    return transforms.ToPILImage()(output)


def main():
    """Build and launch the DigitalFilm Gradio interface."""
    with gr.Blocks(title="DigitalFilm App") as demo:
        image_input = gr.Image(
            type="pil",
            label=f"Upload Image(w <= {MAX_WIDTH} & h <= {MAX_HEIGHT})",
        )
        model_choice = gr.Dropdown(models, label="Select Model", allow_custom_value=False)
        image_output = gr.Image(type="pil", label="Generated Image")
        run_button = gr.Button("Run Model")
        run_button.click(process_images, inputs=[image_input, model_choice], outputs=image_output)
    demo.launch()


if __name__ == "__main__":
    main()