# NOTE: page-extraction artifacts from the original source, kept as comments
# so the file stays parseable. Original header lines read:
#   "Spaces:" / "Runtime error" / "Runtime error"
# Standard library
import os
import shutil
import time

# Third-party
import cv2
import numpy as np
import torch
import torch_fidelity

# Project-local
import src.models.mar.misc as misc
def torch_evaluate(model, args):
    """Sample class-conditional images with ``model`` and report FID / Inception Score.

    Each rank of the distributed job generates its shard of ``args.num_images``
    images, writes them as PNGs to a temporary folder, then rank 0 computes FID
    and IS with ``torch_fidelity`` against precomputed ImageNet-256 statistics
    and deletes the folder.

    Args:
        model: generative model exposing ``.sample(labels, num_iter, cfg,
            cfg_schedule, temperature, progress)``; switched to eval mode here.
        args: namespace providing ``num_images``, ``batch_size``, ``num_iter``,
            ``temperature``, ``cfg``, ``cfg_schedule``, ``class_num``,
            ``output_dir``.

    Side effects: creates and removes ``save_folder`` on disk; requires CUDA
    and an initialized ``torch.distributed`` process group.
    """
    model.eval()
    # One extra step so every rank covers its full shard even when num_images
    # is not a multiple of (batch_size * world_size); overflow ids are skipped
    # at save time below.
    num_steps = args.num_images // (args.batch_size * misc.get_world_size()) + 1
    save_folder = os.path.join(args.output_dir, "ariter{}-temp{}-{}cfg{}-image{}".format(
        args.num_iter, args.temperature, args.cfg_schedule, args.cfg, args.num_images))
    print("Save to:", save_folder)
    if misc.get_rank() == 0:
        # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
        os.makedirs(save_folder, exist_ok=True)

    class_num = args.class_num
    assert args.num_images % class_num == 0  # number of images per class must be the same
    class_label_gen_world = np.arange(0, class_num).repeat(args.num_images // class_num)
    # Pad with dummy labels so the final (partial) step can still slice a full batch.
    class_label_gen_world = np.hstack([class_label_gen_world, np.zeros(50000)])
    world_size = misc.get_world_size()
    local_rank = misc.get_rank()

    used_time = 0
    gen_img_cnt = 0
    for i in range(num_steps):
        print("Generation step {}/{}".format(i, num_steps))

        # Contiguous per-rank slice of the global label list for this step.
        labels_gen = class_label_gen_world[world_size * args.batch_size * i + local_rank * args.batch_size:
                                           world_size * args.batch_size * i + (local_rank + 1) * args.batch_size]
        labels_gen = torch.Tensor(labels_gen).long().cuda()

        torch.cuda.synchronize()
        start_time = time.time()

        # generation
        with torch.no_grad():
            with torch.cuda.amp.autocast():
                if args.cfg != 1.0:
                    # Classifier-free guidance: append unconditional (-1) labels so
                    # the model sees a conditional and an unconditional half.
                    labels_gen = torch.cat([
                        labels_gen, torch.full_like(labels_gen, fill_value=-1)])
                sampled_images = model.sample(labels_gen,
                                              num_iter=args.num_iter, cfg=args.cfg, cfg_schedule=args.cfg_schedule,
                                              temperature=args.temperature, progress=False)

        # measure speed after the first generation batch (step 0 is warm-up)
        if i >= 1:
            torch.cuda.synchronize()
            used_time += time.time() - start_time
            gen_img_cnt += args.batch_size
            print("Generating {} images takes {:.5f} seconds, {:.5f} sec per image".format(gen_img_cnt, used_time, used_time / gen_img_cnt))

        torch.distributed.barrier()
        sampled_images = sampled_images.detach().cpu()
        sampled_images = (sampled_images + 1) / 2  # map [-1, 1] -> [0, 1]

        # distributed save: interleave ranks so every image gets a globally unique id
        for b_id in range(sampled_images.size(0)):
            img_id = i * sampled_images.size(0) * world_size + local_rank * sampled_images.size(0) + b_id
            if img_id >= args.num_images:
                break  # padding labels past num_images — do not save
            gen_img = np.round(np.clip(sampled_images[b_id].numpy().transpose([1, 2, 0]) * 255, 0, 255))
            gen_img = gen_img.astype(np.uint8)[:, :, ::-1]  # RGB -> BGR for cv2.imwrite
            cv2.imwrite(os.path.join(save_folder, '{}.png'.format(str(img_id).zfill(5))), gen_img)

    torch.distributed.barrier()
    time.sleep(10)  # give slow filesystems time to flush all ranks' writes

    # compute FID and IS on rank 0 only
    if misc.get_rank() == 0:
        input2 = None
        fid_statistics_file = 'fid_stats/adm_in256_stats.npz'
        metrics_dict = torch_fidelity.calculate_metrics(
            input1=save_folder,
            input2=input2,
            fid_statistics_file=fid_statistics_file,
            cuda=True,
            isc=True,
            fid=True,
            kid=False,
            prc=False,
            verbose=True,
        )
        fid = metrics_dict['frechet_inception_distance']
        inception_score = metrics_dict['inception_score_mean']
        print("FID: {:.4f}, Inception Score: {:.4f}".format(fid, inception_score))
        # remove temporal saving folder
        shutil.rmtree(save_folder)

    torch.distributed.barrier()
    time.sleep(10)