# /// script
# requires-python = ">=3.9"
# dependencies = [
#     "click>=8.0.0",
#     "psutil>=5.9.0",
#     "pyyaml>=6.0.0",
# ]
# ///
"""
Training script for Vietnamese Dependency Parser (TRE-1).

Uses MaltParser (Java-based transition-based parser) trained on VnDT v1.1.
Supports multiple parsing algorithms and gold/predicted POS tags.

Models are saved to: models/dependency_parsing/{version}/

Usage:
    uv run src/train_dependency_parsing.py
    uv run src/train_dependency_parsing.py --pos-type gold
    uv run src/train_dependency_parsing.py --algorithm stackproj
    uv run src/train_dependency_parsing.py --version my_experiment
"""

import platform
import shutil
import subprocess
import time
from datetime import datetime
from pathlib import Path

import click
import psutil
import yaml

PROJECT_ROOT = Path(__file__).parent.parent
MALTPARSER_JAR = PROJECT_ROOT / "tools" / "maltparser-1.9.2" / "maltparser-1.9.2.jar"
DATASET_DIR = PROJECT_ROOT / "datasets" / "VnDT"

# Values accepted by the --algorithm option; passed to MaltParser's "-a" flag.
ALGORITHMS = [
    "nivreeager",
    "nivrestandard",
    "stackproj",
    "stackeager",
    "stacklazy",
    "covproj",
    "covnonproj",
]


def get_hardware_info():
    """Collect hardware and system information.

    Returns:
        dict with platform, release, architecture, Python version, CPU core
        counts, total RAM in GiB, and ``cpu_model``. ``cpu_model`` is read
        from ``/proc/cpuinfo`` on Linux and is ``"Unknown"`` whenever it
        cannot be determined, so the key is always present.
    """
    info = {
        "platform": platform.system(),
        "platform_release": platform.release(),
        "architecture": platform.machine(),
        "python_version": platform.python_version(),
        "cpu_physical_cores": psutil.cpu_count(logical=False),
        "cpu_logical_cores": psutil.cpu_count(logical=True),
        "ram_total_gb": round(psutil.virtual_memory().total / (1024**3), 2),
        "cpu_model": "Unknown",
    }

    if info["platform"] == "Linux":
        # Best effort only: a missing or odd /proc/cpuinfo must not abort training.
        try:
            with open("/proc/cpuinfo", "r", encoding="utf-8", errors="replace") as f:
                for line in f:
                    if "model name" in line:
                        # maxsplit=1 keeps model names that themselves contain ":".
                        info["cpu_model"] = line.split(":", 1)[1].strip()
                        break
        except (OSError, ValueError, IndexError):
            info["cpu_model"] = "Unknown"

    return info


