|
|
|
|
|
""" |
|
|
Extract batch5 from Dolci dataset with correct field names. |
|
|
""" |
|
|
|
|
|
import json |
|
|
from datasets import load_dataset |
|
|
from tqdm import tqdm |
|
|
|
|
|
def has_tool_calling(messages):
    """Return True if any message carries a non-empty 'function_calls' field.

    Args:
        messages: list of message dicts in Dolci format.

    Returns:
        bool: whether the conversation contains at least one tool call.
    """
    return any(message.get('function_calls') for message in messages)
|
|
|
|
|
def convert_to_llamafactory_format(sample):
    """Convert one Dolci-format sample into the LlamaFactory sharegpt layout.

    Dolci messages carry:
        - role: system / user / assistant / function
        - content: text content
        - function_calls: function call JSON string (assistant turns)
        - functions: available functions JSON string (system turn)

    LlamaFactory conversations carry:
        - from: human / gpt / function_call / observation
        - value: text or JSON string

    The first system message that declares functions supplies the top-level
    'system' and 'tools' keys; any other system message is dropped. An
    assistant turn with function_calls emits only the call (its text content,
    if any, is intentionally discarded); otherwise non-empty assistant text
    becomes a 'gpt' turn.

    Returns:
        dict with 'conversations' and optionally 'system' / 'tools'.
    """
    conversations = []
    tools = None
    system_prompt = None

    for message in sample.get('messages', []):
        role = message.get('role', '')
        text = message.get('content', '')
        calls = message.get('function_calls')

        if role == 'system':
            # Only the first functions-bearing system turn is kept.
            declared = message.get('functions')
            if declared and not tools:
                tools = declared
                system_prompt = text
            continue

        if role == 'user':
            conversations.append({'from': 'human', 'value': text})
        elif role == 'assistant':
            if calls:
                conversations.append({'from': 'function_call', 'value': calls})
            elif text:
                conversations.append({'from': 'gpt', 'value': text})
        elif role == 'function':
            conversations.append({'from': 'observation', 'value': text})

    result = {'conversations': conversations}
    if system_prompt:
        result['system'] = system_prompt
    if tools:
        result['tools'] = tools
    return result
|
|
|
|
|
def get_sample_hash(sample):
    """Hash the first user turn of a sample to identify duplicate prompts.

    Args:
        sample: Dolci-format sample dict with a 'messages' list.

    Returns:
        int hash of the first user message's content, or None when the
        sample has no user turn.
    """
    user_texts = (
        message.get('content', '')
        for message in sample.get('messages', [])
        if message.get('role') == 'user'
    )
    first = next(user_texts, None)
    return None if first is None else hash(first)
|
|
|
|
|
def _load_existing_hashes():
    """Collect hashes of the first human turn of every sample already exported
    in batches 1-4, so batch5 only contains unseen prompts.

    Missing batch files are reported and skipped.

    Returns:
        set of int hashes of known user prompts.
    """
    existing_hashes = set()
    for batch_num in range(1, 5):
        batch_file = f"data/dolci_10k_with_tool_call_batch{batch_num}.json"
        try:
            with open(batch_file, 'r', encoding='utf-8') as f:
                batch_data = json.load(f)
        except FileNotFoundError:
            print(f" Warning: {batch_file} not found, skipping...")
            continue
        for sample in batch_data:
            # Hash the first 'human' turn — mirrors get_sample_hash on the
            # converted (LlamaFactory) representation.
            for conv in sample.get('conversations', []):
                if conv.get('from') == 'human':
                    existing_hashes.add(hash(conv.get('value', '')))
                    break
        print(f" Loaded batch{batch_num}: {len(batch_data)} samples")
    return existing_hashes


def _collect_new_samples(samples, existing_hashes, desc, limit=None):
    """Filter an iterable of Dolci samples down to converted, previously
    unseen tool-calling samples.

    Bug fix vs. the original inline loops: each accepted sample's hash is
    added to `existing_hashes`, so duplicate prompts occurring *within this
    run* are also deduplicated, not just those from earlier batch files.

    Args:
        samples: iterable of Dolci-format sample dicts.
        existing_hashes: set of prompt hashes to skip; mutated in place.
        desc: tqdm progress-bar description.
        limit: stop after collecting this many samples (None = no limit).

    Returns:
        list of converted samples in LlamaFactory format.
    """
    collected = []
    for sample in tqdm(samples, desc=desc):
        if not has_tool_calling(sample.get('messages', [])):
            continue
        sample_hash = get_sample_hash(sample)
        if sample_hash in existing_hashes:
            continue
        converted = convert_to_llamafactory_format(sample)
        if converted.get('conversations'):
            collected.append(converted)
            # Samples with no user turn hash to None; leave those out of the
            # dedup set so they are never suppressed (original behavior).
            if sample_hash is not None:
                existing_hashes.add(sample_hash)
        if limit is not None and len(collected) >= limit:
            break
    return collected


def main():
    """Build data/dolci_10k_with_tool_call_batch5.json with up to 10k new
    tool-calling samples (LlamaFactory format) absent from batches 1-4.

    Strategy: scan the last 20k rows of the dataset first; if that yields
    nothing new, fall back to scanning the entire dataset.
    """
    print("Loading allenai/Dolci-Instruct-SFT-Tool-Use dataset...")
    dataset = load_dataset("allenai/Dolci-Instruct-SFT-Tool-Use", split="train")

    total_samples = len(dataset)
    print(f"Total samples in dataset: {total_samples}")

    print("\nLoading existing batches to avoid duplicates...")
    existing_hashes = _load_existing_hashes()
    print(f"Total existing samples to avoid: {len(existing_hashes)}")

    # First pass: only the dataset tail — much cheaper than a full scan.
    start_idx = max(0, total_samples - 20000)
    last_20k = dataset.select(range(start_idx, total_samples))
    print(f"\nProcessing last 20k samples (from index {start_idx} to {total_samples})")

    tool_calling_samples = _collect_new_samples(
        last_20k, existing_hashes, desc="Filtering tool calling samples")

    print(f"\nFound {len(tool_calling_samples)} NEW tool calling samples")

    if len(tool_calling_samples) > 10000:
        selected_samples = tool_calling_samples[:10000]
        print(f"Selected first 10,000 samples for batch5")
    else:
        selected_samples = tool_calling_samples
        print(f"Using all {len(selected_samples)} samples for batch5")

    if not selected_samples:
        # Fallback: the tail contained nothing new — scan everything, but
        # stop as soon as 10k new samples are in hand.
        print("\n❌ No new tool calling samples found in last 20k!")
        print("Trying from entire dataset...")
        tool_calling_samples = _collect_new_samples(
            dataset, existing_hashes, desc="Scanning entire dataset",
            limit=10000)
        selected_samples = tool_calling_samples[:10000]
        print(f"Found {len(selected_samples)} new samples from entire dataset")

    if not selected_samples:
        print("\n❌ No new tool calling samples available!")
        return

    output_file = "data/dolci_10k_with_tool_call_batch5.json"
    print(f"\nSaving to {output_file}...")
    with open(output_file, 'w', encoding='utf-8') as f:
        json.dump(selected_samples, f, ensure_ascii=False, indent=2)

    print(f"✓ Successfully created batch5 with {len(selected_samples)} samples")

    if selected_samples:
        print("\nSample entry:")
        print(json.dumps(selected_samples[0], ensure_ascii=False, indent=2)[:1000] + "...")
|
|
|
|
|
# Entry-point guard: allows the helpers above to be imported without
# triggering the dataset download and extraction run.
if __name__ == "__main__":
    main()
|
|
|