| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| import json |
| import os |
| import signal |
| import sys |
| import time |
| from concurrent.futures import ThreadPoolExecutor |
| from datetime import timedelta |
| from typing import TYPE_CHECKING, Any, Optional |
|
|
| import torch |
| import transformers |
| from peft import PeftModel |
| from transformers import PreTrainedModel, ProcessorMixin, TrainerCallback |
| from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, has_length |
| from transformers.utils import ( |
| SAFE_WEIGHTS_NAME, |
| WEIGHTS_NAME, |
| is_safetensors_available, |
| ) |
| from typing_extensions import override |
|
|
| from ..extras import logging |
| from ..extras.constants import TRAINER_LOG, V_HEAD_SAFE_WEIGHTS_NAME, V_HEAD_WEIGHTS_NAME |
| from ..extras.misc import get_peak_memory, is_env_enabled, use_ray |
|
|
|
|
| if is_safetensors_available(): |
| from safetensors import safe_open |
| from safetensors.torch import save_file |
|
|
|
|
| if TYPE_CHECKING: |
| from transformers import TrainerControl, TrainerState, TrainingArguments |
| from trl import AutoModelForCausalLMWithValueHead |
|
|
| from ..hparams import DataArguments, FinetuningArguments, GeneratingArguments, ModelArguments |
|
|
|
|
| logger = logging.get_logger(__name__) |
|
|
|
|
def fix_valuehead_checkpoint(
    model: "AutoModelForCausalLMWithValueHead", output_dir: str, safe_serialization: bool
) -> None:
    r"""Fix the valuehead checkpoint files.

    The model is already unwrapped.

    There are three cases:
    1. full tuning without ds_zero3: state_dict = {"model.layers.*": ..., "v_head.summary.*": ...}
    2. lora tuning without ds_zero3: state_dict = {"v_head.summary.*": ...}
    3. under deepspeed zero3: state_dict = {"pretrained_model.model.layers.*": ..., "v_head.summary.*": ...}

    We assume `stage3_gather_16bit_weights_on_model_save=true`.
    """
    # Only PreTrainedModel/PeftModel expose the save_pretrained API used below;
    # for anything else the checkpoint is left untouched.
    if not isinstance(model.pretrained_model, (PreTrainedModel, PeftModel)):
        return

    # Load the combined checkpoint (decoder + value head) the trainer just wrote.
    if safe_serialization:
        path_to_checkpoint = os.path.join(output_dir, SAFE_WEIGHTS_NAME)
        with safe_open(path_to_checkpoint, framework="pt", device="cpu") as f:
            state_dict: dict[str, torch.Tensor] = {key: f.get_tensor(key) for key in f.keys()}
    else:
        path_to_checkpoint = os.path.join(output_dir, WEIGHTS_NAME)
        state_dict: dict[str, torch.Tensor] = torch.load(path_to_checkpoint, map_location="cpu", weights_only=True)

    # Remove the combined file first; it is re-written below split into
    # decoder weights and a separate value-head file.
    os.remove(path_to_checkpoint)
    decoder_state_dict, v_head_state_dict = {}, {}
    for name, param in state_dict.items():
        if name.startswith("v_head."):
            v_head_state_dict[name] = param
        else:
            # Strip the zero3 "pretrained_model." prefix (case 3) so keys match
            # what `model.pretrained_model` expects.
            decoder_state_dict[name.replace("pretrained_model.", "", 1)] = param

    # `decoder_state_dict or None`: in the lora-only case (2) the dict is empty,
    # and passing None lets PEFT save its own adapter weights instead.
    model.pretrained_model.save_pretrained(
        output_dir, state_dict=decoder_state_dict or None, safe_serialization=safe_serialization
    )

    # Persist the value head under its dedicated filename.
    if safe_serialization:
        save_file(v_head_state_dict, os.path.join(output_dir, V_HEAD_SAFE_WEIGHTS_NAME), metadata={"format": "pt"})
    else:
        torch.save(v_head_state_dict, os.path.join(output_dir, V_HEAD_WEIGHTS_NAME))

    logger.info_rank0(f"Value head model saved at: {output_dir}")
