|
|
|
|
|
"""Debug why samples from the dolci_200k_fixed dataset get dropped during LLaMA-Factory preprocessing."""
|
|
|
|
|
import sys |
|
|
sys.path.insert(0, '/shared_workspace_mfs/ximing/LLaMA-Factory/src') |
|
|
|
|
|
from llamafactory.data import get_dataset |
|
|
from llamafactory.hparams import get_train_args |
|
|
|
|
|
|
|
|
# Arguments handed to LLaMA-Factory's get_train_args(); they reproduce the SFT
# run whose dataset is being debugged (model, template, length limit, sampling).
args_dict = dict(
    model_name_or_path="/shared_workspace_mfs/ximing/Qwen3-8B-Base",
    trust_remote_code=True,
    stage="sft",
    do_train=True,
    finetuning_type="full",
    dataset="dolci_200k_fixed",
    template="qwen3_nothink",
    cutoff_len=65536,
    rope_scaling="yarn",
    max_samples=200_000,
    overwrite_cache=True,
    preprocessing_num_workers=16,
    output_dir="/tmp/test_output",
)

# Banner: a rule line, the title, and another rule line.
print("\n".join(["=" * 80, "调试 dolci_200k_fixed 数据加载", "=" * 80]))
|
|
|
|
|
# Main debugging flow: parse the arguments, run LLaMA-Factory's dataset
# pipeline, then report how many training samples survived and how long they
# are. Any failure is printed with a traceback and the script exits with
# status 1 so callers (shell, CI) can see that it failed.
try:
    # Function-scope imports: only needed here, and keeping them inside the
    # try means an API mismatch is reported by the handler below.
    from llamafactory.data import get_template_and_fix_tokenizer
    from llamafactory.model import load_tokenizer

    print("\n正在加载参数...")
    # get_train_args returns five argument groups; only the first three are used.
    model_args, data_args, training_args, _finetuning_args, _generating_args = get_train_args(
        args_dict
    )

    print("\n✓ 参数加载成功")
    print(f" 数据集: {data_args.dataset}")
    print(f" 模板: {data_args.template}")
    print(f" cutoff_len: {data_args.cutoff_len}")
    print(f" max_samples: {data_args.max_samples}")

    print("\n正在加载数据集...")
    # get_dataset needs the chat template, training args, stage and tokenizer
    # in addition to model/data args; build them the same way the SFT
    # workflow does (tokenizer module -> template -> dataset).
    tokenizer_module = load_tokenizer(model_args)
    template = get_template_and_fix_tokenizer(tokenizer_module["tokenizer"], data_args)
    dataset_module = get_dataset(
        template, model_args, data_args, training_args, stage="sft", **tokenizer_module
    )

    print("\n✓ 数据集加载成功!")

    if "train_dataset" in dataset_module:
        train_dataset = dataset_module["train_dataset"]
        print(f" 训练样本数: {len(train_dataset)}")

        # Only a prefix of the dataset is inspected to keep the run quick.
        num_probe = min(1000, len(train_dataset))
        cutoff_len = data_args.cutoff_len

        print("\n正在分析样本长度分布...")
        lengths = []
        for i in range(num_probe):
            sample = train_dataset[i]
            if "input_ids" in sample:
                lengths.append(len(sample["input_ids"]))

        if lengths:
            lengths.sort()
            over_cutoff = sum(1 for n in lengths if n > cutoff_len)
            # NOTE(review): preprocessing appears to truncate sequences to
            # cutoff_len, so the "> cutoff_len" count is expected to be 0;
            # samples sitting exactly at the limit are the likely-truncated
            # ones. Confirm against the installed LLaMA-Factory version.
            at_cutoff = sum(1 for n in lengths if n >= cutoff_len)
            print(f" 前 {num_probe} 个样本长度统计:")
            print(f" 最小长度: {lengths[0]}")
            print(f" 最大长度: {lengths[-1]}")
            print(f" 平均长度: {sum(lengths) / len(lengths):.2f}")
            # Upper median for even-sized samples (index len // 2 of the sorted list).
            print(f" 中位数: {lengths[len(lengths) // 2]}")
            print(f" 超过 cutoff_len ({cutoff_len}) 的样本数: {over_cutoff}")
            print(f" 达到 cutoff_len ({cutoff_len}) 的样本数 (可能被截断): {at_cutoff}")

        if len(train_dataset) > 0:
            print("\n第一个样本的详细信息:")
            first_sample = train_dataset[0]
            print(f" Keys: {list(first_sample.keys())}")
            if "input_ids" in first_sample:
                print(f" input_ids 长度: {len(first_sample['input_ids'])}")
            if "labels" in first_sample:
                print(f" labels 长度: {len(first_sample['labels'])}")
    else:
        print(" 没有找到 train_dataset")

except Exception as e:
    import traceback

    # Top-level boundary of a debug script: report everything, then fail loudly.
    print("\n✗ 错误:")
    print(f"{type(e).__name__}: {e}")
    traceback.print_exc()
    sys.exit(1)
|
|
|