| |
| |
| |
| |
| |
| |
| |
| |
| """ |
| Training script for Vietnamese Dependency Parser (TRE-1). |
| |
| Uses MaltParser (Java-based transition-based parser) trained on VnDT v1.1. |
| Supports multiple parsing algorithms and gold/predicted POS tags. |
| |
| Models are saved to: models/dependency_parsing/{version}/ |
| |
| Usage: |
| uv run src/train_dependency_parsing.py |
| uv run src/train_dependency_parsing.py --pos-type gold |
| uv run src/train_dependency_parsing.py --algorithm stackproj |
| uv run src/train_dependency_parsing.py --version my_experiment |
| """ |
|
|
| import platform |
| import shutil |
| import subprocess |
| import time |
| from datetime import datetime |
| from pathlib import Path |
|
|
| import click |
| import psutil |
| import yaml |
|
|
|
|
# Repository root: the script lives one directory below it (src/, per the
# usage examples in the module docstring).
PROJECT_ROOT = Path(__file__).parent.parent

# Location of the unpacked MaltParser 1.9.2 distribution under tools/.
MALTPARSER_JAR = PROJECT_ROOT.joinpath("tools", "maltparser-1.9.2", "maltparser-1.9.2.jar")

# Checkout of the VnDT treebank; the CoNLL splits sit directly in this folder.
DATASET_DIR = PROJECT_ROOT.joinpath("datasets", "VnDT")

# Values accepted by MaltParser's `-a` option, grouped by algorithm family
# (names as in the MaltParser user guide). Order is kept stable because click
# lists the choices in this order in --help.
ALGORITHMS = [
    # Nivre family
    "nivreeager",
    "nivrestandard",
    # Stack family
    "stackproj",
    "stackeager",
    "stacklazy",
    # Covington family
    "covproj",
    "covnonproj",
]
|
|
|
|
def get_hardware_info():
    """Collect hardware and system information for the training report.

    Returns:
        dict with keys ``platform``, ``platform_release``, ``architecture``,
        ``python_version``, ``cpu_physical_cores``, ``cpu_logical_cores``,
        ``ram_total_gb`` (total RAM in GiB, rounded to 2 decimals) and
        ``cpu_model``.

        ``cpu_model`` is always present. It is the first ``model name`` entry
        of ``/proc/cpuinfo`` on Linux, or ``platform.processor()`` on other
        systems. It falls back to ``"Unknown"`` when neither yields a value.
    """
    system = platform.system()
    info = {
        "platform": system,
        "platform_release": platform.release(),
        "architecture": platform.machine(),
        "python_version": platform.python_version(),
        "cpu_physical_cores": psutil.cpu_count(logical=False),
        "cpu_logical_cores": psutil.cpu_count(logical=True),
        "ram_total_gb": round(psutil.virtual_memory().total / (1024**3), 2),
        # Default first, so the key exists even when detection fails below.
        "cpu_model": "Unknown",
    }

    if system == "Linux":
        try:
            with open("/proc/cpuinfo", encoding="utf-8", errors="replace") as f:
                for line in f:
                    if line.startswith("model name"):
                        # partition() tolerates ':' inside the model string.
                        model = line.partition(":")[2].strip()
                        info["cpu_model"] = model or "Unknown"
                        break
        except OSError:
            # Best effort only: /proc may be unreadable (e.g. in a sandbox).
            pass
    else:
        info["cpu_model"] = platform.processor() or "Unknown"

    return info
|
|
|
|
def format_duration(seconds):
    """Format a duration in seconds as a human-readable string.

    Examples: ``"5.00s"``, ``"1m 5.50s"``, ``"1h 2m 5.00s"``.

    The value is rounded to centiseconds once, up front, and then split with
    integer ``divmod``. Rounding only at format time (the previous approach)
    could print impossible values such as ``"1m 60.00s"`` for 119.996 s.
    """
    centis = round(seconds * 100)
    if centis < 6000:  # under one minute
        return f"{centis / 100:.2f}s"

    minutes, centis = divmod(centis, 6000)
    if minutes < 60:
        return f"{minutes}m {centis / 100:.2f}s"

    hours, minutes = divmod(minutes, 60)
    return f"{hours}h {minutes}m {centis / 100:.2f}s"
