llm / extract_exactly_10k.py
dongxx1104's picture
Upload folder using huggingface_hub
db704cb verified
#!/usr/bin/env python3
"""
从 allenai/Dolci-Instruct-SFT-Tool-Use 数据集的最后部分提取恰好 10000 个符合规则的样本
"""
import json
from datasets import load_dataset
from tqdm import tqdm
def convert_to_llamafactory_format(example):
    """Convert one Dolci-format example to the LLaMA-Factory ShareGPT format.

    Role mapping: user -> "human", tool -> "observation", assistant -> a
    "function_call" turn (when tool_calls are present) optionally followed by
    a "gpt" turn (when the assistant also produced text). Any other role is
    silently dropped. A single tool call is serialized as a JSON object, two
    or more as a JSON array.

    Returns a dict with a "conversations" list, plus optional "system" and
    "tools" (JSON-encoded) keys when the example carries them.
    """
    turns = []
    for message in example.get('messages', []):
        role = message.get('role', '')
        text = message.get('content', '')
        calls = message.get('tool_calls', [])
        if role == 'user':
            turns.append({"from": "human", "value": text})
        elif role == 'assistant':
            if calls:
                parsed = []
                for call in calls:
                    fn = call.get('function', {})
                    args = fn.get('arguments', '')
                    # Arguments may arrive as a JSON string or an already-decoded object.
                    if isinstance(args, str):
                        args = json.loads(args)
                    parsed.append({"name": fn.get('name', ''), "arguments": args})
                # One call -> object payload, several -> array payload.
                payload = parsed[0] if len(parsed) == 1 else parsed
                turns.append({"from": "function_call", "value": json.dumps(payload)})
            # Emit the textual reply only when it is non-blank.
            if text and text.strip():
                turns.append({"from": "gpt", "value": text})
        elif role == 'tool':
            turns.append({"from": "observation", "value": text})
    result = {"conversations": turns}
    if 'system' in example:
        result['system'] = example['system']
    # Only attach tools when the field exists and is non-empty.
    if example.get('tools'):
        result['tools'] = json.dumps(example['tools'])
    return result
def validate_position_rules(conversations):
    """Check the ShareGPT alternation rule over a conversation.

    Rule (1-based positions): odd positions must be 'human' or 'observation';
    even positions must be 'gpt' or 'function_call'.

    Returns (True, "Valid") when every turn conforms, otherwise
    (False, <message describing the first violation>).
    """
    for position, turn in enumerate(conversations, start=1):
        role = turn['from']
        is_odd = position % 2 == 1
        allowed = ('human', 'observation') if is_odd else ('gpt', 'function_call')
        if role not in allowed:
            expectation = ("'human' or 'observation'" if is_odd
                           else "'gpt' or 'function_call'")
            return False, f"Position {position} should be {expectation}, but got '{role}'"
    return True, "Valid"
def main():
    """Extract exactly 10,000 valid samples from the tail of the Dolci
    tool-use dataset, convert them to LLaMA-Factory ShareGPT format, and
    save them as a JSON file.

    Samples are taken from the end of the train split; if filtering leaves
    fewer than 10,000, earlier batches are processed until the target is
    reached (or the dataset is exhausted).
    """
    print("Loading Dolci-Instruct-SFT-Tool-Use dataset...")
    dataset = load_dataset("allenai/Dolci-Instruct-SFT-Tool-Use", split="train")
    total_samples = len(dataset)
    print(f"Total samples in dataset: {total_samples}")

    target_count = 10000
    batch_size = 15000  # oversample so filtering can still yield 10k
    start_idx = max(0, total_samples - batch_size)
    candidates = dataset.select(range(start_idx, total_samples))
    print(f"Processing samples from index {start_idx} to {total_samples} ({len(candidates)} samples)")

    print("Converting to LLaMA-Factory ShareGPT format and filtering...")
    converted_data = []
    skipped = 0
    invalid_position = 0
    # Fix: track the processed count explicitly so it is defined even when
    # `candidates` is empty (the original used `idx` after the loop, which
    # would raise NameError in that case).
    processed = 0
    for idx, example in enumerate(tqdm(candidates)):
        processed = idx + 1
        if len(converted_data) >= target_count:
            break  # collected enough
        try:
            converted = convert_to_llamafactory_format(example)
            # Drop samples whose turns violate the alternation rule.
            is_valid, message = validate_position_rules(converted['conversations'])
            if not is_valid:
                invalid_position += 1
                continue
            if converted['conversations']:  # keep only non-empty conversations
                converted_data.append(converted)
        # Fix: narrowed from a broad handler; Exception still covers
        # json.loads failures and malformed records without hiding
        # KeyboardInterrupt/SystemExit.
        except Exception as e:
            print(f"\nError processing sample {idx}: {e}")
            skipped += 1

    print(f"\nConversion completed:")
    print(f"- Total processed: {processed}")
    print(f"- Successfully converted: {len(converted_data)}")
    print(f"- Skipped due to errors: {skipped}")
    print(f"- Invalid position: {invalid_position}")

    # If still short, keep pulling earlier batches until the target is met.
    if len(converted_data) < target_count:
        print(f"\nNeed more samples! Current: {len(converted_data)}, Target: {target_count}")
        print("Processing more samples...")
        while len(converted_data) < target_count and start_idx > 0:
            batch_start = max(0, start_idx - batch_size)
            batch = dataset.select(range(batch_start, start_idx))
            print(f"Processing batch: {batch_start} to {start_idx}")
            for idx, example in enumerate(tqdm(batch)):
                if len(converted_data) >= target_count:
                    break
                try:
                    converted = convert_to_llamafactory_format(example)
                    is_valid, _ = validate_position_rules(converted['conversations'])
                    if is_valid and converted['conversations']:
                        converted_data.append(converted)
                # Fix: was a bare `except: pass`, which swallowed
                # KeyboardInterrupt/SystemExit and lost the error count.
                except Exception:
                    skipped += 1
            start_idx = batch_start

    # Trim to exactly 10,000.
    converted_data = converted_data[:target_count]
    print(f"\nFinal count: {len(converted_data)} samples")

    # Summary statistics over the final selection.
    has_tool_call = sum(1 for sample in converted_data
                        if any(conv['from'] == 'function_call' for conv in sample['conversations']))
    has_tools_field = sum(1 for sample in converted_data if 'tools' in sample and sample['tools'])
    print(f"- Samples with function_call: {has_tool_call}")
    print(f"- Samples with tools field: {has_tools_field}")
    print(f"- Samples without tool calls: {len(converted_data) - has_tool_call}")

    # Persist the result.
    output_file = "/shared_workspace_mfs/ximing/LLaMA-Factory/data/dolci_last_10k_from_20k.json"
    print(f"\nSaving to {output_file}...")
    with open(output_file, 'w', encoding='utf-8') as f:
        json.dump(converted_data, f, ensure_ascii=False, indent=2)
    print(f"Successfully saved {len(converted_data)} samples!")

    # Show one example that contains a tool call (truncated to 1000 chars).
    print("\n" + "=" * 80)
    print("Sample example with tool call:")
    print("=" * 80)
    for sample in converted_data:
        if any(conv['from'] == 'function_call' for conv in sample['conversations']):
            print(json.dumps(sample, ensure_ascii=False, indent=2)[:1000])
            break


if __name__ == "__main__":
    main()