# llm / extract_last_10k_correct.py
# dongxx1104's picture — Upload folder using huggingface_hub (db704cb verified)
#!/usr/bin/env python3
"""
从 allenai/Dolci-Instruct-SFT-Tool-Use 数据集的最后 2w 数据中提取 1w 数据
使用正确的转换逻辑
"""
import json
from datasets import load_dataset
from tqdm import tqdm
import re
import ast
def _dotted_callee_name(node):
    """Return the dotted name of a Name/Attribute AST node, or None."""
    parts = []
    while isinstance(node, ast.Attribute):
        parts.append(node.attr)
        node = node.value
    if not isinstance(node, ast.Name):
        return None
    parts.append(node.id)
    return ".".join(reversed(parts))


def parse_python_function_call(func_call_str):
    """
    Parse Python function call syntax to JSON format.

    Example: 'weather.forecast_weather_api(q="Paris", days=5)'
    -> {"name": "weather.forecast_weather_api", "arguments": {"q": "Paris", "days": 5}}

    Only keyword arguments are supported (the dataset emits kwargs-only
    calls); returns None for anything that is not a parseable kwargs-only
    call. Uses ast.parse instead of the previous regex + comma-split,
    which broke on values containing commas (lists, quoted strings) and
    silently produced a *set* for positional args.
    """
    try:
        tree = ast.parse(func_call_str.strip(), mode="eval")
    except SyntaxError:
        return None
    call = tree.body
    if not isinstance(call, ast.Call):
        return None
    func_name = _dotted_callee_name(call.func)
    if func_name is None:
        return None
    if call.args:
        # Positional arguments have no key to serialize under -> unsupported.
        # (The old code turned them into a set, which crashed json.dumps.)
        return None
    arguments = {}
    for kw in call.keywords:
        if kw.arg is None:
            # **kwargs expansion cannot be resolved statically.
            return None
        try:
            arguments[kw.arg] = ast.literal_eval(kw.value)
        except (ValueError, SyntaxError):
            # Non-literal value (e.g. a variable reference): keep its source text.
            arguments[kw.arg] = ast.unparse(kw.value)
    return {"name": func_name, "arguments": arguments}
def convert_function_calls_to_json(function_calls_str):
    """
    Convert newline-separated Python function-call syntax to a JSON string.

    Each non-empty line is parsed with parse_python_function_call. A single
    parsed call serializes as one JSON object; multiple calls serialize as
    a JSON array. Returns None for empty input or when no line parses.
    """
    if not function_calls_str or not function_calls_str.strip():
        return None
    try:
        lines = [ln.strip() for ln in function_calls_str.strip().split('\n') if ln.strip()]
        parsed_calls = [p for p in map(parse_python_function_call, lines) if p]
        if not parsed_calls:
            return None
        # Single call -> object, multiple calls -> array.
        payload = parsed_calls[0] if len(parsed_calls) == 1 else parsed_calls
        return json.dumps(payload)
    except Exception:
        # Defensive catch-all (narrowed from bare `except:`, which also
        # swallowed KeyboardInterrupt/SystemExit): any failure means "skip".
        return None
def convert_to_llamafactory_format(example):
    """Convert a Dolci `messages` example into the ShareGPT layout."""
    turns = []
    sys_prompt = None
    tools = None
    for message in example.get("messages", []):
        role = message.get("role", "")
        content = message.get("content", "")
        calls = message.get("function_calls", "")
        funcs = message.get("functions", "")
        if role == "system":
            # The system turn carries the prompt and (optionally) tool specs;
            # it never becomes a conversation turn itself.
            if content:
                sys_prompt = content
            if funcs:
                tools = funcs
        elif role == "user" and content:
            turns.append({"from": "human", "value": content})
        elif role == "assistant":
            if not content and calls:
                # Empty assistant content + function_calls => tool invocation.
                as_json = convert_function_calls_to_json(calls)
                if as_json:
                    turns.append({"from": "function_call", "value": as_json})
            elif content:
                turns.append({"from": "gpt", "value": content})
        elif role in ("tool", "function", "environment") and content:
            turns.append({"from": "observation", "value": content})
    result = {"conversations": turns}
    if sys_prompt:
        result["system"] = sys_prompt
    if tools:
        result["tools"] = tools
    return result
