DHPR commited on
Commit
f638d9c
·
1 Parent(s): df3ffa8

Upload 25 files

Browse files
.gitattributes CHANGED
@@ -32,3 +32,8 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
32
  *.zip filter=lfs diff=lfs merge=lfs -text
33
  *.zst filter=lfs diff=lfs merge=lfs -text
34
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
 
 
 
32
  *.zip filter=lfs diff=lfs merge=lfs -text
33
  *.zst filter=lfs diff=lfs merge=lfs -text
34
  *tfevents* filter=lfs diff=lfs merge=lfs -text
35
+ data/random_sample_test_direct_ids.json filter=lfs diff=lfs merge=lfs -text
36
+ data/random_sample_test_indirect_ids.json filter=lfs diff=lfs merge=lfs -text
37
+ results_pair_dict1.json filter=lfs diff=lfs merge=lfs -text
38
+ results_pair_dict2.json filter=lfs diff=lfs merge=lfs -text
39
+ results/results_pair_dict.json filter=lfs diff=lfs merge=lfs -text
app.py ADDED
@@ -0,0 +1,235 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # %%
2
+ from PIL import Image
3
+ import torch
4
+ import torchvision
5
+ from torchvision import transforms
6
+ from torchvision.transforms.functional import InterpolationMode
7
+ import json
8
+ from collections import defaultdict
9
+ import numpy as np
10
+ import matplotlib.pyplot as plt
11
+ import random
12
+ from os import listdir
13
+ from os.path import isfile, join
14
+ from torchvision.io import read_image
15
+ from torchvision.utils import draw_bounding_boxes
16
+ from PIL import Image
17
+ import os
18
+ from scipy.stats import rankdata
19
+ import tqdm
20
+ import streamlit as st
21
+ import pandas as pd
22
+
23
+ # %%
24
def load_json(PATH):
    """Load a JSON file and return its contents.

    Parameters
    ----------
    PATH : str
        Path to a JSON file on disk.

    Returns
    -------
    dict
        Parsed JSON content, or an empty dict when the file is missing
        or not readable (a message is printed in that case).
    """
    if os.path.isfile(PATH) and os.access(PATH, os.R_OK):
        with open(PATH) as json_file:
            dict_data = json.load(json_file)
    else:
        # BUG FIX: previous message was ungrammatical ("The Path of X is not exist").
        print("The path", PATH, "does not exist or is not readable")
        dict_data = {}
    return dict_data
32
+
33
def get_list_folder(PATH):
    """Return the names of the immediate subdirectories of PATH."""
    entries = os.listdir(PATH)
    return [entry for entry in entries if os.path.isdir(os.path.join(PATH, entry))]
35
+
36
def get_file_only(PATH):
    """Return the names of the regular files directly inside PATH."""
    files = []
    for entry in os.listdir(PATH):
        if os.path.isfile(os.path.join(PATH, entry)):
            files.append(entry)
    return files
38
+
39
+ # %%
40
def compute_ndcg(ranks, scores, k=3):
    """Compute NDCG@k from predicted ranks and ground-truth relevance scores.

    Example
    -------
    ranks = [5, 1, 4, 2, 3]
    scores = [0.1, 0.5, 0.3, 0.95, 1.0]

    Parameters
    ----------
    ranks : sequence
        Predicted ranks (1-based, e.g. from ``scipy.stats.rankdata``),
        aligned with ``scores``.
    scores : sequence
        Relevance scores for the same items.
    k : int
        Truncation depth.

    Returns
    -------
    float
        NDCG in [0, 1]; 0.0 when the top-k relevance mass is zero
        (previously this case raised ZeroDivisionError).
    """
    rank_score_tuple = list(zip(ranks, scores))

    # Take the k items with the highest ground-truth relevance.
    top_k = sorted(rank_score_tuple, key=lambda x: x[1], reverse=True)[:k]

    # DCG discounts each relevance score by the model's predicted rank.
    dcg = sum(score / np.log2(rank + 1) for rank, score in top_k)

    # Ideal DCG places the same scores at positions 1..k.
    ideal_dcg = sum(score / np.log2(idx + 2) for idx, (_, score) in enumerate(top_k))

    # BUG FIX: when every relevance score is 0, ideal_dcg == 0 and the
    # division crashed; define NDCG as 0.0 in that degenerate case.
    if ideal_dcg == 0:
        return 0.0
    return dcg / ideal_dcg
55
+
56
def compute_ndcg_score_per_mode(pred_rank_dict, gpt_rel_scores, random_sample_dict, mode='indirect', split='test', k=200):
    """Average NDCG@k over all queries of one evaluation mode.

    Parameters
    ----------
    pred_rank_dict : dict
        Maps each query key to the predicted rank array of its candidates.
    gpt_rel_scores : dict
        Maps query key -> {candidate key: relevance score}; candidates
        without a judgement count as 0.0.
    random_sample_dict : dict
        Maps each query key to its ordered list of candidate keys.
    mode, split : str
        Only used for the progress/log message.
    k : int
        NDCG truncation depth.

    Returns
    -------
    float
        Mean NDCG@k (0.0 for empty input).

    BUG FIXES: default ``mode`` was misspelled 'indrect'; empty input
    previously raised ZeroDivisionError.
    """
    ndcg_scores = []

    for key in tqdm.tqdm(pred_rank_dict.keys(), total=len(pred_rank_dict.keys())):
        # Unjudged candidates get relevance 0.0 (dict.get instead of
        # the previous `x if k in d else 0.0` double lookup).
        gpt_scores_for_key = [gpt_rel_scores[key].get(cand_key, 0.0) for cand_key in random_sample_dict[key]]

        pred_rank_for_key = pred_rank_dict[key]

        ndcg_score = compute_ndcg(pred_rank_for_key, gpt_scores_for_key, k=k)
        ndcg_scores.append(ndcg_score)

    avg_ndcg_score = sum(ndcg_scores) / len(ndcg_scores) if ndcg_scores else 0.0
    print(f"Random split, mode={mode} ndcg score: ", avg_ndcg_score)
    return avg_ndcg_score
70
+ # %%
71
+
72
def get_score_direct(random_sample_pair_test_direct, predictions, key_pair, similarity_score_test_direct, k=200):
    """Compute direct-mode retrieval metrics (mean i2t/t2i rank and NDCG@k).

    Parameters
    ----------
    random_sample_pair_test_direct : dict
        Maps each query key to its list of 1000 candidate keys; the true
        match is the last candidate (see the ``[-1]`` rank reads below).
    predictions : dict
        Maps "img_key:txt_key" pair strings to predicted similarity scores.
    key_pair : dict
        Maps image keys to their paired text keys (assumed injective —
        it is inverted below).
    similarity_score_test_direct : dict
        ChatGPT relevance scores used as NDCG ground truth.
    k : int
        NDCG truncation depth.

    Returns
    -------
    dict
        {'direct': {'i2t rank', 't2i rank', 'ndcg score'}}.
    """
    mode = 'direct'
    i2t_ranks = []
    t2i_ranks = []
    i2t_rank_dict = {}
    results_dict = {}
    key_pair_reversed = {v: k for k, v in key_pair.items()}

    for file_key in tqdm.tqdm(random_sample_pair_test_direct.keys(), total=len(random_sample_pair_test_direct.keys())):
        # BUG FIX: the candidate loop variable was named `k`, shadowing the
        # `k` truncation parameter; renamed to `cand`.
        i2t_rank = rankdata([predictions[str(file_key) + ':' + str(key_pair[cand])] for cand in random_sample_pair_test_direct[file_key]])
        t2i_rank = rankdata([predictions[str(key_pair_reversed[key_pair[cand]]) + ':' + str(key_pair[file_key])] for cand in random_sample_pair_test_direct[file_key]])

        # The ground-truth match is the last candidate, so its rank is the metric.
        i2t_ranks.append(i2t_rank[-1])
        t2i_ranks.append(t2i_rank[-1])
        i2t_rank_dict[file_key] = i2t_rank

    assert len(i2t_ranks) == len(t2i_ranks) == 1000

    # BUG FIX: the mode label passed here was 'indirect' (wrong log label for
    # the direct path) and the depth was hard-coded 200 instead of using `k`.
    ndcg_score = compute_ndcg_score_per_mode(i2t_rank_dict, similarity_score_test_direct, random_sample_pair_test_direct, mode='direct', split='test', k=k)

    results_dict['direct'] = {
        'i2t rank': float(sum(i2t_ranks) / len(i2t_ranks)),
        't2i rank': float(sum(t2i_ranks) / len(t2i_ranks)),
        'ndcg score': float(ndcg_score),
    }
    print(f"Random split, mode={mode} i2t rank: ", sum(i2t_ranks) / len(i2t_ranks))
    print(f"Random split, mode={mode} t2i rank: ", sum(t2i_ranks) / len(t2i_ranks))
    return results_dict
100
+
101
+ # %%
102
+
103
def get_score_indirect(random_sample_pair_test_indirect, predictions, key_pair, similarity_score_test_indirect, k=200):
    """Compute indirect-mode retrieval metrics (mean i2t/t2i rank and NDCG@k).

    Parameters
    ----------
    random_sample_pair_test_indirect : dict
        Maps each query key to its list of 1000 candidate keys; the true
        match is the last candidate (see the ``[-1]`` rank reads below).
    predictions : dict
        Maps "img_key:txt_key" pair strings to predicted similarity scores.
    key_pair : dict
        Maps image keys to their paired text keys (assumed injective —
        it is inverted below).
    similarity_score_test_indirect : dict
        ChatGPT relevance scores used as NDCG ground truth.
    k : int
        NDCG truncation depth.

    Returns
    -------
    dict
        {'indirect': {'i2t rank', 't2i rank', 'ndcg score'}}.
    """
    mode = 'indirect'
    i2t_ranks = []
    t2i_ranks = []
    i2t_rank_dict = {}
    results_dict = {}
    key_pair_reversed = {v: k for k, v in key_pair.items()}

    for file_key in tqdm.tqdm(random_sample_pair_test_indirect.keys(), total=len(random_sample_pair_test_indirect.keys())):
        # BUG FIX: the candidate loop variable was named `k`, shadowing the
        # `k` truncation parameter; renamed to `cand`.
        i2t_rank = rankdata([predictions[str(file_key) + ':' + str(key_pair[cand])] for cand in random_sample_pair_test_indirect[file_key]])
        t2i_rank = rankdata([predictions[str(key_pair_reversed[key_pair[cand]]) + ':' + str(key_pair[file_key])] for cand in random_sample_pair_test_indirect[file_key]])

        # The ground-truth match is the last candidate, so its rank is the metric.
        i2t_ranks.append(i2t_rank[-1])
        t2i_ranks.append(t2i_rank[-1])
        i2t_rank_dict[file_key] = i2t_rank

    assert len(i2t_ranks) == len(t2i_ranks) == 1000

    # BUG FIX: mode was misspelled 'indrect' and the depth was hard-coded 200
    # instead of using `k`.
    ndcg_score = compute_ndcg_score_per_mode(i2t_rank_dict, similarity_score_test_indirect, random_sample_pair_test_indirect, mode='indirect', split='test', k=k)

    results_dict['indirect'] = {
        'i2t rank': float(sum(i2t_ranks) / len(i2t_ranks)),
        't2i rank': float(sum(t2i_ranks) / len(t2i_ranks)),
        'ndcg score': float(ndcg_score),
    }
    print(f"Random split, mode={mode} i2t rank: ", sum(i2t_ranks) / len(i2t_ranks))
    print(f"Random split, mode={mode} t2i rank: ", sum(t2i_ranks) / len(t2i_ranks))
    return results_dict
