Spaces:
Running
Running
| """ | |
| Plotting utilities for PyTorch Playground demos. | |
| Provides consistent, styled plots for training metrics, tensors, etc. | |
| """ | |
| from typing import List, Optional, Dict, Any, Tuple | |
| import matplotlib.pyplot as plt | |
| import matplotlib | |
| import numpy as np | |
# Select the non-interactive Agg backend for use under Gradio: figures are
# rendered off-screen and handed back as Figure objects, never shown in a window.
matplotlib.use("Agg")

# Shared palette for every plot helper in this module. Insertion order is
# significant: some helpers cycle through COLORS.values() in order.
COLORS = dict(
    primary="#ee4c2c",  # PyTorch orange
    secondary="#29b6f6",  # light blue
    tertiary="#66bb6a",  # green
    quaternary="#ab47bc",  # purple
    loss="#ef5350",  # red
    accuracy="#66bb6a",  # green (same hex value as tertiary)
    lr="#ff9800",  # orange
)
def setup_plot_style():
    """Apply the shared playground look to matplotlib's global rcParams.

    Starts from the seaborn white-grid style sheet and then overrides colours
    and font sizes. The seaborn sheets were renamed to ``seaborn-v0_8-*`` in
    matplotlib 3.6, so the pre-3.6 name is tried as a fallback. If neither
    sheet is installed, the base style is left as-is and only the overrides
    below are applied, instead of raising and breaking every plot.

    Note: this mutates process-wide pyplot state (``plt.rcParams``).
    """
    available_styles = set(plt.style.available)
    for style_name in ("seaborn-v0_8-whitegrid", "seaborn-whitegrid"):
        if style_name in available_styles:
            plt.style.use(style_name)
            break

    plt.rcParams.update(
        {
            "figure.facecolor": "white",
            "axes.facecolor": "white",
            "axes.edgecolor": "#cccccc",
            "axes.labelcolor": "#333333",
            "text.color": "#333333",
            "xtick.color": "#666666",
            "ytick.color": "#666666",
            "grid.color": "#eeeeee",
            "font.size": 10,
            "axes.titlesize": 12,
            "axes.labelsize": 10,
        }
    )
def create_loss_plot(
    losses: List[float],
    val_losses: Optional[List[float]] = None,
    title: str = "Training Loss",
    figsize: Tuple[int, int] = (8, 5),
) -> plt.Figure:
    """
    Create a training loss plot.

    Args:
        losses: Training loss values, plotted at x = 1..len(losses).
        val_losses: Optional validation loss values, plotted dashed at
            x = 1..len(val_losses). Their length may differ from ``losses``.
            Lists and 1-D arrays are both accepted.
        title: Plot title
        figsize: Figure size in inches (width, height)

    Returns:
        matplotlib Figure object. An empty ``losses`` sequence yields empty
        axes rather than an error.
    """
    setup_plot_style()
    fig, ax = plt.subplots(figsize=figsize)

    epochs = range(1, len(losses) + 1)
    ax.plot(epochs, losses, color=COLORS["loss"], linewidth=2, label="Training Loss")

    # Explicit length check (not truthiness) so numpy arrays work too.
    has_val = val_losses is not None and len(val_losses) > 0
    if has_val:
        # Plot against the validation series' own length; reusing the
        # training x-range makes matplotlib raise when the lengths differ.
        val_epochs = range(1, len(val_losses) + 1)
        ax.plot(
            val_epochs,
            val_losses,
            color=COLORS["secondary"],
            linewidth=2,
            linestyle="--",
            label="Validation Loss",
        )

    ax.set_xlabel("Epoch")
    ax.set_ylabel("Loss")
    ax.set_title(title)
    ax.legend()

    # Anchor the y-axis at 0 only when every plotted value is non-negative,
    # so negative validation losses are never clipped out of view. The
    # emptiness check avoids min() raising before any epoch has run.
    plotted = list(losses)
    if has_val:
        plotted.extend(val_losses)
    if plotted and min(plotted) >= 0:
        ax.set_ylim(bottom=0)

    fig.tight_layout()
    return fig
