# Uploaded via huggingface_hub (revision 518db7a, verified)
"""
PyTorch Lightning callbacks for validation during training.
Provides non-blocking validation callbacks that integrate with TensorBoard
and respect the 100s training budget.
Key optimization: CUDA streams for validation prefetch to overlap with training.
"""
# Public API: the two Lightning callbacks exported by this module.
__all__ = [
"ValidationCallback",
"CombinedValidationCallback",
]
# Standard library
import logging
from typing import Any, Optional
# Third-party
import pytorch_lightning as pl
import torch
from torch.cuda import Stream
# Local
from .validators import Validator
from .constants import (
VALIDATION_MAX_LENGTH,
VALIDATION_TEMPERATURE,
LOW_GRAMMAR_SCORE_THRESHOLD,
TARGET_GRAMMAR_SCORE,
MAX_TENSORBOARD_SAMPLES,
GRAMMAR_VALIDATION_FREQUENCY,
)
# Module-level logger named after this module, per the logging convention.
logger = logging.getLogger(__name__)
class ValidationCallback(pl.Callback):
    """
    Generic validation callback that delegates to a validator.

    Runs validation at a fixed step frequency with non-blocking execution
    and logs results to TensorBoard.

    CUDA streams optimization: validation runs in a separate stream to overlap
    with training, reducing validation overhead by 10-20%.
    """
    def __init__(self, validator: Validator, frequency: int, name: str) -> None:
        """
        Initialize ValidationCallback.

        Args:
            validator: Validator instance (FastValidator, GrammarValidator, etc.)
            frequency: Run validation every N steps (must be positive)
            name: Name for logging (e.g., "fast", "grammar")

        Raises:
            ValueError: If frequency is not positive. Failing fast here avoids
                a ZeroDivisionError later inside the training loop
                (``global_step % self.frequency``).
        """
        super().__init__()
        if frequency <= 0:
            raise ValueError(f"frequency must be positive, got {frequency}")
        self.validator = validator
        self.frequency = frequency
        self.name = name
        # CUDA stream for async validation (if GPU available)
        self.stream: Optional[Stream]
        if torch.cuda.is_available():
            self.stream = torch.cuda.Stream()  # type: ignore[no-untyped-call]
            logger.info(f"ValidationCallback[{name}]: CUDA stream created for async validation")
        else:
            self.stream = None
            logger.warning(f"ValidationCallback[{name}]: CUDA not available, stream disabled")
    def on_train_batch_end(
        self,
        trainer: pl.Trainer,
        pl_module: pl.LightningModule,
        outputs: Any,
        batch: Any,
        batch_idx: int
    ) -> None:
        """
        Run validation at the configured frequency.

        Uses CUDA streams to overlap validation with training:
        1. Launch validation in a separate stream
        2. Training continues in the default stream
        3. Sync before logging to ensure results are ready
        """
        step = pl_module.global_step
        # Skip step 0 (0 % frequency == 0 would otherwise trigger immediately)
        # and any step that is not a multiple of the frequency.
        if step == 0 or step % self.frequency != 0:
            return
        try:
            results = self._run_validation(pl_module)
            # Log to TensorBoard
            self._log_results(pl_module, results)
            # Alert if quality issues detected
            self._check_alerts(step, results)
        except Exception:
            # Validation is best-effort: never let it crash training.
            # logger.exception (unlike the previous logger.error) records the
            # full traceback, which is essential for diagnosing failures here.
            logger.exception(
                "Validation failed",
                extra={
                    "validator": self.name,
                    "step": step
                }
            )
    def _run_validation(self, pl_module: pl.LightningModule) -> dict[str, Any]:
        """Run the validator, in a side CUDA stream when one is available."""
        if self.stream is None:
            # CPU fallback (no stream)
            return self.validator.validate(pl_module.model, pl_module.global_step)
        with torch.cuda.stream(self.stream):
            # Training continues in the default stream while this runs
            results = self.validator.validate(pl_module.model, pl_module.global_step)
        # Make the default stream wait so results are complete before logging
        torch.cuda.current_stream().wait_stream(self.stream)
        return results
    def _log_results(self, pl_module: pl.LightningModule, results: dict[str, Any]) -> None:
        """Log validation results (scalar metrics and text samples) to TensorBoard."""
        # Log scalar metrics (bool is an int subclass; listed for clarity)
        for key, value in results.items():
            if isinstance(value, (int, float, bool)):
                pl_module.log(f"{self.name}_{key}", float(value))
        # Log text samples if available (missing key or empty list -> skip)
        if results.get("samples"):
            try:
                sample_text = "\n\n".join(
                    f"**Sample {i+1}:** {sample}"
                    for i, sample in enumerate(results["samples"][:MAX_TENSORBOARD_SAMPLES])
                )
                if pl_module.logger is not None:
                    pl_module.logger.experiment.add_text(  # type: ignore[attr-defined]
                        f"{self.name}_samples",
                        sample_text,
                        pl_module.global_step
                    )
            except Exception as e:
                # Sample logging is cosmetic; a failure should not stop training.
                logger.warning(
                    "Failed to log samples",
                    extra={"error": str(e)}
                )
    def _check_alerts(self, step: int, results: dict[str, Any]) -> None:
        """Check validation results for quality issues and emit warnings."""
        if self.name == "fast" and results.get("is_garbage"):
            logger.warning(
                "GARBAGE OUTPUT detected",
                extra={
                    "step": step,
                    "ascii_ratio": results.get('ascii_ratio', 0),
                    "avg_length": results.get('avg_length', 0),
                    "repetition_ratio": results.get('repetition_ratio', 0)
                }
            )
        if self.name == "grammar":
            score = results.get("grammar_score", 0.0)
            if score < LOW_GRAMMAR_SCORE_THRESHOLD:
                logger.warning(
                    "LOW GRAMMAR SCORE",
                    extra={
                        "step": step,
                        "score": score,
                        "target": TARGET_GRAMMAR_SCORE,
                        "is_fallback": results.get('is_fallback', False)
                    }
                )
            # Trend detection is optional: not every validator implements it
            if hasattr(self.validator, 'get_trend'):
                trend = self.validator.get_trend()
                if trend == "degrading":
                    logger.warning(
                        "GRAMMAR DEGRADING",
                        extra={"step": step, "trend": trend}
                    )
