Spaces:
Runtime error
Runtime error
Upload 34 files
Browse files- 4_20230227-0026/config.json +26 -0
- 4_20230227-0026/experiments.log +83 -0
- 4_20230227-0026/main_config.cfg +27 -0
- 4_20230227-0026/merges.txt +0 -0
- 4_20230227-0026/pytorch_model.bin +3 -0
- 4_20230227-0026/training_args.bin +3 -0
- 4_20230227-0026/vocab.json +0 -0
- app.py +236 -0
- data_loader.py +233 -0
- flagged/log.csv +2 -0
- flagged/output/tmpginvysx3.json +1 -0
- main.py +545 -0
- main_config.cfg +58 -0
- modeling.py +403 -0
- requirements.txt +2 -0
- run_classifier_dataset_utils.py +669 -0
- scripts/run.sh +2 -0
- scripts/run_bagging.sh +8 -0
- utils/Config.py +128 -0
- utils/Logger.py +84 -0
- utils/ResultTable.py +160 -0
- utils/Statistics.py +29 -0
- utils/Tool.py +17 -0
- utils/__init__.py +4 -0
- utils/__pycache__/Config.cpython-36.pyc +0 -0
- utils/__pycache__/Config.cpython-38.pyc +0 -0
- utils/__pycache__/Logger.cpython-36.pyc +0 -0
- utils/__pycache__/Logger.cpython-38.pyc +0 -0
- utils/__pycache__/ResultTable.cpython-36.pyc +0 -0
- utils/__pycache__/ResultTable.cpython-38.pyc +0 -0
- utils/__pycache__/Tool.cpython-36.pyc +0 -0
- utils/__pycache__/Tool.cpython-38.pyc +0 -0
- utils/__pycache__/__init__.cpython-36.pyc +0 -0
- utils/__pycache__/__init__.cpython-38.pyc +0 -0
4_20230227-0026/config.json
ADDED
|
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"_name_or_path": "roberta-base",
|
| 3 |
+
"architectures": [
|
| 4 |
+
"RobertaForMaskedLM"
|
| 5 |
+
],
|
| 6 |
+
"attention_probs_dropout_prob": 0.1,
|
| 7 |
+
"bos_token_id": 0,
|
| 8 |
+
"eos_token_id": 2,
|
| 9 |
+
"gradient_checkpointing": false,
|
| 10 |
+
"hidden_act": "gelu",
|
| 11 |
+
"hidden_dropout_prob": 0.1,
|
| 12 |
+
"hidden_size": 768,
|
| 13 |
+
"initializer_range": 0.02,
|
| 14 |
+
"intermediate_size": 3072,
|
| 15 |
+
"layer_norm_eps": 1e-05,
|
| 16 |
+
"max_position_embeddings": 514,
|
| 17 |
+
"model_type": "roberta",
|
| 18 |
+
"num_attention_heads": 12,
|
| 19 |
+
"num_hidden_layers": 12,
|
| 20 |
+
"pad_token_id": 1,
|
| 21 |
+
"position_embedding_type": "absolute",
|
| 22 |
+
"transformers_version": "4.2.2",
|
| 23 |
+
"type_vocab_size": 4,
|
| 24 |
+
"use_cache": true,
|
| 25 |
+
"vocab_size": 50265
|
| 26 |
+
}
|
4_20230227-0026/experiments.log
ADDED
|
@@ -0,0 +1,83 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[56]2023-02-27 00:26:50,668: device: cuda n_gpu: 1
|
| 2 |
+
[56]2023-02-27 00:29:19,004: ***** Running training *****
|
| 3 |
+
[56]2023-02-27 00:29:19,004: Batch size = 16
|
| 4 |
+
[56]2023-02-27 00:29:19,004: Num steps = 30030
|
| 5 |
+
[56]2023-02-27 01:18:38,101: [epoch 1] ,lr: 1.5e-05 ,tr_loss: 4808.392509443685
|
| 6 |
+
[56]2023-02-27 01:19:00,946: ***** Running evaluation *****
|
| 7 |
+
[56]2023-02-27 01:21:16,307: acc = 0.8590286538114976
|
| 8 |
+
[56]2023-02-27 01:21:16,308: f1 = 0.4659498207885305
|
| 9 |
+
[56]2023-02-27 01:21:16,308: precision = 0.727224294086308
|
| 10 |
+
[56]2023-02-27 01:21:16,308: recall = 0.3427925665494726
|
| 11 |
+
[56]2023-02-27 02:08:39,574: [epoch 2] ,lr: 3e-05 ,tr_loss: 2660.9157069001812
|
| 12 |
+
[56]2023-02-27 02:09:03,298: ***** Running evaluation *****
|
| 13 |
+
[56]2023-02-27 02:11:18,653: acc = 0.8952513966480447
|
| 14 |
+
[56]2023-02-27 02:11:18,653: f1 = 0.6620148277365896
|
| 15 |
+
[56]2023-02-27 02:11:18,653: precision = 0.7859855022437003
|
| 16 |
+
[56]2023-02-27 02:11:18,653: recall = 0.5718232044198895
|
| 17 |
+
[56]2023-02-27 02:59:03,528: [epoch 3] ,lr: 0.0 ,tr_loss: 1665.7832884637173
|
| 18 |
+
[56]2023-02-27 02:59:27,660: ***** Running evaluation *****
|
| 19 |
+
[56]2023-02-27 03:01:41,367: acc = 0.9039466570553253
|
| 20 |
+
[56]2023-02-27 03:01:41,367: f1 = 0.7135178715399086
|
| 21 |
+
[56]2023-02-27 03:01:41,367: precision = 0.7673410404624278
|
| 22 |
+
[56]2023-02-27 03:01:41,367: recall = 0.666750376695128
|
| 23 |
+
[56]2023-02-27 03:01:43,253: -----Best Result-----
|
| 24 |
+
[56]2023-02-27 03:01:43,253: acc = 0.9039466570553253
|
| 25 |
+
[56]2023-02-27 03:01:43,253: f1 = 0.7135178715399086
|
| 26 |
+
[56]2023-02-27 03:01:43,253: precision = 0.7673410404624278
|
| 27 |
+
[56]2023-02-27 03:01:43,253: recall = 0.666750376695128
|
| 28 |
+
[56]2023-02-27 03:02:04,028: ***** Running evaluation *****
|
| 29 |
+
[56]2023-02-27 03:04:17,297: acc = 0.9039466570553253
|
| 30 |
+
[56]2023-02-27 03:04:17,297: f1 = 0.7135178715399086
|
| 31 |
+
[56]2023-02-27 03:04:17,297: precision = 0.7673410404624278
|
| 32 |
+
[56]2023-02-27 03:04:17,297: recall = 0.666750376695128
|
| 33 |
+
[56]2023-02-27 03:04:17,298: Saved to saves/roberta-base/4_20230227-0026
|
| 34 |
+
[56]2023-02-28 02:41:26,304: device: cpu n_gpu: 0
|
| 35 |
+
[56]2023-02-28 02:44:12,493: device: cpu n_gpu: 0
|
| 36 |
+
[56]2023-02-28 02:50:35,429: device: cpu n_gpu: 0
|
| 37 |
+
[56]2023-02-28 02:50:56,790: ***** Running evaluation *****
|
| 38 |
+
[56]2023-02-28 02:53:34,368: device: cpu n_gpu: 0
|
| 39 |
+
[56]2023-02-28 03:00:12,499: ***** Running evaluation *****
|
| 40 |
+
[56]2023-02-28 03:22:26,949: device: cpu n_gpu: 0
|
| 41 |
+
[56]2023-02-28 03:47:13,134: device: cpu n_gpu: 0
|
| 42 |
+
[56]2023-02-28 03:48:38,484: device: cpu n_gpu: 0
|
| 43 |
+
[56]2023-02-28 03:49:14,589: device: cpu n_gpu: 0
|
| 44 |
+
[56]2023-02-28 04:02:49,943: device: cpu n_gpu: 0
|
| 45 |
+
[56]2023-02-28 04:32:46,799: device: cpu n_gpu: 0
|
| 46 |
+
[56]2023-02-28 04:34:48,647: device: cpu n_gpu: 0
|
| 47 |
+
[56]2023-02-28 04:37:23,090: device: cpu n_gpu: 0
|
| 48 |
+
[56]2023-02-28 04:40:06,170: device: cpu n_gpu: 0
|
| 49 |
+
[56]2023-02-28 04:43:40,692: device: cpu n_gpu: 0
|
| 50 |
+
[56]2023-02-28 04:57:12,843: device: cpu n_gpu: 0
|
| 51 |
+
[56]2023-02-28 05:03:30,653: device: cpu n_gpu: 0
|
| 52 |
+
[56]2023-02-28 05:14:12,276: device: cpu n_gpu: 0
|
| 53 |
+
[56]2023-02-28 05:15:25,604: device: cpu n_gpu: 0
|
| 54 |
+
[56]2023-02-28 05:21:26,346: device: cpu n_gpu: 0
|
| 55 |
+
[56]2023-02-28 05:29:09,123: device: cpu n_gpu: 0
|
| 56 |
+
[56]2023-02-28 05:30:59,283: device: cpu n_gpu: 0
|
| 57 |
+
[56]2023-02-28 05:34:33,463: device: cpu n_gpu: 0
|
| 58 |
+
[56]2023-02-28 05:37:25,436: device: cpu n_gpu: 0
|
| 59 |
+
[56]2023-02-28 05:40:54,252: device: cpu n_gpu: 0
|
| 60 |
+
[56]2023-02-28 05:51:44,583: device: cpu n_gpu: 0
|
| 61 |
+
[56]2023-02-28 05:54:21,953: device: cpu n_gpu: 0
|
| 62 |
+
[56]2023-02-28 06:04:58,550: device: cpu n_gpu: 0
|
| 63 |
+
[56]2023-02-28 06:12:55,019: device: cpu n_gpu: 0
|
| 64 |
+
[56]2023-02-28 06:17:31,790: device: cpu n_gpu: 0
|
| 65 |
+
[56]2023-02-28 06:21:49,848: device: cpu n_gpu: 0
|
| 66 |
+
[56]2023-02-28 06:23:45,894: device: cpu n_gpu: 0
|
| 67 |
+
[56]2023-02-28 06:30:27,960: device: cpu n_gpu: 0
|
| 68 |
+
[56]2023-02-28 06:34:11,145: device: cpu n_gpu: 0
|
| 69 |
+
[56]2023-02-28 06:36:56,962: device: cpu n_gpu: 0
|
| 70 |
+
[56]2023-02-28 06:38:45,488: device: cpu n_gpu: 0
|
| 71 |
+
[56]2023-02-28 06:39:18,822: device: cpu n_gpu: 0
|
| 72 |
+
[56]2023-02-28 06:39:44,789: device: cpu n_gpu: 0
|
| 73 |
+
[56]2023-02-28 06:44:02,812: device: cpu n_gpu: 0
|
| 74 |
+
[56]2023-02-28 06:45:15,008: device: cpu n_gpu: 0
|
| 75 |
+
[56]2023-02-28 06:48:26,234: device: cpu n_gpu: 0
|
| 76 |
+
[56]2023-02-28 06:54:51,113: device: cpu n_gpu: 0
|
| 77 |
+
[56]2023-02-28 07:08:07,149: device: cpu n_gpu: 0
|
| 78 |
+
[56]2023-02-28 07:10:04,991: device: cpu n_gpu: 0
|
| 79 |
+
[56]2023-02-28 07:11:36,409: device: cpu n_gpu: 0
|
| 80 |
+
[56]2023-02-28 07:12:03,892: device: cpu n_gpu: 0
|
| 81 |
+
[56]2023-02-28 07:13:15,960: device: cpu n_gpu: 0
|
| 82 |
+
[56]2023-02-28 07:21:36,476: device: cpu n_gpu: 0
|
| 83 |
+
[56]2023-02-28 07:23:15,164: device: cpu n_gpu: 0
|
4_20230227-0026/main_config.cfg
ADDED
|
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[args]
|
| 2 |
+
bert_model=roberta-base
|
| 3 |
+
data_dir=data/VUA20
|
| 4 |
+
task_name=vua
|
| 5 |
+
model_type=MELBERT
|
| 6 |
+
classifier_hidden=768
|
| 7 |
+
lr_schedule=warmup_linear
|
| 8 |
+
warmup_epoch=2
|
| 9 |
+
drop_ratio=0.2
|
| 10 |
+
kfold=10
|
| 11 |
+
num_bagging=0
|
| 12 |
+
bagging_index=0
|
| 13 |
+
use_pos=True
|
| 14 |
+
use_local_context=True
|
| 15 |
+
max_seq_length=150
|
| 16 |
+
do_train=True
|
| 17 |
+
do_test=True
|
| 18 |
+
do_eval=True
|
| 19 |
+
do_lower_case=False
|
| 20 |
+
class_weight=3
|
| 21 |
+
train_batch_size=16
|
| 22 |
+
eval_batch_size=8
|
| 23 |
+
learning_rate=3e-05
|
| 24 |
+
num_train_epoch=3
|
| 25 |
+
no_cuda=False
|
| 26 |
+
seed=42
|
| 27 |
+
|
4_20230227-0026/merges.txt
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
4_20230227-0026/pytorch_model.bin
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:3027110c50c22427bd23fad97158fbf5c366c8f3dabb92865abcb66df4334adf
|
| 3 |
+
size 508135877
|
4_20230227-0026/training_args.bin
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:044cdea230767bb31596587242f594f00c14006b56bf8277ef59e3086cb00d4a
|
| 3 |
+
size 1339
|
4_20230227-0026/vocab.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
app.py
ADDED
|
@@ -0,0 +1,236 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import sys
|
| 3 |
+
import pickle
|
| 4 |
+
import random
|
| 5 |
+
import copy
|
| 6 |
+
import numpy as np
|
| 7 |
+
import gradio as gr
|
| 8 |
+
import re
|
| 9 |
+
import string
|
| 10 |
+
|
| 11 |
+
import torch
|
| 12 |
+
import torch.nn as nn
|
| 13 |
+
|
| 14 |
+
from tqdm import tqdm, trange
|
| 15 |
+
from collections import OrderedDict
|
| 16 |
+
from transformers import AutoTokenizer, AutoModel, AdamW, get_linear_schedule_with_warmup
|
| 17 |
+
|
| 18 |
+
from utils import Config, Logger, make_log_dir
|
| 19 |
+
from modeling import (
|
| 20 |
+
AutoModelForSequenceClassification,
|
| 21 |
+
AutoModelForTokenClassification,
|
| 22 |
+
AutoModelForSequenceClassification_SPV,
|
| 23 |
+
AutoModelForSequenceClassification_MIP,
|
| 24 |
+
AutoModelForSequenceClassification_SPV_MIP,
|
| 25 |
+
)
|
| 26 |
+
from run_classifier_dataset_utils import processors, output_modes, compute_metrics
|
| 27 |
+
from data_loader import load_train_data, load_train_data_kf, load_test_data, load_sentence_data
|
| 28 |
+
|
| 29 |
+
from frame_semantic_transformer import FrameSemanticTransformer
|
| 30 |
+
|
| 31 |
+
frame_transformer = FrameSemanticTransformer()
|
| 32 |
+
frame_transformer.setup()
|
| 33 |
+
|
| 34 |
+
CONFIG_NAME = "config.json"
|
| 35 |
+
WEIGHTS_NAME = "pytorch_model.bin"
|
| 36 |
+
ARGS_NAME = "training_args.bin"
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
def main():
    """Entry point for the metaphor-detection Gradio demo.

    Loads the run configuration, (optionally) restores a fine-tuned checkpoint
    from a ``saves/`` directory, builds the tokenizer/model, and launches a
    Gradio interface that labels each word of an input sentence as
    metaphorical ('M') or not, alongside detected semantic frames.
    """
    # read configs from ./main_config.cfg
    config = Config(main_conf_path="./")

    # apply system arguments if they exist; CLI args are expected as
    # alternating "--name value" pairs and override the config file
    argv = sys.argv[1:]
    if len(argv) > 0:
        cmd_arg = OrderedDict()
        argvs = " ".join(sys.argv[1:]).split(" ")
        for i in range(0, len(argvs), 2):
            arg_name, arg_value = argvs[i], argvs[i + 1]
            arg_name = arg_name.strip("-")
            cmd_arg[arg_name] = arg_value
        config.update_params(cmd_arg)

    args = config
    print(args.__dict__)

    # logger: when bert_model points into "saves/", we are resuming a trained
    # run — reload the config stored in that directory, but keep the four
    # run-identity fields from the current invocation
    if "saves" in args.bert_model:
        log_dir = args.bert_model
        logger = Logger(log_dir)
        config = Config(main_conf_path=log_dir)
        old_args = copy.deepcopy(args)
        args.__dict__.update(config.__dict__)

        args.bert_model = old_args.bert_model
        args.do_train = old_args.do_train
        args.data_dir = old_args.data_dir
        args.task_name = old_args.task_name

        # re-apply system arguments so CLI overrides also win over the
        # restored per-run config
        argv = sys.argv[1:]
        if len(argv) > 0:
            cmd_arg = OrderedDict()
            argvs = " ".join(sys.argv[1:]).split(" ")
            for i in range(0, len(argvs), 2):
                arg_name, arg_value = argvs[i], argvs[i + 1]
                arg_name = arg_name.strip("-")
                cmd_arg[arg_name] = arg_value
            config.update_params(cmd_arg)
    else:
        # fresh run: create a new timestamped log directory under saves/
        if not os.path.exists("saves"):
            os.mkdir("saves")
        log_dir = make_log_dir(os.path.join("saves", args.bert_model))
        logger = Logger(log_dir)
        config.save(log_dir)
    args.log_dir = log_dir

    # set CUDA devices
    device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
    args.n_gpu = torch.cuda.device_count()
    args.device = device

    logger.info("device: {} n_gpu: {}".format(device, args.n_gpu))

    # set seed for python / numpy / torch RNGs
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if args.n_gpu > 0:
        torch.cuda.manual_seed_all(args.seed)

    # binary task: metaphor vs. literal
    args.num_labels = 2

    # build tokenizer and model
    tokenizer = AutoTokenizer.from_pretrained(args.bert_model, do_lower_case=args.do_lower_case)
    model = load_pretrained_model(args)

    # Load trained model weights when resuming from a saves/ checkpoint
    if "saves" in args.bert_model:
        model = load_trained_model(args, model, tokenizer)

    def run_one_sentence(sentence):
        """Classify every non-punctuation word of *sentence* as metaphor/literal.

        Returns a (word, label) list for the Gradio highlight component and
        the frame-semantic parse of the sentence.
        """
        print('sentence:', sentence)
        # split punctuation off from words so whitespace tokenization works
        sentence = re.sub(r'([.,!?()-]+)', r' \1 ', sentence)
        sentence = ' '.join(sentence.split())
        print('sentence:', sentence)

        result = frame_transformer.detect_frames(sentence)
        print(result)

        model.eval()
        s_batch = load_sentence_data(args, sentence, ['0','1'], tokenizer, 'classification')

        # MELBERT variants carry a second input view (target-only encoding)
        if args.model_type in ["MELBERT_MIP", "MELBERT"]:
            input_ids, input_mask, segment_ids, label_ids, idx, input_ids_2, input_mask_2, segment_ids_2 = s_batch
        else:
            input_ids, input_mask, segment_ids, label_ids, idx = s_batch

        with torch.no_grad():
            # compute logits (no loss: label_ids are dummies at inference)
            if args.model_type in ["BERT_BASE", "BERT_SEQ", "MELBERT_SPV"]:
                logits = model(
                    input_ids,
                    target_mask=(segment_ids == 1),
                    token_type_ids=segment_ids,
                    attention_mask=input_mask,
                )

            elif args.model_type in ["MELBERT_MIP", "MELBERT"]:
                logits = model(
                    input_ids,
                    input_ids_2,
                    target_mask=(segment_ids == 1),
                    target_mask_2=segment_ids_2,
                    attention_mask_2=input_mask_2,
                    token_type_ids=segment_ids,
                    attention_mask=input_mask,
                )

        pred = logits.detach().cpu().numpy()
        pred = np.argmax(pred, axis=1)
        # map per-example predictions back to word positions; punctuation
        # tokens were never classified and stay None
        pred_list = [None for _ in range(len(sentence.split()))]
        for i,n in enumerate(idx):
            pred_list[n] = 'M' if pred[i] == 1 else None
        print(len(pred_list), pred_list)
        label_list = [(w, p) for w,p in zip(sentence.split(), pred_list)]
        print(label_list)
        return label_list, result

    demo = gr.Interface(
        run_one_sentence,
        gr.Textbox(placeholder="Enter sentence here..."),
        ['highlight', 'json'],
        examples=[
            ['while new departments are born and others extended .'],
            ['The sounds are the same as those of daylight , yet somehow the night magnifies and sharpens the creak of a yielding block , the sigh of air over a shroud , the stretching of a sail , the hiss of water sliding sleek against the hull , the curl of a quarter-wave falling away , and the thump as a wave strikes the cutwater to be sheared into two bright slices of whiteness .'],
            ['and finally, the debate has sharpened.']
        ]
    )
    demo.launch(debug=True)