def create_metrics_plot(
    metrics: Dict[str, List[float]],
    title: str = "Training Metrics",
    figsize: Tuple[int, int] = (10, 5),
) -> plt.Figure:
    """
    Create a row of line plots, one subplot per metric.

    Args:
        metrics: Mapping of metric name -> per-epoch values. Subplots follow
            dict order. Each metric uses its own x range 1..len(values).
        title: Figure-level title
        figsize: Figure size in inches (width, height)

    Returns:
        matplotlib Figure object. An empty ``metrics`` mapping yields a single
        placeholder panel instead of raising.
    """
    setup_plot_style()

    if not metrics:
        # plt.subplots(1, 0) raises, so show an explanatory panel instead.
        fig, ax = plt.subplots(figsize=figsize)
        ax.text(
            0.5,
            0.5,
            "No metrics to display",
            transform=ax.transAxes,
            ha="center",
            va="center",
        )
        ax.axis("off")
        fig.suptitle(title)
        fig.tight_layout()
        return fig

    # squeeze=False always returns a 2-D array of axes, so a single metric
    # needs no special case.
    fig, axes = plt.subplots(1, len(metrics), figsize=figsize, squeeze=False)

    # COLORS may contain duplicate hex values (currently tertiary and
    # accuracy share one). De-duplicate while keeping order so neighbouring
    # metrics are less likely to share a colour.
    palette = list(dict.fromkeys(COLORS.values()))

    for idx, (ax, (name, values)) in enumerate(zip(axes[0], metrics.items())):
        epochs = range(1, len(values) + 1)
        ax.plot(epochs, values, color=palette[idx % len(palette)], linewidth=2)
        label = name.capitalize()
        ax.set_xlabel("Epoch")
        ax.set_ylabel(label)
        ax.set_title(label)

    fig.suptitle(title)
    fig.tight_layout()
    return fig
def create_confusion_matrix_plot(
    cm: np.ndarray,
    class_names: Optional[List[str]] = None,
    title: str = "Confusion Matrix",
    figsize: Tuple[int, int] = (8, 6),
    cmap: str = "Blues",
) -> plt.Figure:
    """
    Create a confusion matrix heatmap with the value printed in every cell.

    Args:
        cm: 2-D confusion matrix. Rows are true labels and columns are
            predicted labels, matching the axis titles. Any array-like is
            accepted. Integer values are printed as integers; other dtypes
            (e.g. a normalised float matrix) are printed with two decimals.
        class_names: Optional tick labels; defaults to "0", "1", ...
        title: Plot title
        figsize: Figure size in inches (width, height)
        cmap: Colormap name

    Returns:
        matplotlib Figure object
    """
    cm = np.asarray(cm)

    setup_plot_style()
    fig, ax = plt.subplots(figsize=figsize)
    im = ax.imshow(cm, interpolation="nearest", cmap=cmap)
    fig.colorbar(im, ax=ax)

    if class_names is None:
        class_names = [str(i) for i in range(len(cm))]

    ax.set(
        xticks=np.arange(len(class_names)),
        yticks=np.arange(len(class_names)),
        xticklabels=class_names,
        yticklabels=class_names,
        ylabel="True label",
        xlabel="Predicted label",
        title=title,
    )
    plt.setp(ax.get_xticklabels(), rotation=45, ha="right", rotation_mode="anchor")

    # format(x, "d") raises for floats, so choose the format from the dtype.
    value_fmt = "d" if np.issubdtype(cm.dtype, np.integer) else ".2f"
    # Cells above half of the maximum get white text. With sequential
    # colormaps such as the default "Blues", those cells are the dark ones.
    thresh = cm.max() / 2.0

    # Iterate over the matrix itself, not class_names, so every cell is
    # annotated even if the label list has a different length.
    n_rows, n_cols = cm.shape
    for i in range(n_rows):
        for j in range(n_cols):
            ax.text(
                j,
                i,
                format(cm[i, j], value_fmt),
                ha="center",
                va="center",
                color="white" if cm[i, j] > thresh else "black",
            )

    fig.tight_layout()
    return fig
def _chw_to_image(chw: np.ndarray) -> Tuple[np.ndarray, Optional[str]]:
    """Turn a (C, H, W) array with C in {1, 3, 4} into an imshow-ready image and cmap."""
    if chw.shape[0] == 1:
        # Single channel: drop the channel axis and draw it in grayscale.
        return chw[0], "gray"
    # Move channels last and keep at most three (drops an alpha channel).
    hwc = np.transpose(chw, (1, 2, 0))[:, :, :3]
    # Min-max normalise into [0, 1]. The epsilon avoids division by zero
    # for constant images.
    hwc = (hwc - hwc.min()) / (hwc.max() - hwc.min() + 1e-8)
    return hwc, None