class CombinedValidationCallback(pl.Callback):
    """
    Combined validation callback that shares samples between validators.

    Generates samples once and passes them to both FastValidator and
    GrammarValidator, reducing generation cost by 50%.
    Runs at the frequency of the slower validator (grammar every 200 steps).

    CUDA streams optimization: sample generation runs in a separate stream to
    overlap with training.
    """
    def __init__(
        self,
        fast_validator: Validator,
        grammar_validator: Validator,
        test_prompts: list[str],
        frequency: int = GRAMMAR_VALIDATION_FREQUENCY
    ) -> None:
        """
        Initialize CombinedValidationCallback.

        Args:
            fast_validator: FastValidator instance
            grammar_validator: GrammarValidator instance
            test_prompts: List of prompts to generate samples from
            frequency: Run validation every N steps (default: 200, must be positive)

        Raises:
            ValueError: If frequency is not positive. Failing fast here avoids
                a ZeroDivisionError later inside the training loop
                (``global_step % self.frequency``).
        """
        super().__init__()
        if frequency <= 0:
            raise ValueError(f"frequency must be positive, got {frequency}")
        self.fast_validator = fast_validator
        self.grammar_validator = grammar_validator
        self.test_prompts = test_prompts
        self.frequency = frequency
        # CUDA stream for async validation (if GPU available)
        self.stream: Optional[Stream]
        if torch.cuda.is_available():
            self.stream = torch.cuda.Stream()  # type: ignore[no-untyped-call]
            logger.info("CombinedValidationCallback: CUDA stream created for async validation")
        else:
            self.stream = None
            logger.warning("CombinedValidationCallback: CUDA not available, stream disabled")
    def on_train_batch_end(
        self,
        trainer: pl.Trainer,
        pl_module: pl.LightningModule,
        outputs: Any,
        batch: Any,
        batch_idx: int
    ) -> None:
        """
        Run combined validation at the configured frequency.

        Uses CUDA streams to overlap validation with training:
        1. Launch sample generation + validation in a separate stream
        2. Training continues in the default stream
        3. Sync before logging to ensure results are ready
        """
        step = pl_module.global_step
        # Skip step 0 (0 % frequency == 0 would otherwise trigger immediately)
        # and any step that is not a multiple of the frequency.
        if step == 0 or step % self.frequency != 0:
            return
        try:
            fast_results, grammar_results = self._run_validations(pl_module)
            # Log results for both validators
            self._log_results(pl_module, "fast", fast_results)
            self._log_results(pl_module, "grammar", grammar_results)
            # Check alerts for both
            self._check_fast_alerts(step, fast_results)
            self._check_grammar_alerts(step, grammar_results)
        except Exception:
            # Validation is best-effort: never let it crash training.
            # logger.exception (unlike the previous logger.error) records the
            # full traceback, which is essential for diagnosing failures here.
            logger.exception(
                "Combined validation failed",
                extra={"step": step}
            )
    def _run_validations(
        self, pl_module: pl.LightningModule
    ) -> tuple[dict[str, Any], dict[str, Any]]:
        """
        Generate samples once and run both validators on them.

        Runs in a side CUDA stream when one is available, then syncs the
        default stream so results are complete before they are logged.

        Returns:
            Tuple of (fast_results, grammar_results).
        """
        step = pl_module.global_step
        if self.stream is None:
            # CPU fallback (no stream)
            samples = self._generate_samples(pl_module)
            return (
                self.fast_validator.validate_samples(samples, step),
                self.grammar_validator.validate_samples(samples, step),
            )
        with torch.cuda.stream(self.stream):
            # Training continues in the default stream while this runs
            samples = self._generate_samples(pl_module)
            fast_results = self.fast_validator.validate_samples(samples, step)
            grammar_results = self.grammar_validator.validate_samples(samples, step)
        # Make the default stream wait so results are complete before logging
        torch.cuda.current_stream().wait_stream(self.stream)
        return fast_results, grammar_results
    def _generate_samples(self, pl_module: pl.LightningModule) -> list[str]:
        """
        Generate one text sample per test prompt for validation.

        Args:
            pl_module: LightningModule with model

        Returns:
            List of generated text samples, same length and order as
            ``self.test_prompts``; a failed generation yields "" so the
            alignment with prompts is preserved.
        """
        samples = []
        with torch.inference_mode():
            for prompt in self.test_prompts:
                try:
                    sample = pl_module.model.generate_text(
                        prompt,
                        max_length=VALIDATION_MAX_LENGTH,
                        temperature=VALIDATION_TEMPERATURE
                    )
                    samples.append(sample)
                except Exception as e:
                    # Best-effort: one bad prompt must not abort validation.
                    logger.warning(
                        "Generation failed for prompt",
                        extra={"prompt": prompt, "error": str(e)}
                    )
                    samples.append("")
        return samples
    def _log_results(self, pl_module: pl.LightningModule, name: str, results: dict[str, Any]) -> None:
        """Log validation results (scalar metrics and text samples) to TensorBoard."""
        # Log scalar metrics (bool is an int subclass; listed for clarity)
        for key, value in results.items():
            if isinstance(value, (int, float, bool)):
                pl_module.log(f"{name}_{key}", float(value))
        # Log text samples if available (missing key or empty list -> skip)
        if results.get("samples"):
            try:
                sample_text = "\n\n".join(
                    f"**Sample {i+1}:** {sample}"
                    for i, sample in enumerate(results["samples"][:MAX_TENSORBOARD_SAMPLES])
                )
                if pl_module.logger is not None:
                    pl_module.logger.experiment.add_text(  # type: ignore[attr-defined]
                        f"{name}_samples",
                        sample_text,
                        pl_module.global_step
                    )
            except Exception as e:
                # Sample logging is cosmetic; a failure should not stop training.
                logger.warning(
                    "Failed to log samples",
                    extra={"error": str(e)}
                )
    def _check_fast_alerts(self, step: int, results: dict[str, Any]) -> None:
        """Warn when fast validation flags garbage output."""
        if results.get("is_garbage"):
            logger.warning(
                "GARBAGE OUTPUT detected",
                extra={
                    "step": step,
                    "ascii_ratio": results.get('ascii_ratio', 0),
                    "avg_length": results.get('avg_length', 0),
                    "repetition_ratio": results.get('repetition_ratio', 0)
                }
            )
    def _check_grammar_alerts(self, step: int, results: dict[str, Any]) -> None:
        """Warn on low grammar scores or a degrading grammar trend."""
        score = results.get("grammar_score", 0.0)
        if score < LOW_GRAMMAR_SCORE_THRESHOLD:
            logger.warning(
                "LOW GRAMMAR SCORE",
                extra={
                    "step": step,
                    "score": score,
                    "target": TARGET_GRAMMAR_SCORE,
                    "is_fallback": results.get('is_fallback', False)
                }
            )
        # Trend detection is optional: not every validator implements it
        if hasattr(self.grammar_validator, 'get_trend'):
            trend = self.grammar_validator.get_trend()
            if trend == "degrading":
                logger.warning(
                    "GRAMMAR DEGRADING",
                    extra={"step": step, "trend": trend}
                )