"""Modified from https://github.com/khanrc/honeybee
"""
import torch
from torch import nn
from transformers.modeling_utils import PreTrainedModel
from .configuration_m4cxr import HoneybeeConfig


class QuickGELU(nn.Module):
    """GELU approximation used in CLIP: x * sigmoid(1.702 * x)."""

    def forward(self, x: torch.Tensor):
        return x * torch.sigmoid(1.702 * x)
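

# A minimal sanity-check sketch, not part of the original module: it compares
# QuickGELU against torch's exact GELU on random inputs. The function name is
# illustrative, and nothing here runs at import time.
def _demo_quick_gelu():
    x = torch.randn(4, 8)
    approx = QuickGELU()(x)
    exact = torch.nn.functional.gelu(x)
    # The sigmoid approximation tracks the exact GELU closely for typical
    # activation magnitudes; the gap grows slowly with |x|.
    print("max abs difference:", (approx - exact).abs().max().item())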


class LayerNormFp32(nn.LayerNorm):
    """Subclass of torch's LayerNorm that handles fp16 inputs by casting to
    float32 for the normalization and casting back afterwards."""

    def forward(self, x: torch.Tensor):
        output = torch.nn.functional.layer_norm(
            x.float(),
            self.normalized_shape,
            self.weight.float() if self.weight is not None else None,
            self.bias.float() if self.bias is not None else None,
            self.eps,
        )
        return output.type_as(x)
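

# A minimal sketch, not from the original file, showing the intended behavior:
# the normalization statistics are computed in float32 even when the module
# and input are half precision, and the output dtype still matches the input.
def _demo_layer_norm_fp32():
    ln = LayerNormFp32(16).half()  # parameters stored in fp16
    x = torch.randn(2, 16, dtype=torch.float16)
    y = ln(x)
    assert y.dtype == torch.float16  # cast back by type_as(x)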


class HoneybeePreTrainedModel(PreTrainedModel):
    """
    An abstract class that handles weights initialization and provides
    a simple interface for downloading and loading pretrained models.
    """

    config_class = HoneybeeConfig
    base_model_prefix = "honeybee"
    supports_gradient_checkpointing = True
    # Missing keys matching these patterns are not reported when loading a
    # checkpoint (e.g. tied or externally provided embeddings).
    _keys_to_ignore_on_load_missing = [
        r"position_ids",
        r"language_model.encoder.embed_tokens.weight",
        r"language_model.decoder.embed_tokens.weight",
        r"language_model.lm_head.weight",
    ]
    # Modules that must not be split across devices when the model is
    # sharded with accelerate's device_map.
    _no_split_modules = [
        "CLIPEncoderLayer",
        "LlamaDecoderLayer",
        "HoneybeeVisualProjectorLayer",
        "LlamaForCausalLM",
        "Parameter",
    ]
    # Parameters under modules with these names stay in fp32 even when the
    # rest of the model is loaded in half precision.
    _keep_in_fp32_modules = ["wo"]

    def _init_weights(self, module):
        """Initialize the weights."""
        if isinstance(module, (nn.Conv2d, nn.Embedding, nn.Linear)):
            module.weight.data.normal_(mean=0.0, std=0.02)
            if hasattr(module, "bias") and module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
        elif isinstance(module, nn.Parameter):
            # Bare parameters are expected to be initialized by their owner
            # module, not here.
            raise ValueError("nn.Parameter should be initialized by its owning module.")
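

# A hedged usage sketch: concrete models subclass HoneybeePreTrainedModel and
# call self.post_init() at the end of __init__; transformers then walks the
# module tree with Module.apply() and invokes _init_weights on each submodule.
# The demo class below is hypothetical (it does not exist in this repo), and
# the assumption that HoneybeeConfig() is constructible without arguments may
# not hold for the real config, so the sketch is left commented out.
#
#     class _DemoModel(HoneybeePreTrainedModel):
#         def __init__(self, config):
#             super().__init__(config)
#             self.proj = nn.Linear(8, 8)
#             self.post_init()  # applies _init_weights to self.proj
#
#     model = _DemoModel(HoneybeeConfig())
#     assert model.proj.bias.abs().max().item() == 0.0  # bias zeroed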