| """ |
| PyTorch Lightning callbacks for validation during training. |
| |
| Provides non-blocking validation callbacks that integrate with TensorBoard |
| and respect the 100s training budget. |
| |
| Key optimization: CUDA streams for validation prefetch to overlap with training. |
| """ |
|
|
# Public API of this module: everything else is an implementation detail.
__all__ = [
    "ValidationCallback",
    "CombinedValidationCallback",
]
|
|
| |
| import logging |
| from typing import Any, Optional |
|
|
| |
| import pytorch_lightning as pl |
| import torch |
| from torch.cuda import Stream |
|
|
| |
| from .validators import Validator |
| from .constants import ( |
| VALIDATION_MAX_LENGTH, |
| VALIDATION_TEMPERATURE, |
| LOW_GRAMMAR_SCORE_THRESHOLD, |
| TARGET_GRAMMAR_SCORE, |
| MAX_TENSORBOARD_SAMPLES, |
| GRAMMAR_VALIDATION_FREQUENCY, |
| ) |
|
|
# Module-level logger named after this module, per the stdlib logging convention.
logger = logging.getLogger(__name__)
|
|
class ValidationCallback(pl.Callback):
    """
    Generic validation callback that delegates to a validator.

    Runs validation at specified frequency with proper GPU memory management
    and non-blocking execution.

    CUDA streams optimization: Validation runs in a separate stream to overlap
    with training, reducing validation overhead by 10-20%.
    """

    def __init__(self, validator: Validator, frequency: int, name: str) -> None:
        """
        Initialize ValidationCallback.

        Args:
            validator: Validator instance (FastValidator, GrammarValidator, etc.)
            frequency: Run validation every N steps; must be a positive integer
            name: Name for logging (e.g., "fast", "grammar")

        Raises:
            ValueError: If ``frequency`` is not positive.
        """
        super().__init__()
        if frequency <= 0:
            # Fail fast: a non-positive frequency would otherwise surface as a
            # ZeroDivisionError (or fire every step) deep inside training.
            raise ValueError(f"frequency must be a positive integer, got {frequency}")
        self.validator = validator
        self.frequency = frequency
        self.name = name

        # Dedicated stream lets validation kernels overlap with training work
        # queued on the default stream. None when CUDA is unavailable.
        self.stream: Optional[Stream]
        if torch.cuda.is_available():
            self.stream = torch.cuda.Stream()
            logger.info(f"ValidationCallback[{name}]: CUDA stream created for async validation")
        else:
            self.stream = None
            logger.warning(f"ValidationCallback[{name}]: CUDA not available, stream disabled")

    def on_train_batch_end(
        self,
        trainer: pl.Trainer,
        pl_module: pl.LightningModule,
        outputs: Any,
        batch: Any,
        batch_idx: int
    ) -> None:
        """
        Run validation at the configured frequency.

        Uses CUDA streams to overlap validation with training:
        1. Validation stream waits for pending training work (weight updates)
        2. Validation is launched in the separate stream
        3. Default stream syncs on the validation stream before logging so
           results are ready
        """
        if pl_module.global_step % self.frequency != 0:
            return

        # Skip step 0: the model has not been trained yet.
        if pl_module.global_step == 0:
            return

        try:
            if self.stream is not None:
                # BUGFIX: the validation stream must first wait for work
                # already queued on the current (training) stream; otherwise
                # validation may read weights that the optimizer step is
                # still writing, producing garbage results.
                self.stream.wait_stream(torch.cuda.current_stream())
                with torch.cuda.stream(self.stream):
                    results = self.validator.validate(pl_module.model, pl_module.global_step)

                # Make the default stream wait for validation to finish before
                # its results are consumed below.
                torch.cuda.current_stream().wait_stream(self.stream)
            else:
                # CPU fallback: run validation synchronously.
                results = self.validator.validate(pl_module.model, pl_module.global_step)

            self._log_results(pl_module, results)
            self._check_alerts(pl_module.global_step, results)

        except Exception as e:
            # logger.exception records the full traceback (logger.error with
            # str(e) alone loses it), while keeping the structured fields.
            logger.exception(
                "Validation failed",
                extra={
                    "validator": self.name,
                    "step": pl_module.global_step,
                    "error": str(e)
                }
            )

    def _log_results(self, pl_module: pl.LightningModule, results: dict[str, Any]) -> None:
        """Log scalar validation results and text samples to TensorBoard."""
        # Scalars: anything numeric/boolean is logged as "<name>_<key>".
        for key, value in results.items():
            if isinstance(value, (int, float, bool)):
                pl_module.log(f"{self.name}_{key}", float(value))

        # Text samples (capped at MAX_TENSORBOARD_SAMPLES) go to the
        # TensorBoard text panel; failures here are non-fatal.
        if "samples" in results and results["samples"]:
            try:
                sample_text = "\n\n".join(
                    f"**Sample {i+1}:** {sample}"
                    for i, sample in enumerate(results["samples"][:MAX_TENSORBOARD_SAMPLES])
                )
                if pl_module.logger is not None:
                    pl_module.logger.experiment.add_text(
                        f"{self.name}_samples",
                        sample_text,
                        pl_module.global_step
                    )
            except Exception as e:
                logger.warning(
                    "Failed to log samples",
                    extra={"error": str(e)}
                )

    def _check_alerts(self, step: int, results: dict[str, Any]) -> None:
        """Check validation results for quality issues and emit warnings."""
        # "fast" validator: warn when the output heuristics flag garbage.
        if self.name == "fast" and results.get("is_garbage"):
            logger.warning(
                "GARBAGE OUTPUT detected",
                extra={
                    "step": step,
                    "ascii_ratio": results.get('ascii_ratio', 0),
                    "avg_length": results.get('avg_length', 0),
                    "repetition_ratio": results.get('repetition_ratio', 0)
                }
            )

        # "grammar" validator: warn on low absolute score and on a
        # degrading trend (when the validator exposes get_trend()).
        if self.name == "grammar":
            score = results.get("grammar_score", 0.0)
            if score < LOW_GRAMMAR_SCORE_THRESHOLD:
                logger.warning(
                    "LOW GRAMMAR SCORE",
                    extra={
                        "step": step,
                        "score": score,
                        "target": TARGET_GRAMMAR_SCORE,
                        "is_fallback": results.get('is_fallback', False)
                    }
                )

            if hasattr(self.validator, 'get_trend'):
                trend = self.validator.get_trend()
                if trend == "degrading":
                    logger.warning(
                        "GRAMMAR DEGRADING",
                        extra={"step": step, "trend": trend}
                    )
