"""
MatFuse Pipeline for diffusers.
A custom diffusers pipeline for generating PBR material maps using the MatFuse model.
Note: This pipeline uses:
- Standard UNet2DConditionModel from diffusers (with custom in/out channels config)
- Custom MatFuseVQModel (required because MatFuse uses 4 separate encoders/quantizers)
"""
import os
import inspect
from typing import Optional, Union, List, Callable, Dict, Any, Tuple
import torch
import torch.nn.functional as F
from PIL import Image
import numpy as np
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
from diffusers.models import UNet2DConditionModel
from diffusers.schedulers import DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler, EulerDiscreteScheduler
# The custom VAE and condition encoder ship alongside this file in the model
# repository. Fall back to ModelMixin so the type annotations below still
# resolve when those modules are not importable.
try:
    from vae_matfuse import MatFuseVQModel
except ImportError:
    from diffusers.models.modeling_utils import ModelMixin as MatFuseVQModel
try:
    from condition_encoders import MultiConditionEncoder
except ImportError:
    from diffusers.models.modeling_utils import ModelMixin as MultiConditionEncoder
class MatFusePipeline(DiffusionPipeline):
"""
Pipeline for generating PBR material maps using MatFuse.
This pipeline generates 4 material maps (diffuse, normal, roughness, specular)
from various conditioning inputs like reference images, text, sketches, and color palettes.
Args:
vae: MatFuseVQModel for encoding/decoding material maps (custom, required).
unet: UNet2DConditionModel for denoising (standard diffusers model).
scheduler: Diffusion scheduler.
condition_encoder: Multi-condition encoder for processing inputs.
Note:
The VQ-VAE must be the custom MatFuseVQModel because MatFuse uses 4 separate
encoders and quantizers (one per material map type). The UNet can be the
standard diffusers UNet2DConditionModel configured with:
- in_channels=16 (12 latent + 4 sketch concat)
- out_channels=12 (4 maps × 3 channels)
- cross_attention_dim=512
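
    Example:
        A minimal usage sketch (the checkpoint path and prompt are
        illustrative, not shipped defaults)::

            pipe = MatFusePipeline.from_pretrained("path/to/matfuse")
            out = pipe(text="weathered red bricks", num_inference_steps=50)
            out["diffuse"][0].save("diffuse.png")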
"""
model_cpu_offload_seq = "condition_encoder->unet->vae"
_optional_components = ["condition_encoder"]
def __init__(
self,
vae: MatFuseVQModel,
unet: UNet2DConditionModel,
scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler, EulerDiscreteScheduler],
condition_encoder: Optional[MultiConditionEncoder] = None,
):
super().__init__()
self.register_modules(
vae=vae,
unet=unet,
scheduler=scheduler,
condition_encoder=condition_encoder,
)
self.vae_scale_factor = 8 # Downsampling factor of VQ-VAE
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
"""
Load the MatFuse pipeline from a local directory.
Loads each component (UNet, VAE, scheduler, condition_encoder) individually
from their respective subdirectories.
Args:
pretrained_model_name_or_path: Path to the directory containing the model components.
**kwargs: Additional keyword arguments (e.g., torch_dtype).
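
        Expected on-disk layout (reconstructed from the loading code below;
        only ``condition_encoder/`` is optional)::

            model_dir/
                unet/                 # standard diffusers UNet2DConditionModel
                vae/                  # custom MatFuseVQModel
                scheduler/            # DDIMScheduler config
                condition_encoder/    # optional MultiConditionEncoder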
"""
model_dir = pretrained_model_name_or_path
torch_dtype = kwargs.get("torch_dtype", None)
# Load UNet (standard diffusers)
unet = UNet2DConditionModel.from_pretrained(
os.path.join(model_dir, "unet"),
torch_dtype=torch_dtype,
)
# Load VAE (custom)
vae = MatFuseVQModel.from_pretrained(
os.path.join(model_dir, "vae"),
torch_dtype=torch_dtype,
)
# Load scheduler
scheduler = DDIMScheduler.from_pretrained(
os.path.join(model_dir, "scheduler"),
)
# Load condition encoder (custom) if it exists
cond_dir = os.path.join(model_dir, "condition_encoder")
condition_encoder = None
if os.path.isdir(cond_dir):
condition_encoder = MultiConditionEncoder.from_pretrained(
cond_dir,
torch_dtype=torch_dtype,
)
return cls(
vae=vae,
unet=unet,
scheduler=scheduler,
condition_encoder=condition_encoder,
)
    @property
    def _execution_device(self):
        if self.device != torch.device("meta"):
            return self.device
        for model in self.components.values():
            if isinstance(model, torch.nn.Module):
                param = next(model.parameters(), None)
                if param is not None:
                    return param.device
        # Also check condition_encoder (it may not be in the components dict)
        if self.condition_encoder is not None:
            param = next(self.condition_encoder.parameters(), None)
            if param is not None:
                return param.device
        return torch.device("cpu")
def to(self, *args, **kwargs):
"""Override to() to also move condition_encoder (not auto-tracked by diffusers)."""
result = super().to(*args, **kwargs)
if self.condition_encoder is not None:
self.condition_encoder = self.condition_encoder.to(*args, **kwargs)
return result
def decode_latents(self, latents: torch.Tensor) -> torch.Tensor:
"""Decode latents to material maps."""
# Add circular padding for seamless textures
latents = F.pad(latents, (7, 7, 7, 7), mode="circular")
# Upcast to float32 for VAE decoding to avoid NaN from float16 precision
needs_upcast = latents.dtype == torch.float16
if needs_upcast:
self.vae.to(dtype=torch.float32)
latents = latents.float()
# Decode
materials = self.vae.decode(latents)
if needs_upcast:
self.vae.to(dtype=torch.float16)
materials = materials.half()
# Center crop to remove padding
_, _, h, w = materials.shape
target_h = (h - 14 * self.vae_scale_factor)
target_w = (w - 14 * self.vae_scale_factor)
start_h = (h - target_h) // 2
start_w = (w - target_w) // 2
materials = materials[:, :, start_h:start_h + target_h, start_w:start_w + target_w]
return materials
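    # Worked example of the pad/crop arithmetic above, assuming the default
    # vae_scale_factor of 8: a 32x32 latent padded by 7 on each side becomes
    # 46x46 and decodes to 368x368 px; cropping 14 * 8 = 112 px per dimension
    # recovers the original 32 * 8 = 256 px output, now seamlessly tileable.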
def prepare_latents(
self,
batch_size: int,
num_channels_latents: int,
height: int,
width: int,
dtype: torch.dtype,
device: torch.device,
generator: Optional[torch.Generator] = None,
latents: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""Prepare initial noise latents."""
shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
        if latents is None:
            # Draw the noise on the generator's device (a CPU generator with a
            # CUDA target would otherwise raise), then move to the target device
            rand_device = generator.device if generator is not None else device
            latents = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype).to(device)
        else:
            if latents.shape != shape:
                raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
            latents = latents.to(device=device, dtype=dtype)
        # Scale the initial noise by the scheduler's standard deviation
        latents = latents * self.scheduler.init_noise_sigma
        return latents
    def prepare_extra_step_kwargs(self, generator: Optional[torch.Generator], eta: float) -> Dict[str, Any]:
        """Prepare extra kwargs for the scheduler step (not all schedulers accept eta/generator)."""
        step_params = set(inspect.signature(self.scheduler.step).parameters)
        extra_step_kwargs = {}
        if "eta" in step_params:
            extra_step_kwargs["eta"] = eta
        if "generator" in step_params:
            extra_step_kwargs["generator"] = generator
        return extra_step_kwargs
def _encode_conditions(
self,
image: Optional[torch.Tensor] = None,
text: Optional[Union[str, List[str]]] = None,
sketch: Optional[torch.Tensor] = None,
palette: Optional[torch.Tensor] = None,
batch_size: int = 1,
image_size: int = 256,
device: torch.device = None,
dtype: torch.dtype = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Encode all condition inputs through their respective encoders.
When a condition is not provided, the encoder creates a placeholder
and encodes it (matching training behavior), rather than using zero tensors.
"""
device = device or self._execution_device
if self.condition_encoder is not None:
cond = self.condition_encoder(
image_embed=image,
text=text,
sketch=sketch,
palette=palette,
batch_size=batch_size,
image_size=image_size,
device=device,
)
c_crossattn = cond["c_crossattn"]
c_concat = cond["c_concat"]
else:
c_crossattn = None
c_concat = None
# Ensure proper dtype
if c_crossattn is not None:
c_crossattn = c_crossattn.to(dtype=dtype, device=device)
if c_concat is not None:
c_concat = c_concat.to(dtype=dtype, device=device)
return c_crossattn, c_concat
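    # Shape contract assumed throughout this file (inferred from the UNet
    # config in the class docstring, not verified against the encoder):
    #   c_crossattn: (batch, seq_len, 512) cross-attention tokens
    #   c_concat:    (batch, 4, height // 8, width // 8) sketch features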
def _get_uncond_embeddings(
self,
batch_size: int,
image_size: int,
device: torch.device,
dtype: torch.dtype,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Get unconditional embeddings for classifier-free guidance.
Creates proper unconditional embeddings by encoding placeholder inputs
through the actual encoders (gray image → CLIP, empty string → SentenceTransformer,
zero palette → PaletteEncoder, zero sketch → SketchEncoder).
This matches the original training behavior where ucg_training drops conditions
by setting them to val=0.0 (images/palette/sketch) or val="" (text), and then
encoding those placeholder values through the encoders.
"""
if self.condition_encoder is not None:
uc = self.condition_encoder.get_unconditional_conditioning(
batch_size=batch_size,
image_size=image_size,
device=device,
)
uc_crossattn = uc["c_crossattn"].to(dtype=dtype, device=device)
uc_concat = uc["c_concat"].to(dtype=dtype, device=device)
else:
uc_crossattn = None
uc_concat = None
return uc_crossattn, uc_concat
@torch.no_grad()
def __call__(
self,
image: Optional[Union[torch.Tensor, Image.Image]] = None,
text: Optional[Union[str, List[str]]] = None,
sketch: Optional[Union[torch.Tensor, Image.Image]] = None,
palette: Optional[Union[torch.Tensor, np.ndarray, List[Tuple[int, int, int]]]] = None,
height: int = 256,
width: int = 256,
num_inference_steps: int = 50,
guidance_scale: float = 7.5,
num_images_per_prompt: int = 1,
eta: float = 0.0,
generator: Optional[torch.Generator] = None,
latents: Optional[torch.Tensor] = None,
output_type: str = "pil",
return_dict: bool = True,
callback: Optional[Callable[[int, int, torch.Tensor], None]] = None,
callback_steps: int = 1,
) -> Dict[str, Any]:
"""
Generate PBR material maps.
Args:
image: Reference image for style/appearance guidance.
text: Text description of the material.
sketch: Binary edge/sketch map for structure guidance.
palette: Color palette (5 colors) for color guidance.
height: Output image height.
width: Output image width.
num_inference_steps: Number of denoising steps.
guidance_scale: Classifier-free guidance scale.
num_images_per_prompt: Number of images to generate per prompt.
eta: DDIM eta parameter.
generator: Random number generator for reproducibility.
latents: Pre-generated noise latents.
output_type: Output format ("pil", "tensor", "np").
return_dict: Whether to return a dict.
callback: Callback function called every `callback_steps` steps.
callback_steps: Frequency of callback calls.
Returns:
Dictionary containing:
- images: List of generated images (4 maps per generation).
- diffuse: Diffuse/albedo maps.
- normal: Normal maps.
- roughness: Roughness maps.
- specular: Specular maps.
"""
device = self._execution_device
dtype = self.unet.dtype if hasattr(self.unet, 'dtype') else torch.float32
        # Determine batch size from the text input (defaults to 1 otherwise)
        if text is not None and not isinstance(text, str):
            batch_size = len(text)
        else:
            batch_size = 1
batch_size = batch_size * num_images_per_prompt
# Preprocess inputs
if image is not None and isinstance(image, Image.Image):
image = self._preprocess_image(image, device, dtype)
if sketch is not None and isinstance(sketch, Image.Image):
sketch = self._preprocess_sketch(sketch, height, width, device, dtype)
if palette is not None and not isinstance(palette, torch.Tensor):
palette = self._preprocess_palette(palette, device, dtype)
# Encode conditions
# The encoder handles None conditions by encoding placeholder inputs
# (matching the original model's UCG training behavior)
c_crossattn, c_concat = self._encode_conditions(
image=image,
text=text,
sketch=sketch,
palette=palette,
batch_size=batch_size,
image_size=height,
device=device,
dtype=dtype,
)
# Get unconditional embeddings for CFG
# These are encoded placeholders, NOT zero tensors
do_classifier_free_guidance = guidance_scale > 1.0
if do_classifier_free_guidance:
uc_crossattn, uc_concat = self._get_uncond_embeddings(
batch_size, height, device, dtype
)
# Prepare timesteps
self.scheduler.set_timesteps(num_inference_steps, device=device)
timesteps = self.scheduler.timesteps
# Prepare latent variables
num_channels_latents = 12 # 4 maps * 3 channels per quantizer
latents = self.prepare_latents(
batch_size,
num_channels_latents,
height,
width,
dtype,
device,
generator,
latents,
)
# Prepare extra step kwargs
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
# Denoising loop
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
                # Concatenate sketch features with the latents channel-wise.
                # When no condition encoder is loaded the concat features are
                # None; the UNet's in_channels must then match the bare latents.
                if do_classifier_free_guidance:
                    # CFG: the unconditional branch uses uc_concat, the conditional branch c_concat
                    latent_uncond = torch.cat([latents, uc_concat], dim=1) if uc_concat is not None else latents
                    latent_cond = torch.cat([latents, c_concat], dim=1) if c_concat is not None else latents
                    latent_model_input = torch.cat([latent_uncond, latent_cond])
                    if c_crossattn is not None:
                        encoder_hidden_states = torch.cat([uc_crossattn, c_crossattn])
                    else:
                        encoder_hidden_states = None
                else:
                    latent_model_input = torch.cat([latents, c_concat], dim=1) if c_concat is not None else latents
                    encoder_hidden_states = c_crossattn
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
                # Predict the noise residual. With return_dict=False the UNet
                # returns a tuple whose first element is the sample.
                noise_pred = self.unet(
                    latent_model_input,
                    t,
                    encoder_hidden_states=encoder_hidden_states,
                    return_dict=False,
                )[0]
# Classifier-free guidance
if do_classifier_free_guidance:
noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)
# Compute previous noisy sample
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
# Callback
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
if callback is not None and i % callback_steps == 0:
callback(i, t, latents)
# Decode latents
materials = self.decode_latents(latents)
# Split into individual maps
diffuse = materials[:, 0:3]
normal = materials[:, 3:6]
roughness = materials[:, 6:9]
specular = materials[:, 9:12]
# Post-process outputs
if output_type == "pil":
diffuse = self._tensor_to_pil(diffuse)
normal = self._tensor_to_pil(normal)
roughness = self._tensor_to_pil(roughness)
specular = self._tensor_to_pil(specular)
elif output_type == "np":
diffuse = self._tensor_to_numpy(diffuse)
normal = self._tensor_to_numpy(normal)
roughness = self._tensor_to_numpy(roughness)
specular = self._tensor_to_numpy(specular)
if return_dict:
return {
"diffuse": diffuse,
"normal": normal,
"roughness": roughness,
"specular": specular,
}
return (diffuse, normal, roughness, specular)
def _preprocess_image(self, image: Image.Image, device: torch.device, dtype: torch.dtype) -> torch.Tensor:
"""Preprocess PIL image to tensor."""
image = image.convert("RGB")
image = np.array(image).astype(np.float32) / 255.0
image = torch.from_numpy(image).permute(2, 0, 1).unsqueeze(0)
image = image * 2.0 - 1.0 # Scale to [-1, 1]
return image.to(device=device, dtype=dtype)
def _preprocess_sketch(
self,
sketch: Image.Image,
height: int,
width: int,
device: torch.device,
dtype: torch.dtype,
) -> torch.Tensor:
"""Preprocess sketch image to tensor."""
sketch = sketch.convert("L")
sketch = sketch.resize((width, height), Image.BILINEAR)
sketch = np.array(sketch).astype(np.float32) / 255.0
sketch = torch.from_numpy(sketch).unsqueeze(0).unsqueeze(0)
return sketch.to(device=device, dtype=dtype)
    def _preprocess_palette(
        self,
        palette: Union[np.ndarray, List[Tuple[int, int, int]]],
        device: torch.device,
        dtype: torch.dtype,
    ) -> torch.Tensor:
        """Preprocess a color palette to a (1, 5, 3) tensor in [0, 1]."""
        palette = np.asarray(palette, dtype=np.float32)
        if len(palette) == 0:
            raise ValueError("Palette must contain at least one color")
        # Normalize 0-255 inputs; values already in [0, 1] pass through
        if palette.max() > 1.0:
            palette = palette / 255.0
        # Pad by repeating the last color, then truncate to exactly 5 colors
        while len(palette) < 5:
            palette = np.concatenate([palette, palette[-1:]], axis=0)
        palette = palette[:5]
        palette = torch.from_numpy(palette).unsqueeze(0)
        return palette.to(device=device, dtype=dtype)
def _tensor_to_pil(self, tensor: torch.Tensor) -> List[Image.Image]:
"""Convert tensor to list of PIL images."""
tensor = (tensor + 1.0) / 2.0
tensor = tensor.clamp(0, 1)
tensor = tensor.cpu().permute(0, 2, 3, 1).numpy()
tensor = (tensor * 255).astype(np.uint8)
return [Image.fromarray(img) for img in tensor]
def _tensor_to_numpy(self, tensor: torch.Tensor) -> np.ndarray:
"""Convert tensor to numpy array."""
tensor = (tensor + 1.0) / 2.0
tensor = tensor.clamp(0, 1)
return tensor.cpu().permute(0, 2, 3, 1).numpy()
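

if __name__ == "__main__":
    # Illustrative smoke test, not part of the pipeline itself. The checkpoint
    # path, prompt, and palette below are assumptions for demonstration only.
    brick_palette = [
        (180, 30, 30),
        (150, 40, 35),
        (120, 60, 50),
        (90, 90, 90),
        (200, 190, 180),
    ]
    pipe = MatFusePipeline.from_pretrained("./matfuse")
    result = pipe(
        text="weathered red bricks",
        palette=brick_palette,
        num_inference_steps=30,
    )
    for map_name in ("diffuse", "normal", "roughness", "specular"):
        result[map_name][0].save(f"matfuse_{map_name}.png")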