Upload safetensors export
Browse filesThis view is limited to 50 files because it contains too many changes. See raw diff
- .gitattributes +3 -31
- README.md +91 -0
- config.json +37 -0
- configuration_m2_encoder.py +90 -0
- image_processing_m2_encoder.py +42 -0
- m2_encoder_0.4B.safetensors +3 -0
- modeling_m2_encoder.py +150 -0
- preprocessor_config.json +11 -0
- processing_m2_encoder.py +58 -0
- processor_config.json +6 -0
- requirements.txt +15 -0
- sp.model +3 -0
- tokenization_glm.py +307 -0
- tokenizer_config.json +17 -0
- upload_to_hub.py +31 -0
- vlmo/__init__.py +0 -0
- vlmo/__pycache__/__init__.cpython-311.pyc +0 -0
- vlmo/__pycache__/config.cpython-311.pyc +0 -0
- vlmo/config.py +165 -0
- vlmo/modules/__init__.py +1 -0
- vlmo/modules/__pycache__/__init__.cpython-311.pyc +0 -0
- vlmo/modules/__pycache__/heads.cpython-311.pyc +0 -0
- vlmo/modules/__pycache__/modeling_utils.cpython-311.pyc +0 -0
- vlmo/modules/__pycache__/objectives.cpython-311.pyc +0 -0
- vlmo/modules/__pycache__/vlmo_module.cpython-311.pyc +0 -0
- vlmo/modules/__pycache__/vlmo_utils.cpython-311.pyc +0 -0
- vlmo/modules/heads.py +24 -0
- vlmo/modules/modeling_utils.py +179 -0
- vlmo/modules/multiway_transformer.py +396 -0
- vlmo/modules/objectives.py +12 -0
- vlmo/modules/vlmo_module.py +405 -0
- vlmo/modules/vlmo_utils.py +12 -0
- vlmo/tokenizer/__init__.py +6 -0
- vlmo/tokenizer/__pycache__/__init__.cpython-311.pyc +0 -0
- vlmo/tokenizer/__pycache__/tokenization_glm.cpython-311.pyc +0 -0
- vlmo/tokenizer/sp.model +3 -0
- vlmo/tokenizer/tokenization_glm.py +307 -0
- vlmo/tokenizer/tokenizer_config.json +17 -0
- vlmo/torchscale/__init__.py +2 -0
- vlmo/torchscale/__pycache__/__init__.cpython-311.pyc +0 -0
- vlmo/torchscale/architecture/__init__.py +2 -0
- vlmo/torchscale/architecture/__pycache__/__init__.cpython-311.pyc +0 -0
- vlmo/torchscale/architecture/__pycache__/config.cpython-311.pyc +0 -0
- vlmo/torchscale/architecture/__pycache__/encoder.cpython-311.pyc +0 -0
- vlmo/torchscale/architecture/__pycache__/utils.cpython-311.pyc +0 -0
- vlmo/torchscale/architecture/config.py +197 -0
- vlmo/torchscale/architecture/decoder.py +428 -0
- vlmo/torchscale/architecture/encoder.py +482 -0
- vlmo/torchscale/architecture/encoder_decoder.py +43 -0
- vlmo/torchscale/architecture/utils.py +33 -0
.gitattributes
CHANGED
|
@@ -1,35 +1,7 @@
|
|
| 1 |
-
*.7z filter=lfs diff=lfs merge=lfs -text
|
| 2 |
-
*.arrow filter=lfs diff=lfs merge=lfs -text
|
| 3 |
-
*.bin filter=lfs diff=lfs merge=lfs -text
|
| 4 |
-
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
| 5 |
*.ckpt filter=lfs diff=lfs merge=lfs -text
|
| 6 |
-
*.
|
| 7 |
-
*.gz filter=lfs diff=lfs merge=lfs -text
|
| 8 |
-
*.h5 filter=lfs diff=lfs merge=lfs -text
|
| 9 |
-
*.joblib filter=lfs diff=lfs merge=lfs -text
|
| 10 |
-
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
| 11 |
-
*.mlmodel filter=lfs diff=lfs merge=lfs -text
|
| 12 |
-
*.model filter=lfs diff=lfs merge=lfs -text
|
| 13 |
-
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
| 14 |
-
*.npy filter=lfs diff=lfs merge=lfs -text
|
| 15 |
-
*.npz filter=lfs diff=lfs merge=lfs -text
|
| 16 |
-
*.onnx filter=lfs diff=lfs merge=lfs -text
|
| 17 |
-
*.ot filter=lfs diff=lfs merge=lfs -text
|
| 18 |
-
*.parquet filter=lfs diff=lfs merge=lfs -text
|
| 19 |
-
*.pb filter=lfs diff=lfs merge=lfs -text
|
| 20 |
-
*.pickle filter=lfs diff=lfs merge=lfs -text
|
| 21 |
-
*.pkl filter=lfs diff=lfs merge=lfs -text
|
| 22 |
*.pt filter=lfs diff=lfs merge=lfs -text
|
| 23 |
*.pth filter=lfs diff=lfs merge=lfs -text
|
| 24 |
-
*.rar filter=lfs diff=lfs merge=lfs -text
|
| 25 |
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
*.tar filter=lfs diff=lfs merge=lfs -text
|
| 29 |
-
*.tflite filter=lfs diff=lfs merge=lfs -text
|
| 30 |
-
*.tgz filter=lfs diff=lfs merge=lfs -text
|
| 31 |
-
*.wasm filter=lfs diff=lfs merge=lfs -text
|
| 32 |
-
*.xz filter=lfs diff=lfs merge=lfs -text
|
| 33 |
-
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
-
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
-
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
*.ckpt filter=lfs diff=lfs merge=lfs -text
|
| 2 |
+
*.bin filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3 |
*.pt filter=lfs diff=lfs merge=lfs -text
|
| 4 |
*.pth filter=lfs diff=lfs merge=lfs -text
|
|
|
|
| 5 |
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
| 6 |
+
sp.model filter=lfs diff=lfs merge=lfs -text
|
| 7 |
+
vlmo/tokenizer/sp.model filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
README.md
ADDED
|
@@ -0,0 +1,91 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
license: apache-2.0
|
| 3 |
+
library_name: transformers
|
| 4 |
+
pipeline_tag: zero-shot-image-classification
|
| 5 |
+
tags:
|
| 6 |
+
- multimodal
|
| 7 |
+
- image-text-retrieval
|
| 8 |
+
- bilingual
|
| 9 |
+
- chinese
|
| 10 |
+
- english
|
| 11 |
+
- vision-language
|
| 12 |
+
- custom-code
|
| 13 |
+
---
|
| 14 |
+
|
| 15 |
+
# M2-Encoder-0.4B Hugging Face Export
|
| 16 |
+
|
| 17 |
+
This folder is generated from `Ant-Multi-Modal-Framework/prj/M2_Encoder` and is structured for direct upload to Hugging Face Hub.
|
| 18 |
+
|
| 19 |
+
## What This Repo Supports
|
| 20 |
+
|
| 21 |
+
- `AutoConfig.from_pretrained(..., trust_remote_code=True)`
|
| 22 |
+
- `AutoProcessor.from_pretrained(..., trust_remote_code=True)`
|
| 23 |
+
- `AutoModel.from_pretrained(..., trust_remote_code=True)`
|
| 24 |
+
- Zero-shot image-text retrieval and zero-shot image classification
|
| 25 |
+
|
| 26 |
+
## Required Weight File
|
| 27 |
+
|
| 28 |
+
Put the model weight file in the repo root with this exact filename:
|
| 29 |
+
|
| 30 |
+
`m2_encoder_0.4B.safetensors`
|
| 31 |
+
|
| 32 |
+
Large files should be tracked by Git LFS. A `.gitattributes` file is included for that.
|
| 33 |
+
|
| 34 |
+
## Usage
|
| 35 |
+
|
| 36 |
+
### ModelScope-equivalent scoring
|
| 37 |
+
|
| 38 |
+
The original ModelScope sample computes probabilities from the raw normalized embeddings:
|
| 39 |
+
|
| 40 |
+
```python
|
| 41 |
+
from transformers import AutoModel, AutoProcessor
|
| 42 |
+
|
| 43 |
+
repo_id = "your-name/your-m2-encoder-repo"
|
| 44 |
+
|
| 45 |
+
model = AutoModel.from_pretrained(repo_id, trust_remote_code=True)
|
| 46 |
+
processor = AutoProcessor.from_pretrained(repo_id, trust_remote_code=True)
|
| 47 |
+
|
| 48 |
+
text_inputs = processor(
|
| 49 |
+
text=["杰尼龟", "妙蛙种子", "小火龙", "皮卡丘"],
|
| 50 |
+
return_tensors="pt",
|
| 51 |
+
)
|
| 52 |
+
image_inputs = processor(images="pokemon.jpeg", return_tensors="pt")
|
| 53 |
+
|
| 54 |
+
text_outputs = model(**text_inputs)
|
| 55 |
+
image_outputs = model(**image_inputs)
|
| 56 |
+
|
| 57 |
+
probs = (image_outputs.image_embeds @ text_outputs.text_embeds.t()).softmax(dim=-1)
|
| 58 |
+
print(probs)
|
| 59 |
+
```
|
| 60 |
+
|
| 61 |
+
### CLIP-style logits
|
| 62 |
+
|
| 63 |
+
`model(**inputs)` also returns `logits_per_image` and `logits_per_text`, which use the model's learned `logit_scale`.
|
| 64 |
+
Those logits are useful, but they are not the same computation as the raw dot product in the original ModelScope demo.
|
| 65 |
+
|
| 66 |
+
## Upload
|
| 67 |
+
|
| 68 |
+
Option 1:
|
| 69 |
+
|
| 70 |
+
```bash
|
| 71 |
+
python upload_to_hub.py --repo-id your-name/your-m2-encoder-repo
|
| 72 |
+
```
|
| 73 |
+
|
| 74 |
+
Option 2:
|
| 75 |
+
|
| 76 |
+
```bash
|
| 77 |
+
huggingface-cli login
|
| 78 |
+
git init
|
| 79 |
+
git lfs install
|
| 80 |
+
git remote add origin https://huggingface.co/your-name/your-m2-encoder-repo
|
| 81 |
+
git add .
|
| 82 |
+
git commit -m "Upload M2-Encoder HF export"
|
| 83 |
+
git push origin main
|
| 84 |
+
```
|
| 85 |
+
|
| 86 |
+
## Notes
|
| 87 |
+
|
| 88 |
+
- This is a Hugging Face remote-code adapter, not a native `transformers` implementation.
|
| 89 |
+
- The underlying model code still comes from the official M2-Encoder repo.
|
| 90 |
+
- You need `trust_remote_code=True`.
|
| 91 |
+
- The weights are not bundled by default when exporting unless you pass `--checkpoint`.
|
config.json
ADDED
|
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"beit_version": "base",
|
| 3 |
+
"encoder_embed_dim": 768,
|
| 4 |
+
"out_embed_dim": 768,
|
| 5 |
+
"image_size": 224,
|
| 6 |
+
"visual_mask_size": 14,
|
| 7 |
+
"loss_names": {
|
| 8 |
+
"itc": 1
|
| 9 |
+
},
|
| 10 |
+
"encoder_layers": 9,
|
| 11 |
+
"beit3_vl_layers": 3,
|
| 12 |
+
"tokenizer_type": "GLMChineseTokenizer",
|
| 13 |
+
"tokenizer": ".",
|
| 14 |
+
"vocab_size": 115244,
|
| 15 |
+
"whole_word_masking": true,
|
| 16 |
+
"precision": 32,
|
| 17 |
+
"test_only": true,
|
| 18 |
+
"flash_attn": false,
|
| 19 |
+
"modelscope": {
|
| 20 |
+
"model_id": "M2Cognition/M2-Encoder"
|
| 21 |
+
},
|
| 22 |
+
"model_file": "m2_encoder_0.4B.safetensors",
|
| 23 |
+
"model_type": "m2_encoder",
|
| 24 |
+
"architectures": [
|
| 25 |
+
"M2EncoderModel"
|
| 26 |
+
],
|
| 27 |
+
"processor_class": "M2EncoderProcessor",
|
| 28 |
+
"auto_map": {
|
| 29 |
+
"AutoConfig": "configuration_m2_encoder.M2EncoderConfig",
|
| 30 |
+
"AutoModel": "modeling_m2_encoder.M2EncoderModel",
|
| 31 |
+
"AutoProcessor": "processing_m2_encoder.M2EncoderProcessor",
|
| 32 |
+
"AutoTokenizer": [
|
| 33 |
+
"tokenization_glm.GLMChineseTokenizer",
|
| 34 |
+
null
|
| 35 |
+
]
|
| 36 |
+
}
|
| 37 |
+
}
|
configuration_m2_encoder.py
ADDED
|
@@ -0,0 +1,90 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
import os
|
| 3 |
+
from typing import Any, Dict
|
| 4 |
+
|
| 5 |
+
from transformers import PretrainedConfig
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
class M2EncoderConfig(PretrainedConfig):
|
| 9 |
+
model_type = "m2_encoder"
|
| 10 |
+
|
| 11 |
+
def __init__(
|
| 12 |
+
self,
|
| 13 |
+
loss_names=None,
|
| 14 |
+
beit_version="large",
|
| 15 |
+
encoder_embed_dim=1024,
|
| 16 |
+
out_embed_dim=1024,
|
| 17 |
+
encoder_layers=21,
|
| 18 |
+
beit3_vl_layers=3,
|
| 19 |
+
image_size=224,
|
| 20 |
+
visual_mask_size=14,
|
| 21 |
+
tokenizer_type="GLMChineseTokenizer",
|
| 22 |
+
tokenizer=".",
|
| 23 |
+
vocab_size=115244,
|
| 24 |
+
whole_word_masking=False,
|
| 25 |
+
precision=32,
|
| 26 |
+
test_only=True,
|
| 27 |
+
flash_attn=False,
|
| 28 |
+
model_file="m2_encoder_1B.ckpt",
|
| 29 |
+
architectures=None,
|
| 30 |
+
auto_map=None,
|
| 31 |
+
**kwargs,
|
| 32 |
+
):
|
| 33 |
+
super().__init__(**kwargs)
|
| 34 |
+
self.loss_names = loss_names or {"itc": 1}
|
| 35 |
+
self.beit_version = beit_version
|
| 36 |
+
self.encoder_embed_dim = encoder_embed_dim
|
| 37 |
+
self.out_embed_dim = out_embed_dim
|
| 38 |
+
self.encoder_layers = encoder_layers
|
| 39 |
+
self.beit3_vl_layers = beit3_vl_layers
|
| 40 |
+
self.image_size = image_size
|
| 41 |
+
self.visual_mask_size = visual_mask_size
|
| 42 |
+
self.tokenizer_type = tokenizer_type
|
| 43 |
+
self.tokenizer = tokenizer
|
| 44 |
+
self.vocab_size = vocab_size
|
| 45 |
+
self.whole_word_masking = whole_word_masking
|
| 46 |
+
self.precision = precision
|
| 47 |
+
self.test_only = test_only
|
| 48 |
+
self.flash_attn = flash_attn
|
| 49 |
+
self.model_file = model_file
|
| 50 |
+
self.architectures = architectures or ["M2EncoderModel"]
|
| 51 |
+
self.auto_map = auto_map or {
|
| 52 |
+
"AutoConfig": "configuration_m2_encoder.M2EncoderConfig",
|
| 53 |
+
"AutoModel": "modeling_m2_encoder.M2EncoderModel",
|
| 54 |
+
"AutoProcessor": "processing_m2_encoder.M2EncoderProcessor",
|
| 55 |
+
"AutoTokenizer": ["tokenization_glm.GLMChineseTokenizer", None],
|
| 56 |
+
}
|
| 57 |
+
|
| 58 |
+
@classmethod
|
| 59 |
+
def from_encoder_json(cls, config_path: str, **kwargs) -> "M2EncoderConfig":
|
| 60 |
+
with open(config_path, "r", encoding="utf-8") as f:
|
| 61 |
+
data = json.load(f)
|
| 62 |
+
data.update(kwargs)
|
| 63 |
+
return cls(**data)
|
| 64 |
+
|
| 65 |
+
def to_vlmo_overrides(self, model_dir: str) -> Dict[str, Any]:
|
| 66 |
+
return {
|
| 67 |
+
"loss_names": self.loss_names,
|
| 68 |
+
"beit_version": self.beit_version,
|
| 69 |
+
"encoder_embed_dim": self.encoder_embed_dim,
|
| 70 |
+
"out_embed_dim": self.out_embed_dim,
|
| 71 |
+
"encoder_layers": self.encoder_layers,
|
| 72 |
+
"beit3_vl_layers": self.beit3_vl_layers,
|
| 73 |
+
"image_size": self.image_size,
|
| 74 |
+
"visual_mask_size": self.visual_mask_size,
|
| 75 |
+
"tokenizer_type": self.tokenizer_type,
|
| 76 |
+
"tokenizer": self._resolve_tokenizer_dir(model_dir),
|
| 77 |
+
"vocab_size": self.vocab_size,
|
| 78 |
+
"whole_word_masking": self.whole_word_masking,
|
| 79 |
+
"precision": self.precision,
|
| 80 |
+
"test_only": self.test_only,
|
| 81 |
+
"flash_attn": self.flash_attn,
|
| 82 |
+
"load_path": os.path.join(model_dir, self.model_file),
|
| 83 |
+
}
|
| 84 |
+
|
| 85 |
+
def _resolve_tokenizer_dir(self, model_dir: str) -> str:
|
| 86 |
+
if os.path.isabs(self.tokenizer):
|
| 87 |
+
return self.tokenizer
|
| 88 |
+
if self.tokenizer in (".", "./", ""):
|
| 89 |
+
return model_dir
|
| 90 |
+
return os.path.join(model_dir, self.tokenizer)
|
image_processing_m2_encoder.py
ADDED
|
@@ -0,0 +1,42 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import List, Optional, Union
|
| 2 |
+
|
| 3 |
+
import numpy as np
|
| 4 |
+
import torch
|
| 5 |
+
from PIL import Image
|
| 6 |
+
from transformers.feature_extraction_utils import BatchFeature, FeatureExtractionMixin
|
| 7 |
+
from transformers.image_utils import ImageFeatureExtractionMixin
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class M2EncoderImageProcessor(FeatureExtractionMixin, ImageFeatureExtractionMixin):
|
| 11 |
+
model_input_names = ["pixel_values"]
|
| 12 |
+
|
| 13 |
+
def __init__(self, size: int = 224, resample: int = Image.BICUBIC, **kwargs):
|
| 14 |
+
super().__init__(**kwargs)
|
| 15 |
+
if isinstance(size, dict):
|
| 16 |
+
size = int(size.get("height") or size.get("width"))
|
| 17 |
+
self.size = size
|
| 18 |
+
self.resample = resample
|
| 19 |
+
|
| 20 |
+
def __call__(
|
| 21 |
+
self,
|
| 22 |
+
images,
|
| 23 |
+
return_tensors: Optional[Union[str, torch.Tensor]] = None,
|
| 24 |
+
**kwargs,
|
| 25 |
+
) -> BatchFeature:
|
| 26 |
+
if not isinstance(images, (list, tuple)):
|
| 27 |
+
images = [images]
|
| 28 |
+
|
| 29 |
+
pixel_values: List[np.ndarray] = []
|
| 30 |
+
for image in images:
|
| 31 |
+
if not isinstance(image, Image.Image):
|
| 32 |
+
image = Image.fromarray(np.asarray(image))
|
| 33 |
+
image = image.convert("RGB")
|
| 34 |
+
image = image.resize((self.size, self.size), resample=self.resample)
|
| 35 |
+
array = np.asarray(image, dtype=np.float32) / 255.0
|
| 36 |
+
array = np.transpose(array, (2, 0, 1))
|
| 37 |
+
pixel_values.append(array)
|
| 38 |
+
|
| 39 |
+
return BatchFeature(
|
| 40 |
+
data={"pixel_values": pixel_values},
|
| 41 |
+
tensor_type=return_tensors,
|
| 42 |
+
)
|
m2_encoder_0.4B.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:a75931d296881382b0c8282dafc63cbe2d1cacbb315c6d32572acbdaa2d9203e
|
| 3 |
+
size 1053218384
|
modeling_m2_encoder.py
ADDED
|
@@ -0,0 +1,150 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import sys
|
| 3 |
+
import importlib
|
| 4 |
+
from dataclasses import dataclass
|
| 5 |
+
from typing import Optional, Tuple, Union
|
| 6 |
+
|
| 7 |
+
import torch
|
| 8 |
+
from huggingface_hub import snapshot_download
|
| 9 |
+
from safetensors.torch import load_file
|
| 10 |
+
from transformers import PreTrainedModel
|
| 11 |
+
from transformers.modeling_outputs import ModelOutput
|
| 12 |
+
|
| 13 |
+
from .configuration_m2_encoder import M2EncoderConfig
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
@dataclass
|
| 17 |
+
class M2EncoderOutput(ModelOutput):
|
| 18 |
+
loss: Optional[torch.FloatTensor] = None
|
| 19 |
+
text_embeds: Optional[torch.FloatTensor] = None
|
| 20 |
+
image_embeds: Optional[torch.FloatTensor] = None
|
| 21 |
+
logits_per_image: Optional[torch.FloatTensor] = None
|
| 22 |
+
logits_per_text: Optional[torch.FloatTensor] = None
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
class M2EncoderModel(PreTrainedModel):
|
| 26 |
+
config_class = M2EncoderConfig
|
| 27 |
+
base_model_prefix = "m2_encoder"
|
| 28 |
+
main_input_name = "pixel_values"
|
| 29 |
+
|
| 30 |
+
def __init__(self, config: M2EncoderConfig):
|
| 31 |
+
super().__init__(config)
|
| 32 |
+
model_dir = getattr(config, "_model_dir", None)
|
| 33 |
+
if model_dir is None:
|
| 34 |
+
raise ValueError(
|
| 35 |
+
"M2EncoderConfig is missing `_model_dir`. Use "
|
| 36 |
+
"`M2EncoderModel.from_pretrained(...)` so the checkpoint path can be resolved."
|
| 37 |
+
)
|
| 38 |
+
if model_dir not in sys.path:
|
| 39 |
+
sys.path.insert(0, model_dir)
|
| 40 |
+
|
| 41 |
+
vlmo_default_config = importlib.import_module("vlmo.config").config
|
| 42 |
+
VLMo = importlib.import_module("vlmo.modules").VLMo
|
| 43 |
+
|
| 44 |
+
vlmo_config = vlmo_default_config()
|
| 45 |
+
vlmo_config.update(config.to_vlmo_overrides(model_dir))
|
| 46 |
+
load_path = vlmo_config["load_path"]
|
| 47 |
+
use_safetensors = load_path.endswith(".safetensors")
|
| 48 |
+
if use_safetensors:
|
| 49 |
+
vlmo_config["load_path"] = ""
|
| 50 |
+
|
| 51 |
+
if vlmo_config["flash_attn"]:
|
| 52 |
+
patch_torch_scale_with_flash_attn = importlib.import_module(
|
| 53 |
+
"vlmo.utils.patch_utils"
|
| 54 |
+
).patch_torch_scale_with_flash_attn
|
| 55 |
+
patch_torch_scale_with_flash_attn()
|
| 56 |
+
|
| 57 |
+
self.model = VLMo(vlmo_config)
|
| 58 |
+
if use_safetensors:
|
| 59 |
+
state_dict = load_file(load_path)
|
| 60 |
+
self.model.load_state_dict(state_dict, strict=False)
|
| 61 |
+
|
| 62 |
+
@classmethod
|
| 63 |
+
def from_pretrained(
|
| 64 |
+
cls,
|
| 65 |
+
pretrained_model_name_or_path,
|
| 66 |
+
*model_args,
|
| 67 |
+
config: Optional[M2EncoderConfig] = None,
|
| 68 |
+
**kwargs,
|
| 69 |
+
):
|
| 70 |
+
model_dir = pretrained_model_name_or_path
|
| 71 |
+
if not os.path.isdir(model_dir):
|
| 72 |
+
model_dir = snapshot_download(repo_id=pretrained_model_name_or_path)
|
| 73 |
+
|
| 74 |
+
if config is None:
|
| 75 |
+
config = M2EncoderConfig.from_pretrained(model_dir, **kwargs)
|
| 76 |
+
checkpoint_path = os.path.join(
|
| 77 |
+
model_dir,
|
| 78 |
+
kwargs.pop("m2_checkpoint_name", config.model_file),
|
| 79 |
+
)
|
| 80 |
+
if not os.path.exists(checkpoint_path):
|
| 81 |
+
raise FileNotFoundError(
|
| 82 |
+
f"Missing M2-Encoder checkpoint: {checkpoint_path}"
|
| 83 |
+
)
|
| 84 |
+
config._model_dir = model_dir
|
| 85 |
+
return cls(config, *model_args)
|
| 86 |
+
|
| 87 |
+
def get_text_features(
|
| 88 |
+
self,
|
| 89 |
+
input_ids: torch.LongTensor,
|
| 90 |
+
attention_mask: torch.LongTensor,
|
| 91 |
+
) -> torch.FloatTensor:
|
| 92 |
+
outputs = self.model.infer_text(
|
| 93 |
+
{
|
| 94 |
+
"text_ids": input_ids,
|
| 95 |
+
"text_masks": attention_mask,
|
| 96 |
+
"text_labels": None,
|
| 97 |
+
}
|
| 98 |
+
)
|
| 99 |
+
return outputs["cls_vlffn_feats"]
|
| 100 |
+
|
| 101 |
+
def get_image_features(self, pixel_values: torch.FloatTensor) -> torch.FloatTensor:
|
| 102 |
+
outputs = self.model.infer_image({"image": [pixel_values]})
|
| 103 |
+
return outputs["cls_vlffn_feats"]
|
| 104 |
+
|
| 105 |
+
def forward(
|
| 106 |
+
self,
|
| 107 |
+
input_ids: Optional[torch.LongTensor] = None,
|
| 108 |
+
attention_mask: Optional[torch.LongTensor] = None,
|
| 109 |
+
pixel_values: Optional[torch.FloatTensor] = None,
|
| 110 |
+
return_dict: Optional[bool] = True,
|
| 111 |
+
**kwargs,
|
| 112 |
+
) -> Union[M2EncoderOutput, Tuple[torch.FloatTensor, ...]]:
|
| 113 |
+
text_embeds = None
|
| 114 |
+
image_embeds = None
|
| 115 |
+
|
| 116 |
+
if input_ids is not None:
|
| 117 |
+
if attention_mask is None:
|
| 118 |
+
attention_mask = torch.ones_like(input_ids)
|
| 119 |
+
text_embeds = self.get_text_features(
|
| 120 |
+
input_ids=input_ids, attention_mask=attention_mask
|
| 121 |
+
)
|
| 122 |
+
|
| 123 |
+
if pixel_values is not None:
|
| 124 |
+
image_embeds = self.get_image_features(pixel_values=pixel_values)
|
| 125 |
+
|
| 126 |
+
logits_per_image = None
|
| 127 |
+
logits_per_text = None
|
| 128 |
+
if image_embeds is not None and text_embeds is not None:
|
| 129 |
+
logit_scale = self.model.logit_scale.exp()
|
| 130 |
+
logits_per_image = logit_scale * image_embeds @ text_embeds.t()
|
| 131 |
+
logits_per_text = logits_per_image.t()
|
| 132 |
+
|
| 133 |
+
if not return_dict:
|
| 134 |
+
return tuple(
|
| 135 |
+
value
|
| 136 |
+
for value in (
|
| 137 |
+
text_embeds,
|
| 138 |
+
image_embeds,
|
| 139 |
+
logits_per_image,
|
| 140 |
+
logits_per_text,
|
| 141 |
+
)
|
| 142 |
+
if value is not None
|
| 143 |
+
)
|
| 144 |
+
|
| 145 |
+
return M2EncoderOutput(
|
| 146 |
+
text_embeds=text_embeds,
|
| 147 |
+
image_embeds=image_embeds,
|
| 148 |
+
logits_per_image=logits_per_image,
|
| 149 |
+
logits_per_text=logits_per_text,
|
| 150 |
+
)
|
preprocessor_config.json
ADDED
|
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"processor_class": "M2EncoderProcessor",
|
| 3 |
+
"image_processor_type": "M2EncoderImageProcessor",
|
| 4 |
+
"auto_map": {
|
| 5 |
+
"AutoProcessor": "processing_m2_encoder.M2EncoderProcessor"
|
| 6 |
+
},
|
| 7 |
+
"size": {
|
| 8 |
+
"height": 224,
|
| 9 |
+
"width": 224
|
| 10 |
+
}
|
| 11 |
+
}
|
processing_m2_encoder.py
ADDED
|
@@ -0,0 +1,58 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Optional
|
| 2 |
+
|
| 3 |
+
from transformers import AutoTokenizer
|
| 4 |
+
from transformers.processing_utils import ProcessorMixin
|
| 5 |
+
|
| 6 |
+
from .image_processing_m2_encoder import M2EncoderImageProcessor
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
class M2EncoderProcessor(ProcessorMixin):
|
| 10 |
+
attributes = ["image_processor", "tokenizer"]
|
| 11 |
+
image_processor_class = "M2EncoderImageProcessor"
|
| 12 |
+
tokenizer_class = ("GLMChineseTokenizer", None)
|
| 13 |
+
|
| 14 |
+
def __init__(self, image_processor, tokenizer):
|
| 15 |
+
self.image_processor = image_processor
|
| 16 |
+
self.tokenizer = tokenizer
|
| 17 |
+
|
| 18 |
+
@classmethod
|
| 19 |
+
def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
|
| 20 |
+
trust_remote_code = kwargs.pop("trust_remote_code", True)
|
| 21 |
+
image_processor = M2EncoderImageProcessor.from_pretrained(
|
| 22 |
+
pretrained_model_name_or_path, **kwargs
|
| 23 |
+
)
|
| 24 |
+
tokenizer = AutoTokenizer.from_pretrained(
|
| 25 |
+
pretrained_model_name_or_path,
|
| 26 |
+
trust_remote_code=trust_remote_code,
|
| 27 |
+
**kwargs,
|
| 28 |
+
)
|
| 29 |
+
return cls(image_processor=image_processor, tokenizer=tokenizer)
|
| 30 |
+
|
| 31 |
+
def __call__(
|
| 32 |
+
self,
|
| 33 |
+
text=None,
|
| 34 |
+
images=None,
|
| 35 |
+
padding="max_length",
|
| 36 |
+
truncation=True,
|
| 37 |
+
max_length: Optional[int] = 52,
|
| 38 |
+
return_tensors=None,
|
| 39 |
+
**kwargs,
|
| 40 |
+
):
|
| 41 |
+
encoding = {}
|
| 42 |
+
if text is not None:
|
| 43 |
+
encoding.update(
|
| 44 |
+
self.tokenizer(
|
| 45 |
+
text,
|
| 46 |
+
padding=padding,
|
| 47 |
+
truncation=truncation,
|
| 48 |
+
max_length=max_length,
|
| 49 |
+
return_special_tokens_mask=True,
|
| 50 |
+
return_tensors=return_tensors,
|
| 51 |
+
**kwargs,
|
| 52 |
+
)
|
| 53 |
+
)
|
| 54 |
+
if images is not None:
|
| 55 |
+
encoding.update(
|
| 56 |
+
self.image_processor(images, return_tensors=return_tensors, **kwargs)
|
| 57 |
+
)
|
| 58 |
+
return encoding
|
processor_config.json
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"processor_class": "M2EncoderProcessor",
|
| 3 |
+
"auto_map": {
|
| 4 |
+
"AutoProcessor": "processing_m2_encoder.M2EncoderProcessor"
|
| 5 |
+
}
|
| 6 |
+
}
|
requirements.txt
ADDED
|
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
torch
|
| 2 |
+
pytorch_lightning<=2.0.8
|
| 3 |
+
transformers
|
| 4 |
+
safetensors
|
| 5 |
+
Pillow
|
| 6 |
+
tqdm
|
| 7 |
+
einops
|
| 8 |
+
sacred
|
| 9 |
+
timm
|
| 10 |
+
torchvision
|
| 11 |
+
fairscale
|
| 12 |
+
numpy
|
| 13 |
+
opencv-python
|
| 14 |
+
sentencepiece
|
| 15 |
+
huggingface_hub
|
sp.model
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:b7fe3bcc8d284fcb782691411e8b6fd4f45d7245565b094de6ab795e66bcd32f
|
| 3 |
+
size 2270960
|
tokenization_glm.py
ADDED
|
@@ -0,0 +1,307 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
from shutil import copyfile
|
| 3 |
+
from typing import Optional, Tuple, List, Union
|
| 4 |
+
|
| 5 |
+
import sentencepiece as spm
|
| 6 |
+
import torch
|
| 7 |
+
from transformers import PreTrainedTokenizer
|
| 8 |
+
from transformers.models.auto.tokenization_auto import get_tokenizer_config
|
| 9 |
+
from transformers.tokenization_utils_base import BatchEncoding
|
| 10 |
+
from transformers.utils import logging
|
| 11 |
+
|
| 12 |
+
logger = logging.get_logger(__name__)
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
class GLMBatchEncoding(BatchEncoding):
|
| 16 |
+
def to(self, device: Union[str, "torch.device"]) -> "BatchEncoding":
|
| 17 |
+
"""
|
| 18 |
+
Send all values to device by calling `v.to(device)` (PyTorch only).
|
| 19 |
+
|
| 20 |
+
Args:
|
| 21 |
+
device (`str` or `torch.device`): The device to put the tensors on.
|
| 22 |
+
|
| 23 |
+
Returns:
|
| 24 |
+
[`BatchEncoding`]: The same instance after modification.
|
| 25 |
+
"""
|
| 26 |
+
|
| 27 |
+
# This check catches things like APEX blindly calling "to" on all inputs to a module
|
| 28 |
+
# Otherwise it passes the casts down and casts the LongTensor containing the token idxs
|
| 29 |
+
# into a HalfTensor
|
| 30 |
+
if isinstance(device, str) or isinstance(device, int):
|
| 31 |
+
#if isinstance(device, str) or _is_torch_device(device) or isinstance(device, int):
|
| 32 |
+
self.data = {k: v.to(device=device) if torch.is_tensor(v) else v for k, v in self.data.items()}
|
| 33 |
+
else:
|
| 34 |
+
logger.warning(f"Attempting to cast a BatchEncoding to type {str(device)}. This is not supported.")
|
| 35 |
+
return self
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
class GLMTokenizerMixin:
    """Mixin adding GLM-specific helpers to a tokenizer: special-token id
    properties and input builders for autoregressive blank filling
    (multiple-choice scoring and free-form generation).

    Assumes the host class provides the standard `PreTrainedTokenizer`
    surface: `__call__`, `convert_tokens_to_ids`, `mask_token_id`,
    `pad_token_id`.
    """

    @property
    def sop_token(self) -> Optional[str]:
        # Start-of-piece marker inserted before each generated blank.
        return "<|startofpiece|>"

    @property
    def sop_token_id(self) -> Optional[int]:
        """
        `Optional[int]`: Id of the start token in the vocabulary, used when training a model with autoregressive blank filling.
        """
        return self.convert_tokens_to_ids(self.sop_token)

    @property
    def eop_token(self) -> Optional[str]:
        # End-of-piece marker terminating each generated blank.
        return "<|endofpiece|>"

    @property
    def eop_token_id(self) -> Optional[int]:
        """
        `Optional[int]`: Id of the end token in the vocabulary, used when training a model with autoregressive blank filling.
        """
        return self.convert_tokens_to_ids(self.eop_token)

    @property
    def gmask_token_id(self) -> int:
        # [gMASK]: generation-style (long-span) mask token id.
        return self.convert_tokens_to_ids("[gMASK]")

    @property
    def smask_token_id(self) -> int:
        # [sMASK]: sentence-level mask token id.
        return self.convert_tokens_to_ids("[sMASK]")

    @property
    def mask_token_ids(self):
        """All mask-token ids recognized as blank positions: [MASK], [sMASK], [gMASK]."""
        return [self.mask_token_id, self.smask_token_id, self.gmask_token_id]

    def _build_input_for_multiple_choice(self, context, choices):
        """Build one sample for multiple-choice scoring.

        Appends every choice after the context as a blank-filling span
        (<|startofpiece|> + choice tokens), with 2D position ids
        (global position, in-blank position) and a block-diagonal
        causal attention mask over the choice spans.
        """
        context_id = context["input_ids"]
        if torch.is_tensor(context_id):
            context_id = context_id.tolist()

        division = len(context_id)
        # Position of the first [MASK] in the context; all choice tokens
        # share this global position (blank-filling convention).
        mask_position = context_id.index(self.mask_token_id)

        token = torch.tensor(context_id, dtype=torch.long)
        attention_mask = [context["attention_mask"].expand(division, -1)]
        position_id = torch.arange(division, dtype=torch.long)
        block_position_id = torch.zeros(division, dtype=torch.long)

        choice_ids, choice_indices = [], []

        for choice_str in choices:
            choice = torch.tensor(self(choice_str, add_special_tokens=False, padding=False)['input_ids'],
                                  dtype=torch.long)
            choice_ids.append(choice)
            # Indices (into the concatenated sequence) where this choice's
            # logits will be read out for scoring.
            choice_indices.append(torch.arange(len(token), len(token) + len(choice), dtype=torch.long))
            # Causal (lower-triangular) mask within the choice span.
            attention_mask.append(torch.tril(torch.ones((len(choice), len(choice)), dtype=torch.long)))

            # Feed <sop> + choice[:-1] so the model predicts each choice token.
            token = torch.cat((token, torch.tensor([self.sop_token_id], dtype=torch.long), choice[:-1]))
            position_id = torch.cat((position_id, torch.tensor([mask_position] * len(choice), dtype=torch.long)))
            block_position_id = torch.cat((block_position_id, torch.arange(1, 1 + len(choice), dtype=torch.long)))

        attention_mask = torch.block_diag(*attention_mask)
        # Every choice token may attend to the (unmasked part of the) context.
        attention_mask[division:, :division] = context["attention_mask"].unsqueeze(0)

        return {
            "input_ids": token,
            "position_ids": torch.stack((position_id, block_position_id)),
            "attention_mask": attention_mask,
            "choice_ids": choice_ids,
            "choice_indices": choice_indices
        }

    def _pad_batch(self, tokens, position_ids, attention_mask, max_seq_length):
        """Right-pad one sample to `max_seq_length` (tokens with 0, the 2D
        attention mask with 0, position ids by repeating the last column)."""
        pad_length = max_seq_length - len(tokens)
        attention_mask = torch.nn.functional.pad(
            attention_mask,
            (0, pad_length, 0, pad_length),
            mode="constant",
            value=0,
        )
        tokens = torch.cat((tokens, torch.zeros(pad_length, dtype=torch.long)))
        position_ids = torch.cat((position_ids, position_ids[..., -1:].expand(-1, pad_length)), dim=-1)
        return tokens, position_ids, attention_mask

    def _collate(self, samples):
        """Stack variable-length multiple-choice samples into one batch,
        padding to the longest sample (rounded up to a multiple of TILE)."""
        TILE = 1
        length_to_pad = (max(map(lambda spl: len(spl["input_ids"]), samples)) + TILE - 1) // TILE * TILE

        token_batch, position_id_batch, attention_mask_batch = [], [], []
        choices_batch, choice_target_ids_batch = [], []

        for sample in samples:
            token, position_id, attention_mask = self._pad_batch(
                sample["input_ids"], sample["position_ids"], sample["attention_mask"], length_to_pad
            )
            token_batch.append(token)
            position_id_batch.append(position_id)
            attention_mask_batch.append(attention_mask)
            choices_batch.append(sample["choice_ids"])
            choice_target_ids_batch.append(sample["choice_indices"])
        return {
            "input_ids": torch.stack(token_batch),
            "position_ids": torch.stack(position_id_batch),
            # Extra dim so the mask broadcasts over attention heads.
            "attention_mask": torch.stack(attention_mask_batch).unsqueeze(1),
            "choice_ids": choices_batch,
            "choice_indices": choice_target_ids_batch,
        }

    def build_inputs_for_multiple_choice(self, model_input: BatchEncoding, choices, max_length=None):
        """Turn a tokenized batch plus per-sample choice lists into model
        inputs for multiple-choice scoring.

        NOTE(review): `max_length` is accepted but never used — confirm
        whether truncation was intended here.
        """
        samples = [{key: value[i] for key, value in model_input.items()} for i in range(len(model_input["input_ids"]))]
        samples = [self._build_input_for_multiple_choice(sample, choice) for sample, choice in
                   zip(samples, choices)]
        inputs = self._collate(samples)
        return GLMBatchEncoding(inputs)

    def build_inputs_for_generation(self, model_input: BatchEncoding, max_gen_length=512, targets=None, padding=False):
        """Extend a tokenized batch with the generation region for blank filling.

        Appends <sop> (and, when `targets` is given, the shifted target tokens)
        after the context, builds 2D position ids anchored at the first mask
        token, and a 4D attention mask that is bidirectional over the context
        and causal over the generation region. With `targets`, also returns
        `labels` (-100 over context and padding).
        """
        mask_ids = self.mask_token_ids
        input_ids = model_input.input_ids
        batch_size, seq_length = input_ids.shape[:2]
        position_id, block_position_id = list(range(seq_length)), [0 for _ in range(seq_length)]
        position_ids, block_position_ids = [], []
        labels = None
        if targets is not None:
            is_batched = isinstance(targets, (list, tuple))
            targets = self(targets, add_special_tokens=False, padding=False).input_ids
            if not is_batched:
                targets = [targets]
            assert len(targets) == len(input_ids)
            # Each target ends with <eop>, then everything is clipped to max_gen_length.
            targets = [(target + [self.eop_token_id])[:max_gen_length] for target in targets]
            if not padding:
                max_gen_length = max(map(len, targets))
            targets = [[self.sop_token_id] + target for target in targets]
            # Labels are the targets shifted left by one (next-token prediction).
            labels = [target[1:] for target in targets]
            targets = [target + [self.pad_token_id] * (max_gen_length + 1 - len(target)) for target in targets]
            labels = [label + [-100] * (max_gen_length - len(label)) for label in labels]
            targets = torch.tensor(targets, dtype=input_ids.dtype, device=input_ids.device)
            labels = torch.tensor(labels, dtype=input_ids.dtype, device=input_ids.device)
            # Context positions never contribute to the loss.
            labels = torch.cat((input_ids.new_full((batch_size, seq_length), -100), labels), dim=1)
        for i in range(batch_size):
            mask_positions = []
            for mask_id in mask_ids:
                mask_positions += (input_ids[i] == mask_id).nonzero(as_tuple=True)[0].tolist()
            if not mask_positions:
                raise ValueError("Cannot find mask token in the input")
            mask_positions.sort()
            # Generation is anchored at the earliest mask position.
            mask_pos = mask_positions[0]
            position_ids.append(position_id + [mask_pos] * max_gen_length)
            block_position_ids.append(block_position_id + list(range(1, max_gen_length + 1)))
        position_ids = torch.tensor(position_ids, dtype=input_ids.dtype, device=input_ids.device)
        block_position_ids = torch.tensor(block_position_ids, dtype=input_ids.dtype, device=input_ids.device)
        position_ids = torch.stack((position_ids, block_position_ids), dim=1)
        attention_mask = model_input.attention_mask
        attention_mask = attention_mask.unsqueeze(1).expand(-1, seq_length + max_gen_length, -1)
        # Generation region: context rows cannot see it; within it, causal.
        generation_attention_mask = torch.cat([attention_mask.new_zeros((seq_length, max_gen_length)),
                                               torch.tril(attention_mask.new_ones((max_gen_length, max_gen_length)))],
                                              dim=0).unsqueeze(0).expand(batch_size, -1, -1)
        attention_mask = torch.cat((attention_mask, generation_attention_mask), dim=2)
        attention_mask = attention_mask.unsqueeze(1)
        if targets is None:
            # Inference: only prime the generation region with <sop>.
            input_ids = torch.cat((input_ids, input_ids.new_full((batch_size, 1), self.sop_token_id)), dim=-1)
        else:
            input_ids = torch.cat((input_ids, targets[:, :-1]), dim=1)
        batch = {"input_ids": input_ids, "position_ids": position_ids}
        if labels is None:
            batch["generation_attention_mask"] = attention_mask
        else:
            batch["attention_mask"] = attention_mask
            batch["labels"] = labels
        return BatchEncoding(batch)
