Avanish11 committed on
Commit
dcdc448
·
verified ·
1 Parent(s): 5b80bc7

Update inference.py

Browse files
Files changed (1) hide show
  1. inference.py +514 -177
inference.py CHANGED
@@ -9,11 +9,25 @@ import numpy as np
9
  from models.anime_gan import GeneratorV1
10
  from models.anime_gan_v2 import GeneratorV2
11
  from models.anime_gan_v3 import GeneratorV3
12
- from utils.common import load_checkpoint, RELEASED_WEIGHTS
13
- from utils.image_processing import resize_image, normalize_input, denormalize_input
14
- from utils import read_image, is_image_file, is_video_file
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15
  from tqdm import tqdm
16
- from color_transfer import ColorTransfer
17
 
18
  try:
19
  import matplotlib.pyplot as plt
@@ -28,363 +42,657 @@ except ImportError:
28
  VideoFileClip = None
29
 
30
 
 
 
 
 
31
  def profile(func):
 
32
  def wrap(*args, **kwargs):
 
33
  started_at = time.time()
 
34
  result = func(*args, **kwargs)
 
35
  elapsed = time.time() - started_at
 
36
  print(f"Processed in {elapsed:.3f}s")
 
37
  return result
 
38
  return wrap
39
 
40
 
41
- def auto_load_weight(weight, version=None, map_location=None):
42
- """Auto load Generator version from weight."""
 
 
 
 
 
 
 
 
43
  weight_name = os.path.basename(weight).lower()
 
44
  if version is not None:
 
45
  version = version.lower()
46
- assert version in {"v1", "v2", "v3"}, f"Version {version} does not exist"
47
- # If version is provided, use it.
 
 
 
 
 
48
  cls = {
49
  "v1": GeneratorV1,
50
  "v2": GeneratorV2,
51
  "v3": GeneratorV3
52
  }[version]
 
53
  else:
54
- # Try to get class by name of weight file
55
- # For convenience, the weight file name should start with the class name
56
- # e.g: Generatorv2_{anything}.pt
57
  if weight_name in RELEASED_WEIGHTS:
58
- version = RELEASED_WEIGHTS[weight_name][0]
59
- return auto_load_weight(weight, version=version, map_location=map_location)
60
 
61
- elif weight_name.startswith("generatorv2"):
 
 
 
 
 
 
 
 
 
 
 
 
 
62
  cls = GeneratorV2
63
- elif weight_name.startswith("generatorv3"):
 
 
 
 
64
  cls = GeneratorV3
65
- elif weight_name.startswith("generator"):
 
 
 
 
66
  cls = GeneratorV1
 
67
  else:
68
- raise ValueError((f"Can not get Model from {weight_name}, "
69
- "you might need to explicitly specify version"))
 
 
 
70
  model = cls()
71
- load_checkpoint(model, weight, strip_optimizer=True, map_location=map_location)
 
 
 
 
 
 
 
72
  model.eval()
 
73
  return model
74
 
75
 
 
 
 
 
76
  class Predictor:
77
- """
78
- Generic class for transferring an image to an anime-like image.
79
- """
80
  def __init__(
81
  self,
82
- weight='hayao',
83
  device='cuda',
84
  amp=True,
85
  retain_color=False,
86
  imgsz=None,
87
  ):
 
88
  if not torch.cuda.is_available():
 
89
  device = 'cpu'
90
- # Amp not working on cpu
91
  amp = False
92
- print("Use CPU device")
 
 
93
  else:
94
- print(f"Use GPU {torch.cuda.get_device_name()}")
95
-
 
 
 
 
96
  self.imgsz = imgsz
 
97
  self.retain_color = retain_color
98
- self.amp = amp # Automatic Mixed Precision
99
- self.device_type = 'cuda' if device.startswith('cuda') else 'cpu'
 
 
 
 
 
 
 
100
  self.device = torch.device(device)
101
- self.G = auto_load_weight(weight, map_location=device)
 
 
 
 
 
102
  self.G.to(self.device)
103
 
 
 
 
 
104
  def transform_and_show(
105
  self,
106
  image_path,
107
  figsize=(18, 10),
108
  save_path=None
109
  ):
