Spaces:
Running
Running
"""Reduce goEmotionDataset.csv to only the rows used by emotion_analyzer.py.

slim_dataset.py — shrinks the dataset so it can be committed without Git LFS.

emotion_analyzer.py caps centroid building at MAX_SAMPLES=300 per emotion (seed 42).
This script pre-applies that same sampling and takes the union of all selected rows,
so the CSV never needs more than ~8,400 rows instead of 57,732.

The CSV is overwritten in place. Running the script again re-samples the
already-slimmed file, so run it only once.

Run once from the repo root:
    python scripts/slim_dataset.py
"""
import os
import random

import pandas as pd

MAX_SAMPLES = 300
SEED = 42
# Size below which the file is reported as safe to commit without Git LFS.
SIZE_LIMIT_BYTES = 10_000_000
EMOTION_LABELS = [
    "admiration", "amusement", "anger", "annoyance", "approval", "caring",
    "confusion", "curiosity", "desire", "disappointment", "disapproval",
    "disgust", "embarrassment", "excitement", "fear", "gratitude", "grief",
    "joy", "love", "nervousness", "optimism", "pride", "realization",
    "relief", "remorse", "sadness", "surprise", "neutral",
]


def select_row_indices(
    df: pd.DataFrame,
    labels: list[str],
    max_samples: int = MAX_SAMPLES,
    seed: int = SEED,
) -> list[int]:
    """Return the sorted union of row labels sampled for every emotion column.

    For each name in *labels* that is a column of *df*, the rows where that
    column equals 1 are collected; if there are more than *max_samples* of
    them, *max_samples* are drawn at random. Labels that are not columns of
    *df* are reported and skipped.

    Args:
        df: Frame holding one indicator column per emotion.
        labels: Emotion column names, processed in the given order.
        max_samples: Per-emotion cap on the number of rows kept.
        seed: Seed for the sampling generator.

    Returns:
        Sorted index labels of every row selected for at least one emotion.
    """
    # One generator seeded once and shared across all emotions: this yields the
    # same draw sequence as calling random.seed(seed) before the loop, without
    # touching the process-wide random state.
    rng = random.Random(seed)
    keep: set[int] = set()
    for emotion in labels:
        if emotion not in df.columns:
            print(f" [warn] column '{emotion}' not found — skipping")
            continue
        indices = df.index[df[emotion] == 1].tolist()
        if len(indices) > max_samples:
            indices = rng.sample(indices, max_samples)
        keep.update(indices)
        print(f" {emotion}: kept {len(indices)} rows")
    return sorted(keep)


def size_verdict(size_bytes: float) -> str:
    """Return the status line saying whether *size_bytes* is under the 10 MB limit."""
    if size_bytes < SIZE_LIMIT_BYTES:
        return "OK: Under 10 MB - safe to commit without Git LFS"
    # Fewer samples per emotion means fewer rows, hence a smaller file.
    return "WARN: Still over 10 MB - lower MAX_SAMPLES cutoff or drop metadata columns"


def main() -> None:
    """Slim goEmotionDataset.csv in place and print a before/after summary."""
    root = os.path.join(os.path.dirname(__file__), "..")
    csv_path = os.path.join(root, "goEmotionDataset.csv")

    print(f"Reading {csv_path} ...")
    df = pd.read_csv(csv_path)
    original_rows = len(df)
    print(f"Original: {original_rows} rows, {os.path.getsize(csv_path) / 1e6:.2f} MB")

    keep_indices = select_row_indices(df, EMOTION_LABELS)
    if not keep_indices:
        # Nothing matched (wrong file, renamed columns, empty CSV). Writing the
        # result would replace the dataset with an empty file, so stop here.
        raise SystemExit(
            f"No rows selected from {csv_path}; refusing to overwrite it."
        )

    slim = df.loc[keep_indices].reset_index(drop=True)
    slim.to_csv(csv_path, index=False)

    new_size = os.path.getsize(csv_path)
    kept_pct = len(slim) / original_rows * 100 if original_rows else 0.0
    print(f"\nSlim dataset: {len(slim)} rows (was {original_rows}, "
          f"{kept_pct:.1f}% kept)")
    print(f"File size: {new_size / 1e6:.2f} MB")
    print(size_verdict(new_size))


if __name__ == "__main__":
    main()