Spaces:
Runtime error
Runtime error
Upload 25 files
Browse files- .gitattributes +5 -0
- app.py +235 -0
- config/__init__.py +0 -0
- config/examples.py +16 -0
- config/experiment.yaml +41 -0
- config/with_decoder.yaml +60 -0
- create_result_script.py +531 -0
- data/chatgpt_similarity_score_test_direct.json +0 -0
- data/chatgpt_similarity_score_test_indirect.json +0 -0
- data/eval_test_image.json +0 -0
- data/eval_test_text.json +0 -0
- data/images/.DS_Store +0 -0
- data/key_pair.json +0 -0
- data/preview_image.jpeg +0 -0
- data/random_sample_test_direct_ids.json +3 -0
- data/random_sample_test_indirect_ids.json +3 -0
- models/__init__.py +0 -0
- models/__pycache__/__init__.cpython-38.pyc +0 -0
- models/__pycache__/fused_model.cpython-38.pyc +0 -0
- models/fused_model.py +544 -0
- models/model.py +117 -0
- results/config.yaml +42 -0
- results/model_epoch3.pth +3 -0
- results/results_pair_dict.json +3 -0
- results_pair_dict1.json +3 -0
- results_pair_dict2.json +3 -0
.gitattributes
CHANGED
|
@@ -32,3 +32,8 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
| 32 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 33 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 32 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 33 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 35 |
+
data/random_sample_test_direct_ids.json filter=lfs diff=lfs merge=lfs -text
|
| 36 |
+
data/random_sample_test_indirect_ids.json filter=lfs diff=lfs merge=lfs -text
|
| 37 |
+
results_pair_dict1.json filter=lfs diff=lfs merge=lfs -text
|
| 38 |
+
results_pair_dict2.json filter=lfs diff=lfs merge=lfs -text
|
| 39 |
+
results/results_pair_dict.json filter=lfs diff=lfs merge=lfs -text
|
app.py
ADDED
|
@@ -0,0 +1,235 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# %%
|
| 2 |
+
from PIL import Image
|
| 3 |
+
import torch
|
| 4 |
+
import torchvision
|
| 5 |
+
from torchvision import transforms
|
| 6 |
+
from torchvision.transforms.functional import InterpolationMode
|
| 7 |
+
import json
|
| 8 |
+
from collections import defaultdict
|
| 9 |
+
import numpy as np
|
| 10 |
+
import matplotlib.pyplot as plt
|
| 11 |
+
import random
|
| 12 |
+
from os import listdir
|
| 13 |
+
from os.path import isfile, join
|
| 14 |
+
from torchvision.io import read_image
|
| 15 |
+
from torchvision.utils import draw_bounding_boxes
|
| 16 |
+
from PIL import Image
|
| 17 |
+
import os
|
| 18 |
+
from scipy.stats import rankdata
|
| 19 |
+
import tqdm
|
| 20 |
+
import streamlit as st
|
| 21 |
+
import pandas as pd
|
| 22 |
+
|
| 23 |
+
# %%
|
| 24 |
+
def load_json(PATH):
    """Load a JSON file into a dict.

    Args:
        PATH: path to the JSON file.

    Returns:
        The parsed JSON content, or an empty dict when the file is
        missing or unreadable (a warning is printed in that case).
    """
    # EAFP: try the open directly instead of pre-checking with
    # os.path.isfile + os.access (two extra stat calls and a TOCTOU race).
    try:
        with open(PATH) as json_file:
            return json.load(json_file)
    except OSError:
        print("The path", PATH, "does not exist or is not readable")
        return {}
|
| 32 |
+
|
| 33 |
+
def get_list_folder(PATH):
    """Return the names of the immediate subdirectories of PATH."""
    folders = []
    for entry in os.listdir(PATH):
        if os.path.isdir(os.path.join(PATH, entry)):
            folders.append(entry)
    return folders
|
| 35 |
+
|
| 36 |
+
def get_file_only(PATH):
    """Return the names of the regular files directly inside PATH."""
    files = []
    for entry in os.listdir(PATH):
        if os.path.isfile(os.path.join(PATH, entry)):
            files.append(entry)
    return files
|
| 38 |
+
|
| 39 |
+
# %%
|
| 40 |
+
def compute_ndcg(ranks, scores, k=3):
    """Compute an NDCG-style score from predicted ranks and relevance scores.

    The k candidates with the highest ground-truth relevance are selected;
    the DCG discounts each one by its *predicted* rank, while the ideal DCG
    discounts by the best possible position (1, 2, ...).

    Example:
        ranks = [5, 1, 4, 2, 3]
        scores = [0.1, 0.5, 0.3, 0.95, 1.0]

    Args:
        ranks: predicted rank of each candidate (1 = best), same order as scores.
        scores: ground-truth relevance of each candidate.
        k: number of top-relevance candidates to evaluate.

    Returns:
        NDCG in [0, 1]; 0.0 when every selected relevance is zero
        (the original raised ZeroDivisionError in that case).
    """
    rank_score_tuple = list(zip(ranks, scores))

    # Top-k candidates by ground-truth relevance.
    top_k = sorted(rank_score_tuple, key=lambda x: x[1], reverse=True)[:k]

    dcg = sum(score / np.log2(rank + 1) for rank, score in top_k)

    # Ideal DCG: same relevances placed at the best possible positions 1..k.
    ideal_dcg = sum(score / np.log2(idx + 2) for idx, (_, score) in enumerate(top_k))

    # Guard: when no candidate has a relevance score (e.g. none of them had a
    # ChatGPT similarity entry), the ideal DCG is 0 — define the NDCG as 0.
    if ideal_dcg == 0:
        return 0.0

    return dcg / ideal_dcg
|
| 55 |
+
|
| 56 |
+
def compute_ndcg_score_per_mode(pred_rank_dict, gpt_rel_scores, random_sample_dict, mode='indirect', split='test', k=200):
    """Average the per-query NDCG over every query in ``pred_rank_dict``.

    Args:
        pred_rank_dict: query key -> array of predicted ranks over its candidates.
        gpt_rel_scores: query key -> {candidate key: ChatGPT relevance score}.
        random_sample_dict: query key -> ordered list of candidate keys
            (must match the rank order in pred_rank_dict).
        mode: label used only in the progress printout.  (Fixed default:
            was the typo 'indrect'.)
        split: dataset split label (informational only).
        k: top-k cutoff forwarded to compute_ndcg.

    Returns:
        Mean NDCG over all queries; 0.0 for an empty pred_rank_dict
        (the original raised ZeroDivisionError).
    """
    ndcg_scores = []

    for key in tqdm.tqdm(pred_rank_dict.keys(), total=len(pred_rank_dict.keys())):
        # Candidates without a ChatGPT similarity entry count as irrelevant (0.0).
        gpt_scores_for_key = [gpt_rel_scores[key].get(cand_key, 0.0) for cand_key in random_sample_dict[key]]

        pred_rank_for_key = pred_rank_dict[key]

        ndcg_scores.append(compute_ndcg(pred_rank_for_key, gpt_scores_for_key, k=k))

    if not ndcg_scores:
        return 0.0

    avg_ndcg_score = sum(ndcg_scores) / len(ndcg_scores)
    print(f"Random split, mode={mode} ndcg score: ", avg_ndcg_score)
    return avg_ndcg_score
|
| 70 |
+
# %%
|
| 71 |
+
|
| 72 |
+
def get_score_direct(random_sample_pair_test_direct, predictions, key_pair, similarity_score_test_direct, k = 200):
    """Score the *direct* test split: mean i2t/t2i rank and NDCG.

    Args:
        random_sample_pair_test_direct: image key -> list of candidate image keys
            (the true pair is presumably last in the list — verify against the
            rank[-1] usage below).
        predictions: '<img_key>:<text_key>' -> similarity score.
        key_pair: image key -> text key mapping.
        similarity_score_test_direct: image key -> {candidate: ChatGPT score}.
        k: top-k cutoff for the NDCG computation (was ignored — the call below
           hard-coded 200; now forwarded).

    Returns:
        {'direct': {'i2t rank': ..., 't2i rank': ..., 'ndcg score': ...}}
    """
    mode = 'direct'
    i2t_ranks = []
    t2i_ranks = []
    i2t_rank_dict = {}
    results_dict = {}
    key_pair_reversed = {v: k for k, v in key_pair.items()}

    for file_key in tqdm.tqdm(random_sample_pair_test_direct.keys(), total=len(random_sample_pair_test_direct.keys())):
        # NOTE: the comprehension variable was named `k`, shadowing the k=200
        # parameter — renamed to `cand` for clarity.
        i2t_rank = rankdata([predictions[str(file_key)+':'+str(key_pair[cand])] for cand in random_sample_pair_test_direct[file_key]])
        # key_pair_reversed[key_pair[cand]] round-trips back to cand when
        # key_pair is one-to-one — kept as-is to preserve behavior for any
        # duplicate-value mappings.
        t2i_rank = rankdata([predictions[str(key_pair_reversed[key_pair[cand]])+':'+str(key_pair[file_key])] for cand in random_sample_pair_test_direct[file_key]])

        i2t_ranks.append(i2t_rank[-1])
        t2i_ranks.append(t2i_rank[-1])
        i2t_rank_dict[file_key] = i2t_rank

    # Sanity check: one rank per query (was a hard-coded == 1000).
    assert len(i2t_ranks) == len(t2i_ranks) == len(random_sample_pair_test_direct)

    # BUG FIX: this is the *direct* scorer but the helper was called with
    # mode='indirect', producing a mislabeled printout; also forward k.
    ndcg_score = compute_ndcg_score_per_mode(i2t_rank_dict, similarity_score_test_direct, random_sample_pair_test_direct, mode='direct', split='test', k=k)

    results_dict['direct'] = {}
    results_dict['direct']['i2t rank'] = float(sum(i2t_ranks) / len(i2t_ranks))
    results_dict['direct']['t2i rank'] = float(sum(t2i_ranks) / len(t2i_ranks))
    results_dict['direct']['ndcg score'] = float(ndcg_score)
    print(f"Random split, mode={mode} i2t rank: ", sum(i2t_ranks) / len(i2t_ranks))
    print(f"Random split, mode={mode} t2i rank: ", sum(t2i_ranks) / len(t2i_ranks))
    return results_dict
|
| 100 |
+
|
| 101 |
+
# %%
|
| 102 |
+
|
| 103 |
+
def get_score_indirect(random_sample_pair_test_indirect, predictions, key_pair, similarity_score_test_indirect, k = 200):
    """Score the *indirect* test split: mean i2t/t2i rank and NDCG.

    Mirrors get_score_direct for the indirect candidate lists; see that
    function for the argument contract.

    Returns:
        {'indirect': {'i2t rank': ..., 't2i rank': ..., 'ndcg score': ...}}
    """
    mode = 'indirect'
    i2t_ranks = []
    t2i_ranks = []
    i2t_rank_dict = {}
    results_dict = {}
    key_pair_reversed = {v: k for k, v in key_pair.items()}

    for file_key in tqdm.tqdm(random_sample_pair_test_indirect.keys(), total=len(random_sample_pair_test_indirect.keys())):
        # `cand` was previously named `k`, shadowing the k=200 parameter.
        i2t_rank = rankdata([predictions[str(file_key)+':'+str(key_pair[cand])] for cand in random_sample_pair_test_indirect[file_key]])
        t2i_rank = rankdata([predictions[str(key_pair_reversed[key_pair[cand]])+':'+str(key_pair[file_key])] for cand in random_sample_pair_test_indirect[file_key]])

        i2t_ranks.append(i2t_rank[-1])
        t2i_ranks.append(t2i_rank[-1])
        i2t_rank_dict[file_key] = i2t_rank

    # Sanity check: one rank per query (was a hard-coded == 1000).
    assert len(i2t_ranks) == len(t2i_ranks) == len(random_sample_pair_test_indirect)

    # BUG FIX: mode was the typo 'indrect'; also forward k instead of a
    # hard-coded 200.
    ndcg_score = compute_ndcg_score_per_mode(i2t_rank_dict, similarity_score_test_indirect, random_sample_pair_test_indirect, mode='indirect', split='test', k=k)

    results_dict['indirect'] = {}
    results_dict['indirect']['i2t rank'] = float(sum(i2t_ranks) / len(i2t_ranks))
    results_dict['indirect']['t2i rank'] = float(sum(t2i_ranks) / len(t2i_ranks))
    results_dict['indirect']['ndcg score'] = float(ndcg_score)
    print(f"Random split, mode={mode} i2t rank: ", sum(i2t_ranks) / len(i2t_ranks))
    print(f"Random split, mode={mode} t2i rank: ", sum(t2i_ranks) / len(t2i_ranks))
    return results_dict