110
- image = resize_image(read_image(image_path))
 
 
 
 
111
  anime_img = self.transform(image)
 
112
  anime_img = anime_img.astype('uint8')
113
 
114
  fig = plt.figure(figsize=figsize)
 
115
  fig.add_subplot(1, 2, 1)
116
- # plt.title("Input")
117
  plt.imshow(image)
 
118
  plt.axis('off')
 
119
  fig.add_subplot(1, 2, 2)
120
- # plt.title("Anime style")
121
  plt.imshow(anime_img[0])
 
122
  plt.axis('off')
 
123
  plt.tight_layout()
 
124
  plt.show()
 
125
  if save_path is not None:
 
126
  plt.savefig(save_path)
127
 
128
- def transform(self, image, denorm=True):
129
- '''
130
- Transform an image to animation
131
 
132
- @Arguments:
133
- - image: np.array, shape = (Batch, width, height, channels)
 
 
 
134
 
135
- @Returns:
136
- - anime version of image: np.array
137
- '''
138
  with torch.no_grad():
 
139
  image = self.preprocess_images(image)
140
- # image = image.to(self.device)
141
- # with autocast(self.device_type, enabled=self.amp):
142
- # print(image.dtype, self.G)
143
  fake = self.G(image)
144
- # Transfer colors so the fake image looks similar in color to the input image
145
- if self.retain_color:
146
- fake = color_transfer_pytorch(fake, image)
147
- fake = (fake / 0.5) - 1.0 # remap to [-1. 1]
148
  fake = fake.detach().cpu().numpy()
149
- # Channel last
150
- fake = fake.transpose(0, 2, 3, 1)
 
 
 
 
 
151
 
152
  if denorm:
153
- fake = denormalize_input(fake, dtype=np.uint8)
 
 
 
 
 
154
  return fake
155
 
156
- def read_and_resize(self, path, max_size=1536):
 
 
 
 
 
 
 
 
 
157
  image = read_image(path)
 
158
  _, ext = os.path.splitext(path)
 
159
  h, w = image.shape[:2]
 
160
  if self.imgsz is not None:
161
- image = resize_image(image, width=self.imgsz)
 
 
 
 
 
162
  elif max(h, w) > max_size:
163
- print(f"Image {os.path.basename(path)} is too big ({h}x{w}), resize to max size {max_size}")
 
 
 
 
 
164
  image = resize_image(
165
  image,
166
  width=max_size if w > h else None,
167
  height=max_size if w < h else None,
168
  )
169
- cv2.imwrite(path.replace(ext, ".jpg"), image[:,:,::-1])
 
 
 
 
 
170
  else:
 
171
  image = resize_image(image)
172
- # image = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
173
- # image = np.stack([image, image, image], -1)
174
- # cv2.imwrite(path.replace(ext, ".jpg"), image[:,:,::-1])
175
  return image
176
 
 
 
 
 
177
  @profile
178
- def transform_file(self, file_path, save_path):
 
 
 
 
 
179
  if not is_image_file(save_path):
180
- raise ValueError(f"{save_path} is not valid")
181
 
182
- image = self.read_and_resize(file_path)
 
 
 
 
 
 
 
183
  anime_img = self.transform(image)[0]
184
- cv2.imwrite(save_path, anime_img[..., ::-1])
185
- print(f"Anime image saved to {save_path}")
 
 
 
 
 
 
 
 
186
  return anime_img
187
 
 
 
 
 
188
  @profile
189
- def transform_gif(self, file_path, save_path, batch_size=4):
 
 
 
 
 
 
190
  import imageio
191
 
192
  def _preprocess_gif(img):
 
193
  if img.shape[-1] == 4:
194
- img = cv2.cvtColor(img, cv2.COLOR_RGBA2RGB)
 
 
 
 
 
195
  return resize_image(img)
196
 
197
  images = imageio.mimread(file_path)
 
198
  images = np.stack([
199
  _preprocess_gif(img)
200
  for img in images
201
  ])
202
 
203
- print(images.shape)
204
-
205
  anime_gif = np.zeros_like(images)
206
 
207
- for i in tqdm(range(0, len(images), batch_size)):
 
 
 
208
  end = i + batch_size
209
- anime_gif[i: end] = self.transform(
210
- images[i: end]
211
- )
212
 
