File size: 5,467 Bytes
27e6434
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
"""Run feature ablation study for word segmentation.

Runs leave-one-out and additive experiments, collecting results into a summary.

Usage:
    python src/run_ablation.py
"""

import csv
import subprocess
import sys
import time
from datetime import datetime
from pathlib import Path

# The project root is the parent of the directory that contains this script.
_SCRIPT_DIR = Path(__file__).parent
PROJECT_ROOT = _SCRIPT_DIR.parent
RESULTS_FILE = PROJECT_ROOT.joinpath("results", "word_segmentation", "ablation_study.csv")

# Feature groups that can be switched off, in the order their overrides are written.
_FEATURE_GROUPS = ("type", "morphology", "left", "right", "bigram", "trigram")


def _disable(*groups: str) -> str:
    """Return space-separated overrides that set each given feature group to false."""
    return " ".join(f"model.features.{group}=false" for group in groups)


# Leave-one-out: each experiment disables exactly one feature group.
LEAVE_ONE_OUT = {f"all-{group}": _disable(group) for group in _FEATURE_GROUPS}

# Additive: start from the form-only baseline and disable one group fewer at
# each step. Every entry lists the groups that are still disabled.
ADDITIVE = {
    "form_only": _disable(*_FEATURE_GROUPS),
    "form+left": _disable("type", "morphology", "right", "bigram", "trigram"),
    "form+left+right": _disable("type", "morphology", "bigram", "trigram"),
    "form+ctx+bigram": _disable("type", "morphology", "trigram"),
    "form+ctx+bi+type": _disable("morphology", "trigram"),
    "form+ctx+bi+type+morph": _disable("trigram"),
    # The full model (all features) is the baseline, so it is not listed here.
}


def parse_metrics(output: str) -> dict:
    """Parse evaluation metrics from the training script's console output.

    Every line is checked for one of the known metric labels; on a match, the
    text after the last colon is stored as a float. A line containing both
    "Features:" and "templates)" supplies the template count.

    Args:
        output: Captured text output of the training script.

    Returns:
        A dict with whichever of "syl_accuracy", "word_precision",
        "word_recall", "word_f1" (floats) and "num_templates" (int) were
        found. Metrics that never appear in the output are absent.

    Raises:
        ValueError: If a matching line does not carry a parseable number.
    """
    # Checked in order; the first label found on a line wins.
    float_metrics = (
        ("Syllable-level Accuracy:", "syl_accuracy"),
        ("Word-level Precision:", "word_precision"),
        ("Word-level Recall:", "word_recall"),
        ("Word-level F1:", "word_f1"),
    )

    metrics = {}
    for line in output.split("\n"):
        for label, key in float_metrics:
            if label in line:
                metrics[key] = float(line.rsplit(":", 1)[-1].strip())
                break
        else:
            if "templates)" in line and "Features:" in line:
                # Expected shape: "Features: [...] (N templates)". The feature
                # list may itself contain parentheses, so read the count from
                # the last "(" that precedes "templates)" rather than from the
                # first "(" on the line.
                prefix = line[: line.rindex("templates)")]
                metrics["num_templates"] = int(prefix.rsplit("(", 1)[-1].strip())
    return metrics


