# Modified from:
#   DiT: https://github.com/facebookresearch/DiT/blob/main/sample_ddp.py
"""Example run:
python generate.py \
    --gpt-ckpt ./checkpoints/VibeTokenGen-xxl-dynamic-65_750k.pt \
    --gpt-model GPT-XXL --num-output-layer 4 \
    --num-codebooks 8 --codebook-size 32768 \
    --image-size 256 --cfg-scale 2.0 --top-k 0 --temperature 1.0 \
    --class-dropout-prob 0.1 \
    --extra-layers "QKV" \
    --latent-size 65 \
    --config ./configs/vibetoken_ll.yaml \
    --vq-ckpt ./checkpoints/VibeToken_LL.bin \
    --sample-dir ./assets/ \
    --skip-folder-creation \
    --compile \
    --decoder-patch-size 16,16 \
    --target-resolution 1024,1024 \
    --llamagen-target-resolution 256,256 \
    --precision bf16
"""
import torch
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
torch.set_float32_matmul_precision('high')
setattr(torch.nn.Linear, 'reset_parameters', lambda self: None)
setattr(torch.nn.LayerNorm, 'reset_parameters', lambda self: None)
import torch.nn.functional as F
import torch.distributed as dist
from tqdm import tqdm
import os
from PIL import Image
import numpy as np
import math
import argparse
import sys
from omegaconf import OmegaConf

from vibetokengen.model import GPT_models
from vibetokengen.generate import generate
from vibetoken import VibeTokenTokenizer, auto_preprocess_image, center_crop_to_multiple


def create_npz_from_sample_folder(sample_dir, num=50_000):
    """
    Builds a single .npz file from a folder of .png samples.
    """
    samples = []
    for i in tqdm(range(num), desc="Building .npz file from samples"):
        sample_pil = Image.open(f"{sample_dir}/{i:06d}.png")
        sample_np = np.asarray(sample_pil).astype(np.uint8)
        samples.append(sample_np)
    samples = np.stack(samples)
    assert samples.shape == (num, samples.shape[1], samples.shape[2], 3)
    npz_path = f"{sample_dir}.npz"
    np.savez(npz_path, arr_0=samples)
    print(f"Saved .npz file to {npz_path} [shape={samples.shape}].")
    return npz_path
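
# `create_npz_from_sample_folder` is carried over from the DiT sampling script
# and is not invoked below; it packs numbered .png samples into the .npz layout
# used by FID-style evaluation suites. A minimal standalone usage sketch (the
# folder path below is hypothetical):
#
#   create_npz_from_sample_folder("./samples/my-run", num=50_000)
#
# This expects ./samples/my-run/000000.png .. 049999.png and writes
# ./samples/my-run.npz containing arr_0 with shape (num, H, W, 3).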


def main(args):
    # Setup PyTorch:
    assert torch.cuda.is_available(), "Sampling with DDP requires at least one GPU. sample.py supports CPU-only usage"
    torch.set_grad_enabled(False)

    # Set global seed for reproducibility
    torch.manual_seed(args.global_seed)
    np.random.seed(args.global_seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(args.global_seed)
        torch.cuda.manual_seed_all(args.global_seed)

    device = "cuda" if torch.cuda.is_available() else "cpu"
    precision = {'none': torch.float32, 'bf16': torch.bfloat16, 'fp16': torch.float16}[args.precision]

    # Load the VibeToken tokenizer
    vq_model = VibeTokenTokenizer.from_config(
        args.config,
        args.vq_ckpt,
        device=device,
        dtype=precision,
    )
    print("VibeToken image tokenizer is loaded")

    # Create and load the GPT model
    gpt_model = GPT_models[args.gpt_model](
        vocab_size=args.codebook_size,
        block_size=args.latent_size,
        num_classes=args.num_classes,
        cls_token_num=args.cls_token_num,
        model_type=args.gpt_type,
        num_codebooks=args.num_codebooks,
        n_output_layer=args.num_output_layer,
        class_dropout_prob=args.class_dropout_prob,
        extra_layers=args.extra_layers,
        capping=args.capping,
    ).to(device=device, dtype=precision)
    print("GPT model is loaded")

    checkpoint = torch.load(args.gpt_ckpt, map_location="cpu", weights_only=False)
    if args.from_fsdp:  # fsdp
        model_weight = checkpoint
    elif "model" in checkpoint:  # ddp
        model_weight = checkpoint["model"]
    elif "module" in checkpoint:  # deepspeed
        model_weight = checkpoint["module"]
    elif "state_dict" in checkpoint:
        model_weight = checkpoint["state_dict"]
    else:
        raise Exception("please check the model weights; you may need to add --from-fsdp to the run command")
    gpt_model.load_state_dict(model_weight, strict=True)
    gpt_model.eval()
    del checkpoint
    print("GPT model weights are loaded")

    if args.compile:
        print("compiling the model...")
        gpt_model = torch.compile(
            gpt_model,
            mode="reduce-overhead",
            fullgraph=True,
        )  # requires PyTorch 2.0 (optional)
        print("GPT model is compiled")
    else:
        print("no model compile")

    # Create folder to save samples:
    model_string_name = args.gpt_model.replace("/", "-")
    if args.from_fsdp:
        ckpt_string_name = args.gpt_ckpt.split('/')[-2]
    else:
        ckpt_string_name = os.path.basename(args.gpt_ckpt).replace(".pth", "").replace(".pt", "")
    folder_name = f"{model_string_name}-{ckpt_string_name}-target-resolution-{args.target_resolution}-llamagen-target-resolution-{args.llamagen_target_resolution}-vibetoken-" \
                  f"topk-{args.top_k}-topp-{args.top_p}-temperature-{args.temperature}-" \
                  f"cfg-{args.cfg_scale}-seed-{args.global_seed}"
    if args.skip_folder_creation:
        sample_folder_dir = args.sample_dir
    else:
        sample_folder_dir = f"{args.sample_dir}/{folder_name}"
        os.makedirs(sample_folder_dir, exist_ok=True)
    print(f"Saving .png samples at {sample_folder_dir}")

    # Classifier-free guidance doubles the batch (conditional + unconditional)
    multiplier = 2 if args.cfg_scale > 1.0 else 1

    # Use fixed class labels
    class_labels = [207, 360, 387, 974, 88, 979, 417, 279]
    c_indices = torch.tensor(class_labels, device=device)
    n = len(class_labels)
    nrow = 4  # 2 rows x 4 columns for 8 images

    index_sample = generate(
        gpt_model, c_indices, args.latent_size, args.num_codebooks,
        cfg_scale=args.cfg_scale,
        cfg_interval=args.cfg_interval,
        # Target height/width are normalized by 1536 before conditioning
        target_h=torch.tensor(args.llamagen_target_resolution[0] / 1536, device=device, dtype=precision).unsqueeze(0).repeat(len(c_indices) * multiplier, 1),
        target_w=torch.tensor(args.llamagen_target_resolution[1] / 1536, device=device, dtype=precision).unsqueeze(0).repeat(len(c_indices) * multiplier, 1),
        temperature=args.temperature,
        top_k=args.top_k,
        top_p=args.top_p,
        sample_logits=True,
    )

    # Decode with VibeToken's decode method;
    # VibeToken expects tokens in shape (batch_size, seq_len, 1)
    index_sample = index_sample.unsqueeze(2)
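
    # Optional sanity check (an added assumption based on the shape comment
    # above, not part of the original pipeline): after unsqueeze, the tokens
    # should be rank-3 with a trailing singleton axis.
    assert index_sample.dim() == 3 and index_sample.size(-1) == 1, \
        f"unexpected token shape {tuple(index_sample.shape)}"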
    samples = vq_model.decode(
        index_sample,
        height=args.target_resolution[0],
        width=args.target_resolution[1],
        patch_size=args.decoder_patch_size,
    )

    # VibeToken output is in [0, 1] range; clamp before converting to uint8
    samples = torch.clamp(samples, 0, 1)

    # Create a grid of images (2 rows x 4 columns)
    from torchvision.utils import make_grid
    grid = make_grid(samples, nrow=nrow, padding=2, normalize=False)

    # Convert to PIL and save
    grid_np = (grid.permute(1, 2, 0).to(torch.float32).cpu().numpy() * 255).astype('uint8')
    Image.fromarray(grid_np).save(f"{sample_folder_dir}/generated_images.png")
    print(f"Saved grid of {n} images to {sample_folder_dir}/generated_images.png")
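

# NOTE: main() saves a single grid image. If per-sample files are needed
# instead (e.g. to feed create_npz_from_sample_folder above), the grid step
# could be replaced by a loop; a minimal sketch, assuming `samples` is a
# (N, 3, H, W) tensor in [0, 1] as produced by vq_model.decode:
#
#   for i, sample in enumerate(samples):
#       arr = (sample.permute(1, 2, 0).to(torch.float32).cpu().numpy() * 255).astype(np.uint8)
#       Image.fromarray(arr).save(f"{sample_folder_dir}/{i:06d}.png")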


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--gpt-model", type=str, choices=list(GPT_models.keys()), default="GPT-B")
    parser.add_argument("--gpt-ckpt", type=str, default=None)
    parser.add_argument("--gpt-type", type=str, choices=['c2i', 't2i'], default="c2i", help="class-conditional or text-conditional")
    parser.add_argument("--from-fsdp", action='store_true')
    parser.add_argument("--cls-token-num", type=int, default=1, help="max token number of condition input")
    parser.add_argument("--precision", type=str, default='bf16', choices=["none", "fp16", "bf16"])
    # NOTE: with action='store_true' and default=True this flag is effectively
    # always on; set default=False to make --compile opt-in
    parser.add_argument("--compile", action='store_true', default=True)
    # parser.add_argument("--vq-model", type=str, choices=list(VQ_models.keys()), default="VQ-16")
    parser.add_argument("--vq-ckpt", type=str, default=None, help="ckpt path for vq model")
    parser.add_argument("--config", type=str, required=True, help="Path to VibeToken config file")
    parser.add_argument("--codebook-size", type=int, default=16384, help="codebook size for vector quantization")
    parser.add_argument("--codebook-embed-dim", type=int, default=8, help="codebook dimension for vector quantization")
    parser.add_argument("--image-size", type=int, choices=[256, 384, 512], default=384)
    parser.add_argument("--image-size-eval", type=int, choices=[256, 384, 512], default=256)
    parser.add_argument("--downsample-size", type=int, choices=[8, 16], default=16)
    parser.add_argument("--num-classes", type=int, default=1000)
    parser.add_argument("--cfg-scale", type=float, default=4.0)
    parser.add_argument("--cfg-interval", type=float, default=-1)
    parser.add_argument("--sample-dir", type=str, default="samples")
    parser.add_argument("--per-proc-batch-size", type=int, default=32)  # not used
    parser.add_argument("--num-fid-samples", type=int, default=50000)  # not used
    parser.add_argument("--global-seed", type=int, default=0)
    parser.add_argument("--top-k", type=int, default=500, help="top-k value to sample with")
    parser.add_argument("--temperature", type=float, default=1.0, help="temperature value to sample with")
    parser.add_argument("--top-p", type=float, default=1.0, help="top-p value to sample with")
    parser.add_argument("--num-codebooks", type=int, default=1)
    parser.add_argument("--num-output-layer", type=int, default=1)
    parser.add_argument("--class-dropout-prob", type=float, default=0.1)
    parser.add_argument("--extra-layers", type=str,
                        choices=['QK', 'QKV', 'FC', 'cap', 'clip',
                                 'QK_cap', 'QKV_cap', 'QK_clip', 'QKV_clip',
                                 'QK_FC_cap', 'QKV_FC_cap', 'QK_FC_clip', 'QKV_FC_clip'],
                        default=None,
                        help="Type of extra layers to add: combinations of QK (query-key) or QKV "
                             "(query-key-value) attention projections, FC (fully connected), and "
                             "optional cap/clip variants")
    parser.add_argument("--capping", type=float, default=50.0, help="Capping for attention softmax")
    # VibeToken dynamic
    parser.add_argument("--decoder-patch-size", type=str, default="8,8", help="Decoder patch size as 'height,width'")
    parser.add_argument("--target-resolution", type=str, default="256,256", help="Target resolution as 'height,width'")
    parser.add_argument("--llamagen-target-resolution", type=str, default="256,256", help="LlamaGen target resolution as 'height,width'")
    parser.add_argument("--latent-size", type=int, default=16, help="Latent size")
    parser.add_argument("--skip-folder-creation", action='store_true', default=False, help="skip folder creation")
    args = parser.parse_args()

    # Parse comma-separated 'height,width' strings into integer tuples
    args.decoder_patch_size = tuple(map(int, args.decoder_patch_size.split(",")))
    args.target_resolution = tuple(map(int, args.target_resolution.split(",")))
    args.llamagen_target_resolution = tuple(map(int, args.llamagen_target_resolution.split(",")))
    main(args)