|
| 175 |
+
|
| 176 |
+
|
| 177 |
+
|
| 178 |
+
def load_pretrained_model(args):
    """Build the pretrained backbone plus the task head for args.model_type.

    Loads the HuggingFace backbone named by ``args.bert_model``, replaces its
    token-type embedding table with a fresh 4-entry one (the code sets
    ``type_vocab_size = 4``; presumably the extra segments encode target/POS
    information — confirm against the feature builders), then wraps the
    backbone in the classifier matching ``args.model_type``.

    Args:
        args: run configuration (bert_model, model_type, num_labels, device,
            n_gpu, no_cuda).

    Returns:
        The model moved to ``args.device``; wrapped in ``DataParallel`` when
        more than one GPU is available and CUDA is enabled.

    Raises:
        ValueError: if ``args.model_type`` is not one of the known heads
            (previously this fell through and raised a NameError on ``model``).
    """
    # Pretrained backbone
    bert = AutoModel.from_pretrained(args.bert_model)
    config = bert.config
    config.type_vocab_size = 4
    # ALBERT factorizes embeddings, so its token-type table uses
    # embedding_size rather than hidden_size.
    if "albert" in args.bert_model:
        bert.embeddings.token_type_embeddings = nn.Embedding(
            config.type_vocab_size, config.embedding_size
        )
    else:
        bert.embeddings.token_type_embeddings = nn.Embedding(
            config.type_vocab_size, config.hidden_size
        )
    bert._init_weights(bert.embeddings.token_type_embeddings)

    # Task head selection: dispatch table instead of an if-chain, so an
    # unknown model_type fails with a clear error instead of a NameError.
    head_classes = {
        "BERT_BASE": AutoModelForSequenceClassification,
        "BERT_SEQ": AutoModelForTokenClassification,
        "MELBERT_SPV": AutoModelForSequenceClassification_SPV,
        "MELBERT_MIP": AutoModelForSequenceClassification_MIP,
        "MELBERT": AutoModelForSequenceClassification_SPV_MIP,
    }
    try:
        head_cls = head_classes[args.model_type]
    except KeyError:
        raise ValueError(
            "unknown model_type: {!r} (expected one of {})".format(
                args.model_type, sorted(head_classes)
            )
        ) from None
    model = head_cls(args=args, Model=bert, config=config, num_labels=args.num_labels)

    model.to(args.device)
    if args.n_gpu > 1 and not args.no_cuda:
        model = torch.nn.DataParallel(model)
    return model
|
| 221 |
+
|
| 222 |
+
|
| 223 |
+
def load_trained_model(args, model, tokenizer):
    """Restore fine-tuned weights from ``args.log_dir`` into *model*.

    The checkpoint is the raw state dict saved under the predefined
    WEIGHTS_NAME file. ``tokenizer`` is unused here; it is kept so the call
    signature matches the training-side loader.

    Returns:
        The same *model* object with weights loaded in place.
    """
    weights_path = os.path.join(args.log_dir, WEIGHTS_NAME)
    state_dict = torch.load(weights_path, map_location=args.device)

    # The saved state dict matches the bare module, so unwrap DataParallel
    # before loading when present.
    target = model.module if hasattr(model, "module") else model
    target.load_state_dict(state_dict)

    return model
|
| 233 |
+
|
| 234 |
+
|
| 235 |
+
# Script entry point: build the model and launch the Gradio demo.
if __name__ == "__main__":
    main()
|
data_loader.py
ADDED
|
@@ -0,0 +1,233 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
import torch.nn as nn
|
| 5 |
+
import string
|
| 6 |
+
from sklearn.model_selection import StratifiedKFold
|
| 7 |
+
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler, TensorDataset
|
| 8 |
+
from run_classifier_dataset_utils import (
|
| 9 |
+
convert_examples_to_two_features,
|
| 10 |
+
convert_examples_to_features,
|
| 11 |
+
convert_two_examples_to_features,
|
| 12 |
+
)
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
def load_train_data(args, logger, processor, task_name, label_list, tokenizer, output_mode, k=None):
    """Build a shuffled training DataLoader for the given task.

    Args:
        args: run configuration (data_dir, model_type, max_seq_length,
            train_batch_size).
        logger: run logger (unused here; kept for interface parity).
        processor: dataset processor exposing ``get_train_examples``.
        task_name: "vua" or "trofi".
        label_list: list of label strings.
        tokenizer: HuggingFace tokenizer.
        output_mode: feature-conversion mode (e.g. "classification").
        k: fold index, used only by the "trofi" task.

    Returns:
        A ``DataLoader`` with a ``RandomSampler`` over the training features.

    Raises:
        ValueError: if ``task_name`` is neither "vua" nor "trofi".
    """
    # Prepare examples
    if task_name == "vua":
        train_examples = processor.get_train_examples(args.data_dir)
    elif task_name == "trofi":
        train_examples = processor.get_train_examples(args.data_dir, k)
    else:
        # BUG FIX: `raise ("...")` raises TypeError ("exceptions must derive
        # from BaseException") — raise a proper exception type.
        raise ValueError("task_name should be 'vua' or 'trofi'!")
    # NOTE(review): removed leftover debugging code here — an
    # `import pdb; pdb.set_trace()` and a print of `args.max_data_num`,
    # which is not defined in main_config.cfg and would raise at runtime.

    # Convert examples to model-type-specific features
    if args.model_type == "BERT_BASE":
        train_features = convert_two_examples_to_features(
            train_examples, label_list, args.max_seq_length, tokenizer, output_mode
        )
    if args.model_type in ["BERT_SEQ", "MELBERT_SPV"]:
        train_features = convert_examples_to_features(
            train_examples, label_list, args.max_seq_length, tokenizer, output_mode, args
        )
    if args.model_type in ["MELBERT_MIP", "MELBERT"]:
        train_features = convert_examples_to_two_features(
            train_examples, label_list, args.max_seq_length, tokenizer, output_mode, args
        )

    # Pack features into tensors
    all_input_ids = torch.tensor([f.input_ids for f in train_features], dtype=torch.long)
    all_input_mask = torch.tensor([f.input_mask for f in train_features], dtype=torch.long)
    all_segment_ids = torch.tensor([f.segment_ids for f in train_features], dtype=torch.long)
    all_label_ids = torch.tensor([f.label_id for f in train_features], dtype=torch.long)

    # MELBERT variants carry a second encoded view of each example
    if args.model_type in ["MELBERT_MIP", "MELBERT"]:
        all_input_ids_2 = torch.tensor([f.input_ids_2 for f in train_features], dtype=torch.long)
        all_input_mask_2 = torch.tensor([f.input_mask_2 for f in train_features], dtype=torch.long)
        all_segment_ids_2 = torch.tensor(
            [f.segment_ids_2 for f in train_features], dtype=torch.long
        )
        train_data = TensorDataset(
            all_input_ids,
            all_input_mask,
            all_segment_ids,
            all_label_ids,
            all_input_ids_2,
            all_input_mask_2,
            all_segment_ids_2,
        )
    else:
        train_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_label_ids)
    train_sampler = RandomSampler(train_data)
    train_dataloader = DataLoader(
        train_data, sampler=train_sampler, batch_size=args.train_batch_size
    )

    return train_dataloader
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
def load_train_data_kf(
    args, logger, processor, task_name, label_list, tokenizer, output_mode, k=None
):
    """Build a training TensorDataset plus a stratified K-fold split generator.

    Same feature pipeline as ``load_train_data``, but instead of a DataLoader
    it returns the raw ``TensorDataset`` together with a
    ``StratifiedKFold(n_splits=args.num_bagging)`` index generator stratified
    on the label ids, for bagging-style training.

    Returns:
        Tuple of (train_data, gkf) where ``gkf`` yields (train_idx, val_idx)
        index arrays per fold.

    Raises:
        ValueError: if ``task_name`` is neither "vua" nor "trofi".
    """
    # Prepare examples
    if task_name == "vua":
        train_examples = processor.get_train_examples(args.data_dir)
    elif task_name == "trofi":
        train_examples = processor.get_train_examples(args.data_dir, k)
    else:
        # BUG FIX: raising a bare string is a TypeError in Python 3 —
        # raise a proper exception type.
        raise ValueError("task_name should be 'vua' or 'trofi'!")

    # Convert examples to model-type-specific features
    if args.model_type == "BERT_BASE":
        train_features = convert_two_examples_to_features(
            train_examples, label_list, args.max_seq_length, tokenizer, output_mode
        )
    if args.model_type in ["BERT_SEQ", "MELBERT_SPV"]:
        train_features = convert_examples_to_features(
            train_examples, label_list, args.max_seq_length, tokenizer, output_mode, args
        )
    if args.model_type in ["MELBERT_MIP", "MELBERT"]:
        train_features = convert_examples_to_two_features(
            train_examples, label_list, args.max_seq_length, tokenizer, output_mode, args
        )

    # Pack features into tensors
    all_input_ids = torch.tensor([f.input_ids for f in train_features], dtype=torch.long)
    all_input_mask = torch.tensor([f.input_mask for f in train_features], dtype=torch.long)
    all_segment_ids = torch.tensor([f.segment_ids for f in train_features], dtype=torch.long)
    all_label_ids = torch.tensor([f.label_id for f in train_features], dtype=torch.long)

    # MELBERT variants carry a second encoded view of each example
    if args.model_type in ["MELBERT_MIP", "MELBERT"]:
        all_input_ids_2 = torch.tensor([f.input_ids_2 for f in train_features], dtype=torch.long)
        all_input_mask_2 = torch.tensor([f.input_mask_2 for f in train_features], dtype=torch.long)
        all_segment_ids_2 = torch.tensor(
            [f.segment_ids_2 for f in train_features], dtype=torch.long
        )
        train_data = TensorDataset(
            all_input_ids,
            all_input_mask,
            all_segment_ids,
            all_label_ids,
            all_input_ids_2,
            all_input_mask_2,
            all_segment_ids_2,
        )
    else:
        train_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_label_ids)
    gkf = StratifiedKFold(n_splits=args.num_bagging).split(X=all_input_ids, y=all_label_ids.numpy())
    return train_data, gkf
