Chen42 committed
Commit 41404b8 · verified · 1 Parent(s): fb68b12

Upload folder using huggingface_hub

__pycache__/model_and_train.cpython-312.pyc ADDED
Binary file (9.84 kB)
 
__pycache__/utils.cpython-313.pyc ADDED
Binary file (2.1 kB)
 
filt_result_by_bleu.py ADDED
@@ -0,0 +1,9 @@
+ from datasets import load_dataset
+ import pandas as pd
+
+ decoding_res = '/home/zychen/hwproject/my_modeling_phase_1/mytest_3600_test5k/decoding_res.json'
+ dataset2 = load_dataset("json", data_files=decoding_res)["train"]
+ print(f"Number of examples: {len(dataset2)}")
+ decoding_df = dataset2.to_pandas()
+
+ print(decoding_df)  # a bare `decoding_df` expression is a no-op outside a notebook
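
The committed script stops after loading the decoding results and never applies the filter its filename implies. A minimal sketch of that missing step, run after the script above, assuming the per-example `instance_bleu` field that `test_bleu_chrf.py` writes into decoding_res.json; the 0.3 threshold and the output filename are illustrative only, not part of the commit:

# Illustrative sketch, not part of the commit.
bleu_threshold = 0.3  # assumed cut-off, tune as needed
low_bleu_df = decoding_df[decoding_df["instance_bleu"] < bleu_threshold]
print(f"{len(low_bleu_df)} / {len(decoding_df)} examples below BLEU {bleu_threshold}")
low_bleu_df.to_json("low_bleu_examples.jsonl", orient="records", lines=True, force_ascii=False)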
make_comet_hyp_and_ref.py ADDED
@@ -0,0 +1,38 @@
+ from datasets import load_dataset
+ import pandas as pd
+
+ text_src_jsonl = '/home/zychen/hwproject/my_modeling_phase_1/mytest/text_src.jsonl'
+ dataset = load_dataset("json", data_files=text_src_jsonl)["train"]
+ print(f"Number of examples: {len(dataset)}")
+ text_src_df = dataset.to_pandas()
+
+ decoding_res = '/home/zychen/hwproject/my_modeling_phase_1/mytest_3600_test5k/decoding_res.json'
+ dataset2 = load_dataset("json", data_files=decoding_res)["train"]
+ print(f"Number of examples: {len(dataset2)}")
+ decoding_df = dataset2.to_pandas()
+
+ df_merged = pd.concat([text_src_df, decoding_df], axis=1)
+ print(df_merged.columns.tolist(), df_merged.iloc[4500])
+
+
+ def clean(sentence):
+     # remove all whitespace from a sentence
+     return ''.join(sentence.split())
+
+
+ df = df_merged
+ with open('text_src.txt', 'w', encoding='utf-8') as f:
+     for text in df['text_src']:
+         # cleaned_text = clean(text)
+         f.write(text + '\n')
+
+ # write the trans_res_seg column to hyp.txt
+ with open('hyp.txt', 'w', encoding='utf-8') as f:
+     for text in df['trans_res_seg']:
+         cleaned_text = clean(text)
+         f.write(cleaned_text + '\n')
+
+ # write the gt_seg column to ref.txt
+ with open('ref.txt', 'w', encoding='utf-8') as f:
+     for text in df['gt_seg']:
+         cleaned_text = clean(text)
+         f.write(cleaned_text + '\n')
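
These three files line up source, hypothesis, and reference one sentence per line, which is the layout COMET scoring expects. A minimal sketch of how they might be scored, assuming the unbabel-comet package with its `download_model` / `load_from_checkpoint` API and the `Unbabel/wmt22-comet-da` checkpoint (all of this is an assumption; none of it is in the commit):

# Illustrative sketch, not part of the commit; assumes unbabel-comet is installed.
from comet import download_model, load_from_checkpoint

def read_lines(path):
    with open(path, encoding="utf-8") as f:
        return [line.rstrip("\n") for line in f]

srcs = read_lines("text_src.txt")
hyps = read_lines("hyp.txt")
refs = read_lines("ref.txt")
data = [{"src": s, "mt": h, "ref": r} for s, h, r in zip(srcs, hyps, refs)]

model = load_from_checkpoint(download_model("Unbabel/wmt22-comet-da"))
output = model.predict(data, batch_size=8, gpus=1)  # set gpus=0 to run on CPU
print(output.system_score)  # corpus-level COMET score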
make_jsonl.py ADDED
@@ -0,0 +1,98 @@
+ import json
+ import csv
+ import os
+ from tqdm import tqdm  # tqdm for showing a progress bar
+ import re
+ import cv2  # cv2.imread is used below (utils also imports cv2)
+ from utils import *
+ import traceback
+
+
+ def process_json_files(csv_path, output_dir):
+     # create the output directory
+     os.makedirs(output_dir, exist_ok=True)
+     json_file = open(os.path.join(output_dir, 'output1.jsonl'),
+                      'w',
+                      encoding='utf-8')
+     try:
+         # read the CSV file
+         with open(csv_path, 'r', encoding='utf-8') as csv_file:
+             csv_reader = csv.reader(csv_file)
+             next(csv_reader)  # skip the header row
+
+             # wrap csv_reader with tqdm to display a progress bar
+             for row in tqdm(csv_reader,
+                             desc="Processing JSON files",
+                             unit="file"):
+                 json_path = row[0]  # path to the JSON file
+                 # print('row', row)
+                 try:
+                     # read the JSON file
+                     with open(json_path, 'r', encoding='utf-8') as f:
+                         json_data = json.load(f)
+                     img_path = row[1]
+                     shape = cv2.imread(img_path).shape
+                     # element -> tuple: (word_text, word_bbox, normed_word_bbox)
+                     # resize_box(text_['src_word_bboxes'][i], shape)
+                     doc_triplet = []
+                     doc_tgt_sen_trans = []
+                     doc_words_boxes_list = []
+                     # process the JSON data
+                     for key, value in json_data.items():
+                         if value.get("attribute") == 'text_block':
+                             for text_ in value.get('text', []):
+                                 combined_list = [(
+                                     text_['src_words'][i],
+                                     text_['src_word_bboxes'][i],
+                                 ) for i in range(len(text_['src_words']))]
+                                 doc_words_boxes_list.extend(combined_list)
+                                 # print(f'combined_list:{combined_list}')
+                                 doc_tgt_sen_trans.append(
+                                     text_['tgt_text.zh-CN'])
+                     processed_list = [
+                         (src_w, src_w_boxes, resize_box(src_w_boxes, shape))
+                         for (src_w, src_w_boxes) in doc_words_boxes_list
+                     ]
+                     # print(f'processed:{processed_list}')
+                     sorted_tuple_list = tblr_reading_order_detector(
+                         processed_list)
+
+                     text_src_list = [atuple[0] for atuple in sorted_tuple_list]
+                     layout_src_list = [
+                         atuple[2] for atuple in sorted_tuple_list
+                     ]
+                     text_src = ' '.join(text_src_list)
+                     tgt_sen_trans = ''.join(doc_tgt_sen_trans)
+                     # print('text_src', text_src)
+                     data_dict = {
+                         "img_path": img_path,
+                         "text_src": text_src,
+                         "layout_src": layout_src_list,
+                         "tgt_sen_trans": tgt_sen_trans
+                     }
+                     # print(data_dict)
+                     json_line = json.dumps(data_dict, ensure_ascii=False)
+                     json_file.write(json_line + '\n')
+
+                 except FileNotFoundError:
+                     print(f"File not found: {json_path}")
+                 except json.JSONDecodeError:
+                     print(f"Error decoding JSON in file: {json_path}")
+                 except KeyError as e:
+                     print(f"Missing key {e} in file: {json_path}")
+                 except Exception as e:
+                     print(f"Unexpected error processing {json_path}: {str(e)}")
+                     traceback.print_exc()
+
+     except FileNotFoundError:
+         print(f"CSV file not found: {csv_path}")
+     except Exception as e:
+         print(f"Error reading CSV file: {str(e)}")
+     finally:
+         json_file.close()  # close the output JSONL file
+
+     print("Processing completed!")
+
+
+ # csv_path = '/home/zychen/hwproject/my_modeling_phase_1/dataset/output_part2.csv'  # replace with your CSV file path
+ csv_path = '/home/zychen/hwproject/my_modeling_phase_1/dataset/output.csv'  # replace with your CSV file path
+ output_dir = '/home/zychen/hwproject/my_modeling_phase_1/dataset'  # output directory name
+
+ process_json_files(csv_path, output_dir)
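
Each output1.jsonl record is expected to keep `text_src` tokens and `layout_src` boxes aligned one-to-one, which is exactly what `MyDataset` later asserts. A quick check over the generated file might look like this (illustrative sketch, not in the commit):

# Illustrative sketch, not part of the commit: verify word/box alignment in the output.
import json

with open('/home/zychen/hwproject/my_modeling_phase_1/dataset/output1.jsonl', encoding='utf-8') as f:
    for i, line in enumerate(f):
        rec = json.loads(line)
        assert len(rec["text_src"].split(" ")) == len(rec["layout_src"]), f"misaligned record {i}"
print("all records aligned")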
make_text_src_list.py ADDED
@@ -0,0 +1,43 @@
+ import json
+
+ from model_and_train import MyDataset, prepare_dataset_df, prepare_tokenizer
+ from torch.utils.data import DataLoader, Dataset
+
+ dataset_dir = "/home/zychen/hwproject/my_modeling_phase_1/dataset"
+ data_file = f"{dataset_dir}/testset_10k.jsonl"
+ if __name__ == "__main__":
+
+     encoder_ckpt_dir = "/home/zychen/hwproject/my_modeling_phase_1/Tokenizer_PretrainedWeights/lilt-roberta-en-base"
+
+     tgt_tokenizer_dir = "/home/zychen/hwproject/my_modeling_phase_1/Tokenizer_PretrainedWeights/bert-base-chinese-tokenizer"
+
+     src_tokenizer, tgt_tokenizer = prepare_tokenizer(
+         src_tokenizer_dir=encoder_ckpt_dir,
+         tgt_tokenizer_dir=tgt_tokenizer_dir,
+     )
+     dataset_df = prepare_dataset_df(data_file=data_file)
+     my_dataset = MyDataset(df=dataset_df,
+                            src_tokenizer=src_tokenizer,
+                            tgt_tokenizer=tgt_tokenizer,
+                            max_src_length=512,
+                            max_target_length=512)
+     print(len(my_dataset))
+     from torch.utils.data import Subset
+     num_test = 5000  # total 10k
+     my_dataset = Subset(my_dataset, range(0, num_test))
+     # my_dataloader = DataLoader(
+     #     my_dataset,
+     #     batch_size=batch_size,
+     #     shuffle=False,
+     # )
+     img_name_list = dataset_df["img_path"].iloc[0:num_test].tolist()
+     text_src_list = dataset_df["text_src"].iloc[0:num_test].tolist()
+     with open('./mytest/text_src.jsonl', "w") as decoding_res_file:
+         for img_name, text_src in zip(img_name_list, text_src_list):
+             res_dict = {
+                 "img_name": img_name,
+                 "text_src": text_src,
+             }
+
+             record = f"{json.dumps(res_dict, ensure_ascii=False)}\n"
+             decoding_res_file.write(record)
model_and_train.py ADDED
@@ -0,0 +1,293 @@
+ # basic imports
+ import os
+
+ os.environ["CUDA_VISIBLE_DEVICES"] = "4"
+
+ # other external imports
+ import pandas as pd
+ # torch imports
+ import torch
+ from datasets import load_dataset
+ from torch.utils.data import DataLoader, Dataset
+ # transformers imports
+ from transformers import (BertConfig, BertTokenizer, EncoderDecoderConfig,
+                           EncoderDecoderModel, LayoutLMv3Tokenizer, LiltConfig,
+                           LiltModel, Seq2SeqTrainer, Seq2SeqTrainingArguments,
+                           default_data_collator)
+
+ # internal imports
+
+
+ # prepare tokenizer.
+ def prepare_tokenizer(src_tokenizer_dir, tgt_tokenizer_dir):
+     src_tokenizer = LayoutLMv3Tokenizer.from_pretrained(src_tokenizer_dir)
+     tgt_tokenizer = BertTokenizer.from_pretrained(tgt_tokenizer_dir)
+
+     return src_tokenizer, tgt_tokenizer
+
+
+ # read data points.
+ def prepare_dataset_df(data_file):
+
+     def filter_fn(exam):
+         bboxes = exam["layout_src"]
+         for box in bboxes:
+             x0, y0, x1, y1 = box
+             if (x0 > x1) or (y0 > y1):
+                 print("(x0 > x1) or (y0 > y1)")
+                 return False
+             for cor in box:
+                 if cor < 0 or cor > 1000:
+                     # print("cor < 0 or cor > 1000")
+                     # print(exam['img_path'], box)
+                     return False
+         return True
+
+     dataset = load_dataset("json", data_files=data_file)["train"]
+     print()
+     print(f"Number of examples: {len(dataset)}")
+     print()
+
+     dataset = dataset.filter(filter_fn, num_proc=48)
+
+     dataset_df = dataset.to_pandas()
+     # dataset_df = pd.read_json(data_file, lines=True, orient="records")
+
+     # filter out the NaN data points.
+     dataset_df = dataset_df[~dataset_df["tgt_sen_trans"].isna()]
+     dataset_df = dataset_df[~dataset_df["text_src"].isna()]
+     dataset_df = dataset_df[~dataset_df["layout_src"].isna()]
+     # remove entries where "text_src" length is less than 3
+     dataset_df = dataset_df[dataset_df["text_src"].str.len() >= 3]
+     # rebuild the index to avoid IndexError.
+     dataset_df = dataset_df.reset_index(drop=True)
+
+     print(f"Number of examples after filtering: {len(dataset_df)}")
+     return dataset_df
+
+
+ class MyDataset(Dataset):
+
+     def __init__(
+         self,
+         df,
+         src_tokenizer,
+         tgt_tokenizer,
+         max_src_length,
+         max_target_length,
+     ):
+         self.df = df
+         self.src_tokenizer = src_tokenizer
+         self.tgt_tokenizer = tgt_tokenizer
+         self.max_src_length = max_src_length
+         self.max_target_length = max_target_length
+
+     def __len__(self):
+         return len(self.df)
+
+     def __getitem__(self, idx):
+         # get text_src + layout_src + tgt_trans.
+         text_src = self.df['text_src'][idx]
+         layout_src = self.df['layout_src'][idx]
+         tgt_trans = self.df['tgt_sen_trans'][idx]
+
+         # read in annotations at word level (words, word boxes)
+         words_ = text_src.split(" ")
+         word_boxes_ = layout_src
+         # print('words', words_, len(words_), len(word_boxes_))
+         assert len(words_) == len(word_boxes_)
+         words = []
+         word_boxes = []
+         for word, word_box in zip(words_, word_boxes_):
+             if (word_box[0] >= word_box[2]) or (word_box[1] >= word_box[3]):
+                 continue
+
+             words.append(word)
+             word_boxes.append(word_box)
+
+         assert len(words) == len(word_boxes)
+
+         encoding = self.src_tokenizer(
+             words,
+             boxes=word_boxes,
+             padding="max_length",
+             truncation=True,
+             max_length=self.max_src_length,
+         )
+
+         # construct labels.
+         labels = self.tgt_tokenizer(
+             tgt_trans,
+             padding="max_length",
+             truncation=True,
+             max_length=self.max_target_length)["input_ids"]
+         # important: make sure that PAD tokens are ignored by the loss function
+         labels = [
+             label if label != self.tgt_tokenizer.pad_token_id else -100
+             for label in labels
+         ]
+
+         encoding["labels"] = labels
+
+         assert len(encoding['input_ids']) == self.max_src_length
+         assert len(encoding['attention_mask']) == self.max_src_length
+         assert len(encoding['bbox']) == self.max_src_length
+         assert len(encoding['labels']) == self.max_target_length
+
+         # finally, convert everything to PyTorch tensors
+         for k, v in encoding.items():
+             encoding[k] = torch.as_tensor(v)
+
+         return encoding
+
+
+ def prepare_model(src_tokenizer,
+                   tgt_tokenizer,
+                   max_src_len,
+                   max_tgt_len,
+                   num_encoder_hidden_layers,
+                   num_decoder_hidden_layers,
+                   encoder_ckpt_dir,
+                   model_ckpt_dir=None):
+     config_encoder = LiltConfig.from_pretrained(
+         encoder_ckpt_dir,
+         max_position_embeddings=max_src_len + 2,
+         num_hidden_layers=num_encoder_hidden_layers)
+     config_decoder = BertConfig(vocab_size=tgt_tokenizer.vocab_size,
+                                 max_position_embeddings=max_tgt_len,
+                                 num_hidden_layers=num_decoder_hidden_layers)
+
+     model_config = EncoderDecoderConfig.from_encoder_decoder_configs(
+         encoder_config=config_encoder,
+         decoder_config=config_decoder,
+     )
+     model = EncoderDecoderModel(config=model_config)
+
+     model.config.decoder_start_token_id = tgt_tokenizer.cls_token_id
+     model.config.pad_token_id = tgt_tokenizer.pad_token_id
+     model.config.vocab_size = tgt_tokenizer.vocab_size
+     model.config.eos_token_id = tgt_tokenizer.pad_token_id
+
+     from safetensors.torch import load_file
+     if model_ckpt_dir:
+         bin_path = f"{model_ckpt_dir}/pytorch_model.bin"
+         safetensors_path = f"{model_ckpt_dir}/model.safetensors"
+         if os.path.exists(bin_path):
+             state_dict = torch.load(bin_path)
+         elif os.path.exists(safetensors_path):
+             state_dict = load_file(safetensors_path)
+         else:
+             raise FileNotFoundError(
+                 "Neither pytorch_model.bin nor model.safetensors found in the specified directory."
+             )
+         model.load_state_dict(state_dict, strict=False)
+         model.save_pretrained(
+             f"continued_{model_ckpt_dir}")  # save when continuing training
+     else:
+         # Load the pre-trained params and then save the model, including its configuration.
+         tmp_encoder = LiltModel.from_pretrained(
+             pretrained_model_name_or_path=encoder_ckpt_dir,
+             config=config_encoder,
+         )
+         # tmp_encoder = LiltModel(config=config_encoder)
+         model.encoder = tmp_encoder
+         # model.save_pretrained("undertrained_default_safe_true")
+         model.save_pretrained("undertrained_safe_serialization_False",
+                               safe_serialization=False)
+         # model.load_state_dict(torch.load(f"undertrained/pytorch_model.bin"))
+
+         bin_path = "undertrained_safe_serialization_False/pytorch_model.bin"
+         safetensors_path = "undertrained_default_safe_true/model.safetensors"
+         if os.path.exists(bin_path):
+             state_dict = torch.load(bin_path)
+         elif os.path.exists(safetensors_path):
+             state_dict = load_file(safetensors_path)
+         else:
+             raise FileNotFoundError(
+                 "Neither pytorch_model.bin nor model.safetensors found in the specified directory."
+             )
+         model.load_state_dict(state_dict, strict=False)
+
+     print(model.config)
+     print(model)
+
+     return model
+
+
+ if __name__ == "__main__":
+
+     # hyper-parameters.
+     ## for model.
+     MAX_TGT_LEN = 512
+     MAX_SRC_LEN = 512
+     num_encoder_hidden_layers = 12
+     num_decoder_hidden_layers = 12
+
+     ## for training.
+     num_instances = 500000  # total 620082 in ./dataset/merged.jsonl; 547084 examples after filtering
+     learning_rate = 1e-4
+     batch_size = 28
+     num_train_steps = 400000  # 400000
+     output_dir = f"./train.lr_{learning_rate}.bsz_{batch_size}.step_{num_train_steps}.layer_{num_encoder_hidden_layers}-{num_decoder_hidden_layers}"
+     save_total_limit = 100
+     save_steps = num_train_steps // save_total_limit
+
+     dataset_dir = "/home/zychen/hwproject/my_modeling_phase_1/dataset"
+     data_file = f"{dataset_dir}/merged.jsonl"
+
+     # model_ckpt_dir = '/home/zychen/hwproject/my_modeling_phase_1/train.lr_0.0001.bsz_8.step_400000.layer_12-12/checkpoint-32000'
+     model_ckpt_dir = '/home/zychen/hwproject/my_modeling_phase_1/train.lr_0.0001.bsz_16.step_500000.layer_12-12_36k+20k/checkpoint-20000'
+     encoder_ckpt_dir = "/home/zychen/hwproject/my_modeling_phase_1/Tokenizer_PretrainedWeights/lilt-roberta-en-base"
+
+     tgt_tokenizer_dir = "/home/zychen/hwproject/my_modeling_phase_1/Tokenizer_PretrainedWeights/bert-base-chinese-tokenizer"
+
+     src_tokenizer, tgt_tokenizer = prepare_tokenizer(
+         src_tokenizer_dir=encoder_ckpt_dir,
+         tgt_tokenizer_dir=tgt_tokenizer_dir,
+     )
+     dataset_df = prepare_dataset_df(data_file=data_file)[:num_instances]
+     print(f"\nnum_instances: {len(dataset_df)}\n")
+     print(dataset_df)
+     my_dataset = MyDataset(
+         df=dataset_df,
+         src_tokenizer=src_tokenizer,
+         tgt_tokenizer=tgt_tokenizer,
+         max_src_length=MAX_SRC_LEN,
+         max_target_length=MAX_TGT_LEN,
+     )
+     model = prepare_model(src_tokenizer=src_tokenizer,
+                           tgt_tokenizer=tgt_tokenizer,
+                           max_src_len=MAX_SRC_LEN,
+                           max_tgt_len=MAX_TGT_LEN,
+                           num_encoder_hidden_layers=num_encoder_hidden_layers,
+                           num_decoder_hidden_layers=num_decoder_hidden_layers,
+                           encoder_ckpt_dir=encoder_ckpt_dir,
+                           model_ckpt_dir=model_ckpt_dir)
+
+     training_args = Seq2SeqTrainingArguments(
+         predict_with_generate=False,
+         evaluation_strategy="no",
+         per_device_train_batch_size=batch_size,
+         fp16=True,
+         output_dir=output_dir,
+         logging_steps=1,
+         # save_strategy="epoch",
+         learning_rate=learning_rate,
+         max_steps=num_train_steps,
+         warmup_ratio=0.05,
+         save_total_limit=save_total_limit,
+         save_steps=save_steps,
+         save_safetensors=False,
+     )
+     # print(training_args)
+     # instantiate trainer
+     trainer = Seq2SeqTrainer(
+         model=model,
+         args=training_args,
+         compute_metrics=None,
+         train_dataset=my_dataset,
+         eval_dataset=None,
+         data_collator=default_data_collator,
+     )
+
+     trainer.train()
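
To sanity-check the encoding before launching a long run, one might fetch a single item and confirm the fixed-length tensors the dataset asserts. The snippet below is an illustrative sketch, assuming it is run inside the same `__main__` block right after `my_dataset` is built:

# Illustrative sketch, not part of the commit: inspect one encoded example.
sample = my_dataset[0]
for name in ("input_ids", "attention_mask", "bbox", "labels"):
    print(name, tuple(sample[name].shape))  # expect (512,) for each, (512, 4) for bbox
print("non-pad label tokens:", int((sample["labels"] != -100).sum()))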
old_model_and_train.py ADDED
@@ -0,0 +1,266 @@
+ # basic imports
+ import os
+
+ os.environ["CUDA_VISIBLE_DEVICES"] = "0"
+
+ # transformers imports
+ from transformers import LiltConfig, BertConfig, EncoderDecoderConfig, EncoderDecoderModel, BertTokenizer, LayoutLMv3Tokenizer, LiltModel
+ from transformers import Seq2SeqTrainer, Seq2SeqTrainingArguments
+ from transformers import default_data_collator
+ from datasets import load_dataset
+
+ # torch imports
+ import torch
+ from torch.utils.data import Dataset, DataLoader
+
+ # internal imports
+
+ # other external imports
+ import pandas as pd
+
+
+ # prepare tokenizer.
+ def prepare_tokenizer(src_tokenizer_dir, tgt_tokenizer_dir):
+     src_tokenizer = LayoutLMv3Tokenizer.from_pretrained(src_tokenizer_dir)
+     tgt_tokenizer = BertTokenizer.from_pretrained(tgt_tokenizer_dir)
+
+     return src_tokenizer, tgt_tokenizer
+
+
+ # read data points.
+ def prepare_dataset_df(data_file):
+
+     def filter_fn(exam):
+         bboxes = exam['block_list']
+         for box in bboxes:
+             x0, y0, x1, y1 = box["block_bbox"]
+             if (x0 > x1) or (y0 > y1):
+                 print(box["block_bbox"])
+                 return False
+             for cor in box["block_bbox"]:
+                 # if cor < 0 or cor > 1000:
+                 if cor < 0:
+                     return False
+         return True
+
+     dataset = load_dataset("json", data_files=data_file)["train"]
+     print()
+     print(f"Number of examples: {len(dataset)}")
+     print()
+     # print(dataset[0]['block_list'])
+     dataset = dataset.filter(filter_fn, num_proc=48)
+
+     dataset_df = dataset.to_pandas()
+     # dataset_df = pd.read_json(data_file, lines=True, orient="records")
+
+     # filter out the NaN data points.
+     # dataset_df = dataset_df[~dataset_df["tgt_sen_trans"].isna()]
+     # dataset_df = dataset_df[~dataset_df["text_src"].isna()]
+     # dataset_df = dataset_df[~dataset_df["layout_src"].isna()]
+
+     # rebuild the index to avoid IndexError.
+     dataset_df = dataset_df.reset_index(drop=True)
+
+     print(f"Number of examples after filtering: {len(dataset_df)}")
+     print(dataset_df)
+     return dataset_df
+
+
+ class MyDataset(Dataset):
+
+     def __init__(
+         self,
+         df,
+         src_tokenizer,
+         tgt_tokenizer,
+         max_src_length,
+         max_target_length,
+     ):
+         self.df = df
+         self.src_tokenizer = src_tokenizer
+         self.tgt_tokenizer = tgt_tokenizer
+         self.max_src_length = max_src_length
+         self.max_target_length = max_target_length
+
+     def __len__(self):
+         return len(self.df)
+
+     def __getitem__(self, idx):
+         # get text_src + layout_src + tgt_trans.
+         text_src = self.df['text_src'][idx]
+         layout_src = self.df['layout_src'][idx]
+         tgt_trans = self.df['tgt_sen_trans'][idx]
+
+         # read in annotations at word level (words, word boxes)
+         words_ = text_src.split(" ")
+         word_boxes_ = layout_src
+         assert len(words_) == len(word_boxes_)
+         words = []
+         word_boxes = []
+         for word, word_box in zip(words_, word_boxes_):
+             if (word_box[0] >= word_box[2]) or (word_box[1] >= word_box[3]):
+                 continue
+
+             words.append(word)
+             word_boxes.append(word_box)
+
+         assert len(words) == len(word_boxes)
+
+         encoding = self.src_tokenizer(
+             words,
+             boxes=word_boxes,
+             padding="max_length",
+             truncation=True,
+             max_length=self.max_src_length,
+         )
+
+         # construct labels.
+         labels = self.tgt_tokenizer(
+             tgt_trans,
+             padding="max_length",
+             truncation=True,
+             max_length=self.max_target_length)["input_ids"]
+         # important: make sure that PAD tokens are ignored by the loss function
+         labels = [
+             label if label != self.tgt_tokenizer.pad_token_id else -100
+             for label in labels
+         ]
+
+         encoding["labels"] = labels
+
+         assert len(encoding['input_ids']) == self.max_src_length
+         assert len(encoding['attention_mask']) == self.max_src_length
+         assert len(encoding['bbox']) == self.max_src_length
+         assert len(encoding['labels']) == self.max_target_length
+
+         # finally, convert everything to PyTorch tensors
+         for k, v in encoding.items():
+             encoding[k] = torch.as_tensor(v)
+
+         return encoding
+
+
+ def prepare_model(src_tokenizer,
+                   tgt_tokenizer,
+                   max_src_len,
+                   max_tgt_len,
+                   num_encoder_hidden_layers,
+                   num_decoder_hidden_layers,
+                   encoder_ckpt_dir,
+                   model_ckpt_dir=None):
+     config_encoder = LiltConfig.from_pretrained(
+         encoder_ckpt_dir,
+         max_position_embeddings=max_src_len + 2,
+         num_hidden_layers=num_encoder_hidden_layers)
+     config_decoder = BertConfig(vocab_size=tgt_tokenizer.vocab_size,
+                                 max_position_embeddings=max_tgt_len,
+                                 num_hidden_layers=num_decoder_hidden_layers)
+
+     model_config = EncoderDecoderConfig.from_encoder_decoder_configs(
+         encoder_config=config_encoder,
+         decoder_config=config_decoder,
+     )
+     model = EncoderDecoderModel(config=model_config)
+
+     model.config.decoder_start_token_id = tgt_tokenizer.cls_token_id
+     model.config.pad_token_id = tgt_tokenizer.pad_token_id
+     model.config.vocab_size = tgt_tokenizer.vocab_size
+     model.config.eos_token_id = tgt_tokenizer.pad_token_id
+
+     if model_ckpt_dir:
+         model.load_state_dict(
+             torch.load(f"{model_ckpt_dir}/pytorch_model.bin"))
+     else:
+         # Load the pre-trained params and then save the model, including its configuration.
+         tmp_encoder = LiltModel.from_pretrained(
+             pretrained_model_name_or_path=encoder_ckpt_dir,
+             config=config_encoder,
+         )
+         # tmp_encoder = LiltModel(config=config_encoder)
+         model.encoder = tmp_encoder
+         model.save_pretrained("undertrained")
+         model.load_state_dict(torch.load("undertrained/pytorch_model.bin"))
+
+     print(model.config)
+     print(model)
+
+     return model
+
+
+ if __name__ == "__main__":
+
+     # hyper-parameters.
+     ## for model.
+     MAX_TGT_LEN = 512
+     MAX_SRC_LEN = 512
+     num_encoder_hidden_layers = 12
+     num_decoder_hidden_layers = 12
+
+     ## for training.
+     # wc 12420 ./dataset/scene_imgs/jsons/en_json/en_scene.jsonl
+     # wc 12230 ./dataset/scene_imgs/jsons/zh_json/zh_scene.jsonl
+     num_instances = 500000
+     learning_rate = 1e-4
+     batch_size = 16
+     num_train_steps = 400000
+     output_dir = f"./train.lr_{learning_rate}.bsz_{batch_size}.step_{num_train_steps}.layer_{num_encoder_hidden_layers}-{num_decoder_hidden_layers}"
+     save_total_limit = 100
+     save_steps = num_train_steps // save_total_limit
+
+     # dataset_dir = "/home/zychen/hwproject/my_modeling_phase_1/dataset/scene_imgs/jsons/en_json/en_scene.jsonl"
+     data_file = "/home/zychen/hwproject/my_modeling_phase_1/dataset/scene_imgs/jsons/en_json/en_scene.jsonl"
+
+     model_ckpt_dir = None
+
+     encoder_ckpt_dir = "./Tokenizer_PretrainedWeights/lilt-roberta-en-base"
+
+     tgt_tokenizer_dir = "./Tokenizer_PretrainedWeights/bert-base-chinese-tokenizer"
+
+     src_tokenizer, tgt_tokenizer = prepare_tokenizer(
+         src_tokenizer_dir=encoder_ckpt_dir,
+         tgt_tokenizer_dir=tgt_tokenizer_dir,
+     )
+     dataset_df = prepare_dataset_df(data_file=data_file)[:num_instances]
+     print(f"\nnum_instances: {len(dataset_df)}\n")
+     my_dataset = MyDataset(
+         df=dataset_df,
+         src_tokenizer=src_tokenizer,
+         tgt_tokenizer=tgt_tokenizer,
+         max_src_length=MAX_SRC_LEN,
+         max_target_length=MAX_TGT_LEN,
+     )
+     model = prepare_model(src_tokenizer=src_tokenizer,
+                           tgt_tokenizer=tgt_tokenizer,
+                           max_src_len=MAX_SRC_LEN,
+                           max_tgt_len=MAX_TGT_LEN,
+                           num_encoder_hidden_layers=num_encoder_hidden_layers,
+                           num_decoder_hidden_layers=num_decoder_hidden_layers,
+                           encoder_ckpt_dir=encoder_ckpt_dir,
+                           model_ckpt_dir=model_ckpt_dir)
+
+     training_args = Seq2SeqTrainingArguments(
+         predict_with_generate=False,
+         evaluation_strategy="no",
+         per_device_train_batch_size=batch_size,
+         fp16=True,
+         output_dir=output_dir,
+         logging_steps=1,
+         # save_strategy="epoch",
+         learning_rate=learning_rate,
+         max_steps=num_train_steps,
+         warmup_ratio=0.05,
+         save_total_limit=save_total_limit,
+         save_steps=save_steps,
+     )
+
+     # instantiate trainer
+     trainer = Seq2SeqTrainer(
+         model=model,
+         args=training_args,
+         compute_metrics=None,
+         train_dataset=my_dataset,
+         eval_dataset=None,
+         data_collator=default_data_collator,
+     )
+
+     trainer.train()
sample_generate.py ADDED
@@ -0,0 +1,82 @@
+ # basic imports
+ import os
+
+ os.environ["CUDA_VISIBLE_DEVICES"] = "0"
+
+ # other external imports
+ import pandas as pd
+ # torch imports
+ import torch
+ from datasets import load_dataset
+ from torch.utils.data import DataLoader, Dataset
+ # transformers imports
+ from transformers import (BertConfig, BertTokenizer, EncoderDecoderConfig,
+                           EncoderDecoderModel, LayoutLMv3Tokenizer, LiltConfig,
+                           LiltModel, Seq2SeqTrainer, Seq2SeqTrainingArguments,
+                           default_data_collator)
+
+ # internal imports
+
+
+ def prepare_tokenizer(src_tokenizer_dir, tgt_tokenizer_dir):
+     src_tokenizer = LayoutLMv3Tokenizer.from_pretrained(src_tokenizer_dir)
+     tgt_tokenizer = BertTokenizer.from_pretrained(tgt_tokenizer_dir)
+
+     return src_tokenizer, tgt_tokenizer
+
+
+ if __name__ == "__main__":
+     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+     device = 'cpu'
+     print(device)
+     checkpoints_dir = '/home/zychen/hwproject/my_modeling_phase_1/train.lr_0.0001.bsz_8.step_400000.layer_12-12_36000'
+     model = EncoderDecoderModel.from_pretrained(
+         f"{checkpoints_dir}/checkpoint-36000").to(device)
+     encoder_ckpt_dir = "/home/zychen/hwproject/my_modeling_phase_1/Tokenizer_PretrainedWeights/lilt-roberta-en-base"
+     tgt_tokenizer_dir = "/home/zychen/hwproject/my_modeling_phase_1/Tokenizer_PretrainedWeights/bert-base-chinese-tokenizer"
+
+     src_tokenizer, tgt_tokenizer = prepare_tokenizer(
+         src_tokenizer_dir=encoder_ckpt_dir,
+         tgt_tokenizer_dir=tgt_tokenizer_dir,
+     )
+     model.eval()
+
+     from model_and_train import (MyDataset, prepare_dataset_df,
+                                  prepare_tokenizer)
+
+     dataset_dir = "/home/zychen/hwproject/my_modeling_phase_1/dataset"
+     data_file = f"{dataset_dir}/merged.jsonl"
+     dataset_df = prepare_dataset_df(data_file=data_file)[:1000]
+     print(f"\nnum_instances: {len(dataset_df)}\n")
+     print(dataset_df)
+     my_dataset = MyDataset(
+         df=dataset_df,
+         src_tokenizer=src_tokenizer,
+         tgt_tokenizer=tgt_tokenizer,
+         max_src_length=512,
+         max_target_length=512,
+     )
+     sample = my_dataset[0]
+     from transformers import GenerationConfig
+     generation_config = GenerationConfig(
+         max_length=512,
+         early_stopping=True,
+         num_beams=1,
+         use_cache=True,
+         length_penalty=1.0,
+     )
+
+     with torch.no_grad():
+         generation_config = None  # discards the config above; generate() falls back to the model defaults
+         outputs = model.generate(
+             input_ids=sample['input_ids'].unsqueeze(
+                 0),  # unsqueeze to add a batch dimension
+             attention_mask=sample['attention_mask'].unsqueeze(0),
+             do_sample=False,
+             generation_config=generation_config,
+             bos_token_id=0)
+         decoded_preds = tgt_tokenizer.batch_decode(outputs,
+                                                    skip_special_tokens=True)
+         print(decoded_preds)
+         print(sample['labels'])
@@ -0,0 +1,100 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # basic imports
2
+ import os
3
+
4
+ os.environ["CUDA_VISIBLE_DEVICES"] = "0"
5
+
6
+ # other external imports
7
+ import pandas as pd
8
+ import sacrebleu
9
+ # torch imports
10
+ import torch
11
+ from datasets import load_dataset
12
+ from torch.utils.data import DataLoader, Dataset
13
+ # transformers imports
14
+ from tqdm import tqdm
15
+ from transformers import (BertConfig, BertTokenizer, EncoderDecoderConfig,
16
+ EncoderDecoderModel, LayoutLMv3Tokenizer, LiltConfig,
17
+ LiltModel, Seq2SeqTrainer, Seq2SeqTrainingArguments,
18
+ default_data_collator)
19
+
20
+ # internal imports
21
+
22
+
23
+
24
+ def prepare_tokenizer(src_tokenizer_dir, tgt_tokenizer_dir):
25
+ src_tokenizer = LayoutLMv3Tokenizer.from_pretrained(src_tokenizer_dir)
26
+ tgt_tokenizer = BertTokenizer.from_pretrained(tgt_tokenizer_dir)
27
+
28
+ return src_tokenizer, tgt_tokenizer
29
+
30
+
31
+ def prepare_dataset_df(data_file):
32
+ dataset_df = pd.read_json(data_file, lines=True)
33
+ return dataset_df
34
+
35
+
36
+ if __name__ == "__main__":
37
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
38
+ device = 'cpu'
39
+ print(device)
40
+ checkpoints_dir = '/home/zychen/hwproject/my_modeling_phase_1/train.lr_0.0001.bsz_28.step_400000.layer_12-12'
41
+ model = EncoderDecoderModel.from_pretrained(
42
+ f"{checkpoints_dir}/checkpoint-64000").to(device)
43
+ encoder_ckpt_dir = "/home/zychen/hwproject/my_modeling_phase_1/Tokenizer_PretrainedWeights/lilt-roberta-en-base"
44
+ tgt_tokenizer_dir = "/home/zychen/hwproject/my_modeling_phase_1/Tokenizer_PretrainedWeights/bert-base-chinese-tokenizer"
45
+
46
+ src_tokenizer, tgt_tokenizer = prepare_tokenizer(
47
+ src_tokenizer_dir=encoder_ckpt_dir,
48
+ tgt_tokenizer_dir=tgt_tokenizer_dir,
49
+ )
50
+ model.eval()
51
+
52
+ dataset_dir = "/home/zychen/hwproject/my_modeling_phase_1/dataset"
53
+ data_file = f"{dataset_dir}/merged.jsonl"
54
+ dataset_df = prepare_dataset_df(data_file=data_file)[:5000]
55
+ print(f"\nnum_instances: {len(dataset_df)}\n")
56
+ from model_and_train import (MyDataset, prepare_dataset_df,
57
+ prepare_tokenizer)
58
+
59
+ my_dataset = MyDataset(
60
+ df=dataset_df,
61
+ src_tokenizer=src_tokenizer,
62
+ tgt_tokenizer=tgt_tokenizer,
63
+ max_src_length=512,
64
+ max_target_length=512,
65
+ )
66
+
67
+ dataloader = DataLoader(my_dataset, batch_size=4, shuffle=False)
68
+
69
+ references = []
70
+ predictions = []
71
+
72
+ with torch.no_grad():
73
+ for batch in tqdm(dataloader):
74
+ input_ids = batch['input_ids'].to(device)
75
+ attention_mask = batch['attention_mask'].to(device)
76
+ labels = batch['labels'].tolist()
77
+ outputs = model.generate(input_ids=input_ids,
78
+ attention_mask=attention_mask,
79
+ do_sample=True,
80
+ max_length=512,
81
+ num_beams=1,
82
+ use_cache=True,
83
+ length_penalty=1.0,
84
+ bos_token_id=0)
85
+
86
+ decoded_preds = tgt_tokenizer.batch_decode(
87
+ outputs, skip_special_tokens=True)
88
+ decoded_labels = tgt_tokenizer.batch_decode(
89
+ labels, skip_special_tokens=True)
90
+
91
+ predictions.extend(decoded_preds)
92
+ references.extend([label.split(' ') for label in decoded_labels])
93
+
94
+ predictions_str = ''.join(predictions)
95
+ references_str = ''.join([''.join(ref) for ref in references])
96
+
97
+ print(predictions_str, references_str)
98
+
99
+ bleu_score = sacrebleu.corpus_bleu(predictions, [references])
100
+ print(f"BLEU score: {bleu_score.score}")
test_bleu_chrf.py ADDED
@@ -0,0 +1,164 @@
+ # basic imports
+ import json
+ import os
+
+ import jieba
+ # other external imports
+ import pandas as pd
+ # torch imports
+ import torch
+ # internal imports
+ from model_and_train import MyDataset, prepare_dataset_df, prepare_tokenizer
+ from nltk.translate.bleu_score import sentence_bleu
+ from sacrebleu.metrics import CHRF
+ from torch.utils.data import DataLoader
+ from tqdm import tqdm
+ # transformers imports
+ from transformers import BertTokenizer, EncoderDecoderModel
+
+ chrf = CHRF(word_order=2)  # word_order=2 makes this chrF++.
+
+ os.environ["CUDA_VISIBLE_DEVICES"] = "5"
+
+ # hyper-parameters.
+ ## for model.
+ MAX_TGT_LEN = 512
+ MAX_SRC_LEN = 512
+
+ ## for decoding.
+ output_dir = "./mytest"
+ os.makedirs(output_dir, exist_ok=True)
+ early_stopping = True
+ num_beams = 2
+ length_penalty = 1.0
+ batch_size = 16
+ metric_res_filepath = os.path.join(output_dir, "metric_res.json")
+ decoding_res_filepath = os.path.join(output_dir, "decoding_res.json")
+ trained_model_dir = "/home/zychen/hwproject/my_modeling_phase_1/train.lr_0.0001.bsz_28.step_400000.layer_12-12/checkpoint-64000"
+
+ dataset_dir = "/home/zychen/hwproject/my_modeling_phase_1/dataset"
+ data_file = f"{dataset_dir}/testset_10k.jsonl"
+
+
+ def no_blank(sen):
+     return "".join(sen.split())
+
+
+ if __name__ == "__main__":
+
+     encoder_ckpt_dir = "/home/zychen/hwproject/my_modeling_phase_1/Tokenizer_PretrainedWeights/lilt-roberta-en-base"
+
+     tgt_tokenizer_dir = "/home/zychen/hwproject/my_modeling_phase_1/Tokenizer_PretrainedWeights/bert-base-chinese-tokenizer"
+
+     src_tokenizer, tgt_tokenizer = prepare_tokenizer(
+         src_tokenizer_dir=encoder_ckpt_dir,
+         tgt_tokenizer_dir=tgt_tokenizer_dir,
+     )
+     dataset_df = prepare_dataset_df(data_file=data_file)
+     my_dataset = MyDataset(df=dataset_df,
+                            src_tokenizer=src_tokenizer,
+                            tgt_tokenizer=tgt_tokenizer,
+                            max_src_length=512,
+                            max_target_length=512)
+     print(len(my_dataset))
+     from torch.utils.data import Subset
+     num_test = 5000  # total 10k
+     my_dataset = Subset(my_dataset, range(0, num_test))
+     my_dataloader = DataLoader(
+         my_dataset,
+         batch_size=batch_size,
+         shuffle=False,
+     )
+
+     # load the model and config from the pretrained folder
+     model = EncoderDecoderModel.from_pretrained(trained_model_dir)
+     # device = 'cpu'
+     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+     model.to(device)
+     model.eval()
+
+     print(model)
+
+     # decode the test set
+     pred_res_list = []
+     gt_list = []
+
+     for batch in tqdm(my_dataloader):
+         # predict using generate
+         with torch.no_grad():
+             encoder_outputs = model.encoder(
+                 input_ids=batch["input_ids"].to(device),
+                 bbox=batch["bbox"].to(device),
+                 attention_mask=batch["attention_mask"].to(device),
+             )
+             outputs = model.generate(
+                 input_ids=batch["input_ids"].to(device),
+                 attention_mask=batch["attention_mask"].to(device),
+                 encoder_outputs=encoder_outputs,
+                 max_length=MAX_TGT_LEN,
+                 early_stopping=early_stopping,
+                 num_beams=num_beams,
+                 length_penalty=length_penalty,
+                 use_cache=True,
+                 decoder_start_token_id=0)
+
+         # decode
+         pred_str = tgt_tokenizer.batch_decode(outputs,
+                                               skip_special_tokens=True)
+         labels = batch["labels"]
+         labels[labels == -100] = tgt_tokenizer.pad_token_id
+         label_str = tgt_tokenizer.batch_decode(labels,
+                                                skip_special_tokens=True)
+
+         pred_res_list += pred_str
+         gt_list += label_str
+
+     gt_list = [no_blank(sen) for sen in gt_list]
+     pred_res_list = [no_blank(sen) for sen in pred_res_list]
+
+     # write the decoding results and compute the metrics.
+     img_name_list = dataset_df["img_path"].iloc[0:num_test].tolist()
+     text_src_list = dataset_df["text_src"].iloc[0:num_test].tolist()
+     bleu_list = []
+     chrf_list = []
+
+     pred_res_seg_list = [" ".join(jieba.cut(item)) for item in pred_res_list]
+     gt_seg_list = [" ".join(jieba.cut(item)) for item in gt_list]
+     print(len(text_src_list), len(pred_res_seg_list), len(gt_seg_list))
+     # print(img_name_list, pred_res_list, gt_seg_list)
+     assert len(img_name_list) == len(pred_res_seg_list) == len(gt_seg_list)
+
+     with open(decoding_res_filepath, "w") as decoding_res_file:
+         for img_name, text_src, pred_res_seg, gt_seg in zip(
+                 img_name_list, text_src_list, pred_res_seg_list, gt_seg_list):
+
+             instance_bleu = sentence_bleu([gt_seg.split()],
+                                           pred_res_seg.split())
+             bleu_list.append(instance_bleu)
+
+             instance_chrf = chrf.sentence_score(
+                 hypothesis=pred_res_seg,
+                 references=[gt_seg],
+             ).score
+             chrf_list.append(instance_chrf)
+
+             res_dict = {
+                 "img_name": img_name,
+                 "text_src": text_src,
+                 "instance_bleu": instance_bleu,
+                 "instance_chrf": instance_chrf,
+                 "trans_res_seg": pred_res_seg,
+                 "gt_seg": gt_seg,
+             }
+
+             record = f"{json.dumps(res_dict, ensure_ascii=False)}\n"
+             decoding_res_file.write(record)
+
+     trans_avg_bleu = sum(bleu_list) / len(bleu_list)
+     trans_avg_chrf = sum(chrf_list) / len(chrf_list)
+     with open(metric_res_filepath, "w") as metric_res_file:
+         eval_res_dict = {
+             "trans_avg_bleu": trans_avg_bleu,
+             "trans_avg_chrf": trans_avg_chrf,
+         }
+         json.dump(eval_res_dict, metric_res_file, indent=4, ensure_ascii=False)
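
The script reports averaged sentence-level BLEU and chrF++; if corpus-level figures are also wanted, a short addition over the same jieba-segmented lists might look like this (illustrative sketch, not in the commit; sacrebleu takes one hypothesis list and one list of reference streams):

# Illustrative sketch, not part of the commit: corpus-level scores over the same lists.
import sacrebleu

corpus_bleu = sacrebleu.corpus_bleu(pred_res_seg_list, [gt_seg_list])
corpus_chrf = chrf.corpus_score(pred_res_seg_list, [gt_seg_list])
print(f"corpus BLEU: {corpus_bleu.score:.2f}, corpus chrF++: {corpus_chrf.score:.2f}")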
utils.py ADDED
@@ -0,0 +1,56 @@
+ import numpy as np
+ import cv2
+ import os
+ import json
+ from functools import cmp_to_key
+
+ def resize_box(box, ori_img_shape, nomarlized_img_shape=(1000, 1000, 3)):
+     """
+     box: [x0, y0, x1, y1],
+     ori_img_shape: e.g. (1560, 1103, 3)
+     """
+
+     height_ratio = nomarlized_img_shape[0] / ori_img_shape[0]
+     width_ratio = nomarlized_img_shape[1] / ori_img_shape[1]
+
+     x0, y0, x1, y1 = box
+     norm_x0, norm_x1 = round(x0 * width_ratio), round(x1 * width_ratio)
+     norm_y0, norm_y1 = round(y0 * height_ratio), round(y1 * height_ratio)
+
+     return [norm_x0, norm_y0, norm_x1, norm_y1]
+
+ def tblr_reading_order_detector(tuple_list):
+     """rule: top-to-bottom, left-to-right
+
+     tuple: (word_text, word_bbox, normed_word_bbox)
+
+     return: sorted_tuple_list
+     """
+
+     def sort_cmp_fn(word_box1, word_box2):
+         """
+         Comparison rule for the two elements being ordered by sorted():
+         1. Compare the y coordinates of box1 and box2; if their vertical overlap reaches 50% of both, they are on the same line, otherwise they are on different lines.
+         2. If they are on the same line, compare x0: if box1_x0 < box2_x0 return -1 (box1 before box2), otherwise return 0 (equal) or 1 (box1 after box2).
+         3. If they are on different lines, compare y0: if box1_y0 < box2_y0 return -1, otherwise return 0 or 1.
+         (The implementation below simply compares y0 first, then x0.)
+         """
+
+         x0, y0, x1, y1 = word_box1[1]
+         x0_, y0_, x1_, y1_ = word_box2[1]
+
+         if y0 < y0_:
+             return -1
+         elif y0 > y0_:
+             return 1
+         elif y0 == y0_:
+             if x0 <= x0_:
+                 return -1
+             elif x0 > x0_:
+                 return 1
+
+     sorted_tuple_list = sorted(tuple_list, key=cmp_to_key(sort_cmp_fn))
+     # print(sorted_word_box_list)
+
+     return sorted_tuple_list
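
A small usage sketch of the two helpers with made-up boxes, showing the top-to-bottom, left-to-right ordering and the 0-1000 normalization (illustrative only, not part of the commit):

# Illustrative sketch, not part of the commit.
from utils import resize_box, tblr_reading_order_detector

shape = (800, 600, 3)  # hypothetical original image shape (H, W, C)
words = [("world", [300, 100, 360, 130]),
         ("Hello", [100, 100, 160, 130]),
         ("again", [100, 300, 170, 330])]
triplets = [(w, box, resize_box(box, shape)) for w, box in words]
ordered = tblr_reading_order_detector(triplets)
print([t[0] for t in ordered])  # ['Hello', 'world', 'again'] -- sorted by y0, then x0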