# File size: 9,592 Bytes
# 69800cb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
"""vi

Analyze STDP weight changes between checkpoints to evaluate learning progress.



Usage:

    python -m STDP_Communicator.analyze_stdp_weights --early 5 --final 20 --visualize



This script loads checkpoint files from different epochs, compares synaptic weights,

and provides statistics and optional visualizations of the weight changes.

"""
import os
import sys
import torch
import argparse
import logging
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path
from typing import Dict, List, Tuple, Optional, Union

# Configure logging: INFO level, timestamped records tagged with logger name and level.
logging.basicConfig(
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)
logger = logging.getLogger(__name__)

def load_checkpoint(checkpoint_path: str) -> Dict:
    """
    Deserialize a checkpoint file onto the CPU.

    Args:
        checkpoint_path: Location of the checkpoint on disk.

    Returns:
        The object produced by ``torch.load`` (expected to be a dict).

    Raises:
        FileNotFoundError: If ``checkpoint_path`` does not exist.
        Exception: Whatever ``torch.load`` raises; it is logged, then re-raised.
    """
    # NOTE(review): torch.load is pickle-based (unless restricted via
    # weights_only), so only point this at checkpoints from a trusted source.
    if os.path.exists(checkpoint_path):
        logger.info(f"Loading checkpoint: {checkpoint_path}")
        target_device = torch.device('cpu')
        try:
            return torch.load(checkpoint_path, map_location=target_device)
        except Exception as exc:
            logger.error(f"Error loading checkpoint: {exc}")
            raise
    raise FileNotFoundError(f"Checkpoint not found: {checkpoint_path}")

def _as_weight_tensor(value) -> torch.Tensor:
    """Return ``value`` unchanged if it is a tensor, otherwise convert it with ``torch.as_tensor``."""
    return value if isinstance(value, torch.Tensor) else torch.as_tensor(value)

def _first_weight_entry(mapping, tensors_only: bool = False) -> Optional[Tuple]:
    """Return the first ``(key, value)`` of a dict whose key contains "weight" (case-insensitive), else ``None``."""
    if not isinstance(mapping, dict):
        return None
    for key, value in mapping.items():
        # str() so non-string keys cannot crash the lookup.
        if "weight" not in str(key).lower():
            continue
        if tensors_only and not isinstance(value, torch.Tensor):
            continue
        return key, value
    return None

def extract_synaptic_weights(checkpoint: Dict) -> torch.Tensor:
    """
    Extract synaptic weights from a checkpoint.

    Locations are tried in this order; the first hit wins:

    1. ``checkpoint["synaptic_weights"]``
    2. the first entry of ``checkpoint["model_state_dict"]`` whose key
       contains "weight" (case-insensitive)
    3. ``checkpoint[key]["weights"]`` for the first nested dict that has a
       ``"weights"`` entry
    4. the first entry of ``checkpoint["state_dict"]`` whose key contains
       "weight"
    5. the first top-level tensor whose key contains "weight" (a state dict
       saved directly as the checkpoint)

    A checkpoint that is itself a tensor is returned unchanged. Values that
    are not tensors (e.g. NumPy arrays or nested lists) are converted with
    ``torch.as_tensor`` so downstream tensor operations work.

    Args:
        checkpoint: Loaded checkpoint data

    Returns:
        Tensor containing synaptic weights

    Raises:
        ValueError: If no weights can be located; the message lists the
            checkpoint's top-level keys.
    """
    # A checkpoint saved as a single tensor is already the weight tensor.
    if isinstance(checkpoint, torch.Tensor):
        logger.info("Checkpoint is a bare tensor; using it as the weights")
        return checkpoint

    if "synaptic_weights" in checkpoint:
        logger.info("Found direct synaptic_weights key")
        return _as_weight_tensor(checkpoint["synaptic_weights"])

    # Try to find weights in model state dict
    entry = _first_weight_entry(checkpoint.get("model_state_dict"))
    if entry is not None:
        logger.info(f"Using weights from key: {entry[0]}")
        return _as_weight_tensor(entry[1])

    # Look for weights in nested dictionaries, e.g. {"stdp": {"weights": ...}}
    for key, value in checkpoint.items():
        if isinstance(value, dict) and "weights" in value:
            logger.info(f"Found weights in nested dictionary: {key}.weights")
            return _as_weight_tensor(value["weights"])

    # Additional layouts, checked last so they never take precedence over
    # the lookups above: a "state_dict" key, then a raw top-level state dict.
    entry = _first_weight_entry(checkpoint.get("state_dict"))
    if entry is not None:
        logger.info(f"Using weights from key: state_dict.{entry[0]}")
        return _as_weight_tensor(entry[1])

    entry = _first_weight_entry(checkpoint, tensors_only=True)
    if entry is not None:
        logger.info(f"Using top-level weights from key: {entry[0]}")
        return entry[1]

    raise ValueError(
        "Could not find synaptic weights in checkpoint "
        f"(top-level keys: {list(checkpoint.keys())})"
    )

