Spaces:
Sleeping
Sleeping
File size: 6,591 Bytes
# copied and modified from https://github.com/bytedance/USO/blob/main/uso/flux/modules/layers.py
import os
from typing import Optional

import torch
from safetensors.torch import load_file
from torch import nn
class SigLIPMultiFeatProjModel(nn.Module):
    """
    SigLIP Multi-Feature Projection Model for processing style features from different layers
    and projecting them into a unified hidden space.

    For every requested hidden-state index an independent branch is built:
    a token-count reduction (``Linear`` over the token axis + SiLU), an optional
    ``LayerNorm`` over the feature axis, and a projection to ``hidden_size``.
    The branch outputs are concatenated along the token axis and, optionally,
    flattened and passed through one final linear layer.

    Args:
        layer_indices (list[int], optional): List of SigLIP hidden_states indices to
            extract, e.g. [-2, -11, -20]. Defaults to ``[-2, -11, -20]`` when ``None``.
            The sequence is copied, so later mutation by the caller has no effect.
        siglip_token_nums (int): Number of SigLIP tokens.
        style_token_nums (int): Number of style tokens.
        siglip_token_dims (int): Dimension of SigLIP tokens.
        hidden_size (int): Hidden layer size for the projection network.
        projection_layers (int): Number of linear layers in the projection network.
            ``1`` yields a bare ``nn.Linear``; ``n > 1`` yields a ``nn.Sequential`` of the
            first linear layer followed by ``n - 1`` (``Linear``, ``SiLU``) pairs;
            ``<= 0`` disables the projection (``nn.Identity``).
        context_layer_norm (bool): Whether to use context layer normalization.
        post_projection_dim (int, optional): If specified, adds a final projection layer
            to this dimension.

    Raises:
        ValueError: If ``layer_indices`` is empty.
    """

    def __init__(
        self,
        layer_indices: Optional[list] = None,
        siglip_token_nums: int = 729,
        style_token_nums: int = 64,
        siglip_token_dims: int = 1152,
        hidden_size: int = 4096,
        projection_layers: int = 1,  # New parameter to control projection depth
        context_layer_norm: bool = True,
        post_projection_dim: Optional[int] = None,
    ):
        super().__init__()

        # Build a fresh list per instance: a shared mutable default (or the
        # caller's own list) must never be aliased by ``self.layer_indices``.
        if layer_indices is None:
            layer_indices = [-2, -11, -20]
        layer_indices = list(layer_indices)
        if not layer_indices:
            raise ValueError("layer_indices must contain at least one index")

        self.layer_indices = layer_indices
        self.style_token_nums = style_token_nums
        self.hidden_size = hidden_size
        self.post_projection_dim = post_projection_dim

        # Create independent processing modules for each specified layer
        self.embedding_linears = nn.ModuleDict()
        self.layer_norms = nn.ModuleDict()
        self.projections = nn.ModuleDict()

        # SiLU holds no parameters, so one instance can be shared by all branches.
        activate_fn = nn.SiLU()

        for idx in layer_indices:
            name = str(idx)  # Use layer index as the key

            # Linear layer to adjust token numbers
            self.embedding_linears[name] = nn.Sequential(
                nn.Linear(siglip_token_nums, style_token_nums),
                activate_fn,
            )

            # Layer normalization
            self.layer_norms[name] = (
                nn.LayerNorm(siglip_token_dims) if context_layer_norm else nn.Identity()
            )

            if projection_layers > 0:
                # First layer maps from siglip_token_dims to hidden_size
                proj_layers = [nn.Linear(siglip_token_dims, hidden_size, bias=True)]
                # NOTE(review): the activation is appended together with each
                # extra linear layer; this is the only arrangement in which the
                # single-layer (bare ``nn.Linear``) branch below is reachable.
                for _ in range(projection_layers - 1):
                    proj_layers.append(nn.Linear(hidden_size, hidden_size, bias=True))
                    proj_layers.append(activate_fn)

                if len(proj_layers) != 1:
                    self.projections[name] = nn.Sequential(*proj_layers)
                else:
                    # Keep compatibility with the previous architecture: a
                    # single layer is registered as a plain ``nn.Linear`` so the
                    # state-dict keys stay ``projections.<idx>.weight/bias``.
                    self.projections[name] = proj_layers[0]
            else:
                self.projections[name] = nn.Identity()

        self.post_projection = None
        if self.post_projection_dim is not None:
            # With the projection disabled the features keep their SigLIP
            # width, so the flattened size must be derived from that width
            # rather than from ``hidden_size``.
            projected_dim = hidden_size if projection_layers > 0 else siglip_token_dims
            # Note: The input dimension here is flattened
            input_dim = projected_dim * len(layer_indices) * style_token_nums
            self.post_projection = nn.Linear(input_dim, self.post_projection_dim)

    def forward(self, siglip_outputs):
        """
        Args:
            siglip_outputs: Output from a SigLIP model, which contains the `hidden_states`.

        Returns:
            torch.Tensor: The final projected features. Without a post-projection the
            per-layer embeddings are concatenated along dim 1; with one, the result is
            flattened per sample and has shape ``(batch, post_projection_dim)``.

        Raises:
            ValueError: If ``siglip_outputs`` carries no ``hidden_states``.
        """
        all_hidden_states = getattr(siglip_outputs, "hidden_states", None)
        if all_hidden_states is None:
            raise ValueError(
                "siglip_outputs.hidden_states is missing; "
                "the SigLIP model must return its hidden states"
            )

        # Inputs are cast to the dtype of this module's own parameters.
        first_module = next(iter(self.embedding_linears.values()))
        dtype = next(first_module.parameters()).dtype

        embeddings = []
        for idx in self.layer_indices:
            name = str(idx)
            embeddings.append(
                self._process_layer_features(
                    all_hidden_states[idx],
                    self.embedding_linears[name],
                    self.layer_norms[name],
                    self.projections[name],
                    dtype,
                )
            )

        # Concatenate all embeddings along the token dimension
        embeddings = torch.cat(embeddings, dim=1)

        # If a post-projection layer is defined, apply it
        if self.post_projection is not None:
            # Flatten the tensor for the final linear layer
            embeddings = self.post_projection(embeddings.flatten(start_dim=1))

        return embeddings

    def load_proj_model(self, checkpoint_path):
        """Loads weights for the projection model from a .pt or .safetensors checkpoint.

        Keys starting with ``proj_model.`` have that leading prefix removed; the
        weights are then loaded non-strictly and any missing or unexpected keys
        are reported on stdout.

        Args:
            checkpoint_path: Path to the checkpoint file. A ``.safetensors``
                extension (case-insensitive) selects the safetensors loader; any
                other extension is read with ``torch.load``.
        """
        proj_model_name = 'proj_model.'

        # Pick the loader from the file extension.
        ext = os.path.splitext(checkpoint_path)[1].lower()
        if ext == ".safetensors":
            # Load directly from the .safetensors file.
            all_state_dict = load_file(checkpoint_path)
        else:
            # Otherwise treat the file as a regular PyTorch checkpoint.
            # NOTE(security): weights_only=False unpickles arbitrary Python
            # objects and can execute code; only load trusted checkpoints.
            checkpoint = torch.load(checkpoint_path, map_location="cpu", weights_only=False)
            all_state_dict = checkpoint.get("model_state_dict", checkpoint)

        # Normalize key names: strip only the *leading* prefix so that any later
        # occurrence of the same text inside a key is left untouched.
        # NOTE(review): keys without the prefix are kept as they are, so
        # checkpoints storing bare module keys load as well.
        model_state_dict = {}
        for k, v in all_state_dict.items():
            if k.startswith(proj_model_name):
                k = k[len(proj_model_name):]
            model_state_dict[k] = v

        # Load the parameters.
        missing, unexpected = self.load_state_dict(model_state_dict, strict=False)
        if missing or unexpected:
            print(f"[Warning] Missing keys: {missing}, Unexpected keys: {unexpected}")

    def _process_layer_features(
        self,
        hidden_states: torch.Tensor,
        embedding_linear: nn.Module,
        layer_norm: nn.Module,
        projection: nn.Module,
        dtype: torch.dtype
    ) -> torch.Tensor:
        """Helper function to process features from a single layer.

        Args:
            hidden_states: Hidden states of one SigLIP layer; dims 1 and 2 are
                swapped so that ``embedding_linear`` acts on the token axis.
            embedding_linear: Module that maps the token count.
            layer_norm: Normalization applied after the token reduction.
            projection: Module applied last, on the feature axis.
            dtype: dtype the hidden states are cast to before processing.

        Returns:
            torch.Tensor: The reduced, normalized and projected embedding.
        """
        # Adjust token numbers: move tokens to the last axis, map them, move back.
        tokens_last = hidden_states.to(dtype).transpose(1, 2)
        embedding = embedding_linear(tokens_last).transpose(1, 2)

        # Normalize and project
        embedding = layer_norm(embedding)
        embedding = projection(embedding)
        return embedding
|