|
|
class CombinedValidationCallback(pl.Callback):
    """
    Combined validation callback that shares samples between validators.

    Generates samples once and passes them to both FastValidator and
    GrammarValidator, reducing generation cost by 50%.

    Runs at the frequency of the slower validator (grammar every 200 steps).

    CUDA streams optimization: Sample generation runs in separate stream to
    overlap with training.
    """

    def __init__(
        self,
        fast_validator: Validator,
        grammar_validator: Validator,
        test_prompts: list[str],
        frequency: int = GRAMMAR_VALIDATION_FREQUENCY
    ) -> None:
        """
        Initialize CombinedValidationCallback.

        Args:
            fast_validator: FastValidator instance
            grammar_validator: GrammarValidator instance
            test_prompts: List of prompts to generate samples from
            frequency: Run validation every N steps (default: 200);
                must be a positive integer

        Raises:
            ValueError: If ``frequency`` is not positive.
        """
        super().__init__()
        if frequency <= 0:
            # Fail fast: a non-positive frequency would otherwise surface as a
            # ZeroDivisionError (or fire every step) deep inside training.
            raise ValueError(f"frequency must be a positive integer, got {frequency}")
        self.fast_validator = fast_validator
        self.grammar_validator = grammar_validator
        self.test_prompts = test_prompts
        self.frequency = frequency

        # Dedicated stream lets generation/validation kernels overlap with
        # training work queued on the default stream. None without CUDA.
        self.stream: Optional[Stream]
        if torch.cuda.is_available():
            self.stream = torch.cuda.Stream()
            logger.info("CombinedValidationCallback: CUDA stream created for async validation")
        else:
            self.stream = None
            logger.warning("CombinedValidationCallback: CUDA not available, stream disabled")

    def on_train_batch_end(
        self,
        trainer: pl.Trainer,
        pl_module: pl.LightningModule,
        outputs: Any,
        batch: Any,
        batch_idx: int
    ) -> None:
        """
        Run combined validation at the configured frequency.

        Uses CUDA streams to overlap validation with training:
        1. Validation stream waits for pending training work (weight updates)
        2. Sample generation + both validations launch in the separate stream
        3. Default stream syncs on the validation stream before logging so
           results are ready
        """
        if pl_module.global_step % self.frequency != 0:
            return

        # Skip step 0: the model has not been trained yet.
        if pl_module.global_step == 0:
            return

        try:
            if self.stream is not None:
                # BUGFIX: the validation stream must first wait for work
                # already queued on the current (training) stream; otherwise
                # generation may read weights that the optimizer step is
                # still writing, producing garbage samples.
                self.stream.wait_stream(torch.cuda.current_stream())
                with torch.cuda.stream(self.stream):
                    # Generate once, validate twice — this is the 50% saving.
                    samples = self._generate_samples(pl_module)
                    fast_results = self.fast_validator.validate_samples(
                        samples, pl_module.global_step
                    )
                    grammar_results = self.grammar_validator.validate_samples(
                        samples, pl_module.global_step
                    )

                # Make the default stream wait for validation to finish before
                # its results are consumed below.
                torch.cuda.current_stream().wait_stream(self.stream)
            else:
                # CPU fallback: run generation and both validations synchronously.
                samples = self._generate_samples(pl_module)
                fast_results = self.fast_validator.validate_samples(
                    samples, pl_module.global_step
                )
                grammar_results = self.grammar_validator.validate_samples(
                    samples, pl_module.global_step
                )

            self._log_results(pl_module, "fast", fast_results)
            self._log_results(pl_module, "grammar", grammar_results)

            self._check_fast_alerts(pl_module.global_step, fast_results)
            self._check_grammar_alerts(pl_module.global_step, grammar_results)

        except Exception as e:
            # logger.exception records the full traceback (logger.error with
            # str(e) alone loses it), while keeping the structured fields.
            logger.exception(
                "Combined validation failed",
                extra={
                    "step": pl_module.global_step,
                    "error": str(e)
                }
            )

    def _generate_samples(self, pl_module: pl.LightningModule) -> list[str]:
        """
        Generate one sample per test prompt for validation.

        A failed generation contributes an empty string so the sample list
        always aligns 1:1 with ``self.test_prompts``.

        Args:
            pl_module: LightningModule with model

        Returns:
            List of generated text samples (same length as test_prompts)
        """
        samples = []
        # inference_mode avoids autograd bookkeeping during generation.
        with torch.inference_mode():
            for prompt in self.test_prompts:
                try:
                    sample = pl_module.model.generate_text(
                        prompt,
                        max_length=VALIDATION_MAX_LENGTH,
                        temperature=VALIDATION_TEMPERATURE
                    )
                    samples.append(sample)
                except Exception as e:
                    # Best-effort: one bad prompt must not abort validation.
                    logger.warning(
                        "Generation failed for prompt",
                        extra={"prompt": prompt, "error": str(e)}
                    )
                    samples.append("")
        return samples

    def _log_results(self, pl_module: pl.LightningModule, name: str, results: dict[str, Any]) -> None:
        """Log scalar validation results and text samples to TensorBoard."""
        # Scalars: anything numeric/boolean is logged as "<name>_<key>".
        for key, value in results.items():
            if isinstance(value, (int, float, bool)):
                pl_module.log(f"{name}_{key}", float(value))

        # Text samples (capped at MAX_TENSORBOARD_SAMPLES) go to the
        # TensorBoard text panel; failures here are non-fatal.
        if "samples" in results and results["samples"]:
            try:
                sample_text = "\n\n".join(
                    f"**Sample {i+1}:** {sample}"
                    for i, sample in enumerate(results["samples"][:MAX_TENSORBOARD_SAMPLES])
                )
                if pl_module.logger is not None:
                    pl_module.logger.experiment.add_text(
                        f"{name}_samples",
                        sample_text,
                        pl_module.global_step
                    )
            except Exception as e:
                logger.warning(
                    "Failed to log samples",
                    extra={"error": str(e)}
                )

    def _check_fast_alerts(self, step: int, results: dict[str, Any]) -> None:
        """Warn when the fast validator's heuristics flag garbage output."""
        if results.get("is_garbage"):
            logger.warning(
                "GARBAGE OUTPUT detected",
                extra={
                    "step": step,
                    "ascii_ratio": results.get('ascii_ratio', 0),
                    "avg_length": results.get('avg_length', 0),
                    "repetition_ratio": results.get('repetition_ratio', 0)
                }
            )

    def _check_grammar_alerts(self, step: int, results: dict[str, Any]) -> None:
        """Warn on a low grammar score or a degrading grammar trend."""
        score = results.get("grammar_score", 0.0)
        if score < LOW_GRAMMAR_SCORE_THRESHOLD:
            logger.warning(
                "LOW GRAMMAR SCORE",
                extra={
                    "step": step,
                    "score": score,
                    "target": TARGET_GRAMMAR_SCORE,
                    "is_fallback": results.get('is_fallback', False)
                }
            )

        # Trend check is optional: only validators exposing get_trend() get it.
        if hasattr(self.grammar_validator, 'get_trend'):
            trend = self.grammar_validator.get_trend()
            if trend == "degrading":
                logger.warning(
                    "GRAMMAR DEGRADING",
                    extra={"step": step, "trend": trend}
                )
|
|