File size: 13,647 Bytes
34c8a90
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

"""
Build training graphs from KWDLC with JUMANDIC.

Pipeline:
1) Read gold morphemes from KNP files
2) Parse text with MeCab (JUMANDIC) to get candidate morphemes
3) Match candidates to gold and assign annotations ('+', '-', '?')
4) Save graph data as .pt
"""

import argparse
from collections import defaultdict
from pathlib import Path
from typing import Dict, List

import torch
import yaml
from tqdm import tqdm

from mecari.analyzers.mecab import MeCabAnalyzer
from mecari.data.data_module import DataModule
from mecari.featurizers.lexical import LexicalNGramFeaturizer as LexicalFeaturizer
from mecari.featurizers.lexical import Morpheme
from mecari.utils.morph_utils import build_adjacent_edges, dedup_morphemes, normalize_mecab_candidates


def add_lexical_features(morphemes: List[Dict], text: str, feature_dim: int = 100000) -> List[Dict]:
    """Add lexical (index, value) pairs to morphemes. Not used when saving JSON.

    Kept for backward-compatibility and test equivalence.
    """
    featurizer = LexicalFeaturizer(dim=feature_dim, add_bias=True)
    for m in morphemes:
        surf = m.get("surface", "")
        morph_obj = Morpheme(
            surf=surf,
            lemma=m.get("base_form", surf),
            pos=m.get("pos", "*"),
            pos1=m.get("pos_detail1", "*"),
            ctype="*",
            cform="*",
            reading=m.get("reading", "*"),
        )
        st = m.get("start_pos", 0)
        ed = m.get("end_pos", st + len(surf))
        prev_char = text[st - 1] if st > 0 and st <= len(text) else None
        next_char = text[ed] if ed < len(text) else None
        feats = featurizer.unigram_feats(morph_obj, prev_char, next_char)
        m["lexical_features"] = feats
    return morphemes


def hiragana_to_katakana(text: str) -> str:
    """Convert hiragana to katakana."""
    return "".join([chr(ord(c) + 96) if "ぁ" <= c <= "ん" else c for c in text])


def _load_gold_with_kyoto(knp_path: Path) -> List[Dict]:
    """Load sentences and morphemes from a KNP file using kyoto-reader (required)."""
    try:
        from kyoto_reader import KyotoReader  # type: ignore
    except Exception as e:  # pragma: no cover
        raise RuntimeError("kyoto-reader is required for gold loading. Install it (pip install kyoto-reader).") from e

    try:
        try:
            reader = KyotoReader(str(knp_path), n_jobs=0)
        except TypeError:
            reader = KyotoReader(str(knp_path))
        sents: List[Dict] = []
        for doc in reader.process_all_documents(n_jobs=0):
            if doc is None:
                continue
            for sent in doc.sentences:
                text = sent.surf
                morphemes: List[Dict] = []
                pos = 0
                for mrph in sent.mrph_list():
                    surf = getattr(mrph, "midasi", "") or ""
                    read = getattr(mrph, "yomi", surf) or surf
                    lemma = getattr(mrph, "genkei", surf) or surf
                    pos_main = getattr(mrph, "hinsi", "*") or "*"
                    pos1 = getattr(mrph, "bunrui", "*") or "*"
                    st = pos
                    ed = st + len(surf)
                    pos = ed
                    morphemes.append(
                        {
                            "surface": surf,
                            "reading": read,
                            "base_form": lemma,
                            "pos": pos_main,
                            "pos_detail1": pos1,
                            "pos_detail2": "*",
                            "pos_detail3": "*",
                            "start_pos": st,
                            "end_pos": ed,
                        }
                    )
                sents.append({"text": text, "morphemes": morphemes})
        return sents
    except Exception as e:
        raise RuntimeError(f"Failed to parse KNP with kyoto-reader: {knp_path}") from e


