Spaces:
Sleeping
Sleeping
| from typing import Optional | |
| import matplotlib.pyplot as plt | |
| import numpy as np | |
| try: | |
| from torch.utils.tensorboard import SummaryWriter | |
| except ModuleNotFoundError: | |
| SummaryWriter = None | |
| try: | |
| import wandb | |
| except ModuleNotFoundError: | |
| wandb = None | |
| plt.ioff() | |
| PLOT_FILE_NAME = "metrics_plot.png" | |
| def safe_index(arr, idx): | |
| return arr[idx] if 0 <= idx < len(arr) else None | |
| class MetricsPlotSink: | |
| """ | |
| The MetricsPlotSink class records training metrics and saves them to a plot. | |
| Args: | |
| output_dir (str): Directory where the plot will be saved. | |
| """ | |
| def __init__(self, output_dir: str): | |
| self.output_dir = output_dir | |
| self.history = [] | |
| def update(self, values: dict): | |
| self.history.append(values) | |
| def save(self): | |
| if not self.history: | |
| print("No data to plot.") | |
| return | |
| def get_array(key): | |
| return np.array([h[key] for h in self.history if key in h]) | |
| epochs = get_array('epoch') | |
| train_loss = get_array('train_loss') | |
| test_loss = get_array('test_loss') | |
| test_coco_eval = [h['test_coco_eval_bbox'] for h in self.history if 'test_coco_eval_bbox' in h] | |
| ap50_90 = np.array([safe_index(x, 0) for x in test_coco_eval if x is not None], dtype=np.float32) | |
| ap50 = np.array([safe_index(x, 1) for x in test_coco_eval if x is not None], dtype=np.float32) | |
| ar50_90 = np.array([safe_index(x, 8) for x in test_coco_eval if x is not None], dtype=np.float32) | |
| ema_coco_eval = [h['ema_test_coco_eval_bbox'] for h in self.history if 'ema_test_coco_eval_bbox' in h] | |
| ema_ap50_90 = np.array([safe_index(x, 0) for x in ema_coco_eval if x is not None], dtype=np.float32) | |
| ema_ap50 = np.array([safe_index(x, 1) for x in ema_coco_eval if x is not None], dtype=np.float32) | |
| ema_ar50_90 = np.array([safe_index(x, 8) for x in ema_coco_eval if x is not None], dtype=np.float32) | |
| fig, axes = plt.subplots(2, 2, figsize=(18, 12)) | |
| # Subplot (0,0): Training and Validation Loss | |
| if len(epochs) > 0: | |
| if len(train_loss): | |
| axes[0][0].plot(epochs, train_loss, label='Training Loss', marker='o', linestyle='-') | |
| if len(test_loss): | |
| axes[0][0].plot(epochs, test_loss, label='Validation Loss', marker='o', linestyle='--') | |
| axes[0][0].set_title('Training and Validation Loss') | |
| axes[0][0].set_xlabel('Epoch Number') | |
| axes[0][0].set_ylabel('Loss Value') | |
| axes[0][0].legend() | |
| axes[0][0].grid(True) | |
| # Subplot (0,1): Average Precision @0.50 | |
| if ap50.size > 0 or ema_ap50.size > 0: | |
| if ap50.size > 0: | |
| axes[0][1].plot(epochs[:len(ap50)], ap50, marker='o', linestyle='-', label='Base Model') | |
| if ema_ap50.size > 0: | |
| axes[0][1].plot(epochs[:len(ema_ap50)], ema_ap50, marker='o', linestyle='--', label='EMA Model') | |
| axes[0][1].set_title('Average Precision @0.50') | |
| axes[0][1].set_xlabel('Epoch Number') | |
| axes[0][1].set_ylabel('AP50') | |
| axes[0][1].legend() | |
| axes[0][1].grid(True) | |
| # Subplot (1,0): Average Precision @0.50:0.95 | |
| if ap50_90.size > 0 or ema_ap50_90.size > 0: | |
| if ap50_90.size > 0: | |
| axes[1][0].plot(epochs[:len(ap50_90)], ap50_90, marker='o', linestyle='-', label='Base Model') | |
| if ema_ap50_90.size > 0: | |
| axes[1][0].plot(epochs[:len(ema_ap50_90)], ema_ap50_90, marker='o', linestyle='--', label='EMA Model') | |
| axes[1][0].set_title('Average Precision @0.50:0.95') | |
| axes[1][0].set_xlabel('Epoch Number') | |
| axes[1][0].set_ylabel('AP') | |
| axes[1][0].legend() | |
| axes[1][0].grid(True) | |
| # Subplot (1,1): Average Recall @0.50:0.95 | |
| if ar50_90.size > 0 or ema_ar50_90.size > 0: | |
| if ar50_90.size > 0: | |
| axes[1][1].plot(epochs[:len(ar50_90)], ar50_90, marker='o', linestyle='-', label='Base Model') | |
| if ema_ar50_90.size > 0: | |
| axes[1][1].plot(epochs[:len(ema_ar50_90)], ema_ar50_90, marker='o', linestyle='--', label='EMA Model') | |
| axes[1][1].set_title('Average Recall @0.50:0.95') | |
| axes[1][1].set_xlabel('Epoch Number') | |
| axes[1][1].set_ylabel('AR') | |
| axes[1][1].legend() | |
| axes[1][1].grid(True) | |
| plt.tight_layout() | |
| plt.savefig(f"{self.output_dir}/{PLOT_FILE_NAME}") | |
| plt.close(fig) | |
| print(f"Results saved to {self.output_dir}/{PLOT_FILE_NAME}") | |
class MetricsTensorBoardSink:
    """
    Logs training metrics to TensorBoard.

    Degrades to a no-op when TensorBoard support is unavailable (the
    module-level ``SummaryWriter`` import fell back to ``None``).

    Args:
        output_dir (str): Directory where TensorBoard logs will be written.
    """

    # COCO eval stats layout assumed throughout this module:
    # index 0 = AP@0.50:0.95, index 1 = AP@0.50, index 8 = AR@0.50:0.95.
    _COCO_SCALARS = (("AP50_90", 0), ("AP50", 1), ("AR50_90", 8))

    def __init__(self, output_dir: str):
        if SummaryWriter:
            self.writer = SummaryWriter(log_dir=output_dir)
            print(f"TensorBoard logging initialized. To monitor logs, use 'tensorboard --logdir {output_dir}' and open http://localhost:6006/ in browser.")
        else:
            # Missing dependency: keep the sink usable but inert.
            self.writer = None
            print("Unable to initialize TensorBoard. Logging is turned off for this session. Run 'pip install tensorboard' to enable logging.")

    def _log_coco(self, prefix: str, coco_eval, epoch):
        """Log the AP/AR scalars extracted from one COCO eval stats list.

        `prefix` selects the tag group ("Base" or "EMA"); out-of-range
        indices are skipped silently via safe_index returning None.
        """
        for name, idx in self._COCO_SCALARS:
            value = safe_index(coco_eval, idx)
            if value is not None:
                self.writer.add_scalar(f"Metrics/{prefix}/{name}", value, epoch)

    def update(self, values: dict):
        """Write one epoch's metrics; no-op when TensorBoard is unavailable.

        Requires ``values['epoch']``; every other key is optional.
        """
        if not self.writer:
            return
        epoch = values['epoch']
        if 'train_loss' in values:
            self.writer.add_scalar("Loss/Train", values['train_loss'], epoch)
        if 'test_loss' in values:
            self.writer.add_scalar("Loss/Test", values['test_loss'], epoch)
        if 'test_coco_eval_bbox' in values:
            self._log_coco("Base", values['test_coco_eval_bbox'], epoch)
        if 'ema_test_coco_eval_bbox' in values:
            self._log_coco("EMA", values['ema_test_coco_eval_bbox'], epoch)
        self.writer.flush()

    def close(self):
        """Flush and release the underlying writer, if any."""
        if not self.writer:
            return
        self.writer.close()
