stm32-modelzoo-app / common /utils /visualize_utils.py
FBAGSTM's picture
STM32 AI Experimentation Hub
747451d
# /*---------------------------------------------------------------------------------------------
# * Copyright (c) 2022 STMicroelectronics.
# * All rights reserved.
# *
# * This software is licensed under terms that can be found in the LICENSE file in
# * the root directory of this software component.
# * If no LICENSE file comes with this software, it is provided AS-IS.
# *--------------------------------------------------------------------------------------------*/
import os
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
from hydra.core.hydra_config import HydraConfig
import numpy as np
from sklearn.preprocessing import normalize
from sklearn.metrics import confusion_matrix
from PIL import Image
from typing import Dict
def vis_training_curves(history=None, output_dir: str = None) -> None:
"""
Visualizes the training curves of the model.
Args:
history: The history object returned by the model.fit() method.
output_dir (Optional[str]): The output directory to save the training curves plot.
Returns:
None
"""
# Extract the accuracy and loss values for training and validation data
acc = history.history['accuracy']
val_acc = history.history['val_accuracy']
loss = history.history['loss']
val_loss = history.history['val_loss']
epochs_range = range(len(acc))
# Create dataframes for the training and validation data
df_val = pd.DataFrame(
{'run': 'validation', 'step': epochs_range, 'epoch_accuracy': val_acc, 'epoch_loss': val_loss})
df_train = pd.DataFrame({'run': 'train', 'step': epochs_range, 'epoch_accuracy': acc, 'epoch_loss': loss})
# Concatenate the dataframes
frames = [df_val, df_train]
df = pd.concat(frames)
df = df.reset_index()
# Plot the training curves
plt.figure(figsize=(16, 6))
plt.subplot(1, 2, 1)
sns.lineplot(data=df, x="step", y="epoch_accuracy", hue="run").set_title("accuracy")
plt.grid()
plt.subplot(1, 2, 2)
sns.lineplot(data=df, x="step", y="epoch_loss", hue="run").set_title("loss")
plt.grid()
plt.savefig(os.path.join(output_dir, 'Training_curves.png'))
def plot_confusion_matrix(cm: np.ndarray = None,
class_names: list = None,
title: str = "f",
model_name: str = "f",
output_dir: str = None) -> None:
"""
Plots a confusion matrix using seaborn and saves it as an image.
Args:
cm (numpy.ndarray): The confusion matrix to plot.
class_names (list): A list of class names.title : str, Pre-pended to model test accuracy in the figure title
model_name (str): The name of the model.
output_dir (str): The directory where to save the image.
Returns:
None
"""
plt.figure(figsize=(14, 14))
confusion_normalized = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
axis_labels = list(class_names)
if len(class_names) > 20:
sns.set(font_scale=0.5)
plt.xticks(rotation=45, ha='right')
sns.heatmap(confusion_normalized, xticklabels=axis_labels, yticklabels=axis_labels, cmap='Blues',
annot=True, fmt='.2f', square=True)
plt.title(title, fontsize=10)
plt.tight_layout(pad=3)
plt.ylabel("True Label", fontsize=10)
plt.xlabel("Predicted Label", fontsize=10)
plt.savefig(os.path.join(output_dir, f"{model_name}.png"))
def display_figures(cfg: Dict = None):
"""
Displays all the figures created during the execution of current run stored in output_dir.
"""
if cfg.general.display_figures:
# Get a list of all the PNG files in it and display it
png_files = [f for f in os.listdir(cfg.output_dir) if f.endswith('.png')]
for png_file in png_files:
img = Image.open(os.path.join(cfg.output_dir, png_file))
img.show()
def compute_confusion_matrix2(y_true, y_pred):
'''Computes a confusion matrix for multiclass monolabel classification
Takes one-hot encoded labels and predictions as input.
Inputs
------
y_true : np.ndarray, (n_samples, n_classes) True labels. Must be one-hot encoded labels
y_pred : np.ndarray, (n_samples, n_classes) Predicted labels. Must be one-hot encoded labels.
Outputs
-------
matrix : ndarray, (n_classes, n_classes) : Confusion matrix'''
# Convert one-hot vectors to integer
y_pred = np.argmax(y_pred, axis=1)
y_true = np.argmax(y_true, axis=1)
matrix = confusion_matrix(y_true, y_pred)
return matrix
def compute_multilabel_confusion_matrices():
"""
Compute confusion matrices for multilabel inference.
Outputs 1 2x2 matrix per class.
"""
print("this feature is not available yet!")
pass
def plot_multilabel_confusion_matrices():
"""Plots confusion matrices for multilabel inference.
Plots 1 2x2 matrix per class."""
print("this feature is not available yet!")
pass