sztanki commited on
Commit
169a997
Β·
1 Parent(s): 6b608d5

Add zombie model

Browse files
app.py CHANGED
@@ -15,10 +15,6 @@ from torch.nn import functional as F
15
  from tqdm import tqdm
16
  import lpips
17
  from model import *
18
-
19
-
20
- #from e4e_projection import projection as e4e_projection
21
-
22
  from copy import deepcopy
23
  import imageio
24
 
@@ -44,7 +40,6 @@ net = pSp(opts, device).eval().to(device)
44
  @ torch.no_grad()
45
  def projection(img, name, device='cuda'):
46
 
47
-
48
  transform = transforms.Compose(
49
  [
50
  transforms.Resize(256),
@@ -60,12 +55,8 @@ def projection(img, name, device='cuda'):
60
  torch.save(result_file, name)
61
  return w_plus[0]
62
 
63
-
64
-
65
-
66
  device = 'cpu'
67
 
68
-
69
  latent_dim = 512
70
 
71
  model_path_s = hf_hub_download(repo_id="akhaliq/jojogan-stylegan2-ffhq-config-f", filename="stylegan2-ffhq-config-f.pt")
@@ -74,22 +65,17 @@ ckpt = torch.load(model_path_s, map_location=lambda storage, loc: storage)
74
  original_generator.load_state_dict(ckpt["g_ema"], strict=False)
75
  mean_latent = original_generator.mean_latent(10000)
76
 
77
- generatorjojo = deepcopy(original_generator)
78
 
 
 
 
79
  generatordisney = deepcopy(original_generator)
80
-
81
  generatorjinx = deepcopy(original_generator)
82
-
83
  generatorcaitlyn = deepcopy(original_generator)
84
-
85
  generatoryasuho = deepcopy(original_generator)
86
-
87
  generatorarcanemulti = deepcopy(original_generator)
88
-
89
  generatorart = deepcopy(original_generator)
90
-
91
  generatorspider = deepcopy(original_generator)
92
-
93
  generatorsketch = deepcopy(original_generator)
94
 
95
 
@@ -102,69 +88,68 @@ transform = transforms.Compose(
102
  )
103
 
104
 
 
 
 
 
105
 
106
-
107
  modeljojo = hf_hub_download(repo_id="akhaliq/JoJoGAN-jojo", filename="jojo_preserve_color.pt")
108
-
109
-
110
  ckptjojo = torch.load(modeljojo, map_location=lambda storage, loc: storage)
111
  generatorjojo.load_state_dict(ckptjojo["g"], strict=False)
112
 
113
-
114
  modeldisney = hf_hub_download(repo_id="akhaliq/jojogan-disney", filename="disney_preserve_color.pt")
115
-
116
  ckptdisney = torch.load(modeldisney, map_location=lambda storage, loc: storage)
117
  generatordisney.load_state_dict(ckptdisney["g"], strict=False)
118
 
119
-
120
  modeljinx = hf_hub_download(repo_id="akhaliq/jojo-gan-jinx", filename="arcane_jinx_preserve_color.pt")
121
-
122
  ckptjinx = torch.load(modeljinx, map_location=lambda storage, loc: storage)
123
  generatorjinx.load_state_dict(ckptjinx["g"], strict=False)
124
 
125
-
126
  modelcaitlyn = hf_hub_download(repo_id="akhaliq/jojogan-arcane", filename="arcane_caitlyn_preserve_color.pt")
127
-
128
  ckptcaitlyn = torch.load(modelcaitlyn, map_location=lambda storage, loc: storage)
129
  generatorcaitlyn.load_state_dict(ckptcaitlyn["g"], strict=False)
130
 
131
-
132
  modelyasuho = hf_hub_download(repo_id="akhaliq/JoJoGAN-jojo", filename="jojo_yasuho_preserve_color.pt")
133
-
134
  ckptyasuho = torch.load(modelyasuho, map_location=lambda storage, loc: storage)
135
  generatoryasuho.load_state_dict(ckptyasuho["g"], strict=False)
136
 
137
-
138
  model_arcane_multi = hf_hub_download(repo_id="akhaliq/jojogan-arcane", filename="arcane_multi_preserve_color.pt")
139
-
140
  ckptarcanemulti = torch.load(model_arcane_multi, map_location=lambda storage, loc: storage)
141
  generatorarcanemulti.load_state_dict(ckptarcanemulti["g"], strict=False)
142
 
143
-
144
  modelart = hf_hub_download(repo_id="akhaliq/jojo-gan-art", filename="art.pt")
145
-
146
  ckptart = torch.load(modelart, map_location=lambda storage, loc: storage)
147
  generatorart.load_state_dict(ckptart["g"], strict=False)
148
 
149
-
150
  modelSpiderverse = hf_hub_download(repo_id="akhaliq/jojo-gan-spiderverse", filename="Spiderverse-face-500iters-8face.pt")
151
-
152
  ckptspider = torch.load(modelSpiderverse, map_location=lambda storage, loc: storage)
153
  generatorspider.load_state_dict(ckptspider["g"], strict=False)
154
 
 
155
  modelSketch = hf_hub_download(repo_id="akhaliq/jojogan-sketch", filename="sketch_multi.pt")
156
-
157
  ckptsketch = torch.load(modelSketch, map_location=lambda storage, loc: storage)
158
  generatorsketch.load_state_dict(ckptsketch["g"], strict=False)
159
 
 
160
  def inference(img, model):
161
  img.save('out.jpg')
162
  aligned_face = align_face('out.jpg')
163
 
164
  my_w = projection(aligned_face, "test.pt", device).unsqueeze(0)
165
- if model == 'JoJo':
166
  with torch.no_grad():
167
- my_sample = generatorjojo(my_w, input_is_latent=True)
 
 
 
168
  elif model == 'Disney':
169
  with torch.no_grad():
170
  my_sample = generatordisney(my_w, input_is_latent=True)
@@ -196,5 +181,5 @@ def inference(img, model):
196
  return 'filename.jpeg'
197
 
198
  title = "JoJoGAN Test πŸ€–"
199
- examples=[['assets/image01.jpg','JoJo'],['assets/image02.jpg','Disney'],['assets/image03.jpg','Jinx'],['assets/image04.jpg','Sketch']]
200
- gr.Interface(inference, [gr.inputs.Image(type="pil"),gr.inputs.Dropdown(choices=['JoJo', 'Disney','Jinx','Caitlyn','Yasuho','Arcane Multi','Art','Spider-Verse','Sketch'], type="value", default='JoJo', label="Model")], gr.outputs.Image(type="file"),title=title,allow_flagging=False,examples=examples,allow_screenshot=False).launch()
 
15
  from tqdm import tqdm
16
  import lpips
17
  from model import *
 
 
 
 
18
  from copy import deepcopy
19
  import imageio
20
 
 
40
  @ torch.no_grad()
41
  def projection(img, name, device='cuda'):
42
 
 
43
  transform = transforms.Compose(
44
  [
45
  transforms.Resize(256),
 
55
  torch.save(result_file, name)
56
  return w_plus[0]
57
 
 
 
 
58
  device = 'cpu'
59
 
 
60
  latent_dim = 512
61
 
62
  model_path_s = hf_hub_download(repo_id="akhaliq/jojogan-stylegan2-ffhq-config-f", filename="stylegan2-ffhq-config-f.pt")
 
65
  original_generator.load_state_dict(ckpt["g_ema"], strict=False)
66
  mean_latent = original_generator.mean_latent(10000)
67
 
 
68
 
69
+ #MODELS
70
+ generatorzombie = deepcopy(original_generator)
71
+ generatorjojo = deepcopy(original_generator)
72
  generatordisney = deepcopy(original_generator)
 
73
  generatorjinx = deepcopy(original_generator)
 
74
  generatorcaitlyn = deepcopy(original_generator)
 
75
  generatoryasuho = deepcopy(original_generator)
 
76
  generatorarcanemulti = deepcopy(original_generator)
 
77
  generatorart = deepcopy(original_generator)
 
78
  generatorspider = deepcopy(original_generator)
 
79
  generatorsketch = deepcopy(original_generator)
80
 
81
 
 
88
  )
89
 
90
 
91
+ #ZOMBIE
92
+ modelzombie = hf_hub_download(repo_id="akhaliq/JoJoGAN-jojo", filename="jojo_preserve_color.pt")
93
+ ckptzombie = torch.load(modelzombie, map_location=lambda storage, loc: storage)
94
+ generatorzombie.load_state_dict(ckptzombie["g"], strict=False)
95
 
96
+ #JOJO
97
  modeljojo = hf_hub_download(repo_id="akhaliq/JoJoGAN-jojo", filename="jojo_preserve_color.pt")
 
 
98
  ckptjojo = torch.load(modeljojo, map_location=lambda storage, loc: storage)
99
  generatorjojo.load_state_dict(ckptjojo["g"], strict=False)
100
 
101
+ #DISNEY
102
  modeldisney = hf_hub_download(repo_id="akhaliq/jojogan-disney", filename="disney_preserve_color.pt")
 
103
  ckptdisney = torch.load(modeldisney, map_location=lambda storage, loc: storage)
104
  generatordisney.load_state_dict(ckptdisney["g"], strict=False)
105
 
106
+ #JINX
107
  modeljinx = hf_hub_download(repo_id="akhaliq/jojo-gan-jinx", filename="arcane_jinx_preserve_color.pt")
 
108
  ckptjinx = torch.load(modeljinx, map_location=lambda storage, loc: storage)
109
  generatorjinx.load_state_dict(ckptjinx["g"], strict=False)
110
 
111
+ #CAITLYN
112
  modelcaitlyn = hf_hub_download(repo_id="akhaliq/jojogan-arcane", filename="arcane_caitlyn_preserve_color.pt")
 
113
  ckptcaitlyn = torch.load(modelcaitlyn, map_location=lambda storage, loc: storage)
114
  generatorcaitlyn.load_state_dict(ckptcaitlyn["g"], strict=False)
115
 
116
+ #YASHUO
117
  modelyasuho = hf_hub_download(repo_id="akhaliq/JoJoGAN-jojo", filename="jojo_yasuho_preserve_color.pt")
 
118
  ckptyasuho = torch.load(modelyasuho, map_location=lambda storage, loc: storage)
119
  generatoryasuho.load_state_dict(ckptyasuho["g"], strict=False)
120
 
121
+ #ARCANE
122
  model_arcane_multi = hf_hub_download(repo_id="akhaliq/jojogan-arcane", filename="arcane_multi_preserve_color.pt")
 
123
  ckptarcanemulti = torch.load(model_arcane_multi, map_location=lambda storage, loc: storage)
124
  generatorarcanemulti.load_state_dict(ckptarcanemulti["g"], strict=False)
125
 
126
+ #ART
127
  modelart = hf_hub_download(repo_id="akhaliq/jojo-gan-art", filename="art.pt")
 
128
  ckptart = torch.load(modelart, map_location=lambda storage, loc: storage)
129
  generatorart.load_state_dict(ckptart["g"], strict=False)
130
 
131
+ #SPIDER
132
  modelSpiderverse = hf_hub_download(repo_id="akhaliq/jojo-gan-spiderverse", filename="Spiderverse-face-500iters-8face.pt")
 
133
  ckptspider = torch.load(modelSpiderverse, map_location=lambda storage, loc: storage)
134
  generatorspider.load_state_dict(ckptspider["g"], strict=False)
135
 
136
+ #SKETCH
137
  modelSketch = hf_hub_download(repo_id="akhaliq/jojogan-sketch", filename="sketch_multi.pt")
 
138
  ckptsketch = torch.load(modelSketch, map_location=lambda storage, loc: storage)
139
  generatorsketch.load_state_dict(ckptsketch["g"], strict=False)
140
 
141
+
142
  def inference(img, model):
143
  img.save('out.jpg')
144
  aligned_face = align_face('out.jpg')
145
 
146
  my_w = projection(aligned_face, "test.pt", device).unsqueeze(0)
147
+ if model == 'Zombie':
148
  with torch.no_grad():
149
+ my_sample = generatorzombie(my_w, input_is_latent=True)
150
+ elif model == 'JoJo':
151
+ with torch.no_grad():
152
+ my_sample = generatordisney(my_w, input_is_latent=True)
153
  elif model == 'Disney':
154
  with torch.no_grad():
155
  my_sample = generatordisney(my_w, input_is_latent=True)
 
181
  return 'filename.jpeg'
182
 
183
  title = "JoJoGAN Test πŸ€–"
184
+ examples=[['assets/samples/image01.jpg','JoJo'],['assets/samples/image02.jpg','Disney'],['assets/samples/image03.jpg','Jinx'],['assets/samples/image04.jpg','Sketch']]
185
+ gr.Interface(inference, [gr.inputs.Image(type="pil"),gr.inputs.Dropdown(choices=['JoJo', 'Disney','Jinx','Caitlyn','Yasuho','Arcane Multi','Art','Spider-Verse','Sketch'], type="value", default='Zombie', label="Model")], gr.outputs.Image(type="file"),title=title,allow_flagging=False,examples=examples,allow_screenshot=False).launch()
assets/references/zombie/image01.jpg ADDED
assets/references/zombie/image02.jpg ADDED
assets/references/zombie/image03.jpg ADDED
assets/references/zombie/image04.jpg ADDED
assets/references/zombie/image05.jpg ADDED
assets/references/zombie/image06.jpg ADDED
assets/{image01.jpg β†’ samples/image01.jpg} RENAMED
File without changes
assets/{image02.jpg β†’ samples/image02.jpg} RENAMED
File without changes
assets/{image03.jpg β†’ samples/image03.jpg} RENAMED
File without changes
assets/{image04.jpg β†’ samples/image04.jpg} RENAMED
File without changes