malusama committed on
Commit
ea0524d
·
verified ·
1 Parent(s): bb7942b

Upload safetensors export

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitattributes +3 -31
  2. README.md +91 -0
  3. config.json +37 -0
  4. configuration_m2_encoder.py +90 -0
  5. image_processing_m2_encoder.py +42 -0
  6. m2_encoder_1B.safetensors +3 -0
  7. modeling_m2_encoder.py +150 -0
  8. preprocessor_config.json +11 -0
  9. processing_m2_encoder.py +58 -0
  10. processor_config.json +6 -0
  11. requirements.txt +15 -0
  12. sp.model +3 -0
  13. tokenization_glm.py +307 -0
  14. tokenizer_config.json +17 -0
  15. upload_to_hub.py +31 -0
  16. vlmo/__init__.py +0 -0
  17. vlmo/__pycache__/__init__.cpython-311.pyc +0 -0
  18. vlmo/__pycache__/config.cpython-311.pyc +0 -0
  19. vlmo/config.py +165 -0
  20. vlmo/modules/__init__.py +1 -0
  21. vlmo/modules/__pycache__/__init__.cpython-311.pyc +0 -0
  22. vlmo/modules/__pycache__/heads.cpython-311.pyc +0 -0
  23. vlmo/modules/__pycache__/modeling_utils.cpython-311.pyc +0 -0
  24. vlmo/modules/__pycache__/objectives.cpython-311.pyc +0 -0
  25. vlmo/modules/__pycache__/vlmo_module.cpython-311.pyc +0 -0
  26. vlmo/modules/__pycache__/vlmo_utils.cpython-311.pyc +0 -0
  27. vlmo/modules/heads.py +24 -0
  28. vlmo/modules/modeling_utils.py +179 -0
  29. vlmo/modules/multiway_transformer.py +396 -0
  30. vlmo/modules/objectives.py +12 -0
  31. vlmo/modules/vlmo_module.py +405 -0
  32. vlmo/modules/vlmo_utils.py +12 -0
  33. vlmo/tokenizer/__init__.py +6 -0
  34. vlmo/tokenizer/__pycache__/__init__.cpython-311.pyc +0 -0
  35. vlmo/tokenizer/__pycache__/tokenization_glm.cpython-311.pyc +0 -0
  36. vlmo/tokenizer/sp.model +3 -0
  37. vlmo/tokenizer/tokenization_glm.py +307 -0
  38. vlmo/tokenizer/tokenizer_config.json +17 -0
  39. vlmo/torchscale/__init__.py +2 -0
  40. vlmo/torchscale/__pycache__/__init__.cpython-311.pyc +0 -0
  41. vlmo/torchscale/architecture/__init__.py +2 -0
  42. vlmo/torchscale/architecture/__pycache__/__init__.cpython-311.pyc +0 -0
  43. vlmo/torchscale/architecture/__pycache__/config.cpython-311.pyc +0 -0
  44. vlmo/torchscale/architecture/__pycache__/encoder.cpython-311.pyc +0 -0
  45. vlmo/torchscale/architecture/__pycache__/utils.cpython-311.pyc +0 -0
  46. vlmo/torchscale/architecture/config.py +197 -0
  47. vlmo/torchscale/architecture/decoder.py +428 -0
  48. vlmo/torchscale/architecture/encoder.py +482 -0
  49. vlmo/torchscale/architecture/encoder_decoder.py +43 -0
  50. vlmo/torchscale/architecture/utils.py +33 -0
.gitattributes CHANGED
@@ -1,35 +1,7 @@
1
- *.7z filter=lfs diff=lfs merge=lfs -text
2
- *.arrow filter=lfs diff=lfs merge=lfs -text
3
- *.bin filter=lfs diff=lfs merge=lfs -text
4
- *.bz2 filter=lfs diff=lfs merge=lfs -text
5
  *.ckpt filter=lfs diff=lfs merge=lfs -text
6
- *.ftz filter=lfs diff=lfs merge=lfs -text
7
- *.gz filter=lfs diff=lfs merge=lfs -text
8
- *.h5 filter=lfs diff=lfs merge=lfs -text
9
- *.joblib filter=lfs diff=lfs merge=lfs -text
10
- *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
- *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
- *.model filter=lfs diff=lfs merge=lfs -text
13
- *.msgpack filter=lfs diff=lfs merge=lfs -text
14
- *.npy filter=lfs diff=lfs merge=lfs -text
15
- *.npz filter=lfs diff=lfs merge=lfs -text
16
- *.onnx filter=lfs diff=lfs merge=lfs -text
17
- *.ot filter=lfs diff=lfs merge=lfs -text
18
- *.parquet filter=lfs diff=lfs merge=lfs -text
19
- *.pb filter=lfs diff=lfs merge=lfs -text
20
- *.pickle filter=lfs diff=lfs merge=lfs -text
21
- *.pkl filter=lfs diff=lfs merge=lfs -text
22
  *.pt filter=lfs diff=lfs merge=lfs -text
23
  *.pth filter=lfs diff=lfs merge=lfs -text
24
- *.rar filter=lfs diff=lfs merge=lfs -text
25
  *.safetensors filter=lfs diff=lfs merge=lfs -text
26
- saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
- *.tar.* filter=lfs diff=lfs merge=lfs -text
28
- *.tar filter=lfs diff=lfs merge=lfs -text
29
- *.tflite filter=lfs diff=lfs merge=lfs -text
30
- *.tgz filter=lfs diff=lfs merge=lfs -text
31
- *.wasm filter=lfs diff=lfs merge=lfs -text
32
- *.xz filter=lfs diff=lfs merge=lfs -text
33
- *.zip filter=lfs diff=lfs merge=lfs -text
34
- *.zst filter=lfs diff=lfs merge=lfs -text
35
- *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
 
 
1
  *.ckpt filter=lfs diff=lfs merge=lfs -text
2
+ *.bin filter=lfs diff=lfs merge=lfs -text
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3
  *.pt filter=lfs diff=lfs merge=lfs -text
4
  *.pth filter=lfs diff=lfs merge=lfs -text
 
5
  *.safetensors filter=lfs diff=lfs merge=lfs -text
6
+ sp.model filter=lfs diff=lfs merge=lfs -text
7
+ vlmo/tokenizer/sp.model filter=lfs diff=lfs merge=lfs -text
 
 
 
 
 
 
 
 
README.md ADDED
@@ -0,0 +1,91 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ license: apache-2.0
3
+ library_name: transformers
4
+ pipeline_tag: zero-shot-image-classification
5
+ tags:
6
+ - multimodal
7
+ - image-text-retrieval
8
+ - bilingual
9
+ - chinese
10
+ - english
11
+ - vision-language
12
+ - custom-code
13
+ ---
14
+
15
+ # M2-Encoder-1B Hugging Face Export
16
+
17
+ This folder is generated from `Ant-Multi-Modal-Framework/prj/M2_Encoder` and is structured for direct upload to Hugging Face Hub.
18
+
19
+ ## What This Repo Supports
20
+
21
+ - `AutoConfig.from_pretrained(..., trust_remote_code=True)`
22
+ - `AutoProcessor.from_pretrained(..., trust_remote_code=True)`
23
+ - `AutoModel.from_pretrained(..., trust_remote_code=True)`
24
+ - Zero-shot image-text retrieval and zero-shot image classification
25
+
26
+ ## Required Weight File
27
+
28
+ Put the model weight file in the repo root with this exact filename:
29
+
30
+ `m2_encoder_1B.safetensors`
31
+
32
+ Large files should be tracked by Git LFS. A `.gitattributes` file is included for that.
33
+
34
+ ## Usage
35
+
36
+ ### ModelScope-equivalent scoring
37
+
38
+ The original ModelScope sample computes probabilities from the raw normalized embeddings:
39
+
40
+ ```python
41
+ from transformers import AutoModel, AutoProcessor
42
+
43
+ repo_id = "your-name/your-m2-encoder-repo"
44
+
45
+ model = AutoModel.from_pretrained(repo_id, trust_remote_code=True)
46
+ processor = AutoProcessor.from_pretrained(repo_id, trust_remote_code=True)
47
+
48
+ text_inputs = processor(
49
+ text=["杰尼龟", "妙蛙种子", "小火龙", "皮卡丘"],
50
+ return_tensors="pt",
51
+ )
52
+ image_inputs = processor(images="pokemon.jpeg", return_tensors="pt")
53
+
54
+ text_outputs = model(**text_inputs)
55
+ image_outputs = model(**image_inputs)
56
+
57
+ probs = (image_outputs.image_embeds @ text_outputs.text_embeds.t()).softmax(dim=-1)
58
+ print(probs)
59
+ ```
60
+
61
+ ### CLIP-style logits
62
+
63
+ `model(**inputs)` also returns `logits_per_image` and `logits_per_text`, which use the model's learned `logit_scale`.
64
+ Those logits are useful, but they are not the same computation as the raw dot product in the original ModelScope demo.
65
+
66
+ ## Upload
67
+
68
+ Option 1:
69
+
70
+ ```bash
71
+ python upload_to_hub.py --repo-id your-name/your-m2-encoder-repo
72
+ ```
73
+
74
+ Option 2:
75
+
76
+ ```bash
77
+ huggingface-cli login
78
+ git init
79
+ git lfs install
80
+ git remote add origin https://huggingface.co/your-name/your-m2-encoder-repo
81
+ git add .
82
+ git commit -m "Upload M2-Encoder HF export"
83
+ git push origin main
84
+ ```
85
+
86
+ ## Notes
87
+
88
+ - This is a Hugging Face remote-code adapter, not a native `transformers` implementation.
89
+ - The underlying model code still comes from the official M2-Encoder repo.
90
+ - You need `trust_remote_code=True`.
91
+ - The weights are not bundled by default when exporting unless you pass `--checkpoint`.
config.json ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "beit_version": "large",
3
+ "encoder_embed_dim": 1024,
4
+ "out_embed_dim": 1024,
5
+ "image_size": 224,
6
+ "visual_mask_size": 14,
7
+ "loss_names": {
8
+ "itc": 1
9
+ },
10
+ "encoder_layers": 21,
11
+ "beit3_vl_layers": 3,
12
+ "tokenizer_type": "GLMChineseTokenizer",
13
+ "tokenizer": ".",
14
+ "vocab_size": 115244,
15
+ "whole_word_masking": false,
16
+ "precision": 32,
17
+ "test_only": true,
18
+ "flash_attn": false,
19
+ "modelscope": {
20
+ "model_id": "M2Cognition/M2_Encoder_Large"
21
+ },
22
+ "model_file": "m2_encoder_1B.safetensors",
23
+ "model_type": "m2_encoder",
24
+ "architectures": [
25
+ "M2EncoderModel"
26
+ ],
27
+ "processor_class": "M2EncoderProcessor",
28
+ "auto_map": {
29
+ "AutoConfig": "configuration_m2_encoder.M2EncoderConfig",
30
+ "AutoModel": "modeling_m2_encoder.M2EncoderModel",
31
+ "AutoProcessor": "processing_m2_encoder.M2EncoderProcessor",
32
+ "AutoTokenizer": [
33
+ "tokenization_glm.GLMChineseTokenizer",
34
+ null
35
+ ]
36
+ }
37
+ }
configuration_m2_encoder.py ADDED
@@ -0,0 +1,90 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import os
3
+ from typing import Any, Dict
4
+
5
+ from transformers import PretrainedConfig
6
+
7
+
8
class M2EncoderConfig(PretrainedConfig):
    """Hugging Face configuration for the M2-Encoder vision-language model.

    The fields mirror the knobs of the underlying VLMo training configuration
    (`vlmo.config`); `to_vlmo_overrides` converts this object back into the
    override dict the original model code expects.
    """

    model_type = "m2_encoder"

    def __init__(
        self,
        loss_names=None,
        beit_version="large",
        encoder_embed_dim=1024,
        out_embed_dim=1024,
        encoder_layers=21,
        beit3_vl_layers=3,
        image_size=224,
        visual_mask_size=14,
        tokenizer_type="GLMChineseTokenizer",
        tokenizer=".",
        vocab_size=115244,
        whole_word_masking=False,
        precision=32,
        test_only=True,
        flash_attn=False,
        model_file="m2_encoder_1B.ckpt",
        architectures=None,
        auto_map=None,
        **kwargs,
    ):
        super().__init__(**kwargs)
        # Mutable defaults are materialized per-instance here; `loss_names`
        # defaults to image-text contrastive ("itc") only.
        self.loss_names = loss_names or {"itc": 1}
        self.beit_version = beit_version
        self.encoder_embed_dim = encoder_embed_dim
        self.out_embed_dim = out_embed_dim
        self.encoder_layers = encoder_layers
        self.beit3_vl_layers = beit3_vl_layers
        self.image_size = image_size
        self.visual_mask_size = visual_mask_size
        self.tokenizer_type = tokenizer_type
        # Tokenizer location relative to the model directory ("." = repo root).
        self.tokenizer = tokenizer
        self.vocab_size = vocab_size
        self.whole_word_masking = whole_word_masking
        self.precision = precision
        self.test_only = test_only
        self.flash_attn = flash_attn
        # Checkpoint filename relative to the model directory.  NOTE(review):
        # the class default is a .ckpt name while the shipped config.json uses
        # "m2_encoder_1B.safetensors" — confirm the fallback is intentional.
        self.model_file = model_file
        self.architectures = architectures or ["M2EncoderModel"]
        # auto_map wires AutoConfig/AutoModel/AutoProcessor/AutoTokenizer to
        # the remote-code modules shipped alongside this config.
        self.auto_map = auto_map or {
            "AutoConfig": "configuration_m2_encoder.M2EncoderConfig",
            "AutoModel": "modeling_m2_encoder.M2EncoderModel",
            "AutoProcessor": "processing_m2_encoder.M2EncoderProcessor",
            "AutoTokenizer": ["tokenization_glm.GLMChineseTokenizer", None],
        }

    @classmethod
    def from_encoder_json(cls, config_path: str, **kwargs) -> "M2EncoderConfig":
        """Build a config from a raw JSON file; keyword overrides win over file values."""
        with open(config_path, "r", encoding="utf-8") as f:
            data = json.load(f)
        data.update(kwargs)
        return cls(**data)

    def to_vlmo_overrides(self, model_dir: str) -> Dict[str, Any]:
        """Return the override dict consumed by the original vlmo config.

        Args:
            model_dir: Local snapshot directory; anchors the tokenizer
                directory and the checkpoint path (`load_path`).
        """
        return {
            "loss_names": self.loss_names,
            "beit_version": self.beit_version,
            "encoder_embed_dim": self.encoder_embed_dim,
            "out_embed_dim": self.out_embed_dim,
            "encoder_layers": self.encoder_layers,
            "beit3_vl_layers": self.beit3_vl_layers,
            "image_size": self.image_size,
            "visual_mask_size": self.visual_mask_size,
            "tokenizer_type": self.tokenizer_type,
            "tokenizer": self._resolve_tokenizer_dir(model_dir),
            "vocab_size": self.vocab_size,
            "whole_word_masking": self.whole_word_masking,
            "precision": self.precision,
            "test_only": self.test_only,
            "flash_attn": self.flash_attn,
            "load_path": os.path.join(model_dir, self.model_file),
        }

    def _resolve_tokenizer_dir(self, model_dir: str) -> str:
        """Resolve the tokenizer path: absolute stays as-is, "."/""/"./" means model_dir."""
        if os.path.isabs(self.tokenizer):
            return self.tokenizer
        if self.tokenizer in (".", "./", ""):
            return model_dir
        return os.path.join(model_dir, self.tokenizer)
image_processing_m2_encoder.py ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List, Optional, Union
2
+
3
+ import numpy as np
4
+ import torch
5
+ from PIL import Image
6
+ from transformers.feature_extraction_utils import BatchFeature, FeatureExtractionMixin
7
+ from transformers.image_utils import ImageFeatureExtractionMixin
8
+
9
+
10
class M2EncoderImageProcessor(FeatureExtractionMixin, ImageFeatureExtractionMixin):
    """Minimal image preprocessor: RGB convert, square resize, scale to [0, 1].

    NOTE(review): no mean/std normalization is applied here — presumably the
    VLMo backbone expects raw 0-1 input or normalizes internally; verify
    against the original M2-Encoder preprocessing pipeline.
    """

    model_input_names = ["pixel_values"]

    def __init__(self, size: int = 224, resample: int = Image.BICUBIC, **kwargs):
        super().__init__(**kwargs)
        # preprocessor_config.json stores size as {"height": ..., "width": ...};
        # collapse it to a single int (square resize only).
        if isinstance(size, dict):
            size = int(size.get("height") or size.get("width"))
        self.size = size
        self.resample = resample

    def __call__(
        self,
        images,
        return_tensors: Optional[Union[str, torch.Tensor]] = None,
        **kwargs,
    ) -> BatchFeature:
        """Preprocess one image or a list of images into CHW float32 arrays.

        Args:
            images: A file path, PIL image, or numpy-convertible array — or a
                list/tuple of those.
            return_tensors: Forwarded to `BatchFeature` (e.g. "pt").

        Returns:
            BatchFeature with a "pixel_values" entry of shape (3, size, size)
            per image, values in [0, 1].
        """
        if not isinstance(images, (list, tuple)):
            images = [images]

        pixel_values: List[np.ndarray] = []
        for image in images:
            # Fix: the README example passes a file path string
            # (processor(images="pokemon.jpeg")); the original code fed the
            # string to np.asarray/Image.fromarray, which raises.
            if isinstance(image, (str, os.PathLike)):
                image = Image.open(image)
            if not isinstance(image, Image.Image):
                image = Image.fromarray(np.asarray(image))
            image = image.convert("RGB")
            # Fixed square resize — aspect ratio is not preserved.
            image = image.resize((self.size, self.size), resample=self.resample)
            array = np.asarray(image, dtype=np.float32) / 255.0
            array = np.transpose(array, (2, 0, 1))  # HWC -> CHW
            pixel_values.append(array)

        return BatchFeature(
            data={"pixel_values": pixel_values},
            tensor_type=return_tensors,
        )
m2_encoder_1B.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c8f7b220e3728a8211018c3fdb3b92c9a8eb9ffcbcf690057b258b819987b1bb
3
+ size 2921785216
modeling_m2_encoder.py ADDED
@@ -0,0 +1,150 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import sys
3
+ import importlib
4
+ from dataclasses import dataclass
5
+ from typing import Optional, Tuple, Union
6
+
7
+ import torch
8
+ from huggingface_hub import snapshot_download
9
+ from safetensors.torch import load_file
10
+ from transformers import PreTrainedModel
11
+ from transformers.modeling_outputs import ModelOutput
12
+
13
+ from .configuration_m2_encoder import M2EncoderConfig
14
+
15
+
16
@dataclass
class M2EncoderOutput(ModelOutput):
    """Output container for `M2EncoderModel.forward`.

    Fields stay `None` when the corresponding modality (text/image) was not
    provided in the forward call.
    """

    # Never populated by the current forward(); kept for ModelOutput API parity.
    loss: Optional[torch.FloatTensor] = None
    # CLS-level vlffn features from the text branch (presumably normalized in vlmo — confirm).
    text_embeds: Optional[torch.FloatTensor] = None
    # CLS-level vlffn features from the image branch.
    image_embeds: Optional[torch.FloatTensor] = None
    # logit_scale-weighted similarity matrices; set only when both modalities are given.
    logits_per_image: Optional[torch.FloatTensor] = None
    logits_per_text: Optional[torch.FloatTensor] = None
