| | """
|
| | MatFuse Condition Encoders for diffusers.
|
| |
|
| | These encoders handle the multi-modal conditioning:
|
| | - Image embedding (CLIP image encoder)
|
| | - Text embedding (CLIP text encoder)
|
| | - Sketch encoder (CNN)
|
| | - Palette encoder (MLP)
|
| | """
|
| |
|
| | import torch
|
| | import torch.nn as nn
|
| | import torch.nn.functional as F
|
| | from typing import Optional, Dict, Union, List
|
| | from diffusers.configuration_utils import ConfigMixin, register_to_config
|
| | from diffusers.models.modeling_utils import ModelMixin
|
| |
|
| |
|
class SketchEncoder(ModelMixin, ConfigMixin):
    """
    Convolutional encoder for binary sketch / edge maps.

    Maps a sketch image to a low-resolution spatial feature map that is
    concatenated with the diffusion latent (hybrid conditioning).
    """

    @register_to_config
    def __init__(
        self,
        in_channels: int = 1,
        out_channels: int = 4,
    ):
        super().__init__()

        # One row per Conv2d -> BatchNorm2d -> GELU stage:
        # (in_ch, out_ch, kernel, stride, padding).
        # The flattened order conv, bn, gelu, conv, bn, gelu, ... fixes the
        # Sequential indices, which are the state-dict keys ("net.0.weight", ...),
        # so it must not change.
        stages = (
            (in_channels, 32, 7, 1, 1),
            (32, 64, 3, 2, 1),
            (64, 128, 3, 2, 1),
            (128, 256, 3, 2, 1),
            (256, out_channels, 1, 1, 0),
        )
        layers = []
        for c_in, c_out, kernel, stride, pad in stages:
            layers += [
                nn.Conv2d(c_in, c_out, kernel, stride, pad),
                nn.BatchNorm2d(c_out),
                nn.GELU(),
            ]
        self.net = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Encode a sketch.

        Args:
            x: Tensor of shape (B, in_channels, H, W); sketches are expected
                to be binary maps with values in [0, 1].

        Returns:
            Features of shape (B, out_channels, ceil((H - 4) / 8),
            ceil((W - 4) / 8)). This is exactly (H/8, W/8) when H and W are
            multiples of 8. The first 7x7 conv uses padding 1, so it trims 4
            pixels per axis.
        """
        features = self.net(x)
        return features
|
| |
|
| |
|
class PaletteEncoder(ModelMixin, ConfigMixin):
    """
    MLP encoder for color palettes.

    Each palette color is projected on its own to ``hidden_channels``
    features. The per-color features are then flattened and projected to one
    ``out_channels`` embedding used for cross-attention conditioning.
    """

    @register_to_config
    def __init__(
        self,
        in_channels: int = 3,
        hidden_channels: int = 64,
        out_channels: int = 512,
        n_colors: int = 5,
    ):
        super().__init__()

        # Applied to the last axis, i.e. to every color independently.
        per_color = nn.Linear(in_channels, hidden_channels)
        flat_width = hidden_channels * n_colors

        # Index layout (state-dict keys): 0 Linear, 1 GELU, 2 Flatten,
        # 3 Linear, 4 GELU.
        self.net = nn.Sequential(
            per_color,
            nn.GELU(),
            nn.Flatten(),
            nn.Linear(flat_width, out_channels),
            nn.GELU(),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Encode a color palette.

        Args:
            x: Tensor of shape (B, n_colors, in_channels); RGB values are
                expected in [0, 1].

        Returns:
            Embedding of shape (B, out_channels).
        """
        embedding = self.net(x)
        return embedding
