cthleen commited on
Commit
fdfadb4
·
1 Parent(s): fb3c7bc

add generator

Browse files
Files changed (5) hide show
  1. app.py +2 -3
  2. dcgan_generator.py +1 -1
  3. generator.py +53 -30
  4. progan_generator.py +32 -0
  5. stylegan2_generator.py +33 -0
app.py CHANGED
@@ -1,7 +1,6 @@
1
  import gradio as gr
2
- from generator import generate_stylegan2
3
- from stylegan_generator import generate_stylegan
4
- from dcgan_generator import generate_dcgan
5
 
6
  def generate_image(model_name):
7
  if model_name == "GAN / VanillaGAN":
 
1
  import gradio as gr
2
+ from stylegan2_generator import generate_stylegan2
3
+ from generator import generate_gan, generate_dcgan, generate_progan, generate_stylegan
 
4
 
5
  def generate_image(model_name):
6
  if model_name == "GAN / VanillaGAN":
dcgan_generator.py CHANGED
@@ -4,7 +4,7 @@ import onnxruntime as ort
4
  from PIL import Image
5
 
6
  LATENT_FEATURES = 512
7
- MODEL_PATH = os.path.join("model", "dcgan.onnx")
8
 
9
  model = ort.InferenceSession(MODEL_PATH)
10
  input_name = model.get_inputs()[0].name
 
4
  from PIL import Image
5
 
6
  LATENT_FEATURES = 512
7
+ MODEL_PATH = os.path.join("model", "batik_dcgan.onnx")
8
 
9
  model = ort.InferenceSession(MODEL_PATH)
10
  input_name = model.get_inputs()[0].name
generator.py CHANGED
@@ -1,33 +1,56 @@
1
- import os
2
- import sys
3
-
4
- BASE_DIR = os.path.dirname(os.path.abspath(__file__))
5
- STYLEGAN2_DIR = os.path.join(BASE_DIR, "stylegan2")
6
- MODEL_PATH = os.path.join(BASE_DIR, "model", "network-snapshot-000560.pkl")
7
-
8
- sys.path.append(STYLEGAN2_DIR)
9
-
10
- import torch
11
- import legacy
12
- import dnnlib
13
  import numpy as np
 
14
  from PIL import Image
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15
 
16
- torch.autograd.set_grad_enabled(False)
17
- torch.backends.cudnn.benchmark = True
18
-
19
- device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
20
-
21
- with dnnlib.util.open_url(MODEL_PATH) as f:
22
- G = legacy.load_network_pkl(f)['G_ema'].to(device)
23
-
24
- def generate_stylegan2():
25
- seed = np.random.randint(0, 2**32)
26
- z = torch.from_numpy(np.random.RandomState(seed).randn(1, G.z_dim)).to(device)
27
- label = torch.zeros([1, G.c_dim], device=device)
28
- img = G(z, label, truncation_psi=1.0, noise_mode='const')
29
- img = (img.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8)
30
-
31
- pil_img = Image.fromarray(img[0].cpu().numpy(), 'RGB')
32
- resized = pil_img.resize((512, 512), Image.LANCZOS)
33
- return resized
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import numpy as np
2
+ import onnxruntime as ort
3
  from PIL import Image
