stevengrove's picture
Initial commit with Xet-tracked image assets
fcfea15
from __future__ import annotations
from dataclasses import dataclass
import json
from pathlib import Path
@dataclass(frozen=True)
class CheckpointLayout:
root: Path
transformer_ckpt: Path
vae_ckpt: Path
text_encoder_ckpt: Path
def _must_exist(path: Path, kind: str) -> Path:
if not (path.exists() or path.is_symlink()):
raise FileNotFoundError(f"Missing {kind}: {path}")
return path
def _find_single_entry(directory: Path, kind: str, *, expect_dir: bool) -> Path:
_must_exist(directory, f"{kind} directory")
entries = sorted(p for p in directory.iterdir() if not p.name.startswith("."))
if len(entries) != 1:
raise FileNotFoundError(f"Expected exactly one entry in {directory} for {kind}, found {len(entries)}")
entry = entries[0]
if expect_dir:
if not (entry.is_dir() or entry.is_symlink()):
raise FileNotFoundError(f"Expected directory-like entry for {kind}: {entry}")
else:
if not (entry.is_file() or entry.is_symlink()):
raise FileNotFoundError(f"Expected file-like entry for {kind}: {entry}")
return entry
def resolve_checkpoint_layout(root: str | Path) -> CheckpointLayout:
root_path = Path(root).expanduser().resolve()
transformer_ckpt = str(root_path / "transformer" / "transformer.pth")
vae_ckpt = _find_single_entry(root_path / "vae", "vae checkpoint", expect_dir=False)
text_encoder_ckpt = _must_exist(root_path / "JoyAI-Image-Und", "text encoder checkpoint directory")
if not text_encoder_ckpt.is_dir():
raise FileNotFoundError(f"Expected text encoder checkpoint directory: {text_encoder_ckpt}")
return CheckpointLayout(
root=root_path,
transformer_ckpt=transformer_ckpt,
vae_ckpt=vae_ckpt,
text_encoder_ckpt=text_encoder_ckpt,
)
def build_manifest(layout: CheckpointLayout) -> dict[str, str]:
return {
"root": str(layout.root),
"transformer_ckpt": str(layout.transformer_ckpt),
"vae_ckpt": str(layout.vae_ckpt),
"text_encoder_ckpt": str(layout.text_encoder_ckpt),
}
def write_manifest(layout: CheckpointLayout, output_path: str | Path) -> Path:
output = Path(output_path)
output.write_text(json.dumps(build_manifest(layout), indent=2) + "\n", encoding="utf-8")
return output