| |
| """ |
| 从 allenai/Dolci-Instruct-SFT-Tool-Use 数据集的最后部分提取恰好 10000 个符合规则的样本 |
| """ |
|
|
| import json |
| from datasets import load_dataset |
| from tqdm import tqdm |
|
|
def convert_to_llamafactory_format(example):
    """Convert one Dolci-format example into LLaMA-Factory ShareGPT format.

    Roles are mapped as: user -> "human", tool -> "observation",
    assistant -> "function_call" (when tool_calls are present) and/or
    "gpt" (when non-blank text content is present).

    Args:
        example: dict with a 'messages' list (each message holding 'role',
            'content' and optionally 'tool_calls'), and optional 'system'
            and 'tools' keys.

    Returns:
        dict with a 'conversations' list, plus 'system' / 'tools' keys when
        present in the input ('tools' is JSON-serialized).

    Raises:
        json.JSONDecodeError: if a tool call's 'arguments' string is not
            valid JSON.
    """
    turns = []

    for msg in example.get('messages', []):
        role = msg.get('role', '')
        content = msg.get('content', '')
        calls = msg.get('tool_calls', [])

        if role == 'user':
            turns.append({"from": "human", "value": content})
        elif role == 'assistant':
            if calls:
                parsed = []
                for call in calls:
                    fn = call.get('function', {})
                    args = fn.get('arguments', '')
                    # Arguments may arrive as a JSON string or an object;
                    # normalize to an object before re-serializing.
                    if isinstance(args, str):
                        args = json.loads(args)
                    parsed.append({"name": fn.get('name', ''), "arguments": args})

                # A single call is emitted as a bare object, several as a list.
                payload = parsed[0] if len(parsed) == 1 else parsed
                turns.append({"from": "function_call", "value": json.dumps(payload)})

            # Assistant text (if any) follows the function_call turn.
            if content and content.strip():
                turns.append({"from": "gpt", "value": content})
        elif role == 'tool':
            turns.append({"from": "observation", "value": content})

    result = {"conversations": turns}

    if 'system' in example:
        result['system'] = example['system']

    if example.get('tools'):
        result['tools'] = json.dumps(example['tools'])

    return result
|
|
def validate_position_rules(conversations):
    """Check that conversation turns alternate roles correctly.

    Odd 1-based positions (1, 3, 5, ...) must come from 'human' or
    'observation'; even positions (2, 4, 6, ...) from 'gpt' or
    'function_call'.

    Args:
        conversations: list of dicts, each with a 'from' key.

    Returns:
        (True, "Valid") when every turn is in a legal position, otherwise
        (False, reason) describing the first violation.
    """
    for position, turn in enumerate(conversations, start=1):
        role = turn['from']
        if position % 2 == 1:
            if role not in ('human', 'observation'):
                return False, f"Position {position} should be 'human' or 'observation', but got '{role}'"
        elif role not in ('gpt', 'function_call'):
            return False, f"Position {position} should be 'gpt' or 'function_call', but got '{role}'"

    return True, "Valid"
|
|
def _collect_valid_samples(batch, sink, target_count, verbose=True):
    """Convert examples from *batch*, appending position-valid ones to *sink*.

    Stops early once *sink* holds *target_count* samples.

    Args:
        batch: iterable of raw Dolci examples.
        sink: list to append converted samples to (mutated in place).
        target_count: stop once the sink reaches this many samples.
        verbose: print per-sample conversion errors when True.

    Returns:
        (processed, skipped, invalid_position) counts for this batch.
    """
    processed = skipped = invalid = 0
    for idx, example in enumerate(tqdm(batch)):
        if len(sink) >= target_count:
            break
        processed += 1
        try:
            converted = convert_to_llamafactory_format(example)
        except Exception as e:
            # Narrowed from a bare `except:` so KeyboardInterrupt/SystemExit
            # still propagate; failures are counted instead of silently lost.
            if verbose:
                print(f"\nError processing sample {idx}: {e}")
            skipped += 1
            continue

        is_valid, _ = validate_position_rules(converted['conversations'])
        if not is_valid:
            invalid += 1
            continue

        # Drop degenerate samples with no turns at all.
        if converted['conversations']:
            sink.append(converted)
    return processed, skipped, invalid


