| """ |
| VnDT Corpus loader for dependency parsing. |
| |
| This module provides a corpus class for the VnDT dataset (Vietnamese Dependency Treebank) |
| by Dat Quoc Nguyen (https://github.com/datquocnguyen/VnDT). |
| |
| VnDT v1.1: |
| - Train: 8,977 sentences |
| - Dev: 200 sentences |
| - Test: 1,020 sentences |
| - Format: CoNLL format with gold POS tags |
| """ |
|
|
| import urllib.request |
| from pathlib import Path |
| from typing import Optional |
|
|
|
|
class VnDTCorpus:
    """
    Corpus class for the VnDT dataset.

    VnDT is a Vietnamese dependency treebank containing 10,200+ sentences.
    This is the standard benchmark for Vietnamese dependency parsing,
    used by VnCoreNLP, PhoBERT, and other Vietnamese NLP models.

    Attributes:
        train: Path to the training data file (CoNLL-U format)
        dev: Path to the development/validation data file (CoNLL-U format)
        test: Path to the test data file (CoNLL-U format)

    Example:
        >>> from src.vndt_corpus import VnDTCorpus
        >>> corpus = VnDTCorpus()
        >>> print(corpus.train)  # Path to train.conllu
    """

    name = "VnDT"

    # Raw split files are fetched directly from the official GitHub repo.
    BASE_URL = "https://raw.githubusercontent.com/datquocnguyen/VnDT/master"

    # Upstream file names for each split (VnDT v1.1, gold POS tags).
    FILE_NAMES = {
        "train": "VnDTv1.1-gold-POS-tags-train.conll",
        "dev": "VnDTv1.1-gold-POS-tags-dev.conll",
        "test": "VnDTv1.1-gold-POS-tags-test.conll",
    }

    def __init__(self, data_dir: Optional[str] = None, force_download: bool = False):
        """
        Initialize the VnDT corpus.

        Args:
            data_dir: Directory to store the CoNLL-U files.
                Defaults to ./data/vndt (relative to the package root).
            force_download: If True, re-download even if files exist.

        Raises:
            RuntimeError: If the dataset cannot be downloaded.
        """
        if data_dir is None:
            data_dir = Path(__file__).parent.parent / "data" / "vndt"
        self.data_dir = Path(data_dir)
        self.data_dir.mkdir(parents=True, exist_ok=True)

        # Converted (CoNLL-U) files live alongside the raw downloads.
        self._train = self.data_dir / "train.conllu"
        self._dev = self.data_dir / "dev.conllu"
        self._test = self.data_dir / "test.conllu"

        if force_download or not self._files_exist():
            self._download()

    def _files_exist(self) -> bool:
        """Check if all three converted split files exist on disk."""
        return all(p.exists() for p in (self._train, self._dev, self._test))

    def _download(self):
        """Download VnDT files from GitHub and convert them to CoNLL-U format.

        Raises:
            RuntimeError: If any split fails to download or convert; the
                original exception is chained for debugging.
        """
        print("Downloading VnDT from GitHub...")

        for split, filename in self.FILE_NAMES.items():
            # BUG FIX: the URL must point at the actual split file.
            url = f"{self.BASE_URL}/{filename}"
            temp_path = self.data_dir / filename
            output_path = getattr(self, f"_{split}")

            print(f"  Downloading {filename}...")
            try:
                urllib.request.urlretrieve(url, temp_path)
                self._convert_to_conllu(temp_path, output_path)
            except Exception as e:
                raise RuntimeError(
                    f"Failed to download VnDT. "
                    f"Please download manually from: "
                    f"https://github.com/datquocnguyen/VnDT\n"
                    f"Error: {e}"
                ) from e
            finally:
                # Remove the raw download whether or not conversion succeeded,
                # so a failed run does not leave partial files behind.
                if temp_path.exists():
                    temp_path.unlink()

        print(f"Dataset saved to {self.data_dir}")
        self._print_statistics()

    def _convert_to_conllu(self, input_path: Path, output_path: Path):
        """
        Convert VnDT CoNLL format to CoNLL-U format.

        VnDT format:
            ID FORM _ _ POS _ HEAD DEPREL _ _

        CoNLL-U format:
            ID FORM LEMMA UPOS XPOS FEATS HEAD DEPREL DEPS MISC

        The VnDT POS tag is written to both UPOS and XPOS; LEMMA, FEATS,
        DEPS and MISC are left as "_".
        """
        with open(input_path, "r", encoding="utf-8") as f_in, \
                open(output_path, "w", encoding="utf-8") as f_out:

            sent_id = 0
            tokens = []

            for line in f_in:
                line = line.strip()

                if not line:
                    # Blank line terminates a sentence.
                    if tokens:
                        sent_id += 1
                        self._write_sentence(f_out, sent_id, tokens)
                        tokens = []
                    continue

                parts = line.split("\t")
                # Skip malformed lines with too few columns.
                if len(parts) >= 8:
                    token_id = parts[0]
                    form = parts[1]
                    pos = parts[4]
                    head = parts[6]
                    deprel = self._map_deprel(parts[7])
                    tokens.append(
                        f"{token_id}\t{form}\t_\t{pos}\t{pos}\t_\t{head}\t{deprel}\t_\t_"
                    )

            # Flush the final sentence if the file lacks a trailing blank line.
            if tokens:
                sent_id += 1
                self._write_sentence(f_out, sent_id, tokens)

    @staticmethod
    def _write_sentence(f_out, sent_id: int, tokens: list):
        """Write one sentence block: comment header, token lines, blank line."""
        f_out.write(f"# sent_id = {sent_id}\n")
        f_out.write("\n".join(tokens) + "\n\n")

    def _map_deprel(self, deprel: str) -> str:
        """
        Map VnDT dependency relations to UD-style relations.

        VnDT uses its own tagset; this maps them to something closer to UD.
        Unknown labels are passed through unchanged.
        """
        mapping = {
            "sub": "nsubj",
            "dob": "obj",
            "iob": "iobj",
            "pob": "obl",
            "vmod": "xcomp",
            "nmod": "nmod",
            "amod": "amod",
            "adv": "advmod",
            "coord": "cc",
            "conj": "conj",
            "det": "det",
            "punct": "punct",
            "root": "root",
            "tmp": "obl:tmod",
            "loc": "obl:loc",
            "prp": "advcl",
            "dep": "dep",
            "prd": "xcomp",
            "lgs": "obl:agent",
            "mnr": "advmod",
            "topic": "dislocated",
        }
        return mapping.get(deprel, deprel)

    def _print_statistics(self):
        """Print per-split sentence/token counts to stdout."""
        for name, path in [("Train", self._train), ("Dev", self._dev), ("Test", self._test)]:
            n_sents, n_tokens = self._count_sentences_tokens(path)
            print(f"  {name}: {n_sents} sentences, {n_tokens} tokens")

    def _count_sentences_tokens(self, path: Path) -> tuple:
        """Count sentences and syntactic-word tokens in a CoNLL-U file.

        Multi-word ranges ("1-2") and empty nodes ("1.1") are excluded from
        the token count, following CoNLL-U conventions.

        Returns:
            (n_sentences, n_tokens) tuple of ints.
        """
        n_sents = 0
        n_tokens = 0
        in_sentence = False

        with open(path, "r", encoding="utf-8") as f:
            for line in f:
                line = line.strip()
                if not line:
                    # Count a boundary only once, even if the file contains
                    # consecutive blank lines.
                    if in_sentence:
                        n_sents += 1
                        in_sentence = False
                elif not line.startswith("#"):
                    in_sentence = True
                    token_id = line.split("\t", 1)[0]
                    if "-" not in token_id and "." not in token_id:
                        n_tokens += 1

        # Handle a final sentence that is not followed by a blank line.
        if in_sentence:
            n_sents += 1

        return n_sents, n_tokens

    @property
    def train(self) -> str:
        """Path to training data file."""
        return str(self._train)

    @property
    def dev(self) -> str:
        """Path to development/validation data file."""
        return str(self._dev)

    @property
    def test(self) -> str:
        """Path to test data file."""
        return str(self._test)

    def get_statistics(self) -> dict:
        """Get dataset statistics.

        Returns:
            Dict with per-split sentence/token counts plus the set of POS
            tags and dependency relations observed across all splits.
        """
        stats = {}

        for split_name, path in [
            ("train", self._train),
            ("dev", self._dev),
            ("test", self._test),
        ]:
            n_sents, n_tokens = self._count_sentences_tokens(path)
            stats[f"{split_name}_sentences"] = n_sents
            stats[f"{split_name}_tokens"] = n_tokens

        # Collect the label inventories over all three splits.
        all_pos = set()
        all_deprels = set()

        for path in [self._train, self._dev, self._test]:
            with open(path, "r", encoding="utf-8") as f:
                for line in f:
                    line = line.strip()
                    if line and not line.startswith("#"):
                        parts = line.split("\t")
                        if len(parts) >= 8 and "-" not in parts[0] and "." not in parts[0]:
                            all_pos.add(parts[3])
                            all_deprels.add(parts[7])

        stats["num_pos_tags"] = len(all_pos)
        stats["num_deprels"] = len(all_deprels)
        stats["pos_tags"] = sorted(all_pos)
        stats["deprels"] = sorted(all_deprels)

        return stats
|
|