| #!/usr/bin/env python3 | |
| # coding=utf-8 | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from model.head.abstract_head import AbstractHead | |
| from data.parser.to_mrp.sequential_parser import SequentialParser | |
| from utility.cross_entropy import cross_entropy | |
class SequentialHead(AbstractHead):
    """Head built on ``AbstractHead`` that owns a ``SequentialParser``.

    The constructor hands a fixed set of boolean flags to the base class
    and then attaches a ``SequentialParser`` created from the dataset.
    """

    def __init__(self, dataset, args, initialize):
        """Configure the base head and create the parser.

        Args:
            dataset: Forwarded to ``AbstractHead.__init__`` and to
                ``SequentialParser``.
            args: Forwarded unchanged to ``AbstractHead.__init__``.
            initialize: Forwarded unchanged to ``AbstractHead.__init__``.
        """
        # Flags handed to the base class. Their exact meaning is defined by
        # AbstractHead (not visible here) -- presumably each one toggles a
        # sub-component of the head; TODO confirm against AbstractHead.
        # Key spelling (note the spaces in the edge keys) and insertion
        # order are kept exactly as the base class receives them.
        head_flags = dict(
            (
                ("label", True),
                ("edge presence", False),
                ("edge label", False),
                ("anchor", True),
                ("source_anchor", True),
                ("target_anchor", True),
            )
        )
        super().__init__(dataset, args, head_flags, initialize)

        # The parser is attached only after the base class has finished
        # its own initialisation, matching the original statement order.
        self.parser = SequentialParser(dataset)