Spaces:
Sleeping
Sleeping
| import matplotlib.pyplot as plt | |
| import seaborn as sns | |
| import numpy as np | |
| import torch | |
| from typing import List, Optional, Tuple | |
| import os | |
| class Visualizer: | |
| def __init__(self, style: str = 'seaborn-v0_8-darkgrid'): | |
| try: | |
| plt.style.use(style) | |
| except: | |
| plt.style.use('ggplot') | |
| sns.set_palette('husl') | |
| def plot_training_history( | |
| history: dict, | |
| save_path: Optional[str] = None, | |
| figsize: Tuple[int, int] = (15, 5) | |
| ): | |
| fig, axes = plt.subplots(1, 3, figsize=figsize) | |
| epochs = history['epoch'] | |
| axes[0].plot(epochs, history['train_loss'], 'b-', label='Train Loss', linewidth=2) | |
| axes[0].plot(epochs, history['test_loss'], 'r-', label='Test Loss', linewidth=2) | |
| axes[0].set_xlabel('Época') | |
| axes[0].set_ylabel('Loss') | |
| axes[0].set_title('Loss durante Entrenamiento') | |
| axes[0].legend() | |
| axes[0].grid(True, alpha=0.3) | |
| axes[1].plot(epochs, history['train_acc'], 'b-', label='Train Acc', linewidth=2) | |
| axes[1].plot(epochs, history['test_acc'], 'r-', label='Test Acc', linewidth=2) | |
| axes[1].set_xlabel('Época') | |
| axes[1].set_ylabel('Accuracy (%)') | |
| axes[1].set_title('Accuracy durante Entrenamiento') | |
| axes[1].legend() | |
| axes[1].grid(True, alpha=0.3) | |
| axes[2].plot(epochs, history['learning_rate'], 'g-', linewidth=2) | |
| axes[2].set_xlabel('Época') | |
| axes[2].set_ylabel('Learning Rate') | |
| axes[2].set_title('Learning Rate Schedule') | |
| axes[2].set_yscale('log') | |
| axes[2].grid(True, alpha=0.3) | |
| plt.tight_layout() | |
| if save_path: | |
| plt.savefig(save_path, dpi=150, bbox_inches='tight') | |
| print(f'Gráfico guardado en: {save_path}') | |
| else: | |
| plt.show() | |
| plt.close() | |
| def plot_sample_predictions( | |
| images: torch.Tensor, | |
| predictions: np.ndarray, | |
| targets: np.ndarray, | |
| classes: List[str], | |
| num_samples: int = 16, | |
| save_path: Optional[str] = None | |
| ): | |
| num_samples = min(num_samples, len(images)) | |
| rows = int(np.sqrt(num_samples)) | |
| cols = (num_samples + rows - 1) // rows | |
| fig, axes = plt.subplots(rows, cols, figsize=(cols * 2.5, rows * 2.5)) | |
| axes = axes.flatten() if num_samples > 1 else [axes] | |
| for i in range(num_samples): | |
| img = images[i].cpu().numpy() | |
| if img.shape[0] == 1: | |
| img = img.squeeze() | |
| axes[i].imshow(img, cmap='gray') | |
| else: | |
| img = np.transpose(img, (1, 2, 0)) | |
| img = (img - img.min()) / (img.max() - img.min()) | |
| axes[i].imshow(img) | |
| pred_class = classes[predictions[i]] | |
| true_class = classes[targets[i]] | |
| color = 'green' if predictions[i] == targets[i] else 'red' | |
| axes[i].set_title(f'P: {pred_class}\nR: {true_class}', color=color, fontsize=9) | |
| axes[i].axis('off') | |
| for i in range(num_samples, len(axes)): | |
| axes[i].axis('off') | |
| plt.tight_layout() | |
| if save_path: | |
| plt.savefig(save_path, dpi=150, bbox_inches='tight') | |
| else: | |
| plt.show() | |
| plt.close() | |
| def plot_class_distribution( | |
| targets: np.ndarray, | |
| classes: List[str], | |
| title: str = 'Distribución de Clases', | |
| save_path: Optional[str] = None | |
| ): | |
| unique, counts = np.unique(targets, return_counts=True) | |
| plt.figure(figsize=(12, 6)) | |
| bars = plt.bar(range(len(classes)), [0] * len(classes), color='lightgray') | |
| for i, u in enumerate(unique): | |
| if u < len(classes): | |
| bars[u].set_height(counts[i]) | |
| bars[u].set_color('steelblue') | |
| plt.xlabel('Clase') | |
| plt.ylabel('Cantidad') | |
| plt.title(title) | |
| plt.xticks(range(len(classes)), [classes[i] for i in range(len(classes))], rotation=45, ha='right') | |
| plt.grid(axis='y', alpha=0.3) | |
| plt.tight_layout() | |
| if save_path: | |
| plt.savefig(save_path, dpi=150, bbox_inches='tight') | |
| else: | |
| plt.show() | |
| plt.close() | |
| def plot_loss_comparison( | |
| results: dict, | |
| save_path: Optional[str] = None | |
| ): | |
| datasets = list(results.keys()) | |
| train_losses = [results[d]['train_loss'] for d in datasets] | |
| test_losses = [results[d]['test_loss'] for d in datasets] | |
| x = np.arange(len(datasets)) | |
| width = 0.35 | |
| fig, ax = plt.subplots(figsize=(12, 6)) | |
| bars1 = ax.bar(x - width/2, train_losses, width, label='Train Loss', color='steelblue') | |
| bars2 = ax.bar(x + width/2, test_losses, width, label='Test Loss', color='coral') | |
| ax.set_xlabel('Dataset') | |
| ax.set_ylabel('Loss') | |
| ax.set_title('Comparación de Loss entre Datasets') | |
| ax.set_xticks(x) | |
| ax.set_xticklabels(datasets) | |
| ax.legend() | |
| ax.grid(axis='y', alpha=0.3) | |
| plt.tight_layout() | |
| if save_path: | |
| plt.savefig(save_path, dpi=150, bbox_inches='tight') | |
| else: | |
| plt.show() | |
| plt.close() | |
| def plot_accuracy_comparison( | |
| results: dict, | |
| save_path: Optional[str] = None | |
| ): | |
| datasets = list(results.keys()) | |
| train_accs = [results[d]['train_acc'] for d in datasets] | |
| test_accs = [results[d]['test_acc'] for d in datasets] | |
| x = np.arange(len(datasets)) | |
| width = 0.35 | |
| fig, ax = plt.subplots(figsize=(12, 6)) | |
| bars1 = ax.bar(x - width/2, train_accs, width, label='Train Acc', color='steelblue') | |
| bars2 = ax.bar(x + width/2, test_accs, width, label='Test Acc', color='coral') | |
| for bar in bars1: | |
| height = bar.get_height() | |
| ax.text(bar.get_x() + bar.get_width()/2., height + 1, | |
| f'{height:.1f}%', ha='center', va='bottom', fontsize=8) | |
| for bar in bars2: | |
| height = bar.get_height() | |
| ax.text(bar.get_x() + bar.get_width()/2., height + 1, | |
| f'{height:.1f}%', ha='center', va='bottom', fontsize=8) | |
| ax.set_xlabel('Dataset') | |
| ax.set_ylabel('Accuracy (%)') | |
| ax.set_title('Comparación de Accuracy entre Datasets') | |
| ax.set_xticks(x) | |
| ax.set_xticklabels(datasets) | |
| ax.legend() | |
| ax.grid(axis='y', alpha=0.3) | |
| ax.set_ylim(0, 110) | |
| plt.tight_layout() | |
| if save_path: | |
| plt.savefig(save_path, dpi=150, bbox_inches='tight') | |
| else: | |
| plt.show() | |
| plt.close() | |
| def plot_filters(model: torch.nn.Module, layer_name: str = 'features.0', save_path: Optional[str] = None): | |
| try: | |
| layer = dict(model.named_modules())[layer_name] | |
| if isinstance(layer, torch.nn.Conv2d): | |
| filters = layer.weight.data.cpu().numpy() | |
| num_filters = min(32, filters.shape[0]) | |
| fig, axes = plt.subplots(4, 8, figsize=(16, 8)) | |
| axes = axes.flatten() | |
| for i in range(num_filters): | |
| if filters.shape[1] == 3: | |
| f = filters[i].transpose(1, 2, 0) | |
| f = (f - f.min()) / (f.max() - f.min()) | |
| else: | |
| f = filters[i, 0] | |
| axes[i].imshow(f, cmap='gray') | |
| axes[i].axis('off') | |
| for i in range(num_filters, len(axes)): | |
| axes[i].axis('off') | |
| plt.suptitle(f'Filtros de la capa: {layer_name}') | |
| plt.tight_layout() | |
| if save_path: | |
| plt.savefig(save_path, dpi=150, bbox_inches='tight') | |
| else: | |
| plt.show() | |
| plt.close() | |
| except Exception as e: | |
| print(f'No se pudieron visualizar los filtros: {e}') | |