# File: deepspeed/src/models/sca/configuration_sca.py
# Repository page header: author xingzhikb, commit "init" (002bd9b)
from transformers.configuration_utils import PretrainedConfig
from transformers.models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
from transformers.utils import logging
from transformers.models.auto.configuration_auto import AutoConfig
from ..sam.configuration_sam import (
SamPromptEncoderConfig,
SamVisionConfig,
SAM_PRETRAINED_CONFIG_ARCHIVE_MAP,
SamConfig,
)
from transformers.models.auto import CONFIG_MAPPING
import copy
from typing import Optional
# Module-level logger obtained from the transformers logging utility, named after this module.
logger = logging.get_logger(__name__)
class ScaMaskCaptionDecoderConfig(PretrainedConfig):
    """Hyper-parameter container for the SCA mask/caption decoder.

    Every named constructor argument is stored on the instance under the same
    name; any remaining keyword arguments are forwarded to ``PretrainedConfig``.

    The first ten arguments are the ones ``ScaConfig.from_sam_text_configs``
    fills from a SAM mask-decoder config dict. The last three are
    captioning-specific additions:

    * ``additional_num_hidden_layers`` - used in the transformer layers to
      further fuse features.
    * ``num_caption_tokens`` / ``num_caption_heads`` - captioning settings
      (presumably token and head counts; confirm against the model code).

    NOTE: when adding a captioning argument here, remember to update
    ``from_sam_text_configs`` as well.
    """

    def __init__(
        self,
        hidden_size=256,
        hidden_act="relu",
        mlp_dim=2048,
        num_hidden_layers=2,
        num_attention_heads=8,
        attention_downsample_rate=2,
        num_multimask_outputs=3,
        iou_head_depth=3,
        iou_head_hidden_dim=256,
        layer_norm_eps=1e-6,
        # NOTE(xiaoke): for captioning
        # NOTE: Remember to change `from_sam_text_configs` as well!
        additional_num_hidden_layers: int = 2,
        num_caption_tokens: int = 1,
        num_caption_heads: int = 1,
        **kwargs,
    ):
        super().__init__(**kwargs)

        # Settings shared with the SAM mask decoder.
        mask_decoder_fields = (
            ("hidden_size", hidden_size),
            ("hidden_act", hidden_act),
            ("mlp_dim", mlp_dim),
            ("num_hidden_layers", num_hidden_layers),
            ("num_attention_heads", num_attention_heads),
            ("attention_downsample_rate", attention_downsample_rate),
            ("num_multimask_outputs", num_multimask_outputs),
            ("iou_head_depth", iou_head_depth),
            ("iou_head_hidden_dim", iou_head_hidden_dim),
            ("layer_norm_eps", layer_norm_eps),
        )
        # Captioning-only settings.
        # NOTE(xiaoke): additional_num_hidden_layers is used in the transformer
        # layers to further fuse features.
        caption_fields = (
            ("additional_num_hidden_layers", additional_num_hidden_layers),
            ("num_caption_tokens", num_caption_tokens),
            ("num_caption_heads", num_caption_heads),
        )

        # Assign in declaration order so the instance dict keeps a stable layout.
        for field_name, field_value in mask_decoder_fields + caption_fields:
            setattr(self, field_name, field_value)
