| | """ huggingface model adapter |
| | |
| | Wraps HuggingFace transformers (https://github.com/huggingface/transformers) models for use as a text tower in CLIP model. |
| | """ |
| |
|
import re

import torch
import torch.nn as nn
from torch import TensorType

try:
    import transformers
    from transformers import AutoModel, AutoTokenizer, AutoConfig, PretrainedConfig
    from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, \
        BaseModelOutputWithPoolingAndCrossAttentions
except ImportError:
    transformers = None


    # fallback stubs so the annotations below still resolve without `transformers`
    class BaseModelOutput:
        pass


    class PretrainedConfig:
        pass

from .hf_configs import arch_dict


def _camel2snake(s):
    return re.sub(r'(?<!^)(?=[A-Z])', '_', s).lower()
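
# e.g. _camel2snake("ClsPooler") -> "cls_pooler"; register_pooler below uses this
# to derive each pooler's registry key from its class name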


_POOLERS = {}


def register_pooler(cls):
    """Decorator registering pooler class"""
    _POOLERS[_camel2snake(cls.__name__)] = cls
    return cls
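
# after the three registrations below, _POOLERS contains:
#   {"mean_pooler": MeanPooler, "max_pooler": MaxPooler, "cls_pooler": ClsPooler}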


@register_pooler
class MeanPooler(nn.Module):
    """Mean pooling"""

    def forward(self, x: BaseModelOutput, attention_mask: TensorType):
        masked_output = x.last_hidden_state * attention_mask.unsqueeze(-1)
        return masked_output.sum(dim=1) / attention_mask.sum(-1, keepdim=True)


@register_pooler
class MaxPooler(nn.Module):
    """Max pooling"""

    def forward(self, x: BaseModelOutput, attention_mask: TensorType):
        # fill padding positions (attention_mask == 0) with -inf so they cannot win the max
        masked_output = x.last_hidden_state.masked_fill(attention_mask.unsqueeze(-1) == 0, -torch.inf)
        return masked_output.max(1).values


@register_pooler
class ClsPooler(nn.Module):
    """CLS token pooling"""

    def __init__(self, use_pooler_output=True):
        super().__init__()
        self.cls_token_position = 0
        self.use_pooler_output = use_pooler_output

    def forward(self, x: BaseModelOutput, attention_mask: TensorType):
        # prefer the model's own pooled output (its processed [CLS] embedding) when available
        if (self.use_pooler_output and
                isinstance(x, (BaseModelOutputWithPooling, BaseModelOutputWithPoolingAndCrossAttentions)) and
                (x.pooler_output is not None)
        ):
            return x.pooler_output

        return x.last_hidden_state[:, self.cls_token_position, :]


class HFTextEncoder(nn.Module):
    """HuggingFace model adapter"""
    output_tokens: torch.jit.Final[bool]

    def __init__(
            self,
            model_name_or_path: str,
            output_dim: int,
            config: PretrainedConfig = None,
            pooler_type: str = None,
            proj: str = None,
            pretrained: bool = True,
            output_tokens: bool = False,
    ):
        super().__init__()
        self.output_tokens = output_tokens
        self.output_dim = output_dim

        # the HF pooling head only needs to be instantiated for the "cls_pooler" strategy
        uses_transformer_pooler = (pooler_type == "cls_pooler")

        if transformers is None:
            raise RuntimeError("Please `pip install transformers` to use pre-trained HuggingFace models")
        if config is None:
            self.config = AutoConfig.from_pretrained(model_name_or_path)
            create_func, model_args = (AutoModel.from_pretrained, model_name_or_path) if pretrained else (
                AutoModel.from_config, self.config)

            if hasattr(self.config, "is_encoder_decoder") and self.config.is_encoder_decoder:
                self.transformer = create_func(model_args)
                self.transformer = self.transformer.encoder
            else:
                self.transformer = create_func(model_args, add_pooling_layer=uses_transformer_pooler)
        else:
            self.config = config
            self.transformer = AutoModel.from_config(config)
        if pooler_type is None:  # fall back to the architecture's default pooler
            pooler_type = arch_dict[self.config.model_type]["pooler"]

        self.pooler = _POOLERS[pooler_type]()

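        # arch_dict (from .hf_configs) maps a HF model_type to its config attribute
        # names and a default pooler; roughly, for illustration:
        #   arch_dict["roberta"] == {
        #       "config_names": {"width": "hidden_size", "layer_attr": "layer", ...},
        #       "pooler": "mean_pooler",
        #   }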
        d_model = getattr(self.config, arch_dict[self.config.model_type]["config_names"]["width"])
        if (d_model == output_dim) and (proj is None):  # no projection needed when widths match
            self.proj = nn.Identity()
        elif proj == 'linear':
            self.proj = nn.Linear(d_model, output_dim, bias=False)
        elif proj == 'mlp':
            hidden_size = (d_model + output_dim) // 2
            self.proj = nn.Sequential(
                nn.Linear(d_model, hidden_size, bias=False),
                nn.GELU(),
                nn.Linear(hidden_size, output_dim, bias=False),
            )

    def forward(self, x: TensorType):
        # treat every non-pad token as valid attention context
        attn_mask = (x != self.config.pad_token_id).long()
        out = self.transformer(input_ids=x, attention_mask=attn_mask)
        pooled_out = self.pooler(out, attn_mask)
        projected = self.proj(pooled_out)

        seq_len = out.last_hidden_state.shape[1]
        # when pooling from the CLS token, drop it from the returned token sequence
        tokens = (
            out.last_hidden_state[:, torch.arange(seq_len) != self.pooler.cls_token_position, :]
            if isinstance(self.pooler, ClsPooler)
            else out.last_hidden_state
        )

        if self.output_tokens:
            return projected, tokens
        return projected
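
    # e.g. with output_tokens=True and a ClsPooler, forward returns a pair:
    #   projected: [batch, output_dim], tokens: [batch, seq_len - 1, d_model]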

    def lock(self, unlocked_layers: int = 0, freeze_layer_norm: bool = True):
        if not unlocked_layers:  # freeze the whole text tower
            for n, p in self.transformer.named_parameters():
                p.requires_grad = (not freeze_layer_norm) if "LayerNorm" in n.split(".") else False
            return

        encoder = self.transformer.encoder if hasattr(self.transformer, 'encoder') else self.transformer
        layer_list = getattr(encoder, arch_dict[self.config.model_type]["config_names"]["layer_attr"])
        print(f"Unlocking {unlocked_layers}/{len(layer_list) + 1} layers of hf model")
        embeddings = getattr(
            self.transformer, arch_dict[self.config.model_type]["config_names"]["token_embeddings_attr"])
        modules = [embeddings, *layer_list][:-unlocked_layers]
        # freeze the embeddings and all but the last `unlocked_layers` transformer blocks
        for module in modules:
            for n, p in module.named_parameters():
                p.requires_grad = (not freeze_layer_norm) if "LayerNorm" in n.split(".") else False

    @torch.jit.ignore
    def set_grad_checkpointing(self, enable=True):
        self.transformer.gradient_checkpointing_enable()

    def init_parameters(self):
        pass
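

# Minimal usage sketch (illustrative, not part of the module API; assumes the
# `transformers` package and the `roberta-base` checkpoint are available):
#
#   encoder = HFTextEncoder("roberta-base", output_dim=512, proj="linear")
#   tokenizer = AutoTokenizer.from_pretrained("roberta-base")
#   ids = tokenizer(["a photo of a cat"], return_tensors="pt", padding=True).input_ids
#   features = encoder(ids)  # -> tensor of shape (batch_size, 512)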