| """ |
| Prepare balanced FastText training data from TurkuNLP/register_oscar dataset. |
| |
| Downloads English shards, extracts labeled documents, and creates a balanced |
| training set by oversampling minority classes and undersampling majority classes |
| to the median class size. |
| |
| Requirements: |
| pip install huggingface_hub |
| |
| Usage: |
| # Download shards first: |
| for i in $(seq 0 9); do |
| hf download TurkuNLP/register_oscar \ |
| $(printf "en/en_%05d.jsonl.gz" $i) \ |
| --repo-type dataset --local-dir ./data |
| done |
| |
| # Then run: |
| python prepare_data.py --data-dir ./data/en --output-dir ./prepared |
| """ |
|
|
| import json |
| import gzip |
| import re |
| import random |
| import glob |
| import argparse |
| from collections import Counter, defaultdict |
| from pathlib import Path |
|
|
|
|
# Human-readable names for the eight top-level register labels emitted by
# TurkuNLP/register_oscar (CORE register taxonomy).
# NOTE(review): these look like shortened forms of the official taxonomy
# names (e.g. "IP" is "Informational Persuasion", "ID" is "Interactive
# Discussion" in the dataset card) — confirm against the dataset card if
# the exact wording matters; they are only used for report printing here.
REGISTER_LABELS = {
    "IN": "Informational",
    "NA": "Narrative",
    "OP": "Opinion",
    "IP": "Persuasion",
    "HI": "HowTo",
    "ID": "Discussion",
    "SP": "Spoken",
    "LY": "Lyrical",
}
|
|
|
|
def clean_text(text: str, max_words: int = 500) -> str:
    """Normalize whitespace in *text* and keep at most *max_words* words.

    Runs of any whitespace collapse to single spaces; leading and trailing
    whitespace is dropped. Returns "" for whitespace-only input.
    """
    # str.split() with no separator already splits on arbitrary whitespace
    # runs and discards leading/trailing whitespace, so a single pass
    # replaces the regex-collapse + strip + split of the obvious approach.
    tokens = text.split()
    return " ".join(tokens[:max_words])
|
|
|
|
| def main(): |
| parser = argparse.ArgumentParser(description="Prepare balanced FastText training data") |
| parser.add_argument("--data-dir", default="./data/en", help="Directory with .jsonl.gz shards") |
| parser.add_argument("--output-dir", default="./prepared", help="Output directory for train/test files") |
| parser.add_argument("--max-words", type=int, default=500, help="Max words per document") |
| parser.add_argument("--min-text-len", type=int, default=50, help="Min character length to keep") |
| parser.add_argument("--test-ratio", type=float, default=0.1, help="Fraction held out for test") |
| parser.add_argument("--seed", type=int, default=42) |
| args = parser.parse_args() |
|
|
| random.seed(args.seed) |
| output_dir = Path(args.output_dir) |
| output_dir.mkdir(parents=True, exist_ok=True) |
|
|
| |
| by_label = defaultdict(list) |
| total = 0 |
| skipped_nolabel = 0 |
| skipped_short = 0 |
|
|
| shard_files = sorted(glob.glob(f"{args.data_dir}/*.jsonl.gz")) |
| if not shard_files: |
| raise FileNotFoundError(f"No .jsonl.gz files found in {args.data_dir}") |
|
|
| print(f"Found {len(shard_files)} shard(s)") |
|
|
| for shard_file in shard_files: |
| print(f" Processing {Path(shard_file).name}...") |
| with gzip.open(shard_file, "rt") as f: |
| for line in f: |
| d = json.loads(line) |
| labels = d.get("labels", []) |
| text = d.get("text", "") |
|
|
| if not labels: |
| skipped_nolabel += 1 |
| continue |
| if len(text) < args.min_text_len: |
| skipped_short += 1 |
| continue |
|
|
| cleaned = clean_text(text, args.max_words) |
| if not cleaned: |
| continue |
|
|
| label_str = " ".join(f"__label__{l}" for l in labels) |
| ft_line = f"{label_str} {cleaned}\n" |
|
|
| primary = labels[0] |
| by_label[primary].append(ft_line) |
| total += 1 |
|
|
| print(f"\nTotal labeled docs: {total}") |
| print(f"Skipped (no label): {skipped_nolabel}") |
| print(f"Skipped (too short): {skipped_short}") |
|
|
| |
| print("\nRaw distribution:") |
| for label in sorted(by_label.keys()): |
| name = REGISTER_LABELS.get(label, label) |
| print(f" {label} ({name}): {len(by_label[label])}") |
|
|
| |
| sizes = {k: len(v) for k, v in by_label.items()} |
| sorted_sizes = sorted(sizes.values()) |
| median_size = sorted_sizes[len(sorted_sizes) // 2] |
| target = median_size |
|
|
| print(f"\nBalancing target (median): {target}") |
|
|
| train_lines = [] |
| test_lines = [] |
|
|
| for label, lines in by_label.items(): |
| random.shuffle(lines) |
|
|
| n_test = max(len(lines) // 10, 50) |
| test_pool = lines[:n_test] |
| train_pool = lines[n_test:] |
|
|
| test_lines.extend(test_pool) |
| n_train = len(train_pool) |
|
|
| if n_train >= target: |
| sampled = random.sample(train_pool, target) |
| train_lines.extend(sampled) |
| print(f" {label}: {n_train} -> {target} (undersampled)") |
| else: |
| train_lines.extend(train_pool) |
| n_needed = target - n_train |
| oversampled = random.choices(train_pool, k=n_needed) |
| train_lines.extend(oversampled) |
| print(f" {label}: {n_train} -> {target} (oversampled +{n_needed})") |
|
|
| random.shuffle(train_lines) |
| random.shuffle(test_lines) |
|
|
| train_path = output_dir / "train.txt" |
| test_path = output_dir / "test.txt" |
|
|
| with open(train_path, "w") as f: |
| f.writelines(train_lines) |
| with open(test_path, "w") as f: |
| f.writelines(test_lines) |
|
|
| print(f"\nTrain: {len(train_lines)} -> {train_path}") |
| print(f"Test: {len(test_lines)} -> {test_path}") |
|
|
| |
| c = Counter() |
| for line in train_lines: |
| for tok in line.split(): |
| if tok.startswith("__label__"): |
| c[tok] += 1 |
| print("\nFinal train label distribution:") |
| for l, cnt in c.most_common(): |
| name = REGISTER_LABELS.get(l.replace("__label__", ""), l) |
| print(f" {l} ({name}): {cnt}") |
|
|
|
|
# Script entry point; guarded so the module can be imported without side effects.
if __name__ == "__main__":
    main()
|
|