"""
PyTorch Lightning callbacks for validation during training.

Provides validation callbacks that log to the Lightning logger (scalars via
``LightningModule.log``, text samples via TensorBoard ``add_text``) and are
intended to stay within the 100s training budget.

CUDA streams: when the LightningModule lives on a CUDA device, validation
kernels are queued on a dedicated side stream. The side stream first waits for
everything already queued on the training stream, so validation never reads
parameters that the optimizer step is still writing. The training stream then
waits for the side stream, so later training kernels are ordered after
validation. The validators themselves are called synchronously on the host, so
the side stream mainly provides correct ordering and isolation; any actual
overlap with training depends on the validators.
"""

__all__ = [
    "ValidationCallback",
    "CombinedValidationCallback",
]

# Standard library
import logging
from contextlib import contextmanager
from typing import Any, Iterator, Optional

# Third-party
import pytorch_lightning as pl
import torch
from torch.cuda import Stream

# Local
from .validators import Validator
from .constants import (
    VALIDATION_MAX_LENGTH,
    VALIDATION_TEMPERATURE,
    LOW_GRAMMAR_SCORE_THRESHOLD,
    TARGET_GRAMMAR_SCORE,
    MAX_TENSORBOARD_SAMPLES,
    GRAMMAR_VALIDATION_FREQUENCY,
)

logger = logging.getLogger(__name__)


# ---------------------------------------------------------------------------
# Shared private helpers
# ---------------------------------------------------------------------------


def _require_positive_frequency(frequency: int) -> int:
    """Return *frequency* unchanged, raising ValueError if it is below 1.

    A frequency of 0 would otherwise raise ZeroDivisionError from the modulo
    check on every training batch, outside the callbacks' error handling.
    """
    if frequency < 1:
        raise ValueError(f"frequency must be >= 1, got {frequency!r}")
    return frequency


def _is_due(step: int, frequency: int, last_step: Optional[int]) -> bool:
    """Return True when validation should run at optimizer step *step*.

    Skips step 0 and steps not divisible by *frequency*. Also skips a step
    equal to *last_step*. ``on_train_batch_end`` fires once per batch, and with
    gradient accumulation several consecutive batches report the same
    ``global_step``.
    """
    return step != 0 and step % frequency == 0 and step != last_step


def _ensure_side_stream(
    current: Optional[Stream], device: torch.device, owner: str
) -> Optional[Stream]:
    """Return a CUDA side stream on *device*, creating it on first use.

    Returns None when *device* is not a CUDA device. Validation then simply
    runs on the current execution path.

    The stream is created lazily rather than in ``__init__``. Callbacks are
    usually constructed before the trainer selects each process's CUDA device.
    Creating a stream at that point initializes CUDA early and pins the stream
    to whatever device is current then.

    Args:
        current: Previously created stream, or None.
        device: Device the LightningModule currently lives on.
        owner: Label used in the creation log message.
    """
    if device.type != "cuda":
        return None
    if device.index is None:
        device = torch.device("cuda", torch.cuda.current_device())
    if current is not None and current.device == device:
        return current
    stream = torch.cuda.Stream(device=device)  # type: ignore[no-untyped-call]
    logger.info("%s: CUDA stream created on %s for validation", owner, device)
    return stream


@contextmanager
def _on_side_stream(stream: Optional[Stream]) -> Iterator[None]:
    """Queue CUDA work issued in the ``with`` body on *stream*, ordered against training.

    With ``stream=None`` the body runs unchanged on the current stream.
    """
    if stream is None:
        yield
        return
    training_stream = torch.cuda.current_stream(stream.device)
    # Work for the just-finished training step (e.g. the optimizer update) may
    # still be executing on the training stream. Make the side stream wait so
    # validation reads fully updated weights.
    stream.wait_stream(training_stream)
    try:
        with torch.cuda.stream(stream):
            yield
    finally:
        # Order later training kernels after all queued validation work. This
        # also happens when validation raised part-way through.
        training_stream.wait_stream(stream)


def _log_validation_results(
    pl_module: pl.LightningModule, prefix: str, results: dict[str, Any]
) -> None:
    """Log numeric entries of *results* as scalars and ``samples`` as text.

    Every int/float/bool value is logged via ``pl_module.log`` as
    ``"{prefix}_{key}"``. If ``results["samples"]`` is non-empty, its first
    ``MAX_TENSORBOARD_SAMPLES`` entries are joined into one markdown string.
    That string is written with the logger experiment's ``add_text``. Failures
    while writing samples are logged as warnings, not raised.
    """
    for key, value in results.items():
        # bool is already an int subclass; listed explicitly for readability.
        if isinstance(value, (int, float, bool)):
            pl_module.log(f"{prefix}_{key}", float(value))

    samples = results.get("samples")
    if not samples:
        return
    try:
        sample_text = "\n\n".join(
            f"**Sample {i+1}:** {sample}"
            for i, sample in enumerate(samples[:MAX_TENSORBOARD_SAMPLES])
        )
        if pl_module.logger is not None:
            pl_module.logger.experiment.add_text(  # type: ignore[attr-defined]
                f"{prefix}_samples",
                sample_text,
                pl_module.global_step,
            )
    except Exception as e:
        logger.warning("Failed to log samples", extra={"error": str(e)})


