llm / extract_batch5.py
dongxx1104's picture
Upload folder using huggingface_hub
db704cb verified
#!/usr/bin/env python3
"""
Extract batch5 from the last 20k samples of allenai/Dolci-Instruct-SFT-Tool-Use dataset.
Select samples with tool calling (function_call and observation).
"""
import json
from datasets import load_dataset
from tqdm import tqdm
def has_tool_calling(conversations):
    """Return True when a conversation has both a tool call and a tool result.

    Args:
        conversations: Iterable of turn dicts. Each turn's role is read from
            the ``'from'`` key, falling back to ``'role'`` when ``'from'`` is
            missing or falsy. ``None`` is treated as an empty conversation.

    Returns:
        bool: True only if at least one turn has the role ``'function_call'``
        and at least one turn has the role ``'observation'``.
    """
    # A null conversation list cannot contain tool calls; without this guard
    # iterating over None would raise TypeError.
    if conversations is None:
        return False
    roles = [turn.get('from') or turn.get('role') for turn in conversations]
    return 'function_call' in roles and 'observation' in roles
def convert_to_llamafactory_format(sample):
    """Convert a dataset sample into a LlamaFactory-style record.

    Args:
        sample: Mapping with an optional ``'conversations'`` list of turn
            dicts and optional ``'system'`` / ``'tools'`` entries. Each turn
            carries its role under ``'from'`` or ``'role'`` and its text under
            ``'value'`` or ``'content'``.

    Returns:
        dict: ``{'conversations': [{'from': ..., 'value': ...}, ...]}`` plus
        ``'system'`` and ``'tools'`` keys, each copied only when the sample
        holds a truthy value for it.
    """
    converted_conversations = []
    # `or []` also covers a sample whose 'conversations' entry is None.
    for turn in sample.get('conversations') or []:
        role = turn.get('from') or turn.get('role')
        # Fall back to 'content' only when 'value' is absent or None, so a
        # legitimately empty message ('') is kept instead of becoming None
        # (which would be serialized as JSON null).
        value = turn.get('value')
        if value is None:
            value = turn.get('content')
        converted_conversations.append({'from': role, 'value': value})

    result = {'conversations': converted_conversations}

    # Optional top-level fields are copied only when present and non-empty.
    for optional_key in ('system', 'tools'):
        optional_value = sample.get(optional_key)
        if optional_value:
            result[optional_key] = optional_value

    return result
def main():
    """Build the batch5 file from the tail of the Dolci tool-use dataset.

    Loads the ``train`` split of ``allenai/Dolci-Instruct-SFT-Tool-Use``,
    takes its last 20,000 rows (or every row when there are fewer), keeps the
    rows whose conversation passes ``has_tool_calling``, converts them with
    ``convert_to_llamafactory_format``, caps the result at 10,000 entries and
    writes it as JSON to ``data/dolci_10k_with_tool_call_batch5.json``.
    Progress messages and a truncated preview of the first entry are printed.
    """
    import os  # local import: only needed to create the output directory

    print("Loading allenai/Dolci-Instruct-SFT-Tool-Use dataset...")
    dataset = load_dataset("allenai/Dolci-Instruct-SFT-Tool-Use", split="train")
    total_samples = len(dataset)
    print(f"Total samples in dataset: {total_samples}")

    # Get the last 20k samples (clamped so short datasets are used in full)
    start_idx = max(0, total_samples - 20000)
    last_20k = dataset.select(range(start_idx, total_samples))
    print(f"Processing last 20k samples (from index {start_idx} to {total_samples})")

    # Keep only samples with tool calling, converted to the output format.
    # `or []` guards against rows whose 'conversations' entry is None.
    tool_calling_samples = [
        convert_to_llamafactory_format(sample)
        for sample in tqdm(last_20k, desc="Filtering samples with tool calling")
        if has_tool_calling(sample.get('conversations') or [])
    ]
    print(f"\nFound {len(tool_calling_samples)} samples with tool calling")

    # Select up to 10k samples
    if len(tool_calling_samples) > 10000:
        selected_samples = tool_calling_samples[:10000]
        print("Selected first 10,000 samples for batch5")
    else:
        selected_samples = tool_calling_samples
        print(f"Using all {len(selected_samples)} samples for batch5")

    # Save to file
    output_file = "data/dolci_10k_with_tool_call_batch5.json"
    print(f"\nSaving to {output_file}...")
    # Create the target directory first; otherwise open() raises
    # FileNotFoundError and the download/filter work above is wasted.
    output_dir = os.path.dirname(output_file)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)
    with open(output_file, 'w', encoding='utf-8') as f:
        json.dump(selected_samples, f, ensure_ascii=False, indent=2)
    print(f"✓ Successfully created batch5 with {len(selected_samples)} samples")

    # Print a sample for verification (first 500 characters only)
    if selected_samples:
        print("\nSample entry:")
        print(json.dumps(selected_samples[0], ensure_ascii=False, indent=2)[:500] + "...")
# Run the extraction only when this file is executed as a script, so that
# importing it does not trigger the dataset load or write the output file.
if __name__ == "__main__":
    main()