"""
Thoroughly shuffle the dataset while maintaining class distributions and data integrity.
This script implements stratified shuffling to ensure balanced representation of classes
and languages in the shuffled data.
"""

# Standard library imports.
import argparse
import json
import logging
import sys
from collections import defaultdict
from datetime import datetime
from pathlib import Path
from typing import Dict, List, Tuple

# Third-party imports.
import numpy as np
import pandas as pd
from sklearn.model_selection import StratifiedKFold
|
|
| |
# Ensure the log directory exists BEFORE configuring the FileHandler below:
# the handler opens its file immediately at import time, so a missing logs/
# directory raised FileNotFoundError here (main() creates the directory, but
# only after this module-level code has already run).
Path('logs').mkdir(exist_ok=True)

# Log to both stdout and a timestamped file under logs/.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[
        logging.StreamHandler(sys.stdout),
        logging.FileHandler(f'logs/shuffle_{datetime.now().strftime("%Y%m%d_%H%M%S")}.log')
    ]
)
logger = logging.getLogger(__name__)
|
|
def create_stratification_label(row: pd.Series, toxicity_labels: List[str]) -> str:
    """
    Build a composite stratification key for one row.

    The key combines the row's language with a bit-string holding one
    '1'/'0' character per toxicity label (in the order given), e.g.
    'en_100010'. NaN label values compare unequal to 1 and map to '0'.
    """
    flags = []
    for label in toxicity_labels:
        flags.append('1' if row[label] == 1 else '0')
    return '{}_{}'.format(row['lang'], ''.join(flags))
|
|
def validate_data(df: pd.DataFrame, toxicity_labels: List[str]) -> bool:
    """
    Check the dataset for required columns and basic data integrity.

    Verifies that 'comment_text', 'lang' and every toxicity label column are
    present, warns about null values and empty comments, and rejects label
    columns containing anything other than 0, 1 or NaN.

    Returns True when validation passes; False otherwise (the failure reason
    is logged rather than raised to the caller).
    """
    try:
        required = ['comment_text', 'lang'] + toxicity_labels
        absent = [column for column in required if column not in df.columns]
        if absent:
            raise ValueError(f"Missing required columns: {absent}")

        # Nulls are tolerated but surfaced, since they may break stratification.
        nulls = df[required].isnull().sum()
        if nulls.any():
            logger.warning(f"Found null values:\n{nulls[nulls > 0]}")

        # Label columns must be strictly binary (NaN allowed, warned above).
        for label in toxicity_labels:
            bad = df.loc[~df[label].isin([0, 1, np.nan]), label]
            if not bad.empty:
                raise ValueError(f"Found non-binary values in {label}: {bad.unique()}")

        # Zero-length comments are suspicious but not fatal.
        if df['comment_text'].str.len().min() == 0:
            logger.warning("Found empty comments in dataset")

        return True

    except Exception as e:
        logger.error(f"Data validation failed: {str(e)}")
        return False
|
|
def _class_stats(frame: pd.DataFrame, toxicity_labels: List[str]) -> Dict:
    """Per-label positive/negative counts and positive ratio for *frame*."""
    return {
        label: {
            'positive': int(frame[label].sum()),
            'negative': int(len(frame) - frame[label].sum()),
            'ratio': float(frame[label].mean())
        }
        for label in toxicity_labels
    }


def analyze_distribution(df: pd.DataFrame, toxicity_labels: List[str]) -> Dict:
    """
    Analyze the class distribution and language distribution in the dataset.

    Returns a JSON-serializable dict with the total sample count, per-language
    row counts, per-label class balance over the whole frame, and the same
    per-label balance broken down by language.
    """
    return {
        'total_samples': len(df),
        'language_distribution': df['lang'].value_counts().to_dict(),
        'class_distribution': _class_stats(df, toxicity_labels),
        # Same stats restricted to each language's rows (dedupes the previous
        # copy-pasted comprehension via _class_stats).
        'language_class_distribution': {
            lang: _class_stats(df[df['lang'] == lang], toxicity_labels)
            for lang in df['lang'].unique()
        }
    }
