File size: 7,038 Bytes
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""Shared utilities for image and video model inference scripts."""
import argparse
import os
from pathlib import Path
from typing import Optional, Tuple, List, Any
import torch
from fastgen.configs.config import BaseConfig
from fastgen.utils import instantiate
from fastgen.utils.checkpointer import FSDPCheckpointer, Checkpointer
import fastgen.utils.logging_utils as logger
def expand_path(path: str | Path, relative_to: str = "cwd") -> Path:
"""Resolve path - absolute paths stay as-is, relative paths are resolved.
Args:
path: Path to resolve
relative_to: How to resolve relative paths:
- "cwd": relative to current working directory
- "script": relative to the calling script's directory (for prompt files)
Returns:
Resolved absolute path
"""
path = Path(path)
if path.is_absolute():
return path
if relative_to == "cwd":
return Path.cwd() / path
elif relative_to == "script":
# Get the caller's directory - useful for default prompt files
import inspect
frame = inspect.currentframe()
if frame and frame.f_back:
caller_file = frame.f_back.f_globals.get("__file__")
if caller_file:
return Path(caller_file).parent / path
return Path.cwd() / path
else:
return Path.cwd() / path
def load_prompts(prompt_file: str | Path, relative_to: str = "cwd") -> List[str]:
    """Load prompts from a text file, one prompt per line.

    Surrounding whitespace is stripped from every line and lines that are
    empty after stripping are skipped.

    Args:
        prompt_file: Path to the prompt file (one prompt per line).
        relative_to: How to resolve a relative ``prompt_file``
            ("cwd" or "script"); forwarded to ``expand_path``.

    Returns:
        List of non-empty prompts in file order.

    Raises:
        FileNotFoundError: If the resolved prompt file does not exist.
    """
    prompt_path = expand_path(prompt_file, relative_to)
    if not prompt_path.is_file():
        raise FileNotFoundError(f"Prompt file not found: {prompt_path}")

    # Decode explicitly as UTF-8: the platform default encoding is
    # locale-dependent and can fail on (or mangle) non-ASCII prompts.
    with prompt_path.open("r", encoding="utf-8") as f:
        stripped_lines = (line.strip() for line in f)
        prompts = [line for line in stripped_lines if line]

    logger.info(f"Loaded {len(prompts)} prompts from {prompt_path}")
    return prompts
def init_model(config: BaseConfig) -> Any:
    """Initialize the model from config.

    ``config.model`` is attached to ``config.model_class`` as its ``config``
    attribute for the duration of the ``instantiate`` call and reset to
    ``None`` afterwards.

    Args:
        config: Base configuration object.

    Returns:
        Instantiated model.
    """
    config.model_class.config = config.model
    try:
        return instantiate(config.model_class)
    finally:
        # Always undo the temporary attachment, even if instantiation fails,
        # so the caller's config is not left in a modified state.
        config.model_class.config = None
def init_checkpointer(config: BaseConfig) -> Checkpointer | FSDPCheckpointer:
    """Build the checkpointer that matches the trainer configuration.

    Args:
        config: Base configuration object; ``config.trainer.fsdp`` selects
            the checkpointer class and ``config.trainer.checkpointer`` is
            passed to its constructor.

    Returns:
        An ``FSDPCheckpointer`` when ``config.trainer.fsdp`` is truthy,
        otherwise a ``Checkpointer``.
    """
    trainer_cfg = config.trainer
    checkpointer_cls = FSDPCheckpointer if trainer_cfg.fsdp else Checkpointer
    return checkpointer_cls(trainer_cfg.checkpointer)
def load_checkpoint(
    checkpointer: Checkpointer | FSDPCheckpointer,
    model: Any,
    ckpt_path: str,
    config: BaseConfig,
) -> Tuple[Optional[int], str]:
    """Load a checkpoint if a valid path is provided and derive the save directory.

    A path counts as valid when ``ckpt_path + ".net_model"`` is a directory
    or ``ckpt_path`` itself is a file.

    Args:
        checkpointer: Checkpointer instance used to load the weights.
        model: Model to load weights into; its ``net`` and the entries of
            its ``ema_dict`` are handed to the checkpointer.
        ckpt_path: Path to the checkpoint (may be empty).
        config: Base configuration object; ``config.log_config.save_path``
            is the root of the returned save directory.

    Returns:
        Tuple of (value returned by ``checkpointer.load`` or None when
        nothing was loaded, save directory).
    """
    save_root = config.log_config.save_path
    has_checkpoint = bool(ckpt_path) and (
        os.path.isdir(ckpt_path + ".net_model") or os.path.isfile(ckpt_path)
    )

    if not has_checkpoint:
        save_dir = f"{save_root}/inference_validation"
        logger.warning(f"No valid ckpt path, save_dir: {save_dir}")
        return None, save_dir

    # Derive the save directory from the checkpoint path: the component two
    # levels above the checkpoint file plus the checkpoint name up to its
    # first ".".
    parts = ckpt_path.split("/")
    ckpt_name = parts[-1].split(".")[0]
    if len(parts) >= 3:
        save_dir = f"{save_root}/{parts[-3]}/{ckpt_name}"
    else:
        # Shallow paths such as "ckpt/0001.pth" have no such component;
        # fall back to the checkpoint name alone instead of raising
        # IndexError.
        save_dir = f"{save_root}/{ckpt_name}"
    logger.info(f"ckpt_path: {ckpt_path}, save_dir: {save_dir}")

    # Build model dict for loading
    model_dict_infer = torch.nn.ModuleDict({"net": model.net, **model.ema_dict})
    ckpt_iter = checkpointer.load(model_dict_infer, path=ckpt_path)
    logger.success(f"Loading successfully checkpoint {ckpt_iter}")
    return ckpt_iter, save_dir
def cleanup_unused_modules(model: Any, do_teacher_sampling: bool) -> None:
    """Delete model attributes that are not used during inference.

    ``fake_score`` and ``discriminator`` are always removed when present;
    ``teacher`` is removed only when teacher sampling is disabled.

    Args:
        model: Model to clean up (modified in place).
        do_teacher_sampling: Whether teacher sampling will be performed.
    """
    removable = ["fake_score", "discriminator"]
    if not do_teacher_sampling:
        removable.append("teacher")

    for attr_name in removable:
        if hasattr(model, attr_name):
            delattr(model, attr_name)
def setup_inference_modules(
    model: Any,
    config: BaseConfig,
    do_teacher_sampling: bool,
    do_student_sampling: bool,
    precision: torch.dtype,
) -> Tuple[Optional[Any], Optional[Any], Optional[Any]]:
    """Prepare the teacher, student and VAE modules used for inference.

    Each selected module is put in eval mode (teacher/student) and moved to
    ``model.device`` with the requested ``precision``.

    Args:
        model: The model instance.
        config: Base configuration object; ``config.model.enable_preprocessors``
            gates preprocessor/VAE initialization.
        do_teacher_sampling: Whether to set up the teacher for sampling.
        do_student_sampling: Whether to set up the student for sampling.
        precision: Inference precision dtype.

    Returns:
        Tuple of (teacher, student, vae); any entry is None when it was not
        set up.
    """
    target_device = model.device
    teacher = None
    student = None
    vae = None

    if do_teacher_sampling:
        # Prefer a dedicated teacher; otherwise run teacher-style sampling on model.net.
        teacher = getattr(model, "teacher", None)
        if teacher is None:
            teacher = model.net
        teacher.eval().to(dtype=precision, device=target_device)

    if do_student_sampling:
        # The first name listed in model.use_ema selects the student; fall back to model.net.
        if model.use_ema:
            student = getattr(model, model.use_ema[0])
        else:
            student = model.net
        student.eval().to(dtype=precision, device=target_device)
        logger.info(f"Evaluating student sample steps: {model.config.student_sample_steps}")

    network = model.net
    if hasattr(network, "init_preprocessors") and config.model.enable_preprocessors:
        network.init_preprocessors()
        vae = network.vae
        vae.to(device=target_device, dtype=precision)

    return teacher, student, vae
def add_common_args(parser: argparse.ArgumentParser) -> None:
    """Register the command-line options shared by image and video inference.

    Adds ``--ckpt_path`` (string, default "") and the boolean-like flags
    ``--do_student_sampling`` / ``--do_teacher_sampling`` (default True).
    A flag value is True when, case-insensitively, it is one of
    "true", "1" or "yes"; any other value is False.

    Args:
        parser: ArgumentParser to add arguments to.
    """

    def parse_flag(raw: str) -> bool:
        # Shared converter for the boolean-like options below.
        return raw.lower() in ("true", "1", "yes")

    parser.add_argument(
        "--ckpt_path",
        type=str,
        default="",
        help="Path to the checkpoint (optional, uses pretrained if not provided)",
    )

    sampling_flags = (
        ("--do_student_sampling", "Whether to perform student sampling"),
        ("--do_teacher_sampling", "Whether to perform teacher sampling"),
    )
    for flag_name, help_text in sampling_flags:
        parser.add_argument(flag_name, type=parse_flag, default=True, help=help_text)
|