# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""Image generation inference script.
Supports:
- Text-conditional models: SD15, SDXL, Flux
- Class-conditional models: EDM, SiT, DiT (ImageNet)
- Unconditional generation
Examples:
# Text-conditional: eval teacher only (SDXL)
PYTHONPATH=$(pwd) FASTGEN_OUTPUT_ROOT='FASTGEN_OUTPUT' torchrun --nproc_per_node=1 --standalone \
scripts/inference/image_model_inference.py --do_student_sampling False \
--config fastgen/configs/experiments/SDXL/config_sft.py \
-- trainer.seed=1 trainer.ddp=True model.guidance_scale=5.0 log_config.name=sdxl_inference
# Class-conditional: eval teacher (SiT on ImageNet)
PYTHONPATH=$(pwd) FASTGEN_OUTPUT_ROOT='FASTGEN_OUTPUT' torchrun --nproc_per_node=1 --standalone \
scripts/inference/image_model_inference.py --do_student_sampling False \
--prompt_file scripts/inference/prompts/classes.txt --classes 1000 \
--config fastgen/configs/experiments/SiT/config_sft.py \
-- trainer.seed=1 trainer.ddp=True log_config.name=sit_inference
# Unconditional generation (EDM CIFAR-10)
PYTHONPATH=$(pwd) FASTGEN_OUTPUT_ROOT='FASTGEN_OUTPUT' torchrun --nproc_per_node=1 --standalone \
scripts/inference/image_model_inference.py --do_student_sampling False \
--unconditional --num_samples 16 \
--config fastgen/configs/experiments/EDM/config_sft_edm_cifar10.py \
-- trainer.seed=1 trainer.ddp=True log_config.name=edm_cifar10_inference
# Eval both student and teacher
PYTHONPATH=$(pwd) FASTGEN_OUTPUT_ROOT='FASTGEN_OUTPUT' torchrun --nproc_per_node=1 --standalone \
scripts/inference/image_model_inference.py --ckpt_path /path/to/checkpoints/0003000.pth \
--do_student_sampling True --do_teacher_sampling True \
--config fastgen/configs/experiments/SD15/config_dmd2.py \
-- trainer.seed=1 trainer.ddp=True log_config.name=sd15_inference
"""
import argparse
import time
from pathlib import Path
import torch
from fastgen.configs.config import BaseConfig
import fastgen.utils.logging_utils as logger
from fastgen.utils import basic_utils
from fastgen.utils.distributed import clean_up
from fastgen.utils.scripts import parse_args, setup
from scripts.inference.inference_utils import (
load_prompts,
init_model,
init_checkpointer,
load_checkpoint,
cleanup_unused_modules,
setup_inference_modules,
add_common_args,
)
def _prepare_condition(args, prompt, model, ctx):
"""Prepare conditioning based on generation mode.
Args:
args: Command line arguments
prompt: Text prompt or class label (None for unconditional)
model: The model instance
ctx: Device/dtype context
Returns:
The prepared condition: a one-hot tensor (class-conditional), an encoded text
embedding (or the raw prompt list if the net has no text_encoder), or None.
"""
if args.unconditional:
# Unconditional: use zeros for class-conditional, None for text-conditional
if args.classes is not None:
return torch.zeros(1, args.classes, **ctx)
return None
if args.classes is not None:
# Class-conditional: one-hot encode the class label
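# Illustration (hypothetical --classes 10): prompt "7" becomes
# tensor([[0., 0., 0., 0., 0., 0., 0., 1., 0., 0.]]) in the model's device/dtype.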
assert prompt.isdigit(), f"Each prompt must be an integer class label, got: {prompt}"
assert int(prompt) < args.classes, f"Class label {prompt} out of range for --classes={args.classes}"
condition = torch.zeros(1, args.classes, **ctx)
condition[0, int(prompt)] = 1
return condition
# Text-conditional: encode the prompt
condition = [prompt]
if hasattr(model.net, "text_encoder"):
with basic_utils.inference_mode(
model.net.text_encoder, precision_amp=model.precision_amp_enc, device_type=model.device.type
):
condition = basic_utils.to(model.net.text_encoder.encode(condition), **ctx)
return condition
def main(args, config: BaseConfig):
# Load prompts or set up unconditional generation
if args.unconditional:
pos_prompt_set = [None] * args.num_samples
prompt_name = "unconditional"
else:
pos_prompt_set = load_prompts(args.prompt_file, relative_to="cwd")
prompt_name = Path(args.prompt_file).stem
# Fix sampling seeds
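# by_rank=True presumably offsets the base seed by the process rank so that multi-GPU
# launches draw distinct noise on each rank while a given run stays reproducible.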
seed = basic_utils.set_random_seed(config.trainer.seed, by_rank=True)
# Initialize model and checkpointer
model = init_model(config)
checkpointer = init_checkpointer(config)
# Load checkpoint
ckpt_iter, save_dir = load_checkpoint(checkpointer, model, args.ckpt_path, config)
if ckpt_iter is None and args.do_student_sampling:
logger.warning(f"Performing {model.config.student_sample_steps}-step generation on the non-distilled model")
# Set up save directory
if args.image_save_dir:
save_dir = args.image_save_dir
logger.info(f"image_save_dir: {save_dir}")
save_dir = Path(save_dir) / prompt_name
# Remove unused modules to free memory
cleanup_unused_modules(model, args.do_teacher_sampling)
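# (Presumably this drops modules not needed for the requested sampling modes, e.g. the
# teacher branch when --do_teacher_sampling is False, to free GPU memory.)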
# Set up inference modules
teacher, student, vae = setup_inference_modules(
model, config, args.do_teacher_sampling, args.do_student_sampling, model.precision
)
ctx = {"dtype": model.precision, "device": model.device}
# Validate sampling configuration
has_teacher_sampling = teacher is not None and hasattr(teacher, "sample")
has_student_sampling = student is not None and hasattr(model, "generator_fn")
assert (
has_teacher_sampling or has_student_sampling
), "At least one of teacher or student (with generator_fn) must be provided for sampling"
# Prepare negative condition for CFG
neg_condition = None
if args.classes is not None:
# Class-conditional: use zero vector as negative
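# The all-zeros vector acts as the "null" class for classifier-free guidance; the
# teacher extrapolates from this unconditional prediction toward the conditional one.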
neg_condition = torch.zeros(1, args.classes, **ctx)
elif args.neg_prompt_file is not None:
neg_prompts = load_prompts(args.neg_prompt_file, relative_to="cwd")
if len(neg_prompts) > 1:
logger.warning(f"Found {len(neg_prompts)} negative prompts, only using the first one.")
neg_condition = neg_prompts[:1]
logger.debug(f"Loaded negative prompt: {neg_condition[0]}")
if hasattr(model.net, "text_encoder"):
with basic_utils.inference_mode(
model.net.text_encoder, precision_amp=model.precision_amp_enc, device_type=model.device.type
):
neg_condition = basic_utils.to(model.net.text_encoder.encode(neg_condition), **ctx)
# Build skip-layer guidance tag for filenames
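# e.g., skip_layers=[7, 8, 9] yields "_slg7_8_9" in the teacher output filenames.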
slg_tag = ""
if config.model.skip_layers is not None:
slg_tag = f"_slg{'_'.join([str(x) for x in config.model.skip_layers])}"
# Initialize noise (regenerated per sample for unconditional mode)
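# For conditional generation the same noise is reused across all prompts, so outputs
# differ only through the conditioning, which makes side-by-side comparison easier.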
noise = torch.randn([1, *config.model.input_shape], **ctx)
# Main generation loop
for i, prompt in enumerate(pos_prompt_set):
# Log progress
if args.unconditional:
logger.info(f"[{i+1}/{len(pos_prompt_set)}] Generating unconditional sample...")
# Generate different noise for each unconditional sample (diversity)
noise = torch.randn([1, *config.model.input_shape], **ctx)
else:
logger.info(f"[{i+1}/{len(pos_prompt_set)}] Generating: {prompt[:80]}...")
# Prepare condition based on model type
condition = _prepare_condition(args, prompt, model, ctx)
# Student sampling
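# The distilled student samples in model.config.student_sample_steps steps (typically
# 1-4, far fewer than the teacher) along the timestep schedule in sample_t_cfg.t_list.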
if has_student_sampling:
start_time = time.time()
image_student = model.generator_fn(
student,
noise,
condition=condition,
student_sample_steps=model.config.student_sample_steps,
student_sample_type=model.config.student_sample_type,
t_list=model.config.sample_t_cfg.t_list,
precision_amp=model.precision_amp_infer,
)
logger.info(f"Student sampling time: {time.time() - start_time:.2f}s")
save_path = save_dir / f"student_step{model.config.student_sample_steps}_{i:04d}_seed{seed}.png"
basic_utils.save_media(image_student, str(save_path), vae=vae, precision_amp=model.precision_amp_infer)
# Teacher sampling
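# The teacher runs the full iterative sampler for --num_steps steps with classifier-free
# guidance at config.model.guidance_scale, steering away from neg_condition when provided.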
if has_teacher_sampling:
start_time = time.time()
teacher_kwargs = {
"num_steps": args.num_steps,
"second_order": False,
"precision_amp": model.precision_amp_infer,
}
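# Forward skip-layer guidance (SLG) only when configured; presumably not every teacher
# sampler accepts the skip_layers kwarg.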
if config.model.skip_layers is not None:
teacher_kwargs["skip_layers"] = config.model.skip_layers
image_teacher = model.sample(
teacher, noise, condition=condition, neg_condition=neg_condition, **teacher_kwargs
)
logger.info(f"Teacher sampling time: {time.time() - start_time:.2f}s")
save_path = (
save_dir
/ f"teacher_cfg{config.model.guidance_scale}_steps{args.num_steps}{slg_tag}_{i:04d}_seed{seed}.png"
)
basic_utils.save_media(image_teacher, str(save_path), vae=vae, precision_amp=model.precision_amp_infer)
# ----------------------------------------------------------------------------
if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="Image generation inference for text-conditional, class-conditional, and unconditional models",
formatter_class=argparse.RawDescriptionHelpFormatter,
)
# Add common args
add_common_args(parser)
# Prompt/condition arguments
parser.add_argument(
"--prompt_file",
default="scripts/inference/prompts/image_prompts.txt",
type=str,
help="File containing prompts (one per line). For class-conditional models, use integer class labels.",
)
parser.add_argument(
"--neg_prompt_file",
default=None,
type=str,
help="File containing negative prompt for CFG (only first line used).",
)
parser.add_argument(
"--classes",
default=None,
type=int,
help="Number of classes for class-conditional generation (e.g., 1000 for ImageNet). "
"Prompts should be integer class labels.",
)
parser.add_argument(
"--unconditional",
action="store_true",
help="Generate unconditional samples (no class or text conditioning).",
)
parser.add_argument(
"--num_samples",
default=10,
type=int,
help="Number of samples for unconditional generation (default: 10).",
)
# Output arguments
parser.add_argument(
"--image_save_dir",
default=None,
type=str,
help="Directory to save generated images (overrides default).",
)
# Sampling arguments
parser.add_argument(
"--num_steps",
default=50,
type=int,
help="Number of sampling steps for teacher (default: 50).",
)
args = parse_args(parser)
config = setup(args, evaluation=True)
main(args, config)
clean_up()