|
|
|
|
def get_conll_paths(pos_type):
    """Build the VnDT v1.1 CoNLL file paths for one POS-tag variant.

    Args:
        pos_type: The POS-tag variant spliced into the VnDT file names
            ("predicted" or "gold", per the --pos-type option).

    Returns:
        dict mapping "train", "dev" and "test" (in that order) to Paths
        inside DATASET_DIR.
    """
    stem = f"VnDTv1.1-{pos_type}-POS-tags"
    return {
        split: DATASET_DIR / f"{stem}-{split}.conll"
        for split in ("train", "dev", "test")
    }
|
|
|
|
def count_sentences(conll_path):
    """Count sentences in a CoNLL file.

    A sentence is a run of non-blank lines. Runs are separated by one or more
    blank lines. Counting sentence *starts*, rather than blank lines, means:

    - a last sentence with no trailing blank line is still counted;
    - repeated, leading or trailing blank lines do not inflate the count.

    The file is read as UTF-8, independent of the platform locale.
    """
    count = 0
    in_sentence = False
    with open(conll_path, encoding="utf-8") as f:
        for line in f:
            if line.strip():
                if not in_sentence:
                    count += 1
                    in_sentence = True
            else:
                in_sentence = False
    return count
|
|
|
|
def count_tokens(conll_path):
    """Count tokens in a CoNLL file.

    Every non-blank line counts as one token row. This assumes the file has
    no comment or other non-token lines; verify this if the data format changes.

    The file is read as UTF-8, so Vietnamese text decodes the same way
    regardless of the platform's locale encoding.
    """
    with open(conll_path, encoding="utf-8") as f:
        return sum(1 for line in f if line.strip())
|
|
|
|
def evaluate_conll(gold_path, predicted_path):
    """Evaluate UAS and LAS by comparing gold and predicted CoNLL files.

    Both files are read in lockstep, line by line, as UTF-8. Each gold token
    line (at least 8 tab-separated columns) is scored as follows:

    - UAS: column 7 (HEAD) must match;
    - LAS: columns 7 and 8 (HEAD and DEPREL) must both match.

    Blank lines (sentence boundaries) and malformed gold lines are skipped.

    A gold token whose predicted line is missing (the predicted file is
    shorter) or malformed is counted as *incorrect*, not dropped. Truncated
    or broken parser output therefore lowers the scores instead of
    silently inflating them.

    Args:
        gold_path: Path to the gold-standard CoNLL file.
        predicted_path: Path to the parser's CoNLL output, line-aligned with
            the gold file.

    Returns:
        dict with ``uas`` and ``las`` as percentages (0.0 when there are no
        gold tokens), and ``total_tokens``, the number of gold tokens scored.
    """
    from itertools import zip_longest

    correct_head = 0
    correct_both = 0
    total = 0

    with open(gold_path, encoding="utf-8") as gf, open(
        predicted_path, encoding="utf-8"
    ) as pf:
        # zip_longest (not zip) so extra gold lines are not silently ignored.
        for gold_line, pred_line in zip_longest(gf, pf, fillvalue=""):
            gold_cols = gold_line.strip().split("\t")
            if len(gold_cols) < 8:
                # Sentence boundary, malformed gold row, or pred-only tail.
                continue

            total += 1
            pred_cols = pred_line.strip().split("\t")
            if len(pred_cols) < 8:
                continue  # missing/malformed prediction: scored as wrong

            if gold_cols[6] == pred_cols[6]:
                correct_head += 1
                if gold_cols[7] == pred_cols[7]:
                    correct_both += 1

    uas = correct_head / total * 100 if total > 0 else 0.0
    las = correct_both / total * 100 if total > 0 else 0.0
    return {"uas": uas, "las": las, "total_tokens": total}