131
+
132
+ # %%
133
def main(json_file):
    """Run the full evaluation for one predictions dict.

    Parameters
    ----------
    json_file : dict
        Maps "img_key:txt_key" pair strings to predicted similarity scores
        (already merged from the uploaded/demo result files).

    Returns
    -------
    dict
        Merged direct + indirect metric dicts.
    """
    predictions = json_file
    root = os.environ['ROOT']

    def _read(rel_path):
        # Load one JSON data file relative to the app root.
        with open(os.path.join(root, rel_path)) as fh:
            return json.load(fh)

    key_pair = _read('data/key_pair.json')
    random_sample_pair_test_direct = _read('data/random_sample_test_direct_ids.json')
    random_sample_pair_test_indirect = _read('data/random_sample_test_indirect_ids.json')
    similarity_score_test_direct = _read('data/chatgpt_similarity_score_test_direct.json')
    similarity_score_test_indirect = _read('data/chatgpt_similarity_score_test_indirect.json')

    ### Compute scores for both evaluation modes and merge the results.
    result_direct = get_score_direct(random_sample_pair_test_direct, predictions, key_pair, similarity_score_test_direct, k=200)
    result_indirect = get_score_indirect(random_sample_pair_test_indirect, predictions, key_pair, similarity_score_test_indirect, k=200)
    return {**result_direct, **result_indirect}
173
+ # %%
174
if __name__ == '__main__':
    os.environ['ROOT'] = os.path.dirname(os.path.realpath(__file__))

    st.title("Evaluation Server for Driving Hazard Prediction and Reasoning ")
    st.image(os.path.join(os.environ['ROOT'], 'data/preview_image.jpeg'))
    st.divider()

    result_text = ''
    result_dict = {}
    json_file = None

    # NOTE(review): the label mentions .json result files but the uploader
    # only accepts .csv (and the files are parsed with read_csv below) —
    # confirm the intended upload format.
    uploaded_files = st.file_uploader(
        "Upload All Result Files Here (results_pair_dict1.json, results_pair_dict2.json)",
        type=["csv"], accept_multiple_files=True)

    if st.button('Run Evaluation with no upload files (using demo files)'):
        json_file1 = load_json(os.path.join(os.environ['ROOT'], 'results_pair_dict1.json'))
        json_file2 = load_json(os.path.join(os.environ['ROOT'], 'results_pair_dict2.json'))
        json_file = {**json_file1, **json_file2}

    # BUG FIX: st.file_uploader(..., accept_multiple_files=True) returns an
    # empty list (never None) when nothing is uploaded, so the previous
    # `if uploaded_files is not None:` branch always ran and clobbered the
    # demo-file dict with an empty one. Only parse when files are present.
    if uploaded_files:
        dataframe = pd.concat([pd.read_csv(f) for f in uploaded_files])
        # 'tight' orient exposes the raw rows; column 1 is the pair key,
        # column 2 the predicted score.
        rows = dataframe.to_dict('tight')['data']
        json_file = {str(row[1]): float(row[2]) for row in rows}

    if json_file is not None and len(json_file) >= 1:
        result_dict = main(json_file)
        result_text = json.dumps(result_dict)

    st.download_button('Download Results', result_text)
    st.json(result_dict)

    # !streamlit run app.py --server.fileWatcherType none
235
+ # %%
config/__init__.py ADDED
File without changes
config/examples.py ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import os
import sys
import json

from omegaconf import OmegaConf
from hydra.core.global_hydra import GlobalHydra
from hydra import initialize, initialize_config_module, initialize_config_dir, compose

# NOTE(review): machine-specific absolute paths — these only resolve on the
# original author's workstation; consider reading them from the environment.
os.environ['ROOT'] = "/home/quang/workspace/traffic_var"
os.environ['DATA_ROOT'] = "/home/quang/datasets/traffic_var"

# initialize hydra config
# Clear any previously-initialized Hydra instance so re-running this module
# (e.g. inside a notebook) does not raise "GlobalHydra is already initialized".
GlobalHydra.instance().clear()
initialize(config_path="./")
# Compose example configs with the decoder variant and the original experiment
# layout, both pinned to the ViT-B/16 CLIP backbone and rationale type 0.
with_decoder_config = compose(config_name='with_decoder.yaml', overrides=["clip_model=ViT-B/16", "rationale_type=0", "val_rationale_type=0"])
original_config = compose(config_name='experiment.yaml', overrides=["clip_model=ViT-B/16", "rationale_type=0", "val_rationale_type=0"])
config/experiment.yaml ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # config.yaml
2
+
3
+ exp_name: exp1-2
4
+ wandb: 90788c79e1500570b08e5acf283e17df7e0c54b2  # SECURITY(review): W&B API key committed to the repo — revoke it and load it from an environment variable instead
5
+ root: "${oc.env:DATA_ROOT}"
6
+ overfit: False
7
+ batch_size: 4
8
+ num_workers: 4
9
+ img_size: 224
10
+ rationale_type: 0 # 0 - only rationale, 1 - randomly add entity desc, 2 - add entity desc
11
+ val_rationale_type: 0
12
+ hide_true_bbox: 8 # clues and inferences selected randomly
13
+ widescreen_processing: 1 # 0 - no widescreen, 1 - widescreen
14
+ h_flip: False
15
+ ema_decay: 0.9999
16
+
17
+ clip_model: 'ViT-B/16' # 'RN101' 'RN50x4''RN50x16' 'RN50x64' 'ViT-L/14@336px' 'ViT-B/32'
18
+ num_layers: 3
19
+ dim_hidden: 512
20
+
21
+
22
+ warmup: 1000
23
+ init_from: ''
24
+ lr: .00001
25
+ n_epochs: 15
26
+ save_every: 0
27
+ early_stop: 5
28
+ val_stat: 'loss'
29
+ device: 'cuda'
30
+ use_multi: False
31
+ local_rank: 0
32
+
33
+ hydra:
34
+ run:
35
+ dir: ./results/${exp_name}
36
+ output_subdir: ./ # directory for saving the yaml configs
37
+ job:
38
+ config:
39
+ override_dirname:
40
+ exclude_keys:
41
+ - exp.name
config/with_decoder.yaml ADDED
@@ -0,0 +1,60 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # config.yaml
2
+
3
+ exp_name: exp1-2
4
+ wandb: 90788c79e1500570b08e5acf283e17df7e0c54b2  # SECURITY(review): W&B API key committed to the repo — revoke it and load it from an environment variable instead
5
+ root: "${oc.env:DATA_ROOT}"
6
+ overfit: False
7
+ batch_size: 4
8
+ num_workers: 4
9
+ img_size: 224
10
+ rationale_type: 0 # 0 - only rationale, 1 - randomly add entity desc, 2 - add entity desc
11
+ val_rationale_type: 0
12
+ hide_true_bbox: 8 # clues and inferences selected randomly
13
+ widescreen_processing: 1 # 0 - no widescreen, 1 - widescreen
14
+ h_flip: False
15
+ ema_decay: 0.9999
16
+ aux_weight: 0.2
17
+ no_hard_negative_itm: False
18
+
19
+ clip_model: 'ViT-B/16' # 'RN101' 'RN50x4''RN50x16' 'RN50x64' 'ViT-L/14@336px' 'ViT-B/32'
20
+ has_extra_txt_decoder: False
21
+ has_extra_img_decoder: False
22
+ has_extra_mix_decoder: False
23
+ has_extra_gen_decoder: False
24
+
25
+ extra_decoder:
26
+ is_decoder: True
27
+ vocab_size: 1000
28
+ d_ff: 512
29
+ d_kv: 64
30
+ d_model: 512
31
+ dropout_rate: 0.1
32
+ num_heads: 8
33
+ num_layers: 2
34
+ # eos_token_id: 1
35
+ # pad_token_id: 0
36
+ # decoder_start_token_id: 0
37
+ # n_positions: 512
38
+ relative_attention_max_distance: 128
39
+ relative_attention_num_buckets: 32
40
+
41
+ warmup: 1000
42
+ init_from: ''
43
+ lr: .00001
44
+ n_epochs: 15
45
+ save_every: 0
46
+ early_stop: 5
47
+ val_stat: 'loss'
48
+ device: 'cuda'
49
+ use_multi: False
50
+ local_rank: 0
51
+
52
+ hydra:
53
+ run:
54
+ dir: ./results/${exp_name}
55
+ output_subdir: ./ # directory for saving the yaml configs
56
+ job:
57
+ config:
58
+ override_dirname:
59
+ exclude_keys:
60
+ - exp.name
create_result_script.py ADDED
@@ -0,0 +1,531 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # %%
2
+
3
+ import json
4
+ import sys
5
+ import pickle
6
+ sys.path.append("../")
7
+ import collections
8
+ from models.fused_model import Model
9
+ import os
10
+ import tqdm
11
+ import time
12
+ import json
13
+ import random
14
+ from PIL import ImageFile
15
+ from PIL import Image, ImageDraw
16
+ import clip
17
+ import torch
18
+ import numpy as np
19
+ import torchvision.transforms as T
20
+ import torchvision.transforms.functional as F
21
+ from pathlib import Path
22
+ import pandas as pd
23
+
24
+ ImageFile.LOAD_TRUNCATED_IMAGES = True
25
+
26
+ # %%
27
+ from types import SimpleNamespace
28
+ # get config
29
+ import os
30
+ from omegaconf import OmegaConf
31
+ from hydra.core.global_hydra import GlobalHydra
32
+ from hydra import initialize, initialize_config_module, initialize_config_dir, compose
33
+
34
+ os.environ['ROOT'] = os.path.dirname(os.path.realpath(__file__))
35
+ os.environ['DATA_ROOT'] = os.path.join(os.environ['ROOT'], 'data')
36
+
37
+ # initialize hydra config
38
+ GlobalHydra.instance().clear()
39
+ initialize(config_path="./config")
40
+
41
+ config = compose(config_name='with_decoder.yaml',
42
+ overrides=["clip_model=ViT-L/14@336px",
43
+ "rationale_type=0", "val_rationale_type=0"])
44
+
45
class SquarePad:
    """Pad a PIL image symmetrically with zeros so it becomes square
    (side length = max(width, height))."""

    def __call__(self, image):
        w, h = image.size
        side = max(w, h)
        # Split the deficit evenly; any odd pixel goes to the right/bottom.
        left = (side - w) // 2
        top = (side - h) // 2
        right = side - w - left
        bottom = side - h - top
        return F.pad(image, (left, top, right, bottom), 0, 'constant')
