|
|
"""
Analyze STDP weight changes between checkpoints to evaluate learning progress.

Usage:

    python -m STDP_Communicator.analyze_stdp_weights --early 5 --final 20 --visualize

This script loads the checkpoint files saved at two different epochs
(``stdp_model_epoch_<N>.pt``), compares their synaptic weights, and reports
statistics and optional visualizations of the weight changes.
"""
|
|
|
import os
|
|
|
import sys
|
|
|
import torch
|
|
|
import argparse
|
|
|
import logging
|
|
|
import numpy as np
|
|
|
import matplotlib.pyplot as plt
|
|
|
from pathlib import Path
|
|
|
from typing import Dict, List, Tuple, Optional, Union
|
|
|
|
|
|
|
|
|
# Script-wide logging setup: INFO level, with timestamp, logger name and
# severity prefixed to every record.
_LOG_FORMAT = "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
logging.basicConfig(format=_LOG_FORMAT, level=logging.INFO)

# Module-level logger used by every function in this script.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
def load_checkpoint(checkpoint_path: str) -> Dict:
    """
    Read a checkpoint file from disk, mapping all tensors onto the CPU.

    Args:
        checkpoint_path: Filesystem path of the checkpoint to read

    Returns:
        The object deserialized by ``torch.load`` (normally a dictionary of
        checkpoint data)

    Raises:
        FileNotFoundError: If nothing exists at ``checkpoint_path``
        Exception: Any error raised while loading is logged, then re-raised
    """
    if not os.path.exists(checkpoint_path):
        raise FileNotFoundError(f"Checkpoint not found: {checkpoint_path}")

    # NOTE(review): torch.load unpickles the file, which can execute arbitrary
    # code; only point this at checkpoints from trusted sources.
    try:
        logger.info(f"Loading checkpoint: {checkpoint_path}")
        return torch.load(checkpoint_path, map_location="cpu")
    except Exception as load_error:
        logger.error(f"Error loading checkpoint: {load_error}")
        raise
|
|
|
|
|
|
def _as_weight_tensor(value, source: str) -> torch.Tensor:
    """Return *value* as a tensor, converting array-likes such as lists or NumPy arrays."""
    if isinstance(value, torch.Tensor):
        return value
    try:
        return torch.as_tensor(value)
    except (TypeError, ValueError, RuntimeError) as err:
        raise ValueError(
            f"Weights found under {source!r} are not tensor-like "
            f"(got {type(value).__name__})"
        ) from err


def extract_synaptic_weights(checkpoint: Union[Dict, torch.Tensor]) -> torch.Tensor:
    """
    Extract synaptic weights from a checkpoint.

    Locations are tried in this order:

    1. The checkpoint itself, if it is already a tensor (saved with
       ``torch.save(weights, path)``).
    2. A top-level ``"synaptic_weights"`` entry.
    3. The first key of ``"model_state_dict"`` whose name contains ``"weight"``
       (case-insensitive).
    4. The ``"weights"`` entry of the first top-level dict value that has one.

    Args:
        checkpoint: Loaded checkpoint data (a dictionary, or a bare tensor)

    Returns:
        Tensor containing synaptic weights. Tensors are returned as-is. Other
        array-likes (lists, NumPy arrays) are converted with ``torch.as_tensor``.

    Raises:
        ValueError: If no weights can be located, or the value found is not
            convertible to a tensor
    """
    if isinstance(checkpoint, torch.Tensor):
        logger.info("Checkpoint is a bare tensor; using it as the synaptic weights")
        return checkpoint

    # Membership tests below only make sense on a mapping; anything else
    # (list, scalar, ...) would fail with an unrelated error.
    if not isinstance(checkpoint, dict):
        raise ValueError(
            "Could not find synaptic weights in checkpoint "
            f"of type {type(checkpoint).__name__}"
        )

    if "synaptic_weights" in checkpoint:
        logger.info("Found direct synaptic_weights key")
        return _as_weight_tensor(checkpoint["synaptic_weights"], "synaptic_weights")

    if "model_state_dict" in checkpoint:
        state_dict = checkpoint["model_state_dict"]
        weight_keys = [k for k in state_dict.keys() if "weight" in k.lower()]
        if weight_keys:
            logger.info(f"Using weights from key: {weight_keys[0]}")
            return _as_weight_tensor(state_dict[weight_keys[0]], weight_keys[0])

    for key, value in checkpoint.items():
        if isinstance(value, dict) and "weights" in value:
            logger.info(f"Found weights in nested dictionary: {key}.weights")
            return _as_weight_tensor(value["weights"], f"{key}.weights")

    raise ValueError("Could not find synaptic weights in checkpoint")