|
| 122 |
+
|
| 123 |
+
|
| 124 |
+
def load_test_data(args, logger, processor, task_name, label_list, tokenizer, output_mode, k=None):
    """Build a sequential evaluation DataLoader plus the example guids.

    Args:
        args: run configuration (data_dir, model_type, max_seq_length,
            eval_batch_size).
        logger: run logger, used for the evaluation banner.
        processor: dataset processor exposing ``get_test_examples``.
        task_name: "vua" or "trofi".
        label_list: list of label strings.
        tokenizer: HuggingFace tokenizer.
        output_mode: feature-conversion mode (e.g. "classification").
        k: fold index, used only by the "trofi" task.

    Returns:
        Tuple of (all_guids, eval_dataloader).

    Raises:
        ValueError: if ``task_name`` is neither "vua" nor "trofi".
    """
    if task_name == "vua":
        eval_examples = processor.get_test_examples(args.data_dir)
    elif task_name == "trofi":
        eval_examples = processor.get_test_examples(args.data_dir, k)
    else:
        # BUG FIX: `raise ("...")` raises TypeError — raise a real exception.
        raise ValueError("task_name should be 'vua' or 'trofi'!")
    # NOTE(review): removed leftover debugging code — two
    # `import pdb; pdb.set_trace()` calls and a hard-coded slice
    # `eval_examples = eval_examples[14185:14216]` that silently restricted
    # evaluation to 31 examples.

    if args.model_type == "BERT_BASE":
        eval_features = convert_two_examples_to_features(
            eval_examples, label_list, args.max_seq_length, tokenizer, output_mode
        )
    if args.model_type in ["BERT_SEQ", "MELBERT_SPV"]:
        eval_features = convert_examples_to_features(
            eval_examples, label_list, args.max_seq_length, tokenizer, output_mode, args
        )
    if args.model_type in ["MELBERT_MIP", "MELBERT"]:
        eval_features = convert_examples_to_two_features(
            eval_examples, label_list, args.max_seq_length, tokenizer, output_mode, args
        )
    logger.info("***** Running evaluation *****")

    if args.model_type in ["MELBERT_MIP", "MELBERT"]:
        all_input_ids = torch.tensor([f.input_ids for f in eval_features], dtype=torch.long)
        all_input_mask = torch.tensor([f.input_mask for f in eval_features], dtype=torch.long)
        all_segment_ids = torch.tensor([f.segment_ids for f in eval_features], dtype=torch.long)
        all_guids = [f.guid for f in eval_features]
        all_idx = torch.tensor([i for i in range(len(eval_features))], dtype=torch.long)
        all_label_ids = torch.tensor([f.label_id for f in eval_features], dtype=torch.long)
        all_input_ids_2 = torch.tensor([f.input_ids_2 for f in eval_features], dtype=torch.long)
        all_input_mask_2 = torch.tensor([f.input_mask_2 for f in eval_features], dtype=torch.long)
        all_segment_ids_2 = torch.tensor([f.segment_ids_2 for f in eval_features], dtype=torch.long)
        eval_data = TensorDataset(
            all_input_ids,
            all_input_mask,
            all_segment_ids,
            all_label_ids,
            all_idx,
            all_input_ids_2,
            all_input_mask_2,
            all_segment_ids_2,
        )
    else:
        all_input_ids = torch.tensor([f.input_ids for f in eval_features], dtype=torch.long)
        all_input_mask = torch.tensor([f.input_mask for f in eval_features], dtype=torch.long)
        all_segment_ids = torch.tensor([f.segment_ids for f in eval_features], dtype=torch.long)
        all_guids = [f.guid for f in eval_features]
        all_idx = torch.tensor([i for i in range(len(eval_features))], dtype=torch.long)
        all_label_ids = torch.tensor([f.label_id for f in eval_features], dtype=torch.long)
        eval_data = TensorDataset(
            all_input_ids, all_input_mask, all_segment_ids, all_label_ids, all_idx
        )

    # Run prediction for full data, in original order
    eval_sampler = SequentialSampler(eval_data)
    eval_dataloader = DataLoader(eval_data, sampler=eval_sampler, batch_size=args.eval_batch_size)

    return all_guids, eval_dataloader
|
| 183 |
+
|
| 184 |
+
from run_classifier_dataset_utils import InputExample
|
| 185 |
+
def load_sentence_data(args, sentence, label_list, tokenizer, output_mode, ):
    """Build model-ready evaluation tensors for a single raw sentence.

    One ``InputExample`` is created per non-punctuation whitespace token of
    ``sentence`` (the token's position is passed through ``text_b``), so the
    model produces one metaphor prediction per content word.

    Args:
        args: config namespace; reads ``max_seq_length`` and ``model_type``.
        sentence: raw sentence string, tokenized by whitespace.
        label_list: label vocabulary for the task.
        tokenizer: HuggingFace tokenizer used by feature conversion.
        output_mode: task output mode passed through to feature conversion.

    Returns:
        A tuple of tensors ``(input_ids, input_mask, segment_ids, label_ids,
        idx)``; for MELBERT / MELBERT_MIP models, three extra tensors
        ``(input_ids_2, input_mask_2, segment_ids_2)`` are appended.
    """
    examples = []
    example_idxs = []
    for index, token in enumerate(sentence.split()):
        # Punctuation tokens get no example and therefore no prediction.
        if token not in string.punctuation:
            examples.append(
                InputExample(
                    guid='', text_a=sentence, text_b=str(index), label='0', POS='', FGPOS=''
                )
            )
            print('[', index, token, ']', end=', ')
            example_idxs.append(index)
    eval_features = convert_examples_to_two_features(
        examples, label_list, args.max_seq_length, tokenizer, output_mode, args
    )

    # Tensors shared by every model type (previously duplicated per branch).
    all_input_ids = torch.tensor([f.input_ids for f in eval_features], dtype=torch.long)
    all_input_mask = torch.tensor([f.input_mask for f in eval_features], dtype=torch.long)
    all_segment_ids = torch.tensor([f.segment_ids for f in eval_features], dtype=torch.long)
    all_idx = torch.tensor(example_idxs, dtype=torch.long)
    all_label_ids = torch.tensor([f.label_id for f in eval_features], dtype=torch.long)

    if args.model_type in ["MELBERT_MIP", "MELBERT"]:
        # Second-encoder inputs used by the MIP branch of MELBERT models.
        all_input_ids_2 = torch.tensor([f.input_ids_2 for f in eval_features], dtype=torch.long)
        all_input_mask_2 = torch.tensor([f.input_mask_2 for f in eval_features], dtype=torch.long)
        all_segment_ids_2 = torch.tensor([f.segment_ids_2 for f in eval_features], dtype=torch.long)
        eval_data = (
            all_input_ids,
            all_input_mask,
            all_segment_ids,
            all_label_ids,
            all_idx,
            all_input_ids_2,
            all_input_mask_2,
            all_segment_ids_2,
        )
    else:
        eval_data = (
            all_input_ids, all_input_mask, all_segment_ids, all_label_ids, all_idx
        )
    return eval_data
|
flagged/log.csv
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
sentence,output,flag,username,timestamp
|
| 2 |
+
"The sounds are the same as those of daylight , yet somehow the night magnifies and sharpens the creak of a yielding block , the sigh of air over a shroud , the stretching of a sail , the hiss of water sliding sleek against the hull , the curl of a quarter-wave falling away , and the thump as a wave strikes the cutwater to be sheared into two bright slices of whiteness .",/Users/yiningmao/Desktop/CS224N/MelBERT-main/flagged/output/tmpginvysx3.json,,,2023-02-28 07:12:44.582261
|
flagged/output/tmpginvysx3.json
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
[["The", null], ["sounds", null], ["are", null], ["the", null], ["same", null], ["as", null], ["those", ""], ["of", null], ["daylight", null], [",", null], ["yet", null], ["somehow", null], ["the", null], ["night", null], ["magnifies", ""], ["and", null], ["sharpens", ""], ["the", null], ["creak", null], ["of", null], ["a", null], ["yielding", null], ["block", null], [",", null], ["the", null], ["sigh", ""], ["of", null], ["air", null], ["over", null], ["a", null], ["shroud", null], [",", null], ["the", null], ["stretching", null], ["of", null], ["a", null], ["sail", null], [",", null], ["the", null], ["hiss", null], ["of", null], ["water", null], ["sliding", ""], ["sleek", ""], ["against", null], ["the", null], ["hull", null], [",", null], ["the", null], ["curl", ""], ["of", null], ["a", null], ["quarter", null], ["-", null], ["wave", null], ["falling", null], ["away", null], [",", null], ["and", null], ["the", null], ["thump", null], ["as", null], ["a", null], ["wave", null], ["strikes", ""], ["the", null], ["cutwater", null], ["to", null], ["be", null], ["sheared", ""], ["into", ""], ["two", null], ["bright", ""], ["slices", ""], ["of", null], ["whiteness", null], [".", null]]
|
main.py
ADDED
|
@@ -0,0 +1,545 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import sys
|
| 3 |
+
import pickle
|
| 4 |
+
import random
|
| 5 |
+
import copy
|
| 6 |
+
import numpy as np
|
| 7 |
+
|
| 8 |
+
import torch
|
| 9 |
+
import torch.nn as nn
|
| 10 |
+
|
| 11 |
+
from tqdm import tqdm, trange
|
| 12 |
+
from collections import OrderedDict
|
| 13 |
+
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler, TensorDataset
|
| 14 |
+
from transformers import AutoTokenizer, AutoModel, AdamW, get_linear_schedule_with_warmup
|
| 15 |
+
|
| 16 |
+
from utils import Config, Logger, make_log_dir
|
| 17 |
+
from modeling import (
|
| 18 |
+
AutoModelForSequenceClassification,
|
| 19 |
+
AutoModelForTokenClassification,
|
| 20 |
+
AutoModelForSequenceClassification_SPV,
|
| 21 |
+
AutoModelForSequenceClassification_MIP,
|
| 22 |
+
AutoModelForSequenceClassification_SPV_MIP,
|
| 23 |
+
)
|
| 24 |
+
from run_classifier_dataset_utils import processors, output_modes, compute_metrics
|
| 25 |
+
from data_loader import load_train_data, load_train_data_kf, load_test_data
|
| 26 |
+
|
| 27 |
+
CONFIG_NAME = "config.json"
|
| 28 |
+
WEIGHTS_NAME = "pytorch_model.bin"
|
| 29 |
+
ARGS_NAME = "training_args.bin"
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
def main():
    """Entry point: parse config + CLI overrides, then train and/or evaluate.

    Flow: build a ``Config`` from ``main_config.cfg``, overlay ``--key value``
    pairs from the command line, set up logging/device/seeds, then dispatch to
    training (bagging / VUA / TroFi k-fold) and inference sections depending
    on ``do_train`` / ``do_eval`` / ``do_test`` and ``task_name``.
    """
    # read configs
    config = Config(main_conf_path="./")

    # apply system arguments if exist
    # CLI is assumed to be alternating "--name value" pairs (no flags without
    # values) — an odd-length argv would raise IndexError here.
    argv = sys.argv[1:]
    if len(argv) > 0:
        cmd_arg = OrderedDict()
        argvs = " ".join(sys.argv[1:]).split(" ")
        for i in range(0, len(argvs), 2):
            arg_name, arg_value = argvs[i], argvs[i + 1]
            arg_name = arg_name.strip("-")
            cmd_arg[arg_name] = arg_value
        config.update_params(cmd_arg)

    args = config
    print(args.__dict__)

    # logger
    # If bert_model points inside "saves", we are resuming from a previous
    # run: reload that run's saved config, but keep a few caller-chosen
    # fields (model path, do_train, data_dir, task_name) from the current CLI.
    if "saves" in args.bert_model:
        log_dir = args.bert_model
        logger = Logger(log_dir)
        config = Config(main_conf_path=log_dir)
        old_args = copy.deepcopy(args)
        args.__dict__.update(config.__dict__)

        args.bert_model = old_args.bert_model
        args.do_train = old_args.do_train
        args.data_dir = old_args.data_dir
        args.task_name = old_args.task_name

        # apply system arguments if exist
        argv = sys.argv[1:]
        if len(argv) > 0:
            cmd_arg = OrderedDict()
            argvs = " ".join(sys.argv[1:]).split(" ")
            for i in range(0, len(argvs), 2):
                arg_name, arg_value = argvs[i], argvs[i + 1]
                arg_name = arg_name.strip("-")
                cmd_arg[arg_name] = arg_value
            config.update_params(cmd_arg)
    else:
        # Fresh run: create a new timestamped log dir under saves/ and
        # snapshot the effective config there.
        if not os.path.exists("saves"):
            os.mkdir("saves")
        log_dir = make_log_dir(os.path.join("saves", args.bert_model))
        logger = Logger(log_dir)
        config.save(log_dir)
    args.log_dir = log_dir

    # set CUDA devices
    device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
    args.n_gpu = torch.cuda.device_count()
    args.device = device

    logger.info("device: {} n_gpu: {}".format(device, args.n_gpu))

    # set seed (python / numpy / torch, plus all GPUs when present)
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if args.n_gpu > 0:
        torch.cuda.manual_seed_all(args.seed)

    # get dataset and processor
    task_name = args.task_name.lower()
    processor = processors[task_name]()
    output_mode = output_modes[task_name]
    label_list = processor.get_labels()
    args.num_labels = len(label_list)

    # build tokenizer and model
    tokenizer = AutoTokenizer.from_pretrained(args.bert_model, do_lower_case=args.do_lower_case)
    model = load_pretrained_model(args)
    ########### Training ###########

    # VUA18 / VUA20 for bagging
    # Each bagging run trains on one k-fold split (selected by bagging_index),
    # dumps test predictions to a pickle, and returns early.
    if args.do_train and args.task_name == "vua" and args.num_bagging:
        train_data, gkf = load_train_data_kf(args, logger, processor, task_name, label_list, tokenizer, output_mode)

        for fold, (train_idx, valid_idx) in enumerate(tqdm(gkf, desc="bagging...")):
            if fold != args.bagging_index:
                continue

            print(f"bagging_index = {args.bagging_index}")

            # Load data
            temp_train_data = TensorDataset(*train_data[train_idx])
            train_sampler = RandomSampler(temp_train_data)
            train_dataloader = DataLoader(temp_train_data, sampler=train_sampler, batch_size=args.train_batch_size)

            # Reset Model
            model = load_pretrained_model(args)
            model, best_result = run_train(args, logger, model, train_dataloader, processor, task_name, label_list, tokenizer, output_mode)

            # Test
            all_guids, eval_dataloader = load_test_data(args, logger, processor, task_name, label_list, tokenizer, output_mode)
            preds = run_eval(args, logger, model, eval_dataloader, all_guids, task_name, return_preds=True)
            with open(os.path.join(args.data_dir, f"seed{args.seed}_preds_{fold}.p"), "wb") as f:
                pickle.dump(preds, f)

            # If train data is VUA20, the model needs to be tested on VUAverb as well.
            # You can just adjust the names of data_dir in conditions below for your own data directories.
            if "VUA20" in args.data_dir:
                # Verb
                args.data_dir = "data/VUAverb"
                all_guids, eval_dataloader = load_test_data(args, logger, processor, task_name, label_list, tokenizer, output_mode)
                preds = run_eval(args, logger, model, eval_dataloader, all_guids, task_name, return_preds=True)
                with open(os.path.join(args.data_dir, f"seed{args.seed}_preds_{fold}.p"), "wb") as f:
                    pickle.dump(preds, f)

        logger.info(f"Saved to {logger.log_dir}")
        return


    # VUA18 / VUA20
    if args.do_train and args.task_name == "vua":
        train_dataloader = load_train_data(
            args, logger, processor, task_name, label_list, tokenizer, output_mode
        )
        model, best_result = run_train(
            args,
            logger,
            model,
            train_dataloader,
            processor,
            task_name,
            label_list,
            tokenizer,
            output_mode,
        )
    # TroFi / MOH-X (K-fold)
    elif args.do_train and args.task_name == "trofi":
        k_result = []
        for k in tqdm(range(args.kfold), desc="K-fold"):
            # A fresh model is trained per fold; each fold's best dev result
            # is collected and averaged below.
            model = load_pretrained_model(args)
            train_dataloader = load_train_data(
                args, logger, processor, task_name, label_list, tokenizer, output_mode, k
            )
            model, best_result = run_train(
                args,
                logger,
                model,
                train_dataloader,
                processor,
                task_name,
                label_list,
                tokenizer,
                output_mode,
                k,
            )
            k_result.append(best_result)

        # Calculate average result
        # NOTE(review): the inner `k` here shadows the fold index `k` above —
        # harmless after the loop ends, but easy to misread.
        avg_result = copy.deepcopy(k_result[0])
        for result in k_result[1:]:
            for k, v in result.items():
                avg_result[k] += v
        for k, v in avg_result.items():
            avg_result[k] /= len(k_result)

        logger.info(f"-----Averge Result-----")
        for key in sorted(avg_result.keys()):
            logger.info(f" {key} = {str(avg_result[key])}")

    # Load trained model
    if "saves" in args.bert_model:
        model = load_trained_model(args, model, tokenizer)

    ########### Inference ###########
    # VUA18 / VUA20
    if (args.do_eval or args.do_test) and task_name == "vua":
        # if test data is genre or POS tag data
        if ("genre" in args.data_dir) or ("pos" in args.data_dir):
            if "genre" in args.data_dir:
                targets = ["acad", "conv", "fict", "news"]
            elif "pos" in args.data_dir:
                targets = ["adj", "adv", "noun", "verb"]
            orig_data_dir = args.data_dir
            # Evaluate each genre / POS subset in its own subdirectory.
            for idx, target in tqdm(enumerate(targets)):
                logger.info(f"====================== Evaluating {target} =====================")
                args.data_dir = os.path.join(orig_data_dir, target)
                all_guids, eval_dataloader = load_test_data(
                    args, logger, processor, task_name, label_list, tokenizer, output_mode
                )
                run_eval(args, logger, model, eval_dataloader, all_guids, task_name)
        else:
            all_guids, eval_dataloader = load_test_data(
                args, logger, processor, task_name, label_list, tokenizer, output_mode
            )
            run_eval(args, logger, model, eval_dataloader, all_guids, task_name)

    # TroFi / MOH-X (K-fold)
    elif (args.do_eval or args.do_test) and args.task_name == "trofi":
        logger.info(f"***** Evaluating with {args.data_dir}")
        k_result = []
        for k in tqdm(range(10), desc="K-fold"):
            all_guids, eval_dataloader = load_test_data(
                args, logger, processor, task_name, label_list, tokenizer, output_mode, k
            )
            result = run_eval(args, logger, model, eval_dataloader, all_guids, task_name)
            k_result.append(result)

        # Calculate average result
        avg_result = copy.deepcopy(k_result[0])
        for result in k_result[1:]:
            for k, v in result.items():
                avg_result[k] += v
        for k, v in avg_result.items():
            avg_result[k] /= len(k_result)

        logger.info(f"-----Averge Result-----")
        for key in sorted(avg_result.keys()):
            logger.info(f" {key} = {str(avg_result[key])}")
    logger.info(f"Saved to {logger.log_dir}")