53
+
54
class VarDatasetForAuxEncoders:
    """Dataset pairing traffic-scene images with hazard-rationale text for
    training auxiliary (image/text) encoders.

    Each item bundles a bbox-highlighted image tensor, the tokenized
    rationale text, and an image-text-matching (ITM) text/label pair
    (positive or a hard negative with the same entity count).
    """

    def __init__(self, config, file_path, split="train", mode="combined", do_swap=False, tensorize=True, do_crop=True):
        # NOTE(review): `do_crop` is accepted but never used in this class.
        self.config = config
        self.mode = mode
        self.split = split
        self.do_swap = do_swap
        # Training may use a different rationale augmentation than validation.
        self.rationale_type = config.rationale_type if split == "train" else config.val_rationale_type
        self.root_path = Path(config.root)
        self.anno_path = file_path #self.root_path / f'annotations/13_05/anno_{split}_{mode}.json'
        if split == "test" and mode == "combined" and config.overfit:
            self.anno_path = self.root_path / f'annotations/13_05/anno_{split}_{mode}_overfit.json'

        self.data = json.load(open(self.anno_path))
        self.idx2name = list(self.data.keys())

        # Bucket sample keys by bounding-box count; used to draw ITM hard
        # negatives that share the same number of entities.
        if 'bounding_box' in self.data[list(self.data.keys())[0]]['details'][-1]:
            self.one_ent_keys = [k for k, v in self.data.items() if len(v['details'][-1]["bounding_box"]) == 1]
            self.two_ent_keys = [k for k, v in self.data.items() if len(v['details'][-1]["bounding_box"]) == 2]
            self.three_ent_keys = [k for k, v in self.data.items() if len(v['details'][-1]["bounding_box"]) == 3]
            self.all_ent_keys = self.one_ent_keys + self.two_ent_keys + self.three_ent_keys

            self.keys = {1: self.one_ent_keys, 2: self.two_ent_keys, 3: self.three_ent_keys}

        # widescreen_processing 0/1 -> resize+crop; otherwise square-pad first.
        if self.config.widescreen_processing in [0, 1]:
            self.resize_crop = self.get_transform(config.img_size, split == "train", padding=False)
        else:
            self.resize_crop = self.get_transform(config.img_size, split == "train", padding=True)

        self.tensorize = tensorize
        # Color jitter is a train-only augmentation; identity elsewhere.
        self.jitter_transform = T.ColorJitter(brightness=.5, hue=.3, saturation=.3) if split == "train" else lambda x: x

        # RGB conversion, then (optionally) tensor conversion + CLIP-style
        # channel normalization.
        self.final_transform = T.Compose([
            lambda image: image.convert("RGB"),
            T.ToTensor() if tensorize else lambda x: x,
            T.Normalize(
                (0.48145466, 0.4578275, 0.40821073),
                (0.26862954, 0.26130258, 0.27577711),
            ) if tensorize else lambda x: x
        ])

    def get_transform(self, n_px, training, padding=False):
        """Build the resize/crop pipeline: random crop for training,
        center crop otherwise; optional square padding first."""
        resize = T.Resize((n_px + 16, n_px + 16), interpolation=Image.BICUBIC)

        # for traning split
        if training and not padding: # train
            return T.Compose([resize, T.RandomCrop(n_px)])

        if training and padding: # train_pad
            return T.Compose([SquarePad(), resize, T.RandomCrop(n_px)])

        # for test and val split
        if not training and not padding: # test
            return T.Compose([resize, T.CenterCrop(n_px)])

        if not training and padding: # test_pad
            return T.Compose([SquarePad(), resize, T.CenterCrop(n_px)])

    def key2img_path(self, key):
        """Resolve a sample key to its image path, probing the known image
        directory layouts in order; returns None when nothing matches."""
        file_paths = [
            self.root_path / f"var_images/{key}.jpg",
            self.root_path / f"var_images/{key}.png",
            self.root_path / f"images/{key}.jpg",
            self.root_path / f"img/train/{key.split('_')[0]}/{key}.png",
            self.root_path / f"img/val/{key.split('_')[0]}/{key}.png",
            self.root_path / f"img/test/{key.split('_')[0]}/{key}.png",
            self.root_path / f"img/{key}.png",
            self.root_path / f"img/{key}.jpg",
            self.root_path / f"images/{key}.png",
            self.root_path / f"images/{key}.jpg",
        ]
        for file_path in file_paths:
            if file_path.exists():
                return file_path

    def key2img(self, key):
        """Open the image for a sample key as a PIL Image."""
        file_path = self.key2img_path(key)
        return Image.open(file_path)

    def hide_region(self, image, bboxes):
        """Draw the entity bounding boxes onto the image according to
        config.hide_true_bbox: 1 hides them, 2/5/7/8/9 highlight them with
        translucent colors, 3 blacks out everything but the boxes, and 6
        shows box positions only."""
        image = image.convert('RGBA')

        if self.config.hide_true_bbox == 1: # hide mode
            draw = ImageDraw.Draw(image, 'RGBA')

        if self.config.hide_true_bbox in [2, 5, 7, 8, 9]: #highlight mode
            overlay = Image.new('RGBA', image.size, '#00000000')
            draw = ImageDraw.Draw(overlay, 'RGBA')

        if self.config.hide_true_bbox == 3 or self.config.hide_true_bbox == 6: #blackout mode or position only mode
            overlay = Image.new('RGBA', image.size, '#7B7575ff')
            draw = ImageDraw.Draw(overlay, 'RGBA')

        color_fill_list = ['#ff05cd3c', '#00F1E83c', '#F2D4003c'] # Green, Blue, Yellow?

        for idx, bbox in enumerate(bboxes):
            # Entities without a box (padded to length 3) are skipped.
            if bbox == None:
                continue

            color_fill = color_fill_list[idx]
            x, y = bbox['left'], bbox['top']

            if self.config.hide_true_bbox == 1: # hide mode
                draw.rectangle([(x, y), (x + bbox['width'], y + bbox['height'])], fill='#7B7575')
            elif self.config.hide_true_bbox in [2, 5, 7, 8, 9]: # highlight mode
                draw.rectangle([(x, y), (x + bbox['width'], y + bbox['height'])], fill=color_fill, outline='#05ff37ff',
                width=3) # Fill with Pink 60% ##00F1E8
            elif self.config.hide_true_bbox == 3: # blackout mode
                draw.rectangle([(x, y), (x + bbox['width'], y + bbox['height'])], fill='#00000000')
            elif self.config.hide_true_bbox == 6: # position only mode
                draw.rectangle([(x, y), (x + bbox['width'], y + bbox['height'])], fill=color_fill)

        if self.config.hide_true_bbox in [2, 3, 5, 6, 7, 8, 9]:
            image = Image.alpha_composite(image, overlay)

        return image

    def get_entity_codes(self):
        """Return the entity permutation [0, 1, 2], shuffled when do_swap
        augmentation is enabled."""
        entity_codes = [0, 1, 2]
        if self.do_swap:
            random.shuffle(entity_codes)
        return entity_codes

    def swap_entities(self, bboxes, text, entity_codes):
        """Apply one entity permutation consistently to the text mentions
        and the bbox list."""
        # text
        for entity_idx, entity_code in enumerate(entity_codes):
            text = text.replace(f"Entity #{entity_idx + 1}", f"Entity #{entity_code + 1}")

        # bboxes: [1, 0, 2] -> [b[1], b[0], b[2]]
        new_boxes = [bboxes[entity_code] for entity_code in entity_codes]
        return new_boxes, text

    def get_text_from_meta(self, meta):
        """Build the rationale text, optionally inlining each entity's
        description after its first mention (rationale_type 1/2)."""
        n_boxes = len(meta['bounding_box']) # key ['1', '2', '3']

        # for rationale
        text = 'Rationale: ' + str(meta['rationale'])

        if self.rationale_type == 1 or self.rationale_type == 2:
            for box_idx in range(n_boxes):
                ent_name = f'Entity #{box_idx + 1}'
                ent_desc = f'{ent_name}, {meta[ent_name]}'
                # todo: replace randomly
                text = text.replace(ent_name, ent_desc, 1)
        return text

    def get_itm_text(self, ori_file_key):
        """Return (text, label) for image-text matching: with p=0.5 the true
        text (label 1), otherwise a negative sampled from samples with the
        same entity count (or uniformly, when no_hard_negative_itm is set)."""
        file_key = ori_file_key

        if random.random() < 0.5:
            n_boxes = len(self.data[file_key]['details'][-1]['bounding_box'])

            file_key = random.choice(self.keys[n_boxes])

            if self.config.get('no_hard_negative_itm', False):
                file_key = random.choice(self.all_ent_keys)

        # The random draw may still pick the original key, so the label is
        # derived from the outcome rather than the branch taken.
        itm_label = 1 if file_key == ori_file_key else 0
        meta = self.data[file_key]['details'][-1]

        itm_text = self.get_text_from_meta(meta)
        return itm_text, itm_label

    def get_bboxes_and_text(self, file_key, meta):
        """Assemble the (possibly entity-swapped) bboxes, rationale text and
        ITM pair for one sample."""
        text = self.get_text_from_meta(meta)
        bboxes = [meta['bounding_box'].get(str(box_idx + 1), None) for box_idx in range(3)]

        entity_codes = self.get_entity_codes()
        bboxes, text = self.swap_entities(bboxes, text, entity_codes)

        # Apply the same permutation to the ITM text so entity numbering stays
        # consistent with the image; its bboxes are irrelevant here.
        itm_text, itm_label = self.get_itm_text(file_key)
        _, itm_text = self.swap_entities([None, None, None], itm_text, entity_codes)
        return {'bboxes': bboxes, 'text': text, 'itm_text': itm_text, 'itm_label': itm_label}

    def get_image(self, file_key, bboxes):
        """Load, jitter, bbox-annotate and normalize one image."""
        image = self.key2img(file_key)
        image = self.jitter_transform(image)
        image = self.hide_region(image, bboxes)
        image = self.final_transform(self.resize_crop(image))
        return image

    def __getitem__(self, idx):
        file_key = self.idx2name[idx]

        # Select the last version of label of the sample
        meta = self.data[file_key]['details'][-1]

        # read bboxes and rationale
        outputs = self.get_bboxes_and_text(file_key, meta)
        text = clip.tokenize(outputs['text'], truncate=True).squeeze()
        itm_text = clip.tokenize(outputs['itm_text'], truncate=True).squeeze()
        itm_label = torch.tensor(outputs['itm_label'])

        image = self.get_image(file_key, outputs['bboxes'])

        return {'image': image, 'caption': text, 'raw_text': text, 'file_key': file_key, 'itm_text': itm_text, 'itm_label': itm_label}

    def __len__(self):
        # Overfit mode shrinks every split except test/combined to 16 samples.
        if self.config.overfit and not (self.split == 'test' and self.mode == 'combined'):
            return 16
        return len(self.data)
