from typing import Union, Optional

from transformers import PretrainedConfig, AutoConfig

# Imported for its side effect: the visual tokenizer config class is assumed to
# register itself with AutoConfig inside .visual_tokenizer, so that
# AutoConfig.for_model can resolve its model_type below.
from .visual_tokenizer import ClipVisualTokenizerConfig


class OvisConfig(PretrainedConfig):
    """Configuration for an Ovis multimodal model.

    Bundles the backbone LLM config (`llm_config`) and the visual tokenizer config
    (`visual_tokenizer_config`); either may be given as a `PretrainedConfig` instance
    or as a plain dict containing a `model_type` key.
    """

    model_type = "ovis"

    def __init__(self,
                 llm_config: Optional[Union[PretrainedConfig, dict]] = None,
                 visual_tokenizer_config: Optional[Union[PretrainedConfig, dict]] = None,
                 multimodal_max_length=2048,
                 hidden_size=None,
                 conversation_formatter_class=None,
                 **kwargs):
        super().__init__(**kwargs)
        if llm_config is not None:
            assert isinstance(llm_config, (PretrainedConfig, dict)), \
                f"expect `llm_config` to be an instance of PretrainedConfig or dict, but got {type(llm_config)}"
            if not isinstance(llm_config, PretrainedConfig):
                # A plain dict is converted to the matching config class via AutoConfig,
                # keyed by its `model_type` entry.
                model_type = llm_config['model_type']
                llm_config.pop('model_type')
                llm_config = AutoConfig.for_model(model_type, **llm_config)
        self.llm_config = llm_config
        if visual_tokenizer_config is not None:
            assert isinstance(visual_tokenizer_config, (PretrainedConfig, dict)), \
                f"expect `visual_tokenizer_config` to be an instance of PretrainedConfig or dict, but got {type(visual_tokenizer_config)}"
            if not isinstance(visual_tokenizer_config, PretrainedConfig):
                # Same dict-to-config conversion as for `llm_config` above.
                model_type = visual_tokenizer_config['model_type']
                visual_tokenizer_config.pop('model_type')
                visual_tokenizer_config = AutoConfig.for_model(model_type, **visual_tokenizer_config)
        self.visual_tokenizer_config = visual_tokenizer_config
        self.multimodal_max_length = multimodal_max_length
        self.hidden_size = hidden_size
        self.conversation_formatter_class = conversation_formatter_class
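

# Illustrative usage sketch (not part of the original module). It assumes an LLM
# model type already registered with transformers (e.g. "llama"), and that
# ClipVisualTokenizerConfig is registered under some model_type such as
# "clip_visual_tokenizer"; plain dicts are turned into PretrainedConfig instances
# in __init__ via AutoConfig.for_model:
#
#     config = OvisConfig(
#         llm_config={"model_type": "llama", "hidden_size": 4096, "num_hidden_layers": 32},
#         visual_tokenizer_config={"model_type": "clip_visual_tokenizer"},  # assumed registry name
#         multimodal_max_length=2048,
#     )
#     assert isinstance(config.llm_config, PretrainedConfig)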