File size: 11,036 Bytes
0839907
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0

"""Image generation inference script.

Supports:
- Text-conditional models: SD15, SDXL, Flux
- Class-conditional models: EDM, SiT, DiT (ImageNet)
- Unconditional generation

Examples:

    # Text-conditional: eval teacher only (SDXL)
    PYTHONPATH=$(pwd) FASTGEN_OUTPUT_ROOT='FASTGEN_OUTPUT' torchrun --nproc_per_node=1 --standalone \
        scripts/inference/image_model_inference.py --do_student_sampling False \
        --config fastgen/configs/experiments/SDXL/config_sft.py \
        - trainer.seed=1 trainer.ddp=True model.guidance_scale=5.0 log_config.name=sdxl_inference

    # Class-conditional: eval teacher (SiT on ImageNet)
    PYTHONPATH=$(pwd) FASTGEN_OUTPUT_ROOT='FASTGEN_OUTPUT' torchrun --nproc_per_node=1 --standalone \
        scripts/inference/image_model_inference.py --do_student_sampling False \
        --prompt_file scripts/inference/prompts/classes.txt --classes 1000 \
        --config fastgen/configs/experiments/SiT/config_sft.py \
        - trainer.seed=1 trainer.ddp=True log_config.name=sit_inference

    # Unconditional generation (EDM CIFAR-10)
    PYTHONPATH=$(pwd) FASTGEN_OUTPUT_ROOT='FASTGEN_OUTPUT' torchrun --nproc_per_node=1 --standalone \
        scripts/inference/image_model_inference.py --do_student_sampling False \
        --unconditional --num_samples 16 \
        --config fastgen/configs/experiments/EDM/config_sft_edm_cifar10.py \
        - trainer.seed=1 trainer.ddp=True log_config.name=edm_cifar10_inference

    # Eval both student and teacher
    PYTHONPATH=$(pwd) FASTGEN_OUTPUT_ROOT='FASTGEN_OUTPUT' torchrun --nproc_per_node=1 --standalone \
        scripts/inference/image_model_inference.py --ckpt_path /path/to/checkpoints/0003000.pth \
        --do_student_sampling True --do_teacher_sampling True \
        --config fastgen/configs/experiments/SD15/config_dmd2.py \
        - trainer.seed=1 trainer.ddp=True log_config.name=sd15_inference
"""

import argparse
import time
from pathlib import Path

import torch

from fastgen.configs.config import BaseConfig
import fastgen.utils.logging_utils as logger
from fastgen.utils import basic_utils
from fastgen.utils.distributed import clean_up
from fastgen.utils.scripts import parse_args, setup
from scripts.inference.inference_utils import (
    load_prompts,
    init_model,
    init_checkpointer,
    load_checkpoint,
    cleanup_unused_modules,
    setup_inference_modules,
    add_common_args,
)


def _prepare_condition(args, prompt, model, ctx):
    """Prepare conditioning based on generation mode.

    Args:
        args: Command line arguments
        prompt: Text prompt or class label (None for unconditional)
        model: The model instance
        ctx: Device/dtype context

    Returns:
        Encoded condition tensor or None
    """
    if args.unconditional:
        # Unconditional: use zeros for class-conditional, None for text-conditional
        if args.classes is not None:
            return torch.zeros(1, args.classes, **ctx)
        return None

    if args.classes is not None:
        # Class-conditional: one-hot encode the class label
        assert prompt.isdigit(), f"Each prompt must be an integer class label, got: {prompt}"
        condition = torch.zeros(1, args.classes, **ctx)
        condition[0, int(prompt)] = 1
        return condition

    # Text-conditional: encode the prompt
    condition = [prompt]
    if hasattr(model.net, "text_encoder"):
        with basic_utils.inference_mode(
            model.net.text_encoder, precision_amp=model.precision_amp_enc, device_type=model.device.type
        ):
            condition = basic_utils.to(model.net.text_encoder.encode(condition), **ctx)
    return condition


