import yaml import os import pathlib import tempfile import torch import re import numpy as np import mlflow import mlflow.pytorch import utils.loggers from utils.metrics import backward_transfer, forward_transfer, forgetting class MLFlowLogger(utils.loggers.Logger): def __init__(self, setting_str: str, dataset_str: str, model_str: str, experiment_name='Default', parent_run_id=None, run_name=None): super().__init__(setting_str, dataset_str, model_str) self.experiment_name = experiment_name client = mlflow.tracking.MlflowClient() self.experiment = client.get_experiment_by_name(experiment_name) if self.experiment is None: new_id = self.find_last_exp_id(client) + 1 artifact_location = repo_dir() / 'mlruns' / str(new_id) id = mlflow.create_experiment(experiment_name, artifact_location=str(artifact_location)) self.experiment = client.get_experiment(id) self.experiment_id = self.experiment.experiment_id self.parent_run_id = parent_run_id self.run_name = run_name self.run_id = None self.metrics_steps = dict() def create_run(): active_run = mlflow.active_run() self.run_id = active_run.info.run_id self.activate_run(create_run) def activate_run(self, function=None): if self.parent_run_id != None: with mlflow.start_run(run_id=self.parent_run_id): with mlflow.start_run(run_id=self.run_id, experiment_id=self.experiment_id, run_name=self.run_name, nested=True): function() else: with mlflow.start_run(run_id=self.run_id, experiment_id=self.experiment_id, run_name=self.run_name, nested=False): function() def find_last_exp_id(self, client): last_id = -1 for i in range(100): try: client.get_experiment(str(i)) except: break last_id = i return last_id def log(self, mean_acc: np.ndarray) -> None: """ Logs a mean accuracy value. :param mean_acc: mean accuracy value """ if self.setting == 'general-continual': self.accs.append(mean_acc) self.log_metric('mean_acc', mean_acc) elif self.setting == 'domain-il': mean_acc, _ = mean_acc self.accs.append(mean_acc) self.log_metric('mean_acc', mean_acc) else: mean_acc_class_il, mean_acc_task_il = mean_acc self.accs.append(mean_acc_class_il) self.accs_mask_classes.append(mean_acc_task_il) self.log_metric('mean_acc_class_il', mean_acc_class_il) self.log_metric('mean_acc_task_il', mean_acc_task_il) def log_fullacc(self, accs): if self.setting == 'class-il': acc_class_il, acc_task_il = accs self.fullaccs.append(acc_class_il) self.fullaccs_mask_classes.append(acc_task_il) for t, acc in enumerate(acc_class_il): self.log_metric(f'acc_class_il_task_{t}', acc) for t, acc in enumerate(acc_task_il): self.log_metric(f'acc_task_il_task_{t}', acc) def add_fwt(self, results, accs, results_mask_classes, accs_mask_classes): self.fwt = forward_transfer(results, accs) self.log_metric('fwt', self.fwt) if self.setting == 'class-il': self.fwt_mask_classes = forward_transfer(results_mask_classes, accs_mask_classes) self.log_metric('fwt_mask_classes', self.fwt_mask_classes) def add_bwt(self, results, results_mask_classes): self.bwt = backward_transfer(results) self.log_metric('bwt', self.bwt) self.bwt_mask_classes = backward_transfer(results_mask_classes) self.log_metric('bwt_mask_classes', self.bwt_mask_classes) def add_forgetting(self, results, results_mask_classes): self.forgetting = forgetting(results) self.log_metric('forgetting', self.forgetting) self.forgetting_mask_classes = forgetting(results_mask_classes) self.log_metric('forgetting_mask_classes', self.forgetting_mask_classes) def log_metric(self, metric_name, value): def log_value(): if metric_name not in self.metrics_steps: self.metrics_steps[metric_name] = -1 self.metrics_steps[metric_name] += 1 mlflow.log_metric(metric_name, value, step=self.metrics_steps[metric_name]) self.activate_run(log_value) def log_args(self, args: dict): self.activate_run(lambda: mlflow.log_params(args)) def log_artifact(self, artifact_path, name): with SwapArtifactUri(self.experiment_id, self.run_id): active_run = mlflow.active_run() if active_run is not None and active_run.info.run_id == self.run_id: mlflow.log_artifact(artifact_path, name) else: self.activate_run(lambda: mlflow.log_artifact(artifact_path, name)) def log_model(self, model: torch.nn.Module, weight_name): with SwapArtifactUri(self.experiment_id, self.run_id): with tempfile.TemporaryDirectory() as tmpdir: model_path = pathlib.Path(tmpdir) / f'{weight_name}.pt' torch.save(model, model_path) self.activate_run(lambda: mlflow.log_artifact(model_path, weight_name)) def log_avrg_accuracy(self): client = mlflow.tracking.MlflowClient() run = client.get_run(self.run_id) run_metrics = run.data.metrics test_accs = [acc for name, acc in run_metrics.items() if name.startswith('test_accuracy_task_')] test_avrg_acc = sum(test_accs) / len(test_accs) client.log_metric(self.run_id, 'avrg_test_acc', test_avrg_acc) class SwapArtifactUri: def __init__(self, experiment_id, run_id): self.experiment_id = experiment_id self.run_id = run_id self.artifact_uri = None def __enter__(self): repo_path = repo_dir() meta_path = repo_path / 'mlruns' / f'{self.experiment_id}' / f'{self.run_id}' / 'meta.yaml' run_meta = self.load_meta(meta_path) self.artifact_uri = run_meta['artifact_uri'] run_meta['artifact_uri'] = f'file://{repo_path}/mlruns/{self.experiment_id}/{self.run_id}/artifacts' with open(meta_path, 'w') as file: yaml.safe_dump(run_meta, file) def __exit__(self, exc_type, exc_value, exc_tb): repo_path = repo_dir() meta_path = repo_path / 'mlruns' / f'{self.experiment_id}' / f'{self.run_id}' / 'meta.yaml' run_meta = self.load_meta(meta_path) run_meta['artifact_uri'] = self.artifact_uri with open(meta_path, 'w') as file: yaml.safe_dump(run_meta, file) def load_meta(self, meta_path): with open(meta_path, 'r') as file: run_meta = yaml.safe_load(file) return run_meta def repo_dir(): repo_dir = os.path.dirname(os.path.abspath(__file__)) repo_dir = pathlib.Path(repo_dir) if type(repo_dir) == pathlib.WindowsPath: repo_dir = pathlib.Path(*repo_dir.parts[1:]).as_posix() return repo_dir.parent