# fastgen-offline / FastGen/scripts/inference/image_model_inference.py
# Hugging Face file-page header (not part of the program): taohu's picture · "Upload folder using huggingface_hub" · 0839907 · verified
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""Image generation inference script.

Supports:
    - Text-conditional models: SD15, SDXL, Flux
    - Class-conditional models: EDM, SiT, DiT (ImageNet)
    - Unconditional generation

Examples:
# Text-conditional: eval teacher only (SDXL)
PYTHONPATH=$(pwd) FASTGEN_OUTPUT_ROOT='FASTGEN_OUTPUT' torchrun --nproc_per_node=1 --standalone \
scripts/inference/image_model_inference.py --do_student_sampling False \
--config fastgen/configs/experiments/SDXL/config_sft.py \
- trainer.seed=1 trainer.ddp=True model.guidance_scale=5.0 log_config.name=sdxl_inference
# Class-conditional: eval teacher (SiT on ImageNet)
PYTHONPATH=$(pwd) FASTGEN_OUTPUT_ROOT='FASTGEN_OUTPUT' torchrun --nproc_per_node=1 --standalone \
scripts/inference/image_model_inference.py --do_student_sampling False \
--prompt_file scripts/inference/prompts/classes.txt --classes 1000 \
--config fastgen/configs/experiments/SiT/config_sft.py \
- trainer.seed=1 trainer.ddp=True log_config.name=sit_inference
# Unconditional generation (EDM CIFAR-10)
PYTHONPATH=$(pwd) FASTGEN_OUTPUT_ROOT='FASTGEN_OUTPUT' torchrun --nproc_per_node=1 --standalone \
scripts/inference/image_model_inference.py --do_student_sampling False \
--unconditional --num_samples 16 \
--config fastgen/configs/experiments/EDM/config_sft_edm_cifar10.py \
- trainer.seed=1 trainer.ddp=True log_config.name=edm_cifar10_inference
# Eval both student and teacher
PYTHONPATH=$(pwd) FASTGEN_OUTPUT_ROOT='FASTGEN_OUTPUT' torchrun --nproc_per_node=1 --standalone \
scripts/inference/image_model_inference.py --ckpt_path /path/to/checkpoints/0003000.pth \
--do_student_sampling True --do_teacher_sampling True \
--config fastgen/configs/experiments/SD15/config_dmd2.py \
- trainer.seed=1 trainer.ddp=True log_config.name=sd15_inference
"""
import argparse
import time
from pathlib import Path
import torch
from fastgen.configs.config import BaseConfig
import fastgen.utils.logging_utils as logger
from fastgen.utils import basic_utils
from fastgen.utils.distributed import clean_up
from fastgen.utils.scripts import parse_args, setup
from scripts.inference.inference_utils import (
load_prompts,
init_model,
init_checkpointer,
load_checkpoint,
cleanup_unused_modules,
setup_inference_modules,
add_common_args,
)
def _prepare_condition(args, prompt, model, ctx):
    """Prepare the conditioning input for a single sample.

    The generation mode is selected from the command line arguments:

    * ``args.unconditional``: an all-zero ``(1, args.classes)`` tensor when
      ``args.classes`` is set, otherwise ``None``.
    * ``args.classes`` set: a one-hot ``(1, args.classes)`` tensor whose entry
      for the class label given in ``prompt`` is 1.
    * otherwise: the prompt wrapped in a one-element list, passed through
      ``model.net.text_encoder.encode`` when the network has a ``text_encoder``.

    Args:
        args: Command line arguments; reads ``unconditional`` and ``classes``.
        prompt: Text prompt or class label string (None for unconditional).
        model: The model instance; only used in the text-conditional branch.
        ctx: Keyword arguments forwarded to ``torch.zeros`` / ``basic_utils.to``.

    Returns:
        The condition (tensor, encoder output, or ``[prompt]``), or None.

    Raises:
        ValueError: If a class-conditional prompt is not a non-negative integer
            smaller than ``args.classes``.
    """
    if args.unconditional:
        # Unconditional: use zeros for class-conditional, None for text-conditional
        if args.classes is not None:
            return torch.zeros(1, args.classes, **ctx)
        return None

    if args.classes is not None:
        # Class-conditional: one-hot encode the class label.
        # Validate explicitly instead of with `assert`, so the check survives
        # `python -O` and an out-of-range label fails with a clear message
        # rather than an IndexError from the tensor indexing below.
        if not (isinstance(prompt, str) and prompt.isdecimal()):
            raise ValueError(f"Each prompt must be an integer class label, got: {prompt}")
        label = int(prompt)
        if label >= args.classes:
            raise ValueError(f"Class label {label} is out of range for {args.classes} classes")
        condition = torch.zeros(1, args.classes, **ctx)
        condition[0, label] = 1
        return condition

    # Text-conditional: encode the prompt
    condition = [prompt]
    if hasattr(model.net, "text_encoder"):
        with basic_utils.inference_mode(
            model.net.text_encoder, precision_amp=model.precision_amp_enc, device_type=model.device.type
        ):
            condition = basic_utils.to(model.net.text_encoder.encode(condition), **ctx)
    return condition
def _prepare_neg_condition(args, model, ctx):
    """Prepare the negative condition for CFG.

    Args:
        args: Command line arguments; reads ``classes`` and ``neg_prompt_file``.
        model: The model instance; its text encoder (if any) encodes the negative prompt.
        ctx: Keyword arguments forwarded to ``torch.zeros`` / ``basic_utils.to``.

    Returns:
        A zero ``(1, args.classes)`` tensor when ``args.classes`` is set, the
        (optionally encoded) first negative prompt when ``args.neg_prompt_file``
        is set, otherwise None.

    Raises:
        ValueError: If the negative prompt file yields no prompts.
    """
    if args.classes is not None:
        # Class-conditional: use zero vector as negative
        return torch.zeros(1, args.classes, **ctx)
    if args.neg_prompt_file is None:
        return None

    neg_prompts = load_prompts(args.neg_prompt_file, relative_to="cwd")
    if len(neg_prompts) == 0:
        # Fail with a clear message instead of an IndexError in the debug log below
        raise ValueError(f"No negative prompt found in {args.neg_prompt_file}")
    if len(neg_prompts) > 1:
        logger.warning(f"Found {len(neg_prompts)} negative prompts, only using the first one.")
    neg_condition = neg_prompts[:1]
    logger.debug(f"Loaded negative prompt: {neg_condition[0]}")
    if hasattr(model.net, "text_encoder"):
        with basic_utils.inference_mode(
            model.net.text_encoder, precision_amp=model.precision_amp_enc, device_type=model.device.type
        ):
            neg_condition = basic_utils.to(model.net.text_encoder.encode(neg_condition), **ctx)
    return neg_condition


def main(args, config: BaseConfig):
    """Generate images with the teacher and/or student network and save them as PNG files.

    Prompts are loaded from ``args.prompt_file`` (one sample per prompt); with
    ``args.unconditional``, ``args.num_samples`` samples are generated without a
    prompt. Images are written to ``<save_dir>/<prompt_name>/``, where ``save_dir``
    is ``args.image_save_dir`` if given and otherwise the directory returned by
    ``load_checkpoint``.

    In the conditional modes one noise tensor is shared by all prompts; in
    unconditional mode fresh noise is drawn for every sample.

    Args:
        args: Parsed command line arguments.
        config: Experiment configuration.

    Raises:
        ValueError: If neither teacher nor student sampling is available, or if
            ``args.neg_prompt_file`` yields no prompts.
    """
    # Load prompts or set up unconditional generation
    if args.unconditional:
        pos_prompt_set = [None] * args.num_samples
        prompt_name = "unconditional"
    else:
        pos_prompt_set = load_prompts(args.prompt_file, relative_to="cwd")
        prompt_name = Path(args.prompt_file).stem

    # Fix sampling seeds
    seed = basic_utils.set_random_seed(config.trainer.seed, by_rank=True)

    # Initialize model and checkpointer
    model = init_model(config)
    checkpointer = init_checkpointer(config)

    # Load checkpoint
    ckpt_iter, save_dir = load_checkpoint(checkpointer, model, args.ckpt_path, config)
    if ckpt_iter is None and args.do_student_sampling:
        logger.warning(f"Performing {model.config.student_sample_steps}-step generation on the non-distilled model")

    # Set up save directory
    if args.image_save_dir:
        save_dir = args.image_save_dir
    logger.info(f"image_save_dir: {save_dir}")
    save_dir = Path(save_dir) / prompt_name

    # Remove unused modules to free memory
    cleanup_unused_modules(model, args.do_teacher_sampling)

    # Set up inference modules
    teacher, student, vae = setup_inference_modules(
        model, config, args.do_teacher_sampling, args.do_student_sampling, model.precision
    )
    ctx = {"dtype": model.precision, "device": model.device}

    # Validate sampling configuration (explicit raise so the check survives `python -O`)
    has_teacher_sampling = teacher is not None and hasattr(teacher, "sample")
    has_student_sampling = student is not None and hasattr(model, "generator_fn")
    if not (has_teacher_sampling or has_student_sampling):
        raise ValueError("At least one of teacher or student (with generator_fn) must be provided for sampling")

    # Prepare negative condition for CFG
    neg_condition = _prepare_neg_condition(args, model, ctx)

    # Build skip-layer guidance tag for filenames
    skip_layers = config.model.skip_layers
    slg_tag = ""
    if skip_layers is not None:
        slg_tag = f"_slg{'_'.join([str(x) for x in skip_layers])}"

    # Teacher sampler arguments and filename stems do not change between samples
    teacher_kwargs = {
        "num_steps": args.num_steps,
        "second_order": False,
        "precision_amp": model.precision_amp_infer,
    }
    if skip_layers is not None:
        teacher_kwargs["skip_layers"] = skip_layers
    student_stem = f"student_step{model.config.student_sample_steps}"
    teacher_stem = f"teacher_cfg{config.model.guidance_scale}_steps{args.num_steps}{slg_tag}"

    # Initialize noise (regenerated per sample for unconditional mode).
    # Drawn unconditionally so the random stream for a given seed is stable.
    noise = torch.randn([1, *config.model.input_shape], **ctx)

    # Main generation loop
    total = len(pos_prompt_set)
    for i, prompt in enumerate(pos_prompt_set):
        # Log progress
        if args.unconditional:
            logger.info(f"[{i+1}/{total}] Generating unconditional sample...")
            # Generate different noise for each unconditional sample (diversity)
            noise = torch.randn([1, *config.model.input_shape], **ctx)
        else:
            logger.info(f"[{i+1}/{total}] Generating: {prompt[:80]}...")

        # Prepare condition based on model type
        condition = _prepare_condition(args, prompt, model, ctx)

        # Student sampling
        if has_student_sampling:
            start_time = time.time()
            image_student = model.generator_fn(
                student,
                noise,
                condition=condition,
                student_sample_steps=model.config.student_sample_steps,
                student_sample_type=model.config.student_sample_type,
                t_list=model.config.sample_t_cfg.t_list,
                precision_amp=model.precision_amp_infer,
            )
            logger.info(f"Student sampling time: {time.time() - start_time:.2f}s")
            save_path = save_dir / f"{student_stem}_{i:04d}_seed{seed}.png"
            basic_utils.save_media(image_student, str(save_path), vae=vae, precision_amp=model.precision_amp_infer)

        # Teacher sampling
        if has_teacher_sampling:
            start_time = time.time()
            image_teacher = model.sample(
                teacher, noise, condition=condition, neg_condition=neg_condition, **teacher_kwargs
            )
            logger.info(f"Teacher sampling time: {time.time() - start_time:.2f}s")
            save_path = save_dir / f"{teacher_stem}_{i:04d}_seed{seed}.png"
            basic_utils.save_media(image_teacher, str(save_path), vae=vae, precision_amp=model.precision_amp_infer)
# ----------------------------------------------------------------------------
if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        formatter_class=argparse.RawDescriptionHelpFormatter,
        description="Image generation inference for text-conditional, class-conditional, and unconditional models",
    )

    # Common options come from inference_utils and are registered first
    add_common_args(parser)

    # Script-specific options as (flag, add_argument keyword arguments),
    # registered in the order listed so the --help output is unchanged.
    script_options = (
        # Prompt/condition arguments
        (
            "--prompt_file",
            {
                "default": "scripts/inference/prompts/image_prompts.txt",
                "type": str,
                "help": "File containing prompts (one per line). For class-conditional models, use integer class labels.",
            },
        ),
        (
            "--neg_prompt_file",
            {
                "default": None,
                "type": str,
                "help": "File containing negative prompt for CFG (only first line used).",
            },
        ),
        (
            "--classes",
            {
                "default": None,
                "type": int,
                "help": "Number of classes for class-conditional generation (e.g., 1000 for ImageNet). "
                "Prompts should be integer class labels.",
            },
        ),
        (
            "--unconditional",
            {
                "action": "store_true",
                "help": "Generate unconditional samples (no class or text conditioning).",
            },
        ),
        (
            "--num_samples",
            {
                "default": 10,
                "type": int,
                "help": "Number of samples for unconditional generation (default: 10).",
            },
        ),
        # Output arguments
        (
            "--image_save_dir",
            {
                "default": None,
                "type": str,
                "help": "Directory to save generated images (overrides default).",
            },
        ),
        # Sampling arguments
        (
            "--num_steps",
            {
                "default": 50,
                "type": int,
                "help": "Number of sampling steps for teacher (default: 50).",
            },
        ),
    )
    for flag, spec in script_options:
        parser.add_argument(flag, **spec)

    # Parse the command line, build the evaluation config, run, then shut down
    args = parse_args(parser)
    config = setup(args, evaluation=True)
    main(args, config)
    clean_up()