"""
Extract 10k samples from the last 20k of the allenai/Dolci-Instruct-SFT-Tool-Use
dataset, using the corrected conversion logic.
"""
|
|
| import json |
| from datasets import load_dataset |
| from tqdm import tqdm |
| import re |
| import ast |
|
|
def parse_python_function_call(func_call_str):
    """
    Parse a Python-style function call string into a call descriptor.

    Example:
        'weather.forecast_weather_api(q="Paris", days=5)'
        -> {"name": "weather.forecast_weather_api",
            "arguments": {"q": "Paris", "days": 5}}

    Only keyword arguments are captured (tool-call traces in this dataset use
    keyword syntax; positional arguments are ignored).  Returns None when the
    string is not a single well-formed call or uses ``**kwargs``.
    """
    try:
        match = re.match(r'^([a-zA-Z_][\w.]*)\((.*)\)$', func_call_str.strip())
        if not match:
            return None

        func_name = match.group(1)
        args_str = match.group(2)

        arguments = {}
        if args_str.strip():
            try:
                # Parse the argument list with the ast module so commas and
                # parentheses inside quoted values (e.g. q="a, b") are handled
                # correctly.  The original "{" + args_str + "}" literal_eval
                # trick was a syntax error for key=value pairs and never
                # succeeded, so everything fell through to a naive split.
                call = ast.parse(f"_f({args_str})", mode="eval").body
                for kw in call.keywords:
                    if kw.arg is None:
                        return None  # **kwargs cannot be represented
                    try:
                        arguments[kw.arg] = ast.literal_eval(kw.value)
                    except ValueError:
                        # Non-literal value: keep a bare name as its string
                        # spelling (mirrors the old quote-stripping fallback).
                        if isinstance(kw.value, ast.Name):
                            arguments[kw.arg] = kw.value.id
                        else:
                            return None
            except SyntaxError:
                # Not valid Python argument syntax: best-effort comma split
                # (last resort; breaks on commas inside quoted values).
                for arg in args_str.split(','):
                    arg = arg.strip()
                    if '=' in arg:
                        key, val = arg.split('=', 1)
                        key = key.strip()
                        val = val.strip()
                        try:
                            arguments[key] = ast.literal_eval(val)
                        except (ValueError, SyntaxError):
                            arguments[key] = val.strip('"\'')

        return {"name": func_name, "arguments": arguments}
    except Exception:
        # Defensive: callers rely on None (never an exception) for bad input.
        return None


def convert_function_calls_to_json(function_calls_str):
    """
    Convert newline-separated Python-style function calls to a JSON string.

    A single call serializes to a JSON object, several calls to a JSON array.
    Returns None when the input is empty or no line parses as a call.
    """
    if not function_calls_str or not function_calls_str.strip():
        return None

    lines = [ln.strip() for ln in function_calls_str.strip().split('\n') if ln.strip()]
    parsed_calls = [p for p in (parse_python_function_call(ln) for ln in lines) if p]

    if not parsed_calls:
        return None

    payload = parsed_calls[0] if len(parsed_calls) == 1 else parsed_calls
    try:
        return json.dumps(payload)
    except (TypeError, ValueError):  # narrowed: was a bare except
        return None
|
|
def convert_to_llamafactory_format(example):
    """
    Turn a raw ``messages`` example into the ShareGPT layout used by
    LLaMA-Factory: a ``conversations`` list plus optional ``system`` and
    ``tools`` fields.
    """
    turns = []
    system_prompt = None
    tools_str = None

    for msg in example.get("messages", []):
        role = msg.get("role", "")
        content = msg.get("content", "")

        if role == "system":
            # System turns contribute the prompt and the tool schema, but
            # never appear in the conversation itself.
            if content:
                system_prompt = content
            functions = msg.get("functions", "")
            if functions:
                tools_str = functions
            continue

        if role == "user":
            if content:
                turns.append({"from": "human", "value": content})
        elif role == "assistant":
            function_calls = msg.get("function_calls", "")
            # An assistant turn with no text but a function_calls payload
            # becomes a function_call turn (only if conversion succeeds).
            if not content and function_calls:
                as_json = convert_function_calls_to_json(function_calls)
                if as_json:
                    turns.append({"from": "function_call", "value": as_json})
            elif content:
                turns.append({"from": "gpt", "value": content})
        elif role in ("tool", "function", "environment"):
            if content:
                turns.append({"from": "observation", "value": content})

    result = {"conversations": turns}
    if system_prompt:
        result["system"] = system_prompt
    if tools_str:
        result["tools"] = tools_str

    return result