|
|
|
|
class FixValueHeadModelCallback(TrainerCallback):
    r"""A callback for fixing the checkpoint for valuehead models."""

    @override
    def on_save(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
        # Only the process responsible for saving splits the checkpoint.
        if not args.should_save:
            return

        checkpoint_dir = os.path.join(args.output_dir, f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}")
        fix_valuehead_checkpoint(
            model=kwargs.pop("model"), output_dir=checkpoint_dir, safe_serialization=args.save_safetensors
        )
|
|
|
|
class SaveProcessorCallback(TrainerCallback):
    r"""A callback that persists the processor alongside model checkpoints."""

    def __init__(self, processor: "ProcessorMixin") -> None:
        self.processor = processor

    @override
    def on_save(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
        # Write the processor into the per-step checkpoint folder.
        if not args.should_save:
            return

        checkpoint_dir = os.path.join(args.output_dir, f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}")
        self.processor.save_pretrained(checkpoint_dir)

    @override
    def on_train_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
        # Write the final processor into the run's top-level output folder.
        if not args.should_save:
            return

        self.processor.save_pretrained(args.output_dir)
|
|
|
|
class PissaConvertCallback(TrainerCallback):
    r"""A callback for converting the PiSSA adapter to a normal one."""

    @override
    def on_train_begin(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
        # Snapshot the freshly-initialized PiSSA adapter; it is required at the end
        # of training as `path_initial_model_for_weight_conversion`.
        if args.should_save:
            model = kwargs.pop("model")
            pissa_init_dir = os.path.join(args.output_dir, "pissa_init")
            logger.info_rank0(f"Initial PiSSA adapter will be saved at: {pissa_init_dir}.")
            if isinstance(model, PeftModel):
                # Temporarily force init_lora_weights=True so save_pretrained writes
                # plain LoRA-style weights, then restore the original flag.
                init_lora_weights = getattr(model.peft_config["default"], "init_lora_weights")
                setattr(model.peft_config["default"], "init_lora_weights", True)
                model.save_pretrained(pissa_init_dir, safe_serialization=args.save_safetensors)
                setattr(model.peft_config["default"], "init_lora_weights", init_lora_weights)

    @override
    def on_train_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
        if args.should_save:
            model = kwargs.pop("model")
            pissa_init_dir = os.path.join(args.output_dir, "pissa_init")
            pissa_backup_dir = os.path.join(args.output_dir, "pissa_backup")
            pissa_convert_dir = os.path.join(args.output_dir, "pissa_converted")
            logger.info_rank0(f"Converted PiSSA adapter will be saved at: {pissa_convert_dir}.")

            if isinstance(model, PeftModel):
                # Step 1: back up the trained adapter — the conversion below mutates
                # the adapter weights in place.
                init_lora_weights = getattr(model.peft_config["default"], "init_lora_weights")
                setattr(model.peft_config["default"], "init_lora_weights", True)
                model.save_pretrained(pissa_backup_dir, safe_serialization=args.save_safetensors)
                setattr(model.peft_config["default"], "init_lora_weights", init_lora_weights)
                # Step 2: save the converted (plain LoRA) adapter using the initial
                # PiSSA snapshot captured in on_train_begin.
                model.save_pretrained(
                    pissa_convert_dir,
                    safe_serialization=args.save_safetensors,
                    path_initial_model_for_weight_conversion=pissa_init_dir,
                )
                # Step 3: reload the backup so the in-memory model keeps its trained
                # (un-converted) weights for any subsequent use.
                model.load_adapter(pissa_backup_dir, "default", is_trainable=True)
                model.set_adapter("default")
                setattr(model.peft_config["default"], "init_lora_weights", init_lora_weights)
|
|
|
|
class LogCallback(TrainerCallback):
    r"""A callback for logging training and evaluation status."""

    def __init__(self) -> None:
        # Progress state for the current run (training or prediction).
        self.start_time = 0
        self.cur_steps = 0
        self.max_steps = 0
        self.elapsed_time = ""
        self.remaining_time = ""
        # Single-worker pool so JSON log writes don't block the training loop.
        self.thread_pool: Optional[ThreadPoolExecutor] = None
        # Status flags.
        self.aborted = False
        self.do_train = False
        # Web UI integration: allow aborting via SIGABRT and forward library logs
        # to the LlamaBoard handler.
        self.webui_mode = is_env_enabled("LLAMABOARD_ENABLED")
        if self.webui_mode and not use_ray():
            signal.signal(signal.SIGABRT, self._set_abort)
            self.logger_handler = logging.LoggerHandler(os.getenv("LLAMABOARD_WORKDIR"))
            logging.add_handler(self.logger_handler)
            transformers.logging.add_handler(self.logger_handler)

    def _set_abort(self, signum, frame) -> None:
        # Signal handler: request a graceful stop at the next step boundary.
        self.aborted = True

    def _reset(self, max_steps: int = 0) -> None:
        # Reset timing/progress counters for a new run.
        self.start_time = time.time()
        self.cur_steps = 0
        self.max_steps = max_steps
        self.elapsed_time = ""
        self.remaining_time = ""

    def _timing(self, cur_steps: int) -> None:
        # Update elapsed/remaining time estimates from the average step duration.
        cur_time = time.time()
        elapsed_time = cur_time - self.start_time
        avg_time_per_step = elapsed_time / cur_steps if cur_steps != 0 else 0
        remaining_time = (self.max_steps - cur_steps) * avg_time_per_step
        self.cur_steps = cur_steps
        self.elapsed_time = str(timedelta(seconds=int(elapsed_time)))
        self.remaining_time = str(timedelta(seconds=int(remaining_time)))

    def _write_log(self, output_dir: str, logs: dict[str, Any]) -> None:
        # Append one JSON line to the trainer log file (runs on the thread pool).
        with open(os.path.join(output_dir, TRAINER_LOG), "a", encoding="utf-8") as f:
            f.write(json.dumps(logs) + "\n")

    def _create_thread_pool(self, output_dir: str) -> None:
        os.makedirs(output_dir, exist_ok=True)
        self.thread_pool = ThreadPoolExecutor(max_workers=1)

    def _close_thread_pool(self) -> None:
        # Drain pending writes before dropping the pool.
        if self.thread_pool is not None:
            self.thread_pool.shutdown(wait=True)
            self.thread_pool = None

    @override
    def on_init_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
        # Start with a clean trainer log when the user asked to overwrite outputs.
        if (
            args.should_save
            and os.path.exists(os.path.join(args.output_dir, TRAINER_LOG))
            and args.overwrite_output_dir
        ):
            logger.warning_rank0_once("Previous trainer log in this folder will be deleted.")
            os.remove(os.path.join(args.output_dir, TRAINER_LOG))

    @override
    def on_train_begin(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
        if args.should_save:
            self.do_train = True
            self._reset(max_steps=state.max_steps)
            self._create_thread_pool(output_dir=args.output_dir)

    @override
    def on_train_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
        self._close_thread_pool()

    @override
    def on_substep_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
        # Honor an abort request (SIGABRT from the web UI) between gradient-accumulation substeps.
        if self.aborted:
            control.should_epoch_stop = True
            control.should_training_stop = True

    @override
    def on_step_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
        # Honor an abort request at optimizer-step granularity.
        if self.aborted:
            control.should_epoch_stop = True
            control.should_training_stop = True

    @override
    def on_evaluate(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
        # Standalone evaluation (no training) owns the pool; close it when done.
        if not self.do_train:
            self._close_thread_pool()

    @override
    def on_predict(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
        if not self.do_train:
            self._close_thread_pool()

    @override
    def on_log(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
        if not args.should_save:
            return

        self._timing(cur_steps=state.global_step)
        # Collect the most recent trainer metrics; missing keys become None and
        # are filtered out below.
        logs = dict(
            current_steps=self.cur_steps,
            total_steps=self.max_steps,
            loss=state.log_history[-1].get("loss"),
            eval_loss=state.log_history[-1].get("eval_loss"),
            predict_loss=state.log_history[-1].get("predict_loss"),
            reward=state.log_history[-1].get("reward"),
            accuracy=state.log_history[-1].get("rewards/accuracies"),
            lr=state.log_history[-1].get("learning_rate"),
            epoch=state.log_history[-1].get("epoch"),
            percentage=round(self.cur_steps / self.max_steps * 100, 2) if self.max_steps != 0 else 100,
            elapsed_time=self.elapsed_time,
            remaining_time=self.remaining_time,
        )
        if state.num_input_tokens_seen:
            # Tokens/second averaged over the whole run so far.
            logs["throughput"] = round(state.num_input_tokens_seen / (time.time() - self.start_time), 2)
            logs["total_tokens"] = state.num_input_tokens_seen

        if is_env_enabled("RECORD_VRAM"):
            vram_allocated, vram_reserved = get_peak_memory()
            logs["vram_allocated"] = round(vram_allocated / (1024**3), 2)
            logs["vram_reserved"] = round(vram_reserved / (1024**3), 2)

        logs = {k: v for k, v in logs.items() if v is not None}
        if self.webui_mode and all(key in logs for key in ("loss", "lr", "epoch")):
            log_str = f"'loss': {logs['loss']:.4f}, 'learning_rate': {logs['lr']:2.4e}, 'epoch': {logs['epoch']:.2f}"
            for extra_key in ("reward", "accuracy", "throughput"):
                if logs.get(extra_key):
                    log_str += f", '{extra_key}': {logs[extra_key]:.2f}"

            logger.info_rank0("{" + log_str + "}")

        # Write the JSON record asynchronously to avoid stalling the step.
        if self.thread_pool is not None:
            self.thread_pool.submit(self._write_log, args.output_dir, logs)

    @override
    def on_prediction_step(
        self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs
    ):
        # Progress reporting for standalone prediction only; during training the
        # regular on_log path handles progress.
        if self.do_train:
            return

        if self.aborted:
            sys.exit(0)

        if not args.should_save:
            return

        eval_dataloader = kwargs.pop("eval_dataloader", None)
        if has_length(eval_dataloader):
            # First prediction step: initialize progress against the loader length.
            if self.max_steps == 0:
                self._reset(max_steps=len(eval_dataloader))
                self._create_thread_pool(output_dir=args.output_dir)

            self._timing(cur_steps=self.cur_steps + 1)
            # Throttle progress records to every 5 steps.
            if self.cur_steps % 5 == 0 and self.thread_pool is not None:
                logs = dict(
                    current_steps=self.cur_steps,
                    total_steps=self.max_steps,
                    percentage=round(self.cur_steps / self.max_steps * 100, 2) if self.max_steps != 0 else 100,
                    elapsed_time=self.elapsed_time,
                    remaining_time=self.remaining_time,
                )
                self.thread_pool.submit(self._write_log, args.output_dir, logs)
|
|
|
|
class ReporterCallback(TrainerCallback):
    r"""A callback for reporting training status to external logger."""

    def __init__(
        self,
        model_args: "ModelArguments",
        data_args: "DataArguments",
        finetuning_args: "FinetuningArguments",
        generating_args: "GeneratingArguments",
    ) -> None:
        self.model_args = model_args
        self.data_args = data_args
        self.finetuning_args = finetuning_args
        self.generating_args = generating_args
        # Default the W&B project name without clobbering a user-provided one.
        os.environ.setdefault("WANDB_PROJECT", "llamafactory")

    def _arg_snapshot(self) -> dict:
        # Serialize every argument group into one config payload.
        return {
            name: getattr(self, name).to_dict()
            for name in ("model_args", "data_args", "finetuning_args", "generating_args")
        }

    @override
    def on_train_begin(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
        # Only the main process reports to external trackers.
        if not state.is_world_process_zero:
            return

        if "wandb" in args.report_to:
            import wandb

            wandb.config.update(self._arg_snapshot())

        if self.finetuning_args.use_swanlab:
            import swanlab

            swanlab.config.update(self._arg_snapshot())
|
|
| class LabelPredictionMonitorCallback(TrainerCallback): |
| """训练过程中的标签和预测监控回调""" |
| |
    def __init__(self,
                 output_dir: str,
                 log_interval: int = 10,
                 save_detailed_logs: bool = True):
        """Initialize the monitoring callback.

        Args:
            output_dir: Directory under which monitoring logs are written.
            log_interval: Emit detailed logs every N steps.
            save_detailed_logs: Whether to save detailed logs.
        """
        self.output_dir = output_dir
        self.log_interval = log_interval
        self.save_detailed_logs = save_detailed_logs

        # All monitoring artifacts live under <output_dir>/monitoring_logs.
        self.log_dir = os.path.join(output_dir, "monitoring_logs")
        os.makedirs(self.log_dir, exist_ok=True)

        # Timestamped file names so repeated runs do not overwrite each other.
        from datetime import datetime
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        self.log_files = {
            "labels": os.path.join(self.log_dir, f"label_analysis_{timestamp}.log"),
            "predictions": os.path.join(self.log_dir, f"prediction_monitor_{timestamp}.log"),
            "alignment": os.path.join(self.log_dir, f"alignment_analysis_{timestamp}.log"),
            "summary": os.path.join(self.log_dir, f"training_summary_{timestamp}.json")
        }

        # One file+console logger per log type ("summary" is plain JSON).
        self.loggers = self._setup_loggers()

        # Mutable monitoring state updated as training progresses.
        self.training_history = []
        self.prediction_history = []
        self.step_count = 0
        self.previous_tokens = {}
        self.previous_predictions = {}
        self.previous_prediction_texts = {}
        self.tokenizer = None  # attached later via set_tokenizer()

        self.loggers["labels"].info("🔧 标签预测监控回调初始化完成")
        self.loggers["labels"].info(f"📁 输出目录: {output_dir}")
        self.loggers["labels"].info(f"📝 日志文件: {self.log_files['labels']}")
        self.loggers["labels"].info(f"📊 记录间隔: {log_interval}步")
| |
| def _setup_loggers(self): |
| """设置日志记录器""" |
| import logging |
| loggers = {} |
| |
| for log_type, log_file in self.log_files.items(): |
| if log_type == "summary": |
| continue |
| |
| logger = logging.getLogger(f"monitor_{log_type}") |
| logger.setLevel(logging.INFO) |
| |
| |
| logger.handlers.clear() |
| |
| |
| file_handler = logging.FileHandler(log_file, encoding='utf-8') |
| file_handler.setLevel(logging.INFO) |
| |
| |
| console_handler = logging.StreamHandler() |
| console_handler.setLevel(logging.INFO) |
| |
| |
| formatter = logging.Formatter( |
| '%(asctime)s | %(levelname)s | %(message)s', |
| datefmt='%Y-%m-%d %H:%M:%S' |
| ) |
| file_handler.setFormatter(formatter) |
| console_handler.setFormatter(formatter) |
| |
| logger.addHandler(file_handler) |
| logger.addHandler(console_handler) |
| |
| loggers[log_type] = logger |
| |
| return loggers |
| |
| def set_tokenizer(self, tokenizer): |
| """设置tokenizer用于Token解码""" |
| self.tokenizer = tokenizer |
| self.loggers["labels"].info(f"🔤 Tokenizer已设置: {type(tokenizer).__name__}") |
| |
| def _decode_tokens(self, token_ids): |
| """解码Token ID为文本""" |
| if self.tokenizer is None: |
| return [f"<token_{tid}>" for tid in token_ids] |
| |
| try: |
| |
| decoded_text = self.tokenizer.decode(token_ids, skip_special_tokens=False) |
| return decoded_text |
| except Exception as e: |
| return [f"<decode_error_{tid}>" for tid in token_ids] |
| |
    def analyze_model_predictions(self, model, inputs, labels):
        """Analyze the model's greedy (argmax) predictions for the current batch.

        Runs an extra no-grad forward pass, compares predicted tokens against
        the trainable label positions (labels != -100), and logs token- and
        text-level diagnostics every `log_interval` steps. Logging only; the
        model is restored to train mode afterwards.
        """
        import torch
        import numpy as np

        if self.step_count % self.log_interval == 0:
            self.loggers["labels"].info(f"\n🔮 模型预测分析 - 步骤 {self.step_count}")
            self.loggers["labels"].info(f"{'='*80}")

        try:
            # Inputs may arrive as a dict batch or a raw tensor.
            if isinstance(inputs, dict):
                input_ids = inputs.get('input_ids')
            else:
                input_ids = inputs

            # Some wrappers nest the tensor under an `input_ids` attribute.
            if hasattr(input_ids, 'input_ids'):
                actual_input_ids = input_ids.input_ids
            else:
                actual_input_ids = input_ids

            if actual_input_ids is not None and labels is not None:
                # Clean forward pass in eval mode; train mode restored below.
                model.eval()

                with torch.no_grad():
                    # NOTE(review): no attention_mask is passed — padded batches
                    # may produce slightly different logits; confirm acceptable.
                    outputs = model(actual_input_ids)
                    logits = outputs.logits

                    # Greedy decoding: most likely token at each position.
                    predicted_token_ids = torch.argmax(logits, dim=-1)

                    batch_size = len(actual_input_ids)

                    # Only the first sample is analyzed to keep logs bounded.
                    for i in range(min(batch_size, 1)):
                        try:
                            # Move per-sample tensors to CPU numpy for analysis.
                            if isinstance(actual_input_ids, torch.Tensor):
                                sample_input = actual_input_ids[i].cpu().numpy()
                            else:
                                sample_input = actual_input_ids[i]

                            if isinstance(labels, torch.Tensor):
                                sample_labels = labels[i].cpu().numpy()
                            else:
                                sample_labels = labels[i]

                            if isinstance(predicted_token_ids, torch.Tensor):
                                sample_predictions = predicted_token_ids[i].cpu().numpy()
                            else:
                                sample_predictions = predicted_token_ids[i]

                            # Positions with label != -100 are the trained ones.
                            trainable_mask = sample_labels != -100
                            trainable_positions = np.where(trainable_mask)[0]

                            if len(trainable_positions) > 0:
                                predicted_tokens = sample_predictions[trainable_mask]
                                target_tokens = sample_labels[trainable_mask]

                                if self.step_count % self.log_interval == 0:
                                    self.loggers["labels"].info(f"\n🎯 样本 {i+1} 预测分析:")
                                    self.loggers["labels"].info(f" 📏 可训练位置数: {len(trainable_positions)}")

                                    self.loggers["labels"].info(f" 🔮 模型预测Token ID: {predicted_tokens.tolist()}")
                                    self.loggers["labels"].info(f" 🎯 目标Token ID: {target_tokens.tolist()}")

                                    predicted_text = self._decode_tokens(predicted_tokens.tolist())
                                    target_text = self._decode_tokens(target_tokens.tolist())

                                    self.loggers["labels"].info(f" 🔮 模型预测文本: {predicted_text}")
                                    self.loggers["labels"].info(f" 🎯 目标文本: {target_text}")

                                    # Track how the decoded prediction evolves between logged steps.
                                    self._analyze_prediction_text_changes(predicted_text, i)

                                    # Token-level accuracy over trainable positions only.
                                    correct_predictions = np.sum(predicted_tokens == target_tokens)
                                    accuracy = correct_predictions / len(target_tokens) * 100
                                    self.loggers["labels"].info(f" 📊 预测准确率: {accuracy:.2f}% ({correct_predictions}/{len(target_tokens)})")

                                    self.loggers["labels"].info(f" 🔍 调试信息:")
                                    self.loggers["labels"].info(f" 预测Token长度: {len(predicted_tokens)}")
                                    self.loggers["labels"].info(f" 目标Token长度: {len(target_tokens)}")
                                    self.loggers["labels"].info(f" 预测Token类型: {type(predicted_tokens)}")
                                    self.loggers["labels"].info(f" 目标Token类型: {type(target_tokens)}")

                                    # NOTE(review): logits[t] predicts token t+1; comparing
                                    # predictions at position t against labels at position t may
                                    # be off by one — confirm the alignment with the loss shift.
                                    if correct_predictions == 0:
                                        self.loggers["labels"].info(f" ⚠️ 警告: 预测准确率为0%,可能的原因:")
                                        self.loggers["labels"].info(f" 1. 模型在训练初期,还未学会正确预测")
                                        self.loggers["labels"].info(f" 2. 预测位置可能不正确")
                                        self.loggers["labels"].info(f" 3. Token对齐可能有问题")

                                        self.loggers["labels"].info(f" 🔍 前5个Token详细对比:")
                                        for j in range(min(5, len(predicted_tokens))):
                                            pred_token = predicted_tokens[j]
                                            target_token = target_tokens[j]
                                            pred_text = self._decode_tokens([pred_token])
                                            target_text_single = self._decode_tokens([target_token])
                                            self.loggers["labels"].info(f" 位置{j}: 预测:{pred_token}({pred_text}) vs 目标:{target_token}({target_text_single})")

                                    # Always show the first positions with match markers.
                                    self.loggers["labels"].info(f" 🔍 前10个位置详细对比:")
                                    for j in range(min(10, len(predicted_tokens))):
                                        pred_token = predicted_tokens[j]
                                        target_token = target_tokens[j]
                                        pred_text = self._decode_tokens([pred_token])
                                        target_text_single = self._decode_tokens([target_token])
                                        match_symbol = "✅" if pred_token == target_token else "❌"
                                        self.loggers["labels"].info(f" 位置{j}: {match_symbol} 预测:{pred_token}({pred_text}) vs 目标:{target_token}({target_text_single})")

                                    # Ignored (-100) positions: typically prompt/user turns.
                                    ignore_mask = sample_labels == -100
                                    ignore_positions = np.where(ignore_mask)[0]
                                    ignore_tokens = sample_input[ignore_mask]

                                    if len(ignore_positions) > 0:
                                        self.loggers["labels"].info(f"\n🚫 忽略的Token分析 (-100部分):")
                                        self.loggers["labels"].info(f" 📏 忽略位置数: {len(ignore_positions)}")
                                        self.loggers["labels"].info(f" 📍 忽略位置: {ignore_positions.tolist()}")
                                        self.loggers["labels"].info(f" 🔤 忽略Token ID: {ignore_tokens.tolist()}")

                                        ignore_text = self._decode_tokens(ignore_tokens.tolist())
                                        self.loggers["labels"].info(f" 🔤 忽略Token文本: {ignore_text}")

                                    # Summarize the trainable/ignored split of the sequence.
                                    self.loggers["labels"].info(f"\n💬 多轮对话结构分析:")
                                    self.loggers["labels"].info(f" 📊 总长度: {len(sample_input)}")
                                    self.loggers["labels"].info(f" 🎯 训练部分: {len(trainable_positions)} ({len(trainable_positions)/len(sample_input)*100:.1f}%)")
                                    self.loggers["labels"].info(f" 🚫 忽略部分: {len(ignore_positions)} ({len(ignore_positions)/len(sample_input)*100:.1f}%)")

                                    self._analyze_conversation_segments(sample_input, sample_labels, i)

                                    # Compare against the previous logged predictions for this sample.
                                    if hasattr(self, 'previous_predictions') and i in self.previous_predictions:
                                        prev_predictions = self.previous_predictions[i]
                                        if len(prev_predictions) == len(predicted_tokens):
                                            changes = np.sum(prev_predictions != predicted_tokens)

                                            if changes > 0:
                                                change_positions = np.where(prev_predictions != predicted_tokens)[0]
                                                self.loggers["labels"].info(f"\n🔄 步骤 {self.step_count} 预测变化:")
                                                self.loggers["labels"].info(f" 📊 变化数量: {changes}/{len(predicted_tokens)}")
                                                self.loggers["labels"].info(f" 📍 变化位置: {change_positions.tolist()}")

                                                for pos in change_positions:
                                                    prev_token = prev_predictions[pos]
                                                    curr_token = predicted_tokens[pos]
                                                    target_token = target_tokens[pos]
                                                    prev_text = self._decode_tokens([prev_token])
                                                    curr_text = self._decode_tokens([curr_token])
                                                    target_text = self._decode_tokens([target_token])
                                                    self.loggers["labels"].info(f" 位置{pos}: {prev_token}({prev_text}) -> {curr_token}({curr_text}) [目标: {target_token}({target_text})]")

                                    # Remember this step's predictions for the next comparison.
                                    if not hasattr(self, 'previous_predictions'):
                                        self.previous_predictions = {}
                                    self.previous_predictions[i] = predicted_tokens.copy()

                        except Exception as e:
                            self.loggers["labels"].error(f"❌ 分析样本 {i} 预测失败: {e}")

                # Restore training mode for the rest of the step.
                model.train()

        except Exception as e:
            self.loggers["labels"].error(f"❌ 预测分析失败: {e}")
| |
    def _analyze_conversation_segments(self, sample_input, sample_labels, sample_idx):
        """Analyze the trainable/ignored segment structure of a multi-turn sample.

        Groups consecutive positions with the same status (label != -100 vs
        label == -100) into segments, logs each segment with its decoded text,
        and reports a multi-turn summary.
        """
        import numpy as np

        try:
            trainable_mask = sample_labels != -100
            # NOTE(review): ignore_mask is computed but never used in this method.
            ignore_mask = sample_labels == -100

            # Run-length grouping: a new segment starts whenever the status flips.
            segments = []
            current_segment = []
            current_type = None

            for i, (is_trainable, token_id) in enumerate(zip(trainable_mask, sample_input)):
                segment_type = "trainable" if is_trainable else "ignore"

                if current_type != segment_type:
                    if current_segment:
                        segments.append({
                            'type': current_type,
                            'start': current_segment[0]['pos'],
                            'end': current_segment[-1]['pos'],
                            'length': len(current_segment),
                            'tokens': [item['token'] for item in current_segment]
                        })
                    current_segment = []
                    current_type = segment_type

                current_segment.append({
                    'pos': i,
                    'token': token_id
                })

            # Flush the trailing segment.
            if current_segment:
                segments.append({
                    'type': current_type,
                    'start': current_segment[0]['pos'],
                    'end': current_segment[-1]['pos'],
                    'length': len(current_segment),
                    'tokens': [item['token'] for item in current_segment]
                })

            # Log each segment with its decoded text.
            self.loggers["labels"].info(f" 📝 对话分段详情:")
            for seg_idx, segment in enumerate(segments):
                segment_text = self._decode_tokens(segment['tokens'])
                segment_type_emoji = "🎯" if segment['type'] == 'trainable' else "🚫"
                self.loggers["labels"].info(f" 分段{seg_idx+1}: {segment_type_emoji} {segment['type']} 位置{segment['start']}-{segment['end']} 长度{segment['length']}")
                self.loggers["labels"].info(f" 文本: {segment_text}")

            # Summarize segment counts; >1 trainable segment implies multi-turn data.
            trainable_segments = [s for s in segments if s['type'] == 'trainable']
            ignore_segments = [s for s in segments if s['type'] == 'ignore']

            self.loggers["labels"].info(f" 📊 对话模式分析:")
            self.loggers["labels"].info(f" 训练分段数: {len(trainable_segments)}")
            self.loggers["labels"].info(f" 忽略分段数: {len(ignore_segments)}")

            if len(trainable_segments) > 1:
                self.loggers["labels"].info(f" 💬 多轮对话检测: 发现{len(trainable_segments)}个训练分段")
                for i, seg in enumerate(trainable_segments):
                    self.loggers["labels"].info(f" 轮次{i+1}: 位置{seg['start']}-{seg['end']} 长度{seg['length']}")

        except Exception as e:
            self.loggers["labels"].error(f"❌ 对话分段分析失败: {e}")
| |
    def _analyze_prediction_text_changes(self, current_text, sample_idx):
        """Compare the decoded prediction text against the previous logged step.

        Logs whether the text changed, its similarity to the previous version,
        and a diff; then stores the current text for the next comparison.
        """
        try:
            if hasattr(self, 'previous_prediction_texts') and sample_idx in self.previous_prediction_texts:
                previous_text = self.previous_prediction_texts[sample_idx]

                if previous_text != current_text:
                    self.loggers["labels"].info(f"\n📝 步骤 {self.step_count} 预测文本变化:")
                    self.loggers["labels"].info(f" 🔄 文本发生变化!")

                    # Similarity in percent, via difflib ratio.
                    similarity = self._calculate_text_similarity(previous_text, current_text)
                    self.loggers["labels"].info(f" 📊 文本相似度: {similarity:.2f}%")

                    self._show_text_differences(previous_text, current_text)
                else:
                    self.loggers["labels"].info(f"\n📝 步骤 {self.step_count} 预测文本变化:")
                    self.loggers["labels"].info(f" ✅ 文本未发生变化")

            # Remember the current text for the next logged step.
            if not hasattr(self, 'previous_prediction_texts'):
                self.previous_prediction_texts = {}
            self.previous_prediction_texts[sample_idx] = current_text

        except Exception as e:
            self.loggers["labels"].error(f"❌ 预测文本变化分析失败: {e}")
| |
| def _calculate_text_similarity(self, text1, text2): |
| """计算两个文本的相似度""" |
| try: |
| |
| if len(text1) == 0 and len(text2) == 0: |
| return 100.0 |
| if len(text1) == 0 or len(text2) == 0: |
| return 0.0 |
| |
| |
| from difflib import SequenceMatcher |
| similarity = SequenceMatcher(None, text1, text2).ratio() |
| return similarity * 100 |
| except Exception: |
| return 0.0 |
| |
    def _show_text_differences(self, old_text, new_text):
        """Log a unified diff between the previous and current prediction text."""
        try:
            from difflib import unified_diff

            self.loggers["labels"].info(f" 📋 文本变化详情:")
            self.loggers["labels"].info(f" 之前: {old_text}")
            self.loggers["labels"].info(f" 现在: {new_text}")

            diff_lines = list(unified_diff(
                old_text.splitlines(keepends=True),
                new_text.splitlines(keepends=True),
                fromfile='之前',
                tofile='现在',
                lineterm=''
            ))

            if diff_lines:
                self.loggers["labels"].info(f" 🔍 差异分析:")
                # Cap output at the first 10 diff lines to keep logs bounded.
                for line in diff_lines[:10]:
                    if line.startswith('+'):
                        self.loggers["labels"].info(f" ➕ 新增: {line[1:].strip()}")
                    elif line.startswith('-'):
                        self.loggers["labels"].info(f" ➖ 删除: {line[1:].strip()}")
                    elif line.startswith('@@'):
                        self.loggers["labels"].info(f" 📍 {line.strip()}")

            # Follow up with coarse add/remove/modify character counts.
            self._analyze_change_types(old_text, new_text)

        except Exception as e:
            self.loggers["labels"].error(f"❌ 文本差异分析失败: {e}")
| |
| def _analyze_change_types(self, old_text, new_text): |
| """分析变化类型""" |
| try: |
| changes = { |
| 'added_chars': 0, |
| 'removed_chars': 0, |
| 'modified_chars': 0 |
| } |
| |
| |
| if len(new_text) > len(old_text): |
| changes['added_chars'] = len(new_text) - len(old_text) |
| elif len(new_text) < len(old_text): |
| changes['removed_chars'] = len(old_text) - len(new_text) |
| |
| |
| min_len = min(len(old_text), len(new_text)) |
| for i in range(min_len): |
| if old_text[i] != new_text[i]: |
| changes['modified_chars'] += 1 |
| |
| self.loggers["labels"].info(f" 📈 变化统计:") |
| self.loggers["labels"].info(f" 新增字符: {changes['added_chars']}") |
| self.loggers["labels"].info(f" 删除字符: {changes['removed_chars']}") |
| self.loggers["labels"].info(f" 修改字符: {changes['modified_chars']}") |
| |
| except Exception as e: |
| self.loggers["labels"].error(f"❌ 变化类型分析失败: {e}") |
| |
    @override
    def on_step_end(self, args, state, control, **kwargs):
        """Called at the end of every training step; records step metadata."""
        self.step_count += 1

        # Emit a monitoring banner every `log_interval` steps.
        if self.step_count % self.log_interval == 0:
            self.loggers["labels"].info(f"\n{'='*80}")
            self.loggers["labels"].info(f"🔄 训练步骤 {state.global_step} 监控")
            self.loggers["labels"].info(f"{'='*80}")

            # Report the most recent loss if the trainer has logged anything yet.
            if state.log_history:
                latest_log = state.log_history[-1]
                self.loggers["labels"].info(f"📈 当前Loss: {latest_log.get('loss', 'N/A')}")

        # Record per-step metadata for the training summary.
        # NOTE(review): `training_time` is not a standard TrainerState attribute;
        # getattr falls back to None — confirm whether it is ever populated.
        from datetime import datetime
        step_info = {
            "step": state.global_step,
            "timestamp": datetime.now().isoformat(),
            "loss": state.log_history[-1].get('loss') if state.log_history else None,
            "training_time": getattr(state, 'training_time', None)
        }

        self.training_history.append(step_info)
| |
| @override |
| def on_log(self, args, state, control, **kwargs): |
| """在日志记录时调用,可以获取到训练数据""" |
| if self.step_count % self.log_interval == 0: |
| |
| if hasattr(kwargs, 'logs') and kwargs['logs']: |
| logs = kwargs['logs'] |
| self.loggers["labels"].info(f"📊 训练日志: {logs}") |
| |
    def analyze_training_tokens(self, model, inputs, labels):
        """Analyze the trainable tokens of the current training batch.

        Dumps (throttled by ``self.log_interval``) the trainable positions and
        token ids of the first sample in the batch and diffs them against the
        tokens captured on the previous call.
        NOTE(review): ``model``, ``attention_mask`` and the ``datetime`` import
        are currently unused in this method.
        """
        import numpy as np
        from datetime import datetime

        # Section banner (only every log_interval steps).
        if self.step_count % self.log_interval == 0:
            self.loggers["labels"].info(f"\n🔍 训练Token分析 - 步骤 {self.step_count}")
            self.loggers["labels"].info(f"{'='*80}")

        # inputs may be the collated batch dict or a bare tensor.
        if isinstance(inputs, dict):
            input_ids = inputs.get('input_ids')
            attention_mask = inputs.get('attention_mask')
        else:
            input_ids = inputs
            attention_mask = None

        # Unwrap one level of nesting (a wrapper object that itself exposes
        # .input_ids) — presumably defensive; TODO confirm which wrapper this targets.
        if hasattr(input_ids, 'input_ids'):

            actual_input_ids = input_ids.input_ids
        else:
            actual_input_ids = input_ids

        # Verbose type/shape dump to help debug unexpected batch layouts.
        if self.step_count % self.log_interval == 0:
            self.loggers["labels"].info(f"🔍 详细调试信息:")
            self.loggers["labels"].info(f" inputs类型: {type(inputs)}")
            if isinstance(inputs, dict):
                self.loggers["labels"].info(f" inputs键: {list(inputs.keys())}")
                for key, value in inputs.items():
                    self.loggers["labels"].info(f" {key}: 类型={type(value)}, 形状={getattr(value, 'shape', 'N/A')}")

            self.loggers["labels"].info(f" input_ids类型: {type(input_ids)}")
            self.loggers["labels"].info(f" actual_input_ids类型: {type(actual_input_ids)}")
            self.loggers["labels"].info(f" labels类型: {type(labels)}")

            if actual_input_ids is not None:
                self.loggers["labels"].info(f" actual_input_ids详细信息:")
                self.loggers["labels"].info(f" 类型: {type(actual_input_ids)}")
                self.loggers["labels"].info(f" 形状: {getattr(actual_input_ids, 'shape', 'N/A')}")
                self.loggers["labels"].info(f" 设备: {getattr(actual_input_ids, 'device', 'N/A')}")
                self.loggers["labels"].info(f" 数据类型: {getattr(actual_input_ids, 'dtype', 'N/A')}")

            if labels is not None:
                self.loggers["labels"].info(f" labels详细信息:")
                self.loggers["labels"].info(f" 类型: {type(labels)}")
                self.loggers["labels"].info(f" 形状: {getattr(labels, 'shape', 'N/A')}")
                self.loggers["labels"].info(f" 设备: {getattr(labels, 'device', 'N/A')}")
                self.loggers["labels"].info(f" 数据类型: {getattr(labels, 'dtype', 'N/A')}")

            # Probe indexability; failures are logged, never raised.
            self.loggers["labels"].info(f"🔍 尝试访问方式:")
            try:
                if actual_input_ids is not None:
                    self.loggers["labels"].info(f" actual_input_ids[0] 类型: {type(actual_input_ids[0])}")
                    if hasattr(actual_input_ids[0], 'shape'):
                        self.loggers["labels"].info(f" actual_input_ids[0] 形状: {actual_input_ids[0].shape}")
            except Exception as e:
                self.loggers["labels"].error(f" ❌ actual_input_ids[0] 访问失败: {e}")

            try:
                if labels is not None:
                    self.loggers["labels"].info(f" labels[0] 类型: {type(labels[0])}")
                    if hasattr(labels[0], 'shape'):
                        self.loggers["labels"].info(f" labels[0] 形状: {labels[0].shape}")
            except Exception as e:
                self.loggers["labels"].error(f" ❌ labels[0] 访问失败: {e}")

        if actual_input_ids is not None and labels is not None:

            batch_size = len(actual_input_ids)

            # Only the first sample is analyzed — note min(batch_size, 1).
            for i in range(min(batch_size, 1)):
                try:
                    # Convert tensors to CPU numpy arrays for analysis.
                    if isinstance(actual_input_ids, torch.Tensor):
                        sample_input = actual_input_ids[i].cpu().numpy()
                    else:
                        sample_input = actual_input_ids[i]

                    if isinstance(labels, torch.Tensor):
                        sample_labels = labels[i].cpu().numpy()
                    else:
                        sample_labels = labels[i]
                except Exception as e:
                    self.loggers["labels"].error(f"❌ 访问样本 {i} 失败: {e}")
                    continue

                # Positions whose label != -100 are the ones that contribute to the loss.
                trainable_mask = sample_labels != -100
                trainable_positions = np.where(trainable_mask)[0]
                trainable_tokens = sample_labels[trainable_mask]

                # Detailed per-sample dump (throttled).
                if self.step_count % self.log_interval == 0:
                    self.loggers["labels"].info(f"\n📝 样本 {i+1} 训练Token详情:")
                    self.loggers["labels"].info(f" 📏 总长度: {len(sample_input)}")
                    self.loggers["labels"].info(f" 🎯 可训练长度: {len(trainable_positions)}")
                    self.loggers["labels"].info(f" 📍 可训练位置: {trainable_positions.tolist()}")
                    self.loggers["labels"].info(f" 🔤 要训练的Token ID: {trainable_tokens.tolist()}")

                    # Decode label tokens to text for readability.
                    trainable_text = self._decode_tokens(trainable_tokens.tolist())
                    self.loggers["labels"].info(f" 🔤 要训练的Token文本: {trainable_text}")

                    # Input tokens at the same (trainable) positions.
                    trainable_input_tokens = sample_input[trainable_mask]
                    self.loggers["labels"].info(f" 📥 对应的Input Token ID: {trainable_input_tokens.tolist()}")

                    input_text = self._decode_tokens(trainable_input_tokens.tolist())
                    self.loggers["labels"].info(f" 📥 对应的Input Token文本: {input_text}")

                    # Preview long sequences as head + "..." + tail to bound log size.
                    if len(sample_input) > 200:
                        input_preview = sample_input[:100].tolist() + ["..."] + sample_input[-100:].tolist()
                        labels_preview = sample_labels[:100].tolist() + ["..."] + sample_labels[-100:].tolist()
                    else:
                        input_preview = sample_input.tolist()
                        labels_preview = sample_labels.tolist()

                    self.loggers["labels"].info(f" 📋 完整Input IDs (预览): {input_preview}")
                    self.loggers["labels"].info(f" 🏷️ 完整Labels (预览): {labels_preview}")

                # Diff against the tokens captured on the previous call for this sample
                # index (only when lengths match — otherwise a diff is meaningless).
                if hasattr(self, 'previous_tokens') and i in self.previous_tokens:
                    prev_tokens = self.previous_tokens[i]
                    if len(prev_tokens) == len(trainable_tokens):
                        changes = np.sum(prev_tokens != trainable_tokens)

                        if changes > 0:
                            change_positions = np.where(prev_tokens != trainable_tokens)[0]
                            self.loggers["labels"].info(f"\n🔄 步骤 {self.step_count} Token变化:")
                            self.loggers["labels"].info(f" 📊 变化数量: {changes}/{len(trainable_tokens)}")
                            self.loggers["labels"].info(f" 📍 变化位置: {change_positions.tolist()}")

                            # Per-position before/after token ids with decoded text.
                            for pos in change_positions:
                                prev_token = prev_tokens[pos]
                                curr_token = trainable_tokens[pos]
                                prev_text = self._decode_tokens([prev_token])
                                curr_text = self._decode_tokens([curr_token])
                                self.loggers["labels"].info(f" 位置{pos}: {prev_token}({prev_text}) -> {curr_token}({curr_text})")

                # Remember the current tokens for the next call's diff (lazy init).
                if not hasattr(self, 'previous_tokens'):
                    self.previous_tokens = {}
                self.previous_tokens[i] = trainable_tokens.copy()
| |
| @override |
| def on_evaluate(self, args, state, control, **kwargs): |
| """在评估时调用""" |
| self.loggers["predictions"].info(f"\n{'='*80}") |
| self.loggers["predictions"].info(f"📊 评估阶段监控 - 步骤 {state.global_step}") |
| self.loggers["predictions"].info(f"{'='*80}") |
| |
| |
| if hasattr(kwargs, 'predict_results') and kwargs['predict_results'] is not None: |
| self._analyze_predictions(kwargs['predict_results'], state.global_step) |
| |
| @override |
| def on_predict(self, args, state, control, **kwargs): |
| """在预测时调用""" |
| self.loggers["predictions"].info(f"\n{'='*80}") |
| self.loggers["predictions"].info(f"🔮 预测阶段监控 - 步骤 {state.global_step}") |
| self.loggers["predictions"].info(f"{'='*80}") |
| |
| |
| predict_results = kwargs.get('predict_results') |
| if predict_results is not None: |
| self._analyze_predictions(predict_results, state.global_step) |
| |
    def _analyze_predictions(self, predict_results, step: int):
        """Analyze prediction output: per-sample label/prediction comparison.

        Logs the first few samples and appends one structured record per sample
        to ``self.prediction_history`` (later serialized in ``on_train_end``).
        NOTE(review): ``predict_results`` is assumed to expose ``.predictions``
        and ``.label_ids`` (PredictionOutput-like) — confirm at the call sites.
        """
        import numpy as np
        from datetime import datetime

        self.loggers["predictions"].info(f"📊 预测结果分析 - 步骤 {step}")

        predictions = predict_results.predictions
        labels = predict_results.label_ids

        # Nothing to do without both predictions and labels.
        if predictions is None or labels is None:
            self.loggers["predictions"].warning("⚠️ 预测结果或标签为空")
            return

        # Normalize to numpy for uniform downstream handling.
        if isinstance(predictions, torch.Tensor):
            predictions = predictions.cpu().numpy()
        if isinstance(labels, torch.Tensor):
            labels = labels.cpu().numpy()

        batch_size = len(predictions)
        self.loggers["predictions"].info(f"📦 批次大小: {batch_size}")

        # Only the first 3 samples, to keep the log bounded.
        for i in range(min(batch_size, 3)):
            self.loggers["predictions"].info(f"\n🔍 样本 {i+1} 分析:")

            pred_sample = predictions[i]
            label_sample = labels[i]

            # Trim leading/trailing -100 padding from both sequences.
            pred_sample = self._remove_padding(pred_sample)
            label_sample = self._remove_padding(label_sample)

            self.loggers["labels"].info(f"\n📝 样本 {i+1} 标签分析:")
            self.loggers["labels"].info(f" 🎯 标签长度: {len(label_sample)}")
            self.loggers["labels"].info(f" 🔮 预测长度: {len(pred_sample)}")

            # Position-by-position alignment statistics.
            alignment_analysis = self._analyze_alignment(pred_sample, label_sample)
            self.loggers["alignment"].info(f"\n🎯 样本 {i+1} 对齐分析:")
            self.loggers["alignment"].info(f" 📏 长度差异: {alignment_analysis['length_difference']}")
            self.loggers["alignment"].info(f" 🎯 精确匹配: {alignment_analysis['exact_match_percentage']:.1f}%")
            self.loggers["alignment"].info(f" ✅ 有效匹配: {alignment_analysis['valid_match_percentage']:.1f}%")

            # Persist a structured record for the end-of-training summary JSON.
            analysis = {
                "step": step,
                "sample_idx": i,
                "timestamp": datetime.now().isoformat(),
                "predictions": pred_sample.tolist(),
                "labels": label_sample.tolist(),
                "alignment_analysis": alignment_analysis
            }

            self.prediction_history.append(analysis)
| |
| def _remove_padding(self, tokens, pad_token_id: int = -100): |
| """移除padding tokens""" |
| import numpy as np |
| |
| non_pad_mask = tokens != pad_token_id |
| if np.any(non_pad_mask): |
| |
| first_non_pad = np.argmax(non_pad_mask) |
| last_non_pad = len(tokens) - 1 - np.argmax(non_pad_mask[::-1]) |
| return tokens[first_non_pad:last_non_pad+1] |
| else: |
| return tokens |
| |
| def _analyze_alignment(self, predictions, labels): |
| """分析预测和标签的对齐情况""" |
| |
| min_len = min(len(predictions), len(labels)) |
| max_len = max(len(predictions), len(labels)) |
| |
| |
| exact_matches = 0 |
| valid_matches = 0 |
| |
| for i in range(min_len): |
| if predictions[i] == labels[i]: |
| exact_matches += 1 |
| if predictions[i] != -100 and labels[i] != -100: |
| valid_matches += 1 |
| |
| |
| exact_match_percentage = (exact_matches / min_len * 100) if min_len > 0 else 0 |
| valid_match_percentage = (exact_matches / valid_matches * 100) if valid_matches > 0 else 0 |
| |
| return { |
| "min_length": min_len, |
| "max_length": max_len, |
| "exact_matches": exact_matches, |
| "valid_matches": valid_matches, |
| "exact_match_percentage": exact_match_percentage, |
| "valid_match_percentage": valid_match_percentage, |
| "length_difference": abs(len(predictions) - len(labels)) |
| } |
| |
| @override |
| def on_train_end(self, args, state, control, **kwargs): |
| """训练结束时调用""" |
| from datetime import datetime |
| |
| self.loggers["labels"].info(f"\n{'='*80}") |
| self.loggers["labels"].info(f"🏁 训练结束监控") |
| self.loggers["labels"].info(f"{'='*80}") |
| |
| |
| summary_data = { |
| "training_info": { |
| "total_steps": len(self.training_history), |
| "total_predictions": len(self.prediction_history), |
| "completion_time": datetime.now().isoformat(), |
| "output_dir": self.output_dir |
| }, |
| "training_history": self.training_history, |
| "prediction_history": self.prediction_history, |
| "log_files": self.log_files |
| } |
| |
| with open(self.log_files["summary"], "w", encoding="utf-8") as f: |
| json.dump(summary_data, f, ensure_ascii=False, indent=2) |
| |
| self.loggers["labels"].info(f"📊 训练摘要已保存: {self.log_files['summary']}") |
| self.loggers["labels"].info(f"📝 标签分析日志: {self.log_files['labels']}") |
| self.loggers["labels"].info(f"🔮 预测监控日志: {self.log_files['predictions']}") |
| self.loggers["labels"].info(f"🎯 对齐分析日志: {self.log_files['alignment']}") |