File size: 5,505 Bytes
3dea709
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
"""
Prepare balanced FastText training data from TurkuNLP/register_oscar dataset.

Downloads English shards, extracts labeled documents, and creates a balanced
training set by oversampling minority classes and undersampling majority classes
to the median class size.

Requirements:
    pip install huggingface_hub

Usage:
    # Download shards first:
    for i in $(seq 0 9); do
        hf download TurkuNLP/register_oscar \
            $(printf "en/en_%05d.jsonl.gz" $i) \
            --repo-type dataset --local-dir ./data
    done

    # Then run:
    python prepare_data.py --data-dir ./data/en --output-dir ./prepared
"""

import json
import gzip
import re
import random
import glob
import argparse
from collections import Counter, defaultdict
from pathlib import Path


# Short register codes (the values appearing in each document's "labels"
# field) mapped to human-readable names. Used only for pretty-printing the
# class distributions below — classification itself uses the raw codes.
# NOTE(review): presumably the main categories of the TurkuNLP register
# taxonomy — confirm against the register_oscar dataset card.
REGISTER_LABELS = {
    "IN": "Informational",
    "NA": "Narrative",
    "OP": "Opinion",
    "IP": "Persuasion",
    "HI": "HowTo",
    "ID": "Discussion",
    "SP": "Spoken",
    "LY": "Lyrical",
}


def clean_text(text: str, max_words: int = 500) -> str:
    """Collapse all whitespace runs to single spaces and truncate to max_words.

    Args:
        text: Raw document text (may contain newlines, tabs, repeated spaces).
        max_words: Maximum number of whitespace-separated tokens to keep.

    Returns:
        The normalized, truncated text; "" for empty/whitespace-only input.
    """
    # str.split() with no argument already splits on arbitrary whitespace runs
    # and drops leading/trailing whitespace, so the previous re.sub(r"\s+", " ")
    # pre-pass was a redundant extra scan of the text — output is identical.
    return " ".join(text.split()[:max_words])


def _parse_args() -> argparse.Namespace:
    """Parse and return the command-line arguments."""
    parser = argparse.ArgumentParser(description="Prepare balanced FastText training data")
    parser.add_argument("--data-dir", default="./data/en", help="Directory with .jsonl.gz shards")
    parser.add_argument("--output-dir", default="./prepared", help="Output directory for train/test files")
    parser.add_argument("--max-words", type=int, default=500, help="Max words per document")
    parser.add_argument("--min-text-len", type=int, default=50, help="Min character length to keep")
    parser.add_argument("--test-ratio", type=float, default=0.1, help="Fraction held out for test")
    parser.add_argument("--seed", type=int, default=42)
    return parser.parse_args()


def _collect_docs(data_dir: str, min_text_len: int, max_words: int):
    """Read every .jsonl.gz shard and group FastText-formatted lines by primary label.

    Returns:
        (by_label, total, skipped_nolabel, skipped_short) where by_label maps
        a document's first label to its formatted training lines.

    Raises:
        FileNotFoundError: if data_dir contains no .jsonl.gz shards.
    """
    by_label = defaultdict(list)
    total = 0
    skipped_nolabel = 0
    skipped_short = 0

    shard_files = sorted(glob.glob(f"{data_dir}/*.jsonl.gz"))
    if not shard_files:
        raise FileNotFoundError(f"No .jsonl.gz files found in {data_dir}")

    print(f"Found {len(shard_files)} shard(s)")

    for shard_file in shard_files:
        print(f"  Processing {Path(shard_file).name}...")
        # BUG FIX: explicit encoding. In text mode gzip.open otherwise uses the
        # locale's preferred encoding, which corrupts UTF-8 data on non-UTF-8
        # systems.
        with gzip.open(shard_file, "rt", encoding="utf-8") as f:
            for line in f:
                d = json.loads(line)
                labels = d.get("labels", [])
                text = d.get("text", "")

                if not labels:
                    skipped_nolabel += 1
                    continue
                if len(text) < min_text_len:
                    skipped_short += 1
                    continue

                cleaned = clean_text(text, max_words)
                if not cleaned:
                    continue

                # Multi-label FastText format: every label goes on the line,
                # but the doc is grouped under its first (primary) label.
                label_str = " ".join(f"__label__{l}" for l in labels)
                by_label[labels[0]].append(f"{label_str} {cleaned}\n")
                total += 1

    return by_label, total, skipped_nolabel, skipped_short