def _warn_if_garbage(step: int, results: dict[str, Any]) -> None:
    """Emit a warning when ``results["is_garbage"]`` is truthy."""
    if results.get("is_garbage"):
        logger.warning(
            "GARBAGE OUTPUT detected",
            extra={
                "step": step,
                "ascii_ratio": results.get("ascii_ratio", 0),
                "avg_length": results.get("avg_length", 0),
                "repetition_ratio": results.get("repetition_ratio", 0),
            },
        )


def _warn_on_grammar_issues(
    step: int, results: dict[str, Any], validator: Any
) -> None:
    """Warn on a grammar score below threshold and on a "degrading" trend.

    A missing ``grammar_score`` is treated as 0.0. The trend is only queried
    when *validator* exposes a ``get_trend`` method.
    """
    score = results.get("grammar_score", 0.0)
    if score < LOW_GRAMMAR_SCORE_THRESHOLD:
        logger.warning(
            "LOW GRAMMAR SCORE",
            extra={
                "step": step,
                "score": score,
                "target": TARGET_GRAMMAR_SCORE,
                "is_fallback": results.get("is_fallback", False),
            },
        )

    # Checked independently of the score: quality can trend down while the
    # score is still above the alert threshold.
    if hasattr(validator, "get_trend"):
        trend = validator.get_trend()
        if trend == "degrading":
            logger.warning(
                "GRAMMAR DEGRADING",
                extra={"step": step, "trend": trend},
            )


# ---------------------------------------------------------------------------
# Callbacks
# ---------------------------------------------------------------------------


class ValidationCallback(pl.Callback):
    """
    Generic validation callback that delegates to a validator.

    Every ``frequency`` optimizer steps (skipping step 0, and at most once per
    step), it calls ``validator.validate(pl_module.model, global_step)``. It
    then logs the numeric results and any ``samples``, and emits warnings for
    quality issues. Exceptions raised along the way are logged with their
    traceback and swallowed, so they do not interrupt training.

    On CUDA, validation work is queued on a side stream that is ordered against
    the training stream in both directions.
    """

    def __init__(self, validator: Validator, frequency: int, name: str) -> None:
        """
        Initialize ValidationCallback.

        Args:
            validator: Validator instance (FastValidator, GrammarValidator, etc.)
            frequency: Run validation every N optimizer steps; must be >= 1.
            name: Name for logging (e.g., "fast", "grammar"). Also selects the
                alerts: "fast" checks for garbage output, "grammar" checks the
                grammar score and trend.

        Raises:
            ValueError: If ``frequency`` is less than 1.
        """
        super().__init__()
        self.validator = validator
        self.frequency = _require_positive_frequency(frequency)
        self.name = name

        # CUDA side stream. Created lazily on the module's device the first time
        # validation runs on CUDA; stays None on CPU.
        self.stream: Optional[Stream] = None
        # global_step of the most recent validation attempt.
        self._last_step: Optional[int] = None

        if not torch.cuda.is_available():
            logger.warning(
                "ValidationCallback[%s]: CUDA not available, stream disabled", name
            )

    def on_train_batch_end(
        self,
        trainer: pl.Trainer,
        pl_module: pl.LightningModule,
        outputs: Any,
        batch: Any,
        batch_idx: int,
    ) -> None:
        """
        Run validation when the current optimizer step is due.

        Args:
            trainer: The Lightning trainer (unused).
            pl_module: Module whose ``model`` is validated and which receives
                the logged metrics.
            outputs: Training step outputs (unused).
            batch: Current batch (unused).
            batch_idx: Index of the current batch (unused).
        """
        step = pl_module.global_step
        if not _is_due(step, self.frequency, self._last_step):
            return
        # Recorded before running so a failing validation is not retried on
        # every accumulation batch of the same step.
        self._last_step = step

        try:
            self.stream = _ensure_side_stream(
                self.stream, pl_module.device, f"ValidationCallback[{self.name}]"
            )
            with _on_side_stream(self.stream):
                results = self.validator.validate(pl_module.model, step)

            self._log_results(pl_module, results)
            self._check_alerts(step, results)
        except Exception as e:
            logger.error(
                "Validation failed",
                extra={"validator": self.name, "step": step, "error": str(e)},
                exc_info=True,
            )

    def _log_results(self, pl_module: pl.LightningModule, results: dict[str, Any]) -> None:
        """Log validation results under this callback's name."""
        _log_validation_results(pl_module, self.name, results)

    def _check_alerts(self, step: int, results: dict[str, Any]) -> None:
        """Check for quality issues relevant to this callback's name."""
        if self.name == "fast":
            _warn_if_garbage(step, results)
        elif self.name == "grammar":
            _warn_on_grammar_issues(step, results, self.validator)


