Spaces:
Running on Zero
Running on Zero
File size: 10,014 Bytes
"""SDXL CLIP model implementation for LightDiffusion.

This module provides SDXL-specific CLIP tokenizers and models, adapted from
ComfyUI's implementation but using local LightDiffusion modules.
"""
import torch
from src.SD15 import SDClip, SDToken
class SDXLClipG(SDClip.SDClipModel):
    """CLIP-G text encoder wrapper used by SDXL.

    Configures the parent ``SDClip.SDClipModel`` with the "bigg" text-model
    config, the SDXL special-token ids (pad id ``0``), no final layer norm on
    the hidden state, and a projected pooled output.
    """

    def __init__(
        self,
        device="cpu",
        max_length=77,
        freeze=True,
        layer="penultimate",
        layer_idx=None,
        dtype=None,
        model_options=None,
    ):
        """Build the G encoder.

        Args:
            device: Device forwarded to the parent constructor.
            max_length: Accepted for signature compatibility; it is not
                forwarded to the parent constructor by this class.
            freeze: Forwarded to the parent constructor.
            layer: Output layer selector. ``"penultimate"`` is rewritten to
                ``"hidden"`` and forces ``layer_idx`` to ``-2``.
            layer_idx: Layer index; only honoured when ``layer`` is not
                ``"penultimate"``.
            dtype: Data type forwarded to the parent constructor.
            model_options: Accepted for signature compatibility; not used by
                this class.
        """
        wants_penultimate = layer == "penultimate"
        resolved_layer = "hidden" if wants_penultimate else layer
        resolved_idx = -2 if wants_penultimate else layer_idx

        super().__init__(
            device=device,
            freeze=freeze,
            layer=resolved_layer,
            layer_idx=resolved_idx,
            # The G encoder uses the "bigg" config shipped under include/clip
            # (relative path).
            textmodel_json_config="./include/clip/clip_config_bigg.json",
            dtype=dtype,
            special_tokens={"start": 49406, "end": 49407, "pad": 0},
            layer_norm_hidden_state=False,
            return_projected_pooled=True,
        )

    def load_sd(self, sd):
        """Load ``sd`` through the parent implementation and return its result."""
        result = super().load_sd(sd)
        return result
class SDXLClipGTokenizer(SDToken.SDTokenizer):
    """Tokenizer configured for the SDXL CLIP-G encoder."""

    def __init__(
        self, tokenizer_path=None, embedding_directory=None, tokenizer_data=None
    ):
        """Configure the parent tokenizer with the CLIP-G settings.

        Args:
            tokenizer_path: Passed positionally to the parent tokenizer.
            embedding_directory: Directory containing embeddings; forwarded
                to the parent tokenizer.
            tokenizer_data: Accepted for interface compatibility only; it is
                never read.
        """
        # Settings that distinguish the G tokenizer from the parent defaults.
        clip_g_settings = {
            "pad_with_end": False,
            "embedding_directory": embedding_directory,
            "embedding_size": 1280,
            "embedding_key": "clip_g",
        }
        super().__init__(tokenizer_path, **clip_g_settings)
class SDXLTokenizer:
    """Pair of tokenizers (CLIP-L and CLIP-G) used together for SDXL prompts."""

    def __init__(self, embedding_directory=None, tokenizer_data=None):
        """Create the L and G tokenizers.

        Args:
            embedding_directory: Directory containing embeddings; handed to
                both tokenizers.
            tokenizer_data: Accepted for interface compatibility only; it is
                never read.
        """
        self.clip_l = SDToken.SDTokenizer(embedding_directory=embedding_directory)
        self.clip_g = SDXLClipGTokenizer(embedding_directory=embedding_directory)

    def tokenize_with_weights(self, text: str, return_word_ids=False, **kwargs):
        """Run ``text`` through both tokenizers.

        Args:
            text: Prompt text to tokenize.
            return_word_ids: Passed positionally to each tokenizer.
            **kwargs: Forwarded unchanged to each tokenizer.

        Returns:
            Dict with the G tokenizer's result under ``"g"`` and the L
            tokenizer's result under ``"l"``.
        """
        # G is tokenized before L, and "g" is inserted first in the result.
        return {
            "g": self.clip_g.tokenize_with_weights(text, return_word_ids, **kwargs),
            "l": self.clip_l.tokenize_with_weights(text, return_word_ids, **kwargs),
        }

    def untokenize(self, token_weight_pair):
        """Map token/weight pairs back through the G tokenizer's ``untokenize``."""
        g_tokenizer = self.clip_g
        return g_tokenizer.untokenize(token_weight_pair)

    def state_dict(self):
        """Return an empty dict; this object exposes no state to save."""
        return dict()