def _plot_tensor_1d(data: np.ndarray, title: str, figsize: Tuple[int, int]) -> plt.Figure:
    """Draw a 1-D array as a bar chart of value by index."""
    fig, ax = plt.subplots(figsize=figsize)
    ax.bar(range(len(data)), data, color=COLORS["primary"])
    ax.set_xlabel("Index")
    ax.set_ylabel("Value")
    ax.set_title(f"{title} (1D Tensor)")
    return fig


def _plot_tensor_2d(data: np.ndarray, title: str, figsize: Tuple[int, int]) -> plt.Figure:
    """Draw a 2-D array as a heatmap with a colorbar."""
    fig, ax = plt.subplots(figsize=figsize)
    im = ax.imshow(data, cmap="RdYlBu_r", aspect="auto")
    fig.colorbar(im, ax=ax)
    ax.set_xlabel("Column")
    ax.set_ylabel("Row")
    ax.set_title(f"{title} (2D Tensor: {data.shape})")
    return fig


def _plot_tensor_3d(data: np.ndarray, title: str, figsize: Tuple[int, int]) -> plt.Figure:
    """Draw a 3-D array as a CHW image (first dim 1/3/4) or as a heatmap of slice 0."""
    fig, ax = plt.subplots(figsize=figsize)
    if data.shape[0] in (1, 3, 4):
        img, cmap = _chw_to_image(data)
        ax.imshow(img, cmap=cmap)
        kind = "Grayscale Image" if data.shape[0] == 1 else "RGB Image"
        ax.set_title(f"{title} ({kind})")
    else:
        im = ax.imshow(data[0], cmap="viridis", aspect="auto")
        fig.colorbar(im, ax=ax)
        ax.set_title(f"{title} (3D Tensor slice 0 of {data.shape[0]})")
    return fig


def _plot_tensor_4d(data: np.ndarray, title: str, figsize: Tuple[int, int]) -> plt.Figure:
    """Draw up to four leading samples of a 4-D array, treating dim 1 as channels."""
    n_show = min(4, data.shape[0])
    # Half-height row of panels. Use true division so small or float
    # heights are not truncated to 0.
    fig, axes = plt.subplots(
        1, n_show, figsize=(figsize[0], figsize[1] / 2), squeeze=False
    )
    for i, ax in enumerate(axes[0]):
        if data.shape[1] in (1, 3, 4):
            # Same channel handling as the 3-D case, including RGBA.
            img, cmap = _chw_to_image(data[i])
            ax.imshow(img, cmap=cmap)
        else:
            ax.imshow(data[i, 0], cmap="viridis")
        ax.set_title(f"Sample {i}")
        ax.axis("off")
    fig.suptitle(f"{title} (4D Tensor: {data.shape})")
    return fig


def _plot_tensor_summary(
    data: np.ndarray, title: str, figsize: Tuple[int, int], label: str
) -> plt.Figure:
    """Draw a text-only panel with the shape and, if non-empty, summary statistics."""
    fig, ax = plt.subplots(figsize=figsize)
    lines = [f"Tensor shape: {data.shape}"]
    if data.size:
        # Reductions such as min() raise on empty arrays, so only add them here.
        lines += [
            f"Min: {data.min():.4f}",
            f"Max: {data.max():.4f}",
            f"Mean: {data.mean():.4f}",
            f"Std: {data.std():.4f}",
        ]
    ax.text(
        0.5,
        0.5,
        "\n".join(lines),
        transform=ax.transAxes,
        fontsize=12,
        verticalalignment="center",
        horizontalalignment="center",
    )
    ax.set_title(f"{title} ({label})")
    ax.axis("off")
    return fig


