| import torch |
| import src.models.mar.misc as misc |
| import torch_fidelity |
| import shutil |
| import cv2 |
| import numpy as np |
| import os |
| import time |
|
|
|
|
| def torch_evaluate(model, args): |
| model.eval() |
| num_steps = args.num_images // (args.batch_size * misc.get_world_size()) + 1 |
| save_folder = os.path.join(args.output_dir, "ariter{}-temp{}-{}cfg{}-image{}".format( |
| args.num_iter, args.temperature, args.cfg_schedule, args.cfg, args.num_images)) |
|
|
| print("Save to:", save_folder) |
| if misc.get_rank() == 0: |
| if not os.path.exists(save_folder): |
| os.makedirs(save_folder) |
|
|
| class_num = args.class_num |
| assert args.num_images % class_num == 0 |
| class_label_gen_world = np.arange(0, class_num).repeat(args.num_images // class_num) |
| class_label_gen_world = np.hstack([class_label_gen_world, np.zeros(50000)]) |
| world_size = misc.get_world_size() |
| local_rank = misc.get_rank() |
| used_time = 0 |
| gen_img_cnt = 0 |
|
|
| for i in range(num_steps): |
| print("Generation step {}/{}".format(i, num_steps)) |
|
|
| labels_gen = class_label_gen_world[world_size * args.batch_size * i + local_rank * args.batch_size: |
| world_size * args.batch_size * i + (local_rank + 1) * args.batch_size] |
| labels_gen = torch.Tensor(labels_gen).long().cuda() |
|
|
| torch.cuda.synchronize() |
| start_time = time.time() |
|
|
| |
| with torch.no_grad(): |
| with torch.cuda.amp.autocast(): |
| |
| |
| |
|
|
| import pdb; pdb.set_trace() |
| if args.cfg != 1.0: |
| labels_gen = torch.cat([ |
| labels_gen, torch.full_like(labels_gen, fill_value=-1)]) |
| sampled_images = model.sample(labels_gen, |
| num_iter=args.num_iter, cfg=args.cfg, cfg_schedule=args.cfg_schedule, |
| temperature=args.temperature, progress=False) |
|
|
| |
| if i >= 1: |
| torch.cuda.synchronize() |
| used_time += time.time() - start_time |
| gen_img_cnt += args.batch_size |
| print("Generating {} images takes {:.5f} seconds, {:.5f} sec per image".format(gen_img_cnt, used_time, used_time / gen_img_cnt)) |
|
|
| torch.distributed.barrier() |
| sampled_images = sampled_images.detach().cpu() |
| sampled_images = (sampled_images + 1) / 2 |
|
|
| |
| for b_id in range(sampled_images.size(0)): |
| img_id = i * sampled_images.size(0) * world_size + local_rank * sampled_images.size(0) + b_id |
| if img_id >= args.num_images: |
| break |
| gen_img = np.round(np.clip(sampled_images[b_id].numpy().transpose([1, 2, 0]) * 255, 0, 255)) |
| gen_img = gen_img.astype(np.uint8)[:, :, ::-1] |
| cv2.imwrite(os.path.join(save_folder, '{}.png'.format(str(img_id).zfill(5))), gen_img) |
|
|
| torch.distributed.barrier() |
| time.sleep(10) |
| if misc.get_rank() == 0: |
| input2 = None |
| fid_statistics_file = 'fid_stats/adm_in256_stats.npz' |
| metrics_dict = torch_fidelity.calculate_metrics( |
| input1=save_folder, |
| input2=input2, |
| fid_statistics_file=fid_statistics_file, |
| cuda=True, |
| isc=True, |
| fid=True, |
| kid=False, |
| prc=False, |
| verbose=True, |
| ) |
| fid = metrics_dict['frechet_inception_distance'] |
| inception_score = metrics_dict['inception_score_mean'] |
| print("FID: {:.4f}, Inception Score: {:.4f}".format(fid, inception_score)) |
| |
| shutil.rmtree(save_folder) |
|
|
| torch.distributed.barrier() |
| time.sleep(10) |
|
|