def main(args, config: BaseConfig):
    """Run image-generation inference for the configured model.

    Loads prompts (or prepares unconditional sampling), restores weights from a
    checkpoint, then for each prompt runs student (few-step) and/or teacher
    (multi-step, classifier-free-guided) sampling and saves the resulting
    images under ``<save_dir>/<prompt_name>/``.

    Args:
        args: Parsed command-line arguments (see the argparse setup in __main__).
        config: Full experiment configuration (model, sampler, trainer, logging).
    """
    # Load prompts or set up unconditional generation
    if args.unconditional:
        # No conditioning: generate args.num_samples images from noise alone.
        pos_prompt_set = [None] * args.num_samples
        prompt_name = "unconditional"
    else:
        pos_prompt_set = load_prompts(args.prompt_file, relative_to="cwd")
        prompt_name = Path(args.prompt_file).stem

    # Fix sampling seeds (per-rank offset so each DDP rank draws distinct noise)
    seed = basic_utils.set_random_seed(config.trainer.seed, by_rank=True)

    # Initialize model and checkpointer
    model = init_model(config)
    checkpointer = init_checkpointer(config)

    # Load checkpoint; ckpt_iter is None when no checkpoint was restored
    ckpt_iter, save_dir = load_checkpoint(checkpointer, model, args.ckpt_path, config)

    if ckpt_iter is None and args.do_student_sampling:
        logger.warning(f"Performing {model.config.student_sample_steps}-step generation on the non-distilled model")

    # Set up save directory; an explicit --image_save_dir overrides the
    # checkpoint-derived default, and the prompt-file stem becomes a subfolder
    if args.image_save_dir:
        save_dir = args.image_save_dir
        logger.info(f"image_save_dir: {save_dir}")
    save_dir = Path(save_dir) / prompt_name

    # Remove unused modules to free memory
    cleanup_unused_modules(model, args.do_teacher_sampling)

    # Set up inference modules
    teacher, student, vae = setup_inference_modules(
        model, config, args.do_teacher_sampling, args.do_student_sampling, model.precision
    )
    ctx = {"dtype": model.precision, "device": model.device}

    # Validate sampling configuration
    has_teacher_sampling = teacher is not None and hasattr(teacher, "sample")
    has_student_sampling = student is not None and hasattr(model, "generator_fn")
    assert (
        has_teacher_sampling or has_student_sampling
    ), "At least one of teacher or student (with generator_fn) must be provided for sampling"

    # Prepare negative condition for CFG (computed once, reused for all prompts)
    neg_condition = None
    if args.classes is not None:
        # Class-conditional: use zero vector as negative
        neg_condition = torch.zeros(1, args.classes, **ctx)
    elif args.neg_prompt_file is not None:
        neg_prompts = load_prompts(args.neg_prompt_file, relative_to="cwd")
        if len(neg_prompts) > 1:
            logger.warning(f"Found {len(neg_prompts)} negative prompts, only using the first one.")
        neg_condition = neg_prompts[:1]
        logger.debug(f"Loaded negative prompt: {neg_condition[0]}")
        if hasattr(model.net, "text_encoder"):
            with basic_utils.inference_mode(
                model.net.text_encoder, precision_amp=model.precision_amp_enc, device_type=model.device.type
            ):
                neg_condition = basic_utils.to(model.net.text_encoder.encode(neg_condition), **ctx)

    # Build skip-layer guidance tag for filenames
    slg_tag = ""
    if config.model.skip_layers is not None:
        slg_tag = f"_slg{'_'.join([str(x) for x in config.model.skip_layers])}"

    # Initialize noise once: conditional prompts share the same initial noise so
    # samples are comparable under one seed (regenerated per sample below for
    # unconditional mode)
    noise = torch.randn([1, *config.model.input_shape], **ctx)

    # Main generation loop
    for i, prompt in enumerate(pos_prompt_set):
        # Log progress
        if args.unconditional:
            logger.info(f"[{i+1}/{len(pos_prompt_set)}] Generating unconditional sample...")
            # Generate different noise for each unconditional sample (diversity)
            noise = torch.randn([1, *config.model.input_shape], **ctx)
        else:
            logger.info(f"[{i+1}/{len(pos_prompt_set)}] Generating: {prompt[:80]}...")

        # Prepare condition based on model type
        condition = _prepare_condition(args, prompt, model, ctx)

        # Student sampling: few-step distilled generator
        if has_student_sampling:
            start_time = time.time()
            image_student = model.generator_fn(
                student,
                noise,
                condition=condition,
                student_sample_steps=model.config.student_sample_steps,
                student_sample_type=model.config.student_sample_type,
                t_list=model.config.sample_t_cfg.t_list,
                precision_amp=model.precision_amp_infer,
            )
            logger.info(f"Student sampling time: {time.time() - start_time:.2f}s")

            save_path = save_dir / f"student_step{model.config.student_sample_steps}_{i:04d}_seed{seed}.png"
            basic_utils.save_media(image_student, str(save_path), vae=vae, precision_amp=model.precision_amp_infer)

        # Teacher sampling: full multi-step sampler with CFG
        if has_teacher_sampling:
            start_time = time.time()
            teacher_kwargs = {
                "num_steps": args.num_steps,
                "second_order": False,
                "precision_amp": model.precision_amp_infer,
            }
            if config.model.skip_layers is not None:
                teacher_kwargs["skip_layers"] = config.model.skip_layers

            image_teacher = model.sample(
                teacher, noise, condition=condition, neg_condition=neg_condition, **teacher_kwargs
            )
            logger.info(f"Teacher sampling time: {time.time() - start_time:.2f}s")

            save_path = (
                save_dir
                / f"teacher_cfg{config.model.guidance_scale}_steps{args.num_steps}{slg_tag}_{i:04d}_seed{seed}.png"
            )
            basic_utils.save_media(image_teacher, str(save_path), vae=vae, precision_amp=model.precision_amp_infer)


