|
|
|
|
|
|
|
|
|
|
|
""" |
|
|
Build training graphs from KWDLC with JUMANDIC. |
|
|
|
|
|
Pipeline: |
|
|
1) Read gold morphemes from KNP files |
|
|
2) Parse text with MeCab (JUMANDIC) to get candidate morphemes |
|
|
3) Match candidates to gold and assign annotations ('+', '-', '?') |
|
|
4) Save graph data as .pt |
|
|
""" |
|
|
|
|
|
import argparse |
|
|
from collections import defaultdict |
|
|
from pathlib import Path |
|
|
from typing import Dict, List |
|
|
|
|
|
import torch |
|
|
import yaml |
|
|
from tqdm import tqdm |
|
|
|
|
|
from mecari.analyzers.mecab import MeCabAnalyzer |
|
|
from mecari.data.data_module import DataModule |
|
|
from mecari.featurizers.lexical import LexicalNGramFeaturizer as LexicalFeaturizer |
|
|
from mecari.featurizers.lexical import Morpheme |
|
|
from mecari.utils.morph_utils import build_adjacent_edges, dedup_morphemes, normalize_mecab_candidates |
|
|
|
|
|
|
|
|
def add_lexical_features(morphemes: List[Dict], text: str, feature_dim: int = 100000) -> List[Dict]:
    """Attach lexical (index, value) feature pairs to each morpheme dict in place.

    Not used when saving JSON; kept for backward-compatibility and test
    equivalence.  Returns the same list for chaining.
    """
    featurizer = LexicalFeaturizer(dim=feature_dim, add_bias=True)
    for entry in morphemes:
        surface = entry.get("surface", "")
        morph = Morpheme(
            surf=surface,
            lemma=entry.get("base_form", surface),
            pos=entry.get("pos", "*"),
            pos1=entry.get("pos_detail1", "*"),
            ctype="*",
            cform="*",
            reading=entry.get("reading", "*"),
        )
        start = entry.get("start_pos", 0)
        end = entry.get("end_pos", start + len(surface))
        # Neighboring characters for context features; None at sentence edges.
        left_char = text[start - 1] if 0 < start <= len(text) else None
        right_char = text[end] if end < len(text) else None
        entry["lexical_features"] = featurizer.unigram_feats(morph, left_char, right_char)
    return morphemes
|
|
|
|
|
|
|
|
def hiragana_to_katakana(text: str) -> str:
    """Convert every hiragana character in *text* to its katakana counterpart.

    Each hiragana code point (U+3041 'ぁ' .. U+3096 'ゖ') is shifted by 0x60
    (96) to the corresponding katakana code point; every other character is
    passed through unchanged.  Compared to the classic ぁ..ん upper bound,
    this also converts ゔ -> ヴ and ゕ/ゖ -> ヵ/ヶ, which can occur in
    readings.
    """
    return "".join(chr(ord(c) + 96) if "ぁ" <= c <= "ゖ" else c for c in text)
|
|
|
|
|
|
|
|
def _load_gold_with_kyoto(knp_path: Path) -> List[Dict]:
    """Load sentences and morphemes from a KNP file using kyoto-reader (required)."""
    try:
        from kyoto_reader import KyotoReader
    except Exception as e:
        raise RuntimeError("kyoto-reader is required for gold loading. Install it (pip install kyoto-reader).") from e

    try:
        # Older kyoto-reader releases do not accept n_jobs in the constructor.
        try:
            reader = KyotoReader(str(knp_path), n_jobs=0)
        except TypeError:
            reader = KyotoReader(str(knp_path))

        results: List[Dict] = []
        for document in reader.process_all_documents(n_jobs=0):
            if document is None:
                continue
            for sentence in document.sentences:
                # Character offsets are reconstructed by tiling surfaces
                # left-to-right across the sentence.
                offset = 0
                morphemes: List[Dict] = []
                for m in sentence.mrph_list():
                    surface = getattr(m, "midasi", "") or ""
                    start = offset
                    end = start + len(surface)
                    offset = end
                    morphemes.append(
                        {
                            "surface": surface,
                            "reading": getattr(m, "yomi", surface) or surface,
                            "base_form": getattr(m, "genkei", surface) or surface,
                            "pos": getattr(m, "hinsi", "*") or "*",
                            "pos_detail1": getattr(m, "bunrui", "*") or "*",
                            "pos_detail2": "*",
                            "pos_detail3": "*",
                            "start_pos": start,
                            "end_pos": end,
                        }
                    )
                results.append({"text": sentence.surf, "morphemes": morphemes})
        return results
    except Exception as e:
        raise RuntimeError(f"Failed to parse KNP with kyoto-reader: {knp_path}") from e