def _split_and_balance(by_label, target: int, test_ratio: float):
    """Hold out a per-class test slice, then resample each train pool to `target`.

    Classes larger than target are undersampled without replacement; smaller
    ones keep everything and are topped up by sampling with replacement.

    Returns:
        (train_lines, test_lines) — shuffled per class, not globally.
    """
    train_lines = []
    test_lines = []

    for label, lines in by_label.items():
        random.shuffle(lines)

        # BUG FIX: honor --test-ratio (previously hard-coded to len//10, i.e.
        # 10%, regardless of the flag). Keep at least 50 test docs when the
        # class is big enough, but always leave >=1 doc for training so the
        # oversampling branch below has something to draw from.
        n_test = max(int(len(lines) * test_ratio), 50)
        n_test = min(n_test, max(len(lines) - 1, 0))
        test_pool = lines[:n_test]
        train_pool = lines[n_test:]

        test_lines.extend(test_pool)
        n_train = len(train_pool)

        if n_train >= target:
            train_lines.extend(random.sample(train_pool, target))
            print(f"  {label}: {n_train} -> {target} (undersampled)")
        else:
            train_lines.extend(train_pool)
            n_needed = target - n_train
            # Sampling with replacement: duplicates are acceptable here.
            train_lines.extend(random.choices(train_pool, k=n_needed))
            print(f"  {label}: {n_train} -> {target} (oversampled +{n_needed})")

    return train_lines, test_lines


def main():
    """CLI entry point: collect labeled docs, balance classes, write train/test files."""
    args = _parse_args()
    random.seed(args.seed)
    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    by_label, total, skipped_nolabel, skipped_short = _collect_docs(
        args.data_dir, args.min_text_len, args.max_words
    )

    print(f"\nTotal labeled docs: {total}")
    print(f"Skipped (no label): {skipped_nolabel}")
    print(f"Skipped (too short): {skipped_short}")

    print("\nRaw distribution:")
    for label in sorted(by_label):
        name = REGISTER_LABELS.get(label, label)
        print(f"  {label} ({name}): {len(by_label[label])}")

    # Balance target: the (upper) median class size. Majority classes are
    # undersampled down to it, minority classes oversampled up to it.
    sorted_sizes = sorted(len(v) for v in by_label.values())
    target = sorted_sizes[len(sorted_sizes) // 2]
    print(f"\nBalancing target (median): {target}")

    train_lines, test_lines = _split_and_balance(by_label, target, args.test_ratio)

    random.shuffle(train_lines)
    random.shuffle(test_lines)

    train_path = output_dir / "train.txt"
    test_path = output_dir / "test.txt"

    # BUG FIX: explicit UTF-8 so output doesn't depend on the locale default.
    with open(train_path, "w", encoding="utf-8") as f:
        f.writelines(train_lines)
    with open(test_path, "w", encoding="utf-8") as f:
        f.writelines(test_lines)

    print(f"\nTrain: {len(train_lines)} -> {train_path}")
    print(f"Test:  {len(test_lines)} -> {test_path}")

    # Sanity check: recount every label token actually written to train.txt.
    c = Counter()
    for line in train_lines:
        for tok in line.split():
            if tok.startswith("__label__"):
                c[tok] += 1
    print("\nFinal train label distribution:")
    for l, cnt in c.most_common():
        name = REGISTER_LABELS.get(l.replace("__label__", ""), l)
        print(f"  {l} ({name}): {cnt}")


# Run the CLI only when executed as a script, not when imported as a module.
if __name__ == "__main__":
    main()