import json
import os
from typing import Any, Dict

from transformers import PretrainedConfig


class M2EncoderConfig(PretrainedConfig):
    """HuggingFace configuration for the M2 Encoder model.

    Holds the BEiT-3-style encoder hyper-parameters together with
    tokenizer / checkpoint bookkeeping, and knows how to translate
    itself into the keyword overrides expected by the VLMo-style
    model constructor (see :meth:`to_vlmo_overrides`).
    """

    model_type = "m2_encoder"

    def __init__(
        self,
        loss_names=None,
        beit_version="large",
        encoder_embed_dim=1024,
        out_embed_dim=1024,
        encoder_layers=21,
        beit3_vl_layers=3,
        image_size=224,
        visual_mask_size=14,
        tokenizer_type="GLMChineseTokenizer",
        tokenizer=".",
        vocab_size=115244,
        whole_word_masking=False,
        precision=32,
        test_only=True,
        flash_attn=False,
        model_file="m2_encoder_1B.safetensors",
        architectures=None,
        auto_map=None,
        **kwargs,
    ):
        super().__init__(**kwargs)

        # Falsy sentinel -> default; avoids a mutable default argument.
        self.loss_names = loss_names or {"itc": 1}

        # Encoder architecture hyper-parameters.
        self.beit_version = beit_version
        self.encoder_embed_dim = encoder_embed_dim
        self.out_embed_dim = out_embed_dim
        self.encoder_layers = encoder_layers
        self.beit3_vl_layers = beit3_vl_layers

        # Vision input geometry.
        self.image_size = image_size
        self.visual_mask_size = visual_mask_size

        # Tokenizer configuration; `tokenizer` is a path resolved lazily
        # against the model directory (see _resolve_tokenizer_dir).
        self.tokenizer_type = tokenizer_type
        self.tokenizer = tokenizer
        self.vocab_size = vocab_size
        self.whole_word_masking = whole_word_masking

        # Runtime / checkpoint options.
        self.precision = precision
        self.test_only = test_only
        self.flash_attn = flash_attn
        self.model_file = model_file

        # AutoClass wiring for `trust_remote_code` loading.
        self.architectures = architectures or ["M2EncoderModel"]
        self.auto_map = auto_map or {
            "AutoConfig": "configuration_m2_encoder.M2EncoderConfig",
            "AutoModel": "modeling_m2_encoder.M2EncoderModel",
            "AutoProcessor": "processing_m2_encoder.M2EncoderProcessor",
        }

    @classmethod
    def from_encoder_json(cls, config_path: str, **kwargs) -> "M2EncoderConfig":
        """Build a config from a raw encoder JSON file.

        Keys in ``kwargs`` override whatever the JSON file provides.
        """
        with open(config_path, "r", encoding="utf-8") as fh:
            payload = json.load(fh)
        payload.update(kwargs)
        return cls(**payload)

    def to_vlmo_overrides(self, model_dir: str) -> Dict[str, Any]:
        """Return the keyword overrides for the underlying VLMo builder.

        ``model_dir`` anchors the relative tokenizer and checkpoint paths.
        """
        tokenizer_dir = self._resolve_tokenizer_dir(model_dir)
        checkpoint_path = os.path.join(model_dir, self.model_file)
        return {
            "loss_names": self.loss_names,
            "beit_version": self.beit_version,
            "encoder_embed_dim": self.encoder_embed_dim,
            "out_embed_dim": self.out_embed_dim,
            "encoder_layers": self.encoder_layers,
            "beit3_vl_layers": self.beit3_vl_layers,
            "image_size": self.image_size,
            "visual_mask_size": self.visual_mask_size,
            "tokenizer_type": self.tokenizer_type,
            "tokenizer": tokenizer_dir,
            "vocab_size": self.vocab_size,
            "whole_word_masking": self.whole_word_masking,
            "precision": self.precision,
            "test_only": self.test_only,
            "flash_attn": self.flash_attn,
            "load_path": checkpoint_path,
        }

    def _resolve_tokenizer_dir(self, model_dir: str) -> str:
        """Map the configured tokenizer path onto an on-disk location.

        Absolute paths pass through untouched; "." / "./" / "" mean the
        model directory itself; anything else is joined under it.
        """
        configured = self.tokenizer
        if os.path.isabs(configured):
            return configured
        if configured in (".", "./", ""):
            return model_dir
        return os.path.join(model_dir, configured)