# llm/fix_batch5.py
# Uploaded by dongxx1104 via huggingface_hub (commit db704cb, verified)
#!/usr/bin/env python3
"""
Fix batch5 by correctly converting environment role to observation.
"""
import json
from datasets import load_dataset
from tqdm import tqdm
def convert_to_llamafactory_format(sample):
    """
    Convert a Dolci-format sample into the LlamaFactory sharegpt format.

    Dolci format (``messages`` list), per message:
        - role: system / user / assistant / environment
        - content: text content
        - function_calls: function-call string (on assistant messages)
        - functions: available-functions JSON string (on the system message)

    LlamaFactory format (``conversations`` list), per turn:
        - from: human / gpt / function_call / observation
        - value: text or JSON string

    The first system message's ``functions`` becomes the top-level ``tools``
    field and its ``content`` becomes the top-level ``system`` field.
    """
    conversations = []
    tools = None
    system_prompt = None

    for message in sample.get('messages', []):
        role = message.get('role', '')
        text = message.get('content', '')

        if role == 'system':
            # System turns are hoisted out of the conversation entirely:
            # tools come from the first message that carries `functions`.
            if message.get('functions') and not tools:
                tools = message.get('functions')
            if text:
                system_prompt = text
            continue

        if role == 'user':
            turn = {'from': 'human', 'value': text}
        elif role == 'assistant':
            calls = message.get('function_calls')
            if calls:
                # Assistant turn that invokes a tool; the call string wins
                # over any accompanying content.
                turn = {'from': 'function_call', 'value': calls}
            elif text:
                turn = {'from': 'gpt', 'value': text}
            else:
                # Empty assistant message: nothing to emit.
                continue
        elif role == 'environment':
            # Tool execution result / observation (kept even if empty).
            turn = {'from': 'observation', 'value': text}
        else:
            continue

        conversations.append(turn)

    result = {'conversations': conversations}
    if system_prompt:
        result['system'] = system_prompt
    if tools:
        result['tools'] = tools
    return result
def get_sample_hash(sample):
    """
    Return a hash identifying a sample for duplicate detection.

    The hash is taken from the content of the sample's first user message;
    returns ``None`` when the sample has no user message.

    NOTE: built-in ``hash()`` of a string is salted per process
    (PYTHONHASHSEED), so these hashes are only comparable within one run —
    main() recomputes hashes of the existing batches in the same process,
    so dedup stays consistent.
    """
    first_user = next(
        (m for m in sample.get('messages', []) if m.get('role') == 'user'),
        None,
    )
    if first_user is None:
        return None
    return hash(first_user.get('content', ''))
def has_tool_calling(messages):
    """Return True if any message carries a truthy ``function_calls`` field."""
    return any(msg.get('function_calls') for msg in messages)
def main():
    """
    Build batch5 of tool-calling SFT data.

    Pipeline:
      1. Load the Dolci tool-use dataset from the Hub.
      2. Re-hash the user prompts of existing batch1-4 JSON files so we can
         skip samples already exported (built-in hash() is process-salted,
         but both sides are computed in this same process, so it's consistent).
      3. Scan the last 20k dataset rows, keep new samples that contain a
         function call, convert them, and require both a 'function_call'
         and an 'observation' turn in the converted conversation.
      4. Cap at 10k samples, write them as JSON, and print a summary of the
         role patterns for a quick format sanity check.
    """
    print("Loading allenai/Dolci-Instruct-SFT-Tool-Use dataset...")
    dataset = load_dataset("allenai/Dolci-Instruct-SFT-Tool-Use", split="train")
    total_samples = len(dataset)
    print(f"Total samples in dataset: {total_samples}")

    # Load existing batch1-4 to avoid duplicates.
    print("\nLoading existing batches to avoid duplicates...")
    existing_hashes = set()
    for batch_num in range(1, 5):
        batch_file = f"data/dolci_10k_with_tool_call_batch{batch_num}.json"
        try:
            with open(batch_file, 'r', encoding='utf-8') as f:
                batch_data = json.load(f)
            for sample in batch_data:
                conversations = sample.get('conversations', [])
                for conv in conversations:
                    # Hash only the FIRST human turn, mirroring
                    # get_sample_hash() on the raw side.
                    if conv.get('from') == 'human':
                        existing_hashes.add(hash(conv.get('value', '')))
                        break
            print(f"  Loaded batch{batch_num}: {len(batch_data)} samples")
        except FileNotFoundError:
            print(f"  Warning: {batch_file} not found, skipping...")
    print(f"Total existing samples to avoid: {len(existing_hashes)}")

    # Get last 20k samples.
    start_idx = max(0, total_samples - 20000)
    last_20k = dataset.select(range(start_idx, total_samples))
    print(f"\nProcessing last 20k samples (from index {start_idx} to {total_samples})")

    # Filter samples with tool calling and proper format.
    tool_calling_samples = []
    # NOTE: was `enumerate(...)` with an unused index; iterate directly.
    for sample in tqdm(last_20k, desc="Filtering tool calling samples"):
        messages = sample.get('messages', [])
        if not has_tool_calling(messages):
            continue
        # Skip if already in batch1-4.
        if get_sample_hash(sample) in existing_hashes:
            continue
        converted = convert_to_llamafactory_format(sample)
        # Verify conversion has proper structure: needs both a tool call
        # and its observation to be useful for tool-use SFT.
        roles = [c['from'] for c in converted.get('conversations', [])]
        if 'function_call' in roles and 'observation' in roles:
            tool_calling_samples.append(converted)
    print(f"\nFound {len(tool_calling_samples)} NEW tool calling samples with proper format")

    # Select up to 10k samples.
    if len(tool_calling_samples) > 10000:
        selected_samples = tool_calling_samples[:10000]
        print(f"Selected first 10,000 samples for batch5")
    else:
        selected_samples = tool_calling_samples
        print(f"Using all {len(selected_samples)} samples for batch5")

    if not selected_samples:
        print("\n❌ No new tool calling samples found!")
        return

    # Save to file.
    output_file = "data/dolci_10k_with_tool_call_batch5.json"
    print(f"\nSaving to {output_file}...")
    with open(output_file, 'w', encoding='utf-8') as f:
        json.dump(selected_samples, f, ensure_ascii=False, indent=2)
    print(f"✓ Successfully created batch5 with {len(selected_samples)} samples")

    # Verify format: tally the role-sequence patterns of the first 100 samples.
    print("\n=== Verifying format ===")
    role_patterns = {}
    for sample in selected_samples[:100]:
        roles = [c['from'] for c in sample['conversations']]
        pattern = ' -> '.join(roles)
        role_patterns[pattern] = role_patterns.get(pattern, 0) + 1
    print("Top patterns in first 100 samples:")
    for pattern, count in sorted(role_patterns.items(), key=lambda x: -x[1])[:5]:
        print(f"  [{count:3d}] {pattern}")

    # Print a truncated sample entry for eyeball verification.
    if selected_samples:
        print("\nSample entry:")
        print(json.dumps(selected_samples[0], ensure_ascii=False, indent=2)[:1000] + "...")


if __name__ == "__main__":
    main()