|
| 131 |
+
|
| 132 |
+
# %%
|
| 133 |
+
def main(json_file):
    """Evaluate a predictions dict against the bundled ground-truth data.

    Args:
        json_file: dict mapping '<img_key>:<text_key>' pair strings to
            similarity scores (already parsed, not a path).

    Returns:
        dict with 'direct' and 'indirect' sub-dicts of i2t/t2i rank and
        NDCG scores.
    """
    predictions = json_file
    root = os.environ['ROOT']

    def _read(rel_path):
        # All ground-truth assets live under $ROOT.
        with open(os.path.join(root, rel_path)) as fp:
            return json.load(fp)

    key_pair = _read('data/key_pair.json')
    random_sample_pair_test_direct = _read('data/random_sample_test_direct_ids.json')
    random_sample_pair_test_indirect = _read('data/random_sample_test_indirect_ids.json')
    similarity_score_test_direct = _read('data/chatgpt_similarity_score_test_direct.json')
    similarity_score_test_indirect = _read('data/chatgpt_similarity_score_test_indirect.json')

    result_direct = get_score_direct(random_sample_pair_test_direct, predictions, key_pair, similarity_score_test_direct, k=200)
    result_indirect = get_score_indirect(random_sample_pair_test_indirect, predictions, key_pair, similarity_score_test_indirect, k=200)
    return {**result_direct, **result_indirect}
|
| 173 |
+
# %%
|
| 174 |
+
if __name__ == '__main__':
    # Streamlit entry point.  Run with:
    #   streamlit run app.py --server.fileWatcherType none
    os.environ['ROOT'] = os.path.dirname(os.path.realpath(__file__))

    st.title("Evaluation Server for Driving Hazard Prediction and Reasoning ")
    st.image(os.path.join(os.environ['ROOT'], 'data/preview_image.jpeg'))
    st.divider()

    result_text = ''
    result_dict = {}
    json_file = None

    uploaded_files = st.file_uploader("Upload All Result Files Here (results_pair_dict1.json, results_pair_dict2.json)", type=["csv"], accept_multiple_files=True)

    if st.button('Run Evaluation with no upload files (using demo files)'):
        # Demo mode: merge the two bundled prediction shards.
        json_file1 = load_json(os.path.join(os.environ['ROOT'], 'results_pair_dict1.json'))
        json_file2 = load_json(os.path.join(os.environ['ROOT'], 'results_pair_dict2.json'))
        json_file = {**json_file1, **json_file2}

    # BUG FIX: with accept_multiple_files=True, st.file_uploader returns an
    # empty list (not None) when nothing is uploaded, so the original
    # `if uploaded_files is not None:` branch always ran and reset
    # json_file = {}, clobbering the demo predictions loaded above.
    # Only parse uploads when at least one file is actually present.
    if uploaded_files:
        dataframe = pd.concat([pd.read_csv(f) for f in uploaded_files])
        json_file = {}
        # 'tight' orientation exposes rows as plain lists; expected row
        # layout: [index, pair_key, score] — TODO confirm against the
        # uploaded CSV schema.
        for row in dataframe.to_dict('tight')['data']:
            json_file[str(row[1])] = float(row[2])

    # json_file is None until the demo button is pressed or files arrive.
    if json_file:
        result_dict = main(json_file)
        result_text = json.dumps(result_dict)

    st.download_button('Download Results', result_text)
    st.json(result_dict)
|
config/__init__.py
ADDED
|
File without changes
|
config/examples.py
ADDED
|
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
import sys
import json

from omegaconf import OmegaConf
from hydra.core.global_hydra import GlobalHydra
from hydra import initialize, initialize_config_module, initialize_config_dir, compose

# NOTE(review): hardcoded developer-machine paths.  DATA_ROOT must be set
# before composing below, because the YAML configs resolve
# "${oc.env:DATA_ROOT}" — confirm these paths for any other environment.
os.environ['ROOT'] = "/home/quang/workspace/traffic_var"
os.environ['DATA_ROOT'] = "/home/quang/datasets/traffic_var"

# initialize hydra config
# GlobalHydra is process-global state; clear any previous initialization so
# this module can be re-run (e.g. inside a notebook) without erroring.
GlobalHydra.instance().clear()
initialize(config_path="./")
# Compose both experiment configs with identical CLI-style overrides.
with_decoder_config = compose(config_name='with_decoder.yaml', overrides=["clip_model=ViT-B/16", "rationale_type=0", "val_rationale_type=0"])
original_config = compose(config_name='experiment.yaml', overrides=["clip_model=ViT-B/16", "rationale_type=0", "val_rationale_type=0"])
|
config/experiment.yaml
ADDED
|
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# config.yaml — baseline experiment configuration (Hydra/OmegaConf).

exp_name: exp1-2
# SECURITY(review): this looks like a committed Weights & Biases API key —
# it should be revoked and supplied via an environment variable instead.
wandb: 90788c79e1500570b08e5acf283e17df7e0c54b2
# Dataset root, resolved from the DATA_ROOT environment variable at compose time.
root: "${oc.env:DATA_ROOT}"
overfit: False
batch_size: 4
num_workers: 4
img_size: 224
rationale_type: 0 # 0 - only rationale, 1 - randomly add entity desc, 2 - add entity desc
val_rationale_type: 0
hide_true_bbox: 8 # clues and inferences selected randomly
widescreen_processing: 1 # 0 - no widescreen, 1 - widescreen
h_flip: False
ema_decay: 0.9999

clip_model: 'ViT-B/16' # 'RN101' 'RN50x4''RN50x16' 'RN50x64' 'ViT-L/14@336px' 'ViT-B/32'
num_layers: 3
dim_hidden: 512

# Optimization schedule.
warmup: 1000
init_from: ''
lr: .00001
n_epochs: 15
save_every: 0
early_stop: 5
val_stat: 'loss'
device: 'cuda'
use_multi: False
local_rank: 0

hydra:
  run:
    dir: ./results/${exp_name}
  output_subdir: ./ # directory for saving the yaml configs
  job:
    config:
      override_dirname:
        exclude_keys:
          - exp.name
|
config/with_decoder.yaml
ADDED
|
@@ -0,0 +1,60 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# config.yaml — experiment configuration with auxiliary decoder heads.

exp_name: exp1-2
# SECURITY(review): this looks like a committed Weights & Biases API key —
# it should be revoked and supplied via an environment variable instead.
wandb: 90788c79e1500570b08e5acf283e17df7e0c54b2
# Dataset root, resolved from the DATA_ROOT environment variable at compose time.
root: "${oc.env:DATA_ROOT}"
overfit: False
batch_size: 4
num_workers: 4
img_size: 224
rationale_type: 0 # 0 - only rationale, 1 - randomly add entity desc, 2 - add entity desc
val_rationale_type: 0
hide_true_bbox: 8 # clues and inferences selected randomly
widescreen_processing: 1 # 0 - no widescreen, 1 - widescreen
h_flip: False
ema_decay: 0.9999
aux_weight: 0.2
no_hard_negative_itm: False

clip_model: 'ViT-B/16' # 'RN101' 'RN50x4''RN50x16' 'RN50x64' 'ViT-L/14@336px' 'ViT-B/32'
# Flags enabling the optional auxiliary decoder heads.
has_extra_txt_decoder: False
has_extra_img_decoder: False
has_extra_mix_decoder: False
has_extra_gen_decoder: False

# T5-style transformer decoder hyperparameters for the auxiliary heads.
extra_decoder:
  is_decoder: True
  vocab_size: 1000
  d_ff: 512
  d_kv: 64
  d_model: 512
  dropout_rate: 0.1
  num_heads: 8
  num_layers: 2
  # eos_token_id: 1
  # pad_token_id: 0
  # decoder_start_token_id: 0
  # n_positions: 512
  relative_attention_max_distance: 128
  relative_attention_num_buckets: 32

# Optimization schedule.
warmup: 1000
init_from: ''
lr: .00001
n_epochs: 15
save_every: 0
early_stop: 5
val_stat: 'loss'
device: 'cuda'
use_multi: False
local_rank: 0

hydra:
  run:
    dir: ./results/${exp_name}
  output_subdir: ./ # directory for saving the yaml configs
  job:
    config:
      override_dirname:
        exclude_keys:
          - exp.name
|
create_result_script.py
ADDED
|
@@ -0,0 +1,531 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# %%
# create_result_script.py — top-level setup: imports, environment variables,
# and Hydra config composition.  Order matters: ROOT/DATA_ROOT must be set
# before compose(), because the YAML resolves "${oc.env:DATA_ROOT}".

import json
import sys
import pickle
sys.path.append("../")
import collections
from models.fused_model import Model
import os
import tqdm
import time
import json
import random
from PIL import ImageFile
from PIL import Image, ImageDraw
import clip
import torch
import numpy as np
import torchvision.transforms as T
import torchvision.transforms.functional as F
from pathlib import Path
import pandas as pd

# Tolerate partially-downloaded/corrupt images instead of raising.
ImageFile.LOAD_TRUNCATED_IMAGES = True

# %%
from types import SimpleNamespace
# get config
import os
from omegaconf import OmegaConf
from hydra.core.global_hydra import GlobalHydra
from hydra import initialize, initialize_config_module, initialize_config_dir, compose

# Anchor all paths to this script's directory; data lives under ./data.
os.environ['ROOT'] = os.path.dirname(os.path.realpath(__file__))
os.environ['DATA_ROOT'] = os.path.join(os.environ['ROOT'], 'data')

# initialize hydra config
# GlobalHydra is process-global; clear before re-initializing.
GlobalHydra.instance().clear()
initialize(config_path="./config")

config = compose(config_name='with_decoder.yaml',
                 overrides=["clip_model=ViT-L/14@336px",
                            "rationale_type=0", "val_rationale_type=0"])
|
| 44 |
+
|
| 45 |
+
class SquarePad:
    """Pad an image to a square, keeping the original content centered.

    The shorter dimension is extended with constant (zero/black) padding
    split as evenly as possible between the two sides.
    """

    def __call__(self, image):
        width, height = image.size
        side = max(width, height)
        # Even split; any odd leftover pixel goes to the right/bottom edge.
        pad_left = (side - width) // 2
        pad_top = (side - height) // 2
        pad_right = side - width - pad_left
        pad_bottom = side - height - pad_top
        return F.pad(image, (pad_left, pad_top, pad_right, pad_bottom), 0, 'constant')