213
- if end < len(images) - 1:
214
- # transform last frame
215
- print("LAST", images[end: ].shape)
216
- anime_gif[end:] = self.transform(images[end:])
217
 
218
- print(anime_gif.shape)
219
  imageio.mimsave(
220
  save_path,
221
  anime_gif,
222
-
223
  )
224
- print(f"Anime image saved to {save_path}")
 
 
 
 
 
 
 
225
 
226
  @profile
227
- def transform_in_dir(self, img_dir, dest_dir, max_images=0, img_size=(512, 512)):
228
- '''
229
- Read all images from img_dir, transform and write the result
230
- to dest_dir
 
 
231
 
232
- '''
233
- os.makedirs(dest_dir, exist_ok=True)
 
 
234
 
235
  files = os.listdir(img_dir)
236
- files = [f for f in files if is_image_file(f)]
237
- print(f'Found {len(files)} images in {img_dir}')
 
 
 
238
 
239
  if max_images:
 
240
  files = files[:max_images]
241
 
242
  bar = tqdm(files)
 
243
  for fname in bar:
244
- path = os.path.join(img_dir, fname)
 
 
 
 
 
245
  image = self.read_and_resize(path)
 
246
  anime_img = self.transform(image)[0]
247
- # anime_img = resize_image(anime_img, width=320)
248
  ext = fname.split('.')[-1]
249
- fname = fname.replace(f'.{ext}', '')
250
- cv2.imwrite(os.path.join(dest_dir, f'{fname}.jpg'), anime_img[..., ::-1])
251
- bar.set_description(f"{fname} {image.shape}")
252
-
253
- def transform_video(self, input_path, output_path, batch_size=4, start=0, end=0):
254
- '''
255
- Transform a video to animation version
256
- https://github.com/lengstrom/fast-style-transfer/blob/master/evaluate.py#L21
257
- '''
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
258
  if VideoFileClip is None:
259
- raise ImportError("moviepy is not installed, please install with `pip install moviepy>=1.0.3`")
260
- # Force to None
 
 
 
261
  end = end or None
262
 
263
  if not os.path.isfile(input_path):
264
- raise FileNotFoundError(f'{input_path} does not exist')
265
 
266
- output_dir = os.path.dirname(output_path)
267
- if output_dir:
268
- os.makedirs(output_dir, exist_ok=True)
 
 
 
 
269
 
270
- is_gg_drive = '/drive/' in output_path
271
- temp_file = ''
272
 
273
- if is_gg_drive:
274
- # Writing directly into google drive can be inefficient
275
- temp_file = f'tmp_anime.{output_path.split(".")[-1]}'
 
276
 
277
- def transform_and_write(frames, count, writer):
278
- anime_images = self.transform(frames)
279
- for i in range(0, count):
280
- img = np.clip(anime_images[i], 0, 255)
281
- writer.write_frame(img)
282
 
283
- video_clip = VideoFileClip(input_path, audio=False)
284
  if start or end:
285
- video_clip = video_clip.subclip(start, end)
286
 
287
- video_writer = ffmpeg_writer.FFMPEG_VideoWriter(
288
- temp_file or output_path,
289
- video_clip.size, video_clip.fps,
290
- codec="libx264",
291
- # preset="medium", bitrate="2000k",
292
- ffmpeg_params=None)
 
 
 
 
 
 
 
 
293
 
294
- total_frames = round(video_clip.fps * video_clip.duration)
295
- print(f'Transfroming video {input_path}, {total_frames} frames, size: {video_clip.size}')
 
 
 
 
 
 
 
 
 
296
 
297
- batch_shape = (batch_size, video_clip.size[1], video_clip.size[0], 3)
298
  frame_count = 0
299
- frames = np.zeros(batch_shape, dtype=np.float32)
300
- for frame in tqdm(video_clip.iter_frames(), total=total_frames):
301
- try:
302
- frames[frame_count] = frame
303
- frame_count += 1
304
- if frame_count == batch_size:
305
- transform_and_write(frames, frame_count, video_writer)
306
- frame_count = 0
307
- except Exception as e:
308
- print(e)
309
- break
310
-
311
- # The last frames
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
312
  if frame_count != 0:
313
- transform_and_write(frames, frame_count, video_writer)
314
 
315
- if temp_file:
316
- # move to output path
317
- shutil.move(temp_file, output_path)
 
 
 
 
 
 
 
 
 
 
318
 
319
- print(f'Animation video saved to {output_path}')
320
  video_writer.close()
321
 
322
- def preprocess_images(self, images):
323
- '''
324
- Preprocess image for inference
325
 
326
- @Arguments:
327
- - images: np.ndarray
 
 
328
 
329
- @Returns
330
- - images: torch.tensor
331
- '''
332
- images = images.astype(np.float32)
333
 
334
- # Normalize to [-1, 1]
335
  images = normalize_input(images)
 
336
  images = torch.from_numpy(images)
337
 
338
  images = images.to(self.device)
339
 
340
- # Add batch dim
341
  if len(images.shape) == 3:
 
342
  images = images.unsqueeze(0)
343
 
344
- # channel first
345
- images = images.permute(0, 3, 1, 2)
 
 
 
 
346
 
347
  return images
348
 
349
 
 
 
 
 
350
  def parse_args():
 
351
  import argparse
 
352
  parser = argparse.ArgumentParser()
 
353
  parser.add_argument(
354
  '--weight',
355
  type=str,
356
- default="hayao:v2",
357
- help=f'Model weight, can be path or pretrained {tuple(RELEASED_WEIGHTS.keys())}'
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
358
  )
359
- parser.add_argument('--src', type=str, help='Source, can be directory contains images, image file or video file.')
360
- parser.add_argument('--device', type=str, default='cuda', help='Device, cuda or cpu')
361
- parser.add_argument('--imgsz', type=int, default=None, help='Resize image to specified size if provided')
362
- parser.add_argument('--out', type=str, default='inference_images', help='Output, can be directory or file')
 
 
 
363
  parser.add_argument(
364
- '--retain-color',
365
- action='store_true',
366
- help='If provided the generated image will retain original color of input image')
367
- # Video params
368
- parser.add_argument('--batch-size', type=int, default=4, help='Batch size when inference video')
369
- parser.add_argument('--start', type=int, default=0, help='Start time of video (second)')
370
- parser.add_argument('--end', type=int, default=0, help='End time of video (second), 0 if not set')
371
 
372
  return parser.parse_args()
373
 
 
 
 
 
 
374
  if __name__ == '__main__':
 
375
  args = parse_args()
376
 
377
  predictor = Predictor(
378
  args.weight,
379
  args.device,
380
- retain_color=args.retain_color,
381
  imgsz=args.imgsz,
382
  )
383
 
384
  if not os.path.exists(args.src):
385
- raise FileNotFoundError(args.src)
 
 
 
386
 
387
  if is_video_file(args.src):
 
388
  predictor.transform_video(
389
  args.src,
390
  args.out,
@@ -392,18 +700,47 @@ if __name__ == '__main__':
392
  start=args.start,
393
  end=args.end
394
  )
 
395
  elif os.path.isdir(args.src):
396
- predictor.transform_in_dir(args.src, args.out)
 
 
 
 
 
397
  elif os.path.isfile(args.src):
 
398
  save_path = args.out
 
399
  if not is_image_file(args.out):
400
- os.makedirs(args.out, exist_ok=True)
401
- save_path = os.path.join(args.out, os.path.basename(args.src))
 
 
 
 
 
 
 
 
402
 
403
  if args.src.endswith('.gif'):
404
- # GIF file
405
- predictor.transform_gif(args.src, save_path, args.batch_size)
 
 
 
 
 
406
  else:
407
- predictor.transform_file(args.src, save_path)
 
 
 
 
 
408
  else:
409
- raise NotImplementedError(f"{args.src} is not supported")
 
 
 
 
9
  from models.anime_gan import GeneratorV1
10
  from models.anime_gan_v2 import GeneratorV2
11
  from models.anime_gan_v3 import GeneratorV3
12
+
13
+ from utils.common import (
14
+ load_checkpoint,
15
+ RELEASED_WEIGHTS
16
+ )
17
+
18
+ from utils.image_processing import (
19
+ resize_image,
20
+ normalize_input,
21
+ denormalize_input
22
+ )
23
+
24
+ from utils import (
25
+ read_image,
26
+ is_image_file,
27
+ is_video_file
28
+ )
29
+
30
  from tqdm import tqdm
 
