| from dataclasses import dataclass |
| from transformers import CLIPModel as HFCLIPModel |
| from transformers import AutoTokenizer |
|
|
| from torch import nn, einsum |
|
|
| from .base_model import BaseModelConfig |
|
|
| from transformers import CLIPConfig |
| from typing import Any, Optional, Tuple, Union |
| import torch |
|
|
| from .cross_modeling import Cross_model |
|
|
| import json, os |
|
|
class XCLIPModel(HFCLIPModel):
    """Hugging Face ``CLIPModel`` whose feature getters also project token-level states.

    ``get_text_features`` returns both the projected token states and the projected
    pooled output. ``get_image_features`` returns the projected vision token states.
    """

    def __init__(self, config: CLIPConfig):
        super().__init__(config)

    def _xclip_output_flags(
        self,
        output_attentions: Optional[bool],
        output_hidden_states: Optional[bool],
        return_dict: Optional[bool],
    ) -> Tuple[bool, bool, bool]:
        """Replace any ``None`` flag with the corresponding default from ``self.config``."""
        if output_attentions is None:
            output_attentions = self.config.output_attentions
        if output_hidden_states is None:
            output_hidden_states = self.config.output_hidden_states
        if return_dict is None:
            return_dict = self.config.use_return_dict
        return output_attentions, output_hidden_states, return_dict

    def get_text_features(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
        """Encode text and apply ``text_projection`` to token states and pooled output.

        Args:
            input_ids: Token ids forwarded to ``self.text_model``.
            attention_mask: Optional attention mask forwarded to ``self.text_model``.
            position_ids: Optional position ids forwarded to ``self.text_model``.
            output_attentions: ``None`` falls back to ``self.config.output_attentions``.
            output_hidden_states: ``None`` falls back to
                ``self.config.output_hidden_states``.
            return_dict: ``None`` falls back to ``self.config.use_return_dict``.

        Returns:
            ``(text_features, text_features_EOS)`` -- ``text_projection`` applied to
            element 0 (last hidden state) and element 1 (pooled output) of the text
            model's output. Presumably shaped (batch, seq_len, proj_dim) and
            (batch, proj_dim) respectively -- TODO confirm.
        """
        output_attentions, output_hidden_states, return_dict = self._xclip_output_flags(
            output_attentions, output_hidden_states, return_dict
        )

        text_outputs = self.text_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # Integer indexing works for both the tuple and the ModelOutput return form.
        last_hidden_state = text_outputs[0]
        text_features = self.text_projection(last_hidden_state)

        pooled_output = text_outputs[1]
        text_features_EOS = self.text_projection(pooled_output)

        return text_features, text_features_EOS

    def get_image_features(
        self,
        pixel_values: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> torch.FloatTensor:
        """Encode images and apply ``visual_projection`` to the vision token states.

        Args:
            pixel_values: Pixel tensor forwarded to ``self.vision_model``.
            output_attentions: ``None`` falls back to ``self.config.output_attentions``.
            output_hidden_states: ``None`` falls back to
                ``self.config.output_hidden_states``.
            return_dict: ``None`` falls back to ``self.config.use_return_dict``.

        Returns:
            ``visual_projection`` applied to element 0 (last hidden state) of the vision
            model's output. Presumably shaped (batch, num_tokens, proj_dim) -- TODO confirm.
        """
        output_attentions, output_hidden_states, return_dict = self._xclip_output_flags(
            output_attentions, output_hidden_states, return_dict
        )

        vision_outputs = self.vision_model(
            pixel_values=pixel_values,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # NOTE(review): the projection is applied to the raw last hidden state. In HF's
        # CLIP vision tower, post_layernorm is applied only to the pooled output, so
        # these features skip it. Confirm this matches how the checkpoint was trained.
        last_hidden_state = vision_outputs[0]
        image_features = self.visual_projection(last_hidden_state)

        return image_features
|
|
|
|
|
|
@dataclass
class ClipModelConfig(BaseModelConfig):
    """Configuration record naming the CLIP wrapper class and its default checkpoint.

    ``_target_`` is a dotted import path, presumably consumed by a Hydra-style
    ``instantiate`` call (TODO confirm against the training entry point).
    """

    # Dotted path of the class to build.
    # NOTE(review): confirm this path matches the module's real package location.
    _target_: str = (
        "diffsynth.extensions.QualityMetric.trainer.models.clip_model.CLIPModel"
    )
    # Default checkpoint location.
    # NOTE(review): CLIPModel.__init__ takes ``ckpt``, not
    # ``pretrained_model_name_or_path``; verify how this field reaches it.
    pretrained_model_name_or_path: str = (
        "checkpoints/clip-vit-base-patch32"
    )
|
|
|
|
class CLIPModel(nn.Module):
    """XCLIP backbone plus a ``Cross_model`` head that scores two image groups vs. text.

    ``forward`` does the following:
    - encodes the prompt, the condition text and the images with the backbone;
    - masks out prompt tokens whose best similarity to any condition token is low;
    - runs the cross head separately on the two halves of the image batch.
    """

    def __init__(self, ckpt, config_file=False):
        """Create the backbone and the cross-attention head.

        Args:
            ckpt: Checkpoint directory (or model id) for the CLIP backbone.
            config_file: Only compared against ``None``.
                - ``None`` loads pretrained weights with
                  ``XCLIPModel.from_pretrained(ckpt)``.
                - Any other value, including the default ``False``, reads
                  ``<ckpt>/config.json`` and builds the backbone with
                  ``XCLIPModel._from_config``. Pretrained weights are not loaded.
        """
        super().__init__()
        # Note: the default ``False`` is not ``None``, so it takes the config-only
        # branch. The check is kept as-is for backward compatibility.
        if config_file is None:
            self.model = XCLIPModel.from_pretrained(ckpt)
        else:
            with open(os.path.join(ckpt, "config.json"), "r", encoding="utf-8") as f:
                config = json.load(f)
            config = CLIPConfig(**config)
            self.model = XCLIPModel._from_config(config)
        # NOTE(review): dim=1024 presumably must equal the backbone's projection
        # dimension -- confirm against the checkpoint config.
        self.cross_model = Cross_model(dim=1024, layer_num=4, heads=16)

    def get_text_features(self, *args, **kwargs):
        """Delegate to ``self.model.get_text_features`` (returns ``(tokens, EOS)``)."""
        return self.model.get_text_features(*args, **kwargs)

    def get_image_features(self, *args, **kwargs):
        """Delegate to ``self.model.get_image_features``."""
        return self.model.get_image_features(*args, **kwargs)

    def forward(self, text_inputs=None, image_inputs=None, condition_inputs=None):
        """Score the two halves of an image batch against the prompt tokens.

        All three arguments are required, even though each defaults to ``None``.

        Args:
            text_inputs: Prompt token ids, passed positionally to
                ``get_text_features``.
            image_inputs: Pixel values. The first ``batch // 2`` images and the
                remaining ones are scored separately against ``text_inputs``.
                Presumably each half has the same batch size as ``text_inputs``
                -- TODO confirm.
            condition_inputs: Condition-text token ids, passed like ``text_inputs``.
                The einsum below requires the same batch size as ``text_inputs``.

        Returns:
            A 3-tuple ``(text_EOS, sim0[:, 0, :], sim1[:, 0, :])``:
            - ``text_EOS``: the projected pooled prompt embedding.
            - ``sim0[:, 0, :]``, ``sim1[:, 0, :]``: index 0 along dim 1 of the cross
              head's output for each image half.
        """
        outputs = ()

        text_f, text_eos = self.model.get_text_features(text_inputs)
        outputs += (text_eos,)

        # Cast pixels to the backbone's parameter dtype instead of hard-coding fp16.
        # This is identical to the former ``.half()`` for an fp16 backbone, and it
        # avoids an input/weight dtype mismatch for fp32/bf16 backbones.
        image_f = self.model.get_image_features(image_inputs.to(self.model.dtype))
        condition_f, _ = self.model.get_text_features(condition_inputs)

        # For each prompt token i: its maximum similarity over condition tokens j.
        sim_text_condition = einsum('b i d, b j d -> b j i', text_f, condition_f)
        sim_text_condition = torch.max(sim_text_condition, dim=1, keepdim=True)[0]
        # NOTE(review): normalizes by the max over the whole batch, not per sample.
        # A zero or negative max would make the 0.01 threshold below meaningless.
        sim_text_condition = sim_text_condition / sim_text_condition.max()
        # 0 where similarity exceeds 0.01, -inf elsewhere. Presumably added to the
        # cross head's attention logits -- confirm in Cross_model.
        mask = torch.where(sim_text_condition > 0.01, 0, float('-inf'))
        # Repeat the per-token mask along dim 1 to match image_f's dim 1.
        mask = mask.repeat(1, image_f.shape[1], 1)
        # Match the image-feature dtype (fp16 for an fp16 backbone, as before).
        mask = mask.to(image_f.dtype)

        half = image_f.shape[0] // 2
        sim0 = self.cross_model(image_f[:half, :, :], text_f, mask)
        sim1 = self.cross_model(image_f[half:, :, :], text_f, mask)
        outputs += (sim0[:, 0, :],)
        outputs += (sim1[:, 0, :],)

        return outputs

    @property
    def logit_scale(self):
        """The backbone's ``logit_scale``."""
        return self.model.logit_scale

    def save(self, path):
        """Save the backbone to ``path`` with ``save_pretrained``.

        NOTE(review): ``cross_model`` weights are not written by this method.
        """
        self.model.save_pretrained(path)
|
|
|
|