|
|
def shuffle_dataset(
    input_file: str,
    output_file: str,
    toxicity_labels: List[str],
    n_splits: int = 10,
    random_state: int = 42
) -> Tuple[bool, Dict]:
    """
    Thoroughly shuffle the dataset while maintaining class distributions.

    Rows get a composite language+toxicity stratification label and are split
    with StratifiedKFold; the folds are concatenated, with rows shuffled
    inside each fold, to produce the final order.

    Args:
        input_file: path of the CSV to read.
        output_file: path the shuffled CSV is written to; before/after stats
            are saved to 'shuffle_stats.json' in the same directory.
        toxicity_labels: names of the binary toxicity label columns.
        n_splits: requested number of stratified folds (default 10).
        random_state: seed for both the fold split and the in-fold shuffle.

    Returns:
        (True, stats dict) on success, (False, {}) on any failure; errors are
        logged rather than raised.
    """
    try:
        logger.info(f"Loading dataset from {input_file}")
        df = pd.read_csv(input_file)

        if not validate_data(df, toxicity_labels):
            return False, {}

        initial_stats = analyze_distribution(df, toxicity_labels)
        logger.info("Initial distribution stats:")
        logger.info(json.dumps(initial_stats, indent=2))

        logger.info("Creating stratification labels")
        df['strat_label'] = df.apply(
            lambda row: create_stratification_label(row, toxicity_labels),
            axis=1
        )

        # StratifiedKFold raises ValueError when any class has fewer members
        # than n_splits. Fold rare label combinations into the most common
        # label so sparse language/label combos cannot crash the run (their
        # placement becomes effectively random, which is fine for shuffling).
        counts = df['strat_label'].value_counts()
        rare_labels = counts.index[counts < n_splits]
        if len(rare_labels) > 0:
            logger.warning(
                f"Merging {len(rare_labels)} rare stratification labels "
                f"(< {n_splits} samples each) into the most common label"
            )
            df.loc[df['strat_label'].isin(rare_labels), 'strat_label'] = counts.idxmax()

        # Clamp the fold count to what the (possibly merged) classes support.
        effective_splits = max(2, min(n_splits, int(df['strat_label'].value_counts().min())))

        skf = StratifiedKFold(
            n_splits=effective_splits,
            shuffle=True,
            random_state=random_state
        )

        logger.info(f"Performing stratified shuffling with {effective_splits} splits")
        rng = np.random.RandomState(random_state)
        all_indices = []
        for _, fold_indices in skf.split(df, df['strat_label']):
            # skf yields each fold's test indices in ascending order, which
            # would leave rows inside a fold in their ORIGINAL relative
            # order; shuffle within the fold for a genuinely random order.
            rng.shuffle(fold_indices)
            all_indices.extend(fold_indices)

        shuffled_df = df.iloc[all_indices].copy()
        shuffled_df = shuffled_df.drop('strat_label', axis=1)

        final_stats = analyze_distribution(shuffled_df, toxicity_labels)

        logger.info(f"Saving shuffled dataset to {output_file}")
        shuffled_df.to_csv(output_file, index=False)

        # Persist before/after stats next to the shuffled CSV for auditing.
        stats_file = Path(output_file).parent / 'shuffle_stats.json'
        stats = {
            'initial': initial_stats,
            'final': final_stats,
            'shuffle_params': {
                'n_splits': n_splits,
                'effective_n_splits': effective_splits,
                'random_state': random_state
            }
        }
        with open(stats_file, 'w') as f:
            json.dump(stats, f, indent=2)

        logger.info(f"Shuffling complete. Statistics saved to {stats_file}")
        return True, stats

    except Exception as e:
        logger.error(f"Error shuffling dataset: {str(e)}")
        return False, {}
|
|
def main():
    """Command-line entry point: parse arguments and run the shuffle."""
    parser = argparse.ArgumentParser(description='Thoroughly shuffle the dataset.')
    parser.add_argument('--input', type=str, required=True,
                        help='Input CSV file path')
    parser.add_argument('--output', type=str, required=True,
                        help='Output CSV file path')
    parser.add_argument('--splits', type=int, default=10,
                        help='Number of splits for stratified shuffling (default: 10)')
    parser.add_argument('--seed', type=int, default=42,
                        help='Random seed (default: 42)')
    opts = parser.parse_args()

    # Make sure the output and log directories exist before anything writes.
    Path(opts.output).parent.mkdir(parents=True, exist_ok=True)
    Path('logs').mkdir(exist_ok=True)

    # Jigsaw-style binary toxicity label columns expected in the CSV.
    toxicity_labels = [
        'toxic', 'severe_toxic', 'obscene', 'threat',
        'insult', 'identity_hate'
    ]

    success, stats = shuffle_dataset(
        opts.input,
        opts.output,
        toxicity_labels,
        opts.splits,
        opts.seed
    )

    if not success:
        logger.error("Dataset shuffling failed")
        sys.exit(1)

    logger.info("Dataset shuffling completed successfully")
    # Summarize the post-shuffle class balance for a quick sanity check.
    for label, dist in stats['final']['class_distribution'].items():
        logger.info(f"{label}: {dist['ratio']:.3f} "
                    f"(+:{dist['positive']}, -:{dist['negative']})")
|
|
# Script entry point.
if __name__ == '__main__':
    main()