def run_experiment(name: str, overrides: str) -> dict:
    """Run one training experiment in a subprocess and collect its metrics.

    The training script is launched from PROJECT_ROOT with the given overrides
    plus an "output=" override that points at a per-experiment directory. Its
    combined stdout/stderr is parsed with parse_metrics().

    Args:
        name: Experiment label; also used as the output sub-directory name.
        overrides: Whitespace-separated command-line overrides (may be empty).
            Quoting follows POSIX shell rules.

    Returns:
        The parsed metrics plus "name" and "time_seconds". "word_f1" and
        "syl_accuracy" are always present. When no word-level F1 could be
        parsed the run counts as failed and every metric is set to 0.
    """
    import shlex  # local import: only this function needs it

    print(f"\n{'='*60}")
    print(f"  Experiment: {name}")
    print(f"  Overrides: {overrides}")
    print(f"{'='*60}")

    # Use an argument list with the current interpreter instead of a shell
    # string: the child then runs in the same environment as this script and
    # nothing in the overrides is interpreted by a shell.
    cmd = [
        sys.executable or "python",
        "src/train_word_segmentation.py",
        *shlex.split(overrides),
        f"output=models/word_segmentation/ablation/{name}",
    ]

    start = time.perf_counter()
    result = subprocess.run(
        cmd, capture_output=True, text=True,
        cwd=str(PROJECT_ROOT),
    )
    elapsed = time.perf_counter() - start

    output = result.stdout + result.stderr
    # Only the tail of the log is echoed to keep the console readable.
    print(output[-500:])

    metrics = parse_metrics(output)
    metrics["name"] = name
    metrics["time_seconds"] = round(elapsed, 1)

    if "word_f1" in metrics:
        # F1 may have been parsed without the accuracy line; default it so the
        # report below (and callers indexing this key) cannot hit a KeyError.
        metrics.setdefault("syl_accuracy", 0)
        print(f"  => Word F1: {metrics['word_f1']:.4f}  Syl Acc: {metrics['syl_accuracy']:.4f}  ({elapsed:.0f}s)")
    else:
        print(f"  => FAILED ({elapsed:.0f}s)")
        print(f"  exit code: {result.returncode}")
        # Zero out every metric so downstream formatting still works.
        for key in ("word_f1", "syl_accuracy", "word_precision", "word_recall", "num_templates"):
            metrics[key] = 0

    return metrics


def _print_banner(title: str) -> None:
    """Print a section banner framed by lines of '#'."""
    print("\n" + "#" * 60)
    print(f"  {title}")
    print("#" * 60)


def _write_results_csv(results: list, path: Path) -> None:
    """Write one CSV row per experiment to *path*; absent metrics are left blank."""
    fieldnames = ["name", "num_templates", "syl_accuracy", "word_precision", "word_recall", "word_f1", "time_seconds"]
    with open(path, "w", newline="", encoding="utf-8") as f:
        writer = csv.DictWriter(f, fieldnames=fieldnames)
        writer.writeheader()
        writer.writerows({k: r.get(k, "") for k in fieldnames} for r in results)


def _print_summary(results: list) -> None:
    """Print the comparison table; F1 deltas are relative to the first result."""
    print("\n" + "=" * 80)
    # Header widths mirror the row format below, including the delta column.
    print(f"{'Experiment':<30} {'Templates':>9} {'Syl Acc':>8} {'Word F1':>8} {'vs full':>10} {'Time':>6}")
    print("-" * 80)

    full_f1 = results[0]["word_f1"] if results else 0

    for r in results:
        delta = r["word_f1"] - full_f1
        delta_str = f"({delta:+.4f})" if r["name"] != "full" else ""
        print(f"{r['name']:<30} {r.get('num_templates', '?'):>9} {r['syl_accuracy']:>8.4f} {r['word_f1']:>8.4f} {delta_str:>10} {r['time_seconds']:>5.0f}s")

    print("=" * 80)


def main():
    """Run the full ablation study and report the results.

    Runs the full-model baseline first, then every leave-one-out and additive
    configuration, writes all results to RESULTS_FILE as CSV, and prints a
    summary table with each experiment's word-level F1 delta against the
    baseline.
    """
    RESULTS_FILE.parent.mkdir(parents=True, exist_ok=True)

    # The baseline must stay first: the summary compares against results[0].
    _print_banner("FULL MODEL (baseline)")
    all_results = [run_experiment("full", "")]

    sections = (
        ("LEAVE-ONE-OUT EXPERIMENTS", LEAVE_ONE_OUT),
        ("ADDITIVE EXPERIMENTS", ADDITIVE),
    )
    for title, experiments in sections:
        _print_banner(title)
        for name, overrides in experiments.items():
            all_results.append(run_experiment(name, overrides))

    _write_results_csv(all_results, RESULTS_FILE)
    print(f"\nResults saved to {RESULTS_FILE}")

    _print_summary(all_results)


# Run the ablation study only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()