| import torch |
| torch.backends.cuda.matmul.allow_tf32 = True |
| torch.backends.cudnn.allow_tf32 = True |
| import torch.distributed as dist |
| from torch.utils.data import DataLoader |
| from torch.utils.data.distributed import DistributedSampler |
| from torchvision.datasets import ImageFolder |
| import argparse |
| import os, yaml |
| from safetensors.torch import save_file |
| from datetime import datetime |
| from datasets.img_latent_dataset import ImgLatentDataset |
| from tokenizer import models_mae |
| from tokenizer.sdvae import Diffusers_AutoencoderKL |
|
|
def load_config(config_path):
    """Load a YAML configuration file.

    Args:
        config_path: Path of the YAML file to read.

    Returns:
        Whatever ``yaml.safe_load`` produces for the file's contents.
    """
    # Decode explicitly as UTF-8 so parsing does not depend on the process
    # locale's default encoding.
    with open(config_path, "r", encoding="utf-8") as file:
        return yaml.safe_load(file)
|
|
def _build_tokenizer(model_name, image_size):
    """Instantiate the tokenizer selected by ``model_name`` (weights not loaded).

    Args:
        model_name: Tokenizer family; ``'vmae'`` or one of
            ``'ae'``, ``'dae'``, ``'vae'``, ``'sdv3'``.
        image_size: Value forwarded to the tokenizer as ``img_size``.

    Returns:
        The freshly constructed tokenizer module.

    Raises:
        ValueError: If ``model_name`` is not one of the supported names.
    """
    if model_name == 'vmae':
        arch = 'mae_for_ldmae_f8d16_prev'
        return getattr(models_mae, arch)(
            ldmae_mode=True,
            no_cls=True,
            kl_loss_weight=True,
            smooth_output=True,
            img_size=image_size,
        )
    if model_name in ('ae', 'dae', 'vae', 'sdv3'):
        return Diffusers_AutoencoderKL(
            img_size=image_size,
            sample_size=128,
            in_channels=3,
            out_channels=3,
            layers_per_block=2,
            latent_channels=16,
            norm_num_groups=32,
            act_fn="silu",
            block_out_channels=(128, 256, 512, 512),
            force_upcast=False,
            use_quant_conv=False,
            use_post_quant_conv=False,
            down_block_types=(
                "DownEncoderBlock2D",
                "DownEncoderBlock2D",
                "DownEncoderBlock2D",
                "DownEncoderBlock2D",
            ),
            up_block_types=(
                "UpDecoderBlock2D",
                "UpDecoderBlock2D",
                "UpDecoderBlock2D",
                "UpDecoderBlock2D",
            ),
        )
    # The original `raise("")` raised a TypeError because a str is not an exception.
    raise ValueError(
        f"Unsupported tokenizer '{model_name}'; expected 'vmae', 'ae', 'dae', 'vae' or 'sdv3'."
    )


def _save_latent_shard(latents, latents_flip, labels, output_dir, rank, shard_idx):
    """Concatenate the buffered batches and write them as one safetensors shard.

    Args:
        latents: List of latent batches from the first (p_hflip=0.0) loader.
        latents_flip: List of latent batches from the second (p_hflip=1.0) loader.
        labels: List of label batches matching ``latents``.
        output_dir: Directory the shard file is written to.
        rank: Process rank; used in the file name and to restrict logging to rank 0.
        shard_idx: Running shard counter of this rank; used in the file name.
    """
    latents = torch.cat(latents, dim=0)
    latents_flip = torch.cat(latents_flip, dim=0)
    labels = torch.cat(labels, dim=0)
    save_dict = {
        'latents': latents,
        'latents_flip': latents_flip,
        'labels': labels
    }
    if rank == 0:
        for key, tensor in save_dict.items():
            print(key, tensor.shape)

    save_dict = {key: tensor.contiguous().cpu() for key, tensor in save_dict.items()}
    save_filename = os.path.join(output_dir, f'latents_rank{rank:02d}_shard{shard_idx:03d}.safetensors')
    # Metadata describes the concatenated latents before the CPU copy above.
    save_file(
        save_dict,
        save_filename,
        metadata={'total_size': f'{latents.shape[0]}', 'dtype': f'{latents.dtype}', 'device': f'{latents.device}'}
    )
    if rank == 0:
        print(f'Saved {save_filename}')