class CombinedValidationCallback(pl.Callback):
    """
    Combined validation callback that shares samples between validators.

    Generates samples once per validation step and passes the same list to
    both the fast and the grammar validator via ``validate_samples``. This
    avoids generating separately for each validator. Runs every ``frequency``
    optimizer steps (default ``GRAMMAR_VALIDATION_FREQUENCY``), skipping step 0
    and running at most once per step.

    On CUDA, generation and validation are queued on a side stream that is
    ordered against the training stream in both directions.
    """

    def __init__(
        self,
        fast_validator: Validator,
        grammar_validator: Validator,
        test_prompts: list[str],
        frequency: int = GRAMMAR_VALIDATION_FREQUENCY,
    ) -> None:
        """
        Initialize CombinedValidationCallback.

        Args:
            fast_validator: FastValidator instance
            grammar_validator: GrammarValidator instance
            test_prompts: List of prompts to generate samples from
            frequency: Run validation every N optimizer steps; must be >= 1
                (default: GRAMMAR_VALIDATION_FREQUENCY).

        Raises:
            ValueError: If ``frequency`` is less than 1.
        """
        super().__init__()
        self.fast_validator = fast_validator
        self.grammar_validator = grammar_validator
        self.test_prompts = test_prompts
        self.frequency = _require_positive_frequency(frequency)

        # CUDA side stream. Created lazily on the module's device the first time
        # validation runs on CUDA; stays None on CPU.
        self.stream: Optional[Stream] = None
        # global_step of the most recent validation attempt.
        self._last_step: Optional[int] = None

        if not torch.cuda.is_available():
            logger.warning("CombinedValidationCallback: CUDA not available, stream disabled")

    def on_train_batch_end(
        self,
        trainer: pl.Trainer,
        pl_module: pl.LightningModule,
        outputs: Any,
        batch: Any,
        batch_idx: int,
    ) -> None:
        """
        Run combined validation when the current optimizer step is due.

        Args:
            trainer: The Lightning trainer (unused).
            pl_module: Module whose ``model`` generates the samples and which
                receives the logged metrics.
            outputs: Training step outputs (unused).
            batch: Current batch (unused).
            batch_idx: Index of the current batch (unused).
        """
        step = pl_module.global_step
        if not _is_due(step, self.frequency, self._last_step):
            return
        # Recorded before running so a failing validation is not retried on
        # every accumulation batch of the same step.
        self._last_step = step

        try:
            self.stream = _ensure_side_stream(
                self.stream, pl_module.device, "CombinedValidationCallback"
            )
            with _on_side_stream(self.stream):
                samples = self._generate_samples(pl_module)
                fast_results = self.fast_validator.validate_samples(samples, step)
                grammar_results = self.grammar_validator.validate_samples(samples, step)

            self._log_results(pl_module, "fast", fast_results)
            self._log_results(pl_module, "grammar", grammar_results)

            self._check_fast_alerts(step, fast_results)
            self._check_grammar_alerts(step, grammar_results)
        except Exception as e:
            logger.error(
                "Combined validation failed",
                extra={"step": step, "error": str(e)},
                exc_info=True,
            )

    def _generate_samples(self, pl_module: pl.LightningModule) -> list[str]:
        """
        Generate one sample per test prompt under ``torch.inference_mode()``.

        A prompt whose generation raises is logged as a warning. It contributes
        an empty string, so the output stays index-aligned with
        ``self.test_prompts``.

        Args:
            pl_module: LightningModule whose ``model`` exposes ``generate_text``.

        Returns:
            List of generated text samples, one per prompt.
        """
        # NOTE(review): the model is not switched to eval() here. Confirm that
        # generate_text disables dropout itself if that matters.
        samples: list[str] = []
        with torch.inference_mode():
            for prompt in self.test_prompts:
                try:
                    sample = pl_module.model.generate_text(
                        prompt,
                        max_length=VALIDATION_MAX_LENGTH,
                        temperature=VALIDATION_TEMPERATURE,
                    )
                except Exception as e:
                    logger.warning(
                        "Generation failed for prompt",
                        extra={"prompt": prompt, "error": str(e)},
                    )
                    sample = ""
                samples.append(sample)
        return samples

    def _log_results(
        self, pl_module: pl.LightningModule, name: str, results: dict[str, Any]
    ) -> None:
        """Log validation results under the given name prefix."""
        _log_validation_results(pl_module, name, results)

    def _check_fast_alerts(self, step: int, results: dict[str, Any]) -> None:
        """Check for fast validation quality issues."""
        _warn_if_garbage(step, results)

    def _check_grammar_alerts(self, step: int, results: dict[str, Any]) -> None:
        """Check for grammar validation quality issues."""
        _warn_on_grammar_issues(step, results, self.grammar_validator)