class ScaConfig(PretrainedConfig):
    """Composite configuration for the SCA model.

    Holds four sub-configurations as attributes:

    * ``vision_config`` (``SamVisionConfig``)
    * ``prompt_encoder_config`` (``SamPromptEncoderConfig``)
    * ``mask_caption_decoder_config`` (``ScaMaskCaptionDecoderConfig``)
    * ``text_config`` (class looked up in ``CONFIG_MAPPING`` by ``model_type``,
      or ``text_config_cls`` when the type is not registered there)

    Args:
        vision_config: dict of ``SamVisionConfig`` kwargs, a ``SamVisionConfig``
            instance, or ``None`` for defaults.
        prompt_encoder_config: dict of ``SamPromptEncoderConfig`` kwargs, a
            ``SamPromptEncoderConfig`` instance, or ``None`` for defaults.
        mask_caption_decoder_config: dict of ``ScaMaskCaptionDecoderConfig``
            kwargs, a ``ScaMaskCaptionDecoderConfig`` instance, or ``None``.
        text_config: dict of kwargs for the text model config, a
            ``PretrainedConfig`` instance, or ``None``. The ``"model_type"``
            key selects the config class; it defaults to ``"gpt2"``.
        text_config_cls: fallback config class used when ``model_type`` is not
            present in ``CONFIG_MAPPING``.
        initializer_range: stored as-is on the instance.
        num_task_tokens: stored as-is on the instance (for recognition pretrain).
        **kwargs: forwarded to ``PretrainedConfig.__init__``.

    Raises:
        ValueError: if the text model type is unknown and no ``text_config_cls``
            is given, or if ``text_config["architectures"]`` lists more than
            one architecture for a model type that is not a registered causal LM.
    """

    model_type = "sca"
    is_composition = True

    def __init__(
        self,
        vision_config=None,
        prompt_encoder_config=None,
        mask_caption_decoder_config=None,
        text_config=None,
        text_config_cls=None,
        initializer_range=0.02,
        # NOTE: for recognition pretrain
        num_task_tokens: int = 6,
        **kwargs,
    ):
        super().__init__(**kwargs)

        vision_config = vision_config if vision_config is not None else {}
        prompt_encoder_config = prompt_encoder_config if prompt_encoder_config is not None else {}
        mask_caption_decoder_config = mask_caption_decoder_config if mask_caption_decoder_config is not None else {}
        text_config = text_config if text_config is not None else {}

        # Normalise config objects to plain dicts so they can be re-expanded with
        # `**` below. Previously the dicts were stored on unrelated attributes and
        # the objects themselves reached the `**` expansion, raising TypeError.
        if isinstance(vision_config, SamVisionConfig):
            vision_config = vision_config.to_dict()
        if isinstance(prompt_encoder_config, SamPromptEncoderConfig):
            prompt_encoder_config = prompt_encoder_config.to_dict()
        if isinstance(mask_caption_decoder_config, ScaMaskCaptionDecoderConfig):
            mask_caption_decoder_config = mask_caption_decoder_config.to_dict()
        if isinstance(text_config, PretrainedConfig):
            # Remember the concrete class so unregistered model types still resolve.
            if text_config_cls is None:
                text_config_cls = type(text_config)
            text_config = text_config.to_dict()

        text_model_type = text_config.get("model_type", "gpt2")
        # Only the mapping lookup is guarded: a KeyError raised while constructing
        # the text config must not be mistaken for an unknown model type.
        try:
            resolved_text_config_cls = CONFIG_MAPPING[text_model_type]
        except KeyError as err:
            if text_config_cls is None:
                raise ValueError(f"Unrecognized text model type: {text_model_type}") from err
            logger.warning(f"use external config cls: {text_config_cls} for text model type: {text_model_type}")
            resolved_text_config_cls = text_config_cls
        self.text_config = resolved_text_config_cls(**text_config)

        self.tie_word_embeddings = self.text_config.tie_word_embeddings
        self.is_encoder_decoder = self.text_config.is_encoder_decoder

        # NOTE(xiaoke): use_decoder_only_language_model only tells us the model family (like GPT2),
        # not the task model class (like GPT2ForCausalLM). The task model class is needed
        # to load the pretrained weights for the task.
        self.use_decoder_only_language_model = self.text_config.model_type in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
        if not self.use_decoder_only_language_model:
            # NOTE: External models like stablelm-zephyr-3b are not in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES.
            # Check the architecture name to see whether it is a decoder-only language model.
            # A missing, None or empty "architectures" entry is treated as "not decoder-only"
            # instead of raising KeyError / TypeError / IndexError.
            architectures = text_config.get("architectures") or []
            if len(architectures) > 1:
                raise ValueError(f"Only support one architecture in text_config, got {text_config['architectures']}")
            if architectures and architectures[0].endswith("ForCausalLM"):
                self.use_decoder_only_language_model = True

        self.vision_config = SamVisionConfig(**vision_config)
        self.prompt_encoder_config = SamPromptEncoderConfig(**prompt_encoder_config)
        self.mask_caption_decoder_config = ScaMaskCaptionDecoderConfig(**mask_caption_decoder_config)
        self.initializer_range = initializer_range
        self.num_task_tokens = num_task_tokens

    def to_dict(self):
        """Return a deep copy of the instance attributes with sub-configs expanded to dicts.

        The four sub-configuration objects are replaced by their own ``to_dict()``
        output, and ``"model_type"`` is set from the class attribute.
        """
        output = copy.deepcopy(self.__dict__)
        output["vision_config"] = self.vision_config.to_dict()
        output["prompt_encoder_config"] = self.prompt_encoder_config.to_dict()
        output["mask_caption_decoder_config"] = self.mask_caption_decoder_config.to_dict()
        output["text_config"] = self.text_config.to_dict()
        output["model_type"] = self.__class__.model_type
        return output

    @classmethod
    def from_sam_text_configs(
        cls,
        sam_config: SamConfig,
        text_config: Optional[PretrainedConfig] = None,
        additional_num_hidden_layers: Optional[int] = None,
        num_caption_tokens: Optional[int] = None,
        num_task_tokens: Optional[int] = None,
        num_caption_heads: Optional[int] = None,
        vl_projector_type: Optional[str] = None,
        vl_projector_norm_type: Optional[str] = None,
        **kwargs,
    ):
        """Build a config from a ``SamConfig`` and an optional text model config.

        The SAM mask-decoder settings are merged with the captioning settings to
        form the mask/caption decoder config. Any option left as ``None`` falls
        back to its default and a warning is logged so the fallback is visible.

        Args:
            sam_config: source of the vision, prompt-encoder and mask-decoder settings.
            text_config: text model config; its class is also forwarded as
                ``text_config_cls`` so unregistered model types can be rebuilt.
            additional_num_hidden_layers: default ``2``.
            num_caption_tokens: default ``1``.
            num_task_tokens: default ``6``.
            num_caption_heads: default ``1``.
            vl_projector_type: default ``"linear"``; passed on as an extra keyword argument.
            vl_projector_norm_type: default ``"none"``; passed on as an extra keyword argument.
            **kwargs: forwarded to the constructor.
        """
        # name -> (caller-supplied value, fallback); insertion order fixes the warning order.
        requested = {
            "additional_num_hidden_layers": (additional_num_hidden_layers, 2),
            "num_caption_tokens": (num_caption_tokens, 1),
            "num_task_tokens": (num_task_tokens, 6),
            "num_caption_heads": (num_caption_heads, 1),
            "vl_projector_type": (vl_projector_type, "linear"),
            "vl_projector_norm_type": (vl_projector_norm_type, "none"),
        }
        resolved = {}
        for name, (value, default) in requested.items():
            if value is None:
                logger.warning(f"{name} is not set, using default value: {default}. Make sure it is correct!")
                value = default
            resolved[name] = value

        return cls(
            vision_config=sam_config.vision_config.to_dict(),
            prompt_encoder_config=sam_config.prompt_encoder_config.to_dict(),
            mask_caption_decoder_config={
                **sam_config.mask_decoder_config.to_dict(),
                "additional_num_hidden_layers": resolved["additional_num_hidden_layers"],
                "num_caption_tokens": resolved["num_caption_tokens"],
                "num_caption_heads": resolved["num_caption_heads"],
            },
            text_config=text_config.to_dict() if text_config is not None else None,
            text_config_cls=type(text_config) if text_config is not None else None,
            num_task_tokens=resolved["num_task_tokens"],
            vl_projector_type=resolved["vl_projector_type"],
            vl_projector_norm_type=resolved["vl_projector_norm_type"],
            **kwargs,
        )
