# architectural-style-classifier / src/utils/visualization.py
# Uploaded by fxxkingusername with huggingface_hub (commit 99ca773, verified).
"""
Visualization utilities for architectural style classification.
"""
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
import torch
from typing import Dict, List, Optional, Tuple
def plot_attention_weights(attention_weights: torch.Tensor,
                           image: torch.Tensor,
                           title: str = "Attention Weights",
                           save_path: Optional[str] = None) -> None:
    """
    Plot attention weights overlaid on the image.

    The left panel shows the min-max normalised input image. The right panel
    shows the attention map (``hot`` colormap) with a translucent copy of the
    image drawn on top, plus a colorbar for the attention values.

    Args:
        attention_weights: Attention weights tensor [H', W']. The map is
            stretched over the image's pixel extent, so it may have a lower
            resolution than the image.
        image: Input image tensor [C, H, W]; a single-channel image is shown
            in grayscale.
        title: Title of the attention panel
        save_path: Path to save the plot (nothing is saved when falsy)
    """
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 6))

    # Detach and move to CPU first: Tensor.numpy() raises for tensors that
    # require grad or live on a GPU (plot_feature_maps does the same).
    img_np = image.detach().cpu().float().permute(1, 2, 0).numpy()
    # imshow cannot display an (H, W, 1) array; drop the channel axis.
    if img_np.shape[-1] == 1:
        img_np = img_np[..., 0]
    # Min-max normalise to [0, 1]; guard against a constant image, which
    # would otherwise divide by zero and produce an all-NaN image.
    lo = img_np.min()
    span = img_np.max() - lo
    if span > 0:
        img_np = (img_np - lo) / span
    else:
        img_np = np.zeros_like(img_np)
    img_cmap = 'gray' if img_np.ndim == 2 else None

    # Plot original image
    ax1.imshow(img_np, cmap=img_cmap)
    ax1.set_title("Original Image")
    ax1.axis('off')

    # Plot attention weights
    attention_np = attention_weights.detach().cpu().float().numpy()
    height, width = img_np.shape[:2]
    # Map the attention array onto the image's pixel extent so a coarser map
    # covers the whole image instead of only its top-left corner. For a map
    # of the same size as the image this equals imshow's default extent.
    extent = (-0.5, width - 0.5, height - 0.5, -0.5)
    im = ax2.imshow(attention_np, cmap='hot', alpha=0.7, extent=extent)
    ax2.imshow(img_np, cmap=img_cmap, alpha=0.3)
    ax2.set_title(title)
    ax2.axis('off')
    fig.colorbar(im, ax=ax2)

    plt.tight_layout()
    if save_path:
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
    plt.show()
def _plot_train_val_curves(ax, train_values, val_values, quantity: str) -> None:
    """Draw the training/validation curves of one quantity onto ``ax``."""
    for values, style, split in ((train_values, 'b-', 'Training'),
                                 (val_values, 'r-', 'Validation')):
        if not values:
            # Missing or empty series are skipped rather than plotted.
            continue
        # Each series is plotted against its own 1-based epoch index, so
        # series of unequal length are drawn instead of raising a shape error.
        epochs = range(1, len(values) + 1)
        ax.plot(epochs, values, style, label=f'{split} {quantity}')
    ax.set_title(f'Training and Validation {quantity}')
    ax.set_xlabel('Epoch')
    ax.set_ylabel(quantity)
    ax.legend()
    ax.grid(True)


def plot_training_curves(train_losses: List[float],
                         val_losses: List[float],
                         train_accuracies: Optional[List[float]] = None,
                         val_accuracies: Optional[List[float]] = None,
                         save_path: Optional[str] = None) -> None:
    """
    Plot training and validation curves.

    The loss curves are always drawn. When at least one accuracy series is
    given, a second panel with the accuracy curves is added to the right.

    Args:
        train_losses: Training losses per epoch
        val_losses: Validation losses per epoch
        train_accuracies: Training accuracies per epoch (optional)
        val_accuracies: Validation accuracies per epoch (optional)
        save_path: Path to save the plot (nothing is saved when falsy)
    """
    # Previously the accuracy panel required BOTH series and silently dropped
    # a lone one; now whichever accuracy series is provided is shown.
    if train_accuracies or val_accuracies:
        fig, (loss_ax, acc_ax) = plt.subplots(1, 2, figsize=(15, 6))
        _plot_train_val_curves(acc_ax, train_accuracies, val_accuracies,
                               'Accuracy')
    else:
        fig, loss_ax = plt.subplots(1, 1, figsize=(10, 6))
    _plot_train_val_curves(loss_ax, train_losses, val_losses, 'Loss')

    plt.tight_layout()
    if save_path:
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
    plt.show()