def validate_position_rules(conversations):
    """
    Check ShareGPT turn ordering.

    Odd 1-based positions (1, 3, 5, ...) must be 'human' or 'observation';
    even positions (2, 4, 6, ...) must be 'gpt' or 'function_call'.
    Returns True when every turn satisfies this, False otherwise.
    """
    odd_roles = ('human', 'observation')
    even_roles = ('gpt', 'function_call')
    for position, turn in enumerate(conversations, start=1):
        allowed = odd_roles if position % 2 else even_roles
        if turn['from'] not in allowed:
            return False
    return True
def main():
    """
    Extract the final 10k samples of the Dolci tool-use dataset, convert
    them to LLaMA-Factory ShareGPT format, validate turn ordering, and
    save the result as JSON.
    """
    print("Loading Dolci-Instruct-SFT-Tool-Use dataset...")
    dataset = load_dataset("allenai/Dolci-Instruct-SFT-Tool-Use", split="train")
    total_samples = len(dataset)
    print(f"Total samples: {total_samples}")
    # Take (up to) the last 20k samples.
    start_idx = max(0, total_samples - 20000)
    last_20k = dataset.select(range(start_idx, total_samples))
    print(f"Selected last 20k samples (from {start_idx} to {total_samples})")
    # Take the last 10k of that window. Fix: the original hard-coded
    # range(10000, 20000), which raised IndexError whenever the dataset
    # held fewer than 20k samples.
    tail_start = max(0, len(last_20k) - 10000)
    last_10k = last_20k.select(range(tail_start, len(last_20k)))
    print(f"Processing last 10k from the 20k (indices {tail_start}-{len(last_20k)} of the batch)")
    # Convert
    print("Converting to LLaMA-Factory ShareGPT format...")
    converted_data = []
    skipped_invalid = 0
    skipped_error = 0
    for example in tqdm(last_10k):
        try:
            converted = convert_to_llamafactory_format(example)
            # Drop samples that violate the odd/even turn-ordering rule.
            if not validate_position_rules(converted['conversations']):
                skipped_invalid += 1
                continue
            if converted['conversations']:
                converted_data.append(converted)
        except Exception:
            skipped_error += 1
    print("\nResults:")
    print(f"- Successfully converted: {len(converted_data)}")
    print(f"- Skipped (invalid position): {skipped_invalid}")
    print(f"- Skipped (errors): {skipped_error}")
    # Statistics
    has_function_call = sum(1 for s in converted_data
                            if any(c['from'] == 'function_call' for c in s['conversations']))
    has_tools = sum(1 for s in converted_data if 'tools' in s and s['tools'])
    print("\nStatistics:")
    print(f"- With function_call: {has_function_call}")
    print(f"- With tools field: {has_tools}")
    print(f"- Without function_call: {len(converted_data) - has_function_call}")
    # Save
    output_file = "/shared_workspace_mfs/ximing/LLaMA-Factory/data/dolci_last_10k_from_20k.json"
    print(f"\nSaving to {output_file}...")
    with open(output_file, 'w', encoding='utf-8') as f:
        json.dump(converted_data, f, ensure_ascii=False, indent=2)
    print(f"Saved {len(converted_data)} samples!")
    # Show the first sample that contains a function_call turn.
    if converted_data:
        print("\n" + "=" * 80)
        print("Sample with function_call:")
        print("=" * 80)
        for sample in converted_data:
            if any(c['from'] == 'function_call' for c in sample['conversations']):
                print(json.dumps(sample, ensure_ascii=False, indent=2)[:1200])
                break
if __name__ == "__main__":
main()