File size: 3,447 Bytes
db704cb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
#!/usr/bin/env python3
"""
使用字符长度估算样本长度(粗略估计)
字符数 * 1.5 ≈ token 数(保守估计)
"""

import json
from tqdm import tqdm

def _text_length(value):
    """Return the character length of one text-bearing field of a sample.

    Falsy values (missing key, ``None``, empty string/list) count as 0.
    Strings are measured directly. Any other value (e.g. ``tools`` stored as a
    list of dicts rather than a JSON string) is measured by its JSON
    serialization, so the result is still a character count rather than an
    item count.
    """
    if not value:
        return 0
    if isinstance(value, str):
        return len(value)
    return len(json.dumps(value, ensure_ascii=False))


def calculate_char_length(sample):
    """Return the total character length of a sample.

    Sums the lengths of the optional ``system`` and ``tools`` fields plus the
    ``value`` of every turn in ``sample['conversations']``.

    Args:
        sample: A dict with a ``conversations`` list of turn dicts; ``system``
            and ``tools`` are optional.

    Returns:
        The summed character count as an int.

    Raises:
        KeyError: If ``sample`` has no ``conversations`` key.
    """
    total_length = _text_length(sample.get('system'))
    total_length += _text_length(sample.get('tools'))

    # Turns without a 'value' (or with value None) contribute nothing.
    total_length += sum(
        _text_length(conv.get('value')) for conv in sample['conversations']
    )
    return total_length

def main(
    input_file="/shared_workspace_mfs/ximing/LLaMA-Factory/data/dolci_last_10k_from_20k.json",
    output_file="/shared_workspace_mfs/ximing/LLaMA-Factory/data/dolci_last_10k_from_20k_filtered.json",
    max_token=65536,
    tokens_per_char=1.5,
):
    """Drop samples whose estimated token count exceeds ``max_token``.

    Loads a JSON list of samples from ``input_file`` and prints character-length
    statistics. It then estimates each sample's tokens as
    ``chars * tokens_per_char`` and keeps the samples at or under
    ``max_token``. The kept samples are written to ``output_file`` (UTF-8,
    ``indent=2``), followed by summary statistics for them.

    Args:
        input_file: Path of the input JSON file (a list of sample dicts).
        output_file: Path the filtered JSON list is written to.
        max_token: Maximum estimated token count a sample may have.
        tokens_per_char: Multiplier converting a character count into an
            estimated token count.
    """
    import os

    print("Loading data...")
    with open(input_file, 'r', encoding='utf-8') as f:
        data = json.load(f)

    print(f"Total samples: {len(data)}")
    if not data:
        # min()/percentiles and the percentage below are undefined for no samples.
        print("No samples to process; nothing written.")
        return

    print("Calculating character lengths...")
    # Compute each sample's length once; it is reused for the stats and the filter.
    char_lengths = [calculate_char_length(sample) for sample in tqdm(data)]
    sorted_lengths = sorted(char_lengths)
    n = len(sorted_lengths)

    print("\nCharacter length statistics:")
    print(f"- Min: {sorted_lengths[0]:,}")
    print(f"- Max: {sorted_lengths[-1]:,}")
    print(f"- Median: {sorted_lengths[n // 2]:,}")
    print(f"- 90th percentile: {sorted_lengths[int(n * 0.9)]:,}")
    print(f"- 95th percentile: {sorted_lengths[int(n * 0.95)]:,}")
    print(f"- 99th percentile: {sorted_lengths[int(n * 0.99)]:,}")

    print(f"\nEstimated token length (chars * {tokens_per_char}):")
    char_threshold = int(max_token / tokens_per_char)
    print(f"- Character threshold for {max_token} tokens: {char_threshold:,}")

    filtered_data = [
        sample
        for sample, length in zip(data, char_lengths)
        if length * tokens_per_char <= max_token
    ]
    exceeding = n - len(filtered_data)
    print(f"- Samples estimated to exceed {max_token} tokens: {exceeding} ({exceeding / n * 100:.2f}%)")

    print(f"\nFiltering samples with estimated tokens ≤ {max_token}...")
    print("\nResults:")
    print(f"- Original samples: {len(data)}")
    print(f"- Filtered samples: {len(filtered_data)}")
    print(f"- Removed: {exceeding}")

    print(f"\nSaving to {output_file}...")
    with open(output_file, 'w', encoding='utf-8') as f:
        json.dump(filtered_data, f, ensure_ascii=False, indent=2)

    # Turns lacking a 'from' key are treated as non-function_call rather than crashing.
    has_func = sum(
        1
        for s in filtered_data
        if any(c.get('from') == 'function_call' for c in s['conversations'])
    )
    has_tools = sum(1 for s in filtered_data if s.get('tools'))

    print("\nFiltered data statistics:")
    print(f"- Total: {len(filtered_data)}")
    print(f"- With function_call: {has_func}")
    print(f"- With tools: {has_tools}")
    print(f"- Without function_call: {len(filtered_data) - has_func}")

    file_size = os.path.getsize(output_file) / (1024 * 1024)
    print(f"- File size: {file_size:.2f} MB")


if __name__ == "__main__":
    main()