riciii7 commited on
Commit
d843d07
·
verified ·
1 Parent(s): bfc3e04

Update utils.py

Browse files
Files changed (1) hide show
  1. utils.py +30 -20
utils.py CHANGED
@@ -1,5 +1,4 @@
1
- from stylegan_model import StyleGAN
2
- from vanillagan_model import VanillaGAN
3
  import torch
4
  from io import BytesIO
5
  from torchvision.utils import save_image
@@ -7,19 +6,16 @@ import numpy as np
7
  import legacy
8
  from PIL import Image
9
  import time
 
10
 
11
  LATENT_FEATURES = 512
12
  RESOLUTION = 128
13
 
14
  DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
15
- def load_model_pt(path='model_128.pt',model_type='stylegan'):
16
- if model_type == "stylegan":
17
- model = StyleGAN(LATENT_FEATURES, RESOLUTION).to(DEVICE)
18
- last_checkpoint = torch.load(path, map_location=DEVICE)
19
- model.load_state_dict(last_checkpoint['generator'], strict=False)
20
- elif model_type == "vanillagan":
21
- model = VanillaGAN(RESOLUTION, LATENT_FEATURES).to(DEVICE)
22
- model.load_state_dict(torch.load(path, map_location=DEVICE))
23
  model.eval()
24
  return model
25
 
@@ -33,16 +29,6 @@ def generate_image_stylegan(generator, steps=5, alpha=1.0):
33
  save_image(image, buffer, format='PNG')
34
  buffer.seek(0)
35
  return buffer
36
-
37
- def generate_image_vanillagan(generator):
38
- with torch.no_grad():
39
- image = generator(torch.randn(1, LATENT_FEATURES, device=DEVICE)).view(1, 3, RESOLUTION, RESOLUTION)
40
- image = (image * 0.5 + 0.5).clamp(0, 1)
41
-
42
- buffer = BytesIO()
43
- save_image(image, buffer, format='PNG')
44
- buffer.seek(0)
45
- return buffer
46
 
47
  def load_model_pkl(path='styleganv2.pkl'):
48
  with open(path, 'rb') as f:
@@ -68,3 +54,27 @@ def generate_image_from_pkl(generator, seed=0, trunc=1):
68
  print(f"Image generation time: {end - start:.2f} seconds")
69
 
70
  return buffer
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from model import StyleGAN
 
2
  import torch
3
  from io import BytesIO
4
  from torchvision.utils import save_image
 
6
  import legacy
7
  from PIL import Image
8
  import time
9
+ import onnxruntime as ort
10
 
11
  LATENT_FEATURES = 512
12
  RESOLUTION = 128
13
 
14
  DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
15
+ def load_model_pt(path='model_128.pt'):
16
+ model = StyleGAN(LATENT_FEATURES, RESOLUTION).to(DEVICE)
17
+ last_checkpoint = torch.load(path, map_location=DEVICE)
18
+ model.load_state_dict(last_checkpoint['generator'], strict=False)
 
 
 
 
19
  model.eval()
20
  return model
21
 
 
29
  save_image(image, buffer, format='PNG')
30
  buffer.seek(0)
31
  return buffer
 
 
 
 
 
 
 
 
 
 
32
 
33
  def load_model_pkl(path='styleganv2.pkl'):
34
  with open(path, 'rb') as f:
 
54
  print(f"Image generation time: {end - start:.2f} seconds")
55
 
56
  return buffer
57
+
58
+ def generate_image_from_onnx(path='model_128.onnx', model=None):
59
+ if model is None:
60
+ return ValueError("Model not provided.")
61
+ if model == 'progan':
62
+ z = np.random.randn(1, 512, 1, 1).astype(np.float32)
63
+ else:
64
+ z = np.random.randn(1, 512).astype(np.float32)
65
+ inference_session = ort.InferenceSession(path)
66
+ input_name = inference_session.get_inputs()[0].name
67
+
68
+ image = inference_session.run(None, {input_name: z})[0]
69
+
70
+ image = image.squeeze(0)
71
+ image = (image * 0.5 + 0.5) * 255
72
+ image = image.astype(np.uint8)
73
+ image = np.transpose(image, (1, 2, 0))
74
+ image = Image.fromarray(image, 'RGB')
75
+
76
+ buffer = BytesIO()
77
+ image.save(buffer, format='PNG')
78
+ buffer.seek(0)
79
+
80
+ return buffer