|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import torch |
|
|
import logging |
|
|
import numpy as np |
|
|
import io |
|
|
import math |
|
|
import random |
|
|
from PIL import Image |
|
|
from matplotlib import pyplot as plt |
|
|
from torch.utils.tensorboard import SummaryWriter |
|
|
from torchvision import transforms |
|
|
from .functions import REG_FUNCTION_MAP |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@torch.no_grad()
def register(
    flags: dict,
    tensor: int | float | torch.Tensor,
    valid_mask: torch.Tensor,
    epoch: int,
    writer: SummaryWriter,
    logger: logging.Logger,
    tensorboard_required: bool,
    parameter_name: str = ''
):
    """
    Registers a parameter according to the register flags (DEFAULT_WATCHER style).

    What gets written depends on the type of ``tensor``:

    * ``torch.nn.Parameter``: for every truthy entry of ``flags['parameters']``, the function
      ``REG_FUNCTION_MAP[flag_key](tensor, valid_mask)`` is evaluated and its result written under
      ``parameters/<flag_key>/<parameter_name>/``. Flags whose key contains ``'grad'`` are
      skipped while the parameter has no gradient (``tensor.grad is None``).
    * any other ``torch.Tensor``: written under ``activations/<parameter_name>/`` -- as a float
      scalar when it has a single element, otherwise the tensor itself is handed to
      ``write_tensorboard`` (which logs tensors as histograms).
    * ``int`` / ``float``: written as a float scalar under ``train/<parameter_name>/``.

    :param flags: A specific watch flag. Only ``flags['parameters']`` is read, and only for parameters.
    :param tensor: The value to register.
    :param valid_mask: The valid mask passed to the ``REG_FUNCTION_MAP`` functions.
    :param epoch: The current epoch (used as the tensorboard step).
    :param writer: The tensorboard writer.
    :param logger: The logger.
    :param tensorboard_required: Whether the tensorboard writer is required.
    :param parameter_name: The name of the parameter.
    :raises ValueError: If ``tensor`` is not a parameter, a tensor, an int or a float.
    :return: None
    """
    # Parameter is a Tensor subclass, so it has to be tested before plain Tensor.
    if isinstance(tensor, torch.nn.Parameter):
        flag_type = 'parameters'
    elif isinstance(tensor, torch.Tensor):
        flag_type = 'activations'
    elif isinstance(tensor, (int, float)):
        flag_type = 'train'
    else:
        raise ValueError(f"{type(tensor)} is not a torch.nn.Parameter, torch.Tensor, int or float.")

    # (tensorboard tag, flag key) pairs to write.
    if flag_type == 'parameters':
        safe_names = [
            (f'{flag_type}/{flag_key}/{parameter_name}/', flag_key)
            for flag_key, flag_value in flags['parameters'].items()
            if flag_value
        ]
    else:
        safe_names = [(f'{flag_type}/{parameter_name}/', '')]

    for name, flag_key in safe_names:
        if flag_type == 'parameters':
            # Gradient statistics need .grad to be populated; every other statistic can be
            # computed from the parameter itself, so only gradient flags are skipped here.
            if 'grad' in flag_key and tensor.grad is None:
                continue
            transformation = REG_FUNCTION_MAP[flag_key](tensor, valid_mask)
        elif flag_type == 'activations' and tensor.numel() > 1:
            # float() only accepts single-element tensors; larger activations are passed on
            # as tensors so write_tensorboard records them as a histogram.
            transformation = tensor
        else:
            transformation = float(tensor)

        if transformation is not None:
            write_tensorboard(
                name=name,
                value=transformation,
                epoch=epoch,
                writer=writer,
                logger=logger,
                tensorboard_required=tensorboard_required,
            )
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@torch.no_grad()
def register_replay(
    predicted: torch.Tensor,
    target: torch.Tensor,
    epoch: int,
    writer: SummaryWriter,
    logger: logging.Logger,
    valid_mask: torch.Tensor = Ellipsis,
    element: int = None,
    tensorboard_required: bool = True,
) -> plt.Figure:
    """
    Registers a replay as an image.

    One batch element of ``predicted`` and ``target`` is selected. The valid positions are
    flattened, zero-padded onto the smallest square grid that holds them, and drawn side by
    side as two annotated heatmaps. The figure is written to tensorboard under ``'replay/'``.

    :param predicted: The predicted value (prediction), indexed by batch element on dim 0.
    :param target: The expected value (labels). When the selected element is a 0-d value, it is
        treated as a class index and expanded into a one-hot array shaped like the prediction.
    :param epoch: The current epoch.
    :param writer: The tensorboard writer.
    :param logger: The logger.
    :param valid_mask: A valid mask tensor of same shape. False positions are ignored (valid mask).
        ``None`` or the default ``Ellipsis`` means every position is valid.
    :param element: The element to register, None chooses a random batch element. Out-of-range
        values are clamped to ``[0, len(predicted) - 1]``.
    :param tensorboard_required: Whether the tensorboard writer is required.
    :return: A matplotlib figure. When a writer is provided, ``write_tensorboard`` closes the
        figure after logging it.
    """
    if element is None:
        element = random.randint(0, len(predicted) - 1)
    else:
        element = min(len(predicted) - 1, max(0, element))

    predicted_np = predicted[element].detach().cpu().numpy()
    target_np = target[element].detach().cpu().numpy()

    # A 0-d target is a class index: turn it into a one-hot array matching the prediction.
    if not target_np.shape:
        one_hot = np.zeros_like(predicted_np)
        one_hot[int(target_np)] = 1.
        target_np = one_hot

    # Ellipsis (the default) is not subscriptable, so it must be treated like None here.
    if valid_mask is None or valid_mask is Ellipsis:
        mask_np = np.ones_like(predicted_np, dtype=bool)
    else:
        mask_np = valid_mask[element].detach().cpu().numpy().astype(bool)

    predicted_flat = predicted_np[mask_np].flatten()
    target_flat = target_np[mask_np].flatten()

    # Smallest square grid holding every valid value; at least 1x1 so that an all-False
    # mask still yields a drawable image.
    size = predicted_flat.shape[0]
    side = max(1, math.ceil(math.sqrt(size)))
    pad = side * side - size

    predicted_grid = np.pad(predicted_flat, (0, pad), constant_values=0.0).reshape(side, side)
    target_grid = np.pad(target_flat, (0, pad), constant_values=0.0).reshape(side, side)

    fig, axs = plt.subplots(1, 2, figsize=(10, 5))
    plot_with_values(axs[0], predicted_grid, "Predicted (y_hat)")
    plot_with_values(axs[1], target_grid, "Target (y)")
    # Lay out this figure explicitly rather than whatever pyplot considers current.
    fig.tight_layout()
    write_tensorboard(
        'replay/',
        fig,
        epoch=epoch,
        writer=writer,
        logger=logger,
        tensorboard_required=tensorboard_required,
    )
    return fig
