#!/usr/bin/env python3
"""Estimate sample lengths via character count (rough estimate).

chars * 1.5 ≈ token count (conservative estimate). Loads a ShareGPT-style
JSON dataset, reports character-length statistics, drops samples whose
estimated token count exceeds the limit, and saves the filtered split.
"""
import json
import os

try:
    from tqdm import tqdm
except ImportError:  # tqdm is cosmetic (progress bar); fall back to a no-op
    def tqdm(iterable, *args, **kwargs):
        return iterable


def calculate_char_length(sample):
    """Return the total character length of one sample.

    Sums the lengths of the optional ``system`` and ``tools`` fields plus
    every conversation turn's ``value`` string.

    Args:
        sample: dict with optional 'system'/'tools' entries and a
            'conversations' list of dicts each carrying a 'value' string.

    Returns:
        int: total number of characters across all text fields.
    """
    total_length = len(sample.get('system') or '')
    total_length += len(sample.get('tools') or '')
    total_length += sum(len(conv.get('value', '')) for conv in sample['conversations'])
    return total_length


def main():
    """Run the load → analyze → filter → save pipeline."""
    input_file = "/shared_workspace_mfs/ximing/LLaMA-Factory/data/dolci_last_10k_from_20k.json"
    output_file = "/shared_workspace_mfs/ximing/LLaMA-Factory/data/dolci_last_10k_from_20k_filtered.json"
    max_token = 65536        # context-length budget used for filtering
    chars_per_token = 1.5    # conservative chars → tokens conversion factor

    print("Loading data...")
    with open(input_file, 'r', encoding='utf-8') as f:
        data = json.load(f)
    print(f"Total samples: {len(data)}")

    print("Calculating character lengths...")
    # Compute each sample's length exactly once; the list is reused for both
    # the statistics and the filtering pass (previously recomputed per pass).
    lengths = [calculate_char_length(sample) for sample in tqdm(data)]

    length_stats = sorted(lengths)
    n = len(length_stats)
    print(f"\nCharacter length statistics:")
    print(f"- Min: {min(length_stats):,}")
    print(f"- Max: {max(length_stats):,}")
    print(f"- Median: {length_stats[n // 2]:,}")
    print(f"- 90th percentile: {length_stats[int(n * 0.9)]:,}")
    print(f"- 95th percentile: {length_stats[int(n * 0.95)]:,}")
    print(f"- 99th percentile: {length_stats[int(n * 0.99)]:,}")

    # Estimate token lengths (chars * 1.5)
    print(f"\nEstimated token length (chars * 1.5):")
    char_threshold = int(max_token / chars_per_token)
    print(f"- Character threshold for {max_token} tokens: {char_threshold:,}")
    exceeding = sum(1 for length in length_stats if length * chars_per_token > max_token)
    print(f"- Samples estimated to exceed {max_token} tokens: {exceeding} ({exceeding / n * 100:.2f}%)")

    # Filter using the precomputed lengths (zip keeps sample/length aligned).
    print(f"\nFiltering samples with estimated tokens ≤ {max_token}...")
    filtered_data = [
        sample
        for sample, length in zip(data, lengths)
        if length * chars_per_token <= max_token
    ]

    print(f"\nResults:")
    print(f"- Original samples: {len(data)}")
    print(f"- Filtered samples: {len(filtered_data)}")
    print(f"- Removed: {len(data) - len(filtered_data)}")

    # Save the filtered split
    print(f"\nSaving to {output_file}...")
    with open(output_file, 'w', encoding='utf-8') as f:
        json.dump(filtered_data, f, ensure_ascii=False, indent=2)

    # Summary statistics for the filtered split. .get avoids a KeyError on
    # conversation turns that lack a 'from' key.
    has_func = sum(
        1 for s in filtered_data
        if any(c.get('from') == 'function_call' for c in s['conversations'])
    )
    has_tools = sum(1 for s in filtered_data if s.get('tools'))
    print(f"\nFiltered data statistics:")
    print(f"- Total: {len(filtered_data)}")
    print(f"- With function_call: {has_func}")
    print(f"- With tools: {has_tools}")
    print(f"- Without function_call: {len(filtered_data) - has_func}")

    file_size = os.path.getsize(output_file) / (1024 * 1024)
    print(f"- File size: {file_size:.2f} MB")


if __name__ == "__main__":
    main()