File size: 3,793 Bytes
dbc6ebe
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
import argparse
import csv
import json
import sys
from collections import defaultdict
from pathlib import Path
from typing import Dict, Iterable, List


# Lift the csv module's default per-field size cap so very long cells parse.
# The limit is held in a C long, which is only 32 bits wide on some platforms
# (notably Windows); passing sys.maxsize there raises OverflowError at import
# time, so clamp the request to the largest value every platform accepts.
csv.field_size_limit(min(sys.maxsize, 2**31 - 1))


# Field names given, in order, to the leading cells of every input row.
COLS = [
    # sentence / token identifiers
    "sent_id", "token_id",
    # token form and morphology
    "word", "lemma", "upos", "xpos", "feats",
    # dependency columns and miscellaneous annotations
    "head", "deprel", "deps", "misc",
    # semantic role labelling columns
    "predicate_sense", "semantic_role",
]


def _iter_rows(path: Path) -> Iterable[List[str]]:
    """Yield the data rows of a delimited text file, skipping noise lines.

    The delimiter is guessed from the first 4096 characters of the file: a
    tab anywhere in that sample selects tab-separated parsing, otherwise the
    file is parsed as comma-separated.

    Empty rows, rows made of a single blank cell, and rows whose first cell
    starts with ``#`` are skipped.

    Args:
        path: File to read.

    Yields:
        Each remaining row as the list of cell strings produced by
        ``csv.reader``.
    """
    # "utf-8-sig" decodes plain UTF-8 unchanged but drops a leading byte-order
    # mark, which would otherwise be glued onto the first cell of the file and
    # defeat the "#" check and any later validation of that cell.
    with path.open("r", encoding="utf-8-sig", newline="") as f:
        sample = f.read(4096)
        f.seek(0)  # rewind so the reader also sees the sampled part
        delimiter = "\t" if "\t" in sample else ","
        for row in csv.reader(f, delimiter=delimiter):
            if not row:
                continue
            first_cell = row[0]
            if len(row) == 1 and not first_cell.strip():
                continue
            if first_cell.startswith("#"):
                continue
            yield row


def read_conllu_srl(path: Path) -> List[Dict[str, str]]:
    """Load token records from a delimited CoNLL-U style SRL file.

    Rows with fewer cells than ``COLS`` are dropped, and cells beyond
    ``len(COLS)`` are ignored. Rows whose sentence id or token id is not a
    plain decimal integer once surrounding whitespace is removed are dropped
    as well; this discards header rows and ids such as ``1-2`` or ``1.1``.

    Args:
        path: File to read.

    Returns:
        One dict per kept row, mapping each ``COLS`` name to the matching
        cell. ``sent_id`` and ``token_id`` are stored with surrounding
        whitespace removed; all other cells are stored as read.
    """
    width = len(COLS)
    records: List[Dict[str, str]] = []
    for row in _iter_rows(path):
        if len(row) < width:
            continue

        sent_id = row[0].strip()
        token_id = row[1].strip()

        # isdecimal() rather than isdigit(): isdigit() also accepts characters
        # such as superscript digits that int() cannot parse, and these ids
        # are converted with int() when the records are sorted later on.
        if not (sent_id.isdecimal() and token_id.isdecimal()):
            continue

        # zip() stops at the shorter input, so extra trailing cells are ignored.
        record = dict(zip(COLS, row))
        # Keep the validated, trimmed ids so that " 1" and "1" end up in the
        # same sentence group instead of being treated as different keys.
        record["sent_id"] = sent_id
        record["token_id"] = token_id
        records.append(record)

    return records


def _join_tokens(tokens_with_misc: List[Dict[str, str]]) -> str:
    """Rebuild a sentence string from its token records.

    Every token's ``word`` is followed by a single space unless its ``misc``
    value contains ``SpaceAfter=No``. Leading and trailing whitespace of the
    assembled string is removed.
    """
    pieces: List[str] = []
    for entry in tokens_with_misc:
        surface = entry["word"]
        glued = "SpaceAfter=No" in entry.get("misc", "_")
        pieces += (surface, "" if glued else " ")
    assembled = "".join(pieces)
    return assembled.strip()


