| | import gradio as gr
|
| | import torch
|
| | import torchvision.transforms as transforms
|
| | import os
|
| | from PIL import Image
|
| |
|
| | import mynet
|
| |
|
# Preprocessing pipeline: convert a PIL image to a float tensor in [0, 1],
# shaped (C, H, W). (The original had a duplicated assignment target:
# `transform = transform = ...`.)
transform = transforms.Compose([
    transforms.ToTensor(),
])
|
| |
|
# Available pretrained checkpoints (filenames looked up under "models/").
# A list — not a set, as originally written — keeps the dropdown order
# deterministic across runs.
models = [
    "kodak_gold_200.pth",
    "fuji_color_200.pth",
    "kodak_E100.pth",
]
|
| |
|
# Run on GPU when CUDA is available, otherwise fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Maximum accepted input dimensions, in pixels (matches the upload label).
MAX_WIDTH = 2457
MAX_HEIGHT = 1843
|
| |
|
def resize_to_nearest_multiple_of_32(image):
    """Resize *image* so each side rounds UP to the nearest multiple of 32.

    The generator network presumably requires dimensions divisible by 32
    (TODO: confirm against mynet). Bilinear resampling is used.

    Args:
        image: a PIL.Image.Image.

    Returns:
        The resized PIL image.

    Raises:
        gr.Error: if either dimension exceeds its maximum
            (MAX_WIDTH / MAX_HEIGHT).
    """
    width, height = image.size

    # Bug fixes vs. original: `or` instead of `and` (an image exceeding
    # either limit is invalid, per the "w <= 2457 & h <= 1843" UI label),
    # and the gr.Error is actually raised instead of being constructed
    # and silently discarded.
    if width > MAX_WIDTH or height > MAX_HEIGHT:
        raise gr.Error("Image too large!")

    # Round each dimension up to the next multiple of 32.
    new_width = (width + 31) // 32 * 32
    new_height = (height + 31) // 32 * 32

    return image.resize((new_width, new_height), Image.BILINEAR)
|
| |
|
def load_model(model_path):
    """Load a checkpoint from *model_path* and move the model to `device`.

    Paths ending in ".pth" are treated as a raw state dict for
    mynet.ResNetGenerator(3, 3); any other extension is treated as a full
    training checkpoint for mynet.FilmStyleTransfer containing a
    'model_state_dict' key.

    Args:
        model_path: filesystem path to the checkpoint file.

    Returns:
        The loaded model, on `device` (caller is responsible for .eval()).
    """
    # NOTE(review): torch.load unpickles arbitrary data — only load
    # checkpoints from trusted sources.
    if model_path.endswith(".pth"):  # idiomatic replacement for model_path[-4:] == ".pth"
        model = mynet.ResNetGenerator(3, 3)
        model.load_state_dict(torch.load(model_path, map_location=device))
    else:
        model = mynet.FilmStyleTransfer()
        checkpoint = torch.load(model_path, map_location=device)
        model.load_state_dict(checkpoint['model_state_dict'])
    model = model.to(device)
    print("[INFO] Open model successfully!")
    return model
|
| |
|
def process_images(image, model_choice):
    """Apply the selected film-style model to *image*.

    Args:
        image: input PIL.Image.Image uploaded by the user.
        model_choice: checkpoint filename selected in the dropdown,
            resolved relative to the "models" directory.

    Returns:
        The stylized output as a PIL image.

    Raises:
        gr.Error: if either image dimension exceeds its maximum.
    """
    width, height = image.size

    # Bug fix vs. original: `or` instead of `and` — an image exceeding
    # either limit must be rejected, matching the "w <= 2457 & h <= 1843"
    # upload label.
    if width > MAX_WIDTH or height > MAX_HEIGHT:
        raise gr.Error("Image too large!")

    image = resize_to_nearest_multiple_of_32(image)
    image = transform(image)
    print(os.path.join("models", model_choice))
    model = load_model(os.path.join("models", model_choice))
    model.eval()
    with torch.no_grad():
        image = image.unsqueeze(0)  # add batch dimension
        image = image.to(device)
        # Model returns (output, aux); only the stylized image is used.
        output, _ = model(image)
    output = output.squeeze().cpu().clamp(0, 1)
    output = transforms.ToPILImage()(output)
    return output
|
| |
|
def main():
    """Build and launch the Gradio UI for the DigitalFilm demo."""
    with gr.Blocks(title="DigitalFilm App") as demo:
        # Input widgets: image upload and checkpoint selector.
        uploaded = gr.Image(type="pil", label="Upload Image(w <= 2457 & h <= 1843)")
        chosen = gr.Dropdown(models, label="Select Model", allow_custom_value=False)
        # Output widget and the button that triggers inference.
        result = gr.Image(type="pil", label="Generated Image")
        run = gr.Button("Run Model")
        run.click(process_images, inputs=[uploaded, chosen], outputs=result)

    demo.launch()
|
| |
|
# Script entry point — launch the app only when run directly.
if __name__ == "__main__":
    main()