from transformers import AutoConfig, AutoModel, PretrainedConfig, CLIPTextConfig, CLIPVisionConfig, PreTrainedModel, CLIPTextModelWithProjection, CLIPVisionModelWithProjection
from transformers.utils import ModelOutput
import torch
import open_clip
from dataclasses import dataclass
import safetensors.torch
from peft import get_peft_model, LoraConfig, TaskType
import os


# File names used when serializing weights with safetensors.
HF_SAFE_WEIGHTS_NAME = "open_clip_model.safetensors"
HF_SAFE_WEIGHTS_NAME_PRIOR = "prior_model.safetensors"

@dataclass
class PriorTransformerOutput(ModelOutput):
    """
    The output of [`PriorTransformer`].

    Args:
        predicted_image_embedding (`torch.FloatTensor` of shape `(batch_size, embedding_dim)`):
            The predicted CLIP image embedding conditioned on the CLIP text embedding input.
    """

    predicted_image_embedding: torch.FloatTensor

@dataclass
class TextEncoderOutput(ModelOutput):
    """
    Output class for the text encoder models below, following the Hugging Face
    transformers output style.

    Attributes:
        text_embeds (torch.FloatTensor): The projected embeddings of the input prompts.
        last_hidden_state (torch.FloatTensor): The last hidden state of the model.
    """

    text_embeds: torch.FloatTensor = None
    last_hidden_state: torch.FloatTensor = None

class CLIPTextEncoderOnlyConfig(CLIPTextConfig):
    model_type = "clip_custom_text_model"

    def __init__(self, model_name: str = None, pretrained: bool = True, frozen: bool = False, lora: dict = None, **kwargs):
        self.model_name = model_name
        self.pretrained = pretrained
        self.frozen = frozen
        self.lora = lora
        super().__init__(**kwargs)

class CLIPTextEncoderOnly(PreTrainedModel):
    config_class = CLIPTextEncoderOnlyConfig

    def __init__(self, config):
        """
        Initializes a text-encoder-only wrapper around a CLIP model, inheriting from PreTrainedModel.

        :param config: A CLIPTextEncoderOnlyConfig holding the model name and the
            pretrained/frozen/LoRA options.
        """
        super().__init__(config)

        if config.pretrained:
            self.model = CLIPTextModelWithProjection.from_pretrained(config.model_name)
        else:
            base_cfg = CLIPTextConfig.from_pretrained(config.model_name)
            self.model = CLIPTextModelWithProjection(base_cfg)

        if config.frozen:
            for param in self.model.parameters():
                param.requires_grad = False

        if config.lora:
            # `config.lora` is expected to expose lora_r/lora_alpha/lora_dropout as
            # attributes (e.g. an OmegaConf node or a namespace, not a plain dict).
            l_config = LoraConfig(
                r=config.lora.lora_r,
                lora_alpha=config.lora.lora_alpha,
                target_modules=[
                    "k_proj",
                    "v_proj",
                    "q_proj",
                    "out_proj",
                    "fc1",
                    "fc2",
                    "text_projection",
                ],
                lora_dropout=config.lora.lora_dropout,
                bias="lora_only",
            )
            self.model = get_peft_model(self.model, l_config)

    def forward(self, input_ids, attention_mask=None, position_ids=None):
        """
        Forward pass of the model.

        :param input_ids: Indices of input sequence tokens in the vocabulary.
        :param attention_mask: Mask to avoid performing attention on padding token indices.
        :param position_ids: Indices of positions of each input sequence token.
        :return: A TextEncoderOutput with the projected text embeddings and the last hidden state.
        """
        outputs = self.model(input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids, output_hidden_states=True)
        return TextEncoderOutput(text_embeds=outputs.text_embeds, last_hidden_state=outputs.last_hidden_state)

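# Minimal usage sketch (illustrative, not part of the training code): the
# checkpoint name and the prompt below are assumptions for demonstration.
def _example_clip_text_encoder():
    from transformers import AutoTokenizer

    cfg = CLIPTextEncoderOnlyConfig(model_name="openai/clip-vit-base-patch32", pretrained=True)
    encoder = CLIPTextEncoderOnly(cfg)
    tokenizer = AutoTokenizer.from_pretrained("openai/clip-vit-base-patch32")

    batch = tokenizer(["a photo of a cat"], padding=True, return_tensors="pt")
    with torch.no_grad():
        out = encoder(input_ids=batch["input_ids"], attention_mask=batch["attention_mask"])
    # For ViT-B/32: (1, 512) projected embedding and (1, seq_len, 512) hidden states.
    print(out.text_embeds.shape, out.last_hidden_state.shape)
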
class CustomTextEncoderOnlyConfig(PretrainedConfig):
    model_type = "whole_custom_text_model"

    def __init__(self, model_name: str = None, pretrained: bool = True, frozen: bool = False, output_hidden_size: int = 512, last_hidden_state: bool = False, lora: dict = None, **kwargs):
        self.model_name = model_name
        self.pretrained = pretrained
        self.frozen = frozen
        self.output_hidden_size = output_hidden_size
        self.last_hidden_state = last_hidden_state
        self.lora = lora
        super().__init__(**kwargs)

class CustomTextEncoderOnly(PreTrainedModel):
    config_class = CustomTextEncoderOnlyConfig

    def __init__(self, config):
        """
        Initializes a text encoder built on an arbitrary Hugging Face backbone
        (e.g. BERT), inheriting from PreTrainedModel.

        :param config: A CustomTextEncoderOnlyConfig holding the model name and the
            pretrained/frozen/LoRA options.
        """
        super().__init__(config)

        self.last_hidden_state = config.last_hidden_state

        if config.pretrained:
            self.model = AutoModel.from_pretrained(config.model_name)
            if config.frozen:
                for param in self.model.parameters():
                    param.requires_grad = False
        else:
            base_cfg = AutoConfig.from_pretrained(config.model_name)
            self.model = AutoModel.from_config(base_cfg)

        # Project the backbone's hidden size down to the requested output size.
        self.fc1 = torch.nn.Linear(self.model.config.hidden_size, config.output_hidden_size)
        if config.last_hidden_state:
            self.fc2 = torch.nn.Linear(self.model.config.hidden_size, config.output_hidden_size)

        if config.lora:
            l_config = LoraConfig(
                task_type=TaskType.FEATURE_EXTRACTION,
                r=config.lora.lora_r,
                lora_alpha=config.lora.lora_alpha,
                lora_dropout=config.lora.lora_dropout,
                bias="lora_only",
            )
            self.model = get_peft_model(self.model, l_config)

    def forward(self, input_ids, attention_mask=None, token_type_ids=None):
        """
        Forward pass of the model.

        :param input_ids: Indices of input sequence tokens in the vocabulary.
        :param attention_mask: Mask to avoid performing attention on padding token indices.
        :param token_type_ids: Segment token indices to indicate first and second portions of the inputs.
        :return: A TextEncoderOutput with the projected text embeddings and the last hidden state.
        """
        outputs = self.model(input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids, output_hidden_states=True)
        text_embeds = self.fc1(outputs.pooler_output)
        if self.last_hidden_state:
            last_hidden_state = self.fc2(outputs.last_hidden_state)
        else:
            last_hidden_state = outputs.last_hidden_state
        return TextEncoderOutput(text_embeds=text_embeds, last_hidden_state=last_hidden_state)

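# Minimal usage sketch (illustrative): wraps a BERT backbone behind the same
# TextEncoderOutput interface; the checkpoint name is an assumption.
def _example_custom_text_encoder():
    from transformers import AutoTokenizer

    cfg = CustomTextEncoderOnlyConfig(
        model_name="google-bert/bert-base-uncased",
        pretrained=True,
        output_hidden_size=512,
    )
    encoder = CustomTextEncoderOnly(cfg)
    tokenizer = AutoTokenizer.from_pretrained("google-bert/bert-base-uncased")

    batch = tokenizer(["a photo of a cat"], padding=True, return_tensors="pt")
    with torch.no_grad():
        out = encoder(**batch)
    # fc1 projects BERT's pooled output (768) down to output_hidden_size (512).
    print(out.text_embeds.shape)  # torch.Size([1, 512])
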
class CLIPVisionEncoderOnlyConfig(PretrainedConfig):
    model_type = "clip_custom_vision_model"

    def __init__(self, model_name: str = None, pretrained: bool = True, frozen: bool = False, lora: dict = None, **kwargs):
        self.model_name = model_name
        self.pretrained = pretrained
        self.frozen = frozen
        self.lora = lora
        super().__init__(**kwargs)

class CLIPVisionEncoderOnly(PreTrainedModel):
    config_class = CLIPVisionEncoderOnlyConfig

    def __init__(self, config):
        """
        Initializes a vision-encoder-only wrapper around a CLIP model, inheriting from PreTrainedModel.

        :param config: A CLIPVisionEncoderOnlyConfig holding the model name and the
            pretrained/frozen/LoRA options.
        """
        super().__init__(config)

        if config.pretrained:
            self.model = CLIPVisionModelWithProjection.from_pretrained(config.model_name)
        else:
            base_cfg = CLIPVisionConfig.from_pretrained(config.model_name)
            self.model = CLIPVisionModelWithProjection(base_cfg)

        if config.frozen:
            for param in self.model.parameters():
                param.requires_grad = False

        if config.lora:
            l_config = LoraConfig(
                r=config.lora.lora_r,
                lora_alpha=config.lora.lora_alpha,
                target_modules=[
                    "k_proj",
                    "v_proj",
                    "q_proj",
                    "out_proj",
                    "fc1",
                    "fc2",
                    "visual_projection",
                ],
                lora_dropout=config.lora.lora_dropout,
                bias="lora_only",
            )
            self.model = get_peft_model(self.model, l_config)

    def forward(self, data):
        """
        Forward pass of the model.

        :param data: A dict of image processor outputs (e.g. `pixel_values`) that is
            unpacked into the underlying CLIP vision model.
        :return: The projected image embeddings.
        """
        return self.model(**data).image_embeds

    def parameters(self, recurse: bool = True):
        return self.model.parameters(recurse=recurse)

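# Minimal usage sketch (illustrative): the forward pass expects a dict of
# processor outputs, since it unpacks `data` into the underlying CLIP model.
# The checkpoint name and placeholder image are assumptions.
def _example_clip_vision_encoder():
    from transformers import CLIPImageProcessor
    from PIL import Image

    cfg = CLIPVisionEncoderOnlyConfig(model_name="openai/clip-vit-base-patch32", pretrained=True)
    encoder = CLIPVisionEncoderOnly(cfg)
    processor = CLIPImageProcessor.from_pretrained("openai/clip-vit-base-patch32")

    image = Image.new("RGB", (224, 224))  # placeholder image
    data = processor(images=image, return_tensors="pt")
    with torch.no_grad():
        image_embeds = encoder(data)
    print(image_embeds.shape)  # torch.Size([1, 512])
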
class OpenCLIPVisionEncoderOnly(torch.nn.Module):
    def __init__(self, model_name: str, pretrained: bool = True, frozen: bool = False, lora: dict = None):
        """
        Initializes a vision-encoder-only wrapper around an OpenCLIP model.

        :param model_name: The name of the pretrained model on the Hugging Face hub.
        :param pretrained: Whether to load the pretrained weights.
        """
        super().__init__()
        if pretrained:
            model, _ = open_clip.create_model_from_pretrained(f"hf-hub:{model_name}")
            model = model.visual
        else:
            raise NotImplementedError("OpenCLIPVisionEncoderOnly currently requires pretrained weights.")
        self.model = model

        if frozen:
            for param in self.model.parameters():
                param.requires_grad = False

        if lora:
            l_config = LoraConfig(
                r=lora.lora_r,
                lora_alpha=lora.lora_alpha,
                target_modules=[
                    "k_proj",
                    "v_proj",
                    "q_proj",
                    "out_proj",
                    "fc1",
                    "fc2",
                    "visual_projection",
                    "text_projection",
                ],
                lora_dropout=lora.lora_dropout,
                bias="lora_only",
            )
            self.model = get_peft_model(self.model, l_config)

    def forward(self, image):
        """
        Forward pass of the model: encodes a batch of preprocessed images.
        """
        return self.model(image)

    def save_pretrained(self, save_dir):
        tensors = self.model.state_dict()
        safetensors.torch.save_file(tensors, os.path.join(save_dir, HF_SAFE_WEIGHTS_NAME))

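# Minimal usage sketch (illustrative): open_clip returns the matching
# preprocessing transform alongside the model, so we reuse it here. The hub
# repo name is an assumption, and the model is loaded twice for simplicity.
def _example_open_clip_vision_encoder():
    from PIL import Image

    model_name = "laion/CLIP-ViT-B-32-laion2B-s34B-b79K"
    _, preprocess = open_clip.create_model_from_pretrained(f"hf-hub:{model_name}")
    encoder = OpenCLIPVisionEncoderOnly(model_name=model_name, pretrained=True)

    image = preprocess(Image.new("RGB", (224, 224))).unsqueeze(0)  # (1, 3, H, W)
    with torch.no_grad():
        image_embeds = encoder(image)
    print(image_embeds.shape)
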
class CustomPriorModel(torch.nn.Module):
    def __init__(self, in_hidden_state, out_hidden_state):
        """
        Initializes a small two-layer MLP prior that maps concatenated text
        features to a predicted image embedding.

        :param in_hidden_state: Size of each input feature vector (the MLP input is twice this).
        :param out_hidden_state: Size of the predicted image embedding.
        """
        super().__init__()
        mid_hidden_state = max(in_hidden_state, out_hidden_state)

        self.fc1 = torch.nn.Linear(in_hidden_state * 2, mid_hidden_state)
        self.relu = torch.nn.ReLU()
        self.fc2 = torch.nn.Linear(mid_hidden_state, out_hidden_state)

    def reinitialize_model(self):
        # Xavier-initialize weight matrices; normal-init 1-D weights, zero biases.
        for name, param in self.named_parameters():
            if param.requires_grad:
                if len(param.shape) > 1:
                    torch.nn.init.xavier_uniform_(param)
                else:
                    if 'weight' in name:
                        torch.nn.init.normal_(param)
                    else:
                        torch.nn.init.zeros_(param)

    def forward(self, feats):
        """
        Forward pass of the model.
        """
        return PriorTransformerOutput(predicted_image_embedding=self.fc2(self.relu(self.fc1(feats))))

    def save_pretrained(self, save_dir):
        # Mirrors OpenCLIPVisionEncoderOnly.save_pretrained, using the prior weights file name.
        safetensors.torch.save_file(self.state_dict(), os.path.join(save_dir, HF_SAFE_WEIGHTS_NAME_PRIOR))

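# Minimal usage sketch (illustrative): fc1 takes in_hidden_state * 2 features,
# which suggests the prior consumes two concatenated text-side vectors; the
# concatenation below is an assumption about the intended input layout.
def _example_prior():
    prior = CustomPriorModel(in_hidden_state=512, out_hidden_state=512)
    prior.reinitialize_model()

    text_embeds = torch.randn(4, 512)
    pooled_feats = torch.randn(4, 512)
    feats = torch.cat([text_embeds, pooled_feats], dim=-1)  # (4, 1024)
    out = prior(feats)
    print(out.predicted_image_embedding.shape)  # torch.Size([4, 512])
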
def test_text_model(register=False, upload=False):
    if register:
        AutoConfig.register("clip_custom_text_model", CLIPTextEncoderOnlyConfig)
        AutoModel.register(CLIPTextEncoderOnlyConfig, CLIPTextEncoderOnly)
        CLIPTextEncoderOnlyConfig.register_for_auto_class()
        CLIPTextEncoderOnly.register_for_auto_class("AutoModel")

    if upload:
        model_name = "openai/clip-vit-base-patch32"
        pretrained = True
        lora = None

        cfg = CLIPTextEncoderOnlyConfig(model_name=model_name, pretrained=pretrained, lora=lora)
        model = CLIPTextEncoderOnly(cfg)
        model.push_to_hub("test-text-hf-upload")

    model = CLIPTextEncoderOnly.from_pretrained("mpatel57/test-text-hf-upload", force_download=True)

def test_custom_text_model(register=False, upload=False):
    if register:
        AutoConfig.register("whole_custom_text_model", CustomTextEncoderOnlyConfig)
        AutoModel.register(CustomTextEncoderOnlyConfig, CustomTextEncoderOnly)
        CustomTextEncoderOnlyConfig.register_for_auto_class()
        CustomTextEncoderOnly.register_for_auto_class("AutoModel")

    if upload:
        model_name = "google-bert/bert-base-uncased"
        pretrained = True
        frozen = False
        output_hidden_size = 512
        last_hidden_state = False
        lora = None

        cfg = CustomTextEncoderOnlyConfig(model_name=model_name, pretrained=pretrained, frozen=frozen, output_hidden_size=output_hidden_size, last_hidden_state=last_hidden_state, lora=lora)
        model = CustomTextEncoderOnly(cfg)
        model.push_to_hub("test-text-hf-upload")

    model = CustomTextEncoderOnly.from_pretrained("mpatel57/test-text-hf-upload", force_download=True)

def test_vision_model(register=False, upload=False):
    if register:
        AutoConfig.register("clip_custom_vision_model", CLIPVisionEncoderOnlyConfig)
        AutoModel.register(CLIPVisionEncoderOnlyConfig, CLIPVisionEncoderOnly)
        CLIPVisionEncoderOnlyConfig.register_for_auto_class()
        CLIPVisionEncoderOnly.register_for_auto_class("AutoModel")

    if upload:
        model_name = "openai/clip-vit-base-patch32"
        pretrained = True
        lora = None

        cfg = CLIPVisionEncoderOnlyConfig(model_name=model_name, pretrained=pretrained, lora=lora)
        model = CLIPVisionEncoderOnly(cfg)
        model.push_to_hub("test-vision-hf-upload")

    model = CLIPVisionEncoderOnly.from_pretrained("mpatel57/test-vision-hf-upload", force_download=True)

| | if __name__ == "__main__": |
| | test_custom_text_model(register=False, upload=True) |
| | |
| | |
| |
|