"""
Analyze Dolci dataset and extract batch5.
Find tool calling samples that haven't been used in batch1-4.
"""
|
|
| import json |
| from datasets import load_dataset |
| from tqdm import tqdm |
|
|
def has_tool_calling(conversations):
    """Return True when the conversation contains both a function_call and an observation turn."""
    seen_roles = {turn.get('from') or turn.get('role') for turn in conversations}
    return {'function_call', 'observation'} <= seen_roles
|
|
def convert_to_llamafactory_format(sample):
    """Convert a dataset sample to the LlamaFactory (sharegpt-style) format.

    Normalizes each turn to 'from'/'value' keys (accepting 'role'/'content'
    aliases) and carries over the optional top-level 'system' and 'tools'
    fields only when they are present and non-empty.
    """
    normalized_turns = [
        {
            'from': turn.get('from') or turn.get('role'),
            'value': turn.get('value') or turn.get('content'),
        }
        for turn in sample.get('conversations', [])
    ]

    result = {'conversations': normalized_turns}

    # Optional metadata: copy only truthy values, matching the source checks.
    for key in ('system', 'tools'):
        if key in sample and sample[key]:
            result[key] = sample[key]

    return result
|
|
def get_sample_hash(sample):
    """Create a hash for a sample to identify duplicates.

    Hashes the first message's text, reading both 'value' (LlamaFactory
    batch files) and 'content' (raw chat-style dataset samples) keys.
    The original implementation only read 'value', so every raw sample
    storing text under 'content' collapsed to hash('') and deduplication
    against batch1-4 silently failed for them.

    Note: str hashes are salted per process (PYTHONHASHSEED), which is
    fine here because all hashing happens within a single run.

    Returns:
        int hash of the first message, or None if there are no conversations.
    """
    conversations = sample.get('conversations', [])
    if conversations:
        first = conversations[0]
        first_msg = first.get('value') or first.get('content') or ''
        return hash(first_msg)
    return None
|
|
def _load_existing_hashes():
    """Return the set of first-message hashes for samples already in batch1-4.

    Missing batch files are skipped with a warning so a partial set of
    earlier batches still works.
    """
    existing_hashes = set()
    for batch_num in range(1, 5):
        batch_file = f"data/dolci_10k_with_tool_call_batch{batch_num}.json"
        try:
            with open(batch_file, 'r', encoding='utf-8') as f:
                batch_data = json.load(f)
            for sample in batch_data:
                sample_hash = get_sample_hash(sample)
                # `is not None` rather than truthiness: hash('') == 0 is a
                # legitimate hash value and must not be dropped.
                if sample_hash is not None:
                    existing_hashes.add(sample_hash)
            print(f" Loaded batch{batch_num}: {len(batch_data)} samples")
        except FileNotFoundError:
            print(f" Warning: {batch_file} not found, skipping...")
    return existing_hashes


def _collect_new_tool_samples(dataset, existing_hashes):
    """Scan the dataset and return [{'index', 'data'}] for tool-calling samples not seen before."""
    tool_calling_samples = []
    for idx, sample in enumerate(tqdm(dataset, desc="Scanning entire dataset")):
        conversations = sample.get('conversations', [])
        if not has_tool_calling(conversations):
            continue
        sample_hash = get_sample_hash(sample)
        if sample_hash not in existing_hashes:
            tool_calling_samples.append({
                'index': idx,  # original dataset position, reported in the summary
                'data': convert_to_llamafactory_format(sample),
            })
    return tool_calling_samples


def main():
    """Extract up to 10,000 new tool-calling samples as batch5.

    Loads the Dolci tool-use dataset, excludes samples already present in
    batch1-4 (matched by first-message hash), converts the remainder to
    LlamaFactory format and writes them to the batch5 JSON file.
    """
    print("Loading allenai/Dolci-Instruct-SFT-Tool-Use dataset...")
    dataset = load_dataset("allenai/Dolci-Instruct-SFT-Tool-Use", split="train")

    total_samples = len(dataset)
    print(f"Total samples in dataset: {total_samples}")

    print("\nLoading existing batches to avoid duplicates...")
    existing_hashes = _load_existing_hashes()
    print(f"Total existing samples to avoid: {len(existing_hashes)}")

    print("\nAnalyzing dataset for tool calling samples...")
    tool_calling_samples = _collect_new_tool_samples(dataset, existing_hashes)
    print(f"\nFound {len(tool_calling_samples)} NEW tool calling samples (excluding batch1-4)")

    # Cap the batch at 10k; otherwise take everything that is new.
    if len(tool_calling_samples) > 10000:
        selected_samples = [s['data'] for s in tool_calling_samples[:10000]]
        print(f"Selected first 10,000 samples for batch5")
        print(f"Index range: {tool_calling_samples[0]['index']} to {tool_calling_samples[9999]['index']}")
    else:
        selected_samples = [s['data'] for s in tool_calling_samples]
        print(f"Using all {len(selected_samples)} new samples for batch5")
        if tool_calling_samples:
            print(f"Index range: {tool_calling_samples[0]['index']} to {tool_calling_samples[-1]['index']}")

    if not selected_samples:
        print("\n❌ No new tool calling samples found!")
        return

    output_file = "data/dolci_10k_with_tool_call_batch5.json"
    print(f"\nSaving to {output_file}...")
    with open(output_file, 'w', encoding='utf-8') as f:
        json.dump(selected_samples, f, ensure_ascii=False, indent=2)

    print(f"✓ Successfully created batch5 with {len(selected_samples)} samples")

    if selected_samples:
        print("\nSample entry:")
        print(json.dumps(selected_samples[0], ensure_ascii=False, indent=2)[:800] + "...")
|
|
| if __name__ == "__main__": |
| main() |
|
|