Instructions to use yitongl/5090_test with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Diffusers
How to use yitongl/5090_test with Diffusers:
pip install -U diffusers transformers accelerate
import torch from diffusers import DiffusionPipeline # switch to "mps" for apple devices pipe = DiffusionPipeline.from_pretrained("yitongl/5090_test", dtype=torch.bfloat16, device_map="cuda") prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k" image = pipe(prompt).images[0] - Notebooks
- Google Colab
- Kaggle
| # SPDX-License-Identifier: Apache-2.0 | |
| """Convert a DCP training checkpoint to a diffusers-style model directory. | |
| Works on a single GPU regardless of how many GPUs were used for training | |
| (DCP handles resharding automatically). | |
| Usage (no torchrun needed):: | |
| python -m fastvideo.train.entrypoint.dcp_to_diffusers \ | |
| --checkpoint /path/to/checkpoint-1000 \ | |
| --output-dir /path/to/diffusers_output | |
| Or with torchrun (also fine):: | |
| torchrun --nproc_per_node=1 \ | |
| -m fastvideo.train.entrypoint.dcp_to_diffusers \ | |
| --checkpoint ... --output-dir ... | |
| The checkpoint must contain ``metadata.json`` (written by | |
| ``CheckpointManager``). If the checkpoint predates metadata | |
| support, pass ``--config`` explicitly to provide the training | |
| YAML. | |
| """ | |
| from __future__ import annotations | |
| import argparse | |
| import os | |
| import sys | |
| from typing import Any | |
| from fastvideo.logger import init_logger | |
| logger = init_logger(__name__) | |
| def _ensure_distributed() -> None: | |
| """Set up a single-process distributed env if needed. | |
| When running under ``torchrun`` the env vars are already set. | |
| For plain ``python`` we fill in the minimum required vars so | |
| that ``init_process_group`` succeeds with world_size=1. | |
| """ | |
| for key, default in [ | |
| ("RANK", "0"), | |
| ("LOCAL_RANK", "0"), | |
| ("WORLD_SIZE", "1"), | |
| ("MASTER_ADDR", "127.0.0.1"), | |
| ("MASTER_PORT", "29500"), | |
| ]: | |
| os.environ.setdefault(key, default) | |
| def _save_role_pretrained( | |
| *, | |
| role: str, | |
| base_model_path: str, | |
| output_dir: str, | |
| module_names: list[str] | None = None, | |
| overwrite: bool = False, | |
| model: Any, | |
| ) -> str: | |
| """Export a role's modules into a diffusers-style model dir. | |
| Produces a ``model_path`` loadable by | |
| ``PipelineComponentLoader`` (``model_index.json``, | |
| ``transformer/``, ``vae/``, etc. copied from | |
| ``base_model_path``). | |
| """ | |
| import shutil | |
| from pathlib import Path | |
| import torch | |
| import torch.distributed as dist | |
| from torch.distributed.checkpoint.state_dict import ( | |
| StateDictOptions, | |
| get_model_state_dict, | |
| ) | |
| from fastvideo.utils import maybe_download_model | |
| def _rank() -> int: | |
| if dist.is_available() and dist.is_initialized(): | |
| return int(dist.get_rank()) | |
| return 0 | |
| def _barrier() -> None: | |
| if dist.is_available() and dist.is_initialized(): | |
| dist.barrier() | |
| local_base = Path(maybe_download_model(str(base_model_path))).resolve() | |
| dst = Path(os.path.expanduser(str(output_dir))).resolve() | |
| if _rank() == 0: | |
| if dst.exists(): | |
| if overwrite: | |
| shutil.rmtree(dst, ignore_errors=True) | |
| else: | |
| raise FileExistsError(f"Refusing to overwrite existing " | |
| f"directory: {dst}. " | |
| "Pass --overwrite to replace it.") | |
| def _copy_or_link(src: str, dest: str) -> None: | |
| try: | |
| os.link(src, dest) | |
| except OSError: | |
| shutil.copy2(src, dest) | |
| logger.info( | |
| "Creating pretrained export dir at %s " | |
| "(base=%s)", | |
| dst, | |
| local_base, | |
| ) | |
| shutil.copytree( | |
| local_base, | |
| dst, | |
| symlinks=True, | |
| copy_function=_copy_or_link, | |
| ) | |
| _barrier() | |
| modules: dict[str, torch.nn.Module] = {} | |
| if model.transformer is not None: | |
| modules["transformer"] = model.transformer | |
| if module_names is None: | |
| module_names = sorted(modules.keys()) | |
| for module_name in module_names: | |
| if module_name not in modules: | |
| raise KeyError(f"Role {role!r} does not have module " | |
| f"{module_name!r}. " | |
| f"Available: {sorted(modules.keys())}") | |
| module_dir = dst / module_name | |
| if not module_dir.is_dir(): | |
| raise FileNotFoundError(f"Export directory missing component " | |
| f"dir {module_name!r}: {module_dir}") | |
| options = StateDictOptions( | |
| full_state_dict=True, | |
| cpu_offload=True, | |
| ) | |
| state_dict = get_model_state_dict( | |
| modules[module_name], | |
| options=options, | |
| ) | |
| if _rank() == 0: | |
| for path in module_dir.glob("*.safetensors"): | |
| path.unlink(missing_ok=True) | |
| # Convert internal parameter names back to HF format. | |
| # load_model_from_full_model_state_dict builds reverse_param_names_mapping | |
| # (internal_key → hf_key) and stores it on the module. Without this, | |
| # the exported safetensors would have internal keys (e.g. | |
| # "patch_embedding.proj.bias") and the next load would double-map them | |
| # (e.g. → "patch_embedding.proj.proj.bias"). | |
| reverse_mapping: dict = getattr(modules[module_name], "reverse_param_names_mapping", {}) | |
| tensor_state: dict[str, torch.Tensor] = {} | |
| for key, value in state_dict.items(): | |
| if not isinstance(value, torch.Tensor): | |
| raise TypeError(f"Expected tensor in state_dict " | |
| f"for {module_name}.{key}, " | |
| f"got {type(value).__name__}") | |
| if key in reverse_mapping: | |
| hf_key, merge_index, _ = reverse_mapping[key] | |
| if merge_index is not None: | |
| logger.warning( | |
| "Skipping reverse-mapping for merged param %s " | |
| "(merge_index=%s); saving under internal key.", | |
| key, | |
| merge_index, | |
| ) | |
| hf_key = key | |
| key = hf_key | |
| tensor_state[key] = value.detach().cpu() | |
| from safetensors.torch import save_file | |
| out_path = module_dir / "model.safetensors" | |
| logger.info( | |
| "Saving %s weights to %s (%s tensors)", | |
| module_name, | |
| out_path, | |
| len(tensor_state), | |
| ) | |
| save_file(tensor_state, str(out_path)) | |
| _barrier() | |
| return str(dst) | |
| def convert( | |
| *, | |
| checkpoint_dir: str, | |
| output_dir: str, | |
| config_path: str | None = None, | |
| role: str = "student", | |
| overwrite: bool = False, | |
| ) -> str: | |
| """Load a DCP checkpoint and export as a diffusers model. | |
| Returns the path to the exported model directory. | |
| """ | |
| _ensure_distributed() | |
| from fastvideo.distributed import ( | |
| maybe_init_distributed_environment_and_model_parallel, ) | |
| from fastvideo.train.utils.builder import build_from_config | |
| from fastvideo.train.utils.checkpoint import ( | |
| CheckpointManager, | |
| _resolve_resume_checkpoint, | |
| ) | |
| from fastvideo.train.utils.config import ( | |
| RunConfig, | |
| load_run_config, | |
| ) | |
| import torch.distributed.checkpoint as dcp | |
| # -- Resolve checkpoint directory -- | |
| resolved = _resolve_resume_checkpoint( | |
| checkpoint_dir, | |
| output_dir=checkpoint_dir, | |
| ) | |
| if resolved is None: | |
| raise FileNotFoundError(f"Could not resolve checkpoint directory from {checkpoint_dir!r}") | |
| dcp_dir = resolved / "dcp" | |
| if not dcp_dir.is_dir(): | |
| raise FileNotFoundError(f"Missing dcp/ under {resolved}") | |
| # -- Obtain config -- | |
| cfg: RunConfig | |
| if config_path is not None: | |
| cfg = load_run_config(config_path) | |
| else: | |
| metadata = CheckpointManager.load_metadata(resolved) | |
| raw_config = metadata.get("config") | |
| if raw_config is None: | |
| raise ValueError("Checkpoint metadata.json does not " | |
| "contain 'config'. Pass --config " | |
| "explicitly.") | |
| cfg = _run_config_from_raw(raw_config) | |
| tc = cfg.training | |
| # -- Init distributed (1 GPU is enough; DCP reshards) -- | |
| maybe_init_distributed_environment_and_model_parallel( | |
| tp_size=1, | |
| sp_size=1, | |
| ) | |
| # Override distributed config so model loading uses 1 GPU. | |
| tc.distributed.tp_size = 1 | |
| tc.distributed.sp_size = 1 | |
| tc.distributed.num_gpus = 1 | |
| tc.distributed.hsdp_replicate_dim = 1 | |
| tc.distributed.hsdp_shard_dim = 1 | |
| # -- Build model (loads pretrained weights + FSDP) -- | |
| _, method, _, _ = build_from_config(cfg) | |
| # -- Load DCP weights into the model -- | |
| states = method.checkpoint_state() | |
| logger.info( | |
| "Loading DCP checkpoint from %s", | |
| resolved, | |
| ) | |
| dcp.load(states, checkpoint_id=str(dcp_dir)) | |
| # -- Export to diffusers format -- | |
| model = method._role_models[role] | |
| base_model_path = str(tc.model_path) | |
| if not base_model_path: | |
| raise ValueError("Cannot determine base_model_path from " | |
| "config. Ensure models.student.init_from " | |
| "is set.") | |
| logger.info( | |
| "Exporting role=%s to %s (base=%s)", | |
| role, | |
| output_dir, | |
| base_model_path, | |
| ) | |
| result = _save_role_pretrained( | |
| role=role, | |
| base_model_path=base_model_path, | |
| output_dir=output_dir, | |
| overwrite=overwrite, | |
| model=model, | |
| ) | |
| logger.info("Export complete: %s", result) | |
| return result | |
| def _run_config_from_raw(raw: dict[str, Any], ) -> Any: | |
| """Reconstruct a RunConfig from a raw config dict. | |
| This mirrors ``load_run_config`` but operates on an | |
| already-parsed dict (from metadata.json) instead of | |
| reading from a YAML file. | |
| """ | |
| from fastvideo.train.utils.config import ( | |
| RunConfig, | |
| _build_training_config, | |
| _parse_pipeline_config, | |
| _require_mapping, | |
| _require_str, | |
| ) | |
| models_raw = _require_mapping( | |
| raw.get("models"), | |
| where="models", | |
| ) | |
| models: dict[str, dict[str, Any]] = {} | |
| for role_key, model_cfg_raw in models_raw.items(): | |
| role_str = _require_str( | |
| role_key, | |
| where="models.<role>", | |
| ) | |
| model_cfg = _require_mapping( | |
| model_cfg_raw, | |
| where=f"models.{role_str}", | |
| ) | |
| models[role_str] = dict(model_cfg) | |
| method_raw = _require_mapping( | |
| raw.get("method"), | |
| where="method", | |
| ) | |
| method = dict(method_raw) | |
| callbacks_raw = raw.get("callbacks") | |
| callbacks: dict[str, dict[str, Any]] = (_require_mapping( | |
| callbacks_raw, | |
| where="callbacks", | |
| ) if callbacks_raw is not None else {}) | |
| pipeline_config = _parse_pipeline_config( | |
| raw, | |
| models=models, | |
| ) | |
| training_raw = _require_mapping( | |
| raw.get("training"), | |
| where="training", | |
| ) | |
| t = dict(training_raw) | |
| training = _build_training_config( | |
| t, | |
| models=models, | |
| pipeline_config=pipeline_config, | |
| ) | |
| return RunConfig( | |
| models=models, | |
| method=method, | |
| training=training, | |
| callbacks=callbacks, | |
| raw=raw, | |
| ) | |
| def main() -> None: | |
| parser = argparse.ArgumentParser(description=("Convert a DCP training checkpoint to a " | |
| "diffusers-style model directory. " | |
| "Only 1 GPU needed (DCP reshards " | |
| "automatically)."), ) | |
| parser.add_argument( | |
| "--checkpoint", | |
| type=str, | |
| required=True, | |
| help=("Path to checkpoint-<step> dir, its dcp/ " | |
| "subdir, or an output_dir (auto-picks " | |
| "latest)."), | |
| ) | |
| parser.add_argument( | |
| "--output-dir", | |
| type=str, | |
| required=True, | |
| help="Destination for the diffusers model.", | |
| ) | |
| parser.add_argument( | |
| "--config", | |
| type=str, | |
| default=None, | |
| help=("Training YAML config. If omitted, read " | |
| "from checkpoint metadata.json."), | |
| ) | |
| parser.add_argument( | |
| "--role", | |
| type=str, | |
| default="student", | |
| help="Role to export (default: student).", | |
| ) | |
| parser.add_argument( | |
| "--overwrite", | |
| action="store_true", | |
| help="Overwrite output-dir if it exists.", | |
| ) | |
| args = parser.parse_args(sys.argv[1:]) | |
| convert( | |
| checkpoint_dir=args.checkpoint, | |
| output_dir=args.output_dir, | |
| config_path=args.config, | |
| role=args.role, | |
| overwrite=args.overwrite, | |
| ) | |
| if __name__ == "__main__": | |
| main() | |