File size: 829 Bytes
97bc03d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 |
import json
from attrdict2 import AttrDict
from transformers.configuration_utils import PretrainedConfig
def load_config_from_json(json_path):
with open(json_path, "r") as f:
config_data = json.load(f)
return config_data
class STARMultiModalConfig(PretrainedConfig):
model_type = "STARMultiModal"
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.pixel_encoder = AttrDict(kwargs.get("pixel_encoder", {}))
self.pixel_adapter = AttrDict(kwargs.get("pixel_adapter", {}))
self.pixel_output_head = AttrDict(kwargs.get("pixel_output_head", {}))
self.language_model = AttrDict(kwargs.get("language_model", {}))
self.stacked_ar = AttrDict(kwargs.get("stacked_ar", {}))
self.pixel_decoder = AttrDict(kwargs.get("pixel_decoder", {}))
|