"""Custom diffusers pipeline wrapper for the CRS-Diff remote-sensing model.

Loads the CRS-Diff model + DDIM scheduler via the sibling ``modular_pipeline``
module and exposes a diffusers-compatible ``DiffusionPipeline`` with classifier-
free guidance over text, local-control, and global-control conditioning.
"""

import sys
from dataclasses import dataclass
from pathlib import Path
from typing import List, Optional, Union

import numpy as np
import torch
from diffusers import DDIMScheduler, DiffusionPipeline
from diffusers.utils import BaseOutput
from PIL import Image

# Make this file's directory importable so `modular_pipeline` resolves when the
# file is loaded as a diffusers custom pipeline (outside a package context).
_ROOT = Path(__file__).resolve().parent
if str(_ROOT) not in sys.path:
    sys.path.insert(0, str(_ROOT))

# Register alias for cached custom-pipeline imports.
sys.modules["pipeline"] = sys.modules[__name__]

from modular_pipeline import load_components, resolve_model_root  # noqa: E402


@dataclass
class CRSDiffPipelineOutput(BaseOutput):
    """Output container: a list of generated images (PIL or numpy arrays)."""

    images: List[Image.Image]


class CRSDiffPipeline(DiffusionPipeline):
    """DiffusionPipeline wrapper around a CRS-Diff model and DDIM scheduler."""

    def register_modules(self, **kwargs):
        """Register components, tolerating config placeholders.

        Overrides the diffusers implementation so that ``(library, class)``
        string tuples coming from a saved ``model_index.json`` — and ``None`` /
        ``(None, ...)`` markers for absent modules — can be recorded in the
        config without being treated as live modules.
        """
        for name, module in kwargs.items():
            if module is None or (
                isinstance(module, (tuple, list)) and len(module) > 0 and module[0] is None
            ):
                # Absent module: record a (None, None) placeholder.
                self.register_to_config(**{name: (None, None)})
                setattr(self, name, module)
            elif _is_component_list(module):
                # Already a (library, class_name) placeholder — pass it through.
                self.register_to_config(**{name: (module[0], module[1])})
                setattr(self, name, module)
            else:
                # Live module: ask diffusers for its (library, class) tuple.
                from diffusers.pipelines.pipeline_loading_utils import _fetch_class_library_tuple

                library, class_name = _fetch_class_library_tuple(module)
                self.register_to_config(**{name: (library, class_name)})
                setattr(self, name, module)

    def __init__(
        self,
        crs_model=None,
        scheduler=None,
        scale_factor: float = 0.18215,
        model_path: Optional[Union[str, Path]] = None,
        _name_or_path: Optional[Union[str, Path]] = None,
    ):
        """Build the pipeline from live components or from config placeholders.

        When diffusers re-instantiates the pipeline from a cached
        ``model_index.json`` it passes ``(library, class)`` string tuples
        instead of live modules; in that case the model root is resolved and
        the real components are loaded via ``load_components``.

        Raises:
            ValueError: if placeholders were received but no model path could
                be resolved from any of the candidate sources.
        """
        super().__init__()
        if _is_component_list(crs_model) or _is_component_list(scheduler):
            # Try explicit path first, then the loading-time name, then
            # whatever diffusers stashed on the config.
            model_root = (
                resolve_model_root(model_path)
                or resolve_model_root(_name_or_path)
                or resolve_model_root(getattr(getattr(self, "config", None), "_name_or_path", None))
            )
            if model_root is None:
                raise ValueError(
                    "CRS-Diff received config placeholders but could not resolve model path. "
                    "Pass `model_path` or load via DiffusionPipeline.from_pretrained(..., custom_pipeline=...)."
                )
            loaded = load_components(model_root)
            crs_model = loaded["crs_model"]
            scheduler = loaded["scheduler"]
            scale_factor = loaded["scale_factor"]
        self.register_modules(crs_model=crs_model, scheduler=scheduler)
        # NOTE(review): this stores the latent *scaling* factor (e.g. 0.18215),
        # not the spatial downsample factor diffusers usually calls
        # `vae_scale_factor`; the attribute name is kept for compatibility.
        self.vae_scale_factor = scale_factor

    @property
    def device(self) -> torch.device:
        """Device of the wrapped model's first parameter (CPU if parameterless)."""
        params = list(self.crs_model.parameters())
        if params:
            return params[0].device
        return torch.device("cpu")

    @classmethod
    def from_pretrained(
        cls,
        pretrained_model_name_or_path: Union[str, Path],
        device: Optional[Union[str, torch.device]] = None,
        subfolder: Optional[str] = None,
        **kwargs,
    ) -> "CRSDiffPipeline":
        """Load the pipeline (or just its scheduler) from a model directory.

        Args:
            pretrained_model_name_or_path: model identifier or path, resolved
                via ``resolve_model_root``.
            device: optional device to move the pipeline to after loading.
            subfolder: when ``"scheduler"``, return only the DDIM scheduler —
                this is how diffusers loads sub-components.

        Raises:
            ValueError: if the model root cannot be resolved.
        """
        path = resolve_model_root(pretrained_model_name_or_path)
        if path is None:
            raise ValueError(f"Could not resolve CRS-Diff model root from: {pretrained_model_name_or_path}")
        # `subfolder` is an explicit keyword parameter, so it can never appear
        # in **kwargs — the old `kwargs.pop("subfolder", subfolder)` was dead.
        if subfolder == "scheduler":
            return DDIMScheduler.from_pretrained(path, subfolder="scheduler")
        loaded = load_components(path)
        pipe = cls(
            crs_model=loaded["crs_model"],
            scheduler=loaded["scheduler"],
            scale_factor=loaded["scale_factor"],
        )
        if device is not None:
            pipe = pipe.to(device)
        return pipe

    def _to_tensor(self, x, device: torch.device, dtype=torch.float32) -> torch.Tensor:
        """Coerce a tensor or ndarray to a torch.Tensor on `device`/`dtype`."""
        if isinstance(x, np.ndarray):
            x = torch.from_numpy(x)
        if not isinstance(x, torch.Tensor):
            raise TypeError("Expected torch.Tensor or np.ndarray for conditioning inputs.")
        return x.to(device=device, dtype=dtype)

    @torch.no_grad()
    def __call__(
        self,
        prompt: Union[str, List[str]],
        local_control,
        global_control,
        metadata,
        negative_prompt: Union[str, List[str]] = "",
        num_inference_steps: int = 50,
        guidance_scale: float = 7.5,
        eta: float = 0.0,
        strength: float = 1.0,
        global_strength: float = 1.0,
        generator: Optional[torch.Generator] = None,
        output_type: str = "pil",
    ) -> CRSDiffPipelineOutput:
        """Run DDIM sampling with text + local/global control conditioning.

        Args:
            prompt: text prompt(s); a single string is broadcast to the batch.
            local_control: spatial conditioning tensor, shape (B, C, H, W)
                (batch size and latent size are derived from it).
            global_control: global conditioning tensor; zeroed for the
                unconditional branch of classifier-free guidance.
            metadata: per-sample metadata vector(s); a 1-D tensor is broadcast
                to the batch.
            negative_prompt: negative text prompt(s) for guidance.
            num_inference_steps: number of scheduler timesteps.
            guidance_scale: classifier-free guidance weight; <= 1.0 disables
                the unconditional pass.
            eta: DDIM eta passed to the scheduler step.
            strength: local-control scale applied to all 13 control levels.
            global_strength: scale forwarded to the model's apply_model.
            generator: optional RNG for reproducible latents/noise.
                NOTE(review): a CPU generator with a CUDA `device` may error in
                torch.randn — confirm against callers.
            output_type: "pil" or "numpy".

        Returns:
            CRSDiffPipelineOutput with the decoded images.

        Raises:
            ValueError: for an unknown ``output_type``.
        """
        device = self.device
        local_control = self._to_tensor(local_control, device=device)
        global_control = self._to_tensor(global_control, device=device)
        metadata = self._to_tensor(metadata, device=device)

        batch_size = local_control.shape[0]
        if isinstance(prompt, str):
            prompt = [prompt] * batch_size
        if isinstance(negative_prompt, str):
            negative_prompt = [negative_prompt] * batch_size
        if metadata.dim() == 1:
            metadata = metadata.unsqueeze(0).repeat(batch_size, 1)

        # Conditional / unconditional branches for classifier-free guidance;
        # the unconditional branch keeps local control but zeroes global control.
        cond = {
            "local_control": [local_control],
            "c_crossattn": [self.crs_model.get_learned_conditioning(prompt)],
            "global_control": [global_control],
        }
        un_cond = {
            "local_control": [local_control],
            "c_crossattn": [self.crs_model.get_learned_conditioning(negative_prompt)],
            "global_control": [torch.zeros_like(global_control)],
        }

        if hasattr(self.crs_model, "local_control_scales"):
            # 13 control levels, all scaled uniformly by `strength`.
            self.crs_model.local_control_scales = [strength] * 13

        # Latents are 1/8 the spatial size of the control input.
        _, _, h, w = local_control.shape
        latents = torch.randn(
            (batch_size, self.crs_model.channels, h // 8, w // 8),
            generator=generator,
            device=device,
        )
        latents = latents * self.scheduler.init_noise_sigma

        self.scheduler.set_timesteps(num_inference_steps, device=device)
        for t in self.scheduler.timesteps:
            ts = torch.full((batch_size,), int(t), device=device, dtype=torch.long)
            if guidance_scale > 1.0:
                noise_text = self.crs_model.apply_model(latents, ts, cond, metadata, global_strength)
                noise_uncond = self.crs_model.apply_model(latents, ts, un_cond, metadata, global_strength)
                noise_pred = noise_uncond + guidance_scale * (noise_text - noise_uncond)
            else:
                noise_pred = self.crs_model.apply_model(latents, ts, cond, metadata, global_strength)
            latents = self.scheduler.step(
                model_output=noise_pred,
                timestep=t,
                sample=latents,
                eta=eta,
                generator=generator,
                return_dict=True,
            ).prev_sample

        # Decode latents to images in [-1, 1], then to uint8 HWC arrays.
        images = self.crs_model.decode_first_stage(latents)
        images = images.clamp(-1, 1)
        images = ((images + 1.0) / 2.0).permute(0, 2, 3, 1).cpu().numpy()
        images = (images * 255.0).clip(0, 255).astype(np.uint8)

        if output_type == "pil":
            images = [Image.fromarray(img) for img in images]
        elif output_type != "numpy":
            raise ValueError("output_type must be 'pil' or 'numpy'")
        return CRSDiffPipelineOutput(images=images)


def _is_component_list(v):
    """True if `v` is a (library, class_name) placeholder pair of strings."""
    return (
        isinstance(v, (list, tuple))
        and len(v) == 2
        and isinstance(v[0], str)
        and isinstance(v[1], str)
    )