File size: 10,062 Bytes
d4398e6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
"""
Preprocessing Pipeline Runner
=============================

Central pipeline that runs all enabled preprocessing stages
sequentially and logs each step.
"""

from dataclasses import dataclass, field
from typing import List, Dict, Any, Optional, Tuple
import time
import pandas as pd

from preprocessing.text_cleaning import TextCleaningConfig, apply_text_cleaning
from preprocessing.tokenization import (
    TokenizationConfig, get_tokenizer, compute_token_stats,
    truncate_samples, split_long_samples,
)
from preprocessing.system_prompt import SystemPromptConfig
from preprocessing.dataset_balancing import BalancingConfig, balance_dataset
from preprocessing.quality_filters import QualityFilterConfig, apply_quality_filters
from preprocessing.deduplication import DeduplicationConfig, apply_deduplication
from preprocessing.train_val_split import SplitConfig, split_dataset
from preprocessing.output_formatter import OutputFormatConfig, format_dataset, export_jsonl
from preprocessing.pii_filter import PIIFilterConfig, apply_pii_filter_df
from preprocessing.augmentation import AugmentationConfig, augment_dataset


@dataclass
class PreprocessingConfig:
    """Master configuration for the entire preprocessing pipeline.

    Holds the dataset column mapping plus one sub-config per preprocessing
    stage. ``PreprocessingPipeline.run`` reads these fields to decide which
    stages run and which columns they operate on.
    """
    # Column mappings. An empty string / None means "not mapped"; run() only
    # treats a column as a text column when it is set AND present in the
    # DataFrame.
    # Instruction/prompt column: processed as text (cleaning, PII masking,
    # tokenization) and used as the key column for deduplication and
    # augmentation.
    instruction_col: str = ""
    # Response column: processed as text and the only column the quality
    # filters are applied to.
    output_col: str = ""
    # Optional extra input column, processed as text when present.
    input_col: Optional[str] = None
    # NOTE(review): not read by PreprocessingPipeline.run -- balancing uses
    # balancing.label_column instead. Confirm whether callers still need this.
    label_col: Optional[str] = None

    # Sub-configs, one per stage. default_factory gives every
    # PreprocessingConfig its own fresh instance instead of a shared default.
    text_cleaning: TextCleaningConfig = field(default_factory=TextCleaningConfig)
    tokenization: TokenizationConfig = field(default_factory=TokenizationConfig)
    # NOTE(review): not read by PreprocessingPipeline.run; presumably applied
    # by a later formatting/export step -- confirm.
    system_prompt: SystemPromptConfig = field(default_factory=SystemPromptConfig)
    balancing: BalancingConfig = field(default_factory=BalancingConfig)
    quality_filters: QualityFilterConfig = field(default_factory=QualityFilterConfig)
    deduplication: DeduplicationConfig = field(default_factory=DeduplicationConfig)
    # Passed unchanged to split_dataset at the end of run().
    split: SplitConfig = field(default_factory=SplitConfig)
    # NOTE(review): not read by PreprocessingPipeline.run; presumably used
    # when exporting (format_dataset / export_jsonl) -- confirm.
    output_format: OutputFormatConfig = field(default_factory=OutputFormatConfig)
    pii_filter: PIIFilterConfig = field(default_factory=PIIFilterConfig)
    augmentation: AugmentationConfig = field(default_factory=AugmentationConfig)


@dataclass
class PipelineLog:
    """Record describing one executed pipeline stage.

    Captures the stage's name, a short human-readable summary, the row count
    entering and leaving the stage, and how long the stage took.
    """

    # Display name of the stage, e.g. "Deduplication".
    stage: str
    # Short summary of what the stage did.
    description: str
    # Number of rows entering the stage.
    rows_before: int
    # Number of rows leaving the stage.
    rows_after: int
    # Elapsed time for the stage, in milliseconds.
    duration_ms: float

    @property
    def rows_delta(self) -> int:
        """Net row-count change: negative when rows were dropped, positive when added."""
        after, before = self.rows_after, self.rows_before
        return after - before