def format_duration(seconds):
    """Format a duration given in seconds as ``"S.SSs"``, ``"Mm S.SSs"`` or ``"Hh Mm S.SSs"``."""
    if seconds < 60:
        return f"{seconds:.2f}s"
    elif seconds < 3600:
        minutes = int(seconds // 60)
        secs = seconds % 60
        return f"{minutes}m {secs:.2f}s"
    else:
        hours = int(seconds // 3600)
        minutes = int((seconds % 3600) // 60)
        secs = seconds % 60
        return f"{hours}h {minutes}m {secs:.2f}s"


def get_conll_paths(pos_type):
    """Get train/dev/test CoNLL file paths for the given POS type.

    Args:
        pos_type: POS tag variant embedded in the file name
            (the CLI offers ``"predicted"`` and ``"gold"``).

    Returns:
        dict mapping ``"train"``, ``"dev"`` and ``"test"`` to paths under DATASET_DIR.
    """
    prefix = f"VnDTv1.1-{pos_type}-POS-tags"
    return {
        "train": DATASET_DIR / f"{prefix}-train.conll",
        "dev": DATASET_DIR / f"{prefix}-dev.conll",
        "test": DATASET_DIR / f"{prefix}-test.conll",
    }


def count_sentences(conll_path):
    """Count sentences in a CoNLL file (blank-line separated).

    A sentence is a maximal run of non-blank lines, so the count is correct
    even when the last sentence has no trailing blank line or when sentences
    are separated by more than one blank line.
    """
    count = 0
    in_sentence = False
    with open(conll_path, encoding="utf-8") as f:
        for line in f:
            if line.strip():
                in_sentence = True
            elif in_sentence:
                count += 1
                in_sentence = False
    # Last sentence not terminated by a blank line.
    if in_sentence:
        count += 1
    return count


def count_tokens(conll_path):
    """Count tokens in a CoNLL file (non-blank lines)."""
    with open(conll_path, encoding="utf-8") as f:
        return sum(1 for line in f if line.strip())


def evaluate_conll(gold_path, predicted_path):
    """Evaluate UAS and LAS by comparing gold and predicted CoNLL files.

    Compares column 7 (HEAD) for UAS and columns 7+8 (HEAD+DEPREL) for LAS.
    Skips blank lines (sentence boundaries) and lines with fewer than 8
    tab-separated columns in either file.

    The files are compared line by line; if one file is shorter, the
    comparison stops at its end and ``total_tokens`` reflects only the
    lines actually compared.

    Returns:
        dict with ``"uas"`` and ``"las"`` as percentages (0.0 when no token
        was evaluated) and ``"total_tokens"``.
    """
    correct_head = 0
    correct_both = 0
    total = 0

    # VnDT is Vietnamese text: read as UTF-8 regardless of the locale default.
    with open(gold_path, encoding="utf-8") as gf, open(predicted_path, encoding="utf-8") as pf:
        for gold_line, pred_line in zip(gf, pf):
            gold_line = gold_line.strip()
            pred_line = pred_line.strip()
            if not gold_line:
                continue

            gold_cols = gold_line.split("\t")
            pred_cols = pred_line.split("\t")
            if len(gold_cols) < 8 or len(pred_cols) < 8:
                continue

            total += 1
            if gold_cols[6] == pred_cols[6]:
                correct_head += 1
                # LAS requires the head AND the relation label to match.
                if gold_cols[7] == pred_cols[7]:
                    correct_both += 1

    uas = correct_head / total * 100 if total > 0 else 0.0
    las = correct_both / total * 100 if total > 0 else 0.0
    return {"uas": uas, "las": las, "total_tokens": total}


def run_maltparser(args, cwd, java_mem="4g"):
    """Run MaltParser via Java subprocess.

    Args:
        args: list of MaltParser command-line arguments appended after the jar.
        cwd: working directory for the subprocess.
        java_mem: value for the JVM ``-Xmx`` heap flag.

    Returns:
        The completed ``subprocess.CompletedProcess`` (stdout/stderr captured as text).

    Raises:
        RuntimeError: if MaltParser exits with a non-zero status; its stdout
            and stderr are echoed first.
    """
    cmd = [
        "java",
        f"-Xmx{java_mem}",
        "-jar",
        str(MALTPARSER_JAR),
    ] + args
    click.echo(f" $ {' '.join(cmd)}")
    result = subprocess.run(
        cmd,
        cwd=str(cwd),
        capture_output=True,
        text=True,
    )
    if result.returncode != 0:
        click.echo(f"STDOUT:\n{result.stdout}")
        click.echo(f"STDERR:\n{result.stderr}")
        raise RuntimeError(f"MaltParser failed with exit code {result.returncode}")
    return result


@click.command()
@click.option(
    "--algorithm",
    "-a",
    type=click.Choice(ALGORITHMS),
    default="nivreeager",
    help="Parsing algorithm",
    show_default=True,
)
@click.option(
    "--pos-type",
    type=click.Choice(["predicted", "gold"]),
    default="predicted",
    help="POS tag type in CoNLL files",
    show_default=True,
)
@click.option(
    "--version",
    "-v",
    default=None,
    help="Model version (default: timestamp)",
)
@click.option(
    "--java-mem",
    default="4g",
    help="Java heap size",
    show_default=True,
)
def train(algorithm, pos_type, version, java_mem):
    """Train Vietnamese Dependency Parser using MaltParser on VnDT v1.1."""
    total_start_time = time.time()
    start_datetime = datetime.now()

    # Validate prerequisites
    if shutil.which("java") is None:
        # Fail early with a readable message instead of a FileNotFoundError
        # traceback from subprocess once training starts.
        raise click.ClickException("Java executable not found on PATH (required to run MaltParser)")

    if not MALTPARSER_JAR.exists():
        raise click.ClickException(
            f"MaltParser not found at {MALTPARSER_JAR}\n"
            "Download: wget http://maltparser.org/dist/maltparser-1.9.2.tar.gz -P tools/ "
            "&& tar xzf tools/maltparser-1.9.2.tar.gz -C tools/"
        )

    paths = get_conll_paths(pos_type)
    for name, path in paths.items():
        if not path.exists():
            raise click.ClickException(
                f"{name} file not found: {path}\n"
                "Download: git clone https://github.com/datquocnguyen/VnDT.git datasets/VnDT"
            )

    # Version
    if version is None:
        version = datetime.now().strftime("%Y%m%d_%H%M%S")

    output_dir = PROJECT_ROOT / "models" / "dependency_parsing" / version
    output_dir.mkdir(parents=True, exist_ok=True)

    # Hardware info
    hw_info = get_hardware_info()

    click.echo("=" * 60)
    click.echo(f"Dependency Parser Training - {version}")
    click.echo("=" * 60)
    click.echo(f"Algorithm: {algorithm}")
    click.echo(f"POS type: {pos_type}")
    click.echo(f"Java memory: {java_mem}")
    click.echo(f"Platform: {hw_info['platform']}")
    click.echo(f"CPU: {hw_info.get('cpu_model', 'Unknown')}")
    click.echo(f"Output: {output_dir}")
    click.echo(f"Started: {start_datetime.strftime('%Y-%m-%d %H:%M:%S')}")
    click.echo("=" * 60)

    # Dataset stats
    train_sents = count_sentences(paths["train"])
    dev_sents = count_sentences(paths["dev"])
    test_sents = count_sentences(paths["test"])
    train_tokens = count_tokens(paths["train"])
    test_tokens = count_tokens(paths["test"])

    click.echo(f"\nDataset: VnDT v1.1 ({pos_type} POS)")
    click.echo(f"Train: {train_sents} sentences, {train_tokens} tokens")
    click.echo(f"Dev: {dev_sents} sentences")
    click.echo(f"Test: {test_sents} sentences, {test_tokens} tokens")

    # Copy train file to working directory (MaltParser reads from cwd)
    train_copy = output_dir / "train.conll"
    shutil.copy2(paths["train"], train_copy)

    # Phase 1: Train
    click.echo(f"\nPhase 1: Training MaltParser ({algorithm})...")
    train_start = time.time()
    try:
        run_maltparser(
            ["-c", "model", "-i", "train.conll", "-m", "learn", "-a", algorithm],
            cwd=output_dir,
            java_mem=java_mem,
        )
        train_time = time.time() - train_start
    finally:
        # Clean up train copy even when MaltParser fails, so a failed run
        # does not leave a duplicate of the dataset in the model directory.
        train_copy.unlink(missing_ok=True)
    click.echo(f"Training time: {format_duration(train_time)}")

    # Phase 2: Parse test set
    click.echo("\nPhase 2: Parsing test set...")
    test_copy = output_dir / "test.conll"
    shutil.copy2(paths["test"], test_copy)

    parse_start = time.time()
    try:
        run_maltparser(
            ["-c", "model", "-i", "test.conll", "-o", "output.conll", "-m", "parse"],
            cwd=output_dir,
            java_mem=java_mem,
        )
        parse_time = time.time() - parse_start
    finally:
        # Clean up test copy regardless of the parse outcome.
        test_copy.unlink(missing_ok=True)
    click.echo(f"Parse time: {format_duration(parse_time)}")

    # Phase 3: Evaluate
    click.echo("\nPhase 3: Evaluating...")
    output_conll = output_dir / "output.conll"
    if not output_conll.exists():
        raise click.ClickException(f"Parser output not found: {output_conll}")

    metrics = evaluate_conll(paths["test"], output_conll)
    total_time = time.time() - total_start_time

    click.echo(f"\nUAS: {metrics['uas']:.2f}%")
    click.echo(f"LAS: {metrics['las']:.2f}%")
    click.echo(f"Tokens evaluated: {metrics['total_tokens']}")

    # Save metadata
    metadata = {
        "model": {
            "name": "Vietnamese Dependency Parser",
            "version": version,
            "type": "MaltParser (transition-based)",
            "algorithm": algorithm,
        },
        "training": {
            "dataset": "VnDT v1.1",
            "pos_type": pos_type,
            "train_sentences": train_sents,
            "dev_sentences": dev_sents,
            "test_sentences": test_sents,
            "train_tokens": train_tokens,
            "test_tokens": test_tokens,
            "duration_seconds": round(total_time, 2),
            "train_duration_seconds": round(train_time, 2),
            "parse_duration_seconds": round(parse_time, 2),
        },
        "performance": {
            "uas": round(metrics["uas"], 2),
            "las": round(metrics["las"], 2),
            "total_tokens": metrics["total_tokens"],
        },
        "environment": {
            "platform": hw_info["platform"],
            "cpu_model": hw_info.get("cpu_model", "Unknown"),
            "python_version": hw_info["python_version"],
            "java_memory": java_mem,
        },
        "files": {
            "model": "model.mco",
            "output": "output.conll",
        },
        "created_at": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
        "author": "undertheseanlp",
    }

    metadata_path = output_dir / "metadata.yaml"
    # allow_unicode=True emits raw non-ASCII characters, so the file must be
    # opened as UTF-8 rather than with the locale default encoding.
    with open(metadata_path, "w", encoding="utf-8") as f:
        yaml.dump(metadata, f, default_flow_style=False, allow_unicode=True, sort_keys=False)

    click.echo("\n" + "=" * 60)
    click.echo("Training Summary")
    click.echo("=" * 60)
    click.echo(f"Algorithm: {algorithm}")
    click.echo(f"POS type: {pos_type}")
    click.echo(f"Version: {version}")
    click.echo(f"UAS: {metrics['uas']:.2f}%")
    click.echo(f"LAS: {metrics['las']:.2f}%")
    click.echo(f"Total time: {format_duration(total_time)}")
    click.echo(f"Model: {output_dir / 'model.mco'}")
    click.echo(f"Metadata: {metadata_path}")
    click.echo("=" * 60)


if __name__ == "__main__":
    train()