|
| 53 |
+
|
| 54 |
+
class VarDatasetForAuxEncoders:
    """Dataset of (image, tokenized rationale text, ITM pair) samples.

    Loads an annotation JSON keyed by image file key, renders entity bounding
    boxes onto the image according to ``config.hide_true_bbox``, and returns
    CLIP-tokenized rationale text plus an image-text-matching (ITM) example.
    """

    def __init__(self, config, file_path, split="train", mode="combined", do_swap=False, tensorize=True, do_crop=True):
        self.config = config
        self.mode = mode
        self.split = split
        # When True, entity identities (#1/#2/#3) are shuffled per sample.
        self.do_swap = do_swap
        # Train and validation can use different rationale augmentation modes.
        self.rationale_type = config.rationale_type if split == "train" else config.val_rationale_type
        self.root_path = Path(config.root)
        self.anno_path = file_path #self.root_path / f'annotations/13_05/anno_{split}_{mode}.json'
        if split == "test" and mode == "combined" and config.overfit:
            self.anno_path = self.root_path / f'annotations/13_05/anno_{split}_{mode}_overfit.json'

        self.data = json.load(open(self.anno_path))
        # Stable index -> annotation key mapping for __getitem__.
        self.idx2name = list(self.data.keys())

        # Bucket samples by entity count so ITM negatives can be drawn with
        # the same number of bounding boxes (hard negatives).
        if 'bounding_box' in self.data[list(self.data.keys())[0]]['details'][-1]:
            self.one_ent_keys = [k for k, v in self.data.items() if len(v['details'][-1]["bounding_box"]) == 1]
            self.two_ent_keys = [k for k, v in self.data.items() if len(v['details'][-1]["bounding_box"]) == 2]
            self.three_ent_keys = [k for k, v in self.data.items() if len(v['details'][-1]["bounding_box"]) == 3]
            self.all_ent_keys = self.one_ent_keys + self.two_ent_keys + self.three_ent_keys

            self.keys = {1: self.one_ent_keys, 2: self.two_ent_keys, 3: self.three_ent_keys}

        # widescreen_processing 0/1 -> plain resize+crop; otherwise pad to
        # square first (see SquarePad).
        if self.config.widescreen_processing in [0, 1]:
            self.resize_crop = self.get_transform(config.img_size, split == "train", padding=False)
        else:
            self.resize_crop = self.get_transform(config.img_size, split == "train", padding=True)

        self.tensorize = tensorize
        # Color jitter only during training; identity elsewhere.
        self.jitter_transform = T.ColorJitter(brightness=.5, hue=.3, saturation=.3) if split == "train" else lambda x: x

        # CLIP's published normalization constants.
        self.final_transform = T.Compose([
            lambda image: image.convert("RGB"),
            T.ToTensor() if tensorize else lambda x: x,
            T.Normalize(
                (0.48145466, 0.4578275, 0.40821073),
                (0.26862954, 0.26130258, 0.27577711),
            ) if tensorize else lambda x: x
        ])

    def get_transform(self, n_px, training, padding=False):
        """Build the resize/crop pipeline: random crop when training, center crop otherwise; optional square padding first."""
        # Resize slightly larger than the target so the crop has slack.
        resize = T.Resize((n_px + 16, n_px + 16), interpolation=Image.BICUBIC)

        # for traning split
        if training and not padding: # train
            return T.Compose([resize, T.RandomCrop(n_px)])

        if training and padding: # train_pad
            return T.Compose([SquarePad(), resize, T.RandomCrop(n_px)])

        # for test and val split
        if not training and not padding: # test
            return T.Compose([resize, T.CenterCrop(n_px)])

        if not training and padding: # test_pad
            return T.Compose([SquarePad(), resize, T.CenterCrop(n_px)])

    def key2img_path(self, key):
        """Probe the known image directory layouts and return the first existing path (None if no candidate exists)."""
        file_paths = [
            self.root_path / f"var_images/{key}.jpg",
            self.root_path / f"var_images/{key}.png",
            self.root_path / f"images/{key}.jpg",
            self.root_path / f"img/train/{key.split('_')[0]}/{key}.png",
            self.root_path / f"img/val/{key.split('_')[0]}/{key}.png",
            self.root_path / f"img/test/{key.split('_')[0]}/{key}.png",
            self.root_path / f"img/{key}.png",
            self.root_path / f"img/{key}.jpg",
            self.root_path / f"images/{key}.png",
            self.root_path / f"images/{key}.jpg",
        ]
        for file_path in file_paths:
            if file_path.exists():
                return file_path

    def key2img(self, key):
        """Open the PIL image for an annotation key (raises if key2img_path found nothing)."""
        file_path = self.key2img_path(key)
        return Image.open(file_path)

    def hide_region(self, image, bboxes):
        """Draw/overlay entity bounding boxes per config.hide_true_bbox.

        Modes: 1 = hide (gray boxes drawn directly), 2/5/7/8/9 = highlight
        (translucent colored overlay + green outline), 3 = blackout
        (gray everywhere except transparent boxes), 6 = position-only
        (gray everywhere, solid colored boxes).
        """
        image = image.convert('RGBA')

        if self.config.hide_true_bbox == 1: # hide mode
            draw = ImageDraw.Draw(image, 'RGBA')

        if self.config.hide_true_bbox in [2, 5, 7, 8, 9]: #highlight mode
            overlay = Image.new('RGBA', image.size, '#00000000')
            draw = ImageDraw.Draw(overlay, 'RGBA')

        if self.config.hide_true_bbox == 3 or self.config.hide_true_bbox == 6: #blackout mode or position only mode
            overlay = Image.new('RGBA', image.size, '#7B7575ff')
            draw = ImageDraw.Draw(overlay, 'RGBA')

        # One translucent color per entity slot (alpha 0x3c).
        color_fill_list = ['#ff05cd3c', '#00F1E83c', '#F2D4003c'] # Green, Blue, Yellow?

        for idx, bbox in enumerate(bboxes):
            # Missing entity slots are None (see get_bboxes_and_text).
            if bbox == None:
                continue

            color_fill = color_fill_list[idx]
            x, y = bbox['left'], bbox['top']

            if self.config.hide_true_bbox == 1: # hide mode
                draw.rectangle([(x, y), (x + bbox['width'], y + bbox['height'])], fill='#7B7575')
            elif self.config.hide_true_bbox in [2, 5, 7, 8, 9]: # highlight mode
                draw.rectangle([(x, y), (x + bbox['width'], y + bbox['height'])], fill=color_fill, outline='#05ff37ff',
                               width=3) # Fill with Pink 60% ##00F1E8
            elif self.config.hide_true_bbox == 3: # blackout mode
                draw.rectangle([(x, y), (x + bbox['width'], y + bbox['height'])], fill='#00000000')
            elif self.config.hide_true_bbox == 6: # position only mode
                draw.rectangle([(x, y), (x + bbox['width'], y + bbox['height'])], fill=color_fill)

        if self.config.hide_true_bbox in [2, 3, 5, 6, 7, 8, 9]:
            image = Image.alpha_composite(image, overlay)

        return image

    def get_entity_codes(self):
        """Return the entity permutation [0, 1, 2], shuffled when do_swap is set."""
        entity_codes = [0, 1, 2]
        if self.do_swap:
            random.shuffle(entity_codes)
        return entity_codes

    def swap_entities(self, bboxes, text, entity_codes):
        """Apply an entity permutation to both the text mentions and the bbox list."""
        # text
        for entity_idx, entity_code in enumerate(entity_codes):
            text = text.replace(f"Entity #{entity_idx + 1}", f"Entity #{entity_code + 1}")

        # bboxes: [1, 0, 2] -> [b[1], b[0], b[2]]
        new_boxes = [bboxes[entity_code] for entity_code in entity_codes]
        return new_boxes, text

    def get_text_from_meta(self, meta):
        """Build the rationale string; rationale_type 1/2 appends each entity's description after its first mention."""
        n_boxes = len(meta['bounding_box']) # key ['1', '2', '3']

        # for rationale
        text = 'Rationale: ' + str(meta['rationale'])

        if self.rationale_type == 1 or self.rationale_type == 2:
            for box_idx in range(n_boxes):
                ent_name = f'Entity #{box_idx + 1}'
                ent_desc = f'{ent_name}, {meta[ent_name]}'
                # todo: replace randomly
                text = text.replace(ent_name, ent_desc, 1)
        return text

    def get_itm_text(self, ori_file_key):
        """Sample an ITM example: 50% the true caption (label 1), 50% a negative with the same entity count (label 0)."""
        file_key = ori_file_key

        if random.random() < 0.5:
            n_boxes = len(self.data[file_key]['details'][-1]['bounding_box'])

            # Hard negative: same number of entities as the anchor.
            file_key = random.choice(self.keys[n_boxes])

            # Optionally relax to a uniform (easy) negative.
            if self.config.get('no_hard_negative_itm', False):
                file_key = random.choice(self.all_ent_keys)

        # The 50% "negative" draw can still pick the anchor itself, so the
        # label is derived from the final key, not from the coin flip.
        itm_label = 1 if file_key == ori_file_key else 0
        meta = self.data[file_key]['details'][-1]

        itm_text = self.get_text_from_meta(meta)
        return itm_text, itm_label

    def get_bboxes_and_text(self, file_key, meta):
        """Assemble bboxes, rationale text, and an ITM pair, all under one shared entity permutation."""
        text = self.get_text_from_meta(meta)
        # Pad to exactly 3 slots; absent entities become None.
        bboxes = [meta['bounding_box'].get(str(box_idx + 1), None) for box_idx in range(3)]

        entity_codes = self.get_entity_codes()
        bboxes, text = self.swap_entities(bboxes, text, entity_codes)

        itm_text, itm_label = self.get_itm_text(file_key)
        # Only the ITM *text* needs permuting; its boxes are unused.
        _, itm_text = self.swap_entities([None, None, None], itm_text, entity_codes)
        return {'bboxes': bboxes, 'text': text, 'itm_text': itm_text, 'itm_label': itm_label}

    def get_image(self, file_key, bboxes):
        """Load, jitter, draw bboxes, then resize/crop/normalize the image."""
        image = self.key2img(file_key)
        image = self.jitter_transform(image)
        image = self.hide_region(image, bboxes)
        image = self.final_transform(self.resize_crop(image))
        return image

    def __getitem__(self, idx):
        file_key = self.idx2name[idx]

        # Select the last version of label of the sample
        meta = self.data[file_key]['details'][-1]

        # read bboxes and rationale
        outputs = self.get_bboxes_and_text(file_key, meta)
        text = clip.tokenize(outputs['text'], truncate=True).squeeze()
        itm_text = clip.tokenize(outputs['itm_text'], truncate=True).squeeze()
        itm_label = torch.tensor(outputs['itm_label'])

        image = self.get_image(file_key, outputs['bboxes'])

        return {'image': image, 'caption': text, 'raw_text': text, 'file_key': file_key, 'itm_text': itm_text, 'itm_label': itm_label}

    def __len__(self):
        # Overfit mode shrinks every split except the combined test set.
        if self.config.overfit and not (self.split == 'test' and self.mode == 'combined'):
            return 16
        return len(self.data)