def analyze_weight_changes(
    early_weights: torch.Tensor,
    final_weights: torch.Tensor,
    change_threshold: float = 0.0001,
) -> Dict:
    """
    Summarize element-wise changes between two weight tensors.

    Args:
        early_weights: Weights from the earlier epoch.
        final_weights: Weights from the later epoch; must have the same shape
            as ``early_weights``. Moved to ``early_weights``' device if needed.
        change_threshold: An element counts as "changed" when its absolute
            change is strictly greater than this value (default 1e-4).

    Returns:
        Dict of Python floats:
        ``mean_change``, ``max_change``, ``std_change`` are statistics of
        ``|final - early|``. ``percent_changed`` is in the range 0-100.
        ``early_weights_mean/std`` and ``final_weights_mean/std`` describe
        each tensor. Standard deviations use torch's default (unbiased)
        estimator, so they are NaN for single-element tensors.

    Raises:
        ValueError: If the shapes differ or the tensors are empty.
    """
    if early_weights.shape != final_weights.shape:
        # Broadcasting (e.g. (1, N) vs (M, N)) would otherwise silently pair
        # unrelated elements and report meaningless statistics.
        raise ValueError(
            f"Weight shapes differ: early {tuple(early_weights.shape)} "
            f"vs final {tuple(final_weights.shape)}"
        )
    if early_weights.numel() == 0:
        raise ValueError("Cannot analyze empty weight tensors")

    # Statistics only; no autograd graph is needed.
    early_weights = early_weights.detach()
    final_weights = final_weights.detach()

    # Ensure tensors are on the same device
    if final_weights.device != early_weights.device:
        final_weights = final_weights.to(early_weights.device)

    # mean()/std() reject integer and bool tensors, bool subtraction is not
    # supported, and unsigned integers wrap around; promote such dtypes.
    if not torch.is_floating_point(early_weights):
        early_weights = early_weights.double()
    if not torch.is_floating_point(final_weights):
        final_weights = final_weights.double()

    # Absolute per-element change
    diff = torch.abs(final_weights - early_weights)

    return {
        "mean_change": diff.mean().item(),
        "max_change": diff.max().item(),
        "std_change": diff.std().item(),
        "percent_changed": (diff > change_threshold).float().mean().item() * 100,
        "early_weights_mean": early_weights.mean().item(),
        "early_weights_std": early_weights.std().item(),
        "final_weights_mean": final_weights.mean().item(),
        "final_weights_std": final_weights.std().item(),
    }

def _weights_to_numpy(weights: torch.Tensor) -> np.ndarray:
    """Detach ``weights`` to a CPU NumPy array, casting dtypes other than float32/float64 to float32."""
    tensor = weights.detach().cpu()
    if tensor.dtype not in (torch.float32, torch.float64):
        # bfloat16 has no NumPy equivalent, bool arrays cannot be subtracted,
        # and unsigned integers would wrap around on subtraction.
        tensor = tensor.float()
    return tensor.numpy()