def match_morphemes_with_gold(candidates: List[Dict], gold_morphemes: List[Dict], text: str) -> List[Dict]:
    """Match candidate morphemes to gold and assign annotations ('?', '+', '-').

    Policy:
      - Initialize every candidate as '?'
      - Mark '+' for candidates that strictly match gold (surface, POS, base, reading)
      - Mark '-' for candidates that overlap any '+' span
    """
    # Reconstruct gold spans in character offsets
    gold_details = []
    cur = 0
    for g in gold_morphemes:
        surf = g.get("surface", "")
        st, ed = cur, cur + len(surf)
        cur = ed
        gold_details.append(
            {
                "start_pos": st,
                "end_pos": ed,
                "surface": surf,
                "pos": g.get("pos", "*"),
                "pos_detail1": g.get("pos_detail1", "*"),
                "pos_detail2": g.get("pos_detail2", "*"),
                "base_form": g.get("base_form", ""),
                "reading": hiragana_to_katakana(g.get("reading", "")),
            }
        )

    # Initialize all candidates with '?'
    annotated: List[Dict] = []
    for cand in candidates:
        a = {**cand}
        a["annotation"] = "?"
        if "inflection_type" not in a:
            a["inflection_type"] = "*"
        if "inflection_form" not in a:
            a["inflection_form"] = "*"
        annotated.append(a)

    # Match by strict equality first; allow reading mismatch as fallback
    span_to_cands: dict[tuple[int, int], list[Dict]] = {}
    for a in annotated:
        cs = a.get("start_pos", 0)
        ce = a.get("end_pos", cs + len(a.get("surface", "")))
        span_to_cands.setdefault((cs, ce), []).append(a)

    matched_spans: List[tuple[int, int]] = []
    for g in gold_details:
        span = (g["start_pos"], g["end_pos"])
        cands = span_to_cands.get(span, [])
        if not cands:
            continue
        strict = []
        fallback = []
        for a in cands:
            if a.get("surface", "") != g["surface"]:
                continue
            if a.get("pos", "*") != g["pos"]:
                continue
            if a.get("pos_detail1", "*") != g.get("pos_detail1", "*"):
                continue
            if a.get("base_form", "") != g["base_form"]:
                continue
            if hiragana_to_katakana(a.get("reading", "")) == g["reading"]:
                strict.append(a)
            else:
                fallback.append(a)
        chosen_list = strict if strict else fallback
        if chosen_list:
            for a in chosen_list:
                a["annotation"] = "+"
            matched_spans.append(span)
            for a in cands:
                if (a not in chosen_list) and a.get("annotation") != "+":
                    a["annotation"] = "-"

    # Demote any morpheme that overlaps (by at least 1 char) with any '+' span.
    plus_spans = []
    for a in annotated:
        if a.get("annotation") == "+":
            cs = a.get("start_pos", 0)
            ce = a.get("end_pos", cs + len(a.get("surface", "")))
            plus_spans.append((cs, ce))

    def _strict_overlap(st1: int, ed1: int, st2: int, ed2: int) -> bool:
        # overlap only if intersection length > 0 (touching is not overlap)
        return max(st1, st2) < min(ed1, ed2)

    for a in annotated:
        if a.get("annotation") == "+":
            continue
        cs = a.get("start_pos", 0)
        ce = a.get("end_pos", cs + len(a.get("surface", "")))
        for ms, me in plus_spans:
            if _strict_overlap(cs, ce, ms, me):
                a["annotation"] = "-"
                break
    return annotated