# ----------------------------------------------------------------------------

if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Image generation inference for text-conditional, class-conditional, and unconditional models",
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )

    # Flags shared with the other inference scripts (checkpoint path, sampling toggles, ...)
    add_common_args(parser)

    # Prompt/condition arguments
    parser.add_argument(
        "--prompt_file",
        default="scripts/inference/prompts/image_prompts.txt",
        type=str,
        help="File containing prompts (one per line). For class-conditional models, use integer class labels.",
    )
    parser.add_argument(
        "--neg_prompt_file",
        default=None,
        type=str,
        help="File containing negative prompt for CFG (only first line used).",
    )
    parser.add_argument(
        "--classes",
        default=None,
        type=int,
        help="Number of classes for class-conditional generation (e.g., 1000 for ImageNet). "
        "Prompts should be integer class labels.",
    )
    parser.add_argument(
        "--unconditional",
        action="store_true",
        help="Generate unconditional samples (no class or text conditioning).",
    )
    parser.add_argument(
        "--num_samples",
        default=10,
        type=int,
        help="Number of samples for unconditional generation (default: 10).",
    )

    # Output arguments
    parser.add_argument(
        "--image_save_dir",
        default=None,
        type=str,
        help="Directory to save generated images (overrides default).",
    )

    # Sampling arguments
    parser.add_argument(
        "--num_steps",
        default=50,
        type=int,
        help="Number of sampling steps for teacher (default: 50).",
    )

    args = parse_args(parser)
    config = setup(args, evaluation=True)
    try:
        main(args, config)
    finally:
        # Always tear down distributed state, even when sampling fails mid-run;
        # previously clean_up() was skipped if main() raised, leaking the process group.
        clean_up()