#!/usr/bin/env python3
"""
Extract batch5 from the last 20k samples of the allenai/Dolci-Instruct-SFT-Tool-Use dataset.

Keeps only samples whose conversation contains both a ``function_call`` turn and an
``observation`` turn, converts them to the ``{"from": ..., "value": ...}`` turn layout
used by LlamaFactory, and writes up to 10k of them to a JSON file.
"""
import json
from pathlib import Path

DATASET_NAME = "allenai/Dolci-Instruct-SFT-Tool-Use"
TAIL_SIZE = 20_000  # number of trailing train-split samples to scan
MAX_SAMPLES = 10_000  # cap on the number of samples written
OUTPUT_FILE = "data/dolci_10k_with_tool_call_batch5.json"


def _first_not_none(mapping, keys):
    """Return the value of the first key in *keys* whose value is not None.

    Unlike ``mapping.get(a) or mapping.get(b)``, this keeps legitimate falsy
    values such as an empty string. Both missing keys and explicit ``None``
    values (how Hugging Face ``datasets`` typically represents absent struct
    fields) are skipped. Returns None when no key has a value.
    """
    for key in keys:
        value = mapping.get(key)
        if value is not None:
            return value
    return None


def _role(turn):
    """Return a turn's speaker, read from ``from`` or, failing that, ``role``."""
    return _first_not_none(turn, ("from", "role"))


def has_tool_calling(conversations):
    """Check if conversations contain both a function_call and an observation turn.

    Args:
        conversations: Iterable of turn dicts keyed by ``from``/``role``; None is
            treated as an empty conversation.

    Returns:
        True only if at least one turn has role ``function_call`` and at least
        one turn has role ``observation``.
    """
    roles = {_role(turn) for turn in conversations or ()}
    return "function_call" in roles and "observation" in roles


def convert_to_llamafactory_format(sample):
    """Convert a dataset sample to LlamaFactory format.

    Each turn is normalized to ``{"from": role, "value": text}``. The role is read
    from ``from`` or ``role``, and the text from ``value`` or ``content``. Truthy
    ``system`` and ``tools`` fields are copied through unchanged.

    Args:
        sample: Mapping with a ``conversations`` list (missing/None -> empty) and
            optional ``system`` / ``tools`` fields.

    Returns:
        A new dict; the input sample is not modified.
    """
    result = {
        "conversations": [
            {
                "from": _role(turn),
                # Fall back to "content" only when "value" is absent/None, so an
                # empty observation ("") stays "" instead of becoming None.
                "value": _first_not_none(turn, ("value", "content")),
            }
            for turn in sample.get("conversations") or ()
        ]
    }
    # Add system prompt / tools only if present and non-empty.
    for key in ("system", "tools"):
        if sample.get(key):
            result[key] = sample[key]
    return result


def main(*, tail_size=TAIL_SIZE, max_samples=MAX_SAMPLES, output_file=OUTPUT_FILE):
    """Load the dataset, keep tool-calling samples from its tail, and save batch5.

    Args:
        tail_size: Number of trailing train-split samples to scan.
        max_samples: Maximum number of matching samples written (first ones found).
        output_file: Destination JSON path; missing parent directories are created.
    """
    # Imported lazily so the conversion helpers above can be imported (and tested)
    # without pulling in these heavy third-party dependencies.
    from datasets import load_dataset
    from tqdm import tqdm

    print(f"Loading {DATASET_NAME} dataset...")
    dataset = load_dataset(DATASET_NAME, split="train")
    total_samples = len(dataset)
    print(f"Total samples in dataset: {total_samples}")

    # Take the last `tail_size` samples (or the whole split if it is smaller).
    start_idx = max(0, total_samples - tail_size)
    tail = dataset.select(range(start_idx, total_samples))
    print(
        f"Processing last {len(tail):,} samples "
        f"(from index {start_idx} to {total_samples})"
    )

    tool_calling_samples = [
        convert_to_llamafactory_format(sample)
        for sample in tqdm(tail, desc="Filtering samples with tool calling")
        if has_tool_calling(sample.get("conversations"))
    ]
    print(f"\nFound {len(tool_calling_samples)} samples with tool calling")
    if not tool_calling_samples:
        print(
            "WARNING: no samples matched; check that the dataset has a "
            "'conversations' field with 'function_call'/'observation' roles"
        )

    # Keep at most `max_samples`, preferring the earliest matches.
    selected_samples = tool_calling_samples[:max_samples]
    if len(tool_calling_samples) > max_samples:
        print(f"Selected first {max_samples:,} samples for batch5")
    else:
        print(f"Using all {len(selected_samples)} samples for batch5")

    output_path = Path(output_file)
    # open() does not create directories; without this, a missing data/ dir
    # would fail only after the whole download-and-filter pass.
    output_path.parent.mkdir(parents=True, exist_ok=True)
    print(f"\nSaving to {output_path}...")
    with output_path.open("w", encoding="utf-8") as f:
        json.dump(selected_samples, f, ensure_ascii=False, indent=2)
    print(f"✓ Successfully created batch5 with {len(selected_samples)} samples")

    # Print a (possibly truncated) sample for verification.
    if selected_samples:
        preview = json.dumps(selected_samples[0], ensure_ascii=False, indent=2)
        print("\nSample entry:")
        print(preview[:500] + ("..." if len(preview) > 500 else ""))


if __name__ == "__main__":
    main()