r"""
Prepare balanced FastText training data from the TurkuNLP/register_oscar dataset.

Reads already-downloaded English shards, extracts labeled documents, holds out
a per-class test split, and builds a balanced training set by oversampling
minority classes and undersampling majority classes to the median size of the
per-class training pools.

Output lines use FastText supervised format: one ``__label__<CODE>`` token for
every label on the document, followed by the cleaned text. Grouping, the test
split and balancing are all keyed on each document's first label only.

Requirements:
    pip install huggingface_hub   # for the download step below

Usage:
    # Download shards first:
    for i in $(seq 0 9); do
        hf download TurkuNLP/register_oscar \
            $(printf "en/en_%05d.jsonl.gz" $i) \
            --repo-type dataset --local-dir ./data
    done

    # Then run:
    python prepare_data.py --data-dir ./data/en --output-dir ./prepared
"""

import argparse
import glob
import gzip
import json
import random
import re
from collections import Counter, defaultdict
from pathlib import Path

# Human-readable names for register codes; used only in printed summaries
# (codes not listed here are printed as-is).
REGISTER_LABELS = {
    "IN": "Informational",
    "NA": "Narrative",
    "OP": "Opinion",
    "IP": "Persuasion",
    "HI": "HowTo",
    "ID": "Discussion",
    "SP": "Spoken",
    "LY": "Lyrical",
}

# Compiled once at import time instead of being looked up for every document.
_WHITESPACE_RE = re.compile(r"\s+")


def clean_text(text: str, max_words: int = 500) -> str:
    """Collapse whitespace and truncate to ``max_words`` words.

    Newlines are collapsed as well, which keeps each document on a single
    line as FastText's one-example-per-line format requires.
    """
    text = _WHITESPACE_RE.sub(" ", text).strip()
    return " ".join(text.split()[:max_words])


def _parse_args(argv=None) -> argparse.Namespace:
    """Parse and validate command-line options (``argv=None`` reads sys.argv)."""
    parser = argparse.ArgumentParser(description="Prepare balanced FastText training data")
    parser.add_argument("--data-dir", default="./data/en", help="Directory with .jsonl.gz shards")
    parser.add_argument("--output-dir", default="./prepared", help="Output directory for train/test files")
    parser.add_argument("--max-words", type=int, default=500, help="Max words per document")
    parser.add_argument("--min-text-len", type=int, default=50, help="Min character length to keep")
    parser.add_argument("--test-ratio", type=float, default=0.1, help="Fraction held out for test")
    parser.add_argument(
        "--min-test",
        type=int,
        default=50,
        help="Min test docs per class (never more than half of the class)",
    )
    parser.add_argument("--seed", type=int, default=42)
    args = parser.parse_args(argv)

    # A ratio of 1 (or more) would leave nothing to train on.
    if not 0.0 <= args.test_ratio < 1.0:
        parser.error("--test-ratio must be in [0, 1)")
    if args.max_words < 1:
        parser.error("--max-words must be >= 1")
    if args.min_test < 0:
        parser.error("--min-test must be >= 0")
    return args


def _load_by_label(data_dir: str, min_text_len: int, max_words: int):
    """Read every ``*.jsonl.gz`` shard in ``data_dir`` and group FastText lines.

    Returns ``(by_label, stats)``: ``by_label`` maps each document's first
    label to a list of FastText lines; ``stats`` is a Counter with the keys
    ``kept``, ``no_label``, ``short`` and ``empty``.

    Raises FileNotFoundError if ``data_dir`` contains no shards.
    """
    # Escape the directory so characters like '[' in the path are literal.
    shard_files = sorted(glob.glob(f"{glob.escape(str(data_dir))}/*.jsonl.gz"))
    if not shard_files:
        raise FileNotFoundError(f"No .jsonl.gz files found in {data_dir}")
    print(f"Found {len(shard_files)} shard(s)")

    by_label = defaultdict(list)
    stats = Counter()
    for shard_file in shard_files:
        print(f"  Processing {Path(shard_file).name}...")
        # Explicit UTF-8: the platform default encoding may not decode web text.
        with gzip.open(shard_file, "rt", encoding="utf-8") as f:
            for raw in f:
                if not raw.strip():
                    continue  # tolerate blank lines, e.g. a trailing newline
                doc = json.loads(raw)
                labels = doc.get("labels") or []
                text = doc.get("text") or ""
                if not labels:
                    stats["no_label"] += 1
                    continue
                if len(text) < min_text_len:
                    stats["short"] += 1
                    continue
                cleaned = clean_text(text, max_words)
                if not cleaned:
                    # Long enough in characters, but whitespace only.
                    stats["empty"] += 1
                    continue
                label_str = " ".join(f"__label__{label}" for label in labels)
                by_label[labels[0]].append(f"{label_str} {cleaned}\n")
                stats["kept"] += 1
    return by_label, stats