254
+
255
+ # %%
256
class VarDatasetImageOnly(VarDatasetForAuxEncoders):
    """Dataset variant that yields only the processed image and its key
    (no text/ITM fields)."""

    def __init__(self, args, file_path, split="val", mode="combined", do_swap=False):
        super().__init__(args, file_path, split=split, mode=mode, do_swap=do_swap)

    def __getitem__(self, idx):
        key = self.idx2name[idx]
        details = self.data[key]['details'][-1]
        # Pad the per-entity boxes to a fixed length of 3 (missing -> None).
        raw_boxes = [details['bounding_box'].get(str(slot + 1), None) for slot in range(3)]
        # Apply the (possibly shuffled) entity permutation to the boxes.
        order = self.get_entity_codes()
        permuted_boxes = [raw_boxes[code] for code in order]
        return {'image': self.get_image(key, permuted_boxes), 'file_key': key}
268
+
269
+ # %%
270
class VarDatasetTextOnly(VarDatasetForAuxEncoders):
    """Dataset variant that yields only the tokenized hazard text and key."""

    def __init__(self, args, file_path, split="val", mode="combined", do_swap=False):
        super().__init__(args, file_path, split=split, mode=mode, do_swap=do_swap)

    def __getitem__(self, idx):
        key = self.idx2name[idx]
        details = self.data[key]['details'][-1]
        hazard = details['hazard']

        # Infer the entity count from the highest entity mention in the text.
        if 'Entity #3' in hazard:
            n_boxes = 3
        elif 'Entity #2' in hazard:
            n_boxes = 2
        else:
            n_boxes = 1

        text = 'Rationale: ' + str(hazard)

        # Optionally inline each entity's description after its first mention.
        if self.rationale_type in (1, 2):
            for slot in range(n_boxes):
                ent_name = f'Entity #{slot + 1}'
                text = text.replace(ent_name, f'{ent_name}, {details[ent_name]}', 1)

        # Renumber entity mentions according to the (possibly shuffled) order.
        for position, code in enumerate(self.get_entity_codes()):
            text = text.replace(f"Entity #{position + 1}", f"Entity #{code + 1}")

        tokens = clip.tokenize(text, truncate=True).squeeze()
        return {'caption': tokens, 'file_key': key}
300
+
301
+ # %%
302
+ import os
303
+ import sys
304
+
305
+ sys.path.append('..')
306
+ import json
307
+ import fire
308
+ import tqdm
309
+
310
+ import clip
311
+ import torch
312
+ import sklearn
313
+ import numpy as np
314
+
315
+ from omegaconf import OmegaConf
316
+ from models.fused_model import Model
317
+ from torch.utils.data import DataLoader
318
+ # from datasets import VarDatasetForAuxEncoders
319
+
320
+ from scipy.stats import rankdata
321
+ from sklearn.metrics import ndcg_score
322
+ from sklearn.metrics import pairwise_distances
323
+
324
+
325
+ # def get_data_loader(config, split="test", mode="combined", do_swap=False):
326
+ # dataset = VarDatasetForAuxEncoders(config, split=split, mode=mode, do_swap=do_swap)
327
+ # return DataLoader(dataset, batch_size=4, shuffle=False)
328
+
329
def get_image_data_loader(config, file_path, split="test", mode="combined", do_swap=False):
    """Build a DataLoader over the image-only dataset (batch 4, no shuffle)."""
    ds = VarDatasetImageOnly(config, file_path, split=split, mode=mode, do_swap=do_swap)
    loader = DataLoader(ds, batch_size=4, shuffle=False)
    return loader
332
+
333
def get_text_data_loader(config, file_path, split="test", mode="combined", do_swap=False):
    """Build a DataLoader over the text-only dataset (batch 4, no shuffle)."""
    ds = VarDatasetTextOnly(config, file_path, split=split, mode=mode, do_swap=do_swap)
    loader = DataLoader(ds, batch_size=4, shuffle=False)
    return loader
336
+
337
+ # def get_data_loader(config, split="test", mode="combined", do_swap=False):
338
+ # dataset = VarDatasetForAuxEncoders(config, split=split, mode=mode, do_swap=do_swap)
339
+ # return DataLoader(dataset, batch_size=4, shuffle=False)
340
+
341
def compute_rand_rank(split='test', mode='spec', img_token_dict={}, txt_token_dict={}): # the dicts contain all 2000 test samples
    """Mean image-to-text / text-to-image retrieval rank on the random split.

    NOTE(review): the dict defaults are mutable; they are only read here,
    but callers should always pass both dicts explicitly.

    Returns the per-query i2t rank dict, used downstream for NDCG.
    """
    data = json.load(open( os.path.join(os.environ['ROOT'], f"data/annotations/13_05/anno_random_{split}_{mode}_ids.json")))

    i2t_ranks = []
    t2i_ranks = []
    i2t_rank_dict = {}
    t2i_rank_dict = {}

    for file_key in data.keys():
        # Query embeddings, with a batch dim for pairwise_distances.
        img_emb = (img_token_dict[file_key]).unsqueeze(0)
        txt_emb = (txt_token_dict[file_key]).unsqueeze(0)

        # 1000-candidate gallery; the query's true match is the last entry.
        txt_embs = torch.stack([txt_token_dict[k] for k in data[file_key]])
        img_embs = torch.stack([img_token_dict[k] for k in data[file_key]])
        assert txt_embs.shape[0] == img_embs.shape[0] == 1000

        # Rank candidates by cosine distance (smaller distance -> better rank).
        i2t_rank = rankdata(pairwise_distances(img_emb, txt_embs, metric='cosine', n_jobs=8), axis=1)[0]
        t2i_rank = rankdata(pairwise_distances(txt_emb, img_embs, metric='cosine', n_jobs=8), axis=1)[0]

        # The true match is the last candidate, so its rank is the metric.
        i2t_ranks.append(i2t_rank[-1])
        t2i_ranks.append(t2i_rank[-1])

        i2t_rank_dict[file_key] = i2t_rank
        t2i_rank_dict[file_key] = t2i_rank

    assert len(i2t_ranks) == len(t2i_ranks) == 1000
    print(f"Random split, mode={mode} i2t rank: ", sum(i2t_ranks) / len(i2t_ranks))
    print(f"Random split, mode={mode} t2i rank: ", sum(t2i_ranks) / len(t2i_ranks))
    # for k in i2t_rank_dict.keys():
    #     print(k, i2t_rank_dict[k])
    #     print('------------------')
    #     break
    return i2t_rank_dict # for computing the NDCG scores
374
+
375
+
376
def read_relevance_scores(anno_path="anno_random_test_obvi_ids.json", gpt_path="chatgpt_similarity_score_test_direct_combined.json"):
    """Load ChatGPT relevance scores and fill in defaults.

    Missing (anchor, candidate) pairs get relevance 0.0, and each anchor's
    self-match is forced to 1.0.
    """
    gpt_scores = json.load(open(gpt_path))
    data = json.load(open(anno_path))

    # add_missing_relevance_scores
    for anchor in tqdm.tqdm(data, total=len(data)):
        for cand in data[anchor]:
            gpt_scores[anchor].setdefault(cand, 0.0)
            if cand == anchor:
                gpt_scores[anchor][cand] = 1.0

    return gpt_scores
390
+ # %%
391
+
392
def compute_ndcg(ranks, scores, k=3):
    """Compute NDCG@k from predicted ranks and ground-truth relevance scores.

    Example
    -------
    ranks = [5, 1, 4, 2, 3]
    scores = [0.1, 0.5, 0.3, 0.95, 1.0]

    Parameters
    ----------
    ranks : sequence
        Predicted ranks (1-based, e.g. from ``scipy.stats.rankdata``),
        aligned with ``scores``.
    scores : sequence
        Relevance scores for the same items.
    k : int
        Truncation depth.

    Returns
    -------
    float
        NDCG in [0, 1]; 0.0 when the top-k relevance mass is zero
        (previously this case raised ZeroDivisionError).
    """
    rank_score_tuple = list(zip(ranks, scores))

    # Take the k items with the highest ground-truth relevance.
    top_k = sorted(rank_score_tuple, key=lambda x: x[1], reverse=True)[:k]

    # DCG discounts each relevance score by the model's predicted rank.
    dcg = sum(score / np.log2(rank + 1) for rank, score in top_k)

    # Ideal DCG places the same scores at positions 1..k.
    ideal_dcg = sum(score / np.log2(idx + 2) for idx, (_, score) in enumerate(top_k))

    # BUG FIX: all-zero relevance made ideal_dcg == 0 and crashed; define
    # NDCG as 0.0 in that degenerate case (matches the app.py copy).
    if ideal_dcg == 0:
        return 0.0
    return dcg / ideal_dcg
407
+
408
+
409
def compute_ndcg_score_per_mode(pred_rank_dict, gpt_rel_scores, mode='spec', split='test', k=200):
    """Average the NDCG@k over all queries of one annotation mode/split.

    Args:
        pred_rank_dict: query key -> array of predicted ranks, aligned with the
            candidate list stored in the annotation file for that key.
        gpt_rel_scores: query key -> {candidate key: relevance}, e.g. the output
            of read_relevance_scores().
        mode: annotation mode used to pick the annotation file (e.g. 'spec').
        split: data split name used in the annotation file path.
        k: NDCG cut-off passed through to compute_ndcg.

    Returns:
        The mean NDCG score over all keys in ``pred_rank_dict``.
    """
    # Candidate lists per query; requires the ROOT env var to point at the data root.
    data = json.load(open(os.path.join(os.environ['ROOT'],f"data/annotations/13_05/anno_random_{split}_{mode}_ids.json")))

    ndcg_scores = []

    for key in tqdm.tqdm(pred_rank_dict.keys(), total=len(pred_rank_dict.keys())):
        # Ground-truth relevances in the same candidate order as the predicted ranks.
        gpt_scores_for_key = [gpt_rel_scores[key][cand_key] for cand_key in data[key]]
        pred_rank_for_key = pred_rank_dict[key]

        ndcg_score = compute_ndcg(pred_rank_for_key, gpt_scores_for_key, k=k)
        ndcg_scores.append(ndcg_score)

    avg_ndcg_score = sum(ndcg_scores) / len(ndcg_scores)
    print(f"Random split, mode={mode} ndcg score: ", avg_ndcg_score)
    return avg_ndcg_score
