File size: 3,296 Bytes
25faba3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
"""
Helper script to intelligently sample a large dataset for training on M2 Mac.
This creates balanced subsets for quick iteration.
"""
import pandas as pd
import argparse
from pathlib import Path

def sample_dataset(input_path: str, output_path: str, n_samples: int, stratify: bool = True):
    """
    Sample a dataset while maintaining class balance.

    Args:
        input_path: Path to input .csv, .json, or .jsonl file.
        output_path: Path to save the sampled dataset (.jsonl writes JSON
            lines; anything else is written as CSV).
        n_samples: Number of samples to keep.
        stratify: If True and a label column is found, keep the class
            distribution balanced.

    Raises:
        ValueError: If the input file extension is not supported.
    """
    input_path = str(input_path)
    print(f"πŸ“– Loading dataset from {input_path}...")

    # Load dataset, dispatching on the file extension.
    if input_path.endswith(".csv"):
        df = pd.read_csv(input_path)
    elif input_path.endswith((".jsonl", ".json")):
        # .jsonl is one JSON record per line; .json is a single document.
        df = pd.read_json(input_path, lines=input_path.endswith(".jsonl"))
    else:
        raise ValueError(f"Unsupported format: {input_path}")

    print(f"πŸ“Š Original dataset size: {len(df):,} samples")

    # Auto-detect the label column from common naming conventions.
    label_col = next(
        (col for col in ("label", "target", "class", "is_ai") if col in df.columns),
        None,
    )

    if label_col:
        print(f"πŸ“ˆ Class distribution:")
        print(df[label_col].value_counts())

    # Sample
    if stratify and label_col:
        # BUG FIX: the original divided by a hard-coded 2, which only balances
        # binary labels — with k classes it could return up to k * n/2 rows.
        # Divide the budget by the actual number of classes instead.
        n_classes = df[label_col].nunique()
        per_class = n_samples // max(n_classes, 1)
        sampled = pd.concat(
            [
                group.sample(min(len(group), per_class), random_state=42)
                for _, group in df.groupby(label_col)
            ]
        )
        # Top up with random leftover rows when some classes were too small
        # (or integer division left a remainder) so we hit n_samples if possible.
        if len(sampled) < n_samples:
            remaining = df[~df.index.isin(sampled.index)]
            needed = n_samples - len(sampled)
            if len(remaining) > 0:
                additional = remaining.sample(min(len(remaining), needed), random_state=42)
                sampled = pd.concat([sampled, additional])
    else:
        sampled = df.sample(min(len(df), n_samples), random_state=42)

    print(f"βœ… Sampled dataset size: {len(sampled):,} samples")
    if label_col:
        print(f"πŸ“ˆ Sampled class distribution:")
        print(sampled[label_col].value_counts())

    # Save, creating parent directories as needed; default format is CSV.
    output_path = Path(output_path)
    output_path.parent.mkdir(parents=True, exist_ok=True)

    if output_path.suffix == ".jsonl":
        sampled.to_json(output_path, orient="records", lines=True)
    else:
        sampled.to_csv(output_path, index=False)

    print(f"πŸ’Ύ Saved to {output_path}")

if __name__ == "__main__":
    # CLI entry point: sample INPUT into OUTPUT, optionally disabling
    # class-balanced (stratified) sampling.
    cli = argparse.ArgumentParser(description="Sample a dataset for training")
    cli.add_argument("input", help="Input dataset path")
    cli.add_argument("output", help="Output dataset path")
    cli.add_argument(
        "-n",
        "--n-samples",
        type=int,
        default=10000,
        help="Number of samples (default: 10000)",
    )
    cli.add_argument(
        "--no-stratify",
        action="store_true",
        help="Don't maintain class balance",
    )
    opts = cli.parse_args()

    sample_dataset(opts.input, opts.output, opts.n_samples, stratify=not opts.no_stratify)