|
| 254 |
+
|
| 255 |
+
# %%
|
| 256 |
+
class VarDatasetImageOnly(VarDatasetForAuxEncoders):
|
| 257 |
+
def __init__(self, args, file_path, split="val", mode="combined", do_swap= False):
|
| 258 |
+
super().__init__(args, file_path, split=split, mode=mode, do_swap=do_swap)
|
| 259 |
+
|
| 260 |
+
def __getitem__(self, idx):
|
| 261 |
+
file_key = self.idx2name[idx]
|
| 262 |
+
meta = self.data[file_key]['details'][-1]
|
| 263 |
+
bboxes = [meta['bounding_box'].get(str(box_idx + 1), None) for box_idx in range(3)]
|
| 264 |
+
entity_codes = self.get_entity_codes()
|
| 265 |
+
bboxes = [bboxes[entity_code] for entity_code in entity_codes]
|
| 266 |
+
image = self.get_image(file_key, bboxes)
|
| 267 |
+
return {'image': image, 'file_key': file_key}
|
| 268 |
+
|
| 269 |
+
# %%
|
| 270 |
+
class VarDatasetTextOnly(VarDatasetForAuxEncoders):
|
| 271 |
+
def __init__(self, args, file_path, split="val", mode="combined", do_swap= False):
|
| 272 |
+
super().__init__(args, file_path, split=split, mode=mode, do_swap=do_swap)
|
| 273 |
+
|
| 274 |
+
def __getitem__(self, idx):
|
| 275 |
+
file_key = self.idx2name[idx]
|
| 276 |
+
meta = self.data[file_key]['details'][-1]
|
| 277 |
+
# text = self.get_text_from_meta(meta)
|
| 278 |
+
if 'Entity #3' in meta['hazard']:
|
| 279 |
+
n_boxes = 3
|
| 280 |
+
elif 'Entity #2' in meta['hazard']:
|
| 281 |
+
n_boxes = 2
|
| 282 |
+
else:
|
| 283 |
+
n_boxes = 1
|
| 284 |
+
|
| 285 |
+
# for rationale
|
| 286 |
+
text = 'Rationale: ' + str(meta['hazard'])
|
| 287 |
+
|
| 288 |
+
if self.rationale_type == 1 or self.rationale_type == 2:
|
| 289 |
+
for box_idx in range(n_boxes):
|
| 290 |
+
ent_name = f'Entity #{box_idx + 1}'
|
| 291 |
+
ent_desc = f'{ent_name}, {meta[ent_name]}'
|
| 292 |
+
# todo: replace randomly
|
| 293 |
+
text = text.replace(ent_name, ent_desc, 1)
|
| 294 |
+
|
| 295 |
+
entity_codes = self.get_entity_codes()
|
| 296 |
+
for entity_idx, entity_code in enumerate(entity_codes):
|
| 297 |
+
text = text.replace(f"Entity #{entity_idx + 1}", f"Entity #{entity_code + 1}")
|
| 298 |
+
text = clip.tokenize(text, truncate=True).squeeze()
|
| 299 |
+
return {'caption': text,'file_key': file_key}
|
| 300 |
+
|
| 301 |
+
# %%
|
| 302 |
+
import os
|
| 303 |
+
import sys
|
| 304 |
+
|
| 305 |
+
sys.path.append('..')
|
| 306 |
+
import json
|
| 307 |
+
import fire
|
| 308 |
+
import tqdm
|
| 309 |
+
|
| 310 |
+
import clip
|
| 311 |
+
import torch
|
| 312 |
+
import sklearn
|
| 313 |
+
import numpy as np
|
| 314 |
+
|
| 315 |
+
from omegaconf import OmegaConf
|
| 316 |
+
from models.fused_model import Model
|
| 317 |
+
from torch.utils.data import DataLoader
|
| 318 |
+
# from datasets import VarDatasetForAuxEncoders
|
| 319 |
+
|
| 320 |
+
from scipy.stats import rankdata
|
| 321 |
+
from sklearn.metrics import ndcg_score
|
| 322 |
+
from sklearn.metrics import pairwise_distances
|
| 323 |
+
|
| 324 |
+
|
| 325 |
+
# def get_data_loader(config, split="test", mode="combined", do_swap=False):
|
| 326 |
+
# dataset = VarDatasetForAuxEncoders(config, split=split, mode=mode, do_swap=do_swap)
|
| 327 |
+
# return DataLoader(dataset, batch_size=4, shuffle=False)
|
| 328 |
+
|
| 329 |
+
def get_image_data_loader(config, file_path, split="test", mode="combined", do_swap=False):
|
| 330 |
+
dataset = VarDatasetImageOnly(config, file_path, split=split, mode=mode, do_swap=do_swap)
|
| 331 |
+
return DataLoader(dataset, batch_size=4, shuffle=False)
|
| 332 |
+
|
| 333 |
+
def get_text_data_loader(config, file_path, split="test", mode="combined", do_swap=False):
|
| 334 |
+
dataset = VarDatasetTextOnly(config, file_path, split=split, mode=mode, do_swap=do_swap)
|
| 335 |
+
return DataLoader(dataset, batch_size=4, shuffle=False)
|
| 336 |
+
|
| 337 |
+
# def get_data_loader(config, split="test", mode="combined", do_swap=False):
|
| 338 |
+
# dataset = VarDatasetForAuxEncoders(config, split=split, mode=mode, do_swap=do_swap)
|
| 339 |
+
# return DataLoader(dataset, batch_size=4, shuffle=False)
|
| 340 |
+
|
| 341 |
+
def compute_rand_rank(split='test', mode='spec', img_token_dict={}, txt_token_dict={}): # the dicts contain all 2000 test samples
|
| 342 |
+
data = json.load(open( os.path.join(os.environ['ROOT'], f"data/annotations/13_05/anno_random_{split}_{mode}_ids.json")))
|
| 343 |
+
|
| 344 |
+
i2t_ranks = []
|
| 345 |
+
t2i_ranks = []
|
| 346 |
+
i2t_rank_dict = {}
|
| 347 |
+
t2i_rank_dict = {}
|
| 348 |
+
|
| 349 |
+
for file_key in data.keys():
|
| 350 |
+
img_emb = (img_token_dict[file_key]).unsqueeze(0)
|
| 351 |
+
txt_emb = (txt_token_dict[file_key]).unsqueeze(0)
|
| 352 |
+
|
| 353 |
+
txt_embs = torch.stack([txt_token_dict[k] for k in data[file_key]])
|
| 354 |
+
img_embs = torch.stack([img_token_dict[k] for k in data[file_key]])
|
| 355 |
+
assert txt_embs.shape[0] == img_embs.shape[0] == 1000
|
| 356 |
+
|
| 357 |
+
i2t_rank = rankdata(pairwise_distances(img_emb, txt_embs, metric='cosine', n_jobs=8), axis=1)[0]
|
| 358 |
+
t2i_rank = rankdata(pairwise_distances(txt_emb, img_embs, metric='cosine', n_jobs=8), axis=1)[0]
|
| 359 |
+
|
| 360 |
+
i2t_ranks.append(i2t_rank[-1])
|
| 361 |
+
t2i_ranks.append(t2i_rank[-1])
|
| 362 |
+
|
| 363 |
+
i2t_rank_dict[file_key] = i2t_rank
|
| 364 |
+
t2i_rank_dict[file_key] = t2i_rank
|
| 365 |
+
|
| 366 |
+
assert len(i2t_ranks) == len(t2i_ranks) == 1000
|
| 367 |
+
print(f"Random split, mode={mode} i2t rank: ", sum(i2t_ranks) / len(i2t_ranks))
|
| 368 |
+
print(f"Random split, mode={mode} t2i rank: ", sum(t2i_ranks) / len(t2i_ranks))
|
| 369 |
+
# for k in i2t_rank_dict.keys():
|
| 370 |
+
# print(k, i2t_rank_dict[k])
|
| 371 |
+
# print('------------------')
|
| 372 |
+
# break
|
| 373 |
+
return i2t_rank_dict # for computing the NDCG scores
|
| 374 |
+
|
| 375 |
+
|
| 376 |
+
def read_relevance_scores(anno_path="anno_random_test_obvi_ids.json", gpt_path="chatgpt_similarity_score_test_direct_combined.json"):
|
| 377 |
+
gpt_scores = json.load(open(gpt_path))
|
| 378 |
+
data = json.load(open(anno_path))
|
| 379 |
+
|
| 380 |
+
# add_missing_relevance_scores
|
| 381 |
+
for k in tqdm.tqdm(data, total=len(data)):
|
| 382 |
+
cand_keys = data[k]
|
| 383 |
+
for cand_key in cand_keys:
|
| 384 |
+
if cand_key not in gpt_scores[k]:
|
| 385 |
+
gpt_scores[k][cand_key] = 0.0
|
| 386 |
+
if cand_key == k:
|
| 387 |
+
gpt_scores[k][cand_key] = 1.0
|
| 388 |
+
|
| 389 |
+
return gpt_scores
|
| 390 |
+
# %%
|
| 391 |
+
|
| 392 |
+
def compute_ndcg(ranks, scores, k=3):
|
| 393 |
+
"""
|
| 394 |
+
ranks = [5, 1, 4, 2, 3]
|
| 395 |
+
scores = [0.1, 0.5, 0.3, 0.95, 1.0]
|
| 396 |
+
"""
|
| 397 |
+
rank_score_tuple = list(zip(ranks, scores))
|
| 398 |
+
|
| 399 |
+
top_k = sorted(rank_score_tuple, key=lambda x: x[1], reverse=True)[:k]
|
| 400 |
+
|
| 401 |
+
dcg = sum([score / np.log2(rank + 1) for rank, score in top_k])
|
| 402 |
+
|
| 403 |
+
ideal_dcg = sum([score / np.log2(idx + 2) for idx, (_, score) in enumerate(top_k)])
|
| 404 |
+
|
| 405 |
+
ndcg = dcg / ideal_dcg
|
| 406 |
+
return ndcg
|
| 407 |
+
|
| 408 |
+
|
| 409 |
+
def compute_ndcg_score_per_mode(pred_rank_dict, gpt_rel_scores, mode='spec', split='test', k=200):
|
| 410 |
+
data = json.load(open(os.path.join(os.environ['ROOT'],f"data/annotations/13_05/anno_random_{split}_{mode}_ids.json")))
|
| 411 |
+
|
| 412 |
+
ndcg_scores = []
|
| 413 |
+
|
| 414 |
+
for key in tqdm.tqdm(pred_rank_dict.keys(), total=len(pred_rank_dict.keys())):
|
| 415 |
+
gpt_scores_for_key = [gpt_rel_scores[key][cand_key] for cand_key in data[key]]
|
| 416 |
+
pred_rank_for_key = pred_rank_dict[key]
|
| 417 |
+
|
| 418 |
+
ndcg_score = compute_ndcg(pred_rank_for_key, gpt_scores_for_key, k=k)
|
| 419 |
+
ndcg_scores.append(ndcg_score)
|
| 420 |
+
|
| 421 |
+
avg_ndcg_score = sum(ndcg_scores) / len(ndcg_scores)
|
| 422 |
+
print(f"Random split, mode={mode} ndcg score: ", avg_ndcg_score)
|
| 423 |
+
return avg_ndcg_score
|
| 424 |
+
|
| 425 |
+
|
| 426 |
+
# %%
|
| 427 |
+
def main():
|
| 428 |
+
# %%
|
| 429 |
+
## Load Model
|
| 430 |
+
config_path= os.path.join(os.environ['ROOT'],"results/config.yaml")
|
| 431 |
+
model_path= os.path.join(os.environ['ROOT'],"results/model_epoch3.pth")
|
| 432 |
+
# %%
|
| 433 |
+
print("Loading config from:", config_path)
|
| 434 |
+
config = OmegaConf.load(config_path)
|
| 435 |
+
#print(OmegaConf.to_yaml(config))
|
| 436 |
+
# %%
|
| 437 |
+
|
| 438 |
+
# load checkpoint
|
| 439 |
+
checkpoint = torch.load(model_path, map_location=torch.device('cpu'))
|
| 440 |
+
print("Loaded model from:", model_path)
|
| 441 |
+
|
| 442 |
+
clip_model, _ = clip.load(config.clip_model, jit=False)
|
| 443 |
+
model = Model(clip_model, config)
|
| 444 |
+
model.load_state_dict(checkpoint['model_state_dict'])
|
| 445 |
+
|
| 446 |
+
model = model.to(config.device)
|
| 447 |
+
model = model.eval()
|
| 448 |
+
model = model.float()
|
| 449 |
+
logit_scale = model.clip_model.logit_scale.exp()
|
| 450 |
+
|
| 451 |
+
image_path = os.path.join(os.environ['ROOT'], "data/eval_test_image.json")
|
| 452 |
+
text_path = os.path.join(os.environ['ROOT'], "data/eval_test_text.json")
|
| 453 |
+
|
| 454 |
+
data_loader_image = get_image_data_loader(config, image_path, split='test', mode='combined' )
|
| 455 |
+
data_loader_text = get_text_data_loader(config, text_path, split='test', mode='combined' )
|
| 456 |
+
|
| 457 |
+
# %%
|
| 458 |
+
key_text_dict = {}
|
| 459 |
+
text_tensor_embedding = None
|
| 460 |
+
with torch.no_grad():
|
| 461 |
+
for i, d in tqdm.tqdm(enumerate(data_loader_text), total=len(data_loader_text)):
|
| 462 |
+
# print("d", d['file_key'])
|
| 463 |
+
|
| 464 |
+
# with torch.amp.autocast(device_type=config.device, dtype=torch.float16):
|
| 465 |
+
text_tensor_out, text_cls_out = model.var_txt_forward(d['caption'].to(config.device))
|
| 466 |
+
#print("text_tensor_out", text_tensor_out[0].shape)
|
| 467 |
+
|
| 468 |
+
if text_tensor_embedding == None:
|
| 469 |
+
text_tensor_embedding = text_cls_out
|
| 470 |
+
else:
|
| 471 |
+
text_tensor_embedding = torch.cat((text_tensor_embedding, text_cls_out), 0)
|
| 472 |
+
|
| 473 |
+
for j,key in enumerate(d['file_key']):
|
| 474 |
+
key_text_dict[key] = int(i*len(d['file_key']) +j)
|
| 475 |
+
|
| 476 |
+
# %%
|
| 477 |
+
key_image_dict = {}
|
| 478 |
+
image_tensor_embedding = None
|
| 479 |
+
with torch.no_grad():
|
| 480 |
+
for i, d in tqdm.tqdm(enumerate(data_loader_image), total=len(data_loader_image)):
|
| 481 |
+
image_tensor_out, img_cls_out = model.var_img_forward(d['image'].to(config.device))
|
| 482 |
+
|
| 483 |
+
if image_tensor_embedding == None:
|
| 484 |
+
image_tensor_embedding = img_cls_out
|
| 485 |
+
else:
|
| 486 |
+
image_tensor_embedding = torch.cat((image_tensor_embedding, img_cls_out), 0)
|
| 487 |
+
|
| 488 |
+
for j,key in enumerate(d['file_key']):
|
| 489 |
+
key_image_dict[key] = int(i*len(d['file_key']) +j)
|
| 490 |
+
|
| 491 |
+
idx2img = {idx: k for idx, k in enumerate(key_image_dict)}
|
| 492 |
+
idx2text = {idx: k for idx, k in enumerate(key_text_dict)}
|
| 493 |
+
# %%
|
| 494 |
+
image_tensor_embedding = image_tensor_embedding.to('cpu')
|
| 495 |
+
text_tensor_embedding = text_tensor_embedding.to('cpu')
|
| 496 |
+
|
| 497 |
+
# %%
|
| 498 |
+
similarity_matrix = pairwise_distances(image_tensor_embedding, text_tensor_embedding, metric='cosine', n_jobs=8)
|
| 499 |
+
|
| 500 |
+
# %%
|
| 501 |
+
results_pair_dict = {}
|
| 502 |
+
## put into matrix
|
| 503 |
+
for i in range (2000):
|
| 504 |
+
for j in range (2000):
|
| 505 |
+
results_pair_dict[str(idx2img[i])+':'+str(idx2text[j])] = float(similarity_matrix[i][j])
|
| 506 |
+
|
| 507 |
+
# %%
|
| 508 |
+
results_pair_dict1 = {}
|
| 509 |
+
results_pair_dict2 = {}
|
| 510 |
+
len_ = int(len(results_pair_dict)/2)
|
| 511 |
+
for j, key in enumerate(results_pair_dict):
|
| 512 |
+
if j <= len_:
|
| 513 |
+
results_pair_dict1[key] = results_pair_dict[key]
|
| 514 |
+
else:
|
| 515 |
+
results_pair_dict2[key] = results_pair_dict[key]
|
| 516 |
+
|
| 517 |
+
# %%
|
| 518 |
+
# with open(os.path.join(os.environ['ROOT'],'results_pair_dict1.json'), 'w', encoding='utf-8') as f:
|
| 519 |
+
# json.dump(results_pair_dict1, f, ensure_ascii=False, indent=4)
|
| 520 |
+
# with open(os.path.join(os.environ['ROOT'],'results_pair_dict2.json'), 'w', encoding='utf-8') as f:
|
| 521 |
+
# json.dump(results_pair_dict2, f, ensure_ascii=False, indent=4)
|
| 522 |
+
df = pd.DataFrame(results_pair_dict1.items(), columns=['key_pair','score'])
|
| 523 |
+
df.to_csv(os.path.join(os.environ['ROOT'],'results_pair_dict1.csv'))
|
| 524 |
+
df = pd.DataFrame(results_pair_dict2.items(), columns=['key_pair','score'])
|
| 525 |
+
df.to_csv(os.path.join(os.environ['ROOT'],'results_pair_dict2.csv'))
|
| 526 |
+
|
| 527 |
+
# %%
|
| 528 |
+
if __name__ == "__main__":
|
| 529 |
+
main()
|
| 530 |
+
|
| 531 |
+
# %%
|
data/chatgpt_similarity_score_test_direct.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
data/chatgpt_similarity_score_test_indirect.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
data/eval_test_image.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
data/eval_test_text.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
data/images/.DS_Store
ADDED
|
Binary file (6.15 kB). View file
|
|
|
data/key_pair.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
data/preview_image.jpeg
ADDED
|
data/random_sample_test_direct_ids.json
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:86c25ef7d12166f6c27fe725004a8857ddfcd4dbc8cfafd14722c189c13efbf5
|
| 3 |
+
size 27110039
|
data/random_sample_test_indirect_ids.json
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:74851fc1e326e6c838b868062949f2c3daa6a1c3c35f6dae87549caf86acd39d
|
| 3 |
+
size 27109999
|
models/__init__.py
ADDED
|
File without changes
|
models/__pycache__/__init__.cpython-38.pyc
ADDED
|
Binary file (154 Bytes). View file
|
|
|
models/__pycache__/fused_model.cpython-38.pyc
ADDED
|
Binary file (16.1 kB). View file
|
|
|
models/fused_model.py
ADDED
|
@@ -0,0 +1,544 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import copy
|
| 2 |
+
import math
|
| 3 |
+
from typing import Optional, Tuple, Union
|
| 4 |
+
import numpy as np
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
from torch import nn
|
| 8 |
+
import torch.nn.functional as F
|
| 9 |
+
from torch.nn import CrossEntropyLoss
|
| 10 |
+
|
| 11 |
+
from transformers.activations import ACT2FN
|
| 12 |
+
from einops import rearrange
|
| 13 |
+
from transformers.models.t5.configuration_t5 import T5Config
|
| 14 |
+
from transformers.modeling_utils import ModuleUtilsMixin
|
| 15 |
+
|
| 16 |
+
from einops import rearrange, reduce
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
class FeedForward(nn.Module):
|
| 20 |
+
|
| 21 |
+
def __init__(self, config: T5Config):
|
| 22 |
+
super().__init__()
|
| 23 |
+
self.wi = nn.Linear(config.d_model, config.d_ff, bias=False)
|
| 24 |
+
self.wo = nn.Linear(config.d_ff, config.d_model, bias=False)
|
| 25 |
+
self.dropout = nn.Dropout(config.dropout_rate)
|
| 26 |
+
self.act = ACT2FN["gelu"]
|
| 27 |
+
self.layer_norm = nn.LayerNorm(config.d_model)
|
| 28 |
+
|
| 29 |
+
def forward(self, x):
|
| 30 |
+
x_hidden = self.wo(self.dropout(self.act(self.wi(self.layer_norm(x)))))
|
| 31 |
+
return x + self.dropout(x_hidden)
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
class Attention(nn.Module):
|
| 35 |
+
|
| 36 |
+
def __init__(self, config: T5Config, has_relative_attention_bias=False):
|
| 37 |
+
super().__init__()
|
| 38 |
+
self.is_decoder = config.is_decoder
|
| 39 |
+
self.has_relative_attention_bias = has_relative_attention_bias
|
| 40 |
+
self.relative_attention_num_buckets = config.relative_attention_num_buckets
|
| 41 |
+
self.relative_attention_max_distance = config.relative_attention_max_distance
|
| 42 |
+
self.d_model = config.d_model
|
| 43 |
+
self.key_value_proj_dim = config.d_kv
|
| 44 |
+
self.n_heads = config.num_heads
|
| 45 |
+
self.dropout = config.dropout_rate
|
| 46 |
+
self.inner_dim = self.n_heads * self.key_value_proj_dim
|
| 47 |
+
|
| 48 |
+
# Mesh TensorFlow initialization to avoid scaling before softmax
|
| 49 |
+
self.q = nn.Linear(self.d_model, self.inner_dim, bias=False)
|
| 50 |
+
self.k = nn.Linear(self.d_model, self.inner_dim, bias=False)
|
| 51 |
+
self.v = nn.Linear(self.d_model, self.inner_dim, bias=False)
|
| 52 |
+
self.o = nn.Linear(self.inner_dim, self.d_model, bias=False)
|
| 53 |
+
|
| 54 |
+
if self.has_relative_attention_bias:
|
| 55 |
+
self.relative_attention_bias = nn.Embedding(self.relative_attention_num_buckets, self.n_heads)
|
| 56 |
+
|
| 57 |
+
@staticmethod
|
| 58 |
+
def _relative_position_bucket(relative_position, bidirectional=True, num_buckets=32, max_distance=128):
|
| 59 |
+
"""
|
| 60 |
+
memory_position - query_position -> bucket_idx.
|
| 61 |
+
If bidirectional=False, then positive relative positions are invalid.
|
| 62 |
+
We use smaller buckets for small absolute relative_position
|
| 63 |
+
and larger buckets for larger absolute relative_positions.
|
| 64 |
+
* All relative positions >=max_distance map to the same bucket.
|
| 65 |
+
* All relative positions <=-max_distance map to the same bucket.
|
| 66 |
+
This should allow for more graceful generalization to longer sequences than the model has been trained on
|
| 67 |
+
Args:
|
| 68 |
+
relative_position: an int32 Tensor
|
| 69 |
+
bidirectional: a boolean - whether the attention is bidirectional
|
| 70 |
+
num_buckets: an integer
|
| 71 |
+
max_distance: an integer
|
| 72 |
+
Returns:
|
| 73 |
+
a Tensor with the same shape as relative_position, containing int32 values in the range [0, num_buckets)
|
| 74 |
+
"""
|
| 75 |
+
relative_buckets = 0
|
| 76 |
+
if bidirectional:
|
| 77 |
+
num_buckets //= 2
|
| 78 |
+
relative_buckets += (relative_position > 0).to(torch.long) * num_buckets
|
| 79 |
+
relative_position = torch.abs(relative_position)
|
| 80 |
+
else:
|
| 81 |
+
relative_position = -torch.min(relative_position, torch.zeros_like(relative_position))
|
| 82 |
+
# now relative_position is in the range [0, inf)
|
| 83 |
+
|
| 84 |
+
# half of the buckets are for exact increments in positions
|
| 85 |
+
max_exact = num_buckets // 2
|
| 86 |
+
is_small = relative_position < max_exact
|
| 87 |
+
|
| 88 |
+
# The other half of the buckets are for logarithmically bigger bins in positions up to max_distance
|
| 89 |
+
relative_position_if_large = max_exact + (torch.log(relative_position.float() / max_exact) / math.log(max_distance / max_exact) *
|
| 90 |
+
(num_buckets - max_exact)).to(torch.long)
|
| 91 |
+
relative_position_if_large = torch.min(relative_position_if_large, torch.full_like(relative_position_if_large, num_buckets - 1))
|
| 92 |
+
|
| 93 |
+
relative_buckets += torch.where(is_small, relative_position, relative_position_if_large)
|
| 94 |
+
return relative_buckets
|
| 95 |
+
|
| 96 |
+
def compute_bias(self, query_length, key_length, device=None):
|
| 97 |
+
"""Compute binned relative position bias"""
|
| 98 |
+
if device is None:
|
| 99 |
+
device = self.relative_attention_bias.weight.device
|
| 100 |
+
context_position = torch.arange(query_length, dtype=torch.long, device=device)[:, None]
|
| 101 |
+
memory_position = torch.arange(key_length, dtype=torch.long, device=device)[None, :]
|
| 102 |
+
relative_position = memory_position - context_position # shape (query_length, key_length)
|
| 103 |
+
relative_position_bucket = self._relative_position_bucket(
|
| 104 |
+
relative_position, # shape (query_length, key_length)
|
| 105 |
+
bidirectional=(not self.is_decoder),
|
| 106 |
+
num_buckets=self.relative_attention_num_buckets,
|
| 107 |
+
max_distance=self.relative_attention_max_distance,
|
| 108 |
+
)
|
| 109 |
+
values = self.relative_attention_bias(relative_position_bucket) # shape (query_length, key_length, num_heads)
|
| 110 |
+
values = values.permute([2, 0, 1]).unsqueeze(0) # shape (1, num_heads, query_length, key_length)
|
| 111 |
+
return values
|
| 112 |
+
|
| 113 |
+
def forward(self, x, mask=None, x_kv=None, pos_bias=None):
|
| 114 |
+
"""
|
| 115 |
+
Self-attention (if key_value_states is None) or attention over source sentence (provided by key_value_states).
|
| 116 |
+
"""
|
| 117 |
+
# Input is (batch_size, seq_length, dim)
|
| 118 |
+
# Mask is (batch_size, key_length) (non-causal) or (batch_size, key_length, key_length)
|
| 119 |
+
batch_size, seq_length = x.shape[:2]
|
| 120 |
+
|
| 121 |
+
real_seq_length = seq_length
|
| 122 |
+
key_length = real_seq_length if x_kv is None else x_kv.shape[1]
|
| 123 |
+
|
| 124 |
+
reshape = lambda states: rearrange(states, 'b s (h d) -> b h s d', h=self.n_heads)
|
| 125 |
+
unshape = lambda states: rearrange(states, 'b h s d -> b s (h d)')
|
| 126 |
+
|
| 127 |
+
q = reshape(self.q(x)) # (batch_size, n_heads, seq_length, dim_per_head)
|
| 128 |
+
k = reshape(self.k(x if x_kv is None else x_kv))
|
| 129 |
+
v = reshape(self.v(x if x_kv is None else x_kv))
|
| 130 |
+
|
| 131 |
+
# compute scores
|
| 132 |
+
scores = torch.matmul(q, k.transpose(3, 2))
|
| 133 |
+
|
| 134 |
+
if pos_bias is None:
|
| 135 |
+
if not self.has_relative_attention_bias:
|
| 136 |
+
pos_bias = torch.zeros((1, self.n_heads, real_seq_length, key_length), device=scores.device, dtype=scores.dtype)
|
| 137 |
+
else:
|
| 138 |
+
pos_bias = self.compute_bias(real_seq_length, key_length, device=scores.device)
|
| 139 |
+
|
| 140 |
+
if mask is not None:
|
| 141 |
+
pos_bias = pos_bias + mask # (batch_size, n_heads, seq_length, key_length)
|
| 142 |
+
|
| 143 |
+
position_bias_masked = pos_bias
|
| 144 |
+
scores += position_bias_masked
|
| 145 |
+
attn_weights = F.softmax(scores.float(), dim=-1).type_as(scores) # (B, H, seq_length, key_length)
|
| 146 |
+
attn_weights = F.dropout(attn_weights, p=self.dropout, training=self.training) # (B, H, seq_length, key_length)
|
| 147 |
+
|
| 148 |
+
attn_output = unshape(torch.matmul(attn_weights, v)) # (batch_size, seq_length, dim)
|
| 149 |
+
attn_output = self.o(attn_output)
|
| 150 |
+
return (attn_output, pos_bias)
|
| 151 |
+
|
| 152 |
+
|
| 153 |
+
class LayerSelfAttention(nn.Module):
|
| 154 |
+
|
| 155 |
+
def __init__(self, config, has_relative_attention_bias=False):
|
| 156 |
+
super().__init__()
|
| 157 |
+
self.SelfAttention = Attention(config, has_relative_attention_bias=has_relative_attention_bias)
|
| 158 |
+
self.layer_norm = nn.LayerNorm(config.d_model)
|
| 159 |
+
self.dropout = nn.Dropout(config.dropout_rate)
|
| 160 |
+
|
| 161 |
+
def forward(self, x, mask=None, pos_bias=None): # x + drop(attn(ln(x)))
|
| 162 |
+
h = self.layer_norm(x)
|
| 163 |
+
outputs = self.SelfAttention(h, mask=mask, pos_bias=pos_bias)
|
| 164 |
+
x = x + self.dropout(outputs[0])
|
| 165 |
+
return (x, outputs[1]) # outputs[1] is pos_bias
|
| 166 |
+
|
| 167 |
+
|
| 168 |
+
class LayerCrossAttention(nn.Module):
|
| 169 |
+
|
| 170 |
+
def __init__(self, config):
|
| 171 |
+
super().__init__()
|
| 172 |
+
self.EncDecAttention = Attention(config, has_relative_attention_bias=False)
|
| 173 |
+
self.layer_norm = nn.LayerNorm(config.d_model)
|
| 174 |
+
self.dropout = nn.Dropout(config.dropout_rate)
|
| 175 |
+
|
| 176 |
+
def forward(self, x, x_kv, mask=None, pos_bias=None): # x + drop(attn(ln(x), x_kv))
|
| 177 |
+
h = self.layer_norm(x)
|
| 178 |
+
|
| 179 |
+
outputs = self.EncDecAttention(h, mask=mask, x_kv=x_kv, pos_bias=pos_bias)
|
| 180 |
+
x = x + self.dropout(outputs[0])
|
| 181 |
+
return (x, outputs[1]) # outputs[1] is pos_bias
|
| 182 |
+
|
| 183 |
+
|
| 184 |
+
class Block(nn.Module):
|
| 185 |
+
|
| 186 |
+
def __init__(self, config, has_relative_attention_bias=False):
|
| 187 |
+
super().__init__()
|
| 188 |
+
self.is_decoder = config.is_decoder
|
| 189 |
+
self.layer = nn.ModuleList()
|
| 190 |
+
self.layer.append(LayerSelfAttention(config, has_relative_attention_bias=has_relative_attention_bias))
|
| 191 |
+
if self.is_decoder:
|
| 192 |
+
self.layer.append(LayerCrossAttention(config))
|
| 193 |
+
|
| 194 |
+
self.layer.append(FeedForward(config))
|
| 195 |
+
|
| 196 |
+
def forward(self, x, mask=None, pos_bias=None, context=None, context_mask=None, context_pos_bias=None):
|
| 197 |
+
|
| 198 |
+
self_attention_outputs = self.layer[0](x, mask=mask, pos_bias=pos_bias)
|
| 199 |
+
hidden_states = self_attention_outputs[0]
|
| 200 |
+
|
| 201 |
+
do_cross_attention = self.is_decoder and context is not None
|
| 202 |
+
if do_cross_attention:
|
| 203 |
+
|
| 204 |
+
cross_attention_outputs = self.layer[1](
|
| 205 |
+
hidden_states,
|
| 206 |
+
x_kv=context,
|
| 207 |
+
mask=context_mask,
|
| 208 |
+
pos_bias=context_pos_bias,
|
| 209 |
+
)
|
| 210 |
+
hidden_states = cross_attention_outputs[0]
|
| 211 |
+
|
| 212 |
+
# Apply Feed Forward layer
|
| 213 |
+
hidden_states = self.layer[-1](hidden_states)
|
| 214 |
+
|
| 215 |
+
pos_bias = self_attention_outputs[1]
|
| 216 |
+
context_pos_bias = cross_attention_outputs[1] if do_cross_attention else None
|
| 217 |
+
|
| 218 |
+
return (hidden_states, pos_bias, context_pos_bias)
|
| 219 |
+
|
| 220 |
+
|
| 221 |
+
class Stack(nn.Module):
    """A T5-style stack of ``Block`` modules usable as an encoder or decoder.

    Only the first block has relative attention biases; its biases are
    threaded through and reused by every subsequent block (see ``forward``).
    """

    def __init__(self, config, is_decoder=True, has_embedding=False, generate_causal_mask=False):
        """
        Args:
            config: expects ``vocab_size``, ``d_model``, ``num_layers``,
                ``dropout_rate`` and ``is_decoder`` attributes.
            is_decoder: enables the cross-attention path in each block.
            has_embedding: create a token embedding so ``forward`` can accept
                ``input_ids`` (otherwise callers must pass hidden states).
            generate_causal_mask: apply a causal mask in
                ``get_extended_attention_mask`` for 2-D masks.
        """
        super().__init__()
        self.config = config
        if has_embedding:
            self.embed_tokens = nn.Embedding(config.vocab_size, config.d_model)

        self.is_decoder = is_decoder
        # Target dtype used by invert_attention_mask (fp16 compatibility).
        self.dtype = torch.float32

        self.generate_causal_mask = generate_causal_mask

        # Only block 0 owns relative attention biases; later blocks reuse them.
        self.block = nn.ModuleList([Block(config, has_relative_attention_bias=bool(i == 0)) for i in range(config.num_layers)])
        self.final_layer_norm = nn.LayerNorm(config.d_model)
        self.dropout = nn.Dropout(config.dropout_rate)

    def forward(
        self,
        input_ids=None,
        dec_hidden_states=None,
        enc_hidden_states=None,
        dec_attention_mask=None,
        enc_attention_mask=None,
    ):
        """Run the stack; exactly one of ``input_ids`` / ``dec_hidden_states``
        must be provided. Returns a 1-tuple ``(hidden_states,)``."""
        input_shape = input_ids.size() if input_ids is not None else dec_hidden_states.shape[:-1]
        batch_size, seq_length = input_shape

        if input_ids is not None:
            # Requires has_embedding=True at construction time.
            input_ids = input_ids.view(-1, input_shape[-1])
            inputs_embeds = self.embed_tokens(input_ids)
        else:
            inputs_embeds = dec_hidden_states

        # required mask seq length can be calculated via length of past
        mask_seq_length = seq_length

        # Default masks attend everywhere when the caller provides none.
        if dec_attention_mask is None:
            dec_attention_mask = torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device)
        if self.is_decoder and enc_attention_mask is None and enc_hidden_states is not None:
            encoder_seq_length = enc_hidden_states.shape[1]
            enc_attention_mask = torch.ones(batch_size, encoder_seq_length, device=inputs_embeds.device, dtype=torch.long)

        # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
        # ourselves in which case we just need to make it broadcastable to all heads.
        extended_attention_mask = self.get_extended_attention_mask(dec_attention_mask, input_shape)

        # If a 2D or 3D attention mask is provided for the cross-attention
        # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
        if self.is_decoder and enc_hidden_states is not None:
            encoder_batch_size, encoder_sequence_length, _ = enc_hidden_states.size()
            encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
            if enc_attention_mask is None:
                enc_attention_mask = torch.ones(encoder_hidden_shape, device=inputs_embeds.device)
            encoder_extended_attention_mask = self.invert_attention_mask(enc_attention_mask)
        else:
            encoder_extended_attention_mask = None

        # Position biases start empty; block 0 fills them, later blocks reuse.
        pos_bias = None
        context_pos_bias = None

        hidden_states = self.dropout(inputs_embeds)

        for i, layer_module in enumerate(self.block):

            layer_outputs = layer_module(
                hidden_states,
                mask=extended_attention_mask,  # [1, 1, 1, 1 ] [B, L]
                pos_bias=pos_bias,
                context=enc_hidden_states,
                context_mask=encoder_extended_attention_mask,
                context_pos_bias=context_pos_bias,
            )

            # layer_outputs is a tuple with:
            # insert a None placeholder (present_key_value_state slot) so the
            # unpacking below matches the HF-style layout.
            layer_outputs = layer_outputs[:1] + (None,) + layer_outputs[1:]

            hidden_states, present_key_value_state = layer_outputs[:2]  # [B, L, D], None

            # We share the position biases between the layers - the first layer store them
            pos_bias = layer_outputs[2]  # [B, H, L, L]
            if self.is_decoder and enc_hidden_states is not None:
                context_pos_bias = layer_outputs[3]

        hidden_states = self.final_layer_norm(hidden_states)
        hidden_states = self.dropout(hidden_states)

        return (hidden_states,)

    def invert_attention_mask(self, attention_mask):
        """
        Input: 1 for attend, 0 for masked/ignored
        Output: 0 for attend, -1e30 for masked/ignored.
        Then we can add it to the attention logits.
        [B, L] -> [B, 1, 1, L]
        [B, L, L] -> [B, 1, L, L]
        """
        if attention_mask.dim() == 3:
            extended_attention_mask = attention_mask[:, None, :, :]
        if attention_mask.dim() == 2:
            extended_attention_mask = attention_mask[:, None, None, :]

        # Cast to self.dtype (float32) BEFORE finfo so integer masks work.
        extended_attention_mask = extended_attention_mask.to(dtype=self.dtype)  # fp16 compatibility
        extended_attention_mask = (1.0 - extended_attention_mask) * torch.finfo(self.dtype).min

        return extended_attention_mask

    def get_extended_attention_mask(self, attention_mask, input_shape, device=None, dtype=None):
        """
        Makes broadcastable attention and causal masks so that future and masked tokens are ignored.
        attention_mask: 1 for attend, 0 for masked/ignored
        Return: The extended attention mask: 0 for attend, -1e30 for masked/ignored
        [B, L] -> [B, 1, 1, L]
        [B, L, L] -> [B, 1, L, L]
        """
        # NOTE(review): torch.finfo below requires a floating dtype — an
        # integer attention_mask with dtype=None would raise; masks built in
        # forward() are float, but confirm for external callers.
        dtype = dtype if dtype else attention_mask.dtype

        # If input [B, query_length, key_length] -> [B, 1, query_length, key_length]
        if attention_mask.dim() == 3:
            extended_attention_mask = attention_mask[:, None, :, :]
        elif attention_mask.dim() == 2:
            # Provided a padding mask of dimensions [batch_size, seq_length]
            # - if the model is a decoder, apply a causal mask in addition to the padding mask
            # - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length]
            if self.config.is_decoder and self.generate_causal_mask:
                # NOTE(review): ModuleUtilsMixin is not defined in this view —
                # presumably imported from transformers earlier in the file; verify.
                extended_attention_mask = ModuleUtilsMixin.create_extended_attention_mask_for_decoder(input_shape, attention_mask, device)
            else:
                extended_attention_mask = attention_mask[:, None, None, :]
        else:
            raise ValueError(f"Wrong shape for input_ids (shape {input_shape}) or attention_mask (shape {attention_mask.shape})")

        # Input: valid = 1, padding = 0
        # Output: valid = 0, padding = -1e30
        # => then we can add it to the attention logits
        extended_attention_mask = extended_attention_mask.to(dtype=dtype)  # fp16 compatibility
        extended_attention_mask = (1.0 - extended_attention_mask) * torch.finfo(dtype).min
        return extended_attention_mask
