#!/usr/bin/env python3
"""
Fix batch5 by correctly converting environment role to observation.
"""
import hashlib
import json
from collections import Counter

# NOTE: `datasets` and `tqdm` are imported lazily inside main() so that the
# pure conversion helpers in this module can be imported (and unit-tested)
# without the heavy third-party dependencies installed.


def _stable_hash(text: str) -> str:
    """Return a deterministic content hash for *text*.

    The builtin hash() is randomized per interpreter run (PYTHONHASHSEED),
    so its values cannot be compared across runs; sha256 is stable, which
    makes the batch1-4 dedup reliable even if batches were written earlier.
    """
    return hashlib.sha256(text.encode('utf-8')).hexdigest()


def convert_to_llamafactory_format(sample: dict) -> dict:
    """
    Convert from Dolci format to LlamaFactory format.

    Dolci format (messages):
    - role: system/user/assistant/environment
    - content: text content
    - function_calls: function call string (in assistant messages)
    - functions: available functions JSON string (in system message)

    LlamaFactory format (conversations):
    - from: human/gpt/function_call/observation/system
    - value: text or JSON
    """
    messages = sample.get('messages', [])
    conversations = []
    tools = None
    system_prompt = None

    for msg in messages:
        role = msg.get('role', '')
        content = msg.get('content', '')
        function_calls = msg.get('function_calls')
        functions = msg.get('functions')

        # System messages carry tools/system prompt; they do not become turns.
        if role == 'system':
            # Keep the first non-empty `functions` payload; a later system
            # message's text overwrites the system prompt.
            if functions and not tools:
                tools = functions
            if content:
                system_prompt = content
            continue

        if role == 'user':
            conversations.append({'from': 'human', 'value': content})
        elif role == 'assistant':
            if function_calls:
                # Function-call turn. NOTE(review): any accompanying text
                # `content` is dropped here — presumably intentional for the
                # LlamaFactory schema; confirm against the dataset.
                conversations.append({'from': 'function_call',
                                      'value': function_calls})
            elif content:
                # Regular assistant response.
                conversations.append({'from': 'gpt', 'value': content})
        elif role == 'environment':
            # Tool execution result maps to an observation turn.
            conversations.append({'from': 'observation', 'value': content})

    result = {'conversations': conversations}
    if system_prompt:
        result['system'] = system_prompt
    if tools:
        result['tools'] = tools
    return result


def get_sample_hash(sample: dict):
    """Hash the first user turn of a Dolci sample; None if no user turn."""
    for msg in sample.get('messages', []):
        if msg.get('role') == 'user':
            return _stable_hash(msg.get('content', ''))
    return None


def has_tool_calling(messages: list) -> bool:
    """Return True if any message carries a non-empty function_calls field."""
    return any(msg.get('function_calls') for msg in messages)


def _load_existing_hashes() -> set:
    """Hashes of the first human turn of every sample in batch1-4 files."""
    existing_hashes = set()
    for batch_num in range(1, 5):
        batch_file = f"data/dolci_10k_with_tool_call_batch{batch_num}.json"
        try:
            with open(batch_file, 'r', encoding='utf-8') as f:
                batch_data = json.load(f)
            for sample in batch_data:
                # Only the first human turn identifies a sample.
                for conv in sample.get('conversations', []):
                    if conv.get('from') == 'human':
                        existing_hashes.add(_stable_hash(conv.get('value', '')))
                        break
            print(f" Loaded batch{batch_num}: {len(batch_data)} samples")
        except FileNotFoundError:
            # Missing batches are tolerated: dedup is best-effort.
            print(f" Warning: {batch_file} not found, skipping...")
    return existing_hashes


def _verify_format(selected_samples: list) -> None:
    """Print the most frequent role-sequence patterns as a spot check."""
    print("\n=== Verifying format ===")
    role_patterns = Counter(
        ' -> '.join(c['from'] for c in sample['conversations'])
        for sample in selected_samples[:100]
    )
    print("Top patterns in first 100 samples:")
    for pattern, count in role_patterns.most_common(5):
        print(f" [{count:3d}] {pattern}")


def main():
    """Build batch5 from new tool-calling samples in the last 20k rows."""
    # Heavy third-party deps: imported here so module import stays light.
    from datasets import load_dataset
    from tqdm import tqdm

    print("Loading allenai/Dolci-Instruct-SFT-Tool-Use dataset...")
    dataset = load_dataset("allenai/Dolci-Instruct-SFT-Tool-Use", split="train")
    total_samples = len(dataset)
    print(f"Total samples in dataset: {total_samples}")

    # Load existing batch1-4 to avoid duplicates.
    print("\nLoading existing batches to avoid duplicates...")
    existing_hashes = _load_existing_hashes()
    print(f"Total existing samples to avoid: {len(existing_hashes)}")

    # Get last 20k samples.
    start_idx = max(0, total_samples - 20000)
    last_20k = dataset.select(range(start_idx, total_samples))
    print(f"\nProcessing last 20k samples (from index {start_idx} to {total_samples})")

    # Filter: must contain tool calls, must be new, and the converted form
    # must contain both a function_call and an observation turn.
    tool_calling_samples = []
    for sample in tqdm(last_20k, desc="Filtering tool calling samples"):
        messages = sample.get('messages', [])
        if not has_tool_calling(messages):
            continue
        # Skip if already present in batch1-4.
        if get_sample_hash(sample) in existing_hashes:
            continue
        converted = convert_to_llamafactory_format(sample)
        roles = [c['from'] for c in converted.get('conversations', [])]
        if 'function_call' in roles and 'observation' in roles:
            tool_calling_samples.append(converted)

    print(f"\nFound {len(tool_calling_samples)} NEW tool calling samples with proper format")

    # Select up to 10k samples.
    if len(tool_calling_samples) > 10000:
        selected_samples = tool_calling_samples[:10000]
        print("Selected first 10,000 samples for batch5")
    else:
        selected_samples = tool_calling_samples
        print(f"Using all {len(selected_samples)} samples for batch5")

    if not selected_samples:
        print("\nāŒ No new tool calling samples found!")
        return

    # Save to file.
    output_file = "data/dolci_10k_with_tool_call_batch5.json"
    print(f"\nSaving to {output_file}...")
    with open(output_file, 'w', encoding='utf-8') as f:
        json.dump(selected_samples, f, ensure_ascii=False, indent=2)
    print(f"āœ“ Successfully created batch5 with {len(selected_samples)} samples")

    _verify_format(selected_samples)

    # Print a sample for verification (truncated to 1000 chars).
    print("\nSample entry:")
    print(json.dumps(selected_samples[0], ensure_ascii=False, indent=2)[:1000] + "...")


if __name__ == "__main__":
    main()