Spaces:
Configuration error
Configuration error
| """ | |
| Preprocessing Pipeline Runner | |
| ================================ | |
| Central pipeline that runs all enabled preprocessing stages | |
| sequentially and logs each step. | |
| """ | |
| from dataclasses import dataclass, field | |
| from typing import List, Dict, Any, Optional, Tuple | |
| import time | |
| import pandas as pd | |
| from preprocessing.text_cleaning import TextCleaningConfig, apply_text_cleaning | |
| from preprocessing.tokenization import ( | |
| TokenizationConfig, get_tokenizer, compute_token_stats, | |
| truncate_samples, split_long_samples, | |
| ) | |
| from preprocessing.system_prompt import SystemPromptConfig | |
| from preprocessing.dataset_balancing import BalancingConfig, balance_dataset | |
| from preprocessing.quality_filters import QualityFilterConfig, apply_quality_filters | |
| from preprocessing.deduplication import DeduplicationConfig, apply_deduplication | |
| from preprocessing.train_val_split import SplitConfig, split_dataset | |
| from preprocessing.output_formatter import OutputFormatConfig, format_dataset, export_jsonl | |
| from preprocessing.pii_filter import PIIFilterConfig, apply_pii_filter_df | |
| from preprocessing.augmentation import AugmentationConfig, augment_dataset | |
@dataclass
class PreprocessingConfig:
    """Master configuration for the entire preprocessing pipeline.

    Declared as a dataclass so that the ``field(default_factory=...)``
    defaults below are materialised per instance and the class can be
    built with keyword arguments (e.g. ``PreprocessingConfig(split=...)``).
    """

    # Column mappings: names of the DataFrame columns the pipeline reads.
    # Empty string / None means "not mapped".
    instruction_col: str = ""
    output_col: str = ""
    input_col: Optional[str] = None
    label_col: Optional[str] = None

    # Sub-configs: one per preprocessing module; each gets a fresh default
    # instance so configurations are never shared between pipelines.
    text_cleaning: TextCleaningConfig = field(default_factory=TextCleaningConfig)
    tokenization: TokenizationConfig = field(default_factory=TokenizationConfig)
    system_prompt: SystemPromptConfig = field(default_factory=SystemPromptConfig)
    balancing: BalancingConfig = field(default_factory=BalancingConfig)
    quality_filters: QualityFilterConfig = field(default_factory=QualityFilterConfig)
    deduplication: DeduplicationConfig = field(default_factory=DeduplicationConfig)
    split: SplitConfig = field(default_factory=SplitConfig)
    output_format: OutputFormatConfig = field(default_factory=OutputFormatConfig)
    pii_filter: PIIFilterConfig = field(default_factory=PIIFilterConfig)
    augmentation: AugmentationConfig = field(default_factory=AugmentationConfig)
@dataclass
class PipelineLog:
    """A single log entry from a pipeline stage.

    Declared as a dataclass so entries can be created with keyword
    arguments (``PipelineLog(stage=..., description=..., ...)``).
    """

    stage: str          # short stage name, e.g. "Deduplication"
    description: str    # human-readable summary of what the stage did
    rows_before: int    # row count entering the stage
    rows_after: int     # row count leaving the stage
    duration_ms: float  # stage duration in milliseconds

    # NOTE(review): kept as a plain method because that is what this file
    # shows; confirm whether callers expect attribute-style access.
    def rows_delta(self) -> int:
        """Return the change in row count (negative when rows were removed)."""
        return self.rows_after - self.rows_before