424
+
425
+
426
+ # %%
427
def main():
    """Run test-set retrieval evaluation: embed all images and texts with the
    fine-tuned CLIP model, compute the full image x text cosine-distance
    matrix, and dump the pairwise scores to two CSV files.

    Requires the ROOT env var to point at the project data root, a GPU if
    config.device == 'cuda', and the checkpoint/config under results/.
    """
    # %%
    ## Load Model
    config_path= os.path.join(os.environ['ROOT'],"results/config.yaml")
    model_path= os.path.join(os.environ['ROOT'],"results/model_epoch3.pth")
    # %%
    print("Loading config from:", config_path)
    config = OmegaConf.load(config_path)
    #print(OmegaConf.to_yaml(config))
    # %%

    # load checkpoint (to CPU first; moved to config.device after load)
    checkpoint = torch.load(model_path, map_location=torch.device('cpu'))
    print("Loaded model from:", model_path)

    clip_model, _ = clip.load(config.clip_model, jit=False)
    model = Model(clip_model, config)
    model.load_state_dict(checkpoint['model_state_dict'])

    model = model.to(config.device)
    model = model.eval()
    model = model.float()
    logit_scale = model.clip_model.logit_scale.exp()

    image_path = os.path.join(os.environ['ROOT'], "data/eval_test_image.json")
    text_path = os.path.join(os.environ['ROOT'], "data/eval_test_text.json")

    data_loader_image = get_image_data_loader(config, image_path, split='test', mode='combined' )
    data_loader_text = get_text_data_loader(config, text_path, split='test', mode='combined' )

    # %%
    # Embed all test captions; key_text_dict maps file_key -> row index in the
    # stacked embedding tensor.
    key_text_dict = {}
    text_tensor_embedding = None
    with torch.no_grad():
        for i, d in tqdm.tqdm(enumerate(data_loader_text), total=len(data_loader_text)):
            # print("d", d['file_key'])

            # with torch.amp.autocast(device_type=config.device, dtype=torch.float16):
            text_tensor_out, text_cls_out = model.var_txt_forward(d['caption'].to(config.device))
            #print("text_tensor_out", text_tensor_out[0].shape)

            # NOTE(review): `== None` works here but `is None` is the idiomatic
            # (and safer) check for a possibly-tensor variable.
            if text_tensor_embedding == None:
                text_tensor_embedding = text_cls_out
            else:
                text_tensor_embedding = torch.cat((text_tensor_embedding, text_cls_out), 0)

            # NOTE(review): the index assumes every batch has the same size as
            # the current one; a short final batch would mis-index — verify the
            # loader uses drop_last or a divisible dataset size.
            for j,key in enumerate(d['file_key']):
                key_text_dict[key] = int(i*len(d['file_key']) +j)

    # %%
    # Same pass over the images.
    key_image_dict = {}
    image_tensor_embedding = None
    with torch.no_grad():
        for i, d in tqdm.tqdm(enumerate(data_loader_image), total=len(data_loader_image)):
            image_tensor_out, img_cls_out = model.var_img_forward(d['image'].to(config.device))

            if image_tensor_embedding == None:
                image_tensor_embedding = img_cls_out
            else:
                image_tensor_embedding = torch.cat((image_tensor_embedding, img_cls_out), 0)

            for j,key in enumerate(d['file_key']):
                key_image_dict[key] = int(i*len(d['file_key']) +j)

    # Row index -> file key (dicts preserve insertion order).
    idx2img = {idx: k for idx, k in enumerate(key_image_dict)}
    idx2text = {idx: k for idx, k in enumerate(key_text_dict)}
    # %%
    image_tensor_embedding = image_tensor_embedding.to('cpu')
    text_tensor_embedding = text_tensor_embedding.to('cpu')

    # %%
    # Full pairwise cosine DISTANCE matrix (lower = more similar).
    similarity_matrix = pairwise_distances(image_tensor_embedding, text_tensor_embedding, metric='cosine', n_jobs=8)

    # %%
    results_pair_dict = {}
    ## put into matrix
    # NOTE(review): 2000 is hard-coded; presumably the test set size — confirm
    # it matches len(idx2img) / len(idx2text).
    for i in range (2000):
        for j in range (2000):
            results_pair_dict[str(idx2img[i])+':'+str(idx2text[j])] = float(similarity_matrix[i][j])

    # %%
    # Split the 4M-entry dict in half so each output file stays a manageable size.
    results_pair_dict1 = {}
    results_pair_dict2 = {}
    len_ = int(len(results_pair_dict)/2)
    for j, key in enumerate(results_pair_dict):
        if j <= len_:
            results_pair_dict1[key] = results_pair_dict[key]
        else:
            results_pair_dict2[key] = results_pair_dict[key]

    # %%
    # with open(os.path.join(os.environ['ROOT'],'results_pair_dict1.json'), 'w', encoding='utf-8') as f:
    #     json.dump(results_pair_dict1, f, ensure_ascii=False, indent=4)
    # with open(os.path.join(os.environ['ROOT'],'results_pair_dict2.json'), 'w', encoding='utf-8') as f:
    #     json.dump(results_pair_dict2, f, ensure_ascii=False, indent=4)
    df = pd.DataFrame(results_pair_dict1.items(), columns=['key_pair','score'])
    df.to_csv(os.path.join(os.environ['ROOT'],'results_pair_dict1.csv'))
    df = pd.DataFrame(results_pair_dict2.items(), columns=['key_pair','score'])
    df.to_csv(os.path.join(os.environ['ROOT'],'results_pair_dict2.csv'))

# %%
if __name__ == "__main__":
    main()

# %%
data/chatgpt_similarity_score_test_direct.json ADDED
The diff for this file is too large to render. See raw diff
 
data/chatgpt_similarity_score_test_indirect.json ADDED
The diff for this file is too large to render. See raw diff
 
data/eval_test_image.json ADDED
The diff for this file is too large to render. See raw diff
 
data/eval_test_text.json ADDED
The diff for this file is too large to render. See raw diff
 
data/images/.DS_Store ADDED
Binary file (6.15 kB). View file
 
data/key_pair.json ADDED
The diff for this file is too large to render. See raw diff
 
data/preview_image.jpeg ADDED
data/random_sample_test_direct_ids.json ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:86c25ef7d12166f6c27fe725004a8857ddfcd4dbc8cfafd14722c189c13efbf5
3
+ size 27110039
data/random_sample_test_indirect_ids.json ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:74851fc1e326e6c838b868062949f2c3daa6a1c3c35f6dae87549caf86acd39d
3
+ size 27109999
models/__init__.py ADDED
File without changes
models/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (154 Bytes). View file
 
models/__pycache__/fused_model.cpython-38.pyc ADDED
Binary file (16.1 kB). View file
 
models/fused_model.py ADDED
@@ -0,0 +1,544 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import copy
2
+ import math
3
+ from typing import Optional, Tuple, Union
4
+ import numpy as np
5
+
6
+ import torch
7
+ from torch import nn
8
+ import torch.nn.functional as F
9
+ from torch.nn import CrossEntropyLoss
10
+
11
+ from transformers.activations import ACT2FN
12
+ from einops import rearrange
13
+ from transformers.models.t5.configuration_t5 import T5Config
14
+ from transformers.modeling_utils import ModuleUtilsMixin
15
+
16
+ from einops import rearrange, reduce
17
+
18
+
19
class FeedForward(nn.Module):
    """Pre-LayerNorm position-wise feed-forward sublayer with residual connection.

    Computes x + dropout(wo(dropout(gelu(wi(layer_norm(x)))))).
    Submodule names (wi/wo/dropout/act/layer_norm) follow the T5 convention
    and are part of the checkpoint state_dict layout — do not rename them.
    """

    def __init__(self, config: T5Config):
        super().__init__()
        self.wi = nn.Linear(config.d_model, config.d_ff, bias=False)
        self.wo = nn.Linear(config.d_ff, config.d_model, bias=False)
        self.dropout = nn.Dropout(config.dropout_rate)
        self.act = ACT2FN["gelu"]
        self.layer_norm = nn.LayerNorm(config.d_model)

    def forward(self, x):
        """Apply the FFN to normalized input and add the residual."""
        normed = self.layer_norm(x)
        expanded = self.dropout(self.act(self.wi(normed)))
        projected = self.wo(expanded)
        return x + self.dropout(projected)
