# taller_CNN/utils/visualization.py
# Uploaded by NICOMOSHE in commit b4b8733 ("Upload 74 files").
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
import torch
from typing import List, Optional, Tuple
import os
class Visualizer:
    """Plotting helpers for CNN experiments.

    Covers training curves, sample predictions, class distributions,
    cross-dataset loss/accuracy comparisons and convolution filters. Every
    public ``plot_*`` method saves the figure to ``save_path`` when one is
    given (otherwise it calls ``plt.show()``) and then closes the figure.
    """

    def __init__(self, style: str = 'seaborn-v0_8-darkgrid'):
        """Set the global matplotlib style and the seaborn colour palette.

        Args:
            style: matplotlib style name. If matplotlib rejects it (the
                seaborn styles were renamed across matplotlib releases),
                ``'ggplot'`` is used instead.
        """
        try:
            plt.style.use(style)
        except (OSError, ValueError):
            # matplotlib raises OSError for an unknown style name; fall back
            # to a style bundled with every version instead of failing.
            plt.style.use('ggplot')
        sns.set_palette('husl')

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------
    @staticmethod
    def _save_or_show(fig, save_path: Optional[str] = None, announce: bool = False) -> None:
        """Tight-layout *fig*, save it to *save_path* (or show it), then close it.

        When *announce* is true a confirmation line is printed after saving.
        """
        fig.tight_layout()
        if save_path:
            fig.savefig(save_path, dpi=150, bbox_inches='tight')
            if announce:
                print(f'Gráfico guardado en: {save_path}')
        else:
            plt.show()
        plt.close(fig)

    @staticmethod
    def _minmax_normalize(arr) -> np.ndarray:
        """Linearly rescale *arr* to [0, 1].

        A constant array (max == min) is mapped to zeros instead of the NaNs
        a plain ``(x - min) / (max - min)`` would produce.
        """
        arr = np.asarray(arr, dtype=float)
        lo = arr.min()
        span = arr.max() - lo
        if span == 0:
            return np.zeros_like(arr)
        return (arr - lo) / span

    @staticmethod
    def _grid_shape(n: int) -> Tuple[int, int]:
        """Return ``(rows, cols)`` of a near-square grid holding at least *n* cells."""
        rows = max(1, int(np.sqrt(n)))
        cols = (n + rows - 1) // rows
        return rows, cols

    @staticmethod
    def _class_counts(targets, num_classes: int) -> np.ndarray:
        """Count each label in ``range(num_classes)``; labels outside it are ignored."""
        labels = np.asarray(targets).ravel()
        labels = labels[(labels >= 0) & (labels < num_classes)]
        return np.bincount(labels.astype(np.int64), minlength=num_classes)

    @staticmethod
    def _plot_grouped_bars(
        results: dict,
        series,
        ylabel: str,
        title: str,
        save_path: Optional[str] = None,
        annotate_fmt: Optional[str] = None,
        ylim: Optional[Tuple[float, float]] = None,
    ) -> None:
        """Draw one group of side-by-side bars per dataset.

        Args:
            results: ``{dataset_name: {metric_key: value, ...}, ...}``.
            series: sequence of ``(metric_key, legend_label, color)``; each
                entry contributes one bar to every dataset's group.
            ylabel: y-axis label.
            title: axes title.
            save_path: where to save the figure; shown interactively if None.
            annotate_fmt: ``str.format`` pattern used to write each bar's
                height just above it; no labels when None.
            ylim: optional fixed y-axis limits.
        """
        datasets = list(results)
        x = np.arange(len(datasets))
        width = 0.35
        fig, ax = plt.subplots(figsize=(12, 6))
        # Offsets that centre the group of bars on each x tick
        # (±width/2 for two series).
        offsets = (np.arange(len(series)) - (len(series) - 1) / 2) * width
        for offset, (key, label, color) in zip(offsets, series):
            values = [results[d][key] for d in datasets]
            bars = ax.bar(x + offset, values, width, label=label, color=color)
            if annotate_fmt is not None:
                for bar in bars:
                    height = bar.get_height()
                    ax.text(bar.get_x() + bar.get_width() / 2., height + 1,
                            annotate_fmt.format(height), ha='center', va='bottom', fontsize=8)
        ax.set_xlabel('Dataset')
        ax.set_ylabel(ylabel)
        ax.set_title(title)
        ax.set_xticks(x)
        ax.set_xticklabels(datasets)
        ax.legend()
        ax.grid(axis='y', alpha=0.3)
        if ylim is not None:
            ax.set_ylim(*ylim)
        Visualizer._save_or_show(fig, save_path)

    # ------------------------------------------------------------------
    # Public plots
    # ------------------------------------------------------------------
    @staticmethod
    def plot_training_history(
        history: dict,
        save_path: Optional[str] = None,
        figsize: Tuple[int, int] = (15, 5)
    ):
        """Plot loss, accuracy and learning rate against epoch, side by side.

        Args:
            history: mapping whose keys ``'epoch'``, ``'train_loss'``,
                ``'test_loss'``, ``'train_acc'``, ``'test_acc'`` and
                ``'learning_rate'`` hold sequences of the same length.
            save_path: if given, save the figure there (and print the path)
                instead of showing it.
            figsize: matplotlib figure size in inches.
        """
        fig, (ax_loss, ax_acc, ax_lr) = plt.subplots(1, 3, figsize=figsize)
        epochs = history['epoch']

        # Train (blue) vs test (red) curves for the two metric panels.
        for ax, metric, short, ylabel, title in (
            (ax_loss, 'loss', 'Loss', 'Loss', 'Loss durante Entrenamiento'),
            (ax_acc, 'acc', 'Acc', 'Accuracy (%)', 'Accuracy durante Entrenamiento'),
        ):
            ax.plot(epochs, history[f'train_{metric}'], 'b-', label=f'Train {short}', linewidth=2)
            ax.plot(epochs, history[f'test_{metric}'], 'r-', label=f'Test {short}', linewidth=2)
            ax.set_xlabel('Época')
            ax.set_ylabel(ylabel)
            ax.set_title(title)
            ax.legend()
            ax.grid(True, alpha=0.3)

        ax_lr.plot(epochs, history['learning_rate'], 'g-', linewidth=2)
        ax_lr.set_xlabel('Época')
        ax_lr.set_ylabel('Learning Rate')
        ax_lr.set_title('Learning Rate Schedule')
        ax_lr.set_yscale('log')
        ax_lr.grid(True, alpha=0.3)

        Visualizer._save_or_show(fig, save_path, announce=True)

    @staticmethod
    def plot_sample_predictions(
        images: torch.Tensor,
        predictions: np.ndarray,
        targets: np.ndarray,
        classes: List[str],
        num_samples: int = 16,
        save_path: Optional[str] = None
    ):
        """Show a grid of images titled with predicted (P) and real (R) class.

        Titles are green for correct predictions and red otherwise.

        Args:
            images: batch of images, each either single-channel ((1, H, W) or
                (H, W)) or channels-first (C, H, W). Torch tensors (on any
                device, with or without grad) and numpy arrays are accepted.
            predictions: predicted class indices into *classes*.
            targets: true class indices into *classes*.
            classes: class names.
            num_samples: maximum number of images to draw.
            save_path: where to save the figure; shown interactively if None.
        """
        num_samples = min(num_samples, len(images))
        if num_samples <= 0:
            # Nothing to draw; building a 0-row grid would divide by zero.
            return
        rows, cols = Visualizer._grid_shape(num_samples)
        # squeeze=False keeps a 2-D array of axes even for a 1x1 grid.
        fig, axes = plt.subplots(rows, cols, figsize=(cols * 2.5, rows * 2.5), squeeze=False)
        axes = axes.ravel()

        for i, ax in enumerate(axes[:num_samples]):
            img = images[i]
            if isinstance(img, torch.Tensor):
                # detach() so tensors that still require grad can be converted.
                img = img.detach().cpu().numpy()
            img = np.asarray(img)
            if img.ndim == 2:
                ax.imshow(img, cmap='gray')
            elif img.shape[0] == 1:
                ax.imshow(img[0], cmap='gray')
            else:
                # Channels-first -> channels-last for imshow, rescaled to [0, 1].
                ax.imshow(Visualizer._minmax_normalize(np.transpose(img, (1, 2, 0))))
            pred_class = classes[predictions[i]]
            true_class = classes[targets[i]]
            color = 'green' if predictions[i] == targets[i] else 'red'
            ax.set_title(f'P: {pred_class}\nR: {true_class}', color=color, fontsize=9)
            ax.axis('off')

        # Hide the unused cells of the last grid row.
        for ax in axes[num_samples:]:
            ax.axis('off')
        Visualizer._save_or_show(fig, save_path)

    @staticmethod
    def plot_class_distribution(
        targets: np.ndarray,
        classes: List[str],
        title: str = 'Distribución de Clases',
        save_path: Optional[str] = None
    ):
        """Bar chart of how many samples fall in each class.

        Classes with at least one sample are drawn in steel blue, empty ones
        in light gray. Labels outside ``range(len(classes))`` are ignored.

        Args:
            targets: integer class labels.
            classes: class names, indexed by label.
            title: axes title.
            save_path: where to save the figure; shown interactively if None.
        """
        num_classes = len(classes)
        counts = Visualizer._class_counts(targets, num_classes)
        # Build the bars with their final heights so the y-axis autoscales
        # to the counts (resizing bars after creation leaves the limits at 0).
        colors = ['steelblue' if c > 0 else 'lightgray' for c in counts]
        positions = np.arange(num_classes)

        fig, ax = plt.subplots(figsize=(12, 6))
        ax.bar(positions, counts, color=colors)
        ax.set_xlabel('Clase')
        ax.set_ylabel('Cantidad')
        ax.set_title(title)
        ax.set_xticks(positions)
        ax.set_xticklabels(list(classes), rotation=45, ha='right')
        ax.grid(axis='y', alpha=0.3)
        Visualizer._save_or_show(fig, save_path)

    @staticmethod
    def plot_loss_comparison(
        results: dict,
        save_path: Optional[str] = None
    ):
        """Grouped bars of final train/test loss per dataset.

        Args:
            results: ``{dataset_name: {'train_loss': float, 'test_loss': float, ...}}``.
            save_path: where to save the figure; shown interactively if None.
        """
        Visualizer._plot_grouped_bars(
            results,
            series=(
                ('train_loss', 'Train Loss', 'steelblue'),
                ('test_loss', 'Test Loss', 'coral'),
            ),
            ylabel='Loss',
            title='Comparación de Loss entre Datasets',
            save_path=save_path,
        )

    @staticmethod
    def plot_accuracy_comparison(
        results: dict,
        save_path: Optional[str] = None
    ):
        """Grouped bars of final train/test accuracy (%) per dataset, labelled.

        Args:
            results: ``{dataset_name: {'train_acc': float, 'test_acc': float, ...}}``
                with accuracies in percent (the y-axis is fixed to 0-110).
            save_path: where to save the figure; shown interactively if None.
        """
        Visualizer._plot_grouped_bars(
            results,
            series=(
                ('train_acc', 'Train Acc', 'steelblue'),
                ('test_acc', 'Test Acc', 'coral'),
            ),
            ylabel='Accuracy (%)',
            title='Comparación de Accuracy entre Datasets',
            save_path=save_path,
            annotate_fmt='{:.1f}%',
            ylim=(0, 110),
        )

    @staticmethod
    def plot_filters(
        model: torch.nn.Module,
        layer_name: str = 'features.0',
        save_path: Optional[str] = None,
        max_filters: int = 32,
    ):
        """Display up to *max_filters* kernels of a ``Conv2d`` layer.

        Kernels with 3 input channels are shown as RGB (min-max rescaled);
        otherwise only their first input channel is shown in gray. This is
        best effort: problems are reported on stdout instead of raised.

        Args:
            model: model containing the layer.
            layer_name: dotted name as listed by ``model.named_modules()``.
            save_path: where to save the figure; shown interactively if None.
            max_filters: maximum number of kernels to draw; the grid has up
                to 8 columns (32 gives the original 4x8 layout).
        """
        fig = None
        try:
            layer = dict(model.named_modules())[layer_name]
            if not isinstance(layer, torch.nn.Conv2d):
                print(f'No se pudieron visualizar los filtros: {layer_name!r} no es una capa Conv2d')
                return
            filters = layer.weight.detach().cpu().numpy()
            num_filters = min(max_filters, filters.shape[0])
            cols = max(1, min(8, max_filters))
            rows = max(1, -(-max_filters // cols))  # ceiling division
            fig, axes = plt.subplots(rows, cols, figsize=(cols * 2, rows * 2), squeeze=False)
            axes = axes.ravel()
            for i, ax in enumerate(axes[:num_filters]):
                if filters.shape[1] == 3:
                    ax.imshow(Visualizer._minmax_normalize(filters[i].transpose(1, 2, 0)))
                else:
                    ax.imshow(filters[i, 0], cmap='gray')
                ax.axis('off')
            for ax in axes[num_filters:]:
                ax.axis('off')
            fig.suptitle(f'Filtros de la capa: {layer_name}')
            Visualizer._save_or_show(fig, save_path)
        except Exception as e:
            print(f'No se pudieron visualizar los filtros: {e}')
        finally:
            # Make sure a half-drawn figure is not leaked on failure
            # (closing an already-closed figure is a no-op).
            if fig is not None:
                plt.close(fig)