|
| 207 |
+
|
| 208 |
+
def encode_whitespaces(content):
    """Replace runs of 2-10 consecutive spaces with GLM blank tokens.

    Widths are processed longest-first so a run of e.g. 12 spaces becomes
    ``<|blank_10|><|blank_2|>`` rather than six ``<|blank_2|>`` tokens.
    """
    for width in range(10, 1, -1):
        marker = f'<|blank_{width}|>'
        content = content.replace(' ' * width, marker)
    return content
|
| 212 |
+
|
| 213 |
+
def decode_whitespaces(content):
    """Inverse of :func:`encode_whitespaces`: expand GLM blank tokens
    (``<|blank_2|>`` .. ``<|blank_10|>``) back into runs of spaces."""
    for width in range(10, 1, -1):
        marker = f'<|blank_{width}|>'
        content = content.replace(marker, ' ' * width)
    return content
|
| 217 |
+
|
| 218 |
+
|
| 219 |
+
class GLMChineseTokenizer(PreTrainedTokenizer, GLMTokenizerMixin):
    """SentencePiece-based Chinese GLM tokenizer.

    Loads an ``sp.model`` SentencePiece model and encodes whitespace runs as
    ``<|blank_n|>`` tokens before segmentation. Truncation keeps the tail of
    over-long inputs (``truncation_side = "left"``).
    """

    vocab_files_names = {"vocab_file": "sp.model"}
    truncation_side: str = "left"

    def __init__(self, vocab_file, **kwargs):
        # The sp model must be loaded before super().__init__, which may
        # already query vocab_size / special tokens.
        self.vocab_file = vocab_file
        self.sp_model = spm.SentencePieceProcessor()
        self.sp_model.Load(vocab_file)
        super().__init__(**kwargs)

    @property
    def vocab_size(self):
        # Size of the underlying SentencePiece vocabulary (added tokens excluded).
        return len(self.sp_model)

    def get_vocab(self):
        """Return the full token -> id mapping, including added tokens."""
        vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)}
        vocab.update(self.added_tokens_encoder)
        return vocab

    def _tokenize(self, text, **kwargs):
        """Encode whitespace runs, then split into SentencePiece pieces."""
        text = encode_whitespaces(text)
        return self.sp_model.EncodeAsPieces(text)

    def _convert_token_to_id(self, token):
        """Converts a token (str) in an id using the vocab."""
        return self.sp_model.PieceToId(token)

    def _convert_id_to_token(self, index):
        """Converts an index (integer) in a token (str) using the vocab."""
        return self.sp_model.IdToPiece(index)

    def convert_tokens_to_string(self, tokens):
        """Join pieces back into text and restore encoded whitespace."""
        # NOTE(review): DecodeIds is documented for integer ids, but `tokens`
        # here are string pieces — confirm this works with the installed
        # sentencepiece version (DecodePieces / decode() is the usual call).
        res = self.sp_model.DecodeIds(tokens)
        return decode_whitespaces(res)

    def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
        """Copy (or serialize) the SentencePiece model into `save_directory`.

        Returns a 1-tuple with the written path.
        NOTE(review): returns None (bare return) when `save_directory` is not
        a directory, which breaks callers that unpack the tuple — confirm.
        """
        if not os.path.isdir(save_directory):
            logger.error(f"Vocabulary path ({save_directory}) should be a directory")
            return
        out_vocab_file = os.path.join(
            save_directory, (filename_prefix + "-" if filename_prefix else "") + self.vocab_files_names["vocab_file"]
        )

        if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file) and os.path.isfile(self.vocab_file):
            copyfile(self.vocab_file, out_vocab_file)
        elif not os.path.isfile(self.vocab_file):
            # Source file is gone (e.g. loaded from a hub cache that was
            # cleaned up) — re-serialize the in-memory model instead.
            with open(out_vocab_file, "wb") as fi:
                content_spiece_model = self.sp_model.serialized_model_proto()
                fi.write(content_spiece_model)

        return (out_vocab_file,)

    def build_inputs_with_special_tokens(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
    ) -> List[int]:
        """
        Build model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating and
        adding special tokens. A BERT sequence has the following format:

        - single sequence: ``[CLS] X [SEP]``
        - pair of sequences: ``[CLS] A [SEP] B [SEP]``

        Args:
            token_ids_0 (:obj:`List[int]`):
                List of IDs to which the special tokens will be added.
            token_ids_1 (:obj:`List[int]`, `optional`):
                Optional second list of IDs for sequence pairs.

        Returns:
            :obj:`List[int]`: List of `input IDs <../glossary.html#input-ids>`__ with the appropriate special tokens.
        """
        # Sequence pairs are not supported by this tokenizer.
        assert token_ids_1 is None
        cls = [self.cls_token_id]
        eos = [self.eos_token_id]
        return cls + token_ids_0 + eos
