File size: 3,724 Bytes
dbbc4dd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
"""
Shared gradient visualization utilities for DP-SGD trainers.

This module provides consistent gradient norm generation and clipping
visualization across all trainer implementations.
"""

from typing import Dict, List, Optional

import numpy as np


def generate_gradient_norms(clipping_norm: float, num_points: int = 100) -> List[Dict[str, float]]:
    """
    Generate realistic gradient norms following a log-normal distribution.

    In real DP-SGD training, gradient norms typically follow a log-normal
    distribution, with most gradients being smaller than the clipping threshold
    and some exceeding it.

    Args:
        clipping_norm: The clipping threshold (C). Must be positive.
        num_points: Number of gradient samples to generate

    Returns:
        List of dicts with 'x' (gradient norm) and 'y' (density) keys,
        sorted by x value for smooth visualization

    Raises:
        ValueError: If clipping_norm is not strictly positive (np.log of a
            non-positive threshold would silently produce NaN/-inf samples).
    """
    if clipping_norm <= 0:
        raise ValueError("clipping_norm must be positive")

    # Parameters for log-normal distribution
    # Center around clipping_norm with some spread
    mu = np.log(clipping_norm) - 0.5
    sigma = 0.8

    # Sample directly from the log-normal distribution. The previous
    # hand-rolled Box-Muller transform could evaluate log(0) when the
    # uniform draw was exactly 0.0 (np.random.random is [0, 1)),
    # producing an infinite gradient norm.
    norms = np.random.lognormal(mean=mu, sigma=sigma, size=num_points)

    # Log-normal pdf evaluated at each sampled norm.
    densities = (
        np.exp(-np.power(np.log(norms) - mu, 2) / (2 * sigma * sigma))
        / (norms * sigma * np.sqrt(2 * np.pi))
    )

    # Shift/jitter for visual effect; equivalent to the original
    # 0.2 + 0.8 * (density / 0.8) + noise, with the redundant
    # multiply-then-divide removed.
    densities = 0.2 + densities + 0.1 * (np.random.random(num_points) - 0.5)

    gradients = [
        # Floor at 0.01 so every point remains visible on the chart.
        {'x': float(norm), 'y': float(max(0.01, density))}
        for norm, density in zip(norms, densities)
    ]

    return sorted(gradients, key=lambda g: g['x'])


def generate_clipped_gradients(
    clipping_norm: float,
    original_gradients: Optional[List[Dict[str, float]]] = None,
    num_points: int = 100
) -> List[Dict[str, float]]:
    """
    Generate clipped versions of gradient norms.

    Demonstrates how gradient clipping limits the maximum gradient norm,
    creating a "pile-up" effect at the clipping threshold.

    Args:
        clipping_norm: The clipping threshold (C)
        original_gradients: Optional pre-generated gradients to clip.
                           If None, generates new gradients first.
                           The input list is not mutated.
        num_points: Number of points if generating new gradients

    Returns:
        List of dicts with 'x' (clipped gradient norm) and 'y' (density) keys,
        sorted by x value
    """
    if original_gradients is None:
        original_gradients = generate_gradient_norms(clipping_norm, num_points)

    # New dicts (not in-place mutation), with x capped at the threshold;
    # densities carry over unchanged.
    clipped = [
        {'x': min(g['x'], clipping_norm), 'y': g['y']}
        for g in original_gradients
    ]

    return sorted(clipped, key=lambda g: g['x'])


def generate_gradient_info(clipping_norm: float, num_points: int = 100) -> Dict[str, List[Dict[str, float]]]:
    """
    Build the full before/after clipping dataset for visualization.

    Convenience wrapper: samples a gradient-norm distribution once, then
    derives its clipped counterpart from those same samples so the two
    series correspond point-for-point.

    Args:
        clipping_norm: The clipping threshold (C)
        num_points: Number of gradient samples to generate

    Returns:
        Dict with 'before_clipping' and 'after_clipping' keys,
        each containing a list of gradient samples
    """
    samples = generate_gradient_norms(clipping_norm, num_points)
    return {
        'before_clipping': samples,
        'after_clipping': generate_clipped_gradients(clipping_norm, samples),
    }