File size: 2,952 Bytes
db704cb |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 |
#!/usr/bin/env python3
"""调试 dolci_200k_fixed 数据被 drop 的原因"""
import sys
sys.path.insert(0, '/shared_workspace_mfs/ximing/LLaMA-Factory/src')
from llamafactory.data import get_dataset
from llamafactory.hparams import get_train_args
# Same configuration that the training run uses, so that the dataset goes
# through the same loading and preprocessing path here as it does in training.
args_dict = {
    # --- model ---
    "model_name_or_path": "/shared_workspace_mfs/ximing/Qwen3-8B-Base",
    "trust_remote_code": True,
    "rope_scaling": "yarn",

    # --- training method ---
    "stage": "sft",
    "do_train": True,
    "finetuning_type": "full",

    # --- dataset and preprocessing ---
    "dataset": "dolci_200k_fixed",
    "template": "qwen3_nothink",
    "cutoff_len": 65536,
    "max_samples": 200000,
    "overwrite_cache": True,  # force re-processing instead of reusing a cache
    "preprocessing_num_workers": 16,

    # --- output ---
    "output_dir": "/tmp/test_output",
}
print("="*80)
print("调试 dolci_200k_fixed 数据加载")
print("="*80)
try:
# 获取训练参数
print("\n正在加载参数...")
model_args, data_args, training_args, finetuning_args, generating_args = get_train_args(args_dict)
print(f"\n✓ 参数加载成功")
print(f" 数据集: {data_args.dataset}")
print(f" 模板: {data_args.template}")
print(f" cutoff_len: {data_args.cutoff_len}")
print(f" max_samples: {data_args.max_samples}")
# 加载数据集
print(f"\n正在加载数据集...")
dataset_module = get_dataset(model_args, data_args)
# 打印数据集信息
print(f"\n✓ 数据集加载成功!")
if 'train_dataset' in dataset_module:
train_dataset = dataset_module['train_dataset']
print(f" 训练样本数: {len(train_dataset)}")
# 检查样本长度分布
print(f"\n正在分析样本长度分布...")
lengths = []
for i in range(min(1000, len(train_dataset))):
sample = train_dataset[i]
if 'input_ids' in sample:
lengths.append(len(sample['input_ids']))
if lengths:
lengths.sort()
print(f" 前 1000 个样本长度统计:")
print(f" 最小长度: {min(lengths)}")
print(f" 最大长度: {max(lengths)}")
print(f" 平均长度: {sum(lengths)/len(lengths):.2f}")
print(f" 中位数: {lengths[len(lengths)//2]}")
print(f" 超过 cutoff_len (65536) 的样本数: {sum(1 for l in lengths if l > 65536)}")
# 打印第一个样本的详细信息
if len(train_dataset) > 0:
print(f"\n第一个样本的详细信息:")
sample = train_dataset[0]
print(f" Keys: {list(sample.keys())}")
if 'input_ids' in sample:
print(f" input_ids 长度: {len(sample['input_ids'])}")
if 'labels' in sample:
print(f" labels 长度: {len(sample['labels'])}")
else:
print(" 没有找到 train_dataset")
except Exception as e:
print(f"\n✗ 错误:")
print(f"{type(e).__name__}: {e}")
import traceback
traceback.print_exc()
|