|
|
from transformers.configuration_utils import PretrainedConfig |
|
|
from transformers.utils import logging |
|
|
from transformers.models.auto.configuration_auto import AutoConfig |
|
|
import copy |
|
|
|
|
|
logger = logging.get_logger(__name__) |
|
|
|
|
|
|
|
|
class SAMCaptionerConfig(PretrainedConfig): |
|
|
model_type = "sam-captioner" |
|
|
is_composition = True |
|
|
|
|
|
def __init__(self, **kwargs): |
|
|
super().__init__(**kwargs) |
|
|
sam_config = kwargs.pop("sam", None) |
|
|
sam_model_type = sam_config.pop("model_type", None) |
|
|
captioner_config = kwargs.pop("captioner", None) |
|
|
captioner_model_type = captioner_config.pop("model_type", None) |
|
|
|
|
|
self.use_vcot = kwargs.pop("use_vcot", False) |
|
|
self.dtype = kwargs.pop("dtype", "float32") |
|
|
|
|
|
self.sam = AutoConfig.for_model(sam_model_type, **sam_config) |
|
|
self.captioner = AutoConfig.for_model(captioner_model_type, **captioner_config) |
|
|
|
|
|
@classmethod |
|
|
def from_sam_captioner_configs( |
|
|
cls, sam_config: PretrainedConfig, captioner_config: PretrainedConfig, dtype, use_vcot, **kwargs |
|
|
) -> PretrainedConfig: |
|
|
return cls( |
|
|
sam=sam_config.to_dict(), captioner=captioner_config.to_dict(), dtype=dtype, use_vcot=use_vcot, **kwargs |
|
|
) |
|
|
|
|
|
def to_dict(self): |
|
|
output = copy.deepcopy(self.__dict__) |
|
|
output["sam"] = self.sam.to_dict() |
|
|
output["captioner"] = self.captioner.to_dict() |
|
|
output["model_type"] = self.__class__.model_type |
|
|
return output |
|
|
|