File size: 1,474 Bytes
002bd9b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
from transformers.configuration_utils import PretrainedConfig
from transformers.utils import logging
from transformers.models.auto.configuration_auto import AutoConfig
import copy

logger = logging.get_logger(__name__)


class SAMCaptionerConfig(PretrainedConfig):
    """Composite configuration pairing a SAM config with a captioner config.

    The serialized form nests both sub-configurations under the ``"sam"`` and
    ``"captioner"`` keys; each nested dict carries its own ``"model_type"`` so
    the concrete config classes can be rebuilt via ``AutoConfig.for_model``
    when loading.
    """

    model_type = "sam-captioner"
    is_composition = True

    def __init__(self, **kwargs):
        # Pop composition-specific keys BEFORE delegating to
        # PretrainedConfig.__init__: the parent class setattr's every unknown
        # kwarg onto the instance, so popping first keeps the raw "sam" /
        # "captioner" dicts from leaking in as plain attributes. (Same order
        # as other composite configs such as EncoderDecoderConfig.)
        sam_config = kwargs.pop("sam", None)
        captioner_config = kwargs.pop("captioner", None)
        self.use_vcot = kwargs.pop("use_vcot", False)
        self.dtype = kwargs.pop("dtype", "float32")
        super().__init__(**kwargs)

        # Fail with a clear message instead of the cryptic
        # "'NoneType' object has no attribute 'pop'" the old code raised
        # when either sub-config was missing.
        if sam_config is None or captioner_config is None:
            raise ValueError(
                f"{self.__class__.__name__} requires both a `sam` and a "
                f"`captioner` config dict; got sam={sam_config!r}, "
                f"captioner={captioner_config!r}."
            )

        # Shallow-copy before popping so we never mutate the caller's dicts
        # (**kwargs only copies the top-level mapping, not its values).
        sam_config = dict(sam_config)
        captioner_config = dict(captioner_config)
        sam_model_type = sam_config.pop("model_type", None)
        captioner_model_type = captioner_config.pop("model_type", None)

        # Rebuild each sub-config as its concrete class from its model_type.
        self.sam = AutoConfig.for_model(sam_model_type, **sam_config)
        self.captioner = AutoConfig.for_model(captioner_model_type, **captioner_config)

    @classmethod
    def from_sam_captioner_configs(
        cls, sam_config: PretrainedConfig, captioner_config: PretrainedConfig, dtype, use_vcot, **kwargs
    ) -> PretrainedConfig:
        """Build a composite config from already-instantiated sub-configs.

        Args:
            sam_config: Configuration of the SAM sub-model.
            captioner_config: Configuration of the captioner sub-model.
            dtype: Torch dtype name to store (e.g. ``"float32"``).
            use_vcot: Whether visual chain-of-thought is enabled.
            **kwargs: Extra arguments forwarded to ``PretrainedConfig``.
        """
        return cls(
            sam=sam_config.to_dict(), captioner=captioner_config.to_dict(), dtype=dtype, use_vcot=use_vcot, **kwargs
        )

    def to_dict(self):
        """Serialize to a plain dict, nesting both sub-configs.

        Overrides the base implementation so that ``sam`` / ``captioner``
        appear as nested dicts (not config objects) and ``model_type``
        reflects this composite class.
        """
        output = copy.deepcopy(self.__dict__)
        output["sam"] = self.sam.to_dict()
        output["captioner"] = self.captioner.to_dict()
        output["model_type"] = self.__class__.model_type
        return output