|
|
""" |
|
|
Seed and split dataset for training and evaluation. |
|
|
|
|
|
- Reads: data/samples/training_data.json |
|
|
- Writes: |
|
|
- data/samples/train.json |
|
|
- data/samples/eval.json |
|
|
|
|
|
This script enforces: |
|
|
- Stratified split by label |
|
|
- Deterministic output (fixed random seed) |
|
|
- Basic data validation |
|
|
""" |
|
|
|
|
|
import json
import random
from collections import Counter, defaultdict
from pathlib import Path
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Seed for the shuffling RNG; a fixed value keeps the split reproducible.
RANDOM_SEED = 42
# Fraction of each label's samples that goes to the training set.
TRAIN_RATIO = 0.7

# Parent of the directory that contains this script.
BASE_DIR = Path(__file__).resolve().parent.parent
SAMPLES_DIR = BASE_DIR.joinpath("data", "samples")

# Input dataset and the two files produced by the split.
SOURCE_FILE = SAMPLES_DIR.joinpath("training_data.json")
TRAIN_FILE = SAMPLES_DIR.joinpath("train.json")
EVAL_FILE = SAMPLES_DIR.joinpath("eval.json")
|
|
|
|
|
|
|
|
def main():
    """Split the source dataset into stratified train and eval files.

    Loads ``SOURCE_FILE``, validates every sample, splits each label's
    samples according to ``TRAIN_RATIO`` with an RNG seeded from
    ``RANDOM_SEED``, writes the results to ``TRAIN_FILE`` and ``EVAL_FILE``
    and prints a summary with per-label counts.

    Raises:
        FileNotFoundError: If ``SOURCE_FILE`` does not exist.
        ValueError: If the dataset is not a non-empty list, or a sample is
            not a JSON object with "text" and "label" keys, or a label
            cannot be used as a grouping key.
    """
    data = _load_dataset(SOURCE_FILE)
    _validate_samples(data)

    train_data, eval_data = _stratified_split(data, TRAIN_RATIO, RANDOM_SEED)

    SAMPLES_DIR.mkdir(parents=True, exist_ok=True)
    _write_json(TRAIN_FILE, train_data)
    _write_json(EVAL_FILE, eval_data)

    print("====================================")
    print("Dataset seeding completed")
    print("====================================")
    print(f"Total samples : {len(data)}")
    print(f"Train samples : {len(train_data)}")
    print(f"Eval samples : {len(eval_data)}")
    print()

    print("Label distribution (train):")
    _print_distribution(train_data)

    print("\nLabel distribution (eval):")
    _print_distribution(eval_data)


def _load_dataset(path):
    """Read the JSON dataset at *path* and check it is a non-empty list."""
    if not path.exists():
        raise FileNotFoundError(f"Source dataset not found: {path}")

    with open(path, "r", encoding="utf-8") as f:
        data = json.load(f)

    if not isinstance(data, list) or not data:
        raise ValueError("Dataset must be a non-empty list")
    return data


def _validate_samples(data):
    """Raise ValueError for the first sample that cannot be split by label."""
    for i, item in enumerate(data):
        # Check the type first: ``in`` on a str is a substring test and on
        # an int raises TypeError, neither of which is a key lookup.
        if not isinstance(item, dict) or "text" not in item or "label" not in item:
            raise ValueError(f"Invalid sample at index {i}: {item}")
        try:
            # Labels become dict keys when grouping, so they must be hashable.
            hash(item["label"])
        except TypeError:
            raise ValueError(f"Invalid sample at index {i}: {item}") from None


def _stratified_split(data, train_ratio, seed):
    """Return ``(train, eval)`` lists, splitting each label by *train_ratio*."""
    # A private generator yields the same sequence as seeding the module-level
    # one with the same value, without reseeding the process-wide RNG.
    rng = random.Random(seed)

    by_label = defaultdict(list)
    for item in data:
        by_label[item["label"]].append(item)

    train_data = []
    eval_data = []
    for items in by_label.values():
        rng.shuffle(items)
        # Every label contributes at least one training sample.
        split_idx = max(1, int(len(items) * train_ratio))
        train_data.extend(items[:split_idx])
        eval_data.extend(items[split_idx:])

    # Interleave the labels so the output is not grouped by label.
    rng.shuffle(train_data)
    rng.shuffle(eval_data)
    return train_data, eval_data


def _write_json(path, samples):
    """Write *samples* to *path* as indented UTF-8 JSON."""
    with open(path, "w", encoding="utf-8") as f:
        json.dump(samples, f, indent=2, ensure_ascii=False)
|
|
|
|
|
|
|
|
def _print_distribution(dataset):
    """Print how many samples in *dataset* carry each label.

    One line is printed per distinct label: the label left-aligned in a
    20-character column, followed by its count. Labels are listed in sorted
    order.

    Args:
        dataset: Iterable of mappings, each providing a "label" key.
    """
    counts = Counter(item["label"] for item in dataset)

    try:
        ordered = sorted(counts.items())
    except TypeError:
        # Labels of different types (e.g. str next to int or None) cannot be
        # compared with each other; order them by type name, then by text.
        ordered = sorted(
            counts.items(),
            key=lambda pair: (type(pair[0]).__name__, str(pair[0])),
        )

    for label, count in ordered:
        # ``!s`` converts to str before padding: None rejects a format spec
        # and bool would otherwise be rendered as 1/0.
        print(f" {label!s:<20} {count}")
|
|
|
|
|
|
|
|
if __name__ == "__main__": |
|
|
main() |
|
|
|