#!/usr/bin/env python3
"""
Extract 50k samples without tool calls from Dolci-Instruct-SFT dataset.

Reads ``train-*.parquet`` shards from SOURCE_DIR in sorted order, keeps rows
whose message list contains no tool/function calls (see ``has_tool_call``),
stops once TARGET_COUNT rows have been collected, and writes the kept rows as
a JSON array to DEST_DIR/OUTPUT_FILE.
"""

import json
import os
from pathlib import Path

# Source and destination paths
SOURCE_DIR = "/shared_workspace_mfs/ximing/datasets/Dolci-Instruct-SFT/data"
DEST_DIR = "/shared_workspace_mfs/ximing/LLaMA-Factory/data"
OUTPUT_FILE = "dolci_50k_no_tool_call.json"
TARGET_COUNT = 50000

# Candidate columns holding a row's message list, tried in order. The script
# originally read only "conversations"; if that column is absent every row
# would silently pass the filter, so "messages" is tried as a fallback.
# NOTE(review): confirm which name the Dolci-Instruct-SFT parquet schema uses.
MESSAGE_COLUMNS = ("conversations", "messages")

# Message roles that denote a tool/function result turn.
TOOL_ROLES = ("tool", "function")

# Rows decoded per parquet record batch; streaming lets us stop early without
# materialising a whole shard as Python objects.
BATCH_SIZE = 1024

# Print a progress line every this many rows within a shard.
PROGRESS_EVERY = 1000


def has_tool_call(conversations):
    """Return True if any message in *conversations* involves a tool call.

    A message counts as a tool call when any of the following holds:

    * its ``role`` is ``"tool"`` or ``"function"``;
    * it has a truthy ``tool_calls`` value;
    * it has a truthy ``function_call`` value (older function-calling format).

    Args:
        conversations: Iterable of message dicts, or ``None`` / empty.

    Returns:
        ``False`` for ``None`` or an empty sequence; otherwise whether any
        message matches one of the rules above.
    """
    if not conversations:
        return False
    return any(
        msg.get("role") in TOOL_ROLES
        or msg.get("tool_calls")
        or msg.get("function_call")
        for msg in conversations
    )


def find_message_column(column_names):
    """Return the first name in MESSAGE_COLUMNS present in *column_names*.

    Returns ``None`` when none of the candidate columns exist.
    """
    return next((name for name in MESSAGE_COLUMNS if name in column_names), None)


def _iter_rows(reader, batch_size=BATCH_SIZE):
    """Yield each row of a ``pyarrow.parquet.ParquetFile`` as a column->value dict."""
    for batch in reader.iter_batches(batch_size=batch_size):
        yield from batch.to_pylist()


def main(*, source_dir=SOURCE_DIR, dest_dir=DEST_DIR,
         output_file=OUTPUT_FILE, target_count=TARGET_COUNT):
    """Collect up to *target_count* tool-call-free rows and save them as JSON.

    All keyword arguments default to the module-level constants, so calling
    ``main()`` with no arguments behaves like the command-line script.
    """
    # Imported here (not at module top) so has_tool_call / find_message_column
    # can be imported and reused in environments without pyarrow.
    import pyarrow.parquet as pq

    # Get all parquet files sorted
    parquet_files = sorted(Path(source_dir).glob("train-*.parquet"))
    print(f"Found {len(parquet_files)} parquet files")
    if not parquet_files:
        print(f"WARNING: no train-*.parquet files found under {source_dir}")

    no_tool_call_samples = []
    total_processed = 0

    for parquet_file in parquet_files:
        # Stop once the target is reached (also handles target_count <= 0).
        if len(no_tool_call_samples) >= target_count:
            break

        print(f"\nProcessing: {parquet_file.name}")

        # Own the file handle so it is closed deterministically.
        with open(parquet_file, "rb") as fh:
            reader = pq.ParquetFile(fh)
            num_rows = reader.metadata.num_rows
            print(f"  Total rows: {num_rows}")

            message_column = find_message_column(reader.schema_arrow.names)
            if message_column is None:
                print(
                    f"  WARNING: none of {MESSAGE_COLUMNS} found in "
                    f"{parquet_file.name}; rows cannot be checked for tool "
                    "calls and will all be kept"
                )

            for idx, row_data in enumerate(_iter_rows(reader)):
                total_processed += 1

                conversations = (
                    row_data.get(message_column) if message_column else None
                )
                if not has_tool_call(conversations):
                    no_tool_call_samples.append(row_data)
                    if len(no_tool_call_samples) >= target_count:
                        print(f"\n✓ Reached target of {target_count} samples!")
                        break

                if (idx + 1) % PROGRESS_EVERY == 0:
                    print(
                        f"  Progress: {idx + 1}/{num_rows}, "
                        f"no tool call: {len(no_tool_call_samples)}"
                    )

        print(f"  No tool call samples so far: {len(no_tool_call_samples)}")

    # Save to JSON file; create the destination directory if needed.
    if dest_dir:
        os.makedirs(dest_dir, exist_ok=True)
    output_path = os.path.join(dest_dir, output_file)
    print(f"\nSaving {len(no_tool_call_samples)} samples to: {output_path}")
    with open(output_path, "w", encoding="utf-8") as f:
        json.dump(no_tool_call_samples, f, ensure_ascii=False, indent=2)

    if len(no_tool_call_samples) < target_count:
        print(
            f"WARNING: only {len(no_tool_call_samples)} samples found; "
            f"target was {target_count}"
        )

    print(f"\n{'=' * 60}")
    print("Summary:")
    print(f"  Total processed: {total_processed}")
    print(f"  No tool call samples: {len(no_tool_call_samples)}")
    print(f"  Output file: {output_path}")
    print(f"  File size: {os.path.getsize(output_path) / (1024 * 1024):.2f} MB")
    print(f"{'=' * 60}")


if __name__ == "__main__":
    main()