File size: 3,197 Bytes
ea0524d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e33e190
ea0524d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
import json
import os
from typing import Any, Dict

from transformers import PretrainedConfig


class M2EncoderConfig(PretrainedConfig):
    """HuggingFace configuration for the M2-Encoder vision-language model.

    Holds the BEiT-3-style encoder hyperparameters plus tokenizer/checkpoint
    bookkeeping, and knows how to translate itself into the keyword overrides
    expected by the underlying VLMo-style model builder.
    """

    model_type = "m2_encoder"

    def __init__(
        self,
        loss_names=None,
        beit_version="large",
        encoder_embed_dim=1024,
        out_embed_dim=1024,
        encoder_layers=21,
        beit3_vl_layers=3,
        image_size=224,
        visual_mask_size=14,
        tokenizer_type="GLMChineseTokenizer",
        tokenizer=".",
        vocab_size=115244,
        whole_word_masking=False,
        precision=32,
        test_only=True,
        flash_attn=False,
        model_file="m2_encoder_1B.safetensors",
        architectures=None,
        auto_map=None,
        **kwargs,
    ):
        """Store hyperparameters; `None` (or falsy) container args fall back to defaults."""
        super().__init__(**kwargs)

        # Falsy sentinel -> default, so callers may pass None without sharing
        # a mutable default between instances.
        self.loss_names = {"itc": 1} if not loss_names else loss_names

        # Encoder architecture.
        self.beit_version = beit_version
        self.encoder_embed_dim = encoder_embed_dim
        self.out_embed_dim = out_embed_dim
        self.encoder_layers = encoder_layers
        self.beit3_vl_layers = beit3_vl_layers

        # Vision input geometry.
        self.image_size = image_size
        self.visual_mask_size = visual_mask_size

        # Text / tokenizer settings.
        self.tokenizer_type = tokenizer_type
        self.tokenizer = tokenizer
        self.vocab_size = vocab_size
        self.whole_word_masking = whole_word_masking

        # Runtime flags and checkpoint location.
        self.precision = precision
        self.test_only = test_only
        self.flash_attn = flash_attn
        self.model_file = model_file

        # HF auto-class wiring; falsy -> canonical defaults for this repo.
        self.architectures = ["M2EncoderModel"] if not architectures else architectures
        default_auto_map = {
            "AutoConfig": "configuration_m2_encoder.M2EncoderConfig",
            "AutoModel": "modeling_m2_encoder.M2EncoderModel",
            "AutoProcessor": "processing_m2_encoder.M2EncoderProcessor",
        }
        self.auto_map = default_auto_map if not auto_map else auto_map

    @classmethod
    def from_encoder_json(cls, config_path: str, **kwargs) -> "M2EncoderConfig":
        """Load a config from a raw encoder JSON file.

        Keyword arguments take precedence over values found in the file.
        """
        with open(config_path, "r", encoding="utf-8") as handle:
            payload = json.load(handle)
        payload.update(kwargs)
        return cls(**payload)

    def to_vlmo_overrides(self, model_dir: str) -> Dict[str, Any]:
        """Return the keyword overrides for the VLMo model builder.

        Tokenizer and checkpoint paths are resolved relative to `model_dir`.
        """
        # Attributes copied through verbatim, split so that the resolved
        # tokenizer path lands in the same key position as before.
        before_tokenizer = (
            "loss_names",
            "beit_version",
            "encoder_embed_dim",
            "out_embed_dim",
            "encoder_layers",
            "beit3_vl_layers",
            "image_size",
            "visual_mask_size",
            "tokenizer_type",
        )
        after_tokenizer = (
            "vocab_size",
            "whole_word_masking",
            "precision",
            "test_only",
            "flash_attn",
        )
        overrides: Dict[str, Any] = {key: getattr(self, key) for key in before_tokenizer}
        overrides["tokenizer"] = self._resolve_tokenizer_dir(model_dir)
        overrides.update({key: getattr(self, key) for key in after_tokenizer})
        overrides["load_path"] = os.path.join(model_dir, self.model_file)
        return overrides

    def _resolve_tokenizer_dir(self, model_dir: str) -> str:
        """Resolve `self.tokenizer` to a concrete directory.

        Absolute paths pass through; "."/"./"/"" mean `model_dir` itself;
        anything else is joined under `model_dir`.
        """
        tok = self.tokenizer
        if os.path.isabs(tok):
            return tok
        return model_dir if tok in ("", ".", "./") else os.path.join(model_dir, tok)