class ScaTimmConfig(ScaConfig):
    """``ScaConfig`` variant that additionally records a timm vision model name.

    Args:
        timm_vision_name: name of the timm vision model; defaults to
            ``"vit_base_patch16_clip_224.openai"`` when ``None``.
        vision_config, prompt_encoder_config, mask_caption_decoder_config,
        text_config, initializer_range, num_task_tokens: forwarded unchanged to
            ``ScaConfig.__init__``.
        **kwargs: forwarded to ``ScaConfig.__init__`` (this is how
            ``text_config_cls`` reaches the parent constructor).

    Raises:
        TypeError: if ``timm_vision_name`` is neither ``None`` nor a string.
    """

    model_type = "sca_timm"
    is_composition = True

    def __init__(
        self,
        timm_vision_name=None,
        vision_config=None,
        prompt_encoder_config=None,
        mask_caption_decoder_config=None,
        text_config=None,
        initializer_range=0.02,
        # NOTE: for recognition pretrain
        num_task_tokens: int = 6,
        **kwargs,
    ):
        if timm_vision_name is None:
            timm_vision_name = "vit_base_patch16_clip_224.openai"
        # Fail early with a clear message. Previously a non-string value was silently
        # ignored and only surfaced later as an AttributeError in `to_dict`.
        if not isinstance(timm_vision_name, str):
            raise TypeError(
                f"timm_vision_name must be a str or None, got {type(timm_vision_name).__name__}"
            )

        super().__init__(
            vision_config=vision_config,
            prompt_encoder_config=prompt_encoder_config,
            mask_caption_decoder_config=mask_caption_decoder_config,
            text_config=text_config,
            initializer_range=initializer_range,
            num_task_tokens=num_task_tokens,
            **kwargs,
        )
        self.timm_vision_name = timm_vision_name

    def to_dict(self):
        """Return the parent's dict representation plus ``"timm_vision_name"``."""
        output = super().to_dict()
        output["timm_vision_name"] = self.timm_vision_name
        return output

    @classmethod
    def from_sam_timm_text_configs(
        cls,
        timm_vision_name: str,
        sam_config: SamConfig,
        text_config: Optional[PretrainedConfig] = None,
        additional_num_hidden_layers: Optional[int] = None,
        num_caption_tokens: Optional[int] = None,
        num_task_tokens: Optional[int] = None,
        num_caption_heads: Optional[int] = None,
        vl_projector_type: Optional[str] = None,
        vl_projector_norm_type: Optional[str] = None,
        **kwargs,
    ):
        """Build a config from a timm model name, a ``SamConfig`` and an optional text config.

        Mirrors ``ScaConfig.from_sam_text_configs``: options left as ``None`` fall
        back to their defaults and a warning is logged for each fallback.

        Args:
            timm_vision_name: name of the timm vision model.
            sam_config: source of the vision, prompt-encoder and mask-decoder settings.
            text_config: text model config; its class is forwarded as
                ``text_config_cls`` so model types missing from ``CONFIG_MAPPING``
                can still be instantiated.
            additional_num_hidden_layers: default ``2``.
            num_caption_tokens: default ``1``.
            num_task_tokens: default ``6``.
            num_caption_heads: default ``1``.
            vl_projector_type: default ``"linear"``; passed on as an extra keyword argument.
            vl_projector_norm_type: default ``"none"``; passed on as an extra keyword argument.
            **kwargs: forwarded to the constructor.
        """
        # name -> (caller-supplied value, fallback); insertion order fixes the warning order.
        requested = {
            "additional_num_hidden_layers": (additional_num_hidden_layers, 2),
            "num_caption_tokens": (num_caption_tokens, 1),
            "num_task_tokens": (num_task_tokens, 6),
            "num_caption_heads": (num_caption_heads, 1),
            "vl_projector_type": (vl_projector_type, "linear"),
            "vl_projector_norm_type": (vl_projector_norm_type, "none"),
        }
        resolved = {}
        for name, (value, default) in requested.items():
            if value is None:
                logger.warning(f"{name} is not set, using default value: {default}. Make sure it is correct!")
                value = default
            resolved[name] = value

        return cls(
            timm_vision_name=timm_vision_name,
            vision_config=sam_config.vision_config.to_dict(),
            prompt_encoder_config=sam_config.prompt_encoder_config.to_dict(),
            mask_caption_decoder_config={
                **sam_config.mask_decoder_config.to_dict(),
                "additional_num_hidden_layers": resolved["additional_num_hidden_layers"],
                "num_caption_tokens": resolved["num_caption_tokens"],
                "num_caption_heads": resolved["num_caption_heads"],
            },
            text_config=text_config.to_dict() if text_config is not None else None,
            # Kept consistent with `from_sam_text_configs`: without the class, text
            # models that are not registered in CONFIG_MAPPING cannot be rebuilt.
            text_config_cls=type(text_config) if text_config is not None else None,
            num_task_tokens=resolved["num_task_tokens"],
            vl_projector_type=resolved["vl_projector_type"],
            vl_projector_norm_type=resolved["vl_projector_norm_type"],
            **kwargs,
        )