File size: 6,619 Bytes
db704cb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
#!/usr/bin/env python3
"""
Fix batch5 by correctly converting environment role to observation.
"""

import json
from datasets import load_dataset
from tqdm import tqdm

def convert_to_llamafactory_format(sample):
    """
    Translate a Dolci-style sample into the LlamaFactory sharegpt layout.

    Dolci messages carry:
      - role: system / user / assistant / environment
      - content: plain text
      - function_calls: serialized tool call (on assistant turns)
      - functions: available-tools JSON string (on the system turn)

    LlamaFactory conversations carry:
      - from: human / gpt / function_call / observation
      - value: text or JSON
    plus optional top-level 'system' and 'tools' fields.
    """
    conversations = []
    tools = None
    system_prompt = None

    for message in sample.get('messages', []):
        role = message.get('role', '')
        text = message.get('content', '')

        # System turns are lifted out of the conversation entirely:
        # their tools/prompt become top-level fields on the result.
        if role == 'system':
            available = message.get('functions')
            if available and not tools:
                tools = available
            if text:
                system_prompt = text
            continue

        if role == 'user':
            conversations.append({'from': 'human', 'value': text})
        elif role == 'assistant':
            call_payload = message.get('function_calls')
            if call_payload:
                # Tool invocation takes precedence over any plain text.
                conversations.append({'from': 'function_call', 'value': call_payload})
            elif text:
                conversations.append({'from': 'gpt', 'value': text})
        elif role == 'environment':
            # Environment turns are the tool results, i.e. observations.
            conversations.append({'from': 'observation', 'value': text})

    result = {'conversations': conversations}
    if system_prompt:
        result['system'] = system_prompt
    if tools:
        result['tools'] = tools
    return result

def get_sample_hash(sample):
    """Hash of the first user turn's content, or None when no user turn exists.

    Used only for within-run deduplication, so the builtin hash() is fine.
    """
    return next(
        (hash(m.get('content', '')) for m in sample.get('messages', [])
         if m.get('role') == 'user'),
        None,
    )

def has_tool_calling(messages):
    """Return True when any message carries a non-empty function_calls field."""
    return any(bool(m.get('function_calls')) for m in messages)

def main():
    """Build batch5: new tool-calling samples from the last 20k rows of Dolci.

    Steps:
      1. Load the Dolci tool-use dataset from the Hugging Face hub.
      2. Collect hashes of the first user turn from batches 1-4 so batch5
         contains only unseen samples.
      3. Scan the last 20k rows, keep samples that contain a tool call,
         convert them, and require both function_call and observation turns.
      4. Write up to 10k of them to the batch5 JSON file and print a quick
         role-pattern summary for eyeballing.
    """
    print("Loading allenai/Dolci-Instruct-SFT-Tool-Use dataset...")
    dataset = load_dataset("allenai/Dolci-Instruct-SFT-Tool-Use", split="train")

    total_samples = len(dataset)
    print(f"Total samples in dataset: {total_samples}")

    # Load existing batch1-4 to avoid duplicates. Hashing the first 'human'
    # turn matches get_sample_hash() on the raw Dolci side: both hash the
    # first user message content with the builtin hash().
    print("\nLoading existing batches to avoid duplicates...")
    existing_hashes = set()
    for batch_num in range(1, 5):
        batch_file = f"data/dolci_10k_with_tool_call_batch{batch_num}.json"
        try:
            with open(batch_file, 'r', encoding='utf-8') as f:
                batch_data = json.load(f)
                for sample in batch_data:
                    conversations = sample.get('conversations', [])
                    for conv in conversations:
                        if conv.get('from') == 'human':
                            sample_hash = hash(conv.get('value', ''))
                            existing_hashes.add(sample_hash)
                            break
                print(f"  Loaded batch{batch_num}: {len(batch_data)} samples")
        except FileNotFoundError:
            print(f"  Warning: {batch_file} not found, skipping...")

    print(f"Total existing samples to avoid: {len(existing_hashes)}")

    # Get last 20k samples (or everything if the dataset is smaller).
    start_idx = max(0, total_samples - 20000)
    last_20k = dataset.select(range(start_idx, total_samples))
    print(f"\nProcessing last 20k samples (from index {start_idx} to {total_samples})")

    # Filter samples with tool calling and proper format.
    tool_calling_samples = []
    for sample in tqdm(last_20k, desc="Filtering tool calling samples"):
        messages = sample.get('messages', [])
        if has_tool_calling(messages):
            sample_hash = get_sample_hash(sample)

            # Skip if already in batch1-4 (or already accepted into batch5).
            if sample_hash not in existing_hashes:
                converted = convert_to_llamafactory_format(sample)

                # Verify conversion has proper structure with observation.
                conversations = converted.get('conversations', [])
                roles = [c['from'] for c in conversations]

                # Check if has both function_call and observation.
                if 'function_call' in roles and 'observation' in roles:
                    tool_calling_samples.append(converted)
                    # BUGFIX: also record accepted hashes so batch5 cannot
                    # contain duplicates of itself. None means "no user turn";
                    # never add it, or every later user-less sample would be
                    # wrongly skipped.
                    if sample_hash is not None:
                        existing_hashes.add(sample_hash)

    print(f"\nFound {len(tool_calling_samples)} NEW tool calling samples with proper format")

    # Select up to 10k samples.
    if len(tool_calling_samples) > 10000:
        selected_samples = tool_calling_samples[:10000]
        print("Selected first 10,000 samples for batch5")
    else:
        selected_samples = tool_calling_samples
        print(f"Using all {len(selected_samples)} samples for batch5")

    if not selected_samples:
        print("\n❌ No new tool calling samples found!")
        return

    # Save to file.
    output_file = "data/dolci_10k_with_tool_call_batch5.json"
    print(f"\nSaving to {output_file}...")
    with open(output_file, 'w', encoding='utf-8') as f:
        json.dump(selected_samples, f, ensure_ascii=False, indent=2)

    print(f"✓ Successfully created batch5 with {len(selected_samples)} samples")

    # Verify format: count role-sequence patterns over the first 100 samples.
    print("\n=== Verifying format ===")
    role_patterns = {}
    for sample in selected_samples[:100]:
        roles = [c['from'] for c in sample['conversations']]
        pattern = ' -> '.join(roles)
        role_patterns[pattern] = role_patterns.get(pattern, 0) + 1

    print("Top patterns in first 100 samples:")
    for pattern, count in sorted(role_patterns.items(), key=lambda x: -x[1])[:5]:
        print(f"  [{count:3d}] {pattern}")

    # Print a sample for verification.
    if selected_samples:
        print("\nSample entry:")
        print(json.dumps(selected_samples[0], ensure_ascii=False, indent=2)[:1000] + "...")

if __name__ == "__main__":
    main()