Spaces:
Runtime error
Runtime error
| #!/usr/bin/env python3 | |
| # coding=utf-8 | |
| from data.parser.to_mrp.abstract_parser import AbstractParser | |
class NodeCentricParser(AbstractParser):
    """Turn one node-centric model prediction into an MRP-style graph dict.

    Node construction and anchoring are delegated to ``create_nodes`` and
    ``create_anchors`` (defined outside this class, presumably on
    ``AbstractParser``). Edges are drawn from every node whose label is NOT
    "Source"/"Target" to every node whose label IS one of those two.
    """

    def parse(self, prediction):
        """Decode ``prediction`` into ``{"id": ..., "nodes": ..., "edges": ...}``.

        Args:
            prediction: mapping holding at least "id" (an object exposing
                ``.item()``) plus whatever the node/anchor/edge helpers read,
                including "edge presence".

        Returns:
            A new dict with keys "id", "nodes" and "edges", in that order.
        """
        # The prediction stores a numeric index; translate it back to the
        # dataset's string id through the id field's vocabulary.
        id_index = prediction["id"].item()
        graph_id = self.dataset.id_field.vocab.itos[id_index]

        graph_nodes = self.create_nodes(prediction)
        # Anchoring returns the node list that is used from here on.
        graph_nodes = self.create_anchors(
            prediction,
            graph_nodes,
            join_contiguous=True,
            at_least_one=True,
        )
        graph_edges = self.create_edges(prediction, graph_nodes)

        return {"id": graph_id, "nodes": graph_nodes, "edges": graph_edges}

    def create_edge(self, source, target, prediction, edges, nodes):
        """Append an unlabelled ``source -> target`` edge to ``edges`` in place.

        ``prediction`` and ``nodes`` are not read here. Presumably they are
        kept so overrides can share this signature (confirm against
        subclasses / ``AbstractParser``).
        """
        edges.append({"source": source, "target": target, "label": None})

    def create_edges(self, prediction, nodes):
        """Collect edges whose predicted presence is at least 0.5.

        Args:
            prediction: mapping whose "edge presence" entry is indexed as a
                2-D ``[source, target]`` grid. It is assumed to be a tensor
                of per-pair scores (TODO confirm).
            nodes: node dicts, each carrying a "label" key.

        Returns:
            A list of edge dicts built by ``self.create_edge``. Pairs are
            visited with the role node ("Source"/"Target") in the outer loop
            and the other node in the inner loop.
        """
        node_count = len(nodes)
        # Only the top-left block that corresponds to real nodes is relevant.
        presence = prediction["edge presence"][:node_count, :node_count]

        # "Source"/"Target" are NODE labels (roles). Such nodes always sit at
        # the edge's *target* end; every other node is an edge *source*.
        role_labels = ["Source", "Target"]
        role_nodes = [idx for idx, node in enumerate(nodes) if node["label"] in role_labels]
        other_nodes = [idx for idx, node in enumerate(nodes) if node["label"] not in role_labels]

        collected = []
        candidate_pairs = ((src, tgt) for tgt in role_nodes for src in other_nodes)
        for src, tgt in candidate_pairs:
            if presence[src, tgt] >= 0.5:
                self.create_edge(src, tgt, prediction, collected, nodes)

        return collected