|
|
|
|
|
|
def analyze_weight_changes(
    early_weights: torch.Tensor,
    final_weights: torch.Tensor,
    threshold: float = 1e-4,
) -> Dict:
    """
    Analyze changes between two sets of weights.

    Args:
        early_weights: Weights from earlier epoch
        final_weights: Weights from later epoch. Must have the same shape as
            ``early_weights``; it is moved to that tensor's device if needed.
        threshold: A weight counts as "changed" in ``percent_changed`` when its
            absolute change is strictly greater than this value (default 1e-4)

    Returns:
        Dictionary of Python floats:

        - ``mean_change``, ``max_change``, ``std_change``: statistics of the
          absolute per-weight change
        - ``percent_changed``: share of weights (0-100) whose change exceeds
          ``threshold``
        - ``early_weights_mean``, ``early_weights_std``, ``final_weights_mean``,
          ``final_weights_std``

        Standard deviations use PyTorch's default (unbiased) estimator, so
        they are NaN for single-element tensors.

    Raises:
        ValueError: If the two tensors differ in shape or are empty
    """
    # Mismatched shapes would either raise an opaque broadcasting error or,
    # worse, broadcast silently (e.g. (N, 1) vs (N,)) into N*N meaningless
    # differences.
    if early_weights.shape != final_weights.shape:
        raise ValueError(
            f"Weight shapes differ: early {tuple(early_weights.shape)} "
            f"vs final {tuple(final_weights.shape)}"
        )
    if early_weights.numel() == 0:
        raise ValueError("Cannot analyze empty weight tensors")

    if early_weights.device != final_weights.device:
        final_weights = final_weights.to(early_weights.device)

    # mean()/std() are undefined for integer/bool tensors; promote those while
    # leaving floating dtypes (including float64) at their own precision.
    if not torch.is_floating_point(early_weights):
        early_weights = early_weights.float()
    if not torch.is_floating_point(final_weights):
        final_weights = final_weights.float()

    # Only scalar summaries are produced, so skip autograd graph construction.
    with torch.no_grad():
        diff = torch.abs(final_weights - early_weights)
        stats = {
            "mean_change": diff.mean().item(),
            "max_change": diff.max().item(),
            "std_change": diff.std().item(),
            "percent_changed": (diff > threshold).float().mean().item() * 100,
            "early_weights_mean": early_weights.mean().item(),
            "early_weights_std": early_weights.std().item(),
            "final_weights_mean": final_weights.mean().item(),
            "final_weights_std": final_weights.std().item(),
        }

    return stats
