#!/usr/bin/env python3
"""Debug why samples from the dolci_200k_fixed dataset get dropped.

Loads the dataset through LLaMA-Factory with the same arguments used for
training, then prints the resulting sample count, a length distribution
for the first few tokenized samples, and the fields of the first sample.
"""

import sys
import traceback

# Local LLaMA-Factory checkout that provides the ``llamafactory`` package.
LLAMA_FACTORY_SRC = "/shared_workspace_mfs/ximing/LLaMA-Factory/src"

# Number of leading training samples inspected for the length statistics.
NUM_INSPECTED = 1000

# Same configuration as the training run.
ARGS_DICT = {
    "model_name_or_path": "/shared_workspace_mfs/ximing/Qwen3-8B-Base",
    "trust_remote_code": True,
    "stage": "sft",
    "do_train": True,
    "finetuning_type": "full",
    "dataset": "dolci_200k_fixed",
    "template": "qwen3_nothink",
    "cutoff_len": 65536,
    "rope_scaling": "yarn",
    "max_samples": 200000,
    "overwrite_cache": True,
    "preprocessing_num_workers": 16,
    "output_dir": "/tmp/test_output",
}


def summarize_lengths(lengths, cutoff_len):
    """Return basic statistics for a collection of sequence lengths.

    Args:
        lengths: Iterable of ints, one token count per sample.
        cutoff_len: Threshold; lengths strictly greater than it are counted
            in ``over_cutoff``.

    Returns:
        ``None`` when ``lengths`` is empty. Otherwise a dict with the keys
        ``count``, ``min``, ``max``, ``mean``, ``median`` and
        ``over_cutoff``. ``median`` is the element at index ``n // 2`` of
        the sorted lengths (the upper median for even ``n``).
    """
    ordered = sorted(lengths)
    if not ordered:
        return None
    n = len(ordered)
    return {
        "count": n,
        "min": ordered[0],
        "max": ordered[-1],
        "mean": sum(ordered) / n,
        "median": ordered[n // 2],
        "over_cutoff": sum(1 for length in ordered if length > cutoff_len),
    }


def report_train_dataset(train_dataset, cutoff_len, num_inspected=NUM_INSPECTED):
    """Print the sample count, a length distribution and the first sample.

    Args:
        train_dataset: Indexable, sized collection whose items are mappings
            (e.g. a Hugging Face ``Dataset``). Items may contain
            ``input_ids`` and ``labels`` sequences.
        cutoff_len: Configured maximum sequence length, reported alongside
            the number of inspected samples that exceed it.
        num_inspected: How many leading samples feed the length statistics.
    """
    total = len(train_dataset)
    print(f"  训练样本数: {total}")

    print("\n正在分析样本长度分布...")
    n_inspected = min(num_inspected, total)
    lengths = []
    for i in range(n_inspected):
        sample = train_dataset[i]
        if "input_ids" in sample:
            lengths.append(len(sample["input_ids"]))

    stats = summarize_lengths(lengths, cutoff_len)
    if stats is not None:
        # Report the number actually inspected, which can be < num_inspected.
        print(f"  前 {n_inspected} 个样本长度统计:")
        print(f"    最小长度: {stats['min']}")
        print(f"    最大长度: {stats['max']}")
        print(f"    平均长度: {stats['mean']:.2f}")
        print(f"    中位数: {stats['median']}")
        print(f"    超过 cutoff_len ({cutoff_len}) 的样本数: {stats['over_cutoff']}")

    if total > 0:
        print("\n第一个样本的详细信息:")
        sample = train_dataset[0]
        print(f"  Keys: {list(sample.keys())}")
        if "input_ids" in sample:
            print(f"  input_ids 长度: {len(sample['input_ids'])}")
        if "labels" in sample:
            print(f"  labels 长度: {len(sample['labels'])}")


def main():
    """Load the dataset with the training configuration and print diagnostics.

    Returns:
        Process exit code: 0 on success, 1 if argument parsing or dataset
        loading raised an exception (the traceback is printed first).
    """
    # Make the local LLaMA-Factory checkout importable before importing it.
    sys.path.insert(0, LLAMA_FACTORY_SRC)
    from llamafactory.data import get_dataset, get_template_and_fix_tokenizer
    from llamafactory.hparams import get_train_args
    from llamafactory.model import load_tokenizer

    print("=" * 80)
    print("调试 dolci_200k_fixed 数据加载")
    print("=" * 80)

    try:
        print("\n正在加载参数...")
        model_args, data_args, training_args, finetuning_args, _ = get_train_args(ARGS_DICT)
        print("\n✓ 参数加载成功")
        print(f"  数据集: {data_args.dataset}")
        print(f"  模板: {data_args.template}")
        print(f"  cutoff_len: {data_args.cutoff_len}")
        print(f"  max_samples: {data_args.max_samples}")

        print("\n正在加载数据集...")
        # get_dataset needs the chat template and tokenizer, not just the
        # argument objects. This mirrors the call sequence in LLaMA-Factory's
        # SFT workflow -- NOTE(review): confirm against the installed version.
        tokenizer_module = load_tokenizer(model_args)
        template = get_template_and_fix_tokenizer(tokenizer_module["tokenizer"], data_args)
        dataset_module = get_dataset(
            template,
            model_args,
            data_args,
            training_args,
            stage=finetuning_args.stage,
            **tokenizer_module,
        )

        print("\n✓ 数据集加载成功!")
        if "train_dataset" in dataset_module:
            # Use the configured cutoff rather than a hard-coded value.
            report_train_dataset(dataset_module["train_dataset"], data_args.cutoff_len)
        else:
            print("  没有找到 train_dataset")
    except Exception as e:  # top-level boundary of a debug script
        print("\n✗ 错误:")
        print(f"{type(e).__name__}: {e}")
        traceback.print_exc()
        return 1
    return 0


if __name__ == "__main__":
    sys.exit(main())