Spaces:
Sleeping
Sleeping
| # copied and modified from https://github.com/bytedance/USO/blob/main/uso/flux/modules/layers.py | |
| import torch | |
| from torch import nn | |
| import os | |
| from safetensors.torch import load_file | |
class SigLIPMultiFeatProjModel(nn.Module):
    """
    SigLIP Multi-Feature Projection Model for processing style features from different layers
    and projecting them into a unified hidden space.

    For each selected SigLIP layer the token axis is resampled from ``siglip_token_nums``
    to ``style_token_nums``, optionally layer-normalized, and projected. The per-layer
    results are concatenated along the token axis.

    Args:
        layer_indices (list[int], optional): SigLIP ``hidden_states`` indices to extract.
            Defaults to ``[-2, -11, -20]``. The sequence is copied, so later changes to the
            caller's list (or to another instance's list) do not affect this model.
        siglip_token_nums (int): Number of SigLIP tokens per layer.
        style_token_nums (int): Number of style tokens each layer is resampled to.
        siglip_token_dims (int): Feature dimension of SigLIP tokens.
        hidden_size (int): Output width of the projection network.
        projection_layers (int): Number of linear layers in the projection network.
            A value ``<= 0`` disables the projection (identity).
        context_layer_norm (bool): Whether to apply LayerNorm before the projection.
        post_projection_dim (int, optional): If specified, the concatenated features are
            flattened per sample and passed through a final Linear to this dimension.

    Raises:
        ValueError: If ``layer_indices`` is empty.
    """

    def __init__(
        self,
        layer_indices: "list[int] | None" = None,
        siglip_token_nums: int = 729,
        style_token_nums: int = 64,
        siglip_token_dims: int = 1152,
        hidden_size: int = 4096,
        projection_layers: int = 1,  # controls projection depth
        context_layer_norm: bool = True,
        post_projection_dim: "int | None" = None,
    ):
        super().__init__()
        # Default and copy the indices so instances never share one mutable list.
        if layer_indices is None:
            layer_indices = [-2, -11, -20]
        layer_indices = list(layer_indices)
        if not layer_indices:
            raise ValueError("layer_indices must contain at least one SigLIP layer index")

        self.layer_indices = layer_indices
        self.style_token_nums = style_token_nums
        self.hidden_size = hidden_size
        self.post_projection_dim = post_projection_dim

        # Independent processing modules for each selected layer, keyed by str(index)
        # because ModuleDict keys must be strings.
        self.embedding_linears = nn.ModuleDict()
        self.layer_norms = nn.ModuleDict()
        self.projections = nn.ModuleDict()
        # SiLU has no parameters, so a single instance can be shared by all heads.
        activate_fn = nn.SiLU()

        for idx in layer_indices:
            name = str(idx)
            # Linear over the token axis: siglip_token_nums -> style_token_nums.
            self.embedding_linears[name] = nn.Sequential(
                nn.Linear(siglip_token_nums, style_token_nums),
                activate_fn,
            )
            self.layer_norms[name] = (
                nn.LayerNorm(siglip_token_dims) if context_layer_norm else nn.Identity()
            )
            self.projections[name] = self._build_projection(
                siglip_token_dims, hidden_size, projection_layers, activate_fn
            )

        self.post_projection = None
        if self.post_projection_dim is not None:
            # Per-token width after `projections`: the identity branch keeps the
            # SigLIP width; otherwise the projection outputs hidden_size.
            proj_out_dim = hidden_size if projection_layers > 0 else siglip_token_dims
            # Input is the flattened (tokens * width) of all selected layers.
            input_dim = proj_out_dim * len(layer_indices) * style_token_nums
            self.post_projection = nn.Linear(input_dim, self.post_projection_dim)

    @staticmethod
    def _build_projection(in_dim, hidden_size, projection_layers, activate_fn):
        """Build one per-layer projection head; the layout matches existing checkpoints."""
        if projection_layers <= 0:
            return nn.Identity()
        # NOTE(review): the order is Linear(in, h) followed by [Linear(h, h), SiLU] * (n - 1),
        # i.e. there is no activation between the first two Linears. This may be
        # unintended, but reordering would shift the Sequential indices and therefore
        # the state_dict keys of saved checkpoints, so it is preserved.
        layers = [nn.Linear(in_dim, hidden_size, bias=True)]
        for _ in range(projection_layers - 1):
            layers.append(nn.Linear(hidden_size, hidden_size, bias=True))
            layers.append(activate_fn)
        # A single layer is stored bare (not wrapped in Sequential) to stay compatible
        # with the previous single-Linear architecture's state_dict keys.
        if len(layers) == 1:
            return layers[0]
        return nn.Sequential(*layers)

    def forward(self, siglip_outputs):
        """
        Args:
            siglip_outputs: Output from a SigLIP model. Only its ``hidden_states``
                sequence is read, indexed by each entry of ``self.layer_indices``.
                Each selected tensor is expected to be
                (batch, siglip_token_nums, siglip_token_dims).

        Returns:
            torch.Tensor: (batch, len(layer_indices) * style_token_nums, D), where D is
            ``hidden_size`` (or ``siglip_token_dims`` when the projection is disabled).
            When a post-projection is configured, the result is instead
            (batch, post_projection_dim).
        """
        # Inputs are cast to the parameter dtype of the first per-layer module.
        first_module = next(iter(self.embedding_linears.values()))
        dtype = next(first_module.parameters()).dtype

        embeddings = []
        for idx in self.layer_indices:
            name = str(idx)
            embeddings.append(
                self._process_layer_features(
                    siglip_outputs.hidden_states[idx],
                    self.embedding_linears[name],
                    self.layer_norms[name],
                    self.projections[name],
                    dtype,
                )
            )
        # Concatenate all layers along the token dimension.
        embeddings = torch.cat(embeddings, dim=1)

        if self.post_projection is not None:
            # Flatten per sample; reshape (unlike view) also handles non-contiguous input.
            embeddings = embeddings.reshape(embeddings.shape[0], -1)
            embeddings = self.post_projection(embeddings)
        return embeddings

    def load_proj_model(self, checkpoint_path):
        """Load projection-model weights from a ``.safetensors`` or PyTorch checkpoint.

        Keys that start with ``proj_model.`` have that leading prefix stripped. All other
        keys are passed through unchanged. Loading is non-strict: missing or unexpected
        keys are reported with a printed warning instead of raising.

        Args:
            checkpoint_path: Path to the checkpoint file. A ``.safetensors`` extension
                (case-insensitive) is read with safetensors; anything else is read with
                ``torch.load``. For torch checkpoints, a top-level ``model_state_dict``
                entry is used if present; otherwise the loaded object itself is used.
        """
        prefix = 'proj_model.'

        # Pick the loader from the file extension.
        ext = os.path.splitext(checkpoint_path)[1].lower()
        if ext == ".safetensors":
            all_state_dict = load_file(checkpoint_path)
        else:
            # SECURITY: weights_only=False unpickles arbitrary objects and can execute
            # code from the file -- only load checkpoints from trusted sources.
            checkpoint = torch.load(checkpoint_path, map_location="cpu", weights_only=False)
            all_state_dict = checkpoint.get("model_state_dict", checkpoint)

        # Strip only the leading prefix (str.replace would remove every occurrence).
        model_state_dict = {
            (k[len(prefix):] if k.startswith(prefix) else k): v
            for k, v in all_state_dict.items()
        }

        missing, unexpected = self.load_state_dict(model_state_dict, strict=False)
        if missing or unexpected:
            print(f"[Warning] Missing keys: {missing}, Unexpected keys: {unexpected}")

    def _process_layer_features(
        self,
        hidden_states: torch.Tensor,
        embedding_linear: nn.Module,
        layer_norm: nn.Module,
        projection: nn.Module,
        dtype: torch.dtype,
    ) -> torch.Tensor:
        """Resample tokens, normalize, and project the features of a single layer."""
        # Move the token axis last so the Linear mixes across tokens, then move it back.
        embedding = embedding_linear(
            hidden_states.to(dtype).transpose(1, 2)
        ).transpose(1, 2)
        embedding = layer_norm(embedding)
        return projection(embedding)