Spaces:
Sleeping
Sleeping
| import torch | |
| import torch.nn as nn | |
| # from transformers.modeling_utils import PreTrainedModel | |
| from diffusers.configuration_utils import register_to_config, ConfigMixin | |
| from unipicv2.modeling_connector import ConnectorEncoder | |
| from unipicv2.configuration_connector import ConnectorConfig | |
| from diffusers.models.modeling_utils import ModelMixin | |
class StableDiffusion3Conditioner(ModelMixin, ConfigMixin):
    """Maps LLM hidden states to Stable Diffusion 3 text-conditioning inputs.

    The module projects LLM states into a connector encoder, then emits the
    two embeddings SD3 consumes: per-token ``prompt_embeds`` (joint-attention
    width) and a mean-pooled ``pooled_prompt_embeds`` (pooled-projection
    width).
    """

    model_type: str = "sd3_conditioner"  # stored into config for hub niceties

    # register_to_config records the init kwargs into the ConfigMixin config;
    # without it, save_pretrained() writes an incomplete config and the model
    # cannot be reloaded via from_pretrained().
    @register_to_config
    def __init__(
        self,
        connector_config: dict,  # kwargs forwarded to ConnectorConfig(**connector_config)
        num_queries: int = 256,
        llm_hidden_size: int = 3584,
        pooled_projection_dim: int = 2048,
        joint_attention_dim: int = 4096,
    ):
        super().__init__()
        # Encoder that refines the projected LLM states.
        self.connector = ConnectorEncoder(ConnectorConfig(**connector_config))
        # LLM width -> connector width.
        self.projector_1 = nn.Linear(llm_hidden_size, self.connector.config.hidden_size)
        # Connector width -> pooled embedding width expected by SD3.
        self.projector_2 = nn.Linear(self.connector.config.hidden_size, pooled_projection_dim)
        # Connector width -> per-token joint-attention width expected by SD3.
        self.projector_3 = nn.Linear(self.connector.config.hidden_size, joint_attention_dim)
        # Learnable query embeddings; not consumed by forward() here —
        # presumably read by an external caller (e.g. prepended to LLM input).
        self.meta_queries = nn.Parameter(torch.zeros(num_queries, llm_hidden_size))

    def _init_weights(self, module):
        # Deliberate no-op: weights are expected to come from a checkpoint
        # via load_state_dict, so no random (re-)initialization is wanted.
        pass

    def forward(self, x: torch.Tensor):
        """Project LLM hidden states into SD3 conditioning embeddings.

        Args:
            x: (batch, seq_len, llm_hidden_size) LLM hidden states.

        Returns:
            prompt_embeds: (batch, seq_len, joint_attention_dim)
            pooled_prompt_embeds: (batch, pooled_projection_dim)
        """
        x = self.projector_1(x)
        x = self.connector(x)  # expects (B, L, hidden)
        # Mean-pool over the sequence dimension for the pooled embedding.
        pooled_prompt_embeds = self.projector_2(x.mean(1))
        prompt_embeds = self.projector_3(x)
        return prompt_embeds, pooled_prompt_embeds
if __name__ == "__main__":
    # CLI: load a raw state-dict checkpoint into the conditioner and
    # re-export it in diffusers save_pretrained format.
    import argparse
    import os

    parser = argparse.ArgumentParser(
        description="Convert a conditioner state-dict checkpoint to diffusers format."
    )
    # required=True: a missing path previously fell through to
    # torch.load(None), which fails with an opaque error.
    parser.add_argument("--checkpoint", type=str, required=True,
                        help="Path to the torch state-dict checkpoint to load.")
    parser.add_argument("--output", type=str, required=True,
                        help="Directory to write the save_pretrained output into.")
    args = parser.parse_args()

    conditioner = StableDiffusion3Conditioner(
        num_queries=256,
        connector_config=dict(
            hidden_size=1536,
            intermediate_size=8960,
            num_hidden_layers=24,
            _attn_implementation="flash_attention_2",
            num_attention_heads=24,
        ),
        llm_hidden_size=3584,
        pooled_projection_dim=2048,
        joint_attention_dim=4096,
    ).bfloat16()

    # map_location="cpu" so conversion works without the checkpoint's
    # original (possibly GPU) device being available.
    checkpoint = torch.load(args.checkpoint, map_location="cpu")
    # strict=False tolerates extra/missing keys; surface them explicitly
    # instead of dropping into a debugger (removed leftover pdb.set_trace()).
    info = conditioner.load_state_dict(checkpoint, strict=False)
    print(f"load_state_dict result: {info}")

    os.makedirs(args.output, exist_ok=True)
    conditioner.save_pretrained(args.output)