# tre-1 / src/run_ablation.py
# Last commit 27e6434 by rain1024:
#   "Add VLSP 2013 word segmentation, feature ablation study, and Hydra config"
"""Run feature ablation study for word segmentation.
Runs leave-one-out and additive experiments, collecting results into a summary.
Usage:
python src/run_ablation.py
"""
import csv
import shlex
import subprocess
import sys
import time
from datetime import datetime
from pathlib import Path
PROJECT_ROOT = Path(__file__).parent.parent
RESULTS_FILE = PROJECT_ROOT.joinpath("results", "word_segmentation", "ablation_study.csv")

# Optional feature groups that can be switched off with a Hydra override, in the
# canonical order used when composing override strings.
_FEATURE_GROUPS = ("type", "morphology", "left", "right", "bigram", "trigram")


def _disable(*groups):
    """Return a space-separated Hydra override string turning off each of *groups*."""
    return " ".join(f"model.features.{group}=false" for group in groups)


# Leave-one-out: the full feature set minus exactly one group per experiment.
LEAVE_ONE_OUT = {f"all-{group}": _disable(group) for group in _FEATURE_GROUPS}

# Additive: start from the form-only baseline (every optional group off) and
# re-enable groups step by step. Each label maps to the groups switched ON;
# the override string disables all remaining groups, in _FEATURE_GROUPS order.
_ADDITIVE_ENABLED = {
    "form_only": (),
    "form+left": ("left",),
    "form+left+right": ("left", "right"),
    "form+ctx+bigram": ("left", "right", "bigram"),
    "form+ctx+bi+type": ("left", "right", "bigram", "type"),
    "form+ctx+bi+type+morph": ("left", "right", "bigram", "type", "morphology"),
}
ADDITIVE = {
    label: _disable(*(group for group in _FEATURE_GROUPS if group not in enabled))
    for label, enabled in _ADDITIVE_ENABLED.items()
}
# The full model (every group enabled) is not listed here; main() runs it
# separately as the "full" baseline.
def parse_metrics(output: str) -> dict:
    """Parse evaluation metrics from the training script's console output.

    Recognised lines (float values are taken from after the last ``:``)::

        Syllable-level Accuracy: <float>   -> "syl_accuracy"
        Word-level Precision: <float>      -> "word_precision"
        Word-level Recall: <float>         -> "word_recall"
        Word-level F1: <float>             -> "word_f1"
        Features: [...] (<int> templates)  -> "num_templates"

    If a label appears more than once, the last occurrence wins. A line whose
    value cannot be parsed is skipped (its key is left unset) rather than
    raising, so one malformed log line cannot abort a whole ablation sweep.

    Args:
        output: Combined stdout/stderr text of one training run.

    Returns:
        Dict containing whichever of the keys above were found.
    """
    float_labels = {
        "Syllable-level Accuracy:": "syl_accuracy",
        "Word-level Precision:": "word_precision",
        "Word-level Recall:": "word_recall",
        "Word-level F1:": "word_f1",
    }
    metrics = {}
    for line in output.splitlines():
        # First matching label wins, mirroring an if/elif chain.
        key = next((k for label, k in float_labels.items() if label in line), None)
        if key is not None:
            try:
                metrics[key] = float(line.split(":")[-1].strip())
            except ValueError:
                pass  # Best effort: malformed value, leave the key unset.
        elif "templates)" in line and "Features:" in line:
            # "Features: [...] (N templates)" -- split on the LAST "(" so that
            # parentheses inside the printed feature list don't break parsing.
            try:
                metrics["num_templates"] = int(line.rsplit("(", 1)[1].split()[0])
            except (IndexError, ValueError):
                pass  # Best effort: unexpected format, leave the key unset.
    return metrics