class SDXLClipModel(torch.nn.Module):
    """SDXL text encoder that combines a CLIP-L and a CLIP-G encoder."""

    def __init__(self, device="cpu", dtype=None, model_options=None):
        """Create both encoders.

        Args:
            device: Device handed to both encoders.
            dtype: Data type handed to both encoders; also recorded in
                ``self.dtypes`` when truthy.
            model_options: Optional options forwarded to the G encoder
                (``None`` is replaced by an empty dict).
        """
        super().__init__()
        if model_options is None:
            model_options = {}
        self.clip_l = SDClip.SDClipModel(
            layer="hidden",
            layer_idx=-2,
            device=device,
            dtype=dtype,
            layer_norm_hidden_state=False,
        )
        self.clip_g = SDXLClipG(device=device, dtype=dtype, model_options=model_options)
        self.dtypes = {dtype} if dtype else set()

    def set_clip_options(self, options):
        """Apply ``options`` to the L encoder, then the G encoder."""
        self.clip_l.set_clip_options(options)
        self.clip_g.set_clip_options(options)

    def reset_clip_options(self):
        """Reset the options of the G encoder, then the L encoder."""
        self.clip_g.reset_clip_options()
        self.clip_l.reset_clip_options()

    def _normalize_logit_scale(self, state_dict, drop_unresolved):
        """Return a copy of ``state_dict`` with ``logit_scale`` shapes fixed.

        An entry is touched only if its key contains ``"logit_scale"`` and
        also exists in this module's own state dict. A single-element tensor
        whose shape differs from a 1-D expected shape is reshaped to it.

        Args:
            state_dict: Mapping of key -> tensor to normalize (not mutated).
            drop_unresolved: If true, entries whose shape differs from the
                expected one and cannot be reshaped are omitted; otherwise
                they are passed through unchanged.

        Returns:
            A new dict preserving the iteration order of ``state_dict``.
        """
        # state_dict() rebuilds the full mapping on every call, so take one
        # snapshot for the whole pass, and none when there is nothing to fix.
        if any("logit_scale" in key for key in state_dict):
            expected_shapes = {
                name: tensor.shape
                for name, tensor in self.state_dict().items()
                if "logit_scale" in name
            }
        else:
            expected_shapes = {}

        normalized = {}
        for key, value in state_dict.items():
            expected = expected_shapes.get(key)
            if expected is not None and value.shape != expected:
                if value.numel() == 1 and len(expected) == 1:
                    # Scalar in the checkpoint, 1-D in the model.
                    value = value.reshape(expected)
                elif drop_unresolved:
                    continue
            normalized[key] = value
        return normalized

    def load_state_dict(self, state_dict, strict=True):
        """Load ``state_dict`` after fixing scalar-vs-1D ``logit_scale`` shapes.

        Args:
            state_dict: State dictionary to load.
            strict: Forwarded to ``torch.nn.Module.load_state_dict``.

        Returns:
            The result of ``torch.nn.Module.load_state_dict`` (missing and
            unexpected keys).
        """
        fixed_sd = self._normalize_logit_scale(state_dict, drop_unresolved=False)
        return super().load_state_dict(fixed_sd, strict=strict)

    def encode_token_weights(self, token_weight_pairs):
        """Encode with both encoders and concatenate along the last dim.

        Args:
            token_weight_pairs: Dict with ``"g"`` and ``"l"`` token/weight
                pairs.

        Returns:
            Tuple of (L and G outputs truncated along dim 1 to the shorter
            of the two and concatenated on the last dim, pooled output of
            the G encoder).
        """
        g_out, g_pooled = self.clip_g.encode_token_weights(token_weight_pairs["g"])
        # The L encoder's pooled output is not part of the result.
        l_out, _ = self.clip_l.encode_token_weights(token_weight_pairs["l"])
        cut_to = min(l_out.shape[1], g_out.shape[1])
        return torch.cat([l_out[:, :cut_to], g_out[:, :cut_to]], dim=-1), g_pooled

    def load_sd(self, sd):
        """Load ``sd`` into the G or the L encoder, chosen by the keys present.

        ``logit_scale`` entries are normalized first; entries whose shape
        mismatch cannot be resolved are dropped.

        Args:
            sd: State dict to load.

        Returns:
            Whatever the selected encoder's ``load_sd`` returns.
        """
        # NOTE(review): expected shapes are looked up in this combined
        # module's state dict, whose keys normally carry the sub-module
        # attribute prefix, while the routing key below is un-prefixed.
        # Confirm the incoming keys can actually match; otherwise this
        # normalization is a no-op.
        fixed_sd = self._normalize_logit_scale(sd, drop_unresolved=True)

        # Layer index 30 is used as the marker for a G-encoder state dict.
        if "text_model.encoder.layers.30.mlp.fc1.weight" in fixed_sd:
            return self.clip_g.load_sd(fixed_sd)
        return self.clip_l.load_sd(fixed_sd)
