# NOTE: non-Python page residue (hosting-site chrome) removed from the top of this file.
"""
Helper script to intelligently sample a large dataset for training on M2 Mac.
This creates balanced subsets for quick iteration.
"""
import pandas as pd
import argparse
from pathlib import Path
def _load_dataframe(path):
    """Load a CSV, JSON, or JSONL file into a DataFrame.

    Raises:
        ValueError: If the extension is not .csv, .json, or .jsonl.
    """
    p = str(path)
    if p.endswith(".csv"):
        return pd.read_csv(p)
    if p.endswith((".jsonl", ".json")):
        # JSONL files are newline-delimited records; plain JSON is a single document.
        return pd.read_json(p, lines=p.endswith(".jsonl"))
    raise ValueError(f"Unsupported format: {path}")


def _find_label_column(df):
    """Return the first recognized label column name in df, or None."""
    for col in ("label", "target", "class", "is_ai"):
        if col in df.columns:
            return col
    return None


def _save_dataframe(df, path):
    """Write df to path: JSONL for .jsonl, CSV for .csv and any other extension."""
    if str(path).endswith(".jsonl"):
        df.to_json(path, orient="records", lines=True)
    else:
        # .csv and unrecognized extensions both fall back to CSV.
        df.to_csv(path, index=False)


def sample_dataset(input_path: str, output_path: str, n_samples: int, stratify: bool = True):
    """Sample a dataset down to at most n_samples rows, optionally class-balanced.

    Args:
        input_path: Path to the input CSV / JSON / JSONL file.
        output_path: Path to save the sampled dataset (format chosen by extension).
        n_samples: Number of samples to keep.
        stratify: If True and a label column is found, keep classes balanced.

    Raises:
        ValueError: If the input file extension is unsupported.
    """
    print(f"Loading dataset from {input_path}...")
    df = _load_dataframe(input_path)
    print(f"Original dataset size: {len(df):,} samples")

    label_col = _find_label_column(df)
    if label_col:
        print("Class distribution:")
        print(df[label_col].value_counts())

    if stratify and label_col:
        # Split the budget evenly across all classes (not just two); each class
        # contributes at most its full size. random_state fixed for reproducibility.
        n_classes = df[label_col].nunique()
        per_class = n_samples // max(n_classes, 1)
        parts = [
            group.sample(min(len(group), per_class), random_state=42)
            for _, group in df.groupby(label_col)
        ]
        sampled = pd.concat(parts)
        # Integer division (or undersized classes) can leave the quota short;
        # top up with random rows that were not already selected.
        if len(sampled) < n_samples:
            remaining = df[~df.index.isin(sampled.index)]
            needed = n_samples - len(sampled)
            if len(remaining) > 0:
                extra = remaining.sample(min(len(remaining), needed), random_state=42)
                sampled = pd.concat([sampled, extra])
    else:
        sampled = df.sample(min(len(df), n_samples), random_state=42)

    print(f"Sampled dataset size: {len(sampled):,} samples")
    if label_col:
        print("Sampled class distribution:")
        print(sampled[label_col].value_counts())

    # Ensure the destination directory exists before writing.
    output_path = Path(output_path)
    output_path.parent.mkdir(parents=True, exist_ok=True)
    _save_dataframe(sampled, output_path)
    print(f"Saved to {output_path}")
if __name__ == "__main__":
    # CLI entry point: positional input/output paths plus sampling options.
    cli = argparse.ArgumentParser(description="Sample a dataset for training")
    cli.add_argument("input", help="Input dataset path")
    cli.add_argument("output", help="Output dataset path")
    cli.add_argument(
        "-n",
        "--n-samples",
        type=int,
        default=10000,
        help="Number of samples (default: 10000)",
    )
    cli.add_argument(
        "--no-stratify",
        action="store_true",
        help="Don't maintain class balance",
    )
    opts = cli.parse_args()
    # The flag is "--no-stratify", so stratification is the inverted default.
    sample_dataset(opts.input, opts.output, opts.n_samples, stratify=not opts.no_stratify)