File size: 6,619 Bytes
db704cb |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 |
#!/usr/bin/env python3
"""
Fix batch5 by correctly converting environment role to observation.
"""
import json
from datasets import load_dataset
from tqdm import tqdm
def convert_to_llamafactory_format(sample):
    """Translate one Dolci-format sample into the LlamaFactory schema.

    Dolci samples carry a ``messages`` list with roles
    system/user/assistant/environment, plus optional ``function_calls``
    (assistant turns) and ``functions`` (system turn) string fields.

    The LlamaFactory output uses ``conversations`` entries of the form
    ``{'from': human|gpt|function_call|observation, 'value': str}``, with
    optional top-level ``system`` and ``tools`` strings.

    Notes on edge behavior (mirrors the original converter):
    - only the first non-empty ``functions`` payload becomes ``tools``;
    - a later system message's content overwrites ``system``;
    - an assistant turn with ``function_calls`` emits only the call
      (its text content, if any, is dropped);
    - assistant turns with neither calls nor text are skipped.
    """
    converted = []
    tools = None
    system_prompt = None

    for message in sample.get('messages', []):
        role = message.get('role', '')
        text = message.get('content', '')

        if role == 'system':
            funcs = message.get('functions')
            # First functions payload wins; later ones are ignored.
            if funcs and not tools:
                tools = funcs
            if text:
                system_prompt = text
            continue

        if role == 'user':
            converted.append({'from': 'human', 'value': text})
        elif role == 'assistant':
            calls = message.get('function_calls')
            if calls:
                # Tool-invocation turn: the call string is the value.
                converted.append({'from': 'function_call', 'value': calls})
            elif text:
                # Plain text reply.
                converted.append({'from': 'gpt', 'value': text})
        elif role == 'environment':
            # Tool result becomes an observation turn.
            converted.append({'from': 'observation', 'value': text})

    result = {'conversations': converted}
    if system_prompt:
        result['system'] = system_prompt
    if tools:
        result['tools'] = tools
    return result
def get_sample_hash(sample):
    """Return a hash of the first user message's content, or None.

    Uses the built-in ``hash()``, whose string hashing is salted per
    process — values are only comparable within a single run, which is
    all the in-run deduplication in this script requires.
    """
    user_turns = (m for m in sample.get('messages', []) if m.get('role') == 'user')
    first = next(user_turns, None)
    return None if first is None else hash(first.get('content', ''))
def has_tool_calling(messages):
    """Return True iff any message carries a truthy ``function_calls`` field."""
    return any(msg.get('function_calls') for msg in messages)
def _load_existing_hashes():
    """Collect hashes of the first human turn of every sample in batches 1-4.

    Hashes come from the built-in ``hash()`` and are therefore only valid
    within the current interpreter run — matching ``get_sample_hash``, which
    the dedup check below compares against. Missing batch files are skipped
    with a warning rather than treated as fatal.
    """
    print("\nLoading existing batches to avoid duplicates...")
    existing_hashes = set()
    for batch_num in range(1, 5):
        batch_file = f"data/dolci_10k_with_tool_call_batch{batch_num}.json"
        try:
            with open(batch_file, 'r', encoding='utf-8') as f:
                batch_data = json.load(f)
        except FileNotFoundError:
            print(f" Warning: {batch_file} not found, skipping...")
            continue
        for sample in batch_data:
            # Only the first human turn identifies a sample, mirroring
            # get_sample_hash on the raw Dolci side.
            for conv in sample.get('conversations', []):
                if conv.get('from') == 'human':
                    existing_hashes.add(hash(conv.get('value', '')))
                    break
        print(f" Loaded batch{batch_num}: {len(batch_data)} samples")
    return existing_hashes


def _collect_new_samples(dataset, existing_hashes):
    """Filter and convert unseen tool-calling samples from the last 20k rows.

    Keeps only samples that (a) contain a function call, (b) are not already
    present in batches 1-4, and (c) convert to a conversation containing both
    a ``function_call`` and an ``observation`` turn.
    """
    total_samples = len(dataset)
    start_idx = max(0, total_samples - 20000)
    last_20k = dataset.select(range(start_idx, total_samples))
    print(f"\nProcessing last 20k samples (from index {start_idx} to {total_samples})")

    tool_calling_samples = []
    for sample in tqdm(last_20k, desc="Filtering tool calling samples"):
        messages = sample.get('messages', [])
        if not has_tool_calling(messages):
            continue
        # Skip anything already emitted in batch1-4. A sample with no user
        # turn hashes to None, which never appears in existing_hashes.
        if get_sample_hash(sample) in existing_hashes:
            continue
        converted = convert_to_llamafactory_format(sample)
        roles = [c['from'] for c in converted.get('conversations', [])]
        if 'function_call' in roles and 'observation' in roles:
            tool_calling_samples.append(converted)
    return tool_calling_samples


def _report_patterns(selected_samples):
    """Print the most common role-sequence patterns plus one sample entry."""
    print("\n=== Verifying format ===")
    role_patterns = {}
    for sample in selected_samples[:100]:
        pattern = ' -> '.join(c['from'] for c in sample['conversations'])
        role_patterns[pattern] = role_patterns.get(pattern, 0) + 1
    print("Top patterns in first 100 samples:")
    for pattern, count in sorted(role_patterns.items(), key=lambda x: -x[1])[:5]:
        print(f" [{count:3d}] {pattern}")
    if selected_samples:
        print("\nSample entry:")
        print(json.dumps(selected_samples[0], ensure_ascii=False, indent=2)[:1000] + "...")


def main():
    """Build batch5 from new tool-calling samples in the Dolci dataset.

    Pipeline: load the Dolci tool-use dataset, drop samples whose first user
    turn already appears in batches 1-4, convert the remainder to
    LlamaFactory format, keep conversations with both a function_call and an
    observation, cap at 10k samples, and write the result to JSON.
    """
    print("Loading allenai/Dolci-Instruct-SFT-Tool-Use dataset...")
    dataset = load_dataset("allenai/Dolci-Instruct-SFT-Tool-Use", split="train")
    print(f"Total samples in dataset: {len(dataset)}")

    existing_hashes = _load_existing_hashes()
    print(f"Total existing samples to avoid: {len(existing_hashes)}")

    tool_calling_samples = _collect_new_samples(dataset, existing_hashes)
    print(f"\nFound {len(tool_calling_samples)} NEW tool calling samples with proper format")

    # Cap the batch at 10k samples.
    if len(tool_calling_samples) > 10000:
        selected_samples = tool_calling_samples[:10000]
        print("Selected first 10,000 samples for batch5")
    else:
        selected_samples = tool_calling_samples
        print(f"Using all {len(selected_samples)} samples for batch5")

    if not selected_samples:
        print("\n❌ No new tool calling samples found!")
        return

    output_file = "data/dolci_10k_with_tool_call_batch5.json"
    print(f"\nSaving to {output_file}...")
    with open(output_file, 'w', encoding='utf-8') as f:
        json.dump(selected_samples, f, ensure_ascii=False, indent=2)
    print(f"✓ Successfully created batch5 with {len(selected_samples)} samples")

    _report_patterns(selected_samples)


if __name__ == "__main__":
    main()
|