class PreprocessingPipeline:
    """
    Sequential preprocessing pipeline runner.

    Applies every enabled stage of a ``PreprocessingConfig`` in a fixed order
    (text cleaning, quality filters, deduplication, PII filtering, balancing,
    augmentation, tokenization), records one ``PipelineLog`` per stage, and
    finally splits the result into train/validation frames.
    """

    def __init__(self, config: PreprocessingConfig):
        self.config = config
        self.logs: List[PipelineLog] = []

    def _log(self, stage: str, desc: str, before: int, after: int, elapsed: float):
        """Append a PipelineLog entry; *elapsed* is in seconds and stored as ms."""
        self.logs.append(PipelineLog(
            stage=stage,
            description=desc,
            rows_before=before,
            rows_after=after,
            duration_ms=round(elapsed * 1000, 1),
        ))

    def _run_stage(self, stage: str, desc: str, df: pd.DataFrame, enabled, action) -> pd.DataFrame:
        """Time and log one stage, applying it only when it is enabled.

        Args:
            stage: Stage name recorded in the log.
            desc: Log description for the stage.
            df: Frame entering the stage.
            enabled: Truthy when the stage should run; a disabled stage is
                still logged (with unchanged row counts) so the UI always
                sees every stage.
            action: Callable taking the frame and returning the processed frame.

        Returns:
            The processed frame, or *df* unchanged when the stage is disabled.
        """
        # perf_counter is monotonic, so durations cannot go negative or jump
        # when the system clock is adjusted (unlike time.time()).
        start = time.perf_counter()
        before = len(df)
        if enabled:
            df = action(df)
        self._log(stage, desc, before, len(df), time.perf_counter() - start)
        return df

    def _apply_tokenization(
        self, df: pd.DataFrame, text_cols: List[str]
    ) -> Tuple[pd.DataFrame, Optional[str]]:
        """Split or truncate over-long samples in every text column.

        Splitting takes precedence over truncation when both flags are set.

        Returns:
            ``(df, skip_reason)``. ``skip_reason`` is None when the controls
            ran. It is a short message when ``get_tokenizer`` raised
            ImportError; in that case ``df`` is returned unchanged.
        """
        tok = self.config.tokenization
        try:
            tokenizer = get_tokenizer(tok)
        except ImportError as exc:
            # The tokenizer backend is optional: keep this stage best-effort,
            # but report why it was skipped instead of silently passing.
            reason = str(exc) or "tokenizer not available"
            return df, f"tokenizer unavailable: {reason}"

        is_tiktoken = tok.tokenizer_name == "tiktoken"
        for col in text_cols:
            if tok.split_long:
                df = split_long_samples(
                    df, col, tok.max_total_tokens,
                    tokenizer, is_tiktoken, tok.split_overlap,
                )
            elif tok.truncate_long:
                df = truncate_samples(
                    df, col, tok.max_total_tokens,
                    tokenizer, is_tiktoken,
                )
        return df, None

    def run(
        self,
        df: pd.DataFrame,
        progress_callback=None,
    ) -> Tuple[pd.DataFrame, pd.DataFrame, List[PipelineLog]]:
        """
        Run the complete preprocessing pipeline.

        Args:
            df: Input DataFrame.
            progress_callback: Optional callable(stage_name, progress_pct) for UI
                updates. It is invoked once after each of the seven stages with
                an integer percentage.

        Returns:
            (train_df, val_df, logs). ``logs`` holds one PipelineLog per stage;
            disabled stages are logged with unchanged row counts. The final
            split is delegated to ``split_dataset``; if split is disabled,
            val_df will be empty.
        """
        self.logs = []
        total_stages = 7  # text cleaning, quality, dedup, pii, balancing, augmentation, tokenization
        current_stage = 0

        def _progress(name):
            nonlocal current_stage
            current_stage += 1
            if progress_callback:
                pct = int((current_stage / total_stages) * 100)
                progress_callback(name, pct)

        cfg = self.config
        cleaning = cfg.text_cleaning
        quality = cfg.quality_filters
        dedup = cfg.deduplication
        pii = cfg.pii_filter
        bal = cfg.balancing

        # Only mapped columns that actually exist in the input are treated as text.
        text_cols = [
            c for c in (cfg.instruction_col, cfg.output_col, cfg.input_col)
            if c and c in df.columns
        ]

        # ── Stage 1: Text Cleaning ── (skipped when every cleaning flag is off)
        any_cleaning = any((
            cleaning.remove_html, cleaning.remove_urls, cleaning.remove_emojis,
            cleaning.normalize_whitespace, cleaning.lowercase,
            cleaning.remove_special_chars, cleaning.strip_extra_linebreaks,
        ))
        df = self._run_stage(
            "Text Cleaning", "Applied text cleaning filters", df, any_cleaning,
            lambda d: apply_text_cleaning(d, text_cols, cleaning),
        )
        _progress("Text Cleaning")

        # ── Stage 2: Quality Filters ── (evaluated on the output column only)
        has_quality = (
            quality.min_word_count > 0
            or quality.max_word_count > 0
            or quality.profanity_filter
            or quality.language_filter
            or quality.remove_low_quality
        )
        df = self._run_stage(
            "Quality Filters", "Applied quality filters", df,
            has_quality and cfg.output_col,
            lambda d: apply_quality_filters(d, cfg.output_col, quality),
        )
        _progress("Quality Filters")

        # ── Stage 3: Deduplication ── (keyed on the instruction column)
        df = self._run_stage(
            "Deduplication", "Removed duplicate samples", df,
            cfg.instruction_col and (dedup.remove_exact or dedup.remove_semantic),
            lambda d: apply_deduplication(d, cfg.instruction_col, dedup),
        )
        _progress("Deduplication")

        # ── Stage 4: PII Filtering ──
        has_pii = any((
            pii.filter_emails, pii.filter_phones, pii.filter_id_numbers,
            pii.filter_api_keys, pii.filter_addresses,
        ))
        df = self._run_stage(
            "PII Filtering", "Masked PII data", df, has_pii,
            lambda d: apply_pii_filter_df(d, text_cols, pii),
        )
        _progress("PII Filtering")

        # ── Stage 5: Dataset Balancing ──
        df = self._run_stage(
            "Balancing", "Balanced dataset classes", df,
            bal.enabled and bal.label_column and bal.strategy != "none",
            lambda d: balance_dataset(d, bal.label_column, bal.strategy),
        )
        _progress("Balancing")

        # ── Stage 6: Augmentation ──
        df = self._run_stage(
            "Augmentation", "Generated augmented samples", df,
            cfg.augmentation.enabled and cfg.instruction_col,
            lambda d: augment_dataset(d, cfg.instruction_col, cfg.augmentation),
        )
        _progress("Augmentation")

        # ── Stage 7: Tokenization Controls ──
        # Logged manually because the description changes when the tokenizer
        # cannot be loaded.
        start = time.perf_counter()
        before = len(df)
        desc = "Applied tokenization controls"
        if cfg.tokenization.truncate_long or cfg.tokenization.split_long:
            df, skip_reason = self._apply_tokenization(df, text_cols)
            if skip_reason is not None:
                desc = f"Skipped tokenization controls ({skip_reason})"
        self._log("Tokenization", desc, before, len(df), time.perf_counter() - start)
        _progress("Tokenization")

        # ── Split ──
        train_df, val_df = split_dataset(df, cfg.split)

        return train_df, val_df, self.logs


