# SummaryGenerator / mlplo/data_cleaning.py
# Adive01's picture
# Upload mlplo/data_cleaning.py with huggingface_hub
# 8beebbb verified
from __future__ import annotations
import argparse
import logging
import sys
from pathlib import Path
from datasets import Dataset, DatasetDict, load_dataset
from .common import (
CACHE_DIR,
DEFAULT_DATASET_NAME,
DEFAULT_INPUT_MAX_LENGTH,
DEFAULT_MODEL_NAME,
DEFAULT_SUMMARY_COLUMN,
DEFAULT_TARGET_MAX_LENGTH,
DEFAULT_TEXT_COLUMN,
IS_WINDOWS,
PROCESSED_DIR,
build_preprocess_function,
count_words,
ensure_project_dirs,
load_tokenizer,
maybe_limit_split,
normalize_text,
write_json,
)
# Module-level logger, named after this module, shared by the helpers and main().
LOGGER = logging.getLogger(__name__)
def parse_args(argv: list[str] | None = None) -> argparse.Namespace:
    """Parse the command-line options of the preprocessing script.

    Args:
        argv: Argument strings to parse. ``None`` (the default) lets argparse
            read ``sys.argv[1:]``, which is what the script entry point
            relies on; passing an explicit list allows programmatic use.

    Returns:
        The populated ``argparse.Namespace``.
    """
    parser = argparse.ArgumentParser(
        description="Clean, filter, deduplicate, and tokenize XSum for BART."
    )

    # Dataset, model, and column selection.
    parser.add_argument("--dataset-name", default=DEFAULT_DATASET_NAME)
    parser.add_argument("--dataset-config", default=None)
    parser.add_argument("--model-name", default=DEFAULT_MODEL_NAME)
    parser.add_argument("--text-column", default=DEFAULT_TEXT_COLUMN)
    parser.add_argument("--summary-column", default=DEFAULT_SUMMARY_COLUMN)

    # Filesystem locations.
    parser.add_argument("--cache-dir", default=str(CACHE_DIR))
    parser.add_argument("--output-dir", default=str(PROCESSED_DIR / "xsum_bart_base"))

    # Length limits handed to the tokenization step.
    parser.add_argument("--max-input-length", type=int, default=DEFAULT_INPUT_MAX_LENGTH)
    parser.add_argument("--max-target-length", type=int, default=DEFAULT_TARGET_MAX_LENGTH)

    # Word-count bounds used by the row filter.
    parser.add_argument("--min-document-words", type=int, default=50)
    parser.add_argument("--max-document-words", type=int, default=1024)
    parser.add_argument("--min-summary-words", type=int, default=5)

    # Optional per-split row caps; None leaves the cap unset.
    parser.add_argument("--train-samples", type=int, default=None)
    parser.add_argument("--validation-samples", type=int, default=None)
    parser.add_argument("--test-samples", type=int, default=None)

    parser.add_argument(
        "--num-proc",
        type=int,
        default=1,
        help="Worker processes for dataset.map(). Forced to 1 on Windows.",
    )
    parser.add_argument(
        "--debug",
        action="store_true",
        help="Use tiny split sizes (256/64/64) for a fast smoke-test.",
    )
    return parser.parse_args(argv)
def clean_batch(
    batch: dict[str, list[str]], text_column: str, summary_column: str
) -> dict[str, list[str]]:
    """Run ``normalize_text`` over the document and summary columns of a batch.

    Args:
        batch: Mapping from column name to that column's list of values.
        text_column: Name of the document column.
        summary_column: Name of the summary column.

    Returns:
        A dict containing only the two named columns, each value replaced by
        the result of ``normalize_text``.
    """
    cleaned: dict[str, list[str]] = {}
    for column in (text_column, summary_column):
        cleaned[column] = list(map(normalize_text, batch[column]))
    return cleaned
