| | from typing import Dict, Optional |
| | import os |
| | from zoneinfo import ZoneInfo |
| |
|
| | import mlflow |
| | import pandas as pd |
| | import torch |
| | import torch.nn as nn |
| | from datetime import datetime, date |
| |
|
| |
|
class TrainerLogger:
    """
    Thin wrapper around an MLflow run for training-time logging.

    On construction it configures the tracking URI and experiment, enables
    PyTorch autologging, starts a run, and logs base tags/parameters.
    Use as a context manager to guarantee the run is ended on exit.
    """

    def __init__(
        self,
        tracking_uri: str,
        experiment: str,
        total_params: int,
        model_name: Optional[str] = None,
        run_name: Optional[str] = None,
        tags: Optional[Dict[str, str]] = None,
    ):
        """
        Start an MLflow run and log baseline metadata.

        Parameters
        ----------
        tracking_uri : str
            MLflow tracking server URI.
        experiment : str
            Experiment name to log under (created by MLflow if absent).
        total_params : int
            Total trainable parameter count of the model, logged as a param.
        model_name : str, optional
            Logical model name; used as the ``model_type`` tag and as the
            registered-model name in ``checkpoint_model``.
        run_name : str, optional
            Display name for the MLflow run.
        tags : dict of str to str, optional
            Extra tags merged over the default ``{"model_type": model_name}``.
        """
        mlflow.set_tracking_uri(tracking_uri)
        mlflow.set_experiment(experiment)

        # Let MLflow capture optimizer/params/models from PyTorch automatically.
        mlflow.pytorch.autolog(log_models=True)

        self.run = mlflow.start_run(run_name=run_name)
        self.run_id = self.run.info.run_id
        self.experiment = experiment
        self.model_name = model_name
        self.total_params = total_params

        # User-supplied tags override the default model_type tag on key clash.
        default_tags = {"model_type": self.model_name}
        if tags:
            default_tags.update(tags)
        mlflow.set_tags(default_tags)

        base_params = {"model_name": self.model_name, "total_params": self.total_params}
        self.log_parameters(base_params)

    def log_parameters(self, parameters: dict) -> None:
        """Log a dict of hyperparameters to the active run."""
        mlflow.log_params(parameters)

    def log_metrics(self, metrics: dict, step: Optional[int] = None) -> None:
        """Log a dict of metric values, optionally attached to a training step."""
        # Pass step by keyword: positional use relied on mlflow's argument order.
        mlflow.log_metrics(metrics, step=step)

    def log_checkpoint_table(self, current_lr: float, loss: float, perplexity: float, last_batch: int) -> None:
        """
        Append a checkpoint record to an MLflow table artifact.

        Each record carries the local timestamp (month/day/hour:minute), the
        last processed batch, and rounded learning-rate/perplexity/loss values.

        Parameters
        ----------
        current_lr : float
            Learning rate at checkpoint time (rounded to 7 decimal places).
        loss : float
            Loss value (rounded to 4 decimal places).
        perplexity : float
            Perplexity metric (rounded to 4 decimal places).
        last_batch : int
            Index of the last batch processed before this checkpoint.
        """
        artifact_dir = "checkpoint_table/model"  # was an f-string with no placeholders
        os.makedirs(artifact_dir, exist_ok=True)

        # NOTE(review): timezone is hard-coded to America/Sao_Paulo — confirm intended.
        now = datetime.now(ZoneInfo("America/Sao_Paulo"))
        record = {
            "month": now.month,
            "day": now.day,
            "hour": f"{now.hour:02d}:{now.minute:02d}",
            "last_batch": last_batch,
            "current_lr": round(current_lr, 7),
            "perplexity": round(perplexity, 4),
            "loss": round(loss, 4),
        }
        df_record = pd.DataFrame([record])

        artifact_file = f"{artifact_dir}/checkpoint_table.json"
        mlflow.log_table(data=df_record, artifact_file=artifact_file)

    def checkpoint_model(self, model: nn.Module, step: int = 1) -> None:
        """
        Save the model's ``state_dict`` locally and log it to MLflow.

        When ``model_name`` was provided, the full model is also logged and
        registered under that name.

        Parameters
        ----------
        model : nn.Module
            Model to checkpoint.
        step : int, default 1
            Checkpoint index used in artifact paths (previously hard-coded).
        """
        checkpoint_dir = f"checkpoints/model_{step}"
        os.makedirs(checkpoint_dir, exist_ok=True)

        checkpoint_path = os.path.join(checkpoint_dir, "model.pth")
        torch.save(model.state_dict(), checkpoint_path)

        mlflow.log_artifact(checkpoint_path, f"model_checkpoints/epoch_{step}")

        # Removed dead code: an unused `input_example` tensor and an unused
        # `table_dict` literal that were never referenced.
        if self.model_name:
            mlflow.pytorch.log_model(
                pytorch_model=model,
                artifact_path=f"models/epoch_{step}",
                registered_model_name=self.model_name,
                pip_requirements=["torch>=1.9.0"],
                code_paths=["tynerox/"],
                signature=None,
            )

    def log_html(self, html: str, step: Optional[int] = None) -> None:
        """
        Write an HTML string to a local file and log it as an MLflow artifact.

        Parameters
        ----------
        html : str
            HTML content to persist.
        step : int, optional
            Accepted for interface compatibility; currently unused.
        """
        file_path = "visualizations/sample.html"  # was an f-string with no placeholders
        os.makedirs(os.path.dirname(file_path), exist_ok=True)

        with open(file_path, "w") as f:
            f.write(html)

        mlflow.log_artifact(file_path)

    def finish(self) -> None:
        """End the MLflow run."""
        mlflow.end_run()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Always end the run, even when the with-block raised.
        self.finish()
|