| | from __future__ import annotations |
| |
|
| | import json |
| | from pathlib import Path |
| | from typing import Any, Dict |
| |
|
| | import torch |
| | from safetensors.torch import load_file as load_safetensors |
| |
|
| | from diffusers.configuration_utils import ConfigMixin, register_to_config |
| | from diffusers.models.modeling_utils import ModelMixin |
| |
|
| | try: |
| | from .transformer.qae import VQModel |
| | except ImportError: |
| | from transformer.qae import VQModel |
| |
|
| |
|
class BitDanceImageNetAutoencoder(ModelMixin, ConfigMixin):
    """Diffusers-style wrapper that owns a ``VQModel`` as ``runtime_model``.

    The wrapper adds no computation of its own: ``encode`` and ``decode``
    forward straight to the wrapped model, and ``forward`` is an alias for
    ``decode``. Constructor arguments are recorded through
    ``register_to_config``.
    """

    @register_to_config
    def __init__(self, ddconfig: Dict[str, Any], num_codebooks: int = 4):
        """Build the wrapped ``VQModel``.

        Args:
            ddconfig: Configuration mapping handed unchanged to ``VQModel``
                as its first positional argument.
            num_codebooks: Handed unchanged to ``VQModel`` as its second
                positional argument.
        """
        super().__init__()
        self.runtime_model = VQModel(ddconfig, num_codebooks)

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path: str, *args, **kwargs):
        """Load the autoencoder from a directory on the local filesystem.

        The directory must contain ``config.json`` (with a ``ddconfig`` entry
        and, optionally, ``num_codebooks``) and
        ``diffusion_pytorch_model.safetensors``.

        Args:
            pretrained_model_name_or_path: Path of the model directory. It is
                used directly as a filesystem path.
            *args: Accepted for call compatibility and ignored.
            **kwargs: Accepted for call compatibility and ignored.

        Returns:
            An instance of ``cls`` with weights loaded into ``runtime_model``
            and switched to eval mode.
        """
        # Extra loader arguments are intentionally discarded.
        del args, kwargs

        root = Path(pretrained_model_name_or_path)

        config_text = (root / "config.json").read_text(encoding="utf-8")
        cfg = json.loads(config_text)

        # "ddconfig" is mandatory; "num_codebooks" falls back to 4.
        ddconfig = cfg["ddconfig"]
        codebook_count = int(cfg.get("num_codebooks", 4))
        instance = cls(ddconfig=ddconfig, num_codebooks=codebook_count)

        # The checkpoint keys target the wrapped model, not the wrapper, and
        # strict loading requires them to match exactly.
        weights = load_safetensors(root / "diffusion_pytorch_model.safetensors")
        instance.runtime_model.load_state_dict(weights, strict=True)

        instance.eval()
        return instance

    def encode(self, x: torch.Tensor):
        """Return ``runtime_model.encode(x)`` unchanged."""
        encoded = self.runtime_model.encode(x)
        return encoded

    def decode(self, z: torch.Tensor):
        """Return ``runtime_model.decode(z)`` unchanged."""
        decoded = self.runtime_model.decode(z)
        return decoded

    def forward(self, z: torch.Tensor):
        """Alias for ``decode``: calling the module decodes ``z``."""
        return self.decode(z)
| |
|