"""
Validator classes for text generation quality assessment.
Provides FastValidator (heuristics), GrammarValidator (LanguageTool),
KnowledgeValidator (factual accuracy), LanguageValidator (language and
word validity), and PerplexityValidator (DistilGPT-2 perplexity), with
security hardening and performance optimizations.
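
Typical usage (illustrative sketch; ``model`` stands for any object that
exposes a ``generate_text(prompt, max_length, temperature)`` method):

    validator = FastValidator(test_prompts=["The sky is"])
    metrics = validator.validate(model, step=100)
    if metrics["is_garbage"]:
        ...  # e.g. flag the checkpoint or alert the training loop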
"""
__all__ = [
"FastValidator",
"GrammarValidator",
"KnowledgeValidator",
"LanguageValidator",
"PerplexityValidator",
"Validator",
"FastValidationResult",
"GrammarValidationResult",
"KnowledgeValidationResult",
"LanguageValidationResult",
"PerplexityValidationResult",
]
import asyncio
import logging
from typing import Any, Protocol, TypedDict, TYPE_CHECKING
from collections import deque, Counter
from dataclasses import dataclass, field
import torch
# Import GrammarResult for type compatibility
if TYPE_CHECKING:
from .grammar_checker import GrammarResult
else:
try:
from .grammar_checker import GrammarResult
except ImportError:
# Fallback if grammar_checker not available
@dataclass
class GrammarResult:
grammar_score: float
num_errors: int
errors: list[dict] = field(default_factory=list)
suggestions: list[list[str]] = field(default_factory=list)
is_fallback: bool = False
# Import unified sanitization
from .sanitizer import sanitize
# Import validation constants
from .constants import (
MIN_ASCII_RATIO,
MAX_REPETITION_RATIO,
MIN_SAMPLE_LENGTH,
VALIDATION_MAX_LENGTH,
VALIDATION_TEMPERATURE,
KNOWLEDGE_MAX_LENGTH,
KNOWLEDGE_TEMPERATURE,
SAMPLE_HISTORY_SIZE,
GRAMMAR_HISTORY_SIZE,
TIMESTAMP_HISTORY_SIZE,
TREND_ANALYSIS_WINDOW,
NGRAM_SIZE,
MIN_NGRAM_TEXT_LENGTH,
FALLBACK_REPETITION_SCORE,
FALLBACK_GRAMMAR_SCORE,
FALLBACK_ERROR_COUNT,
ERROR_LOG_TRUNCATE_LENGTH,
)
logger = logging.getLogger(__name__)
class FastValidationResult(TypedDict):
"""Return type for FastValidator.validate()."""
samples: list[str]
is_garbage: bool
ascii_ratio: float
avg_length: float
repetition_ratio: float
class GrammarValidationResult(TypedDict):
"""Return type for GrammarValidator.validate()."""
grammar_score: float
num_errors: int
is_fallback: bool
samples: list[str]
class KnowledgeValidationResult(TypedDict):
"""Return type for KnowledgeValidator.validate()."""
accuracy: float
correct: int
total: int
failed: list[dict[str, Any]]
class LanguageValidationResult(TypedDict):
"""Return type for LanguageValidator.validate()."""
is_garbage: bool
lang_confidence: float
valid_word_ratio: float
detected_language: str
samples: list[str]
class PerplexityValidationResult(TypedDict):
"""Return type for PerplexityValidator.validate()."""
perplexity: float
perplexity_normalized: float
samples: list[str]
class Validator(Protocol):
"""
Protocol for validation components.
Validators must implement a validate() method that takes a text-generating
model and training step, returning validation metrics.
This Protocol provides structural subtyping (duck typing with type hints),
allowing type checkers to verify validator compliance without requiring
inheritance.
Example:
>>> class CustomValidator:
... def validate(self, model: Any, step: int) -> dict[str, Any]:
... return {"score": 0.95}
...
>>> validator: Validator = CustomValidator() # Type-safe!
"""
def validate(self, model: Any, step: int) -> dict[str, Any]:
"""
Run validation on model at given training step.
Args:
model: Model with .generate_text() method
step: Current training step
Returns:
Dict with validation metrics (keys vary by validator):
- FastValidator: is_garbage, ascii_ratio, avg_length, repetition_ratio
- GrammarValidator: grammar_score, num_errors, is_fallback
- KnowledgeValidator: accuracy, correct, total, failed
"""
...
def validate_samples(self, samples: list[str], step: int) -> dict[str, Any]:
"""
Run validation on pre-generated samples.
This method allows sharing samples between multiple validators,
reducing generation cost.
Args:
samples: Pre-generated text samples
step: Current training step
Returns:
Dict with validation metrics (same as validate())
"""
...
class FastValidator:
"""
Heuristic-based fast validation for garbage detection.
Runs every 100 steps with <1s overhead. Catches obvious failures
like non-ASCII output, extremely short/long output, and repetition.
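
    Example (illustrative; thresholds come from the constants module, so the
    exact is_garbage verdict depends on their configured values):
        >>> fv = FastValidator(["The capital of France is"])
        >>> result = fv.validate_samples(["Paris is the capital of France."], step=100)
        >>> sorted(result.keys())
        ['ascii_ratio', 'avg_length', 'is_garbage', 'repetition_ratio', 'samples']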
"""
def __init__(self, test_prompts: list[str]) -> None:
"""
Initialize FastValidator.
Args:
test_prompts: List of prompts to test generation with
Raises:
ValueError: If test_prompts is empty
TypeError: If test_prompts contains non-string elements
"""
if not test_prompts:
raise ValueError("test_prompts cannot be empty")
if not all(isinstance(p, str) for p in test_prompts):
raise TypeError("All test_prompts must be strings")
self.test_prompts = test_prompts
self.sample_history: deque[tuple[int, list[str]]] = deque(maxlen=SAMPLE_HISTORY_SIZE)
@staticmethod
def _ngram_repetition(text: str) -> float:
"""
Calculate n-gram repetition ratio using memory-efficient generator.
Args:
text: Input text to analyze
Returns:
Repetition ratio (0.0 = no repetition, 1.0 = maximum repetition)
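
        Example (illustrative; the exact value depends on NGRAM_SIZE):
            With NGRAM_SIZE = 3, "aaaaaa" yields four trigrams that are all
            "aaa", so the ratio is 1 - 1/4 = 0.75; fully unique text yields 0.0.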
"""
if len(text) < NGRAM_SIZE:
return 0.0
# Generator avoids materializing full list in memory
ngrams = (text[i:i+NGRAM_SIZE] for i in range(len(text) - NGRAM_SIZE + 1))
counts = Counter(ngrams)
total = sum(counts.values())
unique = len(counts)
# Convert to repetition ratio (inverse of uniqueness)
return 1.0 - (unique / total) if total > 0 else 0.0
def validate(self, model: Any, step: int) -> FastValidationResult:
"""
Run fast heuristic validation.
Args:
model: Model to validate
step: Current training step
Returns:
FastValidationResult with keys:
- samples: list[str]
- is_garbage: bool
- ascii_ratio: float
- avg_length: float
- repetition_ratio: float
"""
samples = []
try:
# Generate with inference mode for performance
with torch.inference_mode():
for prompt in self.test_prompts:
try:
sample = model.generate_text(
prompt,
max_length=VALIDATION_MAX_LENGTH,
temperature=VALIDATION_TEMPERATURE
)
samples.append(sample)
except Exception as e:
logger.warning(
"Generation failed for prompt",
extra={"prompt": prompt, "error": str(e)}
)
samples.append("")
except Exception as e:
logger.error(
"FastValidator failed",
extra={"step": step, "error": str(e)}
)
return {
"samples": [],
"is_garbage": True,
"ascii_ratio": 0.0,
"avg_length": 0.0,
"repetition_ratio": FALLBACK_REPETITION_SCORE
}
# Delegate to validate_samples for actual validation logic
return self.validate_samples(samples, step)
def validate_samples(self, samples: list[str], step: int) -> FastValidationResult:
"""
Run fast heuristic validation on pre-generated samples.
This method allows sharing samples between multiple validators,
reducing generation cost by 50%.
Args:
samples: Pre-generated text samples
step: Current training step
Returns:
FastValidationResult with keys:
- samples: list[str]
- is_garbage: bool
- ascii_ratio: float
- avg_length: float
- repetition_ratio: float
"""
        # Heuristic checks
        total_chars = sum(len(s) for s in samples)
        ascii_chars = sum(sum(c.isascii() for c in s) for s in samples)
        ascii_ratio = ascii_chars / total_chars if total_chars > 0 else 0.0
        avg_length = total_chars / len(samples) if samples else 0.0
# Repetition detection (memory-efficient generator-based)
repetition_scores = []
for sample in samples:
if len(sample) < MIN_NGRAM_TEXT_LENGTH:
repetition_scores.append(FALLBACK_REPETITION_SCORE)
continue
# Use generator-based n-gram detection (O(1) memory)
rep_ratio = self._ngram_repetition(sample)
repetition_scores.append(rep_ratio)
repetition_ratio = sum(repetition_scores) / len(repetition_scores) if repetition_scores else 0.0
# Garbage detection criteria
is_garbage = (
ascii_ratio < MIN_ASCII_RATIO or
avg_length < MIN_SAMPLE_LENGTH or
repetition_ratio > MAX_REPETITION_RATIO
)
# Store sanitized samples
sanitized_samples = [sanitize(s, mode="pii") for s in samples]
self.sample_history.append((step, sanitized_samples))
return {
"samples": sanitized_samples,
"is_garbage": is_garbage,
"ascii_ratio": ascii_ratio,
"avg_length": avg_length,
"repetition_ratio": repetition_ratio
}
class GrammarValidator:
"""
LanguageTool-based grammar validation.
Runs every 200 steps with <2s overhead. Measures grammar quality
using external LanguageTool API with fallback to heuristics.
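
    Example (illustrative sketch; ``lt_client`` is a placeholder for a
    LanguageToolClient exposing ``check()`` / ``check_batch_async()``):
        >>> gv = GrammarValidator(client=lt_client, test_prompts=["Hello"])  # doctest: +SKIP
        >>> result = gv.validate_samples(["This sentence are wrong."], step=200)  # doctest: +SKIP
        >>> result["num_errors"] >= 1  # doctest: +SKIP
        True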
"""
def __init__(self, client: Any, test_prompts: list[str]) -> None:
"""
Initialize GrammarValidator.
Args:
client: LanguageToolClient instance
test_prompts: List of prompts to test generation with
Raises:
ValueError: If client is None or test_prompts is empty
TypeError: If test_prompts contains non-string elements
"""
if client is None:
raise ValueError("client cannot be None")
if not test_prompts:
raise ValueError("test_prompts cannot be empty")
if not all(isinstance(p, str) for p in test_prompts):
raise TypeError("All test_prompts must be strings")
self.client = client
self.test_prompts = test_prompts
# Inline history tracking (removed ValidationHistory abstraction)
self.grammar_scores: deque[float] = deque(maxlen=GRAMMAR_HISTORY_SIZE)
self.sample_outputs: deque[str] = deque(maxlen=SAMPLE_HISTORY_SIZE)
self.timestamps: deque[int] = deque(maxlen=TIMESTAMP_HISTORY_SIZE)
def validate(self, model: Any, step: int) -> GrammarValidationResult:
"""
Run grammar validation (sync wrapper for async validation).
This method wraps validate_async() to maintain backward compatibility
with PyTorch Lightning callbacks that expect synchronous validation.
For direct async usage, call validate_async() instead.
Args:
model: Model to validate
step: Current training step
Returns:
GrammarValidationResult with keys:
- grammar_score: float
- num_errors: int
- is_fallback: bool
- samples: list[str]
"""
        # Use async validation with asyncio.run()
        try:
            return asyncio.run(self.validate_async(model, step))
        except RuntimeError as e:
            # asyncio.run() raises "cannot be called from a running event loop"
            # when a loop is already active (e.g. inside notebook or callback contexts)
            if "running event loop" in str(e) or "already running" in str(e):
                logger.warning("Event loop already running, falling back to sync validation")
                return self._validate_sync(model, step)
            raise
def _validate_sync(self, model: Any, step: int) -> GrammarValidationResult:
"""
Synchronous fallback validation (used when event loop conflicts occur).
This is the original sequential implementation, kept as fallback.
Args:
model: Model to validate
step: Current training step
Returns:
GrammarValidationResult (same structure as validate())
"""
samples = []
try:
with torch.inference_mode():
for prompt in self.test_prompts:
try:
sample = model.generate_text(
prompt,
max_length=VALIDATION_MAX_LENGTH,
temperature=VALIDATION_TEMPERATURE
)
samples.append(sample)
except Exception as e:
logger.warning("Generation failed", extra={"error": str(e)})
samples.append("")
except Exception as e:
logger.error(
"GrammarValidator generation failed",
extra={"error": str(e)}
)
return {
"grammar_score": FALLBACK_GRAMMAR_SCORE,
"num_errors": FALLBACK_ERROR_COUNT,
"is_fallback": True,
"samples": []
}
# Delegate to validate_samples_sync for actual validation logic
return self.validate_samples_sync(samples, step)
def validate_samples_sync(self, samples: list[str], step: int) -> GrammarValidationResult:
"""
Run grammar validation on pre-generated samples (synchronous).
This method allows sharing samples between multiple validators,
reducing generation cost by 50%.
Args:
samples: Pre-generated text samples
step: Current training step
Returns:
GrammarValidationResult with keys:
- grammar_score: float
- num_errors: int
- is_fallback: bool
- samples: list[str]
"""
# Check grammar for all samples (SEQUENTIAL)
results = []
for sample in samples:
if not sample or len(sample) < MIN_SAMPLE_LENGTH:
results.append(GrammarResult(
grammar_score=FALLBACK_GRAMMAR_SCORE,
num_errors=0,
errors=[],
suggestions=[],
is_fallback=True
))
continue
result = self.client.check(sample)
results.append(result)
# Aggregate scores
avg_score = sum(r.grammar_score for r in results) / len(results) if results else 0.0
total_errors = sum(r.num_errors for r in results)
any_fallback = any(r.is_fallback for r in results)
# Update history
if samples:
sanitized = sanitize(samples[0], mode="pii")
self.grammar_scores.append(avg_score)
self.sample_outputs.append(sanitized)
self.timestamps.append(step)
return {
"grammar_score": avg_score,
"num_errors": total_errors,
"is_fallback": any_fallback,
"samples": [sanitize(s, mode="pii") for s in samples]
}
async def validate_async(self, model: Any, step: int) -> GrammarValidationResult:
"""
Run async grammar validation (NON-BLOCKING).
This is the key performance optimization: all grammar checks run
in parallel instead of sequentially, reducing validation time from
2.5s to 0.5s (5x speedup).
Args:
model: Model to validate
step: Current training step
        Returns:
            GrammarValidationResult with keys:
            - grammar_score: float
            - num_errors: int
            - is_fallback: bool
            - samples: list[str]
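
        Example (illustrative; ``grammar_validator`` and ``model`` are placeholders,
        and the call must be awaited from an existing event loop):
            >>> result = await grammar_validator.validate_async(model, step=200)  # doctest: +SKIP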
"""
samples = []
try:
# Generate samples (still synchronous, but fast)
with torch.inference_mode():
for prompt in self.test_prompts:
try:
sample = model.generate_text(
prompt,
max_length=VALIDATION_MAX_LENGTH,
temperature=VALIDATION_TEMPERATURE
)
samples.append(sample)
except Exception as e:
logger.warning("Generation failed", extra={"error": str(e)})
samples.append("")
except Exception as e:
logger.error(
"GrammarValidator generation failed",
extra={"error": str(e)}
)
return {
"grammar_score": FALLBACK_GRAMMAR_SCORE,
"num_errors": FALLBACK_ERROR_COUNT,
"is_fallback": True,
"samples": []
}
# Delegate to validate_samples_async for actual validation logic
return await self.validate_samples_async(samples, step)
async def validate_samples_async(self, samples: list[str], step: int) -> GrammarValidationResult:
"""
Run async grammar validation on pre-generated samples (NON-BLOCKING).
This method allows sharing samples between multiple validators,
reducing generation cost by 50%.
Args:
samples: Pre-generated text samples
step: Current training step
Returns:
GrammarValidationResult with keys:
- grammar_score: float
- num_errors: int
- is_fallback: bool
- samples: list[str]
"""
# Filter out empty/too-short samples
valid_samples = [s for s in samples if s and len(s) >= MIN_SAMPLE_LENGTH]
# ASYNC: Check grammar in parallel (KEY OPTIMIZATION)
if hasattr(self.client, 'check_batch_async'):
# Use async client for parallel checking
results = await self.client.check_batch_async(valid_samples)
else:
# Fallback to sync client (sequential)
logger.warning("Async client not available, falling back to sync")
results = [self.client.check(s) for s in valid_samples]
        # Aggregate scores
        avg_score = sum(r.grammar_score for r in results) / len(results) if results else 0.0
        total_errors = sum(r.num_errors for r in results)
        # No checkable samples means the aggregate is effectively a fallback result
        any_fallback = any(r.is_fallback for r in results) if results else True
# Update history
if samples:
sanitized = sanitize(samples[0], mode="pii")
self.grammar_scores.append(avg_score)
self.sample_outputs.append(sanitized)
self.timestamps.append(step)
return {
"grammar_score": avg_score,
"num_errors": total_errors,
"is_fallback": any_fallback,
"samples": [sanitize(s, mode="pii") for s in samples]
}
def validate_samples(self, samples: list[str], step: int) -> GrammarValidationResult:
"""
Synchronous wrapper for validate_samples_async (for CombinedValidationCallback).
This method provides a synchronous interface for validating pre-generated
samples, allowing the CombinedValidationCallback to share samples between
validators.
Args:
samples: Pre-generated text samples
step: Current training step
Returns:
GrammarValidationResult with same structure as validate()
"""
        try:
            return asyncio.run(self.validate_samples_async(samples, step))
        except RuntimeError as e:
            # asyncio.run() raises "cannot be called from a running event loop"
            # when a loop is already active; fall back to the sequential path
            if "running event loop" in str(e) or "already running" in str(e):
                logger.warning("Event loop already running, falling back to sync validation")
                return self.validate_samples_sync(samples, step)
            raise
def get_trend(self, window: int = TREND_ANALYSIS_WINDOW) -> str:
"""
Detect improving/degrading trend.
Args:
window: Number of recent scores to analyze
Returns:
"improving", "degrading", "stable", or "insufficient_data"
"""
        if len(self.grammar_scores) < window:
            return "insufficient_data"
        recent = list(self.grammar_scores)[-window:]
        non_decreasing = all(recent[i] >= recent[i - 1] for i in range(1, len(recent)))
        non_increasing = all(recent[i] <= recent[i - 1] for i in range(1, len(recent)))
        if non_decreasing and non_increasing:
            # A flat series is stable, not improving
            return "stable"
        if non_decreasing:
            return "improving"
        if non_increasing:
            return "degrading"
        return "stable"
class KnowledgeValidator:
"""
Factual accuracy validation using knowledge base.
Runs post-training only (~10s). Tests model on 10 factual questions
to verify knowledge retention.
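
    Example (illustrative question format; ``model`` is a placeholder):
        >>> questions = [{"q": "What is the capital of France?", "a": ["paris"]}]
        >>> kv = KnowledgeValidator(questions)
        >>> result = kv.validate(model)  # doctest: +SKIP
        >>> result["total"]  # doctest: +SKIP
        1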
"""
def __init__(self, questions: list[dict[str, Any]]) -> None:
"""
Initialize KnowledgeValidator.
Args:
questions: List of {"q": str, "a": list[str]} question/answer pairs
Raises:
ValueError: If questions list is None or has invalid structure
TypeError: If questions is not a list
"""
if questions is None:
raise ValueError("questions cannot be None")
if not isinstance(questions, list):
raise TypeError("questions must be a list")
# Validate structure of questions (each must have 'q' and 'a' keys)
for i, q in enumerate(questions):
if not isinstance(q, dict):
raise TypeError(f"Question at index {i} must be a dict")
if 'q' not in q or 'a' not in q:
raise ValueError(f"Question at index {i} must have 'q' and 'a' keys")
if not isinstance(q['q'], str):
raise TypeError(f"Question 'q' at index {i} must be a string")
if not isinstance(q['a'], list):
raise TypeError(f"Question 'a' at index {i} must be a list")
self.questions = questions
def validate(self, model: Any, step: int = -1) -> KnowledgeValidationResult:
"""
Run factual accuracy validation.
Args:
model: Model to validate
step: Training step (default -1 for post-training)
Returns:
KnowledgeValidationResult with keys:
- accuracy: float
- correct: int
- total: int
- failed: list[dict[str, Any]]
"""
correct = 0
failed = []
try:
with torch.inference_mode():
for item in self.questions:
question = item['q']
valid_answers = [a.lower() for a in item['a']]
try:
output = model.generate_text(
question,
max_length=KNOWLEDGE_MAX_LENGTH,
temperature=KNOWLEDGE_TEMPERATURE
)
output_lower = output.lower()
# Fuzzy matching: check if any valid answer in output
is_correct = any(ans in output_lower for ans in valid_answers)
if is_correct:
correct += 1
else:
failed.append({
'question': question,
'expected': item['a'],
'got': output[:ERROR_LOG_TRUNCATE_LENGTH]
})
except Exception as e:
logger.warning(
"Knowledge validation failed",
extra={"question": question, "error": str(e)}
)
failed.append({
'question': question,
'expected': item['a'],
'got': f"ERROR: {str(e)}"
})
except Exception as e:
logger.error(
"KnowledgeValidator failed",
extra={"error": str(e)}
)
            return {
                "accuracy": 0.0,
                "correct": 0,
                "total": len(self.questions),
                "failed": [
                    {"question": q["q"], "expected": q["a"], "got": f"ERROR: {str(e)}"}
                    for q in self.questions
                ]
            }
return {
"accuracy": correct / len(self.questions) if self.questions else 0.0,
"correct": correct,
"total": len(self.questions),
"failed": failed
}
def validate_samples(self, samples: list[str], step: int) -> KnowledgeValidationResult:
"""
Not applicable for KnowledgeValidator (uses its own Q&A format).
This method exists for Protocol compliance but is not supported.
Use validate() instead.
Args:
samples: Unused (KnowledgeValidator generates from questions)
step: Training step
Raises:
NotImplementedError: KnowledgeValidator doesn't support validate_samples
Note:
KnowledgeValidator doesn't use pre-generated samples since it
tests factual knowledge with specific Q&A pairs.
"""
raise NotImplementedError(
"KnowledgeValidator doesn't support validate_samples. "
"Use validate(model, step) instead."
)
class LanguageValidator:
"""
Language detection and word validity validation.
Validates text is English with real words using:
- langdetect for language detection
- NLTK words corpus for English word validation
- Unicode script detection for multilingual text
Runs every 100 steps with <1s overhead.
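
    Example (illustrative; the first call lazily loads the NLTK words corpus
    and langdetect, so results depend on those packages being installed):
        >>> lv = LanguageValidator(["Tell me about dogs"])
        >>> result = lv.validate_samples(["Dogs are loyal domestic animals."], step=100)  # doctest: +SKIP
        >>> result["detected_language"]  # doctest: +SKIP
        'en'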
"""
def __init__(self, test_prompts: list[str]) -> None:
"""
Initialize LanguageValidator.
Args:
test_prompts: List of prompts to test generation with
Raises:
ValueError: If test_prompts is empty
TypeError: If test_prompts contains non-string elements
"""
if not test_prompts:
raise ValueError("test_prompts cannot be empty")
if not all(isinstance(p, str) for p in test_prompts):
raise TypeError("All test_prompts must be strings")
self.test_prompts = test_prompts
# Load English words corpus (lazy load to avoid startup cost)
self._english_words = None
@property
def english_words(self):
"""Lazy-load NLTK words corpus."""
if self._english_words is None:
            try:
                import os
                import nltk
                from nltk.corpus import words
                # Search the user's local NLTK data directory first (portable default)
                nltk.data.path.insert(0, os.path.expanduser("~/nltk_data"))
                self._english_words = set(w.lower() for w in words.words())
except Exception as e:
logger.warning(
"NLTK words corpus not available, using fallback",
extra={"error": str(e)}
)
# Fallback to small set of common English words
self._english_words = set([
'the', 'be', 'to', 'of', 'and', 'a', 'in', 'that', 'have', 'i',
'it', 'for', 'not', 'on', 'with', 'he', 'as', 'you', 'do', 'at'
])
return self._english_words
@staticmethod
def detect_language_with_confidence(text: str) -> tuple[str, float]:
"""
Detect language and return confidence score.
Args:
text: Input text to analyze
        Returns:
            Tuple of (language_code, english_confidence), e.g. ('en', 0.95).
            The confidence is the probability assigned to English; it is 0.0
            when the detected language is not English or detection fails.
"""
try:
import langdetect
from langdetect import DetectorFactory
# Ensure reproducible results
DetectorFactory.seed = 0
# Detect language
lang = langdetect.detect(text)
# Get probability distribution
probs = langdetect.detect_langs(text)
# Find English confidence
en_confidence = next(
(p.prob for p in probs if p.lang == 'en'),
0.0
)
return lang, en_confidence if lang == 'en' else 0.0
except Exception as e:
logger.debug(
"Language detection failed",
extra={"error": str(e)}
)
return 'unknown', 0.0
@staticmethod
def detect_multilingual(text: str) -> dict[str, Any]:
"""
Detect mixed-language text (common gaming strategy).
Args:
text: Input text to analyze
Returns:
Dict with keys:
- is_multilingual: bool
- primary_script: str
- script_ratios: dict[str, float]
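
        Example (illustrative):
            >>> LanguageValidator.detect_multilingual("hello привет")["is_multilingual"]
            True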
"""
# Unicode script detection
scripts = {
'latin': 0,
'cyrillic': 0,
'arabic': 0,
'cjk': 0,
'greek': 0,
}
for char in text:
if 'a' <= char.lower() <= 'z':
scripts['latin'] += 1
elif '\u0400' <= char <= '\u04FF':
scripts['cyrillic'] += 1
elif '\u0600' <= char <= '\u06FF':
scripts['arabic'] += 1
elif '\u4E00' <= char <= '\u9FFF':
scripts['cjk'] += 1
elif '\u0370' <= char <= '\u03FF':
scripts['greek'] += 1
total_letters = sum(scripts.values())
if total_letters == 0:
return {
'is_multilingual': False,
'primary_script': 'none',
'script_ratios': {}
}
# Normalize to percentages
script_ratios = {k: v/total_letters for k, v in scripts.items()}
# Find dominant script
primary_script = max(script_ratios, key=script_ratios.get)
# Check if multiple scripts present
num_scripts = sum(1 for ratio in script_ratios.values() if ratio > 0.05)
return {
'is_multilingual': num_scripts > 1,
'primary_script': primary_script,
'script_ratios': script_ratios,
}
def validate(self, model: Any, step: int) -> LanguageValidationResult:
"""
Run language detection and word validity validation.
Args:
model: Model to validate
step: Current training step
Returns:
LanguageValidationResult with keys:
- is_garbage: bool
- lang_confidence: float
- valid_word_ratio: float
- detected_language: str
- samples: list[str]
"""
samples = []
try:
with torch.inference_mode():
for prompt in self.test_prompts:
try:
sample = model.generate_text(
prompt,
max_length=VALIDATION_MAX_LENGTH,
temperature=VALIDATION_TEMPERATURE
)
samples.append(sample)
except Exception as e:
logger.warning(
"Generation failed for prompt",
extra={"prompt": prompt, "error": str(e)}
)
samples.append("")
except Exception as e:
logger.error(
"LanguageValidator failed",
extra={"step": step, "error": str(e)}
)
return {
"is_garbage": True,
"lang_confidence": 0.0,
"valid_word_ratio": 0.0,
"detected_language": "unknown",
"samples": []
}
# Delegate to validate_samples for actual validation logic
return self.validate_samples(samples, step)
def validate_samples(self, samples: list[str], step: int) -> LanguageValidationResult:
"""
Run language validation on pre-generated samples.
This method allows sharing samples between multiple validators,
reducing generation cost.
Args:
samples: Pre-generated text samples
step: Current training step
Returns:
LanguageValidationResult with keys:
- is_garbage: bool
- lang_confidence: float
- valid_word_ratio: float
- detected_language: str
- samples: list[str]
"""
if not samples:
return {
"is_garbage": True,
"lang_confidence": 0.0,
"valid_word_ratio": 0.0,
"detected_language": "unknown",
"samples": []
}
# Aggregate language detection across all samples
lang_confidences = []
detected_langs = []
valid_word_ratios = []
for sample in samples:
if not sample or len(sample) < MIN_SAMPLE_LENGTH:
lang_confidences.append(0.0)
detected_langs.append('unknown')
valid_word_ratios.append(0.0)
continue
# Language detection
lang, confidence = self.detect_language_with_confidence(sample)
lang_confidences.append(confidence)
detected_langs.append(lang)
# Word validity check
tokens = sample.lower().split()
clean_tokens = [
t.strip('.,!?;:()[]{}"\'-')
for t in tokens
if t.strip('.,!?;:()[]{}"\'-')
]
if clean_tokens:
valid_count = sum(
1 for t in clean_tokens
if t in self.english_words
)
valid_ratio = valid_count / len(clean_tokens)
else:
valid_ratio = 0.0
valid_word_ratios.append(valid_ratio)
# Aggregate scores
avg_lang_confidence = sum(lang_confidences) / len(lang_confidences)
avg_valid_word_ratio = sum(valid_word_ratios) / len(valid_word_ratios)
        # Most common detected language (Counter is already imported at module level)
        lang_counts = Counter(detected_langs)
        primary_lang = lang_counts.most_common(1)[0][0]
# Check for multilingual text in any sample
any_multilingual = any(
self.detect_multilingual(s)['is_multilingual']
for s in samples
if s and len(s) >= MIN_SAMPLE_LENGTH
)
# Garbage detection criteria
is_garbage = (
primary_lang != 'en' or
avg_lang_confidence < 0.8 or
avg_valid_word_ratio < 0.7 or
any_multilingual
)
# Sanitize samples
sanitized_samples = [sanitize(s, mode="pii") for s in samples]
return {
"is_garbage": is_garbage,
"lang_confidence": avg_lang_confidence,
"valid_word_ratio": avg_valid_word_ratio,
"detected_language": primary_lang,
"samples": sanitized_samples
}
class PerplexityValidator:
"""
Autoregressive perplexity validation using DistilGPT-2.
Measures language fluency using pre-trained transformer model.
Uses mixed precision (AMP) for 2x speedup.
Runs every 100 steps with ~500ms overhead (with batching).
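
    Example (illustrative sketch; the first call downloads DistilGPT-2 from
    HuggingFace, so it is skipped here):
        >>> pv = PerplexityValidator(["Once upon a time"])
        >>> result = pv.validate_samples(["Once upon a time there was a king."], step=100)  # doctest: +SKIP
        >>> 0.0 <= result["perplexity_normalized"] <= 1.0  # doctest: +SKIP
        True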
"""
def __init__(self, test_prompts: list[str], model_name: str = "distilgpt2") -> None:
"""
Initialize PerplexityValidator.
Args:
test_prompts: List of prompts to test generation with
model_name: HuggingFace model name (default: "distilgpt2")
Raises:
ValueError: If test_prompts is empty
TypeError: If test_prompts contains non-string elements
"""
if not test_prompts:
raise ValueError("test_prompts cannot be empty")
if not all(isinstance(p, str) for p in test_prompts):
raise TypeError("All test_prompts must be strings")
self.test_prompts = test_prompts
self.model_name = model_name
        # Lazy-load model (avoid startup cost)
        self._model = None
        self._tokenizer = None
        # Fall back to CPU when CUDA is unavailable (AMP is only enabled on CUDA)
        self._device = "cuda" if torch.cuda.is_available() else "cpu"
@property
def model(self):
"""Lazy-load DistilGPT-2 model."""
if self._model is None:
try:
from transformers import AutoModelForCausalLM
                self._model = AutoModelForCausalLM.from_pretrained(
                    self.model_name
                ).to(self._device)
self._model.eval()
except Exception as e:
logger.error(
"Failed to load perplexity model",
extra={"model": self.model_name, "error": str(e)}
)
raise
return self._model
@property
def tokenizer(self):
"""Lazy-load tokenizer."""
if self._tokenizer is None:
try:
from transformers import AutoTokenizer
self._tokenizer = AutoTokenizer.from_pretrained(self.model_name)
except Exception as e:
logger.error(
"Failed to load tokenizer",
extra={"model": self.model_name, "error": str(e)}
)
raise
return self._tokenizer
def validate(self, model: Any, step: int) -> PerplexityValidationResult:
"""
Run perplexity validation.
Args:
model: Model to validate
step: Current training step
Returns:
PerplexityValidationResult with keys:
- perplexity: float
- perplexity_normalized: float (0-1 score for reward)
- samples: list[str]
"""
samples = []
try:
with torch.inference_mode():
for prompt in self.test_prompts:
try:
sample = model.generate_text(
prompt,
max_length=VALIDATION_MAX_LENGTH,
temperature=VALIDATION_TEMPERATURE
)
samples.append(sample)
except Exception as e:
logger.warning(
"Generation failed for prompt",
extra={"prompt": prompt, "error": str(e)}
)
samples.append("")
except Exception as e:
logger.error(
"PerplexityValidator generation failed",
extra={"step": step, "error": str(e)}
)
return {
"perplexity": float('inf'),
"perplexity_normalized": 0.0,
"samples": []
}
# Delegate to validate_samples for actual validation logic
return self.validate_samples(samples, step)
def validate_samples(self, samples: list[str], step: int) -> PerplexityValidationResult:
"""
Run perplexity validation on pre-generated samples.
This method allows sharing samples between multiple validators,
reducing generation cost.
Args:
samples: Pre-generated text samples
step: Current training step
Returns:
PerplexityValidationResult with keys:
- perplexity: float
- perplexity_normalized: float (0-1 score for reward)
- samples: list[str]
"""
if not samples:
return {
"perplexity": float('inf'),
"perplexity_normalized": 0.0,
"samples": []
}
# Filter valid samples
valid_samples = [
s for s in samples
if s and len(s) >= MIN_SAMPLE_LENGTH
]
if not valid_samples:
return {
"perplexity": float('inf'),
"perplexity_normalized": 0.0,
"samples": [sanitize(s, mode="pii") for s in samples]
}
# Compute perplexity for each sample
perplexities = []
try:
for sample in valid_samples:
# Tokenize
                encodings = self.tokenizer(
                    sample,
                    return_tensors='pt',
                    truncation=True,
                    max_length=512
                ).to(self._device)
                # Compute cross-entropy with mixed precision (autocast disabled on CPU)
                with torch.no_grad(), torch.amp.autocast(self._device, enabled=(self._device == "cuda")):
outputs = self.model(**encodings, labels=encodings.input_ids)
ce = outputs.loss.item()
# Perplexity = exp(cross_entropy)
perplexity = torch.exp(torch.tensor(ce)).item()
perplexities.append(perplexity)
except Exception as e:
logger.error(
"Perplexity computation failed",
extra={"error": str(e)}
)
return {
"perplexity": float('inf'),
"perplexity_normalized": 0.0,
"samples": [sanitize(s, mode="pii") for s in samples]
}
# Aggregate
avg_perplexity = sum(perplexities) / len(perplexities)
# Normalize to [0, 1] for reward (lower perplexity = better)
# exp(-perp/10): perp=0 → 1.0, perp=10 → 0.37, perp=50 → 0.007
import math
normalized_score = math.exp(-avg_perplexity / 10.0)
return {
"perplexity": avg_perplexity,
"perplexity_normalized": normalized_score,
"samples": [sanitize(s, mode="pii") for s in samples]
}