def main(output_file="/shared_workspace_mfs/ximing/LLaMA-Factory/data/dolci_last_10k_from_20k.json",
         target_count=10000, batch_size=15000):
    """Extract *target_count* valid samples from the tail of the dataset.

    Walks backwards from the end of allenai/Dolci-Instruct-SFT-Tool-Use in
    *batch_size* chunks, converting each example to ShareGPT format and
    keeping only samples that satisfy the alternating-position rule, then
    writes the result to *output_file* as a JSON array.

    Args:
        output_file: destination path for the JSON output.
        target_count: number of samples to collect.
        batch_size: how many trailing samples to scan per pass.
    """
    print("Loading Dolci-Instruct-SFT-Tool-Use dataset...")
    dataset = load_dataset("allenai/Dolci-Instruct-SFT-Tool-Use", split="train")

    total_samples = len(dataset)
    print(f"Total samples in dataset: {total_samples}")

    # First pass: the last batch_size samples of the dataset.
    start_idx = max(0, total_samples - batch_size)
    candidates = dataset.select(range(start_idx, total_samples))
    print(f"Processing samples from index {start_idx} to {total_samples} ({len(candidates)} samples)")

    print("Converting to LLaMA-Factory ShareGPT format and filtering...")
    converted_data = []
    # Using an explicit processed counter avoids the NameError the old code
    # hit on an empty batch (it read the loop variable after the loop) and
    # the off-by-one when the loop broke on reaching the target.
    processed, skipped, invalid_position = _collect_valid_samples(
        candidates, converted_data, target_count)

    print(f"\nConversion completed:")
    print(f"- Total processed: {processed}")
    print(f"- Successfully converted: {len(converted_data)}")
    print(f"- Skipped due to errors: {skipped}")
    print(f"- Invalid position: {invalid_position}")

    # Fall back to earlier batches until we have enough samples.
    if len(converted_data) < target_count:
        print(f"\nNeed more samples! Current: {len(converted_data)}, Target: {target_count}")
        print("Processing more samples...")

        while len(converted_data) < target_count and start_idx > 0:
            batch_start = max(0, start_idx - batch_size)
            batch = dataset.select(range(batch_start, start_idx))
            print(f"Processing batch: {batch_start} to {start_idx}")

            # Same conversion path as the first pass (the old code duplicated
            # it with a bare `except: pass` and never updated the counters).
            _, batch_skipped, batch_invalid = _collect_valid_samples(
                batch, converted_data, target_count, verbose=False)
            skipped += batch_skipped
            invalid_position += batch_invalid

            start_idx = batch_start

    # Trim any overshoot from the final batch.
    converted_data = converted_data[:target_count]

    print(f"\nFinal count: {len(converted_data)} samples")

    # Summary statistics on tool usage.
    has_tool_call = sum(1 for sample in converted_data
                        if any(conv['from'] == 'function_call' for conv in sample['conversations']))
    has_tools_field = sum(1 for sample in converted_data if sample.get('tools'))

    print(f"- Samples with function_call: {has_tool_call}")
    print(f"- Samples with tools field: {has_tools_field}")
    print(f"- Samples without tool calls: {len(converted_data) - has_tool_call}")

    print(f"\nSaving to {output_file}...")
    with open(output_file, 'w', encoding='utf-8') as f:
        json.dump(converted_data, f, ensure_ascii=False, indent=2)

    print(f"Successfully saved {len(converted_data)} samples!")

    # Show the first sample containing a tool call, truncated for readability.
    print("\n" + "=" * 80)
    print("Sample example with tool call:")
    print("=" * 80)
    for sample in converted_data:
        if any(conv['from'] == 'function_call' for conv in sample['conversations']):
            print(json.dumps(sample, ensure_ascii=False, indent=2)[:1000])
            break
|
|
| if __name__ == "__main__": |
| main() |
|
|