sztanki commited on
Commit
271f7dc
·
1 Parent(s): 57ab410

Remove unused models

Browse files
Files changed (1) hide show
  1. app.py +5 -76
app.py CHANGED
@@ -69,15 +69,6 @@ mean_latent = original_generator.mean_latent(10000)
69
  #MODELS
70
  generatorzombie = deepcopy(original_generator)
71
  generatorjojo = deepcopy(original_generator)
72
- generatordisney = deepcopy(original_generator)
73
- generatorjinx = deepcopy(original_generator)
74
- generatorcaitlyn = deepcopy(original_generator)
75
- generatoryasuho = deepcopy(original_generator)
76
- generatorarcanemulti = deepcopy(original_generator)
77
- generatorart = deepcopy(original_generator)
78
- generatorspider = deepcopy(original_generator)
79
- generatorsketch = deepcopy(original_generator)
80
-
81
 
82
  transform = transforms.Compose(
83
  [
@@ -89,7 +80,7 @@ transform = transforms.Compose(
89
 
90
 
91
  #ZOMBIE
92
- modelzombie = hf_hub_download(repo_id="Awesimo/jojogan-zombie", filename="jojo_preserve_color.pt")
93
  ckptzombie = torch.load(modelzombie, map_location=lambda storage, loc: storage)
94
  generatorzombie.load_state_dict(ckptzombie["g"], strict=False)
95
 
@@ -98,47 +89,6 @@ modeljojo = hf_hub_download(repo_id="akhaliq/JoJoGAN-jojo", filename="jojo_prese
98
  ckptjojo = torch.load(modeljojo, map_location=lambda storage, loc: storage)
99
  generatorjojo.load_state_dict(ckptjojo["g"], strict=False)
100
 
101
- #DISNEY
102
- modeldisney = hf_hub_download(repo_id="akhaliq/jojogan-disney", filename="disney_preserve_color.pt")
103
- ckptdisney = torch.load(modeldisney, map_location=lambda storage, loc: storage)
104
- generatordisney.load_state_dict(ckptdisney["g"], strict=False)
105
-
106
- #JINX
107
- modeljinx = hf_hub_download(repo_id="akhaliq/jojo-gan-jinx", filename="arcane_jinx_preserve_color.pt")
108
- ckptjinx = torch.load(modeljinx, map_location=lambda storage, loc: storage)
109
- generatorjinx.load_state_dict(ckptjinx["g"], strict=False)
110
-
111
- #CAITLYN
112
- modelcaitlyn = hf_hub_download(repo_id="akhaliq/jojogan-arcane", filename="arcane_caitlyn_preserve_color.pt")
113
- ckptcaitlyn = torch.load(modelcaitlyn, map_location=lambda storage, loc: storage)
114
- generatorcaitlyn.load_state_dict(ckptcaitlyn["g"], strict=False)
115
-
116
- #YASHUO
117
- modelyasuho = hf_hub_download(repo_id="akhaliq/JoJoGAN-jojo", filename="jojo_yasuho_preserve_color.pt")
118
- ckptyasuho = torch.load(modelyasuho, map_location=lambda storage, loc: storage)
119
- generatoryasuho.load_state_dict(ckptyasuho["g"], strict=False)
120
-
121
- #ARCANE
122
- model_arcane_multi = hf_hub_download(repo_id="akhaliq/jojogan-arcane", filename="arcane_multi_preserve_color.pt")
123
- ckptarcanemulti = torch.load(model_arcane_multi, map_location=lambda storage, loc: storage)
124
- generatorarcanemulti.load_state_dict(ckptarcanemulti["g"], strict=False)
125
-
126
- #ART
127
- modelart = hf_hub_download(repo_id="akhaliq/jojo-gan-art", filename="art.pt")
128
- ckptart = torch.load(modelart, map_location=lambda storage, loc: storage)
129
- generatorart.load_state_dict(ckptart["g"], strict=False)
130
-
131
- #SPIDER
132
- modelSpiderverse = hf_hub_download(repo_id="akhaliq/jojo-gan-spiderverse", filename="Spiderverse-face-500iters-8face.pt")
133
- ckptspider = torch.load(modelSpiderverse, map_location=lambda storage, loc: storage)
134
- generatorspider.load_state_dict(ckptspider["g"], strict=False)
135
-
136
- #SKETCH
137
- modelSketch = hf_hub_download(repo_id="akhaliq/jojogan-sketch", filename="sketch_multi.pt")
138
- ckptsketch = torch.load(modelSketch, map_location=lambda storage, loc: storage)
139
- generatorsketch.load_state_dict(ckptsketch["g"], strict=False)
140
-
141
-
142
  def inference(img, model):
143
  img.save('out.jpg')
144
  aligned_face = align_face('out.jpg')
@@ -149,31 +99,10 @@ def inference(img, model):
149
  my_sample = generatorzombie(my_w, input_is_latent=True)
150
  elif model == 'JoJo':
151
  with torch.no_grad():
152
- my_sample = generatordisney(my_w, input_is_latent=True)
153
- elif model == 'Disney':
154
- with torch.no_grad():
155
- my_sample = generatordisney(my_w, input_is_latent=True)
156
- elif model == 'Jinx':
157
- with torch.no_grad():
158
- my_sample = generatorjinx(my_w, input_is_latent=True)
159
- elif model == 'Caitlyn':
160
- with torch.no_grad():
161
- my_sample = generatorcaitlyn(my_w, input_is_latent=True)
162
- elif model == 'Yasuho':
163
- with torch.no_grad():
164
- my_sample = generatoryasuho(my_w, input_is_latent=True)
165
- elif model == 'Arcane Multi':
166
- with torch.no_grad():
167
- my_sample = generatorarcanemulti(my_w, input_is_latent=True)
168
- elif model == 'Art':
169
- with torch.no_grad():
170
- my_sample = generatorart(my_w, input_is_latent=True)
171
- elif model == 'Spider-Verse':
172
- with torch.no_grad():
173
- my_sample = generatorspider(my_w, input_is_latent=True)
174
  else:
175
  with torch.no_grad():
176
- my_sample = generatorsketch(my_w, input_is_latent=True)
177
 
178
 
179
  npimage = my_sample[0].permute(1, 2, 0).detach().numpy()
@@ -181,5 +110,5 @@ def inference(img, model):
181
  return 'filename.jpeg'
182
 
183
  title = "JoJoGAN Test 🤖"
184
- examples=[['assets/samples/image01.jpg','JoJo'],['assets/samples/image02.jpg','Disney'],['assets/samples/image03.jpg','Jinx'],['assets/samples/image04.jpg','Sketch']]
185
- gr.Interface(inference, [gr.inputs.Image(type="pil"),gr.inputs.Dropdown(choices=['Zombie', 'JoJo', 'Disney','Jinx','Caitlyn','Yasuho','Arcane Multi','Art','Spider-Verse','Sketch'], type="value", default='Zombie', label="Model")], gr.outputs.Image(type="file"),title=title,allow_flagging=False,examples=examples,allow_screenshot=False).launch()
 
69
  #MODELS
70
  generatorzombie = deepcopy(original_generator)
71
  generatorjojo = deepcopy(original_generator)
 
 
 
 
 
 
 
 
 
72
 
73
  transform = transforms.Compose(
74
  [
 
80
 
81
 
82
  #ZOMBIE
83
+ modelzombie = hf_hub_download(repo_id="Awesimo/jojogan-zombie", filename="jojo_zombie_preserve_color.pt")
84
  ckptzombie = torch.load(modelzombie, map_location=lambda storage, loc: storage)
85
  generatorzombie.load_state_dict(ckptzombie["g"], strict=False)
86
 
 
89
  ckptjojo = torch.load(modeljojo, map_location=lambda storage, loc: storage)
90
  generatorjojo.load_state_dict(ckptjojo["g"], strict=False)
91
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
92
  def inference(img, model):
93
  img.save('out.jpg')
94
  aligned_face = align_face('out.jpg')
 
99
  my_sample = generatorzombie(my_w, input_is_latent=True)
100
  elif model == 'JoJo':
101
  with torch.no_grad():
102
+ my_sample = generatorjojo(my_w, input_is_latent=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
103
  else:
104
  with torch.no_grad():
105
+ my_sample = generatorzombie(my_w, input_is_latent=True)
106
 
107
 
108
  npimage = my_sample[0].permute(1, 2, 0).detach().numpy()
 
110
  return 'filename.jpeg'
111
 
112
  title = "JoJoGAN Test 🤖"
113
+ examples=[['assets/samples/image01.jpg','Zombie'],['assets/samples/image02.jpg','JoJo'],['assets/samples/image03.jpg','Zombie'],['assets/samples/image04.jpg','JoJo']]
114
+ gr.Interface(inference, [gr.inputs.Image(type="pil"),gr.inputs.Dropdown(choices=['Zombie', 'JoJo'], type="value", default='Zombie', label="Model")], gr.outputs.Image(type="file"),title=title,allow_flagging=False,examples=examples,allow_screenshot=False).launch()