Adive01 commited on
Commit
8beebbb
Β·
verified Β·
1 Parent(s): fecd2b2

Upload mlplo/data_cleaning.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. mlplo/data_cleaning.py +277 -0
mlplo/data_cleaning.py ADDED
@@ -0,0 +1,277 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import argparse
4
+ import logging
5
+ import sys
6
+ from pathlib import Path
7
+
8
+ from datasets import Dataset, DatasetDict, load_dataset
9
+
10
+ from .common import (
11
+ CACHE_DIR,
12
+ DEFAULT_DATASET_NAME,
13
+ DEFAULT_INPUT_MAX_LENGTH,
14
+ DEFAULT_MODEL_NAME,
15
+ DEFAULT_SUMMARY_COLUMN,
16
+ DEFAULT_TARGET_MAX_LENGTH,
17
+ DEFAULT_TEXT_COLUMN,
18
+ IS_WINDOWS,
19
+ PROCESSED_DIR,
20
+ build_preprocess_function,
21
+ count_words,
22
+ ensure_project_dirs,
23
+ load_tokenizer,
24
+ maybe_limit_split,
25
+ normalize_text,
26
+ write_json,
27
+ )
28
+
29
+ LOGGER = logging.getLogger(__name__)
30
+
31
+
32
+ def parse_args() -> argparse.Namespace:
33
+ parser = argparse.ArgumentParser(
34
+ description="Clean, filter, deduplicate, and tokenize XSum for BART."
35
+ )
36
+ parser.add_argument("--dataset-name", default=DEFAULT_DATASET_NAME)
37
+ parser.add_argument("--dataset-config", default=None)
38
+ parser.add_argument("--model-name", default=DEFAULT_MODEL_NAME)
39
+ parser.add_argument("--text-column", default=DEFAULT_TEXT_COLUMN)
40
+ parser.add_argument("--summary-column", default=DEFAULT_SUMMARY_COLUMN)
41
+ parser.add_argument("--cache-dir", default=str(CACHE_DIR))
42
+ parser.add_argument("--output-dir", default=str(PROCESSED_DIR / "xsum_bart_base"))
43
+ parser.add_argument("--max-input-length", type=int, default=DEFAULT_INPUT_MAX_LENGTH)
44
+ parser.add_argument("--max-target-length", type=int, default=DEFAULT_TARGET_MAX_LENGTH)
45
+ parser.add_argument("--min-document-words", type=int, default=50)
46
+ parser.add_argument("--max-document-words", type=int, default=1024)
47
+ parser.add_argument("--min-summary-words", type=int, default=5)
48
+ parser.add_argument("--train-samples", type=int, default=None)
49
+ parser.add_argument("--validation-samples", type=int, default=None)
50
+ parser.add_argument("--test-samples", type=int, default=None)
51
+ parser.add_argument(
52
+ "--num-proc",
53
+ type=int,
54
+ default=1,
55
+ help="Worker processes for dataset.map(). Forced to 1 on Windows.",
56
+ )
57
+ parser.add_argument(
58
+ "--debug",
59
+ action="store_true",
60
+ help="Use tiny split sizes (256/64/64) for a fast smoke-test.",
61
+ )
62
+ return parser.parse_args()
63
+
64
+
65
+ def clean_batch(
66
+ batch: dict[str, list[str]], text_column: str, summary_column: str
67
+ ) -> dict[str, list[str]]:
68
+ return {
69
+ text_column: [normalize_text(text) for text in batch[text_column]],
70
+ summary_column: [normalize_text(text) for text in batch[summary_column]],
71
+ }
72
+
73
+
74
+ def is_valid_example(
75
+ example: dict[str, str],
76
+ text_column: str,
77
+ summary_column: str,
78
+ min_document_words: int,
79
+ max_document_words: int,
80
+ min_summary_words: int,
81
+ ) -> bool:
82
+ document_length = count_words(example.get(text_column, ""))
83
+ summary_length = count_words(example.get(summary_column, ""))
84
+ return (
85
+ min_document_words <= document_length <= max_document_words
86
+ and summary_length >= min_summary_words
87
+ and bool(example.get(text_column, "").strip())
88
+ and bool(example.get(summary_column, "").strip())
89
+ )
90
+
91
+
92
+ def deduplicate_split(split: Dataset, text_column: str) -> tuple[Dataset, int]:
93
+ """Remove exact-duplicate documents using a hash set (O(n) time)."""
94
+ seen: set[str] = set()
95
+ keep: list[int] = []
96
+ for index, example in enumerate(split):
97
+ doc = example[text_column]
98
+ if doc in seen:
99
+ continue
100
+ seen.add(doc)
101
+ keep.append(index)
102
+ removed = len(split) - len(keep)
103
+ return split.select(keep), removed
104
+
105
+
106
+ def _safe_output_dir(output_dir: Path) -> None:
107
+ """Raise FileExistsError if the directory is non-empty, with PermissionError guard."""
108
+ if not output_dir.exists():
109
+ return
110
+ try:
111
+ non_empty = any(output_dir.iterdir())
112
+ except PermissionError as exc:
113
+ raise PermissionError(
114
+ f"Cannot read output directory '{output_dir}'. "
115
+ "It may be locked by another process (e.g. OneDrive sync)."
116
+ ) from exc
117
+ if non_empty:
118
+ raise FileExistsError(
119
+ f"Output directory '{output_dir}' is not empty. "
120
+ "Choose a new path or clear it first."
121
+ )
122
+
123
+
124
+ def _resolve_num_proc(requested: int) -> int:
125
+ """Force num_proc=1 on Windows; warn if the user asked for more."""
126
+ if IS_WINDOWS and requested > 1:
127
+ LOGGER.warning(
128
+ "Multiprocessing with num_proc=%d is unreliable on Windows "
129
+ "(datasets uses fork). Falling back to num_proc=1.",
130
+ requested,
131
+ )
132
+ return 1
133
+ return requested
134
+
135
+
136
+ def main() -> None:
137
+ logging.basicConfig(
138
+ level=logging.INFO,
139
+ format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
140
+ )
141
+ args = parse_args()
142
+ ensure_project_dirs()
143
+
144
+ # ── Validate length arguments ──────────────────────────────────────────────
145
+ if args.max_input_length <= args.max_target_length:
146
+ raise ValueError(
147
+ f"--max-input-length ({args.max_input_length}) must be greater than "
148
+ f"--max-target-length ({args.max_target_length})."
149
+ )
150
+
151
+ # ── Debug mode: use None-safe check so --train-samples 0 is respected ─────
152
+ if args.debug:
153
+ if args.train_samples is None:
154
+ args.train_samples = 256
155
+ if args.validation_samples is None:
156
+ args.validation_samples = 64
157
+ if args.test_samples is None:
158
+ args.test_samples = 64
159
+
160
+ output_dir = Path(args.output_dir)
161
+ _safe_output_dir(output_dir)
162
+
163
+ num_proc = _resolve_num_proc(args.num_proc)
164
+
165
+ # ── Load dataset ───────────────────────────────────────────────────────────
166
+ LOGGER.info("Loading dataset '%s'…", args.dataset_name)
167
+ try:
168
+ dataset = load_dataset(
169
+ args.dataset_name,
170
+ args.dataset_config,
171
+ cache_dir=args.cache_dir,
172
+ )
173
+ except Exception as exc:
174
+ raise RuntimeError(
175
+ f"Failed to load dataset '{args.dataset_name}'. "
176
+ "Check your internet connection and dataset name."
177
+ ) from exc
178
+
179
+ # ── Validate expected splits exist ────────────────────────────────────────
180
+ required_splits = {"train", "validation", "test"}
181
+ missing = required_splits - set(dataset.keys())
182
+ if missing:
183
+ LOGGER.warning(
184
+ "Dataset '%s' is missing splits: %s. Skipping those splits.",
185
+ args.dataset_name,
186
+ missing,
187
+ )
188
+
189
+ subset_limits = {
190
+ "train": args.train_samples,
191
+ "validation": args.validation_samples,
192
+ "test": args.test_samples,
193
+ }
194
+ dataset = DatasetDict(
195
+ {
196
+ split_name: maybe_limit_split(split, subset_limits.get(split_name))
197
+ for split_name, split in dataset.items()
198
+ }
199
+ )
200
+
201
+ # ── Normalize ──────────────────────────────────────────────────────────────
202
+ LOGGER.info("Normalizing text…")
203
+ dataset = dataset.map(
204
+ clean_batch,
205
+ batched=True,
206
+ fn_kwargs={
207
+ "text_column": args.text_column,
208
+ "summary_column": args.summary_column,
209
+ },
210
+ num_proc=num_proc,
211
+ desc="Whitespace cleanup",
212
+ )
213
+
214
+ # ── Filter ────────────────────────────────────────────────────────────────
215
+ LOGGER.info("Filtering unusable rows…")
216
+ dataset = dataset.filter(
217
+ is_valid_example,
218
+ fn_kwargs={
219
+ "text_column": args.text_column,
220
+ "summary_column": args.summary_column,
221
+ "min_document_words": args.min_document_words,
222
+ "max_document_words": args.max_document_words,
223
+ "min_summary_words": args.min_summary_words,
224
+ },
225
+ num_proc=num_proc,
226
+ desc="Length filtering",
227
+ )
228
+
229
+ # ── Deduplicate ───────────────────────────────────────────────────────────
230
+ dedupe_report: dict[str, int] = {}
231
+ deduped_splits: dict[str, Dataset] = {}
232
+ LOGGER.info("Deduplicating rows…")
233
+ for split_name, split in dataset.items():
234
+ deduped_split, removed = deduplicate_split(split, args.text_column)
235
+ deduped_splits[split_name] = deduped_split
236
+ dedupe_report[split_name] = removed
237
+ dataset = DatasetDict(deduped_splits)
238
+
239
+ # ── Tokenize ──────────────────────────────────────────────────────────────
240
+ tokenizer = load_tokenizer(args.model_name)
241
+ preprocess_fn = build_preprocess_function(
242
+ tokenizer=tokenizer,
243
+ text_column=args.text_column,
244
+ summary_column=args.summary_column,
245
+ max_input_length=args.max_input_length,
246
+ max_target_length=args.max_target_length,
247
+ )
248
+ LOGGER.info("Tokenizing rows…")
249
+ tokenized_dataset = dataset.map(
250
+ preprocess_fn,
251
+ batched=True,
252
+ num_proc=num_proc,
253
+ desc="Tokenization",
254
+ )
255
+
256
+ # ── Save ──────────────────────────────────────────────────────────────────
257
+ LOGGER.info("Saving tokenized dataset to %s", output_dir)
258
+ tokenized_dataset.save_to_disk(str(output_dir))
259
+
260
+ manifest = {
261
+ "dataset_name": args.dataset_name,
262
+ "dataset_config": args.dataset_config,
263
+ "model_name": args.model_name,
264
+ "text_column": args.text_column,
265
+ "summary_column": args.summary_column,
266
+ "max_input_length": args.max_input_length,
267
+ "max_target_length": args.max_target_length,
268
+ "subset_limits": subset_limits,
269
+ "splits": {name: len(split) for name, split in tokenized_dataset.items()},
270
+ "duplicates_removed": dedupe_report,
271
+ }
272
+ write_json(output_dir / "manifest.json", manifest)
273
+ LOGGER.info("Finished preprocessing. Split sizes: %s", manifest["splits"])
274
+
275
+
276
+ if __name__ == "__main__":
277
+ main()