|
|
|
|
|
|
def visualize_weights(
    early_weights: torch.Tensor,
    final_weights: torch.Tensor,
    output_dir: str = "weight_analysis"
) -> None:
    """
    Create visualizations of weight changes and save them to a PNG file.

    The figure is written to ``<output_dir>/weight_analysis.png`` at 300 dpi.
    It has four panels:

    1. Overlaid histograms of the early and final weights.
    2. Histogram of the absolute per-weight change.
    3. Early-vs-final scatter plot with a ``y = x`` reference line. At most
       10,000 randomly chosen weights are drawn, so for large tensors this
       panel differs from run to run.
    4. A heatmap of absolute changes for 2-D weight matrices. For any other
       rank, the sorted early and final weight values.

    Args:
        early_weights: Weights from earlier epoch
        final_weights: Weights from later epoch (same shape as ``early_weights``)
        output_dir: Directory to save visualizations (created if missing)

    Raises:
        ValueError: If the two tensors differ in shape or are empty
    """
    if early_weights.shape != final_weights.shape:
        raise ValueError(
            f"Weight shapes differ: early {tuple(early_weights.shape)} "
            f"vs final {tuple(final_weights.shape)}"
        )
    if early_weights.numel() == 0:
        raise ValueError("Cannot visualize empty weight tensors")

    os.makedirs(output_dir, exist_ok=True)

    # Convert on the CPU, since .numpy() rejects tensors on other devices. Use
    # float64 so integer dtypes cannot wrap around on subtraction, and dtypes
    # NumPy lacks (e.g. bfloat16) still convert.
    early_np = early_weights.detach().cpu().to(torch.float64).flatten().numpy()
    final_np = final_weights.detach().cpu().to(torch.float64).flatten().numpy()
    diff_np = np.abs(final_np - early_np)

    output_path = os.path.join(output_dir, "weight_analysis.png")
    fig, axes = plt.subplots(2, 2, figsize=(15, 10))
    ax_dist, ax_diff, ax_scatter, ax_detail = axes.flatten()
    try:
        # Panel 1: weight distributions before and after.
        ax_dist.hist(early_np, bins=50, alpha=0.5, label="Early Weights")
        ax_dist.hist(final_np, bins=50, alpha=0.5, label="Final Weights")
        ax_dist.set_title("Weight Distribution Comparison")
        ax_dist.set_xlabel("Weight Value")
        ax_dist.set_ylabel("Count")
        ax_dist.legend()

        # Panel 2: distribution of absolute changes.
        ax_diff.hist(diff_np, bins=50, color='green')
        ax_diff.set_title("Weight Changes")
        ax_diff.set_xlabel("Absolute Change")
        ax_diff.set_ylabel("Count")

        # Panel 3: subsample large tensors to keep the scatter plot tractable.
        max_points = 10000
        if early_np.size > max_points:
            indices = np.random.choice(early_np.size, max_points, replace=False)
            early_sample, final_sample = early_np[indices], final_np[indices]
        else:
            early_sample, final_sample = early_np, final_np
        ax_scatter.scatter(early_sample, final_sample, alpha=0.1)
        # Span y = x over both tensors' ranges so weights that grew beyond the
        # early range still have a visible reference line.
        lo = min(early_np.min(), final_np.min())
        hi = max(early_np.max(), final_np.max())
        ax_scatter.plot([lo, hi], [lo, hi], 'r--')
        ax_scatter.set_title("Early vs Final Weights")
        ax_scatter.set_xlabel("Early Weights")
        ax_scatter.set_ylabel("Final Weights")

        # Panel 4: spatial heatmap for matrices, sorted curves otherwise.
        if early_weights.dim() == 2:
            heatmap = ax_detail.imshow(
                diff_np.reshape(tuple(early_weights.shape)),
                cmap='hot',
                interpolation='nearest',
            )
            fig.colorbar(heatmap, ax=ax_detail)
            ax_detail.set_title("Weight Change Heatmap")
        else:
            ax_detail.plot(np.sort(early_np), label="Early Weights (Sorted)")
            ax_detail.plot(np.sort(final_np), label="Final Weights (Sorted)")
            ax_detail.set_title("Sorted Weight Comparison")
            ax_detail.legend()

        fig.tight_layout()
        fig.savefig(output_path, dpi=300)
    finally:
        # Release the figure even if plotting fails, so repeated calls do not
        # accumulate open figures.
        plt.close(fig)

    logger.info(f"Visualizations saved to {output_path}")