def main(args, train_config):
    """
    Run a tokenizer on full dataset and save the features.

    Every image under ``<origin_path>/<data_split>`` is encoded twice, once
    through a transform built with ``p_hflip=0.0`` and once with
    ``p_hflip=1.0``. Latents and labels are written to per-rank safetensors
    shards in a directory next to ``origin_path``.

    Args:
        args: Parsed command-line namespace. Reads ``seed``, ``data_split``,
            ``image_size``, ``batch_size`` and ``num_workers``.
        train_config: Configuration mapping. Reads ``['vae']['model_name']``,
            ``['vae']['weight_path']``, ``['data']['origin_path']``,
            ``['data']['name']`` and the optional ``['data']['sample']``.

    Raises:
        AssertionError: If CUDA is not available.
        ValueError: If the configured tokenizer name is not supported.
    """
    assert torch.cuda.is_available(), "Extract features currently requires at least one GPU."

    # Try to join a distributed job; fall back to single-process mode when that
    # fails (e.g. the script was not started by a distributed launcher).
    try:
        dist.init_process_group("nccl")
    except Exception as exc:
        # Best-effort fallback, but no longer a bare `except:` that would also
        # swallow KeyboardInterrupt/SystemExit, and the reason is reported.
        print(f"Failed to initialize DDP ({exc}). Running in local mode.")
        distributed = False
        rank = 0
        world_size = 1
    else:
        distributed = True
        rank = dist.get_rank()
        world_size = dist.get_world_size()
    device = rank % torch.cuda.device_count()
    seed = args.seed + rank
    if distributed and rank == 0:
        print(f"Starting rank={rank}, seed={seed}, world_size={world_size}.")
    torch.manual_seed(seed)
    torch.cuda.set_device(device)

    model_name = train_config['vae']['model_name'].split("_")[0]
    output_path = os.path.dirname(train_config['data']['origin_path'])
    dataset_name = train_config['data']['name']
    use_sample = 'sample' in train_config['data']

    output_dir = os.path.join(output_path, f'{model_name}_feature_{dataset_name}_{args.data_split}_{args.image_size}')
    if use_sample:
        output_dir += '_sample'
    # Every rank creates the directory (idempotent) so no rank can try to
    # write a shard before rank 0 has created it.
    os.makedirs(output_dir, exist_ok=True)
    if rank == 0:
        print(model_name)

    tokenizer = _build_tokenizer(model_name, args.image_size).to(device).eval()
    # NOTE(security): torch.load unpickles the checkpoint file; only point
    # `weight_path` at trusted checkpoints.
    checkpoint = torch.load(train_config['vae']['weight_path'], map_location='cpu')
    msg = tokenizer.load_state_dict(checkpoint['model'], strict=False)
    if rank == 0:
        print(model_name, msg)

    print(f"{device} GPU - Model loaded")

    # Two views of the same folder that differ only in the p_hflip value passed
    # to the tokenizer's transform. Both loaders are unshuffled, so zipping
    # them yields the same samples in the same order.
    data_root = os.path.join(train_config['data']['origin_path'], args.data_split)
    datasets = [
        ImageFolder(data_root, transform=tokenizer.img_transform(p_hflip=0.0, img_size=args.image_size)),
        ImageFolder(data_root, transform=tokenizer.img_transform(p_hflip=1.0, img_size=args.image_size))
    ]
    samplers = [
        DistributedSampler(
            dataset,
            num_replicas=world_size,
            rank=rank,
            shuffle=False,
            seed=args.seed
        ) for dataset in datasets
    ]
    loaders = [
        DataLoader(
            dataset,
            batch_size=args.batch_size,
            shuffle=False,
            sampler=sampler,
            num_workers=args.num_workers,
            pin_memory=True,
            drop_last=False
        ) for dataset, sampler in zip(datasets, samplers)
    ]
    total_data_in_loop = len(loaders[0].dataset)
    if rank == 0:
        print(f"Total data in one loop: {total_data_in_loop}")

    # Flush a shard after about 10000 images. Clamp to at least one batch so a
    # batch size above 10000 still flushes instead of buffering everything.
    batches_per_shard = max(1, 10000 // args.batch_size)

    run_images = 0
    saved_files = 0
    latents = []
    latents_flip = []
    labels = []
    for batch_idx, batch_data in enumerate(zip(*loaders)):
        run_images += batch_data[0][0].shape[0]
        if run_images % 100 == 0 and rank == 0:
            print(f'{datetime.now()} processing {run_images} of {total_data_in_loop} images')

        for loader_idx, data in enumerate(batch_data):
            x = data[0].to(device)
            y = data[1]
            with torch.no_grad():
                if use_sample:
                    # Result stays on its current device until the shard is saved.
                    z = tokenizer._encode(x)
                else:
                    z = tokenizer.encode(x).latent_dist.mode().detach().cpu()

            if batch_idx == 0 and rank == 0:
                print('latent shape', z.shape, 'dtype', z.dtype)

            if loader_idx == 0:
                latents.append(z)
                labels.append(y)
            else:
                latents_flip.append(z)

        if len(latents) >= batches_per_shard:
            _save_latent_shard(latents, latents_flip, labels, output_dir, rank, saved_files)
            latents = []
            latents_flip = []
            labels = []
            saved_files += 1

    # Write whatever is left over as a final, possibly smaller, shard.
    if latents:
        _save_latent_shard(latents, latents_flip, labels, output_dir, rank, saved_files)

    # Collective calls are only valid when a process group exists; in local
    # mode they would raise.
    if distributed:
        dist.barrier()
    if rank == 0:
        # Instantiated only for its side effects once all shards are written.
        # NOTE(review): presumably this computes/caches latent statistics
        # (latent_norm=True); confirm in ImgLatentDataset.
        ImgLatentDataset(output_dir, latent_norm=True, sample=train_config['data'].get('sample', False))
    if distributed:
        dist.barrier()
        dist.destroy_process_group()
|
|
|
|
if __name__ == "__main__":
    # Command-line options as (flag, type, default) triples.
    cli_options = (
        ("--data_split", str, 'train'),
        ("--output_path", str, "/data/dataset/imagenet/"),
        ("--image_size", int, 256),
        ("--batch_size", int, 64),
        ("--seed", int, 42),
        ("--num_workers", int, 8),
        ('--config', str, 'configs/debug.yaml'),
    )
    parser = argparse.ArgumentParser()
    for flag, value_type, default_value in cli_options:
        parser.add_argument(flag, type=value_type, default=default_value)
    args = parser.parse_args()

    # Read the YAML config named on the command line, then run the extraction.
    train_config = load_config(args.config)
    main(args, train_config)
|
|