|
| 358 |
+
|
| 359 |
+
|
| 360 |
+
class Model(torch.nn.Module):
    """CLIP backbone with optional lightweight ``Stack`` decoder heads.

    Depending on ``config``, adds image/text/mix ITM decoders (binary
    image-text-matching heads) and a generative caption decoder on top of a
    frozen-or-finetuned CLIP model.
    """

    def __init__(self, clip_model, config):
        """
        Args:
            clip_model: a loaded ``clip`` model (exposes ``.visual``,
                ``.transformer``, ``.token_embedding``, etc.).
            config: expects ``has_extra_{txt,img,mix,gen}_decoder`` flags,
                ``extra_decoder`` (a Stack config) and ``vocab_size``.
        """
        super().__init__()
        self.clip_model = clip_model
        self.config = config  # (was assigned twice; once is enough)

        if self.config.has_extra_txt_decoder:
            # Text-conditioned-on-image ITM decoder.
            self.txt_decoder = Stack(config.extra_decoder)
            self.itm_txt_head = torch.nn.Linear(config.extra_decoder.d_model, 2)

        if self.config.has_extra_img_decoder:
            # Image-conditioned-on-text ITM decoder.
            self.img_decoder = Stack(config.extra_decoder)
            self.itm_img_head = torch.nn.Linear(config.extra_decoder.d_model, 2)

        if self.config.has_extra_mix_decoder:
            self.mix_decoder = Stack(config.extra_decoder)
            self.mix_itm_head = torch.nn.Linear(config.extra_decoder.d_model, 2)

        if self.config.has_extra_gen_decoder:
            # Generative decoder consumes token ids, hence has_embedding=True.
            self.gen_decoder = Stack(config.extra_decoder, has_embedding=True, generate_causal_mask=True)
            self.gen_head = torch.nn.Linear(config.extra_decoder.d_model, config.vocab_size)

    def img_forward(self, x: torch.Tensor):  # [N, 3, 224, 224]
        """CLIP ViT forward that returns ALL patch tokens plus the CLS token.

        Returns:
            (x, cls_token): projected token sequence [N, L, D'] and its
            first (CLS) token [N, D'].
        """
        x = self.clip_model.visual.conv1(x)  # shape = [*, width, grid, grid]
        x = x.reshape(x.shape[0], x.shape[1], -1)  # shape = [*, width, grid ** 2]
        x = x.permute(0, 2, 1)  # shape = [*, grid ** 2, width]
        x = torch.cat(
            [self.clip_model.visual.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x],
            dim=1)  # shape = [*, grid ** 2 + 1, width]
        x = x + self.clip_model.visual.positional_embedding.to(x.dtype)
        x = self.clip_model.visual.ln_pre(x)

        x = x.permute(1, 0, 2)  # NLD -> LND
        x = self.clip_model.visual.transformer(x)
        x = x.permute(1, 0, 2)  # LND -> NLD
        x = self.clip_model.visual.ln_post(x)  # [NLD]

        if self.clip_model.visual.proj is not None:
            # Project every token, not just CLS, so patch features are usable.
            proj = self.clip_model.visual.proj[None, :, :]
            x = (x @ proj)

        cls_token = x[:, 0, :]
        return x, cls_token

    def txt_forward(self, text):
        """CLIP text forward returning all token states and the EOT embedding."""
        dtype = self.clip_model.dtype
        x = self.clip_model.token_embedding(text).type(dtype)  # [batch_size, n_ctx, d_model]

        x = x + self.clip_model.positional_embedding.type(dtype)
        x = x.permute(1, 0, 2)  # NLD -> LND
        x = self.clip_model.transformer(x)
        x = x.permute(1, 0, 2)  # LND -> NLD
        x = self.clip_model.ln_final(x).type(dtype)

        proj = self.clip_model.text_projection[None, :, :]
        x = (x @ proj)

        # take features from the eot embedding (eot_token is the highest number in each sequence)
        eot = x[torch.arange(x.shape[0]), text.argmax(dim=-1)]
        return x, eot  # [NLD]

    def var_img_forward(self, image):
        """Image forward that averages features over a 2-crop batch [N,2,3,H,W]."""
        if len(image.shape) == 5:
            img_features1, img_token1 = self.img_forward(image[:, 0, ...])
            img_features2, img_token2 = self.img_forward(image[:, 1, ...])
            img_token = (img_token1 + img_token2) / 2
            img_features = (img_features1 + img_features2) / 2
        else:
            img_features, img_token = self.img_forward(image)
        img_token = img_token / img_token.norm(dim=-1, keepdim=True)
        return img_features, img_token

    def var_txt_forward(self, text):
        """Text forward with L2-normalized EOT token."""
        txt_features, txt_token = self.txt_forward(text)
        txt_token = txt_token / txt_token.norm(dim=-1, keepdim=True)
        return txt_features, txt_token

    def get_device(self):
        """Return the device of the model's parameters."""
        return next(self.parameters()).device

    def get_features(self, image=None, text_ids=None):
        """Compute per-modality features/tokens/masks for whichever inputs are given."""
        outputs = {}
        if image is not None:
            img_features, img_token = self.var_img_forward(image)
            outputs['img_features'] = img_features
            outputs['img_token'] = img_token
            # All image tokens are valid; build an all-ones mask of shape [B, L].
            outputs['img_mask'] = torch.ones_like(img_features[:, :, 0])
        if text_ids is not None:
            txt_features, txt_token = self.var_txt_forward(text_ids)
            outputs['txt_features'] = txt_features
            outputs['txt_token'] = txt_token
            # Token id 0 is treated as padding.
            outputs['txt_mask'] = (text_ids != 0).to(txt_features.dtype)
        return outputs

    def get_prediction(self, img_features, txt_features, img_mask=None, txt_mask=None, decoder="txt_decoder", **kwargs):
        """Run one ITM decoder over precomputed features and return its logits/probs."""
        outputs = {}
        if decoder == 'txt_decoder':
            hidden_states = self.txt_decoder(
                dec_hidden_states=txt_features,
                enc_hidden_states=img_features,
                enc_attention_mask=img_mask,
                dec_attention_mask=txt_mask,
            )
            # ITM head reads the first (CLS/BOS) position.
            outputs['itm_txt_logits'] = self.itm_txt_head(hidden_states[0][:, 0, :])
            outputs['itm_txt_probs'] = torch.softmax(outputs['itm_txt_logits'], dim=-1)

        if decoder == 'img_decoder':
            hidden_states = self.img_decoder(
                dec_hidden_states=img_features,
                enc_hidden_states=txt_features,
                enc_attention_mask=txt_mask,
                dec_attention_mask=img_mask,
            )
            outputs['itm_img_logits'] = self.itm_img_head(hidden_states[0][:, 0, :])
            outputs['itm_img_probs'] = torch.softmax(outputs['itm_img_logits'], dim=-1)
        return outputs

    def forward(self, image, text, itm_text=None, itm_labels=None, gen_inputs=None, gen_labels=None):
        """Full training forward: CLIP features plus optional ITM/gen heads.

        Returns a dict with tokens/features and, when the corresponding
        decoders are enabled and inputs provided, ``itm_*_logits`` /
        ``gen_logits``.
        """
        img_features, img_token = self.var_img_forward(image)
        txt_features, txt_token = self.var_txt_forward(text)

        # BUGFIX: itm_text defaults to None but was forwarded unconditionally,
        # crashing inside var_txt_forward. Only compute ITM features when given.
        itm_txt_features = None
        itm_txt_mask = None
        if itm_text is not None:
            itm_txt_features, _ = self.var_txt_forward(itm_text)
            itm_txt_mask = (itm_text != 0).to(itm_txt_features.dtype)

        outputs = dict(
            img_token=img_token,
            txt_token=txt_token,
            img_features=img_features,
            txt_features=txt_features,
        )
        if self.config.has_extra_txt_decoder and itm_text is not None:
            itm_img_features = img_features
            itm_txt_states = self.txt_decoder(
                dec_hidden_states=itm_txt_features,
                enc_hidden_states=itm_img_features,
                enc_attention_mask=None,
                dec_attention_mask=itm_txt_mask,
            )
            outputs['itm_txt_logits'] = self.itm_txt_head(itm_txt_states[0][:, 0])

        if self.config.has_extra_img_decoder and itm_text is not None:
            itm_img_features = img_features
            itm_img_states = self.img_decoder(
                dec_hidden_states=itm_img_features,
                enc_hidden_states=itm_txt_features,
                enc_attention_mask=itm_txt_mask,
                dec_attention_mask=None,
            )
            outputs['itm_img_logits'] = self.itm_img_head(itm_img_states[0][:, 0])

        if self.config.has_extra_mix_decoder:
            pass  # mix decoder training path not implemented yet

        if self.config.has_extra_gen_decoder:
            # BUGFIX: Stack.forward() has no `labels` parameter — passing
            # labels=gen_labels raised TypeError. The generation loss (if any)
            # must be computed by the caller from outputs['gen_logits'].
            gen_features = self.gen_decoder(
                input_ids=gen_inputs,
                enc_hidden_states=img_features,
                enc_attention_mask=None,
                dec_attention_mask=None,
            )
            outputs['gen_logits'] = self.gen_head(gen_features[0])

        return outputs
