from typing import Dict, Optional
import os
from zoneinfo import ZoneInfo
import mlflow
import pandas as pd
import torch
import torch.nn as nn
from datetime import datetime, date


class TrainerLogger:
    """MLflow-backed logger for a training run.

    Opens and owns a single MLflow run: sets the tracking URI and
    experiment, enables PyTorch autologging (metrics only), and records
    base tags/parameters. Usable as a context manager so the run is
    always closed, even on error.
    """

    def __init__(
        self,
        tracking_uri: str,
        experiment: str,
        total_params: int,
        model_name: Optional[str] = None,
        run_name: Optional[str] = None,
        tags: Optional[Dict[str, str]] = None,
    ):
        """Start an MLflow run and record base tags and parameters.

        Parameters
        ----------
        tracking_uri : str
            URI of the MLflow tracking server.
        experiment : str
            Name of the MLflow experiment to log under.
        total_params : int
            Total parameter count of the model (logged as a run param).
        model_name : str, optional
            Model identifier; also used as the registered model name by
            ``checkpoint_model`` and as the ``model_type`` tag.
        run_name : str, optional
            Display name for the MLflow run.
        tags : dict, optional
            Extra tags merged over the default ``model_type`` tag.
        """
        mlflow.set_tracking_uri(tracking_uri)
        mlflow.set_experiment(experiment)

        # Enable PyTorch autologging for metrics, but keep model logging
        # manual: `checkpoint_model` registers models explicitly, and
        # `log_models=True` would duplicate them.
        # Bug fix: was `log_models=True`, contradicting the original
        # comment ("we disable automatic model logging for manual control").
        mlflow.pytorch.autolog(log_models=False)

        # Start the run and keep its metadata for later reference.
        self.run = mlflow.start_run(run_name=run_name)
        self.run_id = self.run.info.run_id
        self.experiment = experiment
        self.model_name = model_name
        self.total_params = total_params

        # Tags help organize runs in the MLflow UI.
        default_tags = {"model_type": self.model_name}
        if tags:
            default_tags.update(tags)
        mlflow.set_tags(default_tags)

        # Base parameters every run records.
        base_params = {"model_name": self.model_name, "total_params": self.total_params}
        self.log_parameters(base_params)

    def log_parameters(self, parameters: dict) -> None:
        """Log a batch of run parameters (one call beats per-param calls)."""
        mlflow.log_params(parameters)

    def log_metrics(self, metrics: dict, step: Optional[int] = None) -> None:
        """Log a batch of metrics, optionally attached to a training step."""
        mlflow.log_metrics(metrics, step=step)

    def log_checkpoint_table(
        self,
        current_lr: float,
        loss: float,
        perplexity: float,
        last_batch: int,
    ) -> None:
        """Append a checkpoint record to an MLflow table artifact.

        Each record captures the current wall-clock time in
        America/Sao_Paulo (month, day, HH:MM), the last processed batch,
        the learning rate (rounded to 7 decimals), and loss/perplexity
        (rounded to 4 decimals).

        Parameters
        ----------
        current_lr : float
            Current learning rate.
        loss : float
            Current loss value.
        perplexity : float
            Current perplexity value.
        last_batch : int
            Index of the last processed batch.
        """
        # The local artifact directory must exist before logging.
        artifact_dir = "checkpoint_table/model"
        os.makedirs(artifact_dir, exist_ok=True)

        # Timestamp in the project's local timezone.
        now = datetime.now(ZoneInfo("America/Sao_Paulo"))

        record = {
            "month": now.month,
            "day": now.day,
            "hour": f"{now.hour:02d}:{now.minute:02d}",
            "last_batch": last_batch,
            "current_lr": round(current_lr, 7),
            "perplexity": round(perplexity, 4),
            "loss": round(loss, 4),
        }
        df_record = pd.DataFrame([record])

        # Relative POSIX path inside the run's artifact store.
        artifact_file = f"{artifact_dir}/checkpoint_table.json"
        mlflow.log_table(data=df_record, artifact_file=artifact_file)

    def checkpoint_model(self, model: nn.Module, step: int = 1) -> None:
        """Save the model state dict locally and register it in MLflow.

        Parameters
        ----------
        model : nn.Module
            Model whose ``state_dict`` is checkpointed.
        step : int, optional
            Checkpoint index used in the local and artifact paths.
            Defaults to 1, matching the previously hard-coded value.
        """
        # Local checkpoint directory (`step` was hard-coded to 1; now a param).
        checkpoint_dir = f"checkpoints/model_{step}"
        os.makedirs(checkpoint_dir, exist_ok=True)

        # Save model weights locally, then push to MLflow artifacts.
        checkpoint_path = os.path.join(checkpoint_dir, "model.pth")
        torch.save(model.state_dict(), checkpoint_path)
        mlflow.log_artifact(checkpoint_path, f"model_checkpoints/epoch_{step}")

        # NOTE(review): unused — kept from the original in case an input
        # example is re-enabled below; adjust dimensions for your model.
        input_example = torch.zeros(1, 128, dtype=torch.long)
        # input_example_numpy = input_example.cpu().numpy()

        # Register in the MLflow model registry only when a name was given.
        if self.model_name:
            registered_model_name = f"{self.model_name}"
            mlflow.pytorch.log_model(
                pytorch_model=model,
                artifact_path=f"models/epoch_{step}",
                registered_model_name=registered_model_name,
                pip_requirements=["torch>=1.9.0"],
                code_paths=["tynerox/"],  # ship relevant source code with the model
                # input_example=input_example_numpy,
                signature=None,  # TODO: add a model signature if possible
            )

    # NOTE(review): example table; not referenced anywhere in this class.
    # Kept for backward compatibility in case external code reads it.
    table_dict = {
        "entrada": ["Pergunta A", "Pergunta B"],
        "saida": ["Resposta A", "Resposta B"],
        "nota": [0.75, 0.40],
    }

    def log_html(self, html: str, step: Optional[int] = None) -> None:
        """Write an HTML snippet to disk and log it as an MLflow artifact.

        Bug fix: ``step`` was accepted but silently ignored, so every
        call overwrote the same local file. When given, it is now part
        of the file name; when omitted, the original path is kept.
        """
        suffix = "" if step is None else f"_{step}"
        file_path = f"visualizations/sample{suffix}.html"
        os.makedirs(os.path.dirname(file_path), exist_ok=True)
        with open(file_path, "w", encoding="utf-8") as f:
            f.write(html)
        mlflow.log_artifact(file_path)

    def finish(self) -> None:
        """End the MLflow run."""
        mlflow.end_run()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Always close the run, even when the body raised.
        self.finish()