"""
Prepare balanced FastText training data from TurkuNLP/register_oscar dataset.
Downloads English shards, extracts labeled documents, and creates a balanced
training set by oversampling minority classes and undersampling majority classes
to the median class size.
Requirements:
pip install huggingface_hub
Usage:
# Download shards first:
for i in $(seq 0 9); do
hf download TurkuNLP/register_oscar \
$(printf "en/en_%05d.jsonl.gz" $i) \
--repo-type dataset --local-dir ./data
done
# Then run:
python prepare_data.py --data-dir ./data/en --output-dir ./prepared
"""
import json
import gzip
import re
import random
import glob
import argparse
from collections import Counter, defaultdict
from pathlib import Path
# Register (genre) taxonomy used by the TurkuNLP register_oscar dataset:
# maps the short label codes found in each shard's "labels" field to
# human-readable names. Used below only for report printing; the label
# codes themselves are what go into the FastText __label__ prefixes.
REGISTER_LABELS = {
    "IN": "Informational",
    "NA": "Narrative",
    "OP": "Opinion",
    "IP": "Persuasion",
    "HI": "HowTo",
    "ID": "Discussion",
    "SP": "Spoken",
    "LY": "Lyrical",
}
def clean_text(text: str, max_words: int = 500) -> str:
    """Normalize whitespace in *text* and keep at most *max_words* words.

    Runs of whitespace collapse to single spaces and leading/trailing
    whitespace is discarded before truncation.
    """
    # str.split() with no separator already splits on any whitespace run
    # and drops leading/trailing whitespace, so no regex pass is needed.
    tokens = text.split()
    return " ".join(tokens[:max_words])
def main():
    """Build a balanced FastText train/test split from register_oscar shards.

    Reads every ``*.jsonl.gz`` shard in ``--data-dir``, converts each labeled
    document into a FastText line (``__label__X [__label__Y ...] text``),
    holds out a per-class test split, then balances the training set to the
    median class size by undersampling majority classes and oversampling
    minority classes. Writes ``train.txt`` and ``test.txt`` to
    ``--output-dir``.

    Raises:
        FileNotFoundError: if ``--data-dir`` contains no ``.jsonl.gz`` files.
    """
    parser = argparse.ArgumentParser(description="Prepare balanced FastText training data")
    parser.add_argument("--data-dir", default="./data/en", help="Directory with .jsonl.gz shards")
    parser.add_argument("--output-dir", default="./prepared", help="Output directory for train/test files")
    parser.add_argument("--max-words", type=int, default=500, help="Max words per document")
    parser.add_argument("--min-text-len", type=int, default=50, help="Min character length to keep")
    parser.add_argument("--test-ratio", type=float, default=0.1, help="Fraction held out for test")
    parser.add_argument("--seed", type=int, default=42)
    args = parser.parse_args()

    random.seed(args.seed)
    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    # Collect all labeled docs grouped by primary (first) label.
    by_label = defaultdict(list)
    total = 0
    skipped_nolabel = 0
    skipped_short = 0

    shard_files = sorted(glob.glob(f"{args.data_dir}/*.jsonl.gz"))
    if not shard_files:
        raise FileNotFoundError(f"No .jsonl.gz files found in {args.data_dir}")
    print(f"Found {len(shard_files)} shard(s)")

    for shard_file in shard_files:
        print(f"  Processing {Path(shard_file).name}...")
        # Explicit encoding so runs don't depend on the locale default.
        with gzip.open(shard_file, "rt", encoding="utf-8") as f:
            for line in f:
                d = json.loads(line)
                labels = d.get("labels", [])
                text = d.get("text", "")
                if not labels:
                    skipped_nolabel += 1
                    continue
                if len(text) < args.min_text_len:
                    skipped_short += 1
                    continue
                cleaned = clean_text(text, args.max_words)
                if not cleaned:
                    continue
                # Multi-label docs keep every label on the FastText line;
                # grouping for balancing uses only the first label.
                label_str = " ".join(f"__label__{l}" for l in labels)
                ft_line = f"{label_str} {cleaned}\n"
                by_label[labels[0]].append(ft_line)
                total += 1

    print(f"\nTotal labeled docs: {total}")
    print(f"Skipped (no label): {skipped_nolabel}")
    print(f"Skipped (too short): {skipped_short}")

    # Raw distribution
    print("\nRaw distribution:")
    for label in sorted(by_label.keys()):
        name = REGISTER_LABELS.get(label, label)
        print(f"  {label} ({name}): {len(by_label[label])}")

    # Balance: oversample minority to median, undersample majority to median.
    sorted_sizes = sorted(len(v) for v in by_label.values())
    target = sorted_sizes[len(sorted_sizes) // 2]
    print(f"\nBalancing target (median): {target}")

    train_lines = []
    test_lines = []
    for label, lines in by_label.items():
        random.shuffle(lines)
        # BUGFIX: honor --test-ratio (it was parsed but a hard-coded 10% was
        # used). Keep at least 50 docs for test, but never consume the whole
        # class so the train pool stays non-empty for classes > 1 doc.
        n_test = max(int(len(lines) * args.test_ratio), 50)
        n_test = min(n_test, max(len(lines) - 1, 0))
        test_lines.extend(lines[:n_test])
        train_pool = lines[n_test:]

        n_train = len(train_pool)
        if not train_pool:
            # BUGFIX: random.choices([], k=n) raises IndexError; a class with
            # no remaining train docs is skipped instead of crashing the run.
            print(f"  {label}: 0 train docs, skipped")
            continue
        if n_train >= target:
            train_lines.extend(random.sample(train_pool, target))
            print(f"  {label}: {n_train} -> {target} (undersampled)")
        else:
            train_lines.extend(train_pool)
            n_needed = target - n_train
            # Sampling WITH replacement: duplicates are expected here.
            train_lines.extend(random.choices(train_pool, k=n_needed))
            print(f"  {label}: {n_train} -> {target} (oversampled +{n_needed})")

    random.shuffle(train_lines)
    random.shuffle(test_lines)

    train_path = output_dir / "train.txt"
    test_path = output_dir / "test.txt"
    with open(train_path, "w", encoding="utf-8") as f:
        f.writelines(train_lines)
    with open(test_path, "w", encoding="utf-8") as f:
        f.writelines(test_lines)
    print(f"\nTrain: {len(train_lines)} -> {train_path}")
    print(f"Test: {len(test_lines)} -> {test_path}")

    # Verify balance by re-counting label tokens in the final training set.
    c = Counter()
    for line in train_lines:
        for tok in line.split():
            if tok.startswith("__label__"):
                c[tok] += 1
    print("\nFinal train label distribution:")
    for l, cnt in c.most_common():
        name = REGISTER_LABELS.get(l.replace("__label__", ""), l)
        print(f"  {l} ({name}): {cnt}")
# Script entry point: run the pipeline only when executed directly,
# not when imported as a module.
if __name__ == "__main__":
    main()