def is_valid_example(
    example: dict[str, str],
    text_column: str,
    summary_column: str,
    min_document_words: int,
    max_document_words: int,
    min_summary_words: int,
) -> bool:
    """Decide whether a single example passes the length and emptiness checks.

    Args:
        example: One row, mapping column name to text.
        text_column: Name of the document column.
        summary_column: Name of the summary column.
        min_document_words: Smallest accepted document word count (inclusive).
        max_document_words: Largest accepted document word count (inclusive).
        min_summary_words: Smallest accepted summary word count (inclusive).

    Returns:
        ``True`` only when the document word count lies within the bounds,
        the summary has enough words, and neither field is blank after
        stripping whitespace. A missing column is treated as ``""``.
    """
    document = example.get(text_column, "")
    summary = example.get(summary_column, "")

    # Both counts are taken up front, before any check can short-circuit.
    document_length = count_words(document)
    summary_length = count_words(summary)

    if document_length < min_document_words or document_length > max_document_words:
        return False
    if summary_length < min_summary_words:
        return False
    if not document.strip():
        return False
    return bool(summary.strip())
def deduplicate_split(split: Dataset, text_column: str) -> tuple[Dataset, int]:
    """Remove rows whose document text exactly matches an earlier row.

    The first occurrence of each document is kept and the surviving rows
    stay in their original order. Membership is tracked in a set, so the
    scan is a single pass over the split.

    Args:
        split: The split to deduplicate.
        text_column: Name of the column holding the document text.

    Returns:
        A ``(deduplicated_split, removed)`` pair, where ``removed`` is the
        number of rows that were dropped.
    """
    seen: set[str] = set()
    keep: list[int] = []
    # Only the text column is needed to detect duplicates, so read that
    # column alone rather than building a full row (every column) per example.
    for index, doc in enumerate(split[text_column]):
        if doc in seen:
            continue
        seen.add(doc)
        keep.append(index)
    removed = len(split) - len(keep)
    return split.select(keep), removed
def _safe_output_dir(output_dir: Path) -> None:
    """Refuse to proceed when the output directory already has content.

    A path that does not exist, or an existing empty directory, is accepted
    silently.

    Args:
        output_dir: Destination directory for the processed dataset.

    Raises:
        PermissionError: If the directory's entries cannot be listed.
        FileExistsError: If the directory contains at least one entry.
    """
    if output_dir.exists():
        try:
            # One entry is enough to know the directory is not empty.
            first_entry = next(iter(output_dir.iterdir()), None)
        except PermissionError as exc:
            raise PermissionError(
                f"Cannot read output directory '{output_dir}'. "
                "It may be locked by another process (e.g. OneDrive sync)."
            ) from exc
        if first_entry is not None:
            raise FileExistsError(
                f"Output directory '{output_dir}' is not empty. "
                "Choose a new path or clear it first."
            )
def _resolve_num_proc(requested: int) -> int:
    """Pick the worker count to use, capping it at 1 on Windows.

    Args:
        requested: Worker count asked for on the command line.

    Returns:
        ``requested`` unchanged, unless ``IS_WINDOWS`` is set and more than
        one worker was requested; in that case a warning is logged and ``1``
        is returned.
    """
    if not IS_WINDOWS or requested <= 1:
        return requested

    LOGGER.warning(
        "Multiprocessing with num_proc=%d is unreliable on Windows "
        "(datasets uses fork). Falling back to num_proc=1.",
        requested,
    )
    return 1