def get_safe_preset() -> PreprocessingConfig:
    """Build a conservative, general-purpose ``PreprocessingConfig``.

    Settings in this preset:

    - Text cleaning strips HTML and URLs and normalizes whitespace and extra
      line breaks. It leaves case, emojis and special characters alone.
    - Quality filters use ``min_word_count=3`` with no upper word limit and
      enable low-quality removal (``min_quality_length=20``).
    - Deduplication removes exact duplicates only.
    - PII masking covers e-mails, phone numbers, ID numbers and API keys, but
      not addresses.
    - Splitting is enabled, shuffled, with ``train_ratio=0.9`` and seed 42.
    - Output targets the ``openai_chat`` format with a generic system prompt
      prepended to every sample.

    Column mappings keep their defaults and must be filled in by the caller.
    """
    cleaning_cfg = TextCleaningConfig(
        remove_html=True,
        remove_urls=True,
        remove_emojis=False,
        normalize_whitespace=True,
        lowercase=False,
        remove_special_chars=False,
        strip_extra_linebreaks=True,
    )
    quality_cfg = QualityFilterConfig(
        min_word_count=3,
        max_word_count=0,
        profanity_filter=False,
        language_filter=False,
        remove_low_quality=True,
        min_quality_length=20,
    )
    dedup_cfg = DeduplicationConfig(remove_exact=True, remove_semantic=False)
    pii_cfg = PIIFilterConfig(
        filter_emails=True,
        filter_phones=True,
        filter_id_numbers=True,
        filter_api_keys=True,
        filter_addresses=False,
    )
    split_cfg = SplitConfig(enabled=True, train_ratio=0.9, random_seed=42, shuffle=True)
    output_cfg = OutputFormatConfig(format_type="openai_chat")
    prompt_cfg = SystemPromptConfig(
        system_prompt="You are a helpful AI assistant.",
        prepend_to_all=True,
    )

    preset = PreprocessingConfig(
        text_cleaning=cleaning_cfg,
        quality_filters=quality_cfg,
        deduplication=dedup_cfg,
        pii_filter=pii_cfg,
        split=split_cfg,
        output_format=output_cfg,
        system_prompt=prompt_cfg,
    )
    return preset