# Copyright 2025 the LlamaFactory team. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import json import os import signal import sys import time from concurrent.futures import ThreadPoolExecutor from datetime import timedelta from typing import TYPE_CHECKING, Any, Optional import torch import transformers from peft import PeftModel from transformers import PreTrainedModel, ProcessorMixin, TrainerCallback from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, has_length from transformers.utils import ( SAFE_WEIGHTS_NAME, WEIGHTS_NAME, is_safetensors_available, ) from typing_extensions import override from ..extras import logging from ..extras.constants import TRAINER_LOG, V_HEAD_SAFE_WEIGHTS_NAME, V_HEAD_WEIGHTS_NAME from ..extras.misc import get_peak_memory, is_env_enabled, use_ray if is_safetensors_available(): from safetensors import safe_open from safetensors.torch import save_file if TYPE_CHECKING: from transformers import TrainerControl, TrainerState, TrainingArguments from trl import AutoModelForCausalLMWithValueHead from ..hparams import DataArguments, FinetuningArguments, GeneratingArguments, ModelArguments logger = logging.get_logger(__name__) def fix_valuehead_checkpoint( model: "AutoModelForCausalLMWithValueHead", output_dir: str, safe_serialization: bool ) -> None: r"""Fix the valuehead checkpoint files. The model is already unwrapped. There are three cases: 1. full tuning without ds_zero3: state_dict = {"model.layers.*": ..., "v_head.summary.*": ...} 2. lora tuning without ds_zero3: state_dict = {"v_head.summary.*": ...} 3. under deepspeed zero3: state_dict = {"pretrained_model.model.layers.*": ..., "v_head.summary.*": ...} We assume `stage3_gather_16bit_weights_on_model_save=true`. """ if not isinstance(model.pretrained_model, (PreTrainedModel, PeftModel)): return if safe_serialization: path_to_checkpoint = os.path.join(output_dir, SAFE_WEIGHTS_NAME) with safe_open(path_to_checkpoint, framework="pt", device="cpu") as f: state_dict: dict[str, torch.Tensor] = {key: f.get_tensor(key) for key in f.keys()} else: path_to_checkpoint = os.path.join(output_dir, WEIGHTS_NAME) state_dict: dict[str, torch.Tensor] = torch.load(path_to_checkpoint, map_location="cpu", weights_only=True) os.remove(path_to_checkpoint) decoder_state_dict, v_head_state_dict = {}, {} for name, param in state_dict.items(): if name.startswith("v_head."): v_head_state_dict[name] = param else: decoder_state_dict[name.replace("pretrained_model.", "", 1)] = param model.pretrained_model.save_pretrained( output_dir, state_dict=decoder_state_dict or None, safe_serialization=safe_serialization ) if safe_serialization: save_file(v_head_state_dict, os.path.join(output_dir, V_HEAD_SAFE_WEIGHTS_NAME), metadata={"format": "pt"}) else: torch.save(v_head_state_dict, os.path.join(output_dir, V_HEAD_WEIGHTS_NAME)) logger.info_rank0(f"Value head model saved at: {output_dir}") class FixValueHeadModelCallback(TrainerCallback): r"""A callback for fixing the checkpoint for valuehead models.""" @override def on_save(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs): if args.should_save: output_dir = os.path.join(args.output_dir, f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}") fix_valuehead_checkpoint( model=kwargs.pop("model"), output_dir=output_dir, safe_serialization=args.save_safetensors ) class SaveProcessorCallback(TrainerCallback): r"""A callback for saving the processor.""" def __init__(self, processor: "ProcessorMixin") -> None: self.processor = processor @override def on_save(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs): if args.should_save: output_dir = os.path.join(args.output_dir, f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}") self.processor.save_pretrained(output_dir) @override def on_train_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs): if args.should_save: self.processor.save_pretrained(args.output_dir) class PissaConvertCallback(TrainerCallback): r"""A callback for converting the PiSSA adapter to a normal one.""" @override def on_train_begin(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs): if args.should_save: model = kwargs.pop("model") pissa_init_dir = os.path.join(args.output_dir, "pissa_init") logger.info_rank0(f"Initial PiSSA adapter will be saved at: {pissa_init_dir}.") if isinstance(model, PeftModel): init_lora_weights = getattr(model.peft_config["default"], "init_lora_weights") setattr(model.peft_config["default"], "init_lora_weights", True) model.save_pretrained(pissa_init_dir, safe_serialization=args.save_safetensors) setattr(model.peft_config["default"], "init_lora_weights", init_lora_weights) @override def on_train_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs): if args.should_save: model = kwargs.pop("model") pissa_init_dir = os.path.join(args.output_dir, "pissa_init") pissa_backup_dir = os.path.join(args.output_dir, "pissa_backup") pissa_convert_dir = os.path.join(args.output_dir, "pissa_converted") logger.info_rank0(f"Converted PiSSA adapter will be saved at: {pissa_convert_dir}.") # 1. save a pissa backup with init_lora_weights: True # 2. save a converted lora with init_lora_weights: pissa # 3. load the pissa backup with init_lora_weights: True # 4. delete the initial adapter and change init_lora_weights to pissa if isinstance(model, PeftModel): init_lora_weights = getattr(model.peft_config["default"], "init_lora_weights") setattr(model.peft_config["default"], "init_lora_weights", True) model.save_pretrained(pissa_backup_dir, safe_serialization=args.save_safetensors) setattr(model.peft_config["default"], "init_lora_weights", init_lora_weights) model.save_pretrained( pissa_convert_dir, safe_serialization=args.save_safetensors, path_initial_model_for_weight_conversion=pissa_init_dir, ) model.load_adapter(pissa_backup_dir, "default", is_trainable=True) model.set_adapter("default") setattr(model.peft_config["default"], "init_lora_weights", init_lora_weights) class LogCallback(TrainerCallback): r"""A callback for logging training and evaluation status.""" def __init__(self) -> None: # Progress self.start_time = 0 self.cur_steps = 0 self.max_steps = 0 self.elapsed_time = "" self.remaining_time = "" self.thread_pool: Optional[ThreadPoolExecutor] = None # Status self.aborted = False self.do_train = False # Web UI self.webui_mode = is_env_enabled("LLAMABOARD_ENABLED") if self.webui_mode and not use_ray(): signal.signal(signal.SIGABRT, self._set_abort) self.logger_handler = logging.LoggerHandler(os.getenv("LLAMABOARD_WORKDIR")) logging.add_handler(self.logger_handler) transformers.logging.add_handler(self.logger_handler) def _set_abort(self, signum, frame) -> None: self.aborted = True def _reset(self, max_steps: int = 0) -> None: self.start_time = time.time() self.cur_steps = 0 self.max_steps = max_steps self.elapsed_time = "" self.remaining_time = "" def _timing(self, cur_steps: int) -> None: cur_time = time.time() elapsed_time = cur_time - self.start_time avg_time_per_step = elapsed_time / cur_steps if cur_steps != 0 else 0 remaining_time = (self.max_steps - cur_steps) * avg_time_per_step self.cur_steps = cur_steps self.elapsed_time = str(timedelta(seconds=int(elapsed_time))) self.remaining_time = str(timedelta(seconds=int(remaining_time))) def _write_log(self, output_dir: str, logs: dict[str, Any]) -> None: with open(os.path.join(output_dir, TRAINER_LOG), "a", encoding="utf-8") as f: f.write(json.dumps(logs) + "\n") def _create_thread_pool(self, output_dir: str) -> None: os.makedirs(output_dir, exist_ok=True) self.thread_pool = ThreadPoolExecutor(max_workers=1) def _close_thread_pool(self) -> None: if self.thread_pool is not None: self.thread_pool.shutdown(wait=True) self.thread_pool = None @override def on_init_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs): if ( args.should_save and os.path.exists(os.path.join(args.output_dir, TRAINER_LOG)) and args.overwrite_output_dir ): logger.warning_rank0_once("Previous trainer log in this folder will be deleted.") os.remove(os.path.join(args.output_dir, TRAINER_LOG)) @override def on_train_begin(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs): if args.should_save: self.do_train = True self._reset(max_steps=state.max_steps) self._create_thread_pool(output_dir=args.output_dir) @override def on_train_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs): self._close_thread_pool() @override def on_substep_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs): if self.aborted: control.should_epoch_stop = True control.should_training_stop = True @override def on_step_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs): if self.aborted: control.should_epoch_stop = True control.should_training_stop = True @override def on_evaluate(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs): if not self.do_train: self._close_thread_pool() @override def on_predict(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs): if not self.do_train: self._close_thread_pool() @override def on_log(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs): if not args.should_save: return self._timing(cur_steps=state.global_step) logs = dict( current_steps=self.cur_steps, total_steps=self.max_steps, loss=state.log_history[-1].get("loss"), eval_loss=state.log_history[-1].get("eval_loss"), predict_loss=state.log_history[-1].get("predict_loss"), reward=state.log_history[-1].get("reward"), accuracy=state.log_history[-1].get("rewards/accuracies"), lr=state.log_history[-1].get("learning_rate"), epoch=state.log_history[-1].get("epoch"), percentage=round(self.cur_steps / self.max_steps * 100, 2) if self.max_steps != 0 else 100, elapsed_time=self.elapsed_time, remaining_time=self.remaining_time, ) if state.num_input_tokens_seen: logs["throughput"] = round(state.num_input_tokens_seen / (time.time() - self.start_time), 2) logs["total_tokens"] = state.num_input_tokens_seen if is_env_enabled("RECORD_VRAM"): vram_allocated, vram_reserved = get_peak_memory() logs["vram_allocated"] = round(vram_allocated / (1024**3), 2) logs["vram_reserved"] = round(vram_reserved / (1024**3), 2) logs = {k: v for k, v in logs.items() if v is not None} if self.webui_mode and all(key in logs for key in ("loss", "lr", "epoch")): log_str = f"'loss': {logs['loss']:.4f}, 'learning_rate': {logs['lr']:2.4e}, 'epoch': {logs['epoch']:.2f}" for extra_key in ("reward", "accuracy", "throughput"): if logs.get(extra_key): log_str += f", '{extra_key}': {logs[extra_key]:.2f}" logger.info_rank0("{" + log_str + "}") if self.thread_pool is not None: self.thread_pool.submit(self._write_log, args.output_dir, logs) @override def on_prediction_step( self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs ): if self.do_train: return if self.aborted: sys.exit(0) if not args.should_save: return eval_dataloader = kwargs.pop("eval_dataloader", None) if has_length(eval_dataloader): if self.max_steps == 0: self._reset(max_steps=len(eval_dataloader)) self._create_thread_pool(output_dir=args.output_dir) self._timing(cur_steps=self.cur_steps + 1) if self.cur_steps % 5 == 0 and self.thread_pool is not None: logs = dict( current_steps=self.cur_steps, total_steps=self.max_steps, percentage=round(self.cur_steps / self.max_steps * 100, 2) if self.max_steps != 0 else 100, elapsed_time=self.elapsed_time, remaining_time=self.remaining_time, ) self.thread_pool.submit(self._write_log, args.output_dir, logs) class ReporterCallback(TrainerCallback): r"""A callback for reporting training status to external logger.""" def __init__( self, model_args: "ModelArguments", data_args: "DataArguments", finetuning_args: "FinetuningArguments", generating_args: "GeneratingArguments", ) -> None: self.model_args = model_args self.data_args = data_args self.finetuning_args = finetuning_args self.generating_args = generating_args os.environ["WANDB_PROJECT"] = os.getenv("WANDB_PROJECT", "llamafactory") @override def on_train_begin(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs): if not state.is_world_process_zero: return if "wandb" in args.report_to: import wandb wandb.config.update( { "model_args": self.model_args.to_dict(), "data_args": self.data_args.to_dict(), "finetuning_args": self.finetuning_args.to_dict(), "generating_args": self.generating_args.to_dict(), } ) if self.finetuning_args.use_swanlab: import swanlab # type: ignore swanlab.config.update( { "model_args": self.model_args.to_dict(), "data_args": self.data_args.to_dict(), "finetuning_args": self.finetuning_args.to_dict(), "generating_args": self.generating_args.to_dict(), } ) class LabelPredictionMonitorCallback(TrainerCallback): """训练过程中的标签和预测监控回调""" def __init__(self, output_dir: str, log_interval: int = 10, save_detailed_logs: bool = True): """ 初始化监控回调 Args: output_dir: 输出目录 log_interval: 日志记录间隔(每N步记录一次) save_detailed_logs: 是否保存详细日志 """ self.output_dir = output_dir self.log_interval = log_interval self.save_detailed_logs = save_detailed_logs # 创建日志目录 self.log_dir = os.path.join(output_dir, "monitoring_logs") os.makedirs(self.log_dir, exist_ok=True) # 设置日志文件 from datetime import datetime timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") self.log_files = { "labels": os.path.join(self.log_dir, f"label_analysis_{timestamp}.log"), "predictions": os.path.join(self.log_dir, f"prediction_monitor_{timestamp}.log"), "alignment": os.path.join(self.log_dir, f"alignment_analysis_{timestamp}.log"), "summary": os.path.join(self.log_dir, f"training_summary_{timestamp}.json") } # 设置日志记录器 self.loggers = self._setup_loggers() # 存储历史数据 self.training_history = [] self.prediction_history = [] self.step_count = 0 self.previous_tokens = {} self.previous_predictions = {} # 存储之前的预测结果 self.previous_prediction_texts = {} # 存储之前的预测文本 self.tokenizer = None # 稍后从训练器获取 # 记录初始化 self.loggers["labels"].info("🔧 标签预测监控回调初始化完成") self.loggers["labels"].info(f"📁 输出目录: {output_dir}") self.loggers["labels"].info(f"📝 日志文件: {self.log_files['labels']}") self.loggers["labels"].info(f"📊 记录间隔: {log_interval}步") def _setup_loggers(self): """设置日志记录器""" import logging loggers = {} for log_type, log_file in self.log_files.items(): if log_type == "summary": continue # summary是JSON文件,不需要logger logger = logging.getLogger(f"monitor_{log_type}") logger.setLevel(logging.INFO) # 清除现有处理器 logger.handlers.clear() # 文件处理器 file_handler = logging.FileHandler(log_file, encoding='utf-8') file_handler.setLevel(logging.INFO) # 控制台处理器 console_handler = logging.StreamHandler() console_handler.setLevel(logging.INFO) # 格式化器 formatter = logging.Formatter( '%(asctime)s | %(levelname)s | %(message)s', datefmt='%Y-%m-%d %H:%M:%S' ) file_handler.setFormatter(formatter) console_handler.setFormatter(formatter) logger.addHandler(file_handler) logger.addHandler(console_handler) loggers[log_type] = logger return loggers def set_tokenizer(self, tokenizer): """设置tokenizer用于Token解码""" self.tokenizer = tokenizer self.loggers["labels"].info(f"🔤 Tokenizer已设置: {type(tokenizer).__name__}") def _decode_tokens(self, token_ids): """解码Token ID为文本""" if self.tokenizer is None: return [f"" for tid in token_ids] try: # 解码Token decoded_text = self.tokenizer.decode(token_ids, skip_special_tokens=False) return decoded_text except Exception as e: return [f"" for tid in token_ids] def analyze_model_predictions(self, model, inputs, labels): """分析模型的预测输出""" import torch import numpy as np # 只在特定步骤记录详细信息 if self.step_count % self.log_interval == 0: self.loggers["labels"].info(f"\n🔮 模型预测分析 - 步骤 {self.step_count}") self.loggers["labels"].info(f"{'='*80}") try: # 获取input_ids if isinstance(inputs, dict): input_ids = inputs.get('input_ids') else: input_ids = inputs # 处理BatchEncoding对象 if hasattr(input_ids, 'input_ids'): actual_input_ids = input_ids.input_ids else: actual_input_ids = input_ids if actual_input_ids is not None and labels is not None: # 设置模型为评估模式以获取预测 model.eval() with torch.no_grad(): # 获取模型输出 outputs = model(actual_input_ids) logits = outputs.logits # 获取预测的Token ID predicted_token_ids = torch.argmax(logits, dim=-1) # 分析每个样本 batch_size = len(actual_input_ids) for i in range(min(batch_size, 1)): # 只分析第一个样本 try: # 获取样本数据 if isinstance(actual_input_ids, torch.Tensor): sample_input = actual_input_ids[i].cpu().numpy() else: sample_input = actual_input_ids[i] if isinstance(labels, torch.Tensor): sample_labels = labels[i].cpu().numpy() else: sample_labels = labels[i] if isinstance(predicted_token_ids, torch.Tensor): sample_predictions = predicted_token_ids[i].cpu().numpy() else: sample_predictions = predicted_token_ids[i] # 找到要训练的部分 trainable_mask = sample_labels != -100 trainable_positions = np.where(trainable_mask)[0] if len(trainable_positions) > 0: # 获取预测的Token(在可训练位置) predicted_tokens = sample_predictions[trainable_mask] target_tokens = sample_labels[trainable_mask] # 只在特定步骤记录详细信息 if self.step_count % self.log_interval == 0: self.loggers["labels"].info(f"\n🎯 样本 {i+1} 预测分析:") self.loggers["labels"].info(f" 📏 可训练位置数: {len(trainable_positions)}") # 显示预测的Token ID self.loggers["labels"].info(f" 🔮 模型预测Token ID: {predicted_tokens.tolist()}") self.loggers["labels"].info(f" 🎯 目标Token ID: {target_tokens.tolist()}") # 解码预测的Token predicted_text = self._decode_tokens(predicted_tokens.tolist()) target_text = self._decode_tokens(target_tokens.tolist()) self.loggers["labels"].info(f" 🔮 模型预测文本: {predicted_text}") self.loggers["labels"].info(f" 🎯 目标文本: {target_text}") # 分析预测文本的变化 self._analyze_prediction_text_changes(predicted_text, i) # 计算预测准确率 correct_predictions = np.sum(predicted_tokens == target_tokens) accuracy = correct_predictions / len(target_tokens) * 100 self.loggers["labels"].info(f" 📊 预测准确率: {accuracy:.2f}% ({correct_predictions}/{len(target_tokens)})") # 调试信息:检查Token长度和类型 self.loggers["labels"].info(f" 🔍 调试信息:") self.loggers["labels"].info(f" 预测Token长度: {len(predicted_tokens)}") self.loggers["labels"].info(f" 目标Token长度: {len(target_tokens)}") self.loggers["labels"].info(f" 预测Token类型: {type(predicted_tokens)}") self.loggers["labels"].info(f" 目标Token类型: {type(target_tokens)}") # 检查是否有任何匹配 if correct_predictions == 0: self.loggers["labels"].info(f" ⚠️ 警告: 预测准确率为0%,可能的原因:") self.loggers["labels"].info(f" 1. 模型在训练初期,还未学会正确预测") self.loggers["labels"].info(f" 2. 预测位置可能不正确") self.loggers["labels"].info(f" 3. Token对齐可能有问题") # 显示前5个Token的详细对比 self.loggers["labels"].info(f" 🔍 前5个Token详细对比:") for j in range(min(5, len(predicted_tokens))): pred_token = predicted_tokens[j] target_token = target_tokens[j] pred_text = self._decode_tokens([pred_token]) target_text_single = self._decode_tokens([target_token]) self.loggers["labels"].info(f" 位置{j}: 预测:{pred_token}({pred_text}) vs 目标:{target_token}({target_text_single})") # 显示前10个位置的详细对比 self.loggers["labels"].info(f" 🔍 前10个位置详细对比:") for j in range(min(10, len(predicted_tokens))): pred_token = predicted_tokens[j] target_token = target_tokens[j] pred_text = self._decode_tokens([pred_token]) target_text_single = self._decode_tokens([target_token]) match_symbol = "✅" if pred_token == target_token else "❌" self.loggers["labels"].info(f" 位置{j}: {match_symbol} 预测:{pred_token}({pred_text}) vs 目标:{target_token}({target_text_single})") # 分析-100部分(忽略的Token) ignore_mask = sample_labels == -100 ignore_positions = np.where(ignore_mask)[0] ignore_tokens = sample_input[ignore_mask] if len(ignore_positions) > 0: self.loggers["labels"].info(f"\n🚫 忽略的Token分析 (-100部分):") self.loggers["labels"].info(f" 📏 忽略位置数: {len(ignore_positions)}") self.loggers["labels"].info(f" 📍 忽略位置: {ignore_positions.tolist()}") self.loggers["labels"].info(f" 🔤 忽略Token ID: {ignore_tokens.tolist()}") # 解码忽略的Token ignore_text = self._decode_tokens(ignore_tokens.tolist()) self.loggers["labels"].info(f" 🔤 忽略Token文本: {ignore_text}") # 分析多轮对话结构 self.loggers["labels"].info(f"\n💬 多轮对话结构分析:") self.loggers["labels"].info(f" 📊 总长度: {len(sample_input)}") self.loggers["labels"].info(f" 🎯 训练部分: {len(trainable_positions)} ({len(trainable_positions)/len(sample_input)*100:.1f}%)") self.loggers["labels"].info(f" 🚫 忽略部分: {len(ignore_positions)} ({len(ignore_positions)/len(sample_input)*100:.1f}%)") # 分析对话分段 self._analyze_conversation_segments(sample_input, sample_labels, i) # 分析预测变化 if hasattr(self, 'previous_predictions') and i in self.previous_predictions: prev_predictions = self.previous_predictions[i] if len(prev_predictions) == len(predicted_tokens): changes = np.sum(prev_predictions != predicted_tokens) if changes > 0: change_positions = np.where(prev_predictions != predicted_tokens)[0] self.loggers["labels"].info(f"\n🔄 步骤 {self.step_count} 预测变化:") self.loggers["labels"].info(f" 📊 变化数量: {changes}/{len(predicted_tokens)}") self.loggers["labels"].info(f" 📍 变化位置: {change_positions.tolist()}") # 显示具体的变化 for pos in change_positions: prev_token = prev_predictions[pos] curr_token = predicted_tokens[pos] target_token = target_tokens[pos] prev_text = self._decode_tokens([prev_token]) curr_text = self._decode_tokens([curr_token]) target_text = self._decode_tokens([target_token]) self.loggers["labels"].info(f" 位置{pos}: {prev_token}({prev_text}) -> {curr_token}({curr_text}) [目标: {target_token}({target_text})]") # 保存当前预测用于下次比较 if not hasattr(self, 'previous_predictions'): self.previous_predictions = {} self.previous_predictions[i] = predicted_tokens.copy() except Exception as e: self.loggers["labels"].error(f"❌ 分析样本 {i} 预测失败: {e}") # 恢复训练模式 model.train() except Exception as e: self.loggers["labels"].error(f"❌ 预测分析失败: {e}") def _analyze_conversation_segments(self, sample_input, sample_labels, sample_idx): """分析多轮对话的分段结构""" import numpy as np try: # 找到训练和忽略的分段 trainable_mask = sample_labels != -100 ignore_mask = sample_labels == -100 # 找到分段的边界 segments = [] current_segment = [] current_type = None for i, (is_trainable, token_id) in enumerate(zip(trainable_mask, sample_input)): segment_type = "trainable" if is_trainable else "ignore" if current_type != segment_type: if current_segment: segments.append({ 'type': current_type, 'start': current_segment[0]['pos'], 'end': current_segment[-1]['pos'], 'length': len(current_segment), 'tokens': [item['token'] for item in current_segment] }) current_segment = [] current_type = segment_type current_segment.append({ 'pos': i, 'token': token_id }) # 添加最后一个分段 if current_segment: segments.append({ 'type': current_type, 'start': current_segment[0]['pos'], 'end': current_segment[-1]['pos'], 'length': len(current_segment), 'tokens': [item['token'] for item in current_segment] }) # 记录分段信息 self.loggers["labels"].info(f" 📝 对话分段详情:") for seg_idx, segment in enumerate(segments): segment_text = self._decode_tokens(segment['tokens']) segment_type_emoji = "🎯" if segment['type'] == 'trainable' else "🚫" self.loggers["labels"].info(f" 分段{seg_idx+1}: {segment_type_emoji} {segment['type']} 位置{segment['start']}-{segment['end']} 长度{segment['length']}") self.loggers["labels"].info(f" 文本: {segment_text}") # 分析对话模式 trainable_segments = [s for s in segments if s['type'] == 'trainable'] ignore_segments = [s for s in segments if s['type'] == 'ignore'] self.loggers["labels"].info(f" 📊 对话模式分析:") self.loggers["labels"].info(f" 训练分段数: {len(trainable_segments)}") self.loggers["labels"].info(f" 忽略分段数: {len(ignore_segments)}") if len(trainable_segments) > 1: self.loggers["labels"].info(f" 💬 多轮对话检测: 发现{len(trainable_segments)}个训练分段") for i, seg in enumerate(trainable_segments): self.loggers["labels"].info(f" 轮次{i+1}: 位置{seg['start']}-{seg['end']} 长度{seg['length']}") except Exception as e: self.loggers["labels"].error(f"❌ 对话分段分析失败: {e}") def _analyze_prediction_text_changes(self, current_text, sample_idx): """分析预测文本的变化""" try: if hasattr(self, 'previous_prediction_texts') and sample_idx in self.previous_prediction_texts: previous_text = self.previous_prediction_texts[sample_idx] if previous_text != current_text: self.loggers["labels"].info(f"\n📝 步骤 {self.step_count} 预测文本变化:") self.loggers["labels"].info(f" 🔄 文本发生变化!") # 计算文本相似度 similarity = self._calculate_text_similarity(previous_text, current_text) self.loggers["labels"].info(f" 📊 文本相似度: {similarity:.2f}%") # 显示变化的部分 self._show_text_differences(previous_text, current_text) else: self.loggers["labels"].info(f"\n📝 步骤 {self.step_count} 预测文本变化:") self.loggers["labels"].info(f" ✅ 文本未发生变化") # 保存当前预测文本 if not hasattr(self, 'previous_prediction_texts'): self.previous_prediction_texts = {} self.previous_prediction_texts[sample_idx] = current_text except Exception as e: self.loggers["labels"].error(f"❌ 预测文本变化分析失败: {e}") def _calculate_text_similarity(self, text1, text2): """计算两个文本的相似度""" try: # 简单的字符级相似度计算 if len(text1) == 0 and len(text2) == 0: return 100.0 if len(text1) == 0 or len(text2) == 0: return 0.0 # 使用编辑距离计算相似度 from difflib import SequenceMatcher similarity = SequenceMatcher(None, text1, text2).ratio() return similarity * 100 except Exception: return 0.0 def _show_text_differences(self, old_text, new_text): """显示文本差异""" try: from difflib import unified_diff self.loggers["labels"].info(f" 📋 文本变化详情:") self.loggers["labels"].info(f" 之前: {old_text}") self.loggers["labels"].info(f" 现在: {new_text}") # 使用unified_diff显示差异 diff_lines = list(unified_diff( old_text.splitlines(keepends=True), new_text.splitlines(keepends=True), fromfile='之前', tofile='现在', lineterm='' )) if diff_lines: self.loggers["labels"].info(f" 🔍 差异分析:") for line in diff_lines[:10]: # 只显示前10行差异 if line.startswith('+'): self.loggers["labels"].info(f" ➕ 新增: {line[1:].strip()}") elif line.startswith('-'): self.loggers["labels"].info(f" ➖ 删除: {line[1:].strip()}") elif line.startswith('@@'): self.loggers["labels"].info(f" 📍 {line.strip()}") # 分析变化类型 self._analyze_change_types(old_text, new_text) except Exception as e: self.loggers["labels"].error(f"❌ 文本差异分析失败: {e}") def _analyze_change_types(self, old_text, new_text): """分析变化类型""" try: changes = { 'added_chars': 0, 'removed_chars': 0, 'modified_chars': 0 } # 简单的变化分析 if len(new_text) > len(old_text): changes['added_chars'] = len(new_text) - len(old_text) elif len(new_text) < len(old_text): changes['removed_chars'] = len(old_text) - len(new_text) # 计算修改的字符数 min_len = min(len(old_text), len(new_text)) for i in range(min_len): if old_text[i] != new_text[i]: changes['modified_chars'] += 1 self.loggers["labels"].info(f" 📈 变化统计:") self.loggers["labels"].info(f" 新增字符: {changes['added_chars']}") self.loggers["labels"].info(f" 删除字符: {changes['removed_chars']}") self.loggers["labels"].info(f" 修改字符: {changes['modified_chars']}") except Exception as e: self.loggers["labels"].error(f"❌ 变化类型分析失败: {e}") @override def on_step_end(self, args, state, control, **kwargs): """在每个训练步骤结束时调用""" self.step_count += 1 if self.step_count % self.log_interval == 0: self.loggers["labels"].info(f"\n{'='*80}") self.loggers["labels"].info(f"🔄 训练步骤 {state.global_step} 监控") self.loggers["labels"].info(f"{'='*80}") # 记录训练状态 if state.log_history: latest_log = state.log_history[-1] self.loggers["labels"].info(f"📈 当前Loss: {latest_log.get('loss', 'N/A')}") # 记录步骤信息 from datetime import datetime step_info = { "step": state.global_step, "timestamp": datetime.now().isoformat(), "loss": state.log_history[-1].get('loss') if state.log_history else None, "training_time": getattr(state, 'training_time', None) } self.training_history.append(step_info) @override def on_log(self, args, state, control, **kwargs): """在日志记录时调用,可以获取到训练数据""" if self.step_count % self.log_interval == 0: # 尝试获取当前批次的数据 if hasattr(kwargs, 'logs') and kwargs['logs']: logs = kwargs['logs'] self.loggers["labels"].info(f"📊 训练日志: {logs}") def analyze_training_tokens(self, model, inputs, labels): """分析训练过程中的token变化""" import numpy as np from datetime import datetime # 只在特定步骤记录详细信息 if self.step_count % self.log_interval == 0: self.loggers["labels"].info(f"\n🔍 训练Token分析 - 步骤 {self.step_count}") self.loggers["labels"].info(f"{'='*80}") # 获取input_ids和labels if isinstance(inputs, dict): input_ids = inputs.get('input_ids') attention_mask = inputs.get('attention_mask') else: input_ids = inputs attention_mask = None # 处理BatchEncoding对象 if hasattr(input_ids, 'input_ids'): # 如果是BatchEncoding,获取实际的input_ids tensor actual_input_ids = input_ids.input_ids else: actual_input_ids = input_ids # 详细调试信息 if self.step_count % self.log_interval == 0: self.loggers["labels"].info(f"🔍 详细调试信息:") self.loggers["labels"].info(f" inputs类型: {type(inputs)}") if isinstance(inputs, dict): self.loggers["labels"].info(f" inputs键: {list(inputs.keys())}") for key, value in inputs.items(): self.loggers["labels"].info(f" {key}: 类型={type(value)}, 形状={getattr(value, 'shape', 'N/A')}") self.loggers["labels"].info(f" input_ids类型: {type(input_ids)}") self.loggers["labels"].info(f" actual_input_ids类型: {type(actual_input_ids)}") self.loggers["labels"].info(f" labels类型: {type(labels)}") if actual_input_ids is not None: self.loggers["labels"].info(f" actual_input_ids详细信息:") self.loggers["labels"].info(f" 类型: {type(actual_input_ids)}") self.loggers["labels"].info(f" 形状: {getattr(actual_input_ids, 'shape', 'N/A')}") self.loggers["labels"].info(f" 设备: {getattr(actual_input_ids, 'device', 'N/A')}") self.loggers["labels"].info(f" 数据类型: {getattr(actual_input_ids, 'dtype', 'N/A')}") if labels is not None: self.loggers["labels"].info(f" labels详细信息:") self.loggers["labels"].info(f" 类型: {type(labels)}") self.loggers["labels"].info(f" 形状: {getattr(labels, 'shape', 'N/A')}") self.loggers["labels"].info(f" 设备: {getattr(labels, 'device', 'N/A')}") self.loggers["labels"].info(f" 数据类型: {getattr(labels, 'dtype', 'N/A')}") # 尝试不同的访问方式 self.loggers["labels"].info(f"🔍 尝试访问方式:") try: if actual_input_ids is not None: self.loggers["labels"].info(f" actual_input_ids[0] 类型: {type(actual_input_ids[0])}") if hasattr(actual_input_ids[0], 'shape'): self.loggers["labels"].info(f" actual_input_ids[0] 形状: {actual_input_ids[0].shape}") except Exception as e: self.loggers["labels"].error(f" ❌ actual_input_ids[0] 访问失败: {e}") try: if labels is not None: self.loggers["labels"].info(f" labels[0] 类型: {type(labels[0])}") if hasattr(labels[0], 'shape'): self.loggers["labels"].info(f" labels[0] 形状: {labels[0].shape}") except Exception as e: self.loggers["labels"].error(f" ❌ labels[0] 访问失败: {e}") if actual_input_ids is not None and labels is not None: # 分析每个样本 batch_size = len(actual_input_ids) for i in range(min(batch_size, 1)): # 只分析第一个样本 try: # 使用正确的tensor索引方式 if isinstance(actual_input_ids, torch.Tensor): sample_input = actual_input_ids[i].cpu().numpy() else: sample_input = actual_input_ids[i] if isinstance(labels, torch.Tensor): sample_labels = labels[i].cpu().numpy() else: sample_labels = labels[i] except Exception as e: self.loggers["labels"].error(f"❌ 访问样本 {i} 失败: {e}") continue # 找到要训练的部分(labels != -100的部分) trainable_mask = sample_labels != -100 trainable_positions = np.where(trainable_mask)[0] trainable_tokens = sample_labels[trainable_mask] # 只在特定步骤记录详细信息 if self.step_count % self.log_interval == 0: self.loggers["labels"].info(f"\n📝 样本 {i+1} 训练Token详情:") self.loggers["labels"].info(f" 📏 总长度: {len(sample_input)}") self.loggers["labels"].info(f" 🎯 可训练长度: {len(trainable_positions)}") self.loggers["labels"].info(f" 📍 可训练位置: {trainable_positions.tolist()}") self.loggers["labels"].info(f" 🔤 要训练的Token ID: {trainable_tokens.tolist()}") # 解码Token为中文 trainable_text = self._decode_tokens(trainable_tokens.tolist()) self.loggers["labels"].info(f" 🔤 要训练的Token文本: {trainable_text}") # 显示对应的input tokens trainable_input_tokens = sample_input[trainable_mask] self.loggers["labels"].info(f" 📥 对应的Input Token ID: {trainable_input_tokens.tolist()}") # 解码Input Token为中文 input_text = self._decode_tokens(trainable_input_tokens.tolist()) self.loggers["labels"].info(f" 📥 对应的Input Token文本: {input_text}") # 显示完整的input_ids和labels(只显示前100个和后100个,避免日志过长) if len(sample_input) > 200: input_preview = sample_input[:100].tolist() + ["..."] + sample_input[-100:].tolist() labels_preview = sample_labels[:100].tolist() + ["..."] + sample_labels[-100:].tolist() else: input_preview = sample_input.tolist() labels_preview = sample_labels.tolist() self.loggers["labels"].info(f" 📋 完整Input IDs (预览): {input_preview}") self.loggers["labels"].info(f" 🏷️ 完整Labels (预览): {labels_preview}") # 分析token变化(每次都检查) if hasattr(self, 'previous_tokens') and i in self.previous_tokens: prev_tokens = self.previous_tokens[i] if len(prev_tokens) == len(trainable_tokens): changes = np.sum(prev_tokens != trainable_tokens) if changes > 0: change_positions = np.where(prev_tokens != trainable_tokens)[0] self.loggers["labels"].info(f"\n🔄 步骤 {self.step_count} Token变化:") self.loggers["labels"].info(f" 📊 变化数量: {changes}/{len(trainable_tokens)}") self.loggers["labels"].info(f" 📍 变化位置: {change_positions.tolist()}") # 显示具体的变化(包含解码文本) for pos in change_positions: prev_token = prev_tokens[pos] curr_token = trainable_tokens[pos] prev_text = self._decode_tokens([prev_token]) curr_text = self._decode_tokens([curr_token]) self.loggers["labels"].info(f" 位置{pos}: {prev_token}({prev_text}) -> {curr_token}({curr_text})") # 保存当前tokens用于下次比较 if not hasattr(self, 'previous_tokens'): self.previous_tokens = {} self.previous_tokens[i] = trainable_tokens.copy() @override def on_evaluate(self, args, state, control, **kwargs): """在评估时调用""" self.loggers["predictions"].info(f"\n{'='*80}") self.loggers["predictions"].info(f"📊 评估阶段监控 - 步骤 {state.global_step}") self.loggers["predictions"].info(f"{'='*80}") # 如果有预测结果,进行分析 if hasattr(kwargs, 'predict_results') and kwargs['predict_results'] is not None: self._analyze_predictions(kwargs['predict_results'], state.global_step) @override def on_predict(self, args, state, control, **kwargs): """在预测时调用""" self.loggers["predictions"].info(f"\n{'='*80}") self.loggers["predictions"].info(f"🔮 预测阶段监控 - 步骤 {state.global_step}") self.loggers["predictions"].info(f"{'='*80}") # 获取预测结果 predict_results = kwargs.get('predict_results') if predict_results is not None: self._analyze_predictions(predict_results, state.global_step) def _analyze_predictions(self, predict_results, step: int): """分析预测结果""" import numpy as np from datetime import datetime self.loggers["predictions"].info(f"📊 预测结果分析 - 步骤 {step}") # 获取预测和标签 predictions = predict_results.predictions labels = predict_results.label_ids if predictions is None or labels is None: self.loggers["predictions"].warning("⚠️ 预测结果或标签为空") return # 转换为numpy数组(如果是tensor) if isinstance(predictions, torch.Tensor): predictions = predictions.cpu().numpy() if isinstance(labels, torch.Tensor): labels = labels.cpu().numpy() # 分析每个样本 batch_size = len(predictions) self.loggers["predictions"].info(f"📦 批次大小: {batch_size}") for i in range(min(batch_size, 3)): # 只分析前3个样本 self.loggers["predictions"].info(f"\n🔍 样本 {i+1} 分析:") pred_sample = predictions[i] label_sample = labels[i] # 移除padding pred_sample = self._remove_padding(pred_sample) label_sample = self._remove_padding(label_sample) # 记录标签信息 self.loggers["labels"].info(f"\n📝 样本 {i+1} 标签分析:") self.loggers["labels"].info(f" 🎯 标签长度: {len(label_sample)}") self.loggers["labels"].info(f" 🔮 预测长度: {len(pred_sample)}") # 对齐分析 alignment_analysis = self._analyze_alignment(pred_sample, label_sample) self.loggers["alignment"].info(f"\n🎯 样本 {i+1} 对齐分析:") self.loggers["alignment"].info(f" 📏 长度差异: {alignment_analysis['length_difference']}") self.loggers["alignment"].info(f" 🎯 精确匹配: {alignment_analysis['exact_match_percentage']:.1f}%") self.loggers["alignment"].info(f" ✅ 有效匹配: {alignment_analysis['valid_match_percentage']:.1f}%") # 存储分析结果 analysis = { "step": step, "sample_idx": i, "timestamp": datetime.now().isoformat(), "predictions": pred_sample.tolist(), "labels": label_sample.tolist(), "alignment_analysis": alignment_analysis } self.prediction_history.append(analysis) def _remove_padding(self, tokens, pad_token_id: int = -100): """移除padding tokens""" import numpy as np # 找到非padding的位置 non_pad_mask = tokens != pad_token_id if np.any(non_pad_mask): # 找到第一个和最后一个非padding位置 first_non_pad = np.argmax(non_pad_mask) last_non_pad = len(tokens) - 1 - np.argmax(non_pad_mask[::-1]) return tokens[first_non_pad:last_non_pad+1] else: return tokens def _analyze_alignment(self, predictions, labels): """分析预测和标签的对齐情况""" # 基本统计 min_len = min(len(predictions), len(labels)) max_len = max(len(predictions), len(labels)) # 计算匹配 exact_matches = 0 valid_matches = 0 for i in range(min_len): if predictions[i] == labels[i]: exact_matches += 1 if predictions[i] != -100 and labels[i] != -100: valid_matches += 1 # 计算匹配率 exact_match_percentage = (exact_matches / min_len * 100) if min_len > 0 else 0 valid_match_percentage = (exact_matches / valid_matches * 100) if valid_matches > 0 else 0 return { "min_length": min_len, "max_length": max_len, "exact_matches": exact_matches, "valid_matches": valid_matches, "exact_match_percentage": exact_match_percentage, "valid_match_percentage": valid_match_percentage, "length_difference": abs(len(predictions) - len(labels)) } @override def on_train_end(self, args, state, control, **kwargs): """训练结束时调用""" from datetime import datetime self.loggers["labels"].info(f"\n{'='*80}") self.loggers["labels"].info(f"🏁 训练结束监控") self.loggers["labels"].info(f"{'='*80}") # 保存最终分析摘要 summary_data = { "training_info": { "total_steps": len(self.training_history), "total_predictions": len(self.prediction_history), "completion_time": datetime.now().isoformat(), "output_dir": self.output_dir }, "training_history": self.training_history, "prediction_history": self.prediction_history, "log_files": self.log_files } with open(self.log_files["summary"], "w", encoding="utf-8") as f: json.dump(summary_data, f, ensure_ascii=False, indent=2) self.loggers["labels"].info(f"📊 训练摘要已保存: {self.log_files['summary']}") self.loggers["labels"].info(f"📝 标签分析日志: {self.log_files['labels']}") self.loggers["labels"].info(f"🔮 预测监控日志: {self.log_files['predictions']}") self.loggers["labels"].info(f"🎯 对齐分析日志: {self.log_files['alignment']}")