class MetricsWandBSink:
    """
    Logs training metrics to Weights & Biases.

    Degrades to a no-op when the ``wandb`` package is not installed (the
    module-level ``wandb`` name is ``None``).

    Args:
        output_dir (str): Directory where W&B logs will be written locally.
        project (str, optional): Associate this training run with a W&B project. If None, W&B will generate a name based on the git repo name.
        run (str, optional): W&B run name. If None, W&B will generate a random name.
        config (dict, optional): Input parameters, like hyperparameters or data preprocessing settings for the run for later comparison.
    """

    # COCO eval stats layout assumed throughout this module:
    # index 0 = AP@0.50:0.95, index 1 = AP@0.50, index 8 = AR@0.50:0.95.
    _COCO_SCALARS = (("AP50_90", 0), ("AP50", 1), ("AR50_90", 8))

    def __init__(self, output_dir: str, project: Optional[str] = None, run: Optional[str] = None, config: Optional[dict] = None):
        self.output_dir = output_dir
        if wandb:
            self.run = wandb.init(
                project=project,
                name=run,
                config=config,
                dir=output_dir
            )
            # Bug fix: report the URL of the run we just created (self.run)
            # rather than the global wandb.run, which may point elsewhere if
            # multiple runs exist in the process.
            print(f"W&B logging initialized. To monitor logs, open {self.run.url}.")
        else:
            # Missing dependency: keep the sink usable but inert.
            self.run = None
            print("Unable to initialize W&B. Logging is turned off for this session. Run 'pip install wandb' to enable logging.")

    def _collect_coco(self, log_dict: dict, prefix: str, coco_eval):
        """Add the AP/AR scalars from one COCO eval stats list into `log_dict`.

        `prefix` selects the key group ("Base" or "EMA"); out-of-range
        indices are skipped silently via safe_index returning None.
        """
        for name, idx in self._COCO_SCALARS:
            value = safe_index(coco_eval, idx)
            if value is not None:
                log_dict[f"Metrics/{prefix}/{name}"] = value

    def update(self, values: dict):
        """Log one epoch's metrics; no-op when W&B is unavailable.

        Requires ``values['epoch']``; every other key is optional.
        """
        if not wandb or not self.run:
            return
        log_dict = {"epoch": values['epoch']}
        if 'train_loss' in values:
            log_dict["Loss/Train"] = values['train_loss']
        if 'test_loss' in values:
            log_dict["Loss/Test"] = values['test_loss']
        if 'test_coco_eval_bbox' in values:
            self._collect_coco(log_dict, "Base", values['test_coco_eval_bbox'])
        if 'ema_test_coco_eval_bbox' in values:
            self._collect_coco(log_dict, "EMA", values['ema_test_coco_eval_bbox'])
        wandb.log(log_dict)

    def close(self):
        """Finish the W&B run, if one was started."""
        if not wandb or not self.run:
            return
        self.run.finish()