def plot_confusion_matrix(confusion_matrix: np.ndarray,
                          class_names: List[str],
                          title: str = "Confusion Matrix",
                          save_path: Optional[str] = None,
                          fmt: Optional[str] = None) -> None:
    """
    Plot confusion matrix as an annotated heatmap.

    Rows are labelled 'Actual' and columns 'Predicted'.

    Args:
        confusion_matrix: Confusion matrix array
        class_names: List of class names, used for both axes
        title: Plot title
        save_path: Path to save the plot (nothing is saved when falsy)
        fmt: Format code for the cell annotations. Defaults to 'd' for
            integer matrices and '.2f' otherwise (e.g. a normalised matrix),
            because the 'd' format code cannot format floats.
    """
    if fmt is None:
        dtype = np.asarray(confusion_matrix).dtype
        fmt = 'd' if np.issubdtype(dtype, np.integer) else '.2f'

    plt.figure(figsize=(10, 8))
    sns.heatmap(confusion_matrix, annot=True, fmt=fmt, cmap='Blues',
                xticklabels=class_names, yticklabels=class_names)
    plt.title(title)
    plt.xlabel('Predicted')
    plt.ylabel('Actual')
    plt.xticks(rotation=45)
    plt.yticks(rotation=0)
    # Keep the rotated tick labels inside the figure when displayed, matching
    # the other plotting helpers in this module.
    plt.tight_layout()
    if save_path:
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
    plt.show()
def plot_class_distribution(class_counts: Dict[str, int],
                            title: str = "Class Distribution",
                            save_path: Optional[str] = None) -> None:
    """
    Draw a bar chart of how many samples each class has.

    Bars appear in the mapping's iteration order, and each one is annotated
    with its count just above its top edge.

    Args:
        class_counts: Mapping of class name -> number of samples
        title: Figure title
        save_path: If truthy, the figure is written to this path at 300 dpi
            before being shown
    """
    names = [*class_counts]
    totals = [*class_counts.values()]

    plt.figure(figsize=(12, 6))
    rects = plt.bar(names, totals)
    plt.title(title)
    plt.xlabel('Class')
    plt.ylabel('Count')
    plt.xticks(rotation=45)

    # Write each count centred horizontally just above its bar.
    for rect, total in zip(rects, totals):
        centre_x = rect.get_x() + rect.get_width() / 2
        top_y = rect.get_height() + 0.01
        plt.text(centre_x, top_y, str(total), ha='center', va='bottom')

    plt.tight_layout()
    if save_path:
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
    plt.show()
def plot_feature_maps(feature_maps: torch.Tensor,
                      title: str = "Feature Maps",
                      num_maps: int = 16,
                      save_path: Optional[str] = None) -> None:
    """
    Plot feature maps from a convolutional layer in a square grid.

    Args:
        feature_maps: Feature maps tensor [B, C, H, W] (only the first sample
            is shown) or a single un-batched sample [C, H, W]
        title: Plot title
        num_maps: Maximum number of channels to display (must be >= 1)
        save_path: Path to save the plot (nothing is saved when falsy)

    Raises:
        ValueError: If ``num_maps`` is less than 1.
    """
    if num_maps < 1:
        raise ValueError(f"num_maps must be at least 1, got {num_maps}")

    # Take the first sample of a batch; a 3-D tensor already is one sample.
    sample = feature_maps[0] if feature_maps.dim() == 4 else feature_maps
    maps = sample[:num_maps].detach().cpu().float()
    n_shown = maps.shape[0]

    # Size the grid for the maps actually available (at least one cell).
    grid_size = int(np.ceil(np.sqrt(max(n_shown, 1))))
    # squeeze=False keeps a 2-D array of Axes even for a 1x1 grid; otherwise
    # plt.subplots returns a bare Axes that has no .flatten().
    fig, axes = plt.subplots(grid_size, grid_size, figsize=(12, 12),
                             squeeze=False)
    for i, ax in enumerate(axes.flat):
        if i < n_shown:
            ax.imshow(maps[i].numpy(), cmap='viridis')
            ax.set_title(f'Feature Map {i+1}')
        # Hide ticks and frames on every cell, including unused trailing ones.
        ax.axis('off')

    plt.suptitle(title)
    plt.tight_layout()
    if save_path:
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
    plt.show()
def plot_learning_rate_schedule(learning_rates: List[float],
                                steps: List[int],
                                title: str = "Learning Rate Schedule",
                                save_path: Optional[str] = None) -> None:
    """
    Draw the learning rate against the training step on a log-scaled y-axis.

    Args:
        learning_rates: Learning rate recorded at each entry of ``steps``
        steps: Step numbers for the x-axis (same length as learning_rates)
        title: Figure title
        save_path: If truthy, the figure is written to this path at 300 dpi
            before being shown
    """
    plt.figure(figsize=(10, 6))
    plt.plot(steps, learning_rates, 'b-', linewidth=2)

    # Title first, then the x and y axis labels.
    for set_text, text in ((plt.title, title),
                           (plt.xlabel, 'Step'),
                           (plt.ylabel, 'Learning Rate')):
        set_text(text)

    # Log y-axis: learning rates often vary over several orders of magnitude.
    plt.yscale('log')
    plt.grid(True)

    if save_path:
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
    plt.show()