32
+
33
+
34
class Attention(nn.Module):
    """Multi-head attention in the T5 style: no scaling before softmax and an
    optional learned relative-position bias shared across layers.

    Used both as self-attention (x_kv=None) and cross-attention (x_kv given).
    """

    def __init__(self, config: T5Config, has_relative_attention_bias=False):
        super().__init__()
        self.is_decoder = config.is_decoder
        self.has_relative_attention_bias = has_relative_attention_bias
        self.relative_attention_num_buckets = config.relative_attention_num_buckets
        self.relative_attention_max_distance = config.relative_attention_max_distance
        self.d_model = config.d_model
        self.key_value_proj_dim = config.d_kv
        self.n_heads = config.num_heads
        self.dropout = config.dropout_rate
        self.inner_dim = self.n_heads * self.key_value_proj_dim

        # Mesh TensorFlow initialization to avoid scaling before softmax
        self.q = nn.Linear(self.d_model, self.inner_dim, bias=False)
        self.k = nn.Linear(self.d_model, self.inner_dim, bias=False)
        self.v = nn.Linear(self.d_model, self.inner_dim, bias=False)
        self.o = nn.Linear(self.inner_dim, self.d_model, bias=False)

        if self.has_relative_attention_bias:
            # One bias scalar per (bucket, head); only the first layer of a
            # Stack owns this table — later layers reuse the computed bias.
            self.relative_attention_bias = nn.Embedding(self.relative_attention_num_buckets, self.n_heads)

    @staticmethod
    def _relative_position_bucket(relative_position, bidirectional=True, num_buckets=32, max_distance=128):
        """
        memory_position - query_position -> bucket_idx.
        If bidirectional=False, then positive relative positions are invalid.
        We use smaller buckets for small absolute relative_position
        and larger buckets for larger absolute relative_positions.
        * All relative positions >=max_distance map to the same bucket.
        * All relative positions <=-max_distance map to the same bucket.
        This should allow for more graceful generalization to longer sequences than the model has been trained on
        Args:
            relative_position: an int32 Tensor
            bidirectional: a boolean - whether the attention is bidirectional
            num_buckets: an integer
            max_distance: an integer
        Returns:
            a Tensor with the same shape as relative_position, containing int32 values in the range [0, num_buckets)
        """
        relative_buckets = 0
        if bidirectional:
            num_buckets //= 2
            relative_buckets += (relative_position > 0).to(torch.long) * num_buckets
            relative_position = torch.abs(relative_position)
        else:
            relative_position = -torch.min(relative_position, torch.zeros_like(relative_position))
        # now relative_position is in the range [0, inf)

        # half of the buckets are for exact increments in positions
        max_exact = num_buckets // 2
        is_small = relative_position < max_exact

        # The other half of the buckets are for logarithmically bigger bins in positions up to max_distance
        relative_position_if_large = max_exact + (torch.log(relative_position.float() / max_exact) / math.log(max_distance / max_exact) *
                                                  (num_buckets - max_exact)).to(torch.long)
        relative_position_if_large = torch.min(relative_position_if_large, torch.full_like(relative_position_if_large, num_buckets - 1))

        relative_buckets += torch.where(is_small, relative_position, relative_position_if_large)
        return relative_buckets

    def compute_bias(self, query_length, key_length, device=None):
        """Compute binned relative position bias"""
        if device is None:
            device = self.relative_attention_bias.weight.device
        context_position = torch.arange(query_length, dtype=torch.long, device=device)[:, None]
        memory_position = torch.arange(key_length, dtype=torch.long, device=device)[None, :]
        relative_position = memory_position - context_position  # shape (query_length, key_length)
        relative_position_bucket = self._relative_position_bucket(
            relative_position,  # shape (query_length, key_length)
            bidirectional=(not self.is_decoder),
            num_buckets=self.relative_attention_num_buckets,
            max_distance=self.relative_attention_max_distance,
        )
        values = self.relative_attention_bias(relative_position_bucket)  # shape (query_length, key_length, num_heads)
        values = values.permute([2, 0, 1]).unsqueeze(0)  # shape (1, num_heads, query_length, key_length)
        return values

    def forward(self, x, mask=None, x_kv=None, pos_bias=None):
        """
        Self-attention (if key_value_states is None) or attention over source sentence (provided by key_value_states).

        Returns (attn_output, pos_bias); the bias (with any mask already added)
        is returned so callers can reuse it in subsequent layers.
        """
        # Input is (batch_size, seq_length, dim)
        # Mask is (batch_size, key_length) (non-causal) or (batch_size, key_length, key_length)
        batch_size, seq_length = x.shape[:2]

        real_seq_length = seq_length
        key_length = real_seq_length if x_kv is None else x_kv.shape[1]

        reshape = lambda states: rearrange(states, 'b s (h d) -> b h s d', h=self.n_heads)
        unshape = lambda states: rearrange(states, 'b h s d -> b s (h d)')

        q = reshape(self.q(x))  # (batch_size, n_heads, seq_length, dim_per_head)
        k = reshape(self.k(x if x_kv is None else x_kv))
        v = reshape(self.v(x if x_kv is None else x_kv))

        # compute scores (no 1/sqrt(d) scaling — T5 convention, see init comment)
        scores = torch.matmul(q, k.transpose(3, 2))

        if pos_bias is None:
            if not self.has_relative_attention_bias:
                pos_bias = torch.zeros((1, self.n_heads, real_seq_length, key_length), device=scores.device, dtype=scores.dtype)
            else:
                pos_bias = self.compute_bias(real_seq_length, key_length, device=scores.device)

            # mask is additive (-inf style); folded into the bias once here.
            if mask is not None:
                pos_bias = pos_bias + mask  # (batch_size, n_heads, seq_length, key_length)

        position_bias_masked = pos_bias
        scores += position_bias_masked
        attn_weights = F.softmax(scores.float(), dim=-1).type_as(scores)  # (B, H, seq_length, key_length)
        attn_weights = F.dropout(attn_weights, p=self.dropout, training=self.training)  # (B, H, seq_length, key_length)

        attn_output = unshape(torch.matmul(attn_weights, v))  # (batch_size, seq_length, dim)
        attn_output = self.o(attn_output)
        return (attn_output, pos_bias)
151
+
152
+
153
class LayerSelfAttention(nn.Module):
    """Pre-norm self-attention sublayer: x + dropout(SelfAttention(LN(x))).

    The submodule name ``SelfAttention`` follows the T5 state_dict convention
    and must not be renamed.
    """

    def __init__(self, config, has_relative_attention_bias=False):
        super().__init__()
        self.SelfAttention = Attention(config, has_relative_attention_bias=has_relative_attention_bias)
        self.layer_norm = nn.LayerNorm(config.d_model)
        self.dropout = nn.Dropout(config.dropout_rate)

    def forward(self, x, mask=None, pos_bias=None):
        """Apply self-attention with a residual connection.

        Returns a (hidden_states, pos_bias) tuple so the position bias computed
        by the first layer can be shared with later layers.
        """
        attn_out, bias = self.SelfAttention(self.layer_norm(x), mask=mask, pos_bias=pos_bias)
        return (x + self.dropout(attn_out), bias)
166
+
167
+
168
class LayerCrossAttention(nn.Module):
    """Pre-norm cross-attention sublayer: x + dropout(Attention(LN(x), x_kv)).

    The submodule name ``EncDecAttention`` follows the T5 state_dict convention
    and must not be renamed.
    """

    def __init__(self, config):
        super().__init__()
        self.EncDecAttention = Attention(config, has_relative_attention_bias=False)
        self.layer_norm = nn.LayerNorm(config.d_model)
        self.dropout = nn.Dropout(config.dropout_rate)

    def forward(self, x, x_kv, mask=None, pos_bias=None):
        """Attend from the normalized decoder states to ``x_kv`` and add the residual.

        Returns a (hidden_states, pos_bias) tuple mirroring LayerSelfAttention.
        """
        attn_out, bias = self.EncDecAttention(self.layer_norm(x), mask=mask, x_kv=x_kv, pos_bias=pos_bias)
        return (x + self.dropout(attn_out), bias)
182
+
183
+
184
class Block(nn.Module):
    """One transformer layer: self-attention, optional cross-attention
    (decoder only, when a context is given), then a feed-forward sublayer."""

    def __init__(self, config, has_relative_attention_bias=False):
        super().__init__()
        self.is_decoder = config.is_decoder
        self.layer = nn.ModuleList()
        self.layer.append(LayerSelfAttention(config, has_relative_attention_bias=has_relative_attention_bias))
        if self.is_decoder:
            self.layer.append(LayerCrossAttention(config))

        self.layer.append(FeedForward(config))

    def forward(self, x, mask=None, pos_bias=None, context=None, context_mask=None, context_pos_bias=None):
        """Run the layer.

        Returns (hidden_states, pos_bias, context_pos_bias); the biases are
        passed back in by the caller for subsequent layers (context_pos_bias is
        None when no cross-attention ran).
        """

        self_attention_outputs = self.layer[0](x, mask=mask, pos_bias=pos_bias)
        hidden_states = self_attention_outputs[0]

        # Cross-attend only when configured as a decoder AND a context is supplied.
        do_cross_attention = self.is_decoder and context is not None
        if do_cross_attention:

            cross_attention_outputs = self.layer[1](
                hidden_states,
                x_kv=context,
                mask=context_mask,
                pos_bias=context_pos_bias,
            )
            hidden_states = cross_attention_outputs[0]

        # Apply Feed Forward layer
        hidden_states = self.layer[-1](hidden_states)

        pos_bias = self_attention_outputs[1]
        context_pos_bias = cross_attention_outputs[1] if do_cross_attention else None

        return (hidden_states, pos_bias, context_pos_bias)