class SDXLRefinerClipModel(SDClip.SD1ClipModel):
    """SDXL Refiner CLIP model (G encoder only)."""

    def __init__(self, device="cpu", dtype=None, model_options=None):
        """Configure the parent model to wrap a single ``SDXLClipG`` named "g".

        Args:
            device: Device forwarded to the parent constructor.
            dtype: Data type forwarded to the parent constructor.
            model_options: Optional options forwarded to the parent
                constructor (``None`` is replaced by an empty dict).
        """
        if model_options is None:
            model_options = {}
        super().__init__(
            device=device,
            dtype=dtype,
            clip_name="g",
            clip_model=SDXLClipG,
            model_options=model_options,
        )

    def load_state_dict(self, state_dict, strict=True):
        """Load ``state_dict`` after filtering known SDXL mismatches.

        Single-element ``logit_scale`` tensors are reshaped to the 1-D shape
        this module expects, and every key containing ``"position_ids"`` is
        dropped before delegating to the parent implementation.

        Args:
            state_dict: State dictionary to load.
            strict: Forwarded to the parent ``load_state_dict``.

        Returns:
            The result of the parent ``load_state_dict``.
        """
        # state_dict() rebuilds the full mapping on every call, so take one
        # snapshot, and only when a logit_scale key is actually present.
        if any("logit_scale" in key for key in state_dict):
            expected_shapes = {
                name: tensor.shape
                for name, tensor in self.state_dict().items()
                if "logit_scale" in name
            }
        else:
            expected_shapes = {}

        filtered_sd = {}
        for key, value in state_dict.items():
            expected = expected_shapes.get(key)
            if (
                expected is not None
                and value.shape != expected
                and value.numel() == 1
                and len(expected) == 1
            ):
                # Scalar in the checkpoint, 1-D in the model.
                filtered_sd[key] = value.reshape(expected)
                continue
            # position_ids entries are skipped to avoid mismatches on load.
            if "position_ids" in key:
                continue
            filtered_sd[key] = value
        return super().load_state_dict(filtered_sd, strict=strict)

    def load_sd(self, sd):
        """Load ``sd`` into the wrapped G encoder and return its result."""
        return self.clip_g.load_sd(sd)
|