#!/usr/bin/env python # -*- coding:utf-8 _*- import importlib import json from transformers import Trainer, TrainingArguments from transformers.trainer_utils import EvalPrediction from transformers.trainer_callback import TrainerCallback import os import torch from models.TimeLanguageModel import TLM, TLMConfig from dataset.dataset import DataCollator from typing import Dict, List, Any, NamedTuple, Optional, Tuple, Union from datasets import load_metric import numpy as np from utils.metrics import open_question_metrics,closed_question_metrics,compute_rul import warnings from tqdm import tqdm import pickle from torch import nn import pandas as pd import matplotlib.pyplot as plt warnings.filterwarnings("ignore") from accelerate import Accelerator accelerator = Accelerator() import torch.distributed as dist from datetime import datetime from contextlib import nullcontext def distributed_tqdm(iterable, desc=None): if not dist.is_initialized() or dist.get_rank() == 0: return tqdm(iterable, desc=desc) else: return iterable class OutputWrapper: def __init__(self, original_output): self.original_output = original_output def __getattr__(self, name): # 如果属性不存在于自身,则尝试从原始对象中获取 return getattr(self.original_output, name) class EvalLoopOutput(NamedTuple): predictions: Union[np.ndarray, Tuple[np.ndarray]] label_ids: Optional[Union[np.ndarray, Tuple[np.ndarray]]] metrics: Optional[Dict[str, float]] num_samples: Optional[int] pred_extra: Optional[Dict[str, Any]] = None class Exp_Instruct(Trainer): def __init__(self, args, train_dataset, tlm_config=None, eval_dataset=None): # Build the model self.tlmconfig = tlm_config model = self._build_model(args) use_bf16 = bool(getattr(args, "bf16", False)) and torch.cuda.is_available() and torch.cuda.is_bf16_supported() use_fp16 = bool(args.fp16) and not use_bf16 # Define training arguments training_args = TrainingArguments( output_dir=args.output_dir, per_device_train_batch_size=args.per_device_train_batch_size, gradient_accumulation_steps=args.gradient_accumulation_steps, dataloader_num_workers = args.dataloader_num_workers, lr_scheduler_type="cosine", warmup_ratio=0.1, per_device_eval_batch_size=args.per_device_eval_batch_size, learning_rate=args.learning_rate, weight_decay=args.weight_decay, logging_dir=args.output_dir, logging_steps=args.logging_steps, save_steps=args.save_steps, eval_strategy='no', eval_steps=args.eval_steps, save_total_limit=args.save_total_limit, ddp_find_unused_parameters=False, fp16=use_fp16, bf16=use_bf16, num_train_epochs=args.num_train_epochs, report_to=args.report_to, # Example: Integrate TensorBoard prediction_loss_only=False, max_grad_norm=float(getattr(args, "max_grad_norm", 0.1)), remove_unused_columns=False, disable_tqdm=False, dataloader_drop_last=True) super().__init__( model=model, args=training_args, train_dataset=train_dataset, data_collator=DataCollator(tokenizer=train_dataset.tokenizer), eval_dataset=eval_dataset, # compute_metrics=self._compute_metrics if eval_dataset else None, ) self.compute_metrics = self.custom_compute_metrics if eval_dataset else None self.special_id = train_dataset.processor.all_special_ids self.processor = train_dataset.processor self.padding_idx = self.processor.pad_token_id # 常用标点符号列表 common_punctuations = [".", ",", ":", ";", "!", "?", "(", ")", "[", "]", "{", "}", "-", "_", "\"", "'"] punctuation_ids = self.processor.convert_tokens_to_ids(common_punctuations) # 将标点符号 ID 合并到特殊标记 ID 列表中 self.special_id.extend(punctuation_ids) self.tlmargs = args # 定义stage权重 self.stage_weights = { 1: 1.0, # 开放式问题 - 基础权重 2: 1.0, # 封闭式问题 - 稍高权重 3: 1.0, # 封闭式问题 - 中等权重 4: 1.0 # 开放式问题 - 稍低权重 } # 初始化损失函数,不使用ignore_index(我们将手动处理) self.base_loss_fn = nn.CrossEntropyLoss(reduction='none', ignore_index=self.padding_idx) # self.args.remove_unused_columns = True # 添加这一行 def load_model(self, checkpoint_path): self.model = TLM.from_pretrained(checkpoint_path, config=self.tlmconfig, ts_config=self.tlmargs).cuda() def _build_model(self, args): """Load the model dynamically based on the configuration.""" # self.tlmconfig = TLMConfig(llm_model_path = args.llm_model_path) model = TLM(self.tlmconfig, ts_config=args).cuda() # monitor = GradientAndActivationMonitor(model,track_outputs=False,verbose=True) return model def concat_np_array(self, array_list,num_samples): """ 对传入的列表进行 Concat 操作。 Args: array_list (List[List[int]]): 每个子列表为需要 Padding 的序列。 num_samples (int): 样本数量。 Returns: np.ndarray: Padding 后的二维数组。 """ # 获取最大长度 max_length = max(arr.shape[-1] for arr in array_list) # 初始化 Padding 后的数组,填充为 padding_idx padded_array = np.full((num_samples, max_length), self.padding_idx, dtype=np.int32) # 填充每个序列 for i, arr in enumerate(array_list): padded_array[:arr.shape[0], :arr.shape[1]] = arr concat_array = np.stack(padded_array, axis=0) return concat_array def debug_generate(self, input_ids, query_ids,ts_values, stage, attention_mask): # 生成阶段 import time start_time = time.time() with torch.no_grad(): output = self.model.generate( input_ids=input_ids, query_ids=query_ids, ts_values=ts_values, stage=stage, past_key_values=None, max_new_tokens=128, do_sample=False, eos_token_id=self.processor.eos_token_id, pad_token_id=self.processor.pad_token_id, attention_mask=attention_mask, use_cache=True, # 新增优化参数 num_beams=1, # 贪婪搜索,最快 temperature=1.0, # 避免额外计算 top_p=None, # 关闭nucleus sampling top_k=None, # 关闭top-k sampling repetition_penalty=1.0, # 关闭重复惩罚 length_penalty=1.0, # 关闭长度惩罚 no_repeat_ngram_size=0, # 关闭n-gram重复检查 output_scores=False, # 不输出分数 output_attentions=False, # 不输出attention output_hidden_states=False, # 不输出隐藏状态 return_dict_in_generate=False, # 简化返回格式 ) return output def generate( self, dataloader, description, prediction_loss_only=None, ignore_keys=None, metric_key_prefix="eval", ): all_predictions = [] all_labels = [] all_losses = [] all_index = [] model = self._wrap_model(self.model, training=False) model.eval() sample_num = len(dataloader.dataset) # forms = [] stages = [] with torch.no_grad(): for step, inputs in enumerate(distributed_tqdm(dataloader, desc=description)): # if step==50: # break input_ids = inputs['input_ids'] ts_values = inputs['ts_values'] stage = inputs['stage'] index = inputs['index'] query_ids = inputs['query_ids'] attention_mask =inputs['attention_mask'] generated_ids = self.debug_generate(input_ids, query_ids,ts_values, stage, attention_mask) prediction = generated_ids.cpu().numpy() all_predictions.extend(prediction) all_labels.extend(inputs["labels"].cpu().numpy()) # forms.extend(inputs['form']) stages.extend(inputs['stage'].tolist()) all_index.extend(inputs['index'].tolist()) filtered_preds, filtered_labels = [], [] str_predictions = self.processor.batch_decode(all_predictions,skip_special_tokens=True) str_labels = self.processor.batch_decode(all_labels,skip_special_tokens=True) #取出assistant\n后的内容 str_predictions = [pred.split('assistant\n')[-1] for pred in str_predictions] output_data = { "predictions": str_predictions, "labels": str_labels, "stages": stages, "index": all_index, "num_samples": sample_num } if accelerator.is_main_process: with open('output_result_all.json', 'w', encoding='utf-8') as f: json.dump(output_data, f, indent=4, ensure_ascii=False) pred_extra = {'stages': stages} avg_loss = np.mean(all_losses) if all_losses else None return EvalLoopOutput(predictions=str_predictions, label_ids=str_labels, metrics=avg_loss, num_samples=sample_num,pred_extra=pred_extra) #写一个过滤str_predictions和str_labels的函数 def evaluate( self, eval_dataset=None, ignore_keys=None, metric_key_prefix="eval", ): eval_dataset = eval_dataset or self.eval_dataset eval_dataloader = self.get_eval_dataloader(eval_dataset) output = self.generate( eval_dataloader, 'eval' ) metrics = self.custom_compute_metrics(output) if accelerator.is_main_process: # 打印到控制台 print(metrics) # 生成时间戳 timestamp = datetime.now().strftime('%Y%m%d_%H%M%S') filename = f'metrics_eval_{timestamp}.txt' # 同时写入文件 with open(filename, 'w', encoding='utf-8') as f: print(metrics, file=f) def custom_compute_metrics(self,eval_pred: EvalLoopOutput) -> Dict[str, Any]: """ 针对 stages 为 1 或 2 的样本,计算 BLEU 和 ROUGE 指标。 Args: eval_pred (EvalPrediction): 包含 predictions 和 labels,以及附加信息 pred_extra。 Returns: Dict[str, Any]: BLEU 和 ROUGE 指标结果字典。 """ # 解析预测和标签 labels = eval_pred.label_ids stages = eval_pred.pred_extra['stages'] # 解析附加信息 # 筛选 stages 为 1 stage1_indices = [i for i, stage in enumerate(stages) if stage in [1]] if len(stage1_indices) >=1: # 提取对应的预测和标签 stage1_labels = [labels[i] for i in stage1_indices] stage1_metrics = open_question_metrics([eval_pred.predictions[i] for i in stage1_indices], stage1_labels,self.special_id) #筛选出stage为2的样本 stage2_indices = [i for i, stage in enumerate(stages) if stage in [2]] if len(stage2_indices) >=1: # 提取对应的预测和标签 stage2_labels = [labels[i] for i in stage2_indices] stage2_predictions = [eval_pred.predictions[i] for i in stage2_indices] stage2_metrics = closed_question_metrics( stage2_predictions, stage2_labels,self.special_id) #筛选出stage为3的样本 stage3_indices = [i for i, stage in enumerate(stages) if stage in [3]] if len(stage3_indices)>=1 : # 提取对应的预测和标签 stage3_labels = [labels[i] for i in stage3_indices] stage3_predictions = [eval_pred.predictions[i] for i in stage3_indices] stage3_metrics = closed_question_metrics( stage3_predictions, stage3_labels,self.special_id) #筛选出stage为4的样本 stage4_indices = [i for i, stage in enumerate(stages) if stage in [4]] if len(stage4_indices) >=1: # 提取对应的预测和标签 stage4_labels = [labels[i] for i in stage4_indices] stage4_metrics = open_question_metrics([eval_pred.predictions[i] for i in stage4_indices], stage4_labels,self.special_id) #合并存在的指标 metrics = {} if stage1_indices: metrics.update({f"stage1_{k}": v for k, v in stage1_metrics.items()}) if stage2_indices: metrics.update({f"stage2_{k}": v for k, v in stage2_metrics.items()}) if stage3_indices: metrics.update({f"stage3_{k}": v for k, v in stage3_metrics.items()}) if stage4_indices: metrics.update({f"stage4_{k}": v for k, v in stage4_metrics.items()}) return metrics def compute_stage_weighted_loss(self, logits, labels, stages, attention_mask=None): """ 修正版本 - 不需要shift,因为Dataset已经处理了 """ batch_size, seq_len, vocab_size = logits.shape # 🔧 不需要shift,直接使用 flat_logits = logits.view(-1, vocab_size) # [batch_size * seq_len, vocab_size] flat_labels = labels.view(-1) # [batch_size * seq_len] # 计算基础损失(padding会被自动ignore) token_losses = self.base_loss_fn(flat_logits, flat_labels) # [batch_size * seq_len] token_losses = token_losses.view(batch_size, seq_len) # [batch_size, seq_len] # 创建有效token掩码 valid_mask = (labels != self.padding_idx).float() # [batch_size, seq_len] # 应用stage权重 stage_weights = torch.tensor([self.stage_weights.get(stage.item(), 1.0) for stage in stages], device=logits.device, dtype=torch.float32) # 计算每个样本的加权损失 sample_losses = [] for i in range(batch_size): valid_tokens = valid_mask[i].sum() # 有效token数量 if valid_tokens > 0: # 🔧 只对有效token计算平均损失 sample_loss = (token_losses[i] * valid_mask[i]).sum() / valid_tokens * stage_weights[i] else: sample_loss = torch.tensor(0.0, device=logits.device) sample_losses.append(sample_loss) return torch.stack(sample_losses).mean() def compute_loss(self, model, inputs, return_outputs=False, **kwargs): """ 内存优化版本的损失计算 """ if self.args.bf16: autocast_context = torch.autocast(device_type="cuda", dtype=torch.bfloat16) elif self.args.fp16: autocast_context = torch.autocast(device_type="cuda", dtype=torch.float16) else: autocast_context = nullcontext() with autocast_context: # 前向传播 outputs = model( input_ids=inputs.get('input_ids'), query_ids=inputs.get('query_ids'), ts_values=inputs.get('ts_values'), stage=inputs.get('stage'), attention_mask=inputs.get('attention_mask'), labels=inputs.get('labels') ) # 获取logits logits = outputs.logits if hasattr(outputs, 'logits') else outputs[0] # 计算损失 loss = self.compute_stage_weighted_loss( logits=logits, labels=inputs.get('labels'), stages=inputs.get('stage'), attention_mask=inputs.get('attention_mask') ) if not torch.isfinite(loss): labels = inputs.get('labels') valid_tokens = int((labels != self.padding_idx).sum().item()) if labels is not None else -1 stage_values = inputs.get('stage').detach().cpu().tolist() if inputs.get('stage') is not None else [] raise RuntimeError( f"Non-finite SFT loss detected: loss={loss.detach().item()}, " f"valid_label_tokens={valid_tokens}, stages={stage_values}" ) if return_outputs: # 清理不必要的输出以节省内存 if hasattr(outputs, 'past_key_values'): outputs.past_key_values = None if hasattr(outputs, 'hidden_states'): outputs.hidden_states = None if hasattr(outputs, 'attentions'): outputs.attentions = None wrapped_outputs = OutputWrapper(outputs) wrapped_outputs.loss = loss return loss, wrapped_outputs return loss def get_stage_loss_statistics(self, dataloader, num_samples=100): """ 分析不同stage的损失分布,用于调整权重 Args: dataloader: 数据加载器 num_samples: 分析的样本数量 Returns: Dict: 包含各stage损失统计信息的字典 """ self.model.eval() stage_losses = {1: [], 2: [], 3: [], 4: []} with torch.no_grad(): for i, inputs in enumerate(dataloader): if i >= num_samples: break # 移动到正确的设备 for key in inputs: if isinstance(inputs[key], torch.Tensor): inputs[key] = inputs[key].to(self.model.device) # 前向传播 outputs = self.model(**inputs) logits = outputs.logits if hasattr(outputs, 'logits') else outputs[0] # 计算每个样本的损失 labels = inputs['labels'] stages = inputs['stage'] attention_mask = inputs.get('attention_mask') batch_size, seq_len, vocab_size = logits.shape flat_logits = logits.view(-1, vocab_size) flat_labels = labels.view(-1) token_losses = self.base_loss_fn(flat_logits, flat_labels) token_losses = token_losses.view(batch_size, seq_len) if attention_mask is not None: valid_mask = attention_mask.bool() else: valid_mask = (labels != self.padding_idx) masked_losses = token_losses * valid_mask.float() valid_token_counts = valid_mask.sum(dim=1).float() valid_token_counts = torch.clamp(valid_token_counts, min=1.0) sample_losses = masked_losses.sum(dim=1) / valid_token_counts # 按stage收集损失 for j, stage in enumerate(stages): stage_val = stage.item() if stage_val in stage_losses: stage_losses[stage_val].append(sample_losses[j].item()) # 计算统计信息 statistics = {} for stage, losses in stage_losses.items(): if losses: statistics[f'stage_{stage}'] = { 'mean': np.mean(losses), 'std': np.std(losses), 'count': len(losses), 'min': np.min(losses), 'max': np.max(losses) } return statistics