31
 
32
  try:
33
  import matplotlib.pyplot as plt
 
42
  VideoFileClip = None
43
 
44
 
45
+ # =========================================================
46
+ # PROFILE
47
+ # =========================================================
48
+
49
  def profile(func):
50
+
51
  def wrap(*args, **kwargs):
52
+
53
  started_at = time.time()
54
+
55
  result = func(*args, **kwargs)
56
+
57
  elapsed = time.time() - started_at
58
+
59
  print(f"Processed in {elapsed:.3f}s")
60
+
61
  return result
62
+
63
  return wrap
64
 
65
 
66
+ # =========================================================
67
+ # AUTO LOAD WEIGHT
68
+ # =========================================================
69
+
70
def auto_load_weight(
    weight,
    version=None,
    map_location=None
):
    """Build the generator matching *weight* and load the checkpoint into it.

    The generator version is resolved in this order:

    1. the explicit ``version`` argument;
    2. the first element of ``RELEASED_WEIGHTS[<lower-cased base name>]``;
    3. the file-name prefix: ``generatorv2*`` -> v2, ``generatorv3*`` -> v3,
       any other ``generator*`` -> v1.

    Args:
        weight: Checkpoint path or released-weight name. Only its lower-cased
            base name is inspected here; the value itself is handed to
            ``load_checkpoint``.
        version: Optional ``"v1"`` / ``"v2"`` / ``"v3"`` (case-insensitive).
        map_location: Forwarded unchanged to ``load_checkpoint``.

    Returns:
        The instantiated generator with the checkpoint loaded, in eval mode.

    Raises:
        ValueError: If ``version`` is not a known version, or if no version
            can be derived from the weight name.
    """
    known_versions = ("v1", "v2", "v3")
    weight_name = os.path.basename(weight).lower()

    if version is None and weight_name in RELEASED_WEIGHTS:
        version = RELEASED_WEIGHTS[weight_name][0]

    if version is None:
        # Order matters: the bare "generator" prefix also matches v2/v3 names.
        prefix_to_version = (
            ("generatorv2", "v2"),
            ("generatorv3", "v3"),
            ("generator", "v1"),
        )
        for prefix, detected in prefix_to_version:
            if weight_name.startswith(prefix):
                version = detected
                break
        else:
            raise ValueError(
                f"Cannot detect model version from {weight_name}"
            )

    version = version.lower()
    # Explicit raise instead of `assert`: asserts vanish under `python -O`,
    # which would turn a bad version into an opaque KeyError below.
    if version not in known_versions:
        raise ValueError(
            f"Version {version} does not exist, "
            f"expected one of {known_versions}"
        )

    generators = {
        "v1": GeneratorV1,
        "v2": GeneratorV2,
        "v3": GeneratorV3,
    }
    model = generators[version]()

    load_checkpoint(
        model,
        weight,
        strip_optimizer=True,
        map_location=map_location
    )

    model.eval()

    return model
144
 
145
 
146
+ # =========================================================
147
+ # PREDICTOR
148
+ # =========================================================
149
+
150
  class Predictor:
151
+
 
 
152
  def __init__(
153
  self,
154
+ weight='hayao:v2',
155
  device='cuda',
156
  amp=True,
157
  retain_color=False,
158
  imgsz=None,
159
  ):
160
+
161
  if not torch.cuda.is_available():
162
+
163
  device = 'cpu'
 
164
  amp = False
165
+
166
+ print("Using CPU")
167
+
168
  else:
169
+
170
+ print(
171
+ f"Using GPU: "
172
+ f"{torch.cuda.get_device_name()}"
173
+ )
174
+
175
  self.imgsz = imgsz
176
+
177
  self.retain_color = retain_color
178
+
179
+ self.amp = amp
180
+
181
+ self.device_type = (
182
+ 'cuda'
183
+ if device.startswith('cuda')
184
+ else 'cpu'
185
+ )
186
+
187
  self.device = torch.device(device)
188
+
189
+ self.G = auto_load_weight(
190
+ weight,
191
+ map_location=device
192
+ )
193
+
194
  self.G.to(self.device)