def create_tensor_visualization(
    tensor: "torch.Tensor",
    title: str = "Tensor Visualization",
    figsize: Tuple[int, int] = (8, 6),
) -> plt.Figure:
    """
    Create a visualization of a tensor, choosing the plot type from its rank.

    - empty tensor: text panel showing the shape
    - 0D: text panel with summary statistics
    - 1D: bar chart of the values
    - 2D: heatmap with colorbar
    - 3D: image if the first dim is 1, 3 or 4 (treated as channels),
      otherwise a heatmap of slice 0
    - 4D: up to 4 samples along dim 0. Each is drawn as an image if dim 1
      is 1, 3 or 4, otherwise channel 0 is drawn as a heatmap.
    - >4D: text panel with summary statistics

    Args:
        tensor: PyTorch tensor to visualize. It is detached and copied to CPU.
        title: Plot title prefix
        figsize: Figure size in inches (width, height)

    Returns:
        matplotlib Figure object
    """
    import torch

    setup_plot_style()

    host = tensor.detach().cpu()
    # NumPy has no bfloat16 dtype, so .numpy() raises on it; upcast first.
    if host.dtype == torch.bfloat16:
        host = host.float()
    data = host.numpy()

    if data.size == 0:
        fig = _plot_tensor_summary(data, title, figsize, "Empty Tensor")
    elif data.ndim == 0:
        fig = _plot_tensor_summary(data, title, figsize, "Scalar Tensor")
    elif data.ndim == 1:
        fig = _plot_tensor_1d(data, title, figsize)
    elif data.ndim == 2:
        fig = _plot_tensor_2d(data, title, figsize)
    elif data.ndim == 3:
        fig = _plot_tensor_3d(data, title, figsize)
    elif data.ndim == 4:
        fig = _plot_tensor_4d(data, title, figsize)
    else:
        fig = _plot_tensor_summary(data, title, figsize, "High-dimensional Tensor")

    fig.tight_layout()
    return fig
def create_lr_schedule_plot(
    lrs: List[float],
    title: str = "Learning Rate Schedule",
    figsize: Tuple[int, int] = (8, 4),
) -> plt.Figure:
    """
    Plot a learning rate schedule, one point per step starting at step 0.

    The y-axis is logarithmic when every learning rate is strictly positive.
    Schedules that reach zero or below (e.g. warm-up from 0, or decay to 0)
    use a linear axis instead, because a log axis cannot display
    non-positive values.

    Args:
        lrs: Learning rate values, one per step
        title: Plot title
        figsize: Figure size in inches (width, height)

    Returns:
        matplotlib Figure object
    """
    setup_plot_style()
    fig, ax = plt.subplots(figsize=figsize)
    ax.plot(range(len(lrs)), lrs, color=COLORS["lr"], linewidth=2)
    ax.set_xlabel("Step")
    ax.set_ylabel("Learning Rate")
    ax.set_title(title)

    if all(lr > 0 for lr in lrs):
        ax.set_yscale("log")

    fig.tight_layout()
    return fig
def create_timing_comparison_plot(
    results: Dict[str, float],
    title: str = "Timing Comparison",
    figsize: Tuple[int, int] = (8, 5),
) -> plt.Figure:
    """
    Create a bar chart comparing timing results, with each value labelled.

    Args:
        results: Mapping of method name -> time in ms. Bars follow dict order.
        title: Plot title
        figsize: Figure size in inches (width, height)

    Returns:
        matplotlib Figure object
    """
    setup_plot_style()
    fig, ax = plt.subplots(figsize=figsize)

    names = list(results)
    times = list(results.values())

    # Cycle the whole (de-duplicated) palette so bars beyond the third still
    # get distinct colours. The first three stay primary, secondary, tertiary.
    palette = list(dict.fromkeys(COLORS.values()))
    colors = [palette[i % len(palette)] for i in range(len(names))]

    bars = ax.bar(names, times, color=colors)
    ax.set_ylabel("Time (ms)")
    ax.set_title(title)

    # Print each value just above its bar.
    for bar, elapsed_ms in zip(bars, times):
        ax.annotate(
            f"{elapsed_ms:.2f}ms",
            xy=(bar.get_x() + bar.get_width() / 2, bar.get_height()),
            xytext=(0, 3),
            textcoords="offset points",
            ha="center",
            va="bottom",
            fontsize=10,
        )

    # Annotations do not take part in autoscaling. Add top headroom so the
    # label over the tallest bar does not collide with the frame.
    ax.margins(y=0.1)

    fig.tight_layout()
    return fig