# llm/extract_no_tool_call.py
# (Hugging Face Hub residue: uploaded via huggingface_hub by dongxx1104, commit db704cb)
#!/usr/bin/env python3
"""
Extract 50k samples without tool calls from Dolci-Instruct-SFT dataset
"""
import os
import json
import pyarrow.parquet as pq
from pathlib import Path
# Source and destination paths
SOURCE_DIR = "/shared_workspace_mfs/ximing/datasets/Dolci-Instruct-SFT/data"
DEST_DIR = "/shared_workspace_mfs/ximing/LLaMA-Factory/data"
OUTPUT_FILE = "dolci_50k_no_tool_call.json"
TARGET_COUNT = 50000
def has_tool_call(conversations):
    """Return True if any message in *conversations* involves a tool call.

    A sample counts as tool-using when any message either:
      * has role "tool" or "function" (a tool-execution response), or
      * carries a non-empty "tool_calls" field (current format), or
      * carries a non-empty "function_call" field (legacy format).

    Args:
        conversations: list of message dicts (may be None or empty).

    Returns:
        bool: True if a tool call is present, False otherwise.
    """
    if not conversations:
        return False
    for msg in conversations:
        # Tool/function-role messages are responses from tool execution.
        if msg.get("role", "") in ("tool", "function"):
            return True
        # Non-empty tool_calls on an assistant message (current format).
        # .get() replaces the original "key in msg and msg[key]" double lookup
        # (and drops an unused `content` local the original assigned).
        if msg.get("tool_calls"):
            return True
        # Legacy OpenAI-style function_call field.
        if msg.get("function_call"):
            return True
    return False
def main():
    """Scan the Dolci-Instruct-SFT parquet shards in order and export the
    first TARGET_COUNT samples whose conversations contain no tool calls
    to a single JSON file under DEST_DIR.
    """
    # Get all parquet files sorted so shard order is deterministic.
    parquet_files = sorted(Path(SOURCE_DIR).glob("train-*.parquet"))
    print(f"Found {len(parquet_files)} parquet files")

    no_tool_call_samples = []
    total_processed = 0

    for parquet_file in parquet_files:
        print(f"\nProcessing: {parquet_file.name}")
        table = pq.read_table(parquet_file)
        # to_pylist() yields one dict per row directly, instead of the
        # original to_pydict() column dict that had to be re-zipped into
        # row dicts by hand (and whose row count required the awkward
        # len(data[list(data.keys())[0]]) — pyarrow exposes .num_rows).
        rows = table.to_pylist()
        num_rows = table.num_rows
        print(f" Total rows: {num_rows}")

        for idx, row_data in enumerate(rows):
            total_processed += 1
            conversations = row_data.get("conversations", [])
            # Keep only samples without any tool calls.
            if not has_tool_call(conversations):
                no_tool_call_samples.append(row_data)
                # Stop scanning this shard once the quota is met.
                if len(no_tool_call_samples) >= TARGET_COUNT:
                    print(f"\n✓ Reached target of {TARGET_COUNT} samples!")
                    break
            # Progress indicator every 1000 rows.
            if (idx + 1) % 1000 == 0:
                print(f" Progress: {idx + 1}/{num_rows}, no tool call: {len(no_tool_call_samples)}")

        print(f" No tool call samples so far: {len(no_tool_call_samples)}")
        # Stop scanning further shards once the quota is met.
        if len(no_tool_call_samples) >= TARGET_COUNT:
            break

    # Save the collected samples as pretty-printed UTF-8 JSON.
    output_path = os.path.join(DEST_DIR, OUTPUT_FILE)
    print(f"\nSaving {len(no_tool_call_samples)} samples to: {output_path}")
    with open(output_path, 'w', encoding='utf-8') as f:
        json.dump(no_tool_call_samples, f, ensure_ascii=False, indent=2)

    print(f"\n{'='*60}")
    print(f"Summary:")
    print(f" Total processed: {total_processed}")
    print(f" No tool call samples: {len(no_tool_call_samples)}")
    print(f" Output file: {output_path}")
    print(f" File size: {os.path.getsize(output_path) / (1024*1024):.2f} MB")
    print(f"{'='*60}")
if __name__ == "__main__":
main()