195
 
196
+ # =====================================================
197
+ # SHOW IMAGE
198
+ # =====================================================
199
+
200
  def transform_and_show(
201
  self,
202
  image_path,
203
  figsize=(18, 10),
204
  save_path=None
205
  ):
206
+
207
+ image = resize_image(
208
+ read_image(image_path)
209
+ )
210
+
211
  anime_img = self.transform(image)
212
+
213
  anime_img = anime_img.astype('uint8')
214
 
215
  fig = plt.figure(figsize=figsize)
216
+
217
  fig.add_subplot(1, 2, 1)
218
+
219
  plt.imshow(image)
220
+
221
  plt.axis('off')
222
+
223
  fig.add_subplot(1, 2, 2)
224
+
225
  plt.imshow(anime_img[0])
226
+
227
  plt.axis('off')
228
+
229
  plt.tight_layout()
230
+
231
  plt.show()
232
+
233
  if save_path is not None:
234
+
235
  plt.savefig(save_path)
236
 
237
+ # =====================================================
238
+ # MAIN TRANSFORM
239
+ # =====================================================
240
 
241
+ def transform(
242
+ self,
243
+ image,
244
+ denorm=True
245
+ ):
246
 
 
 
 
247
  with torch.no_grad():
248
+
249
  image = self.preprocess_images(image)
250
+
 
 
251
  fake = self.G(image)
252
+
 
 
 
253
  fake = fake.detach().cpu().numpy()
254
+
255
+ fake = fake.transpose(
256
+ 0,
257
+ 2,
258
+ 3,
259
+ 1
260
+ )
261
 
262
  if denorm:
263
+
264
+ fake = denormalize_input(
265
+ fake,
266
+ dtype=np.uint8
267
+ )
268
+
269
  return fake
270
 
271
+ # =====================================================
272
+ # READ RESIZE
273
+ # =====================================================
274
+
275
+ def read_and_resize(
276
+ self,
277
+ path,
278
+ max_size=1536
279
+ ):
280
+
281
  image = read_image(path)
282
+
283
  _, ext = os.path.splitext(path)
284
+
285
  h, w = image.shape[:2]
286
+
287
  if self.imgsz is not None:
288
+
289
+ image = resize_image(
290
+ image,
291
+ width=self.imgsz
292
+ )
293
+
294
  elif max(h, w) > max_size:
295
+
296
+ print(
297
+ f"Image too big "
298
+ f"({h}x{w})"
299
+ )
300
+
301
  image = resize_image(
302
  image,
303
  width=max_size if w > h else None,
304
  height=max_size if w < h else None,
305
  )
306
+
307
+ cv2.imwrite(
308
+ path.replace(ext, ".jpg"),
309
+ image[:, :, ::-1]
310
+ )
311
+
312
  else:
313
+
314
  image = resize_image(image)
315
+
 
 
316
  return image
317
 
318
+ # =====================================================
319
+ # TRANSFORM FILE
320
+ # =====================================================
321
+
322
  @profile
323
+ def transform_file(
324
+ self,
325
+ file_path,
326
+ save_path
327
+ ):
328
+
329
  if not is_image_file(save_path):
 
330
 
331
+ raise ValueError(
332
+ f"{save_path} is not valid"
333
+ )
334
+
335
+ image = self.read_and_resize(
336
+ file_path
337
+ )
338
+
339
  anime_img = self.transform(image)[0]
340
+
341
+ cv2.imwrite(
342
+ save_path,
343
+ anime_img[..., ::-1]
344
+ )
345
+
346
+ print(
347
+ f"Anime image saved to {save_path}"
348
+ )
349
+
350
  return anime_img
351
 
352
+ # =====================================================
353
+ # GIF
354
+ # =====================================================
355
+
356
  @profile
357
+ def transform_gif(
358
+ self,
359
+ file_path,
360
+ save_path,
361
+ batch_size=4
362
+ ):
363
+
364
  import imageio
365
 
366
  def _preprocess_gif(img):
367
+
368
  if img.shape[-1] == 4:
369
+
370
+ img = cv2.cvtColor(
371
+ img,
372
+ cv2.COLOR_RGBA2RGB
373
+ )
374
+
375
  return resize_image(img)
