| |
| |
|
|
| """Shared utilities for image and video model inference scripts.""" |
|
|
| import argparse |
| import os |
| from pathlib import Path |
| from typing import Optional, Tuple, List, Any |
|
|
| import torch |
|
|
| from fastgen.configs.config import BaseConfig |
| from fastgen.utils import instantiate |
| from fastgen.utils.checkpointer import FSDPCheckpointer, Checkpointer |
| import fastgen.utils.logging_utils as logger |
|
|
|
|
| def expand_path(path: str | Path, relative_to: str = "cwd") -> Path: |
| """Resolve path - absolute paths stay as-is, relative paths are resolved. |
| |
| Args: |
| path: Path to resolve |
| relative_to: How to resolve relative paths: |
| - "cwd": relative to current working directory |
| - "script": relative to the calling script's directory (for prompt files) |
| |
| Returns: |
| Resolved absolute path |
| """ |
| path = Path(path) |
| if path.is_absolute(): |
| return path |
|
|
| if relative_to == "cwd": |
| return Path.cwd() / path |
| elif relative_to == "script": |
| |
| import inspect |
|
|
| frame = inspect.currentframe() |
| if frame and frame.f_back: |
| caller_file = frame.f_back.f_globals.get("__file__") |
| if caller_file: |
| return Path(caller_file).parent / path |
| return Path.cwd() / path |
| else: |
| return Path.cwd() / path |
|
|
|
|
| def load_prompts(prompt_file: str | Path, relative_to: str = "cwd") -> List[str]: |
| """Load prompts from a file. |
| |
| Args: |
| prompt_file: Path to prompt file (one prompt per line) |
| relative_to: How to resolve relative paths ("cwd" or "script") |
| |
| Returns: |
| List of prompts |
| |
| Raises: |
| FileNotFoundError: If prompt file doesn't exist |
| """ |
| prompt_path = expand_path(prompt_file, relative_to) |
| if not prompt_path.is_file(): |
| raise FileNotFoundError(f"Prompt file not found: {prompt_path}") |
|
|
| with prompt_path.open("r") as f: |
| prompts = [line.strip() for line in f.readlines() if line.strip()] |
|
|
| logger.info(f"Loaded {len(prompts)} prompts from {prompt_path}") |
| return prompts |
|
|
|
|
| def init_model(config: BaseConfig) -> Any: |
| """Initialize the model from config. |
| |
| Args: |
| config: Base configuration object |
| |
| Returns: |
| Instantiated model |
| """ |
| config.model_class.config = config.model |
| model = instantiate(config.model_class) |
| config.model_class.config = None |
| return model |
|
|
|
|
| def init_checkpointer(config: BaseConfig) -> Checkpointer | FSDPCheckpointer: |
| """Initialize the appropriate checkpointer based on config. |
| |
| Args: |
| config: Base configuration object |
| |
| Returns: |
| Checkpointer or FSDPCheckpointer instance |
| """ |
| if config.trainer.fsdp: |
| return FSDPCheckpointer(config.trainer.checkpointer) |
| else: |
| return Checkpointer(config.trainer.checkpointer) |
|
|
|
|
| def load_checkpoint( |
| checkpointer: Checkpointer | FSDPCheckpointer, |
| model: Any, |
| ckpt_path: str, |
| config: BaseConfig, |
| ) -> Tuple[Optional[int], str]: |
| """Load checkpoint if valid path provided. |
| |
| Args: |
| checkpointer: Checkpointer instance |
| model: Model to load weights into |
| ckpt_path: Path to checkpoint |
| config: Base configuration object |
| |
| Returns: |
| Tuple of (checkpoint_iteration or None, save_directory) |
| """ |
| ckpt_iter = None |
|
|
| if ckpt_path and (os.path.isdir(ckpt_path + ".net_model") or os.path.isfile(ckpt_path)): |
| |
| save_dir = f"{config.log_config.save_path}/{ckpt_path.split('/')[-3]}/{ckpt_path.split('/')[-1].split('.')[0]}" |
| logger.info(f"ckpt_path: {ckpt_path}, save_dir: {save_dir}") |
|
|
| |
| model_dict_infer = torch.nn.ModuleDict({"net": model.net, **model.ema_dict}) |
|
|
| ckpt_iter = checkpointer.load(model_dict_infer, path=ckpt_path) |
| logger.success(f"Loading successfully checkpoint {ckpt_iter}") |
| else: |
| save_dir = f"{config.log_config.save_path}/inference_validation" |
| logger.warning(f"No valid ckpt path, save_dir: {save_dir}") |
|
|
| return ckpt_iter, save_dir |
|
|
|
|
| def cleanup_unused_modules(model: Any, do_teacher_sampling: bool) -> None: |
| """Remove unused modules to free memory. |
| |
| Args: |
| model: Model to clean up |
| do_teacher_sampling: Whether teacher sampling will be performed |
| """ |
| if hasattr(model, "fake_score"): |
| del model.fake_score |
| if hasattr(model, "discriminator"): |
| del model.discriminator |
| if (not do_teacher_sampling) and hasattr(model, "teacher"): |
| del model.teacher |
|
|
|
|
| def setup_inference_modules( |
| model: Any, |
| config: BaseConfig, |
| do_teacher_sampling: bool, |
| do_student_sampling: bool, |
| precision: torch.dtype, |
| ) -> Tuple[Optional[Any], Optional[Any], Optional[Any]]: |
| """Set up model modules for inference. |
| |
| Args: |
| model: The model instance |
| config: Base configuration object |
| do_teacher_sampling: Whether to set up teacher for sampling |
| do_student_sampling: Whether to set up student for sampling |
| precision: Inference precision dtype |
| |
| Returns: |
| Tuple of (teacher, student, vae) - any may be None |
| """ |
| teacher, student, vae = None, None, None |
|
|
| if do_teacher_sampling: |
| |
| if getattr(model, "teacher", None) is not None: |
| teacher = model.teacher |
| else: |
| teacher = model.net |
| teacher.eval().to(dtype=precision, device=model.device) |
|
|
| if do_student_sampling: |
| student = getattr(model, model.use_ema[0]) if model.use_ema else model.net |
| student.eval().to(dtype=precision, device=model.device) |
| logger.info(f"Evaluating student sample steps: {model.config.student_sample_steps}") |
|
|
| if hasattr(model.net, "init_preprocessors") and config.model.enable_preprocessors: |
| model.net.init_preprocessors() |
| vae = model.net.vae |
| vae.to(device=model.device, dtype=precision) |
|
|
| return teacher, student, vae |
|
|
|
|
| def add_common_args(parser: argparse.ArgumentParser) -> None: |
| """Add common arguments shared between image and video inference. |
| |
| Args: |
| parser: ArgumentParser to add arguments to |
| """ |
| parser.add_argument( |
| "--ckpt_path", |
| default="", |
| type=str, |
| help="Path to the checkpoint (optional, uses pretrained if not provided)", |
| ) |
| parser.add_argument( |
| "--do_student_sampling", |
| default=True, |
| type=lambda x: x.lower() in ("true", "1", "yes"), |
| help="Whether to perform student sampling", |
| ) |
| parser.add_argument( |
| "--do_teacher_sampling", |
| default=True, |
| type=lambda x: x.lower() in ("true", "1", "yes"), |
| help="Whether to perform teacher sampling", |
| ) |
|
|