File size: 6,738 Bytes
54c5666
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
"""Data quality checks and validation"""
import hashlib
import logging
import re
from typing import Any, Dict, List, Optional

logger = logging.getLogger(__name__)


class DataValidator:
    """Validate training data quality.

    Filters samples that are too short/long, malformed, duplicated, or
    dominated by non-alphanumeric characters, and tracks a counter for
    each filter reason in ``self.stats``.
    """

    def __init__(
        self,
        min_length: int = 10,
        max_length: int = 100000,
        check_duplicates: bool = True,
        check_special_chars: bool = True
    ):
        """
        Args:
            min_length: Minimum accepted text length, in characters.
            max_length: Maximum accepted text length, in characters.
            check_duplicates: If True, drop samples whose text was seen before.
            check_special_chars: If True, drop samples that are less than 50%
                alphanumeric/whitespace characters.
        """
        self.min_length = min_length
        self.max_length = max_length
        self.check_duplicates = check_duplicates
        self.check_special_chars = check_special_chars

        # One counter per filtering outcome; see validate_sample().
        self.stats = {
            'total_samples': 0,
            'filtered_too_short': 0,
            'filtered_too_long': 0,
            'filtered_invalid': 0,
            'filtered_duplicate': 0,
            'filtered_special_chars': 0,
            'valid_samples': 0
        }

        # Digests of texts already seen (None when dedup is disabled).
        self.seen_hashes = set() if check_duplicates else None

    def validate_sample(self, sample: Dict[str, Any]) -> bool:
        """
        Validate a single sample.

        Args:
            sample: Dictionary containing sample data. Must carry either a
                'text' (str) field or an 'input_ids' field.

        Returns:
            True if the sample passes all enabled checks, False otherwise.
        """
        self.stats['total_samples'] += 1

        # Check required fields.
        if 'text' in sample:
            text = sample['text']
        elif 'input_ids' in sample:
            # Already tokenized, assume valid.
            self.stats['valid_samples'] += 1
            return True
        else:
            self.stats['filtered_invalid'] += 1
            logger.debug("Sample missing 'text' or 'input_ids' field")
            return False

        # Check if text is a string.
        if not isinstance(text, str):
            self.stats['filtered_invalid'] += 1
            logger.debug(f"Text is not string: {type(text)}")
            return False

        # Check length bounds.
        text_len = len(text)
        if text_len < self.min_length:
            self.stats['filtered_too_short'] += 1
            # Log for parity with the too-long branch.
            logger.debug(f"Sample too short: {text_len} chars (min: {self.min_length})")
            return False

        if text_len > self.max_length:
            self.stats['filtered_too_long'] += 1
            logger.debug(f"Sample too long: {text_len} chars (max: {self.max_length})")
            return False

        # Check for duplicates. A stable content digest is used instead of
        # the built-in hash(): str hashing is salted per process
        # (PYTHONHASHSEED), so hash() values are not reproducible across
        # runs, and a real digest makes accidental 64-bit collisions
        # (false-positive dedup) negligible.
        if self.check_duplicates:
            text_hash = hashlib.sha1(text.encode('utf-8')).digest()
            if text_hash in self.seen_hashes:
                self.stats['filtered_duplicate'] += 1
                return False
            self.seen_hashes.add(text_hash)

        # Check special-character ratio.
        if self.check_special_chars and not self._check_special_chars(text):
            self.stats['filtered_special_chars'] += 1
            return False

        self.stats['valid_samples'] += 1
        return True

    def _check_special_chars(self, text: str) -> bool:
        """Return True if text is at least 50% alphanumeric or whitespace."""
        if not text:
            return False

        alphanumeric = sum(c.isalnum() or c.isspace() for c in text)
        return alphanumeric / len(text) >= 0.5

    def validate_batch(self, batch: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """
        Validate a batch of samples.

        Args:
            batch: List of sample dictionaries.

        Returns:
            List containing only the valid samples, in input order.
        """
        return [sample for sample in batch if self.validate_sample(sample)]

    def print_stats(self):
        """Log validation statistics at INFO level."""
        logger.info("=" * 60)
        logger.info("Data Validation Statistics")
        logger.info("=" * 60)
        logger.info(f"Total samples processed: {self.stats['total_samples']}")
        logger.info(f"Valid samples: {self.stats['valid_samples']}")
        logger.info(f"Filtered (too short): {self.stats['filtered_too_short']}")
        logger.info(f"Filtered (too long): {self.stats['filtered_too_long']}")
        logger.info(f"Filtered (invalid format): {self.stats['filtered_invalid']}")

        if self.check_duplicates:
            logger.info(f"Filtered (duplicates): {self.stats['filtered_duplicate']}")

        if self.check_special_chars:
            logger.info(f"Filtered (special chars): {self.stats['filtered_special_chars']}")

        if self.stats['total_samples'] > 0:
            valid_pct = 100 * self.stats['valid_samples'] / self.stats['total_samples']
            logger.info(f"Validation rate: {valid_pct:.2f}%")

        logger.info("=" * 60)

    def get_stats(self) -> Dict[str, int]:
        """Return a copy of the validation statistics."""
        return self.stats.copy()

    def reset_stats(self):
        """Reset all counters (and the duplicate-detection set)."""
        for key in self.stats:
            self.stats[key] = 0
        if self.seen_hashes is not None:
            self.seen_hashes.clear()


class TokenValidator:
    """Validate tokenized data."""

    def __init__(self, vocab_size: int, pad_token_id: int = 0):
        self.vocab_size = vocab_size
        self.pad_token_id = pad_token_id

    def validate_tokens(self, input_ids: List[int]) -> bool:
        """Check that a token sequence is usable.

        Rejects empty sequences, any ID outside [0, vocab_size), and
        sequences made up entirely of padding tokens.
        """
        if not input_ids:
            return False

        # Find the first out-of-vocabulary ID, if any.
        bad_id = next((t for t in input_ids if not 0 <= t < self.vocab_size), None)
        if bad_id is not None:
            logger.warning(f"Invalid token ID: {bad_id} (vocab_size: {self.vocab_size})")
            return False

        # Reject sequences that contain nothing but padding.
        if set(input_ids) == {self.pad_token_id}:
            return False

        return True

    def get_token_stats(self, input_ids: List[int]) -> Dict[str, Any]:
        """Get statistics about tokens."""
        if not input_ids:
            return {}

        total = len(input_ids)
        distinct = len(set(input_ids))
        pad_count = input_ids.count(self.pad_token_id)

        return {
            'total_tokens': total,
            'unique_tokens': distinct,
            'pad_tokens': pad_count,
            'vocab_coverage': distinct / self.vocab_size,
            'pad_ratio': pad_count / total
        }