|
|
def validate_position_rules(conversations):
    """
    Check the alternation contract of a ShareGPT conversation.

    Odd 1-based positions (1, 3, 5, ...) must be 'human' or 'observation';
    even positions (2, 4, 6, ...) must be 'gpt' or 'function_call'.
    Returns True for an empty conversation.
    """
    # allowed[0] covers even 0-based indices (odd 1-based positions).
    allowed = (('human', 'observation'), ('gpt', 'function_call'))
    return all(
        turn['from'] in allowed[idx % 2]
        for idx, turn in enumerate(conversations)
    )
|
|
def main():
    """
    Load the tool-use SFT dataset, convert the last 10k of its final 20k
    samples to ShareGPT format, validate turn ordering, report statistics,
    and save the result as JSON.
    """
    print("Loading Dolci-Instruct-SFT-Tool-Use dataset...")
    dataset = load_dataset("allenai/Dolci-Instruct-SFT-Tool-Use", split="train")

    total_samples = len(dataset)
    print(f"Total samples: {total_samples}")

    # Window of the last 20k samples (or the whole dataset, if smaller).
    start_idx = max(0, total_samples - 20000)
    last_20k = dataset.select(range(start_idx, total_samples))
    print(f"Selected last 20k samples (from {start_idx} to {total_samples})")

    # Last 10k of that window.  Clamped to the window's actual length so a
    # dataset with fewer than 20k samples no longer raises (the original
    # hard-coded range(10000, 20000) assumed exactly 20k rows).
    tail_start = max(0, len(last_20k) - 10000)
    last_10k = last_20k.select(range(tail_start, len(last_20k)))
    print(f"Processing last 10k from the 20k (indices {tail_start}-{len(last_20k)} of the 20k batch)")

    print("Converting to LLaMA-Factory ShareGPT format...")
    converted_data = []
    skipped_invalid = 0   # failed the human/gpt alternation check
    skipped_error = 0     # raised during conversion

    for example in tqdm(last_10k):
        try:
            converted = convert_to_llamafactory_format(example)

            if not validate_position_rules(converted['conversations']):
                skipped_invalid += 1
                continue

            # Drop samples whose conversion produced no turns at all.
            if converted['conversations']:
                converted_data.append(converted)
        except Exception:
            skipped_error += 1

    print(f"\nResults:")
    print(f"- Successfully converted: {len(converted_data)}")
    print(f"- Skipped (invalid position): {skipped_invalid}")
    print(f"- Skipped (errors): {skipped_error}")

    has_function_call = sum(1 for s in converted_data
                            if any(c['from'] == 'function_call' for c in s['conversations']))
    has_tools = sum(1 for s in converted_data if 'tools' in s and s['tools'])

    print(f"\nStatistics:")
    print(f"- With function_call: {has_function_call}")
    print(f"- With tools field: {has_tools}")
    print(f"- Without function_call: {len(converted_data) - has_function_call}")

    output_file = "/shared_workspace_mfs/ximing/LLaMA-Factory/data/dolci_last_10k_from_20k.json"
    print(f"\nSaving to {output_file}...")

    with open(output_file, 'w', encoding='utf-8') as f:
        json.dump(converted_data, f, ensure_ascii=False, indent=2)

    print(f"Saved {len(converted_data)} samples!")

    # Show the first converted sample that contains a function_call turn,
    # truncated for readability.
    if converted_data:
        print("\n" + "="*80)
        print("Sample with function_call:")
        print("="*80)
        for sample in converted_data:
            if any(c['from'] == 'function_call' for c in sample['conversations']):
                print(json.dumps(sample, ensure_ascii=False, indent=2)[:1200])
                break


if __name__ == "__main__":
    main()
|
|