gansta / tasks /image_classification /analysis /run_imagenet_analysis.py
Elliotasdasdasfasas's picture
Deploy CTM Codebase bypass FUSE 503
ed89628
# --- Core Libraries ---
import torch
import numpy as np
import os
import argparse
from tqdm.auto import tqdm
import torch.nn.functional as F # Used for interpolate
# --- Plotting & Visualization ---
import matplotlib.pyplot as plt
import matplotlib as mpl
mpl.use('Agg')
import seaborn as sns
sns.set_style('darkgrid')
from matplotlib import patheffects
import seaborn as sns
import imageio
import cv2
from scipy.special import softmax
from tasks.image_classification.plotting import save_frames_to_mp4
# --- Data Handling & Model ---
from torchvision import transforms
from torchvision import datasets # Only used for CIFAR100 in debug mode
from scipy import ndimage # Used in find_island_centers
from data.custom_datasets import ImageNet
from models.ctm import ContinuousThoughtMachine
from tasks.image_classification.imagenet_classes import IMAGENET2012_CLASSES
from tasks.image_classification.plotting import plot_neural_dynamics
# --- Global Settings ---
np.seterr(divide='ignore')
mpl.use('Agg')
sns.set_style('darkgrid')
# --- Helper Functions ---
def find_island_centers(array_2d, threshold):
"""
Finds the center of mass of each island (connected component > threshold)
in a 2D array, weighted by the array's values.
Returns list of (y, x) centers and list of areas.
"""
binary_image = array_2d > threshold
labeled_image, num_labels = ndimage.label(binary_image)
centers = []
areas = []
# Calculate center of mass for each labeled island (label 0 is background)
for i in range(1, num_labels + 1):
island_mask = (labeled_image == i)
total_mass = np.sum(array_2d[island_mask])
if total_mass > 0:
# Get coordinates for this island
y_coords, x_coords = np.mgrid[:array_2d.shape[0], :array_2d.shape[1]]
# Calculate weighted average for center
x_center = np.average(x_coords[island_mask], weights=array_2d[island_mask])
y_center = np.average(y_coords[island_mask], weights=array_2d[island_mask])
centers.append((round(y_center, 4), round(x_center, 4)))
areas.append(np.sum(island_mask)) # Area is the count of pixels in the island
return centers, areas
def parse_args():
"""Parses command-line arguments."""
# Note: Original had two ArgumentParser instances, using the second one.
parser = argparse.ArgumentParser(description="Visualize Continuous Thought Machine Attention")
parser.add_argument('--actions', type=str, nargs='+', default=['videos'], choices=['plots', 'videos', 'demo'], help="Actions to take. Plots=results plots; videos=gifs/mp4s to watch attention; demo: last frame of internal ticks")
parser.add_argument('--device', type=int, nargs='+', default=[-1], help="GPU device index or -1 for CPU")
parser.add_argument('--checkpoint', type=str, default='checkpoints/imagenet/ctm_clean.pt', help="Path to ATM checkpoint")
parser.add_argument('--output_dir', type=str, default='tasks/image_classification/analysis/outputs/imagenet_viz', help="Directory for visualization outputs")
parser.add_argument('--debug', action=argparse.BooleanOptionalAction, default=True, help='Debug mode: use CIFAR100 instead of ImageNet for debugging.')
parser.add_argument('--plot_every', type=int, default=10, help="How often to plot.")
parser.add_argument('--inference_iterations', type=int, default=50, help="Iterations to use during inference.")
parser.add_argument('--data_indices', type=int, nargs='+', default=[], help="Use specific indices in validation data for demos, otherwise random.")
parser.add_argument('--N_to_viz', type=int, default=5, help="When not supplying data_indices.")
return parser.parse_args()
# --- Main Execution Block ---
if __name__=='__main__':
# --- Setup ---
args = parse_args()
if args.device[0] != -1 and torch.cuda.is_available():
device = f'cuda:{args.device[0]}'
else:
device = 'cpu'
print(f"Using device: {device}")
# --- Load Checkpoint & Model ---
print(f"Loading checkpoint: {args.checkpoint}")
checkpoint = torch.load(args.checkpoint, map_location=device, weights_only=False) # removed weights_only=False
model_args = checkpoint['args']
# Handle legacy arguments from checkpoint if necessary
if not hasattr(model_args, 'backbone_type') and hasattr(model_args, 'resnet_type'):
model_args.backbone_type = f'{model_args.resnet_type}-{getattr(model_args, "resnet_feature_scales", [4])[-1]}'
if not hasattr(model_args, 'neuron_select_type'):
model_args.neuron_select_type = 'first-last'
# Instantiate Model based on checkpoint args
print("Instantiating CTM model...")
model = ContinuousThoughtMachine(
iterations=model_args.iterations,
d_model=model_args.d_model,
d_input=model_args.d_input,
heads=model_args.heads,
n_synch_out=model_args.n_synch_out,
n_synch_action=model_args.n_synch_action,
synapse_depth=model_args.synapse_depth,
memory_length=model_args.memory_length,
deep_nlms=model_args.deep_memory,
memory_hidden_dims=model_args.memory_hidden_dims,
do_layernorm_nlm=model_args.do_normalisation,
backbone_type=model_args.backbone_type,
positional_embedding_type=model_args.positional_embedding_type,
out_dims=model_args.out_dims,
prediction_reshaper=[-1], # Kept fixed value from original code
dropout=0, # No dropout for eval
neuron_select_type=model_args.neuron_select_type,
n_random_pairing_self=model_args.n_random_pairing_self,
).to(device)
# Load weights into model
load_result = model.load_state_dict(checkpoint['model_state_dict'], strict=False)
print(f" Loaded state_dict. Missing: {load_result.missing_keys}, Unexpected: {load_result.unexpected_keys}")
model.eval() # Set model to evaluation mode
# --- Prepare Dataset ---
if args.debug:
print("Debug mode: Using CIFAR100")
# CIFAR100 specific normalization constants
dataset_mean = [0.5070751592371341, 0.48654887331495067, 0.4409178433670344]
dataset_std = [0.2673342858792403, 0.2564384629170882, 0.27615047132568393]
img_size = 256 # Resize CIFAR images for consistency
transform = transforms.Compose([
transforms.Resize(img_size),
transforms.ToTensor(),
transforms.Normalize(mean=dataset_mean, std=dataset_std), # Normalize
])
validation_dataset = datasets.CIFAR100('data/', train=False, transform=transform, download=True)
validation_dataset_centercrop = datasets.CIFAR100('data/', train=True, transform=transform, download=True)
else:
print("Using ImageNet")
# ImageNet specific normalization constants
dataset_mean = [0.485, 0.456, 0.406]
dataset_std = [0.229, 0.224, 0.225]
img_size = 256 # Resize ImageNet images
# Note: Original comment mentioned no CenterCrop, this transform reflects that.
transform = transforms.Compose([
transforms.Resize(img_size),
transforms.ToTensor(),
transforms.Normalize(mean=dataset_mean, std=dataset_std) # Normalize
])
validation_dataset = ImageNet(which_split='validation', transform=transform)
validation_dataset_centercrop = ImageNet(which_split='train', transform=transforms.Compose([
transforms.Resize(img_size),
transforms.RandomCrop(img_size),
transforms.ToTensor(),
transforms.Normalize(mean=dataset_mean, std=dataset_std) # Normalize
]))
class_labels = list(IMAGENET2012_CLASSES.values()) # Load actual class names
os.makedirs(f'{args.output_dir}', exist_ok=True)
interp_mode = 'nearest'
cmap_calib = sns.color_palette('viridis', as_cmap=True)
loader = torch.utils.data.DataLoader(validation_dataset, batch_size=1, shuffle=False, num_workers=0, drop_last=False)
loader_crop = torch.utils.data.DataLoader(validation_dataset_centercrop, batch_size=64, shuffle=True, num_workers=0, drop_last=True)
model.eval()
figscale = 0.85
topk = 5
mean_certainties_correct, mean_certainties_incorrect = [],[]
tracked_certainties = []
tracked_targets = []
tracked_predictions = []
if model.iterations != args.inference_iterations:
print('WARNING: you are setting inference iterations to a value not used during training!')
model.iterations = args.inference_iterations
if 'plots' in args.actions:
with torch.inference_mode(): # Disable gradient calculations
with tqdm(total=len(loader), initial=0, leave=False, position=0, dynamic_ncols=True) as pbar:
imgi = 0
for bi, (inputs, targets) in enumerate(loader):
inputs = inputs.to(device)
targets = targets.to(device)
if bi==0:
dynamics_inputs, _ = next(iter(loader_crop)) # Use this because of batching
_, _, _, _, post_activations_viz, _ = model(inputs, track=True)
plot_neural_dynamics(post_activations_viz, 15*10, args.output_dir, axis_snap=True, N_per_row=15)
predictions, certainties, synchronisation = model(inputs)
tracked_predictions.append(predictions.detach().cpu().numpy())
tracked_targets.append(targets.detach().cpu().numpy())
tracked_certainties.append(certainties.detach().cpu().numpy())
pbar.set_description(f'Processing base image of size {inputs.shape}')
pbar.update(1)
if ((bi % args.plot_every == 0) or bi == len(loader)-1) and bi!=0: #
concatenated_certainties = np.concatenate(tracked_certainties, axis=0)
concatenated_targets = np.concatenate(tracked_targets, axis=0)
concatenated_predictions = np.concatenate(tracked_predictions, axis=0)
concatenated_predictions_argsorted = np.argsort(concatenated_predictions, 1)[:,::-1]
for topk in [1, 5]:
concatenated_predictions_argsorted_topk = concatenated_predictions_argsorted[:,:topk]
accs_instant, accs_avg, accs_certain = [], [], []
accs_avg_logits, accs_weighted_logits = [],[]
with tqdm(total=(concatenated_predictions.shape[-1]), initial=0, leave=False, position=1, dynamic_ncols=True) as pbarinner:
pbarinner.set_description('Acc types')
for stepi in np.arange(concatenated_predictions.shape[-1]):
pred_avg = softmax(concatenated_predictions, 1)[:,:,:stepi+1].mean(-1).argsort(1)[:,-topk:]
pred_instant = concatenated_predictions_argsorted_topk[:,:,stepi]
pred_certain = concatenated_predictions_argsorted_topk[np.arange(concatenated_predictions.shape[0]),:, concatenated_certainties[:,1,:stepi+1].argmax(1)]
pred_avg_logits = concatenated_predictions[:,:,:stepi+1].mean(-1).argsort(1)[:,-topk:]
pred_weighted_logits = (concatenated_predictions[:,:,:stepi+1] * concatenated_certainties[:,1:,:stepi+1]).sum(-1).argsort(1)[:, -topk:]
pbarinner.update(1)
accs_instant.append(np.any(pred_instant==concatenated_targets[...,np.newaxis], -1).mean())
accs_avg.append(np.any(pred_avg==concatenated_targets[...,np.newaxis], -1).mean())
accs_avg_logits.append(np.any(pred_avg==concatenated_targets[...,np.newaxis], -1).mean())
accs_weighted_logits.append(np.any(pred_weighted_logits==concatenated_targets[...,np.newaxis], -1).mean())
accs_certain.append(np.any(pred_avg_logits==concatenated_targets[...,np.newaxis], -1).mean())
fig = plt.figure(figsize=(10*figscale, 4*figscale))
ax = fig.add_subplot(111)
cp = sns.color_palette("bright")
ax.plot(np.arange(concatenated_predictions.shape[-1])+1, 100*np.array(accs_instant), linestyle='-', color=cp[0], label='Instant')
# ax.plot(np.arange(concatenated_predictions.shape[-1])+1, 100*np.array(accs_avg), linestyle='--', color=cp[1], label='Based on average probability up to this step')
ax.plot(np.arange(concatenated_predictions.shape[-1])+1, 100*np.array(accs_certain), linestyle=':', color=cp[2], label='Most certain')
ax.plot(np.arange(concatenated_predictions.shape[-1])+1, 100*np.array(accs_avg_logits), linestyle='-.', color=cp[3], label='Average logits')
ax.plot(np.arange(concatenated_predictions.shape[-1])+1, 100*np.array(accs_weighted_logits), linestyle='--', color=cp[4], label='Logits weighted by certainty')
ax.set_xlim([0, concatenated_predictions.shape[-1]+1])
ax.set_ylim([75, 92])
ax.set_xlabel('Internal ticks')
ax.set_ylabel(f'Top-k={topk} accuracy')
ax.legend(loc='lower right')
fig.tight_layout(pad=0.1)
fig.savefig(f'{args.output_dir}/accuracy_types_{topk}.png', dpi=200)
fig.savefig(f'{args.output_dir}/accuracy_types_{topk}.pdf', dpi=200)
plt.close(fig)
print(f'k={topk}. Accuracy most certain at last internal tick={100*np.array(accs_certain)[-1]:0.4f}') # Using certainty based approach
indices_over_80 = []
classes_80 = {}
corrects_80 = {}
topk = 5
concatenated_predictions_argsorted_topk = concatenated_predictions_argsorted[:,:topk]
for certainty_threshold in [0.5, 0.8, 0.9]:
# certainty_threshold = 0.6
percentage_corrects = []
percentage_incorrects = []
with tqdm(total=(concatenated_predictions.shape[-1]), initial=0, leave=False, position=1, dynamic_ncols=True) as pbarinner:
pbarinner.set_description(f'Certainty threshold={certainty_threshold}')
for stepi in np.arange(concatenated_predictions.shape[-1]):
certainty_here = concatenated_certainties[:,1,stepi]
certainty_mask = certainty_here>=certainty_threshold
predictions_here = concatenated_predictions_argsorted_topk[:,:,stepi]
is_correct_here = np.any(predictions_here==concatenated_targets[...,np.newaxis], axis=-1)
percentage_corrects.append(is_correct_here[certainty_mask].sum()/predictions_here.shape[0])
percentage_incorrects.append((~is_correct_here)[certainty_mask].sum()/predictions_here.shape[0])
if certainty_threshold==0.8:
indices_certain = np.where(certainty_mask)[0]
for index in indices_certain:
if index not in indices_over_80:
indices_over_80.append(index)
if concatenated_targets[index] not in classes_80:
classes_80[concatenated_targets[index]] = [stepi]
corrects_80[concatenated_targets[index]] = [is_correct_here[index]]
else:
classes_80[concatenated_targets[index]] = classes_80[concatenated_targets[index]]+[stepi]
corrects_80[concatenated_targets[index]] = corrects_80[concatenated_targets[index]]+[is_correct_here[index]]
pbarinner.update(1)
fig = plt.figure(figsize=(6.5*figscale, 4*figscale))
ax = fig.add_subplot(111)
ax.bar(np.arange(concatenated_predictions.shape[-1])+1,
percentage_corrects,
color='forestgreen',
hatch='OO',
width=0.9,
label='Positive',
alpha=0.9,
linewidth=1.0*figscale)
ax.bar(np.arange(concatenated_predictions.shape[-1])+1,
percentage_incorrects,
bottom=percentage_corrects,
color='crimson',
hatch='xx',
width=0.9,
label='Negative',
alpha=0.9,
linewidth=1.0*figscale)
ax.set_xlim(-1, concatenated_predictions.shape[-1]+1)
ax.set_xlabel('Internal tick')
ax.set_ylabel('% of data')
ax.legend(loc='lower right')
fig.tight_layout(pad=0.1)
fig.savefig(f'{args.output_dir}/steps_versus_correct_{certainty_threshold}.png', dpi=200)
fig.savefig(f'{args.output_dir}/steps_versus_correct_{certainty_threshold}.pdf', dpi=200)
plt.close(fig)
class_list = list(classes_80.keys())
mean_steps = [np.mean(classes_80[cls]) for cls in class_list]
std_steps = [np.std(classes_80[cls]) for cls in class_list]
# Following code plots the class distribution over internal ticks
indices_to_show = np.arange(1000)
colours = cmap_diverse = plt.get_cmap('rainbow')(np.linspace(0, 1, 1000))
# np.random.shuffle(colours)
bottom = np.zeros(concatenated_predictions.shape[-1])
fig = plt.figure(figsize=(7*figscale, 4*figscale))
ax = fig.add_subplot(111)
for iii, idx in enumerate(indices_to_show):
if idx in classes_80:
steps = classes_80[idx]
colour = colours[iii]
vs, cts = np.unique(steps, return_counts=True)
bar = np.zeros(concatenated_predictions.shape[-1])
bar[vs] = cts
ax.bar(np.arange(concatenated_predictions.shape[-1])+1, bar, bottom=bottom, color=colour, width=1, edgecolor='none')
bottom += bar
ax.set_xlabel('Internal ticks')
ax.set_ylabel('Counts over 0.8 certainty')
fig.tight_layout(pad=0.1)
fig.savefig(f'{args.output_dir}/class_counts.png', dpi=200)
fig.savefig(f'{args.output_dir}/class_counts.pdf', dpi=200)
plt.close(fig)
# The following code plots calibration
probability_space = np.linspace(0, 1, 10)
fig = plt.figure(figsize=(6*figscale, 4*figscale))
ax = fig.add_subplot(111)
color_linspace = np.linspace(0, 1, concatenated_predictions.shape[-1])
with tqdm(total=(concatenated_predictions.shape[-1]), initial=0, leave=False, position=1, dynamic_ncols=True) as pbarinner:
pbarinner.set_description(f'Calibration')
for stepi in np.arange(concatenated_predictions.shape[-1]):
color = cmap_calib(color_linspace[stepi])
pred = concatenated_predictions[:,:,stepi].argmax(1)
is_correct = pred == concatenated_targets # BxT
probabilities = softmax(concatenated_predictions[:,:,:stepi+1], axis=1)[np.arange(concatenated_predictions.shape[0]),pred].mean(-1)#softmax(concatenated_predictions[:,:,stepi], axis=1).max(1)
probability_space = np.linspace(0, 1, 10)
accuracies_per_bin = []
bin_centers = []
for pi in range(len(probability_space)-1):
bin_low = probability_space[pi]
bin_high = probability_space[pi+1]
mask = ((probabilities >=bin_low) & (probabilities < bin_high)) if pi !=len(probability_space)-2 else ((probabilities >=bin_low) & (probabilities <= bin_high))
accuracies_per_bin.append(is_correct[mask].mean())
bin_centers.append(probabilities[mask].mean())
if stepi==concatenated_predictions.shape[-1]-1:
ax.plot(bin_centers, accuracies_per_bin, linestyle='-', marker='.', color='#4050f7', alpha=1, label='After all ticks')
else: ax.plot(bin_centers, accuracies_per_bin, linestyle='-', marker='.', color=color, alpha=0.65)
pbarinner.update(1)
ax.plot(probability_space, np.linspace(0, 1, len(probability_space)), 'k--')
ax.legend(loc='upper left')
ax.set_xlim([-0.01, 1.01])
ax.set_ylim([-0.01, 1.01])
sm = plt.cm.ScalarMappable(cmap=cmap_calib, norm=plt.Normalize(vmin=0, vmax=concatenated_predictions.shape[-1] - 1))
sm.set_array([]) # Empty array for colormap
cbar = fig.colorbar(sm, ax=ax, orientation='vertical', pad=0.02)
cbar.set_label('Internal ticks')
ax.set_xlabel('Mean predicted probabilities')
ax.set_ylabel('Ratio of positives')
fig.tight_layout(pad=0.1)
fig.savefig(f'{args.output_dir}/imagenet_calibration.png', dpi=200)
fig.savefig(f'{args.output_dir}/imagenet_calibration.pdf', dpi=200)
plt.close(fig)
if 'videos' in args.actions:
if not args.data_indices: # If list is empty
n_samples = len(validation_dataset)
num_to_sample = min(args.N_to_viz, n_samples)
replace = n_samples < num_to_sample
data_indices = np.random.choice(np.arange(n_samples), size=num_to_sample, replace=replace)
print(f"Selected random indices: {data_indices}")
else:
data_indices = args.data_indices
print(f"Using specified indices: {data_indices}")
for di in data_indices:
print(f'\nBuilding viz for dataset index {di}.')
# --- Get Data & Run Inference ---
# inputs_norm is already normalized by the transform
inputs, ground_truth_target = validation_dataset.__getitem__(int(di))
# Add batch dimension and send to device
inputs = inputs.to(device).unsqueeze(0)
# Run model inference
predictions, certainties, synchronisation, pre_activations, post_activations, attention_tracking = model(inputs, track=True)
# predictions: (B, Classes, Steps), attention_tracking: (Steps*B*Heads, SeqLen)
n_steps = predictions.size(-1)
# --- Reshape Attention ---
# Infer feature map size from model internals (assuming B=1)
h_feat, w_feat = model.kv_features.shape[-2:]
n_heads = attention_tracking.shape[2]
# Reshape to (Steps, Heads, H_feat, W_feat) assuming B=1
attention_tracking = attention_tracking.reshape(n_steps, n_heads, h_feat, w_feat)
# --- Setup for Plotting ---
step_linspace = np.linspace(0, 1, n_steps) # For step colors
# Define color maps
cmap_spectral = sns.color_palette("Spectral", as_cmap=True)
cmap_attention = sns.color_palette('viridis', as_cmap=True)
# Create output directory for this index
index_output_dir = os.path.join(args.output_dir, str(di))
os.makedirs(index_output_dir, exist_ok=True)
frames = [] # Store frames for GIF
head_routes = {h: [] for h in range(n_heads)} # Store (y,x) path points per head
head_routes[-1] = []
route_colours_step = [] # Store colors for each step's path segments
# --- Loop Through Each Step ---
for step_i in range(n_steps):
# --- Prepare Image for Display ---
# Denormalize the input tensor for visualization
data_img_tensor = inputs[0].cpu() # Get first item in batch, move to CPU
mean_tensor = torch.tensor(dataset_mean).view(3, 1, 1)
std_tensor = torch.tensor(dataset_std).view(3, 1, 1)
data_img_denorm = data_img_tensor * std_tensor + mean_tensor
# Permute to (H, W, C) and convert to numpy, clip to [0, 1]
data_img_np = data_img_denorm.permute(1, 2, 0).detach().numpy()
data_img_np = np.clip(data_img_np, 0, 1)
img_h, img_w = data_img_np.shape[:2]
# --- Process Attention & Certainty ---
# Average attention over last few steps (from original code)
start_step = max(0, step_i - 5)
attention_now = attention_tracking[start_step : step_i + 1].mean(0) # Avg over steps -> (Heads, H_feat, W_feat)
# Get certainties up to current step
certainties_now = certainties[0, 1, :step_i+1].detach().cpu().numpy() # Assuming index 1 holds relevant certainty
# --- Calculate Attention Paths (using bilinear interp) ---
# Interpolate attention to image size using bilinear for center finding
attention_interp_bilinear = F.interpolate(
torch.from_numpy(attention_now).unsqueeze(0).float(), # Add batch dim, ensure float
size=(img_h, img_w),
mode=interp_mode,
# align_corners=False
).squeeze(0) # Remove batch dim -> (Heads, H, W)
# Normalize each head's map to [0, 1]
# Deal with mean
attn_mean = attention_interp_bilinear.mean(0)
attn_mean_min = attn_mean.min()
attn_mean_max = attn_mean.max()
attn_mean = (attn_mean - attn_mean_min) / (attn_mean_max - attn_mean_min)
centers, areas = find_island_centers(attn_mean.detach().cpu().numpy(), threshold=0.7)
if centers: # If islands found
largest_island_idx = np.argmax(areas)
current_center = centers[largest_island_idx] # (y, x)
head_routes[-1].append(current_center)
elif head_routes[-1]: # If no center now, repeat last known center if history exists
head_routes[-1].append(head_routes[-1][-1])
attn_min = attention_interp_bilinear.view(n_heads, -1).min(dim=-1, keepdim=True)[0].unsqueeze(-1)
attn_max = attention_interp_bilinear.view(n_heads, -1).max(dim=-1, keepdim=True)[0].unsqueeze(-1)
attention_interp_bilinear = (attention_interp_bilinear - attn_min) / (attn_max - attn_min + 1e-6)
# Store step color
current_colour = list(cmap_spectral(step_linspace[step_i]))
route_colours_step.append(current_colour)
# Find island center for each head
for head_i in range(n_heads):
attn_head_np = attention_interp_bilinear[head_i].detach().cpu().numpy()
# Keep threshold=0.7 based on original call
centers, areas = find_island_centers(attn_head_np, threshold=0.7)
if centers: # If islands found
largest_island_idx = np.argmax(areas)
current_center = centers[largest_island_idx] # (y, x)
head_routes[head_i].append(current_center)
elif head_routes[head_i]: # If no center now, repeat last known center if history exists
head_routes[head_i].append(head_routes[head_i][-1])
# --- Plotting Setup ---
mosaic = [['head_mean', 'head_mean', 'head_mean', 'head_mean', 'head_mean_overlay', 'head_mean_overlay', 'head_mean_overlay', 'head_mean_overlay'],
['head_mean', 'head_mean', 'head_mean', 'head_mean', 'head_mean_overlay', 'head_mean_overlay', 'head_mean_overlay', 'head_mean_overlay'],
['head_mean', 'head_mean', 'head_mean', 'head_mean', 'head_mean_overlay', 'head_mean_overlay', 'head_mean_overlay', 'head_mean_overlay'],
['head_mean', 'head_mean', 'head_mean', 'head_mean', 'head_mean_overlay', 'head_mean_overlay', 'head_mean_overlay', 'head_mean_overlay'],
['head_0', 'head_0_overlay', 'head_1', 'head_1_overlay', 'head_2', 'head_2_overlay', 'head_3', 'head_3_overlay'],
['head_4', 'head_4_overlay', 'head_5', 'head_5_overlay','head_6', 'head_6_overlay', 'head_7', 'head_7_overlay'],
['head_8', 'head_8_overlay', 'head_9', 'head_9_overlay','head_10', 'head_10_overlay', 'head_11', 'head_11_overlay'],
['head_12', 'head_12_overlay', 'head_13', 'head_13_overlay','head_14', 'head_14_overlay', 'head_15', 'head_15_overlay'],
['probabilities', 'probabilities','probabilities', 'probabilities', 'certainty', 'certainty', 'certainty', 'certainty'],
]
img_aspect = data_img_np.shape[0] / data_img_np.shape[1]
aspect_ratio = (8 * figscale, 9 * figscale * img_aspect) # W, H
fig, axes = plt.subplot_mosaic(mosaic, figsize=aspect_ratio)
for ax in axes.values():
ax.axis('off')
# --- Plot Certainty ---
ax_cert = axes['certainty']
ax_cert.plot(np.arange(len(certainties_now)), certainties_now, 'k-', linewidth=figscale*1)
# Add background color based on prediction correctness at each step
for ii in range(len(certainties_now)):
is_correct = predictions[0, :, ii].argmax(-1).item() == ground_truth_target # .item() for scalar tensor
facecolor = 'limegreen' if is_correct else 'orchid'
ax_cert.axvspan(ii, ii + 1, facecolor=facecolor, edgecolor=None, lw=0, alpha=0.3)
# Mark the last point
ax_cert.plot(len(certainties_now)-1, certainties_now[-1], 'k.', markersize=figscale*4)
ax_cert.axis('off')
ax_cert.set_ylim([0.05, 1.05])
ax_cert.set_xlim([0, n_steps]) # Use n_steps for consistent x-axis limit
# --- Plot Probabilities ---
ax_prob = axes['probabilities']
# Get probabilities for the current step
ps = torch.softmax(predictions[0, :, step_i], -1).detach().cpu()
k = 15 # Top k predictions
topk_probs, topk_indices = torch.topk(ps, k, dim=0, largest=True)
topk_indices = topk_indices.numpy()
topk_probs = topk_probs.numpy()
top_classes = np.array(class_labels)[topk_indices]
true_class_idx = ground_truth_target # Ground truth index
# Determine bar colors (green if correct, blue otherwise - consistent with original)
colours = ['g' if idx == true_class_idx else 'b' for idx in topk_indices]
# Plot horizontal bars (inverted range for top-down display)
ax_prob.barh(np.arange(k)[::-1], topk_probs, color=colours, alpha=1) # Use barh and inverted range
ax_prob.set_xlim([0, 1])
ax_prob.axis('off')
# Add text labels for top classes
for i, name_idx in enumerate(topk_indices):
name = class_labels[name_idx] # Get name from index
is_correct = name_idx == true_class_idx
fg_color = 'darkgreen' if is_correct else 'crimson' # Text colors from original
text_str = f'{name[:40]}' # Truncate long names
# Position text on the left side of the horizontal bars
ax_prob.text(
0.01, # Small offset from left edge
k - 1 - i, # Y-position corresponding to the bar
text_str,
#transform=ax_prob.transAxes, # Use data coordinates for Y
verticalalignment='center',
horizontalalignment='left',
fontsize=8,
color=fg_color,
alpha=0.9, # Slightly more visible than 0.5
path_effects=[
patheffects.Stroke(linewidth=2, foreground='white'), # Adjusted stroke
patheffects.Normal()
])
# --- Plot Attention Heads & Overlays (using nearest interp) ---
# Re-interpolate attention using nearest neighbor for visual plotting
attention_interp_plot = F.interpolate(
torch.from_numpy(attention_now).unsqueeze(0).float(),
size=(img_h, img_w),
mode=interp_mode, # 'nearest'
).squeeze(0)
attn_mean = attention_interp_plot.mean(0)
attn_mean_min = attn_mean.min()
attn_mean_max = attn_mean.max()
attn_mean = (attn_mean - attn_mean_min) / (attn_mean_max - attn_mean_min)
# Normalize each head's map to [0, 1]
attn_min_plot = attention_interp_plot.view(n_heads, -1).min(dim=-1, keepdim=True)[0].unsqueeze(-1)
attn_max_plot = attention_interp_plot.view(n_heads, -1).max(dim=-1, keepdim=True)[0].unsqueeze(-1)
attention_interp_plot = (attention_interp_plot - attn_min_plot) / (attn_max_plot - attn_min_plot + 1e-6)
attention_interp_plot_np = attention_interp_plot.detach().cpu().numpy()
for head_i in list(range(n_heads)) + [-1]:
axname = f'head_{head_i}' if head_i != -1 else 'head_mean'
if axname not in axes: continue # Skip if mosaic doesn't have this head
ax = axes[axname]
ax_overlay = axes[f'{axname}_overlay']
# Plot attention heatmap
this_attn = attention_interp_plot_np[head_i] if head_i != -1 else attn_mean
img_to_plot = cmap_attention(this_attn)
ax.imshow(img_to_plot)
ax.axis('off')
# Plot overlay: image + paths
these_route_steps = head_routes[head_i]
arrow_scale = 1.5 if head_i != -1 else 3
if these_route_steps: # Only plot if path exists
# Separate y and x coordinates
y_coords, x_coords = zip(*these_route_steps)
y_coords = np.array(y_coords)
x_coords = np.array(x_coords)
# Flip y-coordinates for correct plotting (imshow origin is top-left)
# NOTE: Original flip seemed complex, simplifying to standard flip
y_coords_flipped = img_h - 1 - y_coords
# Show original image flipped vertically to match coordinate system
ax_overlay.imshow(np.flipud(data_img_np), origin='lower')
# Draw arrows for path segments
# Arrow size scaling from original
for i in range(len(these_route_steps) - 1):
dx = x_coords[i+1] - x_coords[i]
dy = y_coords_flipped[i+1] - y_coords_flipped[i] # Use flipped y for delta
# Draw white background arrow (thicker)
ax_overlay.arrow(x_coords[i], y_coords_flipped[i], dx, dy,
linewidth=1.6 * arrow_scale * 1.3,
head_width=1.9 * arrow_scale * 1.3,
head_length=1.4 * arrow_scale * 1.45,
fc='white', ec='white', length_includes_head=True, alpha=1)
# Draw colored foreground arrow
ax_overlay.arrow(x_coords[i], y_coords_flipped[i], dx, dy,
linewidth=1.6 * arrow_scale,
head_width=1.9 * arrow_scale,
head_length=1.4 * arrow_scale,
fc=route_colours_step[i], ec=route_colours_step[i], # Use step color
length_includes_head=True)
else: # If no path yet, just show the image
ax_overlay.imshow(np.flipud(data_img_np), origin='lower')
# Set limits and turn off axes for overlay
ax_overlay.set_xlim([0, img_w - 1])
ax_overlay.set_ylim([0, img_h - 1])
ax_overlay.axis('off')
# --- Finalize and Save Frame ---
fig.tight_layout(pad=0.1) # Adjust spacing
# Render the plot to a numpy array
canvas = fig.canvas
canvas.draw()
image_numpy = np.frombuffer(canvas.buffer_rgba(), dtype='uint8')
image_numpy = image_numpy.reshape(*reversed(canvas.get_width_height()), 4)[:,:,:3] # Get RGB
frames.append(image_numpy) # Add to list for GIF
plt.close(fig) # Close figure to free memory
# --- Save GIF ---
gif_path = os.path.join(index_output_dir, f'{str(di)}_viz.gif')
print(f"Saving GIF to {gif_path}...")
imageio.mimsave(gif_path, frames, fps=15, loop=0) # loop=0 means infinite loop
save_frames_to_mp4([fm[:,:,::-1] for fm in frames], os.path.join(index_output_dir, f'{str(di)}_viz.mp4'), fps=15, gop_size=1, preset='veryslow')
if 'demo' in args.actions:
# --- Select Data Indices ---
if not args.data_indices: # If list is empty
n_samples = len(validation_dataset)
num_to_sample = min(args.N_to_viz, n_samples)
replace = n_samples < num_to_sample
data_indices = np.random.choice(np.arange(n_samples), size=num_to_sample, replace=replace)
print(f"Selected random indices: {data_indices}")
else:
data_indices = args.data_indices
print(f"Using specified indices: {data_indices}")
for di in data_indices:
index_output_dir = os.path.join(args.output_dir, str(di))
os.makedirs(index_output_dir, exist_ok=True)
print(f'\nBuilding viz for dataset index {di}.')
inputs, ground_truth_target = validation_dataset.__getitem__(int(di))
# Add batch dimension and send to device
inputs = inputs.to(device).unsqueeze(0)
predictions, certainties, synchronisations_over_time, pre_activations, post_activations, attention_tracking = model(inputs, track=True)
# --- Reshape Attention ---
# Infer feature map size from model internals (assuming B=1)
h_feat, w_feat = model.kv_features.shape[-2:]
n_steps = predictions.size(-1)
n_heads = attention_tracking.shape[2]
# Reshape to (Steps, Heads, H_feat, W_feat) assuming B=1
attention_tracking = attention_tracking.reshape(n_steps, n_heads, h_feat, w_feat)
# --- Setup for Plotting ---
step_linspace = np.linspace(0, 1, n_steps) # For step colors
# Define color maps
cmap_steps = sns.color_palette("Spectral", as_cmap=True)
cmap_attention = sns.color_palette('viridis', as_cmap=True)
# Create output directory for this index
frames = [] # Store frames for GIF
head_routes = [] # Store (y,x) path points per head
route_colours_step = [] # Store colors for each step's path segments
# --- Loop Through Each Step ---
for step_i in range(n_steps):
# Store step color
current_colour = list(cmap_steps(step_linspace[step_i]))
route_colours_step.append(current_colour)
# --- Prepare Image for Display ---
# Denormalize the input tensor for visualization
data_img_tensor = inputs[0].cpu() # Get first item in batch, move to CPU
mean_tensor = torch.tensor(dataset_mean).view(3, 1, 1)
std_tensor = torch.tensor(dataset_std).view(3, 1, 1)
data_img_denorm = data_img_tensor * std_tensor + mean_tensor
# Permute to (H, W, C) and convert to numpy, clip to [0, 1]
data_img_np = data_img_denorm.permute(1, 2, 0).detach().numpy()
data_img_np = np.clip(data_img_np, 0, 1)
img_h, img_w = data_img_np.shape[:2]
# --- Process Attention & Certainty ---
# Average attention over last few steps (from original code)
start_step = max(0, step_i - 5)
attention_now = attention_tracking[start_step : step_i + 1].mean(0) # Avg over steps -> (Heads, H_feat, W_feat)
# Get certainties up to current step
certainties_now = certainties[0, 1, :step_i+1].detach().cpu().numpy() # Assuming index 1 holds relevant certainty
# --- Calculate Attention Paths (using bilinear interp) ---
# Interpolate attention to image size using bilinear for center finding
attention_interp_bilinear = F.interpolate(
torch.from_numpy(attention_now).unsqueeze(0).float(), # Add batch dim, ensure float
size=(img_h, img_w),
mode=interp_mode,
).squeeze(0) # Remove batch dim -> (Heads, H, W)
attn_mean = attention_interp_bilinear.mean(0)
attn_mean_min = attn_mean.min()
attn_mean_max = attn_mean.max()
attn_mean = (attn_mean - attn_mean_min) / (attn_mean_max - attn_mean_min)
centers, areas = find_island_centers(attn_mean.detach().cpu().numpy(), threshold=0.7)
if centers: # If islands found
largest_island_idx = np.argmax(areas)
current_center = centers[largest_island_idx] # (y, x)
head_routes.append(current_center)
elif head_routes: # If no center now, repeat last known center if history exists
head_routes.append(head_routes[-1])
# --- Plotting Setup ---
# if n_heads != 8: print(f"Warning: Plotting layout assumes 8 heads, found {n_heads}. Layout may be incorrect.")
mosaic = [['head_0', 'head_1', 'head_2', 'head_3', 'head_mean', 'head_mean', 'head_mean', 'head_mean', 'overlay', 'overlay', 'overlay', 'overlay'],
['head_4', 'head_5', 'head_6', 'head_7', 'head_mean', 'head_mean', 'head_mean', 'head_mean', 'overlay', 'overlay', 'overlay', 'overlay'],
['head_8', 'head_9', 'head_10', 'head_11', 'head_mean', 'head_mean', 'head_mean', 'head_mean', 'overlay', 'overlay', 'overlay', 'overlay'],
['head_12', 'head_13', 'head_14', 'head_15', 'head_mean', 'head_mean', 'head_mean', 'head_mean', 'overlay', 'overlay', 'overlay', 'overlay'],
['probabilities', 'probabilities', 'probabilities', 'probabilities', 'probabilities', 'probabilities', 'certainty', 'certainty', 'certainty', 'certainty', 'certainty', 'certainty'],
['probabilities', 'probabilities', 'probabilities', 'probabilities', 'probabilities', 'probabilities', 'certainty', 'certainty', 'certainty', 'certainty', 'certainty', 'certainty'],
]
img_aspect = data_img_np.shape[0] / data_img_np.shape[1]
aspect_ratio = (12 * figscale, 6 * figscale * img_aspect) # W, H
fig, axes = plt.subplot_mosaic(mosaic, figsize=aspect_ratio)
for ax in axes.values():
ax.axis('off')
# --- Plot Certainty ---
ax_cert = axes['certainty']
ax_cert.plot(np.arange(len(certainties_now)), certainties_now, 'k-', linewidth=figscale*1)
# Add background color based on prediction correctness at each step
for ii in range(len(certainties_now)):
is_correct = predictions[0, :, ii].argmax(-1).item() == ground_truth_target # .item() for scalar tensor
facecolor = 'limegreen' if is_correct else 'orchid'
ax_cert.axvspan(ii, ii + 1, facecolor=facecolor, edgecolor=None, lw=0, alpha=0.3)
# Mark the last point
ax_cert.plot(len(certainties_now)-1, certainties_now[-1], 'k.', markersize=figscale*4)
ax_cert.axis('off')
ax_cert.set_ylim([0.05, 1.05])
ax_cert.set_xlim([0, n_steps]) # Use n_steps for consistent x-axis limit
# --- Plot Probabilities ---
ax_prob = axes['probabilities']
# Get probabilities for the current step
ps = torch.softmax(predictions[0, :, step_i], -1).detach().cpu()
k = 15 # Top k predictions
topk_probs, topk_indices = torch.topk(ps, k, dim=0, largest=True)
topk_indices = topk_indices.numpy()
topk_probs = topk_probs.numpy()
top_classes = np.array(class_labels)[topk_indices]
true_class_idx = ground_truth_target # Ground truth index
# Determine bar colors (green if correct, blue otherwise - consistent with original)
colours = ['g' if idx == true_class_idx else 'b' for idx in topk_indices]
# Plot horizontal bars (inverted range for top-down display)
ax_prob.barh(np.arange(k)[::-1], topk_probs, color=colours, alpha=1) # Use barh and inverted range
ax_prob.set_xlim([0, 1])
ax_prob.axis('off')
# Add text labels for top classes
for i, name_idx in enumerate(topk_indices):
name = class_labels[name_idx] # Get name from index
is_correct = name_idx == true_class_idx
fg_color = 'darkgreen' if is_correct else 'crimson' # Text colors from original
text_str = f'{name[:40]}' # Truncate long names
# Position text on the left side of the horizontal bars
ax_prob.text(
0.01, # Small offset from left edge
k - 1 - i, # Y-position corresponding to the bar
text_str,
#transform=ax_prob.transAxes, # Use data coordinates for Y
verticalalignment='center',
horizontalalignment='left',
fontsize=8,
color=fg_color,
alpha=0.7, # Slightly more visible than 0.5
path_effects=[
patheffects.Stroke(linewidth=2, foreground='white'), # Adjusted stroke
patheffects.Normal()
])
# --- Plot Attention Heads & Overlays (using nearest interp) ---
# Re-interpolate attention using nearest neighbor for visual plotting
attention_interp_plot = F.interpolate(
torch.from_numpy(attention_now).unsqueeze(0).float(),
size=(img_h, img_w),
mode=interp_mode # 'nearest'
).squeeze(0)
attn_mean = attention_interp_plot.mean(0)
attn_mean_min = attn_mean.min()
attn_mean_max = attn_mean.max()
attn_mean = (attn_mean - attn_mean_min) / (attn_mean_max - attn_mean_min)
img_to_plot = cmap_attention(attn_mean)
axes['head_mean'].imshow(img_to_plot)
axes['head_mean'].axis('off')
these_route_steps = head_routes
ax_overlay = axes['overlay']
if these_route_steps: # Only plot if path exists
# Separate y and x coordinates
y_coords, x_coords = zip(*these_route_steps)
y_coords = np.array(y_coords)
x_coords = np.array(x_coords)
# Flip y-coordinates for correct plotting (imshow origin is top-left)
# NOTE: Original flip seemed complex, simplifying to standard flip
y_coords_flipped = img_h - 1 - y_coords
# Show original image flipped vertically to match coordinate system
ax_overlay.imshow(np.flipud(data_img_np), origin='lower')
# Draw arrows for path segments
arrow_scale = 2 # Arrow size scaling from original
for i in range(len(these_route_steps) - 1):
dx = x_coords[i+1] - x_coords[i]
dy = y_coords_flipped[i+1] - y_coords_flipped[i] # Use flipped y for delta
# Draw white background arrow (thicker)
ax_overlay.arrow(x_coords[i], y_coords_flipped[i], dx, dy,
linewidth=1.6 * arrow_scale * 1.3,
head_width=1.9 * arrow_scale * 1.3,
head_length=1.4 * arrow_scale * 1.45,
fc='white', ec='white', length_includes_head=True, alpha=1)
# Draw colored foreground arrow
ax_overlay.arrow(x_coords[i], y_coords_flipped[i], dx, dy,
linewidth=1.6 * arrow_scale,
head_width=1.9 * arrow_scale,
head_length=1.4 * arrow_scale,
fc=route_colours_step[i], ec=route_colours_step[i], # Use step color
length_includes_head=True)
# Set limits and turn off axes for overlay
ax_overlay.set_xlim([0, img_w - 1])
ax_overlay.set_ylim([0, img_h - 1])
ax_overlay.axis('off')
for head_i in range(n_heads):
if f'head_{head_i}' not in axes: continue # Skip if mosaic doesn't have this head
ax = axes[f'head_{head_i}']
# Plot attention heatmap
attn_up_to_now = attention_tracking[:step_i + 1, head_i].mean(0)
attn_up_to_now = (attn_up_to_now - attn_up_to_now.min())/(attn_up_to_now.max() - attn_up_to_now.min())
img_to_plot = cmap_attention(attn_up_to_now)
ax.imshow(img_to_plot)
ax.axis('off')
# --- Finalize and Save Frame ---
fig.tight_layout(pad=0.1) # Adjust spacing
# Render the plot to a numpy array
canvas = fig.canvas
canvas.draw()
image_numpy = np.frombuffer(canvas.buffer_rgba(), dtype='uint8')
image_numpy = image_numpy.reshape(*reversed(canvas.get_width_height()), 4)[:,:,:3] # Get RGB
frames.append(image_numpy) # Add to list for GIF
# Save individual frame if requested
if step_i==model.iterations-1:
fig.savefig(os.path.join(index_output_dir, f'frame_{step_i}.png'), dpi=200)
plt.close(fig) # Close figure to free memory
outfilename = os.path.join(index_output_dir, f'{di}_demo.mp4')
save_frames_to_mp4([fm[:,:,::-1] for fm in frames], outfilename, fps=15, gop_size=1, preset='veryslow')