gdo commited on
Commit
33c4967
·
verified ·
1 Parent(s): 3b9f0b1

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +54 -0
app.py ADDED
@@ -0,0 +1,54 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import torch
3
+ import torchvision.models as models
4
+ from torchvision import transforms
5
+ from PIL import Image
6
+ import gradio as gr
7
+
8
+ # Load a pre-trained GAN model (e.g., DCGAN)
9
+ # You can replace this with any other GAN or VAE model
10
+ generator = torch.hub.load('facebookresearch/pytorch_GAN_zoo:hub', 'DCGAN', pretrained=True, useGPU=False)
11
+
12
+ def generate_image(input_image):
13
+ # Convert the input image to a PIL Image
14
+ input_image = Image.fromarray(input_image.astype('uint8'), 'RGB')
15
+
16
+ # Preprocess the image (resize, normalize, etc.)
17
+ preprocess = transforms.Compose([
18
+ transforms.Resize((64, 64)),
19
+ transforms.ToTensor(),
20
+ transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
21
+ ])
22
+
23
+ input_tensor = preprocess(input_image).unsqueeze(0)
24
+
25
+ # Generate an image using the GAN model
26
+ with torch.no_grad():
27
+ generated_tensor = generator(input_tensor)
28
+
29
+ # Convert the generated tensor back to an image
30
+ generated_image = transforms.ToPILImage()(generated_tensor.squeeze(0))
31
+
32
+ return generated_image
33
+
34
+ def draw_line(image):
35
+ # Convert the image to a numpy array
36
+ image_np = np.array(image)
37
+
38
+ # Generate a similar image using the GAN model
39
+ generated_image = generate_image(image_np)
40
+
41
+ return generated_image
42
+
43
+ # Create the Gradio interface
44
+ iface = gr.Interface(
45
+ fn=draw_line,
46
+ inputs="sketchpad",
47
+ outputs="image",
48
+ live=True,
49
+ title="Draw a Line and Generate a Similar Image",
50
+ description="Draw a line on the sketchpad, and the app will generate a similar image using a GAN model."
51
+ )
52
+
53
+ # Launch the app
54
+ iface.launch()