|
|
|
|
|
def plot_with_values(ax, data, title):
    """
    Plots a 2-D array as a heatmap and prints each cell's value (2 decimals) on top of it.

    :param ax: A matplotlib axes.
    :param data: A 2-D numpy array.
    :param title: The title of the plot.
    :return: None
    """
    image = ax.imshow(data, cmap='viridis', interpolation='nearest')
    ax.set_title(title)
    ax.axis('off')

    # imshow rescales the colormap to the data range, so the text colour is chosen from
    # the same normalised value that determines the cell colour, not from the raw value.
    # Low normalised values are the dark end of viridis (white text); high values are the
    # bright end (black text).
    normalized = image.norm(data)
    for i in range(data.shape[0]):
        for j in range(data.shape[1]):
            text_color = "white" if normalized[i, j] < 0.5 else "black"
            ax.text(j, i, f"{data[i, j]:.2f}", ha="center", va="center", color=text_color, fontsize=8)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def write_tensorboard(
    name: str,
    value: int | float | np.number | str | bytes | list | plt.Figure | np.ndarray | torch.Tensor,
    epoch: int,
    writer: SummaryWriter,
    logger: logging.Logger,
    tensorboard_required: bool = True,
) -> None:
    """
    Write to tensorboard, choosing the summary type from the type of ``value``.

    * ``int`` / ``float`` / numpy integer or floating scalar -> ``writer.add_scalar``
    * ``torch.Tensor`` / ``list`` / ``np.ndarray`` -> ``writer.add_histogram``
    * ``str`` -> ``writer.add_text``
    * ``bytes`` (an encoded image readable by PIL) -> ``writer.add_image``
    * ``plt.Figure`` -> rendered to PNG, ``writer.add_image``; the figure is then closed

    Nothing is written when ``writer`` is None (a warning is logged only if
    ``tensorboard_required``), or when ``value`` or ``name`` is None (a warning is logged).

    :param name: The name of the tensorboard.
    :param value: The value to write.
    :param epoch: The current epoch.
    :param writer: The tensorboard writer.
    :param logger: The logger.
    :param tensorboard_required: Whether the tensorboard writer is required.
    :raises ValueError: If the type of ``value`` is not supported.
    """
    if writer is None:
        if tensorboard_required:
            logger.warning("Writer is None. Please set the writer first.")
        return

    if value is None:
        logger.warning("Value is None. Please set the value first.")
        return

    if name is None:
        logger.warning("Name is None. Please set the name first.")
        return

    if isinstance(value, int):
        writer.add_scalar(name, float(value), epoch)
    elif isinstance(value, float):
        writer.add_scalar(name, value, epoch)
    elif isinstance(value, (np.integer, np.floating)):
        # numpy scalars such as np.float32 / np.int64 are not Python int/float subclasses.
        writer.add_scalar(name, float(value), epoch)
    elif isinstance(value, torch.Tensor):
        writer.add_histogram(name, value.detach().cpu().numpy(), epoch)
    elif isinstance(value, list):
        writer.add_histogram(name, np.array(value), epoch)
    elif isinstance(value, np.ndarray):
        writer.add_histogram(name, value, epoch)
    elif isinstance(value, str):
        writer.add_text(name, value, epoch)
    elif isinstance(value, bytes):
        # Convert inside the context so the image is fully read before it is closed.
        with Image.open(io.BytesIO(value)) as image:
            image_tensor = transforms.ToTensor()(image)
        writer.add_image(name, image_tensor, epoch)
    elif isinstance(value, plt.Figure):
        try:
            with io.BytesIO() as buf:
                value.savefig(buf, format='png')
                buf.seek(0)
                with Image.open(buf) as image:
                    image_tensor = transforms.ToTensor()(image)
        finally:
            # Close this specific figure: a bare plt.close() closes whichever figure is
            # current, which may not be ``value``.
            plt.close(value)
        writer.add_image(name, image_tensor, epoch)
    else:
        raise ValueError(f"Type {type(value)} not supported.")
|
|
|
|
|
|
|
|
|
|
|
|