def main():
    parser = argparse.ArgumentParser(description="Create training data from KWDLC (JUMANDIC)")
    parser.add_argument("--input-dir", type=str, default="KWDLC/knp", help="Directory containing KNP files")
    parser.add_argument("--config", type=str, default="configs/gat.yaml", help="Path to config file")
    parser.add_argument("--limit", type=int, help="Max number of files to process")
    parser.add_argument("--test-only", action="store_true", help="Process only test split IDs")
    parser.add_argument("--jumandic-path", type=str, default="/var/lib/mecab/dic/juman-utf8", help="Path to JUMANDIC")
    args = parser.parse_args()

    config = {}
    if args.config and Path(args.config).exists():
        with open(args.config, "r") as f:
            config = yaml.safe_load(f)

        if "extends" in config:
            parent_config_path = Path(args.config).parent / config["extends"]
            if parent_config_path.exists():
                with open(parent_config_path, "r") as f:
                    parent_config = yaml.safe_load(f)

                def deep_merge(base, override):
                    for key, value in override.items():
                        if key in base and isinstance(base[key], dict) and isinstance(value, dict):
                            deep_merge(base[key], value)
                        else:
                            base[key] = value
                    return base

                config = deep_merge(parent_config, config)

    features_config = config.get("features", {})
    feature_dim = features_config.get("lexical_feature_dim", 100000)
    training_config = config.get("training", {})

    if training_config.get("annotations_dir"):
        output_dir = Path(training_config.get("annotations_dir"))
    else:
        output_dir = Path("annotations_kwdlc_juman")
    output_dir.mkdir(parents=True, exist_ok=True)
    print(f"Lexical features: using {feature_dim} dims")
    print(f"Output directory: {output_dir}")

    analyzer = MeCabAnalyzer(
        jumandic_path=args.jumandic_path,
    )

    knp_files = []

    if args.test_only:
        test_id_file = Path("KWDLC/id/split_for_pas/test.id")
        if test_id_file.exists():
            with open(test_id_file, "r") as f:
                test_ids = [line.strip() for line in f if line.strip()]

            knp_base_dir = Path(args.input_dir)
            for file_id in test_ids:
                dir_name = file_id[:13]
                file_name = f"{file_id}.knp"
                knp_path = knp_base_dir / dir_name / file_name
                if knp_path.exists():
                    knp_files.append(knp_path)
    else:
        knp_dir = Path(args.input_dir)
        knp_files = sorted(knp_dir.glob("**/*.knp"))

    if args.limit:
        knp_files = knp_files[: args.limit]

    print(f"Files to process: {len(knp_files)}")
    print(f"JUMANDIC: {args.jumandic_path}")
    print(f"Output to: {output_dir}")

    total_stats = defaultdict(int)
    annotation_idx = 0

    dm = DataModule(
        annotations_dir=str(output_dir),
        lexical_feature_dim=int(feature_dim),
        use_bidirectional_edges=bool(config.get("edge_features", {}).get("use_bidirectional_edges", True)),
    )

    # Save .pt files directly under the output_dir

    for knp_path in tqdm(knp_files, desc="processing"):
        try:
            sentences = _load_gold_with_kyoto(knp_path)
            if not sentences:
                continue

            doc_id = knp_path.stem
            for s in sentences:
                s["source_id"] = doc_id

            for sent_idx, sentence in enumerate(sentences):
                text = sentence["text"]
                gold_morphemes = sentence["morphemes"]
                source_id = sentence.get("source_id", doc_id)

                candidates = analyzer.get_morpheme_candidates(text)
                candidates = normalize_mecab_candidates(candidates)
                candidates = dedup_morphemes(candidates)
                if not candidates:
                    continue

                annotated_morphemes = match_morphemes_with_gold(candidates, gold_morphemes, text)

                edges = build_adjacent_edges(annotated_morphemes)

                for m in annotated_morphemes:
                    if "lexical_features" in m:
                        m.pop("lexical_features", None)

                morphemes_with_feats = dm.compute_lexical_features(annotated_morphemes, text)
                graph = dm.create_graph_from_morphemes_data(
                    morphemes=morphemes_with_feats,
                    edges=edges,
                    text=text,
                    for_training=True,
                )
                if graph is None:
                    continue

                graph_file = output_dir / f"graph_{annotation_idx:04d}.pt"
                payload = {
                    "graph": graph,
                    "source_id": source_id,
                    "text": text,
                }
                torch.save(payload, graph_file)

                total_stats["sentences"] += 1
                total_stats["morphemes"] += len(annotated_morphemes)
                total_stats["positive"] += sum(1 for m in annotated_morphemes if m.get("annotation") == "+")
                total_stats["negative"] += sum(1 for m in annotated_morphemes if m.get("annotation") == "-")

                annotation_idx += 1

            total_stats["files"] += 1

        except Exception as e:
            print(f"Error ({knp_path}): {e}")
            total_stats["errors"] += 1

    print("\n" + "=" * 50)
    print("Processing complete")
    print("=" * 50)
    print(f"Files: {total_stats['files']}")
    print(f"Sentences: {total_stats['sentences']}")
    print(f"Morphemes: {total_stats['morphemes']}")
    print(f"Positive (+): {total_stats['positive']}")
    print(f"Negative (-): {total_stats['negative']}")
    #
    if total_stats["errors"] > 0:
        print(f"Errors: {total_stats['errors']}")


if __name__ == "__main__":
    main()