# line2image / app.py
# gdo's picture
# Create app.py
# 33c4967 verified
import zlib

import gradio as gr
import numpy as np
import torch
import torchvision.models as models
from PIL import Image
from torchvision import transforms
# Pre-trained DCGAN pulled through torch.hub from the pytorch_GAN_zoo
# repository, configured for CPU only (useGPU=False).
# Any other GAN or VAE model can be swapped in here.
generator = torch.hub.load(
    'facebookresearch/pytorch_GAN_zoo:hub',
    'DCGAN',
    pretrained=True,
    useGPU=False,
)
def generate_image(input_image):
    """Generate an image with the pre-trained DCGAN, keyed on a sketch.

    The hub DCGAN is an unconditional generator: it maps a latent noise
    vector to an image and cannot take a picture as input (the hub object
    is also not directly callable). The sketch is therefore used only to
    seed the latent vector, so the same drawing always yields the same
    generated image while different drawings yield different ones.

    Args:
        input_image: Array-like of pixel values (any shape, e.g. a 2-D
            grayscale or 3-D RGB/RGBA sketch). Values are cast to uint8.

    Returns:
        A PIL image produced by the generator.

    Raises:
        ValueError: If ``input_image`` is None.
    """
    if input_image is None:
        raise ValueError("input_image must not be None")

    # Checksum of the raw pixels -> deterministic 32-bit seed for the latent.
    sketch = np.asarray(input_image).astype('uint8')
    seed = zlib.crc32(sketch.tobytes())

    # buildNoiseData draws from torch's global CPU RNG; fork it so seeding
    # here does not disturb random state used elsewhere in the process.
    with torch.random.fork_rng(devices=[]):
        torch.manual_seed(seed)
        noise, _ = generator.buildNoiseData(1)

    # Documented inference entry point of the pytorch_GAN_zoo hub models.
    with torch.no_grad():
        generated_tensor = generator.test(noise)

    # The generator is expected to emit values in [-1, 1]; ToPILImage needs
    # floats in [0, 1], so rescale and clamp before converting.
    image_tensor = ((generated_tensor.squeeze(0) + 1) / 2).clamp(0, 1)
    return transforms.ToPILImage()(image_tensor)
def draw_line(image):
    """Gradio callback: turn the current sketchpad content into a GAN image.

    Args:
        image: The sketchpad value. This may be None when the canvas is
            empty or has just been cleared (the interface runs live), an
            array-like of pixels, or -- in Gradio versions whose sketchpad
            is an image editor -- a dict holding the flattened drawing
            under the "composite" key.

    Returns:
        The image produced by ``generate_image``, or None when there is
        nothing drawn yet.
    """
    # Image-editor style payload: pick the flattened drawing.
    if isinstance(image, dict):
        image = image.get("composite")

    # Nothing to process yet; returning None leaves the output empty
    # instead of raising inside the live callback.
    if image is None:
        return None

    # Convert the image to a numpy array and hand it to the generator.
    return generate_image(np.array(image))
# Wire the callback into a live-updating Gradio UI: sketchpad in, image out.
iface = gr.Interface(
    draw_line,
    "sketchpad",
    "image",
    live=True,
    title="Draw a Line and Generate a Similar Image",
    description=(
        "Draw a line on the sketchpad, and the app will generate "
        "a similar image using a GAN model."
    ),
)

# Start serving the app.
iface.launch()