File size: 1,983 Bytes
e9f9fd3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
"A `Callback` that saves tracked metrics and notebook file into MLflow server."
from ..torch_core import *
from ..callback import *
from ..basic_train import Learner, LearnerCallback
#This is an optional dependency in fastai.  Must install separately.
try: import mlflow
except: print("To use this tracker, please run 'pip install mlflow'")

class MLFlowTracker(LearnerCallback):
    "A `TrackerCallback` that tracks the loss and metrics into MLFlow"
    def __init__(self, learn:Learner, exp_name: str, params: dict, nb_path: str, uri: str = "http://localhost:5000"):
        super().__init__(learn)
        self.learn,self.exp_name,self.params,self.nb_path,self.uri = learn,exp_name,params,nb_path,uri
        self.metrics_names = ['train_loss', 'valid_loss'] + [o.__name__ for o in learn.metrics]

    def on_train_begin(self, **kwargs: Any) -> None:
        "Prepare MLflow experiment and log params"
        self.client = mlflow.tracking.MlflowClient(self.uri)
        exp = self.client.get_experiment_by_name(self.exp_name)
        self.exp_id = self.client.create_experiment(self.exp_name) if exp is None else exp.experiment_id
        run = self.client.create_run(experiment_id=self.exp_id)
        self.run = run.info.run_uuid
        for k,v in self.params.items():
            self.client.log_param(run_id=self.run, key=k, value=v)

    def on_epoch_end(self, epoch, **kwargs:Any)->None:
        "Send loss and metrics values to MLFlow after each epoch"
        if kwargs['smooth_loss'] is None or kwargs["last_metrics"] is None: return
        metrics = [kwargs['smooth_loss']] + kwargs["last_metrics"]
        for name, val in zip(self.metrics_names, metrics):
            self.client.log_metric(self.run, name, np.float(val), step=epoch)
        
    def on_train_end(self, **kwargs: Any) -> None:  
        "Store the notebook and stop run"
        self.client.log_artifact(run_id=self.run, local_path=self.nb_path)
        self.client.set_terminated(run_id=self.run)