# leo-ai / scripts / slim_dataset.py
# JacobBirger — Deploy Leo to HuggingFace Spaces (0d443b6)
"""
slim_dataset.py — Reduce goEmotionDataset.csv to only the rows actually used
by emotion_analyzer.py, eliminating the Git LFS requirement.

emotion_analyzer.py caps centroid building at MAX_SAMPLES=300 per emotion (seed 42).
This script pre-applies that same sampling and takes the union of all selected rows,
so the CSV never needs more than ~8,400 rows (28 emotions x 300) instead of 57,732.

The CSV is overwritten in place. Run once from the repo root:

    python scripts/slim_dataset.py
"""
import os
import random
import pandas as pd
# Upper bound on the number of rows sampled for any single emotion.
MAX_SAMPLES = 300

# Seed for the random sampling, so repeated runs select the same rows.
SEED = 42

# Label columns looked up in the CSV, in the order they are sampled.
EMOTION_LABELS = [
    "admiration",
    "amusement",
    "anger",
    "annoyance",
    "approval",
    "caring",
    "confusion",
    "curiosity",
    "desire",
    "disappointment",
    "disapproval",
    "disgust",
    "embarrassment",
    "excitement",
    "fear",
    "gratitude",
    "grief",
    "joy",
    "love",
    "nervousness",
    "optimism",
    "pride",
    "realization",
    "relief",
    "remorse",
    "sadness",
    "surprise",
    "neutral",
]
def select_row_indices(df, labels, max_samples, seed) -> list[int]:
    """Return the sorted union of row indices sampled for each label column.

    For every label in *labels* that is a column of *df*, the rows where that
    column equals 1 are collected; if there are more than *max_samples* of
    them, a random subset of exactly *max_samples* is drawn.  Labels that are
    not columns of *df* are reported and skipped.

    Args:
        df: DataFrame with one indicator column per label.
        labels: Label column names, processed in the given order.
        max_samples: Maximum number of rows kept per label.
        seed: Seed for the sampling RNG.

    Returns:
        Sorted list of the distinct index values selected across all labels
        (empty if no label column matched any row).
    """
    # A private generator seeded once yields the same draws as
    # random.seed(seed) followed by random.sample(...), without reseeding
    # the process-wide RNG as a side effect.
    rng = random.Random(seed)
    selected: set[int] = set()
    for emotion in labels:
        if emotion not in df.columns:
            print(f" [warn] column '{emotion}' not found — skipping")
            continue
        indices = df.index[df[emotion] == 1].tolist()
        if len(indices) > max_samples:
            indices = rng.sample(indices, max_samples)
        selected.update(indices)
        print(f" {emotion}: kept {len(indices)} rows")
    return sorted(selected)


def main() -> None:
    """Slim goEmotionDataset.csv in place and report the resulting size.

    Reads the CSV located one directory above this script, keeps only the
    rows chosen by select_row_indices(), writes the result back to the same
    path and prints whether the file is now under 10 MB.

    Raises:
        SystemExit: If no row was selected (empty CSV or none of the label
            columns present); the file is left untouched in that case.
    """
    root = os.path.join(os.path.dirname(__file__), "..")
    csv_path = os.path.join(root, "goEmotionDataset.csv")

    print(f"Reading {csv_path} ...")
    df = pd.read_csv(csv_path)
    original_rows = len(df)
    print(f"Original: {original_rows} rows, {os.path.getsize(csv_path) / 1e6:.2f} MB")

    keep = select_row_indices(df, EMOTION_LABELS, MAX_SAMPLES, SEED)
    if not keep:
        # Refuse to replace the dataset with a header-only file; this also
        # guarantees original_rows > 0 for the percentage printed below.
        raise SystemExit(
            f"ERROR: no rows selected from {csv_path} - file left unchanged"
        )

    slim = df.loc[keep].reset_index(drop=True)
    slim.to_csv(csv_path, index=False)

    new_size = os.path.getsize(csv_path)
    print(f"\nSlim dataset: {len(slim)} rows (was {original_rows}, "
          f"{len(slim)/original_rows*100:.1f}% kept)")
    print(f"File size: {new_size / 1e6:.2f} MB")
    if new_size < 10 * 1e6:
        print("OK: Under 10 MB - safe to commit without Git LFS")
    else:
        # A smaller per-emotion cap keeps fewer rows; raising it would only
        # make the file larger.
        print("WARN: Still over 10 MB - lower MAX_SAMPLES or drop metadata columns")


if __name__ == "__main__":
    main()