def visualize_weights(
    early_weights: torch.Tensor,
    final_weights: torch.Tensor,
    output_dir: str = "weight_analysis"
) -> None:
    """
    Plot how weights changed between two checkpoints and save the figure.

    Writes one 2x2 figure to ``<output_dir>/weight_analysis.png`` at 300 dpi.
    The panels are:
    overlaid value histograms; a histogram of absolute changes; an
    early-vs-final scatter plot (at most 10,000 randomly sampled points);
    and either a change heatmap (2-D weights) or sorted-value curves (any
    other rank).

    Args:
        early_weights: Weights from the earlier epoch (any device or dtype).
        final_weights: Weights from the later epoch; must match the shape of
            ``early_weights``.
        output_dir: Directory to save the PNG into; created if missing.

    Raises:
        ValueError: If the tensors differ in shape or are empty.
    """
    if early_weights.shape != final_weights.shape:
        raise ValueError(
            f"Weight shapes differ: early {tuple(early_weights.shape)} "
            f"vs final {tuple(final_weights.shape)}"
        )
    if early_weights.numel() == 0:
        raise ValueError("Cannot visualize empty weight tensors")

    os.makedirs(output_dir, exist_ok=True)

    # NumPy copies kept in the original shape (for the heatmap) plus flat views.
    early_arr = _weights_to_numpy(early_weights)
    final_arr = _weights_to_numpy(final_weights)
    diff_arr = np.abs(final_arr - early_arr)
    early_np = early_arr.ravel()
    final_np = final_arr.ravel()
    diff_np = diff_arr.ravel()

    fig, axes = plt.subplots(2, 2, figsize=(15, 10))
    try:
        ax_dist, ax_diff, ax_scatter, ax_extra = axes.ravel()

        # 1. Weight distribution histograms
        ax_dist.hist(early_np, bins=50, alpha=0.5, label="Early Weights")
        ax_dist.hist(final_np, bins=50, alpha=0.5, label="Final Weights")
        ax_dist.set_title("Weight Distribution Comparison")
        ax_dist.set_xlabel("Weight Value")
        ax_dist.set_ylabel("Count")
        ax_dist.legend()

        # 2. Changes histogram
        ax_diff.hist(diff_np, bins=50, color='green')
        ax_diff.set_title("Weight Changes")
        ax_diff.set_xlabel("Absolute Change")
        ax_diff.set_ylabel("Count")

        # 3. Scatter plot of early vs final weights; subsample large tensors.
        max_points = 10000
        if early_np.size > max_points:
            indices = np.random.choice(early_np.size, max_points, replace=False)
            early_sample, final_sample = early_np[indices], final_np[indices]
        else:
            early_sample, final_sample = early_np, final_np
        ax_scatter.scatter(early_sample, final_sample, alpha=0.1)
        # y = x reference line spanning BOTH value ranges, so weights that
        # moved outside the early range are still compared against it.
        lo = min(early_np.min(), final_np.min())
        hi = max(early_np.max(), final_np.max())
        ax_scatter.plot([lo, hi], [lo, hi], 'r--')
        ax_scatter.set_title("Early vs Final Weights")
        ax_scatter.set_xlabel("Early Weights")
        ax_scatter.set_ylabel("Final Weights")

        # 4. Weight change heatmap for 2-D weights, sorted curves otherwise
        if diff_arr.ndim == 2:
            image = ax_extra.imshow(diff_arr, cmap='hot', interpolation='nearest')
            fig.colorbar(image, ax=ax_extra)
            ax_extra.set_title("Weight Change Heatmap")
        else:
            ax_extra.plot(np.sort(early_np), label="Early Weights (Sorted)")
            ax_extra.plot(np.sort(final_np), label="Final Weights (Sorted)")
            ax_extra.set_title("Sorted Weight Comparison")
            ax_extra.legend()

        # Save the figure (nothing is shown interactively)
        fig.tight_layout()
        fig.savefig(os.path.join(output_dir, "weight_analysis.png"), dpi=300)
    finally:
        # Always release the figure, even if plotting or saving fails.
        plt.close(fig)

    logger.info(f"Visualizations saved to {output_dir}/weight_analysis.png")

