|
|
""" |
|
|
日志记录器模块 |
|
|
Logger module for training and evaluation logging |
|
|
|
|
|
该模块实现了一个完整的日志记录系统,包含: |
|
|
- 控制台、文件和远程日志输出 |
|
|
- 训练指标记录和可视化 |
|
|
- TensorBoard和WandB集成 |
|
|
- 实验跟踪和结果保存 |
|
|
""" |
|
|
|
|
|
import os |
|
|
import sys |
|
|
import json |
|
|
import logging |
|
|
import logging.handlers |
|
|
from pathlib import Path |
|
|
from typing import Dict, List, Any, Optional, Union |
|
|
import numpy as np |
|
|
import matplotlib.pyplot as plt |
|
|
import seaborn as sns |
|
|
from datetime import datetime |
|
|
import pickle |
|
|
from collections import defaultdict |
|
|
import warnings |
|
|
|
|
|
|
|
|
# Optional visualization / experiment-tracking backends.
# Each import is attempted once at module load; a module-level boolean flag
# records availability so the rest of the module can degrade gracefully.

try:
    from torch.utils.tensorboard import SummaryWriter
except ImportError:
    TENSORBOARD_AVAILABLE = False
    warnings.warn("TensorBoard不可用,请安装tensorboard: pip install tensorboard")
else:
    TENSORBOARD_AVAILABLE = True

try:
    import wandb
except ImportError:
    WANDB_AVAILABLE = False
    warnings.warn("WandB不可用,请安装wandb: pip install wandb")
else:
    WANDB_AVAILABLE = True

try:
    import mlflow
except ImportError:
    MLFLOW_AVAILABLE = False
    warnings.warn("MLflow不可用,请安装mlflow: pip install mlflow")
else:
    MLFLOW_AVAILABLE = True
