Spaces:
Sleeping
Sleeping
home
commited on
Commit
·
776deff
1
Parent(s):
500705e
Add Gradio app with model weights via LFS
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .DS_Store +0 -0
- app.py +51 -0
- app_infer.py +35 -0
- clip_linear.pt +3 -0
- defake/.DS_Store +0 -0
- defake/LICENSE +8 -0
- defake/README.md +44 -0
- defake/blipmodels/__init__.py +1 -0
- defake/blipmodels/__pycache__/__init__.cpython-38.pyc +0 -0
- defake/blipmodels/__pycache__/blip.cpython-38.pyc +0 -0
- defake/blipmodels/__pycache__/med.cpython-38.pyc +0 -0
- defake/blipmodels/__pycache__/vit.cpython-38.pyc +0 -0
- defake/blipmodels/blip.py +238 -0
- defake/blipmodels/blip_itm.py +76 -0
- defake/blipmodels/blip_nlvr.py +103 -0
- defake/blipmodels/blip_pretrain.py +339 -0
- defake/blipmodels/blip_retrieval.py +319 -0
- defake/blipmodels/blip_vqa.py +186 -0
- defake/blipmodels/blipconfig/bert_config.json +21 -0
- defake/blipmodels/blipconfig/caption_coco.yaml +33 -0
- defake/blipmodels/blipconfig/med_config.json +21 -0
- defake/blipmodels/blipconfig/nlvr.yaml +21 -0
- defake/blipmodels/blipconfig/nocaps.yaml +15 -0
- defake/blipmodels/blipconfig/pretrain.yaml +27 -0
- defake/blipmodels/blipconfig/retrieval_coco.yaml +34 -0
- defake/blipmodels/blipconfig/retrieval_flickr.yaml +34 -0
- defake/blipmodels/blipconfig/retrieval_msrvtt.yaml +12 -0
- defake/blipmodels/blipconfig/vqa.yaml +25 -0
- defake/blipmodels/med.py +955 -0
- defake/blipmodels/nlvr_encoder.py +843 -0
- defake/blipmodels/vit.py +305 -0
- defake/clipdatasets.py +110 -0
- defake/environment.yaml +23 -0
- defake/models/__init__.py +0 -0
- defake/models/__pycache__/__init__.cpython-38.pyc +0 -0
- defake/models/__pycache__/blip.cpython-38.pyc +0 -0
- defake/models/__pycache__/med.cpython-38.pyc +0 -0
- defake/models/__pycache__/vit.cpython-38.pyc +0 -0
- defake/models/blip.py +238 -0
- defake/models/blip_itm.py +76 -0
- defake/models/blip_nlvr.py +103 -0
- defake/models/blip_pretrain.py +339 -0
- defake/models/blip_retrieval.py +319 -0
- defake/models/blip_vqa.py +186 -0
- defake/models/med.py +955 -0
- defake/models/nlvr_encoder.py +843 -0
- defake/models/vit.py +305 -0
- defake/test.py +95 -0
- defake/test_api.py +114 -0
- defake/train.py +187 -0
.DS_Store
ADDED
|
Binary file (8.2 kB). View file
|
|
|
app.py
ADDED
|
@@ -0,0 +1,51 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import gradio as gr
|
| 2 |
+
from PIL import Image
|
| 3 |
+
import os
|
| 4 |
+
|
| 5 |
+
from app_infer import run_infer_from_image
|
| 6 |
+
|
| 7 |
+
APP_API_TOKEN = os.environ.get("APP_API_TOKEN", "")
|
| 8 |
+
|
| 9 |
+
def detect_api(image: Image.Image, api_token: str = ""):
|
| 10 |
+
if APP_API_TOKEN and api_token != APP_API_TOKEN:
|
| 11 |
+
raise gr.Error("Invalid API token.")
|
| 12 |
+
result = run_infer_from_image(image)
|
| 13 |
+
return result
|
| 14 |
+
|
| 15 |
+
def detect_ui(image: Image.Image):
|
| 16 |
+
result = run_infer_from_image(image)
|
| 17 |
+
return result["fake_score"], str(result["is_fake"])
|
| 18 |
+
|
| 19 |
+
with gr.Blocks(title="CISPA Citizen DeFake") as demo:
|
| 20 |
+
gr.Markdown("# CISPA Citizen DeFake")
|
| 21 |
+
gr.Markdown("Upload an image to detect if it is fake.")
|
| 22 |
+
|
| 23 |
+
with gr.Row():
|
| 24 |
+
with gr.Column():
|
| 25 |
+
img_input = gr.Image(type="pil", label="Upload Image")
|
| 26 |
+
btn = gr.Button("Detect")
|
| 27 |
+
with gr.Column():
|
| 28 |
+
score_out = gr.Number(label="Fake score (1=fake)")
|
| 29 |
+
is_fake_out = gr.Textbox(label="Prediction (True=fake)")
|
| 30 |
+
|
| 31 |
+
btn.click(fn=detect_ui, inputs=img_input, outputs=[score_out, is_fake_out])
|
| 32 |
+
|
| 33 |
+
gr.Markdown("## API")
|
| 34 |
+
gr.Markdown(
|
| 35 |
+
"POST to `/run/predict` with form-data: `image`, `api_token`.\n"
|
| 36 |
+
"Returns JSON: `{fake_score: float, is_fake: bool, probs: [p0,p1], pred_class: int}`."
|
| 37 |
+
)
|
| 38 |
+
|
| 39 |
+
api = gr.Interface(
|
| 40 |
+
fn=detect_api,
|
| 41 |
+
inputs=[
|
| 42 |
+
gr.Image(type="pil", label="image"),
|
| 43 |
+
gr.Textbox(label="api_token", type="password"),
|
| 44 |
+
],
|
| 45 |
+
outputs="json",
|
| 46 |
+
allow_flagging="never",
|
| 47 |
+
title="API Endpoint",
|
| 48 |
+
)
|
| 49 |
+
|
| 50 |
+
if __name__ == "__main__":
|
| 51 |
+
demo.launch()
|
app_infer.py
ADDED
|
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from PIL import Image
|
| 2 |
+
import tempfile
|
| 3 |
+
from defake.test_api import load_models, predict_image
|
| 4 |
+
|
| 5 |
+
# 全局只加载一次模型
|
| 6 |
+
_models = None
|
| 7 |
+
|
| 8 |
+
def get_models():
|
| 9 |
+
global _models
|
| 10 |
+
if _models is None:
|
| 11 |
+
_models = load_models()
|
| 12 |
+
return _models
|
| 13 |
+
|
| 14 |
+
def run_infer_from_image(pil_image: Image.Image):
|
| 15 |
+
"""
|
| 16 |
+
输入:PIL Image
|
| 17 |
+
输出:dict,包含 fake_score、is_fake 等
|
| 18 |
+
"""
|
| 19 |
+
models = get_models()
|
| 20 |
+
|
| 21 |
+
# 暂存到临时文件
|
| 22 |
+
with tempfile.NamedTemporaryFile(suffix=".jpg", delete=False) as tmp:
|
| 23 |
+
pil_image.save(tmp.name)
|
| 24 |
+
pred, probs = predict_image(tmp.name, models)
|
| 25 |
+
|
| 26 |
+
# 假设类别 1 是 fake,0 是 real
|
| 27 |
+
fake_score = float(probs[1])
|
| 28 |
+
is_fake = (pred == 1)
|
| 29 |
+
|
| 30 |
+
return {
|
| 31 |
+
"fake_score": fake_score,
|
| 32 |
+
"is_fake": is_fake,
|
| 33 |
+
"probs": probs,
|
| 34 |
+
"pred_class": int(pred),
|
| 35 |
+
}
|
clip_linear.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:fb29c2dbea6446aa86d41969032e27d48917823c304cb93e72d0fc3357010f6f
|
| 3 |
+
size 2629335
|
defake/.DS_Store
ADDED
|
Binary file (6.15 kB). View file
|
|
|
defake/LICENSE
ADDED
|
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
The MIT License (MIT)
|
| 2 |
+
Copyright © 2025 Zeyang Sha
|
| 3 |
+
|
| 4 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the “Software”), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
|
| 5 |
+
|
| 6 |
+
The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
|
| 7 |
+
|
| 8 |
+
THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
defake/README.md
ADDED
|
@@ -0,0 +1,44 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# De-Fake
|
| 2 |
+
|
| 3 |
+
It is the code for the paper: DE-FAKE: Detection and Attribution of Fake Images Generated by Text-to-Image Generation Models.
|
| 4 |
+
|
| 5 |
+
### Environment
|
| 6 |
+
|
| 7 |
+
You first need to build the environment by:
|
| 8 |
+
```
|
| 9 |
+
conda env create -f environment.yaml
|
| 10 |
+
conda activate defake
|
| 11 |
+
```
|
| 12 |
+
|
| 13 |
+
### Infer
|
| 14 |
+
|
| 15 |
+
For the usage, You can download our model on
|
| 16 |
+
|
| 17 |
+
https://drive.google.com/file/d/1qI7x5iodaCFq0S61LKw4wWjql7cYou_4/view?usp=sharing
|
| 18 |
+
|
| 19 |
+
and
|
| 20 |
+
|
| 21 |
+
https://drive.google.com/file/d/1SuenxJP10VwArC6zW0SHMUGObMRqQhBD/view?usp=sharing
|
| 22 |
+
|
| 23 |
+
for the encoder and classifier.
|
| 24 |
+
|
| 25 |
+
Then test on
|
| 26 |
+
|
| 27 |
+
```
|
| 28 |
+
python test.py --image_path XXX
|
| 29 |
+
```
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
### Train
|
| 33 |
+
|
| 34 |
+
If you want to train the detector yourself, please enter the correct file path in train.py.
|
| 35 |
+
|
| 36 |
+
Then
|
| 37 |
+
|
| 38 |
+
```
|
| 39 |
+
python train.py
|
| 40 |
+
```
|
| 41 |
+
|
| 42 |
+
### License
|
| 43 |
+
|
| 44 |
+
DeFake is licensed under the term of thr MIT license. See LICENSE for more details.
|
defake/blipmodels/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
from .blip import blip_decoder
|
defake/blipmodels/__pycache__/__init__.cpython-38.pyc
ADDED
|
Binary file (188 Bytes). View file
|
|
|
defake/blipmodels/__pycache__/blip.cpython-38.pyc
ADDED
|
Binary file (6.98 kB). View file
|
|
|
defake/blipmodels/__pycache__/med.cpython-38.pyc
ADDED
|
Binary file (28.2 kB). View file
|
|
|
defake/blipmodels/__pycache__/vit.cpython-38.pyc
ADDED
|
Binary file (12.3 kB). View file
|
|
|
defake/blipmodels/blip.py
ADDED
|
@@ -0,0 +1,238 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
'''
|
| 2 |
+
* Copyright (c) 2022, salesforce.com, inc.
|
| 3 |
+
* All rights reserved.
|
| 4 |
+
* SPDX-License-Identifier: BSD-3-Clause
|
| 5 |
+
* For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause
|
| 6 |
+
* By Junnan Li
|
| 7 |
+
'''
|
| 8 |
+
import warnings
|
| 9 |
+
warnings.filterwarnings("ignore")
|
| 10 |
+
|
| 11 |
+
from .vit import VisionTransformer, interpolate_pos_embed
|
| 12 |
+
from .med import BertConfig, BertModel, BertLMHeadModel
|
| 13 |
+
from transformers import BertTokenizer
|
| 14 |
+
|
| 15 |
+
import torch
|
| 16 |
+
from torch import nn
|
| 17 |
+
import torch.nn.functional as F
|
| 18 |
+
|
| 19 |
+
import os
|
| 20 |
+
from urllib.parse import urlparse
|
| 21 |
+
from timm.models.hub import download_cached_file
|
| 22 |
+
|
| 23 |
+
class BLIP_Base(nn.Module):
|
| 24 |
+
def __init__(self,
|
| 25 |
+
med_config = '/home/sha/stable-diffusion/blipmodels/blipconfig/med_config.json',
|
| 26 |
+
image_size = 224,
|
| 27 |
+
vit = 'base',
|
| 28 |
+
vit_grad_ckpt = False,
|
| 29 |
+
vit_ckpt_layer = 0,
|
| 30 |
+
):
|
| 31 |
+
"""
|
| 32 |
+
Args:
|
| 33 |
+
med_config (str): path for the mixture of encoder-decoder model's configuration file
|
| 34 |
+
image_size (int): input image size
|
| 35 |
+
vit (str): model size of vision transformer
|
| 36 |
+
"""
|
| 37 |
+
super().__init__()
|
| 38 |
+
|
| 39 |
+
self.visual_encoder, vision_width = create_vit(vit,image_size, vit_grad_ckpt, vit_ckpt_layer)
|
| 40 |
+
self.tokenizer = init_tokenizer()
|
| 41 |
+
med_config = BertConfig.from_json_file(med_config)
|
| 42 |
+
med_config.encoder_width = vision_width
|
| 43 |
+
self.text_encoder = BertModel(config=med_config, add_pooling_layer=False)
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
def forward(self, image, caption, mode):
|
| 47 |
+
|
| 48 |
+
assert mode in ['image', 'text', 'multimodal'], "mode parameter must be image, text, or multimodal"
|
| 49 |
+
text = self.tokenizer(caption, return_tensors="pt").to(image.device)
|
| 50 |
+
|
| 51 |
+
if mode=='image':
|
| 52 |
+
# return image features
|
| 53 |
+
image_embeds = self.visual_encoder(image)
|
| 54 |
+
return image_embeds
|
| 55 |
+
|
| 56 |
+
elif mode=='text':
|
| 57 |
+
# return text features
|
| 58 |
+
text_output = self.text_encoder(text.input_ids, attention_mask = text.attention_mask,
|
| 59 |
+
return_dict = True, mode = 'text')
|
| 60 |
+
return text_output.last_hidden_state
|
| 61 |
+
|
| 62 |
+
elif mode=='multimodal':
|
| 63 |
+
# return multimodel features
|
| 64 |
+
image_embeds = self.visual_encoder(image)
|
| 65 |
+
image_atts = torch.ones(image_embeds.size()[:-1],dtype=torch.long).to(image.device)
|
| 66 |
+
|
| 67 |
+
text.input_ids[:,0] = self.tokenizer.enc_token_id
|
| 68 |
+
output = self.text_encoder(text.input_ids,
|
| 69 |
+
attention_mask = text.attention_mask,
|
| 70 |
+
encoder_hidden_states = image_embeds,
|
| 71 |
+
encoder_attention_mask = image_atts,
|
| 72 |
+
return_dict = True,
|
| 73 |
+
)
|
| 74 |
+
return output.last_hidden_state
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
class BLIP_Decoder(nn.Module):
|
| 79 |
+
def __init__(self,
|
| 80 |
+
med_config = '/home/sha/stable-diffusion/blipmodels/blipconfig/med_config.json',
|
| 81 |
+
image_size = 384,
|
| 82 |
+
vit = 'base',
|
| 83 |
+
vit_grad_ckpt = False,
|
| 84 |
+
vit_ckpt_layer = 0,
|
| 85 |
+
prompt = 'a picture of ',
|
| 86 |
+
):
|
| 87 |
+
"""
|
| 88 |
+
Args:
|
| 89 |
+
med_config (str): path for the mixture of encoder-decoder model's configuration file
|
| 90 |
+
image_size (int): input image size
|
| 91 |
+
vit (str): model size of vision transformer
|
| 92 |
+
"""
|
| 93 |
+
super().__init__()
|
| 94 |
+
|
| 95 |
+
self.visual_encoder, vision_width = create_vit(vit,image_size, vit_grad_ckpt, vit_ckpt_layer)
|
| 96 |
+
self.tokenizer = init_tokenizer()
|
| 97 |
+
med_config = BertConfig.from_json_file(med_config)
|
| 98 |
+
med_config.encoder_width = vision_width
|
| 99 |
+
self.text_decoder = BertLMHeadModel(config=med_config)
|
| 100 |
+
|
| 101 |
+
self.prompt = prompt
|
| 102 |
+
self.prompt_length = len(self.tokenizer(self.prompt).input_ids)-1
|
| 103 |
+
|
| 104 |
+
|
| 105 |
+
def forward(self, image, caption):
|
| 106 |
+
|
| 107 |
+
image_embeds = self.visual_encoder(image)
|
| 108 |
+
image_atts = torch.ones(image_embeds.size()[:-1],dtype=torch.long).to(image.device)
|
| 109 |
+
|
| 110 |
+
text = self.tokenizer(caption, padding='longest', truncation=True, max_length=40, return_tensors="pt").to(image.device)
|
| 111 |
+
|
| 112 |
+
text.input_ids[:,0] = self.tokenizer.bos_token_id
|
| 113 |
+
|
| 114 |
+
decoder_targets = text.input_ids.masked_fill(text.input_ids == self.tokenizer.pad_token_id, -100)
|
| 115 |
+
decoder_targets[:,:self.prompt_length] = -100
|
| 116 |
+
|
| 117 |
+
decoder_output = self.text_decoder(text.input_ids,
|
| 118 |
+
attention_mask = text.attention_mask,
|
| 119 |
+
encoder_hidden_states = image_embeds,
|
| 120 |
+
encoder_attention_mask = image_atts,
|
| 121 |
+
labels = decoder_targets,
|
| 122 |
+
return_dict = True,
|
| 123 |
+
)
|
| 124 |
+
loss_lm = decoder_output.loss
|
| 125 |
+
|
| 126 |
+
return loss_lm
|
| 127 |
+
|
| 128 |
+
def generate(self, image, sample=False, num_beams=3, max_length=30, min_length=10, top_p=0.9, repetition_penalty=1.0):
|
| 129 |
+
image_embeds = self.visual_encoder(image)
|
| 130 |
+
|
| 131 |
+
if not sample:
|
| 132 |
+
image_embeds = image_embeds.repeat_interleave(num_beams,dim=0)
|
| 133 |
+
|
| 134 |
+
image_atts = torch.ones(image_embeds.size()[:-1],dtype=torch.long).to(image.device)
|
| 135 |
+
model_kwargs = {"encoder_hidden_states": image_embeds, "encoder_attention_mask":image_atts}
|
| 136 |
+
|
| 137 |
+
prompt = [self.prompt] * image.size(0)
|
| 138 |
+
input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids.to(image.device)
|
| 139 |
+
input_ids[:,0] = self.tokenizer.bos_token_id
|
| 140 |
+
input_ids = input_ids[:, :-1]
|
| 141 |
+
|
| 142 |
+
if sample:
|
| 143 |
+
#nucleus sampling
|
| 144 |
+
outputs = self.text_decoder.generate(input_ids=input_ids,
|
| 145 |
+
max_length=max_length,
|
| 146 |
+
min_length=min_length,
|
| 147 |
+
do_sample=True,
|
| 148 |
+
top_p=top_p,
|
| 149 |
+
num_return_sequences=1,
|
| 150 |
+
eos_token_id=self.tokenizer.sep_token_id,
|
| 151 |
+
pad_token_id=self.tokenizer.pad_token_id,
|
| 152 |
+
repetition_penalty=1.1,
|
| 153 |
+
**model_kwargs)
|
| 154 |
+
else:
|
| 155 |
+
#beam search
|
| 156 |
+
outputs = self.text_decoder.generate(input_ids=input_ids,
|
| 157 |
+
max_length=max_length,
|
| 158 |
+
min_length=min_length,
|
| 159 |
+
num_beams=num_beams,
|
| 160 |
+
eos_token_id=self.tokenizer.sep_token_id,
|
| 161 |
+
pad_token_id=self.tokenizer.pad_token_id,
|
| 162 |
+
repetition_penalty=repetition_penalty,
|
| 163 |
+
**model_kwargs)
|
| 164 |
+
|
| 165 |
+
captions = []
|
| 166 |
+
for output in outputs:
|
| 167 |
+
caption = self.tokenizer.decode(output, skip_special_tokens=True)
|
| 168 |
+
captions.append(caption[len(self.prompt):])
|
| 169 |
+
return captions
|
| 170 |
+
|
| 171 |
+
|
| 172 |
+
def blip_decoder(pretrained='',**kwargs):
|
| 173 |
+
model = BLIP_Decoder(**kwargs)
|
| 174 |
+
if pretrained:
|
| 175 |
+
model,msg = load_checkpoint(model,pretrained)
|
| 176 |
+
assert(len(msg.missing_keys)==0)
|
| 177 |
+
return model
|
| 178 |
+
|
| 179 |
+
def blip_feature_extractor(pretrained='',**kwargs):
|
| 180 |
+
model = BLIP_Base(**kwargs)
|
| 181 |
+
if pretrained:
|
| 182 |
+
model,msg = load_checkpoint(model,pretrained)
|
| 183 |
+
assert(len(msg.missing_keys)==0)
|
| 184 |
+
return model
|
| 185 |
+
|
| 186 |
+
def init_tokenizer():
|
| 187 |
+
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
|
| 188 |
+
tokenizer.add_special_tokens({'bos_token':'[DEC]'})
|
| 189 |
+
tokenizer.add_special_tokens({'additional_special_tokens':['[ENC]']})
|
| 190 |
+
tokenizer.enc_token_id = tokenizer.additional_special_tokens_ids[0]
|
| 191 |
+
return tokenizer
|
| 192 |
+
|
| 193 |
+
|
| 194 |
+
def create_vit(vit, image_size, use_grad_checkpointing=False, ckpt_layer=0, drop_path_rate=0):
|
| 195 |
+
|
| 196 |
+
assert vit in ['base', 'large'], "vit parameter must be base or large"
|
| 197 |
+
if vit=='base':
|
| 198 |
+
vision_width = 768
|
| 199 |
+
visual_encoder = VisionTransformer(img_size=image_size, patch_size=16, embed_dim=vision_width, depth=12,
|
| 200 |
+
num_heads=12, use_grad_checkpointing=use_grad_checkpointing, ckpt_layer=ckpt_layer,
|
| 201 |
+
drop_path_rate=0 or drop_path_rate
|
| 202 |
+
)
|
| 203 |
+
elif vit=='large':
|
| 204 |
+
vision_width = 1024
|
| 205 |
+
visual_encoder = VisionTransformer(img_size=image_size, patch_size=16, embed_dim=vision_width, depth=24,
|
| 206 |
+
num_heads=16, use_grad_checkpointing=use_grad_checkpointing, ckpt_layer=ckpt_layer,
|
| 207 |
+
drop_path_rate=0.1 or drop_path_rate
|
| 208 |
+
)
|
| 209 |
+
return visual_encoder, vision_width
|
| 210 |
+
|
| 211 |
+
def is_url(url_or_filename):
|
| 212 |
+
parsed = urlparse(url_or_filename)
|
| 213 |
+
return parsed.scheme in ("http", "https")
|
| 214 |
+
|
| 215 |
+
def load_checkpoint(model,url_or_filename):
|
| 216 |
+
if is_url(url_or_filename):
|
| 217 |
+
cached_file = download_cached_file(url_or_filename, check_hash=False, progress=True)
|
| 218 |
+
checkpoint = torch.load(cached_file, map_location='cpu')
|
| 219 |
+
elif os.path.isfile(url_or_filename):
|
| 220 |
+
checkpoint = torch.load(url_or_filename, map_location='cpu')
|
| 221 |
+
else:
|
| 222 |
+
raise RuntimeError('checkpoint url or path is invalid')
|
| 223 |
+
|
| 224 |
+
state_dict = checkpoint['model']
|
| 225 |
+
|
| 226 |
+
state_dict['visual_encoder.pos_embed'] = interpolate_pos_embed(state_dict['visual_encoder.pos_embed'],model.visual_encoder)
|
| 227 |
+
if 'visual_encoder_m.pos_embed' in model.state_dict().keys():
|
| 228 |
+
state_dict['visual_encoder_m.pos_embed'] = interpolate_pos_embed(state_dict['visual_encoder_m.pos_embed'],
|
| 229 |
+
model.visual_encoder_m)
|
| 230 |
+
for key in model.state_dict().keys():
|
| 231 |
+
if key in state_dict.keys():
|
| 232 |
+
if state_dict[key].shape!=model.state_dict()[key].shape:
|
| 233 |
+
del state_dict[key]
|
| 234 |
+
|
| 235 |
+
msg = model.load_state_dict(state_dict,strict=False)
|
| 236 |
+
print('load checkpoint from %s'%url_or_filename)
|
| 237 |
+
return model,msg
|
| 238 |
+
|
defake/blipmodels/blip_itm.py
ADDED
|
@@ -0,0 +1,76 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from models.med import BertConfig, BertModel
|
| 2 |
+
from transformers import BertTokenizer
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
from torch import nn
|
| 6 |
+
import torch.nn.functional as F
|
| 7 |
+
|
| 8 |
+
from models.blip import create_vit, init_tokenizer, load_checkpoint
|
| 9 |
+
|
| 10 |
+
class BLIP_ITM(nn.Module):
|
| 11 |
+
def __init__(self,
|
| 12 |
+
med_config = 'configs/med_config.json',
|
| 13 |
+
image_size = 384,
|
| 14 |
+
vit = 'base',
|
| 15 |
+
vit_grad_ckpt = False,
|
| 16 |
+
vit_ckpt_layer = 0,
|
| 17 |
+
embed_dim = 256,
|
| 18 |
+
):
|
| 19 |
+
"""
|
| 20 |
+
Args:
|
| 21 |
+
med_config (str): path for the mixture of encoder-decoder model's configuration file
|
| 22 |
+
image_size (int): input image size
|
| 23 |
+
vit (str): model size of vision transformer
|
| 24 |
+
"""
|
| 25 |
+
super().__init__()
|
| 26 |
+
|
| 27 |
+
self.visual_encoder, vision_width = create_vit(vit,image_size, vit_grad_ckpt, vit_ckpt_layer)
|
| 28 |
+
self.tokenizer = init_tokenizer()
|
| 29 |
+
med_config = BertConfig.from_json_file(med_config)
|
| 30 |
+
med_config.encoder_width = vision_width
|
| 31 |
+
self.text_encoder = BertModel(config=med_config, add_pooling_layer=False)
|
| 32 |
+
|
| 33 |
+
text_width = self.text_encoder.config.hidden_size
|
| 34 |
+
|
| 35 |
+
self.vision_proj = nn.Linear(vision_width, embed_dim)
|
| 36 |
+
self.text_proj = nn.Linear(text_width, embed_dim)
|
| 37 |
+
|
| 38 |
+
self.itm_head = nn.Linear(text_width, 2)
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
def forward(self, image, caption, match_head='itm'):
|
| 42 |
+
|
| 43 |
+
image_embeds = self.visual_encoder(image)
|
| 44 |
+
image_atts = torch.ones(image_embeds.size()[:-1],dtype=torch.long).to(image.device)
|
| 45 |
+
|
| 46 |
+
text = self.tokenizer(caption, padding='max_length', truncation=True, max_length=35,
|
| 47 |
+
return_tensors="pt").to(image.device)
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
if match_head=='itm':
|
| 51 |
+
output = self.text_encoder(text.input_ids,
|
| 52 |
+
attention_mask = text.attention_mask,
|
| 53 |
+
encoder_hidden_states = image_embeds,
|
| 54 |
+
encoder_attention_mask = image_atts,
|
| 55 |
+
return_dict = True,
|
| 56 |
+
)
|
| 57 |
+
itm_output = self.itm_head(output.last_hidden_state[:,0,:])
|
| 58 |
+
return itm_output
|
| 59 |
+
|
| 60 |
+
elif match_head=='itc':
|
| 61 |
+
text_output = self.text_encoder(text.input_ids, attention_mask = text.attention_mask,
|
| 62 |
+
return_dict = True, mode = 'text')
|
| 63 |
+
image_feat = F.normalize(self.vision_proj(image_embeds[:,0,:]),dim=-1)
|
| 64 |
+
text_feat = F.normalize(self.text_proj(text_output.last_hidden_state[:,0,:]),dim=-1)
|
| 65 |
+
|
| 66 |
+
sim = image_feat @ text_feat.t()
|
| 67 |
+
return sim
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
def blip_itm(pretrained='',**kwargs):
|
| 71 |
+
model = BLIP_ITM(**kwargs)
|
| 72 |
+
if pretrained:
|
| 73 |
+
model,msg = load_checkpoint(model,pretrained)
|
| 74 |
+
assert(len(msg.missing_keys)==0)
|
| 75 |
+
return model
|
| 76 |
+
|
defake/blipmodels/blip_nlvr.py
ADDED
|
@@ -0,0 +1,103 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from models.med import BertConfig
|
| 2 |
+
from models.nlvr_encoder import BertModel
|
| 3 |
+
from models.vit import interpolate_pos_embed
|
| 4 |
+
from models.blip import create_vit, init_tokenizer, is_url
|
| 5 |
+
|
| 6 |
+
from timm.models.hub import download_cached_file
|
| 7 |
+
|
| 8 |
+
import torch
|
| 9 |
+
from torch import nn
|
| 10 |
+
import torch.nn.functional as F
|
| 11 |
+
from transformers import BertTokenizer
|
| 12 |
+
import numpy as np
|
| 13 |
+
|
| 14 |
+
class BLIP_NLVR(nn.Module):
|
| 15 |
+
def __init__(self,
|
| 16 |
+
med_config = 'configs/med_config.json',
|
| 17 |
+
image_size = 480,
|
| 18 |
+
vit = 'base',
|
| 19 |
+
vit_grad_ckpt = False,
|
| 20 |
+
vit_ckpt_layer = 0,
|
| 21 |
+
):
|
| 22 |
+
"""
|
| 23 |
+
Args:
|
| 24 |
+
med_config (str): path for the mixture of encoder-decoder model's configuration file
|
| 25 |
+
image_size (int): input image size
|
| 26 |
+
vit (str): model size of vision transformer
|
| 27 |
+
"""
|
| 28 |
+
super().__init__()
|
| 29 |
+
|
| 30 |
+
self.visual_encoder, vision_width = create_vit(vit,image_size, vit_grad_ckpt, vit_ckpt_layer, drop_path_rate=0.1)
|
| 31 |
+
self.tokenizer = init_tokenizer()
|
| 32 |
+
med_config = BertConfig.from_json_file(med_config)
|
| 33 |
+
med_config.encoder_width = vision_width
|
| 34 |
+
self.text_encoder = BertModel(config=med_config, add_pooling_layer=False)
|
| 35 |
+
|
| 36 |
+
self.cls_head = nn.Sequential(
|
| 37 |
+
nn.Linear(self.text_encoder.config.hidden_size, self.text_encoder.config.hidden_size),
|
| 38 |
+
nn.ReLU(),
|
| 39 |
+
nn.Linear(self.text_encoder.config.hidden_size, 2)
|
| 40 |
+
)
|
| 41 |
+
|
| 42 |
+
def forward(self, image, text, targets, train=True):
|
| 43 |
+
|
| 44 |
+
image_embeds = self.visual_encoder(image)
|
| 45 |
+
image_atts = torch.ones(image_embeds.size()[:-1],dtype=torch.long).to(image.device)
|
| 46 |
+
image0_embeds, image1_embeds = torch.split(image_embeds,targets.size(0))
|
| 47 |
+
|
| 48 |
+
text = self.tokenizer(text, padding='longest', return_tensors="pt").to(image.device)
|
| 49 |
+
text.input_ids[:,0] = self.tokenizer.enc_token_id
|
| 50 |
+
|
| 51 |
+
output = self.text_encoder(text.input_ids,
|
| 52 |
+
attention_mask = text.attention_mask,
|
| 53 |
+
encoder_hidden_states = [image0_embeds,image1_embeds],
|
| 54 |
+
encoder_attention_mask = [image_atts[:image0_embeds.size(0)],
|
| 55 |
+
image_atts[image0_embeds.size(0):]],
|
| 56 |
+
return_dict = True,
|
| 57 |
+
)
|
| 58 |
+
hidden_state = output.last_hidden_state[:,0,:]
|
| 59 |
+
prediction = self.cls_head(hidden_state)
|
| 60 |
+
|
| 61 |
+
if train:
|
| 62 |
+
loss = F.cross_entropy(prediction, targets)
|
| 63 |
+
return loss
|
| 64 |
+
else:
|
| 65 |
+
return prediction
|
| 66 |
+
|
| 67 |
+
def blip_nlvr(pretrained='',**kwargs):
|
| 68 |
+
model = BLIP_NLVR(**kwargs)
|
| 69 |
+
if pretrained:
|
| 70 |
+
model,msg = load_checkpoint(model,pretrained)
|
| 71 |
+
print("missing keys:")
|
| 72 |
+
print(msg.missing_keys)
|
| 73 |
+
return model
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
def load_checkpoint(model,url_or_filename):
|
| 77 |
+
if is_url(url_or_filename):
|
| 78 |
+
cached_file = download_cached_file(url_or_filename, check_hash=False, progress=True)
|
| 79 |
+
checkpoint = torch.load(cached_file, map_location='cpu')
|
| 80 |
+
elif os.path.isfile(url_or_filename):
|
| 81 |
+
checkpoint = torch.load(url_or_filename, map_location='cpu')
|
| 82 |
+
else:
|
| 83 |
+
raise RuntimeError('checkpoint url or path is invalid')
|
| 84 |
+
state_dict = checkpoint['model']
|
| 85 |
+
|
| 86 |
+
state_dict['visual_encoder.pos_embed'] = interpolate_pos_embed(state_dict['visual_encoder.pos_embed'],model.visual_encoder)
|
| 87 |
+
|
| 88 |
+
for key in list(state_dict.keys()):
|
| 89 |
+
if 'crossattention.self.' in key:
|
| 90 |
+
new_key0 = key.replace('self','self0')
|
| 91 |
+
new_key1 = key.replace('self','self1')
|
| 92 |
+
state_dict[new_key0] = state_dict[key]
|
| 93 |
+
state_dict[new_key1] = state_dict[key]
|
| 94 |
+
elif 'crossattention.output.dense.' in key:
|
| 95 |
+
new_key0 = key.replace('dense','dense0')
|
| 96 |
+
new_key1 = key.replace('dense','dense1')
|
| 97 |
+
state_dict[new_key0] = state_dict[key]
|
| 98 |
+
state_dict[new_key1] = state_dict[key]
|
| 99 |
+
|
| 100 |
+
msg = model.load_state_dict(state_dict,strict=False)
|
| 101 |
+
print('load checkpoint from %s'%url_or_filename)
|
| 102 |
+
return model,msg
|
| 103 |
+
|
defake/blipmodels/blip_pretrain.py
ADDED
|
@@ -0,0 +1,339 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
'''
|
| 2 |
+
* Copyright (c) 2022, salesforce.com, inc.
|
| 3 |
+
* All rights reserved.
|
| 4 |
+
* SPDX-License-Identifier: BSD-3-Clause
|
| 5 |
+
* For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause
|
| 6 |
+
* By Junnan Li
|
| 7 |
+
'''
|
| 8 |
+
from models.med import BertConfig, BertModel, BertLMHeadModel
|
| 9 |
+
from transformers import BertTokenizer
|
| 10 |
+
import transformers
|
| 11 |
+
transformers.logging.set_verbosity_error()
|
| 12 |
+
|
| 13 |
+
import torch
|
| 14 |
+
from torch import nn
|
| 15 |
+
import torch.nn.functional as F
|
| 16 |
+
|
| 17 |
+
from models.blip import create_vit, init_tokenizer, load_checkpoint
|
| 18 |
+
|
| 19 |
+
class BLIP_Pretrain(nn.Module):
|
| 20 |
+
def __init__(self,
|
| 21 |
+
med_config = 'configs/bert_config.json',
|
| 22 |
+
image_size = 224,
|
| 23 |
+
vit = 'base',
|
| 24 |
+
vit_grad_ckpt = False,
|
| 25 |
+
vit_ckpt_layer = 0,
|
| 26 |
+
embed_dim = 256,
|
| 27 |
+
queue_size = 57600,
|
| 28 |
+
momentum = 0.995,
|
| 29 |
+
):
|
| 30 |
+
"""
|
| 31 |
+
Args:
|
| 32 |
+
med_config (str): path for the mixture of encoder-decoder model's configuration file
|
| 33 |
+
image_size (int): input image size
|
| 34 |
+
vit (str): model size of vision transformer
|
| 35 |
+
"""
|
| 36 |
+
super().__init__()
|
| 37 |
+
|
| 38 |
+
self.visual_encoder, vision_width = create_vit(vit,image_size, vit_grad_ckpt, vit_ckpt_layer, 0)
|
| 39 |
+
|
| 40 |
+
if vit=='base':
|
| 41 |
+
checkpoint = torch.hub.load_state_dict_from_url(
|
| 42 |
+
url="https://dl.fbaipublicfiles.com/deit/deit_base_patch16_224-b5f2ef4d.pth",
|
| 43 |
+
map_location="cpu", check_hash=True)
|
| 44 |
+
state_dict = checkpoint["model"]
|
| 45 |
+
msg = self.visual_encoder.load_state_dict(state_dict,strict=False)
|
| 46 |
+
elif vit=='large':
|
| 47 |
+
from timm.models.helpers import load_custom_pretrained
|
| 48 |
+
from timm.models.vision_transformer import default_cfgs
|
| 49 |
+
load_custom_pretrained(self.visual_encoder,default_cfgs['vit_large_patch16_224_in21k'])
|
| 50 |
+
|
| 51 |
+
self.tokenizer = init_tokenizer()
|
| 52 |
+
encoder_config = BertConfig.from_json_file(med_config)
|
| 53 |
+
encoder_config.encoder_width = vision_width
|
| 54 |
+
self.text_encoder = BertModel.from_pretrained('bert-base-uncased',config=encoder_config, add_pooling_layer=False)
|
| 55 |
+
self.text_encoder.resize_token_embeddings(len(self.tokenizer))
|
| 56 |
+
|
| 57 |
+
text_width = self.text_encoder.config.hidden_size
|
| 58 |
+
|
| 59 |
+
self.vision_proj = nn.Linear(vision_width, embed_dim)
|
| 60 |
+
self.text_proj = nn.Linear(text_width, embed_dim)
|
| 61 |
+
|
| 62 |
+
self.itm_head = nn.Linear(text_width, 2)
|
| 63 |
+
|
| 64 |
+
# create momentum encoders
|
| 65 |
+
self.visual_encoder_m, vision_width = create_vit(vit,image_size)
|
| 66 |
+
self.vision_proj_m = nn.Linear(vision_width, embed_dim)
|
| 67 |
+
self.text_encoder_m = BertModel(config=encoder_config, add_pooling_layer=False)
|
| 68 |
+
self.text_proj_m = nn.Linear(text_width, embed_dim)
|
| 69 |
+
|
| 70 |
+
self.model_pairs = [[self.visual_encoder,self.visual_encoder_m],
|
| 71 |
+
[self.vision_proj,self.vision_proj_m],
|
| 72 |
+
[self.text_encoder,self.text_encoder_m],
|
| 73 |
+
[self.text_proj,self.text_proj_m],
|
| 74 |
+
]
|
| 75 |
+
self.copy_params()
|
| 76 |
+
|
| 77 |
+
# create the queue
|
| 78 |
+
self.register_buffer("image_queue", torch.randn(embed_dim, queue_size))
|
| 79 |
+
self.register_buffer("text_queue", torch.randn(embed_dim, queue_size))
|
| 80 |
+
self.register_buffer("queue_ptr", torch.zeros(1, dtype=torch.long))
|
| 81 |
+
|
| 82 |
+
self.image_queue = nn.functional.normalize(self.image_queue, dim=0)
|
| 83 |
+
self.text_queue = nn.functional.normalize(self.text_queue, dim=0)
|
| 84 |
+
|
| 85 |
+
self.queue_size = queue_size
|
| 86 |
+
self.momentum = momentum
|
| 87 |
+
self.temp = nn.Parameter(0.07*torch.ones([]))
|
| 88 |
+
|
| 89 |
+
# create the decoder
|
| 90 |
+
decoder_config = BertConfig.from_json_file(med_config)
|
| 91 |
+
decoder_config.encoder_width = vision_width
|
| 92 |
+
self.text_decoder = BertLMHeadModel.from_pretrained('bert-base-uncased',config=decoder_config)
|
| 93 |
+
self.text_decoder.resize_token_embeddings(len(self.tokenizer))
|
| 94 |
+
tie_encoder_decoder_weights(self.text_encoder,self.text_decoder.bert,'','/attention')
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
def forward(self, image, caption, alpha):
|
| 98 |
+
with torch.no_grad():
|
| 99 |
+
self.temp.clamp_(0.001,0.5)
|
| 100 |
+
|
| 101 |
+
image_embeds = self.visual_encoder(image)
|
| 102 |
+
image_atts = torch.ones(image_embeds.size()[:-1],dtype=torch.long).to(image.device)
|
| 103 |
+
image_feat = F.normalize(self.vision_proj(image_embeds[:,0,:]),dim=-1)
|
| 104 |
+
|
| 105 |
+
text = self.tokenizer(caption, padding='max_length', truncation=True, max_length=30,
|
| 106 |
+
return_tensors="pt").to(image.device)
|
| 107 |
+
text_output = self.text_encoder(text.input_ids, attention_mask = text.attention_mask,
|
| 108 |
+
return_dict = True, mode = 'text')
|
| 109 |
+
text_feat = F.normalize(self.text_proj(text_output.last_hidden_state[:,0,:]),dim=-1)
|
| 110 |
+
|
| 111 |
+
# get momentum features
|
| 112 |
+
with torch.no_grad():
|
| 113 |
+
self._momentum_update()
|
| 114 |
+
image_embeds_m = self.visual_encoder_m(image)
|
| 115 |
+
image_feat_m = F.normalize(self.vision_proj_m(image_embeds_m[:,0,:]),dim=-1)
|
| 116 |
+
image_feat_all = torch.cat([image_feat_m.t(),self.image_queue.clone().detach()],dim=1)
|
| 117 |
+
|
| 118 |
+
text_output_m = self.text_encoder_m(text.input_ids, attention_mask = text.attention_mask,
|
| 119 |
+
return_dict = True, mode = 'text')
|
| 120 |
+
text_feat_m = F.normalize(self.text_proj_m(text_output_m.last_hidden_state[:,0,:]),dim=-1)
|
| 121 |
+
text_feat_all = torch.cat([text_feat_m.t(),self.text_queue.clone().detach()],dim=1)
|
| 122 |
+
|
| 123 |
+
sim_i2t_m = image_feat_m @ text_feat_all / self.temp
|
| 124 |
+
sim_t2i_m = text_feat_m @ image_feat_all / self.temp
|
| 125 |
+
|
| 126 |
+
sim_targets = torch.zeros(sim_i2t_m.size()).to(image.device)
|
| 127 |
+
sim_targets.fill_diagonal_(1)
|
| 128 |
+
|
| 129 |
+
sim_i2t_targets = alpha * F.softmax(sim_i2t_m, dim=1) + (1 - alpha) * sim_targets
|
| 130 |
+
sim_t2i_targets = alpha * F.softmax(sim_t2i_m, dim=1) + (1 - alpha) * sim_targets
|
| 131 |
+
|
| 132 |
+
sim_i2t = image_feat @ text_feat_all / self.temp
|
| 133 |
+
sim_t2i = text_feat @ image_feat_all / self.temp
|
| 134 |
+
|
| 135 |
+
loss_i2t = -torch.sum(F.log_softmax(sim_i2t, dim=1)*sim_i2t_targets,dim=1).mean()
|
| 136 |
+
loss_t2i = -torch.sum(F.log_softmax(sim_t2i, dim=1)*sim_t2i_targets,dim=1).mean()
|
| 137 |
+
|
| 138 |
+
loss_ita = (loss_i2t+loss_t2i)/2
|
| 139 |
+
|
| 140 |
+
self._dequeue_and_enqueue(image_feat_m, text_feat_m)
|
| 141 |
+
|
| 142 |
+
###============== Image-text Matching ===================###
|
| 143 |
+
encoder_input_ids = text.input_ids.clone()
|
| 144 |
+
encoder_input_ids[:,0] = self.tokenizer.enc_token_id
|
| 145 |
+
|
| 146 |
+
# forward the positve image-text pair
|
| 147 |
+
bs = image.size(0)
|
| 148 |
+
output_pos = self.text_encoder(encoder_input_ids,
|
| 149 |
+
attention_mask = text.attention_mask,
|
| 150 |
+
encoder_hidden_states = image_embeds,
|
| 151 |
+
encoder_attention_mask = image_atts,
|
| 152 |
+
return_dict = True,
|
| 153 |
+
)
|
| 154 |
+
with torch.no_grad():
|
| 155 |
+
weights_t2i = F.softmax(sim_t2i[:,:bs],dim=1)+1e-4
|
| 156 |
+
weights_t2i.fill_diagonal_(0)
|
| 157 |
+
weights_i2t = F.softmax(sim_i2t[:,:bs],dim=1)+1e-4
|
| 158 |
+
weights_i2t.fill_diagonal_(0)
|
| 159 |
+
|
| 160 |
+
# select a negative image for each text
|
| 161 |
+
image_embeds_neg = []
|
| 162 |
+
for b in range(bs):
|
| 163 |
+
neg_idx = torch.multinomial(weights_t2i[b], 1).item()
|
| 164 |
+
image_embeds_neg.append(image_embeds[neg_idx])
|
| 165 |
+
image_embeds_neg = torch.stack(image_embeds_neg,dim=0)
|
| 166 |
+
|
| 167 |
+
# select a negative text for each image
|
| 168 |
+
text_ids_neg = []
|
| 169 |
+
text_atts_neg = []
|
| 170 |
+
for b in range(bs):
|
| 171 |
+
neg_idx = torch.multinomial(weights_i2t[b], 1).item()
|
| 172 |
+
text_ids_neg.append(encoder_input_ids[neg_idx])
|
| 173 |
+
text_atts_neg.append(text.attention_mask[neg_idx])
|
| 174 |
+
|
| 175 |
+
text_ids_neg = torch.stack(text_ids_neg,dim=0)
|
| 176 |
+
text_atts_neg = torch.stack(text_atts_neg,dim=0)
|
| 177 |
+
|
| 178 |
+
text_ids_all = torch.cat([encoder_input_ids, text_ids_neg],dim=0)
|
| 179 |
+
text_atts_all = torch.cat([text.attention_mask, text_atts_neg],dim=0)
|
| 180 |
+
|
| 181 |
+
image_embeds_all = torch.cat([image_embeds_neg,image_embeds],dim=0)
|
| 182 |
+
image_atts_all = torch.cat([image_atts,image_atts],dim=0)
|
| 183 |
+
|
| 184 |
+
output_neg = self.text_encoder(text_ids_all,
|
| 185 |
+
attention_mask = text_atts_all,
|
| 186 |
+
encoder_hidden_states = image_embeds_all,
|
| 187 |
+
encoder_attention_mask = image_atts_all,
|
| 188 |
+
return_dict = True,
|
| 189 |
+
)
|
| 190 |
+
|
| 191 |
+
vl_embeddings = torch.cat([output_pos.last_hidden_state[:,0,:], output_neg.last_hidden_state[:,0,:]],dim=0)
|
| 192 |
+
vl_output = self.itm_head(vl_embeddings)
|
| 193 |
+
|
| 194 |
+
itm_labels = torch.cat([torch.ones(bs,dtype=torch.long),torch.zeros(2*bs,dtype=torch.long)],
|
| 195 |
+
dim=0).to(image.device)
|
| 196 |
+
loss_itm = F.cross_entropy(vl_output, itm_labels)
|
| 197 |
+
|
| 198 |
+
##================= LM ========================##
|
| 199 |
+
decoder_input_ids = text.input_ids.clone()
|
| 200 |
+
decoder_input_ids[:,0] = self.tokenizer.bos_token_id
|
| 201 |
+
decoder_targets = decoder_input_ids.masked_fill(decoder_input_ids == self.tokenizer.pad_token_id, -100)
|
| 202 |
+
|
| 203 |
+
decoder_output = self.text_decoder(decoder_input_ids,
|
| 204 |
+
attention_mask = text.attention_mask,
|
| 205 |
+
encoder_hidden_states = image_embeds,
|
| 206 |
+
encoder_attention_mask = image_atts,
|
| 207 |
+
labels = decoder_targets,
|
| 208 |
+
return_dict = True,
|
| 209 |
+
)
|
| 210 |
+
|
| 211 |
+
loss_lm = decoder_output.loss
|
| 212 |
+
return loss_ita, loss_itm, loss_lm
|
| 213 |
+
|
| 214 |
+
|
| 215 |
+
|
| 216 |
+
@torch.no_grad()
|
| 217 |
+
def copy_params(self):
|
| 218 |
+
for model_pair in self.model_pairs:
|
| 219 |
+
for param, param_m in zip(model_pair[0].parameters(), model_pair[1].parameters()):
|
| 220 |
+
param_m.data.copy_(param.data) # initialize
|
| 221 |
+
param_m.requires_grad = False # not update by gradient
|
| 222 |
+
|
| 223 |
+
|
| 224 |
+
@torch.no_grad()
|
| 225 |
+
def _momentum_update(self):
|
| 226 |
+
for model_pair in self.model_pairs:
|
| 227 |
+
for param, param_m in zip(model_pair[0].parameters(), model_pair[1].parameters()):
|
| 228 |
+
param_m.data = param_m.data * self.momentum + param.data * (1. - self.momentum)
|
| 229 |
+
|
| 230 |
+
|
| 231 |
+
@torch.no_grad()
|
| 232 |
+
def _dequeue_and_enqueue(self, image_feat, text_feat):
|
| 233 |
+
# gather keys before updating queue
|
| 234 |
+
image_feats = concat_all_gather(image_feat)
|
| 235 |
+
text_feats = concat_all_gather(text_feat)
|
| 236 |
+
|
| 237 |
+
batch_size = image_feats.shape[0]
|
| 238 |
+
|
| 239 |
+
ptr = int(self.queue_ptr)
|
| 240 |
+
assert self.queue_size % batch_size == 0 # for simplicity
|
| 241 |
+
|
| 242 |
+
# replace the keys at ptr (dequeue and enqueue)
|
| 243 |
+
self.image_queue[:, ptr:ptr + batch_size] = image_feats.T
|
| 244 |
+
self.text_queue[:, ptr:ptr + batch_size] = text_feats.T
|
| 245 |
+
ptr = (ptr + batch_size) % self.queue_size # move pointer
|
| 246 |
+
|
| 247 |
+
self.queue_ptr[0] = ptr
|
| 248 |
+
|
| 249 |
+
|
| 250 |
+
def blip_pretrain(**kwargs):
|
| 251 |
+
model = BLIP_Pretrain(**kwargs)
|
| 252 |
+
return model
|
| 253 |
+
|
| 254 |
+
|
| 255 |
+
@torch.no_grad()
|
| 256 |
+
def concat_all_gather(tensor):
|
| 257 |
+
"""
|
| 258 |
+
Performs all_gather operation on the provided tensors.
|
| 259 |
+
*** Warning ***: torch.distributed.all_gather has no gradient.
|
| 260 |
+
"""
|
| 261 |
+
tensors_gather = [torch.ones_like(tensor)
|
| 262 |
+
for _ in range(torch.distributed.get_world_size())]
|
| 263 |
+
torch.distributed.all_gather(tensors_gather, tensor, async_op=False)
|
| 264 |
+
|
| 265 |
+
output = torch.cat(tensors_gather, dim=0)
|
| 266 |
+
return output
|
| 267 |
+
|
| 268 |
+
|
| 269 |
+
from typing import List
|
| 270 |
+
def tie_encoder_decoder_weights(encoder: nn.Module, decoder: nn.Module, base_model_prefix: str, skip_key:str):
|
| 271 |
+
uninitialized_encoder_weights: List[str] = []
|
| 272 |
+
if decoder.__class__ != encoder.__class__:
|
| 273 |
+
logger.info(
|
| 274 |
+
f"{decoder.__class__} and {encoder.__class__} are not equal. In this case make sure that all encoder weights are correctly initialized."
|
| 275 |
+
)
|
| 276 |
+
|
| 277 |
+
def tie_encoder_to_decoder_recursively(
|
| 278 |
+
decoder_pointer: nn.Module,
|
| 279 |
+
encoder_pointer: nn.Module,
|
| 280 |
+
module_name: str,
|
| 281 |
+
uninitialized_encoder_weights: List[str],
|
| 282 |
+
skip_key: str,
|
| 283 |
+
depth=0,
|
| 284 |
+
):
|
| 285 |
+
assert isinstance(decoder_pointer, nn.Module) and isinstance(
|
| 286 |
+
encoder_pointer, nn.Module
|
| 287 |
+
), f"{decoder_pointer} and {encoder_pointer} have to be of type torch.nn.Module"
|
| 288 |
+
if hasattr(decoder_pointer, "weight") and skip_key not in module_name:
|
| 289 |
+
assert hasattr(encoder_pointer, "weight")
|
| 290 |
+
encoder_pointer.weight = decoder_pointer.weight
|
| 291 |
+
if hasattr(decoder_pointer, "bias"):
|
| 292 |
+
assert hasattr(encoder_pointer, "bias")
|
| 293 |
+
encoder_pointer.bias = decoder_pointer.bias
|
| 294 |
+
print(module_name+' is tied')
|
| 295 |
+
return
|
| 296 |
+
|
| 297 |
+
encoder_modules = encoder_pointer._modules
|
| 298 |
+
decoder_modules = decoder_pointer._modules
|
| 299 |
+
if len(decoder_modules) > 0:
|
| 300 |
+
assert (
|
| 301 |
+
len(encoder_modules) > 0
|
| 302 |
+
), f"Encoder module {encoder_pointer} does not match decoder module {decoder_pointer}"
|
| 303 |
+
|
| 304 |
+
all_encoder_weights = set([module_name + "/" + sub_name for sub_name in encoder_modules.keys()])
|
| 305 |
+
encoder_layer_pos = 0
|
| 306 |
+
for name, module in decoder_modules.items():
|
| 307 |
+
if name.isdigit():
|
| 308 |
+
encoder_name = str(int(name) + encoder_layer_pos)
|
| 309 |
+
decoder_name = name
|
| 310 |
+
if not isinstance(decoder_modules[decoder_name], type(encoder_modules[encoder_name])) and len(
|
| 311 |
+
encoder_modules
|
| 312 |
+
) != len(decoder_modules):
|
| 313 |
+
# this can happen if the name corresponds to the position in a list module list of layers
|
| 314 |
+
# in this case the decoder has added a cross-attention that the encoder does not have
|
| 315 |
+
# thus skip this step and subtract one layer pos from encoder
|
| 316 |
+
encoder_layer_pos -= 1
|
| 317 |
+
continue
|
| 318 |
+
elif name not in encoder_modules:
|
| 319 |
+
continue
|
| 320 |
+
elif depth > 500:
|
| 321 |
+
raise ValueError(
|
| 322 |
+
"Max depth of recursive function `tie_encoder_to_decoder` reached. It seems that there is a circular dependency between two or more `nn.Modules` of your model."
|
| 323 |
+
)
|
| 324 |
+
else:
|
| 325 |
+
decoder_name = encoder_name = name
|
| 326 |
+
tie_encoder_to_decoder_recursively(
|
| 327 |
+
decoder_modules[decoder_name],
|
| 328 |
+
encoder_modules[encoder_name],
|
| 329 |
+
module_name + "/" + name,
|
| 330 |
+
uninitialized_encoder_weights,
|
| 331 |
+
skip_key,
|
| 332 |
+
depth=depth + 1,
|
| 333 |
+
)
|
| 334 |
+
all_encoder_weights.remove(module_name + "/" + encoder_name)
|
| 335 |
+
|
| 336 |
+
uninitialized_encoder_weights += list(all_encoder_weights)
|
| 337 |
+
|
| 338 |
+
# tie weights recursively
|
| 339 |
+
tie_encoder_to_decoder_recursively(decoder, encoder, base_model_prefix, uninitialized_encoder_weights, skip_key)
|
defake/blipmodels/blip_retrieval.py
ADDED
|
@@ -0,0 +1,319 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from models.med import BertConfig, BertModel
|
| 2 |
+
from transformers import BertTokenizer
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
from torch import nn
|
| 6 |
+
import torch.nn.functional as F
|
| 7 |
+
|
| 8 |
+
from models.blip import create_vit, init_tokenizer, load_checkpoint
|
| 9 |
+
|
| 10 |
+
class BLIP_Retrieval(nn.Module):
|
| 11 |
+
def __init__(self,
|
| 12 |
+
med_config = 'configs/med_config.json',
|
| 13 |
+
image_size = 384,
|
| 14 |
+
vit = 'base',
|
| 15 |
+
vit_grad_ckpt = False,
|
| 16 |
+
vit_ckpt_layer = 0,
|
| 17 |
+
embed_dim = 256,
|
| 18 |
+
queue_size = 57600,
|
| 19 |
+
momentum = 0.995,
|
| 20 |
+
negative_all_rank = False,
|
| 21 |
+
):
|
| 22 |
+
"""
|
| 23 |
+
Args:
|
| 24 |
+
med_config (str): path for the mixture of encoder-decoder model's configuration file
|
| 25 |
+
image_size (int): input image size
|
| 26 |
+
vit (str): model size of vision transformer
|
| 27 |
+
"""
|
| 28 |
+
super().__init__()
|
| 29 |
+
|
| 30 |
+
self.visual_encoder, vision_width = create_vit(vit,image_size, vit_grad_ckpt, vit_ckpt_layer)
|
| 31 |
+
self.tokenizer = init_tokenizer()
|
| 32 |
+
med_config = BertConfig.from_json_file(med_config)
|
| 33 |
+
med_config.encoder_width = vision_width
|
| 34 |
+
self.text_encoder = BertModel(config=med_config, add_pooling_layer=False)
|
| 35 |
+
|
| 36 |
+
text_width = self.text_encoder.config.hidden_size
|
| 37 |
+
|
| 38 |
+
self.vision_proj = nn.Linear(vision_width, embed_dim)
|
| 39 |
+
self.text_proj = nn.Linear(text_width, embed_dim)
|
| 40 |
+
|
| 41 |
+
self.itm_head = nn.Linear(text_width, 2)
|
| 42 |
+
|
| 43 |
+
# create momentum encoders
|
| 44 |
+
self.visual_encoder_m, vision_width = create_vit(vit,image_size)
|
| 45 |
+
self.vision_proj_m = nn.Linear(vision_width, embed_dim)
|
| 46 |
+
self.text_encoder_m = BertModel(config=med_config, add_pooling_layer=False)
|
| 47 |
+
self.text_proj_m = nn.Linear(text_width, embed_dim)
|
| 48 |
+
|
| 49 |
+
self.model_pairs = [[self.visual_encoder,self.visual_encoder_m],
|
| 50 |
+
[self.vision_proj,self.vision_proj_m],
|
| 51 |
+
[self.text_encoder,self.text_encoder_m],
|
| 52 |
+
[self.text_proj,self.text_proj_m],
|
| 53 |
+
]
|
| 54 |
+
self.copy_params()
|
| 55 |
+
|
| 56 |
+
# create the queue
|
| 57 |
+
self.register_buffer("image_queue", torch.randn(embed_dim, queue_size))
|
| 58 |
+
self.register_buffer("text_queue", torch.randn(embed_dim, queue_size))
|
| 59 |
+
self.register_buffer("idx_queue", torch.full((1,queue_size),-100))
|
| 60 |
+
self.register_buffer("ptr_queue", torch.zeros(1, dtype=torch.long))
|
| 61 |
+
|
| 62 |
+
self.image_queue = nn.functional.normalize(self.image_queue, dim=0)
|
| 63 |
+
self.text_queue = nn.functional.normalize(self.text_queue, dim=0)
|
| 64 |
+
|
| 65 |
+
self.queue_size = queue_size
|
| 66 |
+
self.momentum = momentum
|
| 67 |
+
self.temp = nn.Parameter(0.07*torch.ones([]))
|
| 68 |
+
|
| 69 |
+
self.negative_all_rank = negative_all_rank
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
def forward(self, image, caption, alpha, idx):
|
| 73 |
+
with torch.no_grad():
|
| 74 |
+
self.temp.clamp_(0.001,0.5)
|
| 75 |
+
|
| 76 |
+
image_embeds = self.visual_encoder(image)
|
| 77 |
+
image_atts = torch.ones(image_embeds.size()[:-1],dtype=torch.long).to(image.device)
|
| 78 |
+
image_feat = F.normalize(self.vision_proj(image_embeds[:,0,:]),dim=-1)
|
| 79 |
+
|
        text = self.tokenizer(caption, padding='max_length', truncation=True, max_length=35,
                              return_tensors="pt").to(image.device)

        text_output = self.text_encoder(text.input_ids, attention_mask = text.attention_mask,
                                        return_dict = True, mode = 'text')
        text_feat = F.normalize(self.text_proj(text_output.last_hidden_state[:,0,:]),dim=-1)

        ###============== Image-text Contrastive Learning ===================###
        idx = idx.view(-1,1)
        idx_all = torch.cat([idx.t(), self.idx_queue.clone().detach()],dim=1)
        pos_idx = torch.eq(idx, idx_all).float()
        sim_targets = pos_idx / pos_idx.sum(1,keepdim=True)

        # get momentum features
        with torch.no_grad():
            self._momentum_update()
            image_embeds_m = self.visual_encoder_m(image)
            image_feat_m = F.normalize(self.vision_proj_m(image_embeds_m[:,0,:]),dim=-1)
            image_feat_m_all = torch.cat([image_feat_m.t(),self.image_queue.clone().detach()],dim=1)

            text_output_m = self.text_encoder_m(text.input_ids, attention_mask = text.attention_mask,
                                                return_dict = True, mode = 'text')
            text_feat_m = F.normalize(self.text_proj_m(text_output_m.last_hidden_state[:,0,:]),dim=-1)
            text_feat_m_all = torch.cat([text_feat_m.t(),self.text_queue.clone().detach()],dim=1)

            sim_i2t_m = image_feat_m @ text_feat_m_all / self.temp
            sim_t2i_m = text_feat_m @ image_feat_m_all / self.temp

            sim_i2t_targets = alpha * F.softmax(sim_i2t_m, dim=1) + (1 - alpha) * sim_targets
            sim_t2i_targets = alpha * F.softmax(sim_t2i_m, dim=1) + (1 - alpha) * sim_targets

        sim_i2t = image_feat @ text_feat_m_all / self.temp
        sim_t2i = text_feat @ image_feat_m_all / self.temp

        loss_i2t = -torch.sum(F.log_softmax(sim_i2t, dim=1)*sim_i2t_targets,dim=1).mean()
        loss_t2i = -torch.sum(F.log_softmax(sim_t2i, dim=1)*sim_t2i_targets,dim=1).mean()

        loss_ita = (loss_i2t+loss_t2i)/2

        idxs = concat_all_gather(idx)
        self._dequeue_and_enqueue(image_feat_m, text_feat_m, idxs)

        ###============== Image-text Matching ===================###
        encoder_input_ids = text.input_ids.clone()
        encoder_input_ids[:,0] = self.tokenizer.enc_token_id

        # forward the positve image-text pair
        bs = image.size(0)
        output_pos = self.text_encoder(encoder_input_ids,
                                       attention_mask = text.attention_mask,
                                       encoder_hidden_states = image_embeds,
                                       encoder_attention_mask = image_atts,
                                       return_dict = True,
                                      )

        if self.negative_all_rank:
            # compute sample similarity
            with torch.no_grad():
                mask = torch.eq(idx, idxs.t())

                image_feat_world = concat_all_gather(image_feat)
                text_feat_world = concat_all_gather(text_feat)

                sim_i2t = image_feat @ text_feat_world.t() / self.temp
                sim_t2i = text_feat @ image_feat_world.t() / self.temp

                weights_i2t = F.softmax(sim_i2t,dim=1)
                weights_i2t.masked_fill_(mask, 0)

                weights_t2i = F.softmax(sim_t2i,dim=1)
                weights_t2i.masked_fill_(mask, 0)

            image_embeds_world = all_gather_with_grad(image_embeds)

            # select a negative image (from all ranks) for each text
            image_embeds_neg = []
            for b in range(bs):
                neg_idx = torch.multinomial(weights_t2i[b], 1).item()
                image_embeds_neg.append(image_embeds_world[neg_idx])
            image_embeds_neg = torch.stack(image_embeds_neg,dim=0)

            # select a negative text (from all ranks) for each image
            input_ids_world = concat_all_gather(encoder_input_ids)
            att_mask_world = concat_all_gather(text.attention_mask)

            text_ids_neg = []
            text_atts_neg = []
            for b in range(bs):
                neg_idx = torch.multinomial(weights_i2t[b], 1).item()
                text_ids_neg.append(input_ids_world[neg_idx])
                text_atts_neg.append(att_mask_world[neg_idx])

        else:
            with torch.no_grad():
                mask = torch.eq(idx, idx.t())

                sim_i2t = image_feat @ text_feat.t() / self.temp
                sim_t2i = text_feat @ image_feat.t() / self.temp

                weights_i2t = F.softmax(sim_i2t,dim=1)
                weights_i2t.masked_fill_(mask, 0)

                weights_t2i = F.softmax(sim_t2i,dim=1)
                weights_t2i.masked_fill_(mask, 0)

            # select a negative image (from same rank) for each text
            image_embeds_neg = []
            for b in range(bs):
                neg_idx = torch.multinomial(weights_t2i[b], 1).item()
                image_embeds_neg.append(image_embeds[neg_idx])
            image_embeds_neg = torch.stack(image_embeds_neg,dim=0)

            # select a negative text (from same rank) for each image
            text_ids_neg = []
            text_atts_neg = []
            for b in range(bs):
                neg_idx = torch.multinomial(weights_i2t[b], 1).item()
                text_ids_neg.append(encoder_input_ids[neg_idx])
                text_atts_neg.append(text.attention_mask[neg_idx])

        text_ids_neg = torch.stack(text_ids_neg,dim=0)
        text_atts_neg = torch.stack(text_atts_neg,dim=0)

        text_ids_all = torch.cat([encoder_input_ids, text_ids_neg],dim=0)
        text_atts_all = torch.cat([text.attention_mask, text_atts_neg],dim=0)

        image_embeds_all = torch.cat([image_embeds_neg,image_embeds],dim=0)
        image_atts_all = torch.cat([image_atts,image_atts],dim=0)

        output_neg = self.text_encoder(text_ids_all,
                                       attention_mask = text_atts_all,
                                       encoder_hidden_states = image_embeds_all,
                                       encoder_attention_mask = image_atts_all,
                                       return_dict = True,
                                      )

        vl_embeddings = torch.cat([output_pos.last_hidden_state[:,0,:], output_neg.last_hidden_state[:,0,:]],dim=0)
        vl_output = self.itm_head(vl_embeddings)

        itm_labels = torch.cat([torch.ones(bs,dtype=torch.long),torch.zeros(2*bs,dtype=torch.long)],
                               dim=0).to(image.device)
        loss_itm = F.cross_entropy(vl_output, itm_labels)

        return loss_ita, loss_itm


    @torch.no_grad()
    def copy_params(self):
        for model_pair in self.model_pairs:
            for param, param_m in zip(model_pair[0].parameters(), model_pair[1].parameters()):
                param_m.data.copy_(param.data)  # initialize
                param_m.requires_grad = False  # not update by gradient


    @torch.no_grad()
    def _momentum_update(self):
        for model_pair in self.model_pairs:
            for param, param_m in zip(model_pair[0].parameters(), model_pair[1].parameters()):
                param_m.data = param_m.data * self.momentum + param.data * (1. - self.momentum)


    @torch.no_grad()
    def _dequeue_and_enqueue(self, image_feat, text_feat, idxs):
        # gather keys before updating queue
        image_feats = concat_all_gather(image_feat)
        text_feats = concat_all_gather(text_feat)

        batch_size = image_feats.shape[0]

        ptr = int(self.ptr_queue)
        assert self.queue_size % batch_size == 0  # for simplicity

        # replace the keys at ptr (dequeue and enqueue)
        self.image_queue[:, ptr:ptr + batch_size] = image_feats.T
        self.text_queue[:, ptr:ptr + batch_size] = text_feats.T
        self.idx_queue[:, ptr:ptr + batch_size] = idxs.T
        ptr = (ptr + batch_size) % self.queue_size  # move pointer

        self.ptr_queue[0] = ptr


def blip_retrieval(pretrained='',**kwargs):
    model = BLIP_Retrieval(**kwargs)
    if pretrained:
        model,msg = load_checkpoint(model,pretrained)
        print("missing keys:")
        print(msg.missing_keys)
    return model


@torch.no_grad()
def concat_all_gather(tensor):
    """
    Performs all_gather operation on the provided tensors.
    *** Warning ***: torch.distributed.all_gather has no gradient.
    """
    tensors_gather = [torch.ones_like(tensor)
        for _ in range(torch.distributed.get_world_size())]
    torch.distributed.all_gather(tensors_gather, tensor, async_op=False)

    output = torch.cat(tensors_gather, dim=0)
    return output


class GatherLayer(torch.autograd.Function):
    """
    Gather tensors from all workers with support for backward propagation:
    This implementation does not cut the gradients as torch.distributed.all_gather does.
    """

    @staticmethod
    def forward(ctx, x):
        output = [torch.zeros_like(x) for _ in range(torch.distributed.get_world_size())]
        torch.distributed.all_gather(output, x)
        return tuple(output)

    @staticmethod
    def backward(ctx, *grads):
        all_gradients = torch.stack(grads)
        torch.distributed.all_reduce(all_gradients)
        return all_gradients[torch.distributed.get_rank()]


def all_gather_with_grad(tensors):
    """
    Performs all_gather operation on the provided tensors.
    Graph remains connected for backward grad computation.
    """
    # Queue the gathered tensors
    world_size = torch.distributed.get_world_size()
    # There is no need for reduction in the single-proc case
    if world_size == 1:
        return tensors

    tensor_all = GatherLayer.apply(tensors)

    return torch.cat(tensor_all, dim=0)
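Note: the momentum copies (the `*_m` modules) and the feature queues are the core bookkeeping in this file. Below is a minimal standalone sketch of what `copy_params`, `_momentum_update` and `_dequeue_and_enqueue` do, using a toy linear layer and made-up sizes; the real model uses its configured embedding size and `queue_size`, and gathers features across distributed ranks before enqueueing.

import torch
from torch import nn
import torch.nn.functional as F

# Illustrative sizes only (assumptions for this sketch, not the model's real dimensions).
embed_dim, queue_size, batch_size, momentum = 4, 8, 2, 0.995

online = nn.Linear(embed_dim, embed_dim)     # stands in for the trained encoder + projection
target = nn.Linear(embed_dim, embed_dim)     # stands in for the momentum (_m) copy
target.load_state_dict(online.state_dict())  # like copy_params(): start from identical weights
for p in target.parameters():
    p.requires_grad = False                  # momentum copy is never updated by gradients

feat_queue = torch.randn(embed_dim, queue_size)  # columns hold past momentum features
ptr = 0

def momentum_update():
    # exponential moving average of the online weights, as in _momentum_update()
    with torch.no_grad():
        for p, p_m in zip(online.parameters(), target.parameters()):
            p_m.data = p_m.data * momentum + p.data * (1. - momentum)

def dequeue_and_enqueue(feats):
    # overwrite the oldest queue slots and advance the pointer, as in _dequeue_and_enqueue()
    global ptr
    assert queue_size % feats.shape[0] == 0
    feat_queue[:, ptr:ptr + feats.shape[0]] = feats.T
    ptr = (ptr + feats.shape[0]) % queue_size

x = torch.randn(batch_size, embed_dim)
momentum_update()
feats = F.normalize(target(x), dim=-1)
dequeue_and_enqueue(feats)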
defake/blipmodels/blip_vqa.py
ADDED
@@ -0,0 +1,186 @@
from models.med import BertConfig, BertModel, BertLMHeadModel
from models.blip import create_vit, init_tokenizer, load_checkpoint

import torch
from torch import nn
import torch.nn.functional as F
from transformers import BertTokenizer
import numpy as np

class BLIP_VQA(nn.Module):
    def __init__(self,
                 med_config = 'configs/med_config.json',
                 image_size = 480,
                 vit = 'base',
                 vit_grad_ckpt = False,
                 vit_ckpt_layer = 0,
                 ):
        """
        Args:
            med_config (str): path for the mixture of encoder-decoder model's configuration file
            image_size (int): input image size
            vit (str): model size of vision transformer
        """
        super().__init__()

        self.visual_encoder, vision_width = create_vit(vit, image_size, vit_grad_ckpt, vit_ckpt_layer, drop_path_rate=0.1)
        self.tokenizer = init_tokenizer()

        encoder_config = BertConfig.from_json_file(med_config)
        encoder_config.encoder_width = vision_width
        self.text_encoder = BertModel(config=encoder_config, add_pooling_layer=False)

        decoder_config = BertConfig.from_json_file(med_config)
        self.text_decoder = BertLMHeadModel(config=decoder_config)


    def forward(self, image, question, answer=None, n=None, weights=None, train=True, inference='rank', k_test=128):

        image_embeds = self.visual_encoder(image)
        image_atts = torch.ones(image_embeds.size()[:-1],dtype=torch.long).to(image.device)

        question = self.tokenizer(question, padding='longest', truncation=True, max_length=35,
                                  return_tensors="pt").to(image.device)
        question.input_ids[:,0] = self.tokenizer.enc_token_id

        if train:
            '''
            n: number of answers for each question
            weights: weight for each answer
            '''
            answer = self.tokenizer(answer, padding='longest', return_tensors="pt").to(image.device)
            answer.input_ids[:,0] = self.tokenizer.bos_token_id
            answer_targets = answer.input_ids.masked_fill(answer.input_ids == self.tokenizer.pad_token_id, -100)

            question_output = self.text_encoder(question.input_ids,
                                                attention_mask = question.attention_mask,
                                                encoder_hidden_states = image_embeds,
                                                encoder_attention_mask = image_atts,
                                                return_dict = True)

            question_states = []
            question_atts = []
            for b, n in enumerate(n):
                question_states += [question_output.last_hidden_state[b]]*n
                question_atts += [question.attention_mask[b]]*n
            question_states = torch.stack(question_states,0)
            question_atts = torch.stack(question_atts,0)

            answer_output = self.text_decoder(answer.input_ids,
                                              attention_mask = answer.attention_mask,
                                              encoder_hidden_states = question_states,
                                              encoder_attention_mask = question_atts,
                                              labels = answer_targets,
                                              return_dict = True,
                                              reduction = 'none',
                                             )

            loss = weights * answer_output.loss
            loss = loss.sum()/image.size(0)

            return loss


        else:
            question_output = self.text_encoder(question.input_ids,
                                                attention_mask = question.attention_mask,
                                                encoder_hidden_states = image_embeds,
                                                encoder_attention_mask = image_atts,
                                                return_dict = True)

            if inference=='generate':
                num_beams = 3
                question_states = question_output.last_hidden_state.repeat_interleave(num_beams,dim=0)
                question_atts = torch.ones(question_states.size()[:-1],dtype=torch.long).to(question_states.device)
                model_kwargs = {"encoder_hidden_states": question_states, "encoder_attention_mask":question_atts}

                bos_ids = torch.full((image.size(0),1),fill_value=self.tokenizer.bos_token_id,device=image.device)

                outputs = self.text_decoder.generate(input_ids=bos_ids,
                                                     max_length=10,
                                                     min_length=1,
                                                     num_beams=num_beams,
                                                     eos_token_id=self.tokenizer.sep_token_id,
                                                     pad_token_id=self.tokenizer.pad_token_id,
                                                     **model_kwargs)

                answers = []
                for output in outputs:
                    answer = self.tokenizer.decode(output, skip_special_tokens=True)
                    answers.append(answer)
                return answers

            elif inference=='rank':
                max_ids = self.rank_answer(question_output.last_hidden_state, question.attention_mask,
                                           answer.input_ids, answer.attention_mask, k_test)
                return max_ids



    def rank_answer(self, question_states, question_atts, answer_ids, answer_atts, k):

        num_ques = question_states.size(0)
        start_ids = answer_ids[0,0].repeat(num_ques,1) # bos token

        start_output = self.text_decoder(start_ids,
                                         encoder_hidden_states = question_states,
                                         encoder_attention_mask = question_atts,
                                         return_dict = True,
                                         reduction = 'none')
        logits = start_output.logits[:,0,:] # first token's logit

        # topk_probs: top-k probability
        # topk_ids: [num_question, k]
        answer_first_token = answer_ids[:,1]
        prob_first_token = F.softmax(logits,dim=1).index_select(dim=1, index=answer_first_token)
        topk_probs, topk_ids = prob_first_token.topk(k,dim=1)

        # answer input: [num_question*k, answer_len]
        input_ids = []
        input_atts = []
        for b, topk_id in enumerate(topk_ids):
            input_ids.append(answer_ids.index_select(dim=0, index=topk_id))
            input_atts.append(answer_atts.index_select(dim=0, index=topk_id))
        input_ids = torch.cat(input_ids,dim=0)
        input_atts = torch.cat(input_atts,dim=0)

        targets_ids = input_ids.masked_fill(input_ids == self.tokenizer.pad_token_id, -100)

        # repeat encoder's output for top-k answers
        question_states = tile(question_states, 0, k)
        question_atts = tile(question_atts, 0, k)

        output = self.text_decoder(input_ids,
                                   attention_mask = input_atts,
                                   encoder_hidden_states = question_states,
                                   encoder_attention_mask = question_atts,
                                   labels = targets_ids,
                                   return_dict = True,
                                   reduction = 'none')

        log_probs_sum = -output.loss
        log_probs_sum = log_probs_sum.view(num_ques,k)

        max_topk_ids = log_probs_sum.argmax(dim=1)
        max_ids = topk_ids[max_topk_ids>=0,max_topk_ids]

        return max_ids


def blip_vqa(pretrained='',**kwargs):
    model = BLIP_VQA(**kwargs)
    if pretrained:
        model,msg = load_checkpoint(model,pretrained)
        # assert(len(msg.missing_keys)==0)
    return model


def tile(x, dim, n_tile):
    init_dim = x.size(dim)
    repeat_idx = [1] * x.dim()
    repeat_idx[dim] = n_tile
    x = x.repeat(*(repeat_idx))
    order_index = torch.LongTensor(np.concatenate([init_dim * np.arange(n_tile) + i for i in range(init_dim)]))
    return torch.index_select(x, dim, order_index.to(x.device))
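Note: in `rank_answer`, `tile(question_states, 0, k)` duplicates each question's encoder states once per candidate answer. As far as I can tell the helper behaves like `torch.repeat_interleave` along the chosen dimension; a quick self-contained check of that semantics (this equivalence is an observation, not part of the repository code):

import torch

q = torch.tensor([[1, 1], [2, 2]])
# each row repeated 3 times, copies of the same row kept adjacent
expected = torch.tensor([[1, 1], [1, 1], [1, 1], [2, 2], [2, 2], [2, 2]])
assert torch.equal(q.repeat_interleave(3, dim=0), expected)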
defake/blipmodels/blipconfig/bert_config.json
ADDED
@@ -0,0 +1,21 @@
{
  "architectures": [
    "BertModel"
  ],
  "attention_probs_dropout_prob": 0.1,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 768,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "layer_norm_eps": 1e-12,
  "max_position_embeddings": 512,
  "model_type": "bert",
  "num_attention_heads": 12,
  "num_hidden_layers": 12,
  "pad_token_id": 0,
  "type_vocab_size": 2,
  "vocab_size": 30522,
  "encoder_width": 768,
  "add_cross_attention": true
}
defake/blipmodels/blipconfig/caption_coco.yaml
ADDED
@@ -0,0 +1,33 @@
image_root: '/export/share/datasets/vision/coco/images/'
ann_root: 'annotation'
coco_gt_root: 'annotation/coco_gt'

# set pretrained as a file path or an url
pretrained: 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_caption_capfilt_large.pth'

# size of vit model; base or large
vit: 'base'
vit_grad_ckpt: False
vit_ckpt_layer: 0
batch_size: 32
init_lr: 1e-5

# vit: 'large'
# vit_grad_ckpt: True
# vit_ckpt_layer: 5
# batch_size: 16
# init_lr: 2e-6

image_size: 384

# generation configs
max_length: 20
min_length: 5
num_beams: 3
prompt: 'a picture of '

# optimizer
weight_decay: 0.05
min_lr: 0
max_epoch: 5
defake/blipmodels/blipconfig/med_config.json
ADDED
@@ -0,0 +1,21 @@
{
  "architectures": [
    "BertModel"
  ],
  "attention_probs_dropout_prob": 0.1,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 768,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "layer_norm_eps": 1e-12,
  "max_position_embeddings": 512,
  "model_type": "bert",
  "num_attention_heads": 12,
  "num_hidden_layers": 12,
  "pad_token_id": 0,
  "type_vocab_size": 2,
  "vocab_size": 30524,
  "encoder_width": 768,
  "add_cross_attention": true
}
defake/blipmodels/blipconfig/nlvr.yaml
ADDED
@@ -0,0 +1,21 @@
image_root: '/export/share/datasets/vision/NLVR2/'
ann_root: 'annotation'

# set pretrained as a file path or an url
pretrained: 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_nlvr.pth'

#size of vit model; base or large
vit: 'base'
batch_size_train: 16
batch_size_test: 64
vit_grad_ckpt: False
vit_ckpt_layer: 0
max_epoch: 15

image_size: 384

# optimizer
weight_decay: 0.05
init_lr: 3e-5
min_lr: 0
defake/blipmodels/blipconfig/nocaps.yaml
ADDED
@@ -0,0 +1,15 @@
image_root: '/export/share/datasets/vision/nocaps/'
ann_root: 'annotation'

# set pretrained as a file path or an url
pretrained: 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_caption_capfilt_large.pth'

vit: 'base'
batch_size: 32

image_size: 384

max_length: 20
min_length: 5
num_beams: 3
prompt: 'a picture of '
defake/blipmodels/blipconfig/pretrain.yaml
ADDED
@@ -0,0 +1,27 @@
train_file: ['/export/share/junnan-li/VL_pretrain/annotation/coco_karpathy_train.json',
             '/export/share/junnan-li/VL_pretrain/annotation/vg_caption.json',
            ]
laion_path: ''

# size of vit model; base or large
vit: 'base'
vit_grad_ckpt: False
vit_ckpt_layer: 0

image_size: 224
batch_size: 75

queue_size: 57600
alpha: 0.4

# optimizer
weight_decay: 0.05
init_lr: 3e-4
min_lr: 1e-6
warmup_lr: 1e-6
lr_decay_rate: 0.9
max_epoch: 20
warmup_steps: 3000
defake/blipmodels/blipconfig/retrieval_coco.yaml
ADDED
@@ -0,0 +1,34 @@
image_root: '/export/share/datasets/vision/coco/images/'
ann_root: 'annotation'
dataset: 'coco'

# set pretrained as a file path or an url
pretrained: 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_retrieval_coco.pth'

# size of vit model; base or large

vit: 'base'
batch_size_train: 32
batch_size_test: 64
vit_grad_ckpt: True
vit_ckpt_layer: 4
init_lr: 1e-5

# vit: 'large'
# batch_size_train: 16
# batch_size_test: 32
# vit_grad_ckpt: True
# vit_ckpt_layer: 12
# init_lr: 5e-6

image_size: 384
queue_size: 57600
alpha: 0.4
k_test: 256
negative_all_rank: True

# optimizer
weight_decay: 0.05
min_lr: 0
max_epoch: 6
defake/blipmodels/blipconfig/retrieval_flickr.yaml
ADDED
@@ -0,0 +1,34 @@
image_root: '/export/share/datasets/vision/flickr30k/'
ann_root: 'annotation'
dataset: 'flickr'

# set pretrained as a file path or an url
pretrained: 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_retrieval_flickr.pth'

# size of vit model; base or large

vit: 'base'
batch_size_train: 32
batch_size_test: 64
vit_grad_ckpt: True
vit_ckpt_layer: 4
init_lr: 1e-5

# vit: 'large'
# batch_size_train: 16
# batch_size_test: 32
# vit_grad_ckpt: True
# vit_ckpt_layer: 10
# init_lr: 5e-6

image_size: 384
queue_size: 57600
alpha: 0.4
k_test: 128
negative_all_rank: False

# optimizer
weight_decay: 0.05
min_lr: 0
max_epoch: 6
defake/blipmodels/blipconfig/retrieval_msrvtt.yaml
ADDED
@@ -0,0 +1,12 @@
video_root: '/export/share/dongxuli/data/msrvtt_retrieval/videos'
ann_root: 'annotation'

# set pretrained as a file path or an url
pretrained: 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_retrieval_coco.pth'

# size of vit model; base or large
vit: 'base'
batch_size: 64
k_test: 128
image_size: 384
num_frm_test: 8
defake/blipmodels/blipconfig/vqa.yaml
ADDED
@@ -0,0 +1,25 @@
vqa_root: '/export/share/datasets/vision/VQA/Images/mscoco/' #followed by train2014/
vg_root: '/export/share/datasets/vision/visual-genome/' #followed by image/
train_files: ['vqa_train','vqa_val','vg_qa']
ann_root: 'annotation'

# set pretrained as a file path or an url
pretrained: 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_vqa_capfilt_large.pth'

# size of vit model; base or large
vit: 'base'
batch_size_train: 16
batch_size_test: 32
vit_grad_ckpt: False
vit_ckpt_layer: 0
init_lr: 2e-5

image_size: 480

k_test: 128
inference: 'rank'

# optimizer
weight_decay: 0.05
min_lr: 0
max_epoch: 10
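Note: the YAML files above are plain key/value configs consumed by the upstream BLIP training and evaluation scripts. A minimal sketch of loading one of them, assuming PyYAML is installed and the repo layout shown in this commit (upstream BLIP itself reads them with ruamel.yaml; either loader works for these keys):

import yaml  # PyYAML; an assumption for this sketch

with open('defake/blipmodels/blipconfig/vqa.yaml') as f:
    config = yaml.safe_load(f)

print(config['image_size'])   # 480
print(config['train_files'])  # ['vqa_train', 'vqa_val', 'vg_qa']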
defake/blipmodels/med.py
ADDED
@@ -0,0 +1,955 @@
| 1 |
+
'''
|
| 2 |
+
* Copyright (c) 2022, salesforce.com, inc.
|
| 3 |
+
* All rights reserved.
|
| 4 |
+
* SPDX-License-Identifier: BSD-3-Clause
|
| 5 |
+
* For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause
|
| 6 |
+
* By Junnan Li
|
| 7 |
+
* Based on huggingface code base
|
| 8 |
+
* https://github.com/huggingface/transformers/blob/v4.15.0/src/transformers/models/bert
|
| 9 |
+
'''
|
| 10 |
+
|
| 11 |
+
import math
|
| 12 |
+
import os
|
| 13 |
+
import warnings
|
| 14 |
+
from dataclasses import dataclass
|
| 15 |
+
from typing import Optional, Tuple
|
| 16 |
+
|
| 17 |
+
import torch
|
| 18 |
+
from torch import Tensor, device, dtype, nn
|
| 19 |
+
import torch.utils.checkpoint
|
| 20 |
+
from torch import nn
|
| 21 |
+
from torch.nn import CrossEntropyLoss
|
| 22 |
+
import torch.nn.functional as F
|
| 23 |
+
|
| 24 |
+
from transformers.activations import ACT2FN
|
| 25 |
+
from transformers.file_utils import (
|
| 26 |
+
ModelOutput,
|
| 27 |
+
)
|
| 28 |
+
from transformers.modeling_outputs import (
|
| 29 |
+
BaseModelOutputWithPastAndCrossAttentions,
|
| 30 |
+
BaseModelOutputWithPoolingAndCrossAttentions,
|
| 31 |
+
CausalLMOutputWithCrossAttentions,
|
| 32 |
+
MaskedLMOutput,
|
| 33 |
+
MultipleChoiceModelOutput,
|
| 34 |
+
NextSentencePredictorOutput,
|
| 35 |
+
QuestionAnsweringModelOutput,
|
| 36 |
+
SequenceClassifierOutput,
|
| 37 |
+
TokenClassifierOutput,
|
| 38 |
+
)
|
| 39 |
+
from transformers.modeling_utils import (
|
| 40 |
+
PreTrainedModel,
|
| 41 |
+
apply_chunking_to_forward,
|
| 42 |
+
find_pruneable_heads_and_indices,
|
| 43 |
+
prune_linear_layer,
|
| 44 |
+
)
|
| 45 |
+
from transformers.utils import logging
|
| 46 |
+
from transformers.models.bert.configuration_bert import BertConfig
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
logger = logging.get_logger(__name__)
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
class BertEmbeddings(nn.Module):
|
| 53 |
+
"""Construct the embeddings from word and position embeddings."""
|
| 54 |
+
|
| 55 |
+
def __init__(self, config):
|
| 56 |
+
super().__init__()
|
| 57 |
+
self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
|
| 58 |
+
self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
|
| 59 |
+
|
| 60 |
+
# self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
|
| 61 |
+
# any TensorFlow checkpoint file
|
| 62 |
+
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
| 63 |
+
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
| 64 |
+
|
| 65 |
+
# position_ids (1, len position emb) is contiguous in memory and exported when serialized
|
| 66 |
+
self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
|
| 67 |
+
self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
|
| 68 |
+
|
| 69 |
+
self.config = config
|
| 70 |
+
|
| 71 |
+
def forward(
|
| 72 |
+
self, input_ids=None, position_ids=None, inputs_embeds=None, past_key_values_length=0
|
| 73 |
+
):
|
| 74 |
+
if input_ids is not None:
|
| 75 |
+
input_shape = input_ids.size()
|
| 76 |
+
else:
|
| 77 |
+
input_shape = inputs_embeds.size()[:-1]
|
| 78 |
+
|
| 79 |
+
seq_length = input_shape[1]
|
| 80 |
+
|
| 81 |
+
if position_ids is None:
|
| 82 |
+
position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length]
|
| 83 |
+
|
| 84 |
+
if inputs_embeds is None:
|
| 85 |
+
inputs_embeds = self.word_embeddings(input_ids)
|
| 86 |
+
|
| 87 |
+
embeddings = inputs_embeds
|
| 88 |
+
|
| 89 |
+
if self.position_embedding_type == "absolute":
|
| 90 |
+
position_embeddings = self.position_embeddings(position_ids)
|
| 91 |
+
embeddings += position_embeddings
|
| 92 |
+
embeddings = self.LayerNorm(embeddings)
|
| 93 |
+
embeddings = self.dropout(embeddings)
|
| 94 |
+
return embeddings
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
class BertSelfAttention(nn.Module):
|
| 98 |
+
def __init__(self, config, is_cross_attention):
|
| 99 |
+
super().__init__()
|
| 100 |
+
self.config = config
|
| 101 |
+
if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
|
| 102 |
+
raise ValueError(
|
| 103 |
+
"The hidden size (%d) is not a multiple of the number of attention "
|
| 104 |
+
"heads (%d)" % (config.hidden_size, config.num_attention_heads)
|
| 105 |
+
)
|
| 106 |
+
|
| 107 |
+
self.num_attention_heads = config.num_attention_heads
|
| 108 |
+
self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
|
| 109 |
+
self.all_head_size = self.num_attention_heads * self.attention_head_size
|
| 110 |
+
|
| 111 |
+
self.query = nn.Linear(config.hidden_size, self.all_head_size)
|
| 112 |
+
if is_cross_attention:
|
| 113 |
+
self.key = nn.Linear(config.encoder_width, self.all_head_size)
|
| 114 |
+
self.value = nn.Linear(config.encoder_width, self.all_head_size)
|
| 115 |
+
else:
|
| 116 |
+
self.key = nn.Linear(config.hidden_size, self.all_head_size)
|
| 117 |
+
self.value = nn.Linear(config.hidden_size, self.all_head_size)
|
| 118 |
+
|
| 119 |
+
self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
|
| 120 |
+
self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
|
| 121 |
+
if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
|
| 122 |
+
self.max_position_embeddings = config.max_position_embeddings
|
| 123 |
+
self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)
|
| 124 |
+
self.save_attention = False
|
| 125 |
+
|
| 126 |
+
def save_attn_gradients(self, attn_gradients):
|
| 127 |
+
self.attn_gradients = attn_gradients
|
| 128 |
+
|
| 129 |
+
def get_attn_gradients(self):
|
| 130 |
+
return self.attn_gradients
|
| 131 |
+
|
| 132 |
+
def save_attention_map(self, attention_map):
|
| 133 |
+
self.attention_map = attention_map
|
| 134 |
+
|
| 135 |
+
def get_attention_map(self):
|
| 136 |
+
return self.attention_map
|
| 137 |
+
|
| 138 |
+
def transpose_for_scores(self, x):
|
| 139 |
+
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
|
| 140 |
+
x = x.view(*new_x_shape)
|
| 141 |
+
return x.permute(0, 2, 1, 3)
|
| 142 |
+
|
| 143 |
+
def forward(
|
| 144 |
+
self,
|
| 145 |
+
hidden_states,
|
| 146 |
+
attention_mask=None,
|
| 147 |
+
head_mask=None,
|
| 148 |
+
encoder_hidden_states=None,
|
| 149 |
+
encoder_attention_mask=None,
|
| 150 |
+
past_key_value=None,
|
| 151 |
+
output_attentions=False,
|
| 152 |
+
):
|
| 153 |
+
mixed_query_layer = self.query(hidden_states)
|
| 154 |
+
|
| 155 |
+
# If this is instantiated as a cross-attention module, the keys
|
| 156 |
+
# and values come from an encoder; the attention mask needs to be
|
| 157 |
+
# such that the encoder's padding tokens are not attended to.
|
| 158 |
+
is_cross_attention = encoder_hidden_states is not None
|
| 159 |
+
|
| 160 |
+
if is_cross_attention:
|
| 161 |
+
key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))
|
| 162 |
+
value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))
|
| 163 |
+
attention_mask = encoder_attention_mask
|
| 164 |
+
elif past_key_value is not None:
|
| 165 |
+
key_layer = self.transpose_for_scores(self.key(hidden_states))
|
| 166 |
+
value_layer = self.transpose_for_scores(self.value(hidden_states))
|
| 167 |
+
key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
|
| 168 |
+
value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
|
| 169 |
+
else:
|
| 170 |
+
key_layer = self.transpose_for_scores(self.key(hidden_states))
|
| 171 |
+
value_layer = self.transpose_for_scores(self.value(hidden_states))
|
| 172 |
+
|
| 173 |
+
query_layer = self.transpose_for_scores(mixed_query_layer)
|
| 174 |
+
|
| 175 |
+
past_key_value = (key_layer, value_layer)
|
| 176 |
+
|
| 177 |
+
# Take the dot product between "query" and "key" to get the raw attention scores.
|
| 178 |
+
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
|
| 179 |
+
|
| 180 |
+
if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
|
| 181 |
+
seq_length = hidden_states.size()[1]
|
| 182 |
+
position_ids_l = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)
|
| 183 |
+
position_ids_r = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(1, -1)
|
| 184 |
+
distance = position_ids_l - position_ids_r
|
| 185 |
+
positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1)
|
| 186 |
+
positional_embedding = positional_embedding.to(dtype=query_layer.dtype) # fp16 compatibility
|
| 187 |
+
|
| 188 |
+
if self.position_embedding_type == "relative_key":
|
| 189 |
+
relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
|
| 190 |
+
attention_scores = attention_scores + relative_position_scores
|
| 191 |
+
elif self.position_embedding_type == "relative_key_query":
|
| 192 |
+
relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
|
| 193 |
+
relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding)
|
| 194 |
+
attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key
|
| 195 |
+
|
| 196 |
+
attention_scores = attention_scores / math.sqrt(self.attention_head_size)
|
| 197 |
+
if attention_mask is not None:
|
| 198 |
+
# Apply the attention mask is (precomputed for all layers in BertModel forward() function)
|
| 199 |
+
attention_scores = attention_scores + attention_mask
|
| 200 |
+
|
| 201 |
+
# Normalize the attention scores to probabilities.
|
| 202 |
+
attention_probs = nn.Softmax(dim=-1)(attention_scores)
|
| 203 |
+
|
| 204 |
+
if is_cross_attention and self.save_attention:
|
| 205 |
+
self.save_attention_map(attention_probs)
|
| 206 |
+
attention_probs.register_hook(self.save_attn_gradients)
|
| 207 |
+
|
| 208 |
+
# This is actually dropping out entire tokens to attend to, which might
|
| 209 |
+
# seem a bit unusual, but is taken from the original Transformer paper.
|
| 210 |
+
attention_probs_dropped = self.dropout(attention_probs)
|
| 211 |
+
|
| 212 |
+
# Mask heads if we want to
|
| 213 |
+
if head_mask is not None:
|
| 214 |
+
attention_probs_dropped = attention_probs_dropped * head_mask
|
| 215 |
+
|
| 216 |
+
context_layer = torch.matmul(attention_probs_dropped, value_layer)
|
| 217 |
+
|
| 218 |
+
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
|
| 219 |
+
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
|
| 220 |
+
context_layer = context_layer.view(*new_context_layer_shape)
|
| 221 |
+
|
| 222 |
+
outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
|
| 223 |
+
|
| 224 |
+
outputs = outputs + (past_key_value,)
|
| 225 |
+
return outputs
|
| 226 |
+
|
| 227 |
+
|
| 228 |
+
class BertSelfOutput(nn.Module):
|
| 229 |
+
def __init__(self, config):
|
| 230 |
+
super().__init__()
|
| 231 |
+
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
|
| 232 |
+
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
| 233 |
+
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
| 234 |
+
|
| 235 |
+
def forward(self, hidden_states, input_tensor):
|
| 236 |
+
hidden_states = self.dense(hidden_states)
|
| 237 |
+
hidden_states = self.dropout(hidden_states)
|
| 238 |
+
hidden_states = self.LayerNorm(hidden_states + input_tensor)
|
| 239 |
+
return hidden_states
|
| 240 |
+
|
| 241 |
+
|
| 242 |
+
class BertAttention(nn.Module):
|
| 243 |
+
def __init__(self, config, is_cross_attention=False):
|
| 244 |
+
super().__init__()
|
| 245 |
+
self.self = BertSelfAttention(config, is_cross_attention)
|
| 246 |
+
self.output = BertSelfOutput(config)
|
| 247 |
+
self.pruned_heads = set()
|
| 248 |
+
|
| 249 |
+
def prune_heads(self, heads):
|
| 250 |
+
if len(heads) == 0:
|
| 251 |
+
return
|
| 252 |
+
heads, index = find_pruneable_heads_and_indices(
|
| 253 |
+
heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads
|
| 254 |
+
)
|
| 255 |
+
|
| 256 |
+
# Prune linear layers
|
| 257 |
+
self.self.query = prune_linear_layer(self.self.query, index)
|
| 258 |
+
self.self.key = prune_linear_layer(self.self.key, index)
|
| 259 |
+
self.self.value = prune_linear_layer(self.self.value, index)
|
| 260 |
+
self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
|
| 261 |
+
|
| 262 |
+
# Update hyper params and store pruned heads
|
| 263 |
+
self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
|
| 264 |
+
self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
|
| 265 |
+
self.pruned_heads = self.pruned_heads.union(heads)
|
| 266 |
+
|
| 267 |
+
def forward(
|
| 268 |
+
self,
|
| 269 |
+
hidden_states,
|
| 270 |
+
attention_mask=None,
|
| 271 |
+
head_mask=None,
|
| 272 |
+
encoder_hidden_states=None,
|
| 273 |
+
encoder_attention_mask=None,
|
| 274 |
+
past_key_value=None,
|
| 275 |
+
output_attentions=False,
|
| 276 |
+
):
|
| 277 |
+
self_outputs = self.self(
|
| 278 |
+
hidden_states,
|
| 279 |
+
attention_mask,
|
| 280 |
+
head_mask,
|
| 281 |
+
encoder_hidden_states,
|
| 282 |
+
encoder_attention_mask,
|
| 283 |
+
past_key_value,
|
| 284 |
+
output_attentions,
|
| 285 |
+
)
|
| 286 |
+
attention_output = self.output(self_outputs[0], hidden_states)
|
| 287 |
+
outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
|
| 288 |
+
return outputs
|
| 289 |
+
|
| 290 |
+
|
| 291 |
+
class BertIntermediate(nn.Module):
|
| 292 |
+
def __init__(self, config):
|
| 293 |
+
super().__init__()
|
| 294 |
+
self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
|
| 295 |
+
if isinstance(config.hidden_act, str):
|
| 296 |
+
self.intermediate_act_fn = ACT2FN[config.hidden_act]
|
| 297 |
+
else:
|
| 298 |
+
self.intermediate_act_fn = config.hidden_act
|
| 299 |
+
|
| 300 |
+
def forward(self, hidden_states):
|
| 301 |
+
hidden_states = self.dense(hidden_states)
|
| 302 |
+
hidden_states = self.intermediate_act_fn(hidden_states)
|
| 303 |
+
return hidden_states
|
| 304 |
+
|
| 305 |
+
|
| 306 |
+
class BertOutput(nn.Module):
|
| 307 |
+
def __init__(self, config):
|
| 308 |
+
super().__init__()
|
| 309 |
+
self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
|
| 310 |
+
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
| 311 |
+
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
| 312 |
+
|
| 313 |
+
def forward(self, hidden_states, input_tensor):
|
| 314 |
+
hidden_states = self.dense(hidden_states)
|
| 315 |
+
hidden_states = self.dropout(hidden_states)
|
| 316 |
+
hidden_states = self.LayerNorm(hidden_states + input_tensor)
|
| 317 |
+
return hidden_states
|
| 318 |
+
|
| 319 |
+
|
| 320 |
+
class BertLayer(nn.Module):
|
| 321 |
+
def __init__(self, config, layer_num):
|
| 322 |
+
super().__init__()
|
| 323 |
+
self.config = config
|
| 324 |
+
self.chunk_size_feed_forward = config.chunk_size_feed_forward
|
| 325 |
+
self.seq_len_dim = 1
|
| 326 |
+
self.attention = BertAttention(config)
|
| 327 |
+
self.layer_num = layer_num
|
| 328 |
+
if self.config.add_cross_attention:
|
| 329 |
+
self.crossattention = BertAttention(config, is_cross_attention=self.config.add_cross_attention)
|
| 330 |
+
self.intermediate = BertIntermediate(config)
|
| 331 |
+
self.output = BertOutput(config)
|
| 332 |
+
|
| 333 |
+
def forward(
|
| 334 |
+
self,
|
| 335 |
+
hidden_states,
|
| 336 |
+
attention_mask=None,
|
| 337 |
+
head_mask=None,
|
| 338 |
+
encoder_hidden_states=None,
|
| 339 |
+
encoder_attention_mask=None,
|
| 340 |
+
past_key_value=None,
|
| 341 |
+
output_attentions=False,
|
| 342 |
+
mode=None,
|
| 343 |
+
):
|
| 344 |
+
# decoder uni-directional self-attention cached key/values tuple is at positions 1,2
|
| 345 |
+
self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
|
| 346 |
+
self_attention_outputs = self.attention(
|
| 347 |
+
hidden_states,
|
| 348 |
+
attention_mask,
|
| 349 |
+
head_mask,
|
| 350 |
+
output_attentions=output_attentions,
|
| 351 |
+
past_key_value=self_attn_past_key_value,
|
| 352 |
+
)
|
| 353 |
+
attention_output = self_attention_outputs[0]
|
| 354 |
+
|
| 355 |
+
outputs = self_attention_outputs[1:-1]
|
| 356 |
+
present_key_value = self_attention_outputs[-1]
|
| 357 |
+
|
| 358 |
+
if mode=='multimodal':
|
| 359 |
+
assert encoder_hidden_states is not None, "encoder_hidden_states must be given for cross-attention layers"
|
| 360 |
+
|
| 361 |
+
cross_attention_outputs = self.crossattention(
|
| 362 |
+
attention_output,
|
| 363 |
+
attention_mask,
|
| 364 |
+
head_mask,
|
| 365 |
+
encoder_hidden_states,
|
| 366 |
+
encoder_attention_mask,
|
| 367 |
+
output_attentions=output_attentions,
|
| 368 |
+
)
|
| 369 |
+
attention_output = cross_attention_outputs[0]
|
| 370 |
+
outputs = outputs + cross_attention_outputs[1:-1] # add cross attentions if we output attention weights
|
| 371 |
+
layer_output = apply_chunking_to_forward(
|
| 372 |
+
self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output
|
| 373 |
+
)
|
| 374 |
+
outputs = (layer_output,) + outputs
|
| 375 |
+
|
| 376 |
+
outputs = outputs + (present_key_value,)
|
| 377 |
+
|
| 378 |
+
return outputs
|
| 379 |
+
|
| 380 |
+
def feed_forward_chunk(self, attention_output):
|
| 381 |
+
intermediate_output = self.intermediate(attention_output)
|
| 382 |
+
layer_output = self.output(intermediate_output, attention_output)
|
| 383 |
+
return layer_output
|
| 384 |
+
|
| 385 |
+
|
| 386 |
+
class BertEncoder(nn.Module):
|
| 387 |
+
def __init__(self, config):
|
| 388 |
+
super().__init__()
|
| 389 |
+
self.config = config
|
| 390 |
+
self.layer = nn.ModuleList([BertLayer(config,i) for i in range(config.num_hidden_layers)])
|
| 391 |
+
self.gradient_checkpointing = False
|
| 392 |
+
|
| 393 |
+
def forward(
|
| 394 |
+
self,
|
| 395 |
+
hidden_states,
|
| 396 |
+
attention_mask=None,
|
| 397 |
+
head_mask=None,
|
| 398 |
+
encoder_hidden_states=None,
|
| 399 |
+
encoder_attention_mask=None,
|
| 400 |
+
past_key_values=None,
|
| 401 |
+
use_cache=None,
|
| 402 |
+
output_attentions=False,
|
| 403 |
+
output_hidden_states=False,
|
| 404 |
+
return_dict=True,
|
| 405 |
+
mode='multimodal',
|
| 406 |
+
):
|
| 407 |
+
all_hidden_states = () if output_hidden_states else None
|
| 408 |
+
all_self_attentions = () if output_attentions else None
|
| 409 |
+
all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
|
| 410 |
+
|
| 411 |
+
next_decoder_cache = () if use_cache else None
|
| 412 |
+
|
| 413 |
+
for i in range(self.config.num_hidden_layers):
|
| 414 |
+
layer_module = self.layer[i]
|
| 415 |
+
if output_hidden_states:
|
| 416 |
+
all_hidden_states = all_hidden_states + (hidden_states,)
|
| 417 |
+
|
| 418 |
+
layer_head_mask = head_mask[i] if head_mask is not None else None
|
| 419 |
+
past_key_value = past_key_values[i] if past_key_values is not None else None
|
| 420 |
+
|
| 421 |
+
if self.gradient_checkpointing and self.training:
|
| 422 |
+
|
| 423 |
+
if use_cache:
|
| 424 |
+
logger.warn(
|
| 425 |
+
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
|
| 426 |
+
)
|
| 427 |
+
use_cache = False
|
| 428 |
+
|
| 429 |
+
def create_custom_forward(module):
|
| 430 |
+
def custom_forward(*inputs):
|
| 431 |
+
return module(*inputs, past_key_value, output_attentions)
|
| 432 |
+
|
| 433 |
+
return custom_forward
|
| 434 |
+
|
| 435 |
+
layer_outputs = torch.utils.checkpoint.checkpoint(
|
| 436 |
+
create_custom_forward(layer_module),
|
| 437 |
+
hidden_states,
|
| 438 |
+
attention_mask,
|
| 439 |
+
layer_head_mask,
|
| 440 |
+
encoder_hidden_states,
|
| 441 |
+
encoder_attention_mask,
|
| 442 |
+
mode=mode,
|
| 443 |
+
)
|
| 444 |
+
else:
|
| 445 |
+
layer_outputs = layer_module(
|
| 446 |
+
hidden_states,
|
| 447 |
+
attention_mask,
|
| 448 |
+
layer_head_mask,
|
| 449 |
+
encoder_hidden_states,
|
| 450 |
+
encoder_attention_mask,
|
| 451 |
+
past_key_value,
|
| 452 |
+
output_attentions,
|
| 453 |
+
mode=mode,
|
| 454 |
+
)
|
| 455 |
+
|
| 456 |
+
hidden_states = layer_outputs[0]
|
| 457 |
+
if use_cache:
|
| 458 |
+
next_decoder_cache += (layer_outputs[-1],)
|
| 459 |
+
if output_attentions:
|
| 460 |
+
all_self_attentions = all_self_attentions + (layer_outputs[1],)
|
| 461 |
+
|
| 462 |
+
if output_hidden_states:
|
| 463 |
+
all_hidden_states = all_hidden_states + (hidden_states,)
|
| 464 |
+
|
| 465 |
+
if not return_dict:
|
| 466 |
+
return tuple(
|
| 467 |
+
v
|
| 468 |
+
for v in [
|
| 469 |
+
hidden_states,
|
| 470 |
+
next_decoder_cache,
|
| 471 |
+
all_hidden_states,
|
| 472 |
+
all_self_attentions,
|
| 473 |
+
all_cross_attentions,
|
| 474 |
+
]
|
| 475 |
+
if v is not None
|
| 476 |
+
)
|
| 477 |
+
return BaseModelOutputWithPastAndCrossAttentions(
|
| 478 |
+
last_hidden_state=hidden_states,
|
| 479 |
+
past_key_values=next_decoder_cache,
|
| 480 |
+
hidden_states=all_hidden_states,
|
| 481 |
+
attentions=all_self_attentions,
|
| 482 |
+
cross_attentions=all_cross_attentions,
|
| 483 |
+
)
|
| 484 |
+
|
| 485 |
+
|
| 486 |
+
class BertPooler(nn.Module):
|
| 487 |
+
def __init__(self, config):
|
| 488 |
+
super().__init__()
|
| 489 |
+
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
|
| 490 |
+
self.activation = nn.Tanh()
|
| 491 |
+
|
| 492 |
+
def forward(self, hidden_states):
|
| 493 |
+
# We "pool" the model by simply taking the hidden state corresponding
|
| 494 |
+
# to the first token.
|
| 495 |
+
first_token_tensor = hidden_states[:, 0]
|
| 496 |
+
pooled_output = self.dense(first_token_tensor)
|
| 497 |
+
pooled_output = self.activation(pooled_output)
|
| 498 |
+
return pooled_output
|
| 499 |
+
|
| 500 |
+
|
| 501 |
+
class BertPredictionHeadTransform(nn.Module):
|
| 502 |
+
def __init__(self, config):
|
| 503 |
+
super().__init__()
|
| 504 |
+
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
|
| 505 |
+
if isinstance(config.hidden_act, str):
|
| 506 |
+
self.transform_act_fn = ACT2FN[config.hidden_act]
|
| 507 |
+
else:
|
| 508 |
+
self.transform_act_fn = config.hidden_act
|
| 509 |
+
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
| 510 |
+
|
| 511 |
+
def forward(self, hidden_states):
|
| 512 |
+
hidden_states = self.dense(hidden_states)
|
| 513 |
+
hidden_states = self.transform_act_fn(hidden_states)
|
| 514 |
+
hidden_states = self.LayerNorm(hidden_states)
|
| 515 |
+
return hidden_states
|
| 516 |
+
|
| 517 |
+
|
| 518 |
+
class BertLMPredictionHead(nn.Module):
|
| 519 |
+
def __init__(self, config):
|
| 520 |
+
super().__init__()
|
| 521 |
+
self.transform = BertPredictionHeadTransform(config)
|
| 522 |
+
|
| 523 |
+
# The output weights are the same as the input embeddings, but there is
|
| 524 |
+
# an output-only bias for each token.
|
| 525 |
+
self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
| 526 |
+
|
| 527 |
+
self.bias = nn.Parameter(torch.zeros(config.vocab_size))
|
| 528 |
+
|
| 529 |
+
# Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
|
| 530 |
+
self.decoder.bias = self.bias
|
| 531 |
+
|
| 532 |
+
def forward(self, hidden_states):
|
| 533 |
+
hidden_states = self.transform(hidden_states)
|
| 534 |
+
hidden_states = self.decoder(hidden_states)
|
| 535 |
+
return hidden_states
|
| 536 |
+
|
| 537 |
+
|
| 538 |
+
class BertOnlyMLMHead(nn.Module):
|
| 539 |
+
def __init__(self, config):
|
| 540 |
+
super().__init__()
|
| 541 |
+
self.predictions = BertLMPredictionHead(config)
|
| 542 |
+
|
| 543 |
+
def forward(self, sequence_output):
|
| 544 |
+
prediction_scores = self.predictions(sequence_output)
|
| 545 |
+
return prediction_scores
|
| 546 |
+
|
| 547 |
+
|
| 548 |
+
class BertPreTrainedModel(PreTrainedModel):
|
| 549 |
+
"""
|
| 550 |
+
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
|
| 551 |
+
models.
|
| 552 |
+
"""
|
| 553 |
+
|
| 554 |
+
config_class = BertConfig
|
| 555 |
+
base_model_prefix = "bert"
|
| 556 |
+
_keys_to_ignore_on_load_missing = [r"position_ids"]
|
| 557 |
+
|
| 558 |
+
def _init_weights(self, module):
|
| 559 |
+
""" Initialize the weights """
|
| 560 |
+
if isinstance(module, (nn.Linear, nn.Embedding)):
|
| 561 |
+
# Slightly different from the TF version which uses truncated_normal for initialization
|
| 562 |
+
# cf https://github.com/pytorch/pytorch/pull/5617
|
| 563 |
+
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
|
| 564 |
+
elif isinstance(module, nn.LayerNorm):
|
| 565 |
+
module.bias.data.zero_()
|
| 566 |
+
module.weight.data.fill_(1.0)
|
| 567 |
+
if isinstance(module, nn.Linear) and module.bias is not None:
|
| 568 |
+
module.bias.data.zero_()
|
| 569 |
+
|
| 570 |
+
|
| 571 |
+
class BertModel(BertPreTrainedModel):
|
| 572 |
+
"""
|
| 573 |
+
The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of
|
| 574 |
+
cross-attention is added between the self-attention layers, following the architecture described in `Attention is
|
| 575 |
+
all you need <https://arxiv.org/abs/1706.03762>`__ by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit,
|
| 576 |
+
Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin.
|
| 577 |
+
To be used as a decoder, the model needs to be initialized with the :obj:`is_decoder` argument and :obj:`add_cross_attention` set to :obj:`True`; an :obj:`encoder_hidden_states` is then expected as an
|
| 578 |
+
input to the forward pass.
|
| 579 |
+
"""
|
| 580 |
+
|
| 581 |
+
def __init__(self, config, add_pooling_layer=True):
|
| 582 |
+
super().__init__(config)
|
| 583 |
+
self.config = config
|
| 584 |
+
|
| 585 |
+
self.embeddings = BertEmbeddings(config)
|
| 586 |
+
|
| 587 |
+
self.encoder = BertEncoder(config)
|
| 588 |
+
|
| 589 |
+
self.pooler = BertPooler(config) if add_pooling_layer else None
|
| 590 |
+
|
| 591 |
+
self.init_weights()
|
| 592 |
+
|
| 593 |
+
|
| 594 |
+
def get_input_embeddings(self):
|
| 595 |
+
return self.embeddings.word_embeddings
|
| 596 |
+
|
| 597 |
+
def set_input_embeddings(self, value):
|
| 598 |
+
self.embeddings.word_embeddings = value
|
| 599 |
+
|
| 600 |
+
def _prune_heads(self, heads_to_prune):
|
| 601 |
+
"""
|
| 602 |
+
Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
|
| 603 |
+
class PreTrainedModel
|
| 604 |
+
"""
|
| 605 |
+
for layer, heads in heads_to_prune.items():
|
| 606 |
+
self.encoder.layer[layer].attention.prune_heads(heads)
|
| 607 |
+
|
| 608 |
+
|
| 609 |
+
def get_extended_attention_mask(self, attention_mask: Tensor, input_shape: Tuple[int], device: device, is_decoder: bool) -> Tensor:
|
| 610 |
+
"""
|
| 611 |
+
Makes broadcastable attention and causal masks so that future and masked tokens are ignored.
|
| 612 |
+
|
| 613 |
+
Arguments:
|
| 614 |
+
attention_mask (:obj:`torch.Tensor`):
|
| 615 |
+
Mask with ones indicating tokens to attend to, zeros for tokens to ignore.
|
| 616 |
+
input_shape (:obj:`Tuple[int]`):
|
| 617 |
+
The shape of the input to the model.
|
| 618 |
+
device: (:obj:`torch.device`):
|
| 619 |
+
The device of the input to the model.
|
| 620 |
+
|
| 621 |
+
Returns:
|
| 622 |
+
:obj:`torch.Tensor` The extended attention mask, with the same dtype as :obj:`attention_mask.dtype`.
|
| 623 |
+
"""
|
| 624 |
+
# We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
|
| 625 |
+
# ourselves in which case we just need to make it broadcastable to all heads.
|
| 626 |
+
if attention_mask.dim() == 3:
|
| 627 |
+
extended_attention_mask = attention_mask[:, None, :, :]
|
| 628 |
+
elif attention_mask.dim() == 2:
|
| 629 |
+
# Provided a padding mask of dimensions [batch_size, seq_length]
|
| 630 |
+
# - if the model is a decoder, apply a causal mask in addition to the padding mask
|
| 631 |
+
# - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length]
|
| 632 |
+
if is_decoder:
|
| 633 |
+
batch_size, seq_length = input_shape
|
| 634 |
+
|
| 635 |
+
seq_ids = torch.arange(seq_length, device=device)
|
| 636 |
+
causal_mask = seq_ids[None, None, :].repeat(batch_size, seq_length, 1) <= seq_ids[None, :, None]
|
| 637 |
+
# in case past_key_values are used we need to add a prefix ones mask to the causal mask
|
| 638 |
+
# causal and attention masks must have same type with pytorch version < 1.3
|
| 639 |
+
causal_mask = causal_mask.to(attention_mask.dtype)
|
| 640 |
+
|
| 641 |
+
if causal_mask.shape[1] < attention_mask.shape[1]:
|
| 642 |
+
prefix_seq_len = attention_mask.shape[1] - causal_mask.shape[1]
|
| 643 |
+
causal_mask = torch.cat(
|
| 644 |
+
[
|
| 645 |
+
torch.ones((batch_size, seq_length, prefix_seq_len), device=device, dtype=causal_mask.dtype),
|
| 646 |
+
causal_mask,
|
| 647 |
+
],
|
| 648 |
+
axis=-1,
|
| 649 |
+
)
|
| 650 |
+
|
| 651 |
+
extended_attention_mask = causal_mask[:, None, :, :] * attention_mask[:, None, None, :]
|
| 652 |
+
else:
|
| 653 |
+
extended_attention_mask = attention_mask[:, None, None, :]
|
| 654 |
+
else:
|
| 655 |
+
raise ValueError(
|
| 656 |
+
"Wrong shape for input_ids (shape {}) or attention_mask (shape {})".format(
|
| 657 |
+
input_shape, attention_mask.shape
|
| 658 |
+
)
|
| 659 |
+
)
|
| 660 |
+
|
| 661 |
+
# Since attention_mask is 1.0 for positions we want to attend and 0.0 for
|
| 662 |
+
# masked positions, this operation will create a tensor which is 0.0 for
|
| 663 |
+
# positions we want to attend and -10000.0 for masked positions.
|
| 664 |
+
# Since we are adding it to the raw scores before the softmax, this is
|
| 665 |
+
# effectively the same as removing these entirely.
|
| 666 |
+
extended_attention_mask = extended_attention_mask.to(dtype=self.dtype) # fp16 compatibility
|
| 667 |
+
extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
|
| 668 |
+
return extended_attention_mask
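
`get_extended_attention_mask` turns a 0/1 padding mask (plus, in decoder mode, a lower-triangular causal mask) into a broadcastable additive mask where allowed positions are 0.0 and blocked positions are -10000.0. A compact sketch of the decoder branch on a toy batch (shapes only; the real method also handles the 3-D and prefix-cache cases):

```python
import torch

batch_size, seq_length = 2, 4
attention_mask = torch.tensor([[1, 1, 1, 1],
                               [1, 1, 0, 0]], dtype=torch.float)  # padding mask

# causal part: token i may attend to tokens <= i
seq_ids = torch.arange(seq_length)
causal = (seq_ids[None, None, :].repeat(batch_size, seq_length, 1)
          <= seq_ids[None, :, None]).float()

# combine causal and padding masks, add a broadcast head dimension
extended = causal[:, None, :, :] * attention_mask[:, None, None, :]
extended = (1.0 - extended) * -10000.0       # 0.0 where allowed, -10000.0 where masked
print(extended.shape)                        # torch.Size([2, 1, 4, 4])
```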
|
| 669 |
+
|
| 670 |
+
def forward(
|
| 671 |
+
self,
|
| 672 |
+
input_ids=None,
|
| 673 |
+
attention_mask=None,
|
| 674 |
+
position_ids=None,
|
| 675 |
+
head_mask=None,
|
| 676 |
+
inputs_embeds=None,
|
| 677 |
+
encoder_embeds=None,
|
| 678 |
+
encoder_hidden_states=None,
|
| 679 |
+
encoder_attention_mask=None,
|
| 680 |
+
past_key_values=None,
|
| 681 |
+
use_cache=None,
|
| 682 |
+
output_attentions=None,
|
| 683 |
+
output_hidden_states=None,
|
| 684 |
+
return_dict=None,
|
| 685 |
+
is_decoder=False,
|
| 686 |
+
mode='multimodal',
|
| 687 |
+
):
|
| 688 |
+
r"""
|
| 689 |
+
encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
|
| 690 |
+
Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
|
| 691 |
+
the model is configured as a decoder.
|
| 692 |
+
encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
|
| 693 |
+
Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
|
| 694 |
+
the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``:
|
| 695 |
+
- 1 for tokens that are **not masked**,
|
| 696 |
+
- 0 for tokens that are **masked**.
|
| 697 |
+
past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
|
| 698 |
+
Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
|
| 699 |
+
If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids`
|
| 700 |
+
(those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`
|
| 701 |
+
instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`.
|
| 702 |
+
use_cache (:obj:`bool`, `optional`):
|
| 703 |
+
If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up
|
| 704 |
+
decoding (see :obj:`past_key_values`).
|
| 705 |
+
"""
|
| 706 |
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
| 707 |
+
output_hidden_states = (
|
| 708 |
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
| 709 |
+
)
|
| 710 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
| 711 |
+
|
| 712 |
+
if is_decoder:
|
| 713 |
+
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
| 714 |
+
else:
|
| 715 |
+
use_cache = False
|
| 716 |
+
|
| 717 |
+
if input_ids is not None and inputs_embeds is not None:
|
| 718 |
+
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
|
| 719 |
+
elif input_ids is not None:
|
| 720 |
+
input_shape = input_ids.size()
|
| 721 |
+
batch_size, seq_length = input_shape
|
| 722 |
+
device = input_ids.device
|
| 723 |
+
elif inputs_embeds is not None:
|
| 724 |
+
input_shape = inputs_embeds.size()[:-1]
|
| 725 |
+
batch_size, seq_length = input_shape
|
| 726 |
+
device = inputs_embeds.device
|
| 727 |
+
elif encoder_embeds is not None:
|
| 728 |
+
input_shape = encoder_embeds.size()[:-1]
|
| 729 |
+
batch_size, seq_length = input_shape
|
| 730 |
+
device = encoder_embeds.device
|
| 731 |
+
else:
|
| 732 |
+
raise ValueError("You have to specify either input_ids or inputs_embeds or encoder_embeds")
|
| 733 |
+
|
| 734 |
+
# past_key_values_length
|
| 735 |
+
past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
|
| 736 |
+
|
| 737 |
+
if attention_mask is None:
|
| 738 |
+
attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device)
|
| 739 |
+
|
| 740 |
+
# We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
|
| 741 |
+
# ourselves in which case we just need to make it broadcastable to all heads.
|
| 742 |
+
extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape,
|
| 743 |
+
device, is_decoder)
|
| 744 |
+
|
| 745 |
+
# If a 2D or 3D attention mask is provided for the cross-attention
|
| 746 |
+
# we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
|
| 747 |
+
if encoder_hidden_states is not None:
|
| 748 |
+
if type(encoder_hidden_states) == list:
|
| 749 |
+
encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states[0].size()
|
| 750 |
+
else:
|
| 751 |
+
encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
|
| 752 |
+
encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
|
| 753 |
+
|
| 754 |
+
if type(encoder_attention_mask) == list:
|
| 755 |
+
encoder_extended_attention_mask = [self.invert_attention_mask(mask) for mask in encoder_attention_mask]
|
| 756 |
+
elif encoder_attention_mask is None:
|
| 757 |
+
encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
|
| 758 |
+
encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
|
| 759 |
+
else:
|
| 760 |
+
encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
|
| 761 |
+
else:
|
| 762 |
+
encoder_extended_attention_mask = None
|
| 763 |
+
|
| 764 |
+
# Prepare head mask if needed
|
| 765 |
+
# 1.0 in head_mask indicate we keep the head
|
| 766 |
+
# attention_probs has shape bsz x n_heads x N x N
|
| 767 |
+
# input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
|
| 768 |
+
# and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
|
| 769 |
+
head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
|
| 770 |
+
|
| 771 |
+
if encoder_embeds is None:
|
| 772 |
+
embedding_output = self.embeddings(
|
| 773 |
+
input_ids=input_ids,
|
| 774 |
+
position_ids=position_ids,
|
| 775 |
+
inputs_embeds=inputs_embeds,
|
| 776 |
+
past_key_values_length=past_key_values_length,
|
| 777 |
+
)
|
| 778 |
+
else:
|
| 779 |
+
embedding_output = encoder_embeds
|
| 780 |
+
|
| 781 |
+
encoder_outputs = self.encoder(
|
| 782 |
+
embedding_output,
|
| 783 |
+
attention_mask=extended_attention_mask,
|
| 784 |
+
head_mask=head_mask,
|
| 785 |
+
encoder_hidden_states=encoder_hidden_states,
|
| 786 |
+
encoder_attention_mask=encoder_extended_attention_mask,
|
| 787 |
+
past_key_values=past_key_values,
|
| 788 |
+
use_cache=use_cache,
|
| 789 |
+
output_attentions=output_attentions,
|
| 790 |
+
output_hidden_states=output_hidden_states,
|
| 791 |
+
return_dict=return_dict,
|
| 792 |
+
mode=mode,
|
| 793 |
+
)
|
| 794 |
+
sequence_output = encoder_outputs[0]
|
| 795 |
+
pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
|
| 796 |
+
|
| 797 |
+
if not return_dict:
|
| 798 |
+
return (sequence_output, pooled_output) + encoder_outputs[1:]
|
| 799 |
+
|
| 800 |
+
return BaseModelOutputWithPoolingAndCrossAttentions(
|
| 801 |
+
last_hidden_state=sequence_output,
|
| 802 |
+
pooler_output=pooled_output,
|
| 803 |
+
past_key_values=encoder_outputs.past_key_values,
|
| 804 |
+
hidden_states=encoder_outputs.hidden_states,
|
| 805 |
+
attentions=encoder_outputs.attentions,
|
| 806 |
+
cross_attentions=encoder_outputs.cross_attentions,
|
| 807 |
+
)
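
A hedged usage sketch for this `BertModel` variant, assuming the file is importable as `defake.blipmodels.med` and that the `blipconfig/med_config.json` added in this commit is available at the path below (both paths are assumptions about the working directory). Text-only encoding of a toy batch:

```python
import torch
from transformers import BertConfig, BertTokenizer
from defake.blipmodels.med import BertModel   # module path as added in this commit

config = BertConfig.from_json_file("defake/blipmodels/blipconfig/med_config.json")
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
model = BertModel(config, add_pooling_layer=False)

batch = tokenizer(["a photo of a cat", "a generated face"],
                  padding=True, return_tensors="pt")
with torch.no_grad():
    out = model(batch.input_ids,
                attention_mask=batch.attention_mask,
                mode="text")                  # skip the cross-attention branch
print(out.last_hidden_state.shape)            # (batch, seq_len, hidden_size)
```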
|
| 808 |
+
|
| 809 |
+
|
| 810 |
+
|
| 811 |
+
class BertLMHeadModel(BertPreTrainedModel):
|
| 812 |
+
|
| 813 |
+
_keys_to_ignore_on_load_unexpected = [r"pooler"]
|
| 814 |
+
_keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]
|
| 815 |
+
|
| 816 |
+
def __init__(self, config):
|
| 817 |
+
super().__init__(config)
|
| 818 |
+
|
| 819 |
+
self.bert = BertModel(config, add_pooling_layer=False)
|
| 820 |
+
self.cls = BertOnlyMLMHead(config)
|
| 821 |
+
|
| 822 |
+
self.init_weights()
|
| 823 |
+
|
| 824 |
+
def get_output_embeddings(self):
|
| 825 |
+
return self.cls.predictions.decoder
|
| 826 |
+
|
| 827 |
+
def set_output_embeddings(self, new_embeddings):
|
| 828 |
+
self.cls.predictions.decoder = new_embeddings
|
| 829 |
+
|
| 830 |
+
def forward(
|
| 831 |
+
self,
|
| 832 |
+
input_ids=None,
|
| 833 |
+
attention_mask=None,
|
| 834 |
+
position_ids=None,
|
| 835 |
+
head_mask=None,
|
| 836 |
+
inputs_embeds=None,
|
| 837 |
+
encoder_hidden_states=None,
|
| 838 |
+
encoder_attention_mask=None,
|
| 839 |
+
labels=None,
|
| 840 |
+
past_key_values=None,
|
| 841 |
+
use_cache=None,
|
| 842 |
+
output_attentions=None,
|
| 843 |
+
output_hidden_states=None,
|
| 844 |
+
return_dict=None,
|
| 845 |
+
return_logits=False,
|
| 846 |
+
is_decoder=True,
|
| 847 |
+
reduction='mean',
|
| 848 |
+
mode='multimodal',
|
| 849 |
+
):
|
| 850 |
+
r"""
|
| 851 |
+
encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
|
| 852 |
+
Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
|
| 853 |
+
the model is configured as a decoder.
|
| 854 |
+
encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
|
| 855 |
+
Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
|
| 856 |
+
the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``:
|
| 857 |
+
- 1 for tokens that are **not masked**,
|
| 858 |
+
- 0 for tokens that are **masked**.
|
| 859 |
+
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
|
| 860 |
+
Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in
|
| 861 |
+
``[-100, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are
|
| 862 |
+
ignored (masked), the loss is only computed for the tokens with labels in ``[0, ..., config.vocab_size]``
|
| 863 |
+
past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
|
| 864 |
+
Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
|
| 865 |
+
If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids`
|
| 866 |
+
(those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`
|
| 867 |
+
instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`.
|
| 868 |
+
use_cache (:obj:`bool`, `optional`):
|
| 869 |
+
If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up
|
| 870 |
+
decoding (see :obj:`past_key_values`).
|
| 871 |
+
Returns:
|
| 872 |
+
Example::
|
| 873 |
+
>>> from transformers import BertTokenizer, BertLMHeadModel, BertConfig
|
| 874 |
+
>>> import torch
|
| 875 |
+
>>> tokenizer = BertTokenizer.from_pretrained('bert-base-cased')
|
| 876 |
+
>>> config = BertConfig.from_pretrained("bert-base-cased")
|
| 877 |
+
>>> model = BertLMHeadModel.from_pretrained('bert-base-cased', config=config)
|
| 878 |
+
>>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
|
| 879 |
+
>>> outputs = model(**inputs)
|
| 880 |
+
>>> prediction_logits = outputs.logits
|
| 881 |
+
"""
|
| 882 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
| 883 |
+
if labels is not None:
|
| 884 |
+
use_cache = False
|
| 885 |
+
|
| 886 |
+
outputs = self.bert(
|
| 887 |
+
input_ids,
|
| 888 |
+
attention_mask=attention_mask,
|
| 889 |
+
position_ids=position_ids,
|
| 890 |
+
head_mask=head_mask,
|
| 891 |
+
inputs_embeds=inputs_embeds,
|
| 892 |
+
encoder_hidden_states=encoder_hidden_states,
|
| 893 |
+
encoder_attention_mask=encoder_attention_mask,
|
| 894 |
+
past_key_values=past_key_values,
|
| 895 |
+
use_cache=use_cache,
|
| 896 |
+
output_attentions=output_attentions,
|
| 897 |
+
output_hidden_states=output_hidden_states,
|
| 898 |
+
return_dict=return_dict,
|
| 899 |
+
is_decoder=is_decoder,
|
| 900 |
+
mode=mode,
|
| 901 |
+
)
|
| 902 |
+
|
| 903 |
+
sequence_output = outputs[0]
|
| 904 |
+
prediction_scores = self.cls(sequence_output)
|
| 905 |
+
|
| 906 |
+
if return_logits:
|
| 907 |
+
return prediction_scores[:, :-1, :].contiguous()
|
| 908 |
+
|
| 909 |
+
lm_loss = None
|
| 910 |
+
if labels is not None:
|
| 911 |
+
# we are doing next-token prediction; shift prediction scores and input ids by one
|
| 912 |
+
shifted_prediction_scores = prediction_scores[:, :-1, :].contiguous()
|
| 913 |
+
labels = labels[:, 1:].contiguous()
|
| 914 |
+
loss_fct = CrossEntropyLoss(reduction=reduction, label_smoothing=0.1)
|
| 915 |
+
lm_loss = loss_fct(shifted_prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
|
| 916 |
+
if reduction=='none':
|
| 917 |
+
lm_loss = lm_loss.view(prediction_scores.size(0),-1).sum(1)
|
| 918 |
+
|
| 919 |
+
if not return_dict:
|
| 920 |
+
output = (prediction_scores,) + outputs[2:]
|
| 921 |
+
return ((lm_loss,) + output) if lm_loss is not None else output
|
| 922 |
+
|
| 923 |
+
return CausalLMOutputWithCrossAttentions(
|
| 924 |
+
loss=lm_loss,
|
| 925 |
+
logits=prediction_scores,
|
| 926 |
+
past_key_values=outputs.past_key_values,
|
| 927 |
+
hidden_states=outputs.hidden_states,
|
| 928 |
+
attentions=outputs.attentions,
|
| 929 |
+
cross_attentions=outputs.cross_attentions,
|
| 930 |
+
)
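
The loss block above implements next-token prediction: logits are dropped for the last position, labels for the first, and `-100` labels are ignored by `CrossEntropyLoss` (here with `label_smoothing=0.1`). A toy reproduction of just that shift, independent of the model:

```python
import torch
from torch.nn import CrossEntropyLoss

vocab_size = 10
prediction_scores = torch.randn(2, 5, vocab_size)          # (batch, seq_len, vocab)
labels = torch.tensor([[3, 4, 1, 2, -100],
                       [7, 8, 9, -100, -100]])              # -100 = ignored

shifted_scores = prediction_scores[:, :-1, :].contiguous()  # predict token t+1 from t
shifted_labels = labels[:, 1:].contiguous()

loss_fct = CrossEntropyLoss(reduction='mean', label_smoothing=0.1)
lm_loss = loss_fct(shifted_scores.view(-1, vocab_size), shifted_labels.view(-1))
print(lm_loss.item())
```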
|
| 931 |
+
|
| 932 |
+
def prepare_inputs_for_generation(self, input_ids, past=None, attention_mask=None, **model_kwargs):
|
| 933 |
+
input_shape = input_ids.shape
|
| 934 |
+
# if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
|
| 935 |
+
if attention_mask is None:
|
| 936 |
+
attention_mask = input_ids.new_ones(input_shape)
|
| 937 |
+
|
| 938 |
+
# cut decoder_input_ids if past is used
|
| 939 |
+
if past is not None:
|
| 940 |
+
input_ids = input_ids[:, -1:]
|
| 941 |
+
|
| 942 |
+
return {
|
| 943 |
+
"input_ids": input_ids,
|
| 944 |
+
"attention_mask": attention_mask,
|
| 945 |
+
"past_key_values": past,
|
| 946 |
+
"encoder_hidden_states": model_kwargs.get("encoder_hidden_states", None),
|
| 947 |
+
"encoder_attention_mask": model_kwargs.get("encoder_attention_mask", None),
|
| 948 |
+
"is_decoder": True,
|
| 949 |
+
}
|
| 950 |
+
|
| 951 |
+
def _reorder_cache(self, past, beam_idx):
|
| 952 |
+
reordered_past = ()
|
| 953 |
+
for layer_past in past:
|
| 954 |
+
reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
|
| 955 |
+
return reordered_past
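
During beam search, `_reorder_cache` re-aligns every cached key/value tensor with the surviving beams via `index_select` on the batch dimension. A minimal illustration with a fake one-layer cache:

```python
import torch

# fake cache: one layer, (key, value), each shaped (batch*beams, heads, seq, head_dim)
kv = torch.arange(4).float().view(4, 1, 1, 1).expand(4, 2, 3, 8)
past = ((kv, kv),)
beam_idx = torch.tensor([2, 2, 0, 1])        # beams selected at this step

reordered = tuple(
    tuple(state.index_select(0, beam_idx) for state in layer_past)
    for layer_past in past
)
print(reordered[0][0][:, 0, 0, 0])           # tensor([2., 2., 0., 1.])
```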
|
defake/blipmodels/nlvr_encoder.py
ADDED
|
@@ -0,0 +1,843 @@
|
| 1 |
+
import math
|
| 2 |
+
import os
|
| 3 |
+
import warnings
|
| 4 |
+
from dataclasses import dataclass
|
| 5 |
+
from typing import Optional, Tuple
|
| 6 |
+
|
| 7 |
+
import torch
|
| 8 |
+
from torch import Tensor, device, dtype, nn
|
| 9 |
+
import torch.utils.checkpoint
|
| 10 |
+
from torch import nn
|
| 11 |
+
from torch.nn import CrossEntropyLoss
|
| 12 |
+
import torch.nn.functional as F
|
| 13 |
+
|
| 14 |
+
from transformers.activations import ACT2FN
|
| 15 |
+
from transformers.file_utils import (
|
| 16 |
+
ModelOutput,
|
| 17 |
+
)
|
| 18 |
+
from transformers.modeling_outputs import (
|
| 19 |
+
BaseModelOutputWithPastAndCrossAttentions,
|
| 20 |
+
BaseModelOutputWithPoolingAndCrossAttentions,
|
| 21 |
+
CausalLMOutputWithCrossAttentions,
|
| 22 |
+
MaskedLMOutput,
|
| 23 |
+
MultipleChoiceModelOutput,
|
| 24 |
+
NextSentencePredictorOutput,
|
| 25 |
+
QuestionAnsweringModelOutput,
|
| 26 |
+
SequenceClassifierOutput,
|
| 27 |
+
TokenClassifierOutput,
|
| 28 |
+
)
|
| 29 |
+
from transformers.modeling_utils import (
|
| 30 |
+
PreTrainedModel,
|
| 31 |
+
apply_chunking_to_forward,
|
| 32 |
+
find_pruneable_heads_and_indices,
|
| 33 |
+
prune_linear_layer,
|
| 34 |
+
)
|
| 35 |
+
from transformers.utils import logging
|
| 36 |
+
from transformers.models.bert.configuration_bert import BertConfig
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
logger = logging.get_logger(__name__)
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
class BertEmbeddings(nn.Module):
|
| 43 |
+
"""Construct the embeddings from word and position embeddings."""
|
| 44 |
+
|
| 45 |
+
def __init__(self, config):
|
| 46 |
+
super().__init__()
|
| 47 |
+
self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
|
| 48 |
+
self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
|
| 49 |
+
|
| 50 |
+
# self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
|
| 51 |
+
# any TensorFlow checkpoint file
|
| 52 |
+
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
| 53 |
+
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
| 54 |
+
|
| 55 |
+
# position_ids (1, len position emb) is contiguous in memory and exported when serialized
|
| 56 |
+
self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
|
| 57 |
+
self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
|
| 58 |
+
|
| 59 |
+
self.config = config
|
| 60 |
+
|
| 61 |
+
def forward(
|
| 62 |
+
self, input_ids=None, position_ids=None, inputs_embeds=None, past_key_values_length=0
|
| 63 |
+
):
|
| 64 |
+
if input_ids is not None:
|
| 65 |
+
input_shape = input_ids.size()
|
| 66 |
+
else:
|
| 67 |
+
input_shape = inputs_embeds.size()[:-1]
|
| 68 |
+
|
| 69 |
+
seq_length = input_shape[1]
|
| 70 |
+
|
| 71 |
+
if position_ids is None:
|
| 72 |
+
position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length]
|
| 73 |
+
|
| 74 |
+
if inputs_embeds is None:
|
| 75 |
+
inputs_embeds = self.word_embeddings(input_ids)
|
| 76 |
+
|
| 77 |
+
embeddings = inputs_embeds
|
| 78 |
+
|
| 79 |
+
if self.position_embedding_type == "absolute":
|
| 80 |
+
position_embeddings = self.position_embeddings(position_ids)
|
| 81 |
+
embeddings += position_embeddings
|
| 82 |
+
embeddings = self.LayerNorm(embeddings)
|
| 83 |
+
embeddings = self.dropout(embeddings)
|
| 84 |
+
return embeddings
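
In this embedding module, `position_ids` are sliced starting at `past_key_values_length`, so tokens fed one at a time during cached decoding keep their absolute positions. A small check of that slicing, using the same registered-buffer idea:

```python
import torch

max_positions = 512
position_ids = torch.arange(max_positions).expand((1, -1))  # registered buffer

past_key_values_length = 7    # tokens already in the key/value cache
seq_length = 1                # decoding one new token
new_positions = position_ids[:, past_key_values_length : seq_length + past_key_values_length]
print(new_positions)          # tensor([[7]])
```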
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
class BertSelfAttention(nn.Module):
|
| 88 |
+
def __init__(self, config, is_cross_attention):
|
| 89 |
+
super().__init__()
|
| 90 |
+
self.config = config
|
| 91 |
+
if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
|
| 92 |
+
raise ValueError(
|
| 93 |
+
"The hidden size (%d) is not a multiple of the number of attention "
|
| 94 |
+
"heads (%d)" % (config.hidden_size, config.num_attention_heads)
|
| 95 |
+
)
|
| 96 |
+
|
| 97 |
+
self.num_attention_heads = config.num_attention_heads
|
| 98 |
+
self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
|
| 99 |
+
self.all_head_size = self.num_attention_heads * self.attention_head_size
|
| 100 |
+
|
| 101 |
+
self.query = nn.Linear(config.hidden_size, self.all_head_size)
|
| 102 |
+
if is_cross_attention:
|
| 103 |
+
self.key = nn.Linear(config.encoder_width, self.all_head_size)
|
| 104 |
+
self.value = nn.Linear(config.encoder_width, self.all_head_size)
|
| 105 |
+
else:
|
| 106 |
+
self.key = nn.Linear(config.hidden_size, self.all_head_size)
|
| 107 |
+
self.value = nn.Linear(config.hidden_size, self.all_head_size)
|
| 108 |
+
|
| 109 |
+
self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
|
| 110 |
+
self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
|
| 111 |
+
if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
|
| 112 |
+
self.max_position_embeddings = config.max_position_embeddings
|
| 113 |
+
self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)
|
| 114 |
+
self.save_attention = False
|
| 115 |
+
|
| 116 |
+
def save_attn_gradients(self, attn_gradients):
|
| 117 |
+
self.attn_gradients = attn_gradients
|
| 118 |
+
|
| 119 |
+
def get_attn_gradients(self):
|
| 120 |
+
return self.attn_gradients
|
| 121 |
+
|
| 122 |
+
def save_attention_map(self, attention_map):
|
| 123 |
+
self.attention_map = attention_map
|
| 124 |
+
|
| 125 |
+
def get_attention_map(self):
|
| 126 |
+
return self.attention_map
|
| 127 |
+
|
| 128 |
+
def transpose_for_scores(self, x):
|
| 129 |
+
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
|
| 130 |
+
x = x.view(*new_x_shape)
|
| 131 |
+
return x.permute(0, 2, 1, 3)
|
| 132 |
+
|
| 133 |
+
def forward(
|
| 134 |
+
self,
|
| 135 |
+
hidden_states,
|
| 136 |
+
attention_mask=None,
|
| 137 |
+
head_mask=None,
|
| 138 |
+
encoder_hidden_states=None,
|
| 139 |
+
encoder_attention_mask=None,
|
| 140 |
+
past_key_value=None,
|
| 141 |
+
output_attentions=False,
|
| 142 |
+
):
|
| 143 |
+
mixed_query_layer = self.query(hidden_states)
|
| 144 |
+
|
| 145 |
+
# If this is instantiated as a cross-attention module, the keys
|
| 146 |
+
# and values come from an encoder; the attention mask needs to be
|
| 147 |
+
# such that the encoder's padding tokens are not attended to.
|
| 148 |
+
is_cross_attention = encoder_hidden_states is not None
|
| 149 |
+
|
| 150 |
+
if is_cross_attention:
|
| 151 |
+
key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))
|
| 152 |
+
value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))
|
| 153 |
+
attention_mask = encoder_attention_mask
|
| 154 |
+
elif past_key_value is not None:
|
| 155 |
+
key_layer = self.transpose_for_scores(self.key(hidden_states))
|
| 156 |
+
value_layer = self.transpose_for_scores(self.value(hidden_states))
|
| 157 |
+
key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
|
| 158 |
+
value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
|
| 159 |
+
else:
|
| 160 |
+
key_layer = self.transpose_for_scores(self.key(hidden_states))
|
| 161 |
+
value_layer = self.transpose_for_scores(self.value(hidden_states))
|
| 162 |
+
|
| 163 |
+
query_layer = self.transpose_for_scores(mixed_query_layer)
|
| 164 |
+
|
| 165 |
+
past_key_value = (key_layer, value_layer)
|
| 166 |
+
|
| 167 |
+
# Take the dot product between "query" and "key" to get the raw attention scores.
|
| 168 |
+
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
|
| 169 |
+
|
| 170 |
+
if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
|
| 171 |
+
seq_length = hidden_states.size()[1]
|
| 172 |
+
position_ids_l = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)
|
| 173 |
+
position_ids_r = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(1, -1)
|
| 174 |
+
distance = position_ids_l - position_ids_r
|
| 175 |
+
positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1)
|
| 176 |
+
positional_embedding = positional_embedding.to(dtype=query_layer.dtype) # fp16 compatibility
|
| 177 |
+
|
| 178 |
+
if self.position_embedding_type == "relative_key":
|
| 179 |
+
relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
|
| 180 |
+
attention_scores = attention_scores + relative_position_scores
|
| 181 |
+
elif self.position_embedding_type == "relative_key_query":
|
| 182 |
+
relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
|
| 183 |
+
relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding)
|
| 184 |
+
attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key
|
| 185 |
+
|
| 186 |
+
attention_scores = attention_scores / math.sqrt(self.attention_head_size)
|
| 187 |
+
if attention_mask is not None:
|
| 188 |
+
# Apply the attention mask is (precomputed for all layers in BertModel forward() function)
|
| 189 |
+
attention_scores = attention_scores + attention_mask
|
| 190 |
+
|
| 191 |
+
# Normalize the attention scores to probabilities.
|
| 192 |
+
attention_probs = nn.Softmax(dim=-1)(attention_scores)
|
| 193 |
+
|
| 194 |
+
if is_cross_attention and self.save_attention:
|
| 195 |
+
self.save_attention_map(attention_probs)
|
| 196 |
+
attention_probs.register_hook(self.save_attn_gradients)
|
| 197 |
+
|
| 198 |
+
# This is actually dropping out entire tokens to attend to, which might
|
| 199 |
+
# seem a bit unusual, but is taken from the original Transformer paper.
|
| 200 |
+
attention_probs_dropped = self.dropout(attention_probs)
|
| 201 |
+
|
| 202 |
+
# Mask heads if we want to
|
| 203 |
+
if head_mask is not None:
|
| 204 |
+
attention_probs_dropped = attention_probs_dropped * head_mask
|
| 205 |
+
|
| 206 |
+
context_layer = torch.matmul(attention_probs_dropped, value_layer)
|
| 207 |
+
|
| 208 |
+
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
|
| 209 |
+
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
|
| 210 |
+
context_layer = context_layer.view(*new_context_layer_shape)
|
| 211 |
+
|
| 212 |
+
outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
|
| 213 |
+
|
| 214 |
+
outputs = outputs + (past_key_value,)
|
| 215 |
+
return outputs
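
When `position_embedding_type` is `relative_key` or `relative_key_query`, each attention score between positions l and r gets an extra term from an embedding of their signed distance, combined with the queries (and keys) via `einsum`. A standalone sketch of the `relative_key` case with toy sizes:

```python
import math
import torch
import torch.nn as nn

batch, heads, seq, head_dim = 2, 4, 6, 8
max_pos = 16

query_layer = torch.randn(batch, heads, seq, head_dim)
key_layer = torch.randn(batch, heads, seq, head_dim)
distance_embedding = nn.Embedding(2 * max_pos - 1, head_dim)

# signed distance l - r for every query/key pair, shifted to be non-negative
pos_l = torch.arange(seq).view(-1, 1)
pos_r = torch.arange(seq).view(1, -1)
positional_embedding = distance_embedding(pos_l - pos_r + max_pos - 1)  # (seq, seq, head_dim)

scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
scores = scores + torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
scores = scores / math.sqrt(head_dim)
print(scores.shape)                           # torch.Size([2, 4, 6, 6])
```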
|
| 216 |
+
|
| 217 |
+
|
| 218 |
+
class BertSelfOutput(nn.Module):
|
| 219 |
+
def __init__(self, config, twin=False, merge=False):
|
| 220 |
+
super().__init__()
|
| 221 |
+
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
| 222 |
+
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
| 223 |
+
if twin:
|
| 224 |
+
self.dense0 = nn.Linear(config.hidden_size, config.hidden_size)
|
| 225 |
+
self.dense1 = nn.Linear(config.hidden_size, config.hidden_size)
|
| 226 |
+
else:
|
| 227 |
+
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
|
| 228 |
+
if merge:
|
| 229 |
+
self.act = ACT2FN[config.hidden_act]
|
| 230 |
+
self.merge_layer = nn.Linear(config.hidden_size * 2, config.hidden_size)
|
| 231 |
+
self.merge = True
|
| 232 |
+
else:
|
| 233 |
+
self.merge = False
|
| 234 |
+
|
| 235 |
+
def forward(self, hidden_states, input_tensor):
|
| 236 |
+
if type(hidden_states) == list:
|
| 237 |
+
hidden_states0 = self.dense0(hidden_states[0])
|
| 238 |
+
hidden_states1 = self.dense1(hidden_states[1])
|
| 239 |
+
if self.merge:
|
| 240 |
+
#hidden_states = self.merge_layer(self.act(torch.cat([hidden_states0,hidden_states1],dim=-1)))
|
| 241 |
+
hidden_states = self.merge_layer(torch.cat([hidden_states0,hidden_states1],dim=-1))
|
| 242 |
+
else:
|
| 243 |
+
hidden_states = (hidden_states0+hidden_states1)/2
|
| 244 |
+
else:
|
| 245 |
+
hidden_states = self.dense(hidden_states)
|
| 246 |
+
hidden_states = self.dropout(hidden_states)
|
| 247 |
+
hidden_states = self.LayerNorm(hidden_states + input_tensor)
|
| 248 |
+
return hidden_states
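
This `BertSelfOutput` is the NLVR-specific twist: when the layer carries two cross-attention streams (one per input image), their outputs are either averaged or, in later layers, concatenated and projected through `merge_layer`. A sketch of the two combinations on dummy streams (sizes are illustrative):

```python
import torch
import torch.nn as nn

hidden_size = 768
h0 = torch.randn(2, 10, hidden_size)   # cross-attention output for image 0
h1 = torch.randn(2, 10, hidden_size)   # cross-attention output for image 1

# early layers: simple average of the two streams
avg = (h0 + h1) / 2

# later layers (layer_num >= 6): learned merge of the concatenation
merge_layer = nn.Linear(hidden_size * 2, hidden_size)
merged = merge_layer(torch.cat([h0, h1], dim=-1))
print(avg.shape, merged.shape)          # both (2, 10, 768)
```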
|
| 249 |
+
|
| 250 |
+
|
| 251 |
+
class BertAttention(nn.Module):
|
| 252 |
+
def __init__(self, config, is_cross_attention=False, layer_num=-1):
|
| 253 |
+
super().__init__()
|
| 254 |
+
if is_cross_attention:
|
| 255 |
+
self.self0 = BertSelfAttention(config, is_cross_attention)
|
| 256 |
+
self.self1 = BertSelfAttention(config, is_cross_attention)
|
| 257 |
+
else:
|
| 258 |
+
self.self = BertSelfAttention(config, is_cross_attention)
|
| 259 |
+
self.output = BertSelfOutput(config, twin=is_cross_attention, merge=(is_cross_attention and layer_num>=6))
|
| 260 |
+
self.pruned_heads = set()
|
| 261 |
+
|
| 262 |
+
def prune_heads(self, heads):
|
| 263 |
+
if len(heads) == 0:
|
| 264 |
+
return
|
| 265 |
+
heads, index = find_pruneable_heads_and_indices(
|
| 266 |
+
heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads
|
| 267 |
+
)
|
| 268 |
+
|
| 269 |
+
# Prune linear layers
|
| 270 |
+
self.self.query = prune_linear_layer(self.self.query, index)
|
| 271 |
+
self.self.key = prune_linear_layer(self.self.key, index)
|
| 272 |
+
self.self.value = prune_linear_layer(self.self.value, index)
|
| 273 |
+
self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
|
| 274 |
+
|
| 275 |
+
# Update hyper params and store pruned heads
|
| 276 |
+
self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
|
| 277 |
+
self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
|
| 278 |
+
self.pruned_heads = self.pruned_heads.union(heads)
|
| 279 |
+
|
| 280 |
+
def forward(
|
| 281 |
+
self,
|
| 282 |
+
hidden_states,
|
| 283 |
+
attention_mask=None,
|
| 284 |
+
head_mask=None,
|
| 285 |
+
encoder_hidden_states=None,
|
| 286 |
+
encoder_attention_mask=None,
|
| 287 |
+
past_key_value=None,
|
| 288 |
+
output_attentions=False,
|
| 289 |
+
):
|
| 290 |
+
if type(encoder_hidden_states)==list:
|
| 291 |
+
self_outputs0 = self.self0(
|
| 292 |
+
hidden_states,
|
| 293 |
+
attention_mask,
|
| 294 |
+
head_mask,
|
| 295 |
+
encoder_hidden_states[0],
|
| 296 |
+
encoder_attention_mask[0],
|
| 297 |
+
past_key_value,
|
| 298 |
+
output_attentions,
|
| 299 |
+
)
|
| 300 |
+
self_outputs1 = self.self1(
|
| 301 |
+
hidden_states,
|
| 302 |
+
attention_mask,
|
| 303 |
+
head_mask,
|
| 304 |
+
encoder_hidden_states[1],
|
| 305 |
+
encoder_attention_mask[1],
|
| 306 |
+
past_key_value,
|
| 307 |
+
output_attentions,
|
| 308 |
+
)
|
| 309 |
+
attention_output = self.output([self_outputs0[0],self_outputs1[0]], hidden_states)
|
| 310 |
+
|
| 311 |
+
outputs = (attention_output,) + self_outputs0[1:] # add attentions if we output them
|
| 312 |
+
else:
|
| 313 |
+
self_outputs = self.self(
|
| 314 |
+
hidden_states,
|
| 315 |
+
attention_mask,
|
| 316 |
+
head_mask,
|
| 317 |
+
encoder_hidden_states,
|
| 318 |
+
encoder_attention_mask,
|
| 319 |
+
past_key_value,
|
| 320 |
+
output_attentions,
|
| 321 |
+
)
|
| 322 |
+
attention_output = self.output(self_outputs[0], hidden_states)
|
| 323 |
+
outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
|
| 324 |
+
return outputs
|
| 325 |
+
|
| 326 |
+
|
| 327 |
+
class BertIntermediate(nn.Module):
|
| 328 |
+
def __init__(self, config):
|
| 329 |
+
super().__init__()
|
| 330 |
+
self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
|
| 331 |
+
if isinstance(config.hidden_act, str):
|
| 332 |
+
self.intermediate_act_fn = ACT2FN[config.hidden_act]
|
| 333 |
+
else:
|
| 334 |
+
self.intermediate_act_fn = config.hidden_act
|
| 335 |
+
|
| 336 |
+
def forward(self, hidden_states):
|
| 337 |
+
hidden_states = self.dense(hidden_states)
|
| 338 |
+
hidden_states = self.intermediate_act_fn(hidden_states)
|
| 339 |
+
return hidden_states
|
| 340 |
+
|
| 341 |
+
|
| 342 |
+
class BertOutput(nn.Module):
|
| 343 |
+
def __init__(self, config):
|
| 344 |
+
super().__init__()
|
| 345 |
+
self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
|
| 346 |
+
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
| 347 |
+
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
| 348 |
+
|
| 349 |
+
def forward(self, hidden_states, input_tensor):
|
| 350 |
+
hidden_states = self.dense(hidden_states)
|
| 351 |
+
hidden_states = self.dropout(hidden_states)
|
| 352 |
+
hidden_states = self.LayerNorm(hidden_states + input_tensor)
|
| 353 |
+
return hidden_states
|
| 354 |
+
|
| 355 |
+
|
| 356 |
+
class BertLayer(nn.Module):
|
| 357 |
+
def __init__(self, config, layer_num):
|
| 358 |
+
super().__init__()
|
| 359 |
+
self.config = config
|
| 360 |
+
self.chunk_size_feed_forward = config.chunk_size_feed_forward
|
| 361 |
+
self.seq_len_dim = 1
|
| 362 |
+
self.attention = BertAttention(config)
|
| 363 |
+
self.layer_num = layer_num
|
| 364 |
+
if self.config.add_cross_attention:
|
| 365 |
+
self.crossattention = BertAttention(config, is_cross_attention=self.config.add_cross_attention, layer_num=layer_num)
|
| 366 |
+
self.intermediate = BertIntermediate(config)
|
| 367 |
+
self.output = BertOutput(config)
|
| 368 |
+
|
| 369 |
+
def forward(
|
| 370 |
+
self,
|
| 371 |
+
hidden_states,
|
| 372 |
+
attention_mask=None,
|
| 373 |
+
head_mask=None,
|
| 374 |
+
encoder_hidden_states=None,
|
| 375 |
+
encoder_attention_mask=None,
|
| 376 |
+
past_key_value=None,
|
| 377 |
+
output_attentions=False,
|
| 378 |
+
mode=None,
|
| 379 |
+
):
|
| 380 |
+
# decoder uni-directional self-attention cached key/values tuple is at positions 1,2
|
| 381 |
+
self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
|
| 382 |
+
self_attention_outputs = self.attention(
|
| 383 |
+
hidden_states,
|
| 384 |
+
attention_mask,
|
| 385 |
+
head_mask,
|
| 386 |
+
output_attentions=output_attentions,
|
| 387 |
+
past_key_value=self_attn_past_key_value,
|
| 388 |
+
)
|
| 389 |
+
attention_output = self_attention_outputs[0]
|
| 390 |
+
|
| 391 |
+
outputs = self_attention_outputs[1:-1]
|
| 392 |
+
present_key_value = self_attention_outputs[-1]
|
| 393 |
+
|
| 394 |
+
if mode=='multimodal':
|
| 395 |
+
assert encoder_hidden_states is not None, "encoder_hidden_states must be given for cross-attention layers"
|
| 396 |
+
cross_attention_outputs = self.crossattention(
|
| 397 |
+
attention_output,
|
| 398 |
+
attention_mask,
|
| 399 |
+
head_mask,
|
| 400 |
+
encoder_hidden_states,
|
| 401 |
+
encoder_attention_mask,
|
| 402 |
+
output_attentions=output_attentions,
|
| 403 |
+
)
|
| 404 |
+
attention_output = cross_attention_outputs[0]
|
| 405 |
+
outputs = outputs + cross_attention_outputs[1:-1] # add cross attentions if we output attention weights
|
| 406 |
+
layer_output = apply_chunking_to_forward(
|
| 407 |
+
self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output
|
| 408 |
+
)
|
| 409 |
+
outputs = (layer_output,) + outputs
|
| 410 |
+
|
| 411 |
+
outputs = outputs + (present_key_value,)
|
| 412 |
+
|
| 413 |
+
return outputs
|
| 414 |
+
|
| 415 |
+
def feed_forward_chunk(self, attention_output):
|
| 416 |
+
intermediate_output = self.intermediate(attention_output)
|
| 417 |
+
layer_output = self.output(intermediate_output, attention_output)
|
| 418 |
+
return layer_output
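
`apply_chunking_to_forward` runs the feed-forward part of the layer over the sequence in chunks of `chunk_size_feed_forward` to cap peak memory; the result matches calling the function on the full tensor. A quick check using the same helper the file imports (import path mirrors this repo's own imports):

```python
import torch
import torch.nn as nn
from transformers.modeling_utils import apply_chunking_to_forward

dense = nn.Linear(32, 32)

def feed_forward_chunk(attention_output):
    return dense(attention_output)

x = torch.randn(2, 12, 32)                 # (batch, seq_len, hidden)
full = feed_forward_chunk(x)
chunked = apply_chunking_to_forward(feed_forward_chunk, 4, 1, x)  # chunk_size=4 over dim 1
print(torch.allclose(full, chunked, atol=1e-6))   # True
```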
|
| 419 |
+
|
| 420 |
+
|
| 421 |
+
class BertEncoder(nn.Module):
|
| 422 |
+
def __init__(self, config):
|
| 423 |
+
super().__init__()
|
| 424 |
+
self.config = config
|
| 425 |
+
self.layer = nn.ModuleList([BertLayer(config,i) for i in range(config.num_hidden_layers)])
|
| 426 |
+
self.gradient_checkpointing = False
|
| 427 |
+
|
| 428 |
+
def forward(
|
| 429 |
+
self,
|
| 430 |
+
hidden_states,
|
| 431 |
+
attention_mask=None,
|
| 432 |
+
head_mask=None,
|
| 433 |
+
encoder_hidden_states=None,
|
| 434 |
+
encoder_attention_mask=None,
|
| 435 |
+
past_key_values=None,
|
| 436 |
+
use_cache=None,
|
| 437 |
+
output_attentions=False,
|
| 438 |
+
output_hidden_states=False,
|
| 439 |
+
return_dict=True,
|
| 440 |
+
mode='multimodal',
|
| 441 |
+
):
|
| 442 |
+
all_hidden_states = () if output_hidden_states else None
|
| 443 |
+
all_self_attentions = () if output_attentions else None
|
| 444 |
+
all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
|
| 445 |
+
|
| 446 |
+
next_decoder_cache = () if use_cache else None
|
| 447 |
+
|
| 448 |
+
for i in range(self.config.num_hidden_layers):
|
| 449 |
+
layer_module = self.layer[i]
|
| 450 |
+
if output_hidden_states:
|
| 451 |
+
all_hidden_states = all_hidden_states + (hidden_states,)
|
| 452 |
+
|
| 453 |
+
layer_head_mask = head_mask[i] if head_mask is not None else None
|
| 454 |
+
past_key_value = past_key_values[i] if past_key_values is not None else None
|
| 455 |
+
|
| 456 |
+
if self.gradient_checkpointing and self.training:
|
| 457 |
+
|
| 458 |
+
if use_cache:
|
| 459 |
+
logger.warn(
|
| 460 |
+
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
|
| 461 |
+
)
|
| 462 |
+
use_cache = False
|
| 463 |
+
|
| 464 |
+
def create_custom_forward(module):
|
| 465 |
+
def custom_forward(*inputs):
|
| 466 |
+
return module(*inputs, past_key_value, output_attentions)
|
| 467 |
+
|
| 468 |
+
return custom_forward
|
| 469 |
+
|
| 470 |
+
layer_outputs = torch.utils.checkpoint.checkpoint(
|
| 471 |
+
create_custom_forward(layer_module),
|
| 472 |
+
hidden_states,
|
| 473 |
+
attention_mask,
|
| 474 |
+
layer_head_mask,
|
| 475 |
+
encoder_hidden_states,
|
| 476 |
+
encoder_attention_mask,
|
| 477 |
+
mode=mode,
|
| 478 |
+
)
|
| 479 |
+
else:
|
| 480 |
+
layer_outputs = layer_module(
|
| 481 |
+
hidden_states,
|
| 482 |
+
attention_mask,
|
| 483 |
+
layer_head_mask,
|
| 484 |
+
encoder_hidden_states,
|
| 485 |
+
encoder_attention_mask,
|
| 486 |
+
past_key_value,
|
| 487 |
+
output_attentions,
|
| 488 |
+
mode=mode,
|
| 489 |
+
)
|
| 490 |
+
|
| 491 |
+
hidden_states = layer_outputs[0]
|
| 492 |
+
if use_cache:
|
| 493 |
+
next_decoder_cache += (layer_outputs[-1],)
|
| 494 |
+
if output_attentions:
|
| 495 |
+
all_self_attentions = all_self_attentions + (layer_outputs[1],)
|
| 496 |
+
|
| 497 |
+
if output_hidden_states:
|
| 498 |
+
all_hidden_states = all_hidden_states + (hidden_states,)
|
| 499 |
+
|
| 500 |
+
if not return_dict:
|
| 501 |
+
return tuple(
|
| 502 |
+
v
|
| 503 |
+
for v in [
|
| 504 |
+
hidden_states,
|
| 505 |
+
next_decoder_cache,
|
| 506 |
+
all_hidden_states,
|
| 507 |
+
all_self_attentions,
|
| 508 |
+
all_cross_attentions,
|
| 509 |
+
]
|
| 510 |
+
if v is not None
|
| 511 |
+
)
|
| 512 |
+
return BaseModelOutputWithPastAndCrossAttentions(
|
| 513 |
+
last_hidden_state=hidden_states,
|
| 514 |
+
past_key_values=next_decoder_cache,
|
| 515 |
+
hidden_states=all_hidden_states,
|
| 516 |
+
attentions=all_self_attentions,
|
| 517 |
+
cross_attentions=all_cross_attentions,
|
| 518 |
+
)
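
With `gradient_checkpointing` enabled, the encoder wraps each layer in `torch.utils.checkpoint.checkpoint`, trading an extra forward pass during backward for not storing intermediate activations (and forcing `use_cache=False`). A minimal standalone example of the same mechanism, not tied to this model:

```python
import torch
import torch.nn as nn
import torch.utils.checkpoint

layer = nn.Sequential(nn.Linear(64, 64), nn.GELU(), nn.Linear(64, 64))

def create_custom_forward(module):
    def custom_forward(*inputs):
        return module(*inputs)
    return custom_forward

x = torch.randn(4, 64, requires_grad=True)
# activations inside `layer` are recomputed during backward instead of stored
out = torch.utils.checkpoint.checkpoint(create_custom_forward(layer), x)
out.sum().backward()
print(x.grad.shape)                        # torch.Size([4, 64])
```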
|
| 519 |
+
|
| 520 |
+
|
| 521 |
+
class BertPooler(nn.Module):
|
| 522 |
+
def __init__(self, config):
|
| 523 |
+
super().__init__()
|
| 524 |
+
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
|
| 525 |
+
self.activation = nn.Tanh()
|
| 526 |
+
|
| 527 |
+
def forward(self, hidden_states):
|
| 528 |
+
# We "pool" the model by simply taking the hidden state corresponding
|
| 529 |
+
# to the first token.
|
| 530 |
+
first_token_tensor = hidden_states[:, 0]
|
| 531 |
+
pooled_output = self.dense(first_token_tensor)
|
| 532 |
+
pooled_output = self.activation(pooled_output)
|
| 533 |
+
return pooled_output
|
| 534 |
+
|
| 535 |
+
|
| 536 |
+
class BertPredictionHeadTransform(nn.Module):
|
| 537 |
+
def __init__(self, config):
|
| 538 |
+
super().__init__()
|
| 539 |
+
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
|
| 540 |
+
if isinstance(config.hidden_act, str):
|
| 541 |
+
self.transform_act_fn = ACT2FN[config.hidden_act]
|
| 542 |
+
else:
|
| 543 |
+
self.transform_act_fn = config.hidden_act
|
| 544 |
+
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
| 545 |
+
|
| 546 |
+
def forward(self, hidden_states):
|
| 547 |
+
hidden_states = self.dense(hidden_states)
|
| 548 |
+
hidden_states = self.transform_act_fn(hidden_states)
|
| 549 |
+
hidden_states = self.LayerNorm(hidden_states)
|
| 550 |
+
return hidden_states
|
| 551 |
+
|
| 552 |
+
|
| 553 |
+
class BertLMPredictionHead(nn.Module):
|
| 554 |
+
def __init__(self, config):
|
| 555 |
+
super().__init__()
|
| 556 |
+
self.transform = BertPredictionHeadTransform(config)
|
| 557 |
+
|
| 558 |
+
# The output weights are the same as the input embeddings, but there is
|
| 559 |
+
# an output-only bias for each token.
|
| 560 |
+
self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
| 561 |
+
|
| 562 |
+
self.bias = nn.Parameter(torch.zeros(config.vocab_size))
|
| 563 |
+
|
| 564 |
+
# Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
|
| 565 |
+
self.decoder.bias = self.bias
|
| 566 |
+
|
| 567 |
+
def forward(self, hidden_states):
|
| 568 |
+
hidden_states = self.transform(hidden_states)
|
| 569 |
+
hidden_states = self.decoder(hidden_states)
|
| 570 |
+
return hidden_states
|
| 571 |
+
|
| 572 |
+
|
| 573 |
+
class BertOnlyMLMHead(nn.Module):
|
| 574 |
+
def __init__(self, config):
|
| 575 |
+
super().__init__()
|
| 576 |
+
self.predictions = BertLMPredictionHead(config)
|
| 577 |
+
|
| 578 |
+
def forward(self, sequence_output):
|
| 579 |
+
prediction_scores = self.predictions(sequence_output)
|
| 580 |
+
return prediction_scores
|
| 581 |
+
|
| 582 |
+
|
| 583 |
+
class BertPreTrainedModel(PreTrainedModel):
|
| 584 |
+
"""
|
| 585 |
+
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
|
| 586 |
+
models.
|
| 587 |
+
"""
|
| 588 |
+
|
| 589 |
+
config_class = BertConfig
|
| 590 |
+
base_model_prefix = "bert"
|
| 591 |
+
_keys_to_ignore_on_load_missing = [r"position_ids"]
|
| 592 |
+
|
| 593 |
+
def _init_weights(self, module):
|
| 594 |
+
""" Initialize the weights """
|
| 595 |
+
if isinstance(module, (nn.Linear, nn.Embedding)):
|
| 596 |
+
# Slightly different from the TF version which uses truncated_normal for initialization
|
| 597 |
+
# cf https://github.com/pytorch/pytorch/pull/5617
|
| 598 |
+
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
|
| 599 |
+
elif isinstance(module, nn.LayerNorm):
|
| 600 |
+
module.bias.data.zero_()
|
| 601 |
+
module.weight.data.fill_(1.0)
|
| 602 |
+
if isinstance(module, nn.Linear) and module.bias is not None:
|
| 603 |
+
module.bias.data.zero_()
|
| 604 |
+
|
| 605 |
+
|
| 606 |
+
class BertModel(BertPreTrainedModel):
|
| 607 |
+
"""
|
| 608 |
+
The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of
|
| 609 |
+
cross-attention is added between the self-attention layers, following the architecture described in `Attention is
|
| 610 |
+
all you need <https://arxiv.org/abs/1706.03762>`__ by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit,
|
| 611 |
+
Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin.
|
| 612 |
+
To be used as a decoder, the model needs to be initialized with the :obj:`is_decoder` argument and :obj:`add_cross_attention` set to :obj:`True`; an :obj:`encoder_hidden_states` is then expected as an
|
| 613 |
+
input to the forward pass.
|
| 614 |
+
"""
|
| 615 |
+
|
| 616 |
+
def __init__(self, config, add_pooling_layer=True):
|
| 617 |
+
super().__init__(config)
|
| 618 |
+
self.config = config
|
| 619 |
+
|
| 620 |
+
self.embeddings = BertEmbeddings(config)
|
| 621 |
+
|
| 622 |
+
self.encoder = BertEncoder(config)
|
| 623 |
+
|
| 624 |
+
self.pooler = BertPooler(config) if add_pooling_layer else None
|
| 625 |
+
|
| 626 |
+
self.init_weights()
|
| 627 |
+
|
| 628 |
+
|
| 629 |
+
def get_input_embeddings(self):
|
| 630 |
+
return self.embeddings.word_embeddings
|
| 631 |
+
|
| 632 |
+
def set_input_embeddings(self, value):
|
| 633 |
+
self.embeddings.word_embeddings = value
|
| 634 |
+
|
| 635 |
+
def _prune_heads(self, heads_to_prune):
|
| 636 |
+
"""
|
| 637 |
+
Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
|
| 638 |
+
class PreTrainedModel
|
| 639 |
+
"""
|
| 640 |
+
for layer, heads in heads_to_prune.items():
|
| 641 |
+
self.encoder.layer[layer].attention.prune_heads(heads)
|
| 642 |
+
|
| 643 |
+
|
| 644 |
+
def get_extended_attention_mask(self, attention_mask: Tensor, input_shape: Tuple[int], device: device, is_decoder: bool) -> Tensor:
|
| 645 |
+
"""
|
| 646 |
+
Makes broadcastable attention and causal masks so that future and masked tokens are ignored.
|
| 647 |
+
|
| 648 |
+
Arguments:
|
| 649 |
+
attention_mask (:obj:`torch.Tensor`):
|
| 650 |
+
Mask with ones indicating tokens to attend to, zeros for tokens to ignore.
|
| 651 |
+
input_shape (:obj:`Tuple[int]`):
|
| 652 |
+
The shape of the input to the model.
|
| 653 |
+
device: (:obj:`torch.device`):
|
| 654 |
+
The device of the input to the model.
|
| 655 |
+
|
| 656 |
+
Returns:
|
| 657 |
+
:obj:`torch.Tensor` The extended attention mask, with the same dtype as :obj:`attention_mask.dtype`.
|
| 658 |
+
"""
|
| 659 |
+
# We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
|
| 660 |
+
# ourselves in which case we just need to make it broadcastable to all heads.
|
| 661 |
+
if attention_mask.dim() == 3:
|
| 662 |
+
extended_attention_mask = attention_mask[:, None, :, :]
|
| 663 |
+
elif attention_mask.dim() == 2:
|
| 664 |
+
# Provided a padding mask of dimensions [batch_size, seq_length]
|
| 665 |
+
# - if the model is a decoder, apply a causal mask in addition to the padding mask
|
| 666 |
+
# - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length]
|
| 667 |
+
if is_decoder:
|
| 668 |
+
batch_size, seq_length = input_shape
|
| 669 |
+
|
| 670 |
+
seq_ids = torch.arange(seq_length, device=device)
|
| 671 |
+
causal_mask = seq_ids[None, None, :].repeat(batch_size, seq_length, 1) <= seq_ids[None, :, None]
|
| 672 |
+
# in case past_key_values are used we need to add a prefix ones mask to the causal mask
|
| 673 |
+
# causal and attention masks must have same type with pytorch version < 1.3
|
| 674 |
+
causal_mask = causal_mask.to(attention_mask.dtype)
|
| 675 |
+
|
| 676 |
+
if causal_mask.shape[1] < attention_mask.shape[1]:
|
| 677 |
+
prefix_seq_len = attention_mask.shape[1] - causal_mask.shape[1]
|
| 678 |
+
causal_mask = torch.cat(
|
| 679 |
+
[
|
| 680 |
+
torch.ones((batch_size, seq_length, prefix_seq_len), device=device, dtype=causal_mask.dtype),
|
| 681 |
+
causal_mask,
|
| 682 |
+
],
|
| 683 |
+
axis=-1,
|
| 684 |
+
)
|
| 685 |
+
|
| 686 |
+
extended_attention_mask = causal_mask[:, None, :, :] * attention_mask[:, None, None, :]
|
| 687 |
+
else:
|
| 688 |
+
extended_attention_mask = attention_mask[:, None, None, :]
|
| 689 |
+
else:
|
| 690 |
+
raise ValueError(
|
| 691 |
+
"Wrong shape for input_ids (shape {}) or attention_mask (shape {})".format(
|
| 692 |
+
input_shape, attention_mask.shape
|
| 693 |
+
)
|
| 694 |
+
)
|
| 695 |
+
|
| 696 |
+
# Since attention_mask is 1.0 for positions we want to attend and 0.0 for
|
| 697 |
+
# masked positions, this operation will create a tensor which is 0.0 for
|
| 698 |
+
# positions we want to attend and -10000.0 for masked positions.
|
| 699 |
+
# Since we are adding it to the raw scores before the softmax, this is
|
| 700 |
+
# effectively the same as removing these entirely.
|
| 701 |
+
extended_attention_mask = extended_attention_mask.to(dtype=self.dtype) # fp16 compatibility
|
| 702 |
+
extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
|
| 703 |
+
return extended_attention_mask
|
| 704 |
+
|
| 705 |
+
def forward(
|
| 706 |
+
self,
|
| 707 |
+
input_ids=None,
|
| 708 |
+
attention_mask=None,
|
| 709 |
+
position_ids=None,
|
| 710 |
+
head_mask=None,
|
| 711 |
+
inputs_embeds=None,
|
| 712 |
+
encoder_embeds=None,
|
| 713 |
+
encoder_hidden_states=None,
|
| 714 |
+
encoder_attention_mask=None,
|
| 715 |
+
past_key_values=None,
|
| 716 |
+
use_cache=None,
|
| 717 |
+
output_attentions=None,
|
| 718 |
+
output_hidden_states=None,
|
| 719 |
+
return_dict=None,
|
| 720 |
+
is_decoder=False,
|
| 721 |
+
mode='multimodal',
|
| 722 |
+
):
|
| 723 |
+
r"""
|
| 724 |
+
encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
|
| 725 |
+
Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
|
| 726 |
+
the model is configured as a decoder.
|
| 727 |
+
encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
|
| 728 |
+
Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
|
| 729 |
+
the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``:
|
| 730 |
+
- 1 for tokens that are **not masked**,
|
| 731 |
+
- 0 for tokens that are **masked**.
|
| 732 |
+
past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
|
| 733 |
+
Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
|
| 734 |
+
If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids`
|
| 735 |
+
(those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`
|
| 736 |
+
instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`.
|
| 737 |
+
use_cache (:obj:`bool`, `optional`):
|
| 738 |
+
If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up
|
| 739 |
+
decoding (see :obj:`past_key_values`).
|
| 740 |
+
"""
|
| 741 |
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
| 742 |
+
output_hidden_states = (
|
| 743 |
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
| 744 |
+
)
|
| 745 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
| 746 |
+
|
| 747 |
+
if is_decoder:
|
| 748 |
+
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
| 749 |
+
else:
|
| 750 |
+
use_cache = False
|
| 751 |
+
|
| 752 |
+
if input_ids is not None and inputs_embeds is not None:
|
| 753 |
+
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
|
| 754 |
+
elif input_ids is not None:
|
| 755 |
+
input_shape = input_ids.size()
|
| 756 |
+
batch_size, seq_length = input_shape
|
| 757 |
+
device = input_ids.device
|
| 758 |
+
elif inputs_embeds is not None:
|
| 759 |
+
input_shape = inputs_embeds.size()[:-1]
|
| 760 |
+
batch_size, seq_length = input_shape
|
| 761 |
+
device = inputs_embeds.device
|
| 762 |
+
elif encoder_embeds is not None:
|
| 763 |
+
input_shape = encoder_embeds.size()[:-1]
|
| 764 |
+
batch_size, seq_length = input_shape
|
| 765 |
+
device = encoder_embeds.device
|
| 766 |
+
else:
|
| 767 |
+
raise ValueError("You have to specify either input_ids or inputs_embeds or encoder_embeds")
|
| 768 |
+
|
| 769 |
+
# past_key_values_length
|
| 770 |
+
past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
|
| 771 |
+
|
| 772 |
+
if attention_mask is None:
|
| 773 |
+
attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device)
|
| 774 |
+
|
| 775 |
+
# We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
|
| 776 |
+
# ourselves in which case we just need to make it broadcastable to all heads.
|
| 777 |
+
extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape,
|
| 778 |
+
device, is_decoder)
|
| 779 |
+
|
| 780 |
+
# If a 2D or 3D attention mask is provided for the cross-attention
|
| 781 |
+
# we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
|
| 782 |
+
if encoder_hidden_states is not None:
|
| 783 |
+
if type(encoder_hidden_states) == list:
|
| 784 |
+
encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states[0].size()
|
| 785 |
+
else:
|
| 786 |
+
encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
|
| 787 |
+
encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
|
| 788 |
+
|
| 789 |
+
if type(encoder_attention_mask) == list:
|
| 790 |
+
encoder_extended_attention_mask = [self.invert_attention_mask(mask) for mask in encoder_attention_mask]
|
| 791 |
+
elif encoder_attention_mask is None:
|
| 792 |
+
encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
|
| 793 |
+
encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
|
| 794 |
+
else:
|
| 795 |
+
encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
|
| 796 |
+
else:
|
| 797 |
+
encoder_extended_attention_mask = None
|
| 798 |
+
|
| 799 |
+
# Prepare head mask if needed
|
| 800 |
+
# 1.0 in head_mask indicate we keep the head
|
| 801 |
+
# attention_probs has shape bsz x n_heads x N x N
|
| 802 |
+
# input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
|
| 803 |
+
# and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
|
| 804 |
+
head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
|
| 805 |
+
|
| 806 |
+
if encoder_embeds is None:
|
| 807 |
+
embedding_output = self.embeddings(
|
| 808 |
+
input_ids=input_ids,
|
| 809 |
+
position_ids=position_ids,
|
| 810 |
+
inputs_embeds=inputs_embeds,
|
| 811 |
+
past_key_values_length=past_key_values_length,
|
| 812 |
+
)
|
| 813 |
+
else:
|
| 814 |
+
embedding_output = encoder_embeds
|
| 815 |
+
|
| 816 |
+
encoder_outputs = self.encoder(
|
| 817 |
+
embedding_output,
|
| 818 |
+
attention_mask=extended_attention_mask,
|
| 819 |
+
head_mask=head_mask,
|
| 820 |
+
encoder_hidden_states=encoder_hidden_states,
|
| 821 |
+
encoder_attention_mask=encoder_extended_attention_mask,
|
| 822 |
+
past_key_values=past_key_values,
|
| 823 |
+
use_cache=use_cache,
|
| 824 |
+
output_attentions=output_attentions,
|
| 825 |
+
output_hidden_states=output_hidden_states,
|
| 826 |
+
return_dict=return_dict,
|
| 827 |
+
mode=mode,
|
| 828 |
+
)
|
| 829 |
+
sequence_output = encoder_outputs[0]
|
| 830 |
+
pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
|
| 831 |
+
|
| 832 |
+
if not return_dict:
|
| 833 |
+
return (sequence_output, pooled_output) + encoder_outputs[1:]
|
| 834 |
+
|
| 835 |
+
return BaseModelOutputWithPoolingAndCrossAttentions(
|
| 836 |
+
last_hidden_state=sequence_output,
|
| 837 |
+
pooler_output=pooled_output,
|
| 838 |
+
past_key_values=encoder_outputs.past_key_values,
|
| 839 |
+
hidden_states=encoder_outputs.hidden_states,
|
| 840 |
+
attentions=encoder_outputs.attentions,
|
| 841 |
+
cross_attentions=encoder_outputs.cross_attentions,
|
| 842 |
+
)
|
| 843 |
+
|
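For reference, a minimal standalone sketch (not part of the repository; shapes are illustrative) of what the encoder branch of get_extended_attention_mask above computes: the padding mask is broadcast to [batch, 1, 1, seq] and turned into an additive bias that is 0.0 for visible tokens and -10000.0 for masked ones.

import torch

# toy padding mask: batch of 2, the second sequence has two padded positions
attention_mask = torch.tensor([[1., 1., 1., 1.],
                               [1., 1., 0., 0.]])

extended = attention_mask[:, None, None, :]      # broadcastable to all heads
extended = (1.0 - extended) * -10000.0           # additive bias applied to raw attention scores
print(extended.shape)        # torch.Size([2, 1, 1, 4])
print(extended[1, 0, 0])     # tensor([-0., -0., -10000., -10000.])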
defake/blipmodels/vit.py
ADDED
|
@@ -0,0 +1,305 @@
'''
 * Copyright (c) 2022, salesforce.com, inc.
 * All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 * For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause
 * By Junnan Li
 * Based on timm code base
 * https://github.com/rwightman/pytorch-image-models/tree/master/timm
'''

import torch
import torch.nn as nn
import torch.nn.functional as F
from functools import partial

from timm.models.vision_transformer import _cfg, PatchEmbed
from timm.models.registry import register_model
from timm.models.layers import trunc_normal_, DropPath
from timm.models.helpers import named_apply, adapt_input_conv

from fairscale.nn.checkpoint.checkpoint_activations import checkpoint_wrapper

class Mlp(nn.Module):
    """ MLP as used in Vision Transformer, MLP-Mixer and related networks
    """
    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x


class Attention(nn.Module):
    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights
        self.scale = qk_scale or head_dim ** -0.5
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)
        self.attn_gradients = None
        self.attention_map = None

    def save_attn_gradients(self, attn_gradients):
        self.attn_gradients = attn_gradients

    def get_attn_gradients(self):
        return self.attn_gradients

    def save_attention_map(self, attention_map):
        self.attention_map = attention_map

    def get_attention_map(self):
        return self.attention_map

    def forward(self, x, register_hook=False):
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]  # make torchscript happy (cannot use tensor as tuple)

        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        if register_hook:
            self.save_attention_map(attn)
            attn.register_hook(self.save_attn_gradients)

        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x


class Block(nn.Module):

    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
                 drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, use_grad_checkpointing=False):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = Attention(
            dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
        # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)

        if use_grad_checkpointing:
            self.attn = checkpoint_wrapper(self.attn)
            self.mlp = checkpoint_wrapper(self.mlp)

    def forward(self, x, register_hook=False):
        x = x + self.drop_path(self.attn(self.norm1(x), register_hook=register_hook))
        x = x + self.drop_path(self.mlp(self.norm2(x)))
        return x


class VisionTransformer(nn.Module):
    """ Vision Transformer
    A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale` -
        https://arxiv.org/abs/2010.11929
    """
    def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12,
                 num_heads=12, mlp_ratio=4., qkv_bias=True, qk_scale=None, representation_size=None,
                 drop_rate=0., attn_drop_rate=0., drop_path_rate=0., norm_layer=None,
                 use_grad_checkpointing=False, ckpt_layer=0):
        """
        Args:
            img_size (int, tuple): input image size
            patch_size (int, tuple): patch size
            in_chans (int): number of input channels
            num_classes (int): number of classes for classification head
            embed_dim (int): embedding dimension
            depth (int): depth of transformer
            num_heads (int): number of attention heads
            mlp_ratio (int): ratio of mlp hidden dim to embedding dim
            qkv_bias (bool): enable bias for qkv if True
            qk_scale (float): override default qk scale of head_dim ** -0.5 if set
            representation_size (Optional[int]): enable and set representation layer (pre-logits) to this value if set
            drop_rate (float): dropout rate
            attn_drop_rate (float): attention dropout rate
            drop_path_rate (float): stochastic depth rate
            norm_layer: (nn.Module): normalization layer
        """
        super().__init__()
        self.num_features = self.embed_dim = embed_dim  # num_features for consistency with other models
        norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)

        self.patch_embed = PatchEmbed(
            img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)

        num_patches = self.patch_embed.num_patches

        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
        self.pos_drop = nn.Dropout(p=drop_rate)

        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]  # stochastic depth decay rule
        self.blocks = nn.ModuleList([
            Block(
                dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
                drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer,
                use_grad_checkpointing=(use_grad_checkpointing and i>=depth-ckpt_layer)
            )
            for i in range(depth)])
        self.norm = norm_layer(embed_dim)

        trunc_normal_(self.pos_embed, std=.02)
        trunc_normal_(self.cls_token, std=.02)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    @torch.jit.ignore
    def no_weight_decay(self):
        return {'pos_embed', 'cls_token'}

    def forward(self, x, register_blk=-1):
        B = x.shape[0]
        x = self.patch_embed(x)

        cls_tokens = self.cls_token.expand(B, -1, -1)  # stole cls_tokens impl from Phil Wang, thanks
        x = torch.cat((cls_tokens, x), dim=1)

        x = x + self.pos_embed[:,:x.size(1),:]
        x = self.pos_drop(x)

        for i,blk in enumerate(self.blocks):
            x = blk(x, register_blk==i)
        x = self.norm(x)

        return x

    @torch.jit.ignore()
    def load_pretrained(self, checkpoint_path, prefix=''):
        _load_weights(self, checkpoint_path, prefix)


@torch.no_grad()
def _load_weights(model: VisionTransformer, checkpoint_path: str, prefix: str = ''):
    """ Load weights from .npz checkpoints for official Google Brain Flax implementation
    """
    import numpy as np

    def _n2p(w, t=True):
        if w.ndim == 4 and w.shape[0] == w.shape[1] == w.shape[2] == 1:
            w = w.flatten()
        if t:
            if w.ndim == 4:
                w = w.transpose([3, 2, 0, 1])
            elif w.ndim == 3:
                w = w.transpose([2, 0, 1])
            elif w.ndim == 2:
                w = w.transpose([1, 0])
        return torch.from_numpy(w)

    w = np.load(checkpoint_path)
    if not prefix and 'opt/target/embedding/kernel' in w:
        prefix = 'opt/target/'

    if hasattr(model.patch_embed, 'backbone'):
        # hybrid
        backbone = model.patch_embed.backbone
        stem_only = not hasattr(backbone, 'stem')
        stem = backbone if stem_only else backbone.stem
        stem.conv.weight.copy_(adapt_input_conv(stem.conv.weight.shape[1], _n2p(w[f'{prefix}conv_root/kernel'])))
        stem.norm.weight.copy_(_n2p(w[f'{prefix}gn_root/scale']))
        stem.norm.bias.copy_(_n2p(w[f'{prefix}gn_root/bias']))
        if not stem_only:
            for i, stage in enumerate(backbone.stages):
                for j, block in enumerate(stage.blocks):
                    bp = f'{prefix}block{i + 1}/unit{j + 1}/'
                    for r in range(3):
                        getattr(block, f'conv{r + 1}').weight.copy_(_n2p(w[f'{bp}conv{r + 1}/kernel']))
                        getattr(block, f'norm{r + 1}').weight.copy_(_n2p(w[f'{bp}gn{r + 1}/scale']))
                        getattr(block, f'norm{r + 1}').bias.copy_(_n2p(w[f'{bp}gn{r + 1}/bias']))
                    if block.downsample is not None:
                        block.downsample.conv.weight.copy_(_n2p(w[f'{bp}conv_proj/kernel']))
                        block.downsample.norm.weight.copy_(_n2p(w[f'{bp}gn_proj/scale']))
                        block.downsample.norm.bias.copy_(_n2p(w[f'{bp}gn_proj/bias']))
        embed_conv_w = _n2p(w[f'{prefix}embedding/kernel'])
    else:
        embed_conv_w = adapt_input_conv(
            model.patch_embed.proj.weight.shape[1], _n2p(w[f'{prefix}embedding/kernel']))
    model.patch_embed.proj.weight.copy_(embed_conv_w)
    model.patch_embed.proj.bias.copy_(_n2p(w[f'{prefix}embedding/bias']))
    model.cls_token.copy_(_n2p(w[f'{prefix}cls'], t=False))
    pos_embed_w = _n2p(w[f'{prefix}Transformer/posembed_input/pos_embedding'], t=False)
    if pos_embed_w.shape != model.pos_embed.shape:
        pos_embed_w = resize_pos_embed(  # resize pos embedding when different size from pretrained weights
            pos_embed_w, model.pos_embed, getattr(model, 'num_tokens', 1), model.patch_embed.grid_size)
    model.pos_embed.copy_(pos_embed_w)
    model.norm.weight.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/scale']))
    model.norm.bias.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/bias']))
#     if isinstance(model.head, nn.Linear) and model.head.bias.shape[0] == w[f'{prefix}head/bias'].shape[-1]:
#         model.head.weight.copy_(_n2p(w[f'{prefix}head/kernel']))
#         model.head.bias.copy_(_n2p(w[f'{prefix}head/bias']))
#     if isinstance(getattr(model.pre_logits, 'fc', None), nn.Linear) and f'{prefix}pre_logits/bias' in w:
#         model.pre_logits.fc.weight.copy_(_n2p(w[f'{prefix}pre_logits/kernel']))
#         model.pre_logits.fc.bias.copy_(_n2p(w[f'{prefix}pre_logits/bias']))
    for i, block in enumerate(model.blocks.children()):
        block_prefix = f'{prefix}Transformer/encoderblock_{i}/'
        mha_prefix = block_prefix + 'MultiHeadDotProductAttention_1/'
        block.norm1.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/scale']))
        block.norm1.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/bias']))
        block.attn.qkv.weight.copy_(torch.cat([
            _n2p(w[f'{mha_prefix}{n}/kernel'], t=False).flatten(1).T for n in ('query', 'key', 'value')]))
        block.attn.qkv.bias.copy_(torch.cat([
            _n2p(w[f'{mha_prefix}{n}/bias'], t=False).reshape(-1) for n in ('query', 'key', 'value')]))
        block.attn.proj.weight.copy_(_n2p(w[f'{mha_prefix}out/kernel']).flatten(1))
        block.attn.proj.bias.copy_(_n2p(w[f'{mha_prefix}out/bias']))
        for r in range(2):
            getattr(block.mlp, f'fc{r + 1}').weight.copy_(_n2p(w[f'{block_prefix}MlpBlock_3/Dense_{r}/kernel']))
            getattr(block.mlp, f'fc{r + 1}').bias.copy_(_n2p(w[f'{block_prefix}MlpBlock_3/Dense_{r}/bias']))
        block.norm2.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_2/scale']))
        block.norm2.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_2/bias']))


def interpolate_pos_embed(pos_embed_checkpoint, visual_encoder):
    # interpolate position embedding
    embedding_size = pos_embed_checkpoint.shape[-1]
    num_patches = visual_encoder.patch_embed.num_patches
    num_extra_tokens = visual_encoder.pos_embed.shape[-2] - num_patches
    # height (== width) for the checkpoint position embedding
    orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
    # height (== width) for the new position embedding
    new_size = int(num_patches ** 0.5)

    if orig_size!=new_size:
        # class_token and dist_token are kept unchanged
        extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
        # only the position tokens are interpolated
        pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
        pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
        pos_tokens = torch.nn.functional.interpolate(
            pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False)
        pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
        new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
        print('reshape position embedding from %d to %d'%(orig_size ** 2,new_size ** 2))

        return new_pos_embed
    else:
        return pos_embed_checkpoint
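For reference, a short sketch (assuming timm and fairscale are installed as pinned in environment.yaml; values are illustrative) that builds the ViT-Base configuration BLIP uses and runs a dummy batch through it:

import torch

vit = VisionTransformer(img_size=224, patch_size=16, embed_dim=768, depth=12, num_heads=12)
dummy = torch.randn(2, 3, 224, 224)    # two RGB images
tokens = vit(dummy)                    # [CLS] token followed by 14x14 patch tokens
print(tokens.shape)                    # torch.Size([2, 197, 768])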
defake/clipdatasets.py
ADDED
|
@@ -0,0 +1,110 @@
import json
import torchvision.transforms as transforms
import torchvision
import PIL.Image as Image
import os
import torch
import torch.nn as nn
from pathlib import Path
import numpy as np
from pycocotools.coco import COCO
import skimage.io as io
import pandas as pd

class real(torch.utils.data.Dataset):
    def __init__(self,realroot,size,transform=None):
        self.transform = transforms.Compose([
            transforms.Resize((size,size)),
            #RandAugment(2, 14),
            #transforms.CenterCrop((size,size)),
            transforms.ToTensor()
            ])
        dataDir='your dir'
        dataType='val2014'
        self.annFile = '{}/annotations/captions_{}.json'.format(dataDir,dataType)
        self.coco=COCO(self.annFile)
        self.imgIds_list=sorted(self.coco.getImgIds())

    def __getitem__(self,item):
        imgIds = self.coco.getImgIds(imgIds = [self.imgIds_list[item]])
        img = self.coco.loadImgs(imgIds[np.random.randint(0,len(imgIds))])[0]
        I = io.imread(img['coco_url'])
        real_image = Image.fromarray(I).convert('RGB')
        real_image = self.transform(real_image)
        annIds = self.coco.getAnnIds(imgIds=img['id'])
        anns = self.coco.loadAnns(annIds)

        label = 0
        return real_image,label,anns[0]['caption']

    def __len__(self):
        return len(self.imgIds_list)

class realflickr(torch.utils.data.Dataset):
    def __init__(self,realroot,size,transform=None):
        self.transform = transforms.Compose([
            transforms.Resize((size,size)),
            #RandAugment(2, 14),
            #transforms.CenterCrop((size,size)),
            transforms.ToTensor()
            ])
        annotations = pd.read_table('your dir', sep='\t', header=None,
                                    names=['image', 'caption'])
        self.prompt_list = np.array(annotations['caption'][::5])
        self.image_list = np.array(annotations['image'][::5])

    def __getitem__(self,item):
        real_image = Image.open('your dir')
        prompts = self.prompt_list[item]
        label = 0
        real_image = self.transform(real_image)
        return real_image,label,prompts

    def __len__(self):
        return len(self.image_list)


class fakereal(torch.utils.data.Dataset):
    def __init__(self,fakeroot,size,transform=None):
        self.transform = transforms.Compose([
            transforms.Resize((size,size)),
            #RandAugment(2, 14),
            #transforms.CenterCrop((size,size)),
            transforms.ToTensor()
            ])

        fake_images_path = Path(fakeroot)
        fake_images_list = list(fake_images_path.glob('*.png'))
        fake_images_list_str = [ str(x) for x in fake_images_list ]
        self.fake_images = fake_images_list_str

    def __getitem__(self,item):
        fake_image_path = self.fake_images[item]
        fake_image = Image.open(fake_image_path).convert('RGB')
        fake_image = self.transform(fake_image)
        label = 1
        prompts = fake_image_path.split('/')[-1].replace('-',' ').split('.png')[0]

        return fake_image,label,prompts

    def __len__(self):
        return len(self.fake_images)


class fakeclip(torch.utils.data.Dataset):
    # parameter renamed from `transforms` to `transform` so it does not shadow the torchvision.transforms module used below
    def __init__(self,fakeroot,size,transform=None):
        self.transform = transforms.Compose([
            transforms.Resize((size,size)),
            #RandAugment(2, 14),
            #transforms.CenterCrop((size,size)),
            transforms.ToTensor()
            ])

        fake_images_path = Path(fakeroot)
        fake_images_list = list(fake_images_path.glob('*.png'))
        fake_images_list_str = [ str(x) for x in fake_images_list ]
        self.fake_images = fake_images_list_str
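A minimal usage sketch for the fakereal dataset above (the directory is a hypothetical placeholder): it globs *.png files under the given root, labels each image 1 (fake), and recovers the prompt from the file name.

from torch.utils.data import DataLoader

dataset = fakereal(fakeroot='path/to/fake_pngs', size=224)   # hypothetical directory
loader = DataLoader(dataset, batch_size=8, shuffle=True, num_workers=2)
images, labels, prompts = next(iter(loader))
print(images.shape, labels[:4], prompts[0])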
defake/environment.yaml
ADDED
|
@@ -0,0 +1,23 @@
name: defake
channels:
  - pytorch
  - defaults
dependencies:
  - python=3.8.5
  - pip=20.3
  - cudatoolkit=11.3
  - pytorch=1.12.1
  - torchvision=0.13.1
  - numpy=1.23.1
  - pip:
    - tqdm
    - Pillow
    - scikit-learn
    - ftfy
    - regex
    - git+https://github.com/openai/CLIP.git
    - fairscale==0.4.4
    - pycocoevalcap
    - natsort
    - timm
    - transformers==4.15.0
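Assuming a standard conda installation, this environment can be reproduced with conda env create -f defake/environment.yaml followed by conda activate defake; the pip section pulls CLIP directly from the OpenAI GitHub repository and pins fairscale and transformers to the versions the BLIP code expects.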
defake/models/__init__.py
ADDED
|
File without changes
|
defake/models/__pycache__/__init__.cpython-38.pyc
ADDED
|
Binary file (170 Bytes). View file
|
|
|
defake/models/__pycache__/blip.cpython-38.pyc
ADDED
|
Binary file (7.04 kB). View file
|
|
|
defake/models/__pycache__/med.cpython-38.pyc
ADDED
|
Binary file (28.2 kB). View file
|
|
|
defake/models/__pycache__/vit.cpython-38.pyc
ADDED
|
Binary file (12.4 kB). View file
|
|
|
defake/models/blip.py
ADDED
|
@@ -0,0 +1,238 @@
'''
 * Copyright (c) 2022, salesforce.com, inc.
 * All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 * For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause
 * By Junnan Li
'''
import warnings
warnings.filterwarnings("ignore")

from models.vit import VisionTransformer, interpolate_pos_embed
from models.med import BertConfig, BertModel, BertLMHeadModel
from transformers import BertTokenizer

import torch
from torch import nn
import torch.nn.functional as F

import os
from urllib.parse import urlparse
from timm.models.hub import download_cached_file

class BLIP_Base(nn.Module):
    def __init__(self,
                 med_config = 'configs/med_config.json',
                 image_size = 224,
                 vit = 'base',
                 vit_grad_ckpt = False,
                 vit_ckpt_layer = 0,
                 ):
        """
        Args:
            med_config (str): path for the mixture of encoder-decoder model's configuration file
            image_size (int): input image size
            vit (str): model size of vision transformer
        """
        super().__init__()

        self.visual_encoder, vision_width = create_vit(vit,image_size, vit_grad_ckpt, vit_ckpt_layer)
        self.tokenizer = init_tokenizer()
        med_config = BertConfig.from_json_file(med_config)
        med_config.encoder_width = vision_width
        self.text_encoder = BertModel(config=med_config, add_pooling_layer=False)


    def forward(self, image, caption, mode):

        assert mode in ['image', 'text', 'multimodal'], "mode parameter must be image, text, or multimodal"
        text = self.tokenizer(caption, return_tensors="pt").to(image.device)

        if mode=='image':
            # return image features
            image_embeds = self.visual_encoder(image)
            return image_embeds

        elif mode=='text':
            # return text features
            text_output = self.text_encoder(text.input_ids, attention_mask = text.attention_mask,
                                            return_dict = True, mode = 'text')
            return text_output.last_hidden_state

        elif mode=='multimodal':
            # return multimodal features
            image_embeds = self.visual_encoder(image)
            image_atts = torch.ones(image_embeds.size()[:-1],dtype=torch.long).to(image.device)

            text.input_ids[:,0] = self.tokenizer.enc_token_id
            output = self.text_encoder(text.input_ids,
                                       attention_mask = text.attention_mask,
                                       encoder_hidden_states = image_embeds,
                                       encoder_attention_mask = image_atts,
                                       return_dict = True,
                                      )
            return output.last_hidden_state



class BLIP_Decoder(nn.Module):
    def __init__(self,
                 med_config = '/home/c01zesh/CISPA-projects/fake_artist-2022/BLIP/configs/med_config.json',
                 image_size = 384,
                 vit = 'base',
                 vit_grad_ckpt = False,
                 vit_ckpt_layer = 0,
                 prompt = 'a picture of ',
                 ):
        """
        Args:
            med_config (str): path for the mixture of encoder-decoder model's configuration file
            image_size (int): input image size
            vit (str): model size of vision transformer
        """
        super().__init__()

        self.visual_encoder, vision_width = create_vit(vit,image_size, vit_grad_ckpt, vit_ckpt_layer)
        self.tokenizer = init_tokenizer()
        med_config = BertConfig.from_json_file(med_config)
        med_config.encoder_width = vision_width
        self.text_decoder = BertLMHeadModel(config=med_config)

        self.prompt = prompt
        self.prompt_length = len(self.tokenizer(self.prompt).input_ids)-1


    def forward(self, image, caption):

        image_embeds = self.visual_encoder(image)
        image_atts = torch.ones(image_embeds.size()[:-1],dtype=torch.long).to(image.device)

        text = self.tokenizer(caption, padding='longest', truncation=True, max_length=40, return_tensors="pt").to(image.device)

        text.input_ids[:,0] = self.tokenizer.bos_token_id

        decoder_targets = text.input_ids.masked_fill(text.input_ids == self.tokenizer.pad_token_id, -100)
        decoder_targets[:,:self.prompt_length] = -100

        decoder_output = self.text_decoder(text.input_ids,
                                           attention_mask = text.attention_mask,
                                           encoder_hidden_states = image_embeds,
                                           encoder_attention_mask = image_atts,
                                           labels = decoder_targets,
                                           return_dict = True,
                                          )
        loss_lm = decoder_output.loss

        return loss_lm

    def generate(self, image, sample=False, num_beams=3, max_length=30, min_length=10, top_p=0.9, repetition_penalty=1.0):
        image_embeds = self.visual_encoder(image)

        if not sample:
            image_embeds = image_embeds.repeat_interleave(num_beams,dim=0)

        image_atts = torch.ones(image_embeds.size()[:-1],dtype=torch.long).to(image.device)
        model_kwargs = {"encoder_hidden_states": image_embeds, "encoder_attention_mask":image_atts}

        prompt = [self.prompt] * image.size(0)
        input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids.to(image.device)
        input_ids[:,0] = self.tokenizer.bos_token_id
        input_ids = input_ids[:, :-1]

        if sample:
            #nucleus sampling
            outputs = self.text_decoder.generate(input_ids=input_ids,
                                                 max_length=max_length,
                                                 min_length=min_length,
                                                 do_sample=True,
                                                 top_p=top_p,
                                                 num_return_sequences=1,
                                                 eos_token_id=self.tokenizer.sep_token_id,
                                                 pad_token_id=self.tokenizer.pad_token_id,
                                                 repetition_penalty=1.1,
                                                 **model_kwargs)
        else:
            #beam search
            outputs = self.text_decoder.generate(input_ids=input_ids,
                                                 max_length=max_length,
                                                 min_length=min_length,
                                                 num_beams=num_beams,
                                                 eos_token_id=self.tokenizer.sep_token_id,
                                                 pad_token_id=self.tokenizer.pad_token_id,
                                                 repetition_penalty=repetition_penalty,
                                                 **model_kwargs)

        captions = []
        for output in outputs:
            caption = self.tokenizer.decode(output, skip_special_tokens=True)
            captions.append(caption[len(self.prompt):])
        return captions


def blip_decoder(pretrained='',**kwargs):
    model = BLIP_Decoder(**kwargs)
    if pretrained:
        model,msg = load_checkpoint(model,pretrained)
        assert(len(msg.missing_keys)==0)
    return model

def blip_feature_extractor(pretrained='',**kwargs):
    model = BLIP_Base(**kwargs)
    if pretrained:
        model,msg = load_checkpoint(model,pretrained)
        assert(len(msg.missing_keys)==0)
    return model

def init_tokenizer():
    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
    tokenizer.add_special_tokens({'bos_token':'[DEC]'})
    tokenizer.add_special_tokens({'additional_special_tokens':['[ENC]']})
    tokenizer.enc_token_id = tokenizer.additional_special_tokens_ids[0]
    return tokenizer


def create_vit(vit, image_size, use_grad_checkpointing=False, ckpt_layer=0, drop_path_rate=0):

    assert vit in ['base', 'large'], "vit parameter must be base or large"
    if vit=='base':
        vision_width = 768
        visual_encoder = VisionTransformer(img_size=image_size, patch_size=16, embed_dim=vision_width, depth=12,
                                           num_heads=12, use_grad_checkpointing=use_grad_checkpointing, ckpt_layer=ckpt_layer,
                                           drop_path_rate=0 or drop_path_rate
                                          )
    elif vit=='large':
        vision_width = 1024
        visual_encoder = VisionTransformer(img_size=image_size, patch_size=16, embed_dim=vision_width, depth=24,
                                           num_heads=16, use_grad_checkpointing=use_grad_checkpointing, ckpt_layer=ckpt_layer,
                                           drop_path_rate=0.1 or drop_path_rate
                                          )
    return visual_encoder, vision_width

def is_url(url_or_filename):
    parsed = urlparse(url_or_filename)
    return parsed.scheme in ("http", "https")

def load_checkpoint(model,url_or_filename):
    if is_url(url_or_filename):
        cached_file = download_cached_file(url_or_filename, check_hash=False, progress=True)
        checkpoint = torch.load(cached_file, map_location='cpu')
    elif os.path.isfile(url_or_filename):
        checkpoint = torch.load(url_or_filename, map_location='cpu')
    else:
        raise RuntimeError('checkpoint url or path is invalid')

    state_dict = checkpoint['model']

    state_dict['visual_encoder.pos_embed'] = interpolate_pos_embed(state_dict['visual_encoder.pos_embed'],model.visual_encoder)
    if 'visual_encoder_m.pos_embed' in model.state_dict().keys():
        state_dict['visual_encoder_m.pos_embed'] = interpolate_pos_embed(state_dict['visual_encoder_m.pos_embed'],
                                                                         model.visual_encoder_m)
    for key in model.state_dict().keys():
        if key in state_dict.keys():
            if state_dict[key].shape!=model.state_dict()[key].shape:
                del state_dict[key]

    msg = model.load_state_dict(state_dict,strict=False)
    print('load checkpoint from %s'%url_or_filename)
    return model,msg
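A hedged usage sketch of the caption decoder defined above (the checkpoint path, config path, and image file are placeholders, not paths shipped with this commit):

import torch
from PIL import Image
from torchvision import transforms

# build the captioner; med_config below is an assumed local config location
model = blip_decoder(pretrained='path/to/model_base_caption.pth',
                     image_size=384, vit='base',
                     med_config='configs/med_config.json')
model.eval()

preprocess = transforms.Compose([transforms.Resize((384, 384)), transforms.ToTensor()])
image = preprocess(Image.open('example.jpg').convert('RGB')).unsqueeze(0)  # hypothetical image

with torch.no_grad():
    captions = model.generate(image, sample=False, num_beams=3, max_length=30, min_length=10)
print(captions[0])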
defake/models/blip_itm.py
ADDED
|
@@ -0,0 +1,76 @@
from models.med import BertConfig, BertModel
from transformers import BertTokenizer

import torch
from torch import nn
import torch.nn.functional as F

from models.blip import create_vit, init_tokenizer, load_checkpoint

class BLIP_ITM(nn.Module):
    def __init__(self,
                 med_config = 'configs/med_config.json',
                 image_size = 384,
                 vit = 'base',
                 vit_grad_ckpt = False,
                 vit_ckpt_layer = 0,
                 embed_dim = 256,
                 ):
        """
        Args:
            med_config (str): path for the mixture of encoder-decoder model's configuration file
            image_size (int): input image size
            vit (str): model size of vision transformer
        """
        super().__init__()

        self.visual_encoder, vision_width = create_vit(vit,image_size, vit_grad_ckpt, vit_ckpt_layer)
        self.tokenizer = init_tokenizer()
        med_config = BertConfig.from_json_file(med_config)
        med_config.encoder_width = vision_width
        self.text_encoder = BertModel(config=med_config, add_pooling_layer=False)

        text_width = self.text_encoder.config.hidden_size

        self.vision_proj = nn.Linear(vision_width, embed_dim)
        self.text_proj = nn.Linear(text_width, embed_dim)

        self.itm_head = nn.Linear(text_width, 2)


    def forward(self, image, caption, match_head='itm'):

        image_embeds = self.visual_encoder(image)
        image_atts = torch.ones(image_embeds.size()[:-1],dtype=torch.long).to(image.device)

        text = self.tokenizer(caption, padding='max_length', truncation=True, max_length=35,
                              return_tensors="pt").to(image.device)


        if match_head=='itm':
            output = self.text_encoder(text.input_ids,
                                       attention_mask = text.attention_mask,
                                       encoder_hidden_states = image_embeds,
                                       encoder_attention_mask = image_atts,
                                       return_dict = True,
                                      )
            itm_output = self.itm_head(output.last_hidden_state[:,0,:])
            return itm_output

        elif match_head=='itc':
            text_output = self.text_encoder(text.input_ids, attention_mask = text.attention_mask,
                                            return_dict = True, mode = 'text')
            image_feat = F.normalize(self.vision_proj(image_embeds[:,0,:]),dim=-1)
            text_feat = F.normalize(self.text_proj(text_output.last_hidden_state[:,0,:]),dim=-1)

            sim = image_feat @ text_feat.t()
            return sim


def blip_itm(pretrained='',**kwargs):
    model = BLIP_ITM(**kwargs)
    if pretrained:
        model,msg = load_checkpoint(model,pretrained)
        assert(len(msg.missing_keys)==0)
    return model
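For reference, a small sketch of how the two match heads above are typically consumed (the logits shown are made up): 'itm' returns two-way logits that are softmaxed into a match probability, while 'itc' returns a cosine-style similarity between the projected image and text features.

import torch
import torch.nn.functional as F

itm_output = torch.tensor([[0.3, 1.9]])            # stand-in for model(image, caption, match_head='itm')
itm_score = F.softmax(itm_output, dim=1)[:, 1]     # probability that the caption matches the image
print(float(itm_score))                            # ~0.83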
defake/models/blip_nlvr.py
ADDED
|
@@ -0,0 +1,103 @@
from models.med import BertConfig
from models.nlvr_encoder import BertModel
from models.vit import interpolate_pos_embed
from models.blip import create_vit, init_tokenizer, is_url

from timm.models.hub import download_cached_file

import os  # needed by load_checkpoint below (missing in the original file)
import torch
from torch import nn
import torch.nn.functional as F
from transformers import BertTokenizer
import numpy as np

class BLIP_NLVR(nn.Module):
    def __init__(self,
                 med_config = 'configs/med_config.json',
                 image_size = 480,
                 vit = 'base',
                 vit_grad_ckpt = False,
                 vit_ckpt_layer = 0,
                 ):
        """
        Args:
            med_config (str): path for the mixture of encoder-decoder model's configuration file
            image_size (int): input image size
            vit (str): model size of vision transformer
        """
        super().__init__()

        self.visual_encoder, vision_width = create_vit(vit,image_size, vit_grad_ckpt, vit_ckpt_layer, drop_path_rate=0.1)
        self.tokenizer = init_tokenizer()
        med_config = BertConfig.from_json_file(med_config)
        med_config.encoder_width = vision_width
        self.text_encoder = BertModel(config=med_config, add_pooling_layer=False)

        self.cls_head = nn.Sequential(
                  nn.Linear(self.text_encoder.config.hidden_size, self.text_encoder.config.hidden_size),
                  nn.ReLU(),
                  nn.Linear(self.text_encoder.config.hidden_size, 2)
                )

    def forward(self, image, text, targets, train=True):

        image_embeds = self.visual_encoder(image)
        image_atts = torch.ones(image_embeds.size()[:-1],dtype=torch.long).to(image.device)
        image0_embeds, image1_embeds = torch.split(image_embeds,targets.size(0))

        text = self.tokenizer(text, padding='longest', return_tensors="pt").to(image.device)
        text.input_ids[:,0] = self.tokenizer.enc_token_id

        output = self.text_encoder(text.input_ids,
                                   attention_mask = text.attention_mask,
                                   encoder_hidden_states = [image0_embeds,image1_embeds],
                                   encoder_attention_mask = [image_atts[:image0_embeds.size(0)],
                                                             image_atts[image0_embeds.size(0):]],
                                   return_dict = True,
                                  )
        hidden_state = output.last_hidden_state[:,0,:]
        prediction = self.cls_head(hidden_state)

        if train:
            loss = F.cross_entropy(prediction, targets)
            return loss
        else:
            return prediction

def blip_nlvr(pretrained='',**kwargs):
    model = BLIP_NLVR(**kwargs)
    if pretrained:
        model,msg = load_checkpoint(model,pretrained)
        print("missing keys:")
        print(msg.missing_keys)
    return model


def load_checkpoint(model,url_or_filename):
    if is_url(url_or_filename):
        cached_file = download_cached_file(url_or_filename, check_hash=False, progress=True)
        checkpoint = torch.load(cached_file, map_location='cpu')
    elif os.path.isfile(url_or_filename):
        checkpoint = torch.load(url_or_filename, map_location='cpu')
    else:
        raise RuntimeError('checkpoint url or path is invalid')
    state_dict = checkpoint['model']

    state_dict['visual_encoder.pos_embed'] = interpolate_pos_embed(state_dict['visual_encoder.pos_embed'],model.visual_encoder)

    for key in list(state_dict.keys()):
        if 'crossattention.self.' in key:
            new_key0 = key.replace('self','self0')
            new_key1 = key.replace('self','self1')
            state_dict[new_key0] = state_dict[key]
            state_dict[new_key1] = state_dict[key]
        elif 'crossattention.output.dense.' in key:
            new_key0 = key.replace('dense','dense0')
            new_key1 = key.replace('dense','dense1')
            state_dict[new_key0] = state_dict[key]
            state_dict[new_key1] = state_dict[key]

    msg = model.load_state_dict(state_dict,strict=False)
    print('load checkpoint from %s'%url_or_filename)
    return model,msg
defake/models/blip_pretrain.py
ADDED
|
@@ -0,0 +1,339 @@
'''
 * Copyright (c) 2022, salesforce.com, inc.
 * All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 * For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause
 * By Junnan Li
'''
from models.med import BertConfig, BertModel, BertLMHeadModel
from transformers import BertTokenizer
import transformers
transformers.logging.set_verbosity_error()

import torch
from torch import nn
import torch.nn.functional as F

from models.blip import create_vit, init_tokenizer, load_checkpoint

class BLIP_Pretrain(nn.Module):
    def __init__(self,
                 med_config = 'configs/bert_config.json',
                 image_size = 224,
                 vit = 'base',
                 vit_grad_ckpt = False,
                 vit_ckpt_layer = 0,
                 embed_dim = 256,
                 queue_size = 57600,
                 momentum = 0.995,
                 ):
        """
        Args:
            med_config (str): path for the mixture of encoder-decoder model's configuration file
            image_size (int): input image size
            vit (str): model size of vision transformer
        """
        super().__init__()

        self.visual_encoder, vision_width = create_vit(vit, image_size, vit_grad_ckpt, vit_ckpt_layer, 0)

        if vit=='base':
            checkpoint = torch.hub.load_state_dict_from_url(
                url="https://dl.fbaipublicfiles.com/deit/deit_base_patch16_224-b5f2ef4d.pth",
                map_location="cpu", check_hash=True)
            state_dict = checkpoint["model"]
            msg = self.visual_encoder.load_state_dict(state_dict, strict=False)
        elif vit=='large':
            from timm.models.helpers import load_custom_pretrained
            from timm.models.vision_transformer import default_cfgs
            load_custom_pretrained(self.visual_encoder, default_cfgs['vit_large_patch16_224_in21k'])

        self.tokenizer = init_tokenizer()
        encoder_config = BertConfig.from_json_file(med_config)
        encoder_config.encoder_width = vision_width
        self.text_encoder = BertModel.from_pretrained('bert-base-uncased', config=encoder_config, add_pooling_layer=False)
        self.text_encoder.resize_token_embeddings(len(self.tokenizer))

        text_width = self.text_encoder.config.hidden_size

        self.vision_proj = nn.Linear(vision_width, embed_dim)
        self.text_proj = nn.Linear(text_width, embed_dim)

        self.itm_head = nn.Linear(text_width, 2)

        # create momentum encoders
        self.visual_encoder_m, vision_width = create_vit(vit, image_size)
        self.vision_proj_m = nn.Linear(vision_width, embed_dim)
        self.text_encoder_m = BertModel(config=encoder_config, add_pooling_layer=False)
        self.text_proj_m = nn.Linear(text_width, embed_dim)

        self.model_pairs = [[self.visual_encoder, self.visual_encoder_m],
                            [self.vision_proj, self.vision_proj_m],
                            [self.text_encoder, self.text_encoder_m],
                            [self.text_proj, self.text_proj_m],
                           ]
        self.copy_params()

        # create the queue
        self.register_buffer("image_queue", torch.randn(embed_dim, queue_size))
        self.register_buffer("text_queue", torch.randn(embed_dim, queue_size))
        self.register_buffer("queue_ptr", torch.zeros(1, dtype=torch.long))

        self.image_queue = nn.functional.normalize(self.image_queue, dim=0)
        self.text_queue = nn.functional.normalize(self.text_queue, dim=0)

        self.queue_size = queue_size
        self.momentum = momentum
        self.temp = nn.Parameter(0.07*torch.ones([]))

        # create the decoder
        decoder_config = BertConfig.from_json_file(med_config)
        decoder_config.encoder_width = vision_width
        self.text_decoder = BertLMHeadModel.from_pretrained('bert-base-uncased', config=decoder_config)
        self.text_decoder.resize_token_embeddings(len(self.tokenizer))
        tie_encoder_decoder_weights(self.text_encoder, self.text_decoder.bert, '', '/attention')


    def forward(self, image, caption, alpha):
        with torch.no_grad():
            self.temp.clamp_(0.001, 0.5)

        image_embeds = self.visual_encoder(image)
        image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(image.device)
        image_feat = F.normalize(self.vision_proj(image_embeds[:,0,:]), dim=-1)

        text = self.tokenizer(caption, padding='max_length', truncation=True, max_length=30,
                              return_tensors="pt").to(image.device)
        text_output = self.text_encoder(text.input_ids, attention_mask = text.attention_mask,
                                        return_dict = True, mode = 'text')
        text_feat = F.normalize(self.text_proj(text_output.last_hidden_state[:,0,:]), dim=-1)

        # get momentum features
        with torch.no_grad():
            self._momentum_update()
            image_embeds_m = self.visual_encoder_m(image)
            image_feat_m = F.normalize(self.vision_proj_m(image_embeds_m[:,0,:]), dim=-1)
            image_feat_all = torch.cat([image_feat_m.t(), self.image_queue.clone().detach()], dim=1)

            text_output_m = self.text_encoder_m(text.input_ids, attention_mask = text.attention_mask,
                                                return_dict = True, mode = 'text')
            text_feat_m = F.normalize(self.text_proj_m(text_output_m.last_hidden_state[:,0,:]), dim=-1)
            text_feat_all = torch.cat([text_feat_m.t(), self.text_queue.clone().detach()], dim=1)

            sim_i2t_m = image_feat_m @ text_feat_all / self.temp
            sim_t2i_m = text_feat_m @ image_feat_all / self.temp

            sim_targets = torch.zeros(sim_i2t_m.size()).to(image.device)
            sim_targets.fill_diagonal_(1)

            sim_i2t_targets = alpha * F.softmax(sim_i2t_m, dim=1) + (1 - alpha) * sim_targets
            sim_t2i_targets = alpha * F.softmax(sim_t2i_m, dim=1) + (1 - alpha) * sim_targets

        sim_i2t = image_feat @ text_feat_all / self.temp
        sim_t2i = text_feat @ image_feat_all / self.temp

        loss_i2t = -torch.sum(F.log_softmax(sim_i2t, dim=1)*sim_i2t_targets, dim=1).mean()
        loss_t2i = -torch.sum(F.log_softmax(sim_t2i, dim=1)*sim_t2i_targets, dim=1).mean()

        loss_ita = (loss_i2t+loss_t2i)/2

        self._dequeue_and_enqueue(image_feat_m, text_feat_m)

        ###============== Image-text Matching ===================###
        encoder_input_ids = text.input_ids.clone()
        encoder_input_ids[:,0] = self.tokenizer.enc_token_id

        # forward the positve image-text pair
        bs = image.size(0)
        output_pos = self.text_encoder(encoder_input_ids,
                                       attention_mask = text.attention_mask,
                                       encoder_hidden_states = image_embeds,
                                       encoder_attention_mask = image_atts,
                                       return_dict = True,
                                      )
        with torch.no_grad():
            weights_t2i = F.softmax(sim_t2i[:,:bs], dim=1)+1e-4
            weights_t2i.fill_diagonal_(0)
            weights_i2t = F.softmax(sim_i2t[:,:bs], dim=1)+1e-4
            weights_i2t.fill_diagonal_(0)

        # select a negative image for each text
        image_embeds_neg = []
        for b in range(bs):
            neg_idx = torch.multinomial(weights_t2i[b], 1).item()
            image_embeds_neg.append(image_embeds[neg_idx])
        image_embeds_neg = torch.stack(image_embeds_neg, dim=0)

        # select a negative text for each image
        text_ids_neg = []
        text_atts_neg = []
        for b in range(bs):
            neg_idx = torch.multinomial(weights_i2t[b], 1).item()
            text_ids_neg.append(encoder_input_ids[neg_idx])
            text_atts_neg.append(text.attention_mask[neg_idx])

        text_ids_neg = torch.stack(text_ids_neg, dim=0)
        text_atts_neg = torch.stack(text_atts_neg, dim=0)

        text_ids_all = torch.cat([encoder_input_ids, text_ids_neg], dim=0)
        text_atts_all = torch.cat([text.attention_mask, text_atts_neg], dim=0)

        image_embeds_all = torch.cat([image_embeds_neg, image_embeds], dim=0)
        image_atts_all = torch.cat([image_atts, image_atts], dim=0)

        output_neg = self.text_encoder(text_ids_all,
                                       attention_mask = text_atts_all,
                                       encoder_hidden_states = image_embeds_all,
                                       encoder_attention_mask = image_atts_all,
                                       return_dict = True,
                                      )

        vl_embeddings = torch.cat([output_pos.last_hidden_state[:,0,:], output_neg.last_hidden_state[:,0,:]], dim=0)
        vl_output = self.itm_head(vl_embeddings)

        itm_labels = torch.cat([torch.ones(bs, dtype=torch.long), torch.zeros(2*bs, dtype=torch.long)],
                               dim=0).to(image.device)
        loss_itm = F.cross_entropy(vl_output, itm_labels)

        ##================= LM ========================##
        decoder_input_ids = text.input_ids.clone()
        decoder_input_ids[:,0] = self.tokenizer.bos_token_id
        decoder_targets = decoder_input_ids.masked_fill(decoder_input_ids == self.tokenizer.pad_token_id, -100)

        decoder_output = self.text_decoder(decoder_input_ids,
                                           attention_mask = text.attention_mask,
                                           encoder_hidden_states = image_embeds,
                                           encoder_attention_mask = image_atts,
                                           labels = decoder_targets,
                                           return_dict = True,
                                          )

        loss_lm = decoder_output.loss
        return loss_ita, loss_itm, loss_lm


    @torch.no_grad()
    def copy_params(self):
        for model_pair in self.model_pairs:
            for param, param_m in zip(model_pair[0].parameters(), model_pair[1].parameters()):
                param_m.data.copy_(param.data)  # initialize
                param_m.requires_grad = False  # not update by gradient


    @torch.no_grad()
    def _momentum_update(self):
        for model_pair in self.model_pairs:
            for param, param_m in zip(model_pair[0].parameters(), model_pair[1].parameters()):
                param_m.data = param_m.data * self.momentum + param.data * (1. - self.momentum)


    @torch.no_grad()
    def _dequeue_and_enqueue(self, image_feat, text_feat):
        # gather keys before updating queue
        image_feats = concat_all_gather(image_feat)
        text_feats = concat_all_gather(text_feat)

        batch_size = image_feats.shape[0]

        ptr = int(self.queue_ptr)
        assert self.queue_size % batch_size == 0  # for simplicity

        # replace the keys at ptr (dequeue and enqueue)
        self.image_queue[:, ptr:ptr + batch_size] = image_feats.T
        self.text_queue[:, ptr:ptr + batch_size] = text_feats.T
        ptr = (ptr + batch_size) % self.queue_size  # move pointer

        self.queue_ptr[0] = ptr


def blip_pretrain(**kwargs):
    model = BLIP_Pretrain(**kwargs)
    return model


@torch.no_grad()
def concat_all_gather(tensor):
    """
    Performs all_gather operation on the provided tensors.
    *** Warning ***: torch.distributed.all_gather has no gradient.
    """
    tensors_gather = [torch.ones_like(tensor)
        for _ in range(torch.distributed.get_world_size())]
    torch.distributed.all_gather(tensors_gather, tensor, async_op=False)

    output = torch.cat(tensors_gather, dim=0)
    return output


from typing import List
def tie_encoder_decoder_weights(encoder: nn.Module, decoder: nn.Module, base_model_prefix: str, skip_key: str):
    uninitialized_encoder_weights: List[str] = []
    if decoder.__class__ != encoder.__class__:
        logger.info(
            f"{decoder.__class__} and {encoder.__class__} are not equal. In this case make sure that all encoder weights are correctly initialized."
        )

    def tie_encoder_to_decoder_recursively(
        decoder_pointer: nn.Module,
        encoder_pointer: nn.Module,
        module_name: str,
        uninitialized_encoder_weights: List[str],
        skip_key: str,
        depth=0,
    ):
        assert isinstance(decoder_pointer, nn.Module) and isinstance(
            encoder_pointer, nn.Module
        ), f"{decoder_pointer} and {encoder_pointer} have to be of type torch.nn.Module"
        if hasattr(decoder_pointer, "weight") and skip_key not in module_name:
            assert hasattr(encoder_pointer, "weight")
            encoder_pointer.weight = decoder_pointer.weight
            if hasattr(decoder_pointer, "bias"):
                assert hasattr(encoder_pointer, "bias")
                encoder_pointer.bias = decoder_pointer.bias
            print(module_name+' is tied')
            return

        encoder_modules = encoder_pointer._modules
        decoder_modules = decoder_pointer._modules
        if len(decoder_modules) > 0:
            assert (
                len(encoder_modules) > 0
            ), f"Encoder module {encoder_pointer} does not match decoder module {decoder_pointer}"

            all_encoder_weights = set([module_name + "/" + sub_name for sub_name in encoder_modules.keys()])
            encoder_layer_pos = 0
            for name, module in decoder_modules.items():
                if name.isdigit():
                    encoder_name = str(int(name) + encoder_layer_pos)
                    decoder_name = name
                    if not isinstance(decoder_modules[decoder_name], type(encoder_modules[encoder_name])) and len(
                        encoder_modules
                    ) != len(decoder_modules):
                        # this can happen if the name corresponds to the position in a list module list of layers
                        # in this case the decoder has added a cross-attention that the encoder does not have
                        # thus skip this step and subtract one layer pos from encoder
                        encoder_layer_pos -= 1
                        continue
                elif name not in encoder_modules:
                    continue
                elif depth > 500:
                    raise ValueError(
                        "Max depth of recursive function `tie_encoder_to_decoder` reached. It seems that there is a circular dependency between two or more `nn.Modules` of your model."
                    )
                else:
                    decoder_name = encoder_name = name
                tie_encoder_to_decoder_recursively(
                    decoder_modules[decoder_name],
                    encoder_modules[encoder_name],
                    module_name + "/" + name,
                    uninitialized_encoder_weights,
                    skip_key,
                    depth=depth + 1,
                )
                all_encoder_weights.remove(module_name + "/" + encoder_name)

            uninitialized_encoder_weights += list(all_encoder_weights)

    # tie weights recursively
    tie_encoder_to_decoder_recursively(decoder, encoder, base_model_prefix, uninitialized_encoder_weights, skip_key)
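Note: the momentum encoders and feature queues in BLIP_Pretrain above follow the MoCo-style recipe (EMA copy of the encoders, plus a fixed-size queue of past features used as extra negatives). A minimal single-process sketch of just that mechanism, for readers of this diff; class and variable names here are ours, and unlike the real _dequeue_and_enqueue it does not gather features across GPUs first:

import torch
import torch.nn.functional as F

class ToyMomentumQueue:
    """Illustrative stand-in for _momentum_update / _dequeue_and_enqueue."""
    def __init__(self, embed_dim=256, queue_size=1024, momentum=0.995):
        self.momentum = momentum
        self.queue_size = queue_size
        # queue stores one normalized feature per column, like image_queue/text_queue
        self.queue = F.normalize(torch.randn(embed_dim, queue_size), dim=0)
        self.ptr = 0

    @torch.no_grad()
    def momentum_update(self, params, params_m):
        # param_m <- m * param_m + (1 - m) * param, same update rule as _momentum_update()
        for p, p_m in zip(params, params_m):
            p_m.data.mul_(self.momentum).add_(p.data, alpha=1.0 - self.momentum)

    @torch.no_grad()
    def dequeue_and_enqueue(self, feats):
        # feats: [batch_size, embed_dim]; overwrite the oldest columns and advance the pointer
        bs = feats.shape[0]
        assert self.queue_size % bs == 0  # mirrors the "for simplicity" assert above
        self.queue[:, self.ptr:self.ptr + bs] = feats.T
        self.ptr = (self.ptr + bs) % self.queue_size

The queue is what lets the contrastive losses above compare each batch against tens of thousands of negatives (queue_size = 57600) without a correspondingly large batch.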
defake/models/blip_retrieval.py
ADDED
@@ -0,0 +1,319 @@
from models.med import BertConfig, BertModel
from transformers import BertTokenizer

import torch
from torch import nn
import torch.nn.functional as F

from models.blip import create_vit, init_tokenizer, load_checkpoint

class BLIP_Retrieval(nn.Module):
    def __init__(self,
                 med_config = 'configs/med_config.json',
                 image_size = 384,
                 vit = 'base',
                 vit_grad_ckpt = False,
                 vit_ckpt_layer = 0,
                 embed_dim = 256,
                 queue_size = 57600,
                 momentum = 0.995,
                 negative_all_rank = False,
                 ):
        """
        Args:
            med_config (str): path for the mixture of encoder-decoder model's configuration file
            image_size (int): input image size
            vit (str): model size of vision transformer
        """
        super().__init__()

        self.visual_encoder, vision_width = create_vit(vit, image_size, vit_grad_ckpt, vit_ckpt_layer)
        self.tokenizer = init_tokenizer()
        med_config = BertConfig.from_json_file(med_config)
        med_config.encoder_width = vision_width
        self.text_encoder = BertModel(config=med_config, add_pooling_layer=False)

        text_width = self.text_encoder.config.hidden_size

        self.vision_proj = nn.Linear(vision_width, embed_dim)
        self.text_proj = nn.Linear(text_width, embed_dim)

        self.itm_head = nn.Linear(text_width, 2)

        # create momentum encoders
        self.visual_encoder_m, vision_width = create_vit(vit, image_size)
        self.vision_proj_m = nn.Linear(vision_width, embed_dim)
        self.text_encoder_m = BertModel(config=med_config, add_pooling_layer=False)
        self.text_proj_m = nn.Linear(text_width, embed_dim)

        self.model_pairs = [[self.visual_encoder, self.visual_encoder_m],
                            [self.vision_proj, self.vision_proj_m],
                            [self.text_encoder, self.text_encoder_m],
                            [self.text_proj, self.text_proj_m],
                           ]
        self.copy_params()

        # create the queue
        self.register_buffer("image_queue", torch.randn(embed_dim, queue_size))
        self.register_buffer("text_queue", torch.randn(embed_dim, queue_size))
        self.register_buffer("idx_queue", torch.full((1, queue_size), -100))
        self.register_buffer("ptr_queue", torch.zeros(1, dtype=torch.long))

        self.image_queue = nn.functional.normalize(self.image_queue, dim=0)
        self.text_queue = nn.functional.normalize(self.text_queue, dim=0)

        self.queue_size = queue_size
        self.momentum = momentum
        self.temp = nn.Parameter(0.07*torch.ones([]))

        self.negative_all_rank = negative_all_rank


    def forward(self, image, caption, alpha, idx):
        with torch.no_grad():
            self.temp.clamp_(0.001, 0.5)

        image_embeds = self.visual_encoder(image)
        image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(image.device)
        image_feat = F.normalize(self.vision_proj(image_embeds[:,0,:]), dim=-1)

        text = self.tokenizer(caption, padding='max_length', truncation=True, max_length=35,
                              return_tensors="pt").to(image.device)

        text_output = self.text_encoder(text.input_ids, attention_mask = text.attention_mask,
                                        return_dict = True, mode = 'text')
        text_feat = F.normalize(self.text_proj(text_output.last_hidden_state[:,0,:]), dim=-1)

        ###============== Image-text Contrastive Learning ===================###
        idx = idx.view(-1, 1)
        idx_all = torch.cat([idx.t(), self.idx_queue.clone().detach()], dim=1)
        pos_idx = torch.eq(idx, idx_all).float()
        sim_targets = pos_idx / pos_idx.sum(1, keepdim=True)

        # get momentum features
        with torch.no_grad():
            self._momentum_update()
            image_embeds_m = self.visual_encoder_m(image)
            image_feat_m = F.normalize(self.vision_proj_m(image_embeds_m[:,0,:]), dim=-1)
            image_feat_m_all = torch.cat([image_feat_m.t(), self.image_queue.clone().detach()], dim=1)

            text_output_m = self.text_encoder_m(text.input_ids, attention_mask = text.attention_mask,
                                                return_dict = True, mode = 'text')
            text_feat_m = F.normalize(self.text_proj_m(text_output_m.last_hidden_state[:,0,:]), dim=-1)
            text_feat_m_all = torch.cat([text_feat_m.t(), self.text_queue.clone().detach()], dim=1)

            sim_i2t_m = image_feat_m @ text_feat_m_all / self.temp
            sim_t2i_m = text_feat_m @ image_feat_m_all / self.temp

            sim_i2t_targets = alpha * F.softmax(sim_i2t_m, dim=1) + (1 - alpha) * sim_targets
            sim_t2i_targets = alpha * F.softmax(sim_t2i_m, dim=1) + (1 - alpha) * sim_targets

        sim_i2t = image_feat @ text_feat_m_all / self.temp
        sim_t2i = text_feat @ image_feat_m_all / self.temp

        loss_i2t = -torch.sum(F.log_softmax(sim_i2t, dim=1)*sim_i2t_targets, dim=1).mean()
        loss_t2i = -torch.sum(F.log_softmax(sim_t2i, dim=1)*sim_t2i_targets, dim=1).mean()

        loss_ita = (loss_i2t+loss_t2i)/2

        idxs = concat_all_gather(idx)
        self._dequeue_and_enqueue(image_feat_m, text_feat_m, idxs)

        ###============== Image-text Matching ===================###
        encoder_input_ids = text.input_ids.clone()
        encoder_input_ids[:,0] = self.tokenizer.enc_token_id

        # forward the positve image-text pair
        bs = image.size(0)
        output_pos = self.text_encoder(encoder_input_ids,
                                       attention_mask = text.attention_mask,
                                       encoder_hidden_states = image_embeds,
                                       encoder_attention_mask = image_atts,
                                       return_dict = True,
                                      )


        if self.negative_all_rank:
            # compute sample similarity
            with torch.no_grad():
                mask = torch.eq(idx, idxs.t())

                image_feat_world = concat_all_gather(image_feat)
                text_feat_world = concat_all_gather(text_feat)

                sim_i2t = image_feat @ text_feat_world.t() / self.temp
                sim_t2i = text_feat @ image_feat_world.t() / self.temp

                weights_i2t = F.softmax(sim_i2t, dim=1)
                weights_i2t.masked_fill_(mask, 0)

                weights_t2i = F.softmax(sim_t2i, dim=1)
                weights_t2i.masked_fill_(mask, 0)

            image_embeds_world = all_gather_with_grad(image_embeds)

            # select a negative image (from all ranks) for each text
            image_embeds_neg = []
            for b in range(bs):
                neg_idx = torch.multinomial(weights_t2i[b], 1).item()
                image_embeds_neg.append(image_embeds_world[neg_idx])
            image_embeds_neg = torch.stack(image_embeds_neg, dim=0)

            # select a negative text (from all ranks) for each image
            input_ids_world = concat_all_gather(encoder_input_ids)
            att_mask_world = concat_all_gather(text.attention_mask)

            text_ids_neg = []
            text_atts_neg = []
            for b in range(bs):
                neg_idx = torch.multinomial(weights_i2t[b], 1).item()
                text_ids_neg.append(input_ids_world[neg_idx])
                text_atts_neg.append(att_mask_world[neg_idx])

        else:
            with torch.no_grad():
                mask = torch.eq(idx, idx.t())

                sim_i2t = image_feat @ text_feat.t() / self.temp
                sim_t2i = text_feat @ image_feat.t() / self.temp

                weights_i2t = F.softmax(sim_i2t, dim=1)
                weights_i2t.masked_fill_(mask, 0)

                weights_t2i = F.softmax(sim_t2i, dim=1)
                weights_t2i.masked_fill_(mask, 0)

            # select a negative image (from same rank) for each text
            image_embeds_neg = []
            for b in range(bs):
                neg_idx = torch.multinomial(weights_t2i[b], 1).item()
                image_embeds_neg.append(image_embeds[neg_idx])
            image_embeds_neg = torch.stack(image_embeds_neg, dim=0)

            # select a negative text (from same rank) for each image
            text_ids_neg = []
            text_atts_neg = []
            for b in range(bs):
                neg_idx = torch.multinomial(weights_i2t[b], 1).item()
                text_ids_neg.append(encoder_input_ids[neg_idx])
                text_atts_neg.append(text.attention_mask[neg_idx])

        text_ids_neg = torch.stack(text_ids_neg, dim=0)
        text_atts_neg = torch.stack(text_atts_neg, dim=0)

        text_ids_all = torch.cat([encoder_input_ids, text_ids_neg], dim=0)
        text_atts_all = torch.cat([text.attention_mask, text_atts_neg], dim=0)

        image_embeds_all = torch.cat([image_embeds_neg, image_embeds], dim=0)
        image_atts_all = torch.cat([image_atts, image_atts], dim=0)

        output_neg = self.text_encoder(text_ids_all,
                                       attention_mask = text_atts_all,
                                       encoder_hidden_states = image_embeds_all,
                                       encoder_attention_mask = image_atts_all,
                                       return_dict = True,
                                      )


        vl_embeddings = torch.cat([output_pos.last_hidden_state[:,0,:], output_neg.last_hidden_state[:,0,:]], dim=0)
        vl_output = self.itm_head(vl_embeddings)

        itm_labels = torch.cat([torch.ones(bs, dtype=torch.long), torch.zeros(2*bs, dtype=torch.long)],
                               dim=0).to(image.device)
        loss_itm = F.cross_entropy(vl_output, itm_labels)

        return loss_ita, loss_itm


    @torch.no_grad()
    def copy_params(self):
        for model_pair in self.model_pairs:
            for param, param_m in zip(model_pair[0].parameters(), model_pair[1].parameters()):
                param_m.data.copy_(param.data)  # initialize
                param_m.requires_grad = False  # not update by gradient


    @torch.no_grad()
    def _momentum_update(self):
        for model_pair in self.model_pairs:
            for param, param_m in zip(model_pair[0].parameters(), model_pair[1].parameters()):
                param_m.data = param_m.data * self.momentum + param.data * (1. - self.momentum)


    @torch.no_grad()
    def _dequeue_and_enqueue(self, image_feat, text_feat, idxs):
        # gather keys before updating queue
        image_feats = concat_all_gather(image_feat)
        text_feats = concat_all_gather(text_feat)

        batch_size = image_feats.shape[0]

        ptr = int(self.ptr_queue)
        assert self.queue_size % batch_size == 0  # for simplicity

        # replace the keys at ptr (dequeue and enqueue)
        self.image_queue[:, ptr:ptr + batch_size] = image_feats.T
        self.text_queue[:, ptr:ptr + batch_size] = text_feats.T
        self.idx_queue[:, ptr:ptr + batch_size] = idxs.T
        ptr = (ptr + batch_size) % self.queue_size  # move pointer

        self.ptr_queue[0] = ptr


def blip_retrieval(pretrained='', **kwargs):
    model = BLIP_Retrieval(**kwargs)
    if pretrained:
        model, msg = load_checkpoint(model, pretrained)
        print("missing keys:")
        print(msg.missing_keys)
    return model


@torch.no_grad()
def concat_all_gather(tensor):
    """
    Performs all_gather operation on the provided tensors.
    *** Warning ***: torch.distributed.all_gather has no gradient.
    """
    tensors_gather = [torch.ones_like(tensor)
        for _ in range(torch.distributed.get_world_size())]
    torch.distributed.all_gather(tensors_gather, tensor, async_op=False)

    output = torch.cat(tensors_gather, dim=0)
    return output


class GatherLayer(torch.autograd.Function):
    """
    Gather tensors from all workers with support for backward propagation:
    This implementation does not cut the gradients as torch.distributed.all_gather does.
    """

    @staticmethod
    def forward(ctx, x):
        output = [torch.zeros_like(x) for _ in range(torch.distributed.get_world_size())]
        torch.distributed.all_gather(output, x)
        return tuple(output)

    @staticmethod
    def backward(ctx, *grads):
        all_gradients = torch.stack(grads)
        torch.distributed.all_reduce(all_gradients)
        return all_gradients[torch.distributed.get_rank()]


def all_gather_with_grad(tensors):
    """
    Performs all_gather operation on the provided tensors.
    Graph remains connected for backward grad computation.
    """
    # Queue the gathered tensors
    world_size = torch.distributed.get_world_size()
    # There is no need for reduction in the single-proc case
    if world_size == 1:
        return tensors

    tensor_all = GatherLayer.apply(tensors)

    return torch.cat(tensor_all, dim=0)
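Note: unlike BLIP_Pretrain, BLIP_Retrieval builds its contrastive targets from image ids (the idx argument), so two captions of the same image both count as positives rather than as false negatives. A tiny worked example of that target construction, with made-up values purely for illustration and the queue assumed empty:

import torch
import torch.nn.functional as F

# Three captions in the batch; captions 0 and 2 describe the same image (id 0).
idx = torch.tensor([[0], [1], [0]])
idx_all = idx.t()                                     # stand-in for cat([idx.t(), idx_queue])
pos_idx = torch.eq(idx, idx_all).float()              # 1 wherever image ids match
sim_targets = pos_idx / pos_idx.sum(1, keepdim=True)  # each row sums to 1

sim = torch.randn(3, 3)                               # stand-in similarity matrix
loss_i2t = -torch.sum(F.log_softmax(sim, dim=1) * sim_targets, dim=1).mean()
print(sim_targets)  # [[0.5, 0, 0.5], [0, 1, 0], [0.5, 0, 0.5]]

The same mask (torch.eq(idx, idx.t())) is reused later in forward() to stop duplicate captions of one image being sampled as hard negatives for each other.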
defake/models/blip_vqa.py
ADDED
@@ -0,0 +1,186 @@
from models.med import BertConfig, BertModel, BertLMHeadModel
from models.blip import create_vit, init_tokenizer, load_checkpoint

import torch
from torch import nn
import torch.nn.functional as F
from transformers import BertTokenizer
import numpy as np

class BLIP_VQA(nn.Module):
    def __init__(self,
                 med_config = 'configs/med_config.json',
                 image_size = 480,
                 vit = 'base',
                 vit_grad_ckpt = False,
                 vit_ckpt_layer = 0,
                 ):
        """
        Args:
            med_config (str): path for the mixture of encoder-decoder model's configuration file
            image_size (int): input image size
            vit (str): model size of vision transformer
        """
        super().__init__()

        self.visual_encoder, vision_width = create_vit(vit, image_size, vit_grad_ckpt, vit_ckpt_layer, drop_path_rate=0.1)
        self.tokenizer = init_tokenizer()

        encoder_config = BertConfig.from_json_file(med_config)
        encoder_config.encoder_width = vision_width
        self.text_encoder = BertModel(config=encoder_config, add_pooling_layer=False)

        decoder_config = BertConfig.from_json_file(med_config)
        self.text_decoder = BertLMHeadModel(config=decoder_config)


    def forward(self, image, question, answer=None, n=None, weights=None, train=True, inference='rank', k_test=128):

        image_embeds = self.visual_encoder(image)
        image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(image.device)

        question = self.tokenizer(question, padding='longest', truncation=True, max_length=35,
                                  return_tensors="pt").to(image.device)
        question.input_ids[:,0] = self.tokenizer.enc_token_id

        if train:
            '''
            n: number of answers for each question
            weights: weight for each answer
            '''
            answer = self.tokenizer(answer, padding='longest', return_tensors="pt").to(image.device)
            answer.input_ids[:,0] = self.tokenizer.bos_token_id
            answer_targets = answer.input_ids.masked_fill(answer.input_ids == self.tokenizer.pad_token_id, -100)

            question_output = self.text_encoder(question.input_ids,
                                                attention_mask = question.attention_mask,
                                                encoder_hidden_states = image_embeds,
                                                encoder_attention_mask = image_atts,
                                                return_dict = True)

            question_states = []
            question_atts = []
            for b, n in enumerate(n):
                question_states += [question_output.last_hidden_state[b]]*n
                question_atts += [question.attention_mask[b]]*n
            question_states = torch.stack(question_states, 0)
            question_atts = torch.stack(question_atts, 0)

            answer_output = self.text_decoder(answer.input_ids,
                                              attention_mask = answer.attention_mask,
                                              encoder_hidden_states = question_states,
                                              encoder_attention_mask = question_atts,
                                              labels = answer_targets,
                                              return_dict = True,
                                              reduction = 'none',
                                             )

            loss = weights * answer_output.loss
            loss = loss.sum()/image.size(0)

            return loss


        else:
            question_output = self.text_encoder(question.input_ids,
                                                attention_mask = question.attention_mask,
                                                encoder_hidden_states = image_embeds,
                                                encoder_attention_mask = image_atts,
                                                return_dict = True)

            if inference=='generate':
                num_beams = 3
                question_states = question_output.last_hidden_state.repeat_interleave(num_beams, dim=0)
                question_atts = torch.ones(question_states.size()[:-1], dtype=torch.long).to(question_states.device)
                model_kwargs = {"encoder_hidden_states": question_states, "encoder_attention_mask": question_atts}

                bos_ids = torch.full((image.size(0), 1), fill_value=self.tokenizer.bos_token_id, device=image.device)

                outputs = self.text_decoder.generate(input_ids=bos_ids,
                                                     max_length=10,
                                                     min_length=1,
                                                     num_beams=num_beams,
                                                     eos_token_id=self.tokenizer.sep_token_id,
                                                     pad_token_id=self.tokenizer.pad_token_id,
                                                     **model_kwargs)

                answers = []
                for output in outputs:
                    answer = self.tokenizer.decode(output, skip_special_tokens=True)
                    answers.append(answer)
                return answers

            elif inference=='rank':
                max_ids = self.rank_answer(question_output.last_hidden_state, question.attention_mask,
                                           answer.input_ids, answer.attention_mask, k_test)
                return max_ids


    def rank_answer(self, question_states, question_atts, answer_ids, answer_atts, k):

        num_ques = question_states.size(0)
        start_ids = answer_ids[0,0].repeat(num_ques, 1)  # bos token

        start_output = self.text_decoder(start_ids,
                                         encoder_hidden_states = question_states,
                                         encoder_attention_mask = question_atts,
                                         return_dict = True,
                                         reduction = 'none')
        logits = start_output.logits[:,0,:]  # first token's logit

        # topk_probs: top-k probability
        # topk_ids: [num_question, k]
        answer_first_token = answer_ids[:,1]
        prob_first_token = F.softmax(logits, dim=1).index_select(dim=1, index=answer_first_token)
        topk_probs, topk_ids = prob_first_token.topk(k, dim=1)

        # answer input: [num_question*k, answer_len]
        input_ids = []
        input_atts = []
        for b, topk_id in enumerate(topk_ids):
            input_ids.append(answer_ids.index_select(dim=0, index=topk_id))
            input_atts.append(answer_atts.index_select(dim=0, index=topk_id))
        input_ids = torch.cat(input_ids, dim=0)
        input_atts = torch.cat(input_atts, dim=0)

        targets_ids = input_ids.masked_fill(input_ids == self.tokenizer.pad_token_id, -100)

        # repeat encoder's output for top-k answers
        question_states = tile(question_states, 0, k)
        question_atts = tile(question_atts, 0, k)

        output = self.text_decoder(input_ids,
                                   attention_mask = input_atts,
                                   encoder_hidden_states = question_states,
                                   encoder_attention_mask = question_atts,
                                   labels = targets_ids,
                                   return_dict = True,
                                   reduction = 'none')

        log_probs_sum = -output.loss
        log_probs_sum = log_probs_sum.view(num_ques, k)

        max_topk_ids = log_probs_sum.argmax(dim=1)
        max_ids = topk_ids[max_topk_ids>=0, max_topk_ids]

        return max_ids


def blip_vqa(pretrained='', **kwargs):
    model = BLIP_VQA(**kwargs)
    if pretrained:
        model, msg = load_checkpoint(model, pretrained)
        # assert(len(msg.missing_keys)==0)
    return model


def tile(x, dim, n_tile):
    init_dim = x.size(dim)
    repeat_idx = [1] * x.dim()
    repeat_idx[dim] = n_tile
    x = x.repeat(*(repeat_idx))
    order_index = torch.LongTensor(np.concatenate([init_dim * np.arange(n_tile) + i for i in range(init_dim)]))
    return torch.index_select(x, dim, order_index.to(x.device))
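Note: rank_answer above repeats each question's encoder states k times via the tile helper at the bottom of the file, so that every candidate answer can be scored against its own copy of the question. tile keeps each question's copies in one consecutive block (plain torch.repeat would instead repeat the whole batch). A quick standalone check of that behaviour; the toy tensor q is ours:

import torch
import numpy as np

def tile(x, dim, n_tile):
    # same helper as in blip_vqa.py, reproduced so this snippet runs on its own
    init_dim = x.size(dim)
    repeat_idx = [1] * x.dim()
    repeat_idx[dim] = n_tile
    x = x.repeat(*repeat_idx)
    order_index = torch.LongTensor(
        np.concatenate([init_dim * np.arange(n_tile) + i for i in range(init_dim)]))
    return torch.index_select(x, dim, order_index.to(x.device))

q = torch.tensor([[1, 1], [2, 2]])   # two "questions"
print(tile(q, 0, 3))
# tensor([[1, 1], [1, 1], [1, 1], [2, 2], [2, 2], [2, 2]])

With that layout, log_probs_sum.view(num_ques, k) in rank_answer lines up each row with the k candidate answers of one question.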
defake/models/med.py
ADDED
@@ -0,0 +1,955 @@
| 1 |
+
'''
|
| 2 |
+
* Copyright (c) 2022, salesforce.com, inc.
|
| 3 |
+
* All rights reserved.
|
| 4 |
+
* SPDX-License-Identifier: BSD-3-Clause
|
| 5 |
+
* For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause
|
| 6 |
+
* By Junnan Li
|
| 7 |
+
* Based on huggingface code base
|
| 8 |
+
* https://github.com/huggingface/transformers/blob/v4.15.0/src/transformers/models/bert
|
| 9 |
+
'''
|
| 10 |
+
|
| 11 |
+
import math
|
| 12 |
+
import os
|
| 13 |
+
import warnings
|
| 14 |
+
from dataclasses import dataclass
|
| 15 |
+
from typing import Optional, Tuple
|
| 16 |
+
|
| 17 |
+
import torch
|
| 18 |
+
from torch import Tensor, device, dtype, nn
|
| 19 |
+
import torch.utils.checkpoint
|
| 20 |
+
from torch import nn
|
| 21 |
+
from torch.nn import CrossEntropyLoss
|
| 22 |
+
import torch.nn.functional as F
|
| 23 |
+
|
| 24 |
+
from transformers.activations import ACT2FN
|
| 25 |
+
from transformers.file_utils import (
|
| 26 |
+
ModelOutput,
|
| 27 |
+
)
|
| 28 |
+
from transformers.modeling_outputs import (
|
| 29 |
+
BaseModelOutputWithPastAndCrossAttentions,
|
| 30 |
+
BaseModelOutputWithPoolingAndCrossAttentions,
|
| 31 |
+
CausalLMOutputWithCrossAttentions,
|
| 32 |
+
MaskedLMOutput,
|
| 33 |
+
MultipleChoiceModelOutput,
|
| 34 |
+
NextSentencePredictorOutput,
|
| 35 |
+
QuestionAnsweringModelOutput,
|
| 36 |
+
SequenceClassifierOutput,
|
| 37 |
+
TokenClassifierOutput,
|
| 38 |
+
)
|
| 39 |
+
from transformers.modeling_utils import (
|
| 40 |
+
PreTrainedModel,
|
| 41 |
+
apply_chunking_to_forward,
|
| 42 |
+
find_pruneable_heads_and_indices,
|
| 43 |
+
prune_linear_layer,
|
| 44 |
+
)
|
| 45 |
+
from transformers.utils import logging
|
| 46 |
+
from transformers.models.bert.configuration_bert import BertConfig
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
logger = logging.get_logger(__name__)
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
class BertEmbeddings(nn.Module):
|
| 53 |
+
"""Construct the embeddings from word and position embeddings."""
|
| 54 |
+
|
| 55 |
+
def __init__(self, config):
|
| 56 |
+
super().__init__()
|
| 57 |
+
self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
|
| 58 |
+
self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
|
| 59 |
+
|
| 60 |
+
# self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
|
| 61 |
+
# any TensorFlow checkpoint file
|
| 62 |
+
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
| 63 |
+
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
| 64 |
+
|
| 65 |
+
# position_ids (1, len position emb) is contiguous in memory and exported when serialized
|
| 66 |
+
self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
|
| 67 |
+
self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
|
| 68 |
+
|
| 69 |
+
self.config = config
|
| 70 |
+
|
| 71 |
+
def forward(
|
| 72 |
+
self, input_ids=None, position_ids=None, inputs_embeds=None, past_key_values_length=0
|
| 73 |
+
):
|
| 74 |
+
if input_ids is not None:
|
| 75 |
+
input_shape = input_ids.size()
|
| 76 |
+
else:
|
| 77 |
+
input_shape = inputs_embeds.size()[:-1]
|
| 78 |
+
|
| 79 |
+
seq_length = input_shape[1]
|
| 80 |
+
|
| 81 |
+
if position_ids is None:
|
| 82 |
+
position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length]
|
| 83 |
+
|
| 84 |
+
if inputs_embeds is None:
|
| 85 |
+
inputs_embeds = self.word_embeddings(input_ids)
|
| 86 |
+
|
| 87 |
+
embeddings = inputs_embeds
|
| 88 |
+
|
| 89 |
+
if self.position_embedding_type == "absolute":
|
| 90 |
+
position_embeddings = self.position_embeddings(position_ids)
|
| 91 |
+
embeddings += position_embeddings
|
| 92 |
+
embeddings = self.LayerNorm(embeddings)
|
| 93 |
+
embeddings = self.dropout(embeddings)
|
| 94 |
+
return embeddings
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
class BertSelfAttention(nn.Module):
|
| 98 |
+
def __init__(self, config, is_cross_attention):
|
| 99 |
+
super().__init__()
|
| 100 |
+
self.config = config
|
| 101 |
+
if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
|
| 102 |
+
raise ValueError(
|
| 103 |
+
"The hidden size (%d) is not a multiple of the number of attention "
|
| 104 |
+
"heads (%d)" % (config.hidden_size, config.num_attention_heads)
|
| 105 |
+
)
|
| 106 |
+
|
| 107 |
+
self.num_attention_heads = config.num_attention_heads
|
| 108 |
+
self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
|
| 109 |
+
self.all_head_size = self.num_attention_heads * self.attention_head_size
|
| 110 |
+
|
| 111 |
+
self.query = nn.Linear(config.hidden_size, self.all_head_size)
|
| 112 |
+
if is_cross_attention:
|
| 113 |
+
self.key = nn.Linear(config.encoder_width, self.all_head_size)
|
| 114 |
+
self.value = nn.Linear(config.encoder_width, self.all_head_size)
|
| 115 |
+
else:
|
| 116 |
+
self.key = nn.Linear(config.hidden_size, self.all_head_size)
|
| 117 |
+
self.value = nn.Linear(config.hidden_size, self.all_head_size)
|
| 118 |
+
|
| 119 |
+
self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
|
| 120 |
+
self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
|
| 121 |
+
if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
|
| 122 |
+
self.max_position_embeddings = config.max_position_embeddings
|
| 123 |
+
self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)
|
| 124 |
+
self.save_attention = False
|
| 125 |
+
|
| 126 |
+
def save_attn_gradients(self, attn_gradients):
|
| 127 |
+
self.attn_gradients = attn_gradients
|
| 128 |
+
|
| 129 |
+
def get_attn_gradients(self):
|
| 130 |
+
return self.attn_gradients
|
| 131 |
+
|
| 132 |
+
def save_attention_map(self, attention_map):
|
| 133 |
+
self.attention_map = attention_map
|
| 134 |
+
|
| 135 |
+
def get_attention_map(self):
|
| 136 |
+
return self.attention_map
|
| 137 |
+
|
| 138 |
+
def transpose_for_scores(self, x):
|
| 139 |
+
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
|
| 140 |
+
x = x.view(*new_x_shape)
|
| 141 |
+
return x.permute(0, 2, 1, 3)
|
| 142 |
+
|
| 143 |
+
def forward(
|
| 144 |
+
self,
|
| 145 |
+
hidden_states,
|
| 146 |
+
attention_mask=None,
|
| 147 |
+
head_mask=None,
|
| 148 |
+
encoder_hidden_states=None,
|
| 149 |
+
encoder_attention_mask=None,
|
| 150 |
+
past_key_value=None,
|
| 151 |
+
output_attentions=False,
|
| 152 |
+
):
|
| 153 |
+
mixed_query_layer = self.query(hidden_states)
|
| 154 |
+
|
| 155 |
+
# If this is instantiated as a cross-attention module, the keys
|
| 156 |
+
# and values come from an encoder; the attention mask needs to be
|
| 157 |
+
# such that the encoder's padding tokens are not attended to.
|
| 158 |
+
is_cross_attention = encoder_hidden_states is not None
|
| 159 |
+
|
| 160 |
+
if is_cross_attention:
|
| 161 |
+
key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))
|
| 162 |
+
value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))
|
| 163 |
+
attention_mask = encoder_attention_mask
|
| 164 |
+
elif past_key_value is not None:
|
| 165 |
+
key_layer = self.transpose_for_scores(self.key(hidden_states))
|
| 166 |
+
value_layer = self.transpose_for_scores(self.value(hidden_states))
|
| 167 |
+
key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
|
| 168 |
+
value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
|
| 169 |
+
else:
|
| 170 |
+
key_layer = self.transpose_for_scores(self.key(hidden_states))
|
| 171 |
+
value_layer = self.transpose_for_scores(self.value(hidden_states))
|
| 172 |
+
|
| 173 |
+
query_layer = self.transpose_for_scores(mixed_query_layer)
|
| 174 |
+
|
| 175 |
+
past_key_value = (key_layer, value_layer)
|
| 176 |
+
|
| 177 |
+
# Take the dot product between "query" and "key" to get the raw attention scores.
|
| 178 |
+
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
|
| 179 |
+
|
| 180 |
+
if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
|
| 181 |
+
seq_length = hidden_states.size()[1]
|
| 182 |
+
position_ids_l = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)
|
| 183 |
+
position_ids_r = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(1, -1)
|
| 184 |
+
distance = position_ids_l - position_ids_r
|
| 185 |
+
positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1)
|
| 186 |
+
positional_embedding = positional_embedding.to(dtype=query_layer.dtype) # fp16 compatibility
|
| 187 |
+
|
| 188 |
+
if self.position_embedding_type == "relative_key":
|
| 189 |
+
relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
|
| 190 |
+
attention_scores = attention_scores + relative_position_scores
|
| 191 |
+
elif self.position_embedding_type == "relative_key_query":
|
| 192 |
+
relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
|
| 193 |
+
relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding)
|
| 194 |
+
attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key
|
| 195 |
+
|
| 196 |
+
attention_scores = attention_scores / math.sqrt(self.attention_head_size)
|
| 197 |
+
if attention_mask is not None:
|
| 198 |
+
            # Apply the attention mask is (precomputed for all layers in BertModel forward() function)
            attention_scores = attention_scores + attention_mask

        # Normalize the attention scores to probabilities.
        attention_probs = nn.Softmax(dim=-1)(attention_scores)

        if is_cross_attention and self.save_attention:
            self.save_attention_map(attention_probs)
            attention_probs.register_hook(self.save_attn_gradients)

        # This is actually dropping out entire tokens to attend to, which might
        # seem a bit unusual, but is taken from the original Transformer paper.
        attention_probs_dropped = self.dropout(attention_probs)

        # Mask heads if we want to
        if head_mask is not None:
            attention_probs_dropped = attention_probs_dropped * head_mask

        context_layer = torch.matmul(attention_probs_dropped, value_layer)

        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = context_layer.view(*new_context_layer_shape)

        outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)

        outputs = outputs + (past_key_value,)
        return outputs


class BertSelfOutput(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states, input_tensor):
        hidden_states = self.dense(hidden_states)
        hidden_states = self.dropout(hidden_states)
        hidden_states = self.LayerNorm(hidden_states + input_tensor)
        return hidden_states


class BertAttention(nn.Module):
    def __init__(self, config, is_cross_attention=False):
        super().__init__()
        self.self = BertSelfAttention(config, is_cross_attention)
        self.output = BertSelfOutput(config)
        self.pruned_heads = set()

    def prune_heads(self, heads):
        if len(heads) == 0:
            return
        heads, index = find_pruneable_heads_and_indices(
            heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads
        )

        # Prune linear layers
        self.self.query = prune_linear_layer(self.self.query, index)
        self.self.key = prune_linear_layer(self.self.key, index)
        self.self.value = prune_linear_layer(self.self.value, index)
        self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)

        # Update hyper params and store pruned heads
        self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
        self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
        self.pruned_heads = self.pruned_heads.union(heads)

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        past_key_value=None,
        output_attentions=False,
    ):
        self_outputs = self.self(
            hidden_states,
            attention_mask,
            head_mask,
            encoder_hidden_states,
            encoder_attention_mask,
            past_key_value,
            output_attentions,
        )
        attention_output = self.output(self_outputs[0], hidden_states)
        outputs = (attention_output,) + self_outputs[1:]  # add attentions if we output them
        return outputs


class BertIntermediate(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
        if isinstance(config.hidden_act, str):
            self.intermediate_act_fn = ACT2FN[config.hidden_act]
        else:
            self.intermediate_act_fn = config.hidden_act

    def forward(self, hidden_states):
        hidden_states = self.dense(hidden_states)
        hidden_states = self.intermediate_act_fn(hidden_states)
        return hidden_states


class BertOutput(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states, input_tensor):
        hidden_states = self.dense(hidden_states)
        hidden_states = self.dropout(hidden_states)
        hidden_states = self.LayerNorm(hidden_states + input_tensor)
        return hidden_states


class BertLayer(nn.Module):
    def __init__(self, config, layer_num):
        super().__init__()
        self.config = config
        self.chunk_size_feed_forward = config.chunk_size_feed_forward
        self.seq_len_dim = 1
        self.attention = BertAttention(config)
        self.layer_num = layer_num
        if self.config.add_cross_attention:
            self.crossattention = BertAttention(config, is_cross_attention=self.config.add_cross_attention)
        self.intermediate = BertIntermediate(config)
        self.output = BertOutput(config)

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        past_key_value=None,
        output_attentions=False,
        mode=None,
    ):
        # decoder uni-directional self-attention cached key/values tuple is at positions 1,2
        self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
        self_attention_outputs = self.attention(
            hidden_states,
            attention_mask,
            head_mask,
            output_attentions=output_attentions,
            past_key_value=self_attn_past_key_value,
        )
        attention_output = self_attention_outputs[0]

        outputs = self_attention_outputs[1:-1]
        present_key_value = self_attention_outputs[-1]

        if mode == 'multimodal':
            assert encoder_hidden_states is not None, "encoder_hidden_states must be given for cross-attention layers"

            cross_attention_outputs = self.crossattention(
                attention_output,
                attention_mask,
                head_mask,
                encoder_hidden_states,
                encoder_attention_mask,
                output_attentions=output_attentions,
            )
            attention_output = cross_attention_outputs[0]
            outputs = outputs + cross_attention_outputs[1:-1]  # add cross attentions if we output attention weights
        layer_output = apply_chunking_to_forward(
            self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output
        )
        outputs = (layer_output,) + outputs

        outputs = outputs + (present_key_value,)

        return outputs

    def feed_forward_chunk(self, attention_output):
        intermediate_output = self.intermediate(attention_output)
        layer_output = self.output(intermediate_output, attention_output)
        return layer_output


class BertEncoder(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.layer = nn.ModuleList([BertLayer(config, i) for i in range(config.num_hidden_layers)])
        self.gradient_checkpointing = False

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        past_key_values=None,
        use_cache=None,
        output_attentions=False,
        output_hidden_states=False,
        return_dict=True,
        mode='multimodal',
    ):
        all_hidden_states = () if output_hidden_states else None
        all_self_attentions = () if output_attentions else None
        all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None

        next_decoder_cache = () if use_cache else None

        for i in range(self.config.num_hidden_layers):
            layer_module = self.layer[i]
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            layer_head_mask = head_mask[i] if head_mask is not None else None
            past_key_value = past_key_values[i] if past_key_values is not None else None

            if self.gradient_checkpointing and self.training:

                if use_cache:
                    logger.warn(
                        "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                    )
                    use_cache = False

                def create_custom_forward(module):
                    def custom_forward(*inputs):
                        return module(*inputs, past_key_value, output_attentions)

                    return custom_forward

                layer_outputs = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(layer_module),
                    hidden_states,
                    attention_mask,
                    layer_head_mask,
                    encoder_hidden_states,
                    encoder_attention_mask,
                    mode=mode,
                )
            else:
                layer_outputs = layer_module(
                    hidden_states,
                    attention_mask,
                    layer_head_mask,
                    encoder_hidden_states,
                    encoder_attention_mask,
                    past_key_value,
                    output_attentions,
                    mode=mode,
                )

            hidden_states = layer_outputs[0]
            if use_cache:
                next_decoder_cache += (layer_outputs[-1],)
            if output_attentions:
                all_self_attentions = all_self_attentions + (layer_outputs[1],)

        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        if not return_dict:
            return tuple(
                v
                for v in [
                    hidden_states,
                    next_decoder_cache,
                    all_hidden_states,
                    all_self_attentions,
                    all_cross_attentions,
                ]
                if v is not None
            )
        return BaseModelOutputWithPastAndCrossAttentions(
            last_hidden_state=hidden_states,
            past_key_values=next_decoder_cache,
            hidden_states=all_hidden_states,
            attentions=all_self_attentions,
            cross_attentions=all_cross_attentions,
        )


class BertPooler(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.activation = nn.Tanh()

    def forward(self, hidden_states):
        # We "pool" the model by simply taking the hidden state corresponding
        # to the first token.
        first_token_tensor = hidden_states[:, 0]
        pooled_output = self.dense(first_token_tensor)
        pooled_output = self.activation(pooled_output)
        return pooled_output


class BertPredictionHeadTransform(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        if isinstance(config.hidden_act, str):
            self.transform_act_fn = ACT2FN[config.hidden_act]
        else:
            self.transform_act_fn = config.hidden_act
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

    def forward(self, hidden_states):
        hidden_states = self.dense(hidden_states)
        hidden_states = self.transform_act_fn(hidden_states)
        hidden_states = self.LayerNorm(hidden_states)
        return hidden_states


class BertLMPredictionHead(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.transform = BertPredictionHeadTransform(config)

        # The output weights are the same as the input embeddings, but there is
        # an output-only bias for each token.
        self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        self.bias = nn.Parameter(torch.zeros(config.vocab_size))

        # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
        self.decoder.bias = self.bias

    def forward(self, hidden_states):
        hidden_states = self.transform(hidden_states)
        hidden_states = self.decoder(hidden_states)
        return hidden_states


class BertOnlyMLMHead(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.predictions = BertLMPredictionHead(config)

    def forward(self, sequence_output):
        prediction_scores = self.predictions(sequence_output)
        return prediction_scores


class BertPreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = BertConfig
    base_model_prefix = "bert"
    _keys_to_ignore_on_load_missing = [r"position_ids"]

    def _init_weights(self, module):
        """ Initialize the weights """
        if isinstance(module, (nn.Linear, nn.Embedding)):
            # Slightly different from the TF version which uses truncated_normal for initialization
            # cf https://github.com/pytorch/pytorch/pull/5617
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
        if isinstance(module, nn.Linear) and module.bias is not None:
            module.bias.data.zero_()


class BertModel(BertPreTrainedModel):
    """
    The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of
    cross-attention is added between the self-attention layers, following the architecture described in `Attention is
    all you need <https://arxiv.org/abs/1706.03762>`__ by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit,
    Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin.
    argument and :obj:`add_cross_attention` set to :obj:`True`; an :obj:`encoder_hidden_states` is then expected as an
    input to the forward pass.
    """

    def __init__(self, config, add_pooling_layer=True):
        super().__init__(config)
        self.config = config

        self.embeddings = BertEmbeddings(config)

        self.encoder = BertEncoder(config)

        self.pooler = BertPooler(config) if add_pooling_layer else None

        self.init_weights()

    def get_input_embeddings(self):
        return self.embeddings.word_embeddings

    def set_input_embeddings(self, value):
        self.embeddings.word_embeddings = value

    def _prune_heads(self, heads_to_prune):
        """
        Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
        class PreTrainedModel
        """
        for layer, heads in heads_to_prune.items():
            self.encoder.layer[layer].attention.prune_heads(heads)

    def get_extended_attention_mask(self, attention_mask: Tensor, input_shape: Tuple[int], device: device, is_decoder: bool) -> Tensor:
        """
        Makes broadcastable attention and causal masks so that future and masked tokens are ignored.

        Arguments:
            attention_mask (:obj:`torch.Tensor`):
                Mask with ones indicating tokens to attend to, zeros for tokens to ignore.
            input_shape (:obj:`Tuple[int]`):
                The shape of the input to the model.
            device: (:obj:`torch.device`):
                The device of the input to the model.

        Returns:
            :obj:`torch.Tensor` The extended attention mask, with the same dtype as :obj:`attention_mask.dtype`.
        """
        # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
        # ourselves in which case we just need to make it broadcastable to all heads.
        if attention_mask.dim() == 3:
            extended_attention_mask = attention_mask[:, None, :, :]
        elif attention_mask.dim() == 2:
            # Provided a padding mask of dimensions [batch_size, seq_length]
            # - if the model is a decoder, apply a causal mask in addition to the padding mask
            # - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length]
            if is_decoder:
                batch_size, seq_length = input_shape

                seq_ids = torch.arange(seq_length, device=device)
                causal_mask = seq_ids[None, None, :].repeat(batch_size, seq_length, 1) <= seq_ids[None, :, None]
                # in case past_key_values are used we need to add a prefix ones mask to the causal mask
                # causal and attention masks must have same type with pytorch version < 1.3
                causal_mask = causal_mask.to(attention_mask.dtype)

                if causal_mask.shape[1] < attention_mask.shape[1]:
                    prefix_seq_len = attention_mask.shape[1] - causal_mask.shape[1]
                    causal_mask = torch.cat(
                        [
                            torch.ones((batch_size, seq_length, prefix_seq_len), device=device, dtype=causal_mask.dtype),
                            causal_mask,
                        ],
                        axis=-1,
                    )

                extended_attention_mask = causal_mask[:, None, :, :] * attention_mask[:, None, None, :]
            else:
                extended_attention_mask = attention_mask[:, None, None, :]
        else:
            raise ValueError(
                "Wrong shape for input_ids (shape {}) or attention_mask (shape {})".format(
                    input_shape, attention_mask.shape
                )
            )

        # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
        # masked positions, this operation will create a tensor which is 0.0 for
        # positions we want to attend and -10000.0 for masked positions.
        # Since we are adding it to the raw scores before the softmax, this is
        # effectively the same as removing these entirely.
        extended_attention_mask = extended_attention_mask.to(dtype=self.dtype)  # fp16 compatibility
        extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
        return extended_attention_mask

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        encoder_embeds=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        past_key_values=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        is_decoder=False,
        mode='multimodal',
    ):
        r"""
        encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
            Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
            the model is configured as a decoder.
        encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
            Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
            the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``:
            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.
        past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
            Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
            If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids`
            (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`
            instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`.
        use_cache (:obj:`bool`, `optional`):
            If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up
            decoding (see :obj:`past_key_values`).
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if is_decoder:
            use_cache = use_cache if use_cache is not None else self.config.use_cache
        else:
            use_cache = False

        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is not None:
            input_shape = input_ids.size()
            batch_size, seq_length = input_shape
            device = input_ids.device
        elif inputs_embeds is not None:
            input_shape = inputs_embeds.size()[:-1]
            batch_size, seq_length = input_shape
            device = inputs_embeds.device
        elif encoder_embeds is not None:
            input_shape = encoder_embeds.size()[:-1]
            batch_size, seq_length = input_shape
            device = encoder_embeds.device
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds or encoder_embeds")

        # past_key_values_length
        past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0

        if attention_mask is None:
            attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device)

        # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
        # ourselves in which case we just need to make it broadcastable to all heads.
        extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape,
                                                                                 device, is_decoder)

        # If a 2D or 3D attention mask is provided for the cross-attention
        # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
        if encoder_hidden_states is not None:
            if type(encoder_hidden_states) == list:
                encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states[0].size()
            else:
                encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
            encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)

            if type(encoder_attention_mask) == list:
                encoder_extended_attention_mask = [self.invert_attention_mask(mask) for mask in encoder_attention_mask]
            elif encoder_attention_mask is None:
                encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
                encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
            else:
                encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
        else:
            encoder_extended_attention_mask = None

        # Prepare head mask if needed
        # 1.0 in head_mask indicate we keep the head
        # attention_probs has shape bsz x n_heads x N x N
        # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
        # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
        head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)

        if encoder_embeds is None:
            embedding_output = self.embeddings(
                input_ids=input_ids,
                position_ids=position_ids,
                inputs_embeds=inputs_embeds,
                past_key_values_length=past_key_values_length,
            )
        else:
            embedding_output = encoder_embeds

        encoder_outputs = self.encoder(
            embedding_output,
            attention_mask=extended_attention_mask,
            head_mask=head_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_extended_attention_mask,
            past_key_values=past_key_values,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            mode=mode,
        )
        sequence_output = encoder_outputs[0]
        pooled_output = self.pooler(sequence_output) if self.pooler is not None else None

        if not return_dict:
            return (sequence_output, pooled_output) + encoder_outputs[1:]

        return BaseModelOutputWithPoolingAndCrossAttentions(
            last_hidden_state=sequence_output,
            pooler_output=pooled_output,
            past_key_values=encoder_outputs.past_key_values,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
            cross_attentions=encoder_outputs.cross_attentions,
        )


class BertLMHeadModel(BertPreTrainedModel):

    _keys_to_ignore_on_load_unexpected = [r"pooler"]
    _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]

    def __init__(self, config):
        super().__init__(config)

        self.bert = BertModel(config, add_pooling_layer=False)
        self.cls = BertOnlyMLMHead(config)

        self.init_weights()

    def get_output_embeddings(self):
        return self.cls.predictions.decoder

    def set_output_embeddings(self, new_embeddings):
        self.cls.predictions.decoder = new_embeddings

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        labels=None,
        past_key_values=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        return_logits=False,
        is_decoder=True,
        reduction='mean',
        mode='multimodal',
    ):
        r"""
        encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
            Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
            the model is configured as a decoder.
        encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
            Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
            the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``:
            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.
        labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
            Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in
            ``[-100, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are
            ignored (masked), the loss is only computed for the tokens with labels in ``[0, ..., config.vocab_size]``
        past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
            Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
            If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids`
            (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`
            instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`.
        use_cache (:obj:`bool`, `optional`):
            If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up
            decoding (see :obj:`past_key_values`).
        Returns:
        Example::
            >>> from transformers import BertTokenizer, BertLMHeadModel, BertConfig
            >>> import torch
            >>> tokenizer = BertTokenizer.from_pretrained('bert-base-cased')
            >>> config = BertConfig.from_pretrained("bert-base-cased")
            >>> model = BertLMHeadModel.from_pretrained('bert-base-cased', config=config)
            >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
            >>> outputs = model(**inputs)
            >>> prediction_logits = outputs.logits
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        if labels is not None:
            use_cache = False

        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            past_key_values=past_key_values,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            is_decoder=is_decoder,
            mode=mode,
        )

        sequence_output = outputs[0]
        prediction_scores = self.cls(sequence_output)

        if return_logits:
            return prediction_scores[:, :-1, :].contiguous()

        lm_loss = None
        if labels is not None:
            # we are doing next-token prediction; shift prediction scores and input ids by one
            shifted_prediction_scores = prediction_scores[:, :-1, :].contiguous()
            labels = labels[:, 1:].contiguous()
            loss_fct = CrossEntropyLoss(reduction=reduction, label_smoothing=0.1)
            lm_loss = loss_fct(shifted_prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
            if reduction == 'none':
                lm_loss = lm_loss.view(prediction_scores.size(0), -1).sum(1)

        if not return_dict:
            output = (prediction_scores,) + outputs[2:]
            return ((lm_loss,) + output) if lm_loss is not None else output

        return CausalLMOutputWithCrossAttentions(
            loss=lm_loss,
            logits=prediction_scores,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
            cross_attentions=outputs.cross_attentions,
        )

    def prepare_inputs_for_generation(self, input_ids, past=None, attention_mask=None, **model_kwargs):
        input_shape = input_ids.shape
        # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
        if attention_mask is None:
            attention_mask = input_ids.new_ones(input_shape)

        # cut decoder_input_ids if past is used
        if past is not None:
            input_ids = input_ids[:, -1:]

        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "past_key_values": past,
            "encoder_hidden_states": model_kwargs.get("encoder_hidden_states", None),
            "encoder_attention_mask": model_kwargs.get("encoder_attention_mask", None),
            "is_decoder": True,
        }

    def _reorder_cache(self, past, beam_idx):
        reordered_past = ()
        for layer_past in past:
            reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
        return reordered_past
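The file above defines the BLIP text transformer (a BERT variant with optional cross-attention and a `mode` switch that selects text-only or multimodal processing). As a minimal sketch only, the snippet below shows how `BertModel` and `BertLMHeadModel` from this `med.py` might be loaded from the bundled `med_config.json` and run on a caption in text-only mode; the config path, the `bert-base-uncased` tokenizer, and the `mode="text"` calls follow the usual BLIP convention and are assumptions, not part of the committed app code.

# Illustrative sketch only, not one of the committed files.
# Assumes the Space root is on PYTHONPATH so defake.models.med is importable.
import torch
from transformers import BertTokenizer
from defake.models.med import BertConfig, BertModel, BertLMHeadModel

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
config = BertConfig.from_json_file("defake/blipmodels/blipconfig/med_config.json")

# Text encoder: self-attention only when mode="text"; mode="multimodal"
# would additionally require encoder_hidden_states from the ViT image encoder.
text_encoder = BertModel(config=config, add_pooling_layer=False)
inputs = tokenizer("a photo of a dog", return_tensors="pt")
enc_out = text_encoder(inputs.input_ids, attention_mask=inputs.attention_mask,
                       return_dict=True, mode="text")
text_feat = enc_out.last_hidden_state[:, 0, :]   # [CLS] representation

# Causal LM head on top of the same backbone (next-token prediction loss).
decoder = BertLMHeadModel(config=config)
dec_out = decoder(inputs.input_ids, attention_mask=inputs.attention_mask,
                  labels=inputs.input_ids, return_dict=True, mode="text")
print(text_feat.shape, dec_out.loss)

In the actual DE-FAKE pipeline these randomly initialized modules would instead be populated from the pretrained BLIP checkpoint; the sketch only illustrates the call pattern.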
defake/models/nlvr_encoder.py ADDED
@@ -0,0 +1,843 @@
import math
import os
import warnings
from dataclasses import dataclass
from typing import Optional, Tuple

import torch
from torch import Tensor, device, dtype, nn
import torch.utils.checkpoint
from torch import nn
from torch.nn import CrossEntropyLoss
import torch.nn.functional as F

from transformers.activations import ACT2FN
from transformers.file_utils import (
    ModelOutput,
)
from transformers.modeling_outputs import (
    BaseModelOutputWithPastAndCrossAttentions,
    BaseModelOutputWithPoolingAndCrossAttentions,
    CausalLMOutputWithCrossAttentions,
    MaskedLMOutput,
    MultipleChoiceModelOutput,
    NextSentencePredictorOutput,
    QuestionAnsweringModelOutput,
    SequenceClassifierOutput,
    TokenClassifierOutput,
)
from transformers.modeling_utils import (
    PreTrainedModel,
    apply_chunking_to_forward,
    find_pruneable_heads_and_indices,
    prune_linear_layer,
)
from transformers.utils import logging
from transformers.models.bert.configuration_bert import BertConfig


logger = logging.get_logger(__name__)


class BertEmbeddings(nn.Module):
    """Construct the embeddings from word and position embeddings."""

    def __init__(self, config):
        super().__init__()
        self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
        self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)

        # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
        # any TensorFlow checkpoint file
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

        # position_ids (1, len position emb) is contiguous in memory and exported when serialized
        self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
        self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")

        self.config = config

    def forward(
        self, input_ids=None, position_ids=None, inputs_embeds=None, past_key_values_length=0
    ):
        if input_ids is not None:
            input_shape = input_ids.size()
        else:
            input_shape = inputs_embeds.size()[:-1]

        seq_length = input_shape[1]

        if position_ids is None:
            position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length]

        if inputs_embeds is None:
            inputs_embeds = self.word_embeddings(input_ids)

        embeddings = inputs_embeds

        if self.position_embedding_type == "absolute":
            position_embeddings = self.position_embeddings(position_ids)
            embeddings += position_embeddings
        embeddings = self.LayerNorm(embeddings)
        embeddings = self.dropout(embeddings)
        return embeddings


class BertSelfAttention(nn.Module):
    def __init__(self, config, is_cross_attention):
        super().__init__()
        self.config = config
        if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
            raise ValueError(
                "The hidden size (%d) is not a multiple of the number of attention "
                "heads (%d)" % (config.hidden_size, config.num_attention_heads)
            )

        self.num_attention_heads = config.num_attention_heads
        self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        self.query = nn.Linear(config.hidden_size, self.all_head_size)
        if is_cross_attention:
            self.key = nn.Linear(config.encoder_width, self.all_head_size)
            self.value = nn.Linear(config.encoder_width, self.all_head_size)
        else:
            self.key = nn.Linear(config.hidden_size, self.all_head_size)
            self.value = nn.Linear(config.hidden_size, self.all_head_size)

        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
        self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
        if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
            self.max_position_embeddings = config.max_position_embeddings
            self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)
        self.save_attention = False

    def save_attn_gradients(self, attn_gradients):
        self.attn_gradients = attn_gradients

    def get_attn_gradients(self):
        return self.attn_gradients

    def save_attention_map(self, attention_map):
        self.attention_map = attention_map

    def get_attention_map(self):
        return self.attention_map

    def transpose_for_scores(self, x):
        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
        x = x.view(*new_x_shape)
        return x.permute(0, 2, 1, 3)

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        past_key_value=None,
        output_attentions=False,
    ):
        mixed_query_layer = self.query(hidden_states)

        # If this is instantiated as a cross-attention module, the keys
        # and values come from an encoder; the attention mask needs to be
        # such that the encoder's padding tokens are not attended to.
        is_cross_attention = encoder_hidden_states is not None

        if is_cross_attention:
            key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))
            value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))
            attention_mask = encoder_attention_mask
        elif past_key_value is not None:
            key_layer = self.transpose_for_scores(self.key(hidden_states))
            value_layer = self.transpose_for_scores(self.value(hidden_states))
            key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
            value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
        else:
            key_layer = self.transpose_for_scores(self.key(hidden_states))
            value_layer = self.transpose_for_scores(self.value(hidden_states))

        query_layer = self.transpose_for_scores(mixed_query_layer)

        past_key_value = (key_layer, value_layer)

        # Take the dot product between "query" and "key" to get the raw attention scores.
        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))

        if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
            seq_length = hidden_states.size()[1]
            position_ids_l = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)
            position_ids_r = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(1, -1)
            distance = position_ids_l - position_ids_r
            positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1)
            positional_embedding = positional_embedding.to(dtype=query_layer.dtype)  # fp16 compatibility

            if self.position_embedding_type == "relative_key":
                relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
                attention_scores = attention_scores + relative_position_scores
            elif self.position_embedding_type == "relative_key_query":
                relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
                relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding)
                attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key

        attention_scores = attention_scores / math.sqrt(self.attention_head_size)
        if attention_mask is not None:
            # Apply the attention mask is (precomputed for all layers in BertModel forward() function)
            attention_scores = attention_scores + attention_mask

        # Normalize the attention scores to probabilities.
        attention_probs = nn.Softmax(dim=-1)(attention_scores)

        if is_cross_attention and self.save_attention:
            self.save_attention_map(attention_probs)
            attention_probs.register_hook(self.save_attn_gradients)

        # This is actually dropping out entire tokens to attend to, which might
        # seem a bit unusual, but is taken from the original Transformer paper.
        attention_probs_dropped = self.dropout(attention_probs)

        # Mask heads if we want to
        if head_mask is not None:
            attention_probs_dropped = attention_probs_dropped * head_mask

        context_layer = torch.matmul(attention_probs_dropped, value_layer)

        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = context_layer.view(*new_context_layer_shape)

        outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)

        outputs = outputs + (past_key_value,)
        return outputs


class BertSelfOutput(nn.Module):
    def __init__(self, config, twin=False, merge=False):
        super().__init__()
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        if twin:
            self.dense0 = nn.Linear(config.hidden_size, config.hidden_size)
            self.dense1 = nn.Linear(config.hidden_size, config.hidden_size)
        else:
            self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        if merge:
            self.act = ACT2FN[config.hidden_act]
            self.merge_layer = nn.Linear(config.hidden_size * 2, config.hidden_size)
            self.merge = True
        else:
            self.merge = False

    def forward(self, hidden_states, input_tensor):
        if type(hidden_states) == list:
            hidden_states0 = self.dense0(hidden_states[0])
            hidden_states1 = self.dense1(hidden_states[1])
            if self.merge:
                # hidden_states = self.merge_layer(self.act(torch.cat([hidden_states0, hidden_states1], dim=-1)))
                hidden_states = self.merge_layer(torch.cat([hidden_states0, hidden_states1], dim=-1))
            else:
                hidden_states = (hidden_states0 + hidden_states1) / 2
        else:
            hidden_states = self.dense(hidden_states)
        hidden_states = self.dropout(hidden_states)
        hidden_states = self.LayerNorm(hidden_states + input_tensor)
        return hidden_states


class BertAttention(nn.Module):
    def __init__(self, config, is_cross_attention=False, layer_num=-1):
        super().__init__()
        if is_cross_attention:
            self.self0 = BertSelfAttention(config, is_cross_attention)
            self.self1 = BertSelfAttention(config, is_cross_attention)
        else:
            self.self = BertSelfAttention(config, is_cross_attention)
        self.output = BertSelfOutput(config, twin=is_cross_attention, merge=(is_cross_attention and layer_num >= 6))
        self.pruned_heads = set()

    def prune_heads(self, heads):
        if len(heads) == 0:
            return
        heads, index = find_pruneable_heads_and_indices(
            heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads
        )

        # Prune linear layers
        self.self.query = prune_linear_layer(self.self.query, index)
        self.self.key = prune_linear_layer(self.self.key, index)
        self.self.value = prune_linear_layer(self.self.value, index)
        self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)

        # Update hyper params and store pruned heads
        self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
        self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
        self.pruned_heads = self.pruned_heads.union(heads)

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        past_key_value=None,
        output_attentions=False,
    ):
        if type(encoder_hidden_states) == list:
            self_outputs0 = self.self0(
                hidden_states,
                attention_mask,
                head_mask,
                encoder_hidden_states[0],
                encoder_attention_mask[0],
                past_key_value,
                output_attentions,
            )
            self_outputs1 = self.self1(
                hidden_states,
                attention_mask,
                head_mask,
                encoder_hidden_states[1],
                encoder_attention_mask[1],
                past_key_value,
                output_attentions,
            )
            attention_output = self.output([self_outputs0[0], self_outputs1[0]], hidden_states)

            outputs = (attention_output,) + self_outputs0[1:]  # add attentions if we output them
        else:
            self_outputs = self.self(
                hidden_states,
                attention_mask,
                head_mask,
                encoder_hidden_states,
                encoder_attention_mask,
                past_key_value,
                output_attentions,
            )
            attention_output = self.output(self_outputs[0], hidden_states)
            outputs = (attention_output,) + self_outputs[1:]  # add attentions if we output them
        return outputs


class BertIntermediate(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
        if isinstance(config.hidden_act, str):
            self.intermediate_act_fn = ACT2FN[config.hidden_act]
        else:
            self.intermediate_act_fn = config.hidden_act

    def forward(self, hidden_states):
        hidden_states = self.dense(hidden_states)
        hidden_states = self.intermediate_act_fn(hidden_states)
        return hidden_states


class BertOutput(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states, input_tensor):
        hidden_states = self.dense(hidden_states)
        hidden_states = self.dropout(hidden_states)
        hidden_states = self.LayerNorm(hidden_states + input_tensor)
        return hidden_states


class BertLayer(nn.Module):
    def __init__(self, config, layer_num):
        super().__init__()
        self.config = config
        self.chunk_size_feed_forward = config.chunk_size_feed_forward
        self.seq_len_dim = 1
        self.attention = BertAttention(config)
        self.layer_num = layer_num
        if self.config.add_cross_attention:
            self.crossattention = BertAttention(config, is_cross_attention=self.config.add_cross_attention, layer_num=layer_num)
        self.intermediate = BertIntermediate(config)
        self.output = BertOutput(config)

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        past_key_value=None,
        output_attentions=False,
        mode=None,
    ):
        # decoder uni-directional self-attention cached key/values tuple is at positions 1,2
        self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
        self_attention_outputs = self.attention(
            hidden_states,
            attention_mask,
            head_mask,
            output_attentions=output_attentions,
            past_key_value=self_attn_past_key_value,
        )
        attention_output = self_attention_outputs[0]

        outputs = self_attention_outputs[1:-1]
        present_key_value = self_attention_outputs[-1]

        if mode == 'multimodal':
            assert encoder_hidden_states is not None, "encoder_hidden_states must be given for cross-attention layers"
            cross_attention_outputs = self.crossattention(
                attention_output,
                attention_mask,
                head_mask,
                encoder_hidden_states,
                encoder_attention_mask,
                output_attentions=output_attentions,
            )
            attention_output = cross_attention_outputs[0]
            outputs = outputs + cross_attention_outputs[1:-1]  # add cross attentions if we output attention weights
        layer_output = apply_chunking_to_forward(
            self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output
        )
        outputs = (layer_output,) + outputs

        outputs = outputs + (present_key_value,)

        return outputs

    def feed_forward_chunk(self, attention_output):
        intermediate_output = self.intermediate(attention_output)
        layer_output = self.output(intermediate_output, attention_output)
        return layer_output


class BertEncoder(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.layer = nn.ModuleList([BertLayer(config, i) for i in range(config.num_hidden_layers)])
        self.gradient_checkpointing = False

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        past_key_values=None,
        use_cache=None,
        output_attentions=False,
        output_hidden_states=False,
        return_dict=True,
        mode='multimodal',
    ):
        all_hidden_states = () if output_hidden_states else None
        all_self_attentions = () if output_attentions else None
        all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None

        next_decoder_cache = () if use_cache else None

        for i in range(self.config.num_hidden_layers):
            layer_module = self.layer[i]
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            layer_head_mask = head_mask[i] if head_mask is not None else None
            past_key_value = past_key_values[i] if past_key_values is not None else None

            if self.gradient_checkpointing and self.training:

                if use_cache:
                    logger.warn(
                        "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                    )
                    use_cache = False

                def create_custom_forward(module):
                    def custom_forward(*inputs):
                        return module(*inputs, past_key_value, output_attentions)

                    return custom_forward

                layer_outputs = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(layer_module),
                    hidden_states,
                    attention_mask,
                    layer_head_mask,
                    encoder_hidden_states,
                    encoder_attention_mask,
                    mode=mode,
                )
            else:
                layer_outputs = layer_module(
                    hidden_states,
                    attention_mask,
                    layer_head_mask,
                    encoder_hidden_states,
                    encoder_attention_mask,
                    past_key_value,
                    output_attentions,
                    mode=mode,
                )

            hidden_states = layer_outputs[0]
            if use_cache:
                next_decoder_cache += (layer_outputs[-1],)
            if output_attentions:
                all_self_attentions = all_self_attentions + (layer_outputs[1],)

        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        if not return_dict:
            return tuple(
                v
                for v in [
                    hidden_states,
                    next_decoder_cache,
                    all_hidden_states,
                    all_self_attentions,
                    all_cross_attentions,
                ]
                if v is not None
            )
        return BaseModelOutputWithPastAndCrossAttentions(
            last_hidden_state=hidden_states,
            past_key_values=next_decoder_cache,
            hidden_states=all_hidden_states,
            attentions=all_self_attentions,
            cross_attentions=all_cross_attentions,
        )


class BertPooler(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.activation = nn.Tanh()

    def forward(self, hidden_states):
        # We "pool" the model by simply taking the hidden state corresponding
        # to the first token.
+
first_token_tensor = hidden_states[:, 0]
|
| 531 |
+
pooled_output = self.dense(first_token_tensor)
|
| 532 |
+
pooled_output = self.activation(pooled_output)
|
| 533 |
+
return pooled_output
|
| 534 |
+
|
| 535 |
+
|
| 536 |
+
class BertPredictionHeadTransform(nn.Module):
|
| 537 |
+
def __init__(self, config):
|
| 538 |
+
super().__init__()
|
| 539 |
+
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
|
| 540 |
+
if isinstance(config.hidden_act, str):
|
| 541 |
+
self.transform_act_fn = ACT2FN[config.hidden_act]
|
| 542 |
+
else:
|
| 543 |
+
self.transform_act_fn = config.hidden_act
|
| 544 |
+
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
| 545 |
+
|
| 546 |
+
def forward(self, hidden_states):
|
| 547 |
+
hidden_states = self.dense(hidden_states)
|
| 548 |
+
hidden_states = self.transform_act_fn(hidden_states)
|
| 549 |
+
hidden_states = self.LayerNorm(hidden_states)
|
| 550 |
+
return hidden_states
|
| 551 |
+
|
| 552 |
+
|
| 553 |
+
class BertLMPredictionHead(nn.Module):
|
| 554 |
+
def __init__(self, config):
|
| 555 |
+
super().__init__()
|
| 556 |
+
self.transform = BertPredictionHeadTransform(config)
|
| 557 |
+
|
| 558 |
+
# The output weights are the same as the input embeddings, but there is
|
| 559 |
+
# an output-only bias for each token.
|
| 560 |
+
self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
| 561 |
+
|
| 562 |
+
self.bias = nn.Parameter(torch.zeros(config.vocab_size))
|
| 563 |
+
|
| 564 |
+
# Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
|
| 565 |
+
self.decoder.bias = self.bias
|
| 566 |
+
|
| 567 |
+
def forward(self, hidden_states):
|
| 568 |
+
hidden_states = self.transform(hidden_states)
|
| 569 |
+
hidden_states = self.decoder(hidden_states)
|
| 570 |
+
return hidden_states
|
| 571 |
+
|
| 572 |
+
|
| 573 |
+
class BertOnlyMLMHead(nn.Module):
|
| 574 |
+
def __init__(self, config):
|
| 575 |
+
super().__init__()
|
| 576 |
+
self.predictions = BertLMPredictionHead(config)
|
| 577 |
+
|
| 578 |
+
def forward(self, sequence_output):
|
| 579 |
+
prediction_scores = self.predictions(sequence_output)
|
| 580 |
+
return prediction_scores
|
| 581 |
+
|
| 582 |
+
|
| 583 |
+
class BertPreTrainedModel(PreTrainedModel):
|
| 584 |
+
"""
|
| 585 |
+
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
|
| 586 |
+
models.
|
| 587 |
+
"""
|
| 588 |
+
|
| 589 |
+
config_class = BertConfig
|
| 590 |
+
base_model_prefix = "bert"
|
| 591 |
+
_keys_to_ignore_on_load_missing = [r"position_ids"]
|
| 592 |
+
|
| 593 |
+
def _init_weights(self, module):
|
| 594 |
+
""" Initialize the weights """
|
| 595 |
+
if isinstance(module, (nn.Linear, nn.Embedding)):
|
| 596 |
+
# Slightly different from the TF version which uses truncated_normal for initialization
|
| 597 |
+
# cf https://github.com/pytorch/pytorch/pull/5617
|
| 598 |
+
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
|
| 599 |
+
elif isinstance(module, nn.LayerNorm):
|
| 600 |
+
module.bias.data.zero_()
|
| 601 |
+
module.weight.data.fill_(1.0)
|
| 602 |
+
if isinstance(module, nn.Linear) and module.bias is not None:
|
| 603 |
+
module.bias.data.zero_()
|
| 604 |
+
|
| 605 |
+
|
| 606 |
+
class BertModel(BertPreTrainedModel):
|
| 607 |
+
"""
|
| 608 |
+
The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of
|
| 609 |
+
cross-attention is added between the self-attention layers, following the architecture described in `Attention is
|
| 610 |
+
all you need <https://arxiv.org/abs/1706.03762>`__ by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit,
|
| 611 |
+
Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin.
|
| 612 |
+
To behave as a decoder, the model needs to be initialized with the :obj:`is_decoder` argument and :obj:`add_cross_attention` set to :obj:`True`; an :obj:`encoder_hidden_states` is then expected as an
|
| 613 |
+
input to the forward pass.
|
| 614 |
+
"""
|
| 615 |
+
|
| 616 |
+
def __init__(self, config, add_pooling_layer=True):
|
| 617 |
+
super().__init__(config)
|
| 618 |
+
self.config = config
|
| 619 |
+
|
| 620 |
+
self.embeddings = BertEmbeddings(config)
|
| 621 |
+
|
| 622 |
+
self.encoder = BertEncoder(config)
|
| 623 |
+
|
| 624 |
+
self.pooler = BertPooler(config) if add_pooling_layer else None
|
| 625 |
+
|
| 626 |
+
self.init_weights()
|
| 627 |
+
|
| 628 |
+
|
| 629 |
+
def get_input_embeddings(self):
|
| 630 |
+
return self.embeddings.word_embeddings
|
| 631 |
+
|
| 632 |
+
def set_input_embeddings(self, value):
|
| 633 |
+
self.embeddings.word_embeddings = value
|
| 634 |
+
|
| 635 |
+
def _prune_heads(self, heads_to_prune):
|
| 636 |
+
"""
|
| 637 |
+
Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
|
| 638 |
+
class PreTrainedModel
|
| 639 |
+
"""
|
| 640 |
+
for layer, heads in heads_to_prune.items():
|
| 641 |
+
self.encoder.layer[layer].attention.prune_heads(heads)
|
| 642 |
+
|
| 643 |
+
|
| 644 |
+
def get_extended_attention_mask(self, attention_mask: Tensor, input_shape: Tuple[int], device: device, is_decoder: bool) -> Tensor:
|
| 645 |
+
"""
|
| 646 |
+
Makes broadcastable attention and causal masks so that future and masked tokens are ignored.
|
| 647 |
+
|
| 648 |
+
Arguments:
|
| 649 |
+
attention_mask (:obj:`torch.Tensor`):
|
| 650 |
+
Mask with ones indicating tokens to attend to, zeros for tokens to ignore.
|
| 651 |
+
input_shape (:obj:`Tuple[int]`):
|
| 652 |
+
The shape of the input to the model.
|
| 653 |
+
device: (:obj:`torch.device`):
|
| 654 |
+
The device of the input to the model.
|
| 655 |
+
|
| 656 |
+
Returns:
|
| 657 |
+
:obj:`torch.Tensor` The extended attention mask, with a the same dtype as :obj:`attention_mask.dtype`.
|
| 658 |
+
"""
|
| 659 |
+
# We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
|
| 660 |
+
# ourselves in which case we just need to make it broadcastable to all heads.
|
| 661 |
+
if attention_mask.dim() == 3:
|
| 662 |
+
extended_attention_mask = attention_mask[:, None, :, :]
|
| 663 |
+
elif attention_mask.dim() == 2:
|
| 664 |
+
# Provided a padding mask of dimensions [batch_size, seq_length]
|
| 665 |
+
# - if the model is a decoder, apply a causal mask in addition to the padding mask
|
| 666 |
+
# - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length]
|
| 667 |
+
if is_decoder:
|
| 668 |
+
batch_size, seq_length = input_shape
|
| 669 |
+
|
| 670 |
+
seq_ids = torch.arange(seq_length, device=device)
|
| 671 |
+
causal_mask = seq_ids[None, None, :].repeat(batch_size, seq_length, 1) <= seq_ids[None, :, None]
|
| 672 |
+
# in case past_key_values are used we need to add a prefix ones mask to the causal mask
|
| 673 |
+
# causal and attention masks must have same type with pytorch version < 1.3
|
| 674 |
+
causal_mask = causal_mask.to(attention_mask.dtype)
|
| 675 |
+
|
| 676 |
+
if causal_mask.shape[1] < attention_mask.shape[1]:
|
| 677 |
+
prefix_seq_len = attention_mask.shape[1] - causal_mask.shape[1]
|
| 678 |
+
causal_mask = torch.cat(
|
| 679 |
+
[
|
| 680 |
+
torch.ones((batch_size, seq_length, prefix_seq_len), device=device, dtype=causal_mask.dtype),
|
| 681 |
+
causal_mask,
|
| 682 |
+
],
|
| 683 |
+
axis=-1,
|
| 684 |
+
)
|
| 685 |
+
|
| 686 |
+
extended_attention_mask = causal_mask[:, None, :, :] * attention_mask[:, None, None, :]
|
| 687 |
+
else:
|
| 688 |
+
extended_attention_mask = attention_mask[:, None, None, :]
|
| 689 |
+
else:
|
| 690 |
+
raise ValueError(
|
| 691 |
+
"Wrong shape for input_ids (shape {}) or attention_mask (shape {})".format(
|
| 692 |
+
input_shape, attention_mask.shape
|
| 693 |
+
)
|
| 694 |
+
)
|
| 695 |
+
|
| 696 |
+
# Since attention_mask is 1.0 for positions we want to attend and 0.0 for
|
| 697 |
+
# masked positions, this operation will create a tensor which is 0.0 for
|
| 698 |
+
# positions we want to attend and -10000.0 for masked positions.
|
| 699 |
+
# Since we are adding it to the raw scores before the softmax, this is
|
| 700 |
+
# effectively the same as removing these entirely.
|
| 701 |
+
extended_attention_mask = extended_attention_mask.to(dtype=self.dtype) # fp16 compatibility
|
| 702 |
+
extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
|
| 703 |
+
return extended_attention_mask
|
| 704 |
+
|
| 705 |
+
def forward(
|
| 706 |
+
self,
|
| 707 |
+
input_ids=None,
|
| 708 |
+
attention_mask=None,
|
| 709 |
+
position_ids=None,
|
| 710 |
+
head_mask=None,
|
| 711 |
+
inputs_embeds=None,
|
| 712 |
+
encoder_embeds=None,
|
| 713 |
+
encoder_hidden_states=None,
|
| 714 |
+
encoder_attention_mask=None,
|
| 715 |
+
past_key_values=None,
|
| 716 |
+
use_cache=None,
|
| 717 |
+
output_attentions=None,
|
| 718 |
+
output_hidden_states=None,
|
| 719 |
+
return_dict=None,
|
| 720 |
+
is_decoder=False,
|
| 721 |
+
mode='multimodal',
|
| 722 |
+
):
|
| 723 |
+
r"""
|
| 724 |
+
encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
|
| 725 |
+
Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
|
| 726 |
+
the model is configured as a decoder.
|
| 727 |
+
encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
|
| 728 |
+
Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
|
| 729 |
+
the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``:
|
| 730 |
+
- 1 for tokens that are **not masked**,
|
| 731 |
+
- 0 for tokens that are **masked**.
|
| 732 |
+
past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
|
| 733 |
+
Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
|
| 734 |
+
If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids`
|
| 735 |
+
(those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`
|
| 736 |
+
instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`.
|
| 737 |
+
use_cache (:obj:`bool`, `optional`):
|
| 738 |
+
If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up
|
| 739 |
+
decoding (see :obj:`past_key_values`).
|
| 740 |
+
"""
|
| 741 |
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
| 742 |
+
output_hidden_states = (
|
| 743 |
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
| 744 |
+
)
|
| 745 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
| 746 |
+
|
| 747 |
+
if is_decoder:
|
| 748 |
+
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
| 749 |
+
else:
|
| 750 |
+
use_cache = False
|
| 751 |
+
|
| 752 |
+
if input_ids is not None and inputs_embeds is not None:
|
| 753 |
+
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
|
| 754 |
+
elif input_ids is not None:
|
| 755 |
+
input_shape = input_ids.size()
|
| 756 |
+
batch_size, seq_length = input_shape
|
| 757 |
+
device = input_ids.device
|
| 758 |
+
elif inputs_embeds is not None:
|
| 759 |
+
input_shape = inputs_embeds.size()[:-1]
|
| 760 |
+
batch_size, seq_length = input_shape
|
| 761 |
+
device = inputs_embeds.device
|
| 762 |
+
elif encoder_embeds is not None:
|
| 763 |
+
input_shape = encoder_embeds.size()[:-1]
|
| 764 |
+
batch_size, seq_length = input_shape
|
| 765 |
+
device = encoder_embeds.device
|
| 766 |
+
else:
|
| 767 |
+
raise ValueError("You have to specify either input_ids or inputs_embeds or encoder_embeds")
|
| 768 |
+
|
| 769 |
+
# past_key_values_length
|
| 770 |
+
past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
|
| 771 |
+
|
| 772 |
+
if attention_mask is None:
|
| 773 |
+
attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device)
|
| 774 |
+
|
| 775 |
+
# We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
|
| 776 |
+
# ourselves in which case we just need to make it broadcastable to all heads.
|
| 777 |
+
extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape,
|
| 778 |
+
device, is_decoder)
|
| 779 |
+
|
| 780 |
+
# If a 2D or 3D attention mask is provided for the cross-attention
|
| 781 |
+
# we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
|
| 782 |
+
if encoder_hidden_states is not None:
|
| 783 |
+
if type(encoder_hidden_states) == list:
|
| 784 |
+
encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states[0].size()
|
| 785 |
+
else:
|
| 786 |
+
encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
|
| 787 |
+
encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
|
| 788 |
+
|
| 789 |
+
if type(encoder_attention_mask) == list:
|
| 790 |
+
encoder_extended_attention_mask = [self.invert_attention_mask(mask) for mask in encoder_attention_mask]
|
| 791 |
+
elif encoder_attention_mask is None:
|
| 792 |
+
encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
|
| 793 |
+
encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
|
| 794 |
+
else:
|
| 795 |
+
encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
|
| 796 |
+
else:
|
| 797 |
+
encoder_extended_attention_mask = None
|
| 798 |
+
|
| 799 |
+
# Prepare head mask if needed
|
| 800 |
+
# 1.0 in head_mask indicate we keep the head
|
| 801 |
+
# attention_probs has shape bsz x n_heads x N x N
|
| 802 |
+
# input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
|
| 803 |
+
# and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
|
| 804 |
+
head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
|
| 805 |
+
|
| 806 |
+
if encoder_embeds is None:
|
| 807 |
+
embedding_output = self.embeddings(
|
| 808 |
+
input_ids=input_ids,
|
| 809 |
+
position_ids=position_ids,
|
| 810 |
+
inputs_embeds=inputs_embeds,
|
| 811 |
+
past_key_values_length=past_key_values_length,
|
| 812 |
+
)
|
| 813 |
+
else:
|
| 814 |
+
embedding_output = encoder_embeds
|
| 815 |
+
|
| 816 |
+
encoder_outputs = self.encoder(
|
| 817 |
+
embedding_output,
|
| 818 |
+
attention_mask=extended_attention_mask,
|
| 819 |
+
head_mask=head_mask,
|
| 820 |
+
encoder_hidden_states=encoder_hidden_states,
|
| 821 |
+
encoder_attention_mask=encoder_extended_attention_mask,
|
| 822 |
+
past_key_values=past_key_values,
|
| 823 |
+
use_cache=use_cache,
|
| 824 |
+
output_attentions=output_attentions,
|
| 825 |
+
output_hidden_states=output_hidden_states,
|
| 826 |
+
return_dict=return_dict,
|
| 827 |
+
mode=mode,
|
| 828 |
+
)
|
| 829 |
+
sequence_output = encoder_outputs[0]
|
| 830 |
+
pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
|
| 831 |
+
|
| 832 |
+
if not return_dict:
|
| 833 |
+
return (sequence_output, pooled_output) + encoder_outputs[1:]
|
| 834 |
+
|
| 835 |
+
return BaseModelOutputWithPoolingAndCrossAttentions(
|
| 836 |
+
last_hidden_state=sequence_output,
|
| 837 |
+
pooler_output=pooled_output,
|
| 838 |
+
past_key_values=encoder_outputs.past_key_values,
|
| 839 |
+
hidden_states=encoder_outputs.hidden_states,
|
| 840 |
+
attentions=encoder_outputs.attentions,
|
| 841 |
+
cross_attentions=encoder_outputs.cross_attentions,
|
| 842 |
+
)
|
| 843 |
+
|
defake/models/vit.py
ADDED
|
@@ -0,0 +1,305 @@
|
| 1 |
+
'''
|
| 2 |
+
* Copyright (c) 2022, salesforce.com, inc.
|
| 3 |
+
* All rights reserved.
|
| 4 |
+
* SPDX-License-Identifier: BSD-3-Clause
|
| 5 |
+
* For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause
|
| 6 |
+
* By Junnan Li
|
| 7 |
+
* Based on timm code base
|
| 8 |
+
* https://github.com/rwightman/pytorch-image-models/tree/master/timm
|
| 9 |
+
'''
|
| 10 |
+
|
| 11 |
+
import torch
|
| 12 |
+
import torch.nn as nn
|
| 13 |
+
import torch.nn.functional as F
|
| 14 |
+
from functools import partial
|
| 15 |
+
|
| 16 |
+
from timm.models.vision_transformer import _cfg, PatchEmbed
|
| 17 |
+
from timm.models.registry import register_model
|
| 18 |
+
from timm.models.layers import trunc_normal_, DropPath
|
| 19 |
+
from timm.models.helpers import named_apply, adapt_input_conv
|
| 20 |
+
|
| 21 |
+
from fairscale.nn.checkpoint.checkpoint_activations import checkpoint_wrapper
|
| 22 |
+
|
| 23 |
+
class Mlp(nn.Module):
|
| 24 |
+
""" MLP as used in Vision Transformer, MLP-Mixer and related networks
|
| 25 |
+
"""
|
| 26 |
+
def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
|
| 27 |
+
super().__init__()
|
| 28 |
+
out_features = out_features or in_features
|
| 29 |
+
hidden_features = hidden_features or in_features
|
| 30 |
+
self.fc1 = nn.Linear(in_features, hidden_features)
|
| 31 |
+
self.act = act_layer()
|
| 32 |
+
self.fc2 = nn.Linear(hidden_features, out_features)
|
| 33 |
+
self.drop = nn.Dropout(drop)
|
| 34 |
+
|
| 35 |
+
def forward(self, x):
|
| 36 |
+
x = self.fc1(x)
|
| 37 |
+
x = self.act(x)
|
| 38 |
+
x = self.drop(x)
|
| 39 |
+
x = self.fc2(x)
|
| 40 |
+
x = self.drop(x)
|
| 41 |
+
return x
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
class Attention(nn.Module):
|
| 45 |
+
def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
|
| 46 |
+
super().__init__()
|
| 47 |
+
self.num_heads = num_heads
|
| 48 |
+
head_dim = dim // num_heads
|
| 49 |
+
# NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights
|
| 50 |
+
self.scale = qk_scale or head_dim ** -0.5
|
| 51 |
+
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
|
| 52 |
+
self.attn_drop = nn.Dropout(attn_drop)
|
| 53 |
+
self.proj = nn.Linear(dim, dim)
|
| 54 |
+
self.proj_drop = nn.Dropout(proj_drop)
|
| 55 |
+
self.attn_gradients = None
|
| 56 |
+
self.attention_map = None
|
| 57 |
+
|
| 58 |
+
def save_attn_gradients(self, attn_gradients):
|
| 59 |
+
self.attn_gradients = attn_gradients
|
| 60 |
+
|
| 61 |
+
def get_attn_gradients(self):
|
| 62 |
+
return self.attn_gradients
|
| 63 |
+
|
| 64 |
+
def save_attention_map(self, attention_map):
|
| 65 |
+
self.attention_map = attention_map
|
| 66 |
+
|
| 67 |
+
def get_attention_map(self):
|
| 68 |
+
return self.attention_map
|
| 69 |
+
|
| 70 |
+
def forward(self, x, register_hook=False):
|
| 71 |
+
B, N, C = x.shape
|
| 72 |
+
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
|
| 73 |
+
q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
|
| 74 |
+
|
| 75 |
+
attn = (q @ k.transpose(-2, -1)) * self.scale
|
| 76 |
+
attn = attn.softmax(dim=-1)
|
| 77 |
+
attn = self.attn_drop(attn)
|
| 78 |
+
|
| 79 |
+
if register_hook:
|
| 80 |
+
self.save_attention_map(attn)
|
| 81 |
+
attn.register_hook(self.save_attn_gradients)
|
| 82 |
+
|
| 83 |
+
x = (attn @ v).transpose(1, 2).reshape(B, N, C)
|
| 84 |
+
x = self.proj(x)
|
| 85 |
+
x = self.proj_drop(x)
|
| 86 |
+
return x
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
class Block(nn.Module):
|
| 90 |
+
|
| 91 |
+
def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
|
| 92 |
+
drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, use_grad_checkpointing=False):
|
| 93 |
+
super().__init__()
|
| 94 |
+
self.norm1 = norm_layer(dim)
|
| 95 |
+
self.attn = Attention(
|
| 96 |
+
dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
|
| 97 |
+
# NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
|
| 98 |
+
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
|
| 99 |
+
self.norm2 = norm_layer(dim)
|
| 100 |
+
mlp_hidden_dim = int(dim * mlp_ratio)
|
| 101 |
+
self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
|
| 102 |
+
|
| 103 |
+
if use_grad_checkpointing:
|
| 104 |
+
self.attn = checkpoint_wrapper(self.attn)
|
| 105 |
+
self.mlp = checkpoint_wrapper(self.mlp)
|
| 106 |
+
|
| 107 |
+
def forward(self, x, register_hook=False):
|
| 108 |
+
x = x + self.drop_path(self.attn(self.norm1(x), register_hook=register_hook))
|
| 109 |
+
x = x + self.drop_path(self.mlp(self.norm2(x)))
|
| 110 |
+
return x
|
| 111 |
+
|
| 112 |
+
|
| 113 |
+
class VisionTransformer(nn.Module):
|
| 114 |
+
""" Vision Transformer
|
| 115 |
+
A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale` -
|
| 116 |
+
https://arxiv.org/abs/2010.11929
|
| 117 |
+
"""
|
| 118 |
+
def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12,
|
| 119 |
+
num_heads=12, mlp_ratio=4., qkv_bias=True, qk_scale=None, representation_size=None,
|
| 120 |
+
drop_rate=0., attn_drop_rate=0., drop_path_rate=0., norm_layer=None,
|
| 121 |
+
use_grad_checkpointing=False, ckpt_layer=0):
|
| 122 |
+
"""
|
| 123 |
+
Args:
|
| 124 |
+
img_size (int, tuple): input image size
|
| 125 |
+
patch_size (int, tuple): patch size
|
| 126 |
+
in_chans (int): number of input channels
|
| 127 |
+
num_classes (int): number of classes for classification head
|
| 128 |
+
embed_dim (int): embedding dimension
|
| 129 |
+
depth (int): depth of transformer
|
| 130 |
+
num_heads (int): number of attention heads
|
| 131 |
+
mlp_ratio (int): ratio of mlp hidden dim to embedding dim
|
| 132 |
+
qkv_bias (bool): enable bias for qkv if True
|
| 133 |
+
qk_scale (float): override default qk scale of head_dim ** -0.5 if set
|
| 134 |
+
representation_size (Optional[int]): enable and set representation layer (pre-logits) to this value if set
|
| 135 |
+
drop_rate (float): dropout rate
|
| 136 |
+
attn_drop_rate (float): attention dropout rate
|
| 137 |
+
drop_path_rate (float): stochastic depth rate
|
| 138 |
+
norm_layer: (nn.Module): normalization layer
|
| 139 |
+
"""
|
| 140 |
+
super().__init__()
|
| 141 |
+
self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
|
| 142 |
+
norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
|
| 143 |
+
|
| 144 |
+
self.patch_embed = PatchEmbed(
|
| 145 |
+
img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
|
| 146 |
+
|
| 147 |
+
num_patches = self.patch_embed.num_patches
|
| 148 |
+
|
| 149 |
+
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
|
| 150 |
+
self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
|
| 151 |
+
self.pos_drop = nn.Dropout(p=drop_rate)
|
| 152 |
+
|
| 153 |
+
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
|
| 154 |
+
self.blocks = nn.ModuleList([
|
| 155 |
+
Block(
|
| 156 |
+
dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
|
| 157 |
+
drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer,
|
| 158 |
+
use_grad_checkpointing=(use_grad_checkpointing and i>=depth-ckpt_layer)
|
| 159 |
+
)
|
| 160 |
+
for i in range(depth)])
|
| 161 |
+
self.norm = norm_layer(embed_dim)
|
| 162 |
+
|
| 163 |
+
trunc_normal_(self.pos_embed, std=.02)
|
| 164 |
+
trunc_normal_(self.cls_token, std=.02)
|
| 165 |
+
self.apply(self._init_weights)
|
| 166 |
+
|
| 167 |
+
def _init_weights(self, m):
|
| 168 |
+
if isinstance(m, nn.Linear):
|
| 169 |
+
trunc_normal_(m.weight, std=.02)
|
| 170 |
+
if isinstance(m, nn.Linear) and m.bias is not None:
|
| 171 |
+
nn.init.constant_(m.bias, 0)
|
| 172 |
+
elif isinstance(m, nn.LayerNorm):
|
| 173 |
+
nn.init.constant_(m.bias, 0)
|
| 174 |
+
nn.init.constant_(m.weight, 1.0)
|
| 175 |
+
|
| 176 |
+
@torch.jit.ignore
|
| 177 |
+
def no_weight_decay(self):
|
| 178 |
+
return {'pos_embed', 'cls_token'}
|
| 179 |
+
|
| 180 |
+
def forward(self, x, register_blk=-1):
|
| 181 |
+
B = x.shape[0]
|
| 182 |
+
x = self.patch_embed(x)
|
| 183 |
+
|
| 184 |
+
cls_tokens = self.cls_token.expand(B, -1, -1) # stole cls_tokens impl from Phil Wang, thanks
|
| 185 |
+
x = torch.cat((cls_tokens, x), dim=1)
|
| 186 |
+
|
| 187 |
+
x = x + self.pos_embed[:,:x.size(1),:]
|
| 188 |
+
x = self.pos_drop(x)
|
| 189 |
+
|
| 190 |
+
for i,blk in enumerate(self.blocks):
|
| 191 |
+
x = blk(x, register_blk==i)
|
| 192 |
+
x = self.norm(x)
|
| 193 |
+
|
| 194 |
+
return x
|
| 195 |
+
|
| 196 |
+
@torch.jit.ignore()
|
| 197 |
+
def load_pretrained(self, checkpoint_path, prefix=''):
|
| 198 |
+
_load_weights(self, checkpoint_path, prefix)
|
| 199 |
+
|
| 200 |
+
|
| 201 |
+
@torch.no_grad()
|
| 202 |
+
def _load_weights(model: VisionTransformer, checkpoint_path: str, prefix: str = ''):
|
| 203 |
+
""" Load weights from .npz checkpoints for official Google Brain Flax implementation
|
| 204 |
+
"""
|
| 205 |
+
import numpy as np
|
| 206 |
+
|
| 207 |
+
def _n2p(w, t=True):
|
| 208 |
+
if w.ndim == 4 and w.shape[0] == w.shape[1] == w.shape[2] == 1:
|
| 209 |
+
w = w.flatten()
|
| 210 |
+
if t:
|
| 211 |
+
if w.ndim == 4:
|
| 212 |
+
w = w.transpose([3, 2, 0, 1])
|
| 213 |
+
elif w.ndim == 3:
|
| 214 |
+
w = w.transpose([2, 0, 1])
|
| 215 |
+
elif w.ndim == 2:
|
| 216 |
+
w = w.transpose([1, 0])
|
| 217 |
+
return torch.from_numpy(w)
|
| 218 |
+
|
| 219 |
+
w = np.load(checkpoint_path)
|
| 220 |
+
if not prefix and 'opt/target/embedding/kernel' in w:
|
| 221 |
+
prefix = 'opt/target/'
|
| 222 |
+
|
| 223 |
+
if hasattr(model.patch_embed, 'backbone'):
|
| 224 |
+
# hybrid
|
| 225 |
+
backbone = model.patch_embed.backbone
|
| 226 |
+
stem_only = not hasattr(backbone, 'stem')
|
| 227 |
+
stem = backbone if stem_only else backbone.stem
|
| 228 |
+
stem.conv.weight.copy_(adapt_input_conv(stem.conv.weight.shape[1], _n2p(w[f'{prefix}conv_root/kernel'])))
|
| 229 |
+
stem.norm.weight.copy_(_n2p(w[f'{prefix}gn_root/scale']))
|
| 230 |
+
stem.norm.bias.copy_(_n2p(w[f'{prefix}gn_root/bias']))
|
| 231 |
+
if not stem_only:
|
| 232 |
+
for i, stage in enumerate(backbone.stages):
|
| 233 |
+
for j, block in enumerate(stage.blocks):
|
| 234 |
+
bp = f'{prefix}block{i + 1}/unit{j + 1}/'
|
| 235 |
+
for r in range(3):
|
| 236 |
+
getattr(block, f'conv{r + 1}').weight.copy_(_n2p(w[f'{bp}conv{r + 1}/kernel']))
|
| 237 |
+
getattr(block, f'norm{r + 1}').weight.copy_(_n2p(w[f'{bp}gn{r + 1}/scale']))
|
| 238 |
+
getattr(block, f'norm{r + 1}').bias.copy_(_n2p(w[f'{bp}gn{r + 1}/bias']))
|
| 239 |
+
if block.downsample is not None:
|
| 240 |
+
block.downsample.conv.weight.copy_(_n2p(w[f'{bp}conv_proj/kernel']))
|
| 241 |
+
block.downsample.norm.weight.copy_(_n2p(w[f'{bp}gn_proj/scale']))
|
| 242 |
+
block.downsample.norm.bias.copy_(_n2p(w[f'{bp}gn_proj/bias']))
|
| 243 |
+
embed_conv_w = _n2p(w[f'{prefix}embedding/kernel'])
|
| 244 |
+
else:
|
| 245 |
+
embed_conv_w = adapt_input_conv(
|
| 246 |
+
model.patch_embed.proj.weight.shape[1], _n2p(w[f'{prefix}embedding/kernel']))
|
| 247 |
+
model.patch_embed.proj.weight.copy_(embed_conv_w)
|
| 248 |
+
model.patch_embed.proj.bias.copy_(_n2p(w[f'{prefix}embedding/bias']))
|
| 249 |
+
model.cls_token.copy_(_n2p(w[f'{prefix}cls'], t=False))
|
| 250 |
+
pos_embed_w = _n2p(w[f'{prefix}Transformer/posembed_input/pos_embedding'], t=False)
|
| 251 |
+
if pos_embed_w.shape != model.pos_embed.shape:
|
| 252 |
+
pos_embed_w = resize_pos_embed( # resize pos embedding when different size from pretrained weights
|
| 253 |
+
pos_embed_w, model.pos_embed, getattr(model, 'num_tokens', 1), model.patch_embed.grid_size)
|
| 254 |
+
model.pos_embed.copy_(pos_embed_w)
|
| 255 |
+
model.norm.weight.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/scale']))
|
| 256 |
+
model.norm.bias.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/bias']))
|
| 257 |
+
# if isinstance(model.head, nn.Linear) and model.head.bias.shape[0] == w[f'{prefix}head/bias'].shape[-1]:
|
| 258 |
+
# model.head.weight.copy_(_n2p(w[f'{prefix}head/kernel']))
|
| 259 |
+
# model.head.bias.copy_(_n2p(w[f'{prefix}head/bias']))
|
| 260 |
+
# if isinstance(getattr(model.pre_logits, 'fc', None), nn.Linear) and f'{prefix}pre_logits/bias' in w:
|
| 261 |
+
# model.pre_logits.fc.weight.copy_(_n2p(w[f'{prefix}pre_logits/kernel']))
|
| 262 |
+
# model.pre_logits.fc.bias.copy_(_n2p(w[f'{prefix}pre_logits/bias']))
|
| 263 |
+
for i, block in enumerate(model.blocks.children()):
|
| 264 |
+
block_prefix = f'{prefix}Transformer/encoderblock_{i}/'
|
| 265 |
+
mha_prefix = block_prefix + 'MultiHeadDotProductAttention_1/'
|
| 266 |
+
block.norm1.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/scale']))
|
| 267 |
+
block.norm1.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/bias']))
|
| 268 |
+
block.attn.qkv.weight.copy_(torch.cat([
|
| 269 |
+
_n2p(w[f'{mha_prefix}{n}/kernel'], t=False).flatten(1).T for n in ('query', 'key', 'value')]))
|
| 270 |
+
block.attn.qkv.bias.copy_(torch.cat([
|
| 271 |
+
_n2p(w[f'{mha_prefix}{n}/bias'], t=False).reshape(-1) for n in ('query', 'key', 'value')]))
|
| 272 |
+
block.attn.proj.weight.copy_(_n2p(w[f'{mha_prefix}out/kernel']).flatten(1))
|
| 273 |
+
block.attn.proj.bias.copy_(_n2p(w[f'{mha_prefix}out/bias']))
|
| 274 |
+
for r in range(2):
|
| 275 |
+
getattr(block.mlp, f'fc{r + 1}').weight.copy_(_n2p(w[f'{block_prefix}MlpBlock_3/Dense_{r}/kernel']))
|
| 276 |
+
getattr(block.mlp, f'fc{r + 1}').bias.copy_(_n2p(w[f'{block_prefix}MlpBlock_3/Dense_{r}/bias']))
|
| 277 |
+
block.norm2.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_2/scale']))
|
| 278 |
+
block.norm2.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_2/bias']))
|
| 279 |
+
|
| 280 |
+
|
| 281 |
+
def interpolate_pos_embed(pos_embed_checkpoint, visual_encoder):
|
| 282 |
+
# interpolate position embedding
|
| 283 |
+
embedding_size = pos_embed_checkpoint.shape[-1]
|
| 284 |
+
num_patches = visual_encoder.patch_embed.num_patches
|
| 285 |
+
num_extra_tokens = visual_encoder.pos_embed.shape[-2] - num_patches
|
| 286 |
+
# height (== width) for the checkpoint position embedding
|
| 287 |
+
orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
|
| 288 |
+
# height (== width) for the new position embedding
|
| 289 |
+
new_size = int(num_patches ** 0.5)
|
| 290 |
+
|
| 291 |
+
if orig_size!=new_size:
|
| 292 |
+
# class_token and dist_token are kept unchanged
|
| 293 |
+
extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
|
| 294 |
+
# only the position tokens are interpolated
|
| 295 |
+
pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
|
| 296 |
+
pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
|
| 297 |
+
pos_tokens = torch.nn.functional.interpolate(
|
| 298 |
+
pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False)
|
| 299 |
+
pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
|
| 300 |
+
new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
|
| 301 |
+
print('reshape position embedding from %d to %d'%(orig_size ** 2,new_size ** 2))
|
| 302 |
+
|
| 303 |
+
return new_pos_embed
|
| 304 |
+
else:
|
| 305 |
+
return pos_embed_checkpoint
|
defake/test.py
ADDED
|
@@ -0,0 +1,95 @@
|
| 1 |
+
from time import process_time_ns
|
| 2 |
+
import torch
|
| 3 |
+
import clip
|
| 4 |
+
from PIL import Image
|
| 5 |
+
import os
|
| 6 |
+
import json
|
| 7 |
+
import numpy as np
|
| 8 |
+
from pathlib import Path
|
| 9 |
+
import matplotlib.pyplot as plt
|
| 10 |
+
|
| 11 |
+
from sklearn.metrics import confusion_matrix
|
| 12 |
+
import itertools
|
| 13 |
+
import torch.nn.functional as F
|
| 14 |
+
|
| 15 |
+
import torch.nn as nn
|
| 16 |
+
from torch.utils.data import random_split
|
| 17 |
+
from torch import nn
|
| 18 |
+
from torchvision import transforms
|
| 19 |
+
import sys
|
| 20 |
+
import argparse
|
| 21 |
+
import time
|
| 22 |
+
from tqdm import tqdm
|
| 23 |
+
from sklearn import metrics
|
| 24 |
+
from sklearn.metrics import accuracy_score, recall_score, precision_score, roc_curve
|
| 25 |
+
from blipmodels import blip_decoder
|
| 26 |
+
|
| 27 |
+
class NeuralNet(nn.Module):
|
| 28 |
+
def __init__(self, input_size, hidden_size_list, num_classes):
|
| 29 |
+
super(NeuralNet, self).__init__()
|
| 30 |
+
self.dropout2 = nn.Dropout(0.5)
|
| 31 |
+
self.fc1 = nn.Linear(input_size, hidden_size_list[0])
|
| 32 |
+
self.fc2 = nn.Linear(hidden_size_list[0], hidden_size_list[1])
|
| 33 |
+
self.fc3 = nn.Linear(hidden_size_list[1], num_classes)
|
| 34 |
+
|
| 35 |
+
def forward(self, x):
|
| 36 |
+
out = self.fc1(x)
|
| 37 |
+
out = F.relu(out)
|
| 38 |
+
out = self.dropout2(out)
|
| 39 |
+
out = self.fc2(out)
|
| 40 |
+
out = F.relu(out)
|
| 41 |
+
out = self.fc3(out)
|
| 42 |
+
return out
|
| 43 |
+
|
| 44 |
+
def preprocess_image(img_path, image_size=224):
|
| 45 |
+
img = Image.open(img_path)
|
| 46 |
+
img = img.resize((image_size, image_size))
|
| 47 |
+
return preprocess(img)
|
| 48 |
+
|
| 49 |
+
parser = argparse.ArgumentParser(description='Finetune the classifier to wash the backdoor')
|
| 50 |
+
parser.add_argument('--image_path',default='CLIP.png',type=str)
|
| 51 |
+
parser.add_argument('--gpu', default='0', type=str)
|
| 52 |
+
args = parser.parse_args()
|
| 53 |
+
|
| 54 |
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 55 |
+
model2, preprocess = clip.load("ViT-B/32")
|
| 56 |
+
|
| 57 |
+
image_size = 224
|
| 58 |
+
|
| 59 |
+
blip_url = 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_capfilt_large.pth'
|
| 60 |
+
|
| 61 |
+
blip = blip_decoder(pretrained=blip_url, image_size=image_size, vit='base')
|
| 62 |
+
blip.eval()
|
| 63 |
+
blip = blip.to(device)
|
| 64 |
+
|
| 65 |
+
img = Image.open(args.image_path).convert('RGB')
|
| 66 |
+
tform = transforms.Compose(
|
| 67 |
+
[
|
| 68 |
+
transforms.Resize(224),
|
| 69 |
+
transforms.CenterCrop(224),
|
| 70 |
+
transforms.ToTensor(),
|
| 71 |
+
]
|
| 72 |
+
)
|
| 73 |
+
img = tform(img)
|
| 74 |
+
img = img.unsqueeze(0).to(device)
|
| 75 |
+
|
| 76 |
+
caption = blip.generate(img, sample=False, num_beams=3, max_length=60, min_length=5)
|
| 77 |
+
text = clip.tokenize(list(caption)).to(device)
|
| 78 |
+
|
| 79 |
+
model = torch.load("finetune_clip.pt").to(device)
|
| 80 |
+
|
| 81 |
+
linear = NeuralNet(1024,[512,256],2).to(device)
|
| 82 |
+
linear = torch.load('clip_linear.pt', map_location=device).to(device)
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
image = preprocess_image(args.image_path,image_size).unsqueeze(0).to(device)
|
| 86 |
+
|
| 87 |
+
with torch.no_grad():
|
| 88 |
+
image_features = model.encode_image(image)
|
| 89 |
+
text_features = model.encode_text(text)
|
| 90 |
+
|
| 91 |
+
emb = torch.cat((image_features, text_features),1)
|
| 92 |
+
output = linear(emb.float())
|
| 93 |
+
predict = output.argmax(1)
|
| 94 |
+
predict = predict.cpu().numpy()
|
| 95 |
+
print(predict)
|
defake/test_api.py
ADDED
|
@@ -0,0 +1,114 @@
|
| 1 |
+
import torch
|
| 2 |
+
import clip
|
| 3 |
+
from PIL import Image
|
| 4 |
+
from torchvision import transforms
|
| 5 |
+
import torch.nn.functional as F
|
| 6 |
+
import torch.nn as nn
|
| 7 |
+
import argparse
|
| 8 |
+
|
| 9 |
+
from blipmodels import blip_decoder
|
| 10 |
+
|
| 11 |
+
class NeuralNet(nn.Module):
|
| 12 |
+
def __init__(self, input_size, hidden_size_list, num_classes):
|
| 13 |
+
super(NeuralNet, self).__init__()
|
| 14 |
+
self.dropout2 = nn.Dropout(0.5)
|
| 15 |
+
self.fc1 = nn.Linear(input_size, hidden_size_list[0])
|
| 16 |
+
self.fc2 = nn.Linear(hidden_size_list[0], hidden_size_list[1])
|
| 17 |
+
self.fc3 = nn.Linear(hidden_size_list[1], num_classes)
|
| 18 |
+
|
| 19 |
+
def forward(self, x):
|
| 20 |
+
out = self.fc1(x)
|
| 21 |
+
out = F.relu(out)
|
| 22 |
+
out = self.dropout2(out)
|
| 23 |
+
out = self.fc2(out)
|
| 24 |
+
out = F.relu(out)
|
| 25 |
+
out = self.fc3(out)
|
| 26 |
+
return out
|
| 27 |
+
|
| 28 |
+
def load_models(device=None):
|
| 29 |
+
"""
|
| 30 |
+
Load CLIP, BLIP, and the linear classifier; load them only once.
|
| 31 |
+
"""
|
| 32 |
+
import os
|
| 33 |
+
print("Current working folder:", os.getcwd()) # ← 加这行
|
| 34 |
+
|
| 35 |
+
if device is None:
|
| 36 |
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 37 |
+
|
| 38 |
+
clip_model, clip_preprocess = clip.load("ViT-B/32", device=device)
|
| 39 |
+
|
| 40 |
+
image_size = 224
|
| 41 |
+
blip_url = 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_capfilt_large.pth'
|
| 42 |
+
blip = blip_decoder(pretrained=blip_url, image_size=image_size, vit='base')
|
| 43 |
+
blip.eval()
|
| 44 |
+
blip = blip.to(device)
|
| 45 |
+
|
| 46 |
+
# Load the finetuned CLIP
|
| 47 |
+
clip_finetuned = torch.load("finetune_clip.pt", map_location=device).to(device)
|
| 48 |
+
|
| 49 |
+
# Load the linear classifier
|
| 50 |
+
linear = NeuralNet(1024, [512, 256], 2).to(device)
|
| 51 |
+
linear = torch.load("clip_linear.pt", map_location=device).to(device)
|
| 52 |
+
linear.eval()
|
| 53 |
+
|
| 54 |
+
return {
|
| 55 |
+
"device": device,
|
| 56 |
+
"clip_model": clip_model,
|
| 57 |
+
"clip_preprocess": clip_preprocess,
|
| 58 |
+
"blip": blip,
|
| 59 |
+
"linear": linear,
|
| 60 |
+
}
|
| 61 |
+
|
| 62 |
+
def predict_image(image_path, models=None):
|
| 63 |
+
"""
|
| 64 |
+
Take an image path and return the prediction and class probabilities.
|
| 65 |
+
"""
|
| 66 |
+
if models is None:
|
| 67 |
+
models = load_models()
|
| 68 |
+
|
| 69 |
+
device = models["device"]
|
| 70 |
+
clip_model = models["clip_model"]
|
| 71 |
+
clip_preprocess = models["clip_preprocess"]
|
| 72 |
+
blip = models["blip"]
|
| 73 |
+
linear = models["linear"]
|
| 74 |
+
|
| 75 |
+
# 1. Generate a caption with BLIP
|
| 76 |
+
img = Image.open(image_path).convert('RGB')
|
| 77 |
+
tform = transforms.Compose([
|
| 78 |
+
transforms.Resize(224),
|
| 79 |
+
transforms.CenterCrop(224),
|
| 80 |
+
transforms.ToTensor(),
|
| 81 |
+
])
|
| 82 |
+
img_tensor = tform(img).unsqueeze(0).to(device)
|
| 83 |
+
|
| 84 |
+
with torch.no_grad():
|
| 85 |
+
caption = blip.generate(img_tensor, sample=False, num_beams=3, max_length=60, min_length=5)
|
| 86 |
+
text = clip.tokenize(list(caption)).to(device)
|
| 87 |
+
|
| 88 |
+
# 2. Process the image with the CLIP preprocess transform
|
| 89 |
+
image = clip_preprocess(Image.open(image_path)).unsqueeze(0).to(device)
|
| 90 |
+
|
| 91 |
+
# 3. 提取特征并分类
|
| 92 |
+
with torch.no_grad():
|
| 93 |
+
image_features = clip_model.encode_image(image)
|
| 94 |
+
text_features = clip_model.encode_text(text)
|
| 95 |
+
emb = torch.cat((image_features, text_features), 1)
|
| 96 |
+
output = linear(emb.float())
|
| 97 |
+
probs = torch.softmax(output, dim=1)
|
| 98 |
+
pred = probs.argmax(1).item()
|
| 99 |
+
probs_list = probs[0].cpu().numpy().tolist()
|
| 100 |
+
|
| 101 |
+
return pred, probs_list
|
| 102 |
+
|
| 103 |
+
def main():
|
| 104 |
+
parser = argparse.ArgumentParser(description='De-Fake single image test')
|
| 105 |
+
parser.add_argument('--image_path', default='CLIP.png', type=str)
|
| 106 |
+
args = parser.parse_args()
|
| 107 |
+
|
| 108 |
+
models = load_models()
|
| 109 |
+
pred, probs = predict_image(args.image_path, models)
|
| 110 |
+
print("Prediction:", pred)
|
| 111 |
+
print("Probabilities:", probs)
|
| 112 |
+
|
| 113 |
+
if __name__ == "__main__":
|
| 114 |
+
main()
|
defake/train.py
ADDED
|
@@ -0,0 +1,187 @@
|
| 1 |
+
import os
|
| 2 |
+
import sys
|
| 3 |
+
import time
|
| 4 |
+
import json
|
| 5 |
+
import random
|
| 6 |
+
import argparse
|
| 7 |
+
import numpy as np
|
| 8 |
+
from tqdm import tqdm
|
| 9 |
+
from pathlib import Path
|
| 10 |
+
from PIL import Image
|
| 11 |
+
from time import process_time_ns
|
| 12 |
+
from sklearn.metrics import accuracy_score, recall_score, precision_score, f1_score, confusion_matrix
|
| 13 |
+
|
| 14 |
+
import torch
|
| 15 |
+
import torch.nn as nn
|
| 16 |
+
import torch.optim as optim
|
| 17 |
+
import torchvision
|
| 18 |
+
import torchvision.transforms as transforms
|
| 19 |
+
from torchvision import datasets
|
| 20 |
+
from torch.utils.data import Dataset, DataLoader, random_split
|
| 21 |
+
from natsort import natsorted
|
| 22 |
+
import torch.nn.functional as F
|
| 23 |
+
import clip
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
class NeuralNet(nn.Module):
|
| 27 |
+
def __init__(self, input_size, hidden_size_list, num_classes):
|
| 28 |
+
super(NeuralNet, self).__init__()
|
| 29 |
+
self.dropout2 = nn.Dropout(0.5)
|
| 30 |
+
self.fc1 = nn.Linear(input_size, hidden_size_list[0])
|
| 31 |
+
self.fc2 = nn.Linear(hidden_size_list[0], hidden_size_list[1])
|
| 32 |
+
self.fc3 = nn.Linear(hidden_size_list[1], num_classes)
|
| 33 |
+
|
| 34 |
+
def forward(self, x):
|
| 35 |
+
out = self.fc1(x)
|
| 36 |
+
out = F.relu(out)
|
| 37 |
+
out = self.dropout2(out)
|
| 38 |
+
out = self.fc2(out)
|
| 39 |
+
out = F.relu(out)
|
| 40 |
+
out = self.fc3(out)
|
| 41 |
+
return out
|
| 42 |
+
|
| 43 |
+
# Define the dataset classes
|
| 44 |
+
class real(Dataset):
|
| 45 |
+
def __init__(self, root_dir1, prompts_file, transform=None):
|
| 46 |
+
self.root_dir1 = root_dir1
|
| 47 |
+
self.transform = transform
|
| 48 |
+
self.image_filenames1 = natsorted(os.listdir(root_dir1))
|
| 49 |
+
self.prompts = self.load_prompts(prompts_file)
|
| 50 |
+
|
| 51 |
+
def load_prompts(self, prompts_file):
|
| 52 |
+
with open(prompts_file, 'r') as file:
|
| 53 |
+
prompts = file.readlines()
|
| 54 |
+
prompts = [prompt.strip() for prompt in prompts] # Remove any extra whitespace
|
| 55 |
+
return prompts
|
| 56 |
+
|
| 57 |
+
def __len__(self):
|
| 58 |
+
return len(self.image_filenames1)
|
| 59 |
+
|
| 60 |
+
def __getitem__(self, idx):
|
| 61 |
+
class_name1 = self.image_filenames1[idx]
|
| 62 |
+
image_path1 = os.path.join(self.root_dir1, class_name1)
|
| 63 |
+
image1 = Image.open(image_path1).convert("RGB")
|
| 64 |
+
|
| 65 |
+
if self.transform:
|
| 66 |
+
image1 = self.transform(image1)
|
| 67 |
+
|
| 68 |
+
label = 0
|
| 69 |
+
prompt = self.prompts[idx] if idx < len(self.prompts) else ""
|
| 70 |
+
|
| 71 |
+
return image1, prompt, label
|
| 72 |
+
|
| 73 |
+
class fake(Dataset):
|
| 74 |
+
def __init__(self, root_dir1, prompts_file, transform=None):
|
| 75 |
+
self.root_dir1 = root_dir1
|
| 76 |
+
self.transform = transform
|
| 77 |
+
self.image_filenames1 = natsorted(os.listdir(root_dir1))
|
| 78 |
+
self.prompts = self.load_prompts(prompts_file)
|
| 79 |
+
|
| 80 |
+
def load_prompts(self, prompts_file):
|
| 81 |
+
with open(prompts_file, 'r') as file:
|
| 82 |
+
prompts = file.readlines()
|
| 83 |
+
prompts = [prompt.strip() for prompt in prompts]
|
| 84 |
+
return prompts
|
| 85 |
+
|
| 86 |
+
def __len__(self):
|
| 87 |
+
return len(self.image_filenames1)
|
| 88 |
+
|
| 89 |
+
def __getitem__(self, idx):
|
| 90 |
+
class_name1 = self.image_filenames1[idx]
|
| 91 |
+
image_path1 = os.path.join(self.root_dir1, class_name1)
|
| 92 |
+
image1 = Image.open(image_path1).convert("RGB")
|
| 93 |
+
|
| 94 |
+
if self.transform:
|
| 95 |
+
image1 = self.transform(image1)
|
| 96 |
+
|
| 97 |
+
label = 1
|
| 98 |
+
prompt = self.prompts[idx] if idx < len(self.prompts) else ""
|
| 99 |
+
|
| 100 |
+
return image1, prompt, label
|
| 101 |
+
|
| 102 |
+
transform = transforms.Compose([
|
| 103 |
+
transforms.Resize((224, 224)),
|
| 104 |
+
transforms.ToTensor(),
|
| 105 |
+
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
|
| 106 |
+
])
|
| 107 |
+
|
| 108 |
+
# realdata = real(root_dir1="real-data",prompts_file="prompt.txt", transform=transform)
|
| 109 |
+
# fakedata = fake(root_dir1="fake-data",prompts_file="prompt.txt", transform=transform)
|
| 110 |
+
|
| 111 |
+
realdata = real(root_dir1="/home/sha/stable-diffusion/real-data",prompts_file="/home/sha/stable-diffusion/prompt_for_coco.txt", transform=transform)
|
| 112 |
+
fakedata = fake(root_dir1="/home/sha/stable-diffusion/output4/samples",prompts_file="/home/sha/stable-diffusion/prompt_for_coco.txt", transform=transform)
|
| 113 |
+
|
| 114 |
+
|
| 115 |
+
dataset = torch.utils.data.ConcatDataset([realdata,fakedata])
|
| 116 |
+
|
| 117 |
+
newsize = 800
|
| 118 |
+
|
| 119 |
+
size = len(dataset)
|
| 120 |
+
train_dataset,test_dataset = random_split(dataset=dataset,lengths=[int(newsize),int(size-newsize)],generator=torch.Generator().manual_seed(0))
|
| 121 |
+
|
| 122 |
+
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
|
| 123 |
+
|
| 124 |
+
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)
|
| 125 |
+
|
| 126 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 127 |
+
model, preprocess = clip.load("ViT-B/32", device=device)
|
| 128 |
+
linear = NeuralNet(1024,[512,256],2).to(device)
|
| 129 |
+
model.to(device)
|
| 130 |
+
optimizer = torch.optim.Adam(list(linear.parameters())+list(model.parameters()), lr=3e-4)
|
| 131 |
+
criterion = nn.CrossEntropyLoss()
|
| 132 |
+
linear.to(device)
|
| 133 |
+
|
| 134 |
+
for epoch in range(50):
|
| 135 |
+
model.train()
|
| 136 |
+
linear.train()
|
| 137 |
+
for batch_idx, (data1, prompt, target) in enumerate(train_loader):
|
| 138 |
+
data1, target = data1.to(device), target.to(device)
|
| 139 |
+
text = clip.tokenize(list(prompt)).to(device)
|
| 140 |
+
with torch.no_grad():
|
| 141 |
+
imga_embedding = model.encode_image(data1)
|
| 142 |
+
text_emb = model.encode_text(text)
|
| 143 |
+
emb = torch.cat((imga_embedding,text_emb),1)
|
| 144 |
+
output = linear(emb.float())
|
| 145 |
+
optimizer.zero_grad()
|
| 146 |
+
loss = criterion(output, target)
|
| 147 |
+
loss.backward()
|
| 148 |
+
optimizer.step()
|
| 149 |
+
|
| 150 |
+
if batch_idx % 10 == 0:
|
| 151 |
+
print('Epoch {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
|
| 152 |
+
epoch, batch_idx * len(data1), len(train_loader.dataset),
|
| 153 |
+
100. * batch_idx / len(train_loader), loss.item()))
|
| 154 |
+
|
| 155 |
+
test_loss = 0
|
| 156 |
+
correct = 0
|
| 157 |
+
model.eval()
|
| 158 |
+
linear.eval()
|
| 159 |
+
all_preds = []
|
| 160 |
+
all_targets = []
|
| 161 |
+
with torch.no_grad():
|
| 162 |
+
for images, prompts, targets in test_loader:
|
| 163 |
+
images, targets = images.to(device), targets.to(device)
|
| 164 |
+
text_tokens = clip.tokenize(prompts).to(device)
|
| 165 |
+
with torch.no_grad():
|
| 166 |
+
image_embeddings = model.encode_image(images)
|
| 167 |
+
text_embeddings = model.encode_text(text_tokens)
|
| 168 |
+
|
| 169 |
+
embeddings = torch.cat((image_embeddings, text_embeddings), 1)
|
| 170 |
+
outputs = linear(embeddings.float())
|
| 171 |
+
_, preds = torch.max(outputs, 1)
|
| 172 |
+
|
| 173 |
+
all_preds.extend(preds.cpu().numpy())
|
| 174 |
+
all_targets.extend(targets.cpu().numpy())
|
| 175 |
+
|
| 176 |
+
all_preds = np.array(all_preds)
|
| 177 |
+
all_targets = np.array(all_targets)
|
| 178 |
+
|
| 179 |
+
accuracy = accuracy_score(all_targets, all_preds)
|
| 180 |
+
recall = recall_score(all_targets, all_preds, average='weighted')
|
| 181 |
+
precision = precision_score(all_targets, all_preds, average='weighted')
|
| 182 |
+
f1 = f1_score(all_targets, all_preds, average='weighted')
|
| 183 |
+
|
| 184 |
+
print(f'Accuracy: {accuracy:.4f}, Recall: {recall:.4f}, Precision: {precision:.4f}, F1 Score: {f1:.4f}')
|
| 185 |
+
|
| 186 |
+
torch.save(model.state_dict(), 'train_clip_model.pth')
|
| 187 |
+
torch.save(linear.state_dict(), 'train_linear_model.pth')
|