| |
| import json |
| import os |
|
|
| import jieba |
| |
| import pandas as pd |
| |
| import torch |
| |
| from model_and_train import MyDataset, prepare_dataset_df, prepare_tokenizer |
| from nltk.translate.bleu_score import sentence_bleu |
| from sacrebleu.metrics import CHRF |
| from torch.utils.data import DataLoader |
| from tqdm import tqdm |
| |
| from transformers import BertTokenizer, EncoderDecoderModel |
|
|
| chrf = CHRF(word_order=2) |
|
|
| os.environ["CUDA_VISIBLE_DEVICES"] = "5" |
|
|
| |
| |
| MAX_TGT_LEN = 512 |
| MAX_SRC_LEN = 512 |
|
|
| |
| output_dir = "./mytest" |
| os.makedirs(output_dir, exist_ok=True) |
| early_stopping = True |
| num_beams = 2 |
| length_penalty = 1.0 |
| batch_size = 16 |
| metric_res_filepath = os.path.join(output_dir, "metric_res.json") |
| decoding_res_filepath = os.path.join(output_dir, "decoding_res.json") |
| trained_model_dir = "/home/zychen/hwproject/my_modeling_phase_1/train.lr_0.0001.bsz_28.step_400000.layer_12-12/checkpoint-64000" |
|
|
| dataset_dir = "/home/zychen/hwproject/my_modeling_phase_1/dataset" |
| data_file = f"{dataset_dir}/testset_10k.jsonl" |
|
|
|
|
| def no_blank(sen): |
| return "".join(sen.split()) |
|
|
|
|
| if __name__ == "__main__": |
|
|
| encoder_ckpt_dir = "/home/zychen/hwproject/my_modeling_phase_1/Tokenizer_PretrainedWeights/lilt-roberta-en-base" |
|
|
| tgt_tokenizer_dir = "/home/zychen/hwproject/my_modeling_phase_1/Tokenizer_PretrainedWeights/bert-base-chinese-tokenizer" |
|
|
| src_tokenizer, tgt_tokenizer = prepare_tokenizer( |
| src_tokenizer_dir=encoder_ckpt_dir, |
| tgt_tokenizer_dir=tgt_tokenizer_dir, |
| ) |
| dataset_df = prepare_dataset_df(data_file=data_file) |
| my_dataset = MyDataset(df=dataset_df, |
| src_tokenizer=src_tokenizer, |
| tgt_tokenizer=tgt_tokenizer, |
| max_src_length=512, |
| max_target_length=512) |
| print(len(my_dataset)) |
| from torch.utils.data import Subset |
| num_test = 5000 |
| my_dataset = Subset(my_dataset, range(0, num_test)) |
| my_dataloader = DataLoader( |
| my_dataset, |
| batch_size=batch_size, |
| shuffle=False, |
| ) |
|
|
| |
| model = EncoderDecoderModel.from_pretrained(trained_model_dir) |
| |
| device = torch.device("cuda" if torch.cuda.is_available() else "cpu") |
| model.to(device) |
| model.eval() |
|
|
| print(model) |
|
|
| |
| pred_res_list = [] |
| gt_list = [] |
|
|
| for batch in tqdm(my_dataloader): |
| |
| with torch.no_grad(): |
| encoder_outputs = model.encoder( |
| input_ids=batch["input_ids"].to(device), |
| bbox=batch["bbox"].to(device), |
| attention_mask=batch["attention_mask"].to(device), |
| ) |
| outputs = model.generate( |
| input_ids=batch["input_ids"].to(device), |
| attention_mask=batch["attention_mask"].to(device), |
| encoder_outputs=encoder_outputs, |
| max_length=MAX_TGT_LEN, |
| early_stopping=early_stopping, |
| num_beams=num_beams, |
| length_penalty=length_penalty, |
| use_cache=True, |
| decoder_start_token_id=0) |
|
|
| |
| pred_str = tgt_tokenizer.batch_decode(outputs, |
| skip_special_tokens=True) |
| labels = batch["labels"] |
| labels[labels == -100] = tgt_tokenizer.pad_token_id |
| label_str = tgt_tokenizer.batch_decode(labels, |
| skip_special_tokens=True) |
|
|
| pred_res_list += pred_str |
| gt_list += label_str |
|
|
| gt_list = [no_blank(sen) for sen in gt_list] |
| pred_res_list = [no_blank(sen) for sen in pred_res_list] |
|
|
| |
| img_name_list = dataset_df["img_path"].iloc[0:num_test].tolist() |
| text_src_list = dataset_df["text_src"].iloc[0:num_test].tolist() |
| bleu_list = [] |
| chrf_list = [] |
|
|
| pred_res_seg_list = [" ".join(jieba.cut(item)) for item in pred_res_list] |
| gt_seg_list = [" ".join(jieba.cut(item)) for item in gt_list] |
| print(len(text_src_list), len(pred_res_seg_list), len(gt_seg_list)) |
| |
| assert len(img_name_list) == len(pred_res_seg_list) == len(gt_seg_list) |
|
|
| with open(decoding_res_filepath, "w") as decoding_res_file: |
| for img_name, text_src, pred_res_seg, gt_seg in zip( |
| img_name_list, text_src_list, pred_res_seg_list, gt_seg_list): |
|
|
| instance_bleu = sentence_bleu([gt_seg.split()], |
| pred_res_seg.split()) |
| bleu_list.append(instance_bleu) |
|
|
| instance_chrf = chrf.sentence_score( |
| hypothesis=pred_res_seg, |
| references=[gt_seg], |
| ).score |
| chrf_list.append(instance_chrf) |
|
|
| res_dict = { |
| "img_name": img_name, |
| "text_src": text_src, |
| "instance_bleu": instance_bleu, |
| "instance_chrf": instance_chrf, |
| "trans_res_seg": pred_res_seg, |
| "gt_seg": gt_seg, |
| } |
|
|
| record = f"{json.dumps(res_dict, ensure_ascii=False)}\n" |
| decoding_res_file.write(record) |
|
|
| trans_avg_bleu = sum(bleu_list) / len(bleu_list) |
| trans_avg_chrf = sum(chrf_list) / len(chrf_list) |
| with open(metric_res_filepath, "w") as metric_res_file: |
| eval_res_dict = { |
| "trans_avg_bleu": trans_avg_bleu, |
| "trans_avg_chrf": trans_avg_chrf, |
| } |
| json.dump(eval_res_dict, metric_res_file, indent=4, ensure_ascii=False) |
|
|