File size: 3,102 Bytes
db704cb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
#!/usr/bin/env python3
"""
Extract batch5 from the last 20k samples of allenai/Dolci-Instruct-SFT-Tool-Use dataset.
Select samples with tool calling (function_call and observation).
"""

import json
from pathlib import Path

from datasets import load_dataset
from tqdm import tqdm

def has_tool_calling(conversations):
    """Return True if the conversation contains both a 'function_call' and an 'observation' turn.

    Each turn's role is read from its 'from' key, falling back to 'role'.
    """
    seen_roles = {turn.get('from') or turn.get('role') for turn in conversations}
    return 'function_call' in seen_roles and 'observation' in seen_roles

def convert_to_llamafactory_format(sample):
    """Normalize a dataset sample into LlamaFactory's conversation schema.

    Each turn is rewritten to use 'from'/'value' keys (falling back to
    'role'/'content' when the original keys are absent). Optional top-level
    'system' and 'tools' fields are carried over only when present and truthy.
    """
    normalized_turns = [
        {
            'from': turn.get('from') or turn.get('role'),
            'value': turn.get('value') or turn.get('content'),
        }
        for turn in sample.get('conversations', [])
    ]

    result = {'conversations': normalized_turns}

    # Copy optional metadata fields; skip missing or empty values.
    for key in ('system', 'tools'):
        if sample.get(key):
            result[key] = sample[key]

    return result

def main(num_tail_samples=20000, max_selected=10000,
         output_file="data/dolci_10k_with_tool_call_batch5.json"):
    """Build batch5 from the tail of allenai/Dolci-Instruct-SFT-Tool-Use.

    Scans the last *num_tail_samples* of the train split, keeps samples
    whose conversation contains tool calling, converts them to the
    LlamaFactory format, and writes up to *max_selected* of them as JSON.

    Args:
        num_tail_samples: how many samples from the end of the split to scan.
        max_selected: cap on the number of filtered samples written out.
        output_file: destination JSON path; parent directories are created.
    """
    print("Loading allenai/Dolci-Instruct-SFT-Tool-Use dataset...")
    dataset = load_dataset("allenai/Dolci-Instruct-SFT-Tool-Use", split="train")

    total_samples = len(dataset)
    print(f"Total samples in dataset: {total_samples}")

    # Take the tail of the split (the whole split if it is shorter).
    start_idx = max(0, total_samples - num_tail_samples)
    tail = dataset.select(range(start_idx, total_samples))
    # Report the actual count, which may be < num_tail_samples.
    print(f"Processing last {len(tail)} samples (from index {start_idx} to {total_samples})")

    # Keep only samples whose conversation contains a tool-calling exchange.
    tool_calling_samples = []
    for sample in tqdm(tail, desc="Filtering samples with tool calling"):
        if has_tool_calling(sample.get('conversations', [])):
            tool_calling_samples.append(convert_to_llamafactory_format(sample))

    print(f"\nFound {len(tool_calling_samples)} samples with tool calling")

    # Cap the output size at max_selected.
    if len(tool_calling_samples) > max_selected:
        selected_samples = tool_calling_samples[:max_selected]
        print(f"Selected first {max_selected:,} samples for batch5")
    else:
        selected_samples = tool_calling_samples
        print(f"Using all {len(selected_samples)} samples for batch5")

    # Create the output directory first — open() would otherwise raise
    # FileNotFoundError when e.g. data/ does not exist yet.
    out_path = Path(output_file)
    out_path.parent.mkdir(parents=True, exist_ok=True)
    print(f"\nSaving to {output_file}...")
    with out_path.open('w', encoding='utf-8') as f:
        json.dump(selected_samples, f, ensure_ascii=False, indent=2)

    print(f"✓ Successfully created batch5 with {len(selected_samples)} samples")

    # Show a truncated first entry for quick manual verification.
    if selected_samples:
        print("\nSample entry:")
        print(json.dumps(selected_samples[0], ensure_ascii=False, indent=2)[:500] + "...")

# Run the extraction only when executed as a script, not on import.
if __name__ == "__main__":
    main()