|
|
|
|
|
|
|
|
import os |
|
|
from parser.utils import Embedding |
|
|
from parser.utils.alg import eisner |
|
|
from parser.utils.common import bos, pad, unk |
|
|
from parser.utils.corpus import CoNLL, Corpus |
|
|
from parser.utils.field import BertField, CharField, Field |
|
|
from parser.utils.fn import ispunct |
|
|
from parser.utils.metric import Metric |
|
|
|
|
|
import torch |
|
|
import torch.nn as nn |
|
|
from transformers import AutoTokenizer, BertTokenizer |
|
|
|
|
|
|
|
|
class CMD(object):
    """Shared driver for the parser's train/evaluate/predict subcommands.

    Calling the object with parsed CLI ``args`` builds (or reloads) the
    data fields, exposes them as attributes, and writes the
    vocabulary-dependent sizes back onto ``args`` for model construction.
    The remaining methods assume ``self.model`` (and, for training,
    ``self.optimizer``/``self.scheduler``) have been set by the caller.
    """

    def __call__(self, args):
        """Prepare fields, vocabularies and derived indices.

        Args:
            args: namespace-like object with the CLI options read here
                (``file``, ``fields``, ``preprocess``, ``feat``,
                ``bert_model``, ``fix_len``, ``ftrain``, ``min_freq``,
                ``device``). It is updated in place with ``n_words``,
                ``n_feats``, ``n_rels`` and the pad/unk/bos indices of
                the word vocabulary.
        """
        self.args = args
        # makedirs(exist_ok=True) is race-free and also creates missing
        # parent directories, unlike the exists()+mkdir() pattern.
        os.makedirs(args.file, exist_ok=True)
        if not os.path.exists(args.fields) or args.preprocess:
            print("Preprocess the data")
            self.WORD = Field('words', pad=pad, unk=unk, bos=bos, lower=True)
            if args.feat == 'char':
                self.FEAT = CharField('chars', pad=pad, unk=unk, bos=bos,
                                      fix_len=args.fix_len, tokenize=list)
            elif args.feat == 'bert':
                tokenizer = BertTokenizer.from_pretrained(args.bert_model)
                self.FEAT = BertField('bert', pad='[PAD]', bos='[CLS]',
                                      tokenize=tokenizer.encode)
            else:
                self.FEAT = Field('tags', bos=bos)
            self.HEAD = Field('heads', bos=bos, use_vocab=False, fn=int)
            self.REL = Field('rels', bos=bos)
            if args.feat in ('char', 'bert'):
                # char/bert features are derived from the FORM column too
                self.fields = CoNLL(FORM=(self.WORD, self.FEAT),
                                    HEAD=self.HEAD, DEPREL=self.REL)
            else:
                self.fields = CoNLL(FORM=self.WORD, CPOS=self.FEAT,
                                    HEAD=self.HEAD, DEPREL=self.REL)

            train = Corpus.load(args.ftrain, self.fields)
            # NOTE(review): pretrained-embedding loading appears to have
            # been removed here -- `Embedding` is imported at the top of
            # the file but never used. Confirm whether `embed` should be
            # built from an args-provided embedding file.
            embed = None
            self.WORD.build(train, args.min_freq, embed)
            self.FEAT.build(train)
            self.REL.build(train)
            torch.save(self.fields, args.fields)
        else:
            self.fields = torch.load(args.fields)
            if args.feat in ('char', 'bert'):
                self.WORD, self.FEAT = self.fields.FORM
            else:
                self.WORD, self.FEAT = self.fields.FORM, self.fields.CPOS
            self.HEAD, self.REL = self.fields.HEAD, self.fields.DEPREL
        # vocabulary ids of punctuation tokens, used to exclude them
        # from evaluation metrics when args.punct is false
        self.puncts = torch.tensor([i for s, i in self.WORD.vocab.stoi.items()
                                    if ispunct(s)]).to(args.device)
        self.criterion = nn.CrossEntropyLoss()

        print(f"{self.WORD}\n{self.FEAT}\n{self.HEAD}\n{self.REL}")
        args.update({
            'n_words': self.WORD.vocab.n_init,
            'n_feats': len(self.FEAT.vocab),
            'n_rels': len(self.REL.vocab),
            'pad_index': self.WORD.pad_index,
            'unk_index': self.WORD.unk_index,
            'bos_index': self.WORD.bos_index
        })

    def train(self, loader):
        """Run one training epoch over ``loader``, updating the model."""
        self.model.train()

        for words, feats, arcs, rels in loader:
            self.optimizer.zero_grad()

            mask = words.ne(self.args.pad_index)
            # exclude position 0 (the prepended BOS token) from the loss
            mask[:, 0] = 0
            arc_scores, rel_scores = self.model(words, feats)
            loss = self.get_loss(arc_scores, rel_scores, arcs, rels, mask)
            loss.backward()
            nn.utils.clip_grad_norm_(self.model.parameters(),
                                     self.args.clip)
            self.optimizer.step()
            self.scheduler.step()

    @torch.no_grad()
    def evaluate(self, loader):
        """Return (average loss, attachment metric) over ``loader``."""
        self.model.eval()

        loss, metric = 0, Metric()

        for words, feats, arcs, rels in loader:
            mask = words.ne(self.args.pad_index)
            # exclude position 0 (the prepended BOS token)
            mask[:, 0] = 0
            arc_scores, rel_scores = self.model(words, feats)
            loss += self.get_loss(arc_scores, rel_scores, arcs, rels, mask)
            arc_preds, rel_preds = self.decode(arc_scores, rel_scores, mask)

            if not self.args.punct:
                # also drop punctuation tokens from the metric
                mask &= words.unsqueeze(-1).ne(self.puncts).all(-1)
            metric(arc_preds, rel_preds, arcs, rels, mask)
        loss /= len(loader)

        return loss, metric

    @torch.no_grad()
    def predict(self, loader):
        """Predict head indices and relation labels for every sentence.

        Returns:
            A pair ``(all_arcs, all_rels)`` of per-sentence lists:
            head indices as ints and relation labels as strings.
        """
        self.model.eval()

        all_arcs, all_rels = [], []
        for words, feats in loader:
            mask = words.ne(self.args.pad_index)
            # exclude position 0 (the prepended BOS token)
            mask[:, 0] = 0
            lens = mask.sum(1).tolist()
            arc_scores, rel_scores = self.model(words, feats)
            arc_preds, rel_preds = self.decode(arc_scores, rel_scores, mask)
            # flatten over the mask, then split back into per-sentence
            # chunks of the original lengths
            all_arcs.extend(arc_preds[mask].split(lens))
            all_rels.extend(rel_preds[mask].split(lens))
        all_arcs = [seq.tolist() for seq in all_arcs]
        all_rels = [self.REL.vocab.id2token(seq.tolist()) for seq in all_rels]

        return all_arcs, all_rels

    def get_loss(self, arc_scores, rel_scores, arcs, rels, mask):
        """Sum of cross-entropy losses for head and relation prediction.

        Only positions where ``mask`` is true contribute to the loss.
        """
        arc_scores, arcs = arc_scores[mask], arcs[mask]
        rel_scores, rels = rel_scores[mask], rels[mask]
        # select the relation scores at each token's gold head position
        rel_scores = rel_scores[torch.arange(len(arcs)), arcs]
        arc_loss = self.criterion(arc_scores, arcs)
        rel_loss = self.criterion(rel_scores, rels)
        loss = arc_loss + rel_loss

        return loss

    def decode(self, arc_scores, rel_scores, mask):
        """Decode head and relation predictions from the score tensors.

        With ``args.tree`` set, heads are constrained to form a
        well-formed tree via Eisner's algorithm; otherwise each token
        greedily takes its argmax head.
        """
        if self.args.tree:
            arc_preds = eisner(arc_scores, mask)
        else:
            arc_preds = arc_scores.argmax(-1)
        rel_preds = rel_scores.argmax(-1)
        # keep the relation predicted at each token's chosen head
        rel_preds = rel_preds.gather(-1, arc_preds.unsqueeze(-1)).squeeze(-1)

        return arc_preds, rel_preds
|
|
|