|
|
|
|
|
|
|
|
def match_morphemes_with_gold(candidates: List[Dict], gold_morphemes: List[Dict], text: str) -> List[Dict]:
    """Match candidate morphemes to gold and assign annotations ('?', '+', '-').

    Policy:
    - Initialize every candidate as '?'
    - Mark '+' for candidates whose span matches a gold morpheme; strict
      matches (surface, POS, POS detail, base form, reading) are preferred,
      falling back to reading-mismatched candidates when no strict match exists
    - Mark '-' for same-span candidates that were not chosen, and for any
      remaining candidate that overlaps a '+' span

    Returns a new list of annotated copies; the input dicts are not mutated.
    ``text`` is unused but kept for interface compatibility with callers.
    """
    # Recompute gold spans from surfaces: gold morphemes tile the sentence
    # left-to-right, so cumulative surface lengths give the offsets.
    gold_details = []
    cur = 0
    for g in gold_morphemes:
        surf = g.get("surface", "")
        st, ed = cur, cur + len(surf)
        cur = ed
        gold_details.append(
            {
                "start_pos": st,
                "end_pos": ed,
                "surface": surf,
                "pos": g.get("pos", "*"),
                "pos_detail1": g.get("pos_detail1", "*"),
                "pos_detail2": g.get("pos_detail2", "*"),
                "base_form": g.get("base_form", ""),
                # Gold readings are hiragana; candidates use katakana.
                "reading": hiragana_to_katakana(g.get("reading", "")),
            }
        )

    # Work on copies with annotation '?' and inflection fields defaulted.
    annotated: List[Dict] = []
    for cand in candidates:
        a = {**cand}
        a["annotation"] = "?"
        a.setdefault("inflection_type", "*")
        a.setdefault("inflection_form", "*")
        annotated.append(a)

    # Index candidates by (start, end) span for O(1) lookup per gold morpheme.
    span_to_cands: dict[tuple[int, int], list[Dict]] = {}
    for a in annotated:
        cs = a.get("start_pos", 0)
        ce = a.get("end_pos", cs + len(a.get("surface", "")))
        span_to_cands.setdefault((cs, ce), []).append(a)

    for g in gold_details:
        cands = span_to_cands.get((g["start_pos"], g["end_pos"]), [])
        if not cands:
            continue
        strict = []
        fallback = []
        for a in cands:
            if a.get("surface", "") != g["surface"]:
                continue
            if a.get("pos", "*") != g["pos"]:
                continue
            if a.get("pos_detail1", "*") != g.get("pos_detail1", "*"):
                continue
            if a.get("base_form", "") != g["base_form"]:
                continue
            if hiragana_to_katakana(a.get("reading", "")) == g["reading"]:
                strict.append(a)
            else:
                fallback.append(a)
        chosen_list = strict if strict else fallback
        if chosen_list:
            for a in chosen_list:
                a["annotation"] = "+"
            # Same-span candidates that were not chosen become negatives,
            # unless an earlier gold morpheme already marked them '+'.
            for a in cands:
                if (a not in chosen_list) and a.get("annotation") != "+":
                    a["annotation"] = "-"

    # Collect all '+' spans, then demote any remaining candidate that
    # strictly overlaps one of them.
    plus_spans = []
    for a in annotated:
        if a.get("annotation") == "+":
            cs = a.get("start_pos", 0)
            ce = a.get("end_pos", cs + len(a.get("surface", "")))
            plus_spans.append((cs, ce))

    def _strict_overlap(st1: int, ed1: int, st2: int, ed2: int) -> bool:
        # True when the half-open spans share at least one character.
        return max(st1, st2) < min(ed1, ed2)

    for a in annotated:
        if a.get("annotation") == "+":
            continue
        cs = a.get("start_pos", 0)
        ce = a.get("end_pos", cs + len(a.get("surface", "")))
        if any(_strict_overlap(cs, ce, ms, me) for ms, me in plus_spans):
            a["annotation"] = "-"
    return annotated
