"""SDXL CLIP model implementation for LightDiffusion. This module provides SDXL-specific CLIP tokenizers and models, adapted from ComfyUI's implementation but using local LightDiffusion modules. """ import torch from src.SD15 import SDClip, SDToken class SDXLClipG(SDClip.SDClipModel): """SDXL ClipG model - uses the larger G model with different layer settings.""" def __init__( self, device="cpu", max_length=77, freeze=True, layer="penultimate", layer_idx=None, dtype=None, model_options=None, ): """Initialize SDXLClipG model. Args: device: Device to load model on max_length: Maximum token length freeze: Whether to freeze weights layer: Layer type ('penultimate' maps to 'hidden') layer_idx: Specific layer index (defaults to -2 for penultimate) dtype: Data type for model weights model_options: Additional model options (optional) """ if layer == "penultimate": layer = "hidden" layer_idx = -2 # Use the bigg config for SDXL's G model (in include/clip directory) textmodel_json_config = "./include/clip/clip_config_bigg.json" super().__init__( device=device, freeze=freeze, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"start": 49406, "end": 49407, "pad": 0}, layer_norm_hidden_state=False, return_projected_pooled=True, ) def load_sd(self, sd): """Load state dict into model.""" return super().load_sd(sd) class SDXLClipGTokenizer(SDToken.SDTokenizer): """Tokenizer for SDXL ClipG model.""" def __init__( self, tokenizer_path=None, embedding_directory=None, tokenizer_data=None ): """Initialize SDXLClipGTokenizer. Args: tokenizer_path: Path to tokenizer config embedding_directory: Directory containing embeddings tokenizer_data: Pre-loaded tokenizer data dict (ignored for compatibility) """ # Note: tokenizer_data is accepted for compatibility but not used super().__init__( tokenizer_path, pad_with_end=False, embedding_directory=embedding_directory, embedding_size=1280, embedding_key="clip_g", ) class SDXLTokenizer: """Dual tokenizer for SDXL (combines L and G models).""" def __init__(self, embedding_directory=None, tokenizer_data=None): """Initialize SDXLTokenizer with both L and G tokenizers. Args: embedding_directory: Directory containing embeddings tokenizer_data: Pre-loaded tokenizer data dict (ignored for compatibility) """ # Note: tokenizer_data is accepted for compatibility but not used self.clip_l = SDToken.SDTokenizer(embedding_directory=embedding_directory) self.clip_g = SDXLClipGTokenizer(embedding_directory=embedding_directory) def tokenize_with_weights(self, text: str, return_word_ids=False, **kwargs): """Tokenize text with both L and G tokenizers. Args: text: Input text to tokenize return_word_ids: Whether to return word IDs **kwargs: Additional arguments Returns: Dict with 'g' and 'l' tokenization results """ out = {} out["g"] = self.clip_g.tokenize_with_weights(text, return_word_ids, **kwargs) out["l"] = self.clip_l.tokenize_with_weights(text, return_word_ids, **kwargs) return out def untokenize(self, token_weight_pair): """Convert tokens back to text using G tokenizer.""" return self.clip_g.untokenize(token_weight_pair) def state_dict(self): """Return empty state dict (tokenizer has no trainable params).""" return {} class SDXLClipModel(torch.nn.Module): """SDXL CLIP model combining both L and G encoders.""" def __init__(self, device="cpu", dtype=None, model_options=None): """Initialize SDXL CLIP model with both L and G components. Args: device: Device to load models on dtype: Data type for model weights model_options: Additional model options (optional) """ super().__init__() if model_options is None: model_options = {} self.clip_l = SDClip.SDClipModel( layer="hidden", layer_idx=-2, device=device, dtype=dtype, layer_norm_hidden_state=False, ) self.clip_g = SDXLClipG(device=device, dtype=dtype, model_options=model_options) self.dtypes = {dtype} if dtype else set() def set_clip_options(self, options): """Set options for both CLIP models.""" self.clip_l.set_clip_options(options) self.clip_g.set_clip_options(options) def reset_clip_options(self): """Reset options for both CLIP models.""" self.clip_g.reset_clip_options() self.clip_l.reset_clip_options() def load_state_dict(self, state_dict, strict=True): """Override load_state_dict to handle shape mismatches. Args: state_dict: State dictionary to load strict: Whether to strictly enforce key matching Returns: NamedTuple with missing_keys and unexpected_keys """ # Filter out keys with shape mismatches filtered_sd = {} for k, v in state_dict.items(): # Handle logit_scale shape mismatch (scalar vs 1D) if "logit_scale" in k and k in self.state_dict(): expected_shape = self.state_dict()[k].shape if v.shape != expected_shape and v.numel() == 1 and len(expected_shape) == 1: # Reshape scalar to 1D filtered_sd[k] = v.reshape(expected_shape) continue filtered_sd[k] = v # Call parent load_state_dict with filtered dict return super().load_state_dict(filtered_sd, strict=strict) def encode_token_weights(self, token_weight_pairs): """Encode tokens from both L and G models and concatenate. Args: token_weight_pairs: Dict with 'g' and 'l' token weight pairs Returns: Tuple of (concatenated embeddings, pooled output from G) """ token_weight_pairs_g = token_weight_pairs["g"] token_weight_pairs_l = token_weight_pairs["l"] g_out, g_pooled = self.clip_g.encode_token_weights(token_weight_pairs_g) l_out, l_pooled = self.clip_l.encode_token_weights(token_weight_pairs_l) # Cut to minimum length and concatenate cut_to = min(l_out.shape[1], g_out.shape[1]) return torch.cat([l_out[:, :cut_to], g_out[:, :cut_to]], dim=-1), g_pooled def load_sd(self, sd): """Load state dict - routes to G or L model based on keys present. Args: sd: State dict to load Returns: Tuple of (missing keys, unexpected keys) """ # Filter out problematic keys that have shape mismatches # logit_scale can be a scalar in checkpoint but 1D in model filtered_sd = {} for k, v in sd.items(): # Skip logit_scale if it has the wrong shape if "logit_scale" in k: # Check expected shape from the model if k in self.state_dict(): expected_shape = self.state_dict()[k].shape if v.shape != expected_shape: # Try to reshape or skip if v.numel() == 1 and len(expected_shape) == 1: # Reshape scalar to 1D filtered_sd[k] = v.reshape(expected_shape) continue else: # Skip if can't resolve continue filtered_sd[k] = v # Check if this is a G model state dict (has layer 30) if "text_model.encoder.layers.30.mlp.fc1.weight" in filtered_sd: return self.clip_g.load_sd(filtered_sd) else: return self.clip_l.load_sd(filtered_sd) class SDXLRefinerClipModel(SDClip.SD1ClipModel): """SDXL Refiner CLIP model (G only).""" def __init__(self, device="cpu", dtype=None, model_options=None): """Initialize SDXL Refiner CLIP model. Args: device: Device to load model on dtype: Data type for model weights model_options: Additional model options (optional) """ if model_options is None: model_options = {} super().__init__( device=device, dtype=dtype, clip_name="g", clip_model=SDXLClipG, model_options=model_options ) def load_state_dict(self, state_dict, strict=True): """Override load_state_dict to handle common SDXL mismatches.""" filtered_sd = {} for k, v in state_dict.items(): # Handle logit_scale shape mismatch if "logit_scale" in k and k in self.state_dict(): expected_shape = self.state_dict()[k].shape if v.shape != expected_shape and v.numel() == 1 and len(expected_shape) == 1: filtered_sd[k] = v.reshape(expected_shape) continue # Skip position_ids if they cause mismatches (Refiner often doesn't need them if embeddings.weight is present) if "position_ids" in k: continue filtered_sd[k] = v return super().load_state_dict(filtered_sd, strict=strict) def load_sd(self, sd): """Load state dict and route to G model.""" return self.clip_g.load_sd(sd)