import math
from datetime import datetime
from numbers import Real
from pathlib import Path
from typing import Dict, List, Optional

import matplotlib.pyplot as plt


class Plotter:
    """Draw training and validation plots with matplotlib.

    Every plot is displayed with ``plt.show()``. When a save directory was
    configured, the figure is additionally written there as a PNG whose name
    carries a ``%Y%m%d_%H%M%S`` timestamp.
    """

    def __init__(self, save_dir: Optional[Path] = None):
        """Store the output directory and create it when one is given.

        Args:
            save_dir: Directory for saved figures. A plain string path is
                accepted and converted to ``Path``. A falsy value (``None``,
                ``""``) disables saving.
        """
        if save_dir:
            # Callers frequently pass a str; coerce so ``mkdir`` and ``/`` work.
            save_dir = Path(save_dir)
            save_dir.mkdir(parents=True, exist_ok=True)
        self.save_dir = save_dir

    def _save_and_show(self, fig, stem: str) -> None:
        """Save *fig* (if saving is enabled), show it, then release it.

        Args:
            fig: The matplotlib figure to finish.
            stem: File-name prefix placed before the timestamp.
        """
        try:
            if self.save_dir:
                timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
                fig.savefig(Path(self.save_dir) / f'{stem}_{timestamp}.png')
            plt.show()
        finally:
            # pyplot keeps every figure alive until it is closed explicitly,
            # so repeated calls would otherwise accumulate figures. In
            # interactive mode ``show()`` does not block, and closing here
            # would tear the window down immediately, so leave it open.
            if not plt.isinteractive():
                plt.close(fig)

    @staticmethod
    def _flatten_metrics(metrics: Dict[str, object]) -> Dict[str, float]:
        """Collect the numeric entries of *metrics*, expanding one nesting level.

        ``'num_queries_tested'`` is skipped. Values inside a nested dict are
        stored under ``"<key>_<subkey>"``. Non-numeric values are ignored.
        """
        flat: Dict[str, float] = {}
        for key, value in metrics.items():
            if key == 'num_queries_tested':
                continue
            if isinstance(value, dict):
                for subkey, subvalue in value.items():
                    if isinstance(subvalue, Real):
                        flat[f"{key}_{subkey}"] = float(subvalue)
            # ``Real`` also matches NumPy scalar types, which a plain
            # ``(int, float)`` check would silently drop.
            elif isinstance(value, Real):
                flat[key] = float(value)
        return flat

    def plot_training_history(self, history: Dict[str, List[float]], title: str = "Training History"):
        """Plot loss curves and, when available, the learning-rate schedule.

        Args:
            history: Per-metric series. ``'train_loss'`` is required;
                ``'val_loss'`` and ``'learning_rate'`` are plotted only when
                present.
            title: Figure title.

        Raises:
            KeyError: If ``history`` has no ``'train_loss'`` entry.
        """
        # Resolve the required series before a figure exists, so a bad
        # history cannot leave an orphaned figure behind.
        train_loss = history['train_loss']

        fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(10, 12))

        # Top panel: loss curves.
        ax1.plot(train_loss, label='Train Loss')
        if 'val_loss' in history:
            ax1.plot(history['val_loss'], label='Validation Loss')
        ax1.set_xlabel('Epoch')
        ax1.set_ylabel('Loss')
        ax1.set_title('Training and Validation Loss')
        ax1.legend()
        ax1.grid(True)

        # Bottom panel: learning-rate schedule, if it was recorded.
        if 'learning_rate' in history:
            ax2.plot(history['learning_rate'], label='Learning Rate')
            ax2.set_xlabel('Step')
            ax2.set_ylabel('Learning Rate')
            ax2.set_title('Learning Rate Schedule')
            ax2.legend()
            ax2.grid(True)
        else:
            # Hide the unused panel instead of drawing an empty frame.
            ax2.set_visible(False)

        fig.suptitle(title)
        fig.tight_layout()

        self._save_and_show(fig, 'training_history')

    def plot_validation_metrics(self, metrics: Dict[str, object]):
        """Plot validation metrics as a labelled bar chart.

        Args:
            metrics: Mapping of metric name to a number, or to a dict of
                numbers (one nesting level). ``'num_queries_tested'`` and
                non-numeric values are skipped. Nothing is drawn when no
                numeric metric remains.
        """
        flat_metrics = self._flatten_metrics(metrics)
        if not flat_metrics:
            return

        metric_names = list(flat_metrics)
        values = list(flat_metrics.values())
        positions = range(len(metric_names))

        fig, ax = plt.subplots(figsize=(12, 6))
        bars = ax.bar(positions, values)

        ax.set_title('Validation Metrics')
        ax.set_xticks(positions)
        ax.set_xticklabels(metric_names, rotation=45, ha='right')
        ax.set_ylabel('Value')

        # Print each value just above its bar.
        for bar in bars:
            height = bar.get_height()
            ax.text(bar.get_x() + bar.get_width() / 2., height,
                    f'{height:.3f}',
                    ha='center', va='bottom')

        # Keep the 0..1.1 window for ordinary 0-1 metrics, but widen it when
        # a value falls outside so bars and labels are not clipped. NaN/inf
        # are excluded because they are not valid axis limits.
        scaled = [v * 1.1 for v in values if math.isfinite(v)]
        ax.set_ylim(min([0.0] + scaled), max([1.1] + scaled))
        fig.tight_layout()

        self._save_and_show(fig, 'validation_metrics')