219
+
220
+
221
class Stack(nn.Module):
    """A stack of transformer Blocks with shared relative-position biases.

    Acts as an encoder or decoder depending on ``is_decoder``; may own a token
    embedding table (``has_embedding``) when fed ``input_ids`` instead of
    precomputed hidden states.
    """

    def __init__(self, config, is_decoder=True, has_embedding=False, generate_causal_mask=False):
        super().__init__()
        self.config = config
        if has_embedding:
            self.embed_tokens = nn.Embedding(config.vocab_size, config.d_model)

        self.is_decoder = is_decoder
        # Fixed dtype used when building additive attention masks.
        self.dtype = torch.float32

        self.generate_causal_mask = generate_causal_mask

        # Only the first block owns the relative-attention-bias table; the bias
        # it computes is reused by all later blocks.
        self.block = nn.ModuleList([Block(config, has_relative_attention_bias=bool(i == 0)) for i in range(config.num_layers)])
        self.final_layer_norm = nn.LayerNorm(config.d_model)
        self.dropout = nn.Dropout(config.dropout_rate)

    def forward(
        self,
        input_ids=None,
        dec_hidden_states=None,
        enc_hidden_states=None,
        dec_attention_mask=None,
        enc_attention_mask=None,
    ):
        """Run the stack; exactly one of input_ids / dec_hidden_states must be given.

        Returns a 1-tuple (hidden_states,) of shape [B, L, d_model].
        """
        input_shape = input_ids.size() if input_ids is not None else dec_hidden_states.shape[:-1]
        batch_size, seq_length = input_shape

        if input_ids is not None:
            input_ids = input_ids.view(-1, input_shape[-1])
            inputs_embeds = self.embed_tokens(input_ids)
        else:
            inputs_embeds = dec_hidden_states

        # required mask seq length can be calculated via length of past
        mask_seq_length = seq_length

        # Default masks: attend everywhere.
        if dec_attention_mask is None:
            dec_attention_mask = torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device)
        if self.is_decoder and enc_attention_mask is None and enc_hidden_states is not None:
            encoder_seq_length = enc_hidden_states.shape[1]
            enc_attention_mask = torch.ones(batch_size, encoder_seq_length, device=inputs_embeds.device, dtype=torch.long)

        # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
        # ourselves in which case we just need to make it broadcastable to all heads.
        extended_attention_mask = self.get_extended_attention_mask(dec_attention_mask, input_shape)

        # If a 2D or 3D attention mask is provided for the cross-attention
        # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
        if self.is_decoder and enc_hidden_states is not None:
            encoder_batch_size, encoder_sequence_length, _ = enc_hidden_states.size()
            encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
            if enc_attention_mask is None:
                enc_attention_mask = torch.ones(encoder_hidden_shape, device=inputs_embeds.device)
            encoder_extended_attention_mask = self.invert_attention_mask(enc_attention_mask)
        else:
            encoder_extended_attention_mask = None

        pos_bias = None
        context_pos_bias = None

        hidden_states = self.dropout(inputs_embeds)

        for i, layer_module in enumerate(self.block):

            layer_outputs = layer_module(
                hidden_states,
                mask=extended_attention_mask,  # [1, 1, 1, 1 ] [B, L]
                pos_bias=pos_bias,
                context=enc_hidden_states,
                context_mask=encoder_extended_attention_mask,
                context_pos_bias=context_pos_bias,
            )

            # layer_outputs is a tuple with:
            # (hidden_states, None placeholder, pos_bias, context_pos_bias)
            layer_outputs = layer_outputs[:1] + (None,) + layer_outputs[1:]

            hidden_states, present_key_value_state = layer_outputs[:2]  # [B, L, D], None

            # We share the position biases between the layers - the first layer store them
            pos_bias = layer_outputs[2]  # [B, H, L, L]
            if self.is_decoder and enc_hidden_states is not None:
                context_pos_bias = layer_outputs[3]

        hidden_states = self.final_layer_norm(hidden_states)
        hidden_states = self.dropout(hidden_states)

        return (hidden_states,)

    def invert_attention_mask(self, attention_mask):
        """
        Input: 1 for attend, 0 for masked/ignored
        Output: 0 for attend, -1e30 for masked/ignored.
        Then we can add it to the attention logits.
        [B, L] -> [B, 1, 1, L]
        [B, L, L] -> [B, 1, L, L]
        """
        if attention_mask.dim() == 3:
            extended_attention_mask = attention_mask[:, None, :, :]
        if attention_mask.dim() == 2:
            extended_attention_mask = attention_mask[:, None, None, :]

        extended_attention_mask = extended_attention_mask.to(dtype=self.dtype)  # fp16 compatibility
        extended_attention_mask = (1.0 - extended_attention_mask) * torch.finfo(self.dtype).min

        return extended_attention_mask

    def get_extended_attention_mask(self, attention_mask, input_shape, device=None, dtype=None):
        """
        Makes broadcastable attention and causal masks so that future and masked tokens are ignored.
        attention_mask: 1 for attend, 0 for masked/ignored
        Return: The extended attention mask: 0 for attend, -1e30 for masked/ignored
        [B, L] -> [B, 1, 1, L]
        [B, L, L] -> [B, 1, L, L]
        """
        dtype = dtype if dtype else attention_mask.dtype

        # If input [B, query_length, key_length] -> [B, 1, query_length, key_length]
        if attention_mask.dim() == 3:
            extended_attention_mask = attention_mask[:, None, :, :]
        elif attention_mask.dim() == 2:
            # Provided a padding mask of dimensions [batch_size, seq_length]
            # - if the model is a decoder, apply a causal mask in addition to the padding mask
            # - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length]
            if self.config.is_decoder and self.generate_causal_mask:
                extended_attention_mask = ModuleUtilsMixin.create_extended_attention_mask_for_decoder(input_shape, attention_mask, device)
            else:
                extended_attention_mask = attention_mask[:, None, None, :]
        else:
            raise ValueError(f"Wrong shape for input_ids (shape {input_shape}) or attention_mask (shape {attention_mask.shape})")

        # Input: valid = 1, padding = 0
        # Output: valid = 0, padding = -1e30
        # => then we can add it to the attention logits
        extended_attention_mask = extended_attention_mask.to(dtype=dtype)  # fp16 compatibility
        extended_attention_mask = (1.0 - extended_attention_mask) * torch.finfo(dtype).min
        return extended_attention_mask
358
+
359
+
360
class Model(torch.nn.Module):
    """CLIP backbone with optional T5-style fusion decoders on top.

    Produces normalized image/text CLS embeddings for contrastive retrieval,
    plus optional image-text-matching (ITM) heads and a generation head,
    depending on the config flags.
    """

    def __init__(self, clip_model, config):
        super().__init__()
        self.clip_model = clip_model
        self.config = config

        if self.config.has_extra_txt_decoder:
            self.txt_decoder = Stack(config.extra_decoder)
            self.itm_txt_head = torch.nn.Linear(config.extra_decoder.d_model, 2)

        if self.config.has_extra_img_decoder:
            self.img_decoder = Stack(config.extra_decoder)
            self.itm_img_head = torch.nn.Linear(config.extra_decoder.d_model, 2)

        if self.config.has_extra_mix_decoder:
            self.mix_decoder = Stack(config.extra_decoder)
            self.mix_itm_head = torch.nn.Linear(config.extra_decoder.d_model, 2)

        if self.config.has_extra_gen_decoder:
            self.gen_decoder = Stack(config.extra_decoder, has_embedding=True, generate_causal_mask=True)
            self.gen_head = torch.nn.Linear(config.extra_decoder.d_model, config.vocab_size)

        # NOTE(review): redundant — self.config was already assigned above.
        self.config = config

    def img_forward(self, x: torch.Tensor):  # [N, 3, 224, 224]
        """CLIP ViT forward returning the full projected token sequence and the CLS token."""
        x = self.clip_model.visual.conv1(x)  # shape = [*, width, grid, grid]
        x = x.reshape(x.shape[0], x.shape[1], -1)  # shape = [*, width, grid ** 2]
        x = x.permute(0, 2, 1)  # shape = [*, gri d ** 2, width]
        x = torch.cat(
            [self.clip_model.visual.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x],
            dim=1)  # shape = [*, grid ** 2 + 1, width]
        x = x + self.clip_model.visual.positional_embedding.to(x.dtype)
        x = self.clip_model.visual.ln_pre(x)

        x = x.permute(1, 0, 2)  # NLD -> LND
        x = self.clip_model.visual.transformer(x)
        x = x.permute(1, 0, 2)  # LND -> NLD
        x = self.clip_model.visual.ln_post(x)  # [NLD]

        # Project ALL tokens (not just CLS) into the shared embedding space.
        if self.clip_model.visual.proj is not None:
            proj = self.clip_model.visual.proj[None, :, :]
            x = (x @ proj)

        cls_token = x[:, 0, :]
        return x, cls_token

    def txt_forward(self, text):
        """CLIP text-encoder forward returning all projected tokens and the EOT token."""
        dtype = self.clip_model.dtype
        x = self.clip_model.token_embedding(text).type(dtype)  # [batch_size, n_ctx, d_model]

        x = x + self.clip_model.positional_embedding.type(dtype)
        x = x.permute(1, 0, 2)  # NLD -> LND
        x = self.clip_model.transformer(x)
        x = x.permute(1, 0, 2)  # LND -> NLD
        x = self.clip_model.ln_final(x).type(dtype)

        proj = self.clip_model.text_projection[None, :, :]
        x = (x @ proj)

        # take features from the eot embedding (eot_token is the highest number in each sequence)
        eot = x[torch.arange(x.shape[0]), text.argmax(dim=-1)]
        return x, eot  # [NLD]

    def var_img_forward(self, image):
        """Image forward that also accepts a 5-D [N, 2, C, H, W] image pair
        (features are averaged); the CLS token is L2-normalized."""
        if len(image.shape) == 5:
            img_features1, img_token1 = self.img_forward(image[:, 0, ...])
            img_features2, img_token2 = self.img_forward(image[:, 1, ...])
            img_token = (img_token1 + img_token2) / 2
            img_features = (img_features1 + img_features2) / 2
        else:
            img_features, img_token = self.img_forward(image)
        img_token = img_token / img_token.norm(dim=-1, keepdim=True)
        return img_features, img_token

    def var_txt_forward(self, text):
        """Text forward with an L2-normalized EOT token."""
        txt_features, txt_token = self.txt_forward(text)
        txt_token = txt_token / txt_token.norm(dim=-1, keepdim=True)
        return txt_features, txt_token

    def get_device(self):
        """Device of the model's parameters."""
        return next(self.parameters()).device

    def get_features(self, image=None, text_ids=None):
        """Encode image and/or text; returns a dict of features, CLS/EOT tokens, and masks."""
        outputs = {}
        if image is not None:
            img_features, img_token = self.var_img_forward(image)
            outputs['img_features'] = img_features
            outputs['img_token'] = img_token
            # All image tokens are valid (no padding).
            outputs['img_mask'] = torch.ones_like(img_features[:, :, 0])
        if text_ids is not None:
            txt_features, txt_token = self.var_txt_forward(text_ids)
            outputs['txt_features'] = txt_features
            outputs['txt_token'] = txt_token
            # Token id 0 is treated as padding.
            outputs['txt_mask'] = (text_ids != 0).to(txt_features.dtype)
        return outputs

    def get_prediction(self, img_features, txt_features, img_mask=None, txt_mask=None, decoder="txt_decoder", **kwargs):
        """Run one of the ITM fusion decoders and return its logits/probabilities."""
        outputs = {}
        if decoder == 'txt_decoder':
            hidden_states = self.txt_decoder(
                dec_hidden_states=txt_features,
                enc_hidden_states=img_features,
                enc_attention_mask=img_mask,
                dec_attention_mask=txt_mask,
            )
            # ITM classification from the first token's fused representation.
            outputs['itm_txt_logits'] = self.itm_txt_head(hidden_states[0][:, 0, :])
            outputs['itm_txt_probs'] = torch.softmax(outputs['itm_txt_logits'], dim=-1)

        if decoder == 'img_decoder':
            hidden_states = self.img_decoder(
                dec_hidden_states=img_features,
                enc_hidden_states=txt_features,
                enc_attention_mask=txt_mask,
                dec_attention_mask=img_mask,
            )
            outputs['itm_img_logits'] = self.itm_img_head(hidden_states[0][:, 0, :])
            outputs['itm_img_probs'] = torch.softmax(outputs['itm_img_logits'], dim=-1)
        return outputs

    def forward(self, image, text, itm_text=None, itm_labels=None, gen_inputs=None, gen_labels=None):  # , gen_inputs, gen_labels, **kwargs):
        """Training forward: contrastive tokens plus optional ITM/generation logits.

        NOTE(review): itm_text defaults to None but is passed to
        var_txt_forward unconditionally below — calling without itm_text would
        fail; presumably training always supplies it. Verify against callers.
        """
        img_features, img_token = self.var_img_forward(image)
        txt_features, txt_token = self.var_txt_forward(text)

        itm_txt_features, _ = self.var_txt_forward(itm_text)
        itm_txt_mask = (itm_text != 0).to(itm_txt_features.dtype)

        outputs = dict(
            img_token=img_token,
            txt_token=txt_token,
            img_features=img_features,
            txt_features=txt_features,
        )
        if self.config.has_extra_txt_decoder and itm_text is not None:
            itm_img_features = img_features
            itm_txt_states = self.txt_decoder(
                dec_hidden_states=itm_txt_features,
                enc_hidden_states=itm_img_features,
                enc_attention_mask=None,
                dec_attention_mask=itm_txt_mask,
            )
            outputs['itm_txt_logits'] = self.itm_txt_head(itm_txt_states[0][:, 0])

        if self.config.has_extra_img_decoder and itm_text is not None:
            itm_img_features = img_features
            itm_img_states = self.img_decoder(
                dec_hidden_states=itm_img_features,
                enc_hidden_states=itm_txt_features,
                enc_attention_mask=itm_txt_mask,
                dec_attention_mask=None,
            )
            outputs['itm_img_logits'] = self.itm_img_head(itm_img_states[0][:, 0])

        if self.config.has_extra_mix_decoder:
            pass

        if self.config.has_extra_gen_decoder:
            # NOTE(review): Stack.forward has no `labels` parameter — this call
            # would raise TypeError if has_extra_gen_decoder were enabled
            # (it is false in the shipped results/config.yaml).
            gen_features = self.gen_decoder(
                input_ids=gen_inputs,
                enc_hidden_states=img_features,
                enc_attention_mask=None,
                dec_attention_mask=None,
                labels=gen_labels,
            )
            outputs['gen_logits'] = self.gen_head(gen_features[0])

        return outputs
