File size: 6,738 Bytes
54c5666 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 |
"""Data quality checks and validation"""
import hashlib
import logging
import re
from typing import Any, Dict, List, Optional
logger = logging.getLogger(__name__)
class DataValidator:
"""Validate training data quality"""
def __init__(
self,
min_length: int = 10,
max_length: int = 100000,
check_duplicates: bool = True,
check_special_chars: bool = True
):
self.min_length = min_length
self.max_length = max_length
self.check_duplicates = check_duplicates
self.check_special_chars = check_special_chars
self.stats = {
'total_samples': 0,
'filtered_too_short': 0,
'filtered_too_long': 0,
'filtered_invalid': 0,
'filtered_duplicate': 0,
'filtered_special_chars': 0,
'valid_samples': 0
}
self.seen_hashes = set() if check_duplicates else None
def validate_sample(self, sample: Dict[str, Any]) -> bool:
"""
Validate a single sample
Args:
sample: Dictionary containing sample data
Returns:
True if sample is valid, False otherwise
"""
self.stats['total_samples'] += 1
# Check required fields
text = None
if 'text' in sample:
text = sample['text']
elif 'input_ids' in sample:
# Already tokenized, assume valid
self.stats['valid_samples'] += 1
return True
else:
self.stats['filtered_invalid'] += 1
logger.debug("Sample missing 'text' or 'input_ids' field")
return False
# Check if text is string
if not isinstance(text, str):
self.stats['filtered_invalid'] += 1
logger.debug(f"Text is not string: {type(text)}")
return False
# Check length
text_len = len(text)
if text_len < self.min_length:
self.stats['filtered_too_short'] += 1
return False
if text_len > self.max_length:
self.stats['filtered_too_long'] += 1
logger.debug(f"Sample too long: {text_len} chars (max: {self.max_length})")
return False
# Check for duplicates
if self.check_duplicates:
text_hash = hash(text)
if text_hash in self.seen_hashes:
self.stats['filtered_duplicate'] += 1
return False
self.seen_hashes.add(text_hash)
# Check special characters ratio
if self.check_special_chars:
if not self._check_special_chars(text):
self.stats['filtered_special_chars'] += 1
return False
self.stats['valid_samples'] += 1
return True
def _check_special_chars(self, text: str) -> bool:
"""Check if text has too many special characters"""
if not text:
return False
# Count alphanumeric characters
alphanumeric = sum(c.isalnum() or c.isspace() for c in text)
ratio = alphanumeric / len(text)
# Text should be at least 50% alphanumeric + spaces
return ratio >= 0.5
def validate_batch(self, batch: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
"""
Validate a batch of samples
Args:
batch: List of sample dictionaries
Returns:
List of valid samples
"""
valid_samples = []
for sample in batch:
if self.validate_sample(sample):
valid_samples.append(sample)
return valid_samples
def print_stats(self):
"""Print validation statistics"""
logger.info("=" * 60)
logger.info("Data Validation Statistics")
logger.info("=" * 60)
logger.info(f"Total samples processed: {self.stats['total_samples']}")
logger.info(f"Valid samples: {self.stats['valid_samples']}")
logger.info(f"Filtered (too short): {self.stats['filtered_too_short']}")
logger.info(f"Filtered (too long): {self.stats['filtered_too_long']}")
logger.info(f"Filtered (invalid format): {self.stats['filtered_invalid']}")
if self.check_duplicates:
logger.info(f"Filtered (duplicates): {self.stats['filtered_duplicate']}")
if self.check_special_chars:
logger.info(f"Filtered (special chars): {self.stats['filtered_special_chars']}")
if self.stats['total_samples'] > 0:
valid_pct = 100 * self.stats['valid_samples'] / self.stats['total_samples']
logger.info(f"Validation rate: {valid_pct:.2f}%")
logger.info("=" * 60)
def get_stats(self) -> Dict[str, int]:
"""Get validation statistics"""
return self.stats.copy()
def reset_stats(self):
"""Reset validation statistics"""
for key in self.stats:
self.stats[key] = 0
if self.seen_hashes is not None:
self.seen_hashes.clear()
class TokenValidator:
    """Validate tokenized data"""

    def __init__(self, vocab_size: int, pad_token_id: int = 0):
        # Valid token IDs live in the half-open range [0, vocab_size).
        self.vocab_size = vocab_size
        # Token ID used for padding positions.
        self.pad_token_id = pad_token_id

    def validate_tokens(self, input_ids: List[int]) -> bool:
        """Validate token IDs are within vocabulary"""
        if not input_ids:
            return False
        # Locate the first out-of-vocabulary ID, if there is one.
        bad_id = next(
            (tok for tok in input_ids if tok < 0 or tok >= self.vocab_size),
            None,
        )
        if bad_id is not None:
            logger.warning(f"Invalid token ID: {bad_id} (vocab_size: {self.vocab_size})")
            return False
        # A sequence made up entirely of padding carries no content.
        return not all(tok == self.pad_token_id for tok in input_ids)

    def get_token_stats(self, input_ids: List[int]) -> Dict[str, Any]:
        """Get statistics about tokens"""
        if not input_ids:
            return {}
        total = len(input_ids)
        distinct = len(set(input_ids))
        n_pad = input_ids.count(self.pad_token_id)
        return {
            'total_tokens': total,
            'unique_tokens': distinct,
            'pad_tokens': n_pad,
            'vocab_coverage': distinct / self.vocab_size,
            'pad_ratio': n_pad / total
        }
|