| """ | |
| Cascade PyTorch Hook - Deep Neural Network Instrumentation. | |
| This is the missing piece: direct integration with PyTorch training loops | |
| to capture what stdout never shows: | |
| - Per-layer gradient norms | |
| - Weight statistics | |
| - Activation patterns | |
| - Attention maps | |
| - Memory allocation | |
| Usage: | |
| from cascade.torch_hook import CascadeHook | |
| model = YourModel() | |
| hook = CascadeHook(model, monitor) | |
| # Training loop | |
| for batch in dataloader: | |
| loss = model(batch) | |
| loss.backward() # Hook automatically captures gradients | |
| optimizer.step() | |
| # Hook auto-logs per-layer stats to monitor | |
| """ | |

from typing import Dict, Any, Optional, List, Callable
from dataclasses import dataclass
import weakref

try:
    import torch
    import torch.nn as nn

    TORCH_AVAILABLE = True
except ImportError:
    TORCH_AVAILABLE = False
    torch = None
    nn = None


@dataclass
class LayerStats:
    """Statistics for a single layer."""

    name: str
    param_count: int
    grad_norm: Optional[float] = None
    grad_mean: Optional[float] = None
    grad_std: Optional[float] = None
    weight_norm: Optional[float] = None
    weight_mean: Optional[float] = None
    weight_std: Optional[float] = None
    activation_norm: Optional[float] = None
    activation_mean: Optional[float] = None


class CascadeHook:
    """
    PyTorch hook for deep instrumentation.

    Captures per-layer metrics that stdout logging misses:

    - Gradient flow through each layer
    - Weight evolution
    - Activation statistics
    - Memory usage

    Example:
        >>> from cascade import Monitor
        >>> from cascade.torch_hook import CascadeHook
        >>>
        >>> monitor = Monitor()
        >>> model = nn.Sequential(...)
        >>> hook = CascadeHook(model, monitor)
        >>>
        >>> # Training happens...
        >>> # Hook automatically captures:
        >>> #   - grad_norm/layer_0, grad_norm/layer_1, ...
        >>> #   - weight_norm/layer_0, ...
        >>> #   - activation_mean/layer_0, ...
    """

    def __init__(
        self,
        model: "nn.Module",
        monitor: Optional[Any] = None,
        track_gradients: bool = True,
        track_weights: bool = True,
        track_activations: bool = False,  # Can be expensive
        layer_filter: Optional[Callable[[str, "nn.Module"], bool]] = None,
    ):
        if not TORCH_AVAILABLE:
            raise ImportError("PyTorch required for CascadeHook. pip install torch")

        self.model = model
        self.monitor = monitor
        self.track_gradients = track_gradients
        self.track_weights = track_weights
        self.track_activations = track_activations
        self.layer_filter = layer_filter or self._default_filter

        self._handles: List[Any] = []
        self._layer_stats: Dict[str, LayerStats] = {}
        self._step = 0

        # Register hooks
        self._register_hooks()

    def _default_filter(self, name: str, module: "nn.Module") -> bool:
        """Default: track Linear, Conv, Attention, Norm, and Embedding layers."""
        return isinstance(module, (
            nn.Linear,
            nn.Conv1d, nn.Conv2d, nn.Conv3d,
            nn.MultiheadAttention,
            nn.LayerNorm, nn.BatchNorm1d, nn.BatchNorm2d,
            nn.Embedding,
        ))

    def _register_hooks(self):
        """Register forward and backward hooks on tracked layers."""
        for name, module in self.model.named_modules():
            if self.layer_filter(name, module):
                # Count params
                param_count = sum(p.numel() for p in module.parameters())
                self._layer_stats[name] = LayerStats(name=name, param_count=param_count)

                # Gradient hook
                if self.track_gradients:
                    handle = module.register_full_backward_hook(
                        self._make_grad_hook(name)
                    )
                    self._handles.append(handle)

                # Activation hook
                if self.track_activations:
                    handle = module.register_forward_hook(
                        self._make_activation_hook(name)
                    )
                    self._handles.append(handle)

    def _make_grad_hook(self, layer_name: str):
        """Create gradient hook for a specific layer."""
        def hook(module, grad_input, grad_output):
            stats = self._layer_stats[layer_name]
            # Get gradient from output
            if grad_output and grad_output[0] is not None:
                grad = grad_output[0]
                if grad.numel() > 0:
                    stats.grad_norm = grad.norm().item()
                    stats.grad_mean = grad.mean().item()
                    stats.grad_std = grad.std().item()
        return hook

    def _make_activation_hook(self, layer_name: str):
        """Create activation hook for a specific layer."""
        def hook(module, input, output):
            stats = self._layer_stats[layer_name]
            if isinstance(output, torch.Tensor):
                stats.activation_norm = output.norm().item()
                stats.activation_mean = output.mean().item()
        return hook

    def capture_weights(self):
        """Capture current weight statistics."""
        for name, module in self.model.named_modules():
            if name in self._layer_stats:
                stats = self._layer_stats[name]
                # Get weight tensor
                if hasattr(module, 'weight') and module.weight is not None:
                    w = module.weight.data
                    stats.weight_norm = w.norm().item()
                    stats.weight_mean = w.mean().item()
                    stats.weight_std = w.std().item()

    def step(self, extra_data: Optional[Dict[str, Any]] = None):
        """
        Call after each training step to log metrics.

        Args:
            extra_data: Additional data to include (loss, lr, etc.)
        """
        self._step += 1

        if self.track_weights:
            self.capture_weights()

        # Build event data
        data = {"step": self._step}
        if extra_data:
            data.update(extra_data)

        # Add per-layer stats
        for layer_name, stats in self._layer_stats.items():
            prefix = layer_name.replace(".", "_")
            if stats.grad_norm is not None:
                data[f"grad_norm/{prefix}"] = stats.grad_norm
            if stats.grad_mean is not None:
                data[f"grad_mean/{prefix}"] = stats.grad_mean
            if stats.weight_norm is not None:
                data[f"weight_norm/{prefix}"] = stats.weight_norm
            if stats.activation_mean is not None:
                data[f"activation_mean/{prefix}"] = stats.activation_mean

        # Aggregate metrics
        grad_norms = [s.grad_norm for s in self._layer_stats.values() if s.grad_norm is not None]
        if grad_norms:
            data["grad_norm_min"] = min(grad_norms)
            data["grad_norm_max"] = max(grad_norms)
            data["grad_norm_mean"] = sum(grad_norms) / len(grad_norms)

        weight_norms = [s.weight_norm for s in self._layer_stats.values() if s.weight_norm is not None]
        if weight_norms:
            data["weight_norm_total"] = sum(weight_norms)

        # Log to monitor
        if self.monitor:
            self.monitor.observe(data, event_type="training_step", component="torch_hook")

        return data
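
    # A minimal sketch of how step() is typically driven from a training loop.
    # Hedged: `loss`, `optimizer`, and `scheduler` below are the caller's own
    # objects, not part of this module.
    #
    #     loss.backward()                      # backward hooks fill gradient stats
    #     optimizer.step()
    #     hook.step({"loss": loss.item(),
    #                "lr": scheduler.get_last_lr()[0]})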

    def get_layer_report(self) -> Dict[str, Dict[str, Any]]:
        """Get current stats for all tracked layers."""
        return {
            name: {
                "param_count": stats.param_count,
                "grad_norm": stats.grad_norm,
                "weight_norm": stats.weight_norm,
                "activation_mean": stats.activation_mean,
            }
            for name, stats in self._layer_stats.items()
        }

    def detect_issues(self) -> List[str]:
        """Quick check for common issues."""
        issues = []
        for name, stats in self._layer_stats.items():
            if stats.grad_norm is not None:
                if stats.grad_norm == 0:
                    # Dead layer (no gradient flow)
                    issues.append(f"Dead layer (zero gradient): {name}")
                elif stats.grad_norm < 1e-7:
                    # Vanishing gradients
                    issues.append(f"Vanishing gradient in {name}: {stats.grad_norm:.2e}")
                elif stats.grad_norm > 100:
                    # Exploding gradients
                    issues.append(f"Exploding gradient in {name}: {stats.grad_norm:.2f}")
            # Weight explosion
            if stats.weight_norm is not None and stats.weight_norm > 1000:
                issues.append(f"Large weights in {name}: {stats.weight_norm:.2f}")
        return issues

    def remove(self):
        """Remove all hooks."""
        for handle in self._handles:
            handle.remove()
        self._handles.clear()

    def __del__(self):
        self.remove()

    @property
    def tracked_layers(self) -> List[str]:
        """List of tracked layer names."""
        return list(self._layer_stats.keys())

    @property
    def total_params(self) -> int:
        """Total parameters in tracked layers."""
        return sum(s.param_count for s in self._layer_stats.values())

    # =========================================================================
    # BRANCHING: From observation to understanding to action
    # =========================================================================

    def trace_anomaly(self, metric_name: str = "loss") -> Dict[str, Any]:
        """
        BACKWARD BRANCH: When something goes wrong, which layer caused it?

        Correlates a metric anomaly with per-layer gradient behavior and
        returns the likely culprit layer(s).
        """
        if not self.monitor or not self.monitor.metrics:
            return {"culprit": None, "reason": "No monitor data"}

        # Get anomalies from the metric
        anomalies = self.monitor.metrics.recent_anomalies
        if not anomalies:
            return {"culprit": None, "reason": "No anomalies detected"}

        # Find layers with extreme gradients at the time of the anomaly
        suspects = []
        for name, stats in self._layer_stats.items():
            if stats.grad_norm is not None:
                if stats.grad_norm < 1e-7:
                    suspects.append({"layer": name, "issue": "vanishing", "grad_norm": stats.grad_norm})
                elif stats.grad_norm > 50:
                    suspects.append({"layer": name, "issue": "exploding", "grad_norm": stats.grad_norm})

        if suspects:
            # Sort by severity: largest exploding norms first, then smallest vanishing norms
            suspects.sort(
                key=lambda x: abs(x["grad_norm"]) if x["issue"] == "exploding" else -x["grad_norm"],
                reverse=True,
            )
            return {
                "culprit": suspects[0]["layer"],
                "issue": suspects[0]["issue"],
                "all_suspects": suspects,
                "recommendation": self._recommend_fix(suspects[0]),
            }

        return {"culprit": None, "reason": "No layer anomalies found"}

    def _recommend_fix(self, suspect: Dict[str, Any]) -> str:
        """Generate an actionable recommendation."""
        if suspect["issue"] == "exploding":
            return f"Gradient explosion in {suspect['layer']}. Try: lower LR, add gradient clipping, check for NaN inputs."
        elif suspect["issue"] == "vanishing":
            return f"Vanishing gradient in {suspect['layer']}. Try: residual connections, different activation, layer norm."
        return "Unknown issue"

    def predict_failure(self, lookahead: int = 5) -> Dict[str, Any]:
        """
        FORWARD BRANCH: Predict if training is about to fail.

        Uses gradient trends to predict explosion/vanishing before it happens.
        """
        if not self.monitor or not self.monitor.metrics:
            return {"risk": "unknown", "reason": "No history"}

        warnings = []
        for name, stats in self._layer_stats.items():
            # Check gradient history via the monitor
            grad_key = f"grad_norm/{name.replace('.', '_')}"
            series = self.monitor.metrics.get_metric(grad_key)

            if series and series.count >= 5:
                trend = series.trend()
                roc = series.rate_of_change()

                if trend == "rising" and roc and roc > 0:
                    # Project forward at the current rate of change
                    projected = series.current + (roc * lookahead)
                    if projected > 100:
                        warnings.append({
                            "layer": name,
                            "prediction": "explosion",
                            "current": series.current,
                            "projected": projected,
                            "steps_until": int(max(0.0, 100 - series.current) / roc) if roc > 0 else None,
                        })
                elif trend == "falling" and series.current < 0.001:
                    warnings.append({
                        "layer": name,
                        "prediction": "vanishing",
                        "current": series.current,
                        "trend": "falling",
                    })

        if warnings:
            return {
                "risk": "high",
                "warnings": warnings,
                "action": "Consider intervention now",
            }
        return {"risk": "low", "warnings": [], "action": "Continue monitoring"}

    def suggest_intervention(self) -> Optional[Dict[str, Any]]:
        """
        FORWARD BRANCH: Suggest specific parameter changes.

        Based on current state, recommend concrete actions.
        """
        prediction = self.predict_failure()
        if prediction["risk"] != "high":
            return None

        interventions = []
        for warning in prediction["warnings"]:
            if warning["prediction"] == "explosion":
                interventions.append({
                    "action": "reduce_lr",
                    "factor": 0.5,
                    "reason": f"Gradient explosion predicted in {warning['layer']}",
                })
                interventions.append({
                    "action": "add_grad_clip",
                    "value": 1.0,
                    "reason": "Prevent gradient explosion",
                })
            elif warning["prediction"] == "vanishing":
                interventions.append({
                    "action": "increase_lr",
                    "factor": 1.5,
                    "reason": f"Vanishing gradient in {warning['layer']}",
                })

        return {
            "interventions": interventions,
            "urgency": "high" if len(interventions) > 1 else "medium",
        }
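
    # A minimal sketch of acting on the suggested interventions. Hedged: the
    # `optimizer` and `model` below are the caller's objects, not part of this
    # module, and applying the plan is entirely up to the training script.
    #
    #     plan = hook.suggest_intervention()
    #     if plan:
    #         for item in plan["interventions"]:
    #             if item["action"] == "reduce_lr":
    #                 for group in optimizer.param_groups:
    #                     group["lr"] *= item["factor"]
    #             elif item["action"] == "add_grad_clip":
    #                 # e.g. torch.nn.utils.clip_grad_norm_(model.parameters(), item["value"])
    #                 # before each optimizer.step()
    #                 pass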

    def get_attention_pattern(self, layer_name: str) -> Optional[Dict[str, Any]]:
        """
        DEEP BRANCH: Extract attention patterns (for transformer layers).

        Returns attention entropy, sparsity, positional bias.
        """
        for name, module in self.model.named_modules():
            if name == layer_name and isinstance(module, nn.MultiheadAttention):
                # Would need a forward hook that captures attention weights;
                # this is a stub showing the branch exists.
                return {
                    "layer": layer_name,
                    "type": "attention",
                    "note": "Full implementation requires attention weight capture",
                }
        return None

    def find_dead_neurons(self, threshold: float = 0.01) -> List[Dict[str, Any]]:
        """
        DEEP BRANCH: Find layers whose mean activation is near zero.

        Dead neurons = wasted parameters = pruning candidates. This check works
        at layer granularity, using the mean activation as a cheap proxy.
        """
        dead = []
        for name, stats in self._layer_stats.items():
            if stats.activation_mean is not None:
                if abs(stats.activation_mean) < threshold:
                    dead.append({
                        "layer": name,
                        "activation_mean": stats.activation_mean,
                        "recommendation": "Consider pruning or reinitializing",
                    })
        return dead

    def branch_report(self) -> Dict[str, Any]:
        """
        Full branch analysis: backward, forward, and deep.
        """
        return {
            "backward": {
                "anomaly_trace": self.trace_anomaly(),
            },
            "forward": {
                "failure_prediction": self.predict_failure(),
                "suggested_interventions": self.suggest_intervention(),
            },
            "deep": {
                "dead_neurons": self.find_dead_neurons(),
                "layer_health": self.detect_issues(),
            },
            "meta": {
                "tracked_layers": len(self._layer_stats),
                "total_params": self.total_params,
                "step": self._step,
            },
        }


# Convenience function
def instrument(model: "nn.Module", monitor=None, **kwargs) -> CascadeHook:
    """
    Quick instrumentation of a PyTorch model.

    Usage:
        hook = cascade.torch_hook.instrument(model, monitor)
    """
    return CascadeHook(model, monitor, **kwargs)
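

# ---------------------------------------------------------------------------
# Minimal, self-contained usage sketch (illustrative only). The toy model,
# optimizer, and random data below are assumptions for the demo and are not
# part of the Cascade API; monitor is left as None so no Monitor is required.
# ---------------------------------------------------------------------------
if __name__ == "__main__" and TORCH_AVAILABLE:
    # Toy two-layer regression model; any nn.Module works.
    model = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 1))
    hook = instrument(model, monitor=None, track_activations=True)

    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    loss_fn = nn.MSELoss()

    for _ in range(10):
        x = torch.randn(8, 16)
        y = torch.randn(8, 1)

        optimizer.zero_grad()
        loss = loss_fn(model(x), y)
        loss.backward()                           # backward hooks record gradient stats
        optimizer.step()
        data = hook.step({"loss": loss.item()})   # capture weights, build event dict

    print("Last step logged", len(data), "fields")
    print("Tracked layers:", hook.tracked_layers)
    print("Total params in tracked layers:", hook.total_params)
    for issue in hook.detect_issues():
        print("Issue:", issue)

    hook.remove()  # detach all hooks when done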