| """ |
| Validator classes for text generation quality assessment. |
| |
Provides FastValidator (heuristics), GrammarValidator (LanguageTool),
KnowledgeValidator (factual accuracy), LanguageValidator (language detection
and word validity), and PerplexityValidator (fluency via DistilGPT-2), with
security hardening and performance optimizations.
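
A typical training callback generates one batch of samples and shares it across
validators through validate_samples(). A minimal sketch (names are illustrative,
not part of this module's public API):

    fast = FastValidator(test_prompts)
    result = fast.validate(model, step)          # generates samples once
    grammar = grammar_validator.validate_samples(result["samples"], step)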
| """ |
|
|
| __all__ = [ |
| "FastValidator", |
| "GrammarValidator", |
| "KnowledgeValidator", |
| "LanguageValidator", |
| "PerplexityValidator", |
| "Validator", |
| "FastValidationResult", |
| "GrammarValidationResult", |
| "KnowledgeValidationResult", |
| "LanguageValidationResult", |
| "PerplexityValidationResult", |
| ] |
|
|
| import time |
| import asyncio |
| import logging |
| from typing import Any, Protocol, TypedDict, TYPE_CHECKING |
| from collections import deque, Counter |
| from dataclasses import dataclass, field |
| import torch |
|
|
| |
| if TYPE_CHECKING: |
| from .grammar_checker import GrammarResult |
| else: |
| try: |
| from .grammar_checker import GrammarResult |
| except ImportError: |
| |
| @dataclass |
| class GrammarResult: |
| grammar_score: float |
| num_errors: int |
| errors: list[dict] = field(default_factory=list) |
| suggestions: list[list[str]] = field(default_factory=list) |
| is_fallback: bool = False |
|
|
| |
| from .sanitizer import sanitize |
|
|
| |
| from .constants import ( |
| MIN_ASCII_RATIO, |
| MAX_REPETITION_RATIO, |
| MIN_SAMPLE_LENGTH, |
| VALIDATION_MAX_LENGTH, |
| VALIDATION_TEMPERATURE, |
| KNOWLEDGE_MAX_LENGTH, |
| KNOWLEDGE_TEMPERATURE, |
| SAMPLE_HISTORY_SIZE, |
| GRAMMAR_HISTORY_SIZE, |
| TIMESTAMP_HISTORY_SIZE, |
| TREND_ANALYSIS_WINDOW, |
| NGRAM_SIZE, |
| MIN_NGRAM_TEXT_LENGTH, |
| FALLBACK_REPETITION_SCORE, |
| FALLBACK_GRAMMAR_SCORE, |
| FALLBACK_ERROR_COUNT, |
| ERROR_LOG_TRUNCATE_LENGTH, |
| ) |
|
|
| logger = logging.getLogger(__name__) |
|
|
| class FastValidationResult(TypedDict): |
| """Return type for FastValidator.validate().""" |
| samples: list[str] |
| is_garbage: bool |
| ascii_ratio: float |
| avg_length: float |
| repetition_ratio: float |
|
|
| class GrammarValidationResult(TypedDict): |
| """Return type for GrammarValidator.validate().""" |
| grammar_score: float |
| num_errors: int |
| is_fallback: bool |
| samples: list[str] |
|
|
| class KnowledgeValidationResult(TypedDict): |
| """Return type for KnowledgeValidator.validate().""" |
| accuracy: float |
| correct: int |
| total: int |
| failed: list[dict[str, Any]] |
|
|
| class LanguageValidationResult(TypedDict): |
| """Return type for LanguageValidator.validate().""" |
| is_garbage: bool |
| lang_confidence: float |
| valid_word_ratio: float |
| detected_language: str |
| samples: list[str] |
|
|
| class PerplexityValidationResult(TypedDict): |
| """Return type for PerplexityValidator.validate().""" |
| perplexity: float |
| perplexity_normalized: float |
| samples: list[str] |
|
|
| class Validator(Protocol): |
| """ |
| Protocol for validation components. |
| |
| Validators must implement a validate() method that takes a text-generating |
| model and training step, returning validation metrics. |
| |
| This Protocol provides structural subtyping (duck typing with type hints), |
| allowing type checkers to verify validator compliance without requiring |
| inheritance. |
| |
| Example: |
        >>> class CustomValidator:
        ...     def validate(self, model: Any, step: int) -> dict[str, Any]:
        ...         return {"score": 0.95}
        ...     def validate_samples(self, samples: list[str], step: int) -> dict[str, Any]:
        ...         return {"score": 0.95}
        ...
        >>> validator: Validator = CustomValidator()  # Type-safe: both Protocol methods are present
| """ |
|
|
| def validate(self, model: Any, step: int) -> dict[str, Any]: |
| """ |
| Run validation on model at given training step. |
| |
| Args: |
| model: Model with .generate_text() method |
| step: Current training step |
| |
| Returns: |
| Dict with validation metrics (keys vary by validator): |
                - FastValidator: is_garbage, ascii_ratio, avg_length, repetition_ratio
                - GrammarValidator: grammar_score, num_errors, is_fallback
                - KnowledgeValidator: accuracy, correct, total, failed
                - LanguageValidator: is_garbage, lang_confidence, valid_word_ratio, detected_language
                - PerplexityValidator: perplexity, perplexity_normalized
| """ |
| ... |
|
|
| def validate_samples(self, samples: list[str], step: int) -> dict[str, Any]: |
| """ |
| Run validation on pre-generated samples. |
| |
| This method allows sharing samples between multiple validators, |
| reducing generation cost. |
| |
| Args: |
| samples: Pre-generated text samples |
| step: Current training step |
| |
| Returns: |
| Dict with validation metrics (same as validate()) |
| """ |
| ... |
|
|
|
|
| class FastValidator: |
| """ |
| Heuristic-based fast validation for garbage detection. |
| |
    Runs every 100 steps with <1s overhead. Catches obvious failures such as
    non-ASCII output, extremely short samples, and highly repetitive text.
| """ |
|
|
| def __init__(self, test_prompts: list[str]) -> None: |
| """ |
| Initialize FastValidator. |
| |
| Args: |
| test_prompts: List of prompts to test generation with |
| |
| Raises: |
| ValueError: If test_prompts is empty |
| TypeError: If test_prompts contains non-string elements |
| """ |
| if not test_prompts: |
| raise ValueError("test_prompts cannot be empty") |
| if not all(isinstance(p, str) for p in test_prompts): |
| raise TypeError("All test_prompts must be strings") |
|
|
| self.test_prompts = test_prompts |
| self.sample_history: deque[tuple[int, list[str]]] = deque(maxlen=SAMPLE_HISTORY_SIZE) |
|
|
| @staticmethod |
| def _ngram_repetition(text: str) -> float: |
| """ |
| Calculate n-gram repetition ratio using memory-efficient generator. |
| |
| Args: |
| text: Input text to analyze |
| |
| Returns: |
| Repetition ratio (0.0 = no repetition, 1.0 = maximum repetition) |
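
        Example (illustrative, assuming NGRAM_SIZE == 3; the real value comes
        from constants): the text "abababab" yields the 3-grams
        aba, bab, aba, bab, aba, bab (6 total, 2 unique), so the ratio is
        1 - 2/6 ≈ 0.67.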
| """ |
| if len(text) < NGRAM_SIZE: |
| return 0.0 |
|
|
| |
| ngrams = (text[i:i+NGRAM_SIZE] for i in range(len(text) - NGRAM_SIZE + 1)) |
| counts = Counter(ngrams) |
| total = sum(counts.values()) |
| unique = len(counts) |
|
|
| |
| return 1.0 - (unique / total) if total > 0 else 0.0 |
|
|
| def validate(self, model: Any, step: int) -> FastValidationResult: |
| """ |
| Run fast heuristic validation. |
| |
| Args: |
| model: Model to validate |
| step: Current training step |
| |
| Returns: |
| FastValidationResult with keys: |
| - samples: list[str] |
| - is_garbage: bool |
| - ascii_ratio: float |
| - avg_length: float |
| - repetition_ratio: float |
| """ |
| samples = [] |
|
|
| try: |
| |
| with torch.inference_mode(): |
| for prompt in self.test_prompts: |
| try: |
| sample = model.generate_text( |
| prompt, |
| max_length=VALIDATION_MAX_LENGTH, |
| temperature=VALIDATION_TEMPERATURE |
| ) |
| samples.append(sample) |
| except Exception as e: |
| logger.warning( |
| "Generation failed for prompt", |
| extra={"prompt": prompt, "error": str(e)} |
| ) |
| samples.append("") |
|
|
| except Exception as e: |
| logger.error( |
| "FastValidator failed", |
| extra={"step": step, "error": str(e)} |
| ) |
| return { |
| "samples": [], |
| "is_garbage": True, |
| "ascii_ratio": 0.0, |
| "avg_length": 0.0, |
| "repetition_ratio": FALLBACK_REPETITION_SCORE |
| } |
|
|
| |
| return self.validate_samples(samples, step) |
|
|
| def validate_samples(self, samples: list[str], step: int) -> FastValidationResult: |
| """ |
| Run fast heuristic validation on pre-generated samples. |
| |
| This method allows sharing samples between multiple validators, |
| reducing generation cost by 50%. |
| |
| Args: |
| samples: Pre-generated text samples |
| step: Current training step |
| |
| Returns: |
| FastValidationResult with keys: |
| - samples: list[str] |
| - is_garbage: bool |
| - ascii_ratio: float |
| - avg_length: float |
| - repetition_ratio: float |
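
        Example (illustrative; the garbage thresholds come from constants, so
        individual flags may vary):
            >>> validator = FastValidator(["The capital of France is"])
            >>> result = validator.validate_samples(
            ...     ["The capital of France is Paris, a city on the Seine."], step=100
            ... )
            >>> sorted(result.keys())
            ['ascii_ratio', 'avg_length', 'is_garbage', 'repetition_ratio', 'samples']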
| """ |
| |
| total_chars = sum(len(s) for s in samples) |
| ascii_chars = sum(sum(c.isascii() for c in s) for s in samples) |
| ascii_ratio = ascii_chars / total_chars if total_chars > 0 else 0.0 |
|
|
| avg_length = sum(len(s) for s in samples) / len(samples) if samples else 0 |
|
|
| |
| repetition_scores = [] |
| for sample in samples: |
| if len(sample) < MIN_NGRAM_TEXT_LENGTH: |
| repetition_scores.append(FALLBACK_REPETITION_SCORE) |
| continue |
| |
| rep_ratio = self._ngram_repetition(sample) |
| repetition_scores.append(rep_ratio) |
|
|
| repetition_ratio = sum(repetition_scores) / len(repetition_scores) if repetition_scores else 0.0 |
|
|
| |
| is_garbage = ( |
| ascii_ratio < MIN_ASCII_RATIO or |
| avg_length < MIN_SAMPLE_LENGTH or |
| repetition_ratio > MAX_REPETITION_RATIO |
| ) |
|
|
| |
| sanitized_samples = [sanitize(s, mode="pii") for s in samples] |
| self.sample_history.append((step, sanitized_samples)) |
|
|
| return { |
| "samples": sanitized_samples, |
| "is_garbage": is_garbage, |
| "ascii_ratio": ascii_ratio, |
| "avg_length": avg_length, |
| "repetition_ratio": repetition_ratio |
| } |
|
|
| class GrammarValidator: |
| """ |
| LanguageTool-based grammar validation. |
| |
| Runs every 200 steps with <2s overhead. Measures grammar quality |
| using external LanguageTool API with fallback to heuristics. |
| """ |
|
|
| def __init__(self, client: Any, test_prompts: list[str]) -> None: |
| """ |
| Initialize GrammarValidator. |
| |
| Args: |
| client: LanguageToolClient instance |
| test_prompts: List of prompts to test generation with |
| |
| Raises: |
| ValueError: If client is None or test_prompts is empty |
| TypeError: If test_prompts contains non-string elements |
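
        Example (illustrative stub in place of a real LanguageToolClient):
            >>> class StubClient:
            ...     def check(self, text: str) -> GrammarResult:
            ...         return GrammarResult(grammar_score=1.0, num_errors=0)
            ...
            >>> validator = GrammarValidator(StubClient(), ["Write about dogs."])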
| """ |
| if client is None: |
| raise ValueError("client cannot be None") |
| if not test_prompts: |
| raise ValueError("test_prompts cannot be empty") |
| if not all(isinstance(p, str) for p in test_prompts): |
| raise TypeError("All test_prompts must be strings") |
|
|
| self.client = client |
| self.test_prompts = test_prompts |
| |
| self.grammar_scores: deque[float] = deque(maxlen=GRAMMAR_HISTORY_SIZE) |
| self.sample_outputs: deque[str] = deque(maxlen=SAMPLE_HISTORY_SIZE) |
| self.timestamps: deque[int] = deque(maxlen=TIMESTAMP_HISTORY_SIZE) |
|
|
| def validate(self, model: Any, step: int) -> GrammarValidationResult: |
| """ |
| Run grammar validation (sync wrapper for async validation). |
| |
| This method wraps validate_async() to maintain backward compatibility |
| with PyTorch Lightning callbacks that expect synchronous validation. |
| |
| For direct async usage, call validate_async() instead. |
| |
| Args: |
| model: Model to validate |
| step: Current training step |
| |
| Returns: |
| GrammarValidationResult with keys: |
| - grammar_score: float |
| - num_errors: int |
| - is_fallback: bool |
| - samples: list[str] |
| """ |
| |
| try: |
| return asyncio.run(self.validate_async(model, step)) |
| except RuntimeError as e: |
| |
| if "already running" in str(e): |
| logger.warning("Event loop already running, falling back to sync validation") |
| return self._validate_sync(model, step) |
| raise |
|
|
| def _validate_sync(self, model: Any, step: int) -> GrammarValidationResult: |
| """ |
| Synchronous fallback validation (used when event loop conflicts occur). |
| |
| This is the original sequential implementation, kept as fallback. |
| |
| Args: |
| model: Model to validate |
| step: Current training step |
| |
| Returns: |
| GrammarValidationResult (same structure as validate()) |
| """ |
| samples = [] |
|
|
| try: |
| with torch.inference_mode(): |
| for prompt in self.test_prompts: |
| try: |
| sample = model.generate_text( |
| prompt, |
| max_length=VALIDATION_MAX_LENGTH, |
| temperature=VALIDATION_TEMPERATURE |
| ) |
| samples.append(sample) |
| except Exception as e: |
| logger.warning("Generation failed", extra={"error": str(e)}) |
| samples.append("") |
|
|
| except Exception as e: |
| logger.error( |
| "GrammarValidator generation failed", |
| extra={"error": str(e)} |
| ) |
| return { |
| "grammar_score": FALLBACK_GRAMMAR_SCORE, |
| "num_errors": FALLBACK_ERROR_COUNT, |
| "is_fallback": True, |
| "samples": [] |
| } |
|
|
| |
| return self.validate_samples_sync(samples, step) |
|
|
| def validate_samples_sync(self, samples: list[str], step: int) -> GrammarValidationResult: |
| """ |
| Run grammar validation on pre-generated samples (synchronous). |
| |
| This method allows sharing samples between multiple validators, |
| reducing generation cost by 50%. |
| |
| Args: |
| samples: Pre-generated text samples |
| step: Current training step |
| |
| Returns: |
| GrammarValidationResult with keys: |
| - grammar_score: float |
| - num_errors: int |
| - is_fallback: bool |
| - samples: list[str] |
| """ |
| |
| results = [] |
| for sample in samples: |
| if not sample or len(sample) < MIN_SAMPLE_LENGTH: |
| results.append(GrammarResult( |
| grammar_score=FALLBACK_GRAMMAR_SCORE, |
| num_errors=0, |
| errors=[], |
| suggestions=[], |
| is_fallback=True |
| )) |
| continue |
|
|
| result = self.client.check(sample) |
| results.append(result) |
|
|
| |
| avg_score = sum(r.grammar_score for r in results) / len(results) if results else 0.0 |
| total_errors = sum(r.num_errors for r in results) |
| any_fallback = any(r.is_fallback for r in results) |
|
|
| |
| if samples: |
| sanitized = sanitize(samples[0], mode="pii") |
| self.grammar_scores.append(avg_score) |
| self.sample_outputs.append(sanitized) |
| self.timestamps.append(step) |
|
|
| return { |
| "grammar_score": avg_score, |
| "num_errors": total_errors, |
| "is_fallback": any_fallback, |
| "samples": [sanitize(s, mode="pii") for s in samples] |
| } |
|
|
| async def validate_async(self, model: Any, step: int) -> GrammarValidationResult: |
| """ |
| Run async grammar validation (NON-BLOCKING). |
| |
| This is the key performance optimization: all grammar checks run |
| in parallel instead of sequentially, reducing validation time from |
| 2.5s to 0.5s (5x speedup). |
| |
| Args: |
| model: Model to validate |
| step: Current training step |
| |
        Returns:
            GrammarValidationResult with keys:
                - grammar_score: float
                - num_errors: int
                - is_fallback: bool
                - samples: list[str]
| """ |
| samples = [] |
|
|
| try: |
| |
| with torch.inference_mode(): |
| for prompt in self.test_prompts: |
| try: |
| sample = model.generate_text( |
| prompt, |
| max_length=VALIDATION_MAX_LENGTH, |
| temperature=VALIDATION_TEMPERATURE |
| ) |
| samples.append(sample) |
| except Exception as e: |
| logger.warning("Generation failed", extra={"error": str(e)}) |
| samples.append("") |
|
|
| except Exception as e: |
| logger.error( |
| "GrammarValidator generation failed", |
| extra={"error": str(e)} |
| ) |
| return { |
| "grammar_score": FALLBACK_GRAMMAR_SCORE, |
| "num_errors": FALLBACK_ERROR_COUNT, |
| "is_fallback": True, |
| "samples": [] |
| } |
|
|
| |
| return await self.validate_samples_async(samples, step) |
|
|
| async def validate_samples_async(self, samples: list[str], step: int) -> GrammarValidationResult: |
| """ |
| Run async grammar validation on pre-generated samples (NON-BLOCKING). |
| |
| This method allows sharing samples between multiple validators, |
| reducing generation cost by 50%. |
| |
| Args: |
| samples: Pre-generated text samples |
| step: Current training step |
| |
| Returns: |
| GrammarValidationResult with keys: |
| - grammar_score: float |
| - num_errors: int |
| - is_fallback: bool |
| - samples: list[str] |
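
        The client is assumed to expose an async batch method along the lines of
        check_batch_async(texts: list[str]) -> list[GrammarResult]; when that
        attribute is missing, the validator falls back to calling check() per
        sample. A minimal illustrative stub (not the real LanguageTool client):

            >>> class AsyncStubClient:
            ...     async def check_batch_async(self, texts):
            ...         return [GrammarResult(grammar_score=1.0, num_errors=0) for _ in texts]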
| """ |
| |
| valid_samples = [s for s in samples if s and len(s) >= MIN_SAMPLE_LENGTH] |
|
|
| |
| if hasattr(self.client, 'check_batch_async'): |
| |
| results = await self.client.check_batch_async(valid_samples) |
| else: |
| |
| logger.warning("Async client not available, falling back to sync") |
| results = [self.client.check(s) for s in valid_samples] |
|
|
| |
| avg_score = sum(r.grammar_score for r in results) / len(results) if results else 0.0 |
| total_errors = sum(r.num_errors for r in results) |
| any_fallback = any(r.is_fallback for r in results) |
|
|
| |
| if samples: |
| sanitized = sanitize(samples[0], mode="pii") |
| self.grammar_scores.append(avg_score) |
| self.sample_outputs.append(sanitized) |
| self.timestamps.append(step) |
|
|
| return { |
| "grammar_score": avg_score, |
| "num_errors": total_errors, |
| "is_fallback": any_fallback, |
| "samples": [sanitize(s, mode="pii") for s in samples] |
| } |
|
|
| def validate_samples(self, samples: list[str], step: int) -> GrammarValidationResult: |
| """ |
| Synchronous wrapper for validate_samples_async (for CombinedValidationCallback). |
| |
| This method provides a synchronous interface for validating pre-generated |
| samples, allowing the CombinedValidationCallback to share samples between |
| validators. |
| |
| Args: |
| samples: Pre-generated text samples |
| step: Current training step |
| |
| Returns: |
| GrammarValidationResult with same structure as validate() |
| """ |
| try: |
| return asyncio.run(self.validate_samples_async(samples, step)) |
| except RuntimeError as e: |
| |
| if "already running" in str(e): |
| logger.warning("Event loop already running, falling back to sync validation") |
| return self.validate_samples_sync(samples, step) |
| raise |
|
|
| def get_trend(self, window: int = TREND_ANALYSIS_WINDOW) -> str: |
| """ |
| Detect improving/degrading trend. |
| |
| Args: |
| window: Number of recent scores to analyze |
| |
| Returns: |
| "improving", "degrading", "stable", or "insufficient_data" |
| """ |
| if len(self.grammar_scores) < window: |
| return "insufficient_data" |
|
|
| recent = list(self.grammar_scores)[-window:] |
| if all(recent[i] >= recent[i-1] for i in range(1, len(recent))): |
| return "improving" |
| elif all(recent[i] <= recent[i-1] for i in range(1, len(recent))): |
| return "degrading" |
| else: |
| return "stable" |
|
|
| class KnowledgeValidator: |
| """ |
| Factual accuracy validation using knowledge base. |
| |
    Runs post-training only (~10s). Tests the model on a configurable set of
    factual question/answer pairs to verify knowledge retention.
| """ |
|
|
| def __init__(self, questions: list[dict[str, Any]]) -> None: |
| """ |
| Initialize KnowledgeValidator. |
| |
| Args: |
| questions: List of {"q": str, "a": list[str]} question/answer pairs |
| |
| Raises: |
| ValueError: If questions list is None or has invalid structure |
| TypeError: If questions is not a list |
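
        Example (illustrative):
            >>> validator = KnowledgeValidator([
            ...     {"q": "What is the capital of France?", "a": ["paris"]},
            ...     {"q": "Who wrote Hamlet?", "a": ["shakespeare", "william shakespeare"]},
            ... ])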
| """ |
| if questions is None: |
| raise ValueError("questions cannot be None") |
| if not isinstance(questions, list): |
| raise TypeError("questions must be a list") |
| |
| for i, q in enumerate(questions): |
| if not isinstance(q, dict): |
| raise TypeError(f"Question at index {i} must be a dict") |
| if 'q' not in q or 'a' not in q: |
| raise ValueError(f"Question at index {i} must have 'q' and 'a' keys") |
| if not isinstance(q['q'], str): |
| raise TypeError(f"Question 'q' at index {i} must be a string") |
| if not isinstance(q['a'], list): |
| raise TypeError(f"Question 'a' at index {i} must be a list") |
|
|
| self.questions = questions |
|
|
| def validate(self, model: Any, step: int = -1) -> KnowledgeValidationResult: |
| """ |
| Run factual accuracy validation. |
| |
| Args: |
| model: Model to validate |
| step: Training step (default -1 for post-training) |
| |
| Returns: |
| KnowledgeValidationResult with keys: |
| - accuracy: float |
| - correct: int |
| - total: int |
| - failed: list[dict[str, Any]] |
| """ |
| correct = 0 |
| failed = [] |
|
|
| try: |
| with torch.inference_mode(): |
| for item in self.questions: |
| question = item['q'] |
| valid_answers = [a.lower() for a in item['a']] |
|
|
| try: |
| output = model.generate_text( |
| question, |
| max_length=KNOWLEDGE_MAX_LENGTH, |
| temperature=KNOWLEDGE_TEMPERATURE |
| ) |
| output_lower = output.lower() |
|
|
| |
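                        # Case-insensitive substring match: e.g. an output of
                        # "The capital of France is Paris." matches the expected answer "paris".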
| is_correct = any(ans in output_lower for ans in valid_answers) |
|
|
| if is_correct: |
| correct += 1 |
| else: |
| failed.append({ |
| 'question': question, |
| 'expected': item['a'], |
| 'got': output[:ERROR_LOG_TRUNCATE_LENGTH] |
| }) |
|
|
| except Exception as e: |
| logger.warning( |
| "Knowledge validation failed", |
| extra={"question": question, "error": str(e)} |
| ) |
| failed.append({ |
| 'question': question, |
| 'expected': item['a'], |
| 'got': f"ERROR: {str(e)}" |
| }) |
|
|
| except Exception as e: |
| logger.error( |
| "KnowledgeValidator failed", |
| extra={"error": str(e)} |
| ) |
            return {
                "accuracy": 0.0,
                "correct": 0,
                "total": len(self.questions),
                # Report every question as failed, using the same shape as per-question failures
                "failed": [
                    {"question": q["q"], "expected": q["a"], "got": f"ERROR: {e}"}
                    for q in self.questions
                ],
            }
|
|
| return { |
| "accuracy": correct / len(self.questions) if self.questions else 0.0, |
| "correct": correct, |
| "total": len(self.questions), |
| "failed": failed |
| } |
|
|
| def validate_samples(self, samples: list[str], step: int) -> KnowledgeValidationResult: |
| """ |
| Not applicable for KnowledgeValidator (uses its own Q&A format). |
| |
| This method exists for Protocol compliance but is not supported. |
| Use validate() instead. |
| |
| Args: |
| samples: Unused (KnowledgeValidator generates from questions) |
| step: Training step |
| |
| Raises: |
| NotImplementedError: KnowledgeValidator doesn't support validate_samples |
| |
| Note: |
| KnowledgeValidator doesn't use pre-generated samples since it |
| tests factual knowledge with specific Q&A pairs. |
| """ |
| raise NotImplementedError( |
| "KnowledgeValidator doesn't support validate_samples. " |
| "Use validate(model, step) instead." |
| ) |
|
|
|
|
| class LanguageValidator: |
| """ |
| Language detection and word validity validation. |
| |
| Validates text is English with real words using: |
| - langdetect for language detection |
| - NLTK words corpus for English word validation |
| - Unicode script detection for multilingual text |
| |
| Runs every 100 steps with <1s overhead. |
| """ |
|
|
| def __init__(self, test_prompts: list[str]) -> None: |
| """ |
| Initialize LanguageValidator. |
| |
| Args: |
| test_prompts: List of prompts to test generation with |
| |
| Raises: |
| ValueError: If test_prompts is empty |
| TypeError: If test_prompts contains non-string elements |
| """ |
| if not test_prompts: |
| raise ValueError("test_prompts cannot be empty") |
| if not all(isinstance(p, str) for p in test_prompts): |
| raise TypeError("All test_prompts must be strings") |
|
|
| self.test_prompts = test_prompts |
|
|
| |
| self._english_words = None |
|
|
| @property |
| def english_words(self): |
| """Lazy-load NLTK words corpus.""" |
| if self._english_words is None: |
| try: |
| import nltk |
| from nltk.corpus import words |
| |
                # Rely on NLTK's default data paths; point the NLTK_DATA environment
                # variable at a custom corpus location instead of hardcoding a user path.
                self._english_words = {w.lower() for w in words.words()}
| except Exception as e: |
| logger.warning( |
| "NLTK words corpus not available, using fallback", |
| extra={"error": str(e)} |
| ) |
| |
| self._english_words = set([ |
| 'the', 'be', 'to', 'of', 'and', 'a', 'in', 'that', 'have', 'i', |
| 'it', 'for', 'not', 'on', 'with', 'he', 'as', 'you', 'do', 'at' |
| ]) |
| return self._english_words |
|
|
| @staticmethod |
| def detect_language_with_confidence(text: str) -> tuple[str, float]: |
| """ |
| Detect language and return confidence score. |
| |
| Args: |
| text: Input text to analyze |
| |
        Returns:
            Tuple of (language_code, english_confidence). The confidence is the
            probability assigned to English (0.0 when another language is
            detected), e.g. ('en', 0.95) for high-confidence English.
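
        Example (illustrative; requires langdetect, whose probabilities can vary
        slightly between runs despite the fixed seed):
            >>> LanguageValidator.detect_language_with_confidence(
            ...     "This is a plain English sentence about the weather."
            ... )  # doctest: +SKIP
            ('en', 0.99)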
| """ |
| try: |
| import langdetect |
| from langdetect import DetectorFactory |
|
|
| |
| DetectorFactory.seed = 0 |
|
|
| |
| lang = langdetect.detect(text) |
|
|
| |
| probs = langdetect.detect_langs(text) |
|
|
| |
| en_confidence = next( |
| (p.prob for p in probs if p.lang == 'en'), |
| 0.0 |
| ) |
|
|
| return lang, en_confidence if lang == 'en' else 0.0 |
|
|
| except Exception as e: |
| logger.debug( |
| "Language detection failed", |
| extra={"error": str(e)} |
| ) |
| return 'unknown', 0.0 |
|
|
| @staticmethod |
| def detect_multilingual(text: str) -> dict[str, Any]: |
| """ |
| Detect mixed-language text (common gaming strategy). |
| |
| Args: |
| text: Input text to analyze |
| |
| Returns: |
| Dict with keys: |
| - is_multilingual: bool |
| - primary_script: str |
| - script_ratios: dict[str, float] |
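
        Example (illustrative):
            >>> LanguageValidator.detect_multilingual("hello мир")["is_multilingual"]
            True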
| """ |
| |
| scripts = { |
| 'latin': 0, |
| 'cyrillic': 0, |
| 'arabic': 0, |
| 'cjk': 0, |
| 'greek': 0, |
| } |
|
|
| for char in text: |
| if 'a' <= char.lower() <= 'z': |
| scripts['latin'] += 1 |
| elif '\u0400' <= char <= '\u04FF': |
| scripts['cyrillic'] += 1 |
| elif '\u0600' <= char <= '\u06FF': |
| scripts['arabic'] += 1 |
| elif '\u4E00' <= char <= '\u9FFF': |
| scripts['cjk'] += 1 |
| elif '\u0370' <= char <= '\u03FF': |
| scripts['greek'] += 1 |
|
|
| total_letters = sum(scripts.values()) |
| if total_letters == 0: |
| return { |
| 'is_multilingual': False, |
| 'primary_script': 'none', |
| 'script_ratios': {} |
| } |
|
|
| |
| script_ratios = {k: v/total_letters for k, v in scripts.items()} |
|
|
| |
| primary_script = max(script_ratios, key=script_ratios.get) |
|
|
| |
| num_scripts = sum(1 for ratio in script_ratios.values() if ratio > 0.05) |
|
|
| return { |
| 'is_multilingual': num_scripts > 1, |
| 'primary_script': primary_script, |
| 'script_ratios': script_ratios, |
| } |
|
|
| def validate(self, model: Any, step: int) -> LanguageValidationResult: |
| """ |
| Run language detection and word validity validation. |
| |
| Args: |
| model: Model to validate |
| step: Current training step |
| |
| Returns: |
| LanguageValidationResult with keys: |
| - is_garbage: bool |
| - lang_confidence: float |
| - valid_word_ratio: float |
| - detected_language: str |
| - samples: list[str] |
| """ |
| samples = [] |
|
|
| try: |
| with torch.inference_mode(): |
| for prompt in self.test_prompts: |
| try: |
| sample = model.generate_text( |
| prompt, |
| max_length=VALIDATION_MAX_LENGTH, |
| temperature=VALIDATION_TEMPERATURE |
| ) |
| samples.append(sample) |
| except Exception as e: |
| logger.warning( |
| "Generation failed for prompt", |
| extra={"prompt": prompt, "error": str(e)} |
| ) |
| samples.append("") |
|
|
| except Exception as e: |
| logger.error( |
| "LanguageValidator failed", |
| extra={"step": step, "error": str(e)} |
| ) |
| return { |
| "is_garbage": True, |
| "lang_confidence": 0.0, |
| "valid_word_ratio": 0.0, |
| "detected_language": "unknown", |
| "samples": [] |
| } |
|
|
| |
| return self.validate_samples(samples, step) |
|
|
| def validate_samples(self, samples: list[str], step: int) -> LanguageValidationResult: |
| """ |
| Run language validation on pre-generated samples. |
| |
| This method allows sharing samples between multiple validators, |
| reducing generation cost. |
| |
| Args: |
| samples: Pre-generated text samples |
| step: Current training step |
| |
| Returns: |
| LanguageValidationResult with keys: |
| - is_garbage: bool |
| - lang_confidence: float |
| - valid_word_ratio: float |
| - detected_language: str |
| - samples: list[str] |
| """ |
| if not samples: |
| return { |
| "is_garbage": True, |
| "lang_confidence": 0.0, |
| "valid_word_ratio": 0.0, |
| "detected_language": "unknown", |
| "samples": [] |
| } |
|
|
| |
| lang_confidences = [] |
| detected_langs = [] |
| valid_word_ratios = [] |
|
|
| for sample in samples: |
| if not sample or len(sample) < MIN_SAMPLE_LENGTH: |
| lang_confidences.append(0.0) |
| detected_langs.append('unknown') |
| valid_word_ratios.append(0.0) |
| continue |
|
|
| |
| lang, confidence = self.detect_language_with_confidence(sample) |
| lang_confidences.append(confidence) |
| detected_langs.append(lang) |
|
|
| |
| tokens = sample.lower().split() |
| clean_tokens = [ |
| t.strip('.,!?;:()[]{}"\'-') |
| for t in tokens |
| if t.strip('.,!?;:()[]{}"\'-') |
| ] |
|
|
| if clean_tokens: |
| valid_count = sum( |
| 1 for t in clean_tokens |
| if t in self.english_words |
| ) |
| valid_ratio = valid_count / len(clean_tokens) |
| else: |
| valid_ratio = 0.0 |
|
|
| valid_word_ratios.append(valid_ratio) |
|
|
| |
| avg_lang_confidence = sum(lang_confidences) / len(lang_confidences) |
| avg_valid_word_ratio = sum(valid_word_ratios) / len(valid_word_ratios) |
|
|
| |
| lang_counts = Counter(detected_langs) |
| primary_lang = lang_counts.most_common(1)[0][0] |
|
|
| |
| any_multilingual = any( |
| self.detect_multilingual(s)['is_multilingual'] |
| for s in samples |
| if s and len(s) >= MIN_SAMPLE_LENGTH |
| ) |
|
|
| |
| is_garbage = ( |
| primary_lang != 'en' or |
| avg_lang_confidence < 0.8 or |
| avg_valid_word_ratio < 0.7 or |
| any_multilingual |
| ) |
|
|
| |
| sanitized_samples = [sanitize(s, mode="pii") for s in samples] |
|
|
| return { |
| "is_garbage": is_garbage, |
| "lang_confidence": avg_lang_confidence, |
| "valid_word_ratio": avg_valid_word_ratio, |
| "detected_language": primary_lang, |
| "samples": sanitized_samples |
| } |
|
|
|
|
| class PerplexityValidator: |
| """ |
| Autoregressive perplexity validation using DistilGPT-2. |
| |
| Measures language fluency using pre-trained transformer model. |
| Uses mixed precision (AMP) for 2x speedup. |
| |
    Runs every 100 steps with roughly 500ms overhead.
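
    Example (illustrative; downloads the distilgpt2 weights via transformers on
    first use):
        >>> validator = PerplexityValidator(["Tell me about the weather."])
        >>> validator.validate_samples(
        ...     ["The weather today is sunny and mild."], step=100
        ... )  # doctest: +SKIP
        {'perplexity': ..., 'perplexity_normalized': ..., 'samples': [...]}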
| """ |
|
|
| def __init__(self, test_prompts: list[str], model_name: str = "distilgpt2") -> None: |
| """ |
| Initialize PerplexityValidator. |
| |
| Args: |
| test_prompts: List of prompts to test generation with |
| model_name: HuggingFace model name (default: "distilgpt2") |
| |
| Raises: |
| ValueError: If test_prompts is empty |
| TypeError: If test_prompts contains non-string elements |
| """ |
| if not test_prompts: |
| raise ValueError("test_prompts cannot be empty") |
| if not all(isinstance(p, str) for p in test_prompts): |
| raise TypeError("All test_prompts must be strings") |
|
|
| self.test_prompts = test_prompts |
| self.model_name = model_name |
|
|
| |
        # Lazy loading; fall back to CPU when CUDA is unavailable
        self._device = "cuda" if torch.cuda.is_available() else "cpu"
        self._model = None
        self._tokenizer = None
|
|
| @property |
| def model(self): |
| """Lazy-load DistilGPT-2 model.""" |
| if self._model is None: |
| try: |
| from transformers import AutoModelForCausalLM |
                self._model = AutoModelForCausalLM.from_pretrained(
                    self.model_name
                ).to(self._device)
| self._model.eval() |
| except Exception as e: |
| logger.error( |
| "Failed to load perplexity model", |
| extra={"model": self.model_name, "error": str(e)} |
| ) |
| raise |
| return self._model |
|
|
| @property |
| def tokenizer(self): |
| """Lazy-load tokenizer.""" |
| if self._tokenizer is None: |
| try: |
| from transformers import AutoTokenizer |
| self._tokenizer = AutoTokenizer.from_pretrained(self.model_name) |
| except Exception as e: |
| logger.error( |
| "Failed to load tokenizer", |
| extra={"model": self.model_name, "error": str(e)} |
| ) |
| raise |
| return self._tokenizer |
|
|
| def validate(self, model: Any, step: int) -> PerplexityValidationResult: |
| """ |
| Run perplexity validation. |
| |
| Args: |
| model: Model to validate |
| step: Current training step |
| |
| Returns: |
| PerplexityValidationResult with keys: |
| - perplexity: float |
| - perplexity_normalized: float (0-1 score for reward) |
| - samples: list[str] |
| """ |
| samples = [] |
|
|
| try: |
| with torch.inference_mode(): |
| for prompt in self.test_prompts: |
| try: |
| sample = model.generate_text( |
| prompt, |
| max_length=VALIDATION_MAX_LENGTH, |
| temperature=VALIDATION_TEMPERATURE |
| ) |
| samples.append(sample) |
| except Exception as e: |
| logger.warning( |
| "Generation failed for prompt", |
| extra={"prompt": prompt, "error": str(e)} |
| ) |
| samples.append("") |
|
|
| except Exception as e: |
| logger.error( |
| "PerplexityValidator generation failed", |
| extra={"step": step, "error": str(e)} |
| ) |
| return { |
| "perplexity": float('inf'), |
| "perplexity_normalized": 0.0, |
| "samples": [] |
| } |
|
|
| |
| return self.validate_samples(samples, step) |
|
|
| def validate_samples(self, samples: list[str], step: int) -> PerplexityValidationResult: |
| """ |
| Run perplexity validation on pre-generated samples. |
| |
| This method allows sharing samples between multiple validators, |
| reducing generation cost. |
| |
| Args: |
| samples: Pre-generated text samples |
| step: Current training step |
| |
| Returns: |
| PerplexityValidationResult with keys: |
| - perplexity: float |
| - perplexity_normalized: float (0-1 score for reward) |
| - samples: list[str] |
| """ |
| if not samples: |
| return { |
| "perplexity": float('inf'), |
| "perplexity_normalized": 0.0, |
| "samples": [] |
| } |
|
|
| |
| valid_samples = [ |
| s for s in samples |
| if s and len(s) >= MIN_SAMPLE_LENGTH |
| ] |
|
|
| if not valid_samples: |
| return { |
| "perplexity": float('inf'), |
| "perplexity_normalized": 0.0, |
| "samples": [sanitize(s, mode="pii") for s in samples] |
| } |
|
|
| |
| perplexities = [] |
|
|
| try: |
| for sample in valid_samples: |
| |
| encodings = self.tokenizer( |
| sample, |
| return_tensors='pt', |
| truncation=True, |
| max_length=512 |
                ).to(self._device)
|
|
| |
                with torch.no_grad(), torch.amp.autocast(self._device):
| outputs = self.model(**encodings, labels=encodings.input_ids) |
| ce = outputs.loss.item() |
|
|
| |
| perplexity = torch.exp(torch.tensor(ce)).item() |
| perplexities.append(perplexity) |
|
|
| except Exception as e: |
| logger.error( |
| "Perplexity computation failed", |
| extra={"error": str(e)} |
| ) |
| return { |
| "perplexity": float('inf'), |
| "perplexity_normalized": 0.0, |
| "samples": [sanitize(s, mode="pii") for s in samples] |
| } |
|
|
| |
| avg_perplexity = sum(perplexities) / len(perplexities) |
|
|
| |
| |
| import math |
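        # Map average perplexity onto a (0, 1] reward: e.g. a perplexity of 5 gives
        # exp(-0.5) ≈ 0.61, while 50 gives exp(-5) ≈ 0.007 (lower perplexity, higher score)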
| normalized_score = math.exp(-avg_perplexity / 10.0) |
|
|
| return { |
| "perplexity": avg_perplexity, |
| "perplexity_normalized": normalized_score, |
| "samples": [sanitize(s, mode="pii") for s in samples] |
| } |
|
|