|
|
|
|
|
|
|
|
class TrainingLogger:
    """Training logger with multi-backend experiment tracking.

    Features:
        - Multi-level logging (DEBUG, INFO, WARNING, ERROR)
        - Console and file output
        - TensorBoard integration
        - WandB integration
        - MLflow integration
        - Training-metric visualization
        - Experiment-result persistence
    """

    def __init__(self,
                 config: Dict[str, Any],
                 experiment_name: Optional[str] = None,
                 log_dir: Optional[str] = None):
        """Initialize the training logger.

        Args:
            config: Logging configuration dictionary.
            experiment_name: Experiment name; falls back to
                ``config['training_info']['experiment_name']``.
            log_dir: Base log directory; falls back to
                ``config['logging']['log_dir']`` (default ``'logs'``).
        """
        self.config = config
        self.experiment_name = experiment_name or config.get(
            'training_info', {}).get('experiment_name', 'default_experiment')

        # Base log directory (created if missing).
        self.log_dir = Path(log_dir or config.get('logging', {}).get('log_dir', 'logs'))
        self.log_dir.mkdir(parents=True, exist_ok=True)

        # Unique per-run directory: <experiment_name>_<timestamp>.
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        self.experiment_dir = self.log_dir / f"{self.experiment_name}_{timestamp}"
        self.experiment_dir.mkdir(parents=True, exist_ok=True)

        self.logger = self._setup_logger()

        # Tracking backends; populated by _setup_visualization_tools() when
        # enabled in the config AND the corresponding library is installed.
        self.tensorboard_writer = None
        self.wandb_run = None
        self.mlflow_experiment = None
        self._setup_visualization_tools()

        # metric name -> list of (step, value) tuples, in logging order.
        self.metrics_history = defaultdict(list)
        self.config_history = {}

        self.logger.info("训练日志记录器初始化完成")
        self.logger.info(f"实验目录: {self.experiment_dir}")

    def _setup_logger(self) -> logging.Logger:
        """Create the per-experiment logger with console, file and error handlers."""
        logger = logging.getLogger(f"training_{self.experiment_name}")
        logger.setLevel(logging.DEBUG)

        # Drop handlers left over from a previous run with the same name.
        logger.handlers.clear()

        # Overall level comes from the config; unknown names fall back to INFO
        # instead of raising AttributeError.
        log_level = self.config.get('logging', {}).get('level', 'INFO')
        logger.setLevel(getattr(logging, log_level.upper(), logging.INFO))

        formatter = logging.Formatter(
            '%(asctime)s - %(name)s - %(levelname)s - %(message)s',
            datefmt='%Y-%m-%d %H:%M:%S'
        )

        # Console: INFO and above.
        console_handler = logging.StreamHandler(sys.stdout)
        console_handler.setLevel(logging.INFO)
        console_handler.setFormatter(formatter)
        logger.addHandler(console_handler)

        # Main log file: everything from DEBUG up.
        log_file = self.experiment_dir / 'training.log'
        file_handler = logging.FileHandler(log_file, encoding='utf-8')
        file_handler.setLevel(logging.DEBUG)
        file_handler.setFormatter(formatter)
        logger.addHandler(file_handler)

        # Separate file for ERROR and above, for quick failure triage.
        error_log_file = self.experiment_dir / 'errors.log'
        error_handler = logging.FileHandler(error_log_file, encoding='utf-8')
        error_handler.setLevel(logging.ERROR)
        error_handler.setFormatter(formatter)
        logger.addHandler(error_handler)

        return logger

    def _setup_visualization_tools(self):
        """Initialize TensorBoard / WandB / MLflow according to the config."""
        # TensorBoard.
        tensorboard_config = self.config.get('logging', {}).get('tensorboard', {})
        if tensorboard_config.get('enabled', False) and TENSORBOARD_AVAILABLE:
            tb_log_dir = self.experiment_dir / 'tensorboard'
            self.tensorboard_writer = SummaryWriter(
                log_dir=str(tb_log_dir),
                comment=tensorboard_config.get('comment', '')
            )
            self.logger.info(f"TensorBoard已启用,日志目录: {tb_log_dir}")

        # WandB.
        experiment_tracking = self.config.get('experiment_tracking', {})
        if experiment_tracking.get('enabled', False) and WANDB_AVAILABLE:
            wandb_config = experiment_tracking.get('wandb', {})
            try:
                self.wandb_run = wandb.init(
                    project=wandb_config.get('experiment_name', self.experiment_name),
                    name=wandb_config.get('run_name', f"run_{datetime.now().strftime('%Y%m%d_%H%M%S')}"),
                    config=self.config,
                    tags=wandb_config.get('tags', []),
                    reinit=True
                )
                self.logger.info("WandB已启用")
            except Exception as e:
                # Tracking is best-effort; training must not die on a backend.
                self.logger.warning(f"WandB初始化失败: {e}")

        # MLflow.
        if experiment_tracking.get('enabled', False) and MLFLOW_AVAILABLE:
            mlflow_config = experiment_tracking.get('mlflow', {})
            try:
                mlflow.set_tracking_uri(mlflow_config.get('tracking_uri', 'http://localhost:5000'))
                mlflow.set_experiment(mlflow_config.get('experiment_name', self.experiment_name))
                self.mlflow_experiment = True
                self.logger.info("MLflow已启用")
            except Exception as e:
                self.logger.warning(f"MLflow初始化失败: {e}")

    def log_metrics(self, metrics: Dict[str, Union[float, int]], step: Optional[int] = None, prefix: str = ''):
        """Record training metrics to the history and all enabled backends.

        Args:
            metrics: Mapping of metric name to numeric value.
            step: Training step/epoch; may be None for stepless metrics.
            prefix: Optional namespace prefix (joined with '_' in the local
                history, '/' for TensorBoard/WandB, matching their conventions).
        """
        # Local history.
        for key, value in metrics.items():
            full_key = f"{prefix}_{key}" if prefix else key
            self.metrics_history[full_key].append((step, value))

        # Human-readable log line.
        metrics_str = ", ".join([f"{k}: {v:.6f}" if isinstance(v, float) else f"{k}: {v}" for k, v in metrics.items()])
        step_str = f" (step {step})" if step is not None else ""
        self.logger.info(f"记录指标{step_str}: {metrics_str}")

        # TensorBoard.
        if self.tensorboard_writer is not None:
            for key, value in metrics.items():
                full_key = f"{prefix}/{key}" if prefix else key
                self.tensorboard_writer.add_scalar(full_key, value, step)

        # WandB.
        if self.wandb_run is not None:
            wandb_metrics = {f"{prefix}/{key}" if prefix else key: value for key, value in metrics.items()}
            self.wandb_run.log(wandb_metrics, step=step)

        # MLflow (best-effort).
        if self.mlflow_experiment:
            try:
                mlflow.log_metrics(metrics, step=step)
            except Exception as e:
                self.logger.warning(f"MLflow记录指标失败: {e}")

    def log_config(self, config: Dict[str, Any], name: str = 'config'):
        """Persist a configuration dict to JSON and forward it to the backends.

        Args:
            config: Configuration dictionary (must be JSON-serializable).
            name: Base file name for ``<name>.json`` in the experiment dir.
        """
        config_file = self.experiment_dir / f"{name}.json"
        with open(config_file, 'w', encoding='utf-8') as f:
            json.dump(config, f, indent=2, ensure_ascii=False)

        self.config_history[name] = config

        if self.wandb_run is not None:
            self.wandb_run.config.update(config)

        if self.mlflow_experiment:
            try:
                mlflow.log_params({f"{name}_{k}": v for k, v in config.items()})
            except Exception as e:
                self.logger.warning(f"MLflow记录配置失败: {e}")

        self.logger.info(f"配置已保存: {config_file}")

    def log_model_info(self, model_info: Dict[str, Any]):
        """Persist model metadata as ``model_info.json`` in the experiment dir.

        Args:
            model_info: Model information dictionary (JSON-serializable).
        """
        model_info_file = self.experiment_dir / 'model_info.json'
        with open(model_info_file, 'w', encoding='utf-8') as f:
            json.dump(model_info, f, indent=2, ensure_ascii=False)

        self.logger.info(f"模型信息已保存: {model_info_file}")

    def log_figure(self, figure, name: str, step: Optional[int] = None):
        """Save a matplotlib figure to disk and forward it to the backends.

        Args:
            figure: A matplotlib Figure object.
            name: Base name for the saved PNG and backend entries.
            step: Training step; 'final' is used in the file name when None.
        """
        figure_file = self.experiment_dir / f"{name}_{step if step is not None else 'final'}.png"
        figure.savefig(figure_file, dpi=300, bbox_inches='tight')

        if self.tensorboard_writer is not None:
            self.tensorboard_writer.add_figure(name, figure, step)

        if self.wandb_run is not None:
            self.wandb_run.log({name: wandb.Image(figure_file)}, step=step)

        self.logger.info(f"图表已保存: {figure_file}")

    def _plot_metric_group(self, ax, keys: List[str], title: str,
                           ylabel: str, log_scale: bool = False):
        """Plot one group of metric histories onto a single axes object."""
        for key in keys:
            steps, values = zip(*self.metrics_history[key])
            ax.plot(steps, values, label=key, linewidth=2)
        ax.set_title(title)
        ax.set_xlabel('Epoch')
        ax.set_ylabel(ylabel)
        if log_scale:
            ax.set_yscale('log')
        ax.legend()
        ax.grid(True, alpha=0.3)

    def plot_training_curves(self, save_path: Optional[str] = None):
        """Plot loss / MAE / R² / learning-rate curves as a 2x2 grid.

        Args:
            save_path: Output PNG path (str or Path); defaults to
                ``<experiment_dir>/training_curves.png``.
        """
        if not self.metrics_history:
            self.logger.warning("No training metrics data available for plotting")
            return

        # The seaborn style is optional; fall back to the default style
        # instead of crashing on matplotlib builds that lack it.
        try:
            plt.style.use('seaborn-v0_8')
        except (OSError, ValueError):
            self.logger.warning("seaborn-v0_8 style unavailable; using default style")

        fig, axes = plt.subplots(2, 2, figsize=(15, 10))
        fig.suptitle('Training Curves', fontsize=16)

        # (axes slot, name fragment, title, y-label, log-scale y axis).
        groups = [
            (axes[0, 0], 'loss', 'Loss Curves', 'Loss', False),
            (axes[0, 1], 'mae', 'MAE Curves', 'MAE', False),
            (axes[1, 0], 'r2', 'R² Curves', 'R²', False),
            (axes[1, 1], 'lr', 'Learning Rate Curves', 'Learning Rate', True),
        ]
        for ax, token, title, ylabel, log_scale in groups:
            keys = [k for k in self.metrics_history.keys() if token in k.lower()]
            if keys:
                self._plot_metric_group(ax, keys, title, ylabel, log_scale)

        plt.tight_layout()

        if save_path is None:
            save_path = self.experiment_dir / 'training_curves.png'
        else:
            save_path = Path(save_path)  # accept plain strings too

        fig.savefig(save_path, dpi=300, bbox_inches='tight')

        # Forward the figure to the backends directly instead of going through
        # log_figure(), which would write a duplicate PNG to disk.
        if self.tensorboard_writer is not None:
            self.tensorboard_writer.add_figure('training_curves', fig)
        if self.wandb_run is not None:
            self.wandb_run.log({'training_curves': wandb.Image(str(save_path))})

        self.logger.info(f"Training curves saved: {save_path}")
        plt.show()

    def plot_metric_comparison(self, metric_name: str, save_path: Optional[str] = None):
        """Plot every recorded metric whose name contains ``metric_name``.

        Args:
            metric_name: Case-insensitive substring to match against history keys.
            save_path: Output PNG path (str or Path); defaults to
                ``<experiment_dir>/<metric_name>_comparison.png``.
        """
        relevant_keys = [k for k in self.metrics_history.keys() if metric_name.lower() in k.lower()]

        if not relevant_keys:
            self.logger.warning(f"No metrics found containing '{metric_name}'")
            return

        plt.figure(figsize=(12, 8))

        for key in relevant_keys:
            steps, values = zip(*self.metrics_history[key])
            plt.plot(steps, values, label=key, linewidth=2, marker='o', markersize=3)

        plt.title(f'{metric_name} 指标比较', fontsize=16)
        plt.xlabel('Epoch', fontsize=12)
        plt.ylabel(metric_name, fontsize=12)
        plt.legend()
        plt.grid(True, alpha=0.3)

        if save_path is None:
            save_path = self.experiment_dir / f'{metric_name}_comparison.png'
        else:
            save_path = Path(save_path)  # accept plain strings too

        plt.savefig(save_path, dpi=300, bbox_inches='tight')
        plt.show()

        self.logger.info(f"指标比较图已保存: {save_path}")

    def save_metrics_history(self, save_path: Optional[str] = None):
        """Save the metrics history as pickle plus a JSON sidecar.

        Args:
            save_path: Path (str or Path) for the pickle file; the JSON copy
                gets the same name with a ``.json`` suffix. Defaults to
                ``<experiment_dir>/metrics_history.pkl``.
        """
        if save_path is None:
            save_path = self.experiment_dir / 'metrics_history.pkl'
        else:
            # Coerce to Path so .with_suffix() below works for str arguments.
            save_path = Path(save_path)

        with open(save_path, 'wb') as f:
            pickle.dump(dict(self.metrics_history), f)

        # JSON sidecar for tooling that cannot read pickles.
        json_save_path = save_path.with_suffix('.json')
        json_data = {
            key: [{'step': step, 'value': value} for step, value in values]
            for key, values in self.metrics_history.items()
        }

        with open(json_save_path, 'w', encoding='utf-8') as f:
            json.dump(json_data, f, indent=2, ensure_ascii=False)

        self.logger.info(f"Training metrics history saved: {save_path}")

    def load_metrics_history(self, load_path: str):
        """Load a pickled metrics history saved by :meth:`save_metrics_history`.

        Args:
            load_path: Path of the pickle file to load.
        """
        with open(load_path, 'rb') as f:
            self.metrics_history = defaultdict(list, pickle.load(f))

        self.logger.info(f"Training metrics history loaded: {load_path}")

    def log_experiment_summary(self, summary: Dict[str, Any]):
        """Persist the experiment summary and forward it to WandB.

        Args:
            summary: Summary dictionary (JSON-serializable).
        """
        summary_file = self.experiment_dir / 'experiment_summary.json'
        with open(summary_file, 'w', encoding='utf-8') as f:
            json.dump(summary, f, indent=2, ensure_ascii=False)

        if self.wandb_run is not None:
            self.wandb_run.summary.update(summary)

        self.logger.info(f"实验总结已保存: {summary_file}")

    def close(self):
        """Shut down all backends and release log-file handles."""
        if self.tensorboard_writer is not None:
            self.tensorboard_writer.close()

        if self.wandb_run is not None:
            self.wandb_run.finish()

        if self.mlflow_experiment:
            try:
                mlflow.end_run()
            except Exception as e:
                self.logger.warning(f"MLflow结束运行失败: {e}")

        self.logger.info("日志记录器已关闭")

        # Close file handlers explicitly so the log files are flushed and
        # their descriptors released (the logging module keeps them open).
        for handler in list(self.logger.handlers):
            handler.close()
            self.logger.removeHandler(handler)

    def __enter__(self):
        """Context-manager entry: return the logger itself."""
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Context-manager exit: always close the logger."""
        self.close()
|
|
|
|
|
|
|
|
class ProgressLogger:
    """Lightweight progress logger with elapsed/remaining-time estimates."""

    def __init__(self, total_steps: int, log_frequency: int = 10):
        """Initialize the progress logger.

        Args:
            total_steps: Total number of steps expected.
            log_frequency: Emit a log line every this many steps.
        """
        self.total_steps = total_steps
        self.log_frequency = log_frequency
        self.current_step = 0
        self.start_time = None  # set by start(); update() self-starts if needed
        self.logger = logging.getLogger(__name__)

    def start(self):
        """Record the start time and announce the run."""
        self.start_time = datetime.now()
        self.logger.info(f"开始训练,总步数: {self.total_steps}")

    def update(self, step: int = 1, metrics: Optional[Dict[str, float]] = None):
        """Advance the progress counter and periodically log status.

        Args:
            step: Step increment.
            metrics: Optional metrics for the current step, appended to the line.
        """
        # Guard against update() being called before start(): the original
        # would crash computing elapsed time against a None start_time.
        if self.start_time is None:
            self.start()

        self.current_step += step

        # log_frequency <= 0 would make the modulo below raise; treat it
        # as "never log periodically".
        if self.log_frequency > 0 and self.current_step % self.log_frequency == 0:
            # Guard total_steps == 0 against ZeroDivisionError.
            progress = (self.current_step / self.total_steps
                        if self.total_steps else 0.0)
            elapsed_time = datetime.now() - self.start_time

            # Linear extrapolation of the remaining time.
            if progress > 0:
                total_estimated_time = elapsed_time / progress
                remaining_time = total_estimated_time - elapsed_time
                remaining_str = str(remaining_time).split('.')[0]
            else:
                remaining_str = "未知"

            log_msg = (f"进度: {self.current_step}/{self.total_steps} "
                       f"({progress:.1%}) | 已用时间: {elapsed_time} | "
                       f"剩余时间: {remaining_str}")

            if metrics:
                metrics_str = ", ".join([f"{k}: {v:.6f}" for k, v in metrics.items()])
                log_msg += f" | {metrics_str}"

            self.logger.info(log_msg)

    def finish(self):
        """Log total elapsed time; safe even if start() was never called."""
        start = self.start_time or datetime.now()
        elapsed_time = datetime.now() - start
        self.logger.info(f"训练完成,总耗时: {elapsed_time}")
|
|
|
|
|
|
|
|
def setup_logger(level: str = "INFO",
                 log_file: Optional[str] = None) -> None:
    """Configure the global (root) logger.

    Args:
        level: Log-level name (e.g. "INFO"); unknown names fall back to INFO.
        log_file: Optional log-file path; parent directories are created.
    """
    # Resolve the level name; fall back to INFO on anything unrecognized.
    numeric_level = getattr(logging, level.upper(), None)
    if not isinstance(numeric_level, int):
        numeric_level = logging.INFO

    # Always log to stdout; optionally also to a file.
    handlers = [logging.StreamHandler(sys.stdout)]
    if log_file:
        log_file_path = Path(log_file)
        log_file_path.parent.mkdir(parents=True, exist_ok=True)
        # UTF-8 to match the other file handlers in this module and to keep
        # non-ASCII log messages intact regardless of the locale default.
        handlers.append(logging.FileHandler(log_file, encoding='utf-8'))

    # force=True replaces any handlers a previous call installed.
    logging.basicConfig(
        level=numeric_level,
        format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
        handlers=handlers,
        force=True
    )
|
|
|
|
|
|
|
|
def create_logger(config: Dict[str, Any],
                  experiment_name: Optional[str] = None,
                  log_dir: Optional[str] = None) -> TrainingLogger:
    """Factory function for :class:`TrainingLogger`.

    Args:
        config: Logging configuration dictionary.
        experiment_name: Optional experiment name override.
        log_dir: Optional base log-directory override.

    Returns:
        A configured :class:`TrainingLogger` instance.
    """
    return TrainingLogger(
        config,
        experiment_name=experiment_name,
        log_dir=log_dir,
    )
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Smoke test: exercise the logger end to end with a minimal config.
    demo_config = {
        'training_info': {
            'experiment_name': 'test_experiment'
        },
        'logging': {
            'level': 'INFO',
            'tensorboard': {
                'enabled': True
            }
        },
        'experiment_tracking': {
            'enabled': False
        }
    }

    with create_logger(demo_config) as demo_logger:
        demo_logger.log_config(demo_config, 'test_config')

        # Log a few epochs of synthetic, monotonically improving metrics.
        for step in range(5):
            demo_logger.log_metrics(
                {
                    'loss': 1.0 - step * 0.1,
                    'mae': 0.5 - step * 0.05,
                    'r2': step * 0.15,
                    'lr': 0.001 * (0.9 ** step),
                },
                step=step,
                prefix='train',
            )

        demo_logger.plot_training_curves()

        demo_logger.log_experiment_summary({
            'best_loss': 0.5,
            'best_mae': 0.25,
            'best_r2': 0.6,
            'total_epochs': 5,
        })

    print("日志记录器测试完成!")