23
+
24
+
25
class M2EncoderModel(PreTrainedModel):
    """CLIP-style remote-code wrapper around the original VLMo-based M2-Encoder.

    The underlying `vlmo` package ships inside the model repository: at
    construction time the snapshot directory is added to `sys.path` and the
    model code is imported dynamically.
    """

    config_class = M2EncoderConfig
    base_model_prefix = "m2_encoder"
    main_input_name = "pixel_values"

    def __init__(self, config: M2EncoderConfig):
        super().__init__(config)
        # `_model_dir` is injected by the from_pretrained override below; it is
        # the local snapshot directory holding the vlmo package and checkpoint.
        model_dir = getattr(config, "_model_dir", None)
        if model_dir is None:
            raise ValueError(
                "M2EncoderConfig is missing `_model_dir`. Use "
                "`M2EncoderModel.from_pretrained(...)` so the checkpoint path can be resolved."
            )
        if model_dir not in sys.path:
            sys.path.insert(0, model_dir)

        vlmo_default_config = importlib.import_module("vlmo.config").config
        VLMo = importlib.import_module("vlmo.modules").VLMo

        vlmo_config = vlmo_default_config()
        vlmo_config.update(config.to_vlmo_overrides(model_dir))
        load_path = vlmo_config["load_path"]
        use_safetensors = load_path.endswith(".safetensors")
        if use_safetensors:
            # VLMo's own loader only understands .ckpt files; blank the path
            # and load the safetensors state dict ourselves below.
            vlmo_config["load_path"] = ""

        if vlmo_config["flash_attn"]:
            patch_torch_scale_with_flash_attn = importlib.import_module(
                "vlmo.utils.patch_utils"
            ).patch_torch_scale_with_flash_attn
            patch_torch_scale_with_flash_attn()

        self.model = VLMo(vlmo_config)
        if use_safetensors:
            state_dict = load_file(load_path)
            # strict=False tolerates keys absent from the export.
            # NOTE(review): consider logging missing/unexpected keys.
            self.model.load_state_dict(state_dict, strict=False)

    @classmethod
    def from_pretrained(
        cls,
        pretrained_model_name_or_path,
        *model_args,
        config: Optional[M2EncoderConfig] = None,
        **kwargs,
    ):
        """Resolve a local dir or Hub repo, verify the checkpoint, and build the model.

        Accepts an optional `m2_checkpoint_name` kwarg to override
        `config.model_file`.
        """
        model_dir = pretrained_model_name_or_path
        if not os.path.isdir(model_dir):
            model_dir = snapshot_download(repo_id=pretrained_model_name_or_path)

        # Fix: pop our private kwarg *before* building the config.  The
        # original popped it afterwards, so `m2_checkpoint_name` leaked into
        # `M2EncoderConfig.from_pretrained(**kwargs)`.
        checkpoint_name = kwargs.pop("m2_checkpoint_name", None)
        if config is None:
            config = M2EncoderConfig.from_pretrained(model_dir, **kwargs)
        checkpoint_path = os.path.join(
            model_dir,
            checkpoint_name if checkpoint_name is not None else config.model_file,
        )
        if not os.path.exists(checkpoint_path):
            raise FileNotFoundError(
                f"Missing M2-Encoder checkpoint: {checkpoint_path}"
            )
        config._model_dir = model_dir
        return cls(config, *model_args)

    def get_text_features(
        self,
        input_ids: torch.LongTensor,
        attention_mask: torch.LongTensor,
    ) -> torch.FloatTensor:
        """Return text CLS vlffn features for a batch of tokenized text."""
        outputs = self.model.infer_text(
            {
                "text_ids": input_ids,
                "text_masks": attention_mask,
                "text_labels": None,
            }
        )
        return outputs["cls_vlffn_feats"]

    def get_image_features(self, pixel_values: torch.FloatTensor) -> torch.FloatTensor:
        """Return image CLS vlffn features for a batch of preprocessed images."""
        # vlmo's infer_image expects the image batch wrapped in a list.
        outputs = self.model.infer_image({"image": [pixel_values]})
        return outputs["cls_vlffn_feats"]

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.LongTensor] = None,
        pixel_values: Optional[torch.FloatTensor] = None,
        return_dict: Optional[bool] = True,
        **kwargs,
    ) -> Union[M2EncoderOutput, Tuple[torch.FloatTensor, ...]]:
        """Encode text and/or images; compute CLIP-style logits when both are given.

        Args:
            input_ids: Tokenized text (optional).
            attention_mask: Text mask; defaults to all-ones when omitted.
            pixel_values: Preprocessed images (optional).
            return_dict: When False, return a tuple of the non-None values.
        """
        text_embeds = None
        image_embeds = None

        if input_ids is not None:
            if attention_mask is None:
                attention_mask = torch.ones_like(input_ids)
            text_embeds = self.get_text_features(
                input_ids=input_ids, attention_mask=attention_mask
            )

        if pixel_values is not None:
            image_embeds = self.get_image_features(pixel_values=pixel_values)

        logits_per_image = None
        logits_per_text = None
        if image_embeds is not None and text_embeds is not None:
            # Learned temperature, CLIP convention (stored as log-scale).
            logit_scale = self.model.logit_scale.exp()
            logits_per_image = logit_scale * image_embeds @ text_embeds.t()
            logits_per_text = logits_per_image.t()

        if not return_dict:
            return tuple(
                value
                for value in (
                    text_embeds,
                    image_embeds,
                    logits_per_image,
                    logits_per_text,
                )
                if value is not None
            )

        return M2EncoderOutput(
            text_embeds=text_embeds,
            image_embeds=image_embeds,
            logits_per_image=logits_per_image,
            logits_per_text=logits_per_text,
        )
preprocessor_config.json ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "processor_class": "M2EncoderProcessor",
3
+ "image_processor_type": "M2EncoderImageProcessor",
4
+ "auto_map": {
5
+ "AutoProcessor": "processing_m2_encoder.M2EncoderProcessor"
6
+ },
7
+ "size": {
8
+ "height": 224,
9
+ "width": 224
10
+ }
11
+ }
processing_m2_encoder.py ADDED
@@ -0,0 +1,58 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Optional
2
+
3
+ from transformers import AutoTokenizer
4
+ from transformers.processing_utils import ProcessorMixin
5
+
6
+ from .image_processing_m2_encoder import M2EncoderImageProcessor
7
+
8
+
9
class M2EncoderProcessor(ProcessorMixin):
    """Bundles the GLM Chinese tokenizer and the M2 image processor into one callable."""

    attributes = ["image_processor", "tokenizer"]
    image_processor_class = "M2EncoderImageProcessor"
    tokenizer_class = ("GLMChineseTokenizer", None)

    def __init__(self, image_processor, tokenizer):
        # NOTE(review): deliberately skips ProcessorMixin.__init__, so the
        # mixin's attribute validation / save hooks are bypassed — confirm
        # save_pretrained still behaves as intended.
        self.image_processor = image_processor
        self.tokenizer = tokenizer

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
        """Load the image processor and tokenizer from the same repo/directory.

        `trust_remote_code` defaults to True because the tokenizer class lives
        in this repo's remote code.
        """
        trust_remote_code = kwargs.pop("trust_remote_code", True)
        image_processor = M2EncoderImageProcessor.from_pretrained(
            pretrained_model_name_or_path, **kwargs
        )
        tokenizer = AutoTokenizer.from_pretrained(
            pretrained_model_name_or_path,
            trust_remote_code=trust_remote_code,
            **kwargs,
        )
        return cls(image_processor=image_processor, tokenizer=tokenizer)

    def __call__(
        self,
        text=None,
        images=None,
        padding="max_length",
        truncation=True,
        max_length: Optional[int] = 52,
        return_tensors=None,
        **kwargs,
    ):
        """Tokenize text and/or preprocess images into one encoding dict.

        Defaults (pad-to-max, max_length=52) presumably mirror the original
        M2-Encoder text pipeline — confirm against the upstream repo.
        """
        encoding = {}
        if text is not None:
            encoding.update(
                self.tokenizer(
                    text,
                    padding=padding,
                    truncation=truncation,
                    max_length=max_length,
                    return_special_tokens_mask=True,
                    return_tensors=return_tensors,
                    **kwargs,
                )
            )
        if images is not None:
            encoding.update(
                self.image_processor(images, return_tensors=return_tensors, **kwargs)
            )
        return encoding
processor_config.json ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ {
2
+ "processor_class": "M2EncoderProcessor",
3
+ "auto_map": {
4
+ "AutoProcessor": "processing_m2_encoder.M2EncoderProcessor"
5
+ }
6
+ }
requirements.txt ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ torch
2
+ pytorch_lightning<=2.0.8
3
+ transformers
4
+ safetensors
5
+ Pillow
6
+ tqdm
7
+ einops
8
+ sacred
9
+ timm
10
+ torchvision
11
+ fairscale
12
+ numpy
13
+ opencv-python
14
+ sentencepiece
15
+ huggingface_hub
sp.model ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b7fe3bcc8d284fcb782691411e8b6fd4f45d7245565b094de6ab795e66bcd32f
3
+ size 2270960
tokenization_glm.py ADDED
@@ -0,0 +1,307 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from shutil import copyfile
3
+ from typing import Optional, Tuple, List, Union
4
+
5
+ import sentencepiece as spm
6
+ import torch
7
+ from transformers import PreTrainedTokenizer
8
+ from transformers.models.auto.tokenization_auto import get_tokenizer_config
9
+ from transformers.tokenization_utils_base import BatchEncoding
10
+ from transformers.utils import logging
11
+
12
+ logger = logging.get_logger(__name__)
13
+
14
+
15
class GLMBatchEncoding(BatchEncoding):
    def to(self, device: Union[str, "torch.device"]) -> "BatchEncoding":
        """
        Send all values to device by calling `v.to(device)` (PyTorch only).

        Args:
            device (`str` or `torch.device`): The device to put the tensors on.

        Returns:
            [`BatchEncoding`]: The same instance after modification.
        """

        # This check catches things like APEX blindly calling "to" on all inputs to a module
        # Otherwise it passes the casts down and casts the LongTensor containing the token idxs
        # into a HalfTensor
        # Fix: the original check accepted only str/int, so passing an actual
        # `torch.device` object (the common case) logged a warning and moved
        # nothing.  Restore the intended behavior (the upstream version used
        # `_is_torch_device(device)` here).
        if isinstance(device, (str, int, torch.device)):
            self.data = {k: v.to(device=device) if torch.is_tensor(v) else v for k, v in self.data.items()}
        else:
            logger.warning(f"Attempting to cast a BatchEncoding to type {str(device)}. This is not supported.")
        return self
36
+
37
+
38
class GLMTokenizerMixin:
    """GLM-specific helpers: special-token ids plus input builders for
    autoregressive blank filling (multiple-choice scoring and generation)."""

    @property
    def sop_token(self) -> Optional[str]:
        return "<|startofpiece|>"

    @property
    def sop_token_id(self) -> Optional[int]:
        """
        `Optional[int]`: Id of the start token in the vocabulary, used when training a model with autoregressive blank filling.
        """
        return self.convert_tokens_to_ids(self.sop_token)

    @property
    def eop_token(self) -> Optional[str]:
        return "<|endofpiece|>"

    @property
    def eop_token_id(self) -> Optional[int]:
        """
        `Optional[int]`: Id of the end token in the vocabulary, used when training a model with autoregressive blank filling.
        """
        return self.convert_tokens_to_ids(self.eop_token)

    @property
    def gmask_token_id(self) -> int:
        return self.convert_tokens_to_ids("[gMASK]")

    @property
    def smask_token_id(self) -> int:
        return self.convert_tokens_to_ids("[sMASK]")

    @property
    def mask_token_ids(self):
        # All three GLM mask variants: [MASK], [sMASK], [gMASK].
        return [self.mask_token_id, self.smask_token_id, self.gmask_token_id]

    def _build_input_for_multiple_choice(self, context, choices):
        """Build one sample: context followed by each choice as a blank-filling
        continuation anchored at the first [MASK] position.

        Returns input ids, 2-row position ids (position + block position), a
        block-diagonal attention mask, and per-choice id/index tensors used to
        read out choice log-probs.
        """
        context_id = context["input_ids"]
        if torch.is_tensor(context_id):
            context_id = context_id.tolist()

        division = len(context_id)
        mask_position = context_id.index(self.mask_token_id)

        token = torch.tensor(context_id, dtype=torch.long)
        attention_mask = [context["attention_mask"].expand(division, -1)]
        position_id = torch.arange(division, dtype=torch.long)
        block_position_id = torch.zeros(division, dtype=torch.long)

        choice_ids, choice_indices = [], []

        for choice_str in choices:
            choice = torch.tensor(self(choice_str, add_special_tokens=False, padding=False)['input_ids'],
                                  dtype=torch.long)
            choice_ids.append(choice)
            choice_indices.append(torch.arange(len(token), len(token) + len(choice), dtype=torch.long))
            # Each choice attends causally within itself (lower-triangular block).
            attention_mask.append(torch.tril(torch.ones((len(choice), len(choice)), dtype=torch.long)))

            # Choice tokens are teacher-forced: <sop> + choice[:-1].
            token = torch.cat((token, torch.tensor([self.sop_token_id], dtype=torch.long), choice[:-1]))
            # All choice tokens share the mask's position; block positions count within the blank.
            position_id = torch.cat((position_id, torch.tensor([mask_position] * len(choice), dtype=torch.long)))
            block_position_id = torch.cat((block_position_id, torch.arange(1, 1 + len(choice), dtype=torch.long)))

        attention_mask = torch.block_diag(*attention_mask)
        # Every choice token can also attend to the (unmasked part of the) context.
        attention_mask[division:, :division] = context["attention_mask"].unsqueeze(0)

        return {
            "input_ids": token,
            "position_ids": torch.stack((position_id, block_position_id)),
            "attention_mask": attention_mask,
            "choice_ids": choice_ids,
            "choice_indices": choice_indices
        }

    def _pad_batch(self, tokens, position_ids, attention_mask, max_seq_length):
        """Right-pad one sample to `max_seq_length` (ids with 0, mask with 0,
        position ids by repeating the last column)."""
        pad_length = max_seq_length - len(tokens)
        attention_mask = torch.nn.functional.pad(
            attention_mask,
            (0, pad_length, 0, pad_length),
            mode="constant",
            value=0,
        )
        tokens = torch.cat((tokens, torch.zeros(pad_length, dtype=torch.long)))
        position_ids = torch.cat((position_ids, position_ids[..., -1:].expand(-1, pad_length)), dim=-1)
        return tokens, position_ids, attention_mask

    def _collate(self, samples):
        """Pad a list of multiple-choice samples to a common length and stack
        them into batched tensors (choice ids/indices stay Python lists)."""
        TILE = 1  # round padded length up to a multiple of TILE (currently no-op)
        length_to_pad = (max(map(lambda spl: len(spl["input_ids"]), samples)) + TILE - 1) // TILE * TILE

        token_batch, position_id_batch, attention_mask_batch = [], [], []
        choices_batch, choice_target_ids_batch = [], []

        for sample in samples:
            token, position_id, attention_mask = self._pad_batch(
                sample["input_ids"], sample["position_ids"], sample["attention_mask"], length_to_pad
            )
            token_batch.append(token)
            position_id_batch.append(position_id)
            attention_mask_batch.append(attention_mask)
            choices_batch.append(sample["choice_ids"])
            choice_target_ids_batch.append(sample["choice_indices"])
        return {
            "input_ids": torch.stack(token_batch),
            "position_ids": torch.stack(position_id_batch),
            # unsqueeze(1) adds the head/broadcast dimension the model expects.
            "attention_mask": torch.stack(attention_mask_batch).unsqueeze(1),
            "choice_ids": choices_batch,
            "choice_indices": choice_target_ids_batch,
        }

    def build_inputs_for_multiple_choice(self, model_input: BatchEncoding, choices, max_length=None):
        """Expand a tokenized batch into per-sample choice continuations and
        collate them.  `max_length` is currently unused."""
        samples = [{key: value[i] for key, value in model_input.items()} for i in range(len(model_input["input_ids"]))]
        samples = [self._build_input_for_multiple_choice(sample, choice) for sample, choice in
                   zip(samples, choices)]
        inputs = self._collate(samples)
        return GLMBatchEncoding(inputs)

    def build_inputs_for_generation(self, model_input: BatchEncoding, max_gen_length=512, targets=None, padding=False):
        """Prepare blank-filling generation inputs from a tokenized batch.

        Anchors generation at the first mask token per sample.  When `targets`
        is given, also builds teacher-forced targets and -100-masked labels;
        otherwise only appends <sop> and a generation attention mask.
        Raises ValueError if a sample contains no mask token.
        """
        mask_ids = self.mask_token_ids
        input_ids = model_input.input_ids
        batch_size, seq_length = input_ids.shape[:2]
        position_id, block_position_id = list(range(seq_length)), [0 for _ in range(seq_length)]
        position_ids, block_position_ids = [], []
        labels = None
        if targets is not None:
            is_batched = isinstance(targets, (list, tuple))
            targets = self(targets, add_special_tokens=False, padding=False).input_ids
            if not is_batched:
                targets = [targets]
            assert len(targets) == len(input_ids)
            # Truncate each target to max_gen_length after appending <eop>.
            targets = [(target + [self.eop_token_id])[:max_gen_length] for target in targets]
            if not padding:
                # Shrink generation length to the longest actual target.
                max_gen_length = max(map(len, targets))
            targets = [[self.sop_token_id] + target for target in targets]
            labels = [target[1:] for target in targets]
            targets = [target + [self.pad_token_id] * (max_gen_length + 1 - len(target)) for target in targets]
            # -100 = ignore index for the loss.
            labels = [label + [-100] * (max_gen_length - len(label)) for label in labels]
            targets = torch.tensor(targets, dtype=input_ids.dtype, device=input_ids.device)
            labels = torch.tensor(labels, dtype=input_ids.dtype, device=input_ids.device)
            labels = torch.cat((input_ids.new_full((batch_size, seq_length), -100), labels), dim=1)
        for i in range(batch_size):
            mask_positions = []
            for mask_id in mask_ids:
                mask_positions += (input_ids[i] == mask_id).nonzero(as_tuple=True)[0].tolist()
            if not mask_positions:
                raise ValueError("Cannot find mask token in the input")
            mask_positions.sort()
            # Generate at the earliest mask occurrence.
            mask_pos = mask_positions[0]
            position_ids.append(position_id + [mask_pos] * max_gen_length)
            block_position_ids.append(block_position_id + list(range(1, max_gen_length + 1)))
        position_ids = torch.tensor(position_ids, dtype=input_ids.dtype, device=input_ids.device)
        block_position_ids = torch.tensor(block_position_ids, dtype=input_ids.dtype, device=input_ids.device)
        position_ids = torch.stack((position_ids, block_position_ids), dim=1)
        attention_mask = model_input.attention_mask
        attention_mask = attention_mask.unsqueeze(1).expand(-1, seq_length + max_gen_length, -1)
        # Generated tokens: no attention from context rows, causal among themselves.
        generation_attention_mask = torch.cat([attention_mask.new_zeros((seq_length, max_gen_length)),
                                               torch.tril(attention_mask.new_ones((max_gen_length, max_gen_length)))],
                                              dim=0).unsqueeze(0).expand(batch_size, -1, -1)
        attention_mask = torch.cat((attention_mask, generation_attention_mask), dim=2)
        attention_mask = attention_mask.unsqueeze(1)
        if targets is None:
            input_ids = torch.cat((input_ids, input_ids.new_full((batch_size, 1), self.sop_token_id)), dim=-1)
        else:
            input_ids = torch.cat((input_ids, targets[:, :-1]), dim=1)
        batch = {"input_ids": input_ids, "position_ids": position_ids}
        if labels is None:
            batch["generation_attention_mask"] = attention_mask
        else:
            batch["attention_mask"] = attention_mask
            batch["labels"] = labels
        return BatchEncoding(batch)
207
+
208
def encode_whitespaces(content):
    """Fold runs of 2-10 consecutive spaces into GLM `<|blank_k|>` placeholder tokens.

    Longer widths are substituted first, so a 12-space run becomes
    `<|blank_10|><|blank_2|>`. Single spaces are left untouched.
    """
    for width in range(10, 1, -1):
        content = content.replace(" " * width, f"<|blank_{width}|>")
    return content
212
+
213
def decode_whitespaces(content):
    """Expand GLM `<|blank_k|>` placeholder tokens back into runs of k spaces.

    Inverse of `encode_whitespaces` for widths 2-10.
    """
    for width in range(10, 1, -1):
        content = content.replace(f"<|blank_{width}|>", " " * width)
    return content
217
+
218
+
219
class GLMChineseTokenizer(PreTrainedTokenizer, GLMTokenizerMixin):
    """SentencePiece-backed Chinese GLM tokenizer (vocab file `sp.model`)."""

    vocab_files_names = {"vocab_file": "sp.model"}
    # Truncate from the left so the end of long inputs is preserved.
    truncation_side: str = "left"

    def __init__(self, vocab_file, **kwargs):
        self.vocab_file = vocab_file
        self.sp_model = spm.SentencePieceProcessor()
        self.sp_model.Load(vocab_file)
        # Base __init__ runs last: it may call vocab_size/_tokenize, which
        # require sp_model to already be loaded.
        super().__init__(**kwargs)

    @property
    def vocab_size(self):
        return len(self.sp_model)

    def get_vocab(self):
        """Return the token -> id mapping, including user-added tokens."""
        vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)}
        vocab.update(self.added_tokens_encoder)
        return vocab

    def _tokenize(self, text, **kwargs):
        # Whitespace runs are folded into <|blank_k|> placeholders first so
        # SentencePiece treats them as single tokens.
        text = encode_whitespaces(text)
        return self.sp_model.EncodeAsPieces(text)
        #return self.sp_model.EncodeAsPieces(text, out_type=str)

    def _convert_token_to_id(self, token):
        """Converts a token (str) in an id using the vocab."""
        return self.sp_model.PieceToId(token)

    def _convert_id_to_token(self, index):
        """Converts an index (integer) in a token (str) using the vocab."""
        return self.sp_model.IdToPiece(index)

    def convert_tokens_to_string(self, tokens):
        # NOTE(review): DecodeIds is called with piece *strings*, not ids —
        # recent sentencepiece overloads Decode to accept pieces, but confirm;
        # DecodePieces would be the explicit call.
        res = self.sp_model.DecodeIds(tokens)
        return decode_whitespaces(res)

    def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
        """Copy (or re-serialize) sp.model into `save_directory`.

        Returns a 1-tuple with the written path; logs an error and returns
        None when `save_directory` is not a directory.
        """
        if not os.path.isdir(save_directory):
            logger.error(f"Vocabulary path ({save_directory}) should be a directory")
            return
        out_vocab_file = os.path.join(
            save_directory, (filename_prefix + "-" if filename_prefix else "") + self.vocab_files_names["vocab_file"]
        )

        if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file) and os.path.isfile(self.vocab_file):
            copyfile(self.vocab_file, out_vocab_file)
        elif not os.path.isfile(self.vocab_file):
            # Source file gone (e.g. loaded from memory): write the serialized proto instead.
            with open(out_vocab_file, "wb") as fi:
                content_spiece_model = self.sp_model.serialized_model_proto()
                fi.write(content_spiece_model)

        return (out_vocab_file,)

    def build_inputs_with_special_tokens(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
    ) -> List[int]:
        """
        Build model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating and
        adding special tokens. A BERT sequence has the following format:

        - single sequence: ``[CLS] X [SEP]``
        - pair of sequences: ``[CLS] A [SEP] B [SEP]``

        Args:
            token_ids_0 (:obj:`List[int]`):
                List of IDs to which the special tokens will be added.
            token_ids_1 (:obj:`List[int]`, `optional`):
                Optional second list of IDs for sequence pairs.

        Returns:
            :obj:`List[int]`: List of `input IDs <../glossary.html#input-ids>`__ with the appropriate special tokens.
        """
        # Sequence pairs are not supported by this tokenizer.
        assert token_ids_1 is None
        cls = [self.cls_token_id]
        eos = [self.eos_token_id]
        return cls + token_ids_0 + eos
