bbc-document-classifier/src/data/split_data.py
pearlll's picture
Deploy document classifier app
492754f
Raw
History Blame Contribute Delete
870 Bytes
import os
import pandas as pd
from sklearn.model_selection import train_test_split
RANDOM_STATE = 42

# Paths are relative to the current working directory, so run this script
# from the project root (the directory that contains ``data/``).
PROCESSED_PATH = "data/processed/processed_bbc.csv"
SPLITS_DIR = "data/splits"
LABEL_COLUMN = "label_text"


def split_dataframe(
    df: pd.DataFrame,
    label_col: str = LABEL_COLUMN,
    val_size: float = 0.15,
    test_size: float = 0.15,
    random_state: int = RANDOM_STATE,
) -> tuple:
    """Split ``df`` into stratified train / validation / test partitions.

    The split is done in two stages with ``train_test_split``. First, the
    combined hold-out (``val_size + test_size`` of ``df``) is carved off from
    the training rows. Then the hold-out is divided between validation and
    test. Both stages stratify on ``label_col``, so each partition
    approximately keeps the class proportions of ``df``.

    With the defaults this performs exactly the original two calls
    (``test_size=0.30`` and then ``test_size=0.50`` with seed 42). That
    yields a roughly 70/15/15 split.

    Args:
        df: DataFrame to split; must contain ``label_col``.
        label_col: Column used for stratification.
        val_size: Fraction of ``df`` assigned to the validation split.
        test_size: Fraction of ``df`` assigned to the test split.
        random_state: Seed passed to both ``train_test_split`` calls so the
            split is reproducible.

    Returns:
        A ``(train_df, val_df, test_df)`` tuple of DataFrames.

    Raises:
        ValueError: If either fraction is not positive, or if they sum to 1
            or more (nothing would be left for training).
        KeyError: If ``label_col`` is not a column of ``df`` (raised by
            pandas when the column is looked up).
    """
    if val_size <= 0 or test_size <= 0 or val_size + test_size >= 1:
        raise ValueError(
            "val_size and test_size must be positive and sum to less than 1, "
            f"got val_size={val_size}, test_size={test_size}"
        )

    holdout_size = val_size + test_size
    train_df, temp_df = train_test_split(
        df,
        test_size=holdout_size,
        random_state=random_state,
        stratify=df[label_col],
    )
    # Share of the hold-out that becomes the test split (0.5 with defaults).
    val_df, test_df = train_test_split(
        temp_df,
        test_size=test_size / holdout_size,
        random_state=random_state,
        stratify=temp_df[label_col],
    )
    return train_df, val_df, test_df


def main() -> None:
    """Read the processed BBC CSV, split it, and write train/val/test CSVs.

    Also prints the shape of each split and the training-set class counts.
    """
    df = pd.read_csv(PROCESSED_PATH)
    train_df, val_df, test_df = split_dataframe(df)

    os.makedirs(SPLITS_DIR, exist_ok=True)
    for name, split in (("train", train_df), ("val", val_df), ("test", test_df)):
        split.to_csv(os.path.join(SPLITS_DIR, f"{name}.csv"), index=False)

    print("Train/Validation/Test split completed.")
    print("Train shape:", train_df.shape)
    print("Validation shape:", val_df.shape)
    print("Test shape:", test_df.shape)
    print("\nTrain class distribution:")
    print(train_df[LABEL_COLUMN].value_counts())


# Only run the pipeline when executed as a script, so importing this module
# (e.g. to reuse split_dataframe) does not read or overwrite any files.
if __name__ == "__main__":
    main()