|
|
|
|
def run_maltparser(args, cwd, java_mem="4g"):
    """Run MaltParser in a Java subprocess and return the completed process.

    The command executed is ``java -Xmx<java_mem> -jar <MALTPARSER_JAR>``
    followed by ``args``. It is echoed (shell-quoted) before running.

    Args:
        args: Sequence of MaltParser command-line arguments,
            e.g. ``["-c", "model", "-i", "train.conll", "-m", "learn"]``.
        cwd: Working directory for the subprocess. Relative file names in
            ``args`` are resolved against it.
        java_mem: Value for the JVM ``-Xmx`` heap limit (e.g. ``"4g"``).

    Returns:
        subprocess.CompletedProcess with stdout/stderr captured as text.

    Raises:
        RuntimeError: If the ``java`` executable cannot be launched (not on
            PATH, or ``cwd`` does not exist), or if MaltParser exits non-zero.
            In the non-zero case, the captured stdout/stderr are echoed first.
    """
    import shlex

    cmd = ["java", f"-Xmx{java_mem}", "-jar", str(MALTPARSER_JAR), *args]

    click.echo(f" $ {shlex.join(cmd)}")
    try:
        result = subprocess.run(
            cmd,
            cwd=str(cwd),
            capture_output=True,
            text=True,
        )
    except FileNotFoundError as err:
        # Without this, a missing JVM surfaces as a bare "No such file" error.
        raise RuntimeError(
            f"Could not launch {cmd[0]!r} in {cwd}: {err}. "
            "Make sure a Java runtime is installed and on PATH."
        ) from err

    if result.returncode != 0:
        click.echo(f"STDOUT:\n{result.stdout}")
        click.echo(f"STDERR:\n{result.stderr}")
        raise RuntimeError(f"MaltParser failed with exit code {result.returncode}")

    return result