def parse_args(argv: Optional[List[str]] = None) -> argparse.Namespace:
    """
    Parse command-line options for the weight analysis.

    Args:
        argv: Arguments to parse, excluding the program name. ``None`` (the
            default) makes argparse read ``sys.argv[1:]``, as before; passing
            a list allows programmatic use and testing.

    Returns:
        Namespace with ``early``, ``final``, ``checkpoint_dir``,
        ``output_dir`` and ``visualize`` attributes.
    """
    parser = argparse.ArgumentParser(description="Analyze STDP weights between checkpoints")
    parser.add_argument("--early", type=int, default=5, help="Early epoch number")
    parser.add_argument("--final", type=int, default=20, help="Final epoch number")
    parser.add_argument("--checkpoint-dir", type=str, default="checkpoints",
                        help="Directory containing checkpoint files")
    parser.add_argument("--output-dir", type=str, default="weight_analysis",
                        help="Directory to save analysis outputs")
    parser.add_argument("--visualize", action="store_true", help="Create visualizations")
    return parser.parse_args(argv)

def _resolve_under_root(path_arg: str, project_root: Path) -> Path:
    """Return ``path_arg`` as a Path, anchoring relative paths at ``project_root``."""
    path = Path(path_arg)
    return path if path.is_absolute() else project_root / path

def main():
    """
    Command-line entry point: compare the weights of two epoch checkpoints.

    Loads ``stdp_model_epoch_<early>.pt`` and ``stdp_model_epoch_<final>.pt``
    from ``--checkpoint-dir``, logs change statistics and, with
    ``--visualize``, saves plots under ``--output-dir``. Relative directories
    are resolved against the project root: the parent of the directory that
    contains this script.

    Returns:
        0 on success; 1 if any step raised (the error is logged with a traceback).
    """
    args = parse_args()

    try:
        # Resolve the project root up front so it is available for BOTH
        # directories, even when --checkpoint-dir is absolute but
        # --output-dir is relative.
        project_root = Path(os.path.abspath(__file__)).parent.parent
        checkpoint_dir = _resolve_under_root(args.checkpoint_dir, project_root)

        # Construct checkpoint paths
        early_path = checkpoint_dir / f"stdp_model_epoch_{args.early}.pt"
        final_path = checkpoint_dir / f"stdp_model_epoch_{args.final}.pt"

        # Load checkpoints
        early_checkpoint = load_checkpoint(str(early_path))
        final_checkpoint = load_checkpoint(str(final_path))

        # Extract weights
        early_weights = extract_synaptic_weights(early_checkpoint)
        final_weights = extract_synaptic_weights(final_checkpoint)

        # Analyze changes
        stats = analyze_weight_changes(early_weights, final_weights)

        # Print results
        logger.info("Weight Change Analysis:")
        logger.info(f"Epochs {args.early} -> {args.final}")
        logger.info(f"Average weight change: {stats['mean_change']:.6f}")
        logger.info(f"Maximum weight change: {stats['max_change']:.6f}")
        logger.info(f"Standard deviation of changes: {stats['std_change']:.6f}")
        logger.info(f"Percentage of weights changed: {stats['percent_changed']:.2f}%")
        logger.info(f"Early weights - mean: {stats['early_weights_mean']:.6f}, std: {stats['early_weights_std']:.6f}")
        logger.info(f"Final weights - mean: {stats['final_weights_mean']:.6f}, std: {stats['final_weights_std']:.6f}")

        # Create visualizations if requested
        if args.visualize:
            output_dir = _resolve_under_root(args.output_dir, project_root)
            visualize_weights(early_weights, final_weights, str(output_dir))

    except Exception as e:
        # Top-level boundary: report with traceback and signal failure.
        logger.error(f"Error during analysis: {e}", exc_info=True)
        return 1

    return 0

if __name__ == "__main__":
    # Exit with main()'s status code: 0 on success, 1 on failure.
    exit_code = main()
    sys.exit(exit_code)