# Source: bamboo-1 / src/vndt_corpus.py
# Uploaded by rain1024 via huggingface_hub (commit 0bf906a, verified)
"""
VnDT Corpus loader for dependency parsing.
This module provides a corpus class for the VnDT dataset (Vietnamese Dependency Treebank)
by Dat Quoc Nguyen (https://github.com/datquocnguyen/VnDT).
VnDT v1.1:
- Train: 8,977 sentences
- Dev: 200 sentences
- Test: 1,020 sentences
- Format: CoNLL format with gold POS tags
"""
import urllib.request
from pathlib import Path
from typing import Optional
class VnDTCorpus:
    """
    Corpus class for the VnDT dataset.

    VnDT is a Vietnamese dependency treebank containing 10,200+ sentences.
    This is the standard benchmark for Vietnamese dependency parsing,
    used by VnCoreNLP, PhoBERT, and other Vietnamese NLP models.

    Attributes:
        train: Path to the training data file (CoNLL-U format)
        dev: Path to the development/validation data file (CoNLL-U format)
        test: Path to the test data file (CoNLL-U format)

    Example:
        >>> from src.vndt_corpus import VnDTCorpus
        >>> corpus = VnDTCorpus()
        >>> print(corpus.train)  # Path to train.conllu
    """

    name = "VnDT"

    # VnDT GitHub repository URLs
    BASE_URL = "https://raw.githubusercontent.com/datquocnguyen/VnDT/master"
    FILE_NAMES = {
        "train": "VnDTv1.1-gold-POS-tags-train.conll",
        "dev": "VnDTv1.1-gold-POS-tags-dev.conll",
        "test": "VnDTv1.1-gold-POS-tags-test.conll",
    }

    def __init__(self, data_dir: Optional[str] = None, force_download: bool = False):
        """
        Initialize the VnDT corpus.

        Downloads and converts the dataset on first use; subsequent
        instantiations reuse the converted files on disk.

        Args:
            data_dir: Directory to store the CoNLL-U files.
                Defaults to ./data/vndt (relative to the package root).
            force_download: If True, re-download even if files exist.
        """
        if data_dir is None:
            data_dir = Path(__file__).parent.parent / "data" / "vndt"
        self.data_dir = Path(data_dir)
        self.data_dir.mkdir(parents=True, exist_ok=True)

        # Use simplified names locally
        self._train = self.data_dir / "train.conllu"
        self._dev = self.data_dir / "dev.conllu"
        self._test = self.data_dir / "test.conllu"

        if force_download or not self._files_exist():
            self._download()

    def _files_exist(self) -> bool:
        """Check if all three converted split files exist on disk."""
        return all(p.exists() for p in (self._train, self._dev, self._test))

    def _download(self):
        """
        Download VnDT files from GitHub and convert them to CoNLL-U format.

        Raises:
            RuntimeError: If any split fails to download or convert; the
                original exception is chained for debugging.
        """
        print("Downloading VnDT from GitHub...")
        for split, filename in self.FILE_NAMES.items():
            # Fetch the raw upstream file under its original name, convert,
            # then delete the intermediate download.
            url = f"{self.BASE_URL}/{filename}"
            temp_path = self.data_dir / filename
            output_path = getattr(self, f"_{split}")
            print(f"  Downloading {filename}...")
            try:
                urllib.request.urlretrieve(url, temp_path)
                # Convert to CoNLL-U format
                self._convert_to_conllu(temp_path, output_path)
                # Remove temp file
                temp_path.unlink()
            except Exception as e:
                raise RuntimeError(
                    f"Failed to download VnDT. "
                    f"Please download manually from: "
                    f"https://github.com/datquocnguyen/VnDT\n"
                    f"Error: {e}"
                ) from e
        print(f"Dataset saved to {self.data_dir}")
        self._print_statistics()

    def _convert_to_conllu(self, input_path: Path, output_path: Path):
        """
        Convert VnDT CoNLL format to CoNLL-U format.

        VnDT format:
            ID FORM _ _ POS _ HEAD DEPREL _ _
        CoNLL-U format:
            ID FORM LEMMA UPOS XPOS FEATS HEAD DEPREL DEPS MISC

        The POS column is copied into both UPOS and XPOS; LEMMA, FEATS,
        DEPS and MISC are left as "_". Rows with fewer than 8 columns
        are skipped. Sentences are separated by blank lines and each is
        preceded by a sequential "# sent_id = N" comment.
        """
        with open(input_path, "r", encoding="utf-8") as f_in, \
             open(output_path, "w", encoding="utf-8") as f_out:
            sent_id = 0
            tokens = []

            def flush():
                # Emit the accumulated sentence (if any) with its header,
                # then reset the buffer. Shared by the blank-line separator
                # and the end-of-file path, so both behave identically.
                nonlocal sent_id
                if not tokens:
                    return
                sent_id += 1
                f_out.write(f"# sent_id = {sent_id}\n")
                f_out.writelines(token + "\n" for token in tokens)
                f_out.write("\n")
                tokens.clear()

            for line in f_in:
                line = line.strip()
                if not line:
                    # Blank line: end of the current sentence.
                    flush()
                    continue
                parts = line.split("\t")
                if len(parts) < 8:
                    # Malformed row: skip rather than crash mid-file.
                    continue
                # VnDT: ID FORM _ _ POS _ HEAD DEPREL _ _
                token_id, form = parts[0], parts[1]
                pos = parts[4]
                head = parts[6]
                # Map VnDT dependency relations to UD-style
                deprel = self._map_deprel(parts[7])
                tokens.append(
                    f"{token_id}\t{form}\t_\t{pos}\t{pos}\t_\t{head}\t{deprel}\t_\t_"
                )
            # Handle the last sentence if the file has no trailing blank line.
            flush()

    def _map_deprel(self, deprel: str) -> str:
        """
        Map VnDT dependency relations to UD-style relations.

        VnDT uses its own tagset; this maps the known labels to something
        closer to Universal Dependencies. Unknown labels pass through
        unchanged.
        """
        # VnDT to UD-like mapping
        mapping = {
            "sub": "nsubj",
            "dob": "obj",
            "iob": "iobj",
            "pob": "obl",
            "vmod": "xcomp",
            "nmod": "nmod",
            "amod": "amod",
            "adv": "advmod",
            "coord": "cc",
            "conj": "conj",
            "det": "det",
            "punct": "punct",
            "root": "root",
            "tmp": "obl:tmod",
            "loc": "obl:loc",
            "prp": "advcl",
            "dep": "dep",
            "prd": "xcomp",
            "lgs": "obl:agent",
            "mnr": "advmod",
            "topic": "dislocated",
        }
        return mapping.get(deprel, deprel)

    def _print_statistics(self):
        """Print per-split sentence and token counts to stdout."""
        for name, path in [("Train", self._train), ("Dev", self._dev), ("Test", self._test)]:
            n_sents, n_tokens = self._count_sentences_tokens(path)
            print(f"  {name}: {n_sents} sentences, {n_tokens} tokens")

    def _count_sentences_tokens(self, path: Path) -> tuple:
        """
        Count sentences and tokens in a CoNLL-U file.

        A sentence is a maximal run of token lines; comment lines ("#")
        are ignored and multi-word-token ("1-2") / empty-node ("1.1")
        rows are excluded from the token count. Robust to consecutive
        blank lines and to a missing trailing blank line.
        """
        n_sents = 0
        n_tokens = 0
        in_sentence = False
        with open(path, "r", encoding="utf-8") as f:
            for line in f:
                line = line.strip()
                if not line:
                    # Only close a sentence that actually contained tokens,
                    # so repeated blank lines are not over-counted.
                    if in_sentence:
                        n_sents += 1
                        in_sentence = False
                elif not line.startswith("#"):
                    in_sentence = True
                    parts = line.split("\t")
                    if "-" not in parts[0] and "." not in parts[0]:
                        n_tokens += 1
        if in_sentence:
            # File ended without a trailing blank line.
            n_sents += 1
        return n_sents, n_tokens

    @property
    def train(self) -> str:
        """Path to training data file."""
        return str(self._train)

    @property
    def dev(self) -> str:
        """Path to development/validation data file."""
        return str(self._dev)

    @property
    def test(self) -> str:
        """Path to test data file."""
        return str(self._test)

    def get_statistics(self) -> dict:
        """
        Get dataset statistics.

        Returns:
            dict with per-split sentence/token counts plus the number and
            sorted lists of distinct POS tags and dependency relations
            observed across all three splits.
        """
        stats = {}
        for split_name, path in [
            ("train", self._train),
            ("dev", self._dev),
            ("test", self._test),
        ]:
            n_sents, n_tokens = self._count_sentences_tokens(path)
            stats[f"{split_name}_sentences"] = n_sents
            stats[f"{split_name}_tokens"] = n_tokens

        # Collect all POS tags and relations across every split.
        all_pos = set()
        all_deprels = set()
        for path in [self._train, self._dev, self._test]:
            with open(path, "r", encoding="utf-8") as f:
                for line in f:
                    line = line.strip()
                    if line and not line.startswith("#"):
                        parts = line.split("\t")
                        if len(parts) >= 8 and "-" not in parts[0] and "." not in parts[0]:
                            all_pos.add(parts[3])
                            all_deprels.add(parts[7])
        stats["num_pos_tags"] = len(all_pos)
        stats["num_deprels"] = len(all_deprels)
        stats["pos_tags"] = sorted(all_pos)
        stats["deprels"] = sorted(all_deprels)
        return stats