"""CRS-Diff custom diffusers pipeline (``pipeline.py``).

Uploaded by BiliSakura via the upload-large-folder tool (revision b6acc0a, verified).
"""
import sys
from dataclasses import dataclass
from pathlib import Path
from typing import List, Optional, Union
import numpy as np
import torch
from diffusers import DDIMScheduler, DiffusionPipeline
from diffusers.utils import BaseOutput
from PIL import Image
# Make this file's directory importable so the sibling ``modular_pipeline``
# module resolves when this file is loaded as a diffusers custom pipeline.
_ROOT = Path(__file__).resolve().parent
if str(_ROOT) not in sys.path:
    sys.path.insert(0, str(_ROOT))
# Register alias for cached custom-pipeline imports.
# NOTE(review): presumably diffusers caches the custom pipeline under the
# module name "pipeline"; the alias keeps re-imports pointing at this module.
sys.modules["pipeline"] = sys.modules[__name__]
# E402: must run after the sys.path manipulation above.
from modular_pipeline import load_components, resolve_model_root  # noqa: E402
@dataclass
class CRSDiffPipelineOutput(BaseOutput):
    """Output of :class:`CRSDiffPipeline.__call__`.

    # ``images`` holds PIL images when ``output_type="pil"``; with
    # ``output_type="numpy"`` the pipeline returns uint8 HWC arrays instead,
    # so the annotation is the common case, not an exhaustive one.
    """

    images: List[Image.Image]