class PreprocessingPipeline:
    """
    Sequential preprocessing pipeline runner.

    Applies the enabled stages in a fixed order (text cleaning, quality
    filters, deduplication, PII filtering, balancing, augmentation,
    tokenization controls), collects one log entry per executed stage,
    and finally splits the result into train/validation frames.
    """

    def __init__(self, config: PreprocessingConfig):
        """Store the master configuration and start with an empty log."""
        self.config = config
        self.logs: List[PipelineLog] = []

    def _log(self, stage: str, desc: str, before: int, after: int, elapsed: float):
        """Append a log entry; ``elapsed`` is in seconds and stored as ms."""
        self.logs.append(PipelineLog(
            stage=stage,
            description=desc,
            rows_before=before,
            rows_after=after,
            duration_ms=round(elapsed * 1000, 1),
        ))

    def run(
        self,
        df: pd.DataFrame,
        progress_callback=None,
    ) -> Tuple[pd.DataFrame, pd.DataFrame, List[PipelineLog]]:
        """
        Run the complete preprocessing pipeline.

        Args:
            df: Input DataFrame
            progress_callback: Optional callable(stage_name, progress_pct) for UI updates

        Returns:
            (train_df, val_df, logs)
            If split is disabled, val_df will be empty.
        """
        self.logs = []
        total_stages = 7  # text cleaning, quality, dedup, pii, balancing, augmentation, tokenization
        current_stage = 0

        def _progress(name):
            # Called once per stage, whether or not the stage actually ran,
            # so the reported percentage always reaches 100.
            nonlocal current_stage
            current_stage += 1
            if progress_callback:
                pct = int((current_stage / total_stages) * 100)
                progress_callback(name, pct)

        cfg = self.config
        # Only text columns that are mapped AND present in the frame.
        text_cols = [c for c in [cfg.instruction_col, cfg.output_col, cfg.input_col] if c and c in df.columns]

        # -- Stage 1: Text Cleaning --
        # perf_counter is monotonic, so durations cannot go negative if the
        # system clock is adjusted mid-run.
        t0 = time.perf_counter()
        before = len(df)
        any_cleaning = (
            cfg.text_cleaning.remove_html or cfg.text_cleaning.remove_urls or
            cfg.text_cleaning.remove_emojis or cfg.text_cleaning.normalize_whitespace or
            cfg.text_cleaning.lowercase or cfg.text_cleaning.remove_special_chars or
            cfg.text_cleaning.strip_extra_linebreaks
        )
        if any_cleaning:
            df = apply_text_cleaning(df, text_cols, cfg.text_cleaning)
            self._log("Text Cleaning", "Applied text cleaning filters", before, len(df), time.perf_counter() - t0)
        _progress("Text Cleaning")

        # -- Stage 2: Quality Filters --
        t0 = time.perf_counter()
        before = len(df)
        has_quality = (
            cfg.quality_filters.min_word_count > 0 or
            cfg.quality_filters.max_word_count > 0 or
            cfg.quality_filters.profanity_filter or
            cfg.quality_filters.language_filter or
            cfg.quality_filters.remove_low_quality
        )
        if has_quality and cfg.output_col:
            df = apply_quality_filters(df, cfg.output_col, cfg.quality_filters)
            self._log("Quality Filters", "Applied quality filters", before, len(df), time.perf_counter() - t0)
        _progress("Quality Filters")

        # -- Stage 3: Deduplication --
        t0 = time.perf_counter()
        before = len(df)
        if cfg.instruction_col and (cfg.deduplication.remove_exact or cfg.deduplication.remove_semantic):
            df = apply_deduplication(df, cfg.instruction_col, cfg.deduplication)
            self._log("Deduplication", "Removed duplicate samples", before, len(df), time.perf_counter() - t0)
        _progress("Deduplication")

        # -- Stage 4: PII Filtering --
        t0 = time.perf_counter()
        before = len(df)
        has_pii = (
            cfg.pii_filter.filter_emails or cfg.pii_filter.filter_phones or
            cfg.pii_filter.filter_id_numbers or cfg.pii_filter.filter_api_keys or
            cfg.pii_filter.filter_addresses
        )
        if has_pii:
            df = apply_pii_filter_df(df, text_cols, cfg.pii_filter)
            self._log("PII Filtering", "Masked PII data", before, len(df), time.perf_counter() - t0)
        _progress("PII Filtering")

        # -- Stage 5: Dataset Balancing --
        t0 = time.perf_counter()
        before = len(df)
        if cfg.balancing.enabled and cfg.balancing.label_column and cfg.balancing.strategy != "none":
            df = balance_dataset(df, cfg.balancing.label_column, cfg.balancing.strategy)
            self._log("Balancing", "Balanced dataset classes", before, len(df), time.perf_counter() - t0)
        _progress("Balancing")

        # -- Stage 6: Augmentation --
        t0 = time.perf_counter()
        before = len(df)
        if cfg.augmentation.enabled and cfg.instruction_col:
            df = augment_dataset(df, cfg.instruction_col, cfg.augmentation)
            self._log("Augmentation", "Generated augmented samples", before, len(df), time.perf_counter() - t0)
        _progress("Augmentation")

        # -- Stage 7: Tokenization Controls --
        t0 = time.perf_counter()
        before = len(df)
        if cfg.tokenization.truncate_long or cfg.tokenization.split_long:
            tok_desc = "Applied tokenization controls"
            try:
                tokenizer = get_tokenizer(cfg.tokenization)
                is_tiktoken = cfg.tokenization.tokenizer_name == "tiktoken"
                for col in text_cols:
                    # Splitting takes precedence over truncation when both
                    # options are switched on.
                    if cfg.tokenization.split_long:
                        df = split_long_samples(
                            df, col, cfg.tokenization.max_total_tokens,
                            tokenizer, is_tiktoken, cfg.tokenization.split_overlap,
                        )
                    elif cfg.tokenization.truncate_long:
                        df = truncate_samples(
                            df, col, cfg.tokenization.max_total_tokens,
                            tokenizer, is_tiktoken,
                        )
            except ImportError as exc:
                # Best effort: a missing tokenizer backend must not abort the
                # run, but the log should say the stage was skipped instead
                # of claiming it was applied.
                tok_desc = f"Skipped tokenization controls: tokenizer not available ({exc})"
            self._log("Tokenization", tok_desc, before, len(df), time.perf_counter() - t0)
        _progress("Tokenization")

        # -- Split --
        train_df, val_df = split_dataset(df, cfg.split)
        return train_df, val_df, self.logs
