"""
Seed and split dataset for training and evaluation.

- Reads: data/samples/training_data.json
- Writes:
    - data/samples/train.json
    - data/samples/eval.json

This script enforces:
- Stratified split by label
- Deterministic output (fixed random seed)
- Basic data validation
"""

import json
import random
from pathlib import Path
from collections import defaultdict

# -------------------------
# Configuration
# -------------------------
RANDOM_SEED = 42
TRAIN_RATIO = 0.7

BASE_DIR = Path(__file__).resolve().parent.parent
SAMPLES_DIR = BASE_DIR / "data" / "samples"
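# NOTE: BASE_DIR assumes this file lives one directory below the project root
# (e.g. a scripts/ folder); adjust the .parent chain if the script moves.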

SOURCE_FILE = SAMPLES_DIR / "training_data.json"
TRAIN_FILE = SAMPLES_DIR / "train.json"
EVAL_FILE = SAMPLES_DIR / "eval.json"


def main():
    if not SOURCE_FILE.exists():
        raise FileNotFoundError(f"Source dataset not found: {SOURCE_FILE}")

    with open(SOURCE_FILE, "r", encoding="utf-8") as f:
        data = json.load(f)

    if not isinstance(data, list) or len(data) == 0:
        raise ValueError("Dataset must be a non-empty list")

    # -------------------------
    # Basic validation
    # -------------------------
    for i, item in enumerate(data):
        if "text" not in item or "label" not in item:
            raise ValueError(f"Invalid sample at index {i}: {item}")

    # -------------------------
    # Stratified split
    # -------------------------
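    # Seed once up front so every shuffle below is reproducible run-to-run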
    random.seed(RANDOM_SEED)

    by_label = defaultdict(list)
    for item in data:
        by_label[item["label"]].append(item)

    train_data = []
    eval_data = []

    for label, items in by_label.items():
        random.shuffle(items)

        split_idx = max(1, int(len(items) * TRAIN_RATIO))
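        # max(1, ...) guarantees at least one training sample per label; a
        # label with a single sample therefore contributes nothing to eval.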

        train_data.extend(items[:split_idx])
        eval_data.extend(items[split_idx:])

    # Final shuffle so the output files are not grouped by label
    random.shuffle(train_data)
    random.shuffle(eval_data)

    # -------------------------
    # Write outputs
    # -------------------------
    SAMPLES_DIR.mkdir(parents=True, exist_ok=True)

    with open(TRAIN_FILE, "w", encoding="utf-8") as f:
        json.dump(train_data, f, indent=2, ensure_ascii=False)

    with open(EVAL_FILE, "w", encoding="utf-8") as f:
        json.dump(eval_data, f, indent=2, ensure_ascii=False)

    # -------------------------
    # Summary
    # -------------------------
    print("====================================")
    print("Dataset seeding completed")
    print("====================================")
    print(f"Total samples : {len(data)}")
    print(f"Train samples : {len(train_data)}")
    print(f"Eval samples  : {len(eval_data)}")
    print()

    print("Label distribution (train):")
    _print_distribution(train_data)

    print("\nLabel distribution (eval):")
    _print_distribution(eval_data)


def _print_distribution(dataset):
    dist = defaultdict(int)
    for item in dataset:
        dist[item["label"]] += 1

    for label, count in sorted(dist.items()):
        print(f"  {label:<20} {count}")


if __name__ == "__main__":
    main()