376
 
377
  images = imageio.mimread(file_path)
378
+
379
  images = np.stack([
380
  _preprocess_gif(img)
381
  for img in images
382
  ])
383
 
 
 
384
  anime_gif = np.zeros_like(images)
385
 
386
+ for i in tqdm(
387
+ range(0, len(images), batch_size)
388
+ ):
389
+
390
  end = i + batch_size
 
 
 
391
 
392
+ anime_gif[i:end] = self.transform(
393
+ images[i:end]
394
+ )
 
395
 
 
396
  imageio.mimsave(
397
  save_path,
398
  anime_gif,
 
399
  )
400
+
401
+ print(
402
+ f"Anime GIF saved to {save_path}"
403
+ )
404
+
405
+ # =====================================================
406
+ # DIRECTORY
407
+ # =====================================================
408
 
409
  @profile
410
+ def transform_in_dir(
411
+ self,
412
+ img_dir,
413
+ dest_dir,
414
+ max_images=0
415
+ ):
416
 
417
+ os.makedirs(
418
+ dest_dir,
419
+ exist_ok=True
420
+ )
421
 
422
  files = os.listdir(img_dir)
423
+
424
+ files = [
425
+ f for f in files
426
+ if is_image_file(f)
427
+ ]
428
 
429
  if max_images:
430
+
431
  files = files[:max_images]
432
 
433
  bar = tqdm(files)
434
+
435
  for fname in bar:
436
+
437
+ path = os.path.join(
438
+ img_dir,
439
+ fname
440
+ )
441
+
442
  image = self.read_and_resize(path)
443
+
444
  anime_img = self.transform(image)[0]
445
+
446
  ext = fname.split('.')[-1]
447
+
448
+ fname = fname.replace(
449
+ f'.{ext}',
450
+ ''
451
+ )
452
+
453
+ cv2.imwrite(
454
+ os.path.join(
455
+ dest_dir,
456
+ f'{fname}.jpg'
457
+ ),
458
+ anime_img[..., ::-1]
459
+ )
460
+
461
+ # =====================================================
462
+ # VIDEO
463
+ # =====================================================
464
+
465
+ def transform_video(
466
+ self,
467
+ input_path,
468
+ output_path,
469
+ batch_size=4,
470
+ start=0,
471
+ end=0
472
+ ):
473
+
474
  if VideoFileClip is None:
475
+
476
+ raise ImportError(
477
+ "moviepy not installed"
478
+ )
479
+
480
  end = end or None
481
 
482
  if not os.path.isfile(input_path):
 
483
 
484
+ raise FileNotFoundError(
485
+ input_path
486
+ )
487
+
488
+ output_dir = os.path.dirname(
489
+ output_path
490
+ )
491
 
492
+ if output_dir:
 
493
 
494
+ os.makedirs(
495
+ output_dir,
496
+ exist_ok=True
497
+ )
498
 
499
+ video_clip = VideoFileClip(
500
+ input_path,
501
+ audio=False
502
+ )
 
503
 
 
504
  if start or end:
 
505
 
506
+ video_clip = video_clip.subclip(
507
+ start,
508
+ end
509
+ )
510
+
511
+ video_writer = (
512
+ ffmpeg_writer
513
+ .FFMPEG_VideoWriter(
514
+ output_path,
515
+ video_clip.size,
516
+ video_clip.fps,
517
+ codec="libx264",
518
+ )
519
+ )
520
 
521
+ total_frames = round(
522
+ video_clip.fps *
523
+ video_clip.duration
524
+ )
525
+
526
+ batch_shape = (
527
+ batch_size,
528
+ video_clip.size[1],
529
+ video_clip.size[0],
530
+ 3
531
+ )
532
 
 
533
  frame_count = 0
534
+
535
+ frames = np.zeros(
536
+ batch_shape,
537
+ dtype=np.float32
538
+ )
539
+
540
+ for frame in tqdm(
541
+ video_clip.iter_frames(),
542
+ total=total_frames
543
+ ):
544
+
545
+ frames[frame_count] = frame
546
+
547
+ frame_count += 1
548
+
549
+ if frame_count == batch_size:
550
+
551
+ anime_images = self.transform(
552
+ frames
553
+ )
554
+
555
+ for i in range(frame_count):
556
+
557
+ video_writer.write_frame(
558
+ anime_images[i]
559
+ )
560
+
561
+ frame_count = 0
562
+
563
  if frame_count != 0:
 
