# llm/filter_by_length.py
# (Uploaded via huggingface_hub, commit db704cb)
#!/usr/bin/env python3
"""
过滤出序列长度小于 65536 的样本
"""
import json
from transformers import AutoTokenizer
from tqdm import tqdm
def calculate_length(sample, tokenizer):
    """Return the total token count of one sharegpt-style sample.

    Sums the tokenized lengths of the optional ``system`` prompt, the
    optional ``tools`` description, and the ``value`` text of every turn
    in ``conversations``.

    Args:
        sample: Dict with optional ``'system'`` / ``'tools'`` string fields
            and an optional ``'conversations'`` list of dicts, each carrying
            a ``'value'`` string.
        tokenizer: Any object exposing ``encode(text) -> sequence`` (e.g. a
            HuggingFace tokenizer).

    Returns:
        int: Total number of tokens across all text fields.
    """
    total_length = 0
    # Optional fields: only counted when present and non-empty.
    if sample.get('system'):
        total_length += len(tokenizer.encode(sample['system']))
    if sample.get('tools'):
        total_length += len(tokenizer.encode(sample['tools']))
    # Each conversation turn contributes its 'value' text; tolerate a
    # missing 'conversations' key instead of raising KeyError.
    for conv in sample.get('conversations', []):
        total_length += len(tokenizer.encode(conv.get('value', '')))
    return total_length
def main():
    """Filter a sharegpt-style JSON dataset by token length.

    Loads the dataset, measures every sample with the Qwen3-8B-Base
    tokenizer, keeps samples whose total length is <= 65536, prints
    length-distribution statistics, and writes the filtered dataset
    to a new JSON file.
    """
    print("Loading tokenizer...")
    tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-8B-Base", trust_remote_code=True)

    print("Loading data...")
    input_file = "/shared_workspace_mfs/ximing/LLaMA-Factory/data/dolci_last_10k_from_20k.json"
    with open(input_file, 'r', encoding='utf-8') as f:
        data = json.load(f)
    print(f"Total samples: {len(data)}")

    print("Calculating lengths and filtering...")
    filtered_data = []
    length_stats = []
    max_length = 65536
    for sample in tqdm(data):
        length = calculate_length(sample, tokenizer)
        length_stats.append(length)
        if length <= max_length:
            filtered_data.append(sample)

    print(f"\nResults:")
    print(f"- Original samples: {len(data)}")
    print(f"- Filtered samples (≤{max_length}): {len(filtered_data)}")
    print(f"- Removed: {len(data) - len(filtered_data)}")

    # Length-distribution statistics. Guarded so an empty dataset does not
    # crash min()/max(), the percentile indexing, or the percentage division.
    if length_stats:
        length_stats.sort()
        n = len(length_stats)
        print(f"\nLength statistics:")
        print(f"- Min: {length_stats[0]}")
        print(f"- Max: {length_stats[-1]}")
        print(f"- Median: {length_stats[n//2]}")
        print(f"- 90th percentile: {length_stats[int(n*0.9)]}")
        print(f"- 95th percentile: {length_stats[int(n*0.95)]}")
        print(f"- 99th percentile: {length_stats[int(n*0.99)]}")

        # How many samples exceed each candidate cutoff.
        print(f"\nSamples exceeding thresholds:")
        for threshold in [32768, 49152, 65536, 80000]:
            count = sum(1 for length in length_stats if length > threshold)
            print(f"- > {threshold}: {count} ({count/n*100:.2f}%)")

    # Persist the filtered dataset.
    output_file = "/shared_workspace_mfs/ximing/LLaMA-Factory/data/dolci_last_10k_from_20k_filtered.json"
    print(f"\nSaving filtered data to {output_file}...")
    with open(output_file, 'w', encoding='utf-8') as f:
        json.dump(filtered_data, f, ensure_ascii=False, indent=2)
    print(f"Done! Saved {len(filtered_data)} samples")

    # Count how many kept samples contain a function-calling turn; tolerate
    # turns without a 'from' key instead of raising KeyError.
    has_func = sum(
        1 for s in filtered_data
        if any(c.get('from') == 'function_call' for c in s.get('conversations', []))
    )
    print(f"\nFiltered data statistics:")
    print(f"- With function_call: {has_func}")
    print(f"- Without function_call: {len(filtered_data) - has_func}")


if __name__ == "__main__":
    main()