295
+
296
+
297
class GLMTokenizer:
    """Factory that dispatches to the concrete GLM tokenizer implementation
    based on the repo's tokenizer_config.json `tokenizer_class` entry."""

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *inputs, **kwargs):
        tokenizer_config = get_tokenizer_config(pretrained_model_name_or_path, **kwargs)
        config_tokenizer_class = tokenizer_config.get("tokenizer_class")

        # Only the Chinese SentencePiece variant is supported by this export.
        if config_tokenizer_class != "GLMChineseTokenizer":
            raise NotImplementedError("Not implemented tokenizer type:", config_tokenizer_class)
        return GLMChineseTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
tokenizer_config.json ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "name_or_path": "THUDM/glm-10b-chinese",
3
+ "eos_token": "<|endoftext|>",
4
+ "pad_token": "<|endoftext|>",
5
+ "cls_token": "[CLS]",
6
+ "mask_token": "[MASK]",
7
+ "unk_token": "[UNK]",
8
+ "add_prefix_space": false,
9
+ "tokenizer_class": "GLMChineseTokenizer",
10
+ "use_fast": false,
11
+ "auto_map": {
12
+ "AutoTokenizer": [
13
+ "tokenization_glm.GLMChineseTokenizer",
14
+ null
15
+ ]
16
+ }
17
+ }
upload_to_hub.py ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ from pathlib import Path
3
+
4
+ from huggingface_hub import HfApi, create_repo, upload_folder
5
+
6
+
7
def main():
    """Upload this export folder to a Hugging Face Hub repository.

    Parses CLI arguments, creates the target repo if it does not already
    exist, then pushes the folder contents as a single commit.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--repo-id", required=True, help="Hugging Face repo id, e.g. user/M2-Encoder-Large")
    parser.add_argument(
        "--folder",
        default=str(Path(__file__).resolve().parent),
        help="Folder to upload. Defaults to this script's directory.",
    )
    parser.add_argument("--private", action="store_true", help="Create the repo as private.")
    parser.add_argument("--commit-message", default="Upload M2-Encoder HF export")
    args = parser.parse_args()

    folder = Path(args.folder).resolve()
    # The original created an unused `HfApi()` instance here; the module-level
    # helpers below authenticate via the cached token on their own.
    create_repo(repo_id=args.repo_id, private=args.private, exist_ok=True)
    upload_folder(
        repo_id=args.repo_id,
        folder_path=str(folder),
        commit_message=args.commit_message,
    )
    print(f"Uploaded {folder} -> {args.repo_id}")


if __name__ == "__main__":
    main()
vlmo/__init__.py ADDED
File without changes
vlmo/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (193 Bytes). View file
 
vlmo/__pycache__/config.cpython-311.pyc ADDED
Binary file (3.82 kB). View file
 
vlmo/config.py ADDED
@@ -0,0 +1,165 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from sacred import Experiment
2
+
3
+ ex = Experiment("VLMo")
4
+
5
+
6
+ def _loss_names(d):
7
+ ret = {
8
+ "itm": 0, # image-text matching loss
9
+ "itc": 0, # image-text contrastive loss
10
+ "caption": 0, # image captioning loss
11
+ "mvlm": 0, # masked language modeling loss
12
+ "textmlm": 0, # text-only masked language modeling
13
+ "imagemlm": 0, # image-only masked language modeling
14
+ "vqa": 0,
15
+ "nlvr2": 0,
16
+ "irtr": 0, # retrieval task ft
17
+ }
18
+ ret.update(d)
19
+ return ret
20
+
21
+
22
@ex.config
def config():
    """Default sacred configuration for VLMo / M2-Encoder training.

    Every local variable defined here is captured by sacred as a config
    entry and may be overridden from the command line
    (``with key=value``). Do not rename variables: downstream code looks
    them up by these exact names.
    """
    exp_name = "vlmo"
    seed = 1
    datasets = ["coco", "vg", "sbu", "gcc"]  # dataset name, the definition can refer to: vlmo/datamodules/__init__.py # noqa
    loss_names = _loss_names({"itm": 0, "itc": 0, "mvlm": 0})  # training loss
    batch_size = 1024  # this is a desired batch size; pl trainer will accumulate gradients.

    # BEiT-v3 setting
    encoder_layers = 12  # the layer number of backbone
    encoder_embed_dim = 768  # the hidden size of tokenizer
    out_embed_dim = 768  # the hidden size of output embedding
    beit_version = "base"  # model size: base(0.4B)|large(1B)|huge(10B)
    beit3_vl_layers = 3  # the layer number of vl_backbone
    deepnorm_init = True  # init method
    share_layer = False  # if share the weight between layer within backbone
    share_attn = False  # if share the attention weight of different layer
    one_attn = False  # if share the attention weight of vision and language

    # Image setting
    train_transform_keys = ["square_transform_randaug"]  # train transform: refer to vlmo/transforms/__init__.py
    val_transform_keys = ["square_transform"]  # test transform: refer to vlmo/transforms/__init__.py
    image_size = 224  # image size
    reclip_image_size = None  # reclip image size
    patch_size = 16  # patch size
    draw_false_image = 0  # if get negative image
    image_only = False  # only input image
    text_only = False  # only input text

    # Video setting, video_num_frm is not None means video input
    video_num_frm = None

    # Visual tokenizer setting based on beit2
    tokenizer_model = "beit2_visual_tokenizer"
    codebook_size = 8192
    codebook_dim = 32
    visual_mask_size = 14
    visual_mask_num = 80

    # Text Setting
    lang = 'cn'  # language for zero-shot imagenet testing: cn|en
    vqav2_label_size = 3129
    max_text_len = 52  # the number of characters
    max_text_len_of_initckpt = 196
    tokenizer_type = "BertTokenizer"  # Chinese text
    vocab_size = 21128
    tokenizer = "./vocab.txt"
    whole_word_masking = True
    mlm_prob = 0.15  # language mask ratio
    draw_false_text = 0
    mvlm_prob = 0.50  # vision-language mlm task
    mask_ratio = 0  # flip: mask ratio for image

    # cap setting
    cap_onlytext = False  # default caption image to text

    # imagemlm setting
    split_data_for_imagemlm = False  # if True, split a batch data to two parts, and the first part for imagemlm.

    # itc setting
    itc_mask = False  # itc use masked token
    aggregate_nodes = -1  # aggregate nodes num for compute_itc, default -1 is for all nodes

    # Transformer Setting
    model_arch = "vlmo_base_patch16"
    drop_path_rate = 0.1

    # Downstream Setting
    get_recall_metric = False
    get_recall_rerank_metric = False
    get_zeroshot_metric = False
    get_muge_feat = False
    get_f30k_feat = False
    k_test = 32

    # PL Trainer Setting
    resume_from = None
    fast_dev_run = False
    val_check_interval = 1.0
    test_only = False
    use_sharded_training = False
    resume_during_training = False
    save_top_k = 10
    every_n_train_steps = 2000  # the step to save checkpoint
    log_metric_steps = 100  # the step to log metric

    # below params varies with the environment
    use_pcache = False  # data storage method: pcache or nas
    pcache_root = ""
    # main_site: pcache://multimodalproxyi-pool.cz50c.alipay.com:39999/mnt/
    # public_cloud: pcache://pcache_public_cloud.pcache.local:39999/mnt/abc7c88079a60b45ddfce7afa40720b7/
    gpu_env = "main_site"  # public_cloud or main_site
    data_root = ""  # data root for data list


    log_dir = "result"
    per_gpu_batchsize = 4  # you should define this manually with per_gpu_batch_size=#
    num_gpus = 1
    num_nodes = 1
    load_path = ""
    num_workers = 8
    precision = 16
    local_run = True
    flash_attn = False
    deepspeed_config = None  # "ds_config.json"
    coalesce_backbone = False
    mask_data = "v+l"  # 'v+l': choose input of imagemlm+textmlm task, 'vl': choose input of mvlm task.
    communication_benchmark = False
    checkpoint_activations = False

    # dataset setting
    single_cap = True  # if have only one caption
    random_one = False  # if choose one caption from caption list

    # ITC setting
    itc_feats_name = "cls_vlffn_feats"  # feat for itc loss
    itc_distill = ""
    itc_distill_dim = 1024
    itc_teacher_weights = ""

    # mup training setting
    mup = False
    base_encoder_embed_dim = 1
    delta_encoder_embed_dim = 2
    mup_encoder_attention_heads = 1
    base_encoder_ffn_embed_dim = 1
    delta_encoder_ffn_embed_dim = 2

    # atorch
    atorch_config = None
    compile_op = False
    optimizer_state_shard_save = False
    model_state_shard_save = False

    # itc loss
    local_loss = False
    use_dual_softmax = False

    num_frames = 1
    # ----------------------- LMM pretraining config -----------------------

    # norm setting
    deepnorm = False
165
+
vlmo/modules/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from .vlmo_module import VLMo
vlmo/modules/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (255 Bytes). View file
 
vlmo/modules/__pycache__/heads.cpython-311.pyc ADDED
Binary file (2.09 kB). View file
 
vlmo/modules/__pycache__/modeling_utils.cpython-311.pyc ADDED
Binary file (5.9 kB). View file
 
vlmo/modules/__pycache__/objectives.cpython-311.pyc ADDED
Binary file (1.16 kB). View file
 
vlmo/modules/__pycache__/vlmo_module.cpython-311.pyc ADDED
Binary file (25.3 kB). View file
 
vlmo/modules/__pycache__/vlmo_utils.cpython-311.pyc ADDED
Binary file (1.22 kB). View file
 
vlmo/modules/heads.py ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch.nn as nn
2
+
3
+
4
class Pooler(nn.Module):
    """Pool a token sequence by projecting its first ([CLS]) position through a tanh layer."""

    def __init__(self, hidden_size):
        super().__init__()
        self.dense = nn.Linear(hidden_size, hidden_size)
        self.activation = nn.Tanh()

    def forward(self, hidden_states):
        # Only the leading token is pooled, BERT-style.
        cls_state = hidden_states[:, 0]
        return self.activation(self.dense(cls_state))
15
+
16
+
17
class ITCHead(nn.Module):
    """Bias-free linear projection onto the contrastive (ITC) embedding space."""

    def __init__(self, hidden_size, out_size):
        super().__init__()
        self.fc = nn.Linear(hidden_size, out_size, bias=False)

    def forward(self, x):
        return self.fc(x)
vlmo/modules/modeling_utils.py ADDED
@@ -0,0 +1,179 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # --------------------------------------------------------
2
+ # Image as a Foreign Language: BEiT Pretraining for Vision and Vision-Language Tasks (https://arxiv.org/abs/2208.10442)
3
+ # Github source: https://github.com/microsoft/unilm/tree/master/beit3
4
+ # Copyright (c) 2023 Microsoft
5
+ # Licensed under The MIT License [see LICENSE for details]
6
+ # --------------------------------------------------------'
7
+
8
+ import math
9
+ import torch
10
+ import torch.nn as nn
11
+ from timm.models.layers import trunc_normal_ as __call_trunc_normal_
12
+
13
+ from vlmo.torchscale.model.BEiT3 import BEiT3
14
+ from vlmo.torchscale.architecture.config import EncoderConfig
15
+
16
+
17
def trunc_normal_(tensor, mean=0.0, std=1.0):
    """In-place truncated-normal init with the truncation window tied to +/- one std (BEiT convention)."""
    __call_trunc_normal_(tensor, mean=mean, std=std, a=-std, b=std)
19
+
20
+
21
def _get_base_config(
    img_size=224,
    patch_size=16,
    drop_path_rate=0,
    checkpoint_activations=None,
    mlp_ratio=4,
    vocab_size=64010,
    encoder_layers=12,
    encoder_embed_dim=768,
    encoder_attention_heads=12,
    share_layer=False,
    share_attn=False,
    deepnorm=False,
    mask_ratio=0,
    max_text_len=52,
    one_attn=False,
    **kwargs
):
    """Build the EncoderConfig for the base-size (~0.4B) multiway encoder.

    Extra keyword arguments are accepted for caller convenience but ignored.
    """
    # FFN width follows the standard transformer convention: mlp_ratio * embed dim.
    ffn_dim = int(encoder_embed_dim * mlp_ratio)
    return EncoderConfig(
        img_size=img_size,
        patch_size=patch_size,
        vocab_size=vocab_size,
        multiway=True,
        layernorm_embedding=False,
        normalize_output=True,
        no_output_layer=True,
        drop_path_rate=drop_path_rate,
        encoder_embed_dim=encoder_embed_dim,
        encoder_attention_heads=encoder_attention_heads,
        encoder_layers=encoder_layers,
        encoder_ffn_embed_dim=ffn_dim,
        checkpoint_activations=checkpoint_activations,
        share_layer=share_layer,
        share_attn=share_attn,
        deepnorm=deepnorm,
        mask_ratio=mask_ratio,
        max_text_len=max_text_len,
        one_attn=one_attn,
    )
60
+
61
+
62
def _get_large_config(
    img_size=224,
    patch_size=16,
    drop_path_rate=0,
    checkpoint_activations=None,
    mlp_ratio=4,
    vocab_size=64010,
    encoder_layers=24,
    encoder_embed_dim=1024,
    encoder_attention_heads=16,
    share_layer=False,
    share_attn=False,
    deepnorm=False,
    mask_ratio=0,
    max_text_len=52,
    one_attn=False,
    **kwargs
):
    """Build the EncoderConfig for the large-size (~1B) multiway encoder.

    Extra keyword arguments are accepted for caller convenience but ignored.
    """
    ffn_dim = int(encoder_embed_dim * mlp_ratio)
    return EncoderConfig(
        img_size=img_size,
        patch_size=patch_size,
        vocab_size=vocab_size,
        multiway=True,
        layernorm_embedding=False,
        normalize_output=True,
        no_output_layer=True,
        drop_path_rate=drop_path_rate,
        encoder_embed_dim=encoder_embed_dim,
        encoder_attention_heads=encoder_attention_heads,
        encoder_layers=encoder_layers,
        encoder_ffn_embed_dim=ffn_dim,
        checkpoint_activations=checkpoint_activations,
        share_layer=share_layer,
        share_attn=share_attn,
        deepnorm=deepnorm,
        mask_ratio=mask_ratio,
        max_text_len=max_text_len,
        one_attn=one_attn,
    )
101
+
102
+
103
def _get_huge_config(
    img_size=224,
    patch_size=16,
    drop_path_rate=0,
    checkpoint_activations=None,
    mlp_ratio=4,
    vocab_size=30522,
    encoder_layers=32,
    encoder_embed_dim=4096,
    encoder_attention_heads=32,
    share_layer=False,
    share_attn=False,
    deepnorm=False,
    mask_ratio=0,
    max_text_len=52,
    one_attn=False,
    **kwargs
):
    """Build the EncoderConfig for the huge-size (~10B) multiway encoder.

    NOTE(review): the default vocab_size (30522) differs from the base/large
    variants (64010) — confirm this is intentional for the huge checkpoint.
    Extra keyword arguments are accepted for caller convenience but ignored.
    """
    ffn_dim = int(encoder_embed_dim * mlp_ratio)
    return EncoderConfig(
        img_size=img_size,
        patch_size=patch_size,
        vocab_size=vocab_size,
        multiway=True,
        layernorm_embedding=False,
        normalize_output=True,
        no_output_layer=True,
        drop_path_rate=drop_path_rate,
        encoder_embed_dim=encoder_embed_dim,
        encoder_attention_heads=encoder_attention_heads,
        encoder_layers=encoder_layers,
        encoder_ffn_embed_dim=ffn_dim,
        checkpoint_activations=checkpoint_activations,
        share_layer=share_layer,
        share_attn=share_attn,
        deepnorm=deepnorm,
        mask_ratio=mask_ratio,
        max_text_len=max_text_len,
        one_attn=one_attn,
    )
142
+
143
+
144
class BEiT3Wrapper(nn.Module):
    """Thin nn.Module wrapper around the torchscale BEiT3 backbone with BEiT-style init."""

    def __init__(self, args, **kwargs):
        super().__init__()
        self.args = args
        self.beit3 = BEiT3(args)
        self.apply(self._init_weights)

    def fix_init_weight(self):
        # NOTE(review): iterates self.blocks, which this wrapper never defines;
        # confirm the attribute exists before calling this method.
        def rescale(param, layer_id):
            param.div_(math.sqrt(2.0 * layer_id))

        for layer_id, layer in enumerate(self.blocks):
            rescale(layer.attn.proj.weight.data, layer_id + 1)
            rescale(layer.mlp.fc2.weight.data, layer_id + 1)

    def get_num_layers(self):
        return self.beit3.encoder.num_layers

    @torch.jit.ignore
    def no_weight_decay(self):
        # Parameter names the optimizer setup exempts from weight decay.
        return {
            "pos_embed",
            "cls_token",
            "beit3.encoder.embed_positions.A.weight",
            "beit3.vision_embed.cls_token",
            "logit_scale",
        }

    def _init_weights(self, m):
        """Truncated-normal init for Linear weights, constant init for LayerNorm."""
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
vlmo/modules/multiway_transformer.py ADDED
@@ -0,0 +1,396 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """ Vision Transformer (ViT) in PyTorch
2
+
3
+ A PyTorch implement of Vision Transformers as described in
4
+ 'An Image Is Worth 16 x 16 Words: Transformers for Image Recognition at Scale' - https://arxiv.org/abs/2010.11929
5
+
6
+ The official jax code is released and available at https://github.com/google-research/vision_transformer
7
+
8
+ Acknowledgments:
9
+ * The paper authors for releasing code and weights, thanks!
10
+ * I fixed my class token impl based on Phil Wang's https://github.com/lucidrains/vit-pytorch ... check it out
11
+ for some einops/einsum fun
12
+ * Simple transformer style inspired by Andrej Karpathy's https://github.com/karpathy/minGPT
13
+ * Bert reference code checks against Huggingface Transformers and Tensorflow Bert
14
+
15
+ DeiT model defs and weights from https://github.com/facebookresearch/deit,
16
+ paper `DeiT: Data-efficient Image Transformers` - https://arxiv.org/abs/2012.12877
17
+
18
+ Hacked together by / Copyright 2020 Ross Wightman
19
+ """
20
+ from functools import partial
21
+
22
+ import torch
23
+ import torch.nn as nn
24
+ import torch.nn.functional as F
25
+
26
+ from timm.models.layers import DropPath, to_2tuple, trunc_normal_
27
+ from timm.models.registry import register_model
28
+ from pytorch_lightning.utilities.distributed import rank_zero_info
29
+
30
+
31
class Mlp(nn.Module):
    """Two-layer feed-forward block (Linear -> activation -> Linear) with dropout after each stage."""

    def __init__(
        self,
        in_features,
        hidden_features=None,
        out_features=None,
        act_layer=nn.GELU,
        drop=0.0,
    ):
        super().__init__()
        # Hidden and output widths default to the input width.
        hidden = hidden_features or in_features
        out = out_features or in_features
        self.fc1 = nn.Linear(in_features, hidden)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden, out)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        h = self.drop(self.act(self.fc1(x)))
        return self.drop(self.fc2(h))
+
56
+
57
class Attention(nn.Module):
    """Multi-head self-attention with the BEiT-style split q/v bias (k carries no bias)."""

    def __init__(
        self,
        dim,
        num_heads=8,
        qkv_bias=False,
        qk_scale=None,
        attn_drop=0.0,
        proj_drop=0.0,
    ):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights
        self.scale = qk_scale or head_dim**-0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=False)
        if qkv_bias:
            # Separate learnable biases for q and v only; k stays bias-free.
            self.q_bias = nn.Parameter(torch.zeros(dim))
            self.v_bias = nn.Parameter(torch.zeros(dim))
        else:
            self.q_bias = None
            self.v_bias = None

        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x, mask=None, relative_position_bias=None):
        B, N, C = x.shape

        # Assemble [q_bias, zeros, v_bias] so the fused projection biases q/v but not k.
        bias = None
        if self.q_bias is not None:
            bias = torch.cat((self.q_bias, torch.zeros_like(self.v_bias, requires_grad=False), self.v_bias))
        fused = F.linear(input=x, weight=self.qkv.weight, bias=bias)
        fused = fused.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
        q, k, v = fused.unbind(0)

        # Attention logits computed in float32 for numerical stability.
        scores = (q * self.scale).float() @ k.float().transpose(-2, -1)

        if relative_position_bias is not None:
            scores = scores + relative_position_bias.unsqueeze(0)

        if mask is not None:
            # mask is per-token; broadcast over heads and query positions.
            scores = scores.masked_fill(~mask.bool()[:, None, None, :], float("-inf"))
        weights = self.attn_drop(scores.softmax(dim=-1).type_as(x))

        out = (weights @ v).transpose(1, 2).reshape(B, N, C)
        return self.proj_drop(self.proj(out))
116
+
117
+
118
class Block(nn.Module):
    """One multiway transformer block.

    Self-attention is shared across modalities; the post-attention FFN is
    modality-specific: ``mlp_text`` for text-only input, ``mlp_imag`` for
    image-only input, and (when ``with_vlffn``) ``mlp_vl`` for fused
    vision-language sequences.
    """

    def __init__(
        self,
        dim,
        num_heads,
        mlp_ratio=4.0,
        qkv_bias=False,
        qk_scale=None,
        drop=0.0,
        attn_drop=0.0,
        drop_path=0.0,
        act_layer=nn.GELU,
        norm_layer=nn.LayerNorm,
        with_vlffn=False,
        layer_scale_init_values=0.1,
        max_text_len=40,
    ):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = Attention(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            attn_drop=attn_drop,
            proj_drop=drop,
        )
        # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
        # One post-attention LayerNorm + MLP expert per modality.
        self.norm2_text = norm_layer(dim)
        self.norm2_imag = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp_text = Mlp(
            in_features=dim,
            hidden_features=mlp_hidden_dim,
            act_layer=act_layer,
            drop=drop,
        )
        self.mlp_imag = Mlp(
            in_features=dim,
            hidden_features=mlp_hidden_dim,
            act_layer=act_layer,
            drop=drop,
        )
        self.mlp_vl = None
        if with_vlffn:
            self.mlp_vl = Mlp(
                in_features=dim,
                hidden_features=mlp_hidden_dim,
                act_layer=act_layer,
                drop=drop,
            )
            self.norm2_vl = norm_layer(dim)

        # LayerScale: learnable per-channel residual scaling; plain 1.0 disables it.
        self.gamma_1 = (
            nn.Parameter(layer_scale_init_values * torch.ones((dim)), requires_grad=True)
            if layer_scale_init_values is not None
            else 1.0
        )
        self.gamma_2 = (
            nn.Parameter(layer_scale_init_values * torch.ones((dim)), requires_grad=True)
            if layer_scale_init_values is not None
            else 1.0
        )

        # Split point between text and image tokens in a fused sequence.
        self.max_text_len = max_text_len

    def forward(self, x, mask=None, modality_type=None, relative_position_bias=None):
        # Shared self-attention with residual, LayerScale and (optional) DropPath.
        x = x + self.drop_path(
            self.gamma_1 * self.attn(self.norm1(x), mask=mask, relative_position_bias=relative_position_bias)
        )

        if modality_type == "image":
            x = x + self.drop_path(self.gamma_2 * self.mlp_imag(self.norm2_imag(x)))
        elif modality_type == "text":
            x = x + self.drop_path(self.gamma_2 * self.mlp_text(self.norm2_text(x)))
        else:
            # Fused sequence: assumed layout is [text tokens | image tokens] with
            # exactly max_text_len text positions — confirm against caller.
            if self.mlp_vl is None:
                x_text = x[:, : self.max_text_len]
                x_imag = x[:, self.max_text_len :]
                x_text = x_text + self.drop_path(self.gamma_2 * self.mlp_text(self.norm2_text(x_text)))
                x_imag = x_imag + self.drop_path(self.gamma_2 * self.mlp_imag(self.norm2_imag(x_imag)))
                x = torch.cat([x_text, x_imag], dim=1)
            else:
                x = x + self.drop_path(self.gamma_2 * self.mlp_vl(self.norm2_vl(x)))

        return x
205
+
206
+
207
class PatchEmbed(nn.Module):
    """Image to Patch Embedding via a strided convolution."""

    def __init__(
        self,
        img_size=224,
        patch_size=16,
        in_chans=3,
        embed_dim=768,
        no_patch_embed_bias=False,
    ):
        super().__init__()
        img_size = to_2tuple(img_size)
        patch_size = to_2tuple(patch_size)
        self.patch_shape = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
        self.img_size = img_size
        self.patch_size = patch_size
        self.num_patches = self.patch_shape[0] * self.patch_shape[1]

        # Kernel == stride: each patch is embedded independently.
        self.proj = nn.Conv2d(
            in_chans,
            embed_dim,
            kernel_size=patch_size,
            stride=patch_size,
            bias=not no_patch_embed_bias,
        )

    def forward(self, x):
        B, C, H, W = x.shape
        # FIXME look at relaxing size constraints
        assert (
            H == self.img_size[0] and W == self.img_size[1]
        ), f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
        return self.proj(x)
243
+
244
+
245
class MultiWayTransformer(nn.Module):
    """Vision Transformer

    A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale` -
    https://arxiv.org/abs/2010.11929
    """

    def __init__(
        self,
        img_size=224,
        patch_size=16,
        in_chans=3,
        embed_dim=768,
        depth=12,
        num_heads=12,
        mlp_ratio=4.0,
        qkv_bias=True,
        qk_scale=None,
        drop_rate=0.0,
        attn_drop_rate=0.0,
        drop_path_rate=0.0,
        norm_layer=None,
        need_relative_position_embed=True,
        use_abs_pos_emb=False,
        layer_scale_init_values=0.1,
        vlffn_start_layer_index=10,
        config=None,
    ):
        """
        Args:
            img_size (int, tuple): input image size
            patch_size (int, tuple): patch size
            in_chans (int): number of input channels
            num_classes (int): number of classes for classification head
            embed_dim (int): embedding dimension
            depth (int): depth of transformer
            num_heads (int): number of attention heads
            mlp_ratio (int): ratio of mlp hidden dim to embedding dim
            qkv_bias (bool): enable bias for qkv if True
            qk_scale (float): override default qk scale of head_dim ** -0.5 if set
            drop_rate (float): dropout rate
            attn_drop_rate (float): attention dropout rate
            drop_path_rate (float): stochastic depth rate
            norm_layer: (nn.Module): normalization layer
            need_relative_position_embed (bool): enable relative position bias on self-attention
            use_abs_pos_emb (bool): enable abs pos emb
            layer_scale_init_values (float or None): layer scale init values, set None to disable
            vlffn_start_layer_index (int): vl-ffn start index
            config: (dict): other hyper from pytorch-lighting
        """
        super().__init__()
        # The pytorch-lightning config dict overrides the keyword default when given.
        drop_path_rate = drop_path_rate if config is None else config["drop_path_rate"]
        rank_zero_info("drop path rate: {}".format(drop_path_rate))
        self.use_abs_pos_emb = use_abs_pos_emb
        self.need_relative_position_embed = need_relative_position_embed

        self.num_features = self.embed_dim = embed_dim  # num_features for consistency with other models
        norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)

        self.patch_embed = PatchEmbed(
            img_size=img_size,
            patch_size=patch_size,
            in_chans=in_chans,
            embed_dim=embed_dim,
        )
        num_patches = self.patch_embed.num_patches
        self.patch_size = patch_size
        self.num_heads = num_heads
        self.vlffn_start_layer_index = vlffn_start_layer_index
        # NOTE(review): this indexing assumes config is a dict with a
        # "loss_names" entry — passing config=None here would raise.
        if config["loss_names"]["textmlm"] > 0:
            # Text-only pretraining never reaches a VL-FFN layer.
            self.vlffn_start_layer_index = depth
            rank_zero_info(
                "Set vlffn_start_layer_index={} for text-only pretraining".format(self.vlffn_start_layer_index)
            )
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim)) if self.use_abs_pos_emb else None
        self.pos_drop = nn.Dropout(p=drop_rate)

        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]  # stochastic depth decay rule
        self.blocks = nn.ModuleList(
            [
                Block(
                    dim=embed_dim,
                    num_heads=num_heads,
                    mlp_ratio=mlp_ratio,
                    qkv_bias=qkv_bias,
                    qk_scale=qk_scale,
                    drop=drop_rate,
                    attn_drop=attn_drop_rate,
                    drop_path=dpr[i],
                    norm_layer=norm_layer,
                    with_vlffn=(i >= self.vlffn_start_layer_index),
                    layer_scale_init_values=layer_scale_init_values,
                    max_text_len=config["max_text_len"],
                )
                for i in range(depth)
            ]
        )
        self.norm = norm_layer(embed_dim)

        if self.pos_embed is not None:
            trunc_normal_(self.pos_embed, std=0.02)
        trunc_normal_(self.cls_token, std=0.02)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        # Truncated-normal for Linear weights, constants for LayerNorm.
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=0.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    @torch.jit.ignore
    def no_weight_decay(self):
        # Parameter names excluded from weight decay by the optimizer setup.
        return {"pos_embed", "cls_token"}

    def visual_embed(self, _x):
        """Patch-embed an image batch, prepend the [CLS] token, and return
        the token sequence plus an all-ones attention mask.

        NOTE(review): the mask is created on CPU with the default dtype;
        confirm callers move it to the input's device.
        """
        x = self.patch_embed(_x)
        x = x.flatten(2).transpose(1, 2)
        B, L, _ = x.shape

        cls_tokens = self.cls_token.expand(B, -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)

        if self.pos_embed is not None:
            x = x + self.pos_embed
        x = self.pos_drop(x)

        x_mask = torch.ones(x.shape[0], x.shape[1])

        return x, x_mask
378
+
379
+
380
# VLMo base/p16
@register_model
def vlmo_base_patch16(pretrained=False, **kwargs):
    """Build the base-size (ViT-B/16) MultiWayTransformer; extra kwargs are forwarded."""
    size = kwargs.pop("img_size", 224)
    return MultiWayTransformer(
        img_size=size,
        patch_size=16,
        embed_dim=768,
        depth=12,
        num_heads=12,
        mlp_ratio=4,
        qkv_bias=True,
        vlffn_start_layer_index=10,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        **kwargs,
    )
vlmo/modules/objectives.py ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch.nn as nn
2
+
3
+
4
def init_weights(module):
    """BERT-style init: N(0, 0.02) for Linear/Embedding weights, zero Linear bias,
    and ones/zeros for LayerNorm weight/bias.
    """
    if isinstance(module, (nn.Linear, nn.Embedding)):
        module.weight.data.normal_(mean=0.0, std=0.02)
        # Only Linear carries a bias among the two; Embedding has none.
        if isinstance(module, nn.Linear) and module.bias is not None:
            module.bias.data.zero_()
    elif isinstance(module, nn.LayerNorm):
        module.bias.data.zero_()
        module.weight.data.fill_(1.0)
vlmo/modules/vlmo_module.py ADDED
@@ -0,0 +1,405 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import os
3
+ import time
4
+
5
+ import numpy as np
6
+ import pytorch_lightning as pl
7
+ import torch
8
+ import torch.distributed as dist
9
+ import torch.nn as nn
10
+ from pytorch_lightning.utilities.distributed import rank_zero_info
11
+ from timm.models import create_model
12
+ from transformers import AutoTokenizer, BertTokenizer, XLMRobertaTokenizer # noqa
13
+ from vlmo.modules import heads, objectives, vlmo_utils
14
+ from vlmo.tokenizer.tokenization_glm import GLMChineseTokenizer # noqa
15
+ from vlmo.torchscale.architecture.encoder import Encoder
16
+ from vlmo.torchscale.model.BEiT3 import BEiT3 as ts_backbone
17
+ from vlmo.transforms.utils import inception_normalize as img_norm
18
+
19
+ from .modeling_utils import _get_base_config, _get_large_config, _get_huge_config, trunc_normal_ # noqa
20
+
21
+
22
def convert_pl_ckpt(state_dict, num_visual_token=197):
    """Adapt a PyTorch-Lightning checkpoint state_dict to the current model.

    Drops visual-tokenizer weights and resizes the backbone's position
    embedding table to match ``num_visual_token`` patch tokens (plus the
    leading non-patch rows): smaller tables are upsampled by 2-D
    interpolation, larger ones are truncated. All other entries are copied
    through unchanged.

    NOTE(review): the resize assumes a square patch grid
    (sqrt(num_visual_token - 1) per side) — confirm for non-square inputs.
    """
    print("start convert_pl_ckpt!!!")
    new_state_dict = {}
    for key in state_dict:
        value = state_dict[key]
        if "visual_tokenizer" in key:
            # Tokenizer weights are not part of the target model.
            continue
        elif "backbone.encoder.embed_positions.A.weight" in key:
            if value.shape[0] < num_visual_token + 2:
                # Upsample: first 3 rows are non-patch position embeddings and
                # are kept as-is; the remaining rows form the patch grid.
                N = value.shape[0] - 3
                dim = value.shape[-1]
                class_pos_embed = value[:3, ]
                patch_pos_embed = value[3:, ]
                w0, h0 = int(math.sqrt(num_visual_token - 1)), int(math.sqrt(num_visual_token - 1))
                # Interpolate in float32, then restore the original dtype.
                patch_pos_embed = patch_pos_embed.float()
                patch_pos_embed = nn.functional.interpolate(
                    patch_pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(0, 3, 1, 2),
                    size=(w0, h0),
                    mode="area",
                )
                patch_pos_embed = patch_pos_embed.to(class_pos_embed.dtype)
                patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(-1, dim)
                new_value = torch.cat((class_pos_embed, patch_pos_embed), dim=0)
                new_state_dict[key] = new_value
                print("reshape ", key, "raw shape: ", value.shape, "new shape: ", new_value.shape, num_visual_token)
            elif value.shape[0] > num_visual_token + 2:
                # Truncate: keep only the first num_visual_token + 2 rows.
                new_state_dict[key] = value[: num_visual_token + 2, :]
                print("first ", key, "raw shape: ", value.shape, new_state_dict[key].shape, num_visual_token)
            else:
                # Already the right size.
                new_state_dict[key] = value
                print("raw shape")
        else:
            new_state_dict[key] = state_dict[key]

    return new_state_dict
57
+
58
+
59
def convert_deepspeed_ckpt(state_dict, num_visual_token=197):
    """Convert a DeepSpeed checkpoint into plain module state-dict keys.

    Strips the ``_forward_module.`` prefix that DeepSpeed prepends to every
    parameter name and, where needed, resizes position-embedding tables to
    ``num_visual_token`` visual tokens:

    * ``visual_tokenizer.{encoder,decoder}.pos_embed`` — 3-D ``[1, N+1, dim]``;
      index 0 is the CLS position, the remaining N form a square patch grid
      that is area-interpolated to the new grid size.
    * ``backbone.encoder.embed_positions.A.weight`` — 2-D ``[P, dim]``; the
      first 3 rows are special-token positions, the rest the patch grid.

    Keys without the DeepSpeed prefix are copied through unchanged.
    """
    new_state_dict = {}
    for key in state_dict:
        if not key.startswith("_forward_module."):
            new_state_dict[key] = state_dict[key]
            continue
        new_key = key[len("_forward_module."):]
        value = state_dict[key]
        new_state_dict[new_key] = value
        if "visual_tokenizer.encoder.pos_embed" in new_key or "visual_tokenizer.decoder.pos_embed" in new_key:
            # 3-D table: shape[1] is the token axis here.
            if value.shape[1] != num_visual_token:
                N = value.shape[1] - 1
                dim = value.shape[-1]
                class_pos_embed = value[:, 0]
                patch_pos_embed = value[:, 1:]
                w0, h0 = int(math.sqrt(num_visual_token - 1)), int(math.sqrt(num_visual_token - 1))
                patch_pos_embed = patch_pos_embed.float()
                patch_pos_embed = nn.functional.interpolate(
                    patch_pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(0, 3, 1, 2),
                    size=(w0, h0),
                    mode="area",
                )
                patch_pos_embed = patch_pos_embed.to(class_pos_embed.dtype)
                patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
                new_value = torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)
                new_state_dict[new_key] = new_value
                print("reshape ", new_key, "raw shape: ", value.shape, "new_shape: ", new_value.shape)
        if "backbone.encoder.embed_positions.A.weight" in new_key:
            # BUG FIX: this 2-D table is [tokens, dim]; the original compared
            # value.shape[1] (the embedding dim) against the expected token
            # count, so the branch fired for every checkpoint and needlessly
            # re-interpolated tables that were already the right size. Use the
            # token axis shape[0], matching convert_pl_ckpt above.
            if value.shape[0] != num_visual_token + 2:
                N = value.shape[0] - 3
                dim = value.shape[-1]
                class_pos_embed = value[:3, ]
                patch_pos_embed = value[3:, ]
                w0, h0 = int(math.sqrt(num_visual_token - 1)), int(math.sqrt(num_visual_token - 1))
                patch_pos_embed = patch_pos_embed.float()
                patch_pos_embed = nn.functional.interpolate(
                    patch_pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(0, 3, 1, 2),
                    size=(w0, h0),
                    mode="area",
                )
                patch_pos_embed = patch_pos_embed.to(class_pos_embed.dtype)
                patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(-1, dim)
                new_value = torch.cat((class_pos_embed, patch_pos_embed), dim=0)
                new_state_dict[new_key] = new_value
                print("reshape ", new_key, "raw shape: ", value.shape, "new_shape: ", new_value.shape)
    return new_state_dict
107
+
108
+
109
def get_visual_tokenizer(config):
    """Build the visual tokenizer named in *config* via timm and return it in eval mode.

    Reads ``tokenizer_model``, ``image_size``, ``codebook_size`` and
    ``codebook_dim`` from the config dict.
    """
    tokenizer_name = config["tokenizer_model"]
    print(f"Creating visual tokenizer: {tokenizer_name}")
    tokenizer = create_model(
        tokenizer_name,
        img_size=config["image_size"],
        n_code=config["codebook_size"],
        code_dim=config["codebook_dim"],
    )
    return tokenizer.eval()
119
+
120
+
121
def get_pretrained_tokenizer(tokenizer_type, from_pretrained):
    """Resolve *tokenizer_type* to a tokenizer class and load it.

    Under ``torch.distributed``, rank 0 performs the download first so every
    other rank hits the local cache after the barrier.

    Args:
        tokenizer_type: name of a tokenizer class imported in this module
            (e.g. ``"GLMChineseTokenizer"``, ``"BertTokenizer"``).
        from_pretrained: model name or path passed to ``from_pretrained``.

    Raises:
        ValueError: if *tokenizer_type* does not name a known class.
    """
    # Replaces the previous eval(f"{tokenizer_type}"): a plain globals()
    # lookup has identical semantics for valid class names but cannot
    # execute arbitrary expressions from the config.
    try:
        tokenizer_cls = globals()[tokenizer_type]
    except KeyError:
        raise ValueError(f"Unknown tokenizer_type: {tokenizer_type!r}") from None
    if torch.distributed.is_initialized():
        if torch.distributed.get_rank() == 0:
            tokenizer_cls.from_pretrained(from_pretrained)  # warm the cache
        torch.distributed.barrier()
    return tokenizer_cls.from_pretrained(from_pretrained)
128
+
129
+
130
class VLMo(pl.LightningModule):
    """M2-Encoder / VLMo vision-language model (BEiT-3-style multiway transformer).

    Components, all driven by a single ``config`` dict:
      * ``backbone``        — shared multiway text/image encoder (``ts_backbone``);
      * ``backbone_vl``     — optional fusion encoder, enabled when
        ``beit3_vl_layers > 0``;
      * ``visual_tokenizer`` — image tokenizer built only when training
        (``test_only`` is false), used to produce discrete image-token targets;
      * ITC projection heads plus learnable logit scales for image-text
        contrastive learning.
    """

    def __init__(self, config):
        super().__init__()
        self.save_hyperparameters()  # exposes config as self.hparams.config
        s_t = time.time()

        # tokenizer & backbone
        self.img_size = config["image_size"]
        if not config["test_only"]:
            # only training needs the visual tokenizer (image-token targets)
            self.visual_tokenizer = get_visual_tokenizer(config)
        kwargs = {}
        if "encoder_attention_heads" in config:
            kwargs["encoder_attention_heads"] = config["encoder_attention_heads"]
        if "atorch_config" in config and config["atorch_config"]:
            # NOTE(review): presumably atorch takes over activation
            # checkpointing when an atorch_config is present — confirm
            checkpoint_activations = False
        else:
            checkpoint_activations = config["checkpoint_activations"]
        # dispatches to _get_base_config / _get_large_config / _get_huge_config
        args = eval(f'_get_{config["beit_version"]}_config')(
            img_size=config["image_size"],
            patch_size=config["patch_size"],
            vocab_size=config["vocab_size"],
            encoder_layers=config["encoder_layers"],
            encoder_embed_dim=config["encoder_embed_dim"],
            checkpoint_activations=checkpoint_activations,
            share_layer=config["share_layer"],
            share_attn=config["share_attn"],
            deepnorm=config["deepnorm"],
            mask_ratio=config["mask_ratio"],
            max_text_len=config["max_text_len"],
            one_attn=config["one_attn"],
            **kwargs,
        )
        self.num_features = args.encoder_embed_dim
        self.out_features = config["out_embed_dim"]
        self.cap_onlytext = config["cap_onlytext"]
        self.lang = config["lang"]
        self.num_frames = config["num_frames"]
        self.tokenizer_type = config["tokenizer_type"]
        self.text_tokenizer = get_pretrained_tokenizer(self.tokenizer_type, from_pretrained=config["tokenizer"])  # noqa
        print("BEiT args", args.__dict__)
        self.backbone = ts_backbone(args)

        self.use_vl = config["beit3_vl_layers"] > 0
        if self.use_vl:
            # reuse the backbone arg namespace with a different depth for the
            # vision-language fusion encoder
            args.encoder_layers = config["beit3_vl_layers"]
            self.backbone_vl = Encoder(args)

        self.norm = nn.LayerNorm(self.num_features, eps=1e-6)

        # task layers
        self.pooler = heads.Pooler(self.num_features)
        self.pooler.apply(objectives.init_weights)

        # contrastive loss (or sampling for global hard negative)
        if config["loss_names"]["itc"] > 0:
            # projection heads for uni-modal (backbone) features ...
            self.itc_text_proj = heads.ITCHead(self.num_features, self.out_features)
            self.itc_image_proj = heads.ITCHead(self.num_features, self.out_features)
            self.itc_text_proj.apply(objectives.init_weights)
            self.itc_image_proj.apply(objectives.init_weights)

            # ... and for the fused (backbone_vl) features
            self.itc_vl_text_proj = heads.ITCHead(self.num_features, self.out_features)
            self.itc_vl_image_proj = heads.ITCHead(self.num_features, self.out_features)
            self.itc_vl_text_proj.apply(objectives.init_weights)
            self.itc_vl_image_proj.apply(objectives.init_weights)

            # CLIP-style learnable temperatures, initialised to log(1/0.07)
            self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
            self.logit_vl_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))

        lp_s_t = time.time()

        self.load_pretrained_weight()
        load_pretrain_time = time.time() - lp_s_t

        self.current_tasks = list()

        # ===================== load downstream (test_only) ======================

        if self.hparams.config["load_path"] != "" and self.hparams.config["test_only"]:
            rank_zero_info("Load ckpt from: {}".format(self.hparams.config["load_path"]))
            ckpt = torch.load(self.hparams.config["load_path"], map_location="cpu")

            state_dict = None

            for state_dict_key in ("state_dict", "module", "model"):
                if state_dict_key in ckpt:
                    rank_zero_info("Read state dict from ckpt[%s]. " % state_dict_key)
                    state_dict = ckpt[state_dict_key]
                    break
            # "module" marks a DeepSpeed checkpoint, "state_dict" a Lightning
            # one; both may need their position-embedding tables resized to
            # the current number of visual tokens.
            if state_dict_key == "module":
                state_dict = convert_deepspeed_ckpt(state_dict, self.backbone.vision_embed.num_position_embeddings())
            if state_dict_key == "state_dict":
                state_dict = convert_pl_ckpt(state_dict, self.backbone.vision_embed.num_position_embeddings())
            if state_dict is None:
                # raw (un-nested) state dict, possibly still DeepSpeed-prefixed
                if list(ckpt.keys())[0].startswith('_forward_module.'):
                    rank_zero_info("Read state dict from ckpt with _forward_module prefix. ")
                    state_dict = convert_deepspeed_ckpt(ckpt, self.backbone.vision_embed.num_position_embeddings())
                else:
                    rank_zero_info("Read state dict from ckpt. ")
                    state_dict = ckpt

            missing_keys, unexpected_keys = self.load_state_dict(state_dict, strict=False)
            rank_zero_info("missing_keys: {}".format(missing_keys))
            rank_zero_info("unexpected_keys: {}".format(unexpected_keys))

        construct_time = time.time() - s_t
        print(
            f"Process {os.getpid()}. VLMo Constructor time: {construct_time}s;",
            f"load_pretrain_time: {load_pretrain_time}s",
            flush=True,
        )
        # coalesce backbone calls
        self._coalesce_backbone = config["coalesce_backbone"]
        self._mask_data = config["mask_data"]
        self._backbone_inputs = {}
        self._backbone_inputs_current_size = 0
        self._backbone_inputs_keys = {}
        self._backbone_outputs = None
        self._default_attn_masks = {}
        self._itc_group = None
        self._itc_aggregate_dict = None
        self._itc_mask = config["itc_mask"]
        self._local_loss = config["local_loss"]
        self._aggregate_nodes = config["aggregate_nodes"]
        self.accumulated_batches_reached = False
        vlmo_utils.set_task(self)
        # fast path: single-task ITC with node-local feature aggregation
        self._only_itc_single_machine = (
            self._aggregate_nodes > 0 and len(self.current_tasks) == 1 and "itc" in self.current_tasks
        )
        self._split_data_for_imagemlm = config["split_data_for_imagemlm"]
        self.log_metric_steps = config["log_metric_steps"]

    def _init_weights(self, m):
        """timm-style init: truncated-normal linear weights, zero biases, unit LayerNorm."""
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=0.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def fix_init_weight(self):
        """Depth-dependent rescaling of residual-branch weights (BEiT recipe).

        Each layer's value/output projections and FFN second linear are divided
        by ``sqrt(2 * layer_id)``; ``A``/``B`` are the two multiway branches.
        ``backbone_vl`` layers continue the depth count after the backbone.
        """
        def rescale(param, layer_id):
            param.div_(math.sqrt(2.0 * layer_id))

        for layer_id, layer in enumerate(self.backbone.encoder.layers):
            rescale(layer.self_attn.v_proj.A.weight.data, layer_id + 1)
            rescale(layer.self_attn.v_proj.B.weight.data, layer_id + 1)
            rescale(layer.self_attn.out_proj.A.weight.data, layer_id + 1)
            rescale(layer.self_attn.out_proj.B.weight.data, layer_id + 1)
            rescale(layer.ffn.A.fc2.weight.data, layer_id + 1)
            rescale(layer.ffn.B.fc2.weight.data, layer_id + 1)

        if self.use_vl:
            pre_layers = len(self.backbone.encoder.layers) + 1
            for layer_id, layer in enumerate(self.backbone_vl.layers):
                rescale(layer.self_attn.v_proj.A.weight.data, layer_id + pre_layers)
                rescale(layer.self_attn.v_proj.B.weight.data, layer_id + pre_layers)
                rescale(layer.self_attn.out_proj.A.weight.data, layer_id + pre_layers)
                rescale(layer.self_attn.out_proj.B.weight.data, layer_id + pre_layers)
                rescale(layer.ffn.A.fc2.weight.data, layer_id + pre_layers)
                rescale(layer.ffn.B.fc2.weight.data, layer_id + pre_layers)

    def load_pretrained_weight(self):
        """Non-strictly load ``config["load_path"]`` weights for fine-tuning.

        Only runs when a load path is given and ``test_only`` is false (the
        test-only path is handled in ``__init__``). Handles Lightning,
        DeepSpeed and plain state dicts, resizing position embeddings via the
        module-level converters. ``itc_teacher`` keys are expected to be
        missing and are filtered from the report.
        """
        if self.hparams.config["load_path"] != "" and not self.hparams.config["test_only"]:
            config = self.hparams.config
            ckpt = torch.load(self.hparams.config["load_path"], map_location="cpu")
            rank_zero_info("Load ckpt from: {}".format(self.hparams.config["load_path"]))

            state_dict = None

            for state_dict_key in ("state_dict", "module", "model"):
                if state_dict_key in ckpt:
                    rank_zero_info("Read state dict from ckpt[%s]. " % state_dict_key)
                    state_dict = ckpt[state_dict_key]
                    break
            if state_dict_key == "module":
                state_dict = convert_deepspeed_ckpt(state_dict, self.backbone.vision_embed.num_position_embeddings())
            if state_dict_key == "state_dict":
                state_dict = convert_pl_ckpt(state_dict, self.backbone.vision_embed.num_position_embeddings())
            if state_dict is None:
                if list(ckpt.keys())[0].startswith('_forward_module.'):
                    rank_zero_info("Read state dict from ckpt with _forward_module prefix. ")
                    state_dict = convert_deepspeed_ckpt(ckpt,
                                                        self.backbone.vision_embed.num_position_embeddings())
                else:
                    rank_zero_info("Read state dict from ckpt. ")
                    state_dict = ckpt

            missing_keys, unexpected_keys = self.load_state_dict(state_dict, strict=False)
            missing_keys = [k for k in missing_keys if "itc_teacher" not in k]
            rank_zero_info("missing_keys: {}".format(missing_keys))
            rank_zero_info("unexpected_keys: {}".format(unexpected_keys))

    def infer_text(
        self,
        batch,
        mask_text=False,
    ):
        """Encode the text side of *batch* for ITC retrieval.

        Returns normalized projections of the CLS token from the uni-modal
        backbone (``cls_feats``) and from the fusion encoder
        (``cls_vlffn_feats``), plus the raw token embeddings.
        """
        do_mlm = "_mlm" if mask_text else ""
        text_ids = batch[f"text_ids{do_mlm}"]
        # NOTE(review): text_labels is read but not used or returned here
        text_labels = batch[f"text_labels{do_mlm}"]
        text_masks = batch[f"text_masks"]
        text_embed = self.backbone.text_embed(text_ids)
        # backbone expects padding positions (1 = pad), the inverse of the mask
        text_padding_position = 1 - text_masks
        lffn_hiddens = self.backbone(
            textual_tokens=text_ids,
            text_padding_position=text_padding_position,
        )["encoder_out"]
        vlffn_hiddens = self.backbone_vl(
            src_tokens=None,
            token_embeddings=lffn_hiddens,
            encoder_padding_mask=text_padding_position,
            multiway_split_position=-1,
        )["encoder_out"]

        # L2-normalized CLS projections for the contrastive objective
        cls_feats = self.itc_text_proj(lffn_hiddens[:, 0])
        cls_feats = cls_feats / cls_feats.norm(dim=-1, keepdim=True)

        cls_vlffn_feats = self.itc_vl_text_proj(vlffn_hiddens[:, 0])
        cls_vlffn_feats = cls_vlffn_feats / cls_vlffn_feats.norm(dim=-1, keepdim=True)

        ret = {
            "cls_feats": cls_feats,
            "cls_vlffn_feats": cls_vlffn_feats,
            "text_embed": text_embed,
        }

        return ret

    def infer_image(
        self,
        batch,
        mask_image=False,
        image_token_type_idx=1,
        image_embeds=None,
        image_masks=None,
    ):
        """Encode the image side of *batch* for ITC retrieval.

        Tokenizes the raw image with the (frozen) visual tokenizer to build
        masked-token labels, then runs the normalized image through the
        backbone and fusion encoder, returning L2-normalized CLS projections.

        NOTE(review): when ``mask_image`` is false this assumes *image_masks*
        is provided by the caller — a None mask would fail below; confirm.
        """
        if f"image_{image_token_type_idx - 1}" in batch:
            imgkey = f"image_{image_token_type_idx - 1}"
        else:
            imgkey = "image"

        img = batch[imgkey][0]
        if mask_image:
            image_masks = batch[f"{imgkey}_masks"][0].flatten(1)

        # discrete image tokens from the frozen tokenizer (no gradients)
        with torch.no_grad():
            img = self.visual_tokenizer.pre_process(img)
            quantize, embed_ind, _ = self.visual_tokenizer.encode(img)
            image_ids = embed_ind.view(img.shape[0], -1)

        # labels are -100 (ignored) everywhere except masked positions
        # NOTE(review): image_labels is built but not returned here
        image_labels = torch.full_like(image_ids, -100)
        bool_masked_pos = image_masks.to(torch.bool)
        image_labels[bool_masked_pos] = image_ids[bool_masked_pos]

        img_tensor = img_norm(img)
        vffn_hiddens = self.backbone(visual_tokens=img_tensor)["encoder_out"]
        vlffn_hiddens = self.backbone_vl(
            src_tokens=None,
            token_embeddings=vffn_hiddens,
            multiway_split_position=-1,
        )["encoder_out"]

        cls_feats = self.itc_image_proj(vffn_hiddens[:, 0])
        cls_feats = cls_feats / cls_feats.norm(dim=-1, keepdim=True)

        cls_vlffn_feats = self.itc_vl_image_proj(vlffn_hiddens[:, 0])
        cls_vlffn_feats = cls_vlffn_feats / cls_vlffn_feats.norm(dim=-1, keepdim=True)

        ret = {
            "image_feats": vffn_hiddens,
            "cls_feats": cls_feats,
            "cls_vlffn_feats": cls_vlffn_feats,
        }

        return ret
vlmo/modules/vlmo_utils.py ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
def set_task(pl_module):
    """Populate ``pl_module.current_tasks`` with every loss whose weight is >= 1."""
    loss_names = pl_module.hparams.config["loss_names"]
    active = []
    for task_name, weight in loss_names.items():
        if weight >= 1:
            active.append(task_name)
    pl_module.current_tasks = active
    return
4
+
5
+
6
def no_sync_module_apply(module, fn):
    """Apply *fn* depth-first to every descendant of *module* (not *module* itself).

    FSDP's ``Module.apply`` goes through ``_unshard_params_recurse`` which
    syncs params across ranks; use this instead when *fn* does not need
    synchronized parameters.
    """
    for submodule in module.children():
        fn(submodule)
        no_sync_module_apply(submodule, fn)
vlmo/tokenizer/__init__.py ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ # coding: utf-8
2
+ # Copyright (c) Antfin, Inc. All rights reserved.
3
+
4
+ from __future__ import absolute_import
5
+ from __future__ import division
6
+ from __future__ import print_function
vlmo/tokenizer/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (369 Bytes). View file
 
vlmo/tokenizer/__pycache__/tokenization_glm.cpython-311.pyc ADDED
Binary file (24.2 kB). View file
 
vlmo/tokenizer/sp.model ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b7fe3bcc8d284fcb782691411e8b6fd4f45d7245565b094de6ab795e66bcd32f
3
+ size 2270960
vlmo/tokenizer/tokenization_glm.py ADDED
@@ -0,0 +1,307 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from shutil import copyfile
3
+ from typing import Optional, Tuple, List, Union
4
+
5
+ import sentencepiece as spm
6
+ import torch
7
+ from transformers import PreTrainedTokenizer
8
+ from transformers.models.auto.tokenization_auto import get_tokenizer_config
9
+ from transformers.tokenization_utils_base import BatchEncoding
10
+ from transformers.utils import logging
11
+
12
+ logger = logging.get_logger(__name__)
13
+
14
+
15
class GLMBatchEncoding(BatchEncoding):
    """``BatchEncoding`` whose ``to`` moves tensors without up/down-casting them."""

    def to(self, device: Union[str, "torch.device"]) -> "BatchEncoding":
        """
        Send all values to device by calling `v.to(device)` (PyTorch only).

        Args:
            device (`str`, `int` or `torch.device`): The device to put the tensors on.

        Returns:
            [`BatchEncoding`]: The same instance after modification.
        """

        # This check catches things like APEX blindly calling "to" on all inputs to a module
        # Otherwise it passes the casts down and casts the LongTensor containing the token idxs
        # into a HalfTensor
        # BUG FIX: torch.device instances previously fell through to the warning
        # branch (the upstream `_is_torch_device` helper call was commented
        # out), so `batch.to(torch.device("cuda"))` silently did nothing.
        if isinstance(device, (str, int, torch.device)):
            self.data = {k: v.to(device=device) if torch.is_tensor(v) else v for k, v in self.data.items()}
        else:
            logger.warning(f"Attempting to cast a BatchEncoding to type {str(device)}. This is not supported.")
        return self
36
+
37
+
38
class GLMTokenizerMixin:
    """Shared GLM tokenizer helpers.

    Adds the GLM special-token properties and builders that turn plain
    encodings into GLM blank-infilling inputs: 2-D position ids
    (absolute position + in-block position) and block-diagonal attention
    masks, for both multiple-choice scoring and free generation.
    """

    @property
    def sop_token(self) -> Optional[str]:
        # start-of-piece: opens a generated blank
        return "<|startofpiece|>"

    @property
    def sop_token_id(self) -> Optional[int]:
        """
        `Optional[int]`: Id of the start token in the vocabulary, used when training a model with autoregressive blank filling.
        """
        return self.convert_tokens_to_ids(self.sop_token)

    @property
    def eop_token(self) -> Optional[str]:
        # end-of-piece: closes a generated blank
        return "<|endofpiece|>"

    @property
    def eop_token_id(self) -> Optional[int]:
        """
        `Optional[int]`: Id of the end token in the vocabulary, used when training a model with autoregressive blank filling.
        """
        return self.convert_tokens_to_ids(self.eop_token)

    @property
    def gmask_token_id(self) -> int:
        # [gMASK]: generation mask (long completion)
        return self.convert_tokens_to_ids("[gMASK]")

    @property
    def smask_token_id(self) -> int:
        # [sMASK]: sentence-level mask
        return self.convert_tokens_to_ids("[sMASK]")

    @property
    def mask_token_ids(self):
        # every id that may act as a blank to fill
        return [self.mask_token_id, self.smask_token_id, self.gmask_token_id]

    def _build_input_for_multiple_choice(self, context, choices):
        """Append each candidate answer after the [MASK]-bearing context.

        Builds one long sequence ``context + [sop] choice_1[:-1] + [sop] choice_2[:-1] ...``
        with a block-diagonal attention mask (each choice sees the full
        context but not the other choices), absolute positions frozen at the
        mask position for choice tokens, and in-block positions counting from
        1. Returns the ids/index ranges needed to read out per-choice logits.
        """
        context_id = context["input_ids"]
        if torch.is_tensor(context_id):
            context_id = context_id.tolist()

        division = len(context_id)  # boundary between context and choices
        mask_position = context_id.index(self.mask_token_id)

        token = torch.tensor(context_id, dtype=torch.long)
        attention_mask = [context["attention_mask"].expand(division, -1)]
        position_id = torch.arange(division, dtype=torch.long)
        block_position_id = torch.zeros(division, dtype=torch.long)

        choice_ids, choice_indices = [], []

        for choice_str in choices:
            choice = torch.tensor(self(choice_str, add_special_tokens=False, padding=False)['input_ids'],
                                  dtype=torch.long)
            choice_ids.append(choice)
            choice_indices.append(torch.arange(len(token), len(token) + len(choice), dtype=torch.long))
            # causal mask within the choice block
            attention_mask.append(torch.tril(torch.ones((len(choice), len(choice)), dtype=torch.long)))

            # shift-by-one: feed [sop] + choice[:-1], score against choice
            token = torch.cat((token, torch.tensor([self.sop_token_id], dtype=torch.long), choice[:-1]))
            position_id = torch.cat((position_id, torch.tensor([mask_position] * len(choice), dtype=torch.long)))
            block_position_id = torch.cat((block_position_id, torch.arange(1, 1 + len(choice), dtype=torch.long)))

        attention_mask = torch.block_diag(*attention_mask)
        # every choice token may attend to the (unpadded) context
        attention_mask[division:, :division] = context["attention_mask"].unsqueeze(0)

        return {
            "input_ids": token,
            "position_ids": torch.stack((position_id, block_position_id)),
            "attention_mask": attention_mask,
            "choice_ids": choice_ids,
            "choice_indices": choice_indices
        }

    def _pad_batch(self, tokens, position_ids, attention_mask, max_seq_length):
        """Right-pad one sample to *max_seq_length*: zeros for tokens and mask,
        repeat of the last column for position ids."""
        pad_length = max_seq_length - len(tokens)
        attention_mask = torch.nn.functional.pad(
            attention_mask,
            (0, pad_length, 0, pad_length),
            mode="constant",
            value=0,
        )
        tokens = torch.cat((tokens, torch.zeros(pad_length, dtype=torch.long)))
        position_ids = torch.cat((position_ids, position_ids[..., -1:].expand(-1, pad_length)), dim=-1)
        return tokens, position_ids, attention_mask

    def _collate(self, samples):
        """Pad all samples to a common length (rounded up to TILE) and stack
        them into batch tensors; choice ids/indices stay as ragged lists."""
        TILE = 1
        length_to_pad = (max(map(lambda spl: len(spl["input_ids"]), samples)) + TILE - 1) // TILE * TILE

        token_batch, position_id_batch, attention_mask_batch = [], [], []
        choices_batch, choice_target_ids_batch = [], []

        for sample in samples:
            token, position_id, attention_mask = self._pad_batch(
                sample["input_ids"], sample["position_ids"], sample["attention_mask"], length_to_pad
            )
            token_batch.append(token)
            position_id_batch.append(position_id)
            attention_mask_batch.append(attention_mask)
            choices_batch.append(sample["choice_ids"])
            choice_target_ids_batch.append(sample["choice_indices"])
        return {
            "input_ids": torch.stack(token_batch),
            "position_ids": torch.stack(position_id_batch),
            "attention_mask": torch.stack(attention_mask_batch).unsqueeze(1),
            "choice_ids": choices_batch,
            "choice_indices": choice_target_ids_batch,
        }

    def build_inputs_for_multiple_choice(self, model_input: BatchEncoding, choices, max_length=None):
        """Turn a batch encoding plus per-sample choice lists into GLM
        multiple-choice inputs.

        NOTE(review): *max_length* is accepted but never used — confirm
        whether truncation was intended here.
        """
        samples = [{key: value[i] for key, value in model_input.items()} for i in range(len(model_input["input_ids"]))]
        samples = [self._build_input_for_multiple_choice(sample, choice) for sample, choice in
                   zip(samples, choices)]
        inputs = self._collate(samples)
        return GLMBatchEncoding(inputs)

    def build_inputs_for_generation(self, model_input: BatchEncoding, max_gen_length=512, targets=None, padding=False):
        """Extend an encoded prompt with the generation region after its first
        mask token.

        Appends ``[sop]`` (or ``[sop] + targets`` for teacher forcing), builds
        2-D position ids whose absolute component is frozen at the mask
        position, and a mask letting generated tokens attend to the prompt
        plus causally to each other. With *targets*, also returns ``labels``
        (-100 on the prompt/padding). Raises ValueError when no
        [MASK]/[sMASK]/[gMASK] token is present in a sample.
        """
        mask_ids = self.mask_token_ids
        input_ids = model_input.input_ids
        batch_size, seq_length = input_ids.shape[:2]
        position_id, block_position_id = list(range(seq_length)), [0 for _ in range(seq_length)]
        position_ids, block_position_ids = [], []
        labels = None
        if targets is not None:
            is_batched = isinstance(targets, (list, tuple))
            targets = self(targets, add_special_tokens=False, padding=False).input_ids
            if not is_batched:
                targets = [targets]
            assert len(targets) == len(input_ids)
            targets = [(target + [self.eop_token_id])[:max_gen_length] for target in targets]
            if not padding:
                max_gen_length = max(map(len, targets))
            targets = [[self.sop_token_id] + target for target in targets]
            # labels are targets shifted left by one (next-token prediction)
            labels = [target[1:] for target in targets]
            targets = [target + [self.pad_token_id] * (max_gen_length + 1 - len(target)) for target in targets]
            labels = [label + [-100] * (max_gen_length - len(label)) for label in labels]
            targets = torch.tensor(targets, dtype=input_ids.dtype, device=input_ids.device)
            labels = torch.tensor(labels, dtype=input_ids.dtype, device=input_ids.device)
            labels = torch.cat((input_ids.new_full((batch_size, seq_length), -100), labels), dim=1)
        for i in range(batch_size):
            mask_positions = []
            for mask_id in mask_ids:
                mask_positions += (input_ids[i] == mask_id).nonzero(as_tuple=True)[0].tolist()
            if not mask_positions:
                raise ValueError("Cannot find mask token in the input")
            mask_positions.sort()
            mask_pos = mask_positions[0]  # generate at the first mask
            position_ids.append(position_id + [mask_pos] * max_gen_length)
            block_position_ids.append(block_position_id + list(range(1, max_gen_length + 1)))
        position_ids = torch.tensor(position_ids, dtype=input_ids.dtype, device=input_ids.device)
        block_position_ids = torch.tensor(block_position_ids, dtype=input_ids.dtype, device=input_ids.device)
        position_ids = torch.stack((position_ids, block_position_ids), dim=1)
        attention_mask = model_input.attention_mask
        attention_mask = attention_mask.unsqueeze(1).expand(-1, seq_length + max_gen_length, -1)
        # prompt tokens cannot see generated ones; generated tokens are causal
        generation_attention_mask = torch.cat([attention_mask.new_zeros((seq_length, max_gen_length)),
                                               torch.tril(attention_mask.new_ones((max_gen_length, max_gen_length)))],
                                              dim=0).unsqueeze(0).expand(batch_size, -1, -1)
        attention_mask = torch.cat((attention_mask, generation_attention_mask), dim=2)
        attention_mask = attention_mask.unsqueeze(1)
        if targets is None:
            input_ids = torch.cat((input_ids, input_ids.new_full((batch_size, 1), self.sop_token_id)), dim=-1)
        else:
            input_ids = torch.cat((input_ids, targets[:, :-1]), dim=1)
        batch = {"input_ids": input_ids, "position_ids": position_ids}
        if labels is None:
            batch["generation_attention_mask"] = attention_mask
        else:
            batch["attention_mask"] = attention_mask
            batch["labels"] = labels
        return BatchEncoding(batch)
207
+
208
def encode_whitespaces(content):
    """Replace runs of 2-10 consecutive spaces with ``<|blank_k|>`` tokens.

    Longer runs are handled first, so e.g. 12 spaces become
    ``<|blank_10|><|blank_2|>``. Single spaces are left untouched.
    """
    for width in reversed(range(2, 11)):
        content = content.replace(" " * width, f"<|blank_{width}|>")
    return content
212
+
213
def decode_whitespaces(content):
    """Inverse of ``encode_whitespaces``: expand ``<|blank_k|>`` back into k spaces."""
    for width in reversed(range(2, 11)):
        content = content.replace(f"<|blank_{width}|>", " " * width)
    return content
217
+
218
+
219
class GLMChineseTokenizer(PreTrainedTokenizer, GLMTokenizerMixin):
    """SentencePiece-based Chinese GLM tokenizer (vocab file ``sp.model``).

    Runs of 2-10 spaces are mapped to ``<|blank_k|>`` sentinel tokens before
    piece encoding and expanded back on decode.
    """

    vocab_files_names = {"vocab_file": "sp.model"}
    # keep the rightmost tokens when truncating (GLM conditions on the suffix)
    truncation_side: str = "left"

    def __init__(self, vocab_file, **kwargs):
        self.vocab_file = vocab_file
        self.sp_model = spm.SentencePieceProcessor()
        self.sp_model.Load(vocab_file)
        super().__init__(**kwargs)

    @property
    def vocab_size(self):
        # size of the underlying SentencePiece model (excludes added tokens)
        return len(self.sp_model)

    def get_vocab(self):
        """Return the full token -> id map, including added tokens."""
        vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)}
        vocab.update(self.added_tokens_encoder)
        return vocab

    def _tokenize(self, text, **kwargs):
        # encode whitespace runs as sentinel tokens before piece splitting
        text = encode_whitespaces(text)
        return self.sp_model.EncodeAsPieces(text)
        #return self.sp_model.EncodeAsPieces(text, out_type=str)

    def _convert_token_to_id(self, token):
        """Converts a token (str) in an id using the vocab."""
        return self.sp_model.PieceToId(token)

    def _convert_id_to_token(self, index):
        """Converts an index (integer) in a token (str) using the vocab."""
        return self.sp_model.IdToPiece(index)

    def convert_tokens_to_string(self, tokens):
        # NOTE(review): DecodeIds is called with string pieces; recent
        # sentencepiece dispatches on input type, but DecodePieces would be
        # the explicit call — confirm against the pinned sentencepiece version
        res = self.sp_model.DecodeIds(tokens)
        return decode_whitespaces(res)

    def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
        """Copy (or serialize) the SentencePiece model into *save_directory*.

        NOTE(review): returns None (not a tuple) when *save_directory* is not
        a directory — callers that unpack the result would fail.
        """
        if not os.path.isdir(save_directory):
            logger.error(f"Vocabulary path ({save_directory}) should be a directory")
            return
        out_vocab_file = os.path.join(
            save_directory, (filename_prefix + "-" if filename_prefix else "") + self.vocab_files_names["vocab_file"]
        )

        if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file) and os.path.isfile(self.vocab_file):
            copyfile(self.vocab_file, out_vocab_file)
        elif not os.path.isfile(self.vocab_file):
            # source file is gone: re-serialize the in-memory model instead
            with open(out_vocab_file, "wb") as fi:
                content_spiece_model = self.sp_model.serialized_model_proto()
                fi.write(content_spiece_model)

        return (out_vocab_file,)

    def build_inputs_with_special_tokens(
            self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
    ) -> List[int]:
        """
        Build model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating and
        adding special tokens. A BERT sequence has the following format:

        - single sequence: ``[CLS] X [SEP]``
        - pair of sequences: ``[CLS] A [SEP] B [SEP]``

        Args:
            token_ids_0 (:obj:`List[int]`):
                List of IDs to which the special tokens will be added.
            token_ids_1 (:obj:`List[int]`, `optional`):
                Optional second list of IDs for sequence pairs.

        Returns:
            :obj:`List[int]`: List of `input IDs <../glossary.html#input-ids>`__ with the appropriate special tokens.
        """
        # only single sequences are supported by this tokenizer
        assert token_ids_1 is None
        cls = [self.cls_token_id]
        eos = [self.eos_token_id]
        return cls + token_ids_0 + eos
295
+
296
+
297
class GLMTokenizer:
    """Factory that dispatches ``from_pretrained`` to the concrete GLM tokenizer
    class declared in the remote ``tokenizer_config.json``."""

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *inputs, **kwargs):
        """Load whichever tokenizer class the config names (only
        ``GLMChineseTokenizer`` is supported)."""
        tokenizer_config = get_tokenizer_config(pretrained_model_name_or_path, **kwargs)
        config_tokenizer_class = tokenizer_config.get("tokenizer_class")
        if config_tokenizer_class != "GLMChineseTokenizer":
            raise NotImplementedError("Not implemented tokenizer type:", config_tokenizer_class)
        return GLMChineseTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
vlmo/tokenizer/tokenizer_config.json ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "name_or_path": "THUDM/glm-10b-chinese",
3
+ "eos_token": "<|endoftext|>",
4
+ "pad_token": "<|endoftext|>",
5
+ "cls_token": "[CLS]",
6
+ "mask_token": "[MASK]",
7
+ "unk_token": "[UNK]",
8
+ "add_prefix_space": false,
9
+ "tokenizer_class": "GLMChineseTokenizer",
10
+ "use_fast": false,
11
+ "auto_map": {
12
+ "AutoTokenizer": [
13
+ "tokenization_glm.GLMChineseTokenizer",
14
+ null
15
+ ]
16
+ }
17
+ }
vlmo/torchscale/__init__.py ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ # Copyright (c) 2022 Microsoft
2
+ # Licensed under The MIT License [see LICENSE for details]
vlmo/torchscale/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (204 Bytes). View file
 
vlmo/torchscale/architecture/__init__.py ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ # Copyright (c) 2022 Microsoft
2
+ # Licensed under The MIT License [see LICENSE for details]
vlmo/torchscale/architecture/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (217 Bytes). View file
 
vlmo/torchscale/architecture/__pycache__/config.cpython-311.pyc ADDED
Binary file (14.8 kB). View file
 
vlmo/torchscale/architecture/__pycache__/encoder.cpython-311.pyc ADDED
Binary file (22.8 kB). View file
 
vlmo/torchscale/architecture/__pycache__/utils.cpython-311.pyc ADDED
Binary file (2.58 kB). View file
 
vlmo/torchscale/architecture/config.py ADDED
@@ -0,0 +1,197 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2022 Microsoft
2
+ # Licensed under The MIT License [see LICENSE for details]
3
+
4
+
5
class EncoderConfig(object):
    """Hyper-parameter container for a torchscale Encoder.

    Recognised keyword arguments override the defaults below; unknown
    keywords are silently ignored.  After assignment the ``deepnorm`` /
    ``subln`` / ``use_xmoe`` flags are reconciled (``deepnorm`` disables
    ``subln``; ``subln`` forces pre-norm and disables ``deepnorm``).
    """

    def __init__(self, **kwargs):
        # Table-driven defaults, consumed via kwargs.pop just like the
        # hand-written original.
        defaults = [
            ("encoder_embed_dim", 768),
            ("encoder_attention_heads", 12),
            ("encoder_ffn_embed_dim", 3072),
            ("encoder_layers", 12),
            ("encoder_normalize_before", True),
            ("normalize_output", True),
            ("activation_fn", "gelu"),
            ("dropout", 0.0),
            ("drop_path_rate", 0.0),
            ("attention_dropout", 0.0),
            ("activation_dropout", 0.0),
            ("no_scale_embedding", True),
            ("layernorm_embedding", False),
            ("moe_freq", 0),
            ("moe_top1_expert", False),
            ("moe_expert_count", 0),
            ("moe_gating_use_fp32", True),
            ("moe_eval_capacity_token_fraction", 0.25),
            ("moe_second_expert_policy", "random"),
            ("moe_normalize_gate_prob_before_dropping", False),
            ("use_xmoe", False),
            ("rel_pos_buckets", 0),
            ("max_rel_pos", 0),
            ("deepnorm", False),
            ("subln", True),
            ("bert_init", False),
            ("multiway", False),
            ("share_encoder_input_output_embed", False),
            ("max_source_positions", 1024),
            ("no_output_layer", False),
            ("layernorm_eps", 1e-5),
            ("share_layer", False),
            ("share_attn", False),
            ("mask_ratio", 0),
            ("max_text_len", 52),
            ("one_attn", False),
            # Text
            ("vocab_size", -1),
            # Vision
            ("img_size", 224),
            ("patch_size", 16),
            ("in_chans", 3),
            # Fairscale
            ("checkpoint_activations", False),
            ("fsdp", False),
            ("ddp_rank", 0),
            ("xpos_rel_pos", False),
            ("xpos_scale_base", 512),
        ]
        for name, default in defaults:
            setattr(self, name, kwargs.pop(name, default))

        # DeepNorm and SubLN are mutually exclusive; deepnorm takes priority.
        if self.deepnorm:
            self.encoder_normalize_before = False
            self.subln = False
        if self.subln:
            self.encoder_normalize_before = True
            self.deepnorm = False
        if self.use_xmoe:
            self.moe_normalize_gate_prob_before_dropping = True
            self.moe_second_expert_policy = "random"
            assert self.moe_freq > 0 and self.moe_expert_count > 0

    def override(self, args):
        """Copy every non-None attribute of *args* that matches one of ours."""
        for name in self.__dict__.keys():
            if getattr(args, name, None) is not None:
                self.__dict__[name] = getattr(args, name, None)
+
74
+
75
class DecoderConfig(object):
    """Hyper-parameter container for a torchscale Decoder.

    Recognised keyword arguments override the defaults below; unknown
    keywords are silently ignored.  ``deepnorm`` / ``subln`` / ``use_xmoe``
    are reconciled after assignment (``deepnorm`` wins over ``subln``).
    """

    def __init__(self, **kwargs):
        defaults = [
            ("decoder_embed_dim", 768),
            ("decoder_attention_heads", 12),
            ("decoder_ffn_embed_dim", 3072),
            ("decoder_layers", 12),
            ("decoder_normalize_before", True),
            ("activation_fn", "gelu"),
            ("dropout", 0.0),
            ("drop_path_rate", 0.0),
            ("attention_dropout", 0.0),
            ("activation_dropout", 0.0),
            ("no_scale_embedding", True),
            ("layernorm_embedding", False),
            ("moe_freq", 0),
            ("moe_top1_expert", False),
            ("moe_expert_count", 0),
            ("moe_gating_use_fp32", True),
            ("moe_eval_capacity_token_fraction", 0.25),
            ("moe_second_expert_policy", "random"),
            ("moe_normalize_gate_prob_before_dropping", False),
            ("use_xmoe", False),
            ("rel_pos_buckets", 0),
            ("max_rel_pos", 0),
            ("deepnorm", False),
            ("subln", True),
            ("bert_init", False),
            ("multiway", False),
            ("share_decoder_input_output_embed", False),
            ("max_target_positions", 1024),
            ("no_output_layer", False),
            ("layernorm_eps", 1e-5),
            # Text
            ("vocab_size", -1),
            # Fairscale
            ("checkpoint_activations", False),
            ("fsdp", False),
            ("ddp_rank", 0),
            ("xpos_rel_pos", False),
            ("xpos_scale_base", 512),
        ]
        for name, default in defaults:
            setattr(self, name, kwargs.pop(name, default))

        # DeepNorm and SubLN are mutually exclusive; deepnorm takes priority.
        if self.deepnorm:
            self.decoder_normalize_before = False
            self.subln = False
        if self.subln:
            self.decoder_normalize_before = True
            self.deepnorm = False
        if self.use_xmoe:
            self.moe_normalize_gate_prob_before_dropping = True
            self.moe_second_expert_policy = "random"
            assert self.moe_freq > 0 and self.moe_expert_count > 0

    def override(self, args):
        """Copy every non-None attribute of *args* that matches one of ours."""
        for name in self.__dict__.keys():
            if getattr(args, name, None) is not None:
                self.__dict__[name] = getattr(args, name, None)
+ self.__dict__[hp] = getattr(args, hp, None)
131
+
132
+
133
class EncoderDecoderConfig(object):
    """Joint hyper-parameter container for an encoder-decoder model.

    Recognised keyword arguments override the defaults below; unknown
    keywords are silently ignored.  ``deepnorm`` / ``subln`` / ``use_xmoe``
    are reconciled after assignment, on both the encoder and decoder
    normalisation flags.
    """

    def __init__(self, **kwargs):
        defaults = [
            ("encoder_embed_dim", 768),
            ("encoder_attention_heads", 12),
            ("encoder_ffn_embed_dim", 3072),
            ("encoder_layers", 12),
            ("encoder_normalize_before", True),
            ("decoder_embed_dim", 768),
            ("decoder_attention_heads", 12),
            ("decoder_ffn_embed_dim", 3072),
            ("decoder_layers", 12),
            ("decoder_normalize_before", True),
            ("activation_fn", "gelu"),
            ("dropout", 0.0),
            ("drop_path_rate", 0.0),
            ("attention_dropout", 0.0),
            ("activation_dropout", 0.0),
            ("no_scale_embedding", True),
            ("layernorm_embedding", False),
            ("moe_freq", 0),
            ("moe_top1_expert", False),
            ("moe_expert_count", 0),
            ("moe_gating_use_fp32", True),
            ("moe_eval_capacity_token_fraction", 0.25),
            ("moe_second_expert_policy", "random"),
            ("moe_normalize_gate_prob_before_dropping", False),
            ("use_xmoe", False),
            ("rel_pos_buckets", 0),
            ("max_rel_pos", 0),
            ("deepnorm", False),
            ("subln", True),
            ("bert_init", False),
            ("multiway", False),
            ("share_all_embeddings", False),
            ("share_decoder_input_output_embed", False),
            ("max_source_positions", 1024),
            ("max_target_positions", 1024),
            ("no_output_layer", False),
            ("layernorm_eps", 1e-5),
            # Text
            ("vocab_size", -1),
            # Fairscale
            ("checkpoint_activations", False),
            ("fsdp", False),
            ("ddp_rank", 0),
            ("xpos_rel_pos", False),
            ("xpos_scale_base", 512),
        ]
        for name, default in defaults:
            setattr(self, name, kwargs.pop(name, default))

        # DeepNorm and SubLN are mutually exclusive; deepnorm takes priority.
        if self.deepnorm:
            self.encoder_normalize_before = False
            self.decoder_normalize_before = False
            self.subln = False
        if self.subln:
            self.encoder_normalize_before = True
            self.decoder_normalize_before = True
            self.deepnorm = False
        if self.use_xmoe:
            self.moe_normalize_gate_prob_before_dropping = True
            self.moe_second_expert_policy = "random"
            assert self.moe_freq > 0 and self.moe_expert_count > 0

    def override(self, args):
        """Copy every non-None attribute of *args* that matches one of ours."""
        for name in self.__dict__.keys():
            if getattr(args, name, None) is not None:
                self.__dict__[name] = getattr(args, name, None)
+ self.__dict__[hp] = getattr(args, hp, None)
vlmo/torchscale/architecture/decoder.py ADDED
@@ -0,0 +1,428 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2022 Microsoft
2
+ # Licensed under The MIT License [see LICENSE for details]
3
+
4
+ import math
5
+
6
+ import numpy as np
7
+ import torch
8
+ import torch.nn as nn
9
+ from fairscale.nn import checkpoint_wrapper, wrap
10
+
11
+ from vlmo.torchscale.architecture.utils import init_bert_params
12
+ from vlmo.torchscale.component.droppath import DropPath
13
+ from vlmo.torchscale.component.feedforward_network import FeedForwardNetwork, make_experts
14
+ from vlmo.torchscale.component.multihead_attention import MultiheadAttention
15
+ from vlmo.torchscale.component.relative_position_bias import RelativePositionBias
16
+ from vlmo.torchscale.component.xmoe.moe_layer import MOELayer
17
+ from vlmo.torchscale.component.xmoe.routing import Top1Gate, Top2Gate
18
+
19
+ try:
20
+ from apex.normalization import FusedLayerNorm as LayerNorm
21
+ except ModuleNotFoundError:
22
+ from torch.nn import LayerNorm
23
+
24
+
25
class DecoderLayer(nn.Module):
    """One transformer decoder layer.

    Composed of (1) self-attention, (2) optional encoder-decoder
    cross-attention (built only when ``is_encoder_decoder``), and (3) a dense
    FFN or a Mixture-of-Experts layer.  Each sub-layer uses a residual
    connection with pre- or post-LayerNorm selected by
    ``args.decoder_normalize_before``; under ``args.deepnorm`` the residual
    branch is scaled by ``self.alpha`` (DeepNet-style).
    """

    def __init__(
        self,
        args,
        depth,
        is_moe_layer=False,
        is_encoder_decoder=False,
    ):
        # args: DecoderConfig-like namespace.
        # depth: 0-based layer index, used to pick this layer's DropPath rate.
        super().__init__()
        self.args = args
        self.embed_dim = args.decoder_embed_dim
        self.dropout_module = torch.nn.Dropout(args.dropout)

        if args.drop_path_rate > 0:
            # Stochastic depth: drop probability grows linearly with depth.
            drop_path_prob = np.linspace(0, args.drop_path_rate, args.decoder_layers)[depth]
            self.drop_path = DropPath(drop_path_prob)
        else:
            self.drop_path = None

        self.self_attn = self.build_self_attention(self.embed_dim, args)

        self.normalize_before = args.decoder_normalize_before

        self.self_attn_layer_norm = LayerNorm(self.embed_dim, eps=args.layernorm_eps)

        if not is_encoder_decoder:
            # Decoder-only model: no cross-attention sub-layer.
            self.encoder_attn = None
            self.encoder_attn_layer_norm = None
        else:
            self.encoder_attn = self.build_encoder_attention(self.embed_dim, args)
            self.encoder_attn_layer_norm = LayerNorm(self.embed_dim, eps=args.layernorm_eps)

        self.is_moe_layer = is_moe_layer
        self.ffn_dim = args.decoder_ffn_embed_dim

        if not self.is_moe_layer:
            self.ffn = self.build_ffn(
                self.embed_dim,
                self.args,
            )
        else:
            # MoE layer: top-1 or top-2 expert routing gate.
            if args.moe_top1_expert:
                gate = Top1Gate(
                    self.embed_dim,
                    args.moe_expert_count,
                    use_fp32=args.moe_gating_use_fp32,
                    moe_eval_capacity_token_fraction=args.moe_eval_capacity_token_fraction,
                    use_xmoe=args.use_xmoe,
                )
            else:
                gate = Top2Gate(
                    self.embed_dim,
                    args.moe_expert_count,
                    args.moe_gating_use_fp32,
                    args.moe_second_expert_policy,
                    args.moe_normalize_gate_prob_before_dropping,
                    args.moe_eval_capacity_token_fraction,
                    use_xmoe=args.use_xmoe,
                )
            experts = make_experts(args, self.embed_dim, self.ffn_dim)
            self.moe_layer = MOELayer(gate, experts, args)

        self.final_layer_norm = LayerNorm(self.embed_dim, eps=args.layernorm_eps)

        if args.deepnorm:
            # DeepNet residual scaling; the constant differs for encoder-decoder
            # vs decoder-only stacks.
            if is_encoder_decoder:
                self.alpha = math.pow(3.0 * args.decoder_layers, 0.25)
            else:
                self.alpha = math.pow(2.0 * args.decoder_layers, 0.25)
        else:
            self.alpha = 1.0

    def build_ffn(self, embed_dim, args):
        """Build the dense feed-forward sub-layer."""
        return FeedForwardNetwork(
            embed_dim,
            self.ffn_dim,
            args.activation_fn,
            args.dropout,
            args.activation_dropout,
            args.layernorm_eps,
            args.subln,
        )

    def build_self_attention(self, embed_dim, args):
        """Build masked self-attention over decoder states."""
        return MultiheadAttention(
            args,
            embed_dim,
            args.decoder_attention_heads,
            dropout=args.attention_dropout,
            self_attention=True,
            encoder_decoder_attention=False,
            subln=args.subln,
        )

    def build_encoder_attention(self, embed_dim, args):
        """Build cross-attention from decoder queries to encoder outputs."""
        return MultiheadAttention(
            args,
            embed_dim,
            args.decoder_attention_heads,
            dropout=args.attention_dropout,
            self_attention=False,
            encoder_decoder_attention=True,
            subln=args.subln,
        )

    def residual_connection(self, x, residual):
        # DeepNorm-scaled residual (alpha == 1.0 without deepnorm).
        return residual * self.alpha + x

    def forward(
        self,
        x,
        encoder_out=None,
        encoder_padding_mask=None,
        incremental_state=None,
        self_attn_mask=None,
        self_attn_padding_mask=None,
        self_attn_rel_pos=None,
        cross_attn_rel_pos=None,
    ):
        """Run the layer.

        Returns a 4-tuple ``(x, attn, None, l_aux)`` where ``attn`` is the
        last computed attention and ``l_aux`` is the MoE auxiliary loss
        (``None`` for dense layers).
        """
        # --- self-attention sub-layer ---
        residual = x
        if self.normalize_before:
            x = self.self_attn_layer_norm(x)

        x, attn = self.self_attn(
            query=x,
            key=x,
            value=x,
            key_padding_mask=self_attn_padding_mask,
            incremental_state=incremental_state,
            attn_mask=self_attn_mask,
            rel_pos=self_attn_rel_pos,
        )
        x = self.dropout_module(x)

        if self.drop_path is not None:
            x = self.drop_path(x)

        x = self.residual_connection(x, residual)
        if not self.normalize_before:
            x = self.self_attn_layer_norm(x)

        # --- cross-attention sub-layer (only when encoder output is given) ---
        if self.encoder_attn is not None and encoder_out is not None:
            residual = x
            if self.normalize_before:
                x = self.encoder_attn_layer_norm(x)

            x, attn = self.encoder_attn(
                query=x,
                key=encoder_out,
                value=encoder_out,
                key_padding_mask=encoder_padding_mask,
                incremental_state=None,
                rel_pos=cross_attn_rel_pos,
            )
            x = self.dropout_module(x)

            if self.drop_path is not None:
                x = self.drop_path(x)

            x = self.residual_connection(x, residual)
            if not self.normalize_before:
                x = self.encoder_attn_layer_norm(x)

        # --- FFN / MoE sub-layer ---
        residual = x
        if self.normalize_before:
            x = self.final_layer_norm(x)
        if not self.is_moe_layer:
            x = self.ffn(x)
            l_aux = None
        else:
            x, l_aux = self.moe_layer(x)

        if self.drop_path is not None:
            x = self.drop_path(x)

        x = self.residual_connection(x, residual)
        if not self.normalize_before:
            x = self.final_layer_norm(x)

        return x, attn, None, l_aux
205
+
206
+
207
class Decoder(nn.Module):
    """Stack of :class:`DecoderLayer` with embedding, positional encoding,
    optional relative-position bias, and an output projection.

    Supports incremental (step-wise) decoding via ``incremental_state`` and
    MoE layers every ``args.moe_freq`` layers.
    """

    def __init__(
        self, args, embed_tokens=None, embed_positions=None, output_projection=None, is_encoder_decoder=False, **kwargs
    ):
        super().__init__(**kwargs)
        self.args = args

        self.dropout_module = torch.nn.Dropout(args.dropout)

        embed_dim = args.decoder_embed_dim
        self.embed_dim = embed_dim
        # sqrt(d) embedding scale unless explicitly disabled.
        self.embed_scale = 1.0 if args.no_scale_embedding else math.sqrt(embed_dim)

        self.embed_tokens = embed_tokens
        self.embed_positions = embed_positions

        # Build a fresh output projection only when one wasn't supplied and a
        # vocabulary is configured.
        if output_projection is None and not args.no_output_layer and args.vocab_size > 0:
            self.output_projection = self.build_output_projection(args)
        else:
            self.output_projection = output_projection

        if args.layernorm_embedding:
            self.layernorm_embedding = LayerNorm(embed_dim, eps=args.layernorm_eps)
        else:
            self.layernorm_embedding = None

        self.layers = nn.ModuleList([])

        moe_freq = args.moe_freq
        for i in range(args.decoder_layers):
            # Every moe_freq-th layer (1-based) becomes a MoE layer.
            is_moe_layer = moe_freq != 0 and (i + 1) % moe_freq == 0
            self.layers.append(
                self.build_decoder_layer(
                    args,
                    depth=i,
                    is_moe_layer=is_moe_layer,
                    is_encoder_decoder=is_encoder_decoder,
                )
            )

        self.num_layers = len(self.layers)

        if args.decoder_normalize_before:
            self.layer_norm = LayerNorm(embed_dim, eps=args.layernorm_eps)
        else:
            self.layer_norm = None

        self.self_attn_relative_position = None
        self.cross_attn_relative_position = None

        if args.rel_pos_buckets > 0 and args.max_rel_pos > 0:
            self.self_attn_relative_position = RelativePositionBias(
                num_buckets=args.rel_pos_buckets,
                max_distance=args.max_rel_pos,
                n_heads=args.decoder_attention_heads,
            )
            if is_encoder_decoder:
                self.cross_attn_relative_position = RelativePositionBias(
                    num_buckets=args.rel_pos_buckets,
                    max_distance=args.max_rel_pos,
                    n_heads=args.decoder_attention_heads,
                )

        if args.bert_init:
            self.apply(init_bert_params)

        # DeepNorm: shrink fc1/fc2/out_proj/v_proj weights by the DeepNet
        # init scale after standard initialisation.
        if args.deepnorm:
            if is_encoder_decoder:
                init_scale = math.pow(12.0 * args.decoder_layers, 0.25)
            else:
                init_scale = math.pow(8.0 * args.decoder_layers, 0.25)
            for name, p in self.named_parameters():
                if "fc1" in name or "fc2" in name or "out_proj" in name or "v_proj" in name:
                    p.data.div_(init_scale)

        # SubLN: scale the same projections up (cross-attention excluded).
        if args.subln:
            if is_encoder_decoder:
                init_scale = math.sqrt(math.log(args.decoder_layers * 3))
            else:
                init_scale = math.sqrt(math.log(args.decoder_layers * 2))
            for name, p in self.named_parameters():
                if "encoder_attn" in name:
                    continue
                if "fc1" in name or "fc2" in name or "out_proj" in name or "v_proj" in name:
                    p.data.mul_(init_scale)

    def build_output_projection(
        self,
        args,
    ):
        """Build the vocab projection, optionally tied to the input embedding."""
        if args.share_decoder_input_output_embed:
            # Weight tying: projection shares the token-embedding matrix.
            output_projection = torch.nn.Linear(
                self.embed_tokens.weight.shape[1],
                self.embed_tokens.weight.shape[0],
                bias=False,
            )
            output_projection.weight = self.embed_tokens.weight
        else:
            output_projection = torch.nn.Linear(args.decoder_embed_dim, args.vocab_size, bias=False)
            torch.nn.init.normal_(output_projection.weight, mean=0, std=args.decoder_embed_dim**-0.5)
        return output_projection

    def build_decoder_layer(self, args, depth, is_moe_layer=False, is_encoder_decoder=False):
        """Construct one layer, with optional activation checkpointing / FSDP wrap."""
        layer = DecoderLayer(
            args,
            depth,
            is_moe_layer=is_moe_layer,
            is_encoder_decoder=is_encoder_decoder,
        )
        if args.checkpoint_activations:
            layer = checkpoint_wrapper(layer)
        if args.fsdp:
            layer = wrap(layer)
        return layer

    def forward_embedding(
        self,
        tokens,
        token_embedding=None,
        incremental_state=None,
    ):
        """Embed tokens (+ positions), returning (dropped-out x, raw embed).

        During incremental decoding only the last time-step is embedded.
        """
        positions = None
        if self.embed_positions is not None:
            positions = self.embed_positions(tokens, incremental_state=incremental_state)

        if incremental_state is not None:
            # Step-wise decoding: only the newest token is processed.
            tokens = tokens[:, -1:]
            if positions is not None:
                positions = positions[:, -1:]

        if token_embedding is None:
            token_embedding = self.embed_tokens(tokens)

        x = embed = self.embed_scale * token_embedding

        if positions is not None:
            x += positions

        if self.layernorm_embedding is not None:
            x = self.layernorm_embedding(x)

        x = self.dropout_module(x)

        return x, embed

    def forward(
        self,
        prev_output_tokens,
        self_attn_padding_mask=None,
        encoder_out=None,
        incremental_state=None,
        features_only=False,
        return_all_hiddens=False,
        token_embeddings=None,
        **kwargs
    ):
        """Decode ``prev_output_tokens``.

        Returns ``(x, extra)`` where ``extra`` holds per-layer hidden states
        ("inner_states"), MoE auxiliary losses ("l_aux") and "attn" (always
        None here).  When ``features_only`` is False, ``x`` is projected to
        vocab logits via :meth:`output_layer`.
        """
        # embed tokens and positions
        x, _ = self.forward_embedding(prev_output_tokens, token_embeddings, incremental_state)

        # relative position
        self_attn_rel_pos_bias = None
        slen = prev_output_tokens.size(1)
        if self.self_attn_relative_position is not None:
            self_attn_rel_pos_bias = self.self_attn_relative_position(batch_size=x.size(0), qlen=slen, klen=slen)
            if incremental_state is not None:
                # Only the bias row for the newest query position is needed.
                self_attn_rel_pos_bias = self_attn_rel_pos_bias[-1:, :, :]
        cross_attn_rel_pos_bias = None
        if self.cross_attn_relative_position is not None:
            cross_attn_rel_pos_bias = self.cross_attn_relative_position(
                batch_size=x.size(0),
                qlen=slen,
                klen=encoder_out["encoder_out"].size(1),
            )
            if incremental_state is not None:
                cross_attn_rel_pos_bias = cross_attn_rel_pos_bias[-1:, :, :]

        # decoder layers
        inner_states = [x]

        if encoder_out is None:
            l_aux = []
        else:
            l_aux = encoder_out["l_aux"] if "l_aux" in encoder_out else []

        for idx, layer in enumerate(self.layers):
            if incremental_state is None:
                # Full-sequence decoding: causal (upper-triangular -inf) mask.
                self_attn_mask = torch.triu(
                    torch.zeros([x.size(1), x.size(1)]).float().fill_(float("-inf")).type_as(x),
                    1,
                )
            else:
                # Incremental decoding: causality is implicit, no mask needed.
                self_attn_mask = None
                if idx not in incremental_state:
                    incremental_state[idx] = {}

            x, layer_attn, _, l_aux_i = layer(
                x,
                encoder_out["encoder_out"] if encoder_out is not None else None,
                encoder_out["encoder_padding_mask"] if encoder_out is not None else None,
                incremental_state[idx] if incremental_state is not None else None,
                self_attn_mask=self_attn_mask,
                self_attn_padding_mask=self_attn_padding_mask,
                self_attn_rel_pos=self_attn_rel_pos_bias,
                cross_attn_rel_pos=cross_attn_rel_pos_bias,
            )
            l_aux.append(l_aux_i)
            inner_states.append(x)

        if self.layer_norm is not None:
            x = self.layer_norm(x)

        if not features_only:
            # NOTE(review): assumes self.output_projection was built (i.e. not
            # None) — callers with no_output_layer must pass features_only=True.
            x = self.output_layer(x)

        return x, {
            "inner_states": inner_states,
            "l_aux": l_aux,
            "attn": None,
        }

    def output_layer(self, features):
        """Project hidden states to vocabulary logits."""
        return self.output_projection(features)
vlmo/torchscale/architecture/encoder.py ADDED
@@ -0,0 +1,482 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2022 Microsoft
2
+ # Licensed under The MIT License [see LICENSE for details]
3
+
4
+ import math
5
+
6
+ import numpy as np
7
+ import torch
8
+ import torch.nn as nn
9
+ from fairscale.nn import checkpoint_wrapper, wrap
10
+
11
+ try:
12
+ from apex.normalization import FusedLayerNorm as LayerNorm
13
+ except ModuleNotFoundError:
14
+ from torch.nn import LayerNorm
15
+
16
+ from vlmo.torchscale.architecture.utils import init_bert_params
17
+ from vlmo.torchscale.component.droppath import DropPath
18
+ from vlmo.torchscale.component.feedforward_network import FeedForwardNetwork, make_experts
19
+ from vlmo.torchscale.component.multihead_attention import MultiheadAttention
20
+ from vlmo.torchscale.component.multiway_network import MultiwayWrapper, set_split_position
21
+ from vlmo.torchscale.component.relative_position_bias import RelativePositionBias
22
+ from vlmo.torchscale.component.xmoe.moe_layer import MOELayer
23
+ from vlmo.torchscale.component.xmoe.routing import Top1Gate, Top2Gate
24
+ from vlmo.modules.vlmo_utils import no_sync_module_apply
25
+ from pytorch_lightning.utilities.distributed import rank_zero_info
26
+
27
+
28
class EncoderLayer(nn.Module):
    """One transformer encoder layer with Multiway (modality-split) support.

    Self-attention followed by a dense FFN or a Mixture-of-Experts layer,
    each with residual connections and pre-/post-LayerNorm controlled by
    ``args.encoder_normalize_before``.  LayerNorms and the FFN are wrapped
    in :class:`MultiwayWrapper` so text/vision tokens can use separate
    parameters when ``args.multiway`` is set.
    """

    def __init__(self, args, depth, attn=None, is_moe_layer=False, is_encoder_decoder=False):
        # attn: optionally a pre-built (shared) attention module; when None a
        # fresh one is constructed for this layer.
        super().__init__()
        self.args = args
        self.embed_dim = args.encoder_embed_dim
        self.self_attn = self.build_self_attention(self.embed_dim, args) if attn is None else attn
        self.self_attn_layer_norm = MultiwayWrapper(args, LayerNorm(self.embed_dim, eps=args.layernorm_eps))
        self.dropout_module = torch.nn.Dropout(args.dropout)

        if args.drop_path_rate > 0:
            # Stochastic depth: drop probability grows linearly with depth.
            drop_path_prob = np.linspace(0, args.drop_path_rate, args.encoder_layers)[depth]
            self.drop_path = DropPath(drop_path_prob)
        else:
            self.drop_path = None

        self.normalize_before = args.encoder_normalize_before
        self.is_moe_layer = is_moe_layer
        self.ffn_dim = args.encoder_ffn_embed_dim

        if not self.is_moe_layer:
            self.ffn = MultiwayWrapper(
                args,
                self.build_ffn(
                    self.embed_dim,
                    self.args,
                ),
            )
        else:
            # MoE layers are incompatible with multiway parameter splitting.
            assert not self.args.multiway
            if args.moe_top1_expert:
                gate = Top1Gate(
                    self.embed_dim,
                    args.moe_expert_count,
                    use_fp32=args.moe_gating_use_fp32,
                    moe_eval_capacity_token_fraction=args.moe_eval_capacity_token_fraction,
                    use_xmoe=args.use_xmoe,
                )
            else:
                gate = Top2Gate(
                    self.embed_dim,
                    args.moe_expert_count,
                    args.moe_gating_use_fp32,
                    args.moe_second_expert_policy,
                    args.moe_normalize_gate_prob_before_dropping,
                    args.moe_eval_capacity_token_fraction,
                    use_xmoe=args.use_xmoe,
                )
            experts = make_experts(args, self.embed_dim, self.ffn_dim)
            self.moe_layer = MOELayer(gate, experts, args)
        self.final_layer_norm = MultiwayWrapper(args, LayerNorm(self.embed_dim, eps=args.layernorm_eps))

        if args.deepnorm:
            # DeepNet residual scaling constants.
            if is_encoder_decoder:
                self.alpha = math.pow(math.pow(args.encoder_layers, 4) * args.decoder_layers, 0.0625) * 0.81
            else:
                self.alpha = math.pow(2.0 * args.encoder_layers, 0.25)
        else:
            self.alpha = 1.0

    def build_ffn(self, embed_dim, args):
        """Build the dense feed-forward sub-layer."""
        return FeedForwardNetwork(
            embed_dim,
            self.ffn_dim,
            args.activation_fn,
            args.dropout,
            args.activation_dropout,
            args.layernorm_eps,
            args.subln,
        )

    def build_self_attention(self, embed_dim, args):
        """Build the self-attention module (one_attn allows attention sharing)."""
        return MultiheadAttention(
            args,
            embed_dim,
            args.encoder_attention_heads,
            dropout=args.attention_dropout,
            self_attention=True,
            encoder_decoder_attention=False,
            subln=args.subln,
            one_attn=args.one_attn,
        )

    def residual_connection(self, x, residual):
        # DeepNorm-scaled residual (alpha == 1.0 without deepnorm).
        return residual * self.alpha + x

    def forward(
        self,
        x,
        encoder_padding_mask,
        attn_mask=None,
        rel_pos=None,
        multiway_split_position=None,
        incremental_state=None,
    ):
        """Run the layer; returns ``(x, l_aux)``.

        multiway_split_position: index separating the two modalities in the
        sequence; propagated into every Multiway-wrapped submodule.
        """
        if multiway_split_position is not None:
            assert self.args.multiway
            # Tell every MultiwayWrapper where to split the token sequence.
            no_sync_module_apply(self, set_split_position(multiway_split_position))

        if attn_mask is not None:
            # float16: -1e8 equal 0
            # (use a large finite negative instead of -inf for fp16 safety)
            attn_mask = attn_mask.masked_fill(attn_mask.to(torch.bool), -1e8)

        # --- self-attention sub-layer ---
        residual = x
        if self.normalize_before:
            x = self.self_attn_layer_norm(x)
        x, _ = self.self_attn(
            query=x,
            key=x,
            value=x,
            key_padding_mask=encoder_padding_mask,
            attn_mask=attn_mask,
            rel_pos=rel_pos,
            incremental_state=incremental_state,
        )
        x = self.dropout_module(x)

        if self.drop_path is not None:
            x = self.drop_path(x)

        x = self.residual_connection(x, residual)
        if not self.normalize_before:
            x = self.self_attn_layer_norm(x)

        # --- FFN / MoE sub-layer ---
        residual = x
        if self.normalize_before:
            x = self.final_layer_norm(x)
        if not self.is_moe_layer:
            x = self.ffn(x)
            l_aux = None
        else:
            # MOELayer expects the other (batch/seq-swapped) layout, so
            # transpose in and out around it.
            x = x.transpose(0, 1)
            x, l_aux = self.moe_layer(x)
            x = x.transpose(0, 1)

        if self.drop_path is not None:
            x = self.drop_path(x)

        x = self.residual_connection(x, residual)
        if not self.normalize_before:
            x = self.final_layer_norm(x)
        return x, l_aux
+
170
+
171
class Encoder(nn.Module):
    """Multiway Transformer encoder (torchscale-style) with optional MAE-style
    random masking of vision tokens.

    Configured entirely from ``args``: supports full layer sharing
    (``share_layer``), a single self-attention module shared across layers
    (``share_attn``), mixture-of-experts layers every ``moe_freq`` layers,
    relative position bias, deepnorm / sub-LN weight rescaling, and a tied or
    untied vocabulary output projection.
    """

    def __init__(
        self, args, embed_tokens=None, embed_positions=None, output_projection=None, is_encoder_decoder=False, **kwargs
    ):
        # args is stashed before nn.Module.__init__; safe because it is a plain
        # object (nn.Module.__setattr__ only special-cases Parameters/Modules).
        self.args = args
        super().__init__(**kwargs)

        self.dropout_module = torch.nn.Dropout(args.dropout)

        embed_dim = args.encoder_embed_dim
        # Classic Transformer sqrt(d) embedding scale, unless disabled.
        self.embed_scale = 1.0 if args.no_scale_embedding else math.sqrt(embed_dim)
        self.mask_ratio = args.mask_ratio
        self.max_text_len = args.max_text_len
        # Number of image patch tokens (excludes the CLS token).
        self.vision_len = (args.img_size // args.patch_size) * (args.img_size // args.patch_size)

        self.embed_tokens = embed_tokens
        self.embed_positions = embed_positions

        # Build an output projection only for encoder-only, LM-head-style use.
        if output_projection is None and not is_encoder_decoder and not args.no_output_layer and args.vocab_size > 0:
            self.output_projection = self.build_output_projection(args)
        else:
            self.output_projection = output_projection

        if args.layernorm_embedding:
            self.layernorm_embedding = MultiwayWrapper(args, LayerNorm(embed_dim, eps=args.layernorm_eps), dim=1)
        else:
            self.layernorm_embedding = None

        self.layers = nn.ModuleList([])
        if self.args.share_layer:
            # One layer instance appended at every depth => full weight sharing.
            single_layer = self.build_encoder_layer(
                args, depth=0, is_moe_layer=False, is_encoder_decoder=is_encoder_decoder
            )
            for i in range(args.encoder_layers):
                self.layers.append(single_layer)
        elif self.args.share_attn:
            # Distinct layers that all reuse one self-attention module.
            moe_freq = args.moe_freq
            embed_dim = args.encoder_embed_dim
            shared_attn = self.build_self_attention(embed_dim, self.args)
            for i in range(args.encoder_layers):
                # Every moe_freq-th layer becomes a mixture-of-experts layer.
                is_moe_layer = moe_freq != 0 and (i + 1) % moe_freq == 0
                self.layers.append(
                    self.build_encoder_layer(
                        args,
                        depth=i,
                        attn=shared_attn,
                        is_moe_layer=is_moe_layer,
                        is_encoder_decoder=is_encoder_decoder,
                    )
                )

        else:
            # Fully independent layers.
            moe_freq = args.moe_freq
            for i in range(args.encoder_layers):
                is_moe_layer = moe_freq != 0 and (i + 1) % moe_freq == 0
                self.layers.append(
                    self.build_encoder_layer(
                        args,
                        depth=i,
                        is_moe_layer=is_moe_layer,
                        is_encoder_decoder=is_encoder_decoder,
                    )
                )
        self.num_layers = len(self.layers)

        if args.encoder_normalize_before and args.normalize_output:
            self.layer_norm = MultiwayWrapper(args, LayerNorm(embed_dim, eps=args.layernorm_eps))
        else:
            self.layer_norm = None

        if args.rel_pos_buckets > 0 and args.max_rel_pos > 0:
            self.relative_position = RelativePositionBias(
                num_buckets=args.rel_pos_buckets,
                max_distance=args.max_rel_pos,
                n_heads=args.encoder_attention_heads,
            )
        else:
            self.relative_position = None

        if args.bert_init:
            self.apply(init_bert_params)

        # DeepNet-style down-scaling of FFN and attention output/value weights.
        if args.deepnorm:
            if is_encoder_decoder:
                init_scale = math.pow(math.pow(args.encoder_layers, 4) * args.decoder_layers, 0.0625) / 1.15
            else:
                init_scale = math.pow(8.0 * args.encoder_layers, 0.25)
            for name, p in self.named_parameters():
                if "fc1" in name or "fc2" in name or "out_proj" in name or "v_proj" in name:
                    p.data.div_(init_scale)

        # Sub-LayerNorm (Magneto-style) up-scaling of the same weight groups.
        if args.subln:
            if is_encoder_decoder:
                init_scale = math.sqrt(math.log(3 * args.decoder_layers) * math.log(2 * args.encoder_layers) / 3)
            else:
                init_scale = math.sqrt(math.log(args.encoder_layers * 2))
            for name, p in self.named_parameters():
                if "fc1" in name or "fc2" in name or "out_proj" in name or "v_proj" in name:
                    p.data.mul_(init_scale)

    def random_masking(self, x, mask_ratio):
        """MAE-style random token dropping that always keeps position 0 (CLS).

        Returns the kept tokens (CLS re-prepended, so output length is
        ``int(L * (1 - mask_ratio)) + 1``) and the gather indices ``ids_keep``
        into the original sequence (all >= 1; the CLS index is added back by
        the caller's bookkeeping in ``forward``).
        """
        N, L, D = x.shape  # batch, length, dim
        len_keep = int(L * (1 - mask_ratio))

        # Rank the L-1 non-CLS positions by random noise; the +1 shift maps
        # argsort indices past the CLS slot.
        noise = torch.rand(N, L - 1, device=x.device)
        ids_shuffle = torch.argsort(noise, dim=1) + torch.ones(N, L - 1, device=x.device, dtype=int)
        ids_keep = ids_shuffle[:, :len_keep]

        x_masked = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D))

        # Re-attach the CLS token in front of the kept subset.
        x0 = x[:, 0, :]
        x0 = x0.reshape(N, 1, D)
        x_masked_add = torch.cat([x0, x_masked], axis=1)
        return x_masked_add, ids_keep

    def build_self_attention(self, embed_dim, args):
        """Construct the (optionally layer-shared) self-attention module."""
        return MultiheadAttention(
            args,
            embed_dim,
            args.encoder_attention_heads,
            dropout=args.attention_dropout,
            self_attention=True,
            encoder_decoder_attention=False,
            subln=args.subln,
            one_attn=args.one_attn,
        )

    def build_output_projection(
        self,
        args,
    ):
        """Build the vocabulary projection.

        Ties its weight to the input embedding matrix when
        ``share_encoder_input_output_embed`` is set; otherwise creates a fresh
        Linear with d^-0.5 normal init.
        """
        if args.share_encoder_input_output_embed:
            assert args.encoder_embedding_type == "language"
            output_projection = torch.nn.Linear(
                self.embed_tokens.weight.shape[1],
                self.embed_tokens.weight.shape[0],
                bias=False,
            )
            # Weight tying: the projection shares the embedding Parameter.
            output_projection.weight = self.embed_tokens.weight
        else:
            output_projection = torch.nn.Linear(args.encoder_embed_dim, args.vocab_size, bias=False)
            torch.nn.init.normal_(output_projection.weight, mean=0, std=args.encoder_embed_dim**-0.5)
        return output_projection

    def checkpointing_and_params_allgather(
        self,
        origin_layer,
    ):
        """Return a forward wrapper running ``origin_layer`` under DeepSpeed
        activation checkpointing.

        NOTE(review): the wrapper is returned, not installed on the layer —
        presumably the caller assigns it; also, despite the comment below,
        **kwargs are still forwarded to ``checkpointing.checkpoint``. Confirm.
        """
        origin_forward = origin_layer.forward

        from deepspeed import checkpointing

        def forward(*args, **kwargs):
            # deepspeed checkpoint not support kwargs
            ret = checkpointing.checkpoint(origin_forward, *args, **kwargs)
            return ret

        return forward

    def build_encoder_layer(self, args, depth, attn=None, is_moe_layer=False, is_encoder_decoder=False):
        """Create one EncoderLayer, optionally wrapped for activation
        checkpointing and/or FSDP."""
        layer = EncoderLayer(
            args,
            depth,
            attn,
            is_moe_layer=is_moe_layer,
            is_encoder_decoder=is_encoder_decoder,
        )
        if args.checkpoint_activations:
            rank_zero_info("EncoderLayer params: %s", sum(p.numel() for p in layer.parameters() if p.requires_grad))
            layer = checkpoint_wrapper(layer)
            # layer.ffn = checkpoint_wrapper(layer.ffn,)
        if args.fsdp:
            layer = wrap(layer)
        return layer

    def checkpointing_layers(self):
        """Wrap every already-built layer with activation checkpointing, in place."""
        for i, layer in enumerate(self.layers):
            rank_zero_info(f"Checkpointing wrapper EncoderLayers: {i}")
            self.layers[i] = checkpoint_wrapper(layer)

    def forward_embedding(
        self,
        src_tokens,
        token_embedding=None,
        positions=None,
    ):
        """Embed tokens, add positions, then (optionally) random-mask vision tokens.

        Returns ``(x, embed, ids_keep, is_flip)`` where ``is_flip`` is:
          0 — no masking applied,
          1 — a pure-vision sequence (CLS + patches) was masked,
          2 — the vision prefix of a vision+text sequence was masked.
        """
        if token_embedding is None:
            token_embedding = self.embed_tokens(src_tokens)
        x = embed = self.embed_scale * token_embedding
        if self.embed_positions is not None:
            if src_tokens is not None:
                x = embed + self.embed_positions(src_tokens, positions=positions)
            else:
                x = embed + self.embed_positions(x, positions=positions)
        is_flip, ids_keep = 0, None
        if self.mask_ratio > 0:
            # Sequence length identifies the modality mix: CLS + patches,
            # or CLS + patches + text.
            if x.shape[1] == self.vision_len + 1:
                x, ids_keep = self.random_masking(x, self.mask_ratio)
                is_flip = 1
            elif x.shape[1] == self.vision_len + self.max_text_len + 1:
                # Mask only the vision prefix; text tokens are untouched.
                vision_tokens = x[:, : self.vision_len + 1, :]
                vision_tokens, ids_keep = self.random_masking(vision_tokens, self.mask_ratio)
                x = torch.cat(
                    [
                        vision_tokens,
                        x[
                            :,
                            self.vision_len + 1 :,
                        ],
                    ],
                    dim=1,
                )
                is_flip = 2
        if self.layernorm_embedding is not None:
            x = self.layernorm_embedding(x)
        x = self.dropout_module(x)
        return x, embed, ids_keep, is_flip

    def forward(
        self,
        src_tokens,
        encoder_padding_mask=None,
        attn_mask=None,
        return_all_hiddens=False,
        token_embeddings=None,
        multiway_split_position=None,
        features_only=False,
        incremental_state=None,
        positions=None,
        **kwargs
    ):
        """Run the full encoder stack.

        Either ``src_tokens`` or precomputed ``token_embeddings`` must be
        provided. ``multiway_split_position`` (the vision/text boundary) routes
        tokens through the multiway branches. Returns a dict with the final
        states, input embedding, padding mask, per-layer states (when
        ``return_all_hiddens``), MoE aux losses, and the (possibly recomputed)
        split position.
        """
        assert src_tokens is not None or token_embeddings is not None

        # Default padding mask: nothing is padding.
        if encoder_padding_mask is None:
            if src_tokens is not None:
                encoder_padding_mask = torch.zeros_like(src_tokens, device=src_tokens.device).bool()
            else:
                encoder_padding_mask = torch.zeros(
                    [token_embeddings.size(0), token_embeddings.size(1)],
                    device=token_embeddings.device,
                ).bool()

        if multiway_split_position is not None:
            assert self.args.multiway
            no_sync_module_apply(self, set_split_position(multiway_split_position))

        x, encoder_embedding, ids_keep, is_flip = self.forward_embedding(src_tokens, token_embeddings, positions)
        if is_flip > 0:
            # Vision tokens were randomly dropped: rebuild the kept-index list
            # so masks (and the split position) match the shortened sequence.
            if is_flip == 2:
                # Kept order is [CLS, kept vision tokens, all text tokens].
                text_ids = (
                    torch.arange(
                        self.vision_len + 1, self.vision_len + 1 + self.max_text_len, device=x.device, dtype=torch.int64
                    )
                    .unsqueeze(0)
                    .repeat(ids_keep.shape[0], 1)
                )
                cls_ids = torch.zeros(ids_keep.shape[0], 1, device=x.device, dtype=torch.int64)
                ids_keep = torch.cat([cls_ids, ids_keep, text_ids], dim=1)
            elif is_flip == 1:
                cls_ids = torch.zeros(ids_keep.shape[0], 1, device=x.device, dtype=torch.int64)
                ids_keep = torch.cat([cls_ids, ids_keep], dim=1)
            if encoder_padding_mask is not None:
                encoder_padding_mask = torch.gather(encoder_padding_mask, dim=1, index=ids_keep)
            if attn_mask is not None:
                # Subselect both rows (dim 1) and columns (dim 2) of the mask.
                attn_mask = torch.gather(
                    attn_mask, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, attn_mask.shape[-1])
                )
                attn_mask = torch.gather(attn_mask, dim=2, index=ids_keep.unsqueeze(1).repeat(1, attn_mask.shape[1], 1))
            # NOTE(review): this comparison raises TypeError when
            # multiway_split_position is None while masking is active —
            # presumably callers always pass a split position in that
            # configuration; confirm against call sites.
            if multiway_split_position > 0:
                multiway_split_position = ids_keep.shape[1] - self.max_text_len
        # Zero out embeddings at padded positions.
        x = x * (1 - encoder_padding_mask.unsqueeze(-1).type_as(x))

        encoder_states = []

        if return_all_hiddens:
            encoder_states.append(x)

        rel_pos_bias = None
        if self.relative_position is not None:
            rel_pos_bias = self.relative_position(batch_size=x.size(0), qlen=x.size(1), klen=x.size(1))

        l_aux = []
        for idx, layer in enumerate(self.layers):
            # With an incremental_state, per-layer caching replaces the
            # padding mask (decoder-style incremental use).
            x, l_aux_i = layer(
                x,
                encoder_padding_mask=encoder_padding_mask if incremental_state is None else None,
                attn_mask=attn_mask,
                rel_pos=rel_pos_bias,
                multiway_split_position=multiway_split_position,
                incremental_state=incremental_state[idx] if incremental_state is not None else None,
            )
            if return_all_hiddens:
                assert encoder_states is not None
                encoder_states.append(x)
            l_aux.append(l_aux_i)

        # Re-apply the (possibly updated) split position before the final norm.
        if multiway_split_position is not None:
            assert self.args.multiway
            no_sync_module_apply(self, set_split_position(multiway_split_position))
        if self.layer_norm is not None:
            x = self.layer_norm(x)

        if not features_only and self.output_projection is not None:
            x = self.output_projection(x)

        return {
            "encoder_out": x,
            "encoder_embedding": encoder_embedding,
            "encoder_padding_mask": encoder_padding_mask,
            "encoder_states": encoder_states,
            "l_aux": l_aux,
            "multiway_split_position": multiway_split_position,
        }
vlmo/torchscale/architecture/encoder_decoder.py ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2022 Microsoft
2
+ # Licensed under The MIT License [see LICENSE for details]
3
+
4
+ import torch.nn as nn
5
+
6
+ from vlmo.torchscale.architecture.decoder import Decoder
7
+ from vlmo.torchscale.architecture.encoder import Encoder
8
+
9
+
10
class EncoderDecoder(nn.Module):
    """Seq2seq wrapper pairing a torchscale ``Encoder`` with a ``Decoder``.

    When ``args.share_all_embeddings`` is set, the decoder reuses the
    encoder's token embedding and ties its input/output embeddings.
    """

    def __init__(
        self,
        args,
        encoder_embed_tokens=None,
        encoder_embed_positions=None,
        decoder_embed_tokens=None,
        decoder_embed_positions=None,
        output_projection=None,
        **kwargs
    ):
        super().__init__()
        self.args = args

        # Sharing all embeddings implies tying decoder input/output embeddings.
        if args.share_all_embeddings:
            args.share_decoder_input_output_embed = True

        self.encoder = Encoder(
            args,
            encoder_embed_tokens,
            encoder_embed_positions,
            is_encoder_decoder=True,
            **kwargs,
        )

        # Reuse the encoder-side token embedding when sharing is requested and
        # no explicit decoder embedding was supplied.
        if args.share_all_embeddings and decoder_embed_tokens is None:
            decoder_embed_tokens = self.encoder.embed_tokens

        self.decoder = Decoder(
            args,
            decoder_embed_tokens,
            decoder_embed_positions,
            output_projection,
            is_encoder_decoder=True,
            **kwargs,
        )

    def forward(self, src_tokens, prev_output_tokens, return_all_hiddens=False, features_only=False, **kwargs):
        """Encode ``src_tokens`` and decode ``prev_output_tokens`` against the result."""
        encoded = self.encoder(src_tokens, return_all_hiddens=return_all_hiddens)
        return self.decoder(
            prev_output_tokens,
            encoder_out=encoded,
            features_only=features_only,
            return_all_hiddens=return_all_hiddens,
        )
vlmo/torchscale/architecture/utils.py ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2022 Microsoft
2
+ # Licensed under The MIT License [see LICENSE for details]
3
+
4
+ import torch.nn as nn
5
+
6
+ from vlmo.torchscale.component.multihead_attention import MultiheadAttention
7
+ from vlmo.torchscale.component.multiway_network import MultiwayNetwork
8
+
9
+
10
def init_bert_params(module):
    """BERT-style initializer: N(0, 0.02) weights, zeroed biases and padding rows.

    Intended for use via ``model.apply(init_bert_params)``. Handles plain
    ``nn.Linear`` / ``nn.Embedding`` modules as well as ``MultiheadAttention``
    whose q/k/v projections may be ``MultiwayNetwork`` pairs.
    """

    def _fill_normal(data):
        # Sample on CPU, then copy back to the tensor's original device.
        data.copy_(data.cpu().normal_(mean=0.0, std=0.02).to(data.device))

    if isinstance(module, nn.Linear):
        _fill_normal(module.weight.data)
        if module.bias is not None:
            module.bias.data.zero_()

    if isinstance(module, nn.Embedding):
        _fill_normal(module.weight.data)
        if module.padding_idx is not None:
            # Keep the padding embedding at zero.
            module.weight.data[module.padding_idx].zero_()

    if isinstance(module, MultiheadAttention):
        projections = (module.q_proj, module.k_proj, module.v_proj)
        if isinstance(module.q_proj, MultiwayNetwork):
            # Multiway projections carry two parallel branches (A / B).
            for proj in projections:
                _fill_normal(proj.A.weight.data)
                _fill_normal(proj.B.weight.data)
        else:
            for proj in projections:
                _fill_normal(proj.weight.data)