def flatten_to_corpus(records: List[Dict[str, str]]) -> List[Dict[str, str]]:
    """Collapse token records into one flat document per sentence.

    Records are grouped by ``sent_id``; sentences are emitted in ascending
    numeric id order with their tokens in ascending numeric ``token_id``
    order. Each document holds:

    * ``id``: the sentence id,
    * ``text``: the sentence string built by ``_join_tokens``,
    * ``predicate``: the first ``predicate_sense`` that is neither empty nor
      ``_`` (empty string when there is none),
    * one key per ``semantic_role`` value that is neither empty nor ``_``,
      holding the space-joined words carrying that role.
    """
    by_sentence: Dict[str, List[Dict[str, str]]] = defaultdict(list)
    for record in records:
        by_sentence[record["sent_id"]].append(record)

    documents: List[Dict[str, str]] = []
    for sentence_key in sorted(by_sentence, key=int):
        ordered = sorted(
            by_sentence[sentence_key], key=lambda item: int(item["token_id"])
        )
        sentence_text = _join_tokens(ordered)

        first_sense = ""
        words_by_role: Dict[str, List[str]] = {}
        for item in ordered:
            sense = item.get("predicate_sense", "_")
            label = item.get("semantic_role", "_")
            surface = item.get("word", "")

            # Only the first real predicate sense of the sentence is kept.
            if sense and sense != "_" and not first_sense:
                first_sense = sense

            if label and label != "_":
                words_by_role.setdefault(label, []).append(surface)

        # Role keys are merged last, exactly like a dict.update() would do.
        documents.append(
            {
                "id": sentence_key,
                "text": sentence_text,
                "predicate": first_sense,
                **{label: " ".join(found) for label, found in words_by_role.items()},
            }
        )

    return documents


def main() -> None:
    """Command-line entry point: convert every input table to one JSON corpus.

    Reads all ``*.csv`` files directly inside ``--input-dir`` in sorted path
    order, flattens each file with ``read_conllu_srl`` and
    ``flatten_to_corpus``, and writes the concatenated documents as a JSON
    array to ``--output``. Parent directories of the output are created when
    missing. Both options default to locations under the ``data`` directory
    one level above this script's own directory.

    Only the ``.csv`` extension is matched; the files themselves may be tab-
    or comma-delimited.
    """
    default_data_dir = Path(__file__).resolve().parents[1] / "data"

    parser = argparse.ArgumentParser(
        description="Parse CoNLL-U SRL data to flattened JSON corpus"
    )
    parser.add_argument(
        "--input-dir",
        type=Path,
        default=default_data_dir,
        help="Directory containing CoNLL-U style TSV/CSV files",
    )
    parser.add_argument(
        "--output",
        type=Path,
        default=default_data_dir / "corpus.json",
        help="Output JSON corpus path",
    )
    args = parser.parse_args()

    data_files = sorted(args.input_dir.glob("*.csv"))
    if not data_files:
        # Still write the (empty) corpus as before, but do not stay silent
        # about what is most likely a wrong or empty input directory.
        print(
            f"Warning: no *.csv files found in {args.input_dir}; "
            "writing an empty corpus",
            file=sys.stderr,
        )

    corpus: List[Dict[str, str]] = []
    for data_file in data_files:
        corpus.extend(flatten_to_corpus(read_conllu_srl(data_file)))

    args.output.parent.mkdir(parents=True, exist_ok=True)
    with args.output.open("w", encoding="utf-8") as f:
        json.dump(corpus, f, ensure_ascii=False, indent=2)

    print(f"Wrote {len(corpus)} documents to {args.output}")


# Run the conversion only when this file is executed as a script, so that
# importing the module does not parse arguments or write any file.
if __name__ == "__main__":
    main()