# tynerox/src/logger/logger.py
# NOTE(review): removed GitHub web-UI residue ("Ubuntu", commit message
# "Re-adiciona model.safetensors via LFS", hash 58d9159) that had been
# pasted into the file and made it unparseable.
from typing import Dict, Optional
import os
from zoneinfo import ZoneInfo
import mlflow
import pandas as pd
import torch
import torch.nn as nn
from datetime import datetime, date
class TrainerLogger:
    """MLflow-backed experiment logger for training runs.

    Opens an MLflow run on construction and exposes helpers for logging
    parameters, metrics, checkpoint tables, model checkpoints and HTML
    artifacts. Usable as a context manager: the run is ended on exit.
    """

    def __init__(
        self,
        tracking_uri: str,
        experiment: str,
        total_params: int,
        model_name: Optional[str] = None,
        run_name: Optional[str] = None,
        tags: Optional[Dict[str, str]] = None,
    ):
        """Start an MLflow run under *experiment* at *tracking_uri*.

        Parameters
        ----------
        tracking_uri : str
            URI of the MLflow tracking server.
        experiment : str
            Name of the MLflow experiment to log under.
        total_params : int
            Total number of model parameters; logged as a run parameter.
        model_name : str, optional
            Logical model name; used as a tag, a parameter, and the
            registered-model name in ``checkpoint_model``.
        run_name : str, optional
            Display name for the run.
        tags : dict, optional
            Extra tags merged over (and overriding) the defaults.
        """
        mlflow.set_tracking_uri(tracking_uri)
        mlflow.set_experiment(experiment)
        # Enable PyTorch autologging for metrics, but keep model logging
        # manual: checkpoint_model registers models explicitly, and
        # log_models=True here would double-log them.
        mlflow.pytorch.autolog(log_models=False)
        self.run = mlflow.start_run(run_name=run_name)
        self.run_id = self.run.info.run_id
        self.experiment = experiment
        self.model_name = model_name
        self.total_params = total_params
        # Tags make runs easier to filter/organize in the MLflow UI.
        default_tags = {"model_type": self.model_name}
        if tags:
            default_tags.update(tags)
        mlflow.set_tags(default_tags)
        self.log_parameters(
            {"model_name": self.model_name, "total_params": self.total_params}
        )

    def log_parameters(self, parameters: dict) -> None:
        """Log a batch of run parameters in a single API call."""
        mlflow.log_params(parameters)

    def log_metrics(self, metrics: dict, step: Optional[int] = None) -> None:
        """Log a batch of metrics, optionally tied to a training step."""
        mlflow.log_metrics(metrics, step=step)

    def log_checkpoint_table(
        self, current_lr: float, loss: float, perplexity: float, last_batch: int
    ) -> None:
        """Append a checkpoint record to an MLflow table artifact.

        Each record captures the current America/Sao_Paulo local time plus
        the learning rate (rounded to 7 decimals), perplexity and loss
        (rounded to 4 decimals), and the index of the last processed batch.

        Parameters
        ----------
        current_lr : float
            Current learning rate.
        loss : float
            Current loss value.
        perplexity : float
            Current perplexity.
        last_batch : int
            Index of the last processed batch.
        """
        now = datetime.now(ZoneInfo("America/Sao_Paulo"))
        record = {
            "month": now.month,
            "day": now.day,
            "hour": f"{now.hour:02d}:{now.minute:02d}",
            "last_batch": last_batch,
            "current_lr": round(current_lr, 7),
            "perplexity": round(perplexity, 4),
            "loss": round(loss, 4),
        }
        # mlflow.log_table writes to the tracking artifact store directly
        # and appends if the artifact already exists; no local directory
        # is needed (the previous os.makedirs here was dead work).
        mlflow.log_table(
            data=pd.DataFrame([record]),
            artifact_file="checkpoint_table/model/checkpoint_table.json",
        )

    def checkpoint_model(self, model: nn.Module, step: int = 1) -> None:
        """Save *model*'s state dict locally and register it with MLflow.

        Parameters
        ----------
        model : nn.Module
            Model whose ``state_dict`` is checkpointed.
        step : int, optional
            Checkpoint index used in local and artifact paths. Defaults
            to 1, matching the previously hard-coded value.
        """
        # Save the state dict locally before handing it to MLflow.
        checkpoint_dir = f"checkpoints/model_{step}"
        os.makedirs(checkpoint_dir, exist_ok=True)
        checkpoint_path = os.path.join(checkpoint_dir, "model.pth")
        torch.save(model.state_dict(), checkpoint_path)
        mlflow.log_artifact(checkpoint_path, f"model_checkpoints/epoch_{step}")
        # Register in the MLflow model registry only when a name was given.
        if self.model_name:
            mlflow.pytorch.log_model(
                pytorch_model=model,
                artifact_path=f"models/epoch_{step}",
                registered_model_name=self.model_name,
                pip_requirements=["torch>=1.9.0"],
                code_paths=["tynerox/"],  # bundle relevant source code
                signature=None,  # TODO: supply a model signature if possible
            )

    def log_html(self, html: str, step: Optional[int] = None) -> None:
        """Write *html* to a local file and log it as an MLflow artifact.

        ``step`` is accepted for interface compatibility but is currently
        unused: the artifact path is fixed, so successive calls overwrite
        the same local file before logging.
        """
        file_path = "visualizations/sample.html"
        os.makedirs(os.path.dirname(file_path), exist_ok=True)
        with open(file_path, "w", encoding="utf-8") as f:
            f.write(html)
        mlflow.log_artifact(file_path)

    def finish(self) -> None:
        """End the active MLflow run."""
        mlflow.end_run()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Always close the run, even when the with-block raised.
        self.finish()