Image Classification
English
zhanwang's picture
update
377dccd verified
import yaml
import os
import pathlib
import tempfile
import torch
import re
import numpy as np
import mlflow
import mlflow.pytorch
import utils.loggers
from utils.metrics import backward_transfer, forward_transfer, forgetting
class MLFlowLogger(utils.loggers.Logger):
    """Logger that forwards accuracies and transfer metrics to an MLflow run.

    On construction the experiment named ``experiment_name`` is looked up (and
    created when missing) and a run is started inside it, optionally nested
    under ``parent_run_id``. Every later logging call re-enters that run
    through :meth:`activate_run`.

    NOTE(review): ``self.setting``, ``self.accs``, ``self.accs_mask_classes``,
    ``self.fullaccs`` and ``self.fullaccs_mask_classes`` are not defined in
    this class; presumably they are set by ``utils.loggers.Logger.__init__``
    -- confirm against the base class.
    """

    def __init__(self, setting_str: str, dataset_str: str, model_str: str,
                 experiment_name='Default', parent_run_id=None, run_name=None):
        """
        Resolves (or creates) the MLflow experiment and starts the run.
        :param setting_str: forwarded to the base logger
        :param dataset_str: forwarded to the base logger
        :param model_str: forwarded to the base logger
        :param experiment_name: name of the MLflow experiment to log into
        :param parent_run_id: id of an existing run to nest this run under, or None
        :param run_name: name passed to ``mlflow.start_run``
        """
        super().__init__(setting_str, dataset_str, model_str)
        self.experiment_name = experiment_name
        client = mlflow.tracking.MlflowClient()
        self.experiment = client.get_experiment_by_name(experiment_name)
        if self.experiment is None:
            # The experiment does not exist yet: place its artifacts under
            # <repo>/mlruns/<next id> and create it.
            new_id = self.find_last_exp_id(client) + 1
            artifact_location = repo_dir() / 'mlruns' / str(new_id)
            created_id = mlflow.create_experiment(experiment_name, artifact_location=str(artifact_location))
            self.experiment = client.get_experiment(created_id)
        self.experiment_id = self.experiment.experiment_id
        self.parent_run_id = parent_run_id
        self.run_name = run_name
        self.run_id = None
        # Per-metric step counter; see log_metric.
        self.metrics_steps = dict()

        def create_run():
            # Executed inside the freshly started run: remember its id so
            # later calls resume the same run.
            active_run = mlflow.active_run()
            self.run_id = active_run.info.run_id

        self.activate_run(create_run)

    def activate_run(self, function=None):
        """
        Runs a callback while this logger's MLflow run is active.
        The run is started nested inside ``parent_run_id`` when one was given.
        :param function: zero-argument callable executed inside the run; when
            None the run is only entered and left again
        """
        if function is None:
            # The signature advertises None as a valid default, so treat it
            # as "nothing to do" instead of failing on ``None()``.
            def function():
                return None

        if self.parent_run_id is not None:
            with mlflow.start_run(run_id=self.parent_run_id):
                with mlflow.start_run(run_id=self.run_id, experiment_id=self.experiment_id,
                                      run_name=self.run_name, nested=True):
                    function()
        else:
            with mlflow.start_run(run_id=self.run_id, experiment_id=self.experiment_id,
                                  run_name=self.run_name, nested=False):
                function()

    def find_last_exp_id(self, client):
        """
        Probes experiment ids 0..99 in order and returns the last one found.
        :param client: MlflowClient used for the lookups
        :return: highest consecutive existing id, or -1 when id 0 is missing
        """
        last_id = -1
        for i in range(100):
            try:
                client.get_experiment(str(i))
            except Exception:
                # Best effort: the first id that cannot be fetched ends the
                # scan. ``Exception`` (not a bare except) so that
                # KeyboardInterrupt / SystemExit still propagate.
                break
            last_id = i
        return last_id

    def log(self, mean_acc: np.ndarray) -> None:
        """
        Logs a mean accuracy value.
        :param mean_acc: mean accuracy value; a single value for the
            'general-continual' setting, otherwise a pair that is unpacked
        """
        if self.setting == 'general-continual':
            self.accs.append(mean_acc)
            self.log_metric('mean_acc', mean_acc)
        elif self.setting == 'domain-il':
            # Only the first element of the pair is recorded here.
            mean_acc, _ = mean_acc
            self.accs.append(mean_acc)
            self.log_metric('mean_acc', mean_acc)
        else:
            mean_acc_class_il, mean_acc_task_il = mean_acc
            self.accs.append(mean_acc_class_il)
            self.accs_mask_classes.append(mean_acc_task_il)
            self.log_metric('mean_acc_class_il', mean_acc_class_il)
            self.log_metric('mean_acc_task_il', mean_acc_task_il)

    def log_fullacc(self, accs):
        """
        Logs per-task accuracies; does nothing unless the setting is 'class-il'.
        :param accs: pair (class-il accuracies, task-il accuracies), each iterable
        """
        if self.setting == 'class-il':
            acc_class_il, acc_task_il = accs
            self.fullaccs.append(acc_class_il)
            self.fullaccs_mask_classes.append(acc_task_il)
            for t, acc in enumerate(acc_class_il):
                self.log_metric(f'acc_class_il_task_{t}', acc)
            for t, acc in enumerate(acc_task_il):
                self.log_metric(f'acc_task_il_task_{t}', acc)

    def add_fwt(self, results, accs, results_mask_classes, accs_mask_classes):
        """
        Computes forward transfer with ``forward_transfer`` and logs it as 'fwt'.
        The masked variant is computed and logged only in the 'class-il' setting.
        :param results: passed to ``forward_transfer``
        :param accs: passed to ``forward_transfer``
        :param results_mask_classes: masked counterpart of ``results``
        :param accs_mask_classes: masked counterpart of ``accs``
        """
        self.fwt = forward_transfer(results, accs)
        self.log_metric('fwt', self.fwt)
        if self.setting == 'class-il':
            self.fwt_mask_classes = forward_transfer(results_mask_classes, accs_mask_classes)
            self.log_metric('fwt_mask_classes', self.fwt_mask_classes)

    def add_bwt(self, results, results_mask_classes):
        """
        Computes backward transfer with ``backward_transfer`` for both inputs
        and logs the values as 'bwt' and 'bwt_mask_classes'.
        :param results: passed to ``backward_transfer``
        :param results_mask_classes: masked counterpart of ``results``
        """
        self.bwt = backward_transfer(results)
        self.log_metric('bwt', self.bwt)
        self.bwt_mask_classes = backward_transfer(results_mask_classes)
        self.log_metric('bwt_mask_classes', self.bwt_mask_classes)

    def add_forgetting(self, results, results_mask_classes):
        """
        Computes forgetting with ``forgetting`` for both inputs and logs the
        values as 'forgetting' and 'forgetting_mask_classes'.
        :param results: passed to ``forgetting``
        :param results_mask_classes: masked counterpart of ``results``
        """
        self.forgetting = forgetting(results)
        self.log_metric('forgetting', self.forgetting)
        self.forgetting_mask_classes = forgetting(results_mask_classes)
        self.log_metric('forgetting_mask_classes', self.forgetting_mask_classes)

    def log_metric(self, metric_name, value):
        """
        Logs one metric value to the run, auto-incrementing its step.
        The first value logged under a name gets step 0, the next step 1, ...
        :param metric_name: MLflow metric key
        :param value: value to log
        """
        def log_value():
            if metric_name not in self.metrics_steps:
                self.metrics_steps[metric_name] = -1
            self.metrics_steps[metric_name] += 1
            mlflow.log_metric(metric_name, value, step=self.metrics_steps[metric_name])

        self.activate_run(log_value)

    def log_args(self, args: dict):
        """
        Logs a dictionary of parameters to the run.
        :param args: mapping passed to ``mlflow.log_params``
        """
        self.activate_run(lambda: mlflow.log_params(args))

    def log_artifact(self, artifact_path, name):
        """
        Logs a local file or directory as a run artifact.
        The run's artifact URI is swapped via ``SwapArtifactUri`` for the
        duration of the call.
        :param artifact_path: local path handed to ``mlflow.log_artifact``
        :param name: artifact sub-path handed to ``mlflow.log_artifact``
        """
        with SwapArtifactUri(self.experiment_id, self.run_id):
            active_run = mlflow.active_run()
            if active_run is not None and active_run.info.run_id == self.run_id:
                # Already inside our own run: log directly instead of
                # starting it a second time.
                mlflow.log_artifact(artifact_path, name)
            else:
                self.activate_run(lambda: mlflow.log_artifact(artifact_path, name))

    def log_model(self, model: torch.nn.Module, weight_name):
        """
        Saves the model with ``torch.save`` and logs the file as an artifact.
        :param model: module to serialize
        :param weight_name: used both as the file stem ('<weight_name>.pt')
            and as the artifact sub-path
        """
        with SwapArtifactUri(self.experiment_id, self.run_id):
            with tempfile.TemporaryDirectory() as tmpdir:
                # The upload must happen before the temporary directory is
                # removed, hence inside this ``with``.
                model_path = pathlib.Path(tmpdir) / f'{weight_name}.pt'
                torch.save(model, model_path)
                self.activate_run(lambda: mlflow.log_artifact(model_path, weight_name))

    def log_avrg_accuracy(self):
        """
        Averages the run's 'test_accuracy_task_*' metrics and logs the result
        as 'avrg_test_acc'. Nothing is logged when no such metric exists.

        NOTE(review): this class itself never logs a 'test_accuracy_task_*'
        metric; presumably they are written by other code -- confirm.
        """
        client = mlflow.tracking.MlflowClient()
        run = client.get_run(self.run_id)
        run_metrics = run.data.metrics
        test_accs = [acc for name, acc in run_metrics.items() if name.startswith('test_accuracy_task_')]
        if not test_accs:
            # Avoid ZeroDivisionError: there is nothing to average.
            return
        test_avrg_acc = sum(test_accs) / len(test_accs)
        client.log_metric(self.run_id, 'avrg_test_acc', test_avrg_acc)
