File size: 4,909 Bytes
58d9159
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
from typing import Dict, Optional
import os
from zoneinfo import ZoneInfo

import mlflow
import pandas as pd
import torch
import torch.nn as nn
from datetime import datetime, date


class TrainerLogger:
    """Wraps a single MLflow run for model training.

    Handles run lifecycle (start/end), parameter and metric logging,
    checkpoint tables, model checkpoints/registration, and HTML artifacts.
    Usable as a context manager: the run is ended on ``with``-block exit.
    """

    def __init__(
        self,
        tracking_uri: str,
        experiment: str,
        total_params: int,
        model_name: Optional[str] = None,
        run_name: Optional[str] = None,
        tags: Optional[Dict[str, str]] = None,
    ):
        """Configure MLflow, start a run, and log base tags/parameters.

        Parameters
        ----------
        tracking_uri : str
            URI of the MLflow tracking server.
        experiment : str
            Name of the MLflow experiment to log under.
        total_params : int
            Total number of model parameters (logged as a run parameter).
        model_name : str, optional
            Name used for the ``model_type`` tag and model-registry registration.
        run_name : str, optional
            Human-readable name for the MLflow run.
        tags : dict, optional
            Extra tags merged over the default ``model_type`` tag.
        """
        mlflow.set_tracking_uri(tracking_uri)
        mlflow.set_experiment(experiment)

        # Enable PyTorch autologging. NOTE: with log_models=True models are
        # ALSO logged automatically, in addition to the manual logging in
        # checkpoint_model(); pass log_models=False if only manual control
        # of model artifacts is desired.
        mlflow.pytorch.autolog(log_models=True)

        # Start the run and keep handles for subsequent logging calls.
        self.run = mlflow.start_run(run_name=run_name)
        self.run_id = self.run.info.run_id
        self.experiment = experiment
        self.model_name = model_name
        self.total_params = total_params

        # Tags for easier organization/filtering in the MLflow UI.
        default_tags = {"model_type": self.model_name}
        if tags:
            default_tags.update(tags)
        mlflow.set_tags(default_tags)

        # Base run parameters.
        base_params = {"model_name": self.model_name, "total_params": self.total_params}
        self.log_parameters(base_params)

    def log_parameters(self, parameters: dict) -> None:
        """Log a batch of run parameters (one call is cheaper than per-param calls)."""
        mlflow.log_params(parameters)

    def log_metrics(self, metrics: dict, step: Optional[int] = None) -> None:
        """Log a batch of metrics, optionally attached to a training step."""
        mlflow.log_metrics(metrics, step=step)

    def log_checkpoint_table(self, current_lr: float, loss: float, perplexity: float, last_batch: int) -> None:
        """Append a timestamped checkpoint row to an MLflow table artifact.

        Parameters
        ----------
        current_lr : float
            Current learning rate (rounded to 7 decimal places).
        loss : float
            Current loss value (rounded to 4 decimal places).
        perplexity : float
            Current perplexity (rounded to 4 decimal places).
        last_batch : int
            Index of the last batch processed before this checkpoint.
        """
        # Timestamp in Sao Paulo local time.
        now = datetime.now(ZoneInfo("America/Sao_Paulo"))
        record = {
            "month": now.month,
            "day": now.day,
            "hour": f"{now.hour:02d}:{now.minute:02d}",
            "last_batch": last_batch,
            "current_lr": round(current_lr, 7),
            "perplexity": round(perplexity, 4),
            "loss": round(loss, 4),
        }
        df_record = pd.DataFrame([record])

        # log_table appends rows to the same artifact across calls; the path is
        # a relative POSIX path inside the run's artifact store, so no local
        # directory needs to be created here.
        mlflow.log_table(
            data=df_record,
            artifact_file="checkpoint_table/model/checkpoint_table.json",
        )

    def checkpoint_model(self, model: nn.Module, step: int = 1) -> None:
        """Save the model's state dict locally and register it with MLflow.

        Parameters
        ----------
        model : nn.Module
            Model whose ``state_dict`` is checkpointed.
        step : int, optional
            Checkpoint index used in local and artifact paths (default 1,
            matching the previously hard-coded value).
        """
        # Local directory for this checkpoint.
        checkpoint_dir = f"checkpoints/model_{step}"
        os.makedirs(checkpoint_dir, exist_ok=True)

        # Save the model state locally, then upload it as a run artifact.
        checkpoint_path = os.path.join(checkpoint_dir, "model.pth")
        torch.save(model.state_dict(), checkpoint_path)
        mlflow.log_artifact(checkpoint_path, f"model_checkpoints/epoch_{step}")

        # Register the model in the MLflow Model Registry when a name was given.
        if self.model_name:
            mlflow.pytorch.log_model(
                pytorch_model=model,
                artifact_path=f"models/epoch_{step}",
                registered_model_name=self.model_name,
                pip_requirements=["torch>=1.9.0"],
                code_paths=["tynerox/"],  # include relevant project source code
                signature=None,  # TODO: add a model signature if possible
            )

    def log_html(self, html: str, step: Optional[int] = None) -> None:
        """Write *html* to a local file and upload it as a run artifact.

        ``step`` is accepted for interface compatibility but currently unused.
        """
        file_path = "visualizations/sample.html"
        os.makedirs(os.path.dirname(file_path), exist_ok=True)

        with open(file_path, "w", encoding="utf-8") as f:
            f.write(html)

        mlflow.log_artifact(file_path)

    def finish(self) -> None:
        """End the MLflow run."""
        mlflow.end_run()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # End the run even if the with-body raised; the exception is not suppressed.
        self.finish()