# M2-Encoder-1B / configuration_m2_encoder.py
# (Hub page residue — uploader: malusama; commit e33e190, verified:
#  "Fix model card repo names and safetensors wording")
import json
import os
from typing import Any, Dict
from transformers import PretrainedConfig
class M2EncoderConfig(PretrainedConfig):
    """HuggingFace configuration for the M2-Encoder vision-language model.

    Stores the encoder hyper-parameters (BEiT variant, embedding sizes,
    layer counts), tokenizer settings, and checkpoint/runtime flags, and
    can translate itself into the override dict expected by the VLMo-style
    model builder via :meth:`to_vlmo_overrides`.
    """

    model_type = "m2_encoder"

    def __init__(
        self,
        loss_names=None,
        beit_version="large",
        encoder_embed_dim=1024,
        out_embed_dim=1024,
        encoder_layers=21,
        beit3_vl_layers=3,
        image_size=224,
        visual_mask_size=14,
        tokenizer_type="GLMChineseTokenizer",
        tokenizer=".",
        vocab_size=115244,
        whole_word_masking=False,
        precision=32,
        test_only=True,
        flash_attn=False,
        model_file="m2_encoder_1B.safetensors",
        architectures=None,
        auto_map=None,
        **kwargs,
    ):
        super().__init__(**kwargs)

        # Falsy values (None / empty containers) fall back to the defaults,
        # matching the original `x or default` semantics.
        self.loss_names = {"itc": 1} if not loss_names else loss_names
        self.architectures = ["M2EncoderModel"] if not architectures else architectures
        self.auto_map = auto_map if auto_map else {
            "AutoConfig": "configuration_m2_encoder.M2EncoderConfig",
            "AutoModel": "modeling_m2_encoder.M2EncoderModel",
            "AutoProcessor": "processing_m2_encoder.M2EncoderProcessor",
        }

        # Architecture shape.
        self.beit_version = beit_version
        self.encoder_embed_dim = encoder_embed_dim
        self.out_embed_dim = out_embed_dim
        self.encoder_layers = encoder_layers
        self.beit3_vl_layers = beit3_vl_layers

        # Vision input.
        self.image_size = image_size
        self.visual_mask_size = visual_mask_size

        # Text / tokenizer.
        self.tokenizer_type = tokenizer_type
        self.tokenizer = tokenizer
        self.vocab_size = vocab_size
        self.whole_word_masking = whole_word_masking

        # Runtime flags and checkpoint location.
        self.precision = precision
        self.test_only = test_only
        self.flash_attn = flash_attn
        self.model_file = model_file

    @classmethod
    def from_encoder_json(cls, config_path: str, **kwargs) -> "M2EncoderConfig":
        """Build a config from a JSON file, with **kwargs overriding file values."""
        with open(config_path, "r", encoding="utf-8") as handle:
            payload = json.load(handle)
        payload.update(kwargs)
        return cls(**payload)

    def to_vlmo_overrides(self, model_dir: str) -> Dict[str, Any]:
        """Return the override dict for the VLMo builder, with paths resolved
        against *model_dir* (tokenizer directory and checkpoint `load_path`)."""
        overrides = {
            name: getattr(self, name)
            for name in (
                "loss_names",
                "beit_version",
                "encoder_embed_dim",
                "out_embed_dim",
                "encoder_layers",
                "beit3_vl_layers",
                "image_size",
                "visual_mask_size",
                "tokenizer_type",
            )
        }
        overrides["tokenizer"] = self._resolve_tokenizer_dir(model_dir)
        for name in ("vocab_size", "whole_word_masking", "precision", "test_only", "flash_attn"):
            overrides[name] = getattr(self, name)
        overrides["load_path"] = os.path.join(model_dir, self.model_file)
        return overrides

    def _resolve_tokenizer_dir(self, model_dir: str) -> str:
        """Resolve `self.tokenizer` to a concrete directory.

        Absolute paths pass through unchanged; ".", "./" or "" mean
        *model_dir* itself; anything else is joined under *model_dir*.
        """
        tok = self.tokenizer
        if os.path.isabs(tok):
            return tok
        return model_dir if tok in (".", "./", "") else os.path.join(model_dir, tok)