| |
| """ |
| 训练过程中的标签和预测监控回调 |
| 直接集成到LLaMA-Factory的训练流程中 |
| """ |
|
|
| import os |
| import json |
| import logging |
| import numpy as np |
| import torch |
| from datetime import datetime |
| from typing import Dict, List, Any, Optional |
| from transformers import TrainerCallback, TrainerState, TrainerControl |
| from transformers.trainer_utils import PredictionOutput |
|
|
class LabelPredictionMonitor(TrainerCallback):
    """Trainer callback that monitors labels and predictions during training.

    Writes per-step training statistics and per-evaluation prediction/label
    alignment analyses to dedicated log files under
    ``<output_dir>/monitoring_logs``, and dumps a JSON summary of the whole
    run when training ends.
    """

    def __init__(self,
                 output_dir: str,
                 log_interval: int = 10,
                 save_detailed_logs: bool = True):
        """
        Initialize the monitoring callback.

        Args:
            output_dir: Directory under which monitoring logs are written.
            log_interval: Record training stats every N steps.
            save_detailed_logs: Whether detailed per-sample logs are kept.
        """
        self.output_dir = output_dir
        self.log_interval = log_interval
        self.save_detailed_logs = save_detailed_logs

        # All monitoring artifacts live in a dedicated subdirectory.
        self.log_dir = os.path.join(output_dir, "monitoring_logs")
        os.makedirs(self.log_dir, exist_ok=True)

        # Timestamped file names so repeated runs never overwrite each other.
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        self.log_files = {
            "labels": os.path.join(self.log_dir, f"label_analysis_{timestamp}.log"),
            "predictions": os.path.join(self.log_dir, f"prediction_monitor_{timestamp}.log"),
            "alignment": os.path.join(self.log_dir, f"alignment_analysis_{timestamp}.log"),
            "summary": os.path.join(self.log_dir, f"training_summary_{timestamp}.json")
        }

        self.loggers = self._setup_loggers()

        # In-memory histories, serialized into the summary JSON at train end.
        self.training_history: List[Dict[str, Any]] = []
        self.prediction_history: List[Dict[str, Any]] = []
        self.step_count = 0

        self.loggers["labels"].info("🔧 标签预测监控回调初始化完成")
        self.loggers["labels"].info(f"📁 输出目录: {output_dir}")
        self.loggers["labels"].info(f"📝 日志文件: {self.log_files['labels']}")
        self.loggers["labels"].info(f"📊 记录间隔: {log_interval}步")

    def _setup_loggers(self) -> Dict[str, logging.Logger]:
        """Create one file+console logger per log type (except the JSON summary)."""
        loggers = {}

        for log_type, log_file in self.log_files.items():
            if log_type == "summary":
                # The summary is a JSON dump, not a line-oriented log.
                continue

            logger = logging.getLogger(f"monitor_{log_type}")
            logger.setLevel(logging.INFO)
            # FIX: stop propagation to the root logger, which otherwise
            # duplicates every record on the console.
            logger.propagate = False

            # Drop handlers left over from a previous instantiation.
            logger.handlers.clear()

            file_handler = logging.FileHandler(log_file, encoding='utf-8')
            file_handler.setLevel(logging.INFO)

            console_handler = logging.StreamHandler()
            console_handler.setLevel(logging.INFO)

            formatter = logging.Formatter(
                '%(asctime)s | %(levelname)s | %(message)s',
                datefmt='%Y-%m-%d %H:%M:%S'
            )
            file_handler.setFormatter(formatter)
            console_handler.setFormatter(formatter)

            logger.addHandler(file_handler)
            logger.addHandler(console_handler)

            loggers[log_type] = logger

        return loggers

    def on_step_end(self, args, state: TrainerState, control: TrainerControl, **kwargs):
        """Called at the end of each training step; logs every `log_interval` steps."""
        self.step_count += 1

        if self.step_count % self.log_interval != 0:
            return

        # FIX: `TrainerState` does not define `training_time`; accessing it
        # directly raised AttributeError every `log_interval` steps.
        training_time = getattr(state, "training_time", None)

        labels_logger = self.loggers["labels"]
        labels_logger.info(f"\n{'='*80}")
        labels_logger.info(f"🔄 训练步骤 {state.global_step} 监控")
        labels_logger.info(f"{'='*80}")

        latest_log = state.log_history[-1] if state.log_history else {}
        if latest_log:
            labels_logger.info(f"📈 当前Loss: {latest_log.get('loss', 'N/A')}")
            labels_logger.info(f"📊 学习率: {latest_log.get('learning_rate', 'N/A')}")
        if isinstance(training_time, (int, float)):
            labels_logger.info(f"⏱️ 训练时间: {training_time:.2f}秒")

        self.training_history.append({
            "step": state.global_step,
            "timestamp": datetime.now().isoformat(),
            "loss": latest_log.get('loss'),
            "learning_rate": latest_log.get('learning_rate'),
            "training_time": training_time
        })

    def on_evaluate(self, args, state: TrainerState, control: TrainerControl, **kwargs):
        """Called during evaluation; analyzes predictions when the trainer passes them."""
        self.loggers["predictions"].info(f"\n{'='*80}")
        self.loggers["predictions"].info(f"📊 评估阶段监控 - 步骤 {state.global_step}")
        self.loggers["predictions"].info(f"{'='*80}")

        # FIX: the original used `hasattr(kwargs, 'predict_results')`, which
        # checks *attributes* of the dict object and is always False, so the
        # analysis below was unreachable from evaluation.
        predict_results = kwargs.get('predict_results')
        if predict_results is not None:
            self._analyze_predictions(predict_results, state.global_step)

    def on_predict(self, args, state: TrainerState, control: TrainerControl, **kwargs):
        """Called after a predict pass; analyzes predictions when available."""
        self.loggers["predictions"].info(f"\n{'='*80}")
        self.loggers["predictions"].info(f"🔮 预测阶段监控 - 步骤 {state.global_step}")
        self.loggers["predictions"].info(f"{'='*80}")

        predict_results = kwargs.get('predict_results')
        if predict_results is not None:
            self._analyze_predictions(predict_results, state.global_step)

    def _analyze_predictions(self, predict_results: PredictionOutput, step: int):
        """Log per-sample label/prediction stats for up to 3 samples of a batch."""
        self.loggers["predictions"].info(f"📊 预测结果分析 - 步骤 {step}")

        predictions = predict_results.predictions
        labels = predict_results.label_ids

        if predictions is None or labels is None:
            self.loggers["predictions"].warning("⚠️ 预测结果或标签为空")
            return

        # Normalize to numpy so the analysis below is framework-agnostic.
        if isinstance(predictions, torch.Tensor):
            predictions = predictions.cpu().numpy()
        if isinstance(labels, torch.Tensor):
            labels = labels.cpu().numpy()

        batch_size = len(predictions)
        self.loggers["predictions"].info(f"📦 批次大小: {batch_size}")

        # Only inspect the first few samples to keep logs readable.
        for i in range(min(batch_size, 3)):
            self.loggers["predictions"].info(f"\n🔍 样本 {i+1} 分析:")

            pred_sample = self._remove_padding(predictions[i])
            label_sample = self._remove_padding(labels[i])

            self.loggers["labels"].info(f"\n📝 样本 {i+1} 标签分析:")
            self.loggers["labels"].info(f" 🎯 标签长度: {len(label_sample)}")
            self.loggers["labels"].info(f" 🔮 预测长度: {len(pred_sample)}")

            alignment_analysis = self._analyze_alignment(pred_sample, label_sample)
            self.loggers["alignment"].info(f"\n🎯 样本 {i+1} 对齐分析:")
            self.loggers["alignment"].info(f" 📏 长度差异: {alignment_analysis['length_difference']}")
            self.loggers["alignment"].info(f" 🎯 精确匹配: {alignment_analysis['exact_match_percentage']:.1f}%")
            self.loggers["alignment"].info(f" ✅ 有效匹配: {alignment_analysis['valid_match_percentage']:.1f}%")

            self.prediction_history.append({
                "step": step,
                "sample_idx": i,
                "timestamp": datetime.now().isoformat(),
                "predictions": pred_sample.tolist(),
                "labels": label_sample.tolist(),
                "alignment_analysis": alignment_analysis
            })

    def _remove_padding(self, tokens: np.ndarray, pad_token_id: int = -100) -> np.ndarray:
        """Trim leading/trailing padding tokens (interior padding is kept)."""
        non_pad_mask = tokens != pad_token_id
        if np.any(non_pad_mask):
            first_non_pad = np.argmax(non_pad_mask)
            # Index of the last non-pad token, found by scanning the reversed mask.
            last_non_pad = len(tokens) - 1 - np.argmax(non_pad_mask[::-1])
            return tokens[first_non_pad:last_non_pad + 1]
        # All padding: return unchanged rather than an empty slice.
        return tokens

    def _analyze_alignment(self, predictions: np.ndarray, labels: np.ndarray) -> Dict[str, Any]:
        """Compare predictions and labels position-by-position.

        Returns a dict with exact-match stats over the overlapping prefix and
        match stats restricted to "valid" positions (neither side is -100).
        """
        min_len = min(len(predictions), len(labels))
        max_len = max(len(predictions), len(labels))

        exact_matches = 0      # positions where prediction == label
        valid_positions = 0    # positions where neither side is the -100 ignore id
        valid_matches = 0      # matches among the valid positions

        for i in range(min_len):
            is_valid = predictions[i] != -100 and labels[i] != -100
            if is_valid:
                valid_positions += 1
            if predictions[i] == labels[i]:
                exact_matches += 1
                if is_valid:
                    valid_matches += 1

        exact_match_percentage = (exact_matches / min_len * 100) if min_len > 0 else 0
        # FIX: the original computed exact_matches / valid_matches, which is
        # always >= 100% (valid_matches is a subset of exact_matches). Report
        # the match rate over valid (non -100) positions instead.
        valid_match_percentage = (valid_matches / valid_positions * 100) if valid_positions > 0 else 0

        return {
            "min_length": min_len,
            "max_length": max_len,
            "exact_matches": exact_matches,
            "valid_matches": valid_matches,
            "exact_match_percentage": exact_match_percentage,
            "valid_match_percentage": valid_match_percentage,
            "length_difference": abs(len(predictions) - len(labels))
        }

    def on_train_end(self, args, state: TrainerState, control: TrainerControl, **kwargs):
        """Called once when training finishes; persists the JSON summary."""
        self.loggers["labels"].info(f"\n{'='*80}")
        self.loggers["labels"].info(f"🏁 训练结束监控")
        self.loggers["labels"].info(f"{'='*80}")

        summary_data = {
            "training_info": {
                "total_steps": len(self.training_history),
                "total_predictions": len(self.prediction_history),
                "completion_time": datetime.now().isoformat(),
                "output_dir": self.output_dir
            },
            "training_history": self.training_history,
            "prediction_history": self.prediction_history,
            "log_files": self.log_files
        }

        with open(self.log_files["summary"], "w", encoding="utf-8") as f:
            json.dump(summary_data, f, ensure_ascii=False, indent=2)

        self.loggers["labels"].info(f"📊 训练摘要已保存: {self.log_files['summary']}")
        self.loggers["labels"].info(f"📝 标签分析日志: {self.log_files['labels']}")
        self.loggers["labels"].info(f"🔮 预测监控日志: {self.log_files['predictions']}")
        self.loggers["labels"].info(f"🎯 对齐分析日志: {self.log_files['alignment']}")
|
|
def create_label_prediction_monitor(output_dir: str, **kwargs) -> LabelPredictionMonitor:
    """Build a :class:`LabelPredictionMonitor`.

    Args:
        output_dir: Directory where monitoring logs are written.
        **kwargs: Extra keyword arguments forwarded to the monitor's
            constructor (e.g. ``log_interval``, ``save_detailed_logs``).

    Returns:
        A freshly constructed ``LabelPredictionMonitor``.
    """
    monitor_kwargs = dict(kwargs)
    monitor_kwargs["output_dir"] = output_dir
    return LabelPredictionMonitor(**monitor_kwargs)
|
|