| | """
|
| | MatFuse Pipeline for diffusers.
|
| |
|
| | A custom diffusers pipeline for generating PBR material maps using the MatFuse model.
|
| |
|
| | Note: This pipeline uses:
|
| | - Standard UNet2DConditionModel from diffusers (with custom in/out channels config)
|
| | - Custom MatFuseVQModel (required because MatFuse uses 4 separate encoders/quantizers)
|
| | """
|
| |
|
| | import os
|
| | import inspect
|
| | from typing import Optional, Union, List, Callable, Dict, Any, Tuple
|
| |
|
| | import torch
|
| | import torch.nn.functional as F
|
| | from PIL import Image
|
| | import numpy as np
|
| |
|
| | from diffusers.pipelines.pipeline_utils import DiffusionPipeline
|
| | from diffusers.models import UNet2DConditionModel
|
| | from diffusers.schedulers import DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler, EulerDiscreteScheduler
|
| |
|
| | try:
|
| | from vae_matfuse import MatFuseVQModel
|
| | except ImportError:
|
| | from diffusers.models.modeling_utils import ModelMixin as MatFuseVQModel
|
| | try:
|
| | from condition_encoders import MultiConditionEncoder
|
| | except ImportError:
|
| | from diffusers.models.modeling_utils import ModelMixin as MultiConditionEncoder
|
| |
|
| |
|
class MatFusePipeline(DiffusionPipeline):
    """
    Pipeline for generating PBR material maps using MatFuse.

    This pipeline generates 4 material maps (diffuse, normal, roughness, specular)
    from various conditioning inputs like reference images, text, sketches, and color palettes.

    Args:
        vae: MatFuseVQModel for encoding/decoding material maps (custom, required).
        unet: UNet2DConditionModel for denoising (standard diffusers model).
        scheduler: Diffusion scheduler.
        condition_encoder: Multi-condition encoder for processing inputs (optional).

    Note:
        The VQ-VAE must be the custom MatFuseVQModel because MatFuse uses 4 separate
        encoders and quantizers (one per material map type). The UNet can be the
        standard diffusers UNet2DConditionModel configured with:
        - in_channels=16 (12 latent + 4 sketch concat)
        - out_channels=12 (4 maps × 3 channels)
        - cross_attention_dim=512
    """

    model_cpu_offload_seq = "condition_encoder->unet->vae"
    _optional_components = ["condition_encoder"]

    def __init__(
        self,
        vae: MatFuseVQModel,
        unet: UNet2DConditionModel,
        scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler, EulerDiscreteScheduler],
        condition_encoder: Optional[MultiConditionEncoder] = None,
    ):
        super().__init__()

        self.register_modules(
            vae=vae,
            unet=unet,
            scheduler=scheduler,
            condition_encoder=condition_encoder,
        )

        # MatFuse latents are 1/8 of the pixel resolution (3 downsampling stages).
        self.vae_scale_factor = 8

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
        """
        Load the MatFuse pipeline from a local directory.

        Loads each component (UNet, VAE, scheduler, condition_encoder) individually
        from their respective subdirectories.

        Args:
            pretrained_model_name_or_path: Path to the directory containing the model components.
            **kwargs: Additional keyword arguments (e.g., torch_dtype).

        Returns:
            An assembled MatFusePipeline. The condition_encoder is optional and
            only loaded when a "condition_encoder" subdirectory exists.
        """
        model_dir = pretrained_model_name_or_path
        torch_dtype = kwargs.get("torch_dtype", None)

        unet = UNet2DConditionModel.from_pretrained(
            os.path.join(model_dir, "unet"),
            torch_dtype=torch_dtype,
        )

        vae = MatFuseVQModel.from_pretrained(
            os.path.join(model_dir, "vae"),
            torch_dtype=torch_dtype,
        )

        scheduler = DDIMScheduler.from_pretrained(
            os.path.join(model_dir, "scheduler"),
        )

        # The condition encoder is an optional component; skip it when absent.
        cond_dir = os.path.join(model_dir, "condition_encoder")
        condition_encoder = None
        if os.path.isdir(cond_dir):
            condition_encoder = MultiConditionEncoder.from_pretrained(
                cond_dir,
                torch_dtype=torch_dtype,
            )

        return cls(
            vae=vae,
            unet=unet,
            scheduler=scheduler,
            condition_encoder=condition_encoder,
        )

    @property
    def _execution_device(self):
        """Device on which the pipeline actually runs (handles meta-device offload)."""
        if self.device != torch.device("meta"):
            return self.device
        # Fall back to the device of the first registered module with parameters.
        for name, model in self.components.items():
            if isinstance(model, torch.nn.Module):
                return next(model.parameters()).device

        if self.condition_encoder is not None:
            return next(self.condition_encoder.parameters()).device
        return torch.device("cpu")

    def to(self, *args, **kwargs):
        """Override to() to also move condition_encoder (not auto-tracked by diffusers)."""
        result = super().to(*args, **kwargs)
        if self.condition_encoder is not None:
            self.condition_encoder = self.condition_encoder.to(*args, **kwargs)
        return result

    def decode_latents(self, latents: torch.Tensor) -> torch.Tensor:
        """Decode latents to material maps.

        Circular-pads the latents by 7 on each side before decoding (to make the
        decoded maps tileable at their borders), then center-crops the decoded
        output by 14 * vae_scale_factor pixels to undo the padding.
        """
        latents = F.pad(latents, (7, 7, 7, 7), mode="circular")

        # VQ decoding in fp16 is numerically unstable; temporarily upcast to fp32.
        needs_upcast = latents.dtype == torch.float16
        if needs_upcast:
            # Remember the VAE's original dtype so we restore it exactly
            # (the old code hard-coded float16, clobbering a non-fp16 VAE).
            vae_dtype = next(self.vae.parameters()).dtype
            self.vae.to(dtype=torch.float32)
            latents = latents.float()

        materials = self.vae.decode(latents)

        if needs_upcast:
            self.vae.to(dtype=vae_dtype)
            materials = materials.half()

        # Center-crop away the decoded region that corresponds to the circular pad.
        _, _, h, w = materials.shape
        target_h = (h - 14 * self.vae_scale_factor)
        target_w = (w - 14 * self.vae_scale_factor)
        start_h = (h - target_h) // 2
        start_w = (w - target_w) // 2
        materials = materials[:, :, start_h:start_h + target_h, start_w:start_w + target_w]

        return materials

    def prepare_latents(
        self,
        batch_size: int,
        num_channels_latents: int,
        height: int,
        width: int,
        dtype: torch.dtype,
        device: torch.device,
        generator: Optional[torch.Generator] = None,
        latents: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Prepare initial noise latents.

        Args:
            batch_size: Number of samples to generate.
            num_channels_latents: Latent channel count (12 for MatFuse).
            height: Output pixel height (divided by vae_scale_factor for latents).
            width: Output pixel width (divided by vae_scale_factor for latents).
            dtype: Latent dtype (must match the UNet).
            device: Target device.
            generator: Optional RNG for reproducibility.
            latents: Optional pre-generated latents; validated against the expected shape.

        Raises:
            ValueError: If user-supplied latents have the wrong shape.
        """
        shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)

        if latents is None:
            latents = torch.randn(shape, generator=generator, device=device, dtype=dtype)
        else:
            if latents.shape != shape:
                raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
            # Cast dtype as well as device so user-supplied fp32 latents work
            # with an fp16 UNet (the old code only moved the device).
            latents = latents.to(device=device, dtype=dtype)

        # Scale by the scheduler's initial noise sigma (required by e.g. LMS/Euler).
        latents = latents * self.scheduler.init_noise_sigma

        return latents

    def prepare_extra_step_kwargs(self, generator: Optional[torch.Generator], eta: float) -> Dict[str, Any]:
        """Prepare extra kwargs for the scheduler step.

        Not every scheduler accepts `eta` or `generator`; inspect the step()
        signature and only pass what is supported.
        """
        step_params = set(inspect.signature(self.scheduler.step).parameters.keys())

        extra_step_kwargs = {}
        if "eta" in step_params:
            extra_step_kwargs["eta"] = eta
        if "generator" in step_params:
            extra_step_kwargs["generator"] = generator

        return extra_step_kwargs

    def _encode_conditions(
        self,
        image: Optional[torch.Tensor] = None,
        text: Optional[Union[str, List[str]]] = None,
        sketch: Optional[torch.Tensor] = None,
        palette: Optional[torch.Tensor] = None,
        batch_size: int = 1,
        image_size: int = 256,
        device: torch.device = None,
        dtype: torch.dtype = None,
    ) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]:
        """Encode all condition inputs through their respective encoders.

        When a condition is not provided, the encoder creates a placeholder
        and encodes it (matching training behavior), rather than using zero tensors.

        Returns:
            Tuple of (c_crossattn, c_concat); both are None when no
            condition_encoder is registered.
        """
        device = device or self._execution_device

        if self.condition_encoder is not None:
            cond = self.condition_encoder(
                image_embed=image,
                text=text,
                sketch=sketch,
                palette=palette,
                batch_size=batch_size,
                image_size=image_size,
                device=device,
            )
            c_crossattn = cond["c_crossattn"]
            c_concat = cond["c_concat"]
        else:
            c_crossattn = None
            c_concat = None

        if c_crossattn is not None:
            c_crossattn = c_crossattn.to(dtype=dtype, device=device)
        if c_concat is not None:
            c_concat = c_concat.to(dtype=dtype, device=device)

        return c_crossattn, c_concat

    def _get_uncond_embeddings(
        self,
        batch_size: int,
        image_size: int,
        device: torch.device,
        dtype: torch.dtype,
    ) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]:
        """Get unconditional embeddings for classifier-free guidance.

        Creates proper unconditional embeddings by encoding placeholder inputs
        through the actual encoders (gray image → CLIP, empty string → SentenceTransformer,
        zero palette → PaletteEncoder, zero sketch → SketchEncoder).

        This matches the original training behavior where ucg_training drops conditions
        by setting them to val=0.0 (images/palette/sketch) or val="" (text), and then
        encoding those placeholder values through the encoders.
        """
        if self.condition_encoder is not None:
            uc = self.condition_encoder.get_unconditional_conditioning(
                batch_size=batch_size,
                image_size=image_size,
                device=device,
            )
            uc_crossattn = uc["c_crossattn"].to(dtype=dtype, device=device)
            uc_concat = uc["c_concat"].to(dtype=dtype, device=device)
        else:
            uc_crossattn = None
            uc_concat = None

        return uc_crossattn, uc_concat

    @torch.no_grad()
    def __call__(
        self,
        image: Optional[Union[torch.Tensor, Image.Image]] = None,
        text: Optional[Union[str, List[str]]] = None,
        sketch: Optional[Union[torch.Tensor, Image.Image]] = None,
        palette: Optional[Union[torch.Tensor, np.ndarray, List[Tuple[int, int, int]]]] = None,
        height: int = 256,
        width: int = 256,
        num_inference_steps: int = 50,
        guidance_scale: float = 7.5,
        num_images_per_prompt: int = 1,
        eta: float = 0.0,
        generator: Optional[torch.Generator] = None,
        latents: Optional[torch.Tensor] = None,
        output_type: str = "pil",
        return_dict: bool = True,
        callback: Optional[Callable[[int, int, torch.Tensor], None]] = None,
        callback_steps: int = 1,
    ) -> Dict[str, Any]:
        """
        Generate PBR material maps.

        Args:
            image: Reference image for style/appearance guidance.
            text: Text description of the material.
            sketch: Binary edge/sketch map for structure guidance.
            palette: Color palette (5 colors) for color guidance.
            height: Output image height.
            width: Output image width.
            num_inference_steps: Number of denoising steps.
            guidance_scale: Classifier-free guidance scale (CFG is used when > 1.0).
            num_images_per_prompt: Number of images to generate per prompt.
            eta: DDIM eta parameter.
            generator: Random number generator for reproducibility.
            latents: Pre-generated noise latents.
            output_type: Output format ("pil", "tensor", "np").
            return_dict: Whether to return a dict.
            callback: Callback function called every `callback_steps` steps.
            callback_steps: Frequency of callback calls.

        Returns:
            Dictionary with keys "diffuse", "normal", "roughness", "specular"
            (or a 4-tuple in that order when return_dict=False). With
            output_type="tensor", values are raw tensors in [-1, 1].
        """
        device = self._execution_device
        dtype = self.unet.dtype if hasattr(self.unet, 'dtype') else torch.float32

        # Batch size is driven by the text prompt list; everything else defaults to 1.
        if text is not None and isinstance(text, str):
            batch_size = 1
        elif text is not None:
            batch_size = len(text)
        else:
            batch_size = 1

        batch_size = batch_size * num_images_per_prompt

        # Convert PIL / raw inputs into tensors on the right device/dtype.
        if image is not None and isinstance(image, Image.Image):
            image = self._preprocess_image(image, device, dtype)

        if sketch is not None and isinstance(sketch, Image.Image):
            sketch = self._preprocess_sketch(sketch, height, width, device, dtype)

        if palette is not None and not isinstance(palette, torch.Tensor):
            palette = self._preprocess_palette(palette, device, dtype)

        c_crossattn, c_concat = self._encode_conditions(
            image=image,
            text=text,
            sketch=sketch,
            palette=palette,
            batch_size=batch_size,
            image_size=height,
            device=device,
            dtype=dtype,
        )

        do_classifier_free_guidance = guidance_scale > 1.0
        if do_classifier_free_guidance:
            uc_crossattn, uc_concat = self._get_uncond_embeddings(
                batch_size, height, device, dtype
            )

        self.scheduler.set_timesteps(num_inference_steps, device=device)
        timesteps = self.scheduler.timesteps

        # 12 latent channels: 4 material maps × 3 channels each.
        num_channels_latents = 12
        latents = self.prepare_latents(
            batch_size,
            num_channels_latents,
            height,
            width,
            dtype,
            device,
            generator,
            latents,
        )

        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

        # Clamp at 0: for higher-order schedulers len(timesteps) can be smaller
        # than steps * order, which would make this negative.
        num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)

        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):

                if do_classifier_free_guidance:
                    # Channel-concat the spatial condition when present; without a
                    # condition_encoder c_concat/uc_concat are None (old code
                    # crashed here on torch.cat with None).
                    if c_concat is not None and uc_concat is not None:
                        latent_uncond = torch.cat([latents, uc_concat], dim=1)
                        latent_cond = torch.cat([latents, c_concat], dim=1)
                    else:
                        latent_uncond = latents
                        latent_cond = latents
                    latent_model_input = torch.cat([latent_uncond, latent_cond])
                    if c_crossattn is not None:
                        encoder_hidden_states = torch.cat([uc_crossattn, c_crossattn])
                    else:
                        encoder_hidden_states = None
                else:
                    if c_concat is not None:
                        latent_model_input = torch.cat([latents, c_concat], dim=1)
                    else:
                        latent_model_input = latents
                    encoder_hidden_states = c_crossattn

                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

                noise_pred = self.unet(
                    latent_model_input,
                    t,
                    encoder_hidden_states=encoder_hidden_states,
                    return_dict=False,
                )

                # UNet may return a tuple or a dict depending on return_dict support.
                if isinstance(noise_pred, tuple):
                    noise_pred = noise_pred[0]
                elif isinstance(noise_pred, dict):
                    noise_pred = noise_pred["sample"]

                if do_classifier_free_guidance:
                    noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2)
                    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)

                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample

                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                    progress_bar.update()
                    if callback is not None and i % callback_steps == 0:
                        callback(i, t, latents)

        materials = self.decode_latents(latents)

        # Split the 12-channel decode into the four 3-channel material maps.
        diffuse = materials[:, 0:3]
        normal = materials[:, 3:6]
        roughness = materials[:, 6:9]
        specular = materials[:, 9:12]

        if output_type == "pil":
            diffuse = self._tensor_to_pil(diffuse)
            normal = self._tensor_to_pil(normal)
            roughness = self._tensor_to_pil(roughness)
            specular = self._tensor_to_pil(specular)
        elif output_type == "np":
            diffuse = self._tensor_to_numpy(diffuse)
            normal = self._tensor_to_numpy(normal)
            roughness = self._tensor_to_numpy(roughness)
            specular = self._tensor_to_numpy(specular)

        if return_dict:
            return {
                "diffuse": diffuse,
                "normal": normal,
                "roughness": roughness,
                "specular": specular,
            }

        return (diffuse, normal, roughness, specular)

    def _preprocess_image(self, image: Image.Image, device: torch.device, dtype: torch.dtype) -> torch.Tensor:
        """Preprocess PIL image to a (1, 3, H, W) tensor in [-1, 1]."""
        image = image.convert("RGB")
        image = np.array(image).astype(np.float32) / 255.0
        image = torch.from_numpy(image).permute(2, 0, 1).unsqueeze(0)
        image = image * 2.0 - 1.0
        return image.to(device=device, dtype=dtype)

    def _preprocess_sketch(
        self,
        sketch: Image.Image,
        height: int,
        width: int,
        device: torch.device,
        dtype: torch.dtype,
    ) -> torch.Tensor:
        """Preprocess sketch image to a (1, 1, H, W) tensor in [0, 1]."""
        sketch = sketch.convert("L")
        sketch = sketch.resize((width, height), Image.BILINEAR)
        sketch = np.array(sketch).astype(np.float32) / 255.0
        sketch = torch.from_numpy(sketch).unsqueeze(0).unsqueeze(0)
        return sketch.to(device=device, dtype=dtype)

    def _preprocess_palette(
        self,
        palette: Union[np.ndarray, List[Tuple[int, int, int]]],
        device: torch.device,
        dtype: torch.dtype,
    ) -> torch.Tensor:
        """Preprocess color palette to a (1, 5, 3) tensor in [0, 1].

        Accepts a list of RGB tuples (0-255) or an ndarray (0-255 or already
        normalized to [0, 1]); pads/truncates to exactly 5 colors by repeating
        the last color.

        Raises:
            ValueError: If the palette is empty (the old code spun forever here).
        """
        if isinstance(palette, list):
            palette = np.array(palette, dtype=np.float32) / 255.0
        elif isinstance(palette, np.ndarray):
            if palette.max() > 1.0:
                palette = palette.astype(np.float32) / 255.0
            else:
                palette = palette.astype(np.float32)

        if len(palette) == 0:
            raise ValueError("palette must contain at least one color")

        # Pad to 5 colors by repeating the last one, then truncate to exactly 5.
        while len(palette) < 5:
            palette = np.concatenate([palette, palette[-1:]], axis=0)
        palette = palette[:5]

        palette = torch.from_numpy(palette).unsqueeze(0)
        return palette.to(device=device, dtype=dtype)

    def _tensor_to_pil(self, tensor: torch.Tensor) -> List[Image.Image]:
        """Convert a (B, 3, H, W) tensor in [-1, 1] to a list of PIL images."""
        tensor = (tensor + 1.0) / 2.0
        tensor = tensor.clamp(0, 1)
        tensor = tensor.cpu().permute(0, 2, 3, 1).numpy()
        tensor = (tensor * 255).astype(np.uint8)
        return [Image.fromarray(img) for img in tensor]

    def _tensor_to_numpy(self, tensor: torch.Tensor) -> np.ndarray:
        """Convert a (B, 3, H, W) tensor in [-1, 1] to a (B, H, W, 3) float array in [0, 1]."""
        tensor = (tensor + 1.0) / 2.0
        tensor = tensor.clamp(0, 1)
        return tensor.cpu().permute(0, 2, 3, 1).numpy()
|