# deepspeed/src/models/sam_captioner/configuration_sam_captioner.py
from transformers.configuration_utils import PretrainedConfig
from transformers.utils import logging
from transformers.models.auto.configuration_auto import AutoConfig
import copy
# Module-level logger named after this module (standard transformers pattern).
logger = logging.get_logger(__name__)
class SAMCaptionerConfig(PretrainedConfig):
    """Composite configuration pairing a SAM config with a captioner config.

    Stores two nested ``PretrainedConfig`` objects (``self.sam`` and
    ``self.captioner``), each rebuilt from a plain dict via
    ``AutoConfig.for_model`` using the dict's ``model_type`` key, plus two
    scalar options:

    - ``use_vcot`` (bool, default ``False``)
    - ``dtype`` (str, default ``"float32"``)

    ``is_composition = True`` tells transformers this config wraps other
    configs rather than defining its own flat schema.
    """

    model_type = "sam-captioner"
    is_composition = True

    def __init__(self, **kwargs):
        # Pop composite keys BEFORE calling super() so PretrainedConfig does
        # not transiently setattr the raw dicts onto the instance.
        sam_config = kwargs.pop("sam", None)
        captioner_config = kwargs.pop("captioner", None)
        self.use_vcot = kwargs.pop("use_vcot", False)
        self.dtype = kwargs.pop("dtype", "float32")
        super().__init__(**kwargs)

        if sam_config is None or captioner_config is None:
            # Previously this path crashed with an opaque
            # AttributeError: 'NoneType' object has no attribute 'pop'.
            raise ValueError(
                "SAMCaptionerConfig requires both `sam` and `captioner` config dicts."
            )

        # Shallow-copy so popping `model_type` does not mutate the caller's dicts.
        sam_config = dict(sam_config)
        captioner_config = dict(captioner_config)
        sam_model_type = sam_config.pop("model_type", None)
        captioner_model_type = captioner_config.pop("model_type", None)

        self.sam = AutoConfig.for_model(sam_model_type, **sam_config)
        self.captioner = AutoConfig.for_model(captioner_model_type, **captioner_config)

    @classmethod
    def from_sam_captioner_configs(
        cls, sam_config: PretrainedConfig, captioner_config: PretrainedConfig, dtype, use_vcot, **kwargs
    ) -> PretrainedConfig:
        """Build a composite config from two already-instantiated sub-configs.

        Args:
            sam_config: Configuration of the SAM sub-model.
            captioner_config: Configuration of the captioner sub-model.
            dtype: Dtype string stored on the composite config (e.g. ``"float32"``).
            use_vcot: Whether visual chain-of-thought is enabled.
            **kwargs: Forwarded to :class:`PretrainedConfig`.

        Returns:
            A new :class:`SAMCaptionerConfig`.
        """
        return cls(
            sam=sam_config.to_dict(), captioner=captioner_config.to_dict(), dtype=dtype, use_vcot=use_vcot, **kwargs
        )

    def to_dict(self):
        """Serialize to a plain dict, recursing into both sub-configs.

        Follows the transformers composition-config convention (cf.
        ``CLIPConfig.to_dict``): deep-copy instance state, replace the nested
        config objects with their own ``to_dict()`` output, and record this
        class's ``model_type``.
        """
        output = copy.deepcopy(self.__dict__)
        output["sam"] = self.sam.to_dict()
        output["captioner"] = self.captioner.to_dict()
        output["model_type"] = self.__class__.model_type
        return output