| import os |
| import torch |
|
|
| |
# CPU performance tuning for an 80-core host.
# NOTE(review): the TF_* variables only affect TensorFlow, which is not
# imported in this file — presumably consumed by a downstream import
# (e.g. inside toxic_augment); confirm.
os.environ['TF_ENABLE_ONEDNN_OPTS'] = '1'
os.environ['TF_CPU_ENABLE_AVX2'] = '1'
os.environ['TF_CPU_ENABLE_AVX512F'] = '1'
os.environ['TF_CPU_ENABLE_AVX512_VNNI'] = '1'
os.environ['TF_CPU_ENABLE_FMA'] = '1'
# NOTE(review): MKL/OMP thread caps are set *after* `import torch` above;
# some BLAS/OpenMP backends read these at import time, so they may have no
# effect here — verify, or move them before the torch import.
os.environ['MKL_NUM_THREADS'] = '80'
os.environ['OMP_NUM_THREADS'] = '80'


# Pin torch to 80 intra-op and 10 inter-op threads (hard-coded for the
# target machine; adjust for smaller hosts).
torch.set_num_threads(80)
torch.set_num_interop_threads(10)
|
|
| |
| import pandas as pd |
| import numpy as np |
| from pathlib import Path |
| import logging |
| from datetime import datetime |
| import sys |
| from toxic_augment import ToxicAugmenter |
| import json |
|
|
| |
# Per-run log file: logs/balance_english_<timestamp>.log
log_dir = Path("logs")
log_dir.mkdir(exist_ok=True)


# Timestamp doubles as the run identifier (also used for the stats JSON).
timestamp = datetime.now().strftime("%Y_%m_%d_%H_%M_%S")
log_file = log_dir / f"balance_english_{timestamp}.log"


# Mirror every log record to both stdout and the run's log file.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s | %(message)s',
    handlers=[
        logging.StreamHandler(sys.stdout),
        logging.FileHandler(log_file)
    ]
)


logger = logging.getLogger(__name__)
|
|
def analyze_label_distribution(df, lang='en'):
    """Log per-label counts and percentages for one language slice.

    Returns a dict mapping each label column to
    ``{'count': int, 'percentage': float}``; empty dict if the language
    has no rows.
    """
    label_cols = ['toxic', 'severe_toxic', 'obscene', 'threat', 'insult', 'identity_hate']
    subset = df[df['lang'] == lang]
    n_rows = len(subset)

    # Guard: nothing to report for an absent language.
    if n_rows == 0:
        logger.warning(f"No samples found for language {lang.upper()}.")
        return {}

    logger.info(f"\nLabel Distribution for {lang.upper()}:")
    logger.info("-" * 50)
    distribution = {}
    for col in label_cols:
        n_pos = subset[col].sum()
        pct = n_pos / n_rows * 100
        distribution[col] = {'count': int(n_pos), 'percentage': pct}
        logger.info(f"{col}: {n_pos:,} ({pct:.2f}%)")
    return distribution
|
|
def analyze_language_distribution(df):
    """Log how many comments each language contributes; return the counts Series."""
    counts = df['lang'].value_counts()
    total = len(df)
    logger.info("\nCurrent Language Distribution:")
    logger.info("-" * 50)
    for language, n in counts.items():
        logger.info(f"{language}: {n:,} comments ({n/total*100:.2f}%)")
    return counts
|
|
def calculate_required_samples(df):
    """Return how many extra English rows are needed to match the largest language.

    May be <= 0 when English is already the (joint) largest language; the
    caller is expected to treat that as "no augmentation needed".
    """
    counts_by_lang = df['lang'].value_counts()
    ceiling = counts_by_lang.max()
    english_total = counts_by_lang.get('en', 0)
    shortfall = ceiling - english_total

    logger.info(f"\nTarget count per language: {ceiling:,}")
    logger.info(f"Current English count: {english_total:,}")
    logger.info(f"Required additional English samples: {shortfall:,}")

    return shortfall
