File size: 4,235 Bytes
f4dcc30
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
#
# Author: penhe@microsoft.com
# Date: 01/25/2019
#

import os
import csv
import copy
from collections import OrderedDict,defaultdict,Counter
from collections.abc import Sequence
import numpy as np
from ...utils import get_logger
from ...utils import xtqdm as tqdm
from ...data import example_to_feature
from .metrics import *

from ..models import SequenceClassificationModel
# Shared module-level logger used by the task helpers below.
logger=get_logger()

# Explicit public API of this module.
__all__ = ['EvalData', 'Task']

class EvalData:
  """Bundle of a named evaluation split with its metric and prediction callbacks.

  Holds the raw examples plus the functions used to score model logits
  (``metrics_fn``) and to dump predictions to disk (``predict_fn``).
  """

  def __init__(self, name, examples, metrics_fn=None, predict_fn=None, ignore_metric=False, critial_metrics=None):
    """Create an evaluation data holder.

    Args:
      name: identifier of this split (used in output file names).
      examples: the evaluation examples themselves.
      metrics_fn: callable(logits, labels) -> OrderedDict of metrics;
        defaults to plain accuracy.
      predict_fn: callable(logits, output_dir, name, prefix) that writes a
        submission file; defaults to a TSV of argmax class indices.
      ignore_metric: if True, metrics for this split are not tracked.
      critial_metrics: names of the metrics used for model selection.
        (NOTE: 'critial' spelling is part of the public interface.)
    """
    def _default_metrics(logits, labels):
      # Fallback scorer: report accuracy only.
      return OrderedDict(accuracy= metric_accuracy(logits, labels))

    def _default_predict(logits, output_dir, name, prefix):
      # Write argmax class indices as a two-column TSV submission file.
      out_path = os.path.join(output_dir, 'submit-{}-{}.tsv'.format(name, prefix))
      predictions = np.argmax(logits, axis=-1)
      with open(out_path, 'w', encoding='utf-8') as fs:
        fs.write('index\tpredictions\n')
        for idx, pred in enumerate(predictions):
          fs.write('{}\t{}\n'.format(idx, pred))

    self.name = name
    self.data = examples
    self.ignore_metric = ignore_metric
    self.critial_metrics = critial_metrics
    if metrics_fn is None:
      metrics_fn = _default_metrics
    self.metrics_fn = metrics_fn
    if predict_fn is None:
      predict_fn = _default_predict
    self.predict_fn = predict_fn

  def __repr__(self):
    return f'{self.name}, {type(self.data)}: {len(self.data)}, {self.predict_fn}, {self.metrics_fn}'

class Task():
  """Base class for a downstream task.

  Subclasses provide the data accessors (``train_data``/``eval_data``/
  ``test_data``) and the label set; this base supplies default metric,
  prediction, feature-extraction and argument-parsing hooks.
  """
  # Shared per-class metadata cache; subclasses may populate it.
  _meta={}

  def __init__(self, tokenizer, args, **kwargs):
    """Store the tokenizer and parsed CLI args for later use."""
    self.tokenizer = tokenizer
    self.args = args

  def eval_data(self, **kwargs):
    """Return evaluation splits; must be implemented by subclasses."""
    raise NotImplementedError('eval_data method not implemented yet.')

  def train_data(self, **kwargs):
    """Return training data; must be implemented by subclasses."""
    # Fixed copy-paste bug: message previously said 'Eval_data'.
    raise NotImplementedError('train_data method not implemented yet.')

  def test_data(self, **kwargs):
    """Return test splits; must be implemented by subclasses."""
    # Fixed copy-paste bug: message previously said 'Eval_data'.
    raise NotImplementedError('test_data method not implemented yet.')

  def get_labels(self):
    """Gets the list of labels for this data set."""
    raise NotImplementedError()

  def label2id(self, labelstr):
    """Map a label string to its index in get_labels(); -1 for unknown labels."""
    label_dict = {l:i for i,l in enumerate(self.get_labels())}
    # Single lookup via dict.get instead of membership test + index.
    return label_dict.get(labelstr, -1)

  def get_train_fn(self, *args, **kwargs):
    """Optional custom training function; None means use the default."""
    return None

  def get_eval_fn(self, *args, **kwargs):
    """Optional custom evaluation function; None means use the default."""
    return None

  def get_pred_fn(self, *args, **kwargs):
    """Optional custom prediction function; None means use the default."""
    return None

  def get_loss_fn(self, *args, **kwargs):
    """Optional custom loss function; None means use the model's default."""
    return None

  def get_metrics_fn(self):
    """Calculate metrics based on prediction results (default: accuracy)."""
    def metrics_fn(logits, labels):
      return OrderedDict(accuracy= metric_accuracy(logits, labels))
    return metrics_fn

  def get_predict_fn(self):
    """Return a function that writes argmax predictions (as label strings) to a TSV."""
    def predict_fn(logits, output_dir, name, prefix):
      output=os.path.join(output_dir, 'submit-{}-{}.tsv'.format(name, prefix))
      preds = np.argmax(logits, axis=-1)
      labels = self.get_labels()
      with open(output, 'w', encoding='utf-8') as fs:
        fs.write('index\tpredictions\n')
        for i,p in enumerate(preds):
          fs.write('{}\t{}\n'.format(i, labels[p]))

    return predict_fn

  @classmethod
  def _read_tsv(cls, input_file, quotechar=None):
    """Reads a tab separated value file."""
    with open(input_file, "r", encoding='utf-8') as f:
      # list(reader) replaces the manual append loop.
      return list(csv.reader(f, delimiter="\t", quotechar=quotechar))

  def get_feature_fn(self, max_seq_len = 512, mask_gen = None, label_type='int', training=False):
    """Return a closure converting one example into model features.

    Args:
      max_seq_len: maximum token sequence length.
      mask_gen: optional mask generator forwarded to example_to_feature.
      label_type: type of the label field ('int' by default).
      training: unused here; kept for interface compatibility.
    """
    tokenizer = self.tokenizer
    def _example_to_feature(example, rng=None, ext_params=None, **kwargs):
      return example_to_feature(tokenizer, example, max_seq_len = max_seq_len, \
        rng = rng, mask_generator = mask_gen, ext_params = ext_params, label_type=label_type, **kwargs)
    return _example_to_feature

  def get_model_class_fn(self):
    """Return the factory used to load the model for this task."""
    return SequenceClassificationModel.load_model

  @classmethod
  def add_arguments(cls, parser):
    """Add task specific arguments
      e.g. parser.add_argument('--data_dir', type=str, help='The path of data directory.')
    """
    parser.add_argument('--task_example_arg', type=str, default=None, help='An example task specific argument')

    return parser