class SwapArtifactUri:
    """Context manager that temporarily rewrites ``artifact_uri`` in a run's ``meta.yaml``.

    On entry the file ``<repo>/mlruns/<experiment_id>/<run_id>/meta.yaml`` is
    loaded, its current ``artifact_uri`` is remembered, and the entry is
    replaced with a ``file://`` URI pointing at the run's ``artifacts``
    directory below ``repo_dir()``. On exit the remembered value is written
    back. ``__exit__`` returns None, so exceptions raised in the ``with`` body
    are not suppressed.
    """

    def __init__(self, experiment_id, run_id):
        """
        :param experiment_id: experiment directory name under ``mlruns``
        :param run_id: run directory name under the experiment
        """
        self.run_id = run_id
        self.experiment_id = experiment_id
        # Set by __enter__ to the URI found on disk; restored by __exit__.
        self.artifact_uri = None

    def __enter__(self):
        root = repo_dir()
        meta_file = root / 'mlruns' / f'{self.experiment_id}' / f'{self.run_id}' / 'meta.yaml'
        meta = self.load_meta(meta_file)
        # Keep the original value so it can be put back on exit.
        self.artifact_uri = meta['artifact_uri']
        local_uri = f'file://{root}/mlruns/{self.experiment_id}/{self.run_id}/artifacts'
        meta['artifact_uri'] = local_uri
        with open(meta_file, 'w') as handle:
            yaml.safe_dump(meta, handle)

    def __exit__(self, exc_type, exc_value, exc_tb):
        meta_file = repo_dir() / 'mlruns' / f'{self.experiment_id}' / f'{self.run_id}' / 'meta.yaml'
        # Re-read the file rather than reusing the dict from __enter__, so
        # only the ``artifact_uri`` entry is reverted.
        meta = self.load_meta(meta_file)
        meta['artifact_uri'] = self.artifact_uri
        with open(meta_file, 'w') as handle:
            yaml.safe_dump(meta, handle)

    def load_meta(self, meta_path):
        """
        Parses a YAML file with ``yaml.safe_load``.
        :param meta_path: path of the file to read
        :return: the parsed content
        """
        with open(meta_path, 'r') as handle:
            return yaml.safe_load(handle)
def repo_dir():
    """
    Returns the parent of the directory that contains this file.

    On Windows the leading anchor (first element of ``parts``, e.g. the drive)
    is dropped and the result is a ``PurePosixPath`` so that it renders with
    forward slashes; elsewhere a regular ``pathlib.Path`` is returned.

    NOTE(review): without its anchor the Windows result is a relative path;
    confirm that this is what callers that open files through it expect.

    :return: path object supporting ``/`` joins
    """
    module_dir = pathlib.Path(os.path.dirname(os.path.abspath(__file__)))
    if isinstance(module_dir, pathlib.WindowsPath):
        # Stay a path object: ``as_posix()`` yields a ``str``, on which the
        # ``.parent`` below (and callers' ``/`` joins) would fail.
        module_dir = pathlib.PurePosixPath(*module_dir.parts[1:])
    return module_dir.parent