def get_safe_preset() -> PreprocessingConfig:
    """Build the conservative 'safe preset' used for common use cases.

    Each sub-configuration is assembled separately and then combined into
    one PreprocessingConfig; everything not listed keeps its default.
    """
    # Strip markup and URLs and tidy whitespace; leave casing, emojis and
    # special characters untouched.
    cleaning = TextCleaningConfig(
        remove_html=True,
        remove_urls=True,
        remove_emojis=False,
        normalize_whitespace=True,
        lowercase=False,
        remove_special_chars=False,
        strip_extra_linebreaks=True,
    )

    # max_word_count=0 leaves the upper word bound switched off.
    quality = QualityFilterConfig(
        min_word_count=3,
        max_word_count=0,
        profanity_filter=False,
        language_filter=False,
        remove_low_quality=True,
        min_quality_length=20,
    )

    # Exact-match deduplication only.
    dedup = DeduplicationConfig(
        remove_exact=True,
        remove_semantic=False,
    )

    # Every PII category except addresses is enabled.
    pii = PIIFilterConfig(
        filter_emails=True,
        filter_phones=True,
        filter_id_numbers=True,
        filter_api_keys=True,
        filter_addresses=False,
    )

    # Shuffled 90% train split with a fixed seed.
    splitting = SplitConfig(
        enabled=True,
        train_ratio=0.9,
        random_seed=42,
        shuffle=True,
    )

    formatting = OutputFormatConfig(
        format_type="openai_chat",
    )

    prompt = SystemPromptConfig(
        system_prompt="You are a helpful AI assistant.",
        prepend_to_all=True,
    )

    preset = PreprocessingConfig(
        text_cleaning=cleaning,
        quality_filters=quality,
        deduplication=dedup,
        pii_filter=pii,
        split=splitting,
        output_format=formatting,
        system_prompt=prompt,
    )
    return preset