"""
Cascade PyTorch Hook - Deep Neural Network Instrumentation.

Integrates directly with PyTorch training loops to capture what stdout
logging never shows:
- Per-layer gradient norms
- Weight statistics
- Activation patterns (optional; off by default because it adds overhead)
- Attention maps (stub; see CascadeHook.get_attention_pattern)
- Memory allocation (not yet implemented)
Usage:
    from cascade.torch_hook import CascadeHook

    model = YourModel()
    hook = CascadeHook(model, monitor)

    # Training loop
    for batch in dataloader:
        loss = model(batch)
        optimizer.zero_grad()
        loss.backward()   # backward hooks capture per-layer gradient stats
        optimizer.step()
        hook.step({"loss": loss.item()})  # logs per-layer stats to the monitor
"""
from typing import Dict, Any, Optional, List, Callable
from dataclasses import dataclass
try:
import torch
import torch.nn as nn
TORCH_AVAILABLE = True
except ImportError:
TORCH_AVAILABLE = False
torch = None
nn = None
@dataclass
class LayerStats:
"""Statistics for a single layer."""
name: str
param_count: int
grad_norm: Optional[float] = None
grad_mean: Optional[float] = None
grad_std: Optional[float] = None
weight_norm: Optional[float] = None
weight_mean: Optional[float] = None
weight_std: Optional[float] = None
activation_norm: Optional[float] = None
activation_mean: Optional[float] = None
class CascadeHook:
"""
    PyTorch hook for deep instrumentation.

    Captures per-layer metrics that stdout logging misses:
    - Gradient flow through each layer
    - Weight evolution
    - Activation statistics (optional)
    - Memory usage (not yet implemented)

    Example:
        >>> from cascade import Monitor
        >>> from cascade.torch_hook import CascadeHook
        >>>
        >>> monitor = Monitor()
        >>> model = nn.Sequential(...)
        >>> hook = CascadeHook(model, monitor)
        >>>
        >>> # Train as usual; after each optimizer step call:
        >>> hook.step({"loss": loss.item()})
        >>> # which logs keys such as (module name with "." -> "_"):
        >>> # - grad_norm/<layer>, grad_mean/<layer>, ...
        >>> # - weight_norm/<layer>, ...
        >>> # - activation_mean/<layer>, ...
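
    Tracking can be narrowed with a custom ``layer_filter``; a minimal
    sketch (the lambda below is illustrative):

        >>> hook = CascadeHook(
        ...     model, monitor,
        ...     layer_filter=lambda name, mod: isinstance(mod, nn.Linear),
        ... )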
"""
def __init__(
self,
model: "nn.Module",
monitor: Optional[Any] = None,
track_gradients: bool = True,
track_weights: bool = True,
track_activations: bool = False, # Can be expensive
layer_filter: Optional[Callable[[str, "nn.Module"], bool]] = None,
):
if not TORCH_AVAILABLE:
raise ImportError("PyTorch required for CascadeHook. pip install torch")
self.model = model
self.monitor = monitor
self.track_gradients = track_gradients
self.track_weights = track_weights
self.track_activations = track_activations
self.layer_filter = layer_filter or self._default_filter
self._handles: List[Any] = []
self._layer_stats: Dict[str, LayerStats] = {}
self._step = 0
# Register hooks
self._register_hooks()
def _default_filter(self, name: str, module: "nn.Module") -> bool:
"""Default: track Linear, Conv, and Attention layers."""
return isinstance(module, (
nn.Linear,
nn.Conv1d, nn.Conv2d, nn.Conv3d,
nn.MultiheadAttention,
nn.LayerNorm, nn.BatchNorm1d, nn.BatchNorm2d,
nn.Embedding,
))
def _register_hooks(self):
"""Register forward and backward hooks on tracked layers."""
for name, module in self.model.named_modules():
if self.layer_filter(name, module):
# Count params
param_count = sum(p.numel() for p in module.parameters())
self._layer_stats[name] = LayerStats(name=name, param_count=param_count)
# Gradient hook
if self.track_gradients:
handle = module.register_full_backward_hook(
self._make_grad_hook(name)
)
self._handles.append(handle)
# Activation hook
if self.track_activations:
handle = module.register_forward_hook(
self._make_activation_hook(name)
)
self._handles.append(handle)
def _make_grad_hook(self, layer_name: str):
"""Create gradient hook for a specific layer."""
def hook(module, grad_input, grad_output):
stats = self._layer_stats[layer_name]
            # grad_output[0] is the gradient w.r.t. the module's *output*
            # (parameter gradients live on the parameters themselves)
            if grad_output and grad_output[0] is not None:
                grad = grad_output[0].detach()
                if grad.numel() > 0:
                    stats.grad_norm = grad.norm().item()
                    stats.grad_mean = grad.mean().item()
                    # std() of a single-element tensor is NaN; guard against it
                    stats.grad_std = grad.std().item() if grad.numel() > 1 else 0.0
return hook
def _make_activation_hook(self, layer_name: str):
"""Create activation hook for a specific layer."""
def hook(module, input, output):
stats = self._layer_stats[layer_name]
            # Some modules (e.g. MultiheadAttention) return tuples; use the
            # first tensor element in that case
            if isinstance(output, tuple) and output and isinstance(output[0], torch.Tensor):
                output = output[0]
            if isinstance(output, torch.Tensor):
                out = output.detach()
                stats.activation_norm = out.norm().item()
                stats.activation_mean = out.mean().item()
return hook
def capture_weights(self):
"""Capture current weight statistics."""
for name, module in self.model.named_modules():
if name in self._layer_stats:
stats = self._layer_stats[name]
                # Read the weight tensor without touching autograd state
                if hasattr(module, 'weight') and module.weight is not None:
                    w = module.weight.detach()
stats.weight_norm = w.norm().item()
stats.weight_mean = w.mean().item()
stats.weight_std = w.std().item()
def step(self, extra_data: Optional[Dict[str, Any]] = None):
"""
Call after each training step to log metrics.
Args:
extra_data: Additional data to include (loss, lr, etc.)
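
        Example (illustrative values):
            >>> hook.step({"loss": loss.item(), "lr": 1e-3})
            {'step': 1, 'loss': 0.42, 'lr': 0.001, 'grad_norm/0': 0.87, ...}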
"""
self._step += 1
if self.track_weights:
self.capture_weights()
# Build event data
data = {"step": self._step}
if extra_data:
data.update(extra_data)
# Add per-layer stats
for layer_name, stats in self._layer_stats.items():
prefix = layer_name.replace(".", "_")
if stats.grad_norm is not None:
data[f"grad_norm/{prefix}"] = stats.grad_norm
if stats.grad_mean is not None:
data[f"grad_mean/{prefix}"] = stats.grad_mean
if stats.weight_norm is not None:
data[f"weight_norm/{prefix}"] = stats.weight_norm
if stats.activation_mean is not None:
data[f"activation_mean/{prefix}"] = stats.activation_mean
# Aggregate metrics
grad_norms = [s.grad_norm for s in self._layer_stats.values() if s.grad_norm is not None]
if grad_norms:
data["grad_norm_min"] = min(grad_norms)
data["grad_norm_max"] = max(grad_norms)
data["grad_norm_mean"] = sum(grad_norms) / len(grad_norms)
weight_norms = [s.weight_norm for s in self._layer_stats.values() if s.weight_norm is not None]
if weight_norms:
data["weight_norm_total"] = sum(weight_norms)
# Log to monitor
if self.monitor:
self.monitor.observe(data, event_type="training_step", component="torch_hook")
return data
def get_layer_report(self) -> Dict[str, Dict[str, Any]]:
"""Get current stats for all tracked layers."""
return {
name: {
"param_count": stats.param_count,
"grad_norm": stats.grad_norm,
"weight_norm": stats.weight_norm,
"activation_mean": stats.activation_mean,
}
for name, stats in self._layer_stats.items()
}
def detect_issues(self) -> List[str]:
"""Quick check for common issues."""
issues = []
for name, stats in self._layer_stats.items():
# Vanishing gradients
if stats.grad_norm is not None and stats.grad_norm < 1e-7:
issues.append(f"Vanishing gradient in {name}: {stats.grad_norm:.2e}")
# Exploding gradients
if stats.grad_norm is not None and stats.grad_norm > 100:
issues.append(f"Exploding gradient in {name}: {stats.grad_norm:.2f}")
# Dead layer (no gradient flow)
if stats.grad_norm == 0:
issues.append(f"Dead layer (zero gradient): {name}")
# Weight explosion
if stats.weight_norm is not None and stats.weight_norm > 1000:
issues.append(f"Large weights in {name}: {stats.weight_norm:.2f}")
return issues
def remove(self):
"""Remove all hooks."""
for handle in self._handles:
handle.remove()
self._handles.clear()
    def __del__(self):
        # _handles may not exist if __init__ raised before assignment
        if getattr(self, "_handles", None):
            self.remove()
@property
def tracked_layers(self) -> List[str]:
"""List of tracked layer names."""
return list(self._layer_stats.keys())
@property
def total_params(self) -> int:
"""Total parameters in tracked layers."""
return sum(s.param_count for s in self._layer_stats.values())
# =========================================================================
# BRANCHING: From observation to understanding to action
# =========================================================================
def trace_anomaly(self, metric_name: str = "loss") -> Dict[str, Any]:
"""
BACKWARD BRANCH: When something goes wrong, which layer caused it?
Correlates metric anomaly with per-layer gradient behavior.
Returns the likely culprit layer(s).
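
        Example return (illustrative layer name and values):
            {"culprit": "encoder.0", "issue": "exploding",
             "all_suspects": [...], "recommendation": "..."}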
"""
if not self.monitor or not self.monitor.metrics:
return {"culprit": None, "reason": "No monitor data"}
# Get anomalies from the metric
anomalies = self.monitor.metrics.recent_anomalies
if not anomalies:
return {"culprit": None, "reason": "No anomalies detected"}
        # Flag layers whose current gradients are extreme (a proxy for
        # their state at the time of the anomaly)
suspects = []
for name, stats in self._layer_stats.items():
if stats.grad_norm is not None:
if stats.grad_norm < 1e-7:
suspects.append({"layer": name, "issue": "vanishing", "grad_norm": stats.grad_norm})
elif stats.grad_norm > 50:
suspects.append({"layer": name, "issue": "exploding", "grad_norm": stats.grad_norm})
if suspects:
            # Severity order: exploding layers first (largest norm on top),
            # then vanishing layers (smallest norm on top); grad norms are
            # non-negative, so no abs() is needed
            suspects.sort(
                key=lambda x: x["grad_norm"] if x["issue"] == "exploding" else -x["grad_norm"],
                reverse=True,
            )
return {
"culprit": suspects[0]["layer"],
"issue": suspects[0]["issue"],
"all_suspects": suspects,
"recommendation": self._recommend_fix(suspects[0])
}
return {"culprit": None, "reason": "No layer anomalies found"}
def _recommend_fix(self, suspect: Dict[str, Any]) -> str:
"""Generate actionable recommendation."""
if suspect["issue"] == "exploding":
return f"Gradient explosion in {suspect['layer']}. Try: lower LR, add gradient clipping, check for NaN inputs."
elif suspect["issue"] == "vanishing":
return f"Vanishing gradient in {suspect['layer']}. Try: residual connections, different activation, layer norm."
return "Unknown issue"
def predict_failure(self, lookahead: int = 5) -> Dict[str, Any]:
"""
FORWARD BRANCH: Predict if training is about to fail.
Uses gradient trends to predict explosion/vanishing before it happens.
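
        Example return (illustrative layer name and values):
            {"risk": "high",
             "warnings": [{"layer": "encoder.0", "prediction": "explosion",
                           "current": 42.0, "projected": 130.0, "steps_until": 3}],
             "action": "Consider intervention now"}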
"""
if not self.monitor or not self.monitor.metrics:
return {"risk": "unknown", "reason": "No history"}
warnings = []
for name, stats in self._layer_stats.items():
# Check gradient history via monitor
grad_key = f"grad_norm/{name.replace('.', '_')}"
series = self.monitor.metrics.get_metric(grad_key)
if series and series.count >= 5:
trend = series.trend()
roc = series.rate_of_change()
if trend == "rising" and roc and roc > 0:
# Project forward
projected = series.current + (roc * lookahead)
if projected > 100:
warnings.append({
"layer": name,
"prediction": "explosion",
"current": series.current,
"projected": projected,
"steps_until": int(100 / roc) if roc > 0 else None
})
elif trend == "falling" and series.current < 0.001:
warnings.append({
"layer": name,
"prediction": "vanishing",
"current": series.current,
"trend": "falling"
})
if warnings:
return {
"risk": "high",
"warnings": warnings,
"action": "Consider intervention now"
}
return {"risk": "low", "warnings": [], "action": "Continue monitoring"}
def suggest_intervention(self) -> Optional[Dict[str, Any]]:
"""
FORWARD BRANCH: Suggest specific parameter changes.
Based on current state, recommend concrete actions.
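
        Returns None when predicted risk is not high. Example return
        (illustrative):
            {"interventions": [{"action": "reduce_lr", "factor": 0.5,
                                "reason": "..."}],
             "urgency": "medium"}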
"""
prediction = self.predict_failure()
if prediction["risk"] != "high":
return None
interventions = []
for warning in prediction["warnings"]:
if warning["prediction"] == "explosion":
interventions.append({
"action": "reduce_lr",
"factor": 0.5,
"reason": f"Gradient explosion predicted in {warning['layer']}"
})
interventions.append({
"action": "add_grad_clip",
"value": 1.0,
"reason": "Prevent gradient explosion"
})
elif warning["prediction"] == "vanishing":
interventions.append({
"action": "increase_lr",
"factor": 1.5,
"reason": f"Vanishing gradient in {warning['layer']}"
})
return {
"interventions": interventions,
"urgency": "high" if len(interventions) > 1 else "medium"
}
def get_attention_pattern(self, layer_name: str) -> Optional[Dict[str, Any]]:
"""
DEEP BRANCH: Extract attention patterns (for transformer layers).
Returns attention entropy, sparsity, positional bias.
"""
for name, module in self.model.named_modules():
if name == layer_name and isinstance(module, nn.MultiheadAttention):
                # Stub: full support requires calling the module with
                # need_weights=True and capturing the returned attention
                # weights (output[1]) in a forward hook
return {
"layer": layer_name,
"type": "attention",
"note": "Full implementation requires attention weight capture"
}
return None
def find_dead_neurons(self, threshold: float = 0.01) -> List[Dict[str, Any]]:
"""
        DEEP BRANCH: Flag layers whose mean activation is near zero, a
        layer-level heuristic for dead neurons (true per-neuron detection
        would require per-channel activation statistics).
        Dead neurons = wasted parameters = pruning candidates.
"""
dead = []
for name, stats in self._layer_stats.items():
if stats.activation_mean is not None:
if abs(stats.activation_mean) < threshold:
dead.append({
"layer": name,
"activation_mean": stats.activation_mean,
"recommendation": "Consider pruning or reinitializing"
})
return dead
def branch_report(self) -> Dict[str, Any]:
"""
Full branch analysis: backward, forward, and deep.
"""
return {
"backward": {
"anomaly_trace": self.trace_anomaly(),
},
"forward": {
"failure_prediction": self.predict_failure(),
"suggested_interventions": self.suggest_intervention(),
},
"deep": {
"dead_neurons": self.find_dead_neurons(),
"layer_health": self.detect_issues(),
},
"meta": {
"tracked_layers": len(self._layer_stats),
"total_params": self.total_params,
"step": self._step,
}
}
# Convenience function
def instrument(model: "nn.Module", monitor=None, **kwargs) -> CascadeHook:
"""
Quick instrumentation of a PyTorch model.
Usage:
        hook = cascade.torch_hook.instrument(model, monitor)
        ...
        hook.step({"loss": loss.item()})  # once per training step
"""
return CascadeHook(model, monitor, **kwargs)
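

# ---------------------------------------------------------------------------
# Minimal demo / smoke test -- a sketch, not part of the public API. It
# assumes only that PyTorch is installed, and it runs without a cascade
# Monitor (monitor=None still collects stats; step() just skips logging).
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    if not TORCH_AVAILABLE:
        raise SystemExit("PyTorch not installed; demo skipped.")

    # Tiny model: the two Linear layers are matched by the default filter
    model = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 1))
    hook = instrument(model)

    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    for _ in range(3):
        x = torch.randn(4, 8)
        loss = model(x).pow(2).mean()
        optimizer.zero_grad()
        loss.backward()                        # backward hooks fire here
        optimizer.step()
        data = hook.step({"loss": loss.item()})

    print("tracked layers:", hook.tracked_layers)
    print("grad_norm keys:", [k for k in data if k.startswith("grad_norm/")])
    print("issues:", hook.detect_issues())
    hook.remove()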