import logging
import sys

import torch
import wandb

from .plot_color import Plot
from .metrics import Metric


class Logger:
    """Track evaluation metrics, checkpoint the best policy and report to wandb.

    The instance keeps the lowest ``rmsd`` value seen so far (initially
    ``inf``) and is invoked like a function once per logging step.
    """

    def __init__(self, args, mds):
        """Store run settings and build the plotting / metric helpers.

        Args:
            args: Object exposing ``molecule``, ``save_dir`` and ``wandb``
                attributes; it is also forwarded to ``Plot`` and ``Metric``.
            mds: Forwarded unchanged to ``Plot`` and ``Metric``.
        """
        self.molecule = args.molecule
        self.save_dir = args.save_dir
        self.wandb = args.wandb

        self.plot = Plot(args, mds)
        self.metric = Metric(args, mds)

        # Best (lowest) rmsd observed so far; any real value beats infinity.
        self.rmsd = float("inf")

    def __call__(self, loss, rollout, policy):
        """Evaluate metrics, checkpoint on improvement and optionally log.

        Args:
            loss: Value reported under the ``"loss"`` key.
            rollout: Not used by the code visible here.
            policy: Object whose ``state_dict()`` is saved when the rmsd
                improves.
        """
        metrics = self.metric()

        # Checkpoint only on a strict improvement over the best rmsd so far.
        if self.rmsd > metrics["rmsd"]:
            self.rmsd = metrics["rmsd"]
            torch.save(policy.state_dict(), f"{self.save_dir}/policy.pt")

        if not self.wandb:
            return

        # The "ets" entries are reported only when a value is available.
        reported_keys = ["rmsd", "rmsd_std", "thp"]
        if metrics["ets"] is not None:
            reported_keys += ["ets", "ets_std"]

        payload = {key: metrics[key] for key in reported_keys}
        payload["loss"] = loss
        wandb.log(payload)