|
|
"""Data quality checks and validation"""
|
|
|
import hashlib
import logging
import re

from typing import Dict, Any, Optional, List
|
|
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
|
|
|
class DataValidator:
    """Validate training data quality.

    Filters text samples by length, duplicate content, and the ratio of
    special characters, while accumulating statistics on why samples were
    rejected. Pre-tokenized samples (carrying 'input_ids') are accepted
    without text-level checks.
    """

    def __init__(
        self,
        min_length: int = 10,
        max_length: int = 100000,
        check_duplicates: bool = True,
        check_special_chars: bool = True
    ):
        """
        Args:
            min_length: Minimum allowed text length, in characters.
            max_length: Maximum allowed text length, in characters.
            check_duplicates: If True, reject any text this instance has
                already seen.
            check_special_chars: If True, reject texts dominated by
                non-alphanumeric, non-whitespace characters.
        """
        self.min_length = min_length
        self.max_length = max_length
        self.check_duplicates = check_duplicates
        self.check_special_chars = check_special_chars

        # Counters for each accept/reject outcome; exposed via get_stats().
        self.stats = {
            'total_samples': 0,
            'filtered_too_short': 0,
            'filtered_too_long': 0,
            'filtered_invalid': 0,
            'filtered_duplicate': 0,
            'filtered_special_chars': 0,
            'valid_samples': 0
        }

        # SHA-256 digests of texts seen so far (None when dedup is disabled).
        self.seen_hashes = set() if check_duplicates else None

    def validate_sample(self, sample: Dict[str, Any]) -> bool:
        """
        Validate a single sample.

        Args:
            sample: Dictionary containing sample data; expected to hold a
                'text' string or pre-tokenized 'input_ids'.

        Returns:
            True if sample is valid, False otherwise
        """
        self.stats['total_samples'] += 1

        if 'text' in sample:
            text = sample['text']
        elif 'input_ids' in sample:
            # Already tokenized upstream; no text-level checks apply here.
            self.stats['valid_samples'] += 1
            return True
        else:
            self.stats['filtered_invalid'] += 1
            logger.debug("Sample missing 'text' or 'input_ids' field")
            return False

        if not isinstance(text, str):
            self.stats['filtered_invalid'] += 1
            logger.debug(f"Text is not string: {type(text)}")
            return False

        text_len = len(text)
        if text_len < self.min_length:
            self.stats['filtered_too_short'] += 1
            return False

        if text_len > self.max_length:
            self.stats['filtered_too_long'] += 1
            logger.debug(f"Sample too long: {text_len} chars (max: {self.max_length})")
            return False

        if self.check_duplicates:
            # Use a content digest rather than built-in hash(): hash() is
            # salted per process (PYTHONHASHSEED), so its values are not
            # reproducible across runs, and a 64-bit hash collision would
            # silently drop a non-duplicate sample. SHA-256 is stable and
            # effectively collision-free.
            text_hash = hashlib.sha256(text.encode('utf-8')).hexdigest()
            if text_hash in self.seen_hashes:
                self.stats['filtered_duplicate'] += 1
                return False
            self.seen_hashes.add(text_hash)

        if self.check_special_chars and not self._check_special_chars(text):
            self.stats['filtered_special_chars'] += 1
            return False

        self.stats['valid_samples'] += 1
        return True

    def _check_special_chars(self, text: str) -> bool:
        """Check if text has too many special characters.

        Returns True when at least half of the characters are alphanumeric
        or whitespace; empty text fails the check.
        """
        if not text:
            return False

        alphanumeric = sum(c.isalnum() or c.isspace() for c in text)
        ratio = alphanumeric / len(text)

        return ratio >= 0.5

    def validate_batch(self, batch: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """
        Validate a batch of samples.

        Args:
            batch: List of sample dictionaries

        Returns:
            List of valid samples
        """
        # validate_sample() also updates self.stats as a side effect.
        return [sample for sample in batch if self.validate_sample(sample)]

    def print_stats(self):
        """Print validation statistics via the module logger."""
        logger.info("=" * 60)
        logger.info("Data Validation Statistics")
        logger.info("=" * 60)
        logger.info(f"Total samples processed: {self.stats['total_samples']}")
        logger.info(f"Valid samples: {self.stats['valid_samples']}")
        logger.info(f"Filtered (too short): {self.stats['filtered_too_short']}")
        logger.info(f"Filtered (too long): {self.stats['filtered_too_long']}")
        logger.info(f"Filtered (invalid format): {self.stats['filtered_invalid']}")

        if self.check_duplicates:
            logger.info(f"Filtered (duplicates): {self.stats['filtered_duplicate']}")

        if self.check_special_chars:
            logger.info(f"Filtered (special chars): {self.stats['filtered_special_chars']}")

        if self.stats['total_samples'] > 0:
            valid_pct = 100 * self.stats['valid_samples'] / self.stats['total_samples']
            logger.info(f"Validation rate: {valid_pct:.2f}%")

        logger.info("=" * 60)

    def get_stats(self) -> Dict[str, int]:
        """Get a copy of the validation statistics."""
        return self.stats.copy()

    def reset_stats(self):
        """Reset validation statistics and the duplicate-tracking set."""
        for key in self.stats:
            self.stats[key] = 0
        if self.seen_hashes is not None:
            self.seen_hashes.clear()
|
|
|
|
|
|
|
|
|
class TokenValidator:
    """Validate tokenized data"""

    def __init__(self, vocab_size: int, pad_token_id: int = 0):
        # Valid token IDs fall in [0, vocab_size).
        self.vocab_size = vocab_size
        # ID used for padding positions.
        self.pad_token_id = pad_token_id

    def validate_tokens(self, input_ids: List[int]) -> bool:
        """Validate token IDs are within vocabulary"""
        if not input_ids:
            return False

        # Scan for the first ID that falls outside the vocabulary; stop there.
        bad_id = next(
            (t for t in input_ids if t < 0 or t >= self.vocab_size), None
        )
        if bad_id is not None:
            logger.warning(f"Invalid token ID: {bad_id} (vocab_size: {self.vocab_size})")
            return False

        # A sequence consisting entirely of padding carries no content.
        return any(t != self.pad_token_id for t in input_ids)

    def get_token_stats(self, input_ids: List[int]) -> Dict[str, Any]:
        """Get statistics about tokens"""
        if not input_ids:
            return {}

        total = len(input_ids)
        distinct = len(set(input_ids))
        padding = input_ids.count(self.pad_token_id)

        return {
            'total_tokens': total,
            'unique_tokens': distinct,
            'pad_tokens': padding,
            'vocab_coverage': distinct / self.vocab_size,
            'pad_ratio': padding / total
        }
|
|
|
|