# StyleExper-V2 / src/siglip_layers.py
# Author: oedevs · commit 56d35ce ("upload")
# copied and modified from https://github.com/bytedance/USO/blob/main/uso/flux/modules/layers.py
import os
from typing import Optional, Sequence

import torch
from safetensors.torch import load_file
from torch import nn
class SigLIPMultiFeatProjModel(nn.Module):
    """
    SigLIP multi-feature projection model.

    For each selected SigLIP layer, the hidden states are compressed along the
    token axis from ``siglip_token_nums`` to ``style_token_nums`` tokens,
    normalized, and projected. The per-layer results are concatenated along
    the token axis and optionally flattened and projected once more.

    Args:
        layer_indices (Sequence[int], optional): Indices into the SigLIP
            ``hidden_states`` sequence to extract. Defaults to ``[-2, -11, -20]``.
        siglip_token_nums (int): Number of SigLIP tokens per layer.
        style_token_nums (int): Number of style tokens produced per layer.
        siglip_token_dims (int): Feature dimension of the SigLIP tokens.
        hidden_size (int): Output dimension of each projection network.
        projection_layers (int): Number of linear layers in each projection
            network; ``0`` (or less) makes the projection an identity.
        context_layer_norm (bool): Whether to apply LayerNorm before projecting.
        post_projection_dim (int, optional): If set, the concatenated features
            are flattened per sample and mapped to this dimension by a final
            linear layer.

    Raises:
        ValueError: If ``layer_indices`` is empty.
    """

    def __init__(
        self,
        layer_indices: Optional[Sequence[int]] = None,
        siglip_token_nums: int = 729,
        style_token_nums: int = 64,
        siglip_token_dims: int = 1152,
        hidden_size: int = 4096,
        projection_layers: int = 1,  # controls the depth of each projection
        context_layer_norm: bool = True,
        post_projection_dim: Optional[int] = None,
    ):
        super().__init__()
        # Build a fresh list per instance. A list literal as the default would
        # be shared (and mutable) across every instance.
        if layer_indices is None:
            layer_indices = [-2, -11, -20]
        self.layer_indices = list(layer_indices)
        if not self.layer_indices:
            raise ValueError("layer_indices must contain at least one index")
        self.style_token_nums = style_token_nums
        self.hidden_size = hidden_size
        self.post_projection_dim = post_projection_dim

        # Independent modules per layer, keyed by the layer index as a string
        # (nn.ModuleDict requires string keys). These names become state_dict
        # keys, so they must stay stable for checkpoint loading.
        self.embedding_linears = nn.ModuleDict()
        self.layer_norms = nn.ModuleDict()
        self.projections = nn.ModuleDict()
        # SiLU has no parameters, so a single instance is reused everywhere.
        activate_fn = nn.SiLU()
        for idx in self.layer_indices:
            name = str(idx)
            # Linear over the token axis: siglip_token_nums -> style_token_nums.
            self.embedding_linears[name] = nn.Sequential(
                nn.Linear(siglip_token_nums, style_token_nums),
                activate_fn,
            )
            self.layer_norms[name] = (
                nn.LayerNorm(siglip_token_dims) if context_layer_norm else nn.Identity()
            )
            self.projections[name] = self._build_projection(
                siglip_token_dims, hidden_size, projection_layers, activate_fn
            )

        self.post_projection = None
        if self.post_projection_dim is not None:
            # Per-layer feature width after projection. An identity projection
            # (projection_layers <= 0) leaves it at siglip_token_dims.
            proj_out_dim = hidden_size if projection_layers > 0 else siglip_token_dims
            # The input is the flattened (tokens x features) of all layers.
            input_dim = proj_out_dim * len(self.layer_indices) * style_token_nums
            self.post_projection = nn.Linear(input_dim, self.post_projection_dim)

    @staticmethod
    def _build_projection(in_dim, hidden_size, projection_layers, activate_fn):
        """Build one per-layer projection head mapping ``in_dim`` -> ``hidden_size``.

        - ``projection_layers <= 0``: ``nn.Identity``.
        - ``1``: a bare ``nn.Linear``, kept so state_dict keys match checkpoints
          from the earlier single-Linear architecture.
        - More than 1: an ``nn.Sequential``.
        """
        if projection_layers <= 0:
            return nn.Identity()
        layers = [nn.Linear(in_dim, hidden_size, bias=True)]
        for _ in range(projection_layers - 1):
            layers.append(nn.Linear(hidden_size, hidden_size, bias=True))
            layers.append(activate_fn)
        # NOTE(review): the order is Linear, Linear, SiLU, Linear, SiLU, ..., so
        # there is no nonlinearity between the first two Linear layers. It is
        # kept as-is because the Sequential indices are part of the state_dict.
        return layers[0] if len(layers) == 1 else nn.Sequential(*layers)

    def forward(self, siglip_outputs):
        """
        Project the selected SigLIP hidden states.

        Args:
            siglip_outputs: Output from a SigLIP model exposing ``hidden_states``,
                an indexable sequence of tensors expected to be shaped
                (batch, siglip_token_nums, siglip_token_dims).

        Returns:
            torch.Tensor: ``(batch, len(layer_indices) * style_token_nums, D)``,
            where ``D`` is ``hidden_size`` (or ``siglip_token_dims`` when the
            projection is disabled). If ``post_projection_dim`` is set, the
            result is instead ``(batch, post_projection_dim)``.
        """
        # Cast inputs to the dtype of the module parameters.
        first_module = next(iter(self.embedding_linears.values()))
        dtype = next(first_module.parameters()).dtype
        embeddings = [
            self._process_layer_features(
                siglip_outputs.hidden_states[idx],
                self.embedding_linears[str(idx)],
                self.layer_norms[str(idx)],
                self.projections[str(idx)],
                dtype,
            )
            for idx in self.layer_indices
        ]
        # Concatenate all layers along the token dimension (dim 1).
        embeddings = torch.cat(embeddings, dim=1)
        if self.post_projection is not None:
            # Flatten tokens and features of each sample for the final linear layer.
            embeddings = self.post_projection(embeddings.flatten(start_dim=1))
        return embeddings

    def load_proj_model(self, checkpoint_path):
        """
        Load projection weights from a ``.safetensors`` or PyTorch checkpoint.

        - **Key selection:** only keys starting with ``"proj_model."`` are used,
          with that prefix stripped; all other keys are ignored.
        - **PyTorch checkpoints:** a ``"model_state_dict"`` entry is used when
          present; otherwise the loaded object itself is treated as the state
          dict.
        - **Strictness:** loading is non-strict. Missing or unexpected keys are
          reported with a printed warning.
        """
        proj_model_name = 'proj_model.'
        # Choose the loader from the file extension.
        ext = os.path.splitext(checkpoint_path)[1].lower()
        if ext == ".safetensors":
            all_state_dict = load_file(checkpoint_path)
        else:
            # SECURITY: weights_only=False unpickles arbitrary objects and can
            # execute code. Only load checkpoints from trusted sources.
            checkpoint = torch.load(checkpoint_path, map_location="cpu", weights_only=False)
            all_state_dict = checkpoint.get("model_state_dict", checkpoint)
        # Strip only the leading prefix. str.replace would also rewrite any
        # later occurrence of the prefix inside a key.
        model_state_dict = {
            k[len(proj_model_name):]: v
            for k, v in all_state_dict.items()
            if k.startswith(proj_model_name)
        }
        missing, unexpected = self.load_state_dict(model_state_dict, strict=False)
        if missing or unexpected:
            print(f"[Warning] Missing keys: {missing}, Unexpected keys: {unexpected}")

    def _process_layer_features(
        self,
        hidden_states: torch.Tensor,
        embedding_linear: nn.Module,
        layer_norm: nn.Module,
        projection: nn.Module,
        dtype: torch.dtype,
    ) -> torch.Tensor:
        """Compress, normalize and project the hidden states of one layer."""
        # Move the token axis last so the Linear resizes the token count, then
        # restore the (batch, tokens, features) layout.
        embedding = embedding_linear(
            hidden_states.to(dtype).transpose(1, 2)
        ).transpose(1, 2)
        embedding = layer_norm(embedding)
        embedding = projection(embedding)
        return embedding