# llm / extract_batch5_correct.py
# Uploaded by dongxx1104 using huggingface_hub (commit db704cb, verified)
#!/usr/bin/env python3
"""
Extract batch5 from Dolci dataset with correct field names.
"""
import json
from datasets import load_dataset
from tqdm import tqdm
def has_tool_calling(messages):
    """Return True if any message carries a non-empty 'function_calls' field.

    Args:
        messages: list of Dolci message dicts (each may have a
            'function_calls' JSON-string field on assistant turns).

    Returns:
        bool: True when at least one message has a truthy 'function_calls'
        value; False otherwise (including for an empty list).
    """
    # any() short-circuits on the first tool-calling message, same as the
    # original early-return loop.
    return any(msg.get('function_calls') for msg in messages)
def convert_to_llamafactory_format(sample):
    """
    Translate one Dolci sample into the LlamaFactory conversation schema.

    Dolci messages carry: role (system/user/assistant/function), content,
    function_calls (JSON string), and functions (JSON string of tools).
    LlamaFactory turns carry: from (human/gpt/function_call/observation)
    and value; the tool schema and system prompt move to the top-level
    'tools' and 'system' keys of the returned dict.
    """
    turns = []
    tools = None
    system_prompt = None

    for message in sample.get('messages', []):
        role = message.get('role', '')
        text = message.get('content', '')
        calls = message.get('function_calls')

        if role == 'system' and message.get('functions') and not tools:
            # First system message that declares tools supplies both the
            # tool schema and the system prompt; it is not emitted as a turn.
            tools = message.get('functions')
            system_prompt = text
        elif role == 'user':
            turns.append({'from': 'human', 'value': text})
        elif role == 'assistant' and calls:
            # Tool-invoking assistant turn: only the call payload is kept;
            # the tool result is expected in a following 'function' message.
            turns.append({'from': 'function_call', 'value': calls})
        elif role == 'assistant':
            # Plain assistant reply; empty content is dropped.
            if text:
                turns.append({'from': 'gpt', 'value': text})
        elif role == 'function':
            # Tool execution result becomes an observation turn.
            turns.append({'from': 'observation', 'value': text})

    result = {'conversations': turns}
    if system_prompt:
        result['system'] = system_prompt
    if tools:
        result['tools'] = tools
    return result
def get_sample_hash(sample):
    """Fingerprint a sample by hashing its first user message's content.

    Returns None when the sample has no user message. Uses the built-in
    hash(), so values are only comparable within a single process
    (PYTHONHASHSEED varies across runs).
    """
    for message in sample.get('messages', []):
        if message.get('role') != 'user':
            continue
        return hash(message.get('content', ''))
    return None
def _load_existing_hashes():
    """Collect hashes of the first human turn from batches 1-4 on disk.

    Missing batch files are skipped with a warning. Uses the built-in
    hash() to match get_sample_hash(), so comparisons are valid only
    within this process.
    """
    existing_hashes = set()
    for batch_num in range(1, 5):
        batch_file = f"data/dolci_10k_with_tool_call_batch{batch_num}.json"
        try:
            with open(batch_file, 'r', encoding='utf-8') as f:
                batch_data = json.load(f)
        except FileNotFoundError:
            print(f" Warning: {batch_file} not found, skipping...")
            continue
        for sample in batch_data:
            for conv in sample.get('conversations', []):
                if conv.get('from') == 'human':
                    existing_hashes.add(hash(conv.get('value', '')))
                    break
        print(f" Loaded batch{batch_num}: {len(batch_data)} samples")
    return existing_hashes


def _collect_new_samples(samples, existing_hashes, desc, limit=None):
    """Filter tool-calling samples not already seen, converting as we go.

    Each accepted sample's hash is added to existing_hashes so batch5
    stays free of internal duplicates as well (previously the script only
    deduplicated against batches 1-4, so repeated prompts within the scan
    could all be kept). Samples without a user message (hash None) are
    never treated as duplicates of each other.

    Args:
        samples: iterable of raw Dolci samples.
        existing_hashes: set of hashes to skip; mutated in place.
        desc: tqdm progress-bar label.
        limit: optional cap on collected samples (used by the full-dataset
            fallback scan).
    """
    collected = []
    for sample in tqdm(samples, desc=desc):
        if not has_tool_calling(sample.get('messages', [])):
            continue
        sample_hash = get_sample_hash(sample)
        if sample_hash in existing_hashes:
            continue
        converted = convert_to_llamafactory_format(sample)
        # Guard against conversions that produced no turns at all.
        if converted.get('conversations'):
            collected.append(converted)
            if sample_hash is not None:
                existing_hashes.add(sample_hash)
        if limit is not None and len(collected) >= limit:
            break
    return collected


def main():
    """Build batch5: up to 10k new tool-calling samples, LlamaFactory format."""
    print("Loading allenai/Dolci-Instruct-SFT-Tool-Use dataset...")
    dataset = load_dataset("allenai/Dolci-Instruct-SFT-Tool-Use", split="train")
    total_samples = len(dataset)
    print(f"Total samples in dataset: {total_samples}")

    # Load existing batch1-4 to avoid duplicates
    print("\nLoading existing batches to avoid duplicates...")
    existing_hashes = _load_existing_hashes()
    print(f"Total existing samples to avoid: {len(existing_hashes)}")

    # Prefer the tail of the dataset so batch5 draws from unseen samples.
    start_idx = max(0, total_samples - 20000)
    last_20k = dataset.select(range(start_idx, total_samples))
    print(f"\nProcessing last 20k samples (from index {start_idx} to {total_samples})")

    tool_calling_samples = _collect_new_samples(
        last_20k, existing_hashes, "Filtering tool calling samples")
    print(f"\nFound {len(tool_calling_samples)} NEW tool calling samples")

    if len(tool_calling_samples) > 10000:
        selected_samples = tool_calling_samples[:10000]
        print(f"Selected first 10,000 samples for batch5")
    else:
        selected_samples = tool_calling_samples
        print(f"Using all {len(selected_samples)} samples for batch5")

    if not selected_samples:
        print("\n❌ No new tool calling samples found in last 20k!")
        print("Trying from entire dataset...")
        # Fallback: scan the whole dataset, stopping once 10k are found.
        selected_samples = _collect_new_samples(
            dataset, existing_hashes, "Scanning entire dataset", limit=10000)
        print(f"Found {len(selected_samples)} new samples from entire dataset")
        if not selected_samples:
            print("\n❌ No new tool calling samples available!")
            return

    # Save to file
    output_file = "data/dolci_10k_with_tool_call_batch5.json"
    print(f"\nSaving to {output_file}...")
    with open(output_file, 'w', encoding='utf-8') as f:
        json.dump(selected_samples, f, ensure_ascii=False, indent=2)
    print(f"✓ Successfully created batch5 with {len(selected_samples)} samples")

    # Print a sample for verification
    if selected_samples:
        print("\nSample entry:")
        print(json.dumps(selected_samples[0], ensure_ascii=False, indent=2)[:1000] + "...")
# Script entry point: run extraction only when executed directly, not on import.
if __name__ == "__main__":
    main()