|
|
def generate_balanced_samples(df, required_samples):
    """Generate new English samples that mirror the existing label ratios.

    Args:
        df: Full multilingual dataframe; must contain 'lang', 'comment_text'
            and the six toxicity label columns.
        required_samples: Number of new English rows to produce.

    Returns:
        DataFrame of newly generated rows with ``lang`` set to 'en', trimmed
        to at most ``required_samples`` rows.

    Raises:
        RuntimeError: if there are no English rows to derive ratios from, or
            if the augmenter produced no valid samples at all.
    """
    logger.info("\nGenerating balanced samples...")

    en_df = df[df['lang'] == 'en']
    labels = ['toxic', 'severe_toxic', 'obscene', 'threat', 'insult', 'identity_hate']

    # Fix: the ratio computation below divides by len(en_df); without this
    # guard an input with zero English rows raised ZeroDivisionError.
    if en_df.empty:
        raise RuntimeError("No English rows available to derive label ratios from")

    # Per-label targets proportional to each label's share of current English
    # data. NOTE(review): labels are multi-hot, so per-label targets can
    # overlap across samples; the loop stops once required_samples is reached.
    target_counts = {}
    for label in labels:
        count = en_df[label].sum()
        ratio = count / len(en_df)
        target_count = int(ratio * required_samples)
        target_counts[label] = target_count
        logger.info(f"Target count for {label}: {target_count:,}")

    augmented_samples = []
    augmenter = ToxicAugmenter()
    total_generated = 0

    for label, target_count in target_counts.items():
        if target_count == 0:
            continue

        logger.info(f"\nGenerating {target_count:,} samples for {label}")

        # Seed generation with real examples that carry this label.
        seed_texts = en_df[en_df[label] == 1]['comment_text'].tolist()

        if not seed_texts:
            logger.warning(f"No seed texts found for {label}, skipping...")
            continue

        new_samples = augmenter.augment_dataset(
            target_samples=target_count,
            label=label,
            seed_texts=seed_texts,
            timeout_minutes=5
        )

        if new_samples is not None and not new_samples.empty:
            augmented_samples.append(new_samples)
            total_generated += len(new_samples)

            logger.info(f"✓ Generated {len(new_samples):,} samples")
            logger.info(f"Progress: {total_generated:,}/{required_samples:,}")

            if total_generated >= required_samples:
                logger.info("Reached required sample count, stopping generation")
                break

    if not augmented_samples:
        # Fix: raise a specific exception type instead of bare Exception;
        # callers catching Exception still see the same failure.
        raise RuntimeError("Failed to generate any valid samples")

    augmented_df = pd.concat(augmented_samples, ignore_index=True)
    augmented_df['lang'] = 'en'

    # Per-label generation may overshoot; keep exactly the first N rows.
    if len(augmented_df) > required_samples:
        logger.info(f"Trimming excess samples from {len(augmented_df):,} to {required_samples:,}")
        augmented_df = augmented_df.head(required_samples)

    logger.info("\nFinal class distribution in generated samples:")
    for label in labels:
        count = augmented_df[label].sum()
        percentage = (count / len(augmented_df)) * 100
        logger.info(f"{label}: {count:,} ({percentage:.2f}%)")

    # Rows with no positive label at all count as "clean".
    clean_count = len(augmented_df[augmented_df[labels].sum(axis=1) == 0])
    clean_percentage = (clean_count / len(augmented_df)) * 100
    logger.info(f"Clean samples: {clean_count:,} ({clean_percentage:.2f}%)")

    return augmented_df
|
|
def balance_english_data():
    """Balance English data with other languages.

    Loads the multilingual dataset, generates enough English samples to match
    the largest language, writes the combined CSV, and records before/after
    distribution statistics to a JSON file. Re-raises any failure after
    logging it.
    """
    try:
        input_file = 'dataset/processed/MULTILINGUAL_TOXIC_DATASET_360K_7LANG_FINAL.csv'
        logger.info(f"Loading dataset from {input_file}")
        df = pd.read_csv(input_file)

        logger.info("\nAnalyzing current distribution...")
        initial_dist = analyze_language_distribution(df)
        initial_label_dist = analyze_label_distribution(df, 'en')

        required_samples = calculate_required_samples(df)

        if required_samples <= 0:
            logger.info("English data is already balanced. No augmentation needed.")
            return

        augmented_df = generate_balanced_samples(df, required_samples)

        logger.info("\nMerging datasets...")
        output_file = 'dataset/processed/MULTILINGUAL_TOXIC_DATASET_BALANCED.csv'

        combined_df = pd.concat([df, augmented_df], ignore_index=True)

        combined_df.to_csv(output_file, index=False)
        logger.info(f"\nSaved balanced dataset to {output_file}")

        logger.info("\nFinal distribution after balancing:")
        final_dist = analyze_language_distribution(combined_df)
        final_label_dist = analyze_label_distribution(combined_df, 'en')

        # Fix: value_counts() yields numpy integer values, which json.dump
        # rejects on some pandas/numpy versions; coerce keys/values to
        # plain Python types before serializing.
        stats = {
            'timestamp': timestamp,
            'initial_distribution': {
                'languages': {str(k): int(v) for k, v in initial_dist.items()},
                'english_labels': initial_label_dist
            },
            'final_distribution': {
                'languages': {str(k): int(v) for k, v in final_dist.items()},
                'english_labels': final_label_dist
            },
            'samples_generated': len(augmented_df),
            'total_samples': len(combined_df)
        }

        stats_file = f'logs/balance_stats_{timestamp}.json'
        with open(stats_file, 'w') as f:
            json.dump(stats, f, indent=2)
        logger.info(f"\nSaved balancing statistics to {stats_file}")

    except Exception as e:
        # logger.exception records the full traceback before re-raising.
        logger.exception(f"Error during balancing: {str(e)}")
        raise
|
|
def main():
    """Script entry point: run the full English balancing pipeline."""
    balance_english_data()
|
|
| if __name__ == "__main__": |
| logger.info("Starting English data balancing process...") |
| main() |