# tre-1 / src/train_dependency_parsing.py
# (page residue removed: uploader rain1024, commit message "update", commit 6f3ebfa)
# /// script
# requires-python = ">=3.9"
# dependencies = [
# "click>=8.0.0",
# "psutil>=5.9.0",
# "pyyaml>=6.0.0",
# ]
# ///
"""
Training script for Vietnamese Dependency Parser (TRE-1).
Uses MaltParser (Java-based transition-based parser) trained on VnDT v1.1.
Supports multiple parsing algorithms and gold/predicted POS tags.
Models are saved to: models/dependency_parsing/{version}/
Usage:
uv run src/train_dependency_parsing.py
uv run src/train_dependency_parsing.py --pos-type gold
uv run src/train_dependency_parsing.py --algorithm stackproj
uv run src/train_dependency_parsing.py --version my_experiment
"""
import platform
import shutil
import subprocess
import time
from datetime import datetime
from pathlib import Path
import click
import psutil
import yaml
# Project root is the parent of src/ (this file's directory).
PROJECT_ROOT = Path(__file__).parent.parent
# MaltParser jar; train() raises with a download hint if it is missing.
MALTPARSER_JAR = PROJECT_ROOT / "tools" / "maltparser-1.9.2" / "maltparser-1.9.2.jar"
# VnDT treebank directory holding the CoNLL-format train/dev/test splits.
DATASET_DIR = PROJECT_ROOT / "datasets" / "VnDT"
# Transition-based parsing algorithms accepted by MaltParser's -a option.
ALGORITHMS = [
    "nivreeager",
    "nivrestandard",
    "stackproj",
    "stackeager",
    "stacklazy",
    "covproj",
    "covnonproj",
]
def get_hardware_info():
    """Gather platform, CPU, and memory details for the run metadata.

    Returns a dict of basic system facts; ``cpu_model`` is present only
    when it can be read from /proc/cpuinfo (Linux) or when that read
    fails (then it is "Unknown").
    """
    total_ram = psutil.virtual_memory().total
    info = {
        "platform": platform.system(),
        "platform_release": platform.release(),
        "architecture": platform.machine(),
        "python_version": platform.python_version(),
        "cpu_physical_cores": psutil.cpu_count(logical=False),
        "cpu_logical_cores": psutil.cpu_count(logical=True),
        "ram_total_gb": round(total_ram / (1024**3), 2),
    }
    # Best effort: pull the CPU model string out of /proc/cpuinfo on Linux.
    try:
        if platform.system() == "Linux":
            with open("/proc/cpuinfo", "r") as cpuinfo:
                for entry in cpuinfo:
                    if "model name" in entry:
                        info["cpu_model"] = entry.split(":")[1].strip()
                        break
    except Exception:
        info["cpu_model"] = "Unknown"
    return info