class CRSDiffPipeline(DiffusionPipeline):
    """Diffusers pipeline wrapping the CRS-Diff controllable generation model.

    The pipeline bundles a single ``crs_model`` (which provides text
    conditioning, the denoising model, and first-stage decoding) and a
    ``scheduler``. It supports being re-instantiated from diffusers config
    placeholders by resolving the model root and reloading components via
    ``modular_pipeline.load_components``.
    """

    def register_modules(self, **kwargs):
        """Register sub-modules, tolerating config placeholder tuples.

        Overrides the diffusers base implementation: placeholder values
        (``None`` or a ``(library, class_name)`` string pair coming from a
        saved config) are written to the config as-is instead of being
        introspected, while real module objects go through
        ``_fetch_class_library_tuple`` as in stock diffusers.
        """
        for name, module in kwargs.items():
            if module is None or (
                isinstance(module, (tuple, list)) and len(module) > 0 and module[0] is None
            ):
                # Missing component: record a null placeholder in the config.
                self.register_to_config(**{name: (None, None)})
                setattr(self, name, module)
            elif _is_component_list(module):
                # Already a (library, class_name) placeholder from a config load.
                self.register_to_config(**{name: (module[0], module[1])})
                setattr(self, name, module)
            else:
                # Real module object: derive (library, class_name) the diffusers way.
                from diffusers.pipelines.pipeline_loading_utils import _fetch_class_library_tuple
                library, class_name = _fetch_class_library_tuple(module)
                self.register_to_config(**{name: (library, class_name)})
                setattr(self, name, module)

    def __init__(
        self,
        crs_model=None,
        scheduler=None,
        scale_factor: float = 0.18215,
        model_path: Optional[Union[str, Path]] = None,
        _name_or_path: Optional[Union[str, Path]] = None,
    ):
        """Build the pipeline from live components or from config placeholders.

        When ``crs_model``/``scheduler`` arrive as ``(library, class_name)``
        string pairs (i.e. the pipeline is being reconstructed from a saved
        config), the real components are loaded from a resolved model root.

        Args:
            crs_model: The CRS model object, or a config placeholder pair.
            scheduler: The noise scheduler, or a config placeholder pair.
            scale_factor: First-stage latent scale factor (0.18215 is the
                Stable Diffusion VAE convention).
            model_path: Optional explicit model root used to resolve components.
            _name_or_path: Fallback path hint (mirrors diffusers' internal name).

        Raises:
            ValueError: If placeholders were given but no model root resolves.
        """
        super().__init__()
        if _is_component_list(crs_model) or _is_component_list(scheduler):
            # Try explicit path first, then the diffusers-style hints.
            model_root = (
                resolve_model_root(model_path)
                or resolve_model_root(_name_or_path)
                or resolve_model_root(getattr(getattr(self, "config", None), "_name_or_path", None))
            )
            if model_root is None:
                raise ValueError(
                    "CRS-Diff received config placeholders but could not resolve model path. "
                    "Pass `model_path` or load via DiffusionPipeline.from_pretrained(<path>, custom_pipeline=...)."
                )
            loaded = load_components(model_root)
            crs_model = loaded["crs_model"]
            scheduler = loaded["scheduler"]
            scale_factor = loaded["scale_factor"]
        self.register_modules(crs_model=crs_model, scheduler=scheduler)
        # NOTE(review): despite the name, this stores the latent *scale factor*
        # (e.g. 0.18215), not the spatial downsampling factor diffusers usually
        # keeps under ``vae_scale_factor`` — confirm downstream usage.
        self.vae_scale_factor = scale_factor

    @property
    def device(self) -> torch.device:
        """Device of the CRS model's first parameter; CPU if it has none."""
        params = list(self.crs_model.parameters())
        if params:
            return params[0].device
        return torch.device("cpu")

    @classmethod
    def from_pretrained(
        cls,
        pretrained_model_name_or_path: Union[str, Path],
        device: Optional[Union[str, torch.device]] = None,
        subfolder: Optional[str] = None,
        **kwargs,
    ) -> "CRSDiffPipeline":
        """Load the pipeline (or just its scheduler) from a model root.

        Args:
            pretrained_model_name_or_path: Path/name resolved via
                ``resolve_model_root``.
            device: Optional device to move the pipeline to after loading.
            subfolder: If ``"scheduler"``, returns only a ``DDIMScheduler``
                (shim for loaders that fetch components by subfolder).
            **kwargs: May also carry ``subfolder``; other entries are ignored.

        Raises:
            ValueError: If the model root cannot be resolved.
        """
        path = resolve_model_root(pretrained_model_name_or_path)
        if path is None:
            raise ValueError(f"Could not resolve CRS-Diff model root from: {pretrained_model_name_or_path}")
        subfolder = kwargs.pop("subfolder", subfolder)
        if subfolder == "scheduler":
            # Component-only request: hand back the scheduler, not a pipeline.
            return DDIMScheduler.from_pretrained(path, subfolder="scheduler")
        loaded = load_components(path)
        pipe = cls(crs_model=loaded["crs_model"], scheduler=loaded["scheduler"], scale_factor=loaded["scale_factor"])
        if device is not None:
            pipe = pipe.to(device)
        return pipe

    def _to_tensor(self, x, device: torch.device, dtype=torch.float32) -> torch.Tensor:
        """Coerce a numpy array or tensor onto ``device`` with ``dtype``."""
        if isinstance(x, np.ndarray):
            x = torch.from_numpy(x)
        if not isinstance(x, torch.Tensor):
            raise TypeError("Expected torch.Tensor or np.ndarray for conditioning inputs.")
        return x.to(device=device, dtype=dtype)

    @torch.no_grad()
    def __call__(
        self,
        prompt: Union[str, List[str]],
        local_control,
        global_control,
        metadata,
        negative_prompt: Union[str, List[str]] = "",
        num_inference_steps: int = 50,
        guidance_scale: float = 7.5,
        eta: float = 0.0,
        strength: float = 1.0,
        global_strength: float = 1.0,
        generator: Optional[torch.Generator] = None,
        output_type: str = "pil",
    ) -> CRSDiffPipelineOutput:
        """Run DDIM sampling with local/global control conditioning.

        Args:
            prompt: Text prompt(s); a single string is broadcast to the batch.
            local_control: Spatial control input, 4D (batch, channels, H, W)
                — latents are sampled at H//8 x W//8.
            global_control: Global control input; zeroed for the unconditional
                branch of classifier-free guidance.
            metadata: Per-sample conditioning vector; a 1D tensor is broadcast
                across the batch.
            negative_prompt: Negative prompt(s) for the unconditional branch.
            num_inference_steps: Number of scheduler timesteps.
            guidance_scale: Classifier-free guidance weight; <= 1.0 disables
                the unconditional pass entirely.
            eta: DDIM eta passed to ``scheduler.step``.
            strength: Weight applied to each of the 13 local control scales
                (13 is the CRS model's control-block count — TODO confirm).
            global_strength: Weight forwarded to ``apply_model``.
            generator: Optional RNG for reproducible initial noise.
            output_type: ``"pil"`` or ``"numpy"``.

        Returns:
            CRSDiffPipelineOutput with the decoded images.

        Raises:
            ValueError: For an unsupported ``output_type``.
        """
        device = self.device
        local_control = self._to_tensor(local_control, device=device)
        global_control = self._to_tensor(global_control, device=device)
        metadata = self._to_tensor(metadata, device=device)
        batch_size = local_control.shape[0]
        # Broadcast string prompts and 1D metadata across the batch.
        if isinstance(prompt, str):
            prompt = [prompt] * batch_size
        if isinstance(negative_prompt, str):
            negative_prompt = [negative_prompt] * batch_size
        if metadata.dim() == 1:
            metadata = metadata.unsqueeze(0).repeat(batch_size, 1)
        cond = {
            "local_control": [local_control],
            "c_crossattn": [self.crs_model.get_learned_conditioning(prompt)],
            "global_control": [global_control],
        }
        # Unconditional branch keeps local control but zeroes global control.
        un_cond = {
            "local_control": [local_control],
            "c_crossattn": [self.crs_model.get_learned_conditioning(negative_prompt)],
            "global_control": [torch.zeros_like(global_control)],
        }
        if hasattr(self.crs_model, "local_control_scales"):
            self.crs_model.local_control_scales = [strength] * 13
        _, _, h, w = local_control.shape
        # Latent grid is 1/8 the control resolution (SD-style VAE downsampling).
        latents = torch.randn(
            (batch_size, self.crs_model.channels, h // 8, w // 8),
            generator=generator,
            device=device,
        )
        latents = latents * self.scheduler.init_noise_sigma
        self.scheduler.set_timesteps(num_inference_steps, device=device)
        for t in self.scheduler.timesteps:
            ts = torch.full((batch_size,), int(t), device=device, dtype=torch.long)
            if guidance_scale > 1.0:
                # Classifier-free guidance: two passes, then extrapolate.
                noise_text = self.crs_model.apply_model(latents, ts, cond, metadata, global_strength)
                noise_uncond = self.crs_model.apply_model(latents, ts, un_cond, metadata, global_strength)
                noise_pred = noise_uncond + guidance_scale * (noise_text - noise_uncond)
            else:
                noise_pred = self.crs_model.apply_model(latents, ts, cond, metadata, global_strength)
            latents = self.scheduler.step(
                model_output=noise_pred,
                timestep=t,
                sample=latents,
                eta=eta,
                generator=generator,
                return_dict=True,
            ).prev_sample
        # Decode latents to images in [-1, 1], then map to uint8 [0, 255] HWC.
        images = self.crs_model.decode_first_stage(latents)
        images = images.clamp(-1, 1)
        images = ((images + 1.0) / 2.0).permute(0, 2, 3, 1).cpu().numpy()
        images = (images * 255.0).clip(0, 255).astype(np.uint8)
        if output_type == "pil":
            images = [Image.fromarray(img) for img in images]
        elif output_type != "numpy":
            raise ValueError("output_type must be 'pil' or 'numpy'")
        return CRSDiffPipelineOutput(images=images)
def _is_component_list(v):
return (
isinstance(v, (list, tuple))
and len(v) == 2
and isinstance(v[0], str)
and isinstance(v[1], str)
)