|
| |
|
| |
|
class CLIPImageEncoder(ModelMixin, ConfigMixin):
    """
    Wrapper for the CLIP image encoder using the OpenAI CLIP library.

    Produces one image-embedding token per sample for cross-attention
    conditioning. The CLIP model is loaded lazily on the first forward pass,
    and only its visual tower is kept.
    """

    @register_to_config
    def __init__(
        self,
        model_name: str = "ViT-B/16",
        normalize: bool = True,
    ):
        super().__init__()

        self.model_name = model_name
        self.normalize = normalize
        # Populated by _load_model() on first use.
        self.model = None

        # CLIP's per-channel RGB mean/std for input normalization. These are
        # non-persistent buffers, so they are not part of the state dict.
        self.register_buffer(
            "mean", torch.tensor([0.48145466, 0.4578275, 0.40821073]), persistent=False
        )
        self.register_buffer(
            "std", torch.tensor([0.26862954, 0.26130258, 0.27577711]), persistent=False
        )

    def _load_model(self):
        """Lazily load the CLIP model on CPU and keep only its visual tower."""
        if self.model is None:
            import clip

            self.model, _ = clip.load(self.model_name, device="cpu", jit=False)
            self.model = self.model.visual

    def preprocess(self, x: torch.Tensor) -> torch.Tensor:
        """
        Preprocess images for CLIP.

        Resizes to 224x224 (bicubic, antialiased), maps values from [-1, 1] to
        [0, 1], then applies CLIP's per-channel mean/std normalization.
        """
        x = F.interpolate(
            x, size=(224, 224), mode="bicubic", align_corners=True, antialias=True
        )

        x = (x + 1.0) / 2.0

        mean = self.mean.to(x.device).view(1, 3, 1, 1)
        std = self.std.to(x.device).view(1, 3, 1, 1)
        x = (x - mean) / std
        return x

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Encode images using CLIP.

        Args:
            x: Tensor of shape (B, 3, H, W) with values in [-1, 1].

        Returns:
            float32 image embedding of shape (B, 1, D), where D is the CLIP
            embedding width (512 for ViT-B/16). When ``normalize`` is set, each
            embedding is L2-normalized.
        """
        self._load_model()

        # Run CLIP on whichever device the input lives on.
        self.model = self.model.to(x.device)

        x = self.preprocess(x)
        # The visual tower is called directly, so this path skips the input
        # cast to the model dtype that CLIP.encode_image performs. Without this
        # cast, an input whose dtype differs from the weights (e.g. fp16
        # placeholders vs. fp32 CLIP weights) raises.
        model_dtype = next(self.model.parameters()).dtype
        z = self.model(x.to(dtype=model_dtype)).float().unsqueeze(1)

        if self.normalize:
            # L2-normalize along the feature dimension.
            z = z / torch.linalg.norm(z, dim=2, keepdim=True)

        return z
|
| |
|
| |
|
class CLIPTextEncoder(ModelMixin, ConfigMixin):
    """
    Wrapper for a CLIP sentence encoder from sentence-transformers.

    Produces text embeddings for cross-attention conditioning. The underlying
    SentenceTransformer is created lazily on the first forward pass.
    """

    @register_to_config
    def __init__(
        self,
        model_name: str = "sentence-transformers/clip-ViT-B-16",
    ):
        super().__init__()

        self.model_name = model_name
        # Created on first use by _load_model().
        self.model = None

    def _load_model(self):
        """Build the SentenceTransformer once; later calls do nothing."""
        if self.model is not None:
            return

        from sentence_transformers import SentenceTransformer

        self.model = SentenceTransformer(self.model_name)
        self.model.eval()

    def forward(self, text: Union[str, List[str]]) -> torch.Tensor:
        """
        Encode text with the CLIP sentence transformer.

        Args:
            text: A single prompt or a list of prompts. A single string is
                treated as a batch of one.

        Returns:
            Tensor of embeddings with one row per prompt, e.g. (B, 512) for
            the default model.
        """
        self._load_model()

        prompts = [text] if isinstance(text, str) else text
        return self.model.encode(prompts, convert_to_tensor=True)
|
| |
|
| |
|
class MultiConditionEncoder(ModelMixin, ConfigMixin):
    """
    Multi-condition encoder that combines all conditioning modalities.

    This encoder takes multiple condition inputs and produces:
    - c_crossattn: Features for cross-attention (image, text, palette embeddings)
    - c_concat: Features for concatenation (sketch encoding)

    The sketch and palette encoders are built in ``__init__``. The CLIP image
    and text encoders are created lazily on first use.
    """

    @register_to_config
    def __init__(
        self,
        sketch_in_channels: int = 1,
        sketch_out_channels: int = 4,
        palette_in_channels: int = 3,
        palette_hidden_channels: int = 64,
        palette_out_channels: int = 512,
        n_colors: int = 5,
        clip_image_model: str = "ViT-B/16",
        clip_text_model: str = "sentence-transformers/clip-ViT-B-16",
    ):
        super().__init__()

        self.sketch_encoder = SketchEncoder(
            in_channels=sketch_in_channels,
            out_channels=sketch_out_channels,
        )

        self.palette_encoder = PaletteEncoder(
            in_channels=palette_in_channels,
            hidden_channels=palette_hidden_channels,
            out_channels=palette_out_channels,
            n_colors=n_colors,
        )

        # CLIP wrappers are instantiated on demand by _load_clip_encoders().
        self.clip_image_encoder = None
        self.clip_text_encoder = None
        self._clip_image_model = clip_image_model
        self._clip_text_model = clip_text_model

    def _load_clip_encoders(self):
        """Create the CLIP image/text encoder wrappers if not yet created."""
        if self.clip_image_encoder is None:
            self.clip_image_encoder = CLIPImageEncoder(
                model_name=self._clip_image_model
            )
        if self.clip_text_encoder is None:
            self.clip_text_encoder = CLIPTextEncoder(model_name=self._clip_text_model)

    def encode_image(self, image: torch.Tensor) -> torch.Tensor:
        """Encode an image batch using CLIP."""
        self._load_clip_encoders()
        return self.clip_image_encoder(image)

    def encode_text(self, text: Union[str, List[str]]) -> torch.Tensor:
        """Encode text using the CLIP sentence transformer."""
        self._load_clip_encoders()
        return self.clip_text_encoder(text)

    def encode_sketch(self, sketch: torch.Tensor) -> torch.Tensor:
        """Encode a sketch/edge map."""
        return self.sketch_encoder(sketch)

    def encode_palette(self, palette: torch.Tensor) -> torch.Tensor:
        """Encode a color palette."""
        return self.palette_encoder(palette)

    def get_unconditional_conditioning(
        self,
        batch_size: int = 1,
        image_size: int = 256,
        device: Optional[torch.device] = None,
    ) -> Dict[str, torch.Tensor]:
        """
        Get unconditional conditioning for classifier-free guidance.

        IMPORTANT: The original model was trained to drop conditions by replacing them
        with encoded placeholders (zero/gray image through CLIP, empty string through
        sentence-transformers, zero palette through PaletteEncoder, zero sketch through
        SketchEncoder) — NOT with zero tensors. This method produces the correct
        unconditional embeddings.

        Args:
            batch_size: Batch size.
            image_size: Image resolution (for sketch spatial dims).
            device: Device to place tensors on.

        Returns:
            Dictionary with c_crossattn and c_concat for unconditional guidance.
        """
        return self.forward(
            image_embed=None,
            text=None,
            sketch=None,
            palette=None,
            batch_size=batch_size,
            image_size=image_size,
            device=device,
        )

    def forward(
        self,
        image_embed: Optional[torch.Tensor] = None,
        text: Optional[Union[str, List[str]]] = None,
        sketch: Optional[torch.Tensor] = None,
        palette: Optional[torch.Tensor] = None,
        batch_size: int = 1,
        image_size: int = 256,
        device: Optional[torch.device] = None,
    ) -> Dict[str, torch.Tensor]:
        """
        Encode all conditions.

        When a condition is not provided, the model encodes a placeholder input
        through the actual encoder (matching training behavior) rather than using
        zero tensors. This is critical because the model was trained with:
        - Image drop → CLIP encoding of a gray/zero image (0.0 in [-1,1])
        - Text drop → sentence-transformer encoding of ""
        - Palette drop → PaletteEncoder(zeros)
        - Sketch drop → SketchEncoder(zeros)

        Args:
            image_embed: Reference image of shape (B, 3, H, W) in [-1, 1].
            text: Text description(s). A single string is repeated for every
                sample in the batch. A list must have one entry per sample.
            sketch: Binary sketch of shape (B, 1, H, W) in [0, 1].
            palette: Color palette of shape (B, n_colors, 3) in [0, 1].
            batch_size: Batch size, used only when no image, sketch, palette or
                list of prompts is provided.
            image_size: Image resolution (used to create placeholder sketch).
            device: Device to place tensors on.

        Returns:
            Dictionary with:
            - c_crossattn: Cross-attention context of shape (B, 3, 512) - always
              3 tokens [image, text, palette], in the palette encoder's dtype.
            - c_concat: Concatenation features of shape
              (B, sketch_out_channels, ~H/8, ~W/8).
        """
        self._load_clip_encoders()

        # Batch size, device and spatial size come from the first provided
        # tensor. Failing that, a list of prompts fixes the batch size.
        if image_embed is not None:
            batch_size = image_embed.shape[0]
            device = device or image_embed.device
            image_size = image_embed.shape[-1]
        elif sketch is not None:
            batch_size = sketch.shape[0]
            device = device or sketch.device
            image_size = sketch.shape[-1]
        elif palette is not None:
            batch_size = palette.shape[0]
            device = device or palette.device
        elif text is not None and not isinstance(text, str):
            batch_size = len(text)

        device = device or torch.device("cpu")

        # Placeholders and cross-attention tokens follow the trainable encoders' dtype.
        dtype = next(self.palette_encoder.parameters()).dtype

        # Image token: the real image, or a zero (mid-gray in [-1, 1]) image when dropped.
        if image_embed is not None:
            image = image_embed
        else:
            image = torch.zeros(
                batch_size, 3, image_size, image_size, device=device, dtype=dtype
            )
        # CLIPImageEncoder returns float32 on the input's device. Align it with
        # the other tokens so the concatenation below is consistent.
        img_emb = self.clip_image_encoder(image).to(device=device, dtype=dtype)

        # Text token: a dropped prompt is encoded as "". A single prompt string is
        # repeated so its token batch matches the other conditions.
        if text is None:
            prompts = [""] * batch_size
        elif isinstance(text, str):
            prompts = [text] * batch_size
        else:
            prompts = text
        text_emb = self.clip_text_encoder(prompts).to(device=device, dtype=dtype)
        text_emb = text_emb.unsqueeze(1)

        # Palette token: the real palette, or an all-zero palette when dropped.
        if palette is None:
            n_colors = self.config.get("n_colors", 5)
            palette = torch.zeros(batch_size, n_colors, 3, device=device, dtype=dtype)
        palette_emb = self.palette_encoder(palette).unsqueeze(1)

        c_crossattn = torch.cat([img_emb, text_emb, palette_emb], dim=1)

        # Spatial conditioning: the real sketch, or an all-zero sketch when dropped.
        if sketch is None:
            sketch = torch.zeros(
                batch_size, 1, image_size, image_size, device=device, dtype=dtype
            )
        c_concat = self.sketch_encoder(sketch)

        return {
            "c_crossattn": c_crossattn,
            "c_concat": c_concat,
        }
|
| |
|