|
|
|
|
@click.command()
@click.option(
    "--algorithm",
    "-a",
    type=click.Choice(ALGORITHMS),
    default="nivreeager",
    help="Parsing algorithm",
    show_default=True,
)
@click.option(
    "--pos-type",
    type=click.Choice(["predicted", "gold"]),
    default="predicted",
    help="POS tag type in CoNLL files",
    show_default=True,
)
@click.option(
    "--version",
    "-v",
    default=None,
    help="Model version (default: timestamp)",
)
@click.option(
    "--java-mem",
    default="4g",
    help="Java heap size",
    show_default=True,
)
def train(algorithm, pos_type, version, java_mem):
    """Train Vietnamese Dependency Parser using MaltParser on VnDT v1.1."""
    # NOTE: click shows the docstring above as the --help text.
    #
    # Pipeline:
    #   1. Check that MaltParser and the dataset are present.
    #   2. Train a model in models/dependency_parsing/<version>/.
    #   3. Parse the test split into output.conll.
    #   4. Score UAS/LAS against the gold test file.
    #   5. Write metadata.yaml.

    # Wall-clock start for the report, plus a monotonic clock for durations
    # (perf_counter is unaffected by system clock adjustments).
    start_datetime = datetime.now()
    total_start_time = time.perf_counter()

    # Fail fast, with download hints, if the parser or the data is missing.
    if not MALTPARSER_JAR.exists():
        raise click.ClickException(
            f"MaltParser not found at {MALTPARSER_JAR}\n"
            "Download: wget http://maltparser.org/dist/maltparser-1.9.2.tar.gz -P tools/ "
            "&& tar xzf tools/maltparser-1.9.2.tar.gz -C tools/"
        )

    paths = get_conll_paths(pos_type)
    for name, path in paths.items():
        if not path.exists():
            raise click.ClickException(
                f"{name} file not found: {path}\n"
                "Download: git clone https://github.com/datquocnguyen/VnDT.git datasets/VnDT"
            )

    # Default version is the run's start timestamp (the same instant as the
    # "Started:" line below).
    if version is None:
        version = start_datetime.strftime("%Y%m%d_%H%M%S")

    output_dir = PROJECT_ROOT / "models" / "dependency_parsing" / version
    output_dir.mkdir(parents=True, exist_ok=True)

    hw_info = get_hardware_info()

    click.echo("=" * 60)
    click.echo(f"Dependency Parser Training - {version}")
    click.echo("=" * 60)
    click.echo(f"Algorithm: {algorithm}")
    click.echo(f"POS type: {pos_type}")
    click.echo(f"Java memory: {java_mem}")
    click.echo(f"Platform: {hw_info['platform']}")
    click.echo(f"CPU: {hw_info.get('cpu_model', 'Unknown')}")
    click.echo(f"Output: {output_dir}")
    click.echo(f"Started: {start_datetime.strftime('%Y-%m-%d %H:%M:%S')}")
    click.echo("=" * 60)

    train_sents = count_sentences(paths["train"])
    dev_sents = count_sentences(paths["dev"])
    test_sents = count_sentences(paths["test"])
    train_tokens = count_tokens(paths["train"])
    test_tokens = count_tokens(paths["test"])

    click.echo(f"\nDataset: VnDT v1.1 ({pos_type} POS)")
    click.echo(f"Train: {train_sents} sentences, {train_tokens} tokens")
    click.echo(f"Dev: {dev_sents} sentences")
    click.echo(f"Test: {test_sents} sentences, {test_tokens} tokens")

    # MaltParser runs with cwd=output_dir and relative file names, so each
    # input split is staged into output_dir. The copy is removed afterwards,
    # even if MaltParser fails, so dataset copies never linger next to the model.
    train_copy = output_dir / "train.conll"
    shutil.copy2(paths["train"], train_copy)
    try:
        click.echo(f"\nPhase 1: Training MaltParser ({algorithm})...")
        train_start = time.perf_counter()
        run_maltparser(
            ["-c", "model", "-i", "train.conll", "-m", "learn", "-a", algorithm],
            cwd=output_dir,
            java_mem=java_mem,
        )
        train_time = time.perf_counter() - train_start
        click.echo(f"Training time: {format_duration(train_time)}")
    finally:
        train_copy.unlink(missing_ok=True)

    click.echo("\nPhase 2: Parsing test set...")
    test_copy = output_dir / "test.conll"
    shutil.copy2(paths["test"], test_copy)
    try:
        parse_start = time.perf_counter()
        run_maltparser(
            ["-c", "model", "-i", "test.conll", "-o", "output.conll", "-m", "parse"],
            cwd=output_dir,
            java_mem=java_mem,
        )
        parse_time = time.perf_counter() - parse_start
        click.echo(f"Parse time: {format_duration(parse_time)}")
    finally:
        test_copy.unlink(missing_ok=True)

    click.echo("\nPhase 3: Evaluating...")
    output_conll = output_dir / "output.conll"
    if not output_conll.exists():
        raise click.ClickException(f"Parser output not found: {output_conll}")

    # Score against the original gold test file, not the (deleted) staged copy.
    metrics = evaluate_conll(paths["test"], output_conll)
    total_time = time.perf_counter() - total_start_time

    click.echo(f"\nUAS: {metrics['uas']:.2f}%")
    click.echo(f"LAS: {metrics['las']:.2f}%")
    click.echo(f"Tokens evaluated: {metrics['total_tokens']}")

    metadata = {
        "model": {
            "name": "Vietnamese Dependency Parser",
            "version": version,
            "type": "MaltParser (transition-based)",
            "algorithm": algorithm,
        },
        "training": {
            "dataset": "VnDT v1.1",
            "pos_type": pos_type,
            "train_sentences": train_sents,
            "dev_sentences": dev_sents,
            "test_sentences": test_sents,
            "train_tokens": train_tokens,
            "test_tokens": test_tokens,
            "duration_seconds": round(total_time, 2),
            "train_duration_seconds": round(train_time, 2),
            "parse_duration_seconds": round(parse_time, 2),
        },
        "performance": {
            "uas": round(metrics["uas"], 2),
            "las": round(metrics["las"], 2),
            "total_tokens": metrics["total_tokens"],
        },
        "environment": {
            "platform": hw_info["platform"],
            "cpu_model": hw_info.get("cpu_model", "Unknown"),
            "python_version": hw_info["python_version"],
            "java_memory": java_mem,
        },
        "files": {
            "model": "model.mco",
            "output": "output.conll",
        },
        "created_at": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
        "author": "undertheseanlp",
    }

    # Explicit UTF-8 so allow_unicode=True output never depends on the locale.
    metadata_path = output_dir / "metadata.yaml"
    with open(metadata_path, "w", encoding="utf-8") as f:
        yaml.dump(metadata, f, default_flow_style=False, allow_unicode=True, sort_keys=False)

    click.echo("\n" + "=" * 60)
    click.echo("Training Summary")
    click.echo("=" * 60)
    click.echo(f"Algorithm: {algorithm}")
    click.echo(f"POS type: {pos_type}")
    click.echo(f"Version: {version}")
    click.echo(f"UAS: {metrics['uas']:.2f}%")
    click.echo(f"LAS: {metrics['las']:.2f}%")
    click.echo(f"Total time: {format_duration(total_time)}")
    click.echo(f"Model: {output_dir / 'model.mco'}")
    click.echo(f"Metadata: {metadata_path}")
    click.echo("=" * 60)
|
|
|
|
if __name__ == "__main__":
    # Script entry point (e.g. `uv run src/train_dependency_parsing.py`):
    # invoke the click command, which reads its options from the command line.
    train()
|
|