|
|
|
|
|
|
|
|
def main():
    """CLI entry point: build annotated training graphs from KWDLC.

    Pipeline per sentence: load gold morphemes (kyoto-reader), generate
    candidates with MeCab (JUMANDIC), match candidates to gold, build a
    graph via DataModule, and save it as a .pt file.
    """
    parser = argparse.ArgumentParser(description="Create training data from KWDLC (JUMANDIC)")
    parser.add_argument("--input-dir", type=str, default="KWDLC/knp", help="Directory containing KNP files")
    parser.add_argument("--config", type=str, default="configs/gat.yaml", help="Path to config file")
    parser.add_argument("--limit", type=int, help="Max number of files to process")
    parser.add_argument("--test-only", action="store_true", help="Process only test split IDs")
    parser.add_argument("--jumandic-path", type=str, default="/var/lib/mecab/dic/juman-utf8", help="Path to JUMANDIC")
    args = parser.parse_args()

    # Load the YAML config; a missing file silently yields an empty config.
    config = {}
    if args.config and Path(args.config).exists():
        with open(args.config, "r") as f:
            config = yaml.safe_load(f)

    # Single-level config inheritance via an "extends" key: the parent file
    # is loaded first, then overridden by the child's values.
    if "extends" in config:
        parent_config_path = Path(args.config).parent / config["extends"]
        if parent_config_path.exists():
            with open(parent_config_path, "r") as f:
                parent_config = yaml.safe_load(f)

            def deep_merge(base, override):
                # Recursively merge nested dicts; non-dict override values win.
                for key, value in override.items():
                    if key in base and isinstance(base[key], dict) and isinstance(value, dict):
                        deep_merge(base[key], value)
                    else:
                        base[key] = value
                return base

            config = deep_merge(parent_config, config)

    features_config = config.get("features", {})
    feature_dim = features_config.get("lexical_feature_dim", 100000)
    training_config = config.get("training", {})

    # Output directory: config-provided path, else a fixed default.
    if training_config.get("annotations_dir"):
        output_dir = Path(training_config.get("annotations_dir"))
    else:
        output_dir = Path("annotations_kwdlc_juman")
    output_dir.mkdir(parents=True, exist_ok=True)
    print(f"Lexical features: using {feature_dim} dims")
    print(f"Output directory: {output_dir}")

    analyzer = MeCabAnalyzer(
        jumandic_path=args.jumandic_path,
    )

    knp_files = []

    if args.test_only:
        # Restrict processing to the PAS test split listed in KWDLC's id file.
        test_id_file = Path("KWDLC/id/split_for_pas/test.id")
        if test_id_file.exists():
            with open(test_id_file, "r") as f:
                test_ids = [line.strip() for line in f if line.strip()]

            knp_base_dir = Path(args.input_dir)
            for file_id in test_ids:
                # KWDLC layout: the first 13 chars of the id name the subdirectory.
                dir_name = file_id[:13]
                file_name = f"{file_id}.knp"
                knp_path = knp_base_dir / dir_name / file_name
                if knp_path.exists():
                    knp_files.append(knp_path)
    else:
        knp_dir = Path(args.input_dir)
        knp_files = sorted(knp_dir.glob("**/*.knp"))

    if args.limit:
        knp_files = knp_files[: args.limit]

    print(f"Files to process: {len(knp_files)}")
    print(f"JUMANDIC: {args.jumandic_path}")
    print(f"Output to: {output_dir}")

    total_stats = defaultdict(int)
    annotation_idx = 0  # running counter used to name output graph files

    dm = DataModule(
        annotations_dir=str(output_dir),
        lexical_feature_dim=int(feature_dim),
        use_bidirectional_edges=bool(config.get("edge_features", {}).get("use_bidirectional_edges", True)),
    )

    for knp_path in tqdm(knp_files, desc="processing"):
        try:
            sentences = _load_gold_with_kyoto(knp_path)
            if not sentences:
                continue

            # Tag every sentence with its source document id (file stem).
            doc_id = knp_path.stem
            for s in sentences:
                s["source_id"] = doc_id

            for sent_idx, sentence in enumerate(sentences):
                text = sentence["text"]
                gold_morphemes = sentence["morphemes"]
                source_id = sentence.get("source_id", doc_id)

                # Candidate morphemes from MeCab, normalized and deduplicated.
                candidates = analyzer.get_morpheme_candidates(text)
                candidates = normalize_mecab_candidates(candidates)
                candidates = dedup_morphemes(candidates)
                if not candidates:
                    continue

                annotated_morphemes = match_morphemes_with_gold(candidates, gold_morphemes, text)

                edges = build_adjacent_edges(annotated_morphemes)

                # Drop any stale precomputed features; DataModule recomputes below.
                for m in annotated_morphemes:
                    if "lexical_features" in m:
                        m.pop("lexical_features", None)

                morphemes_with_feats = dm.compute_lexical_features(annotated_morphemes, text)
                graph = dm.create_graph_from_morphemes_data(
                    morphemes=morphemes_with_feats,
                    edges=edges,
                    text=text,
                    for_training=True,
                )
                if graph is None:
                    continue

                # One .pt payload per sentence graph.
                graph_file = output_dir / f"graph_{annotation_idx:04d}.pt"
                payload = {
                    "graph": graph,
                    "source_id": source_id,
                    "text": text,
                }
                torch.save(payload, graph_file)

                total_stats["sentences"] += 1
                total_stats["morphemes"] += len(annotated_morphemes)
                total_stats["positive"] += sum(1 for m in annotated_morphemes if m.get("annotation") == "+")
                total_stats["negative"] += sum(1 for m in annotated_morphemes if m.get("annotation") == "-")

                annotation_idx += 1

            total_stats["files"] += 1

        except Exception as e:
            # Best-effort batch processing: report the failing file and continue.
            print(f"Error ({knp_path}): {e}")
            total_stats["errors"] += 1

    print("\n" + "=" * 50)
    print("Processing complete")
    print("=" * 50)
    print(f"Files: {total_stats['files']}")
    print(f"Sentences: {total_stats['sentences']}")
    print(f"Morphemes: {total_stats['morphemes']}")
    print(f"Positive (+): {total_stats['positive']}")
    print(f"Negative (-): {total_stats['negative']}")

    if total_stats["errors"] > 0:
        print(f"Errors: {total_stats['errors']}")
|
|
|
|
|
|
|
|
# Script entry point.
if __name__ == "__main__":
    main()
|
|
|