# llm / check_length_simple.py
# NOTE: the three lines below are Hugging Face file-viewer residue kept for
# provenance — uploader "dongxx1104", commit message "Upload folder using
# huggingface_hub", revision "db704cb verified".
#!/usr/bin/env python3
"""
使用字符长度估算样本长度(粗略估计)
字符数 * 1.5 ≈ token 数(保守估计)
"""
import json
from tqdm import tqdm
def calculate_char_length(sample):
    """Return the total character count of a sample's text fields.

    Sums the lengths of the optional ``system`` and ``tools`` strings plus
    every conversation turn's ``value``.  The caller treats this count as a
    rough token estimate (chars * 1.5 ≈ tokens, per the module docstring).

    Args:
        sample: dict in ShareGPT-like shape — optional ``system``/``tools``
            string fields and a ``conversations`` list of dicts, each
            presumably carrying a ``value`` string (TODO confirm schema
            against the dataset).

    Returns:
        int: total number of characters across the counted fields.
    """
    total_length = 0
    # .get() + `or ''` collapses the original "key present and truthy"
    # two-lookup check into a single lookup; None/missing counts as empty.
    total_length += len(sample.get('system') or '')
    total_length += len(sample.get('tools') or '')
    # A sample with no 'conversations' key now counts as 0 instead of
    # raising KeyError (the original indexed sample['conversations']).
    for conv in sample.get('conversations', []):
        total_length += len(conv.get('value', ''))
    return total_length
def main():
    """Filter an SFT dataset by estimated token length.

    Loads a ShareGPT-style JSON list from a hard-coded path, reports
    character-length percentiles, estimates tokens as ``chars * 1.5``,
    drops samples whose estimate exceeds ``max_token`` (65536), writes the
    survivors to a hard-coded output path, and prints summary statistics
    about function-calling content in the filtered set.
    """
    print("Loading data...")
    input_file = "/shared_workspace_mfs/ximing/LLaMA-Factory/data/dolci_last_10k_from_20k.json"
    with open(input_file, 'r', encoding='utf-8') as f:
        data = json.load(f)
    print(f"Total samples: {len(data)}")

    print("Calculating character lengths...")
    # Compute each sample's length exactly once; the original recomputed
    # all of them in a second full tqdm pass during filtering.
    lengths = [calculate_char_length(sample) for sample in tqdm(data)]

    length_stats = sorted(lengths)
    n = len(length_stats)
    print(f"\nCharacter length statistics:")
    # The list is sorted, so min/max are the end elements.
    print(f"- Min: {length_stats[0]:,}")
    print(f"- Max: {length_stats[-1]:,}")
    print(f"- Median: {length_stats[n // 2]:,}")
    print(f"- 90th percentile: {length_stats[int(n * 0.9)]:,}")
    print(f"- 95th percentile: {length_stats[int(n * 0.95)]:,}")
    print(f"- 99th percentile: {length_stats[int(n * 0.99)]:,}")

    # Token estimate: chars * 1.5, a conservative heuristic (see module doc).
    print(f"\nEstimated token length (chars * 1.5):")
    max_token = 65536
    # Character count that maps onto the token cap under the 1.5 heuristic.
    char_threshold = int(max_token / 1.5)
    print(f"- Character threshold for {max_token} tokens: {char_threshold:,}")
    exceeding = sum(1 for length in length_stats if length * 1.5 > max_token)
    print(f"- Samples estimated to exceed {max_token} tokens: {exceeding} ({exceeding/n*100:.2f}%)")

    # Keep only samples whose estimated token count fits the cap; reuse the
    # per-sample lengths computed above instead of a second scoring pass.
    print(f"\nFiltering samples with estimated tokens ≤ {max_token}...")
    filtered_data = [
        sample
        for sample, char_length in zip(data, lengths)
        if char_length * 1.5 <= max_token
    ]
    print(f"\nResults:")
    print(f"- Original samples: {len(data)}")
    print(f"- Filtered samples: {len(filtered_data)}")
    print(f"- Removed: {len(data) - len(filtered_data)}")

    # Persist the filtered dataset (non-ASCII kept readable, pretty-printed).
    output_file = "/shared_workspace_mfs/ximing/LLaMA-Factory/data/dolci_last_10k_from_20k_filtered.json"
    print(f"\nSaving to {output_file}...")
    with open(output_file, 'w', encoding='utf-8') as f:
        json.dump(filtered_data, f, ensure_ascii=False, indent=2)

    # Post-filter statistics on function-calling content. .get() guards
    # against turns without a 'from' key (original indexed c['from']).
    has_func = sum(
        1 for s in filtered_data
        if any(c.get('from') == 'function_call' for c in s.get('conversations', []))
    )
    has_tools = sum(1 for s in filtered_data if s.get('tools'))
    print(f"\nFiltered data statistics:")
    print(f"- Total: {len(filtered_data)}")
    print(f"- With function_call: {has_func}")
    print(f"- With tools: {has_tools}")
    print(f"- Without function_call: {len(filtered_data) - has_func}")

    # Report the on-disk size of the output file in MiB.
    import os
    file_size = os.path.getsize(output_file) / (1024 * 1024)
    print(f"- File size: {file_size:.2f} MB")
# Run only when executed as a script, not when imported as a module.
if __name__ == "__main__":
    main()