"""Shared gradient visualization utilities for DP-SGD trainers.

This module provides consistent gradient norm generation and clipping
visualization across all trainer implementations.
"""
import numpy as np
from typing import Dict, List, Optional


def generate_gradient_norms(clipping_norm: float, num_points: int = 100) -> List[Dict[str, float]]:
    """Generate realistic gradient norms following a log-normal distribution.

    In real DP-SGD training, gradient norms typically follow a log-normal
    distribution, with most gradients being smaller than the clipping
    threshold and some exceeding it.

    Args:
        clipping_norm: The clipping threshold (C). Must be positive.
        num_points: Number of gradient samples to generate.

    Returns:
        List of dicts with 'x' (gradient norm) and 'y' (density) keys,
        sorted by x value for smooth visualization.

    Raises:
        ValueError: If clipping_norm is not positive.
    """
    if clipping_norm <= 0:
        raise ValueError(f"clipping_norm must be positive, got {clipping_norm}")

    # Log-normal parameters: center the distribution slightly below the
    # clipping threshold so most samples fall under C but some exceed it.
    mu = np.log(clipping_norm) - 0.5
    sigma = 0.8

    gradients = []
    for _ in range(num_points):
        # Box-Muller transform for a standard normal sample.
        # np.random.random() is in [0, 1), so use (1 - u1) to keep the
        # log argument in (0, 1] and avoid log(0) = -inf.
        u1, u2 = np.random.random(2)
        z = np.sqrt(-2.0 * np.log(1.0 - u1)) * np.cos(2.0 * np.pi * u2)
        norm = np.exp(mu + sigma * z)

        # Log-normal probability density evaluated at the sampled norm.
        density = np.exp(-(np.power(np.log(norm) - mu, 2) / (2 * sigma * sigma))) / \
            (norm * sigma * np.sqrt(2 * np.pi))

        # Add a baseline offset plus a little jitter for visual effect;
        # clamp from below so every point stays visible on the chart.
        # (The original `0.2 + 0.8 * (density / 0.8)` is just `0.2 + density`.)
        density = 0.2 + density + 0.1 * (np.random.random() - 0.5)
        gradients.append({'x': float(norm), 'y': float(max(0.01, density))})

    return sorted(gradients, key=lambda g: g['x'])


def generate_clipped_gradients(
    clipping_norm: float,
    original_gradients: Optional[List[Dict[str, float]]] = None,
    num_points: int = 100
) -> List[Dict[str, float]]:
    """Generate clipped versions of gradient norms.

    Demonstrates how gradient clipping limits the maximum gradient norm,
    creating a "pile-up" effect at the clipping threshold.

    Args:
        clipping_norm: The clipping threshold (C).
        original_gradients: Optional pre-generated gradients to clip.
            If None, generates new gradients first.
        num_points: Number of points if generating new gradients.

    Returns:
        List of dicts with 'x' (clipped gradient norm) and 'y' (density)
        keys, sorted by x value.
    """
    if original_gradients is None:
        original_gradients = generate_gradient_norms(clipping_norm, num_points)

    # Clip each norm at the threshold; densities are left unchanged, which
    # produces the characteristic pile-up at x == clipping_norm.
    clipped = [
        {'x': min(g['x'], clipping_norm), 'y': g['y']}
        for g in original_gradients
    ]
    return sorted(clipped, key=lambda g: g['x'])


def generate_gradient_info(clipping_norm: float, num_points: int = 100) -> Dict[str, List[Dict[str, float]]]:
    """Generate complete gradient information for visualization.

    This is a convenience function that generates both before- and
    after-clipping gradient distributions for use in training results.

    Args:
        clipping_norm: The clipping threshold (C).
        num_points: Number of gradient samples to generate.

    Returns:
        Dict with 'before_clipping' and 'after_clipping' keys, each
        containing a list of gradient samples.
    """
    before_clipping = generate_gradient_norms(clipping_norm, num_points)
    # Clip the same sample set so the two series correspond point-for-point.
    after_clipping = generate_clipped_gradients(clipping_norm, before_clipping)
    return {
        'before_clipping': before_clipping,
        'after_clipping': after_clipping
    }