| | import sys |
| | from dataclasses import dataclass |
| | from pathlib import Path |
| | from typing import List, Optional, Union |
| |
|
| | import numpy as np |
| | import torch |
| | from diffusers import DDIMScheduler, DiffusionPipeline |
| | from diffusers.utils import BaseOutput |
| | from PIL import Image |
| |
|
# Make this file's directory importable so sibling project modules (e.g.
# ``modular_pipeline`` imported below) resolve even when diffusers loads this
# file as a custom pipeline from an arbitrary working directory.
_ROOT = Path(__file__).resolve().parent
if str(_ROOT) not in sys.path:
    sys.path.insert(0, str(_ROOT))


# Register this module under the alias "pipeline" as well — presumably so
# saved configs/checkpoints that reference the module name ``pipeline``
# resolve back to this module. NOTE(review): confirm which consumer relies
# on this alias; it must run before the ``modular_pipeline`` import below.
sys.modules["pipeline"] = sys.modules[__name__]
| |
|
| | from modular_pipeline import load_components, resolve_model_root |
| |
|
| |
|
@dataclass
class CRSDiffPipelineOutput(BaseOutput):
    """Output container returned by :meth:`CRSDiffPipeline.__call__`."""

    # Generated images: a list of PIL images when ``output_type="pil"``,
    # otherwise (``output_type="numpy"``) a uint8 numpy array of shape
    # (batch, H, W, C) — see the tail of ``CRSDiffPipeline.__call__``.
    images: List[Image.Image]