| 527 |
+
|
| 528 |
+
|
| 529 |
+
if __name__ == "__main__":
    # Smoke test: build the fused model on CPU from the example config.
    import sys
    from omegaconf import OmegaConf
    # Hard-coded project root so `config.examples` resolves outside a package install.
    sys.path.append("/home/quang/workspace/traffic_var")
    from config.examples import with_decoder_config as config
    config.has_extra_txt_decoder = True
    print(OmegaConf.to_yaml(config))

    import clip

    def get_resolution(model):
        # ViT CLIP models expose input_resolution under .visual; others at top level.
        return model.visual.input_resolution if hasattr(model, 'visual') else model.input_resolution

    model, _ = clip.load(config.clip_model, jit=False, device="cpu")
    config.img_size = get_resolution(model)
    model = Model(model, config)
|
models/model.py
ADDED
|
@@ -0,0 +1,117 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
from easydict import EasyDict as edict
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
class Block(nn.Module):
    """Placeholder transformer block: currently only stores its config.

    Intended for the (commented-out) i2t/t2i encoder stacks in Model below.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
class Model(torch.nn.Module):
    """Thin CLIP wrapper trained with the symmetric contrastive (CLIP) loss.

    Optionally accumulates tokens across gradient-accumulation steps via
    ``past_img_tokens`` / ``past_txt_tokens`` to enlarge the in-batch
    negative pool.
    """

    def __init__(self, clip_model, config):
        super().__init__()
        self.clip_model = clip_model
        # if config.i2t_encoder_layers > 0:
        #     self.i2t_encoder = nn.ModuleList([Block(config) for _ in range(config.i2t_encoder_layers)])

        # if config.t2i_encoder_layers > 0:
        #     self.t2i_encoder = nn.ModuleList([Block(config) for _ in range(config.i2t_encoder_layers)])

        self.config = config

    def img_forward(self, x: torch.Tensor):  # [N, 3, 224, 224]
        """CLIP ViT forward returning all token states and the projected CLS token."""
        x = self.clip_model.visual.conv1(x)  # shape = [*, width, grid, grid]
        x = x.reshape(x.shape[0], x.shape[1], -1)  # shape = [*, width, grid ** 2]
        x = x.permute(0, 2, 1)  # shape = [*, grid ** 2, width]
        x = torch.cat(
            [self.clip_model.visual.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x],
            dim=1)  # shape = [*, grid ** 2 + 1, width]
        x = x + self.clip_model.visual.positional_embedding.to(x.dtype)
        x = self.clip_model.visual.ln_pre(x)

        x = x.permute(1, 0, 2)  # NLD -> LND
        x = self.clip_model.visual.transformer(x)
        x = x.permute(1, 0, 2)  # LND -> NLD
        x = self.clip_model.visual.ln_post(x)  # [NLD]
        # BUGFIX: ln_post was applied a SECOND time to x[:, 0, :] even though
        # x is already post-normalized — double LayerNorm on the CLS token
        # (the CLIP reference applies ln_post exactly once). Take it directly.
        cls_token = x[:, 0, :]

        if self.clip_model.visual.proj is not None:
            cls_token = cls_token @ self.clip_model.visual.proj
        return x, cls_token

    def txt_forward(self, text):
        """CLIP text forward returning all token states and the projected EOT embedding."""
        dtype = self.clip_model.dtype
        x = self.clip_model.token_embedding(text).type(dtype)  # [batch_size, n_ctx, d_model]

        x = x + self.clip_model.positional_embedding.type(dtype)
        x = x.permute(1, 0, 2)  # NLD -> LND
        x = self.clip_model.transformer(x)
        x = x.permute(1, 0, 2)  # LND -> NLD
        x = self.clip_model.ln_final(x).type(dtype)

        # take features from the eot embedding (eot_token is the highest number in each sequence)
        eot = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.clip_model.text_projection
        return x, eot  # [NLD]

    def var_img_forward(self, image):
        """Image forward that averages over a 2-crop batch [N, 2, 3, H, W]; L2-normalizes the token."""
        if len(image.shape) == 5:
            img_features1, img_token1 = self.img_forward(image[:, 0, ...])
            img_features2, img_token2 = self.img_forward(image[:, 1, ...])
            img_token = (img_token1 + img_token2) / 2
            img_features = (img_features1 + img_features2) / 2
        else:
            img_features, img_token = self.img_forward(image)
        img_token = img_token / img_token.norm(dim=-1, keepdim=True)
        return img_features, img_token

    def var_txt_forward(self, text):
        """Text forward with L2-normalized EOT token."""
        txt_features, txt_token = self.txt_forward(text)
        txt_token = txt_token / txt_token.norm(dim=-1, keepdim=True)
        return txt_features, txt_token

    @staticmethod
    def _contrastive_loss(img_tokens, txt_tokens, logit_scale):
        """Symmetric CLIP contrastive loss over one batch of paired tokens.

        Both directions (image->text, text->image) use the diagonal as the
        positive pair; the two cross-entropies are averaged.
        """
        batch_size = img_tokens.shape[0]
        ground_truth = torch.arange(batch_size, dtype=torch.long, device=img_tokens.device)

        logits_for_imgs = logit_scale * img_tokens @ txt_tokens.t()
        logits_for_txts = logits_for_imgs.t()

        # CLIP Contrastive Learning Loss Function
        ce = torch.nn.CrossEntropyLoss()
        return (ce(logits_for_imgs, ground_truth) + ce(logits_for_txts, ground_truth)) / 2

    def forward(self, image, text, past_img_tokens=None, past_txt_tokens=None):
        """Encode a batch and compute the contrastive loss.

        When ``past_img_tokens``/``past_txt_tokens`` are given, the current
        tokens are appended and the loss is computed over the enlarged pool
        (the updated pools are echoed back in the output dict).
        """
        # TODO: aggregate past img and txt tokens
        img_features, img_token = self.var_img_forward(image)
        txt_features, txt_token = self.var_txt_forward(text)
        logit_scale = self.clip_model.logit_scale.exp()

        if past_img_tokens is not None:
            past_img_tokens = torch.cat([past_img_tokens, img_token], dim=0)
            past_txt_tokens = torch.cat([past_txt_tokens, txt_token], dim=0)
            # Loss over the full accumulated pool (dedup: shared helper).
            loss = self._contrastive_loss(past_img_tokens, past_txt_tokens, logit_scale)
        else:
            loss = self._contrastive_loss(img_token, txt_token, logit_scale)

        return dict(
            img_token=img_token,
            txt_token=txt_token,
            img_features=img_features,
            txt_features=txt_features,
            loss=loss,
            past_img_tokens=past_img_tokens,
            past_txt_tokens=past_txt_tokens,
        )
|
results/config.yaml
ADDED
|
@@ -0,0 +1,42 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
exp_name: expl5-3a-bs16-imgtxtdec-ema-aux1
|
| 2 |
+
wandb: ${oc.env:WANDB_API_KEY}
|
| 3 |
+
root: ${oc.env:DATA_ROOT}
|
| 4 |
+
overfit: false
|
| 5 |
+
batch_size: 16
|
| 6 |
+
num_workers: 4
|
| 7 |
+
img_size: 336
|
| 8 |
+
rationale_type: 0
|
| 9 |
+
val_rationale_type: 0
|
| 10 |
+
hide_true_bbox: 8
|
| 11 |
+
widescreen_processing: 1
|
| 12 |
+
h_flip: false
|
| 13 |
+
ema_decay: 0.9999
|
| 14 |
+
aux_weight: 1.0
|
| 15 |
+
no_hard_negative_itm: false
|
| 16 |
+
clip_model: ViT-L/14@336px
|
| 17 |
+
has_extra_txt_decoder: true
|
| 18 |
+
has_extra_img_decoder: true
|
| 19 |
+
has_extra_mix_decoder: false
|
| 20 |
+
has_extra_gen_decoder: false
|
| 21 |
+
extra_decoder:
|
| 22 |
+
is_decoder: true
|
| 23 |
+
vocab_size: 1000
|
| 24 |
+
d_ff: 768
|
| 25 |
+
d_kv: 64
|
| 26 |
+
d_model: 768
|
| 27 |
+
dropout_rate: 0.1
|
| 28 |
+
num_heads: 12
|
| 29 |
+
num_layers: 2
|
| 30 |
+
relative_attention_max_distance: 128
|
| 31 |
+
relative_attention_num_buckets: 32
|
| 32 |
+
warmup: 1000
|
| 33 |
+
init_from: ''
|
| 34 |
+
lr: 1.0e-05
|
| 35 |
+
n_epochs: 15
|
| 36 |
+
save_every: 0
|
| 37 |
+
early_stop: 5
|
| 38 |
+
val_stat: loss
|
| 39 |
+
device: cuda
|
| 40 |
+
use_multi: false
|
| 41 |
+
local_rank: 0
|
| 42 |
+
run_file: /home/acd13872jh/workspace/traffic_var/results/.runinfo/.sh
|
results/model_epoch3.pth
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:a45011f16c569075fcfab80c3f0e6df273a825dbad2ebeca65a438a57a95977f
|
| 3 |
+
size 251753353
|
results/results_pair_dict.json
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:1d3fa8165cf47a2562a3aa534880df625caf3705a59b5f5b221ba40f37a524fb
|
| 3 |
+
size 303815155
|
results_pair_dict1.json
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:7a469b8b91ebec66d55cc3389731e3e9f7e11111f112f13a06ea2e668f573309
|
| 3 |
+
size 151976036
|
results_pair_dict2.json
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:59e06a5e308a71fd97b3f2f87439fde8fbac2c2834896a3ac5c54868ec9c480f
|
| 3 |
+
size 151839121
|