|
| 295 |
+
|
| 296 |
+
|
| 297 |
+
class GLMTokenizer:
    """Factory that dispatches ``from_pretrained`` to the concrete GLM
    tokenizer class named in the checkpoint's tokenizer config."""

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *inputs, **kwargs):
        """Instantiate the tokenizer class declared by the checkpoint.

        Raises:
            NotImplementedError: if the config names an unsupported
                ``tokenizer_class``.
        """
        tokenizer_config = get_tokenizer_config(pretrained_model_name_or_path, **kwargs)
        config_tokenizer_class = tokenizer_config.get("tokenizer_class")

        if config_tokenizer_class == "GLMChineseTokenizer":
            tokenizer_class = GLMChineseTokenizer
        else:
            # Fix: pass a single formatted message instead of two positional
            # args, which previously rendered as a confusing tuple repr.
            raise NotImplementedError(
                f"Not implemented tokenizer type: {config_tokenizer_class!r}"
            )
        return tokenizer_class.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
|
tokenizer_config.json
ADDED
|
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"name_or_path": "THUDM/glm-10b-chinese",
|
| 3 |
+
"eos_token": "<|endoftext|>",
|
| 4 |
+
"pad_token": "<|endoftext|>",
|
| 5 |
+
"cls_token": "[CLS]",
|
| 6 |
+
"mask_token": "[MASK]",
|
| 7 |
+
"unk_token": "[UNK]",
|
| 8 |
+
"add_prefix_space": false,
|
| 9 |
+
"tokenizer_class": "GLMChineseTokenizer",
|
| 10 |
+
"use_fast": false,
|
| 11 |
+
"auto_map": {
|
| 12 |
+
"AutoTokenizer": [
|
| 13 |
+
"tokenization_glm.GLMChineseTokenizer",
|
| 14 |
+
null
|
| 15 |
+
]
|
| 16 |
+
}
|
| 17 |
+
}
|
upload_to_hub.py
ADDED
|
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
from pathlib import Path
|
| 3 |
+
|
| 4 |
+
from huggingface_hub import HfApi, create_repo, upload_folder
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
def main():
    """Create (if needed) a Hugging Face Hub repo and upload a folder to it.

    Command-line flags:
        --repo-id          target repo, e.g. ``user/M2-Encoder-Large`` (required)
        --folder           folder to upload (defaults to this script's directory)
        --private          create the repo as private
        --commit-message   commit message for the upload
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--repo-id", required=True, help="Hugging Face repo id, e.g. user/M2-Encoder-Large")
    parser.add_argument(
        "--folder",
        default=str(Path(__file__).resolve().parent),
        help="Folder to upload. Defaults to this script's directory.",
    )
    parser.add_argument("--private", action="store_true", help="Create the repo as private.")
    parser.add_argument("--commit-message", default="Upload M2-Encoder HF export")
    args = parser.parse_args()

    folder = Path(args.folder).resolve()
    # Fix: dropped the unused `api = HfApi()` local — create_repo/upload_folder
    # are module-level helpers and never used it.
    # exist_ok=True makes re-running against an existing repo safe.
    create_repo(repo_id=args.repo_id, private=args.private, exist_ok=True)
    upload_folder(
        repo_id=args.repo_id,
        folder_path=str(folder),
        commit_message=args.commit_message,
    )
    print(f"Uploaded {folder} -> {args.repo_id}")
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
# Script entry point: run the upload only when executed directly.
if __name__ == "__main__":
    main()
|
vlmo/__init__.py
ADDED
|
File without changes
|
vlmo/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (193 Bytes). View file
|
|
|
vlmo/__pycache__/config.cpython-311.pyc
ADDED
|
Binary file (3.82 kB). View file
|
|
|
vlmo/config.py
ADDED
|
@@ -0,0 +1,165 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from sacred import Experiment
|
| 2 |
+
|
| 3 |
+
ex = Experiment("VLMo")
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
def _loss_names(d):
|
| 7 |
+
ret = {
|
| 8 |
+
"itm": 0, # image-text matching loss
|
| 9 |
+
"itc": 0, # image-text contrastive loss
|
| 10 |
+
"caption": 0, # image captioning loss
|
| 11 |
+
"mvlm": 0, # masked language modeling loss
|
| 12 |
+
"textmlm": 0, # text-only masked language modeling
|
| 13 |
+
"imagemlm": 0, # image-only masked language modeling
|
| 14 |
+
"vqa": 0,
|
| 15 |
+
"nlvr2": 0,
|
| 16 |
+
"irtr": 0, # retrieval task ft
|
| 17 |
+
}
|
| 18 |
+
ret.update(d)
|
| 19 |
+
return ret
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
# Sacred config scope: every local assignment below becomes a config entry.
# Do NOT rename variables here — their names are the config keys.
@ex.config
def config():
    exp_name = "vlmo"
    seed = 1
    datasets = ["coco", "vg", "sbu", "gcc"]  # dataset names; definitions: vlmo/datamodules/__init__.py  # noqa
    loss_names = _loss_names({"itm": 0, "itc": 0, "mvlm": 0})  # training loss weights
    batch_size = 1024  # desired effective batch size; pl trainer will accumulate gradients.

    # BEiT-v3 setting
    encoder_layers = 12  # the layer number of backbone
    encoder_embed_dim = 768  # the hidden size of tokenizer
    out_embed_dim = 768  # the hidden size of output embedding
    beit_version = "base"  # model size: base(0.4B)|large(1B)|huge(10B)
    beit3_vl_layers = 3  # the layer number of vl_backbone
    deepnorm_init = True  # init method
    share_layer = False  # if share the weight between layers within backbone
    share_attn = False  # if share the attention weight of different layers
    one_attn = False  # if share the attention weight of vision and language

    # Image setting
    train_transform_keys = ["square_transform_randaug"]  # train transform: refer to vlmo/transforms/__init__.py
    val_transform_keys = ["square_transform"]  # test transform: refer to vlmo/transforms/__init__.py
    image_size = 224  # image size
    reclip_image_size = None  # reclip image size
    patch_size = 16  # patch size
    draw_false_image = 0  # if get negative image
    image_only = False  # only input image
    text_only = False  # only input text

    # Video setting, video_num_frm is not None means video input
    video_num_frm = None

    # Visual tokenizer setting based on beit2
    tokenizer_model = "beit2_visual_tokenizer"
    codebook_size = 8192
    codebook_dim = 32
    visual_mask_size = 14
    visual_mask_num = 80

    # Text Setting
    lang = 'cn'  # language for zero-shot imagenet testing: cn|en
    vqav2_label_size = 3129
    max_text_len = 52  # the number of characters
    max_text_len_of_initckpt = 196
    tokenizer_type = "BertTokenizer"  # Chinese text
    vocab_size = 21128
    tokenizer = "./vocab.txt"
    whole_word_masking = True
    mlm_prob = 0.15  # language mask ratio
    draw_false_text = 0
    mvlm_prob = 0.50  # vision-language mlm task
    mask_ratio = 0  # flip: mask ratio for image

    # cap setting
    cap_onlytext = False  # default caption image to text

    # imagemlm setting
    split_data_for_imagemlm = False  # if True, split a batch data to two parts, and the first part for imagemlm.

    # itc setting
    itc_mask = False  # itc use masked token
    aggregate_nodes = -1  # aggregate nodes num for compute_itc, default -1 is for all nodes

    # Transformer Setting
    model_arch = "vlmo_base_patch16"
    drop_path_rate = 0.1

    # Downstream Setting
    get_recall_metric = False
    get_recall_rerank_metric = False
    get_zeroshot_metric = False
    get_muge_feat = False
    get_f30k_feat = False
    k_test = 32

    # PL Trainer Setting
    resume_from = None
    fast_dev_run = False
    val_check_interval = 1.0
    test_only = False
    use_sharded_training = False
    resume_during_training = False
    save_top_k = 10
    every_n_train_steps = 2000  # the step to save checkpoint
    log_metric_steps = 100  # the step to log metric

    # below params vary with the environment
    use_pcache = False  # data storage method: pcache or nas
    pcache_root = ""
    # main_site: pcache://multimodalproxyi-pool.cz50c.alipay.com:39999/mnt/
    # public_cloud: pcache://pcache_public_cloud.pcache.local:39999/mnt/abc7c88079a60b45ddfce7afa40720b7/
    gpu_env = "main_site"  # public_cloud or main_site
    data_root = ""  # data root for data list

    log_dir = "result"
    per_gpu_batchsize = 4  # you should define this manually with per_gpu_batch_size=#
    num_gpus = 1
    num_nodes = 1
    load_path = ""
    num_workers = 8
    precision = 16
    local_run = True
    flash_attn = False
    deepspeed_config = None  # "ds_config.json"
    coalesce_backbone = False
    mask_data = "v+l"  # 'v+l': choose input of imagemlm+textmlm task, 'vl': choose input of mvlm task.
    communication_benchmark = False
    checkpoint_activations = False

    # dataset setting
    single_cap = True  # if have only one caption
    random_one = False  # if choose one caption from caption list

    # ITC setting
    itc_feats_name = "cls_vlffn_feats"  # feat for itc loss
    itc_distill = ""
    itc_distill_dim = 1024
    itc_teacher_weights = ""

    # mup training setting
    mup = False
    base_encoder_embed_dim = 1
    delta_encoder_embed_dim = 2
    mup_encoder_attention_heads = 1
    base_encoder_ffn_embed_dim = 1
    delta_encoder_ffn_embed_dim = 2

    # atorch
    atorch_config = None
    compile_op = False
    optimizer_state_shard_save = False
    model_state_shard_save = False

    # itc loss
    local_loss = False
    use_dual_softmax = False

    num_frames = 1
    # ----------------------- LMM pretraining config -----------------------

    # norm setting
    deepnorm = False
|
| 165 |
+
|
vlmo/modules/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
from .vlmo_module import VLMo
|
vlmo/modules/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (255 Bytes). View file
|
|
|
vlmo/modules/__pycache__/heads.cpython-311.pyc
ADDED
|
Binary file (2.09 kB). View file
|
|
|
vlmo/modules/__pycache__/modeling_utils.cpython-311.pyc
ADDED
|
Binary file (5.9 kB). View file
|
|
|
vlmo/modules/__pycache__/objectives.cpython-311.pyc
ADDED
|
Binary file (1.16 kB). View file
|
|
|
vlmo/modules/__pycache__/vlmo_module.cpython-311.pyc
ADDED
|
Binary file (25.3 kB). View file
|
|
|
vlmo/modules/__pycache__/vlmo_utils.cpython-311.pyc
ADDED
|
Binary file (1.22 kB). View file
|
|
|
vlmo/modules/heads.py
ADDED
|
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch.nn as nn
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
class Pooler(nn.Module):
    """Pools a sequence of hidden states by projecting the first ([CLS])
    token through a dense layer followed by tanh."""

    def __init__(self, hidden_size):
        super().__init__()
        # Attribute names are part of the checkpoint state-dict layout.
        self.dense = nn.Linear(hidden_size, hidden_size)
        self.activation = nn.Tanh()

    def forward(self, hidden_states):
        # hidden_states: (batch, seq_len, hidden) -> pooled (batch, hidden)
        cls_state = hidden_states[:, 0]
        return self.activation(self.dense(cls_state))
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
class ITCHead(nn.Module):
    """Bias-free linear projection head mapping backbone features into the
    image-text contrastive embedding space."""

    def __init__(self, hidden_size, out_size):
        super().__init__()
        # Attribute name is part of the checkpoint state-dict layout.
        self.fc = nn.Linear(hidden_size, out_size, bias=False)

    def forward(self, x):
        return self.fc(x)
|
vlmo/modules/modeling_utils.py
ADDED
|
@@ -0,0 +1,179 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# --------------------------------------------------------
|
| 2 |
+
# Image as a Foreign Language: BEiT Pretraining for Vision and Vision-Language Tasks (https://arxiv.org/abs/2208.10442)
|
| 3 |
+
# Github source: https://github.com/microsoft/unilm/tree/master/beit3
|
| 4 |
+
# Copyright (c) 2023 Microsoft
|
| 5 |
+
# Licensed under The MIT License [see LICENSE for details]
|
| 6 |
+
# --------------------------------------------------------'
|
| 7 |
+
|
| 8 |
+
import math
|
| 9 |
+
import torch
|
| 10 |
+
import torch.nn as nn
|
| 11 |
+
from timm.models.layers import trunc_normal_ as __call_trunc_normal_
|
| 12 |
+
|
| 13 |
+
from vlmo.torchscale.model.BEiT3 import BEiT3
|
| 14 |
+
from vlmo.torchscale.architecture.config import EncoderConfig
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
def trunc_normal_(tensor, mean=0.0, std=1.0):
    """In-place truncated-normal init, truncating at one std on each side
    (delegates to timm's trunc_normal_ with a=-std, b=std)."""
    bound = std
    __call_trunc_normal_(tensor, mean=mean, std=std, a=-bound, b=bound)
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def _get_base_config(
    img_size=224,
    patch_size=16,
    drop_path_rate=0,
    checkpoint_activations=None,
    mlp_ratio=4,
    vocab_size=64010,
    encoder_layers=12,
    encoder_embed_dim=768,
    encoder_attention_heads=12,
    share_layer=False,
    share_attn=False,
    deepnorm=False,
    mask_ratio=0,
    max_text_len=52,
    one_attn=False,
    **kwargs
):
    """Build the EncoderConfig for the base-size (12-layer, 768-dim) BEiT-3
    backbone. Unknown keyword arguments are accepted and ignored."""
    settings = dict(
        img_size=img_size,
        patch_size=patch_size,
        vocab_size=vocab_size,
        # BEiT-3 uses multiway (modality-routed) transformer blocks.
        multiway=True,
        layernorm_embedding=False,
        normalize_output=True,
        no_output_layer=True,
        drop_path_rate=drop_path_rate,
        encoder_embed_dim=encoder_embed_dim,
        encoder_attention_heads=encoder_attention_heads,
        encoder_layers=encoder_layers,
        # FFN width derived from the embedding dim via the MLP ratio.
        encoder_ffn_embed_dim=int(encoder_embed_dim * mlp_ratio),
        checkpoint_activations=checkpoint_activations,
        share_layer=share_layer,
        share_attn=share_attn,
        deepnorm=deepnorm,
        mask_ratio=mask_ratio,
        max_text_len=max_text_len,
        one_attn=one_attn,
    )
    return EncoderConfig(**settings)
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
def _get_large_config(
    img_size=224,
    patch_size=16,
    drop_path_rate=0,
    checkpoint_activations=None,
    mlp_ratio=4,
    vocab_size=64010,
    encoder_layers=24,
    encoder_embed_dim=1024,
    encoder_attention_heads=16,
    share_layer=False,
    share_attn=False,
    deepnorm=False,
    mask_ratio=0,
    max_text_len=52,
    one_attn=False,
    **kwargs
):
    """Build the EncoderConfig for the large-size (24-layer, 1024-dim) BEiT-3
    backbone. Unknown keyword arguments are accepted and ignored."""
    settings = dict(
        img_size=img_size,
        patch_size=patch_size,
        vocab_size=vocab_size,
        # BEiT-3 uses multiway (modality-routed) transformer blocks.
        multiway=True,
        layernorm_embedding=False,
        normalize_output=True,
        no_output_layer=True,
        drop_path_rate=drop_path_rate,
        encoder_embed_dim=encoder_embed_dim,
        encoder_attention_heads=encoder_attention_heads,
        encoder_layers=encoder_layers,
        # FFN width derived from the embedding dim via the MLP ratio.
        encoder_ffn_embed_dim=int(encoder_embed_dim * mlp_ratio),
        checkpoint_activations=checkpoint_activations,
        share_layer=share_layer,
        share_attn=share_attn,
        deepnorm=deepnorm,
        mask_ratio=mask_ratio,
        max_text_len=max_text_len,
        one_attn=one_attn,
    )
    return EncoderConfig(**settings)
|
| 101 |
+
|
| 102 |
+
|
| 103 |
+
def _get_huge_config(
    img_size=224,
    patch_size=16,
    drop_path_rate=0,
    checkpoint_activations=None,
    mlp_ratio=4,
    vocab_size=30522,
    encoder_layers=32,
    encoder_embed_dim=4096,
    encoder_attention_heads=32,
    share_layer=False,
    share_attn=False,
    deepnorm=False,
    mask_ratio=0,
    max_text_len=52,
    one_attn=False,
    **kwargs
):
    """Build the EncoderConfig for the huge-size (32-layer, 4096-dim) BEiT-3
    backbone. Unknown keyword arguments are accepted and ignored."""
    settings = dict(
        img_size=img_size,
        patch_size=patch_size,
        vocab_size=vocab_size,
        # BEiT-3 uses multiway (modality-routed) transformer blocks.
        multiway=True,
        layernorm_embedding=False,
        normalize_output=True,
        no_output_layer=True,
        drop_path_rate=drop_path_rate,
        encoder_embed_dim=encoder_embed_dim,
        encoder_attention_heads=encoder_attention_heads,
        encoder_layers=encoder_layers,
        # FFN width derived from the embedding dim via the MLP ratio.
        encoder_ffn_embed_dim=int(encoder_embed_dim * mlp_ratio),
        checkpoint_activations=checkpoint_activations,
        share_layer=share_layer,
        share_attn=share_attn,
        deepnorm=deepnorm,
        mask_ratio=mask_ratio,
        max_text_len=max_text_len,
        one_attn=one_attn,
    )
    return EncoderConfig(**settings)
|
| 142 |
+
|
| 143 |
+
|
| 144 |
+
class BEiT3Wrapper(nn.Module):
    """nn.Module wrapper that instantiates the BEiT-3 backbone from an
    EncoderConfig and applies truncated-normal weight initialization."""

    def __init__(self, args, **kwargs):
        super().__init__()
        self.args = args
        self.beit3 = BEiT3(args)
        # Recursively re-init all Linear / LayerNorm submodules.
        self.apply(self._init_weights)

    def fix_init_weight(self):
        # Depth-dependent rescaling of attention/FFN output projections
        # (1/sqrt(2*layer_id)), as in the BEiT init recipe.
        # NOTE(review): iterates `self.blocks`, which this wrapper never
        # defines — presumably supplied by a subclass; confirm before calling.
        def rescale(param, layer_id):
            param.div_(math.sqrt(2.0 * layer_id))

        for layer_id, layer in enumerate(self.blocks):
            rescale(layer.attn.proj.weight.data, layer_id + 1)
            rescale(layer.mlp.fc2.weight.data, layer_id + 1)

    def get_num_layers(self):
        # Number of transformer layers in the wrapped encoder.
        return self.beit3.encoder.num_layers

    @torch.jit.ignore
    def no_weight_decay(self):
        # Parameter names the optimizer setup should exclude from weight decay.
        return {
            "pos_embed",
            "cls_token",
            "beit3.encoder.embed_positions.A.weight",
            "beit3.vision_embed.cls_token",
            "logit_scale",
        }

    def _init_weights(self, m):
        # Truncated normal for Linear weights, zeros for biases,
        # ones/zeros for LayerNorm weight/bias.
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=0.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
|
vlmo/modules/multiway_transformer.py
ADDED
|
@@ -0,0 +1,396 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
""" Vision Transformer (ViT) in PyTorch
|
| 2 |
+
|
| 3 |
+
A PyTorch implement of Vision Transformers as described in
|
| 4 |
+
'An Image Is Worth 16 x 16 Words: Transformers for Image Recognition at Scale' - https://arxiv.org/abs/2010.11929
|
| 5 |
+
|
| 6 |
+
The official jax code is released and available at https://github.com/google-research/vision_transformer
|
| 7 |
+
|
| 8 |
+
Acknowledgments:
|
| 9 |
+
* The paper authors for releasing code and weights, thanks!
|
| 10 |
+
* I fixed my class token impl based on Phil Wang's https://github.com/lucidrains/vit-pytorch ... check it out
|
| 11 |
+
for some einops/einsum fun
|
| 12 |
+
* Simple transformer style inspired by Andrej Karpathy's https://github.com/karpathy/minGPT
|
| 13 |
+
* Bert reference code checks against Huggingface Transformers and Tensorflow Bert
|
| 14 |
+
|
| 15 |
+
DeiT model defs and weights from https://github.com/facebookresearch/deit,
|
| 16 |
+
paper `DeiT: Data-efficient Image Transformers` - https://arxiv.org/abs/2012.12877
|
| 17 |
+
|
| 18 |
+
Hacked together by / Copyright 2020 Ross Wightman
|
| 19 |
+
"""
|
| 20 |
+
from functools import partial
|
| 21 |
+
|
| 22 |
+
import torch
|
| 23 |
+
import torch.nn as nn
|
| 24 |
+
import torch.nn.functional as F
|
| 25 |
+
|
| 26 |
+
from timm.models.layers import DropPath, to_2tuple, trunc_normal_
|
| 27 |
+
from timm.models.registry import register_model
|
| 28 |
+
from pytorch_lightning.utilities.distributed import rank_zero_info
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
class Mlp(nn.Module):
    """Two-layer feed-forward block: Linear -> activation -> Linear.

    The same dropout rate is applied after the activation and after the
    second projection, matching the timm ViT MLP.
    """

    def __init__(
        self,
        in_features,
        hidden_features=None,
        out_features=None,
        act_layer=nn.GELU,
        drop=0.0,
    ):
        super().__init__()
        # Both widths fall back to the input width when unspecified.
        hidden_width = hidden_features or in_features
        out_width = out_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_width)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_width, out_width)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        hidden = self.drop(self.act(self.fc1(x)))
        return self.drop(self.fc2(hidden))
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
class Attention(nn.Module):
    """Multi-head self-attention with BEiT-style decomposed qkv bias.

    The fused qkv projection is created without a bias; when ``qkv_bias``
    is set, separate learnable q/v bias parameters are used and the k bias
    is a constant zero (the BEiT convention).
    """

    def __init__(
        self,
        dim,
        num_heads=8,
        qkv_bias=False,
        qk_scale=None,
        attn_drop=0.0,
        proj_drop=0.0,
    ):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        # NOTE scale factor was wrong in my original version, can set manually
        # to be compat with prev weights.
        self.scale = qk_scale or head_dim**-0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=False)
        if qkv_bias:
            self.q_bias = nn.Parameter(torch.zeros(dim))
            self.v_bias = nn.Parameter(torch.zeros(dim))
        else:
            self.q_bias = None
            self.v_bias = None

        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x, mask=None, relative_position_bias=None):
        batch, seq_len, channels = x.shape

        # Assemble (q_bias, 0, v_bias): the key projection gets no bias.
        fused_bias = None
        if self.q_bias is not None:
            zero_k_bias = torch.zeros_like(self.v_bias, requires_grad=False)
            fused_bias = torch.cat((self.q_bias, zero_k_bias, self.v_bias))

        fused = F.linear(input=x, weight=self.qkv.weight, bias=fused_bias)
        fused = fused.reshape(batch, seq_len, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
        # Each of q/k/v: (batch, heads, seq, head_dim).
        q, k, v = fused.unbind(0)

        # Logits are computed in float32 for numerical stability, then the
        # softmax result is cast back to the input dtype.
        scores = (q * self.scale).float() @ k.float().transpose(-2, -1)

        if relative_position_bias is not None:
            scores = scores + relative_position_bias.unsqueeze(0)

        if mask is not None:
            keep = mask.bool()
            scores = scores.masked_fill(~keep[:, None, None, :], float("-inf"))

        weights = scores.softmax(dim=-1).type_as(x)
        weights = self.attn_drop(weights)

        out = (weights @ v).transpose(1, 2).reshape(batch, seq_len, channels)
        out = self.proj(out)
        return self.proj_drop(out)
|
| 116 |
+
|
| 117 |
+
|
| 118 |
+
class Block(nn.Module):
    """Multiway transformer block: shared attention + modality-expert FFNs.

    A single self-attention serves every modality; separate feed-forward
    "experts" (each with its own LayerNorm) handle text tokens, image
    tokens, and -- when ``with_vlffn`` is set -- fused vision-language
    tokens. For mixed sequences without a VL expert, the first
    ``max_text_len`` positions are treated as text and the rest as image.
    """

    def __init__(
        self,
        dim,
        num_heads,
        mlp_ratio=4.0,
        qkv_bias=False,
        qk_scale=None,
        drop=0.0,
        attn_drop=0.0,
        drop_path=0.0,
        act_layer=nn.GELU,
        norm_layer=nn.LayerNorm,
        with_vlffn=False,
        layer_scale_init_values=0.1,
        max_text_len=40,
    ):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = Attention(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            attn_drop=attn_drop,
            proj_drop=drop,
        )
        # Stochastic depth (identity when the rate is zero).
        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

        self.norm2_text = norm_layer(dim)
        self.norm2_imag = norm_layer(dim)
        hidden_dim = int(dim * mlp_ratio)

        def build_expert():
            # One FFN expert; all experts share the same geometry.
            return Mlp(
                in_features=dim,
                hidden_features=hidden_dim,
                act_layer=act_layer,
                drop=drop,
            )

        self.mlp_text = build_expert()
        self.mlp_imag = build_expert()
        self.mlp_vl = None
        if with_vlffn:
            self.mlp_vl = build_expert()
            self.norm2_vl = norm_layer(dim)

        def make_layer_scale():
            # Learnable per-channel residual scale; plain 1.0 when disabled.
            if layer_scale_init_values is None:
                return 1.0
            return nn.Parameter(layer_scale_init_values * torch.ones(dim), requires_grad=True)

        self.gamma_1 = make_layer_scale()
        self.gamma_2 = make_layer_scale()

        self.max_text_len = max_text_len

    def forward(self, x, mask=None, modality_type=None, relative_position_bias=None):
        # Attention is shared across modalities.
        x = x + self.drop_path(
            self.gamma_1 * self.attn(self.norm1(x), mask=mask, relative_position_bias=relative_position_bias)
        )

        if modality_type == "image":
            return x + self.drop_path(self.gamma_2 * self.mlp_imag(self.norm2_imag(x)))
        if modality_type == "text":
            return x + self.drop_path(self.gamma_2 * self.mlp_text(self.norm2_text(x)))

        if self.mlp_vl is None:
            # Mixed sequence, no fused expert yet: split on max_text_len,
            # route each half through its own FFN, then re-concatenate.
            text_part = x[:, : self.max_text_len]
            image_part = x[:, self.max_text_len :]
            text_part = text_part + self.drop_path(self.gamma_2 * self.mlp_text(self.norm2_text(text_part)))
            image_part = image_part + self.drop_path(self.gamma_2 * self.mlp_imag(self.norm2_imag(image_part)))
            return torch.cat([text_part, image_part], dim=1)

        return x + self.drop_path(self.gamma_2 * self.mlp_vl(self.norm2_vl(x)))
|
| 205 |
+
|
| 206 |
+
|
| 207 |
+
class PatchEmbed(nn.Module):
    """Image to Patch Embedding via a strided convolution.

    A single Conv2d whose kernel size and stride both equal the patch size
    maps (B, C, H, W) images to (B, embed_dim, H/p, W/p) patch features.
    """

    def __init__(
        self,
        img_size=224,
        patch_size=16,
        in_chans=3,
        embed_dim=768,
        no_patch_embed_bias=False,
    ):
        super().__init__()
        img_size = to_2tuple(img_size)
        patch_size = to_2tuple(patch_size)
        grid_h = img_size[0] // patch_size[0]
        grid_w = img_size[1] // patch_size[1]

        self.patch_shape = (grid_h, grid_w)
        self.img_size = img_size
        self.patch_size = patch_size
        self.num_patches = grid_h * grid_w

        self.proj = nn.Conv2d(
            in_chans,
            embed_dim,
            kernel_size=patch_size,
            stride=patch_size,
            bias=not no_patch_embed_bias,
        )

    def forward(self, x):
        B, C, H, W = x.shape
        assert (
            H == self.img_size[0] and W == self.img_size[1]
        ), f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
        # FIXME look at relaxing size constraints
        return self.proj(x)
|
| 243 |
+
|
| 244 |
+
|
| 245 |
+
class MultiWayTransformer(nn.Module):
    """Vision Transformer

    A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale` -
    https://arxiv.org/abs/2010.11929
    """

    def __init__(
        self,
        img_size=224,
        patch_size=16,
        in_chans=3,
        embed_dim=768,
        depth=12,
        num_heads=12,
        mlp_ratio=4.0,
        qkv_bias=True,
        qk_scale=None,
        drop_rate=0.0,
        attn_drop_rate=0.0,
        drop_path_rate=0.0,
        norm_layer=None,
        need_relative_position_embed=True,
        use_abs_pos_emb=False,
        layer_scale_init_values=0.1,
        vlffn_start_layer_index=10,
        config=None,
    ):
        """
        Args:
            img_size (int, tuple): input image size
            patch_size (int, tuple): patch size
            in_chans (int): number of input channels
            embed_dim (int): embedding dimension
            depth (int): depth of transformer
            num_heads (int): number of attention heads
            mlp_ratio (int): ratio of mlp hidden dim to embedding dim
            qkv_bias (bool): enable bias for qkv if True
            qk_scale (float): override default qk scale of head_dim ** -0.5 if set
            drop_rate (float): dropout rate
            attn_drop_rate (float): attention dropout rate
            drop_path_rate (float): stochastic depth rate
            norm_layer: (nn.Module): normalization layer
            need_relative_position_embed (bool): enable relative position bias on self-attention
            use_abs_pos_emb (bool): enable abs pos emb
            layer_scale_init_values (float or None): layer scale init values, set None to disable
            vlffn_start_layer_index (int): vl-ffn start index
            config: (dict): other hyper from pytorch-lighting; optional --
                when None, drop_path_rate / max_text_len defaults are used.
        """
        super().__init__()
        # config (when given) overrides the drop_path_rate argument.
        drop_path_rate = drop_path_rate if config is None else config["drop_path_rate"]
        rank_zero_info("drop path rate: {}".format(drop_path_rate))
        self.use_abs_pos_emb = use_abs_pos_emb
        self.need_relative_position_embed = need_relative_position_embed

        self.num_features = self.embed_dim = embed_dim  # num_features for consistency with other models
        norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)

        self.patch_embed = PatchEmbed(
            img_size=img_size,
            patch_size=patch_size,
            in_chans=in_chans,
            embed_dim=embed_dim,
        )
        num_patches = self.patch_embed.num_patches
        self.patch_size = patch_size
        self.num_heads = num_heads
        self.vlffn_start_layer_index = vlffn_start_layer_index
        # BUGFIX: these config lookups previously ran even when config was
        # None, so the documented default (config=None) crashed with a
        # TypeError despite drop_path_rate above explicitly handling it.
        if config is not None and config["loss_names"]["textmlm"] > 0:
            # Text-only pretraining: push the VL-FFN start past the last
            # layer so no fused expert is ever built.
            self.vlffn_start_layer_index = depth
            rank_zero_info(
                "Set vlffn_start_layer_index={} for text-only pretraining".format(self.vlffn_start_layer_index)
            )
        # 40 matches Block's own max_text_len default.
        max_text_len = 40 if config is None else config["max_text_len"]
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim)) if self.use_abs_pos_emb else None
        self.pos_drop = nn.Dropout(p=drop_rate)

        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]  # stochastic depth decay rule
        self.blocks = nn.ModuleList(
            [
                Block(
                    dim=embed_dim,
                    num_heads=num_heads,
                    mlp_ratio=mlp_ratio,
                    qkv_bias=qkv_bias,
                    qk_scale=qk_scale,
                    drop=drop_rate,
                    attn_drop=attn_drop_rate,
                    drop_path=dpr[i],
                    norm_layer=norm_layer,
                    with_vlffn=(i >= self.vlffn_start_layer_index),
                    layer_scale_init_values=layer_scale_init_values,
                    max_text_len=max_text_len,
                )
                for i in range(depth)
            ]
        )
        self.norm = norm_layer(embed_dim)

        if self.pos_embed is not None:
            trunc_normal_(self.pos_embed, std=0.02)
        trunc_normal_(self.cls_token, std=0.02)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Truncated-normal init for Linear layers, constant init for LayerNorm."""
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    @torch.jit.ignore
    def no_weight_decay(self):
        """Parameter names that should be excluded from weight decay."""
        return {"pos_embed", "cls_token"}

    def visual_embed(self, _x):
        """Embed an image batch into patch tokens plus a prepended CLS token.

        Returns:
            x: (B, 1 + num_patches, embed_dim) token embeddings.
            x_mask: (B, 1 + num_patches) all-ones attention mask.
                NOTE(review): created on the default device, not _x.device --
                callers apparently move it themselves; confirm.
        """
        x = self.patch_embed(_x)
        x = x.flatten(2).transpose(1, 2)
        B, L, _ = x.shape

        cls_tokens = self.cls_token.expand(B, -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)

        if self.pos_embed is not None:
            x = x + self.pos_embed
        x = self.pos_drop(x)

        x_mask = torch.ones(x.shape[0], x.shape[1])

        return x, x_mask
|
| 378 |
+
|
| 379 |
+
|
| 380 |
+
# VLMo base/p16
|
| 381 |
+
@register_model
def vlmo_base_patch16(pretrained=False, **kwargs):
    """Build the base (ViT-B/16-sized) MultiWay Transformer.

    ``img_size`` may be overridden through kwargs; everything else in kwargs
    is forwarded to the MultiWayTransformer constructor.
    """
    resolution = kwargs.pop("img_size", 224)
    base_config = dict(
        img_size=resolution,
        patch_size=16,
        embed_dim=768,
        depth=12,
        num_heads=12,
        mlp_ratio=4,
        qkv_bias=True,
        vlffn_start_layer_index=10,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
    )
    return MultiWayTransformer(**base_config, **kwargs)
|
vlmo/modules/objectives.py
ADDED
|
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch.nn as nn
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
def init_weights(module):
    """Initialize a module in BERT style.

    Linear/Embedding weights are drawn from N(0, 0.02); LayerNorm gets a
    unit weight and zero bias; any Linear bias is zeroed.
    """
    is_linear = isinstance(module, nn.Linear)
    if is_linear or isinstance(module, nn.Embedding):
        module.weight.data.normal_(mean=0.0, std=0.02)
    elif isinstance(module, nn.LayerNorm):
        module.weight.data.fill_(1.0)
        module.bias.data.zero_()

    if is_linear and module.bias is not None:
        module.bias.data.zero_()
|
vlmo/modules/vlmo_module.py
ADDED
|
@@ -0,0 +1,405 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import math
|
| 2 |
+
import os
|
| 3 |
+
import time
|
| 4 |
+
|
| 5 |
+
import numpy as np
|
| 6 |
+
import pytorch_lightning as pl
|
| 7 |
+
import torch
|
| 8 |
+
import torch.distributed as dist
|
| 9 |
+
import torch.nn as nn
|
| 10 |
+
from pytorch_lightning.utilities.distributed import rank_zero_info
|
| 11 |
+
from timm.models import create_model
|
| 12 |
+
from transformers import AutoTokenizer, BertTokenizer, XLMRobertaTokenizer # noqa
|
| 13 |
+
from vlmo.modules import heads, objectives, vlmo_utils
|
| 14 |
+
from vlmo.tokenizer.tokenization_glm import GLMChineseTokenizer # noqa
|
| 15 |
+
from vlmo.torchscale.architecture.encoder import Encoder
|
| 16 |
+
from vlmo.torchscale.model.BEiT3 import BEiT3 as ts_backbone
|
| 17 |
+
from vlmo.transforms.utils import inception_normalize as img_norm
|
| 18 |
+
|
| 19 |
+
from .modeling_utils import _get_base_config, _get_large_config, _get_huge_config, trunc_normal_ # noqa
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
def convert_pl_ckpt(state_dict, num_visual_token=197):
    """Adapt a PyTorch-Lightning checkpoint to the current model geometry.

    Visual-tokenizer weights are dropped, and the backbone's positional
    embedding table (3 special rows followed by a square patch grid) is
    resized to ``num_visual_token`` -- by area interpolation when too small,
    or by truncation when too large. All other entries pass through.
    """
    print("start convert_pl_ckpt!!!")
    converted = {}
    target_rows = num_visual_token + 2
    for key, value in state_dict.items():
        if "visual_tokenizer" in key:
            # Tokenizer weights are not part of the model being loaded.
            continue
        if "backbone.encoder.embed_positions.A.weight" not in key:
            converted[key] = value
            continue

        rows = value.shape[0]
        if rows < target_rows:
            # Too few positions: upsample the square patch grid.
            grid_count = rows - 3  # rows remaining after the 3 special slots
            dim = value.shape[-1]
            special_rows = value[:3, ]
            grid_rows = value[3:, ]
            new_side = int(math.sqrt(num_visual_token - 1))
            w0, h0 = new_side, new_side
            old_side = int(math.sqrt(grid_count))
            grid_rows = grid_rows.float()
            grid_rows = nn.functional.interpolate(
                grid_rows.reshape(1, old_side, old_side, dim).permute(0, 3, 1, 2),
                size=(w0, h0),
                mode="area",
            )
            grid_rows = grid_rows.to(special_rows.dtype)
            grid_rows = grid_rows.permute(0, 2, 3, 1).view(-1, dim)
            resized = torch.cat((special_rows, grid_rows), dim=0)
            converted[key] = resized
            print("reshape ", key, "raw shape: ", value.shape, "new shape: ", resized.shape, num_visual_token)
        elif rows > target_rows:
            # Too many positions: keep only the leading rows.
            converted[key] = value[:target_rows, :]
            print("first ", key, "raw shape: ", value.shape, converted[key].shape, num_visual_token)
        else:
            converted[key] = value
            print("raw shape")

    return converted
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
def convert_deepspeed_ckpt(state_dict, num_visual_token=197):
    """Adapt a DeepSpeed checkpoint to the current model geometry.

    Strips the ``_forward_module.`` prefix from every key and resizes two
    kinds of positional-embedding tables to ``num_visual_token``:

    * ``visual_tokenizer.{encoder,decoder}.pos_embed`` -- shape
      ``(1, 1 + grid*grid, dim)``; the class token is kept and the square
      patch grid is area-interpolated.
    * ``backbone.encoder.embed_positions.A.weight`` -- shape
      ``(3 + grid*grid, dim)``; the 3 special rows are kept and the grid is
      area-interpolated.

    Keys without the prefix are copied through unchanged.
    """
    new_state_dict = {}
    for key in state_dict:
        if key.startswith("_forward_module."):
            new_key = key[len("_forward_module."):]
            value = state_dict[key]
            new_state_dict[new_key] = value
            if "visual_tokenizer.encoder.pos_embed" in new_key or "visual_tokenizer.decoder.pos_embed" in new_key:
                # (1, tokens, dim): token count lives in shape[1].
                if value.shape[1] != num_visual_token:
                    N = value.shape[1] - 1
                    dim = value.shape[-1]
                    class_pos_embed = value[:, 0]
                    patch_pos_embed = value[:, 1:]
                    w0, h0 = int(math.sqrt(num_visual_token - 1)), int(math.sqrt(num_visual_token - 1))
                    patch_pos_embed = patch_pos_embed.float()
                    patch_pos_embed = nn.functional.interpolate(
                        patch_pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(0, 3, 1, 2),
                        size=(w0, h0),
                        mode="area",
                    )
                    patch_pos_embed = patch_pos_embed.to(class_pos_embed.dtype)
                    patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
                    new_value = torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)
                    new_state_dict[new_key] = new_value
                    print("reshape ", new_key, "raw shape: ", value.shape, "new_shape: ", new_value.shape)
            if "backbone.encoder.embed_positions.A.weight" in new_key:
                # BUGFIX: this table is (rows, dim) and the body below reads
                # value.shape[0] - 3, but the old condition compared
                # value.shape[1] (the embedding dim) against the token count,
                # so the table was re-interpolated even when already sized
                # correctly. Compare the row count, as convert_pl_ckpt does.
                if value.shape[0] != num_visual_token + 2:
                    N = value.shape[0] - 3
                    dim = value.shape[-1]
                    class_pos_embed = value[:3, ]
                    patch_pos_embed = value[3:, ]
                    w0, h0 = int(math.sqrt(num_visual_token - 1)), int(math.sqrt(num_visual_token - 1))
                    patch_pos_embed = patch_pos_embed.float()
                    patch_pos_embed = nn.functional.interpolate(
                        patch_pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(0, 3, 1, 2),
                        size=(w0, h0),
                        mode="area",
                    )
                    patch_pos_embed = patch_pos_embed.to(class_pos_embed.dtype)
                    patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(-1, dim)
                    new_value = torch.cat((class_pos_embed, patch_pos_embed), dim=0)
                    new_state_dict[new_key] = new_value
                    print("reshape ", new_key, "raw shape: ", value.shape, "new_shape: ", new_value.shape)

        else:
            new_state_dict[key] = state_dict[key]

    return new_state_dict
|
| 107 |
+
|
| 108 |
+
|
| 109 |
+
def get_visual_tokenizer(config):
    """Build the visual tokenizer named in ``config`` and set it to eval mode."""
    name = config["tokenizer_model"]
    print(f"Creating visual tokenizer: {name}")
    tokenizer = create_model(
        name,
        img_size=config["image_size"],
        n_code=config["codebook_size"],
        code_dim=config["codebook_dim"],
    ).eval()
    return tokenizer
|
| 119 |
+
|
| 120 |
+
|
| 121 |
+
def get_pretrained_tokenizer(tokenizer_type, from_pretrained):
    """Resolve a tokenizer class by name and load its pretrained weights.

    In distributed runs, rank 0 calls ``from_pretrained`` first (warming any
    download cache) and all other ranks wait at a barrier before loading.

    Args:
        tokenizer_type: name of a tokenizer class imported in this module
            (e.g. "BertTokenizer", "GLMChineseTokenizer").
        from_pretrained: model id or path passed to ``from_pretrained``.

    Raises:
        ValueError: if ``tokenizer_type`` does not name a known class.
    """
    # SECURITY: this was `eval(f"{tokenizer_type}")`, which executes an
    # arbitrary config-supplied string. A plain globals() lookup resolves
    # the same module-level class names without code execution.
    tokenizer_cls = globals().get(tokenizer_type)
    if tokenizer_cls is None:
        raise ValueError(f"Unknown tokenizer_type: {tokenizer_type!r}")
    if torch.distributed.is_initialized():
        if torch.distributed.get_rank() == 0:
            # Rank 0 downloads/caches; the barrier keeps other ranks waiting.
            tokenizer_cls.from_pretrained(from_pretrained)
        torch.distributed.barrier()
    return tokenizer_cls.from_pretrained(from_pretrained)
|
| 128 |
+
|
| 129 |
+
|
| 130 |
+
class VLMo(pl.LightningModule):
|
| 131 |
+
    def __init__(self, config):
        """Build the VLMo Lightning module from a flat config dict.

        Constructs the text/visual tokenizers, the BEiT3 backbone (plus an
        optional fused vision-language encoder), the pooler and ITC heads,
        then loads pretrained or downstream checkpoints depending on
        ``config["test_only"]``.
        """
        super().__init__()
        self.save_hyperparameters()
        s_t = time.time()  # wall-clock start, reported at the end

        # tokenizer & backbone
        self.img_size = config["image_size"]
        if not config["test_only"]:
            # Training path only: the visual tokenizer is used to produce
            # targets and is not needed for inference.
            self.visual_tokenizer = get_visual_tokenizer(config)
        kwargs = {}
        if "encoder_attention_heads" in config:
            kwargs["encoder_attention_heads"] = config["encoder_attention_heads"]
        if "atorch_config" in config and config["atorch_config"]:
            # NOTE(review): activation checkpointing is force-disabled under
            # atorch -- the original "# ?" suggests this was never confirmed.
            checkpoint_activations = False  # ?
        else:
            checkpoint_activations = config["checkpoint_activations"]
        # NOTE(review): eval() dispatches to _get_base_config /
        # _get_large_config / _get_huge_config by name; a config-supplied
        # beit_version string is executed as code here.
        args = eval(f'_get_{config["beit_version"]}_config')(
            img_size=config["image_size"],
            patch_size=config["patch_size"],
            vocab_size=config["vocab_size"],
            encoder_layers=config["encoder_layers"],
            encoder_embed_dim=config["encoder_embed_dim"],
            checkpoint_activations=checkpoint_activations,
            share_layer=config["share_layer"],
            share_attn=config["share_attn"],
            deepnorm=config["deepnorm"],
            mask_ratio=config["mask_ratio"],
            max_text_len=config["max_text_len"],
            one_attn=config["one_attn"],
            **kwargs,
        )
        self.num_features = args.encoder_embed_dim
        self.out_features = config["out_embed_dim"]
        self.cap_onlytext = config["cap_onlytext"]
        self.lang = config["lang"]
        self.num_frames = config["num_frames"]
        self.tokenizer_type = config["tokenizer_type"]
        self.text_tokenizer = get_pretrained_tokenizer(self.tokenizer_type, from_pretrained=config["tokenizer"])  # noqa
        print("BEiT args", args.__dict__)
        self.backbone = ts_backbone(args)

        # Optional fused vision-language encoder stacked after the backbone.
        self.use_vl = config["beit3_vl_layers"] > 0
        if self.use_vl:
            # args is reused with a reduced layer count for the VL encoder.
            args.encoder_layers = config["beit3_vl_layers"]
            self.backbone_vl = Encoder(args)

        self.norm = nn.LayerNorm(self.num_features, eps=1e-6)

        # task layers
        self.pooler = heads.Pooler(self.num_features)
        self.pooler.apply(objectives.init_weights)

        # contrastive loss (or sampling for global hard negative)
        if config["loss_names"]["itc"] > 0:
            # Unimodal projection heads for image-text contrastive loss.
            self.itc_text_proj = heads.ITCHead(self.num_features, self.out_features)
            self.itc_image_proj = heads.ITCHead(self.num_features, self.out_features)
            self.itc_text_proj.apply(objectives.init_weights)
            self.itc_image_proj.apply(objectives.init_weights)

            # Separate heads for features taken after the VL encoder.
            self.itc_vl_text_proj = heads.ITCHead(self.num_features, self.out_features)
            self.itc_vl_image_proj = heads.ITCHead(self.num_features, self.out_features)
            self.itc_vl_text_proj.apply(objectives.init_weights)
            self.itc_vl_image_proj.apply(objectives.init_weights)

            # Learnable temperatures, initialized to CLIP's 1/0.07.
            self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
            self.logit_vl_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))

        lp_s_t = time.time()

        self.load_pretrained_weight()
        load_pretrain_time = time.time() - lp_s_t

        self.current_tasks = list()

        # ===================== load downstream (test_only) ======================

        if self.hparams.config["load_path"] != "" and self.hparams.config["test_only"]:
            rank_zero_info("Load ckpt from: {}".format(self.hparams.config["load_path"]))
            ckpt = torch.load(self.hparams.config["load_path"], map_location="cpu")

            state_dict = None

            # Checkpoints may nest the weights under several known keys.
            for state_dict_key in ("state_dict", "module", "model"):
                if state_dict_key in ckpt:
                    rank_zero_info("Read state dict from ckpt[%s]. " % state_dict_key)
                    state_dict = ckpt[state_dict_key]
                    break
            # NOTE(review): the checks below reuse the loop variable after
            # the loop; when no key matched, state_dict_key is "model" and
            # both branches are skipped -- intentional but fragile.
            if state_dict_key == "module":
                state_dict = convert_deepspeed_ckpt(state_dict, self.backbone.vision_embed.num_position_embeddings())
            if state_dict_key == "state_dict":
                state_dict = convert_pl_ckpt(state_dict, self.backbone.vision_embed.num_position_embeddings())
            if state_dict is None:
                # Flat checkpoint: detect the DeepSpeed prefix on raw keys.
                if list(ckpt.keys())[0].startswith('_forward_module.'):
                    rank_zero_info("Read state dict from ckpt with _forward_module prefix. ")
                    state_dict = convert_deepspeed_ckpt(ckpt, self.backbone.vision_embed.num_position_embeddings())
                else:
                    rank_zero_info("Read state dict from ckpt. ")
                    state_dict = ckpt

            # strict=False: downstream checkpoints may miss task heads.
            missing_keys, unexpected_keys = self.load_state_dict(state_dict, strict=False)
            rank_zero_info("missing_keys: {}".format(missing_keys))
            rank_zero_info("unexpected_keys: {}".format(unexpected_keys))

        construct_time = time.time() - s_t
        print(
            f"Process {os.getpid()}. VLMo Constructor time: {construct_time}s;",
            f"load_pretrain_time: {load_pretrain_time}s",
            flush=True,
        )
        # coalesce backbone calls
        self._coalesce_backbone = config["coalesce_backbone"]
        self._mask_data = config["mask_data"]
        self._backbone_inputs = {}
        self._backbone_inputs_current_size = 0
        self._backbone_inputs_keys = {}
        self._backbone_outputs = None
        self._default_attn_masks = {}
        self._itc_group = None
        self._itc_aggregate_dict = None
        self._itc_mask = config["itc_mask"]
        self._local_loss = config["local_loss"]
        self._aggregate_nodes = config["aggregate_nodes"]
        self.accumulated_batches_reached = False
        vlmo_utils.set_task(self)  # populates self.current_tasks from config
        self._only_itc_single_machine = (
            self._aggregate_nodes > 0 and len(self.current_tasks) == 1 and "itc" in self.current_tasks
        )
        self._split_data_for_imagemlm = config["split_data_for_imagemlm"]
        self.log_metric_steps = config["log_metric_steps"]
|
| 260 |
+
|
| 261 |
+
def _init_weights(self, m):
    """Initialize one submodule (intended for use with ``nn.Module.apply``).

    Linear layers get truncated-normal weights (std 0.02) and zero bias;
    LayerNorm layers get zero bias and unit weight.  Other module types are
    left untouched.
    """
    if isinstance(m, nn.Linear):
        trunc_normal_(m.weight, std=0.02)
        # The original repeated the isinstance(m, nn.Linear) test here even
        # though this branch is already guarded by it; only the bias check
        # is actually needed.
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)
    elif isinstance(m, nn.LayerNorm):
        nn.init.constant_(m.bias, 0)
        nn.init.constant_(m.weight, 1.0)
|
| 269 |
+
|
| 270 |
+
def fix_init_weight(self):
    """Apply depth-dependent rescaling to residual-branch weights.

    Divides the value/output projections and the second FFN linear of every
    transformer layer by sqrt(2 * depth), where depth is the 1-based layer
    index counted across the shared backbone and, when ``use_vl`` is set,
    the VL-specific layers stacked on top of it.
    """
    def _scale_layer(layer, depth):
        # The six residual-branch parameter tensors of one multiway layer;
        # A and B are the two modality-specific expert branches.
        factor = math.sqrt(2.0 * depth)
        targets = (
            layer.self_attn.v_proj.A.weight.data,
            layer.self_attn.v_proj.B.weight.data,
            layer.self_attn.out_proj.A.weight.data,
            layer.self_attn.out_proj.B.weight.data,
            layer.ffn.A.fc2.weight.data,
            layer.ffn.B.fc2.weight.data,
        )
        for tensor in targets:
            tensor.div_(factor)

    backbone_layers = self.backbone.encoder.layers
    for idx, layer in enumerate(backbone_layers):
        _scale_layer(layer, idx + 1)

    if self.use_vl:
        # VL layers continue the depth count right after the backbone.
        offset = len(backbone_layers) + 1
        for idx, layer in enumerate(self.backbone_vl.layers):
            _scale_layer(layer, idx + offset)
|
| 291 |
+
|
| 292 |
+
def load_pretrained_weight(self):
    """Load pretrained weights from ``config["load_path"]`` into this module.

    Runs only when a load path is configured and ``test_only`` is False.
    The checkpoint may wrap its weights under ``"state_dict"`` (PyTorch
    Lightning), ``"module"`` (DeepSpeed) or ``"model"``; the first key found
    wins and a format-specific key conversion is applied.  Loading is
    non-strict, so partially matching checkpoints are accepted.
    """
    if self.hparams.config["load_path"] != "" and not self.hparams.config["test_only"]:
        config = self.hparams.config  # NOTE(review): unused local, kept as-is
        ckpt = torch.load(self.hparams.config["load_path"], map_location="cpu")
        rank_zero_info("Load ckpt from: {}".format(self.hparams.config["load_path"]))

        state_dict = None

        # Pick the first known wrapper key present in the checkpoint.
        for state_dict_key in ("state_dict", "module", "model"):
            if state_dict_key in ckpt:
                rank_zero_info("Read state dict from ckpt[%s]. " % state_dict_key)
                state_dict = ckpt[state_dict_key]
                break
        # NOTE(review): indentation reconstructed — these conversions sit
        # after the loop, relying on `state_dict_key` keeping the matched
        # key (or "model" when nothing matched and state_dict is None).
        if state_dict_key == "module":
            state_dict = convert_deepspeed_ckpt(state_dict, self.backbone.vision_embed.num_position_embeddings())
        if state_dict_key == "state_dict":
            state_dict = convert_pl_ckpt(state_dict, self.backbone.vision_embed.num_position_embeddings())
        if state_dict is None:
            # No wrapper key: the checkpoint itself is the state dict,
            # possibly carrying DeepSpeed's "_forward_module." prefix.
            if list(ckpt.keys())[0].startswith('_forward_module.'):
                rank_zero_info("Read state dict from ckpt with _forward_module prefix. ")
                state_dict = convert_deepspeed_ckpt(ckpt,
                                                    self.backbone.vision_embed.num_position_embeddings())
            else:
                rank_zero_info("Read state dict from ckpt. ")
                state_dict = ckpt

        missing_keys, unexpected_keys = self.load_state_dict(state_dict, strict=False)
        # itc_teacher weights are expected to be absent from such ckpts, so
        # drop them from the report to keep the log signal useful.
        missing_keys = [k for k in missing_keys if "itc_teacher" not in k]
        rank_zero_info("missing_keys: {}".format(missing_keys))
        rank_zero_info("unexpected_keys: {}".format(unexpected_keys))
|
| 322 |
+
|
| 323 |
+
def infer_text(
    self,
    batch,
    mask_text=False,
):
    """Encode the text side of a batch and produce ITC projection features.

    Args:
        batch: dict holding ``text_ids``/``text_masks`` tensors (or their
            ``_mlm`` variants when ``mask_text`` is True).
        mask_text: select the masked-LM variant of the text ids.

    Returns:
        dict with L2-normalized CLS projections from the language branch
        (``cls_feats``) and the VL branch (``cls_vlffn_feats``), plus the
        raw token embeddings (``text_embed``).
    """
    do_mlm = "_mlm" if mask_text else ""
    text_ids = batch[f"text_ids{do_mlm}"]
    # Fixed: was an f-string with no placeholder; also dropped the unused
    # `text_labels` lookup that this method never consumed.
    text_masks = batch["text_masks"]
    text_embed = self.backbone.text_embed(text_ids)
    # The backbone expects padding positions (1 where padded), i.e. the
    # inverted attention mask.
    text_padding_position = 1 - text_masks
    lffn_hiddens = self.backbone(
        textual_tokens=text_ids,
        text_padding_position=text_padding_position,
    )["encoder_out"]
    vlffn_hiddens = self.backbone_vl(
        src_tokens=None,
        token_embeddings=lffn_hiddens,
        encoder_padding_mask=text_padding_position,
        multiway_split_position=-1,
    )["encoder_out"]

    # L2-normalized CLS projections for the contrastive (ITC) heads.
    cls_feats = self.itc_text_proj(lffn_hiddens[:, 0])
    cls_feats = cls_feats / cls_feats.norm(dim=-1, keepdim=True)

    cls_vlffn_feats = self.itc_vl_text_proj(vlffn_hiddens[:, 0])
    cls_vlffn_feats = cls_vlffn_feats / cls_vlffn_feats.norm(dim=-1, keepdim=True)

    ret = {
        "cls_feats": cls_feats,
        "cls_vlffn_feats": cls_vlffn_feats,
        "text_embed": text_embed,
    }

    return ret
|
| 358 |
+
|
| 359 |
+
def infer_image(
    self,
    batch,
    mask_image=False,
    image_token_type_idx=1,
    image_embeds=None,
    image_masks=None,
):
    """Encode the image side of a batch and produce ITC projection features.

    Returns a dict with the raw vision-branch hidden states
    (``image_feats``) and L2-normalized CLS projections from the vision
    branch (``cls_feats``) and the VL branch (``cls_vlffn_feats``).

    NOTE(review): ``image_embeds`` is accepted but never used, and
    ``quantize``/``image_labels`` are computed but not returned — presumably
    leftovers from a masked-image-modeling path; confirm before removing.
    NOTE(review): when ``mask_image`` is False the caller must pass
    ``image_masks``, otherwise ``image_masks.to(...)`` below raises.
    """
    # Multi-image batches store tensors under "image_0", "image_1", ...
    if f"image_{image_token_type_idx - 1}" in batch:
        imgkey = f"image_{image_token_type_idx - 1}"
    else:
        imgkey = "image"

    img = batch[imgkey][0]
    if mask_image:
        image_masks = batch[f"{imgkey}_masks"][0].flatten(1)

    # Discrete visual-token ids from the (gradient-free) visual tokenizer.
    with torch.no_grad():
        img = self.visual_tokenizer.pre_process(img)
        quantize, embed_ind, _ = self.visual_tokenizer.encode(img)
        image_ids = embed_ind.view(img.shape[0], -1)

    # Labels hold the tokenizer ids at masked positions, -100 elsewhere.
    image_labels = torch.full_like(image_ids, -100)
    bool_masked_pos = image_masks.to(torch.bool)
    image_labels[bool_masked_pos] = image_ids[bool_masked_pos]

    img_tensor = img_norm(img)
    vffn_hiddens = self.backbone(visual_tokens=img_tensor)["encoder_out"]
    vlffn_hiddens = self.backbone_vl(
        src_tokens=None,
        token_embeddings=vffn_hiddens,
        multiway_split_position=-1,
    )["encoder_out"]

    # L2-normalized CLS projections for the contrastive (ITC) heads.
    cls_feats = self.itc_image_proj(vffn_hiddens[:, 0])
    cls_feats = cls_feats / cls_feats.norm(dim=-1, keepdim=True)

    cls_vlffn_feats = self.itc_vl_image_proj(vlffn_hiddens[:, 0])
    cls_vlffn_feats = cls_vlffn_feats / cls_vlffn_feats.norm(dim=-1, keepdim=True)

    ret = {
        "image_feats": vffn_hiddens,
        "cls_feats": cls_feats,
        "cls_vlffn_feats": cls_vlffn_feats,
    }

    return ret
|
vlmo/modules/vlmo_utils.py
ADDED
|
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
def set_task(pl_module):
    """Record on ``pl_module`` which training tasks are active.

    A task is considered active when its weight in
    ``hparams.config["loss_names"]`` is at least 1; the active names are
    stored (in dict order) on ``pl_module.current_tasks``.
    """
    loss_names = pl_module.hparams.config["loss_names"]
    active = []
    for task_name, weight in loss_names.items():
        if weight >= 1:
            active.append(task_name)
    pl_module.current_tasks = active
    return
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
def no_sync_module_apply(module, fn):
    """Apply ``fn`` to every strict descendant of ``module``, pre-order.

    FSDP's ``nn.Module.apply`` goes through ``_unshard_params_recurse``,
    which synchronizes parameters across ranks; this helper walks the tree
    manually so ``fn`` can run without that collective.  ``fn`` is invoked
    on children (and grandchildren, ...) but not on ``module`` itself.
    """
    # Iterative pre-order DFS: push children reversed so the leftmost
    # child is processed first, matching the recursive formulation.
    stack = list(reversed(list(module.children())))
    while stack:
        node = stack.pop()
        fn(node)
        stack.extend(reversed(list(node.children())))
|
vlmo/tokenizer/__init__.py
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# coding: utf-8
|
| 2 |
+
# Copyright (c) Antfin, Inc. All rights reserved.
|
| 3 |
+
|
| 4 |
+
from __future__ import absolute_import
|
| 5 |
+
from __future__ import division
|
| 6 |
+
from __future__ import print_function
|
vlmo/tokenizer/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (369 Bytes). View file
|
|
|
vlmo/tokenizer/__pycache__/tokenization_glm.cpython-311.pyc
ADDED
|
Binary file (24.2 kB). View file
|
|
|
vlmo/tokenizer/sp.model
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:b7fe3bcc8d284fcb782691411e8b6fd4f45d7245565b094de6ab795e66bcd32f
|
| 3 |
+
size 2270960
|
vlmo/tokenizer/tokenization_glm.py
ADDED
|
@@ -0,0 +1,307 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
from shutil import copyfile
|
| 3 |
+
from typing import Optional, Tuple, List, Union
|
| 4 |
+
|
| 5 |
+
import sentencepiece as spm
|
| 6 |
+
import torch
|
| 7 |
+
from transformers import PreTrainedTokenizer
|
| 8 |
+
from transformers.models.auto.tokenization_auto import get_tokenizer_config
|
| 9 |
+
from transformers.tokenization_utils_base import BatchEncoding
|
| 10 |
+
from transformers.utils import logging
|
| 11 |
+
|
| 12 |
+
logger = logging.get_logger(__name__)
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
class GLMBatchEncoding(BatchEncoding):
    """``BatchEncoding`` whose ``to`` only ever moves tensors to a device.

    Frameworks such as APEX blindly call ``.to`` on all module inputs; a
    plain device move must never cast the token-id LongTensors, so any
    non-device argument is rejected with a warning.
    """

    def to(self, device: Union[str, "torch.device"]) -> "BatchEncoding":
        """Move every tensor value to *device* in place and return ``self``."""
        if not (isinstance(device, str) or isinstance(device, int)):
            logger.warning(f"Attempting to cast a BatchEncoding to type {str(device)}. This is not supported.")
            return self
        moved = {}
        for key, value in self.data.items():
            moved[key] = value.to(device=device) if torch.is_tensor(value) else value
        self.data = moved
        return self
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
class GLMTokenizerMixin:
    """Shared GLM-specific helpers mixed into concrete tokenizers.

    Provides the GLM special tokens (sop/eop/[gMASK]/[sMASK]) and builders
    for the 2D position ids and block-causal attention masks that GLM's
    autoregressive blank filling requires.
    """

    @property
    def sop_token(self) -> Optional[str]:
        # Start-of-piece token that opens each generated blank.
        return "<|startofpiece|>"

    @property
    def sop_token_id(self) -> Optional[int]:
        """
        `Optional[int]`: Id of the start token in the vocabulary, used when training a model with autoregressive blank filling.
        """
        return self.convert_tokens_to_ids(self.sop_token)

    @property
    def eop_token(self) -> Optional[str]:
        # End-of-piece token that closes each generated blank.
        return "<|endofpiece|>"

    @property
    def eop_token_id(self) -> Optional[int]:
        """
        `Optional[int]`: Id of the end token in the vocabulary, used when training a model with autoregressive blank filling.
        """
        return self.convert_tokens_to_ids(self.eop_token)

    @property
    def gmask_token_id(self) -> int:
        return self.convert_tokens_to_ids("[gMASK]")

    @property
    def smask_token_id(self) -> int:
        return self.convert_tokens_to_ids("[sMASK]")

    @property
    def mask_token_ids(self):
        # All mask variants GLM recognizes as blank positions.
        return [self.mask_token_id, self.smask_token_id, self.gmask_token_id]

    def _build_input_for_multiple_choice(self, context, choices):
        """Build one sample: context followed by each choice as a GLM blank.

        Every choice is appended after a <sop> token, positioned at the
        [MASK] location via GLM's 2D (position, block-position) scheme, and
        attends causally within itself and fully to the context.
        """
        context_id = context["input_ids"]
        if torch.is_tensor(context_id):
            context_id = context_id.tolist()

        division = len(context_id)
        mask_position = context_id.index(self.mask_token_id)

        token = torch.tensor(context_id, dtype=torch.long)
        attention_mask = [context["attention_mask"].expand(division, -1)]
        position_id = torch.arange(division, dtype=torch.long)
        block_position_id = torch.zeros(division, dtype=torch.long)

        choice_ids, choice_indices = [], []

        for choice_str in choices:
            choice = torch.tensor(self(choice_str, add_special_tokens=False, padding=False)['input_ids'],
                                  dtype=torch.long)
            choice_ids.append(choice)
            # Indices of the positions whose logits predict this choice.
            choice_indices.append(torch.arange(len(token), len(token) + len(choice), dtype=torch.long))
            attention_mask.append(torch.tril(torch.ones((len(choice), len(choice)), dtype=torch.long)))

            # Feed <sop> + choice[:-1]; every choice token shares the
            # [MASK] position id and counts block positions from 1.
            token = torch.cat((token, torch.tensor([self.sop_token_id], dtype=torch.long), choice[:-1]))
            position_id = torch.cat((position_id, torch.tensor([mask_position] * len(choice), dtype=torch.long)))
            block_position_id = torch.cat((block_position_id, torch.arange(1, 1 + len(choice), dtype=torch.long)))

        # Block-diagonal (context + per-choice causal) mask, then let all
        # choice tokens attend to the whole context.
        attention_mask = torch.block_diag(*attention_mask)
        attention_mask[division:, :division] = context["attention_mask"].unsqueeze(0)

        return {
            "input_ids": token,
            "position_ids": torch.stack((position_id, block_position_id)),
            "attention_mask": attention_mask,
            "choice_ids": choice_ids,
            "choice_indices": choice_indices
        }

    def _pad_batch(self, tokens, position_ids, attention_mask, max_seq_length):
        """Right-pad one sample (tokens, 2D positions, square mask) to length."""
        pad_length = max_seq_length - len(tokens)
        attention_mask = torch.nn.functional.pad(
            attention_mask,
            (0, pad_length, 0, pad_length),
            mode="constant",
            value=0,
        )
        tokens = torch.cat((tokens, torch.zeros(pad_length, dtype=torch.long)))
        # Padding repeats the last position id instead of extending it.
        position_ids = torch.cat((position_ids, position_ids[..., -1:].expand(-1, pad_length)), dim=-1)
        return tokens, position_ids, attention_mask

    def _collate(self, samples):
        """Pad samples to a common length and stack them into batch tensors."""
        TILE = 1  # round padded length up to a multiple of TILE
        length_to_pad = (max(map(lambda spl: len(spl["input_ids"]), samples)) + TILE - 1) // TILE * TILE

        token_batch, position_id_batch, attention_mask_batch = [], [], []
        choices_batch, choice_target_ids_batch = [], []

        for sample in samples:
            token, position_id, attention_mask = self._pad_batch(
                sample["input_ids"], sample["position_ids"], sample["attention_mask"], length_to_pad
            )
            token_batch.append(token)
            position_id_batch.append(position_id)
            attention_mask_batch.append(attention_mask)
            choices_batch.append(sample["choice_ids"])
            choice_target_ids_batch.append(sample["choice_indices"])
        return {
            "input_ids": torch.stack(token_batch),
            "position_ids": torch.stack(position_id_batch),
            # unsqueeze(1) adds the head-broadcast dimension.
            "attention_mask": torch.stack(attention_mask_batch).unsqueeze(1),
            "choice_ids": choices_batch,
            "choice_indices": choice_target_ids_batch,
        }

    def build_inputs_for_multiple_choice(self, model_input: BatchEncoding, choices, max_length=None):
        """Turn an encoded batch plus per-sample choice lists into GLM inputs.

        NOTE(review): ``max_length`` is accepted but unused — confirm intent.
        """
        samples = [{key: value[i] for key, value in model_input.items()} for i in range(len(model_input["input_ids"]))]
        samples = [self._build_input_for_multiple_choice(sample, choice) for sample, choice in
                   zip(samples, choices)]
        inputs = self._collate(samples)
        return GLMBatchEncoding(inputs)

    def build_inputs_for_generation(self, model_input: BatchEncoding, max_gen_length=512, targets=None, padding=False):
        """Extend an encoded batch with a generation region after [MASK].

        Without ``targets`` this appends a single <sop> token (inference);
        with ``targets`` it appends teacher-forced target ids and returns
        ``labels`` for training.  Raises ValueError when no mask token
        ([MASK]/[sMASK]/[gMASK]) is present in a sample.
        """
        mask_ids = self.mask_token_ids
        input_ids = model_input.input_ids
        batch_size, seq_length = input_ids.shape[:2]
        position_id, block_position_id = list(range(seq_length)), [0 for _ in range(seq_length)]
        position_ids, block_position_ids = [], []
        labels = None
        if targets is not None:
            is_batched = isinstance(targets, (list, tuple))
            targets = self(targets, add_special_tokens=False, padding=False).input_ids
            if not is_batched:
                targets = [targets]
            assert len(targets) == len(input_ids)
            # Append <eop>, truncate, then prepend <sop>; labels are the
            # targets shifted left by one.
            targets = [(target + [self.eop_token_id])[:max_gen_length] for target in targets]
            if not padding:
                max_gen_length = max(map(len, targets))
            targets = [[self.sop_token_id] + target for target in targets]
            labels = [target[1:] for target in targets]
            targets = [target + [self.pad_token_id] * (max_gen_length + 1 - len(target)) for target in targets]
            labels = [label + [-100] * (max_gen_length - len(label)) for label in labels]
            targets = torch.tensor(targets, dtype=input_ids.dtype, device=input_ids.device)
            labels = torch.tensor(labels, dtype=input_ids.dtype, device=input_ids.device)
            # Context positions never contribute to the loss.
            labels = torch.cat((input_ids.new_full((batch_size, seq_length), -100), labels), dim=1)
        for i in range(batch_size):
            mask_positions = []
            for mask_id in mask_ids:
                mask_positions += (input_ids[i] == mask_id).nonzero(as_tuple=True)[0].tolist()
            if not mask_positions:
                raise ValueError("Cannot find mask token in the input")
            mask_positions.sort()
            # Generation is anchored at the first mask occurrence.
            mask_pos = mask_positions[0]
            position_ids.append(position_id + [mask_pos] * max_gen_length)
            block_position_ids.append(block_position_id + list(range(1, max_gen_length + 1)))
        position_ids = torch.tensor(position_ids, dtype=input_ids.dtype, device=input_ids.device)
        block_position_ids = torch.tensor(block_position_ids, dtype=input_ids.dtype, device=input_ids.device)
        position_ids = torch.stack((position_ids, block_position_ids), dim=1)
        # Context attends bidirectionally to itself; generated tokens attend
        # to the context plus causally within the generation region.
        attention_mask = model_input.attention_mask
        attention_mask = attention_mask.unsqueeze(1).expand(-1, seq_length + max_gen_length, -1)
        generation_attention_mask = torch.cat([attention_mask.new_zeros((seq_length, max_gen_length)),
                                               torch.tril(attention_mask.new_ones((max_gen_length, max_gen_length)))],
                                              dim=0).unsqueeze(0).expand(batch_size, -1, -1)
        attention_mask = torch.cat((attention_mask, generation_attention_mask), dim=2)
        attention_mask = attention_mask.unsqueeze(1)
        if targets is None:
            input_ids = torch.cat((input_ids, input_ids.new_full((batch_size, 1), self.sop_token_id)), dim=-1)
        else:
            # Teacher forcing: feed everything except the final target token.
            input_ids = torch.cat((input_ids, targets[:, :-1]), dim=1)
        batch = {"input_ids": input_ids, "position_ids": position_ids}
        if labels is None:
            batch["generation_attention_mask"] = attention_mask
        else:
            batch["attention_mask"] = attention_mask
            batch["labels"] = labels
        return BatchEncoding(batch)
|
| 207 |
+
|
| 208 |
+
def encode_whitespaces(content):
    """Replace runs of 2-10 spaces with ``<|blank_n|>`` placeholder tokens.

    Substitution runs from the widest run (10) down to 2, so longer runs
    are consumed first.  Inverse of :func:`decode_whitespaces`.
    """
    width = 10
    while width > 1:
        content = content.replace(' ' * width, f'<|blank_{width}|>')
        width -= 1
    return content
|
| 212 |
+
|
| 213 |
+
def decode_whitespaces(content):
    """Expand ``<|blank_n|>`` placeholder tokens back into runs of n spaces.

    Processes n from 10 down to 2; inverse of :func:`encode_whitespaces`.
    """
    for width in range(10, 1, -1):
        marker = f'<|blank_{width}|>'
        content = content.replace(marker, ' ' * width)
    return content
|
| 217 |
+
|
| 218 |
+
|
| 219 |
+
class GLMChineseTokenizer(PreTrainedTokenizer, GLMTokenizerMixin):
    """SentencePiece-based Chinese tokenizer for GLM models.

    Whitespace runs are rewritten to ``<|blank_n|>`` placeholders before
    SentencePiece encoding and restored on decode.
    """

    vocab_files_names = {"vocab_file": "sp.model"}
    truncation_side: str = "left"

    def __init__(self, vocab_file, **kwargs):
        # Load the SentencePiece model before the base-class init, which
        # may already need vocab lookups for special tokens.
        self.vocab_file = vocab_file
        self.sp_model = spm.SentencePieceProcessor()
        self.sp_model.Load(vocab_file)
        super().__init__(**kwargs)

    @property
    def vocab_size(self):
        """Size of the SentencePiece vocabulary (excluding added tokens)."""
        return len(self.sp_model)

    def get_vocab(self):
        """Return the full token->id mapping, including added tokens."""
        vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)}
        vocab.update(self.added_tokens_encoder)
        return vocab

    def _tokenize(self, text, **kwargs):
        # Encode whitespace runs as placeholder tokens first so they
        # survive SentencePiece's whitespace handling.
        text = encode_whitespaces(text)
        return self.sp_model.EncodeAsPieces(text)
        #return self.sp_model.EncodeAsPieces(text, out_type=str)

    def _convert_token_to_id(self, token):
        """Converts a token (str) in an id using the vocab."""
        return self.sp_model.PieceToId(token)

    def _convert_id_to_token(self, index):
        """Converts an index (integer) in a token (str) using the vocab."""
        return self.sp_model.IdToPiece(index)

    def convert_tokens_to_string(self, tokens):
        """Join tokens back into text and restore encoded whitespace runs."""
        res = self.sp_model.DecodeIds(tokens)
        return decode_whitespaces(res)

    def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
        """Write the SentencePiece model into ``save_directory``.

        NOTE(review): returns None (not a tuple) when the directory check
        fails — callers expecting a tuple should be aware.
        """
        if not os.path.isdir(save_directory):
            logger.error(f"Vocabulary path ({save_directory}) should be a directory")
            return
        out_vocab_file = os.path.join(
            save_directory, (filename_prefix + "-" if filename_prefix else "") + self.vocab_files_names["vocab_file"]
        )

        # Copy the original file when available; otherwise serialize the
        # in-memory SentencePiece model.
        if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file) and os.path.isfile(self.vocab_file):
            copyfile(self.vocab_file, out_vocab_file)
        elif not os.path.isfile(self.vocab_file):
            with open(out_vocab_file, "wb") as fi:
                content_spiece_model = self.sp_model.serialized_model_proto()
                fi.write(content_spiece_model)

        return (out_vocab_file,)

    def build_inputs_with_special_tokens(
            self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
    ) -> List[int]:
        """
        Build model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating and
        adding special tokens. A BERT sequence has the following format:

        - single sequence: ``[CLS] X [SEP]``
        - pair of sequences: ``[CLS] A [SEP] B [SEP]``

        Args:
            token_ids_0 (:obj:`List[int]`):
                List of IDs to which the special tokens will be added.
            token_ids_1 (:obj:`List[int]`, `optional`):
                Optional second list of IDs for sequence pairs.

        Returns:
            :obj:`List[int]`: List of `input IDs <../glossary.html#input-ids>`__ with the appropriate special tokens.
        """
        # Pair inputs are not supported by this tokenizer.
        assert token_ids_1 is None
        cls = [self.cls_token_id]
        eos = [self.eos_token_id]
        return cls + token_ids_0 + eos
|
| 295 |
+
|
| 296 |
+
|
| 297 |
+
class GLMTokenizer:
    """Factory that resolves the concrete GLM tokenizer for a checkpoint.

    Reads ``tokenizer_class`` from the checkpoint's tokenizer config and
    delegates ``from_pretrained`` to the matching implementation.
    """

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *inputs, **kwargs):
        tokenizer_config = get_tokenizer_config(pretrained_model_name_or_path, **kwargs)
        config_tokenizer_class = tokenizer_config.get("tokenizer_class")

        # Only the Chinese SentencePiece variant is supported here.
        if config_tokenizer_class != "GLMChineseTokenizer":
            raise NotImplementedError("Not implemented tokenizer type:", config_tokenizer_class)
        return GLMChineseTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
|
vlmo/tokenizer/tokenizer_config.json
ADDED
|
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"name_or_path": "THUDM/glm-10b-chinese",
|
| 3 |
+
"eos_token": "<|endoftext|>",
|
| 4 |
+
"pad_token": "<|endoftext|>",
|
| 5 |
+
"cls_token": "[CLS]",
|
| 6 |
+
"mask_token": "[MASK]",
|
| 7 |
+
"unk_token": "[UNK]",
|
| 8 |
+
"add_prefix_space": false,
|
| 9 |
+
"tokenizer_class": "GLMChineseTokenizer",
|
| 10 |
+
"use_fast": false,
|
| 11 |
+
"auto_map": {
|
| 12 |
+
"AutoTokenizer": [
|
| 13 |
+
"tokenization_glm.GLMChineseTokenizer",
|
| 14 |
+
null
|
| 15 |
+
]
|
| 16 |
+
}
|
| 17 |
+
}
|
vlmo/torchscale/__init__.py
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2022 Microsoft
|
| 2 |
+
# Licensed under The MIT License [see LICENSE for details]
|
vlmo/torchscale/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (204 Bytes). View file
|
|
|
vlmo/torchscale/architecture/__init__.py
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2022 Microsoft
|
| 2 |
+
# Licensed under The MIT License [see LICENSE for details]
|
vlmo/torchscale/architecture/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (217 Bytes). View file
|
|
|
vlmo/torchscale/architecture/__pycache__/config.cpython-311.pyc
ADDED
|
Binary file (14.8 kB). View file
|
|
|
vlmo/torchscale/architecture/__pycache__/encoder.cpython-311.pyc
ADDED
|
Binary file (22.8 kB). View file
|
|
|
vlmo/torchscale/architecture/__pycache__/utils.cpython-311.pyc
ADDED
|
Binary file (2.58 kB). View file
|
|
|
vlmo/torchscale/architecture/config.py
ADDED
|
@@ -0,0 +1,197 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2022 Microsoft
|
| 2 |
+
# Licensed under The MIT License [see LICENSE for details]
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
class EncoderConfig(object):
    """Hyper-parameter container for the torchscale encoder.

    Every field is popped from ``**kwargs`` with a default, so unknown keys
    are simply left in the caller's dict.  After field assignment a few
    consistency rules are enforced: ``deepnorm`` and ``subln`` are mutually
    exclusive (deepnorm wins and disables pre-norm), and ``use_xmoe``
    forces the xMoE gating settings.
    """

    # (name, default) pairs, in the original declaration order:
    # transformer core, MoE, relative position, norm variants,
    # multiway/sharing, text, vision, fairscale/xpos.
    _FIELDS = (
        ("encoder_embed_dim", 768),
        ("encoder_attention_heads", 12),
        ("encoder_ffn_embed_dim", 3072),
        ("encoder_layers", 12),
        ("encoder_normalize_before", True),
        ("normalize_output", True),
        ("activation_fn", "gelu"),
        ("dropout", 0.0),
        ("drop_path_rate", 0.0),
        ("attention_dropout", 0.0),
        ("activation_dropout", 0.0),
        ("no_scale_embedding", True),
        ("layernorm_embedding", False),
        ("moe_freq", 0),
        ("moe_top1_expert", False),
        ("moe_expert_count", 0),
        ("moe_gating_use_fp32", True),
        ("moe_eval_capacity_token_fraction", 0.25),
        ("moe_second_expert_policy", "random"),
        ("moe_normalize_gate_prob_before_dropping", False),
        ("use_xmoe", False),
        ("rel_pos_buckets", 0),
        ("max_rel_pos", 0),
        ("deepnorm", False),
        ("subln", True),
        ("bert_init", False),
        ("multiway", False),
        ("share_encoder_input_output_embed", False),
        ("max_source_positions", 1024),
        ("no_output_layer", False),
        ("layernorm_eps", 1e-5),
        ("share_layer", False),
        ("share_attn", False),
        ("mask_ratio", 0),
        ("max_text_len", 52),
        ("one_attn", False),
        # Text
        ("vocab_size", -1),
        # Vision
        ("img_size", 224),
        ("patch_size", 16),
        ("in_chans", 3),
        # Fairscale
        ("checkpoint_activations", False),
        ("fsdp", False),
        ("ddp_rank", 0),
        ("xpos_rel_pos", False),
        ("xpos_scale_base", 512),
    )

    def __init__(self, **kwargs):
        for name, default in self._FIELDS:
            setattr(self, name, kwargs.pop(name, default))

        # deepnorm and subln are mutually exclusive; deepnorm takes priority
        # and switches off pre-norm.
        if self.deepnorm:
            self.encoder_normalize_before = False
            self.subln = False
        if self.subln:
            self.encoder_normalize_before = True
            self.deepnorm = False
        if self.use_xmoe:
            self.moe_normalize_gate_prob_before_dropping = True
            self.moe_second_expert_policy = "random"
            assert self.moe_freq > 0 and self.moe_expert_count > 0

    def override(self, args):
        """Overwrite fields with same-named, non-None attributes of ``args``."""
        for hp in self.__dict__.keys():
            if getattr(args, hp, None) is not None:
                self.__dict__[hp] = getattr(args, hp, None)
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
class DecoderConfig(object):
    """Hyperparameter container for a standalone torchscale decoder.

    Unknown keyword arguments are consumed via ``kwargs.pop`` and silently
    ignored; attribute insertion order matches the declaration order below,
    which ``override`` relies on when iterating ``self.__dict__``.
    """

    def __init__(self, **kwargs):
        # (name, default) pairs, consumed from kwargs in declaration order.
        hyperparams = (
            ("decoder_embed_dim", 768),
            ("decoder_attention_heads", 12),
            ("decoder_ffn_embed_dim", 3072),
            ("decoder_layers", 12),
            ("decoder_normalize_before", True),
            ("activation_fn", "gelu"),
            ("dropout", 0.0),
            ("drop_path_rate", 0.0),
            ("attention_dropout", 0.0),
            ("activation_dropout", 0.0),
            ("no_scale_embedding", True),
            ("layernorm_embedding", False),
            ("moe_freq", 0),
            ("moe_top1_expert", False),
            ("moe_expert_count", 0),
            ("moe_gating_use_fp32", True),
            ("moe_eval_capacity_token_fraction", 0.25),
            ("moe_second_expert_policy", "random"),
            ("moe_normalize_gate_prob_before_dropping", False),
            ("use_xmoe", False),
            ("rel_pos_buckets", 0),
            ("max_rel_pos", 0),
            ("deepnorm", False),
            ("subln", True),
            ("bert_init", False),
            ("multiway", False),
            ("share_decoder_input_output_embed", False),
            ("max_target_positions", 1024),
            ("no_output_layer", False),
            ("layernorm_eps", 1e-5),
            # Text
            ("vocab_size", -1),
            # Fairscale
            ("checkpoint_activations", False),
            ("fsdp", False),
            ("ddp_rank", 0),
            ("xpos_rel_pos", False),
            ("xpos_scale_base", 512),
        )
        for name, default in hyperparams:
            setattr(self, name, kwargs.pop(name, default))

        # deepnorm and subln are mutually exclusive; deepnorm wins when both
        # are requested (it disables subln before subln is checked).
        if self.deepnorm:
            self.decoder_normalize_before = False
            self.subln = False
        if self.subln:
            self.decoder_normalize_before = True
            self.deepnorm = False
        if self.use_xmoe:
            # X-MoE requires normalized gate probabilities and random
            # second-expert routing, plus a valid MoE layout.
            self.moe_normalize_gate_prob_before_dropping = True
            self.moe_second_expert_policy = "random"
            assert self.moe_freq > 0 and self.moe_expert_count > 0

    def override(self, args):
        """Overwrite every hyperparameter for which *args* carries a non-None value."""
        for name in self.__dict__.keys():
            value = getattr(args, name, None)
            if value is not None:
                self.__dict__[name] = value
|
| 131 |
+
|
| 132 |
+
|
| 133 |
+
class EncoderDecoderConfig(object):
    """Hyperparameter container for a joint torchscale encoder-decoder.

    Unknown keyword arguments are consumed via ``kwargs.pop`` and silently
    ignored; attribute insertion order matches the declaration order below,
    which ``override`` relies on when iterating ``self.__dict__``.
    """

    def __init__(self, **kwargs):
        # (name, default) pairs, consumed from kwargs in declaration order.
        hyperparams = (
            ("encoder_embed_dim", 768),
            ("encoder_attention_heads", 12),
            ("encoder_ffn_embed_dim", 3072),
            ("encoder_layers", 12),
            ("encoder_normalize_before", True),
            ("decoder_embed_dim", 768),
            ("decoder_attention_heads", 12),
            ("decoder_ffn_embed_dim", 3072),
            ("decoder_layers", 12),
            ("decoder_normalize_before", True),
            ("activation_fn", "gelu"),
            ("dropout", 0.0),
            ("drop_path_rate", 0.0),
            ("attention_dropout", 0.0),
            ("activation_dropout", 0.0),
            ("no_scale_embedding", True),
            ("layernorm_embedding", False),
            ("moe_freq", 0),
            ("moe_top1_expert", False),
            ("moe_expert_count", 0),
            ("moe_gating_use_fp32", True),
            ("moe_eval_capacity_token_fraction", 0.25),
            ("moe_second_expert_policy", "random"),
            ("moe_normalize_gate_prob_before_dropping", False),
            ("use_xmoe", False),
            ("rel_pos_buckets", 0),
            ("max_rel_pos", 0),
            ("deepnorm", False),
            ("subln", True),
            ("bert_init", False),
            ("multiway", False),
            ("share_all_embeddings", False),
            ("share_decoder_input_output_embed", False),
            ("max_source_positions", 1024),
            ("max_target_positions", 1024),
            ("no_output_layer", False),
            ("layernorm_eps", 1e-5),
            # Text
            ("vocab_size", -1),
            # Fairscale
            ("checkpoint_activations", False),
            ("fsdp", False),
            ("ddp_rank", 0),
            ("xpos_rel_pos", False),
            ("xpos_scale_base", 512),
        )
        for name, default in hyperparams:
            setattr(self, name, kwargs.pop(name, default))

        # deepnorm and subln are mutually exclusive; deepnorm wins when both
        # are requested (it disables subln before subln is checked).
        if self.deepnorm:
            self.encoder_normalize_before = False
            self.decoder_normalize_before = False
            self.subln = False
        if self.subln:
            self.encoder_normalize_before = True
            self.decoder_normalize_before = True
            self.deepnorm = False
        if self.use_xmoe:
            # X-MoE requires normalized gate probabilities and random
            # second-expert routing, plus a valid MoE layout.
            self.moe_normalize_gate_prob_before_dropping = True
            self.moe_second_expert_policy = "random"
            assert self.moe_freq > 0 and self.moe_expert_count > 0

    def override(self, args):
        """Overwrite every hyperparameter for which *args* carries a non-None value."""
        for name in self.__dict__.keys():
            value = getattr(args, name, None)
            if value is not None:
                self.__dict__[name] = value
|
vlmo/torchscale/architecture/decoder.py
ADDED
|
@@ -0,0 +1,428 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2022 Microsoft
|
| 2 |
+
# Licensed under The MIT License [see LICENSE for details]
|
| 3 |
+
|
| 4 |
+
import math
|
| 5 |
+
|
| 6 |
+
import numpy as np
|
| 7 |
+
import torch
|
| 8 |
+
import torch.nn as nn
|
| 9 |
+
from fairscale.nn import checkpoint_wrapper, wrap
|
| 10 |
+
|
| 11 |
+
from vlmo.torchscale.architecture.utils import init_bert_params
|
| 12 |
+
from vlmo.torchscale.component.droppath import DropPath
|
| 13 |
+
from vlmo.torchscale.component.feedforward_network import FeedForwardNetwork, make_experts
|
| 14 |
+
from vlmo.torchscale.component.multihead_attention import MultiheadAttention
|
| 15 |
+
from vlmo.torchscale.component.relative_position_bias import RelativePositionBias
|
| 16 |
+
from vlmo.torchscale.component.xmoe.moe_layer import MOELayer
|
| 17 |
+
from vlmo.torchscale.component.xmoe.routing import Top1Gate, Top2Gate
|
| 18 |
+
|
| 19 |
+
try:
|
| 20 |
+
from apex.normalization import FusedLayerNorm as LayerNorm
|
| 21 |
+
except ModuleNotFoundError:
|
| 22 |
+
from torch.nn import LayerNorm
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
class DecoderLayer(nn.Module):
    """One transformer decoder block: self-attention, optional cross-attention
    (only when built for an encoder-decoder model), and an FFN that is either a
    dense FeedForwardNetwork or an MoE layer.

    Supports pre-norm and post-norm layouts (``args.decoder_normalize_before``),
    stochastic depth (``args.drop_path_rate``), and DeepNet-style residual
    scaling via ``self.alpha``.
    """

    def __init__(
        self,
        args,
        depth,  # 0-based layer index; selects this layer's drop-path probability
        is_moe_layer=False,
        is_encoder_decoder=False,
    ):
        super().__init__()
        self.args = args
        self.embed_dim = args.decoder_embed_dim
        self.dropout_module = torch.nn.Dropout(args.dropout)

        if args.drop_path_rate > 0:
            # Linearly increasing drop-path schedule across the layer stack.
            drop_path_prob = np.linspace(0, args.drop_path_rate, args.decoder_layers)[depth]
            self.drop_path = DropPath(drop_path_prob)
        else:
            self.drop_path = None

        self.self_attn = self.build_self_attention(self.embed_dim, args)

        self.normalize_before = args.decoder_normalize_before

        self.self_attn_layer_norm = LayerNorm(self.embed_dim, eps=args.layernorm_eps)

        if not is_encoder_decoder:
            # Decoder-only model: no cross-attention sub-block at all.
            self.encoder_attn = None
            self.encoder_attn_layer_norm = None
        else:
            self.encoder_attn = self.build_encoder_attention(self.embed_dim, args)
            self.encoder_attn_layer_norm = LayerNorm(self.embed_dim, eps=args.layernorm_eps)

        self.is_moe_layer = is_moe_layer
        self.ffn_dim = args.decoder_ffn_embed_dim

        if not self.is_moe_layer:
            self.ffn = self.build_ffn(
                self.embed_dim,
                self.args,
            )
        else:
            # MoE layer: build the routing gate (top-1 or top-2) plus experts.
            if args.moe_top1_expert:
                gate = Top1Gate(
                    self.embed_dim,
                    args.moe_expert_count,
                    use_fp32=args.moe_gating_use_fp32,
                    moe_eval_capacity_token_fraction=args.moe_eval_capacity_token_fraction,
                    use_xmoe=args.use_xmoe,
                )
            else:
                gate = Top2Gate(
                    self.embed_dim,
                    args.moe_expert_count,
                    args.moe_gating_use_fp32,
                    args.moe_second_expert_policy,
                    args.moe_normalize_gate_prob_before_dropping,
                    args.moe_eval_capacity_token_fraction,
                    use_xmoe=args.use_xmoe,
                )
            experts = make_experts(args, self.embed_dim, self.ffn_dim)
            self.moe_layer = MOELayer(gate, experts, args)

        self.final_layer_norm = LayerNorm(self.embed_dim, eps=args.layernorm_eps)

        if args.deepnorm:
            # DeepNet residual scaling; the constant depends on whether a
            # cross-attention sub-block exists (3 vs 2 sub-layers).
            if is_encoder_decoder:
                self.alpha = math.pow(3.0 * args.decoder_layers, 0.25)
            else:
                self.alpha = math.pow(2.0 * args.decoder_layers, 0.25)
        else:
            self.alpha = 1.0

    def build_ffn(self, embed_dim, args):
        """Construct the dense position-wise feed-forward sub-module."""
        return FeedForwardNetwork(
            embed_dim,
            self.ffn_dim,
            args.activation_fn,
            args.dropout,
            args.activation_dropout,
            args.layernorm_eps,
            args.subln,
        )

    def build_self_attention(self, embed_dim, args):
        """Construct masked (causal) self-attention over the decoder stream."""
        return MultiheadAttention(
            args,
            embed_dim,
            args.decoder_attention_heads,
            dropout=args.attention_dropout,
            self_attention=True,
            encoder_decoder_attention=False,
            subln=args.subln,
        )

    def build_encoder_attention(self, embed_dim, args):
        """Construct cross-attention from decoder queries to encoder output."""
        return MultiheadAttention(
            args,
            embed_dim,
            args.decoder_attention_heads,
            dropout=args.attention_dropout,
            self_attention=False,
            encoder_decoder_attention=True,
            subln=args.subln,
        )

    def residual_connection(self, x, residual):
        # DeepNet-style residual: the skip path is scaled by alpha (1.0 when
        # deepnorm is disabled).
        return residual * self.alpha + x

    def forward(
        self,
        x,
        encoder_out=None,
        encoder_padding_mask=None,
        incremental_state=None,  # per-layer cache dict for incremental decoding
        self_attn_mask=None,
        self_attn_padding_mask=None,
        self_attn_rel_pos=None,
        cross_attn_rel_pos=None,
    ):
        """Run one decoder block.

        Returns ``(x, attn, None, l_aux)`` where ``attn`` is the attention
        weights of the last attention sub-block executed, and ``l_aux`` is the
        MoE auxiliary loss (``None`` for dense layers).
        """
        # --- self-attention sub-block (pre- or post-norm) ---
        residual = x
        if self.normalize_before:
            x = self.self_attn_layer_norm(x)

        x, attn = self.self_attn(
            query=x,
            key=x,
            value=x,
            key_padding_mask=self_attn_padding_mask,
            incremental_state=incremental_state,
            attn_mask=self_attn_mask,
            rel_pos=self_attn_rel_pos,
        )
        x = self.dropout_module(x)

        if self.drop_path is not None:
            x = self.drop_path(x)

        x = self.residual_connection(x, residual)
        if not self.normalize_before:
            x = self.self_attn_layer_norm(x)

        # --- cross-attention sub-block (only when both the module and the
        # encoder output are present) ---
        if self.encoder_attn is not None and encoder_out is not None:
            residual = x
            if self.normalize_before:
                x = self.encoder_attn_layer_norm(x)

            x, attn = self.encoder_attn(
                query=x,
                key=encoder_out,
                value=encoder_out,
                key_padding_mask=encoder_padding_mask,
                incremental_state=None,  # cross-attention keys are static; no cache
                rel_pos=cross_attn_rel_pos,
            )
            x = self.dropout_module(x)

            if self.drop_path is not None:
                x = self.drop_path(x)

            x = self.residual_connection(x, residual)
            if not self.normalize_before:
                x = self.encoder_attn_layer_norm(x)

        # --- feed-forward / MoE sub-block ---
        residual = x
        if self.normalize_before:
            x = self.final_layer_norm(x)
        if not self.is_moe_layer:
            x = self.ffn(x)
            l_aux = None
        else:
            # MoE returns the routed output plus a load-balancing aux loss.
            x, l_aux = self.moe_layer(x)

        if self.drop_path is not None:
            x = self.drop_path(x)

        x = self.residual_connection(x, residual)
        if not self.normalize_before:
            x = self.final_layer_norm(x)

        return x, attn, None, l_aux
|
| 205 |
+
|
| 206 |
+
|
| 207 |
+
class Decoder(nn.Module):
    """Stack of :class:`DecoderLayer` blocks with token/position embedding,
    optional relative-position biases, an optional output projection, and
    support for incremental (step-wise) decoding.
    """

    def __init__(
        self, args, embed_tokens=None, embed_positions=None, output_projection=None, is_encoder_decoder=False, **kwargs
    ):
        super().__init__(**kwargs)
        self.args = args

        self.dropout_module = torch.nn.Dropout(args.dropout)

        embed_dim = args.decoder_embed_dim
        self.embed_dim = embed_dim
        # Classic transformer scaling of token embeddings, unless disabled.
        self.embed_scale = 1.0 if args.no_scale_embedding else math.sqrt(embed_dim)

        self.embed_tokens = embed_tokens
        self.embed_positions = embed_positions

        # Build a projection to vocab only when none was supplied and the
        # config both wants an output layer and knows the vocab size.
        if output_projection is None and not args.no_output_layer and args.vocab_size > 0:
            self.output_projection = self.build_output_projection(args)
        else:
            self.output_projection = output_projection

        if args.layernorm_embedding:
            self.layernorm_embedding = LayerNorm(embed_dim, eps=args.layernorm_eps)
        else:
            self.layernorm_embedding = None

        self.layers = nn.ModuleList([])

        moe_freq = args.moe_freq
        for i in range(args.decoder_layers):
            # Every moe_freq-th layer (1-based) is an MoE layer; moe_freq == 0
            # disables MoE entirely.
            is_moe_layer = moe_freq != 0 and (i + 1) % moe_freq == 0
            self.layers.append(
                self.build_decoder_layer(
                    args,
                    depth=i,
                    is_moe_layer=is_moe_layer,
                    is_encoder_decoder=is_encoder_decoder,
                )
            )

        self.num_layers = len(self.layers)

        # Final LayerNorm only in the pre-norm layout.
        if args.decoder_normalize_before:
            self.layer_norm = LayerNorm(embed_dim, eps=args.layernorm_eps)
        else:
            self.layer_norm = None

        self.self_attn_relative_position = None
        self.cross_attn_relative_position = None

        if args.rel_pos_buckets > 0 and args.max_rel_pos > 0:
            self.self_attn_relative_position = RelativePositionBias(
                num_buckets=args.rel_pos_buckets,
                max_distance=args.max_rel_pos,
                n_heads=args.decoder_attention_heads,
            )
            if is_encoder_decoder:
                self.cross_attn_relative_position = RelativePositionBias(
                    num_buckets=args.rel_pos_buckets,
                    max_distance=args.max_rel_pos,
                    n_heads=args.decoder_attention_heads,
                )

        if args.bert_init:
            self.apply(init_bert_params)

        if args.deepnorm:
            # DeepNet initialization: shrink value/output/FFN projections by
            # a depth-dependent factor.
            if is_encoder_decoder:
                init_scale = math.pow(12.0 * args.decoder_layers, 0.25)
            else:
                init_scale = math.pow(8.0 * args.decoder_layers, 0.25)
            for name, p in self.named_parameters():
                if "fc1" in name or "fc2" in name or "out_proj" in name or "v_proj" in name:
                    p.data.div_(init_scale)

        if args.subln:
            # SubLN initialization: grow the same projections instead;
            # cross-attention parameters are deliberately left untouched.
            if is_encoder_decoder:
                init_scale = math.sqrt(math.log(args.decoder_layers * 3))
            else:
                init_scale = math.sqrt(math.log(args.decoder_layers * 2))
            for name, p in self.named_parameters():
                if "encoder_attn" in name:
                    continue
                if "fc1" in name or "fc2" in name or "out_proj" in name or "v_proj" in name:
                    p.data.mul_(init_scale)

    def build_output_projection(
        self,
        args,
    ):
        """Create the vocab projection, optionally tied to the input embedding."""
        if args.share_decoder_input_output_embed:
            output_projection = torch.nn.Linear(
                self.embed_tokens.weight.shape[1],
                self.embed_tokens.weight.shape[0],
                bias=False,
            )
            # Weight tying: projection shares the embedding matrix.
            output_projection.weight = self.embed_tokens.weight
        else:
            output_projection = torch.nn.Linear(args.decoder_embed_dim, args.vocab_size, bias=False)
            torch.nn.init.normal_(output_projection.weight, mean=0, std=args.decoder_embed_dim**-0.5)
        return output_projection

    def build_decoder_layer(self, args, depth, is_moe_layer=False, is_encoder_decoder=False):
        """Build one layer, optionally wrapped for activation checkpointing / FSDP."""
        layer = DecoderLayer(
            args,
            depth,
            is_moe_layer=is_moe_layer,
            is_encoder_decoder=is_encoder_decoder,
        )
        if args.checkpoint_activations:
            layer = checkpoint_wrapper(layer)
        if args.fsdp:
            layer = wrap(layer)
        return layer

    def forward_embedding(
        self,
        tokens,
        token_embedding=None,  # precomputed embeddings; skips the lookup if given
        incremental_state=None,
    ):
        """Embed tokens (+ positions); during incremental decoding only the
        last time-step is embedded. Returns ``(x, embed)`` where ``embed`` is
        the scaled embedding before positions/LayerNorm/dropout are applied.
        """
        positions = None
        if self.embed_positions is not None:
            positions = self.embed_positions(tokens, incremental_state=incremental_state)

        if incremental_state is not None:
            # Step-wise decoding: only the newest token needs embedding.
            tokens = tokens[:, -1:]
            if positions is not None:
                positions = positions[:, -1:]

        if token_embedding is None:
            token_embedding = self.embed_tokens(tokens)

        x = embed = self.embed_scale * token_embedding

        if positions is not None:
            x += positions

        if self.layernorm_embedding is not None:
            x = self.layernorm_embedding(x)

        x = self.dropout_module(x)

        return x, embed

    def forward(
        self,
        prev_output_tokens,
        self_attn_padding_mask=None,
        encoder_out=None,  # dict with "encoder_out", "encoder_padding_mask", optional "l_aux"
        incremental_state=None,  # dict keyed by layer index; created lazily per layer
        features_only=False,  # skip the vocab projection when True
        return_all_hiddens=False,
        token_embeddings=None,
        **kwargs
    ):
        """Decode; returns ``(x, extra)`` where ``extra`` carries the per-layer
        hidden states, accumulated MoE aux losses, and a (currently None)
        attention slot.
        """
        # embed tokens and positions
        x, _ = self.forward_embedding(prev_output_tokens, token_embeddings, incremental_state)

        # relative position
        self_attn_rel_pos_bias = None
        slen = prev_output_tokens.size(1)
        if self.self_attn_relative_position is not None:
            self_attn_rel_pos_bias = self.self_attn_relative_position(batch_size=x.size(0), qlen=slen, klen=slen)
            if incremental_state is not None:
                # Only the bias row for the newest query position is needed.
                self_attn_rel_pos_bias = self_attn_rel_pos_bias[-1:, :, :]
        cross_attn_rel_pos_bias = None
        if self.cross_attn_relative_position is not None:
            cross_attn_rel_pos_bias = self.cross_attn_relative_position(
                batch_size=x.size(0),
                qlen=slen,
                klen=encoder_out["encoder_out"].size(1),
            )
            if incremental_state is not None:
                cross_attn_rel_pos_bias = cross_attn_rel_pos_bias[-1:, :, :]

        # decoder layers
        inner_states = [x]

        if encoder_out is None:
            l_aux = []
        else:
            # Continue accumulating aux losses on top of the encoder's.
            l_aux = encoder_out["l_aux"] if "l_aux" in encoder_out else []

        for idx, layer in enumerate(self.layers):
            if incremental_state is None:
                # Full-sequence decoding: build a causal (upper-triangular -inf) mask.
                self_attn_mask = torch.triu(
                    torch.zeros([x.size(1), x.size(1)]).float().fill_(float("-inf")).type_as(x),
                    1,
                )
            else:
                # Incremental decoding: the cache enforces causality; no mask.
                self_attn_mask = None
                if idx not in incremental_state:
                    incremental_state[idx] = {}

            x, layer_attn, _, l_aux_i = layer(
                x,
                encoder_out["encoder_out"] if encoder_out is not None else None,
                encoder_out["encoder_padding_mask"] if encoder_out is not None else None,
                incremental_state[idx] if incremental_state is not None else None,
                self_attn_mask=self_attn_mask,
                self_attn_padding_mask=self_attn_padding_mask,
                self_attn_rel_pos=self_attn_rel_pos_bias,
                cross_attn_rel_pos=cross_attn_rel_pos_bias,
            )
            l_aux.append(l_aux_i)
            inner_states.append(x)

        if self.layer_norm is not None:
            x = self.layer_norm(x)

        if not features_only:
            x = self.output_layer(x)

        return x, {
            "inner_states": inner_states,
            "l_aux": l_aux,
            "attn": None,
        }

    def output_layer(self, features):
        """Project hidden states to vocabulary logits."""
        return self.output_projection(features)
|
vlmo/torchscale/architecture/encoder.py
ADDED
|
@@ -0,0 +1,482 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2022 Microsoft
|
| 2 |
+
# Licensed under The MIT License [see LICENSE for details]
|
| 3 |
+
|
| 4 |
+
import math
|
| 5 |
+
|
| 6 |
+
import numpy as np
|
| 7 |
+
import torch
|
| 8 |
+
import torch.nn as nn
|
| 9 |
+
from fairscale.nn import checkpoint_wrapper, wrap
|
| 10 |
+
|
| 11 |
+
try:
|
| 12 |
+
from apex.normalization import FusedLayerNorm as LayerNorm
|
| 13 |
+
except ModuleNotFoundError:
|
| 14 |
+
from torch.nn import LayerNorm
|
| 15 |
+
|
| 16 |
+
from vlmo.torchscale.architecture.utils import init_bert_params
|
| 17 |
+
from vlmo.torchscale.component.droppath import DropPath
|
| 18 |
+
from vlmo.torchscale.component.feedforward_network import FeedForwardNetwork, make_experts
|
| 19 |
+
from vlmo.torchscale.component.multihead_attention import MultiheadAttention
|
| 20 |
+
from vlmo.torchscale.component.multiway_network import MultiwayWrapper, set_split_position
|
| 21 |
+
from vlmo.torchscale.component.relative_position_bias import RelativePositionBias
|
| 22 |
+
from vlmo.torchscale.component.xmoe.moe_layer import MOELayer
|
| 23 |
+
from vlmo.torchscale.component.xmoe.routing import Top1Gate, Top2Gate
|
| 24 |
+
from vlmo.modules.vlmo_utils import no_sync_module_apply
|
| 25 |
+
from pytorch_lightning.utilities.distributed import rank_zero_info
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
class EncoderLayer(nn.Module):
|
| 29 |
+
def __init__(self, args, depth, attn=None, is_moe_layer=False, is_encoder_decoder=False):
|
| 30 |
+
super().__init__()
|
| 31 |
+
self.args = args
|
| 32 |
+
self.embed_dim = args.encoder_embed_dim
|
| 33 |
+
self.self_attn = self.build_self_attention(self.embed_dim, args) if attn is None else attn
|
| 34 |
+
self.self_attn_layer_norm = MultiwayWrapper(args, LayerNorm(self.embed_dim, eps=args.layernorm_eps))
|
| 35 |
+
self.dropout_module = torch.nn.Dropout(args.dropout)
|
| 36 |
+
|
| 37 |
+
if args.drop_path_rate > 0:
|
| 38 |
+
drop_path_prob = np.linspace(0, args.drop_path_rate, args.encoder_layers)[depth]
|
| 39 |
+
self.drop_path = DropPath(drop_path_prob)
|
| 40 |
+
else:
|
| 41 |
+
self.drop_path = None
|
| 42 |
+
|
| 43 |
+
self.normalize_before = args.encoder_normalize_before
|
| 44 |
+
self.is_moe_layer = is_moe_layer
|
| 45 |
+
self.ffn_dim = args.encoder_ffn_embed_dim
|
| 46 |
+
|
| 47 |
+
if not self.is_moe_layer:
|
| 48 |
+
self.ffn = MultiwayWrapper(
|
| 49 |
+
args,
|
| 50 |
+
self.build_ffn(
|
| 51 |
+
self.embed_dim,
|
| 52 |
+
self.args,
|
| 53 |
+
),
|
| 54 |
+
)
|
| 55 |
+
else:
|
| 56 |
+
assert not self.args.multiway
|
| 57 |
+
if args.moe_top1_expert:
|
| 58 |
+
gate = Top1Gate(
|
| 59 |
+
self.embed_dim,
|
| 60 |
+
args.moe_expert_count,
|
| 61 |
+
use_fp32=args.moe_gating_use_fp32,
|
| 62 |
+
moe_eval_capacity_token_fraction=args.moe_eval_capacity_token_fraction,
|
| 63 |
+
use_xmoe=args.use_xmoe,
|
| 64 |
+
)
|
| 65 |
+
else:
|
| 66 |
+
gate = Top2Gate(
|
| 67 |
+
self.embed_dim,
|
| 68 |
+
args.moe_expert_count,
|
| 69 |
+
args.moe_gating_use_fp32,
|
| 70 |
+
args.moe_second_expert_policy,
|
| 71 |
+
args.moe_normalize_gate_prob_before_dropping,
|
| 72 |
+
args.moe_eval_capacity_token_fraction,
|
| 73 |
+
use_xmoe=args.use_xmoe,
|
| 74 |
+
)
|
| 75 |
+
experts = make_experts(args, self.embed_dim, self.ffn_dim)
|
| 76 |
+
self.moe_layer = MOELayer(gate, experts, args)
|
| 77 |
+
self.final_layer_norm = MultiwayWrapper(args, LayerNorm(self.embed_dim, eps=args.layernorm_eps))
|
| 78 |
+
|
| 79 |
+
if args.deepnorm:
|
| 80 |
+
if is_encoder_decoder:
|
| 81 |
+
self.alpha = math.pow(math.pow(args.encoder_layers, 4) * args.decoder_layers, 0.0625) * 0.81
|
| 82 |
+
else:
|
| 83 |
+
self.alpha = math.pow(2.0 * args.encoder_layers, 0.25)
|
| 84 |
+
else:
|
| 85 |
+
self.alpha = 1.0
|
| 86 |
+
|
| 87 |
+
def build_ffn(self, embed_dim, args):
    """Construct the dense position-wise feed-forward sub-layer.

    Args:
        embed_dim: input/output embedding dimension of this layer.
        args: config namespace supplying the FFN hyper-parameters.

    Returns:
        A ``FeedForwardNetwork`` sized to this layer's ``ffn_dim``.
    """
    # Positional argument order follows FeedForwardNetwork's signature:
    # (embed_dim, ffn_dim, activation_fn, dropout, activation_dropout,
    #  layernorm_eps, subln).
    return FeedForwardNetwork(
        embed_dim, self.ffn_dim, args.activation_fn,
        args.dropout, args.activation_dropout,
        args.layernorm_eps, args.subln,
    )
|
| 98 |
+
def build_self_attention(self, embed_dim, args):
    """Create the multi-head self-attention module for this encoder layer.

    Args:
        embed_dim: model embedding dimension.
        args: config namespace (attention heads, dropout, subln, one_attn).

    Returns:
        A ``MultiheadAttention`` configured for encoder self-attention.
    """
    num_heads = args.encoder_attention_heads
    return MultiheadAttention(
        args, embed_dim, num_heads,
        dropout=args.attention_dropout,
        self_attention=True,
        encoder_decoder_attention=False,
        subln=args.subln,
        one_attn=args.one_attn,
    )
|
| 110 |
+
def residual_connection(self, x, residual):
    """Add the residual branch back onto ``x``.

    ``self.alpha`` is 1.0 normally; when DeepNorm is enabled it up-weights
    the residual path for training stability.
    """
    scaled_residual = self.alpha * residual
    return scaled_residual + x
|
| 113 |
+
def forward(
    self,
    x,
    encoder_padding_mask,
    attn_mask=None,
    rel_pos=None,
    multiway_split_position=None,
    incremental_state=None,
):
    """Run one encoder layer: self-attention followed by FFN (or MoE),
    each wrapped with pre-/post-LayerNorm, dropout, optional DropPath,
    and a DeepNorm-scaled residual connection.

    Args:
        x: input hidden states.
        encoder_padding_mask: key padding mask for self-attention.
        attn_mask: optional additive attention mask; nonzero entries are
            converted to a large negative bias below.
        rel_pos: optional relative-position bias passed to attention.
        multiway_split_position: boundary index between modalities when
            multiway (modality-expert) routing is enabled.
        incremental_state: optional cached state for incremental decoding.

    Returns:
        Tuple ``(x, l_aux)``; ``l_aux`` is the MoE load-balancing auxiliary
        loss, ``None`` for dense (non-MoE) layers.
    """
    if multiway_split_position is not None:
        # Propagate the modality split point to every Multiway-wrapped
        # submodule before running the layer.
        assert self.args.multiway
        no_sync_module_apply(self, set_split_position(multiway_split_position))

    if attn_mask is not None:
        # float16: -1e8 equal 0 — use a large negative bias instead of -inf
        # so masked logits underflow safely in half precision.
        attn_mask = attn_mask.masked_fill(attn_mask.to(torch.bool), -1e8)

    # ---- self-attention sub-layer ----
    residual = x
    if self.normalize_before:
        # pre-LN variant: normalize before attention
        x = self.self_attn_layer_norm(x)
    x, _ = self.self_attn(
        query=x,
        key=x,
        value=x,
        key_padding_mask=encoder_padding_mask,
        attn_mask=attn_mask,
        rel_pos=rel_pos,
        incremental_state=incremental_state,
    )
    x = self.dropout_module(x)

    if self.drop_path is not None:
        # stochastic depth on the attention branch
        x = self.drop_path(x)

    x = self.residual_connection(x, residual)
    if not self.normalize_before:
        # post-LN variant: normalize after the residual add
        x = self.self_attn_layer_norm(x)

    # ---- feed-forward / mixture-of-experts sub-layer ----
    residual = x
    if self.normalize_before:
        x = self.final_layer_norm(x)
    if not self.is_moe_layer:
        x = self.ffn(x)
        l_aux = None  # no auxiliary loss for dense FFN layers
    else:
        # NOTE(review): the transposes suggest moe_layer expects the other
        # (batch-first vs. time-first) layout than the rest of this layer —
        # confirm against MOELayer's implementation.
        x = x.transpose(0, 1)
        x, l_aux = self.moe_layer(x)
        x = x.transpose(0, 1)

    if self.drop_path is not None:
        # stochastic depth on the FFN/MoE branch
        x = self.drop_path(x)

    x = self.residual_connection(x, residual)
    if not self.normalize_before:
        x = self.final_layer_norm(x)
    return x, l_aux
|
| 171 |
+
class Encoder(nn.Module):
    """Stack of ``EncoderLayer`` blocks with optional vision-token random
    masking (MAE-style), multiway (modality-expert) routing, MoE layers,
    relative position bias, and DeepNorm/SubLN weight initialization.
    """

    def __init__(
        self, args, embed_tokens=None, embed_positions=None, output_projection=None, is_encoder_decoder=False, **kwargs
    ):
        # NOTE(review): ``self.args`` is assigned before ``super().__init__``;
        # this works because args is a plain attribute, but it relies on
        # nn.Module.__setattr__ tolerating assignment pre-init — confirm.
        self.args = args
        super().__init__(**kwargs)

        self.dropout_module = torch.nn.Dropout(args.dropout)

        embed_dim = args.encoder_embed_dim
        # Scale token embeddings by sqrt(dim) unless disabled.
        self.embed_scale = 1.0 if args.no_scale_embedding else math.sqrt(embed_dim)
        self.mask_ratio = args.mask_ratio
        self.max_text_len = args.max_text_len
        # Number of vision patch tokens (excluding the CLS token).
        self.vision_len = (args.img_size // args.patch_size) * (args.img_size // args.patch_size)

        self.embed_tokens = embed_tokens
        self.embed_positions = embed_positions

        # Build an LM output head only for a standalone encoder with a vocab.
        if output_projection is None and not is_encoder_decoder and not args.no_output_layer and args.vocab_size > 0:
            self.output_projection = self.build_output_projection(args)
        else:
            self.output_projection = output_projection

        if args.layernorm_embedding:
            self.layernorm_embedding = MultiwayWrapper(args, LayerNorm(embed_dim, eps=args.layernorm_eps), dim=1)
        else:
            self.layernorm_embedding = None

        self.layers = nn.ModuleList([])
        if self.args.share_layer:
            # One layer instance shared (tied weights) across all depths.
            single_layer = self.build_encoder_layer(
                args, depth=0, is_moe_layer=False, is_encoder_decoder=is_encoder_decoder
            )
            for i in range(args.encoder_layers):
                self.layers.append(single_layer)
        elif self.args.share_attn:
            # Independent layers but a single shared self-attention module.
            moe_freq = args.moe_freq
            embed_dim = args.encoder_embed_dim
            shared_attn = self.build_self_attention(embed_dim, self.args)
            for i in range(args.encoder_layers):
                # Every moe_freq-th layer becomes a mixture-of-experts layer.
                is_moe_layer = moe_freq != 0 and (i + 1) % moe_freq == 0
                self.layers.append(
                    self.build_encoder_layer(
                        args,
                        depth=i,
                        attn=shared_attn,
                        is_moe_layer=is_moe_layer,
                        is_encoder_decoder=is_encoder_decoder,
                    )
                )

        else:
            # Fully independent layers.
            moe_freq = args.moe_freq
            for i in range(args.encoder_layers):
                is_moe_layer = moe_freq != 0 and (i + 1) % moe_freq == 0
                self.layers.append(
                    self.build_encoder_layer(
                        args,
                        depth=i,
                        is_moe_layer=is_moe_layer,
                        is_encoder_decoder=is_encoder_decoder,
                    )
                )
        self.num_layers = len(self.layers)

        if args.encoder_normalize_before and args.normalize_output:
            self.layer_norm = MultiwayWrapper(args, LayerNorm(embed_dim, eps=args.layernorm_eps))
        else:
            self.layer_norm = None

        if args.rel_pos_buckets > 0 and args.max_rel_pos > 0:
            self.relative_position = RelativePositionBias(
                num_buckets=args.rel_pos_buckets,
                max_distance=args.max_rel_pos,
                n_heads=args.encoder_attention_heads,
            )
        else:
            self.relative_position = None

        if args.bert_init:
            self.apply(init_bert_params)

        if args.deepnorm:
            # DeepNorm: scale down value/output/FFN projections by init_scale.
            if is_encoder_decoder:
                init_scale = math.pow(math.pow(args.encoder_layers, 4) * args.decoder_layers, 0.0625) / 1.15
            else:
                init_scale = math.pow(8.0 * args.encoder_layers, 0.25)
            for name, p in self.named_parameters():
                if "fc1" in name or "fc2" in name or "out_proj" in name or "v_proj" in name:
                    p.data.div_(init_scale)

        if args.subln:
            # SubLN: scale the same projections *up* by init_scale.
            if is_encoder_decoder:
                init_scale = math.sqrt(math.log(3 * args.decoder_layers) * math.log(2 * args.encoder_layers) / 3)
            else:
                init_scale = math.sqrt(math.log(args.encoder_layers * 2))
            for name, p in self.named_parameters():
                if "fc1" in name or "fc2" in name or "out_proj" in name or "v_proj" in name:
                    p.data.mul_(init_scale)

    def random_masking(self, x, mask_ratio):
        """Randomly drop a fraction of the non-CLS tokens (MAE-style).

        Token 0 (the CLS token) is always kept and re-prepended; the
        remaining L-1 tokens are shuffled by random noise and the first
        ``len_keep`` survive.

        Returns:
            ``(x_masked_add, ids_keep)`` — the reduced sequence (CLS +
            kept tokens) and the kept indices (excluding the CLS slot).
        """
        N, L, D = x.shape  # batch, length, dim
        len_keep = int(L * (1 - mask_ratio))

        noise = torch.rand(N, L - 1, device=x.device)
        # +1 shifts indices past the CLS token at position 0.
        ids_shuffle = torch.argsort(noise, dim=1) + torch.ones(N, L - 1, device=x.device, dtype=int)
        ids_keep = ids_shuffle[:, :len_keep]

        x_masked = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D))

        # Re-attach the CLS token at the front.
        x0 = x[:, 0, :]
        x0 = x0.reshape(N, 1, D)
        x_masked_add = torch.cat([x0, x_masked], axis=1)
        return x_masked_add, ids_keep

    def build_self_attention(self, embed_dim, args):
        """Create a self-attention module (used for the shared-attention mode)."""
        return MultiheadAttention(
            args,
            embed_dim,
            args.encoder_attention_heads,
            dropout=args.attention_dropout,
            self_attention=True,
            encoder_decoder_attention=False,
            subln=args.subln,
            one_attn=args.one_attn,
        )

    def build_output_projection(
        self,
        args,
    ):
        """Build the vocabulary output head, optionally tied to the input
        embedding table when ``share_encoder_input_output_embed`` is set.
        """
        if args.share_encoder_input_output_embed:
            assert args.encoder_embedding_type == "language"
            output_projection = torch.nn.Linear(
                self.embed_tokens.weight.shape[1],
                self.embed_tokens.weight.shape[0],
                bias=False,
            )
            # Weight tying: share the embedding matrix with the output head.
            output_projection.weight = self.embed_tokens.weight
        else:
            output_projection = torch.nn.Linear(args.encoder_embed_dim, args.vocab_size, bias=False)
            torch.nn.init.normal_(output_projection.weight, mean=0, std=args.encoder_embed_dim**-0.5)
        return output_projection

    def checkpointing_and_params_allgather(
        self,
        origin_layer,
    ):
        """Wrap a layer's forward with DeepSpeed activation checkpointing.

        Returns the wrapping function (the caller is expected to install it
        as the layer's forward).
        """
        origin_forward = origin_layer.forward

        from deepspeed import checkpointing
        def forward(*args, **kwargs):
            # deepspeed checkpoint not support kwargs
            # NOTE(review): kwargs are forwarded anyway despite the comment
            # above — confirm against the deepspeed checkpointing API.
            ret = checkpointing.checkpoint(origin_forward, *args, **kwargs)
            return ret

        return forward

    def build_encoder_layer(self, args, depth, attn=None, is_moe_layer=False, is_encoder_decoder=False):
        """Construct one EncoderLayer, optionally wrapped for activation
        checkpointing and/or FSDP."""
        layer = EncoderLayer(
            args,
            depth,
            attn,
            is_moe_layer=is_moe_layer,
            is_encoder_decoder=is_encoder_decoder,
        )
        if args.checkpoint_activations:
            rank_zero_info("EncoderLayer params: %s", sum(p.numel() for p in layer.parameters() if p.requires_grad))
            layer = checkpoint_wrapper(layer)
            # layer.ffn = checkpoint_wrapper(layer.ffn,)
        if args.fsdp:
            layer = wrap(layer)
        return layer

    def checkpointing_layers(self):
        """Wrap every already-built layer with activation checkpointing in place."""
        for i, layer in enumerate(self.layers):
            rank_zero_info(f"Checkpointing wrapper EncoderLayers: {i}")
            self.layers[i] = checkpoint_wrapper(layer)

    def forward_embedding(
        self,
        src_tokens,
        token_embedding=None,
        positions=None,
    ):
        """Embed tokens, add positions, and optionally apply vision-token
        random masking.

        Returns:
            ``(x, embed, ids_keep, is_flip)`` where ``is_flip`` is 0 (no
            masking), 1 (vision-only input masked), or 2 (vision+text input,
            vision part masked).
        """
        if token_embedding is None:
            token_embedding = self.embed_tokens(src_tokens)
        x = embed = self.embed_scale * token_embedding
        if self.embed_positions is not None:
            if src_tokens is not None:
                x = embed + self.embed_positions(src_tokens, positions=positions)
            else:
                x = embed + self.embed_positions(x, positions=positions)
        is_flip, ids_keep = 0, None
        if self.mask_ratio > 0:
            # Sequence length identifies the modality mix: vision-only vs.
            # vision+text (both carry one leading CLS token).
            if x.shape[1] == self.vision_len + 1:
                x, ids_keep = self.random_masking(x, self.mask_ratio)
                is_flip = 1
            elif x.shape[1] == self.vision_len + self.max_text_len + 1:
                # Mask only the vision prefix; keep the text suffix intact.
                vision_tokens = x[:, : self.vision_len + 1, :]
                vision_tokens, ids_keep = self.random_masking(vision_tokens, self.mask_ratio)
                x = torch.cat(
                    [
                        vision_tokens,
                        x[
                            :,
                            self.vision_len + 1 :,
                        ],
                    ],
                    dim=1,
                )
                is_flip = 2
        if self.layernorm_embedding is not None:
            x = self.layernorm_embedding(x)
        x = self.dropout_module(x)
        return x, embed, ids_keep, is_flip

    def forward(
        self,
        src_tokens,
        encoder_padding_mask=None,
        attn_mask=None,
        return_all_hiddens=False,
        token_embeddings=None,
        multiway_split_position=None,
        features_only=False,
        incremental_state=None,
        positions=None,
        **kwargs
    ):
        """Run the full encoder stack.

        Either ``src_tokens`` or ``token_embeddings`` must be provided.
        When random masking fired in ``forward_embedding``, the padding and
        attention masks (and the multiway split point) are re-indexed to the
        kept tokens.

        Returns:
            Dict with ``encoder_out``, ``encoder_embedding``,
            ``encoder_padding_mask``, ``encoder_states`` (if requested),
            per-layer MoE ``l_aux`` losses, and ``multiway_split_position``.
        """
        assert src_tokens is not None or token_embeddings is not None

        if encoder_padding_mask is None:
            # Default: nothing is padding.
            if src_tokens is not None:
                encoder_padding_mask = torch.zeros_like(src_tokens, device=src_tokens.device).bool()
            else:
                encoder_padding_mask = torch.zeros(
                    [token_embeddings.size(0), token_embeddings.size(1)],
                    device=token_embeddings.device,
                ).bool()

        if multiway_split_position is not None:
            assert self.args.multiway
            no_sync_module_apply(self, set_split_position(multiway_split_position))

        x, encoder_embedding, ids_keep, is_flip = self.forward_embedding(src_tokens, token_embeddings, positions)
        if is_flip > 0:
            # Rebuild the full kept-index list (CLS + kept vision [+ text])
            # so masks can be gathered down to the surviving positions.
            if is_flip == 2:
                text_ids = (
                    torch.arange(
                        self.vision_len + 1, self.vision_len + 1 + self.max_text_len, device=x.device, dtype=torch.int64
                    )
                    .unsqueeze(0)
                    .repeat(ids_keep.shape[0], 1)
                )
                cls_ids = torch.zeros(ids_keep.shape[0], 1, device=x.device, dtype=torch.int64)
                ids_keep = torch.cat([cls_ids, ids_keep, text_ids], dim=1)
            elif is_flip == 1:
                cls_ids = torch.zeros(ids_keep.shape[0], 1, device=x.device, dtype=torch.int64)
                ids_keep = torch.cat([cls_ids, ids_keep], dim=1)
            if encoder_padding_mask is not None:
                encoder_padding_mask = torch.gather(encoder_padding_mask, dim=1, index=ids_keep)
            if attn_mask is not None:
                # Gather both the query and key dimensions of the 2-D mask.
                attn_mask = torch.gather(
                    attn_mask, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, attn_mask.shape[-1])
                )
                attn_mask = torch.gather(attn_mask, dim=2, index=ids_keep.unsqueeze(1).repeat(1, attn_mask.shape[1], 1))
            # NOTE(review): raises a TypeError if multiway_split_position is
            # None while masking is active — confirm callers always pass it.
            if multiway_split_position > 0:
                # Vision prefix shrank; recompute the modality boundary.
                multiway_split_position = ids_keep.shape[1] - self.max_text_len
        # Zero out embeddings at padded positions.
        x = x * (1 - encoder_padding_mask.unsqueeze(-1).type_as(x))

        encoder_states = []

        if return_all_hiddens:
            encoder_states.append(x)

        rel_pos_bias = None
        if self.relative_position is not None:
            rel_pos_bias = self.relative_position(batch_size=x.size(0), qlen=x.size(1), klen=x.size(1))

        l_aux = []
        for idx, layer in enumerate(self.layers):
            x, l_aux_i = layer(
                x,
                # Incremental decoding handles padding via cached state.
                encoder_padding_mask=encoder_padding_mask if incremental_state is None else None,
                attn_mask=attn_mask,
                rel_pos=rel_pos_bias,
                multiway_split_position=multiway_split_position,
                incremental_state=incremental_state[idx] if incremental_state is not None else None,
            )
            if return_all_hiddens:
                assert encoder_states is not None
                encoder_states.append(x)
            l_aux.append(l_aux_i)

        if multiway_split_position is not None:
            assert self.args.multiway
            no_sync_module_apply(self, set_split_position(multiway_split_position))
        if self.layer_norm is not None:
            x = self.layer_norm(x)

        if not features_only and self.output_projection is not None:
            x = self.output_projection(x)

        return {
            "encoder_out": x,
            "encoder_embedding": encoder_embedding,
            "encoder_padding_mask": encoder_padding_mask,
            "encoder_states": encoder_states,
            "l_aux": l_aux,
            "multiway_split_position": multiway_split_position,
        }
vlmo/torchscale/architecture/encoder_decoder.py
ADDED
|
@@ -0,0 +1,43 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2022 Microsoft
|
| 2 |
+
# Licensed under The MIT License [see LICENSE for details]
|
| 3 |
+
|
| 4 |
+
import torch.nn as nn
|
| 5 |
+
|
| 6 |
+
from vlmo.torchscale.architecture.decoder import Decoder
|
| 7 |
+
from vlmo.torchscale.architecture.encoder import Encoder
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class EncoderDecoder(nn.Module):
    """Sequence-to-sequence model pairing a torchscale Encoder with a Decoder.

    Optionally shares the token-embedding table between the two sides when
    ``args.share_all_embeddings`` is set.
    """

    def __init__(
        self,
        args,
        encoder_embed_tokens=None,
        encoder_embed_positions=None,
        decoder_embed_tokens=None,
        decoder_embed_positions=None,
        output_projection=None,
        **kwargs
    ):
        super().__init__()
        self.args = args

        # Sharing all embeddings implies tying the decoder's input and
        # output embeddings as well.
        if args.share_all_embeddings:
            args.share_decoder_input_output_embed = True

        self.encoder = Encoder(
            args,
            encoder_embed_tokens,
            encoder_embed_positions,
            is_encoder_decoder=True,
            **kwargs,
        )

        if args.share_all_embeddings and decoder_embed_tokens is None:
            # Reuse the encoder's token embedding table on the decoder side.
            decoder_embed_tokens = self.encoder.embed_tokens

        self.decoder = Decoder(
            args,
            decoder_embed_tokens,
            decoder_embed_positions,
            output_projection,
            is_encoder_decoder=True,
            **kwargs,
        )

    def forward(self, src_tokens, prev_output_tokens, return_all_hiddens=False, features_only=False, **kwargs):
        """Encode ``src_tokens``, then decode conditioned on the encoder output.

        Returns the decoder's output (logits, or features when
        ``features_only`` is set).
        """
        enc_state = self.encoder(src_tokens, return_all_hiddens=return_all_hiddens)
        return self.decoder(
            prev_output_tokens,
            encoder_out=enc_state,
            features_only=features_only,
            return_all_hiddens=return_all_hiddens,
        )
|
vlmo/torchscale/architecture/utils.py
ADDED
|
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2022 Microsoft
|
| 2 |
+
# Licensed under The MIT License [see LICENSE for details]
|
| 3 |
+
|
| 4 |
+
import torch.nn as nn
|
| 5 |
+
|
| 6 |
+
from vlmo.torchscale.component.multihead_attention import MultiheadAttention
|
| 7 |
+
from vlmo.torchscale.component.multiway_network import MultiwayNetwork
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
def init_bert_params(module):
    """BERT-style weight initialization, intended for ``nn.Module.apply``.

    Linear and Embedding weights are drawn from N(0, 0.02); Linear biases
    and the padding-index embedding row are zeroed. For MultiheadAttention,
    the q/k/v projections are re-initialized the same way, covering both
    the plain and the MultiwayNetwork (dual A/B experts) variants.
    """
    def _init_normal(data):
        # Sample on CPU for determinism across devices, then copy back.
        data.copy_(data.cpu().normal_(mean=0.0, std=0.02).to(data.device))

    if isinstance(module, nn.Linear):
        _init_normal(module.weight.data)
        if module.bias is not None:
            module.bias.data.zero_()
    if isinstance(module, nn.Embedding):
        _init_normal(module.weight.data)
        if module.padding_idx is not None:
            module.weight.data[module.padding_idx].zero_()
    if isinstance(module, MultiheadAttention):
        projections = (module.q_proj, module.k_proj, module.v_proj)
        if isinstance(module.q_proj, MultiwayNetwork):
            # Multiway attention keeps two expert copies (A and B) per proj.
            for proj in projections:
                _init_normal(proj.A.weight.data)
                _init_normal(proj.B.weight.data)
        else:
            for proj in projections:
                _init_normal(proj.weight.data)
|