Spaces:
Runtime error
Runtime error
| import argparse | |
| import torch | |
| import os | |
| from Project.configs import data_configs | |
| from Project.datasets.inference_dataset import InferenceDataset | |
| from torch.utils.data import DataLoader | |
| from Project.utils.model_utils import setup_model | |
def main(args, device):
    """Compute the latent codes for the input images and save them to disk.

    Loads the model with ``setup_model``, builds the inference data loader
    with ``setup_data_loader``, computes the latent codes with
    ``get_all_latents`` and writes them to ``<args.save_dir>/latents.pt``
    using ``torch.save``.

    Args:
        args: Parsed command-line namespace. ``ckpt``, ``save_dir`` and
            ``n_sample`` are read here; the remaining attributes are
            consumed by ``setup_data_loader``.
        device: Device identifier forwarded to ``setup_model`` and
            ``get_all_latents``.
    """
    net, opts, latent_avg = setup_model(args.ckpt, device)
    is_cars = 'cars_' in opts.dataset_type
    args, data_loader = setup_data_loader(args, opts)

    # The latents are always recomputed; an existing latents.pt is overwritten.
    # Create the output directory up front so torch.save does not fail when
    # it is missing. An empty save_dir means "current directory".
    if args.save_dir:
        os.makedirs(args.save_dir, exist_ok=True)
    latents_file_path = os.path.join(args.save_dir, 'latents.pt')

    latent_codes = get_all_latents(net, device, data_loader, latent_avg,
                                   args.n_sample, is_cars=is_cars)
    torch.save(latent_codes, latents_file_path)
def setup_data_loader(args, opts):
    """Build the inference dataset and its ``DataLoader``.

    The dataset configuration is looked up in ``data_configs.DATASETS`` by
    ``opts.dataset_type``. Images are read from ``args.images_dir`` when it
    is set, otherwise from the configuration's ``'test_source_root'``.
    No alignment preprocessing is applied (``preprocess`` is ``None``).

    Args:
        args: Namespace providing ``images_dir``, ``batch`` and ``n_sample``.
            ``n_sample`` is filled in with the dataset length when ``None``.
        opts: Model options; ``dataset_type`` is read here and the object is
            forwarded to the transforms factory and the dataset.

    Returns:
        The (possibly updated) ``args`` and the data loader, as a tuple.
    """
    dataset_cfg = data_configs.DATASETS[opts.dataset_type]
    transform_factory = dataset_cfg['transforms'](opts)
    available_transforms = transform_factory.get_transforms()

    if args.images_dir is not None:
        source_root = args.images_dir
    else:
        source_root = dataset_cfg['test_source_root']
    print(f"images path: {source_root}")

    test_dataset = InferenceDataset(
        root=source_root,
        transform=available_transforms['transform_test'],
        preprocess=None,
        opts=opts,
    )
    loader = DataLoader(
        test_dataset,
        batch_size=args.batch,
        shuffle=False,
        num_workers=0,
        drop_last=True,
    )
    print(f'dataset length: {len(test_dataset)}')

    # Default to the whole dataset when no sample count was requested.
    if args.n_sample is None:
        args.n_sample = len(test_dataset)
    return args, loader
def get_latents(net, x, latent_avg, is_cars=False):
    """Run ``net`` on one batch and add ``latent_avg`` to the resulting codes.

    ``net`` is driven through ``get_inputs()`` and ``run(None, feed)``
    (presumably an ONNX Runtime inference session -- confirm against
    ``setup_model``). Only the first output of ``run`` is used.

    Args:
        net: Model object exposing ``get_inputs()`` and ``run()``.
        x: Batch tensor; converted with ``to_numpy`` before being fed to
            ``net``.
        latent_avg: Tensor that is repeated ``codes.shape[0]`` times with
            ``repeat(n, 1, 1)`` and added to the codes.
        is_cars: Accepted for interface compatibility; not used here.

    Returns:
        A ``torch.Tensor`` holding the codes offset by ``latent_avg``.
    """
    first_input_name = net.get_inputs()[0].name
    outputs = net.run(None, {first_input_name: to_numpy(x)})
    codes = torch.from_numpy(outputs[0])

    offset = latent_avg.repeat(codes.shape[0], 1, 1)
    if codes.ndim == 2:
        # Two-dimensional codes: use only index 0 along dim 1 of the offset.
        offset = offset[:, 0, :]
    return codes + offset