def format_duration(seconds):
    """Render a duration in seconds as "Ns", "Mm Ns", or "Hh Mm Ns"."""
    if seconds >= 3600:
        whole_hours = int(seconds // 3600)
        whole_minutes = int((seconds % 3600) // 60)
        leftover = seconds % 60
        return f"{whole_hours}h {whole_minutes}m {leftover:.2f}s"
    if seconds >= 60:
        whole_minutes = int(seconds // 60)
        leftover = seconds % 60
        return f"{whole_minutes}m {leftover:.2f}s"
    return f"{seconds:.2f}s"
def get_conll_paths(pos_type):
    """Map split names to VnDT v1.1 CoNLL file paths for *pos_type*.

    pos_type is "predicted" or "gold" and selects which POS-tag variant
    of the treebank files to use.
    """
    stem = f"VnDTv1.1-{pos_type}-POS-tags"
    return {
        split: DATASET_DIR / f"{stem}-{split}.conll"
        for split in ("train", "dev", "test")
    }
def count_sentences(conll_path):
    """Count sentences in a CoNLL file (sentences are blank-line separated).

    Unlike a naive blank-line count, this tracks whether a sentence is
    open, so consecutive blank lines are not double-counted and a final
    sentence without a trailing blank line is still counted.
    """
    count = 0
    in_sentence = False
    with open(conll_path) as f:
        for line in f:
            if line.strip():
                in_sentence = True
            elif in_sentence:
                # A blank line closes the current sentence.
                count += 1
                in_sentence = False
    if in_sentence:
        # File ended without a trailing blank line after the last sentence.
        count += 1
    return count
def count_tokens(conll_path):
    """Count tokens in a CoNLL file, i.e. lines that are not blank."""
    with open(conll_path) as f:
        return sum(1 for line in f if line.strip())
def evaluate_conll(gold_path, predicted_path):
    """Score UAS and LAS by aligning gold and predicted CoNLL files line by line.

    UAS compares column 7 (HEAD); LAS additionally requires column 8
    (DEPREL) to match. Blank lines (sentence boundaries) and rows with
    fewer than 8 columns are skipped. Returns a dict with "uas", "las"
    (percentages), and "total_tokens".
    """
    head_hits = 0
    labeled_hits = 0
    scored = 0
    with open(gold_path) as gold_file, open(predicted_path) as pred_file:
        for gold_raw, pred_raw in zip(gold_file, pred_file):
            gold_row = gold_raw.strip()
            pred_row = pred_raw.strip()
            if not gold_row:
                continue  # sentence boundary
            gold_cols = gold_row.split("\t")
            pred_cols = pred_row.split("\t")
            if len(gold_cols) < 8 or len(pred_cols) < 8:
                continue  # malformed / short row — not scored
            scored += 1
            if gold_cols[6] == pred_cols[6]:
                head_hits += 1
                # LAS credit only when the head is also correct.
                if gold_cols[7] == pred_cols[7]:
                    labeled_hits += 1
    uas = head_hits / scored * 100 if scored > 0 else 0.0
    las = labeled_hits / scored * 100 if scored > 0 else 0.0
    return {"uas": uas, "las": las, "total_tokens": scored}
def run_maltparser(args, cwd, java_mem="4g"):
    """Run the MaltParser jar via ``java`` and return the CompletedProcess.

    Echoes the full command, captures stdout/stderr, and on a non-zero
    exit code prints both streams before raising RuntimeError.
    """
    cmd = ["java", f"-Xmx{java_mem}", "-jar", str(MALTPARSER_JAR), *args]
    click.echo(f" $ {' '.join(cmd)}")
    proc = subprocess.run(cmd, cwd=str(cwd), capture_output=True, text=True)
    if proc.returncode != 0:
        # Surface MaltParser's own diagnostics before failing.
        click.echo(f"STDOUT:\n{proc.stdout}")
        click.echo(f"STDERR:\n{proc.stderr}")
        raise RuntimeError(f"MaltParser failed with exit code {proc.returncode}")
    return proc
@click.command()
@click.option(
    "--algorithm", "-a",
    type=click.Choice(ALGORITHMS),
    default="nivreeager",
    help="Parsing algorithm",
    show_default=True,
)
@click.option(
    "--pos-type",
    type=click.Choice(["predicted", "gold"]),
    default="predicted",
    help="POS tag type in CoNLL files",
    show_default=True,
)
@click.option(
    "--version", "-v",
    default=None,
    help="Model version (default: timestamp)",
)
@click.option(
    "--java-mem",
    default="4g",
    help="Java heap size",
    show_default=True,
)
def train(algorithm, pos_type, version, java_mem):
    """Train Vietnamese Dependency Parser using MaltParser on VnDT v1.1.

    Pipeline: validate the jar and dataset, train a MaltParser model,
    parse the test split, score UAS/LAS against the gold test file, and
    write model.mco, output.conll, and metadata.yaml into
    models/dependency_parsing/{version}/.
    """
    total_start_time = time.time()
    start_datetime = datetime.now()
    # Validate prerequisites: the MaltParser jar and all three CoNLL splits
    # must exist before any work is done.
    if not MALTPARSER_JAR.exists():
        raise click.ClickException(
            f"MaltParser not found at {MALTPARSER_JAR}\n"
            "Download: wget http://maltparser.org/dist/maltparser-1.9.2.tar.gz -P tools/ "
            "&& tar xzf tools/maltparser-1.9.2.tar.gz -C tools/"
        )
    paths = get_conll_paths(pos_type)
    for name, path in paths.items():
        if not path.exists():
            raise click.ClickException(
                f"{name} file not found: {path}\n"
                "Download: git clone https://github.com/datquocnguyen/VnDT.git datasets/VnDT"
            )
    # Version defaults to a sortable timestamp so separate runs never collide.
    if version is None:
        version = datetime.now().strftime("%Y%m%d_%H%M%S")
    output_dir = PROJECT_ROOT / "models" / "dependency_parsing" / version
    output_dir.mkdir(parents=True, exist_ok=True)
    # Hardware info (recorded in metadata and echoed below).
    hw_info = get_hardware_info()
    click.echo("=" * 60)
    click.echo(f"Dependency Parser Training - {version}")
    click.echo("=" * 60)
    click.echo(f"Algorithm: {algorithm}")
    click.echo(f"POS type: {pos_type}")
    click.echo(f"Java memory: {java_mem}")
    click.echo(f"Platform: {hw_info['platform']}")
    click.echo(f"CPU: {hw_info.get('cpu_model', 'Unknown')}")
    click.echo(f"Output: {output_dir}")
    click.echo(f"Started: {start_datetime.strftime('%Y-%m-%d %H:%M:%S')}")
    click.echo("=" * 60)
    # Dataset stats for the console summary and metadata.yaml.
    train_sents = count_sentences(paths["train"])
    dev_sents = count_sentences(paths["dev"])
    test_sents = count_sentences(paths["test"])
    train_tokens = count_tokens(paths["train"])
    test_tokens = count_tokens(paths["test"])
    click.echo(f"\nDataset: VnDT v1.1 ({pos_type} POS)")
    click.echo(f"Train: {train_sents} sentences, {train_tokens} tokens")
    click.echo(f"Dev: {dev_sents} sentences")
    click.echo(f"Test: {test_sents} sentences, {test_tokens} tokens")
    # Copy train file to working directory (MaltParser reads from cwd;
    # run_maltparser sets cwd=output_dir so model.mco lands there too).
    train_copy = output_dir / "train.conll"
    shutil.copy2(paths["train"], train_copy)
    # Phase 1: Train — produces model.mco in output_dir.
    click.echo(f"\nPhase 1: Training MaltParser ({algorithm})...")
    train_start = time.time()
    run_maltparser(
        ["-c", "model", "-i", "train.conll", "-m", "learn", "-a", algorithm],
        cwd=output_dir,
        java_mem=java_mem,
    )
    train_time = time.time() - train_start
    click.echo(f"Training time: {format_duration(train_time)}")
    # Clean up train copy (no longer needed once the model is built).
    train_copy.unlink()
    # Phase 2: Parse test set with the freshly trained model.
    click.echo("\nPhase 2: Parsing test set...")
    test_copy = output_dir / "test.conll"
    shutil.copy2(paths["test"], test_copy)
    parse_start = time.time()
    run_maltparser(
        ["-c", "model", "-i", "test.conll", "-o", "output.conll", "-m", "parse"],
        cwd=output_dir,
        java_mem=java_mem,
    )
    parse_time = time.time() - parse_start
    click.echo(f"Parse time: {format_duration(parse_time)}")
    # Clean up test copy
    test_copy.unlink()
    # Phase 3: Evaluate parser output against the gold test file.
    click.echo("\nPhase 3: Evaluating...")
    output_conll = output_dir / "output.conll"
    if not output_conll.exists():
        raise click.ClickException(f"Parser output not found: {output_conll}")
    metrics = evaluate_conll(paths["test"], output_conll)
    total_time = time.time() - total_start_time
    click.echo(f"\nUAS: {metrics['uas']:.2f}%")
    click.echo(f"LAS: {metrics['las']:.2f}%")
    click.echo(f"Tokens evaluated: {metrics['total_tokens']}")
    # Save metadata describing the model, data, timings, and environment.
    metadata = {
        "model": {
            "name": "Vietnamese Dependency Parser",
            "version": version,
            "type": "MaltParser (transition-based)",
            "algorithm": algorithm,
        },
        "training": {
            "dataset": "VnDT v1.1",
            "pos_type": pos_type,
            "train_sentences": train_sents,
            "dev_sentences": dev_sents,
            "test_sentences": test_sents,
            "train_tokens": train_tokens,
            "test_tokens": test_tokens,
            "duration_seconds": round(total_time, 2),
            "train_duration_seconds": round(train_time, 2),
            "parse_duration_seconds": round(parse_time, 2),
        },
        "performance": {
            "uas": round(metrics["uas"], 2),
            "las": round(metrics["las"], 2),
            "total_tokens": metrics["total_tokens"],
        },
        "environment": {
            "platform": hw_info["platform"],
            "cpu_model": hw_info.get("cpu_model", "Unknown"),
            "python_version": hw_info["python_version"],
            "java_memory": java_mem,
        },
        "files": {
            "model": "model.mco",
            "output": "output.conll",
        },
        "created_at": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
        "author": "undertheseanlp",
    }
    metadata_path = output_dir / "metadata.yaml"
    with open(metadata_path, "w") as f:
        yaml.dump(metadata, f, default_flow_style=False, allow_unicode=True, sort_keys=False)
    click.echo("\n" + "=" * 60)
    click.echo("Training Summary")
    click.echo("=" * 60)
    click.echo(f"Algorithm: {algorithm}")
    click.echo(f"POS type: {pos_type}")
    click.echo(f"Version: {version}")
    click.echo(f"UAS: {metrics['uas']:.2f}%")
    click.echo(f"LAS: {metrics['las']:.2f}%")
    click.echo(f"Total time: {format_duration(total_time)}")
    click.echo(f"Model: {output_dir / 'model.mco'}")
    click.echo(f"Metadata: {metadata_path}")
    click.echo("=" * 60)
if __name__ == "__main__":
    # Click parses CLI arguments and invokes the command.
    train()