def main() -> None:
    """Run the preprocessing pipeline end to end from the command line.

    Steps, in order: configure logging, parse arguments, validate the length
    limits, refuse a non-empty output directory, load the dataset, apply the
    optional per-split row caps, normalize text, filter rows by word count,
    drop exact-duplicate documents within each split, tokenize, save the
    result to disk, and write ``manifest.json`` next to it.

    Raises:
        ValueError: If ``--max-input-length`` is not strictly greater than
            ``--max-target-length``.
        FileExistsError: If the output directory already has content.
        PermissionError: If the output directory cannot be listed.
        RuntimeError: If ``load_dataset`` raises for any reason.
    """
    logging.basicConfig(
        format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
        level=logging.INFO,
    )
    args = parse_args()
    ensure_project_dirs()

    # ── Validate length arguments ──────────────────────────────────────────
    if args.max_input_length <= args.max_target_length:
        raise ValueError(
            f"--max-input-length ({args.max_input_length}) must be greater than "
            f"--max-target-length ({args.max_target_length})."
        )

    # ── Debug mode: fill in only the caps the user left unset, so an
    #    explicit value such as --train-samples 0 is respected ─────────────
    if args.debug:
        debug_defaults = (
            ("train_samples", 256),
            ("validation_samples", 64),
            ("test_samples", 64),
        )
        for attr_name, fallback in debug_defaults:
            if getattr(args, attr_name) is None:
                setattr(args, attr_name, fallback)

    # Fail fast on an unusable destination before any expensive work starts.
    output_dir = Path(args.output_dir)
    _safe_output_dir(output_dir)
    num_proc = _resolve_num_proc(args.num_proc)

    # ── Load dataset ───────────────────────────────────────────────────────
    LOGGER.info("Loading dataset '%s'…", args.dataset_name)
    try:
        raw_splits = load_dataset(
            args.dataset_name,
            args.dataset_config,
            cache_dir=args.cache_dir,
        )
    except Exception as exc:
        raise RuntimeError(
            f"Failed to load dataset '{args.dataset_name}'. "
            "Check your internet connection and dataset name."
        ) from exc

    # ── Warn about expected splits that are absent ─────────────────────────
    absent_splits = {"train", "validation", "test"}.difference(raw_splits.keys())
    if absent_splits:
        LOGGER.warning(
            "Dataset '%s' is missing splits: %s. Skipping those splits.",
            args.dataset_name,
            absent_splits,
        )

    # ── Apply optional per-split row caps ──────────────────────────────────
    subset_limits = {
        "train": args.train_samples,
        "validation": args.validation_samples,
        "test": args.test_samples,
    }
    capped: dict[str, Dataset] = {}
    for split_name, split in raw_splits.items():
        capped[split_name] = maybe_limit_split(split, subset_limits.get(split_name))
    dataset = DatasetDict(capped)

    # Column names are shared by the normalize and filter steps.
    column_kwargs = {
        "text_column": args.text_column,
        "summary_column": args.summary_column,
    }

    # ── Normalize ──────────────────────────────────────────────────────────
    LOGGER.info("Normalizing text…")
    dataset = dataset.map(
        clean_batch,
        batched=True,
        fn_kwargs=dict(column_kwargs),
        num_proc=num_proc,
        desc="Whitespace cleanup",
    )

    # ── Filter ─────────────────────────────────────────────────────────────
    LOGGER.info("Filtering unusable rows…")
    filter_kwargs = {
        **column_kwargs,
        "min_document_words": args.min_document_words,
        "max_document_words": args.max_document_words,
        "min_summary_words": args.min_summary_words,
    }
    dataset = dataset.filter(
        is_valid_example,
        fn_kwargs=filter_kwargs,
        num_proc=num_proc,
        desc="Length filtering",
    )

    # ── Deduplicate (within each split independently) ──────────────────────
    removed_counts: dict[str, int] = {}
    unique_splits: dict[str, Dataset] = {}
    LOGGER.info("Deduplicating rows…")
    for split_name, split in dataset.items():
        unique_split, dropped = deduplicate_split(split, args.text_column)
        unique_splits[split_name] = unique_split
        removed_counts[split_name] = dropped
    dataset = DatasetDict(unique_splits)

    # ── Tokenize ───────────────────────────────────────────────────────────
    tokenizer = load_tokenizer(args.model_name)
    tokenize_batch = build_preprocess_function(
        tokenizer=tokenizer,
        text_column=args.text_column,
        summary_column=args.summary_column,
        max_input_length=args.max_input_length,
        max_target_length=args.max_target_length,
    )
    LOGGER.info("Tokenizing rows…")
    tokenized = dataset.map(
        tokenize_batch,
        batched=True,
        num_proc=num_proc,
        desc="Tokenization",
    )

    # ── Save dataset, then a manifest describing how it was produced ───────
    LOGGER.info("Saving tokenized dataset to %s", output_dir)
    tokenized.save_to_disk(str(output_dir))

    split_sizes = {name: len(split) for name, split in tokenized.items()}
    manifest = {
        "dataset_name": args.dataset_name,
        "dataset_config": args.dataset_config,
        "model_name": args.model_name,
        "text_column": args.text_column,
        "summary_column": args.summary_column,
        "max_input_length": args.max_input_length,
        "max_target_length": args.max_target_length,
        "subset_limits": subset_limits,
        "splits": split_sizes,
        "duplicates_removed": removed_counts,
    }
    write_json(output_dir / "manifest.json", manifest)
    LOGGER.info("Finished preprocessing. Split sizes: %s", manifest["splits"])
# Run the pipeline only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()