|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import copy
import glob
import json
import logging
import os
from dataclasses import asdict

import matplotlib.pyplot as plt
import torch

from .logger import get_logger
from .tensorboard import get_writer
from .seeds import get_seed
from .device import get_device
from .clear import clear_logs
from .marker import register_replay, register
from .watchers import DEFAULT_WATCHER, S_WATCHER, A_WATCHER, B_WATCHER, C_WATCHER, CNN_WATCHER, AEN_WATCHER, TRA_WATCHER
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class Setup: |
|
|
def __init__( |
|
|
self, |
|
|
path: str, |
|
|
device: int = 0, |
|
|
seed: int = None, |
|
|
save_each: int = 1, |
|
|
reload_state: bool = False, |
|
|
tensorboard: int | bool = 6006, |
|
|
autoscaler: bool = True, |
|
|
replay_element: tuple = (-1, None) |
|
|
): |
|
|
""" |
|
|
This class is used to set up the environment for an AI experiment. It saves |
|
|
the model checkpoints, logs, and tensorboard files. It also sets the device |
|
|
and seed for reproducibility. |
|
|
|
|
|
Usage: |
|
|
|
|
|
>>> from *** import Setup |
|
|
>>> setup = Setup(path='logs', device=0, seed=42, save_each=10) |
|
|
|
|
|
Inside the train loop: |
|
|
|
|
|
>>> model: torch.Model |
|
|
>>> loss_value: torch.Tensor |
|
|
>>> y: torch.Tensor |
|
|
>>> y_hat: torch.Tensor |
|
|
|
|
|
>>> setup.check(model) |
|
|
>>> setup.register('loss', loss_value) |
|
|
>>> setup.register_replay(y, y_hat) |
|
|
|
|
|
In case you want to reload latest checkpoint: |
|
|
|
|
|
>>> setup.reload(model) |
|
|
|
|
|
|
|
|
:param path: The path to the logs. |
|
|
:param device: The device to use. |
|
|
:param seed: The seed to use. |
|
|
:param save_each: The number of epochs to save the model. |
|
|
:param reload_state: Whether to reload the latest checkpoint. |
|
|
:param tensorboard: Whether to use tensorboard. |
|
|
:param autoscaler: Whether to use autoscaler for training. |
|
|
:param replay_element: The element to replay. |
|
|
""" |
|
|
|
|
|
self.path = path |
|
|
self.save_each = save_each |
|
|
self.tensorboard_required = tensorboard |
|
|
self.replay_id = replay_element |
|
|
self.__epoch_count = 0 |
|
|
|
|
|
if not reload_state: |
|
|
self.clear(path) |
|
|
|
|
|
self.logger = self.set_logger(path) |
|
|
self.writer, self.ch_path = self.set_writer(path, tensorboard) if tensorboard else (None, os.path.join(path, 'checkpoints')) |
|
|
self.seed = self.set_seed(seed) |
|
|
self.device = self.set_device(device) |
|
|
self.log_setup_info() |
|
|
|
|
|
self.watcher = DEFAULT_WATCHER |
|
|
self.autoscaler = torch.amp.GradScaler(enabled=self.device.type == 'cuda') if autoscaler else None |
|
|
|
|
|
def log_setup_info(self): |
|
|
""" |
|
|
Log the setup information. |
|
|
""" |
|
|
self.logger.info("Setup information:") |
|
|
self.logger.info(f"- Setup path: {self.path}") |
|
|
self.logger.info(f"- Setup checkpoints path: {self.ch_path}") |
|
|
self.logger.info(f"- Setup device: {self.device}") |
|
|
self.logger.info(f"- Setup seed: {self.seed}") |
|
|
self.logger.info(f"- Setup logger: {self.logger}") |
|
|
self.logger.info(f"- Setup writer: {self.writer}") |
|
|
self.logger.info(f"- Setup save each: {self.save_each}") |
|
|
|
|
|
def check( |
|
|
self, |
|
|
model: torch.nn.Module, |
|
|
optimizer: torch.optim.Optimizer | None = None, |
|
|
learning_rate: torch.optim.lr_scheduler.LRScheduler | None = None |
|
|
) -> bool: |
|
|
""" |
|
|
Check the model and save it if the epoch count is a multiple of save_each. |
|
|
:param model: The model to checkpoint and save. |
|
|
:param optimizer: The optimizer to save. |
|
|
:param learning_rate: The learning rate scheduler to save. |
|
|
:return: If the model is checkpointed. |
|
|
""" |
|
|
self.__epoch_count += 1 |
|
|
if self.save_each is not None and self.__epoch_count % self.save_each == 0: |
|
|
self.logger.info(f"Checkpointing model at epoch {self.__epoch_count}") |
|
|
self.save_model( |
|
|
model=model, |
|
|
optimizer=optimizer, |
|
|
learning_rate=learning_rate |
|
|
) |
|
|
self.logger.info(f"Model checkpointed at epoch {self.__epoch_count}") |
|
|
return True |
|
|
return False |
|
|
|
|
|
def save_model( |
|
|
self, |
|
|
model: torch.nn.Module, |
|
|
optimizer: torch.optim.Optimizer | None = None, |
|
|
learning_rate: torch.optim.lr_scheduler.LRScheduler | None = None |
|
|
): |
|
|
""" |
|
|
Saves the model. |
|
|
:param model: The model to save. |
|
|
:param optimizer: The optimizer to save. |
|
|
:param learning_rate: The learning rate scheduler to save. |
|
|
:return: Nothing. |
|
|
""" |
|
|
torch_state = { |
|
|
'epoch': self.__epoch_count, |
|
|
'model_state_dict': model.state_dict(), |
|
|
'optimizer_state_dict': optimizer.state_dict() if optimizer else None, |
|
|
'scheduler_state_dict': learning_rate.state_dict() if learning_rate else None, |
|
|
'seed': self.seed |
|
|
} |
|
|
torch.save(torch_state, self.ch_path + f'/model_epoch_{self.__epoch_count}.pt') |
|
|
|
|
|
def reload( |
|
|
self, |
|
|
model: torch.nn.Module, |
|
|
optimizer: torch.optim.Optimizer | None = None, |
|
|
learning_rate: torch.optim.lr_scheduler.LRScheduler | None = None |
|
|
) -> None: |
|
|
""" |
|
|
Reloads the latest checkpoint into the given model. |
|
|
|
|
|
:param model: The PyTorch model to reload the state into. |
|
|
:param optimizer: The optimizer to reload the state into. |
|
|
:param learning_rate: The learning rate scheduler to reload the state into. |
|
|
""" |
|
|
|
|
|
checkpoints = glob.glob(os.path.join(self.ch_path, 'model_epoch_*.pt')) |
|
|
if not checkpoints: |
|
|
self.logger.warning("No checkpoint files found.") |
|
|
else: |
|
|
|
|
|
checkpoints.sort(key=os.path.getmtime) |
|
|
latest_checkpoint = checkpoints[-1] |
|
|
|
|
|
try: |
|
|
state_dict = torch.load(latest_checkpoint, map_location=self.device) |
|
|
|
|
|
model.load_state_dict(state_dict['model_state_dict']) |
|
|
model.to(self.device) |
|
|
self.__epoch_count = state_dict['epoch'] |
|
|
self.seed = state_dict['seed'] |
|
|
self.logger.info(f"Model reloaded from {latest_checkpoint} at epoch {self.__epoch_count} and " |
|
|
f"seed {self.seed}") |
|
|
|
|
|
|
|
|
if optimizer and state_dict['optimizer_state_dict'] is not None: |
|
|
optimizer.load_state_dict(state_dict['optimizer_state_dict']) |
|
|
self.logger.info(f"Optimizer state_dict loaded from {latest_checkpoint}") |
|
|
if learning_rate and state_dict['scheduler_state_dict'] is not None: |
|
|
learning_rate.load_state_dict(state_dict['scheduler_state_dict']) |
|
|
self.logger.info(f"Scheduler state_dict loaded from {latest_checkpoint}") |
|
|
|
|
|
except Exception as e: |
|
|
self.logger.error(f"Failed to reload model from {latest_checkpoint}: {e}") |
|
|
raise RuntimeError(f"Failed to reload model from {latest_checkpoint}: {e}") |
|
|
|
|
|
def set_watcher(self, flag_names: str | list[tuple], deactivate: bool = False) -> None: |
|
|
""" |
|
|
Sets up the parameter watcher to the tensorboard. |
|
|
:param flag_names: The names of the flags to watch as a tuple of strings. |
|
|
:param deactivate: Whether to deactivate the watcher. |
|
|
:return: Nothing |
|
|
""" |
|
|
if isinstance(flag_names, str): |
|
|
if flag_names == 'S': |
|
|
flag_names = S_WATCHER |
|
|
elif flag_names == 'A': |
|
|
flag_names = A_WATCHER + S_WATCHER |
|
|
elif flag_names == 'B': |
|
|
flag_names = S_WATCHER + A_WATCHER + B_WATCHER |
|
|
elif flag_names == 'C': |
|
|
flag_names = S_WATCHER + A_WATCHER + B_WATCHER + C_WATCHER |
|
|
elif flag_names == 'cnn': |
|
|
flag_names = CNN_WATCHER |
|
|
elif flag_names == 'transformer': |
|
|
flag_names = TRA_WATCHER |
|
|
elif flag_names == 'ae': |
|
|
flag_names = AEN_WATCHER |
|
|
else: |
|
|
self.logger.error(f"[WATCHER] Unknown flag name '{flag_names}'") |
|
|
raise ValueError(f"[WATCHER] Unknown flag tier '{flag_names}'") |
|
|
|
|
|
for top_name, low_name in flag_names: |
|
|
if top_name not in self.watcher: |
|
|
self.logger.error(f"Watcher {top_name} not found in watcher.") |
|
|
raise ValueError(f"Watcher {top_name} not found in watcher.") |
|
|
elif low_name not in self.watcher[top_name]: |
|
|
self.logger.error(f"Watcher {low_name} not found in {top_name}.") |
|
|
raise ValueError(f"Watcher {low_name} not found in {top_name}.") |
|
|
else: |
|
|
self.watcher[top_name][low_name] = not deactivate |
|
|
|
|
|
def register_replay(self, predicted: torch.Tensor, target: torch.Tensor, mask: torch.Tensor = None) -> plt.Figure: |
|
|
""" |
|
|
Visualizes predicted vs. target outputs with an optional mask. |
|
|
Only positions where mask == True are shown. Each cell displays its value with two decimal places. |
|
|
|
|
|
:param predicted: Tensor of shape (S) or (S, Y) representing the model's output. |
|
|
:param target: Tensor of same shape as predicted. |
|
|
:param mask: Optional boolean tensor of same shape. False positions are ignored (valid mask). |
|
|
""" |
|
|
return register_replay( |
|
|
predicted=predicted, |
|
|
target=target, |
|
|
valid_mask=mask, |
|
|
element=self.replay_id[1], |
|
|
epoch=self.__epoch_count, |
|
|
writer=self.writer, |
|
|
logger=self.logger, |
|
|
tensorboard_required=self.tensorboard_required, |
|
|
) |
|
|
|
|
|
def register(self, name: str, parameter: float | torch.Tensor, mask: torch.Tensor = Ellipsis) -> None: |
|
|
""" |
|
|
Registers a named parameter into the tensorboard. |
|
|
:param name: The name of the parameter. |
|
|
:param parameter: The parameter to register. |
|
|
:param mask: The optional boolean tensor of same shape as parameter. |
|
|
:return: Nothing. |
|
|
""" |
|
|
if isinstance(parameter, torch.Tensor) and mask is Ellipsis: |
|
|
mask = torch.ones_like(parameter).bool() |
|
|
elif isinstance(parameter, float): |
|
|
mask = Ellipsis |
|
|
|
|
|
register( |
|
|
flags=self.watcher, |
|
|
tensor=parameter, |
|
|
valid_mask=mask, |
|
|
epoch=self.__epoch_count, |
|
|
writer=self.writer, |
|
|
logger=self.logger, |
|
|
tensorboard_required=self.tensorboard_required, |
|
|
parameter_name=name |
|
|
) |
|
|
|
|
|
def save_config(self, configuration): |
|
|
""" |
|
|
Saves the configuration to a file. |
|
|
:param configuration: A dataclasses configuration object. |
|
|
:return: Nothing. |
|
|
""" |
|
|
config_path = os.path.join(self.path, "config.json") |
|
|
with open(config_path, "w") as f: |
|
|
json.dump(asdict(configuration), f, indent=4) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@staticmethod |
|
|
def clear(path: str) -> None: |
|
|
""" |
|
|
Clear the logs. |
|
|
:param path: The path to the logs. |
|
|
""" |
|
|
clear_logs(path) |
|
|
|
|
|
@staticmethod |
|
|
def set_logger(path: str) -> logging.Logger: |
|
|
""" |
|
|
Set the logger. |
|
|
:param path: The path to the logs. |
|
|
:return: The logger. |
|
|
""" |
|
|
return get_logger(path) |
|
|
|
|
|
def set_writer(self, path: str, tensorboard_port: int | bool) -> tuple: |
|
|
""" |
|
|
Get the writer. |
|
|
:param path: The path to the logs. |
|
|
:param tensorboard_port: The port to use for tensorboard. |
|
|
:return: The writer. |
|
|
""" |
|
|
return get_writer(path, tensorboard_port, self.logger) |
|
|
|
|
|
def set_device(self, device: int) -> torch.device: |
|
|
""" |
|
|
Get the device. |
|
|
:param device: The device to use. |
|
|
:return: The device. |
|
|
""" |
|
|
return get_device(device, self.logger) |
|
|
|
|
|
def set_seed(self, seed: int) -> int: |
|
|
""" |
|
|
Get the seed. |
|
|
:param seed: The seed to use. |
|
|
:return: The seed. |
|
|
""" |
|
|
return get_seed(seed, self.logger) |
|
|
|
|
|
@property |
|
|
def epoch(self): |
|
|
""" |
|
|
Get the current epoch. |
|
|
:return: The current epoch. |
|
|
""" |
|
|
return self.__epoch_count |
|
|
|
|
|
def __enter__(self): |
|
|
return self |
|
|
|
|
|
def __exit__(self, *exc): |
|
|
if self.writer: |
|
|
self.writer.close() |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|