WZT006 commited on
Commit
95ec8d7
·
1 Parent(s): cc4ab88

add application file

Browse files
Files changed (1) hide show
  1. app.py +83 -0
app.py ADDED
@@ -0,0 +1,83 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import numpy as np
3
+ import torch
4
+ from torch import nn
5
+ from torch.nn import functional as F
6
+ from torch.utils import data
7
+ from torchvision import transforms, utils
8
+ from tqdm import tqdm
9
+ torch.backends.cudnn.benchmark = True
10
+ import copy
11
+ from util import *
12
+ from PIL import Image
13
+
14
+ from model import *
15
+ import moviepy.video.io.ImageSequenceClip
16
+ import scipy
17
+ import kornia.augmentation as K
18
+
19
+ from base64 import b64encode
20
+ import gradio as gr
21
+ from torchvision import transforms
22
+
23
+ # torch.hub.download_url_to_file('https://i.imgur.com/HiOTPNg.png', 'mona.png')
24
+ # torch.hub.download_url_to_file('https://i.imgur.com/Cw8HcTN.png', 'painting.png')
25
+
26
+ device = 'cpu'
27
+ latent_dim = 8
28
+ n_mlp = 5
29
+ num_down = 3
30
+
31
+ G_A2B = Generator(256, 4, latent_dim, n_mlp, channel_multiplier=1, lr_mlp=.01,n_res=1).to(device).eval()
32
+
33
+ ensure_checkpoint_exists('GNR_checkpoint_full.pt')
34
+ ckpt = torch.load('GNR_checkpoint_full.pt', map_location=device)
35
+
36
+ G_A2B.load_state_dict(ckpt['G_A2B_ema'])
37
+
38
+ # mean latent
39
+ truncation = 1
40
+ with torch.no_grad():
41
+ mean_style = G_A2B.mapping(torch.randn([1000, latent_dim]).to(device)).mean(0, keepdim=True)
42
+
43
+
44
+ test_transform = transforms.Compose([
45
+ transforms.Resize((256, 256)),
46
+ transforms.ToTensor(),
47
+ transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), inplace=True)
48
+ ])
49
+ plt.rcParams['figure.dpi'] = 200
50
+
51
+ # torch.manual_seed(84986)
52
+
53
+ num_styles = 1
54
+ style = torch.randn([num_styles, latent_dim]).to(device)
55
+
56
+
57
+ def inference(input_im):
58
+ real_A = test_transform(input_im).unsqueeze(0).to(device)
59
+
60
+ with torch.no_grad():
61
+ A2B_content, _ = G_A2B.encode(real_A)
62
+ #fake_A2B = G_A2B.decode(A2B_content.repeat(num_styles,1,1,1), style)
63
+ fake_A2B = G_A2B.decode(A2B_content.repeat(num_styles,1,1,1), torch.randn([num_styles, latent_dim]).to(device))
64
+ std=(0.5, 0.5, 0.5)
65
+ mean=(0.5, 0.5, 0.5)
66
+ z = fake_A2B * torch.tensor(std).view(3, 1, 1)
67
+ z = z + torch.tensor(mean).view(3, 1, 1)
68
+ tensor_to_pil = transforms.ToPILImage(mode='RGB')(z.squeeze())
69
+ return tensor_to_pil
70
+
71
+ title = "GANsNRoses"
72
+ article = "<p style='text-align: center'>GANs N' Roses: Image to Iamge Translation | Obtained from :<a href='https://github.com/mchong6/GANsNRoses'>Github Repo</a></p>"
73
+ demo = gr.Interface(
74
+ inference,
75
+ [gr.inputs.Image(type="pil", label="Input")],
76
+ gr.outputs.Image(type="pil"),
77
+ title=title,
78
+ # description=description,
79
+ article=article,
80
+ allow_flagging = "never",
81
+ )
82
+
83
+ demo.launch(share = True)