| """Modified from https://github.com/khanrc/honeybee |
| """ |
|
|
| import torch |
| from torch import nn |
| from transformers.modeling_utils import PreTrainedModel |
|
|
| from .configuration_m4cxr import HoneybeeConfig |
|
|
|
|
class QuickGELU(nn.Module):
    """QuickGELU activation: ``x * sigmoid(1.702 * x)``, applied element-wise."""

    def forward(self, x: torch.Tensor):
        """Return ``x`` multiplied by the sigmoid of ``1.702 * x``."""
        gate = torch.sigmoid(1.702 * x)
        return x * gate
|
|
|
|
class LayerNormFp32(nn.LayerNorm):
    """LayerNorm that normalizes in float32 and casts the result back.

    The input and the affine parameters (when present) are cast to float32
    before ``layer_norm`` is applied, so reduced-precision inputs such as
    fp16 are normalized in full precision. The output is then converted
    back to the type of the input tensor.
    """

    def __init__(self, *args, **kwargs):
        # Pure pass-through: construction is identical to ``nn.LayerNorm``.
        super().__init__(*args, **kwargs)

    def forward(self, x: torch.Tensor):
        """Apply layer normalization to ``x`` in float32; return it as ``x``'s type."""
        # Affine parameters may be absent (``None``); only cast the ones that exist.
        weight_fp32 = None if self.weight is None else self.weight.float()
        bias_fp32 = None if self.bias is None else self.bias.float()

        normalized = torch.nn.functional.layer_norm(
            x.float(),
            self.normalized_shape,
            weight_fp32,
            bias_fp32,
            self.eps,
        )
        return normalized.type_as(x)
|
|
|
|
class HoneybeePreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and
    a simple interface for downloading and loading pretrained models.
    """

    config_class = HoneybeeConfig
    base_model_prefix = "honeybee"
    supports_gradient_checkpointing = True

    # The underscore-prefixed attributes below are read by the
    # ``PreTrainedModel`` base class, not by any code in this class.
    # NOTE(review): their exact effect depends on the installed
    # ``transformers`` version -- confirm against its documentation.
    _keys_to_ignore_on_load_missing = [
        r"position_ids",
        r"language_model.encoder.embed_tokens.weight",
        r"language_model.decoder.embed_tokens.weight",
        r"language_model.lm_head.weight",
    ]
    _no_split_modules = [
        "CLIPEncoderLayer",
        "LlamaDecoderLayer",
        "HoneybeeVisualProjectorLayer",
        "LlamaForCausalLM",
        "Parameter",
    ]
    _keep_in_fp32_modules = ["wo"]

    def _init_weights(self, module):
        """Initialize the weights of a single submodule in place.

        * ``nn.Conv2d`` / ``nn.Embedding`` / ``nn.Linear``: the weight is
          drawn from a normal distribution (mean 0.0, std 0.02) and the
          bias, when the module has one, is zeroed.
        * ``nn.LayerNorm``: the bias is zeroed and the weight is set to 1.0.
          Parameters that are ``None`` are skipped.
        * Any other module type is left untouched.

        Args:
            module: The submodule to initialize.

        Raises:
            ValueError: If ``module`` is an ``nn.Parameter`` rather than a module.
        """
        if isinstance(module, (nn.Conv2d, nn.Embedding, nn.Linear)):
            module.weight.data.normal_(mean=0.0, std=0.02)
            # ``nn.Embedding`` has no ``bias`` attribute, and Conv2d/Linear
            # built with ``bias=False`` expose ``bias`` as ``None``.
            if getattr(module, "bias", None) is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.LayerNorm):
            # A LayerNorm built with ``elementwise_affine=False`` (or
            # ``bias=False``) has ``None`` parameters; dereferencing ``.data``
            # on them would raise AttributeError, so skip what is absent.
            if module.bias is not None:
                module.bias.data.zero_()
            if module.weight is not None:
                module.weight.data.fill_(1.0)
        elif isinstance(module, nn.Parameter):
            raise ValueError(
                "_init_weights expects an nn.Module but received a bare nn.Parameter"
            )
|
|