#!/usr/bin/env python3
"""
Analyze Dolci dataset and extract batch5.
Find tool calling samples that haven't been used in batch1-4.
"""
import json

from datasets import load_dataset
from tqdm import tqdm
def has_tool_calling(conversations):
    """Check if conversations contain function_call and observation."""
    roles = [conv.get('from') or conv.get('role') for conv in conversations]
    return 'function_call' in roles and 'observation' in roles


def convert_to_llamafactory_format(sample):
    """Convert dataset sample to LlamaFactory format."""
    conversations = sample.get('conversations', [])

    # Convert role names if needed
    converted_conversations = []
    for conv in conversations:
        role = conv.get('from') or conv.get('role')
        value = conv.get('value') or conv.get('content')
        converted_conversations.append({
            'from': role,
            'value': value
        })

    result = {
        'conversations': converted_conversations
    }

    # Add system prompt if exists
    if 'system' in sample and sample['system']:
        result['system'] = sample['system']

    # Add tools if exists
    if 'tools' in sample and sample['tools']:
        result['tools'] = sample['tools']

    return result
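# For reference, each converted record has the shape sketched below. The role
# names and values are illustrative only: the script passes roles through
# unchanged, and 'system'/'tools' appear only when present in the source sample.
# {
#   "conversations": [
#     {"from": "human", "value": "..."},
#     {"from": "function_call", "value": "..."},
#     {"from": "observation", "value": "..."},
#     {"from": "gpt", "value": "..."}
#   ],
#   "system": "...",
#   "tools": "..."
# }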
def get_sample_hash(sample):
    """Create a hash for a sample to identify duplicates."""
    conversations = sample.get('conversations', [])
    if conversations:
        # Hash only the first message, using the same value/content fallback as
        # the converter so raw dataset rows and saved batch files hash consistently.
        first_msg = conversations[0].get('value') or conversations[0].get('content') or ''
        return hash(first_msg)
    return None
def main():
    print("Loading allenai/Dolci-Instruct-SFT-Tool-Use dataset...")
    dataset = load_dataset("allenai/Dolci-Instruct-SFT-Tool-Use", split="train")
    total_samples = len(dataset)
    print(f"Total samples in dataset: {total_samples}")

    # Load existing batch1-4 to avoid duplicates
    print("\nLoading existing batches to avoid duplicates...")
    existing_hashes = set()
    for batch_num in range(1, 5):
        batch_file = f"data/dolci_10k_with_tool_call_batch{batch_num}.json"
        try:
            with open(batch_file, 'r', encoding='utf-8') as f:
                batch_data = json.load(f)
            for sample in batch_data:
                sample_hash = get_sample_hash(sample)
                if sample_hash:
                    existing_hashes.add(sample_hash)
            print(f" Loaded batch{batch_num}: {len(batch_data)} samples")
        except FileNotFoundError:
            print(f" Warning: {batch_file} not found, skipping...")
    print(f"Total existing samples to avoid: {len(existing_hashes)}")

    # Analyze the dataset distribution
    print("\nAnalyzing dataset for tool calling samples...")
    tool_calling_samples = []
    for idx, sample in enumerate(tqdm(dataset, desc="Scanning entire dataset")):
        conversations = sample.get('conversations', [])
        if has_tool_calling(conversations):
            sample_hash = get_sample_hash(sample)
            # Skip if already in batch1-4
            if sample_hash not in existing_hashes:
                converted = convert_to_llamafactory_format(sample)
                tool_calling_samples.append({
                    'index': idx,
                    'data': converted
                })
    print(f"\nFound {len(tool_calling_samples)} NEW tool calling samples (excluding batch1-4)")

    # Select up to 10k samples
    if len(tool_calling_samples) > 10000:
        selected_samples = [s['data'] for s in tool_calling_samples[:10000]]
        print("Selected first 10,000 samples for batch5")
        print(f"Index range: {tool_calling_samples[0]['index']} to {tool_calling_samples[9999]['index']}")
    else:
        selected_samples = [s['data'] for s in tool_calling_samples]
        print(f"Using all {len(selected_samples)} new samples for batch5")
        if tool_calling_samples:
            print(f"Index range: {tool_calling_samples[0]['index']} to {tool_calling_samples[-1]['index']}")

    if not selected_samples:
        print("\n❌ No new tool calling samples found!")
        return

    # Save to file
    output_file = "data/dolci_10k_with_tool_call_batch5.json"
    print(f"\nSaving to {output_file}...")
    with open(output_file, 'w', encoding='utf-8') as f:
        json.dump(selected_samples, f, ensure_ascii=False, indent=2)
    print(f"✓ Successfully created batch5 with {len(selected_samples)} samples")

    # Print a sample for verification
    if selected_samples:
        print("\nSample entry:")
        print(json.dumps(selected_samples[0], ensure_ascii=False, indent=2)[:800] + "...")
if __name__ == "__main__":
    main()