def _n_test_docs(n_docs: int, test_ratio: float, min_test: int) -> int:
    """Return how many of a class's ``n_docs`` documents go to the test split.

    The hold-out is ``test_ratio`` of the class, raised to ``min_test`` so
    that small classes still get a usable test sample. That floor never takes
    more than half the class, and ``test_ratio < 1``, so at least one
    document always remains for training (when ``n_docs >= 1``).
    """
    return max(int(n_docs * test_ratio), min(min_test, n_docs // 2))


def _split_train_test(by_label, test_ratio: float, min_test: int, rng: random.Random):
    """Shuffle each class and split it into a held-out slice and a training pool.

    Returns ``(train_pools, test_lines)``: ``train_pools`` maps label -> list
    of lines; ``test_lines`` holds the held-out lines of every class.
    Shuffles the per-class lists in ``by_label`` in place.
    """
    train_pools = {}
    test_lines = []
    for label in sorted(by_label):
        lines = by_label[label]
        rng.shuffle(lines)
        n_test = _n_test_docs(len(lines), test_ratio, min_test)
        test_lines.extend(lines[:n_test])
        train_pools[label] = lines[n_test:]
    return train_pools, test_lines


def _resample(pool: list, target: int, rng: random.Random) -> list:
    """Return exactly ``target`` lines drawn from the non-empty ``pool``.

    Larger pools are undersampled without replacement. Smaller pools keep
    every line and are topped up with random duplicates.
    """
    if len(pool) >= target:
        return rng.sample(pool, target)
    return pool + rng.choices(pool, k=target - len(pool))


def _write_lines(path: Path, lines: list) -> None:
    """Write pre-terminated lines to ``path`` as UTF-8."""
    with open(path, "w", encoding="utf-8") as f:
        f.writelines(lines)


def main(argv=None):
    """Build ``train.txt`` and ``test.txt`` in --output-dir from --data-dir shards.

    ``argv`` is an optional argument list (default: ``sys.argv[1:]``).
    Raises FileNotFoundError if no shards are found, and ValueError if no
    document survives filtering.
    """
    args = _parse_args(argv)
    rng = random.Random(args.seed)

    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    by_label, stats = _load_by_label(args.data_dir, args.min_text_len, args.max_words)
    print(f"\nTotal labeled docs: {stats['kept']}")
    print(f"Skipped (no label): {stats['no_label']}")
    print(f"Skipped (too short): {stats['short']}")
    if stats["empty"]:
        print(f"Skipped (empty after cleaning): {stats['empty']}")
    if not by_label:
        raise ValueError("No labeled documents survived filtering; nothing to write")

    print("\nRaw distribution:")
    for label in sorted(by_label):
        name = REGISTER_LABELS.get(label, label)
        print(f"  {label} ({name}): {len(by_label[label])}")

    # Split first, then balance. The target is the (upper) median of the
    # training pools, so the median class is used as-is rather than padded
    # with duplicates.
    train_pools, test_lines = _split_train_test(by_label, args.test_ratio, args.min_test, rng)
    pool_sizes = sorted(len(pool) for pool in train_pools.values())
    target = pool_sizes[len(pool_sizes) // 2]
    print(f"\nBalancing target (median): {target}")

    train_lines = []
    for label, pool in train_pools.items():
        train_lines.extend(_resample(pool, target, rng))
        n_train = len(pool)
        if n_train > target:
            print(f"  {label}: {n_train} -> {target} (undersampled)")
        elif n_train < target:
            print(f"  {label}: {n_train} -> {target} (oversampled +{target - n_train})")
        else:
            print(f"  {label}: {n_train} -> {target} (unchanged)")

    rng.shuffle(train_lines)
    rng.shuffle(test_lines)

    train_path = output_dir / "train.txt"
    test_path = output_dir / "test.txt"
    _write_lines(train_path, train_lines)
    _write_lines(test_path, test_lines)
    print(f"\nTrain: {len(train_lines)} -> {train_path}")
    print(f"Test: {len(test_lines)} -> {test_path}")

    # Verify balance: count every label token, including secondary labels.
    label_counts = Counter(
        tok for line in train_lines for tok in line.split() if tok.startswith("__label__")
    )
    print("\nFinal train label distribution:")
    for label_tok, cnt in label_counts.most_common():
        name = REGISTER_LABELS.get(label_tok.replace("__label__", ""), label_tok)
        print(f"  {label_tok} ({name}): {cnt}")


if __name__ == "__main__":
    main()