def get_all_latents(net, device, data_loader, latent_avg, n_images=None, is_cars=False):
    """Compute the latent codes for the batches yielded by ``data_loader``.

    Args:
        net: Model forwarded to ``get_latents``.
        device: Not used by this function; kept so existing callers keep
            working.
        data_loader: Iterable yielding batches that support ``.float()``.
        latent_avg: Forwarded to ``get_latents``.
        n_images: Maximum number of latent codes to return. ``None`` means
            every batch of ``data_loader`` is processed.
        is_cars: Forwarded to ``get_latents``.

    Returns:
        The per-batch latents concatenated along dimension 0 and, when
        ``n_images`` is given, truncated to at most ``n_images`` entries.

    Raises:
        RuntimeError: If no latent codes were produced (the loader yielded
            no batches or ``n_images`` is not positive).
    """
    all_latents = []
    n_collected = 0
    with torch.no_grad():
        for batch in data_loader:
            # Honour the requested sample count instead of always running
            # the whole loader.
            if n_images is not None and n_collected >= n_images:
                break
            inputs = batch.float()
            print(inputs.shape)
            latents = get_latents(net, inputs, latent_avg, is_cars)
            all_latents.append(latents)
            n_collected += len(latents)

    if not all_latents:
        # torch.cat on an empty list raises an opaque RuntimeError; keep the
        # exception type but explain what went wrong.
        raise RuntimeError(
            "no latent codes were produced: the data loader yielded no "
            "batches or n_images is not positive"
        )

    stacked = torch.cat(all_latents)
    if n_images is None:
        return stacked
    # The last batch may overshoot the requested count; trim the excess.
    return stacked[:n_images]
| #@torch.no_grad() | |
| #def generate_inversions(args, g, latent_codes, is_cars): | |
| # print('Saving inversion images') | |
| # inversions_directory_path = os.path.join(args.save_dir, 'inversions') | |
| # os.makedirs(inversions_directory_path, exist_ok=True) | |
| # for i in range(min(args.n_sample, len(latent_codes))): | |
| # imgs, _ = g([latent_codes[i].unsqueeze(0)], input_is_latent=True, randomize_noise=False, return_latents=True) | |
| # if is_cars: | |
| # imgs = imgs[:, :, 64:448, :] | |
| # save_image(imgs[0], inversions_directory_path, i + 1) | |
| #def run_alignment(image_path): | |
| # predictor = dlib.shape_predictor(paths_config.model_paths['shape_predictor']) | |
| # aligned_image = align_face(filepath=image_path, predictor=predictor) | |
| # print("Aligned image has shape: {}".format(aligned_image.size)) | |
| # return aligned_image | |
def to_numpy(tensor):
    """Return ``tensor`` as a NumPy array located on the CPU.

    The tensor is detached from the autograd graph first, because
    ``Tensor.numpy()`` raises for tensors that require grad.

    Args:
        tensor: A ``torch.Tensor`` on any device.

    Returns:
        A ``numpy.ndarray`` with the tensor's values.
    """
    return tensor.detach().cpu().numpy()
def inference():
    """Parse the command-line arguments and run ``main`` on the CPU."""
    device = "cpu"

    parser = argparse.ArgumentParser(description="Inference")
    parser.add_argument("--images_dir", type=str, default='static/img_aligned',
                        help="The directory of the images to be inverted")
    # The help text previously claimed the default was images_dir and that
    # inversion images are saved; only the latent codes are written.
    parser.add_argument("--save_dir", type=str, default='static/latents',
                        help="The directory to save the latent codes. (default: static/latents)")
    parser.add_argument("--batch", type=int, default=1, help="batch size for the generator")
    parser.add_argument("--n_sample", type=int, default=None, help="number of the samples to infer.")
    # NOTE(review): store_true combined with default=True means this flag is
    # always True; kept unchanged so existing command lines still parse.
    parser.add_argument("--latents_only", action="store_true", default=True,
                        help="infer only the latent codes of the directory")
    # NOTE(review): args.align is not read by the visible code -- confirm
    # whether alignment is still meant to be supported.
    parser.add_argument("--align", action="store_true", default=False,
                        help="align face images before inference")
    parser.add_argument("--ckpt", default='Project/pretrained_models/e4e_ffhq_encode.pt',
                        help="path to generator checkpoint")

    args = parser.parse_args()
    main(args, device)
# Script entry point: run the CLI only when executed directly, not on import.
if __name__ == "__main__":
    inference()