| """Image generation inference script. |
| |
| Supports: |
| - Text-conditional models: SD15, SDXL, Flux |
| - Class-conditional models: EDM, SiT, DiT (ImageNet) |
| - Unconditional generation |
| |
| Examples: |
| |
| # Text-conditional: eval teacher only (SDXL) |
| PYTHONPATH=$(pwd) FASTGEN_OUTPUT_ROOT='FASTGEN_OUTPUT' torchrun --nproc_per_node=1 --standalone \ |
| scripts/inference/image_model_inference.py --do_student_sampling False \ |
| --config fastgen/configs/experiments/SDXL/config_sft.py \ |
| - trainer.seed=1 trainer.ddp=True model.guidance_scale=5.0 log_config.name=sdxl_inference |
| |
| # Class-conditional: eval teacher (SiT on ImageNet) |
| PYTHONPATH=$(pwd) FASTGEN_OUTPUT_ROOT='FASTGEN_OUTPUT' torchrun --nproc_per_node=1 --standalone \ |
| scripts/inference/image_model_inference.py --do_student_sampling False \ |
| --prompt_file scripts/inference/prompts/classes.txt --classes 1000 \ |
| --config fastgen/configs/experiments/SiT/config_sft.py \ |
| - trainer.seed=1 trainer.ddp=True log_config.name=sit_inference |
| |
| # Unconditional generation (EDM CIFAR-10) |
| PYTHONPATH=$(pwd) FASTGEN_OUTPUT_ROOT='FASTGEN_OUTPUT' torchrun --nproc_per_node=1 --standalone \ |
| scripts/inference/image_model_inference.py --do_student_sampling False \ |
| --unconditional --num_samples 16 \ |
| --config fastgen/configs/experiments/EDM/config_sft_edm_cifar10.py \ |
| - trainer.seed=1 trainer.ddp=True log_config.name=edm_cifar10_inference |
| |
| # Eval both student and teacher |
| PYTHONPATH=$(pwd) FASTGEN_OUTPUT_ROOT='FASTGEN_OUTPUT' torchrun --nproc_per_node=1 --standalone \ |
| scripts/inference/image_model_inference.py --ckpt_path /path/to/checkpoints/0003000.pth \ |
| --do_student_sampling True --do_teacher_sampling True \ |
| --config fastgen/configs/experiments/SD15/config_dmd2.py \ |
| - trainer.seed=1 trainer.ddp=True log_config.name=sd15_inference |
| """ |
|
|
import argparse
import time
from pathlib import Path

import torch

from fastgen.configs.config import BaseConfig
import fastgen.utils.logging_utils as logger
from fastgen.utils import basic_utils
from fastgen.utils.distributed import clean_up
from fastgen.utils.scripts import parse_args, setup
from scripts.inference.inference_utils import (
    load_prompts,
    init_model,
    init_checkpointer,
    load_checkpoint,
    cleanup_unused_modules,
    setup_inference_modules,
    add_common_args,
)


def _prepare_condition(args, prompt, model, ctx):
    """Prepare conditioning based on generation mode.

    Args:
        args: Command line arguments
        prompt: Text prompt or class label (None for unconditional)
        model: The model instance
        ctx: Device/dtype context

    Returns:
        Encoded condition tensor, a single-element prompt list (when the model
        has no text encoder), or None
    """
    if args.unconditional:
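        # With class conditioning, an all-zero class vector acts as the null (unconditional) class.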
        if args.classes is not None:
            return torch.zeros(1, args.classes, **ctx)
        return None

    if args.classes is not None:
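        # Class-conditional: the prompt is an integer class label; one-hot encode it.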
        assert prompt.isdigit(), f"Each prompt must be an integer class label, got: {prompt}"
        condition = torch.zeros(1, args.classes, **ctx)
        condition[0, int(prompt)] = 1
        return condition

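    # Text-conditional: encode the prompt when the model has a text encoder;
    # otherwise pass the raw prompt list through.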
    condition = [prompt]
    if hasattr(model.net, "text_encoder"):
        with basic_utils.inference_mode(
            model.net.text_encoder, precision_amp=model.precision_amp_enc, device_type=model.device.type
        ):
            condition = basic_utils.to(model.net.text_encoder.encode(condition), **ctx)
    return condition


def main(args, config: BaseConfig):
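    # Build the prompt set: None placeholders for unconditional mode, otherwise one prompt per line.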
    if args.unconditional:
        pos_prompt_set = [None] * args.num_samples
        prompt_name = "unconditional"
    else:
        pos_prompt_set = load_prompts(args.prompt_file, relative_to="cwd")
        prompt_name = Path(args.prompt_file).stem

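    # Seed RNGs (the seed varies per rank when by_rank=True).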
    seed = basic_utils.set_random_seed(config.trainer.seed, by_rank=True)

    model = init_model(config)
    checkpointer = init_checkpointer(config)

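    # Restore model weights; ckpt_iter is None when no checkpoint was loaded.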
    ckpt_iter, save_dir = load_checkpoint(checkpointer, model, args.ckpt_path, config)
    if ckpt_iter is None and args.do_student_sampling:
        logger.warning(f"Performing {model.config.student_sample_steps}-step generation on the non-distilled model")

    if args.image_save_dir:
        save_dir = args.image_save_dir
    logger.info(f"image_save_dir: {save_dir}")
    save_dir = Path(save_dir) / prompt_name

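    # Drop modules that are unused under the requested sampling modes.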
    cleanup_unused_modules(model, args.do_teacher_sampling)

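    # Set up the teacher/student networks and VAE for inference at the model precision.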
    teacher, student, vae = setup_inference_modules(
        model, config, args.do_teacher_sampling, args.do_student_sampling, model.precision
    )
    ctx = {"dtype": model.precision, "device": model.device}

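    # Determine which sampling paths are available.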
    has_teacher_sampling = teacher is not None and hasattr(model, "sample")
    has_student_sampling = student is not None and hasattr(model, "generator_fn")
    assert (
        has_teacher_sampling or has_student_sampling
    ), "At least one of teacher or student (with generator_fn) must be provided for sampling"

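    # Negative condition for classifier-free guidance: an all-zero class vector for
    # class-conditional models, or an encoded negative prompt for text-conditional ones.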
    neg_condition = None
    if args.classes is not None:
        neg_condition = torch.zeros(1, args.classes, **ctx)
    elif args.neg_prompt_file is not None:
        neg_prompts = load_prompts(args.neg_prompt_file, relative_to="cwd")
        if len(neg_prompts) > 1:
            logger.warning(f"Found {len(neg_prompts)} negative prompts, only using the first one.")
        neg_condition = neg_prompts[:1]
        logger.debug(f"Loaded negative prompt: {neg_condition[0]}")
        if hasattr(model.net, "text_encoder"):
            with basic_utils.inference_mode(
                model.net.text_encoder, precision_amp=model.precision_amp_enc, device_type=model.device.type
            ):
                neg_condition = basic_utils.to(model.net.text_encoder.encode(neg_condition), **ctx)

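    # Filename tag for skip-layer guidance (SLG) settings, if configured.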
| slg_tag = "" |
| if config.model.skip_layers is not None: |
| slg_tag = f"_slg{'_'.join([str(x) for x in config.model.skip_layers])}" |
|
|
| |
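    # Draw one noise tensor up front so conditional samples share the same initial noise.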
    noise = torch.randn([1, *config.model.input_shape], **ctx)

    for i, prompt in enumerate(pos_prompt_set):
        if args.unconditional:
            logger.info(f"[{i+1}/{len(pos_prompt_set)}] Generating unconditional sample...")
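            # Redraw the noise for every unconditional sample; otherwise all outputs would be identical.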
            noise = torch.randn([1, *config.model.input_shape], **ctx)
        else:
            logger.info(f"[{i+1}/{len(pos_prompt_set)}] Generating: {prompt[:80]}...")

        condition = _prepare_condition(args, prompt, model, ctx)

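        # Few-step generation with the (distilled) student.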
        if has_student_sampling:
            start_time = time.time()
            image_student = model.generator_fn(
                student,
                noise,
                condition=condition,
                student_sample_steps=model.config.student_sample_steps,
                student_sample_type=model.config.student_sample_type,
                t_list=model.config.sample_t_cfg.t_list,
                precision_amp=model.precision_amp_infer,
            )
            logger.info(f"Student sampling time: {time.time() - start_time:.2f}s")

            save_path = save_dir / f"student_step{model.config.student_sample_steps}_{i:04d}_seed{seed}.png"
            basic_utils.save_media(image_student, str(save_path), vae=vae, precision_amp=model.precision_amp_infer)

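        # Multi-step teacher sampling with classifier-free guidance and optional skip-layer guidance.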
        if has_teacher_sampling:
            start_time = time.time()
            teacher_kwargs = {
                "num_steps": args.num_steps,
                "second_order": False,
                "precision_amp": model.precision_amp_infer,
            }
            if config.model.skip_layers is not None:
                teacher_kwargs["skip_layers"] = config.model.skip_layers

            image_teacher = model.sample(
                teacher, noise, condition=condition, neg_condition=neg_condition, **teacher_kwargs
            )
            logger.info(f"Teacher sampling time: {time.time() - start_time:.2f}s")

            save_path = (
                save_dir
                / f"teacher_cfg{config.model.guidance_scale}_steps{args.num_steps}{slg_tag}_{i:04d}_seed{seed}.png"
            )
            basic_utils.save_media(image_teacher, str(save_path), vae=vae, precision_amp=model.precision_amp_infer)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Image generation inference for text-conditional, class-conditional, and unconditional models",
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )

    add_common_args(parser)

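    # Conditioning options.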
    parser.add_argument(
        "--prompt_file",
        default="scripts/inference/prompts/image_prompts.txt",
        type=str,
        help="File containing prompts (one per line). For class-conditional models, use integer class labels.",
    )
    parser.add_argument(
        "--neg_prompt_file",
        default=None,
        type=str,
        help="File containing the negative prompt for CFG (only the first line is used).",
    )
    parser.add_argument(
        "--classes",
        default=None,
        type=int,
        help="Number of classes for class-conditional generation (e.g., 1000 for ImageNet). "
        "Prompts should be integer class labels.",
    )
    parser.add_argument(
        "--unconditional",
        action="store_true",
        help="Generate unconditional samples (no class or text conditioning).",
    )
    parser.add_argument(
        "--num_samples",
        default=10,
        type=int,
        help="Number of samples for unconditional generation (default: 10).",
    )

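    # Output options.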
    parser.add_argument(
        "--image_save_dir",
        default=None,
        type=str,
        help="Directory to save generated images (overrides default).",
    )

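    # Teacher sampling options.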
    parser.add_argument(
        "--num_steps",
        default=50,
        type=int,
        help="Number of sampling steps for teacher (default: 50).",
    )

    args = parse_args(parser)
    config = setup(args, evaluation=True)
    main(args, config)

    clean_up()