def run_experiment(name: str, overrides: str) -> dict:
    """Train one ablation variant in a subprocess and return its parsed metrics.

    Args:
        name: Experiment label; also used as the model output sub-directory
            ``models/word_segmentation/ablation/<name>``.
        overrides: Space-separated Hydra overrides passed to the training
            script (empty string for the full model).

    Returns:
        The dict from ``parse_metrics`` plus ``name`` and ``time_seconds``.
        If no word-level F1 was parsed, the run is reported as FAILED and all
        metric fields are set to 0. If F1 was parsed but another float metric
        was not, that metric is set to NaN.
    """
    print(f"\n{'='*60}")
    print(f" Experiment: {name}")
    print(f" Overrides: {overrides}")
    print(f"{'='*60}")

    output_override = f"output=models/word_segmentation/ablation/{name}"
    # Pass an argument list (no shell) and use the running interpreter, so the
    # child process sees the same environment/venv as this script instead of
    # whatever "python" happens to resolve to on PATH.
    cmd = [
        sys.executable,
        "src/train_word_segmentation.py",
        *shlex.split(overrides),
        output_override,
    ]

    start = time.perf_counter()
    result = subprocess.run(
        cmd, capture_output=True, text=True,
        cwd=str(PROJECT_ROOT),
    )
    elapsed = time.perf_counter() - start

    output = result.stdout + result.stderr
    # Echo only the tail of the log (a negative slice returns short logs whole).
    print(output[-500:])

    metrics = parse_metrics(output)
    metrics["name"] = name
    metrics["time_seconds"] = round(elapsed, 1)

    if "word_f1" in metrics:
        # F1 may be present while another metric line is missing or malformed;
        # fill those with NaN so the report below and the summary table don't
        # raise KeyError.
        for key in ("syl_accuracy", "word_precision", "word_recall"):
            metrics.setdefault(key, float("nan"))
        print(f" => Word F1: {metrics['word_f1']:.4f} Syl Acc: {metrics['syl_accuracy']:.4f} ({elapsed:.0f}s)")
    else:
        print(f" => FAILED ({elapsed:.0f}s)")
        metrics.update(dict.fromkeys(
            ("word_f1", "syl_accuracy", "word_precision", "word_recall", "num_templates"), 0
        ))
    return metrics
def _run_group(title: str, experiments: dict) -> list:
    """Print a section banner, then run each ``name -> overrides`` experiment in order."""
    print("\n" + "#" * 60)
    print(f" {title}")
    print("#" * 60)
    return [run_experiment(name, overrides) for name, overrides in experiments.items()]


def _write_results_csv(results: list, path: Path) -> None:
    """Write one CSV row per experiment; metrics missing from a result are left blank."""
    fieldnames = ["name", "num_templates", "syl_accuracy", "word_precision", "word_recall", "word_f1", "time_seconds"]
    with open(path, "w", newline="", encoding="utf-8") as f:
        # restval="" blanks missing keys; extrasaction="ignore" drops unlisted ones.
        writer = csv.DictWriter(f, fieldnames=fieldnames, restval="", extrasaction="ignore")
        writer.writeheader()
        writer.writerows(results)


def _print_summary(results: list) -> None:
    """Print a fixed-width results table with word-F1 deltas versus the baseline.

    Deltas are relative to ``results[0]``, which ``main`` always fills with the
    "full" model run; that row itself shows no delta.
    """
    print("\n" + "=" * 80)
    # Header widths match the row format below: Delta is 10 wide, and Time is
    # 6 wide because the row prints a 5-wide number plus an "s" suffix.
    print(f"{'Experiment':<30} {'Templates':>9} {'Syl Acc':>8} {'Word F1':>8} {'Delta':>10} {'Time':>6}")
    print("-" * 80)
    full_f1 = results[0]["word_f1"] if results else 0
    for r in results:
        delta_str = "" if r["name"] == "full" else f"({r['word_f1'] - full_f1:+.4f})"
        print(f"{r['name']:<30} {r.get('num_templates', '?'):>9} {r['syl_accuracy']:>8.4f} {r['word_f1']:>8.4f} {delta_str:>10} {r['time_seconds']:>5.0f}s")
    print("=" * 80)


def main():
    """Run the ablation study and report the results.

    Runs the full-model baseline, then the leave-one-out and additive
    experiments. Saves all results to RESULTS_FILE as CSV and prints a
    summary table.
    """
    # Create the output directory up front so a bad path fails before training.
    RESULTS_FILE.parent.mkdir(parents=True, exist_ok=True)

    all_results = []
    all_results += _run_group("FULL MODEL (baseline)", {"full": ""})
    all_results += _run_group("LEAVE-ONE-OUT EXPERIMENTS", LEAVE_ONE_OUT)
    all_results += _run_group("ADDITIVE EXPERIMENTS", ADDITIVE)

    _write_results_csv(all_results, RESULTS_FILE)
    print(f"\nResults saved to {RESULTS_FILE}")

    _print_summary(all_results)
# Script entry point: run the ablation sweep only when executed directly,
# not when this module is imported.
if __name__ == "__main__":
    main()