|
|
|
|
|
""" |
|
|
Extract batch5 from the last 20k samples of allenai/Dolci-Instruct-SFT-Tool-Use dataset. |
|
|
Select samples with tool calling (function_call and observation). |
|
|
""" |
|
|
|
|
|
import json |
|
|
from datasets import load_dataset |
|
|
from tqdm import tqdm |
|
|
|
|
|
def has_tool_calling(conversations):
    """Report whether a conversation contains both halves of a tool call.

    The role of each turn is read from its ``'from'`` key, falling back to
    ``'role'`` when ``'from'`` is missing or empty.

    Args:
        conversations: Iterable of turn dicts. ``None`` or an empty
            sequence is treated as a conversation with no turns.

    Returns:
        ``True`` only when at least one turn has the role
        ``'function_call'`` and at least one has the role ``'observation'``;
        ``False`` otherwise.
    """
    if not conversations:
        # A missing or null conversation list cannot contain a tool call;
        # answer False instead of failing while iterating over None.
        return False

    roles = [
        turn.get('from') or turn.get('role')
        for turn in conversations
        if isinstance(turn, dict)  # skip null / malformed turns
    ]
    return 'function_call' in roles and 'observation' in roles
|
|
|
|
|
def convert_to_llamafactory_format(sample):
    """Convert a dataset sample into the ``from``/``value`` conversation layout.

    Each turn is normalised to ``{'from': ..., 'value': ...}``. The role is
    taken from ``'from'`` (falling back to ``'role'``) and the text from
    ``'value'`` (falling back to ``'content'``).

    Args:
        sample: Mapping with a ``'conversations'`` list of turn dicts and,
            optionally, ``'system'`` and ``'tools'`` entries. A missing or
            null ``'conversations'`` entry yields an empty conversation list.

    Returns:
        A new dict with a ``'conversations'`` list, plus ``'system'`` and
        ``'tools'`` copied over only when they are present and non-empty.
    """
    converted_conversations = []
    # `or []` also covers a key that is present but holds None.
    for turn in sample.get('conversations') or []:
        role = turn.get('from') or turn.get('role')

        value = turn.get('value') or turn.get('content')
        if value is None:
            # Both lookups came up empty. Prefer whatever 'value' held (for
            # example an empty string) over None, so an intentionally blank
            # turn is not serialised as null.
            value = turn.get('value')

        converted_conversations.append({
            'from': role,
            'value': value,
        })

    result = {'conversations': converted_conversations}

    # Optional fields are carried over only when they hold something.
    for optional_key in ('system', 'tools'):
        if sample.get(optional_key):
            result[optional_key] = sample[optional_key]

    return result
|
|
|
|
|
def main():
    """Build the batch5 JSON file from the tail of the tool-use dataset.

    Steps:
        1. Load the ``train`` split of allenai/Dolci-Instruct-SFT-Tool-Use.
        2. Take the last 20,000 rows (or every row if there are fewer).
        3. Keep rows whose conversation passes ``has_tool_calling`` and
           convert them with ``convert_to_llamafactory_format``.
        4. Cap the result at the first 10,000 matches.
        5. Write the result as indented UTF-8 JSON and print a preview of
           the first entry.
    """
    # Function-scope import: only needed to prepare the output directory.
    import os

    tail_size = 20000
    batch_limit = 10000
    output_file = "data/dolci_10k_with_tool_call_batch5.json"

    print("Loading allenai/Dolci-Instruct-SFT-Tool-Use dataset...")
    dataset = load_dataset("allenai/Dolci-Instruct-SFT-Tool-Use", split="train")

    total_samples = len(dataset)
    print(f"Total samples in dataset: {total_samples}")

    # Clamp at 0 so a dataset smaller than the tail is used in full.
    start_idx = max(0, total_samples - tail_size)
    last_20k = dataset.select(range(start_idx, total_samples))
    print(f"Processing last 20k samples (from index {start_idx} to {total_samples})")

    tool_calling_samples = []
    for sample in tqdm(last_20k, desc="Filtering samples with tool calling"):
        # `or []` guards against a row whose 'conversations' field is null.
        conversations = sample.get('conversations') or []
        if has_tool_calling(conversations):
            tool_calling_samples.append(convert_to_llamafactory_format(sample))

    print(f"\nFound {len(tool_calling_samples)} samples with tool calling")

    if len(tool_calling_samples) > batch_limit:
        selected_samples = tool_calling_samples[:batch_limit]
        print("Selected first 10,000 samples for batch5")
    else:
        selected_samples = tool_calling_samples
        print(f"Using all {len(selected_samples)} samples for batch5")

    print(f"\nSaving to {output_file}...")
    # Create the target directory first; otherwise open() raises
    # FileNotFoundError after all of the loading and filtering work is done.
    output_dir = os.path.dirname(output_file)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)
    with open(output_file, 'w', encoding='utf-8') as f:
        json.dump(selected_samples, f, ensure_ascii=False, indent=2)

    print(f"✓ Successfully created batch5 with {len(selected_samples)} samples")

    if selected_samples:
        print("\nSample entry:")
        # Truncate the preview so one large sample does not flood the console.
        print(json.dumps(selected_samples[0], ensure_ascii=False, indent=2)[:500] + "...")
|
|
|
|
|
# Run the extraction only when this file is executed as a script, so that
# importing the module does not trigger the dataset download.
if __name__ == "__main__":
    main()
|
|
|