File size: 6,881 Bytes
db704cb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
#!/usr/bin/env python3
"""
Extract batch5 from Dolci dataset with correct field names.
"""

import json
from datasets import load_dataset
from tqdm import tqdm

def has_tool_calling(messages):
    """Return True if any message carries a truthy 'function_calls' field."""
    return any(msg.get('function_calls') for msg in messages)

def convert_to_llamafactory_format(sample):
    """
    Convert one Dolci sample (``messages``) to LlamaFactory format
    (``conversations``).

    Dolci message fields:
    - role: system/user/assistant/function
    - content: text content
    - function_calls: function call JSON string
    - functions: available functions JSON string

    LlamaFactory turn fields:
    - from: human/gpt/function_call/observation
    - value: text or JSON

    The first system message that declares ``functions`` is lifted out of
    the turn list into top-level ``system``/``tools`` keys. Later system
    messages are not emitted as turns.
    """
    conversations = []
    tools = None
    system_prompt = None

    for msg in sample.get('messages', []):
        role = msg.get('role', '')
        content = msg.get('content', '')
        function_calls = msg.get('function_calls')

        # Hoist the first tool-declaring system message into metadata.
        if role == 'system' and msg.get('functions') and not tools:
            tools = msg.get('functions')
            system_prompt = content
            continue

        if role == 'user':
            conversations.append({'from': 'human', 'value': content})
        elif role == 'assistant':
            if function_calls:
                # Tool-invoking assistant turn. Any plain text carried in
                # the same message is not emitted as a separate 'gpt' turn.
                conversations.append({'from': 'function_call', 'value': function_calls})
            elif content:
                # Ordinary assistant reply; empty content is skipped.
                conversations.append({'from': 'gpt', 'value': content})
        elif role == 'function':
            # Tool/observation result returned to the model.
            conversations.append({'from': 'observation', 'value': content})

    result = {'conversations': conversations}
    if system_prompt:
        result['system'] = system_prompt
    if tools:
        result['tools'] = tools
    return result

def get_sample_hash(sample):
    """Return hash of the first user message's content, or None if absent.

    NOTE(review): builtin hash() on strings is salted per process
    (PYTHONHASHSEED), so these values are only comparable within a single
    run. main() builds its dedup set with hash() in the same process, so
    this is consistent here — confirm nothing persists these across runs.
    """
    for msg in sample.get('messages', []):
        if msg.get('role') != 'user':
            continue
        return hash(msg.get('content', ''))
    return None

def main():
    print("Loading allenai/Dolci-Instruct-SFT-Tool-Use dataset...")
    dataset = load_dataset("allenai/Dolci-Instruct-SFT-Tool-Use", split="train")

    total_samples = len(dataset)
    print(f"Total samples in dataset: {total_samples}")

    # Load existing batch1-4 to avoid duplicates
    print("\nLoading existing batches to avoid duplicates...")
    existing_hashes = set()
    for batch_num in range(1, 5):
        batch_file = f"data/dolci_10k_with_tool_call_batch{batch_num}.json"
        try:
            with open(batch_file, 'r', encoding='utf-8') as f:
                batch_data = json.load(f)
                for sample in batch_data:
                    conversations = sample.get('conversations', [])
                    for conv in conversations:
                        if conv.get('from') == 'human':
                            sample_hash = hash(conv.get('value', ''))
                            existing_hashes.add(sample_hash)
                            break
                print(f"  Loaded batch{batch_num}: {len(batch_data)} samples")
        except FileNotFoundError:
            print(f"  Warning: {batch_file} not found, skipping...")

    print(f"Total existing samples to avoid: {len(existing_hashes)}")

    # Get last 20k samples
    start_idx = max(0, total_samples - 20000)
    last_20k = dataset.select(range(start_idx, total_samples))
    print(f"\nProcessing last 20k samples (from index {start_idx} to {total_samples})")

    # Filter samples with tool calling
    tool_calling_samples = []
    for idx, sample in enumerate(tqdm(last_20k, desc="Filtering tool calling samples")):
        messages = sample.get('messages', [])
        if has_tool_calling(messages):
            sample_hash = get_sample_hash(sample)

            # Skip if already in batch1-4
            if sample_hash not in existing_hashes:
                converted = convert_to_llamafactory_format(sample)
                # Verify conversion has proper structure
                if converted.get('conversations'):
                    tool_calling_samples.append(converted)

    print(f"\nFound {len(tool_calling_samples)} NEW tool calling samples")

    # Select up to 10k samples
    if len(tool_calling_samples) > 10000:
        selected_samples = tool_calling_samples[:10000]
        print(f"Selected first 10,000 samples for batch5")
    else:
        selected_samples = tool_calling_samples
        print(f"Using all {len(selected_samples)} samples for batch5")

    if not selected_samples:
        print("\n❌ No new tool calling samples found in last 20k!")
        print("Trying from entire dataset...")

        # Try entire dataset
        tool_calling_samples = []
        for idx, sample in enumerate(tqdm(dataset, desc="Scanning entire dataset")):
            messages = sample.get('messages', [])
            if has_tool_calling(messages):
                sample_hash = get_sample_hash(sample)

                if sample_hash not in existing_hashes:
                    converted = convert_to_llamafactory_format(sample)
                    if converted.get('conversations'):
                        tool_calling_samples.append(converted)

                        if len(tool_calling_samples) >= 10000:
                            break

        selected_samples = tool_calling_samples[:10000]
        print(f"Found {len(selected_samples)} new samples from entire dataset")

    if not selected_samples:
        print("\n❌ No new tool calling samples available!")
        return

    # Save to file
    output_file = "data/dolci_10k_with_tool_call_batch5.json"
    print(f"\nSaving to {output_file}...")
    with open(output_file, 'w', encoding='utf-8') as f:
        json.dump(selected_samples, f, ensure_ascii=False, indent=2)

    print(f"✓ Successfully created batch5 with {len(selected_samples)} samples")

    # Print a sample for verification
    if selected_samples:
        print("\nSample entry:")
        print(json.dumps(selected_samples[0], ensure_ascii=False, indent=2)[:1000] + "...")

# Script entry point: run only when executed directly, not on import.
if __name__ == "__main__":
    main()