|
|
|
|
|
|
def parse_args(argv: Optional[List[str]] = None) -> argparse.Namespace:
    """
    Parse command-line options for the weight analysis script.

    Args:
        argv: Argument list to parse. ``None`` (the default) makes argparse read
            ``sys.argv[1:]``, matching normal command-line use.

    Returns:
        Namespace with ``early``, ``final``, ``checkpoint_dir``, ``output_dir``
        and ``visualize`` attributes
    """
    parser = argparse.ArgumentParser(description="Analyze STDP weights between checkpoints")
    parser.add_argument("--early", type=int, default=5, help="Early epoch number")
    parser.add_argument("--final", type=int, default=20, help="Final epoch number")
    parser.add_argument("--checkpoint-dir", type=str, default="checkpoints",
                        help="Directory containing checkpoint files")
    parser.add_argument("--output-dir", type=str, default="weight_analysis",
                        help="Directory to save analysis outputs")
    parser.add_argument("--visualize", action="store_true", help="Create visualizations")
    return parser.parse_args(argv)
|
|
|
|
|
|
def _resolve_against(root: Path, path_str: str) -> Path:
    """Return *path_str* as a Path, joined onto *root* when it is relative."""
    path = Path(path_str)
    return path if path.is_absolute() else root / path


def main():
    """
    Command-line entry point: compare the weights of two epoch checkpoints.

    Loads ``stdp_model_epoch_<early>.pt`` and ``stdp_model_epoch_<final>.pt``
    from the checkpoint directory and logs change statistics. With
    ``--visualize``, it also writes plots to the output directory. Relative
    directories are resolved against the parent of the directory containing
    this file.

    Returns:
        Process exit code: 0 on success, 1 if any step raised
    """
    args = parse_args()

    try:
        # Both --checkpoint-dir and --output-dir may independently be relative,
        # so the project root must be available regardless of which one is.
        project_root = Path(os.path.abspath(__file__)).parent.parent

        checkpoint_dir = _resolve_against(project_root, args.checkpoint_dir)
        early_path = checkpoint_dir / f"stdp_model_epoch_{args.early}.pt"
        final_path = checkpoint_dir / f"stdp_model_epoch_{args.final}.pt"

        if args.early == args.final:
            logger.warning(
                f"Early and final epochs are both {args.early}; "
                "all weight changes will be zero"
            )

        early_checkpoint = load_checkpoint(str(early_path))
        final_checkpoint = load_checkpoint(str(final_path))

        early_weights = extract_synaptic_weights(early_checkpoint)
        final_weights = extract_synaptic_weights(final_checkpoint)

        stats = analyze_weight_changes(early_weights, final_weights)

        logger.info("Weight Change Analysis:")
        logger.info(f"Epochs {args.early} → {args.final}")
        logger.info(f"Average weight change: {stats['mean_change']:.6f}")
        logger.info(f"Maximum weight change: {stats['max_change']:.6f}")
        logger.info(f"Standard deviation of changes: {stats['std_change']:.6f}")
        logger.info(f"Percentage of weights changed: {stats['percent_changed']:.2f}%")
        logger.info(f"Early weights - mean: {stats['early_weights_mean']:.6f}, std: {stats['early_weights_std']:.6f}")
        logger.info(f"Final weights - mean: {stats['final_weights_mean']:.6f}, std: {stats['final_weights_std']:.6f}")

        if args.visualize:
            output_dir = _resolve_against(project_root, args.output_dir)
            visualize_weights(early_weights, final_weights, str(output_dir))

    except Exception as e:
        # Top-level boundary: report the full traceback and signal failure via
        # the exit code instead of crashing.
        logger.error(f"Error during analysis: {e}", exc_info=True)
        return 1

    return 0
|
|
|
|
|
|
# Run as a script: propagate main()'s return value as the process exit status.
if __name__ == "__main__":
    raise SystemExit(main())
|
|
|
|