|
| 246 |
+
|
| 247 |
+
|
| 248 |
+
def run_train(
    args,
    logger,
    model,
    train_dataloader,
    processor,
    task_name,
    label_list,
    tokenizer,
    output_mode,
    k=None,
):
    """Train ``model`` on ``train_dataloader`` with AdamW (+ optional warmup).

    After each epoch, if ``args.do_eval`` is set, the model is evaluated on
    the dev/test split and the checkpoint with the best F1 is saved.

    Args:
        args: config namespace (device, model_type, num_train_epoch, ...).
        logger: project logger.
        model: model returned by ``load_pretrained_model``.
        train_dataloader: batched training data.
        processor / task_name / label_list / tokenizer / output_mode: passed
            through to ``load_test_data`` for per-epoch evaluation.
        k: fold index for k-fold tasks (``None`` for single-split tasks).

    Returns:
        ``(model, max_result)`` — the trained model and the best eval metrics
        (empty dict if ``do_eval`` is off).
    """
    tr_loss = 0
    num_train_optimization_steps = len(train_dataloader) * args.num_train_epoch

    # Prepare optimizer: no weight decay on biases and LayerNorm parameters.
    param_optimizer = list(model.named_parameters())
    no_decay = ["bias", "LayerNorm.bias", "LayerNorm.weight"]
    optimizer_grouped_parameters = [
        {
            "params": [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
            "weight_decay": 0.01,
        },
        {
            "params": [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
            "weight_decay": 0.0,
        },
    ]
    optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate)

    # BUGFIX: the original guard `args.lr_schedule != False or
    # args.lr_schedule.lower() != "none"` was always True (an `or` of the two
    # negations) and raised AttributeError when lr_schedule was a bool, so
    # setting lr_schedule to "none"/False could never disable the scheduler.
    # Evaluate the intent once and reuse it for both creation and stepping.
    use_scheduler = bool(args.lr_schedule) and str(args.lr_schedule).lower() != "none"
    scheduler = None
    if use_scheduler:
        scheduler = get_linear_schedule_with_warmup(
            optimizer,
            num_warmup_steps=int(args.warmup_epoch * len(train_dataloader)),
            num_training_steps=num_train_optimization_steps,
        )

    logger.info("***** Running training *****")
    logger.info(f"  Batch size = {args.train_batch_size}")
    logger.info(f"  Num steps = { num_train_optimization_steps}")

    # Run training
    model.train()
    max_val_f1 = -1
    max_result = {}
    for epoch in trange(int(args.num_train_epoch), desc="Epoch"):
        tr_loss = 0
        for step, batch in enumerate(tqdm(train_dataloader, desc="Iteration")):
            # move batch data to gpu
            batch = tuple(t.to(args.device) for t in batch)

            if args.model_type in ["MELBERT_MIP", "MELBERT"]:
                (
                    input_ids,
                    input_mask,
                    segment_ids,
                    label_ids,
                    input_ids_2,
                    input_mask_2,
                    segment_ids_2,
                ) = batch
            else:
                input_ids, input_mask, segment_ids, label_ids = batch

            # compute loss values; NLLLoss weight up-weights the metaphor
            # class by args.class_weight to counter class imbalance.
            if args.model_type in ["BERT_SEQ", "BERT_BASE", "MELBERT_SPV"]:
                logits = model(
                    input_ids,
                    target_mask=(segment_ids == 1),
                    token_type_ids=segment_ids,
                    attention_mask=input_mask,
                )
                loss_fct = nn.NLLLoss(weight=torch.Tensor([1, args.class_weight]).to(args.device))
                loss = loss_fct(logits.view(-1, args.num_labels), label_ids.view(-1))
            elif args.model_type in ["MELBERT_MIP", "MELBERT"]:
                logits = model(
                    input_ids,
                    input_ids_2,
                    target_mask=(segment_ids == 1),
                    target_mask_2=segment_ids_2,
                    attention_mask_2=input_mask_2,
                    token_type_ids=segment_ids,
                    attention_mask=input_mask,
                )
                loss_fct = nn.NLLLoss(weight=torch.Tensor([1, args.class_weight]).to(args.device))
                loss = loss_fct(logits.view(-1, args.num_labels), label_ids.view(-1))

            # average loss if on multi-gpu.
            if args.n_gpu > 1:
                loss = loss.mean()

            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()

            if use_scheduler:
                scheduler.step()

            optimizer.zero_grad()

            tr_loss += loss.item()

        cur_lr = optimizer.param_groups[0]["lr"]
        logger.info(f"[epoch {epoch+1}] ,lr: {cur_lr} ,tr_loss: {tr_loss}")

        # evaluate after every epoch and keep the best-F1 checkpoint
        if args.do_eval:
            all_guids, eval_dataloader = load_test_data(
                args, logger, processor, task_name, label_list, tokenizer, output_mode, k
            )
            result = run_eval(args, logger, model, eval_dataloader, all_guids, task_name)

            # update best checkpoint (trofi and vua both save; merged from
            # two identical consecutive `if` blocks in the original)
            if result["f1"] > max_val_f1:
                max_val_f1 = result["f1"]
                max_result = result
                if args.task_name in ("trofi", "vua"):
                    save_model(args, model, tokenizer)

    logger.info(f"-----Best Result-----")
    for key in sorted(max_result.keys()):
        logger.info(f"  {key} = {str(max_result[key])}")

    return model, max_result
|
| 374 |
+
|
| 375 |
+
|
| 376 |
+
def run_eval(args, logger, model, eval_dataloader, all_guids, task_name, return_preds=False):
    """Evaluate ``model`` on ``eval_dataloader`` and log/return metrics.

    Accumulates logits, guids, and gold labels batch by batch, takes the
    argmax over the label dimension, and computes task metrics via
    ``compute_metrics``.

    Args:
        args: config namespace (device, model_type, num_labels).
        logger: project logger.
        model: trained model to evaluate.
        eval_dataloader: batched eval data (last element of each batch is the
            example index ``idx`` used to map back into ``all_guids``).
        all_guids: list of example guids, indexed by ``idx``.
        task_name: unused here beyond the caller's bookkeeping.
        return_preds: when True, return the predicted label array instead of
            the metrics dict.

    Returns:
        ``np.ndarray`` of predicted labels if ``return_preds`` else a dict of
        metrics from ``compute_metrics``.
    """
    model.eval()

    eval_loss = 0
    nb_eval_steps = 0
    preds = []            # preds[0] grows into one (N, num_labels) logits array
    pred_guids = []       # pred_guids[0] mirrors preds row-for-row
    out_label_ids = None  # gold labels, appended batch by batch

    for eval_batch in tqdm(eval_dataloader, desc="Evaluating"):
        eval_batch = tuple(t.to(args.device) for t in eval_batch)

        # MELBERT variants carry a second set of encoder inputs per batch.
        if args.model_type in ["MELBERT_MIP", "MELBERT"]:
            (
                input_ids,
                input_mask,
                segment_ids,
                label_ids,
                idx,
                input_ids_2,
                input_mask_2,
                segment_ids_2,
            ) = eval_batch
        else:
            input_ids, input_mask, segment_ids, label_ids, idx = eval_batch

        with torch.no_grad():
            # compute loss values (unweighted NLLLoss for evaluation)
            if args.model_type in ["BERT_BASE", "BERT_SEQ", "MELBERT_SPV"]:
                logits = model(
                    input_ids,
                    target_mask=(segment_ids == 1),
                    token_type_ids=segment_ids,
                    attention_mask=input_mask,
                )
                loss_fct = nn.NLLLoss()
                tmp_eval_loss = loss_fct(logits.view(-1, args.num_labels), label_ids.view(-1))
                eval_loss += tmp_eval_loss.mean().item()
                nb_eval_steps += 1

                # First batch initializes the accumulators; later batches append.
                if len(preds) == 0:
                    preds.append(logits.detach().cpu().numpy())
                    pred_guids.append([all_guids[i] for i in idx])
                    out_label_ids = label_ids.detach().cpu().numpy()
                else:
                    preds[0] = np.append(preds[0], logits.detach().cpu().numpy(), axis=0)
                    pred_guids[0].extend([all_guids[i] for i in idx])
                    out_label_ids = np.append(
                        out_label_ids, label_ids.detach().cpu().numpy(), axis=0
                    )

            elif args.model_type in ["MELBERT_MIP", "MELBERT"]:
                logits = model(
                    input_ids,
                    input_ids_2,
                    target_mask=(segment_ids == 1),
                    target_mask_2=segment_ids_2,
                    attention_mask_2=input_mask_2,
                    token_type_ids=segment_ids,
                    attention_mask=input_mask,
                )
                loss_fct = nn.NLLLoss()
                tmp_eval_loss = loss_fct(logits.view(-1, args.num_labels), label_ids.view(-1))
                eval_loss += tmp_eval_loss.mean().item()
                nb_eval_steps += 1

                if len(preds) == 0:
                    preds.append(logits.detach().cpu().numpy())
                    pred_guids.append([all_guids[i] for i in idx])
                    out_label_ids = label_ids.detach().cpu().numpy()
                else:
                    preds[0] = np.append(preds[0], logits.detach().cpu().numpy(), axis=0)
                    pred_guids[0].extend([all_guids[i] for i in idx])
                    out_label_ids = np.append(
                        out_label_ids, label_ids.detach().cpu().numpy(), axis=0
                    )

    eval_loss = eval_loss / nb_eval_steps
    preds = preds[0]
    preds = np.argmax(preds, axis=1)

    # compute metrics
    result = compute_metrics(preds, out_label_ids)

    for key in sorted(result.keys()):
        logger.info(f" {key} = {str(result[key])}")

    if return_preds:
        return preds
    return result
|
| 466 |
+
|
| 467 |
+
|
| 468 |
+
def load_pretrained_model(args):
    """Load the pretrained encoder and wrap it in the task-specific head.

    Reads ``args.bert_model`` / ``args.model_type`` and returns a model moved
    to ``args.device`` (wrapped in ``DataParallel`` on multi-GPU machines).
    """
    # Pretrained Model
    bert = AutoModel.from_pretrained(args.bert_model)
    # Debug dump: lists every encoder parameter and its trainability.
    for name, param in bert.named_parameters():
        print(name, param.requires_grad)

    config = bert.config
    # Expand the token-type (segment) vocabulary to 4 and reinitialize the
    # embedding so extra segment ids beyond the pretrained 1-2 are usable.
    config.type_vocab_size = 4
    if "albert" in args.bert_model:
        # ALBERT factorizes embeddings: token-type embeddings live in
        # embedding_size, not hidden_size.
        bert.embeddings.token_type_embeddings = nn.Embedding(
            config.type_vocab_size, config.embedding_size
        )
    else:
        bert.embeddings.token_type_embeddings = nn.Embedding(
            config.type_vocab_size, config.hidden_size
        )
    bert._init_weights(bert.embeddings.token_type_embeddings)

    # Additional Layers: pick the classification head for the chosen model
    # type. NOTE(review): an unrecognized model_type leaves `model` unbound
    # and raises NameError at `model.to(...)` below.
    if args.model_type in ["BERT_BASE"]:
        model = AutoModelForSequenceClassification(
            args=args, Model=bert, config=config, num_labels=args.num_labels
        )
    if args.model_type == "BERT_SEQ":
        model = AutoModelForTokenClassification(
            args=args, Model=bert, config=config, num_labels=args.num_labels
        )
    if args.model_type == "MELBERT_SPV":
        model = AutoModelForSequenceClassification_SPV(
            args=args, Model=bert, config=config, num_labels=args.num_labels
        )
    if args.model_type == "MELBERT_MIP":
        model = AutoModelForSequenceClassification_MIP(
            args=args, Model=bert, config=config, num_labels=args.num_labels
        )
    if args.model_type == "MELBERT":
        model = AutoModelForSequenceClassification_SPV_MIP(
            args=args, Model=bert, config=config, num_labels=args.num_labels
        )

    model.to(args.device)
    if args.n_gpu > 1 and not args.no_cuda:
        model = torch.nn.DataParallel(model)
    return model
|
| 512 |
+
|
| 513 |
+
|
| 514 |
+
def save_model(args, model, tokenizer):
    """Persist model weights, config, tokenizer vocab, and args to args.log_dir."""
    # Unwrap DataParallel (which stores the real model in `.module`) so we
    # only save the model it-self.
    if hasattr(model, "module"):
        model_to_save = model.module
    else:
        model_to_save = model

    # If we save using the predefined names, we can load using `from_pretrained`
    weights_path = os.path.join(args.log_dir, WEIGHTS_NAME)
    config_path = os.path.join(args.log_dir, CONFIG_NAME)

    torch.save(model_to_save.state_dict(), weights_path)
    model_to_save.config.to_json_file(config_path)
    tokenizer.save_vocabulary(args.log_dir)

    # Good practice: save your training arguments together with the trained model
    torch.save(args, os.path.join(args.log_dir, ARGS_NAME))
|
| 530 |
+
|
| 531 |
+
|
| 532 |
+
def load_trained_model(args, model, tokenizer):
    """Load weights written by ``save_model`` back into ``model`` and return it."""
    # If we save using the predefined names, we can load using `from_pretrained`
    weights_path = os.path.join(args.log_dir, WEIGHTS_NAME)
    state_dict = torch.load(weights_path, map_location=args.device)

    # DataParallel keeps the real model under `.module`; load into that.
    target = model.module if hasattr(model, "module") else model
    target.load_state_dict(state_dict)

    return model
|
| 542 |
+
|
| 543 |
+
|
| 544 |
+
# Script entry point: run the full configure/train/evaluate pipeline.
if __name__ == "__main__":
    main()
|
main_config.cfg
ADDED
|
@@ -0,0 +1,58 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[args]
|
| 2 |
+
# Bert pre-trained model selected in the list [bert-base-cased, roberta-base, albert-base-v1 / albert-large-v1] (default = roberta-base)
|
| 3 |
+
bert_model = roberta-base
|
| 4 |
+
|
| 5 |
+
# The input data dir. Should contain the .tsv files (VUA18 / VUAverb / MOH-X/CLS / TroFi/CLS / VUA20)
|
| 6 |
+
data_dir = data/VUA20
|
| 7 |
+
# The name of the task to train (vua(1-fold) / trofi(10-fold))
|
| 8 |
+
task_name = vua
|
| 9 |
+
# The name of model type (default = MELBERT) (BERT_BASE / BERT_SEQ / MELBERT_SPV / MELBERT_MIP / MELBERT)
|
| 10 |
+
model_type = MELBERT
|
| 11 |
+
# The hidden dimension for classifier (default = 768)
|
| 12 |
+
classifier_hidden = 768
|
| 13 |
+
# Learning rate scheduler (default = warmup_linear) (none / warmup_linear)
|
| 14 |
+
lr_schedule = warmup_linear
|
| 15 |
+
# Training epochs to perform linear learning rate warmup for. (default = 2)
|
| 16 |
+
warmup_epoch = 2
|
| 17 |
+
# Dropout ratio (default = 0.2)
|
| 18 |
+
drop_ratio = 0.2
|
| 19 |
+
# K-fold (default = 10)
|
| 20 |
+
kfold = 10
|
| 21 |
+
# Number of bagging (default = 0) (0 not for using bagging technique)
|
| 22 |
+
num_bagging = 0
|
| 23 |
+
# The index of bagging only for the case using bagging technique (default = 0)
|
| 24 |
+
bagging_index = 0
|
| 25 |
+
|
| 26 |
+
# Use additional linguistic features
|
| 27 |
+
# POS tag (default = True)
|
| 28 |
+
use_pos = True
|
| 29 |
+
# Local context (default = True)
|
| 30 |
+
use_local_context= True
|
| 31 |
+
|
| 32 |
+
# The maximum total input sequence length after WordPiece tokenization. (default = 200)
|
| 33 |
+
max_seq_length = 150
|
| 34 |
+
# Whether to run training (default = False)
|
| 35 |
+
do_train = False
|
| 36 |
+
# Whether to run eval on the test set (default = False)
|
| 37 |
+
do_test = True
|
| 38 |
+
# Whether to run eval on the dev set. (default = False)
|
| 39 |
+
do_eval = True
|
| 40 |
+
# Set this flag if you are using an uncased model. (default = False)
|
| 41 |
+
do_lower_case = False
|
| 42 |
+
# Weight of metaphor. (default = 3.0)
|
| 43 |
+
class_weight = 3
|
| 44 |
+
# Total batch size for training. (default = 32)
|
| 45 |
+
train_batch_size = 16
|
| 46 |
+
# Total batch size for eval. (default = 8)
|
| 47 |
+
eval_batch_size = 8
|
| 48 |
+
# The initial learning rate for Adam (default = 3e-5)
|
| 49 |
+
learning_rate = 3e-5
|
| 50 |
+
# Total number of training epochs to perform. (default = 3.0)
|
| 51 |
+
num_train_epoch = 3
|
| 52 |
+
|
| 53 |
+
# Whether not to use CUDA when available (default = False)
|
| 54 |
+
no_cuda = False
|
| 55 |
+
# random seed for initialization (default = 42)
|
| 56 |
+
seed = 42
|
| 57 |
+
|
| 58 |
+
max_data_num = None
|
modeling.py
ADDED
|
@@ -0,0 +1,403 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
import torch
|
| 3 |
+
import torch.nn as nn
|
| 4 |
+
|
| 5 |
+
from utils import Config
|
| 6 |
+
from transformers import AutoTokenizer, AutoModel
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
class AutoModelForSequenceClassification(nn.Module):
    """Sentence-level classifier on top of a pretrained transformer encoder.

    The encoder's pooled ([CLS]) representation is regularized with dropout,
    projected by a linear head, and turned into log-probabilities over
    ``num_labels`` classes with a log-softmax.
    """

    def __init__(self, args, Model, config, num_labels=2):
        """Build the classification head around a pretrained encoder.

        Args:
            args: experiment configuration; only ``args.drop_ratio`` is read.
            Model: pretrained transformer encoder (e.g. from ``AutoModel``).
            config: encoder configuration; ``hidden_size`` and
                ``initializer_range`` are read.
            num_labels: number of output classes (default 2).
        """
        super(AutoModelForSequenceClassification, self).__init__()
        self.num_labels = num_labels
        self.encoder = Model
        self.config = config
        self.dropout = nn.Dropout(args.drop_ratio)
        self.classifier = nn.Linear(config.hidden_size, num_labels)
        self.logsoftmax = nn.LogSoftmax(dim=1)

        self._init_weights(self.classifier)

    def _init_weights(self, module):
        """Initialize a freshly added module like the pretrained encoder does."""
        if isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
        elif isinstance(module, (nn.Linear, nn.Embedding)):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if isinstance(module, nn.Linear) and module.bias is not None:
                module.bias.data.zero_()

    def forward(
        self,
        input_ids,
        target_mask=None,
        token_type_ids=None,
        attention_mask=None,
        labels=None,
        head_mask=None,
    ):
        """Encode a batch and classify its pooled representation.

        Args:
            input_ids: [batch, seq_len] token ids.
            target_mask: accepted for interface compatibility with the
                target-aware variants; not used by this model.
            token_type_ids: optional [batch, seq_len] segment ids.
            attention_mask: optional [batch, seq_len] padding mask (1 = keep).
            labels: optional [batch] gold class ids; when given the scalar
                NLL loss is returned instead of the log-probabilities.
            head_mask: optional attention-head mask.

        Returns:
            Scalar NLL loss if ``labels`` is provided, otherwise a
            [batch, num_labels] tensor of log-probabilities.
        """
        encoded = self.encoder(
            input_ids,
            token_type_ids=token_type_ids,
            attention_mask=attention_mask,
            head_mask=head_mask,
        )
        sentence_repr = self.dropout(encoded[1])  # pooled [CLS] vector
        log_probs = self.logsoftmax(self.classifier(sentence_repr))

        if labels is None:
            return log_probs
        return nn.NLLLoss()(log_probs.view(-1, self.num_labels), labels.view(-1))
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
class AutoModelForTokenClassification(nn.Module):
    """Target-word classifier: averages the encoder states of the target
    token(s) and classifies that single vector.

    BUGFIX: the original implementation divided the per-example sum of target
    hidden states by ``target_mask.sum()`` — the target-token count of the
    WHOLE batch — so each example's representation scale depended on the other
    examples in the batch. The mean is now taken per example.
    """

    def __init__(self, args, Model, config, num_labels=2):
        """Build the token-classification head.

        Args:
            args: experiment configuration; only ``args.drop_ratio`` is read.
            Model: pretrained transformer encoder (e.g. from ``AutoModel``).
            config: encoder configuration; ``hidden_size`` and
                ``initializer_range`` are read.
            num_labels: number of output classes (default 2).
        """
        super(AutoModelForTokenClassification, self).__init__()
        self.num_labels = num_labels
        # Attribute kept as `bert` so existing checkpoints/state_dicts load.
        self.bert = Model
        self.config = config
        self.dropout = nn.Dropout(args.drop_ratio)
        self.classifier = nn.Linear(config.hidden_size, num_labels)
        self.logsoftmax = nn.LogSoftmax(dim=1)

        self._init_weights(self.classifier)

    def _init_weights(self, module):
        """Initialize a freshly added module like the pretrained encoder does."""
        if isinstance(module, (nn.Linear, nn.Embedding)):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
        if isinstance(module, nn.Linear) and module.bias is not None:
            module.bias.data.zero_()

    def forward(
        self,
        input_ids,
        target_mask,
        token_type_ids=None,
        attention_mask=None,
        labels=None,
        head_mask=None,
    ):
        """Classify the mean hidden state of the target word.

        Args:
            input_ids: [batch, seq_len] token ids.
            target_mask: [batch, seq_len] mask, 1 on target-word tokens.
            token_type_ids: optional [batch, seq_len] segment ids.
            attention_mask: optional [batch, seq_len] padding mask (1 = keep).
            labels: optional [batch] gold class ids; when given the scalar
                NLL loss is returned instead of the log-probabilities.
            head_mask: optional attention-head mask.

        Returns:
            Scalar NLL loss if ``labels`` is provided, otherwise a
            [batch, num_labels] tensor of log-probabilities.
        """
        outputs = self.bert(
            input_ids,
            token_type_ids=token_type_ids,
            attention_mask=attention_mask,
            head_mask=head_mask,
        )
        sequence_output = outputs[0]  # [batch, max_len, hidden]

        # Zero out every position that is not part of the target word.
        target_output = sequence_output * target_mask.unsqueeze(2)
        target_output = self.dropout(target_output)

        # Per-example mean over target tokens (fix: was a batch-total divisor).
        # clamp(min=1) guards against an all-zero mask producing NaNs.
        token_counts = target_mask.sum(dim=1, keepdim=True).clamp(min=1)
        target_output = target_output.sum(dim=1) / token_counts  # [batch, hidden]

        logits = self.classifier(target_output)
        logits = self.logsoftmax(logits)

        if labels is not None:
            loss_fct = nn.NLLLoss()
            loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
            return loss
        return logits
|
| 142 |
+
|
| 143 |
+
|
| 144 |
+
class AutoModelForSequenceClassification_SPV(nn.Module):
    """MelBERT variant using only SPV (Selectional Preference Violation).

    The pooled sentence vector and the mean hidden state of the target word
    are concatenated and classified into ``num_labels`` classes.
    """

    def __init__(self, args, Model, config, num_labels=2):
        """Build the SPV classification head.

        Args:
            args: experiment configuration; only ``args.drop_ratio`` is read.
            Model: pretrained transformer encoder (e.g. from ``AutoModel``).
            config: encoder configuration; ``hidden_size`` and
                ``initializer_range`` are read.
            num_labels: number of output classes (default 2).
        """
        super(AutoModelForSequenceClassification_SPV, self).__init__()
        self.num_labels = num_labels
        self.encoder = Model
        self.config = config
        self.dropout = nn.Dropout(args.drop_ratio)
        # Head sees [target_repr ; pooled_repr], hence hidden_size * 2.
        self.classifier = nn.Linear(config.hidden_size * 2, num_labels)
        self.logsoftmax = nn.LogSoftmax(dim=1)

        self._init_weights(self.classifier)

    def _init_weights(self, module):
        """Initialize a freshly added module like the pretrained encoder does."""
        if isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
        elif isinstance(module, (nn.Linear, nn.Embedding)):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if isinstance(module, nn.Linear) and module.bias is not None:
                module.bias.data.zero_()

    def forward(
        self,
        input_ids,
        target_mask,
        token_type_ids=None,
        attention_mask=None,
        labels=None,
        head_mask=None,
    ):
        """Classify [target word representation ; pooled sentence vector].

        Args:
            input_ids: [batch, seq_len] token ids.
            target_mask: [batch, seq_len] mask, 1 on target-word tokens.
            token_type_ids: optional [batch, seq_len] segment ids.
            attention_mask: optional [batch, seq_len] padding mask (1 = keep).
            labels: optional [batch] gold class ids; when given the scalar
                NLL loss is returned instead of the log-probabilities.
            head_mask: optional attention-head mask.

        Returns:
            Scalar NLL loss if ``labels`` is provided, otherwise a
            [batch, num_labels] tensor of log-probabilities.
        """
        encoded = self.encoder(
            input_ids,
            token_type_ids=token_type_ids,
            attention_mask=attention_mask,
            head_mask=head_mask,
        )
        hidden_states = encoded[0]  # [batch, max_len, hidden]
        pooled_repr = encoded[1]    # [batch, hidden]

        # Keep only target-word positions, then regularize.
        masked_states = hidden_states * target_mask.unsqueeze(2)
        masked_states = self.dropout(masked_states)
        pooled_repr = self.dropout(pooled_repr)

        # Average over the sequence axis (divisor is max_len, matching the
        # original implementation).
        target_repr = masked_states.mean(1)

        log_probs = self.logsoftmax(
            self.classifier(torch.cat([target_repr, pooled_repr], dim=1))
        )

        if labels is None:
            return log_probs
        return nn.NLLLoss()(log_probs.view(-1, self.num_labels), labels.view(-1))
|
| 217 |
+
|
| 218 |
+
|
| 219 |
+
class AutoModelForSequenceClassification_MIP(nn.Module):
    """MelBERT variant using only MIP (Metaphor Identification Procedure).

    The target word is encoded twice — once in its full sentence context and
    once in isolation — and the two representations are concatenated and
    classified.

    BUGFIX: both target means were previously divided by
    ``target_mask.sum()`` / ``target_mask_2.sum()`` — the target-token counts
    of the WHOLE batch — so each example's representation scale depended on
    the other examples in the batch. The means are now taken per example.
    """

    def __init__(self, args, Model, config, num_labels=2):
        """Build the MIP classification head.

        Args:
            args: experiment configuration; only ``args.drop_ratio`` is read.
            Model: pretrained transformer encoder (e.g. from ``AutoModel``),
                shared between the two passes.
            config: encoder configuration; ``hidden_size`` and
                ``initializer_range`` are read.
            num_labels: number of output classes (default 2).
        """
        super(AutoModelForSequenceClassification_MIP, self).__init__()
        self.num_labels = num_labels
        self.encoder = Model
        self.config = config
        self.dropout = nn.Dropout(args.drop_ratio)
        self.args = args
        # Head sees [isolated_target ; contextual_target], hence hidden * 2.
        self.classifier = nn.Linear(config.hidden_size * 2, num_labels)
        self.logsoftmax = nn.LogSoftmax(dim=1)

        self._init_weights(self.classifier)

    def _init_weights(self, module):
        """Initialize a freshly added module like the pretrained encoder does."""
        if isinstance(module, (nn.Linear, nn.Embedding)):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
        if isinstance(module, nn.Linear) and module.bias is not None:
            module.bias.data.zero_()

    def forward(
        self,
        input_ids,
        input_ids_2,
        target_mask,
        target_mask_2,
        attention_mask_2,
        token_type_ids=None,
        attention_mask=None,
        labels=None,
        head_mask=None,
    ):
        """Classify [isolated target representation ; contextual target repr].

        Args:
            input_ids: [batch, seq_len] token ids of the full sentence.
            input_ids_2: [batch, seq_len] token ids of the target word alone.
            target_mask: [batch, seq_len] mask for the target in the sentence.
            target_mask_2: [batch, seq_len] mask for the target in the
                isolated input.
            attention_mask_2: padding mask for the second input.
            token_type_ids: optional segment ids for the first input.
            attention_mask: optional padding mask for the first input.
            labels: optional [batch] gold class ids; when given the scalar
                NLL loss is returned instead of the log-probabilities.
            head_mask: optional attention-head mask.

        Returns:
            Scalar NLL loss if ``labels`` is provided, otherwise a
            [batch, num_labels] tensor of log-probabilities.
        """
        # First pass: full sentence context.
        outputs = self.encoder(
            input_ids,
            token_type_ids=token_type_ids,
            attention_mask=attention_mask,
            head_mask=head_mask,
        )
        sequence_output = outputs[0]  # [batch, max_len, hidden]

        # Per-example mean over target tokens (fix: was a batch-total divisor).
        # clamp(min=1) guards against an all-zero mask producing NaNs.
        target_output = sequence_output * target_mask.unsqueeze(2)
        target_output = self.dropout(target_output)
        counts = target_mask.sum(dim=1, keepdim=True).clamp(min=1)
        target_output = target_output.sum(dim=1) / counts  # [batch, hidden]

        # Second pass: target word in isolation.
        outputs_2 = self.encoder(input_ids_2, attention_mask=attention_mask_2, head_mask=head_mask)
        sequence_output_2 = outputs_2[0]  # [batch, max_len, hidden]

        target_output_2 = sequence_output_2 * target_mask_2.unsqueeze(2)
        target_output_2 = self.dropout(target_output_2)
        counts_2 = target_mask_2.sum(dim=1, keepdim=True).clamp(min=1)
        target_output_2 = target_output_2.sum(dim=1) / counts_2

        logits = self.classifier(torch.cat([target_output_2, target_output], dim=1))
        logits = self.logsoftmax(logits)

        if labels is not None:
            loss_fct = nn.NLLLoss()
            loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
            return loss
        return logits
|
| 303 |
+
|
| 304 |
+
|
| 305 |
+
class AutoModelForSequenceClassification_SPV_MIP(nn.Module):
    """Full MelBERT: combines the SPV and MIP views of the target word.

    SPV compares the pooled sentence vector with the contextual target
    representation; MIP compares the isolated target representation with the
    contextual one. Each pair feeds its own linear layer, and the two hidden
    vectors are concatenated and classified.
    """

    def __init__(self, args, Model, config, num_labels=2):
        """Build the combined SPV + MIP head.

        Args:
            args: experiment configuration; ``args.drop_ratio`` and
                ``args.classifier_hidden`` are read.
            Model: pretrained transformer encoder (e.g. from ``AutoModel``),
                shared between the two passes.
            config: encoder configuration; ``hidden_size`` and
                ``initializer_range`` are read.
            num_labels: number of output classes (default 2).
        """
        super(AutoModelForSequenceClassification_SPV_MIP, self).__init__()
        self.num_labels = num_labels
        self.encoder = Model
        self.config = config
        self.dropout = nn.Dropout(args.drop_ratio)
        self.args = args

        # NOTE: layer-construction and init-call order is preserved exactly so
        # that seeded runs reproduce the original parameter initialization.
        self.SPV_linear = nn.Linear(config.hidden_size * 2, args.classifier_hidden)
        self.MIP_linear = nn.Linear(config.hidden_size * 2, args.classifier_hidden)
        self.classifier = nn.Linear(args.classifier_hidden * 2, num_labels)
        self._init_weights(self.SPV_linear)
        self._init_weights(self.MIP_linear)

        self.logsoftmax = nn.LogSoftmax(dim=1)
        self._init_weights(self.classifier)

    def _init_weights(self, module):
        """Initialize a freshly added module like the pretrained encoder does."""
        if isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
        elif isinstance(module, (nn.Linear, nn.Embedding)):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if isinstance(module, nn.Linear) and module.bias is not None:
                module.bias.data.zero_()

    def forward(
        self,
        input_ids,
        input_ids_2,
        target_mask,
        target_mask_2,
        attention_mask_2,
        token_type_ids=None,
        attention_mask=None,
        labels=None,
        head_mask=None,
    ):
        """Run both encoder passes and classify the fused SPV/MIP features.

        Args:
            input_ids: [batch, seq_len] token ids of the full sentence.
            input_ids_2: [batch, seq_len] token ids of the target word alone.
            target_mask: [batch, seq_len] mask for the target in the sentence.
            target_mask_2: [batch, seq_len] mask for the target in the
                isolated input.
            attention_mask_2: padding mask for the second input.
            token_type_ids: optional segment ids for the first input.
            attention_mask: optional padding mask for the first input.
            labels: optional [batch] gold class ids; when given the scalar
                NLL loss is returned instead of the log-probabilities.
            head_mask: optional attention-head mask.

        Returns:
            Scalar NLL loss if ``labels`` is provided, otherwise a
            [batch, num_labels] tensor of log-probabilities.
        """
        # Pass 1: target word within its sentence.
        encoded = self.encoder(
            input_ids,
            token_type_ids=token_type_ids,
            attention_mask=attention_mask,
            head_mask=head_mask,
        )
        hidden_states = encoded[0]  # [batch, max_len, hidden]
        pooled_repr = encoded[1]    # [batch, hidden]

        contextual_target = hidden_states * target_mask.unsqueeze(2)
        contextual_target = self.dropout(contextual_target)
        pooled_repr = self.dropout(pooled_repr)
        # Mean over the sequence axis (divisor is max_len, as in the original).
        contextual_target = contextual_target.mean(1)  # [batch, hidden]

        # Pass 2: target word in isolation.
        encoded_2 = self.encoder(input_ids_2, attention_mask=attention_mask_2, head_mask=head_mask)
        isolated_target = encoded_2[0] * target_mask_2.unsqueeze(2)
        isolated_target = self.dropout(isolated_target)
        isolated_target = isolated_target.mean(1)

        # Fuse the two views through their dedicated linear layers.
        spv_hidden = self.SPV_linear(torch.cat([pooled_repr, contextual_target], dim=1))
        mip_hidden = self.MIP_linear(torch.cat([isolated_target, contextual_target], dim=1))

        log_probs = self.logsoftmax(
            self.classifier(self.dropout(torch.cat([spv_hidden, mip_hidden], dim=1)))
        )

        if labels is None:
            return log_probs
        return nn.NLLLoss()(log_probs.view(-1, self.num_labels), labels.view(-1))
|
requirements.txt
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
boto3==1.16.63
nltk==3.5
numpy==1.20.0
requests==2.25.1
scikit-learn==0.24.1
scipy==1.6.0
|
| 2 |
+
torch==1.6.0
torchvision==0.7.0
tqdm==4.56.0
transformers==4.2.2
|
run_classifier_dataset_utils.py
ADDED
|
@@ -0,0 +1,669 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# coding=utf-8
|
| 2 |
+
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
|
| 3 |
+
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
|
| 4 |
+
#
|
| 5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 6 |
+
# you may not use this file except in compliance with the License.
|
| 7 |
+
# You may obtain a copy of the License at
|
| 8 |
+
#
|
| 9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 10 |
+
#
|
| 11 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 14 |
+
# See the License for the specific language governing permissions and
|
| 15 |
+
# limitations under the License.
|
| 16 |
+
""" BERT classification fine-tuning: utilities to work with GLUE tasks """
|
| 17 |
+
|
| 18 |
+
from __future__ import absolute_import, division, print_function
|
| 19 |
+
|
| 20 |
+
import csv
|
| 21 |
+
import logging
|
| 22 |
+
import os
|
| 23 |
+
import sys
|
| 24 |
+
import torch
|
| 25 |
+
from tqdm import tqdm
|
| 26 |
+
|
| 27 |
+
from scipy.stats import pearsonr, spearmanr, truncnorm
|
| 28 |
+
from sklearn.metrics import (
|
| 29 |
+
matthews_corrcoef,
|
| 30 |
+
f1_score,
|
| 31 |
+
precision_score,
|
| 32 |
+
recall_score,
|
| 33 |
+
mean_squared_error,
|
| 34 |
+
)
|
| 35 |
+
import random
|
| 36 |
+
import nltk
|
| 37 |
+
from nltk.corpus import wordnet
|
| 38 |
+
|
| 39 |
+
logger = logging.getLogger(__name__)
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
class InputExample(object):
    """One raw train/dev/test example for sequence classification.

    Attributes:
        guid: unique identifier for the example.
        text_a: untokenized text of the first sequence.
        text_b: optional untokenized text of the second sequence
            (sequence-pair tasks only).
        label: optional gold label (set for train/dev, absent for test).
        POS: optional coarse-grained part-of-speech tag of the target word.
        FGPOS: optional fine-grained part-of-speech tag of the target word.
        text_a_2: optional first text of the second encoder input.
        text_b_2: optional second text of the second encoder input.
    """

    def __init__(
        self,
        guid,
        text_a,
        text_b=None,
        label=None,
        POS=None,
        FGPOS=None,
        text_a_2=None,
        text_b_2=None,
    ):
        """Store all fields of a single example; no processing is done here."""
        self.guid, self.text_a, self.text_b = guid, text_a, text_b
        self.label, self.POS, self.FGPOS = label, POS, FGPOS
        self.text_a_2, self.text_b_2 = text_a_2, text_b_2
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
class InputFeatures(object):
    """Tokenized, id-mapped features for one example.

    Attributes:
        input_ids: token ids for the first encoder input.
        input_mask: attention mask for the first input (1 = real token).
        segment_ids: segment/token-type ids for the first input.
        label_id: integer class id of the gold label.
        guid: optional example identifier carried through from InputExample.
        input_ids_2: optional token ids for the second encoder input.
        input_mask_2: optional attention mask for the second input.
        segment_ids_2: optional segment ids for the second input.
    """

    def __init__(
        self,
        input_ids,
        input_mask,
        segment_ids,
        label_id,
        guid=None,
        input_ids_2=None,
        input_mask_2=None,
        segment_ids_2=None,
    ):
        """Store all feature fields; no processing is done here."""
        self.input_ids, self.input_mask = input_ids, input_mask
        self.segment_ids, self.label_id = segment_ids, label_id
        self.guid = guid
        self.input_ids_2, self.input_mask_2 = input_ids_2, input_mask_2
        self.segment_ids_2 = segment_ids_2
|
| 99 |
+
|
| 100 |
+
|
| 101 |
+
class DataProcessor(object):
    """Abstract base class for dataset-specific example loaders.

    Subclasses implement the ``get_*_examples`` methods and ``get_labels``;
    the TSV reader is shared.
    """

    def get_train_examples(self, data_dir):
        """Return the list of `InputExample`s for the train split."""
        raise NotImplementedError()

    def get_dev_examples(self, data_dir):
        """Return the list of `InputExample`s for the dev split."""
        raise NotImplementedError()

    def get_labels(self):
        """Return the list of label strings for this data set."""
        raise NotImplementedError()

    @classmethod
    def _read_tsv(cls, input_file, quotechar=None):
        """Read a tab-separated file and return its rows as lists of cells."""
        with open(input_file, "r", encoding="utf-8") as handle:
            rows = csv.reader(handle, delimiter="\t", quotechar=quotechar)
            if sys.version_info[0] == 2:
                # Legacy Python 2 path: decode each cell to unicode.
                return [[unicode(cell, "utf-8") for cell in row] for row in rows]
            return [row for row in rows]
|
| 127 |
+
|
| 128 |
+
|
| 129 |
+
class TrofiProcessor(DataProcessor):
    """Processor for the TroFi and MOH-X data set."""

    def get_train_examples(self, data_dir, k=None):
        """See base class. ``k`` selects a numbered cross-validation fold file."""
        fname = "train.tsv" if k is None else "train" + str(k) + ".tsv"
        return self._create_examples(self._read_tsv(os.path.join(data_dir, fname)), "train")

    def get_test_examples(self, data_dir, k=None):
        """See base class. ``k`` selects a numbered cross-validation fold file."""
        fname = "test.tsv" if k is None else "test" + str(k) + ".tsv"
        return self._create_examples(self._read_tsv(os.path.join(data_dir, fname)), "test")

    def get_dev_examples(self, data_dir, k=None):
        """See base class. ``k`` selects a numbered cross-validation fold file."""
        fname = "dev.tsv" if k is None else "dev" + str(k) + ".tsv"
        return self._create_examples(self._read_tsv(os.path.join(data_dir, fname)), "dev")

    def get_labels(self):
        """See base class."""
        return ["0", "1"]

    def _create_examples(self, lines, set_type):
        """Creates examples for the training and dev sets (header row skipped)."""
        examples = []
        for row in lines[1:]:
            examples.append(
                InputExample(
                    guid="%s-%s" % (set_type, row[0]),
                    text_a=row[2],
                    text_b=row[-1],  # last column: index of the target word
                    label=row[1],
                    POS=row[3],
                    FGPOS=row[4],
                )
            )
        return examples
|
| 183 |
+
|
| 184 |
+
|
| 185 |
+
class VUAProcessor(DataProcessor):
    """Processor for the VUA data set."""

    def get_train_examples(self, data_dir):
        """See base class."""
        return self._create_examples(self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")

    def get_test_examples(self, data_dir):
        """See base class."""
        return self._create_examples(self._read_tsv(os.path.join(data_dir, "test.tsv")), "test")

    def get_dev_examples(self, data_dir):
        """See base class."""
        return self._create_examples(self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")

    def get_labels(self):
        """See base class."""
        return ["0", "1"]

    def _create_examples(self, lines, set_type):
        """Creates examples for the training and dev sets (header row skipped)."""
        examples = []
        for row in lines[1:]:
            shared = dict(
                guid="%s-%s" % (set_type, row[0]),
                text_a=row[2],
                label=row[1],
                POS=row[3],
                FGPOS=row[4],
            )
            if len(row) == 8:
                # 8-column rows additionally carry a second sentence and its
                # own target-word index.
                examples.append(
                    InputExample(
                        text_b=row[5],
                        text_a_2=row[6],
                        text_b_2=row[7],
                        **shared
                    )
                )
            else:
                examples.append(InputExample(text_b=row[-1], **shared))
        return examples
|
| 239 |
+
|
| 240 |
+
|
| 241 |
+
def convert_examples_to_features(
    examples, label_list, max_seq_length, tokenizer, output_mode, args
):
    """Loads a data file into a list of `InputBatch`s.

    Each example is tokenized, truncated/padded to ``max_seq_length``, and
    the target word (whose word index arrives in ``example.text_b``) is
    marked with 1s in ``segment_ids``.  When ``example.text_b`` is not an
    int-convertible index (TypeError path), it is treated as a second text
    sequence instead.
    """
    label_map = {label: i for i, label in enumerate(label_list)}

    features = []
    for (ex_index, example) in tqdm(enumerate(examples)):
        if ex_index % 10000 == 0:
            logger.info("Writing example %d of %d" % (ex_index, len(examples)))

        tokens_a = tokenizer.tokenize(example.text_a)  # tokenize the sentence
        tokens_b = None

        try:
            text_b = int(example.text_b)  # index of target word
            tokens_b = text_b

            # truncate the sentence to max_seq_len
            if len(tokens_a) > max_seq_length - 2:
                tokens_a = tokens_a[: (max_seq_length - 2)]

            # Find the target word index
            for i, w in enumerate(example.text_a.split()):
                # If w is a target word, tokenize the word and save to text_b
                if i == text_b:
                    # consider the index due to models that use a byte-level BPE as a tokenizer (e.g., GPT2, RoBERTa)
                    text_b = tokenizer.tokenize(w) if i == 0 else tokenizer.tokenize(" " + w)
                    break
                w_tok = tokenizer.tokenize(w) if i == 0 else tokenizer.tokenize(" " + w)

                # Count number of tokens before the target word to get the target word index
                if w_tok:
                    tokens_b += len(w_tok) - 1

        except TypeError:
            # NOTE(review): this path fires when int(example.text_b) raises
            # TypeError (e.g. text_b is None); text_b may then be left over
            # from a previous loop iteration — verify upstream data always
            # supplies an index.
            if example.text_b:
                tokens_b = tokenizer.tokenize(example.text_b)
                # Modifies `tokens_a` and `tokens_b` in place so that the total
                # length is less than the specified length.
                # Account for [CLS], [SEP], [SEP] with "- 3"
                _truncate_seq_pair(tokens_a, tokens_b, max_seq_length - 3)
            else:
                # Account for [CLS] and [SEP] with "- 2"
                if len(tokens_a) > max_seq_length - 2:
                    tokens_a = tokens_a[: (max_seq_length - 2)]

        tokens = [tokenizer.cls_token] + tokens_a + [tokenizer.sep_token]
        segment_ids = [0] * len(tokens)
        input_ids = tokenizer.convert_tokens_to_ids(tokens)

        # set the target word as 1 in segment ids
        try:
            tokens_b += 1  # add 1 to the target word index considering [CLS]
            for i in range(len(text_b)):
                segment_ids[tokens_b + i] = 1
        except TypeError:
            # tokens_b was a token list (or None) — no single target index.
            pass

        # The mask has 1 for real tokens and 0 for padding tokens. Only real
        # tokens are attended to.
        input_mask = [1] * len(input_ids)

        # Zero-pad up to the sequence length.
        padding = [tokenizer.convert_tokens_to_ids(tokenizer.pad_token)] * (
            max_seq_length - len(input_ids)
        )
        input_ids += padding
        input_mask += [0] * len(padding)
        segment_ids += [0] * len(padding)

        assert len(input_ids) == max_seq_length
        assert len(input_mask) == max_seq_length
        assert len(segment_ids) == max_seq_length

        if output_mode == "classification":
            label_id = label_map[example.label]
        else:
            raise KeyError(output_mode)

        if ex_index < 5:
            logger.info("*** Example ***")
            logger.info("guid: %s" % (example.guid))
            logger.info("tokens: %s" % " ".join([str(x) for x in tokens]))
            logger.info("input_ids: %s" % " ".join([str(x) for x in input_ids]))
            logger.info("input_mask: %s" % " ".join([str(x) for x in input_mask]))
            logger.info("segment_ids: %s" % " ".join([str(x) for x in segment_ids]))
            logger.info("label: %s (id = %s)" % (example.label, str(label_id)))

        features.append(
            InputFeatures(
                input_ids=input_ids,
                input_mask=input_mask,
                segment_ids=segment_ids,
                label_id=label_id,
                guid=example.guid + " " + str(example.text_b),
            )
        )
    return features
|
| 340 |
+
|
| 341 |
+
|
| 342 |
+
def convert_two_examples_to_features(
    examples, label_list, max_seq_length, tokenizer, output_mode, win_size=-1
):
    """Loads a data file into a list of `InputBatch`s.

    Like ``convert_examples_to_features`` but additionally appends the
    tokenized target word as a second segment after the first [SEP]
    (``[CLS] sentence [SEP] target [SEP]``).  ``win_size`` is accepted but
    unused in the visible body.
    """
    label_map = {label: i for i, label in enumerate(label_list)}

    features = []
    for (ex_index, example) in enumerate(examples):
        if ex_index % 10000 == 0:
            logger.info("Writing example %d of %d" % (ex_index, len(examples)))

        tokens_a = tokenizer.tokenize(example.text_a)  # tokenize the sentence
        tokens_b = None
        text_b = None

        try:
            text_b = int(example.text_b)  # index of target word
            tokens_b = text_b

            # truncate the sentence to max_seq_len
            if len(tokens_a) > max_seq_length - 2:
                tokens_a = tokens_a[: (max_seq_length - 2)]

            # Find the target word index
            for i, w in enumerate(example.text_a.split()):
                # If w is a target word, tokenize the word and save to text_b
                if i == text_b:
                    # consider the index due to models that use a byte-level BPE as a tokenizer (e.g., GPT2, RoBERTa)
                    text_b = tokenizer.tokenize(w) if i == 0 else tokenizer.tokenize(" " + w)
                    break
                w_tok = tokenizer.tokenize(w) if i == 0 else tokenizer.tokenize(" " + w)

                # Count number of tokens before the target word to get the target word index
                if w_tok:
                    tokens_b += len(w_tok) - 1

        except TypeError:
            if example.text_b:
                tokens_b = tokenizer.tokenize(example.text_b)

                # Modifies `tokens_a` and `tokens_b` in place so that the total
                # length is less than the specified length.
                # Account for [CLS], [SEP], [SEP] with "- 3"
                _truncate_seq_pair(tokens_a, tokens_b, max_seq_length - 3)
            else:
                # Account for [CLS] and [SEP] with "- 2"
                if len(tokens_a) > max_seq_length - 2:
                    tokens_a = tokens_a[: (max_seq_length - 2)]

        tokens = [tokenizer.cls_token] + tokens_a + [tokenizer.sep_token]
        segment_ids = [0] * len(tokens)
        # set the target word as 1 in segment ids
        try:
            tokens_b += 1  # add 1 to the target word index considering [CLS]
            for i in range(len(text_b)):
                segment_ids[tokens_b + i] = 1

            # concatentate the second sentence ( ["[CLS]"] + tokens_a + ["[SEP]"] -> ["[CLS]"] + tokens_a + ["[SEP]"] + text_b + ["[SEP]"])
            tokens = tokens + text_b + [tokenizer.sep_token]
            segment_ids = segment_ids + [0] * len(text_b)
        except TypeError:
            # tokens_b was a token list (or None) — no single target index.
            pass

        # The mask has 1 for real tokens and 0 for padding tokens. Only real
        # tokens are attended to.
        input_ids = tokenizer.convert_tokens_to_ids(tokens)
        input_mask = [1] * len(input_ids)

        # Zero-pad up to the sequence length.
        padding = [tokenizer.convert_tokens_to_ids(tokenizer.pad_token)] * (
            max_seq_length - len(input_ids)
        )
        input_ids += padding
        input_mask += [0] * len(padding)
        segment_ids += [0] * len(padding)

        assert len(input_ids) == max_seq_length
        assert len(input_mask) == max_seq_length
        assert len(segment_ids) == max_seq_length

        if output_mode == "classification":
            label_id = label_map[example.label]
        else:
            raise KeyError(output_mode)

        if ex_index < 5:
            logger.info("*** Example ***")
            logger.info("guid: %s" % (example.guid))
            logger.info("tokens: %s" % " ".join([str(x) for x in tokens]))
            logger.info("input_ids: %s" % " ".join([str(x) for x in input_ids]))
            logger.info("input_mask: %s" % " ".join([str(x) for x in input_mask]))
            logger.info("segment_ids: %s" % " ".join([str(x) for x in segment_ids]))
            logger.info("label: %s (id = %s)" % (example.label, str(label_id)))

        features.append(
            InputFeatures(
                input_ids=input_ids,
                input_mask=input_mask,
                segment_ids=segment_ids,
                label_id=label_id,
                # NOTE(review): sibling converters wrap text_b in str(); this
                # one assumes example.text_b is already a str — verify.
                guid=example.guid + " " + example.text_b,
            )
        )
    return features
|
| 447 |
+
|
| 448 |
+
|
| 449 |
+
def convert_examples_to_two_features(
    examples, label_list, max_seq_length, tokenizer, output_mode, args
):
    """Loads a data file into a list of `InputBatch`s.

    Builds TWO feature sets per example: the full sentence (with optional POS
    segment and optional local-context segment ids) and a second sequence
    containing only the target word (``[CLS] target [SEP]``).  Examples whose
    target word would not fit within ``max_seq_length - 6`` are skipped.

    Fix: removed the leftover debug ``print('Y|', ...)`` / ``print('after|',
    ...)`` statements (and commented-out pdb lines) that wrote to stdout for
    every example.
    """
    label_map = {label: i for i, label in enumerate(label_list)}

    features = []
    for (ex_index, example) in tqdm(enumerate(examples)):
        if ex_index % 10000 == 0:
            logger.info("Writing example %d of %d" % (ex_index, len(examples)))

        tokens_a = tokenizer.tokenize(example.text_a)  # tokenize the sentence
        tokens_b = None
        text_b = None
        try:
            text_b = int(example.text_b)  # index of target word
            tokens_b = text_b

            # truncate the sentence to max_seq_len (reserve 6 slots for
            # special tokens / POS segment)
            if len(tokens_a) > max_seq_length - 6:
                tokens_a = tokens_a[: (max_seq_length - 6)]

            # Find the target word index
            for i, w in enumerate(example.text_a.split()):
                # If w is a target word, tokenize the word and save to text_b
                if i == text_b:
                    # consider the index due to models that use a byte-level BPE as a tokenizer (e.g., GPT2, RoBERTa)
                    text_b = tokenizer.tokenize(w) if i == 0 else tokenizer.tokenize(" " + w)
                    break

                w_tok = tokenizer.tokenize(w) if i == 0 else tokenizer.tokenize(" " + w)

                # Count number of tokens before the target word to get the target word index
                if w_tok:
                    tokens_b += len(w_tok) - 1

            # Skip examples whose target word falls past the truncation point.
            if tokens_b + len(text_b) > max_seq_length - 6:
                continue

        except TypeError:
            if example.text_b:
                tokens_b = tokenizer.tokenize(example.text_b)
                # Account for [CLS], [SEP], [SEP] with "- 3"
                _truncate_seq_pair(tokens_a, tokens_b, max_seq_length - 3)
            else:
                # Account for [CLS] and [SEP] with "- 2"
                if len(tokens_a) > max_seq_length - 2:
                    tokens_a = tokens_a[: (max_seq_length - 2)]

        tokens = [tokenizer.cls_token] + tokens_a + [tokenizer.sep_token]
        # POS tag tokens
        if args.use_pos:
            POS_token = tokenizer.tokenize(example.POS)
            tokens += POS_token + [tokenizer.sep_token]

        # Local context: mark the comma-delimited clause around the target
        # word with segment id 2.
        if args.use_local_context:
            local_start = 1
            local_end = local_start + len(tokens_a)
            comma1 = tokenizer.tokenize(",")[0]
            comma2 = tokenizer.tokenize(" ,")[0]
            for i, w in enumerate(tokens):
                if i < tokens_b + 1 and (w in [comma1, comma2]):
                    local_start = i
                if i > tokens_b + 1 and (w in [comma1, comma2]):
                    local_end = i
                    break
            segment_ids = [
                2 if i >= local_start and i <= local_end else 0 for i in range(len(tokens))
            ]
        else:
            segment_ids = [0] * len(tokens)

        # POS tag encoding: everything after the first [SEP] gets segment id 3.
        after_token_a = False
        for i, t in enumerate(tokens):
            if t == tokenizer.sep_token:
                after_token_a = True
            if after_token_a and t != tokenizer.sep_token:
                segment_ids[i] = 3

        input_ids = tokenizer.convert_tokens_to_ids(tokens)

        # set the target word as 1 in segment ids
        try:
            tokens_b += 1  # add 1 to the target word index considering [CLS]
            for i in range(len(text_b)):
                segment_ids[tokens_b + i] = 1
        except TypeError:
            pass

        # The mask has 1 for real tokens and 0 for padding tokens.
        input_mask = [1] * len(input_ids)
        padding = [tokenizer.convert_tokens_to_ids(tokenizer.pad_token)] * (
            max_seq_length - len(input_ids)
        )
        input_ids += padding
        input_mask += [0] * len(padding)
        segment_ids += [0] * len(padding)

        assert len(input_ids) == max_seq_length
        assert len(input_mask) == max_seq_length
        assert len(segment_ids) == max_seq_length

        if output_mode == "classification":
            label_id = label_map[example.label]
        else:
            raise KeyError(output_mode)

        # Second features (Target word): [CLS] target [SEP], target marked 1.
        tokens = [tokenizer.cls_token] + text_b + [tokenizer.sep_token]
        segment_ids_2 = [0] * len(tokens)
        try:
            tokens_b = 1  # add 1 to the target word index considering [CLS]
            for i in range(len(text_b)):
                segment_ids_2[tokens_b + i] = 1
        except TypeError:
            pass

        # The mask has 1 for real tokens and 0 for padding tokens. Only real tokens are attended to.
        input_ids_2 = tokenizer.convert_tokens_to_ids(tokens)
        input_mask_2 = [1] * len(input_ids_2)

        padding = [tokenizer.convert_tokens_to_ids(tokenizer.pad_token)] * (
            max_seq_length - len(input_ids_2)
        )
        input_ids_2 += padding
        input_mask_2 += [0] * len(padding)
        segment_ids_2 += [0] * len(padding)

        assert len(input_ids_2) == max_seq_length
        assert len(input_mask_2) == max_seq_length
        assert len(segment_ids_2) == max_seq_length

        features.append(
            InputFeatures(
                input_ids=input_ids,
                input_mask=input_mask,
                segment_ids=segment_ids,
                label_id=label_id,
                guid=example.guid + " " + str(example.text_b),
                input_ids_2=input_ids_2,
                input_mask_2=input_mask_2,
                segment_ids_2=segment_ids_2,
            )
        )

    return features
|
| 603 |
+
|
| 604 |
+
|
| 605 |
+
def _truncate_seq_pair(tokens_a, tokens_b, max_length):
|
| 606 |
+
"""Truncates a sequence pair in place to the maximum length."""
|
| 607 |
+
|
| 608 |
+
# This is a simple heuristic which will always truncate the longer sequence
|
| 609 |
+
# one token at a time. This makes more sense than truncating an equal percent
|
| 610 |
+
# of tokens from each, since if one sequence is very short then each token
|
| 611 |
+
# that's truncated likely contains more information than a longer sequence.
|
| 612 |
+
while True:
|
| 613 |
+
total_length = len(tokens_a) + len(tokens_b)
|
| 614 |
+
if total_length <= max_length:
|
| 615 |
+
break
|
| 616 |
+
if len(tokens_a) > len(tokens_b):
|
| 617 |
+
tokens_a.pop()
|
| 618 |
+
else:
|
| 619 |
+
tokens_b.pop()
|
| 620 |
+
|
| 621 |
+
|
| 622 |
+
def simple_accuracy(preds, labels):
    """Fraction of positions where ``preds`` equals ``labels``.

    Expects array-like inputs supporting elementwise ``==`` and ``.mean()``
    (e.g. numpy arrays).
    """
    matches = preds == labels
    return matches.mean()
|
| 624 |
+
|
| 625 |
+
|
| 626 |
+
def seq_accuracy(preds, labels):
    """Mean per-sequence accuracy.

    :param preds: iterable of per-sequence prediction arrays
    :param labels: matching iterable of per-sequence gold-label arrays
    :return: average of each sequence's accuracy (0.0 for empty input)

    Bug fix: the original collected per-sequence accuracies in a plain
    Python list and then called ``.mean()`` on it, which always raised
    AttributeError; the mean is now computed with ``sum``/``len``.
    """
    acc = []
    for idx, pred in enumerate(preds):
        acc.append((pred == labels[idx]).mean())
    return sum(acc) / len(acc) if acc else 0.0
|
| 631 |
+
|
| 632 |
+
|
| 633 |
+
def acc_and_f1(preds, labels):
    """Return accuracy, F1, and their arithmetic mean as a dict."""
    accuracy = simple_accuracy(preds, labels)
    f1_value = f1_score(y_true=labels, y_pred=preds)
    return {
        "acc": accuracy,
        "f1": f1_value,
        "acc_and_f1": (accuracy + f1_value) / 2,
    }
|
| 641 |
+
|
| 642 |
+
|
| 643 |
+
def all_metrics(preds, labels):
    """Return accuracy, precision, recall, and F1 as a dict."""
    scores = {
        "acc": simple_accuracy(preds, labels),
        "precision": precision_score(y_true=labels, y_pred=preds),
        "recall": recall_score(y_true=labels, y_pred=preds),
        "f1": f1_score(y_true=labels, y_pred=preds),
    }
    return scores
|
| 654 |
+
|
| 655 |
+
|
| 656 |
+
def compute_metrics(preds, labels):
    """Validate lengths match, then return accuracy/precision/recall/F1."""
    assert len(preds) == len(labels)
    return all_metrics(preds, labels)
|
| 659 |
+
|
| 660 |
+
|
| 661 |
+
# Registry mapping --task_name values to their dataset processor classes.
processors = {
    "vua": VUAProcessor,
    "trofi": TrofiProcessor,
}

# Both supported tasks are treated as (binary) classification.
output_modes = {
    "vua": "classification",
    "trofi": "classification",
}
|
scripts/run.sh
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
# Single training run of the MELBERT model on the VUA-20 data
# (batch size 32, learning rate 3e-5, 2 warmup epochs).
python main.py --data_dir data/VUA20 --task_name vua --model_type MELBERT --train_batch_size 32 --learning_rate 3e-5 --warmup_epoch 2
|
scripts/run_bagging.sh
ADDED
|
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash

# Run 10 bagging ensemble members (indices 0-9) sequentially; each run
# receives its own --bagging_index (presumably selecting the bootstrap
# split inside main.py — verify there).
INDEXES=$(seq 0 9)
for i in $INDEXES
do
echo "Running bagging for index $i"
python main.py --data_dir data/VUA20 --task_name vua --model_type MELBERT --train_batch_size 32 --learning_rate 3e-5 --warmup_epoch 2 --num_bagging 10 --bagging_index $i
done
|
utils/Config.py
ADDED
|
@@ -0,0 +1,128 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
from collections import OrderedDict
|
| 3 |
+
from configparser import ConfigParser
|
| 4 |
+
|
| 5 |
+
class Config:
    """Load, coerce, query, update, and persist INI-style configuration.

    Every option found in ``<main_conf_path>/main_config.cfg`` is exposed
    both through the per-section dicts in ``self.main_config`` and as a flat
    attribute on the instance (later sections win on name clashes).
    """

    # Accepted textual spellings of booleans in config files.
    _BOOLEAN = {'false': False, 'False': False,
                'true': True, 'True': True}

    def __init__(self, main_conf_path):
        self.main_conf_path = main_conf_path
        self.main_config = self.read_config(os.path.join(main_conf_path, 'main_config.cfg'))

    def read_config(self, conf_path):
        """Parse ``conf_path`` into an OrderedDict of per-section OrderedDicts,
        coercing value types and mirroring every option onto ``self``."""
        conf_dict = OrderedDict()

        config = ConfigParser()
        config.read(conf_path)
        for section in config.sections():
            section_config = OrderedDict(config[section].items())
            conf_dict[section] = self.type_ensurance(section_config)
            self.__dict__.update((k, v) for k, v in conf_dict[section].items())

        return conf_dict

    def ensure_value_type(self, v):
        """Best-effort coercion of one raw config value.

        Strings that evaluate to str/int/float/list/tuple are converted;
        'true'/'false' spellings become booleans; everything else is
        returned unchanged.
        """
        if not isinstance(v, str):
            # Bug fix: the original fell through to `return value` with
            # `value` unbound for non-string input (UnboundLocalError).
            return v
        try:
            # NOTE(security): eval() executes arbitrary code from the config
            # file — only load trusted configuration files.
            value = eval(v)
            if not isinstance(value, (str, int, float, list, tuple)):
                value = v
        except Exception:
            value = self._BOOLEAN.get(v, v)
        return value

    def type_ensurance(self, config):
        """Coerce every value of ``config`` in place and return it.

        Refactored to reuse ensure_value_type instead of duplicating the
        same eval/boolean coercion logic.
        """
        for k, v in config.items():
            config[k] = self.ensure_value_type(v)
        return config

    def get_param(self, section, param):
        """Return main_config[section][param]; raise NameError if missing."""
        if section in self.main_config:
            section = self.main_config[section]
        else:
            raise NameError("There are not the parameter named '%s'" % section)

        if param in section:
            value = section[param]
        else:
            raise NameError("There are not the parameter named '%s'" % param)

        return value

    def update_params(self, params):
        """Overwrite existing options from a {name: raw_value} dict.

        Unknown names are reported (printed) but otherwise ignored.
        """
        # for now, assume 'params' is dictionary
        for k, v in params.items():
            updated = False
            for section in self.main_config:
                if k in self.main_config[section]:
                    self.main_config[section][k] = self.ensure_value_type(v)
                    self.__dict__[k] = self.main_config[section][k]
                    updated = True
                    break

            if not updated:
                # raise ValueError
                print('Parameter not updated. \'%s\' not exists.' % k)

    def save(self, base_dir):
        """Write the current configuration back to base_dir/main_config.cfg."""
        def helper(section_k, section_v):
            # Render one [section] block in INI syntax.
            sec_str = '[%s]\n' % section_k
            for k, v in section_v.items():
                sec_str += '%s=%s\n' % (str(k), str(v))
            sec_str += '\n'
            return sec_str

        # save main config
        main_conf_str = ''
        for section in self.main_config:
            main_conf_str += helper(section, self.main_config[section])
        with open(os.path.join(base_dir, 'main_config.cfg'), 'wt') as f:
            f.write(main_conf_str)

        print('main config saved in %s' % base_dir)

    def __getitem__(self, item):
        """Dict-style access to a whole section by name."""
        if not isinstance(item, str):
            raise TypeError("index must be a str")

        if item in self.main_config:
            section = self.main_config[item]
        else:
            raise NameError("There are not the parameter named '%s'" % item)
        return section

    def __str__(self):
        config_str = '\n'

        config_str += '>>>>> Main Config\n'
        for section in self.main_config:
            config_str += '[%s]\n' % section
            config_str += '\n'.join(['{}: {}'.format(k, self.main_config[section][k]) for k in self.main_config[section]])
            config_str += '\n\n'

        return config_str
|
| 124 |
+
|
| 125 |
+
if __name__ == '__main__':
    # Ad-hoc smoke test: load a config and print it.
    # NOTE(review): Config joins its argument with 'main_config.cfg', so it
    # expects a *directory*; '../main_config.cfg' looks like a file path —
    # verify this check still works.
    param = Config('../main_config.cfg')

    print(param)
|
utils/Logger.py
ADDED
|
@@ -0,0 +1,84 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
from time import strftime
|
| 3 |
+
import logging
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
def make_log_dir(log_dir):
    """
    Create and return a fresh, uniquely named sub-directory for one run's logs.

    Sub-directories are named "<index>_<YYYYmmdd-HHMM>", where <index> is one
    greater than the largest integer prefix already present in ``log_dir``
    (0 when empty).

    :param log_dir: parent directory collecting per-run log directories
    :return: path of the newly created run directory

    Fixes: non-numeric entries in ``log_dir`` no longer crash the int()
    parse; ``os.makedirs(..., exist_ok=True)`` avoids the check-then-create
    race of the original exists()/mkdir() pair.
    """
    os.makedirs(log_dir, exist_ok=True)

    indices = []
    for d in os.listdir(log_dir):
        head = d.split("_")[0]
        # Skip entries whose prefix is not a (non-negative) integer.
        if head.isdigit():
            indices.append(int(head))
    idx = max(indices) + 1 if indices else 0

    cur_log_dir = "%d_%s" % (idx, strftime("%Y%m%d-%H%M"))
    full_log_dir = os.path.join(log_dir, cur_log_dir)
    os.makedirs(full_log_dir, exist_ok=True)

    return full_log_dir
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
class Logger:
    """Console + file logger scoped to one log directory.

    INFO-and-above goes to the console (message only); DEBUG-and-above goes
    to ``<log_dir>/experiments.log`` with line number and timestamp.
    """

    def __init__(self, log_dir):
        file_fmt = "[%(lineno)d]%(asctime)s: %(message)s"
        console_fmt = "%(message)s"

        # One named logger per directory; don't bubble records to the root.
        self.log_dir = log_dir
        self.logger = logging.getLogger(log_dir)
        self.logger.setLevel(logging.INFO)
        self.logger.propagate = False

        to_console = logging.StreamHandler()
        to_console.setLevel(logging.INFO)
        to_console.setFormatter(logging.Formatter(console_fmt))
        self.logger.addHandler(to_console)

        to_file = logging.FileHandler(os.path.join(log_dir, "experiments.log"))
        to_file.setLevel(logging.DEBUG)
        to_file.setFormatter(logging.Formatter(file_fmt))
        self.logger.addHandler(to_file)

    def info(self, msg):
        """Log ``msg`` at INFO level."""
        self.logger.info(msg)

    def close(self):
        """Detach all handlers from this logger and shut logging down."""
        for handler in list(self.logger.handlers):
            self.logger.removeHandler(handler)
        logging.shutdown()
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
def setup_logger(log_dir):
    """Configure the root logger with console + file handlers and return it.

    Console shows the bare message at INFO; ``<log_dir>/experiments.log``
    records line number and timestamp at DEBUG.
    """
    file_formatter = logging.Formatter("[%(lineno)d]%(asctime)s: %(message)s")
    console_formatter = logging.Formatter("%(message)s")

    root = logging.getLogger()
    root.setLevel(logging.INFO)
    root.propagate = False

    on_console = logging.StreamHandler()
    on_console.setLevel(logging.INFO)
    on_console.setFormatter(console_formatter)
    root.addHandler(on_console)

    in_file = logging.FileHandler(os.path.join(log_dir, "experiments.log"))
    in_file.setLevel(logging.DEBUG)
    in_file.setFormatter(file_formatter)
    root.addHandler(in_file)

    return root
|
utils/ResultTable.py
ADDED
|
@@ -0,0 +1,160 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
from collections import OrderedDict
|
| 3 |
+
|
| 4 |
+
class ResultTable:
    """
    Class to save and show results neatly.

    Rows are added as dicts; column widths grow to fit the longest value
    seen so far and the table renders as aligned ASCII text.
    First column is always 'NAME' column.
    """
    def __init__(self, table_name='table', header=None, splitter='||', int_formatter='%3d', float_formatter='%.4f'):
        """
        Initialize table setting.

        :param str table_name: title printed above the table.
        :param list header: list of string, table headers (optional; inferred
            from the first row if omitted).
        :param str splitter: column separator string.
        :param str int_formatter: %-style format for integer cells.
        :param str float_formatter: %-style format for float cells.
        """
        self.table_name = table_name
        self.header = header
        if self.header is not None:
            self.set_headers(self.header)
        self.num_rows = 0
        self.splitter = splitter
        self.int_formatter = int_formatter
        self.float_formatter = float_formatter

    def set_headers(self, header):
        """
        Set table headers as given and clear all data.

        NOTE: this resets ``self.data`` — rows added before a header change
        are discarded (existing behavior, kept as-is).

        :param list header: list of header strings
        :return: None
        """
        self.header = header
        if 'NAME' not in header:
            self.header = ['NAME'] + self.header
        self.data = OrderedDict([(h, []) for h in self.header])
        self.max_len = OrderedDict([(h, len(h)) for h in self.header])

    def add_row(self, row_name, row_dict):
        """
        Add new row into the table.

        :param str row_name: name of the row, which will be the first column
        :param dict row_dict: dictionary containing column name as a key and column value as value.
        :return: None
        :raises NotImplementedError: for unsupported cell value types.
        """
        # If header is not defined, fetch from input dict
        if self.header is None:
            self.set_headers(list(row_dict.keys()))

        # If input dict has new column, make one
        for key in row_dict:
            if key not in self.data:
                self.set_headers(self.header + [key])

        for h in self.header:
            if h == 'NAME':
                self.data['NAME'].append(row_name)
                self.max_len[h] = max(self.max_len['NAME'], len(row_name))
            else:
                # If input dict doesn't have values for table header, make empty value.
                if h not in row_dict:
                    row_dict[h] = '-'

                # convert input dict to string
                d = row_dict[h]

                if isinstance(d, (int, np.integer)):
                    d_str = self.int_formatter % d
                # BUGFIX: np.float was a removed alias of builtin float
                # (gone since NumPy 1.24); np.floating matches all NumPy
                # float scalars (float32/float64/...) as intended.
                elif isinstance(d, (float, np.floating)):
                    d_str = self.float_formatter % d
                elif isinstance(d, str):
                    d_str = d
                elif isinstance(d, list):
                    d_str = str(d)
                else:
                    raise NotImplementedError('data type currently not supported. %s' % str(type(d)))

                self.data[h].append(d_str)
                self.max_len[h] = max(self.max_len[h], len(d_str))
        self.num_rows += 1

    def row_to_line(self, row_values):
        """
        Convert a row into string form (left-aligned, padded per column).

        :param list row_values: list of row values as string
        :return: string form of a row
        """
        value_str = []
        for i, header in enumerate(self.header):
            max_length = self.max_len[header]
            diff = max_length - len(row_values[i])
            # Left align
            value_str.append(row_values[i] + ' ' * diff)

        return self.splitter + ' ' + (' %s ' % self.splitter).join(value_str) + ' ' + self.splitter

    def to_string(self):
        """
        Convert a table into string form.

        :return: string form of the table
        """
        size_per_col = {h: self.max_len[h] + 2 + len(self.splitter) for h in self.header}
        line_len = sum([size_per_col[c] for c in size_per_col]) + len(self.splitter)
        table_str = '\n'

        # TABLE NAME
        table_str += self.table_name + '\n'

        # HEADER
        line = self.row_to_line(self.header)
        table_str += '=' * line_len + '\n'
        table_str += line + '\n'
        table_str += self.splitter + '-' * (line_len - len(self.splitter) * 2) + self.splitter + '\n'

        # DATA
        for row_values in zip(*self.data.values()):
            line = self.row_to_line(row_values)
            table_str += line + '\n'
        table_str += '=' * line_len + '\n'
        return table_str

    def show(self):
        """Print the rendered table to stdout."""
        print(self.to_string())

    @property
    def shape(self):
        """(num_rows, num_cols) of the table."""
        return (self.num_rows, self.num_cols)

    @property
    def num_cols(self):
        """Number of columns, including the leading 'NAME' column."""
        return len(self.header)
|
utils/Statistics.py
ADDED
|
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
|
| 3 |
+
class Statistics:
    """Accumulate a series of numeric observations and report mean / std."""

    def __init__(self, name='AVG'):
        self.name = name
        self.history = []   # every value seen, in order
        self.sum = 0        # running total (kept for compatibility)
        self.cnt = 0        # number of observations

    def update(self, val):
        """Record a single observation."""
        self.history.append(val)
        self.sum += val
        self.cnt += 1

    @property
    def mean_std(self):
        """Return (mean, std) computed over the full history."""
        return np.mean(self.history), np.std(self.history)

    @property
    def mean(self):
        """Mean of all recorded values."""
        return np.mean(self.history)

    @property
    def std(self):
        """Population standard deviation of all recorded values."""
        return np.std(self.history)
|
utils/Tool.py
ADDED
|
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import math
|
| 2 |
+
import time
|
| 3 |
+
|
| 4 |
+
import numpy as np
|
| 5 |
+
import torch
|
| 6 |
+
|
| 7 |
+
def set_random_seed(seed):
    """Seed NumPy and PyTorch RNGs (CPU and, when present, all CUDA devices)
    and pin cuDNN to deterministic mode for reproducible runs."""
    np.random.seed(seed)
    torch.random.manual_seed(seed)
    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed_all(seed)
    # Trade cuDNN autotuning speed for run-to-run determinism.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
|
| 14 |
+
|
| 15 |
+
def getlocaltime():
    """Return the current local date and time as 'yy-mm-dd HH:MM:SS'.

    BUGFIX: the original computed both strings and fell off the end of the
    function, implicitly returning None — the getter returned nothing.
    Returning the formatted timestamp makes it usable (callers that ignored
    the old None return are unaffected).
    """
    date = time.strftime('%y-%m-%d', time.localtime())
    current_time = time.strftime('%H:%M:%S', time.localtime())
    return date + ' ' + current_time
|
utils/__init__.py
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .Config import Config
|
| 2 |
+
from .ResultTable import ResultTable
|
| 3 |
+
from .Logger import Logger, make_log_dir
|
| 4 |
+
from .Tool import set_random_seed
|
utils/__pycache__/Config.cpython-36.pyc
ADDED
|
Binary file (3.8 kB). View file
|
|
|
utils/__pycache__/Config.cpython-38.pyc
ADDED
|
Binary file (3.77 kB). View file
|
|
|
utils/__pycache__/Logger.cpython-36.pyc
ADDED
|
Binary file (2.35 kB). View file
|
|
|
utils/__pycache__/Logger.cpython-38.pyc
ADDED
|
Binary file (2.32 kB). View file
|
|
|
utils/__pycache__/ResultTable.cpython-36.pyc
ADDED
|
Binary file (4.42 kB). View file
|
|
|
utils/__pycache__/ResultTable.cpython-38.pyc
ADDED
|
Binary file (4.46 kB). View file
|
|
|
utils/__pycache__/Tool.cpython-36.pyc
ADDED
|
Binary file (708 Bytes). View file
|
|
|
utils/__pycache__/Tool.cpython-38.pyc
ADDED
|
Binary file (724 Bytes). View file
|
|
|
utils/__pycache__/__init__.cpython-36.pyc
ADDED
|
Binary file (304 Bytes). View file
|
|
|
utils/__pycache__/__init__.cpython-38.pyc
ADDED
|
Binary file (312 Bytes). View file
|
|
|