File size: 2,992 Bytes
1c980b1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
import json
import tiktoken
from tqdm import tqdm
from multiprocessing import Pool
import pandas as pd

# Global encoder initialization (each worker process initializes its own copy).
def init_process():
    """Pool initializer: build a process-local cl100k_base tiktoken encoder.

    Runs once per worker process (see Pool(initializer=init_process) in
    __main__); the encoder is stored in a module-level global so that
    calculate_tokens() can reuse it without re-building per call.
    """
    global encoder
    encoder = tiktoken.get_encoding("cl100k_base")

def calculate_tokens(obj):
    """Count tokens in one batch-request object (called inside a worker process).

    Walks ``obj["body"]["messages"]`` and collects text from:
      - system messages: plain-string ``content``;
      - user messages: ``content`` as a plain string (generalized — the
        previous version silently dropped this valid shape), as a list of
        parts where ``{"type": "text"}`` items contribute their ``text``,
        or as a single ``{"type": "text"}`` part dict.

    The collected pieces are joined with newlines and encoded with the
    process-global cl100k_base encoder set up by init_process().

    Args:
        obj: one parsed JSONL request record (dict).

    Returns:
        int: token count of the joined text; 0 on any error (the error and
        the record's custom_id are printed for diagnosis).
    """
    global encoder
    total_text = []

    try:
        messages = obj.get("body", {}).get("messages", [])
        for msg in messages:
            role = msg.get("role")

            # System prompt: content is a plain string.
            if role == "system":
                content = msg.get("content", "")
                if content:  # skip empty content
                    total_text.append(content)

            # User message: content may be a string, a list of parts,
            # or a single part dict.
            elif role == "user":
                content = msg.get("content", [])
                if isinstance(content, str):
                    # Plain-string user content — previously ignored,
                    # which under-counted tokens for such requests.
                    if content:
                        total_text.append(content)
                elif isinstance(content, list):
                    for item in content:
                        if isinstance(item, dict) and item.get("type") == "text":
                            text = item.get("text", "")
                            if text:
                                total_text.append(text)
                elif isinstance(content, dict) and content.get("type") == "text":
                    text = content.get("text", "")
                    if text:
                        total_text.append(text)

        # Join the text pieces and count tokens in a single encode pass.
        return len(encoder.encode("\n".join(total_text)))

    except Exception as e:
        print(f"处理错误: {e} | 数据: {obj.get('custom_id')}")
        return 0

def process_line(line):
    """Parse one JSONL line and return its id plus token count.

    Args:
        line: one raw line from the requests file.

    Returns:
        dict with keys "custom_id" and "tokens" on success, or None when
        the line is not valid JSON / any other error occurs (a diagnostic
        is printed so failures can be located and filtered out upstream).
    """
    try:
        record = json.loads(line)
        token_count = calculate_tokens(record)
        return {"custom_id": record.get("custom_id"), "tokens": token_count}
    except json.JSONDecodeError:
        print(f"无效JSON: {line[:100]}...")  # 打印前100字符辅助定位
        return None
    except Exception as e:
        print(f"全局错误: {e}")
        return None

if __name__ == "__main__":
    # 读取数据
    with open("/mnt/data/users/zys/proj/vlm_reasoning/request/vqa_batch_requests.jsonl", "r") as f:
        lines = f.readlines()

    # 并行处理
    with Pool(processes=8, initializer=init_process) as pool:
        results = []
        with tqdm(total=len(lines), desc="处理进度") as pbar:
            for result in pool.imap(process_line, lines):
                if result is not None:  # 过滤失败记录
                    results.append(result)
                pbar.update()
    
    # 保存结果
    df = pd.DataFrame(results)
    df.to_csv("token_results.csv", index=False)
    
    # 统计输出
    total_tokens = df["tokens"].sum()
    avg_tokens = df["tokens"].mean()
    print(f"统计报告:\n"
          f"- 总Token数: {total_tokens:,}\n"
          f"- 平均每条: {avg_tokens:.1f}\n"
          f"- 最大单条: {df['tokens'].max()}\n"
          f"- 有效数据: {len(df)}/{len(lines)}")