|
| |
|
class CRSDiffPipeline(DiffusionPipeline):
    """Diffusers-compatible wrapper around the CRS-Diff model.

    Wraps a project-local ``crs_model`` (which bundles text encoder, UNet-like
    denoiser, and first-stage decoder behind ``get_learned_conditioning`` /
    ``apply_model`` / ``decode_first_stage``) plus a diffusers scheduler, and
    exposes the standard ``DiffusionPipeline`` calling convention.
    """

    def register_modules(self, **kwargs) -> None:
        """Register components on the pipeline and its config.

        Overrides the base implementation so that ``(library, class)`` string
        placeholders coming from a saved config (see ``_is_component_list``)
        are recorded verbatim instead of being introspected as live modules,
        and ``None`` components are recorded as ``(None, None)``.
        """
        for name, module in kwargs.items():
            if module is None or (
                isinstance(module, (tuple, list)) and len(module) > 0 and module[0] is None
            ):
                # Missing component: record a null placeholder in the config.
                self.register_to_config(**{name: (None, None)})
                setattr(self, name, module)
            elif _is_component_list(module):
                # Config placeholder ("library", "class") pair: pass through as-is.
                self.register_to_config(**{name: (module[0], module[1])})
                setattr(self, name, module)
            else:
                # Live module: resolve its (library, class) tuple the same way
                # the stock diffusers implementation does.
                from diffusers.pipelines.pipeline_loading_utils import _fetch_class_library_tuple

                library, class_name = _fetch_class_library_tuple(module)
                self.register_to_config(**{name: (library, class_name)})
                setattr(self, name, module)

    def __init__(
        self,
        crs_model=None,
        scheduler=None,
        scale_factor: float = 0.18215,
        model_path: Optional[Union[str, Path]] = None,
        _name_or_path: Optional[Union[str, Path]] = None,
    ) -> None:
        """Build the pipeline from live components or from config placeholders.

        When ``crs_model``/``scheduler`` arrive as ``(library, class)`` string
        placeholders (the shape diffusers passes when re-instantiating from a
        saved config), the real components are loaded from disk via
        ``load_components``; the model root is resolved from ``model_path``,
        ``_name_or_path``, or the config's ``_name_or_path``, in that order.

        Raises:
            ValueError: placeholders were given but no model root could be
                resolved.
        """
        super().__init__()
        if _is_component_list(crs_model) or _is_component_list(scheduler):
            model_root = (
                resolve_model_root(model_path)
                or resolve_model_root(_name_or_path)
                or resolve_model_root(getattr(getattr(self, "config", None), "_name_or_path", None))
            )
            if model_root is None:
                raise ValueError(
                    "CRS-Diff received config placeholders but could not resolve model path. "
                    "Pass `model_path` or load via DiffusionPipeline.from_pretrained(<path>, custom_pipeline=...)."
                )
            loaded = load_components(model_root)
            crs_model = loaded["crs_model"]
            scheduler = loaded["scheduler"]
            scale_factor = loaded["scale_factor"]

        self.register_modules(crs_model=crs_model, scheduler=scheduler)
        # NOTE(review): despite the name, this holds the latent *scaling*
        # factor (default 0.18215), not the spatial downsample factor that
        # diffusers usually stores under ``vae_scale_factor`` — confirm no
        # caller expects the spatial meaning.
        self.vae_scale_factor = scale_factor

    @property
    def device(self) -> torch.device:
        """Device of the wrapped model's first parameter (CPU if it has none)."""
        params = list(self.crs_model.parameters())
        if params:
            return params[0].device
        return torch.device("cpu")

    @classmethod
    def from_pretrained(
        cls,
        pretrained_model_name_or_path: Union[str, Path],
        device: Optional[Union[str, torch.device]] = None,
        subfolder: Optional[str] = None,
        **kwargs,
    ) -> "CRSDiffPipeline":
        """Load the pipeline (or just its scheduler) from a local model root.

        NOTE(review): with ``subfolder="scheduler"`` this returns a bare
        ``DDIMScheduler``, not a pipeline — apparently to satisfy component
        loaders that fetch the scheduler through the pipeline class; confirm
        callers expect this. All other ``kwargs`` are accepted but ignored.
        """
        path = resolve_model_root(pretrained_model_name_or_path)
        if path is None:
            raise ValueError(f"Could not resolve CRS-Diff model root from: {pretrained_model_name_or_path}")

        # ``subfolder`` is a named parameter, so this pop is a no-op safeguard
        # in case a caller smuggles it through **kwargs.
        subfolder = kwargs.pop("subfolder", subfolder)
        if subfolder == "scheduler":
            return DDIMScheduler.from_pretrained(path, subfolder="scheduler")

        loaded = load_components(path)
        pipe = cls(crs_model=loaded["crs_model"], scheduler=loaded["scheduler"], scale_factor=loaded["scale_factor"])
        if device is not None:
            pipe = pipe.to(device)
        return pipe

    def _to_tensor(self, x, device: torch.device, dtype=torch.float32) -> torch.Tensor:
        """Coerce a numpy array or tensor to a tensor on ``device``/``dtype``."""
        if isinstance(x, np.ndarray):
            x = torch.from_numpy(x)
        if not isinstance(x, torch.Tensor):
            raise TypeError("Expected torch.Tensor or np.ndarray for conditioning inputs.")
        return x.to(device=device, dtype=dtype)

    @torch.no_grad()
    def __call__(
        self,
        prompt: Union[str, List[str]],
        local_control,
        global_control,
        metadata,
        negative_prompt: Union[str, List[str]] = "",
        num_inference_steps: int = 50,
        guidance_scale: float = 7.5,
        eta: float = 0.0,
        strength: float = 1.0,
        global_strength: float = 1.0,
        generator: Optional[torch.Generator] = None,
        output_type: str = "pil",
    ) -> CRSDiffPipelineOutput:
        """Run conditional sampling and decode images.

        Args:
            prompt: Text prompt(s); a single string is broadcast to the batch.
            local_control: Spatial conditioning, tensor/ndarray of shape
                (B, C, H, W) — batch size and latent size are derived from it.
            global_control: Global conditioning tensor (zeroed for the
                unconditional branch).
            metadata: Extra conditioning vector; a 1-D tensor is broadcast to
                the batch.
            negative_prompt: Prompt(s) for the unconditional branch.
            num_inference_steps: Number of scheduler timesteps.
            guidance_scale: Classifier-free guidance weight; values <= 1.0
                skip the unconditional pass entirely.
            eta: DDIM eta forwarded to ``scheduler.step``.
            strength: Per-block local-control scale (applied only if the model
                exposes ``local_control_scales``).
            global_strength: Passed through to ``apply_model``.
            generator: Optional RNG for reproducible initial noise.
            output_type: "pil" or "numpy".

        Raises:
            TypeError: conditioning inputs are neither tensor nor ndarray.
            ValueError: unknown ``output_type``.
        """
        device = self.device
        local_control = self._to_tensor(local_control, device=device)
        global_control = self._to_tensor(global_control, device=device)
        metadata = self._to_tensor(metadata, device=device)

        batch_size = local_control.shape[0]
        if isinstance(prompt, str):
            prompt = [prompt] * batch_size
        if isinstance(negative_prompt, str):
            negative_prompt = [negative_prompt] * batch_size

        # Broadcast a single metadata vector across the batch.
        if metadata.dim() == 1:
            metadata = metadata.unsqueeze(0).repeat(batch_size, 1)

        # Conditioning dicts in the list-valued format ``apply_model`` expects.
        # The unconditional branch keeps local_control but zeroes global_control.
        cond = {
            "local_control": [local_control],
            "c_crossattn": [self.crs_model.get_learned_conditioning(prompt)],
            "global_control": [global_control],
        }
        un_cond = {
            "local_control": [local_control],
            "c_crossattn": [self.crs_model.get_learned_conditioning(negative_prompt)],
            "global_control": [torch.zeros_like(global_control)],
        }

        # NOTE(review): 13 is presumably the number of control injection
        # points in the denoiser — confirm against the model definition.
        if hasattr(self.crs_model, "local_control_scales"):
            self.crs_model.local_control_scales = [strength] * 13

        # Initial latent noise at 1/8 spatial resolution of the control input.
        _, _, h, w = local_control.shape
        latents = torch.randn(
            (batch_size, self.crs_model.channels, h // 8, w // 8),
            generator=generator,
            device=device,
        )
        latents = latents * self.scheduler.init_noise_sigma

        self.scheduler.set_timesteps(num_inference_steps, device=device)
        for t in self.scheduler.timesteps:
            ts = torch.full((batch_size,), int(t), device=device, dtype=torch.long)
            if guidance_scale > 1.0:
                # Classifier-free guidance: two model passes, then extrapolate
                # from the unconditional prediction toward the text prediction.
                noise_text = self.crs_model.apply_model(latents, ts, cond, metadata, global_strength)
                noise_uncond = self.crs_model.apply_model(latents, ts, un_cond, metadata, global_strength)
                noise_pred = noise_uncond + guidance_scale * (noise_text - noise_uncond)
            else:
                noise_pred = self.crs_model.apply_model(latents, ts, cond, metadata, global_strength)

            latents = self.scheduler.step(
                model_output=noise_pred,
                timestep=t,
                sample=latents,
                eta=eta,
                generator=generator,
                return_dict=True,
            ).prev_sample

        # Decode latents and map from [-1, 1] to uint8 HWC images.
        images = self.crs_model.decode_first_stage(latents)
        images = images.clamp(-1, 1)
        images = ((images + 1.0) / 2.0).permute(0, 2, 3, 1).cpu().numpy()
        images = (images * 255.0).clip(0, 255).astype(np.uint8)

        if output_type == "pil":
            images = [Image.fromarray(img) for img in images]
        elif output_type != "numpy":
            raise ValueError("output_type must be 'pil' or 'numpy'")

        return CRSDiffPipelineOutput(images=images)
| |
|
| |
|
| | def _is_component_list(v): |
| | return ( |
| | isinstance(v, (list, tuple)) |
| | and len(v) == 2 |
| | and isinstance(v[0], str) |
| | and isinstance(v[1], str) |
| | ) |
| |
|