527
+
528
+
529
if __name__ == "__main__":
    # Smoke test: build the fused model from an example config on CPU.
    import sys
    from omegaconf import OmegaConf
    # NOTE(review): hard-coded developer path; only works on the author's machine.
    sys.path.append("/home/quang/workspace/traffic_var")
    from config.examples import with_decoder_config as config
    config.has_extra_txt_decoder = True
    print(OmegaConf.to_yaml(config))

    import clip

    def get_resolution(model):
        # CLIP ViT models expose input_resolution under .visual; others at the top level.
        return model.visual.input_resolution if hasattr(model, 'visual') else model.input_resolution

    model, _ = clip.load(config.clip_model, jit=False, device="cpu")
    config.img_size = get_resolution(model)
    model = Model(model, config)
models/model.py ADDED
@@ -0,0 +1,117 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from easydict import EasyDict as edict
4
+
5
+
6
class Block(nn.Module):
    """Placeholder module that only stores its config.

    Presumably intended for the commented-out i2t/t2i encoder stacks in
    Model.__init__ below; currently unused — confirm before removing.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
11
+
12
+
13
class Model(torch.nn.Module):
    """Plain CLIP dual encoder trained with the symmetric contrastive loss,
    with optional accumulation of embeddings from previous batches."""

    def __init__(self, clip_model, config):
        super().__init__()
        self.clip_model = clip_model
        # if config.i2t_encoder_layers > 0:
        #     self.i2t_encoder = nn.ModuleList([Block(config) for _ in range(config.i2t_encoder_layers)])

        # if config.t2i_encoder_layers > 0:
        #     self.t2i_encoder = nn.ModuleList([Block(config) for _ in range(config.i2t_encoder_layers)])

        self.config = config

    def img_forward(self, x: torch.Tensor):  # [N, 3, 224, 224]
        """CLIP ViT forward; returns all post-norm tokens and the projected CLS token."""
        x = self.clip_model.visual.conv1(x)  # shape = [*, width, grid, grid]
        x = x.reshape(x.shape[0], x.shape[1], -1)  # shape = [*, width, grid ** 2]
        x = x.permute(0, 2, 1)  # shape = [*, gri d ** 2, width]
        x = torch.cat(
            [self.clip_model.visual.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x],
            dim=1)  # shape = [*, grid ** 2 + 1, width]
        x = x + self.clip_model.visual.positional_embedding.to(x.dtype)
        x = self.clip_model.visual.ln_pre(x)

        x = x.permute(1, 0, 2)  # NLD -> LND
        x = self.clip_model.visual.transformer(x)
        x = x.permute(1, 0, 2)  # LND -> NLD
        x = self.clip_model.visual.ln_post(x)  # [NLD]
        # NOTE(review): ln_post is applied to x above AND again to x[:, 0, :]
        # here, i.e. the CLS token is layer-normed twice — fused_model.py's
        # img_forward applies it only once. Confirm which the checkpoint expects.
        cls_token = self.clip_model.visual.ln_post(x[:, 0, :])

        if self.clip_model.visual.proj is not None:
            cls_token = cls_token @ self.clip_model.visual.proj
        return x, cls_token

    def txt_forward(self, text):
        """CLIP text-encoder forward; returns all tokens and the projected EOT token."""
        dtype = self.clip_model.dtype
        x = self.clip_model.token_embedding(text).type(dtype)  # [batch_size, n_ctx, d_model]

        x = x + self.clip_model.positional_embedding.type(dtype)
        x = x.permute(1, 0, 2)  # NLD -> LND
        x = self.clip_model.transformer(x)
        x = x.permute(1, 0, 2)  # LND -> NLD
        x = self.clip_model.ln_final(x).type(dtype)

        # take features from the eot embedding (eot_token is the highest number in each sequence)
        eot = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.clip_model.text_projection
        return x, eot  # [NLD]

    def var_img_forward(self, image):
        """Image forward that also accepts a 5-D [N, 2, C, H, W] image pair
        (features averaged); returned CLS token is L2-normalized."""
        if len(image.shape) == 5:
            img_features1, img_token1 = self.img_forward(image[:, 0, ...])
            img_features2, img_token2 = self.img_forward(image[:, 1, ...])
            img_token = (img_token1 + img_token2) / 2
            img_features = (img_features1 + img_features2) / 2
        else:
            img_features, img_token = self.img_forward(image)
        img_token = img_token / img_token.norm(dim=-1, keepdim=True)
        return img_features, img_token

    def var_txt_forward(self, text):
        """Text forward with an L2-normalized EOT token."""
        txt_features, txt_token = self.txt_forward(text)
        txt_token = txt_token / txt_token.norm(dim=-1, keepdim=True)
        return txt_features, txt_token

    def forward(self, image, text, past_img_tokens=None, past_txt_tokens=None):
        """Encode a batch and compute the symmetric CLIP contrastive loss.

        When past_*_tokens are given, the current batch's tokens are appended
        and the loss is computed over the enlarged (accumulated) batch.
        """
        # TODO: aggregate past img and txt tokens
        img_features, img_token = self.var_img_forward(image)
        txt_features, txt_token = self.var_txt_forward(text)
        logit_scale = self.clip_model.logit_scale.exp()

        if past_img_tokens is not None:
            past_img_tokens = torch.cat([past_img_tokens, img_token], dim=0)
            past_txt_tokens = torch.cat([past_txt_tokens, txt_token], dim=0)

            batch_size = past_img_tokens.shape[0]
            # Diagonal targets: the i-th image matches the i-th text.
            ground_truth = torch.arange(batch_size, dtype=torch.long, device=img_token.device)

            logits_for_imgs = logit_scale * past_img_tokens @ past_txt_tokens.t()
            logits_for_txts = logits_for_imgs.t()
            # print(f"past_img_tokens: {past_img_tokens.shape}, past_txt_tokens: {past_txt_tokens.shape}")

            # CLIP Contrastive Learning Loss Function
            # NOTE(review): ground_truth[:batch_size] is the whole tensor; the
            # slice is a no-op kept from an earlier revision.
            loss_img = torch.nn.CrossEntropyLoss()
            loss_txt = torch.nn.CrossEntropyLoss()
            loss = (loss_img(logits_for_imgs, ground_truth[:batch_size]) + loss_txt(logits_for_txts, ground_truth[:batch_size])) / 2
        else:
            batch_size = img_token.shape[0]
            ground_truth = torch.arange(batch_size, dtype=torch.long, device=img_token.device)

            logits_for_imgs = logit_scale * img_token @ txt_token.t()
            logits_for_txts = logits_for_imgs.t()

            # CLIP Contrastive Learning Loss Function
            loss_img = torch.nn.CrossEntropyLoss()
            loss_txt = torch.nn.CrossEntropyLoss()
            loss = (loss_img(logits_for_imgs, ground_truth[:batch_size]) + loss_txt(logits_for_txts, ground_truth[:batch_size])) / 2

        return dict(
            img_token=img_token,
            txt_token=txt_token,
            img_features=img_features,
            txt_features=txt_features,
            loss=loss,
            past_img_tokens=past_img_tokens,
            past_txt_tokens=past_txt_tokens,
        )
results/config.yaml ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ exp_name: expl5-3a-bs16-imgtxtdec-ema-aux1
2
+ wandb: 90788c79e1500570b08e5acf283e17df7e0c54b2
3
+ root: ${oc.env:DATA_ROOT}
4
+ overfit: false
5
+ batch_size: 16
6
+ num_workers: 4
7
+ img_size: 336
8
+ rationale_type: 0
9
+ val_rationale_type: 0
10
+ hide_true_bbox: 8
11
+ widescreen_processing: 1
12
+ h_flip: false
13
+ ema_decay: 0.9999
14
+ aux_weight: 1.0
15
+ no_hard_negative_itm: false
16
+ clip_model: ViT-L/14@336px
17
+ has_extra_txt_decoder: true
18
+ has_extra_img_decoder: true
19
+ has_extra_mix_decoder: false
20
+ has_extra_gen_decoder: false
21
+ extra_decoder:
22
+ is_decoder: true
23
+ vocab_size: 1000
24
+ d_ff: 768
25
+ d_kv: 64
26
+ d_model: 768
27
+ dropout_rate: 0.1
28
+ num_heads: 12
29
+ num_layers: 2
30
+ relative_attention_max_distance: 128
31
+ relative_attention_num_buckets: 32
32
+ warmup: 1000
33
+ init_from: ''
34
+ lr: 1.0e-05
35
+ n_epochs: 15
36
+ save_every: 0
37
+ early_stop: 5
38
+ val_stat: loss
39
+ device: cuda
40
+ use_multi: false
41
+ local_rank: 0
42
+ run_file: /home/acd13872jh/workspace/traffic_var/results/.runinfo/.sh
results/model_epoch3.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a45011f16c569075fcfab80c3f0e6df273a825dbad2ebeca65a438a57a95977f
3
+ size 251753353
results/results_pair_dict.json ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:1d3fa8165cf47a2562a3aa534880df625caf3705a59b5f5b221ba40f37a524fb
3
+ size 303815155
results_pair_dict1.json ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:7a469b8b91ebec66d55cc3389731e3e9f7e11111f112f13a06ea2e668f573309
3
+ size 151976036
results_pair_dict2.json ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:59e06a5e308a71fd97b3f2f87439fde8fbac2c2834896a3ac5c54868ec9c480f
3
+ size 151839121