4
+ import os
5
+ import math
6
+
7
+ def out(output):
8
+ image = output.squeeze()
9
+ image = (image * 0.5 + 0.5) * 255
10
+ image = image.astype(np.uint8)
11
+ image = np.transpose(image, (1, 2, 0))
12
+ return Image.fromarray(image, "RGB").resize((512, 512), Image.LANCZOS)
13
+
14
+ def generate_gan():
15
+ return Image.new("RGB", (512, 512), color="gray")
16
+
17
+ def generate_progan():
18
+ model_path = os.path.join("model", "batik_progan.onnx")
19
+ session = ort.InferenceSession(model_path)
20
+
21
+ noise = np.random.randn(1, 512, 1, 1).astype(np.float32)
22
+ alpha = np.array([1.0], dtype=np.float32)
23
+
24
+ output = session.run(None, {
25
+ 'z': noise,
26
+ 'alpha': alpha,
27
+ })
28
+ return out(output[0])
29
+
30
+ def generate_dcgan():
31
+ model_path = os.path.join("model", "batik_dcgan.onnx")
32
+ session = ort.InferenceSession(model_path)
33
+ noise = np.random.randn(1, 512, 1, 1).astype(np.float32)
34
+ input_name = session.get_inputs()[0].name
35
+ output = session.run(None, {input_name: noise})
36
+ return out(output[0])
37
+
38
+ def generate_stylegan():
39
+ model_path = os.path.join("model", "batik_stylegan.onnx")
40
+ session = ort.InferenceSession(model_path)
41
+
42
+ LATENT_FEATURES = 512
43
+ RESOLUTION = 256
44
+ LAST_INDEX = math.log2(RESOLUTION) - 2
45
+
46
+ z = np.random.randn(1, LATENT_FEATURES).astype(np.float32)
47
+ alpha = np.array([1.0], dtype=np.float32)
48
+ steps = np.array([LAST_INDEX], dtype=np.int64)
49
+
50
+ output = session.run(None, {
51
+ 'z': z,
52
+ 'alpha': alpha,
53
+ 'steps': steps
54
+ })
55
+ return out(output[0])
56
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
progan_generator.py ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import math
3
+ import numpy as np
4
+ import onnxruntime as ort
5
+ from PIL import Image
6
+
7
+ LATENT_FEATURES = 512
8
+ RESOLUTION = 256
9
+ LAST_INDEX = math.log2(RESOLUTION) - 2
10
+
11
+ MODEL_PATH = os.path.join("model", "batik_stylegan.onnx")
12
+ model = ort.InferenceSession(MODEL_PATH)
13
+
14
+ alpha = np.array([1.0], dtype=np.float32)
15
+ steps = np.array([LAST_INDEX], dtype=np.int64)
16
+
17
+ def generate_stylegan():
18
+ z = np.random.randn(1, LATENT_FEATURES).astype(np.float32)
19
+
20
+ output = model.run(None, {
21
+ 'z': z,
22
+ 'alpha': alpha,
23
+ 'steps': steps
24
+ })[0]
25
+
26
+ image = output.squeeze(0)
27
+ image = (image * 0.5 + 0.5) * 255
28
+ image = image.astype(np.uint8)
29
+ image = np.transpose(image, (1, 2, 0))
30
+ pil_img = Image.fromarray(image, 'RGB')
31
+
32
+ return pil_img.resize((512, 512), Image.LANCZOS)
stylegan2_generator.py ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import sys
3
+
4
+ BASE_DIR = os.path.dirname(os.path.abspath(__file__))
5
+ STYLEGAN2_DIR = os.path.join(BASE_DIR, "stylegan2")
6
+ MODEL_PATH = os.path.join(BASE_DIR, "model", "network-snapshot-000560.pkl")
7
+
8
+ sys.path.append(STYLEGAN2_DIR)
9
+
10
+ import torch
11
+ import legacy
12
+ import dnnlib
13
+ import numpy as np
14
+ from PIL import Image
15
+
16
+ torch.autograd.set_grad_enabled(False)
17
+ torch.backends.cudnn.benchmark = True
18
+
19
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
20
+
21
+ with dnnlib.util.open_url(MODEL_PATH) as f:
22
+ G = legacy.load_network_pkl(f)['G_ema'].to(device)
23
+
24
+ def generate_stylegan2():
25
+ seed = np.random.randint(0, 2**32)
26
+ z = torch.from_numpy(np.random.RandomState(seed).randn(1, G.z_dim)).to(device)
27
+ label = torch.zeros([1, G.c_dim], device=device)
28
+ img = G(z, label, truncation_psi=1.0, noise_mode='const')
29
+ img = (img.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8)
30
+
31
+ pil_img = Image.fromarray(img[0].cpu().numpy(), 'RGB')
32
+ resized = pil_img.resize((512, 512), Image.LANCZOS)
33
+ return resized