| | |
| | """ |
| | Extract 50k samples without tool calls from Dolci-Instruct-SFT dataset |
| | """ |
| | import os |
| | import json |
| | import pyarrow.parquet as pq |
| | from pathlib import Path |
| |
|
| | |
# Directory holding the source parquet shards of the Dolci-Instruct-SFT dataset.
SOURCE_DIR = "/shared_workspace_mfs/ximing/datasets/Dolci-Instruct-SFT/data"
# Destination directory where the filtered dataset JSON is written (LLaMA-Factory data dir).
DEST_DIR = "/shared_workspace_mfs/ximing/LLaMA-Factory/data"
# Filename of the output JSON containing the filtered samples.
OUTPUT_FILE = "dolci_50k_no_tool_call.json"
# Stop collecting once this many no-tool-call samples have been gathered.
TARGET_COUNT = 50000
| |
|
def has_tool_call(conversations):
    """Return True if any message in *conversations* involves a tool call.

    A message counts as a tool call when its role is "tool" or "function",
    or when it carries a non-empty "tool_calls" or "function_call" field
    (OpenAI-style assistant tool-invocation markers).

    Args:
        conversations: iterable of message dicts, or None/empty.

    Returns:
        bool: True if any tool-call marker is found, False otherwise.
    """
    if not conversations:
        return False

    for msg in conversations:
        # Messages emitted by tool/function execution.
        if msg.get("role", "") in ("tool", "function"):
            return True

        # Assistant-side tool invocation markers; .get() covers both the
        # "key present" and "value truthy" checks in one lookup.
        if msg.get("tool_calls"):
            return True

        if msg.get("function_call"):
            return True

    return False
| |
|
def main():
    """Scan the source parquet shards, collect up to TARGET_COUNT samples
    whose conversations contain no tool calls, and dump them as one JSON
    file under DEST_DIR.

    Side effects: reads parquet files from SOURCE_DIR, creates DEST_DIR if
    missing, writes DEST_DIR/OUTPUT_FILE, and prints progress/summary.
    """
    parquet_files = sorted(Path(SOURCE_DIR).glob("train-*.parquet"))
    print(f"Found {len(parquet_files)} parquet files")

    no_tool_call_samples = []
    total_processed = 0

    for parquet_file in parquet_files:
        print(f"\nProcessing: {parquet_file.name}")

        table = pq.read_table(parquet_file)
        # Ask the table for its row count directly instead of probing an
        # arbitrary first column (which would crash on a zero-column table).
        num_rows = table.num_rows
        data = table.to_pydict()
        print(f" Total rows: {num_rows}")

        for idx in range(num_rows):
            total_processed += 1

            # Materialize one row as a plain dict of column -> value.
            row_data = {key: column[idx] for key, column in data.items()}
            conversations = row_data.get("conversations", [])

            if not has_tool_call(conversations):
                no_tool_call_samples.append(row_data)

            if len(no_tool_call_samples) >= TARGET_COUNT:
                print(f"\n✓ Reached target of {TARGET_COUNT} samples!")
                break

            if (idx + 1) % 1000 == 0:
                print(f" Progress: {idx + 1}/{num_rows}, no tool call: {len(no_tool_call_samples)}")

        print(f" No tool call samples so far: {len(no_tool_call_samples)}")

        if len(no_tool_call_samples) >= TARGET_COUNT:
            break

    # Ensure the destination directory exists before writing the output.
    os.makedirs(DEST_DIR, exist_ok=True)
    output_path = os.path.join(DEST_DIR, OUTPUT_FILE)
    print(f"\nSaving {len(no_tool_call_samples)} samples to: {output_path}")

    with open(output_path, 'w', encoding='utf-8') as f:
        json.dump(no_tool_call_samples, f, ensure_ascii=False, indent=2)

    print(f"\n{'='*60}")
    print(f"Summary:")
    print(f" Total processed: {total_processed}")
    print(f" No tool call samples: {len(no_tool_call_samples)}")
    print(f" Output file: {output_path}")
    print(f" File size: {os.path.getsize(output_path) / (1024*1024):.2f} MB")
    print(f"{'='*60}")
|
# Script entry point: run the extraction only when executed directly.
if __name__ == "__main__":
    main()
| |
|