# NOTE: exported from a Hugging Face repository page ("Upload folder using
# huggingface_hub", commit 54c5666 verified); page chrome converted to this
# comment so the module parses as Python.
"""Data quality checks and validation"""
from typing import Dict, Any, Optional, List
import logging
import re
logger = logging.getLogger(__name__)
class DataValidator:
    """Filter raw training samples on length, format, duplication and noise.

    A sample is a dict carrying either a raw ``'text'`` string or an
    already-tokenized ``'input_ids'`` sequence (the latter is accepted
    as-is). Per-category filter counts are accumulated in ``self.stats``.
    """

    def __init__(
        self,
        min_length: int = 10,
        max_length: int = 100000,
        check_duplicates: bool = True,
        check_special_chars: bool = True
    ):
        """
        Args:
            min_length: Minimum text length in characters (inclusive).
            max_length: Maximum text length in characters (inclusive).
            check_duplicates: Drop samples whose exact text was already seen
                (tracked via ``hash()`` of the text, within this process).
            check_special_chars: Drop samples that are less than 50%
                alphanumeric/whitespace characters.
        """
        self.min_length = min_length
        self.max_length = max_length
        self.check_duplicates = check_duplicates
        self.check_special_chars = check_special_chars
        # 'total_samples' counts every validate_sample call; the remaining
        # counters partition the outcomes.
        self.stats: Dict[str, int] = {
            'total_samples': 0,
            'filtered_too_short': 0,
            'filtered_too_long': 0,
            'filtered_invalid': 0,
            'filtered_duplicate': 0,
            'filtered_special_chars': 0,
            'valid_samples': 0
        }
        # Hashes of texts seen so far; None when dedup is disabled.
        self.seen_hashes: Optional[set] = set() if check_duplicates else None

    def validate_sample(self, sample: Dict[str, Any]) -> bool:
        """
        Validate a single sample.

        Args:
            sample: Dictionary containing sample data; must provide either
                a ``'text'`` string or pre-tokenized ``'input_ids'``.

        Returns:
            True if the sample passes all enabled checks, False otherwise.
        """
        self.stats['total_samples'] += 1

        # Resolve the payload: raw text, pre-tokenized, or invalid.
        if 'text' in sample:
            text = sample['text']
        elif 'input_ids' in sample:
            # Already tokenized upstream; accept without further checks.
            self.stats['valid_samples'] += 1
            return True
        else:
            self.stats['filtered_invalid'] += 1
            logger.debug("Sample missing 'text' or 'input_ids' field")
            return False

        if not isinstance(text, str):
            self.stats['filtered_invalid'] += 1
            # Lazy %-args: no string formatting unless DEBUG is enabled.
            logger.debug("Text is not string: %s", type(text))
            return False

        # Length bounds (character count, not tokens).
        text_len = len(text)
        if text_len < self.min_length:
            self.stats['filtered_too_short'] += 1
            logger.debug("Sample too short: %d chars (min: %d)", text_len, self.min_length)
            return False
        if text_len > self.max_length:
            self.stats['filtered_too_long'] += 1
            logger.debug("Sample too long: %d chars (max: %d)", text_len, self.max_length)
            return False

        # Exact-duplicate detection via the built-in string hash.
        # NOTE: a hash collision could (rarely) drop a non-duplicate.
        if self.check_duplicates:
            text_hash = hash(text)
            if text_hash in self.seen_hashes:
                self.stats['filtered_duplicate'] += 1
                return False
            self.seen_hashes.add(text_hash)

        # Reject mostly-symbolic noise (binary junk, markup residue, ...).
        if self.check_special_chars:
            if not self._check_special_chars(text):
                self.stats['filtered_special_chars'] += 1
                return False

        self.stats['valid_samples'] += 1
        return True

    def _check_special_chars(self, text: str) -> bool:
        """Return True when *text* is at least 50% alphanumeric or whitespace."""
        if not text:
            return False
        alphanumeric = sum(c.isalnum() or c.isspace() for c in text)
        return alphanumeric / len(text) >= 0.5

    def validate_batch(self, batch: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """
        Validate a batch of samples.

        Args:
            batch: List of sample dictionaries.

        Returns:
            The subset of ``batch`` that passed validation, order preserved.
        """
        return [sample for sample in batch if self.validate_sample(sample)]

    def print_stats(self):
        """Log accumulated validation statistics at INFO level."""
        logger.info("=" * 60)
        logger.info("Data Validation Statistics")
        logger.info("=" * 60)
        logger.info("Total samples processed: %d", self.stats['total_samples'])
        logger.info("Valid samples: %d", self.stats['valid_samples'])
        logger.info("Filtered (too short): %d", self.stats['filtered_too_short'])
        logger.info("Filtered (too long): %d", self.stats['filtered_too_long'])
        logger.info("Filtered (invalid format): %d", self.stats['filtered_invalid'])
        if self.check_duplicates:
            logger.info("Filtered (duplicates): %d", self.stats['filtered_duplicate'])
        if self.check_special_chars:
            logger.info("Filtered (special chars): %d", self.stats['filtered_special_chars'])
        if self.stats['total_samples'] > 0:
            valid_pct = 100 * self.stats['valid_samples'] / self.stats['total_samples']
            logger.info("Validation rate: %.2f%%", valid_pct)
        logger.info("=" * 60)

    def get_stats(self) -> Dict[str, int]:
        """Return a shallow copy of the accumulated statistics."""
        return self.stats.copy()

    def reset_stats(self):
        """Zero all counters and clear the duplicate cache (if enabled)."""
        for key in self.stats:
            self.stats[key] = 0
        if self.seen_hashes is not None:
            self.seen_hashes.clear()
class TokenValidator:
    """Sanity-check tokenized sequences against a fixed vocabulary size."""

    def __init__(self, vocab_size: int, pad_token_id: int = 0):
        self.vocab_size = vocab_size
        self.pad_token_id = pad_token_id

    def validate_tokens(self, input_ids: List[int]) -> bool:
        """Return True when every ID fits the vocab and not all are padding."""
        if not input_ids:
            return False
        # Collect IDs outside [0, vocab_size); warn about the first offender.
        out_of_range = [t for t in input_ids if t < 0 or t >= self.vocab_size]
        if out_of_range:
            token_id = out_of_range[0]
            logger.warning(f"Invalid token ID: {token_id} (vocab_size: {self.vocab_size})")
            return False
        # A sequence made entirely of padding carries no signal.
        return any(t != self.pad_token_id for t in input_ids)

    def get_token_stats(self, input_ids: List[int]) -> Dict[str, Any]:
        """Summarize a token sequence: counts, padding and vocab coverage."""
        if not input_ids:
            return {}
        total = len(input_ids)
        distinct = len(set(input_ids))
        padding = input_ids.count(self.pad_token_id)
        return {
            'total_tokens': total,
            'unique_tokens': distinct,
            'pad_tokens': padding,
            'vocab_coverage': distinct / self.vocab_size,
            'pad_ratio': padding / total
        }