564
 
565
+ anime_images = self.transform(
566
+ frames[:frame_count]
567
+ )
568
+
569
+ for i in range(frame_count):
570
+
571
+ video_writer.write_frame(
572
+ anime_images[i]
573
+ )
574
+
575
+ print(
576
+ f"Anime video saved to {output_path}"
577
+ )
578
 
 
579
  video_writer.close()
580
 
581
+ # =====================================================
582
+ # PREPROCESS
583
+ # =====================================================
584
 
585
+ def preprocess_images(
586
+ self,
587
+ images
588
+ ):
589
 
590
+ images = images.astype(
591
+ np.float32
592
+ )
 
593
 
 
594
  images = normalize_input(images)
595
+
596
  images = torch.from_numpy(images)
597
 
598
  images = images.to(self.device)
599
 
 
600
  if len(images.shape) == 3:
601
+
602
  images = images.unsqueeze(0)
603
 
604
+ images = images.permute(
605
+ 0,
606
+ 3,
607
+ 1,
608
+ 2
609
+ )
610
 
611
  return images
612
 
613
 
614
+ # =========================================================
615
+ # ARGUMENTS
616
+ # =========================================================
617
+
618
  def parse_args():
619
+
620
  import argparse
621
+
622
  parser = argparse.ArgumentParser()
623
+
624
  parser.add_argument(
625
  '--weight',
626
  type=str,
627
+ default="hayao:v2"
628
+ )
629
+
630
+ parser.add_argument(
631
+ '--src',
632
+ type=str
633
+ )
634
+
635
+ parser.add_argument(
636
+ '--device',
637
+ type=str,
638
+ default='cuda'
639
+ )
640
+
641
+ parser.add_argument(
642
+ '--imgsz',
643
+ type=int,
644
+ default=None
645
+ )
646
+
647
+ parser.add_argument(
648
+ '--out',
649
+ type=str,
650
+ default='inference_images'
651
+ )
652
+
653
+ parser.add_argument(
654
+ '--batch-size',
655
+ type=int,
656
+ default=4
657
  )
658
+
659
+ parser.add_argument(
660
+ '--start',
661
+ type=int,
662
+ default=0
663
+ )
664
+
665
  parser.add_argument(
666
+ '--end',
667
+ type=int,
668
+ default=0
669
+ )
 
 
 
670
 
671
  return parser.parse_args()
672
 
673
+
674
+ # =========================================================
675
+ # MAIN
676
+ # =========================================================
677
+
678
  if __name__ == '__main__':
679
+
680
  args = parse_args()
681
 
682
  predictor = Predictor(
683
  args.weight,
684
  args.device,
 
685
  imgsz=args.imgsz,
686
  )
687
 
688
  if not os.path.exists(args.src):
689
+
690
+ raise FileNotFoundError(
691
+ args.src
692
+ )
693
 
694
  if is_video_file(args.src):
695
+
696
  predictor.transform_video(
697
  args.src,
698
  args.out,
 
700
  start=args.start,
701
  end=args.end
702
  )
703
+
704
  elif os.path.isdir(args.src):
705
+
706
+ predictor.transform_in_dir(
707
+ args.src,
708
+ args.out
709
+ )
710
+
711
  elif os.path.isfile(args.src):
712
+
713
  save_path = args.out
714
+
715
  if not is_image_file(args.out):
716
+
717
+ os.makedirs(
718
+ args.out,
719
+ exist_ok=True
720
+ )
721
+
722
+ save_path = os.path.join(
723
+ args.out,
724
+ os.path.basename(args.src)
725
+ )
726
 
727
  if args.src.endswith('.gif'):
728
+
729
+ predictor.transform_gif(
730
+ args.src,
731
+ save_path,
732
+ args.batch_size
733
+ )
734
+
735
  else:
736
+
737
+ predictor.transform_file(
738
+ args.src,
739
+ save_path
740
+ )
741
+
742
  else:
743
+
744
+ raise NotImplementedError(
745
+ f"{args.src} not supported"
746
+ )