File size: 5,665 Bytes
d8a76be
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
# analyze_tokens.py
# -*- coding: utf-8 -*-

import os, json, math
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from transformers import AutoTokenizer

# ===================== Configuration (edit here) =====================
DATA_PATH = "/home/data/STUDY.parquet"   # supports .parquet / .csv / .jsonl
TOKENIZER_PATH = "/home/rm3.4.1_9e-6"  # e.g. "meta-llama/Meta-Llama-3-8B"

# Column names to analyze (must exist in the loaded table).
TEXT_COL   = "text"
PROMPT_COL = "prompt"
RMTEXT_COL = "rm_text"

OUT_DIR = "./figs"          # output directory for figures / summary files
LIMIT = 0                   # if >0, only the first N rows are processed
ADD_SPECIAL_TOKENS = False  # whether to include special tokens in the counts
TRUNCATION = False          # whether to truncate while counting
MAX_LENGTH = None           # truncation length (only used when TRUNCATION=True)
BATCH_SIZE = 1024           # tokenizer batch size
# ===========================================================


def read_table(path: str) -> pd.DataFrame:
    """Load a table from a .parquet/.pq, .csv, or .jsonl/.json file.

    JSON files are read line-by-line (JSON Lines); blank lines are skipped.
    Raises ValueError for any other extension.
    """
    suffix = os.path.splitext(path)[1].lower()
    if suffix in (".parquet", ".pq"):
        return pd.read_parquet(path)
    if suffix == ".csv":
        return pd.read_csv(path)
    if suffix in (".jsonl", ".json"):
        with open(path, "r", encoding="utf-8") as fh:
            # One JSON object per non-empty line.
            records = [json.loads(s) for s in map(str.strip, fh) if s]
        return pd.DataFrame(records)
    raise ValueError(f"Unsupported file type: {suffix}")


def to_str(value):
    """Coerce *value* to str; None and float NaN map to the empty string."""
    if value is None or (isinstance(value, float) and math.isnan(value)):
        return ""
    return str(value)


def batch_token_lengths(texts, tokenizer, add_special_tokens=False,
                        truncation=False, max_length=None, batch_size=1024):
    """Return an int32 array holding the token count of each entry in *texts*.

    None and float NaN entries are treated as empty strings; everything else
    is stringified before tokenization. Texts are fed to the tokenizer in
    chunks of *batch_size*.
    """
    total = len(texts)
    lengths = np.zeros(total, dtype=np.int32)
    for start in range(0, total, batch_size):
        stop = start + batch_size
        # Normalize each raw value to a string (None / NaN -> "").
        chunk = []
        for raw in texts[start:stop]:
            if raw is None or (isinstance(raw, float) and math.isnan(raw)):
                chunk.append("")
            else:
                chunk.append(str(raw))
        encoded = tokenizer(
            chunk,
            add_special_tokens=add_special_tokens,
            truncation=truncation,
            max_length=max_length,
        )
        input_ids = encoded["input_ids"]
        if isinstance(input_ids, list):
            lengths[start:stop] = [len(seq) for seq in input_ids]
        else:
            # Tensor-shaped output: every row shares the (padded) length.
            lengths[start:stop] = input_ids.shape[1]
    return lengths


def summarize(name, arr):
    """Print count/min/max/mean/median/std of *arr*, labeled with *name*."""
    values = np.asarray(arr, dtype=np.int64)
    if values.size == 0:
        print(f"[{name}] empty")
        return
    line = (
        f"[{name}] count={values.size}  min={values.min()}  max={values.max()}  "
        f"mean={values.mean():.2f}  median={np.median(values):.2f}  std={values.std():.2f}"
    )
    print(line)


def save_hist(data, title, out_path, bins=60):
    """Render a histogram of *data* and write it to *out_path* (PNG, 200 dpi)."""
    plt.figure()
    plt.hist(data, bins=bins)
    # Axis labels are fixed: these plots always show token-count frequencies.
    plt.xlabel("Token count")
    plt.ylabel("Frequency")
    plt.title(title)
    plt.tight_layout()
    plt.savefig(out_path, dpi=200)
    plt.close()  # release the figure so repeated calls don't accumulate state
    print(f"[saved] {out_path}")


def main():
    """Load the table, count tokens per column, and write summary figures."""
    os.makedirs(OUT_DIR, exist_ok=True)

    print(f"[info] loading data: {DATA_PATH}")
    df = read_table(DATA_PATH)

    # Drop pandas index artifacts and similar non-business columns.
    junk_names = {"__index_level_0__", "index", "[__index_level_0__]"}
    junk_cols = [c for c in df.columns if str(c).strip() in junk_names]
    if junk_cols:
        df = df.drop(columns=junk_cols)

    # Fail fast if any required column is missing.
    for col in (TEXT_COL, PROMPT_COL, RMTEXT_COL):
        if col not in df.columns:
            raise KeyError(f"Column '{col}' not found! Available: {list(df.columns)[:30]} ...")

    if LIMIT and LIMIT > 0:
        df = df.head(LIMIT).copy()
        print(f"[info] subsampled to first {len(df)} rows")

    print(f"[info] loading tokenizer: {TOKENIZER_PATH}")
    tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_PATH, use_fast=True)

    print("[info] tokenizing ...")
    lengths = {
        col: batch_token_lengths(df[col].tolist(), tokenizer, ADD_SPECIAL_TOKENS,
                                 TRUNCATION, MAX_LENGTH, BATCH_SIZE)
        for col in (TEXT_COL, PROMPT_COL, RMTEXT_COL)
    }
    text_lens = lengths[TEXT_COL]
    prompt_lens = lengths[PROMPT_COL]
    rmtext_lens = lengths[RMTEXT_COL]

    # Summary statistics.
    summarize("text", text_lens)
    summarize("prompt", prompt_lens)
    summarize("rm_text", rmtext_lens)

    # Save histograms (PNG), one per column.
    for arr, title, fname in (
        (text_lens, "Text token count", "hist_text.png"),
        (prompt_lens, "Prompt token count", "hist_prompt.png"),
        (rmtext_lens, "RM_Text token count", "hist_rm_text.png"),
    ):
        save_hist(arr, title, os.path.join(OUT_DIR, fname))

    # Comparison scatter: prompt vs text and rm_text vs text on one figure,
    # restricted to rows where all three columns are non-empty strings.
    valid = np.ones(len(df), dtype=bool)
    for col in (TEXT_COL, PROMPT_COL, RMTEXT_COL):
        valid &= df[col].map(lambda v: isinstance(v, str) and len(v) > 0).values

    x1, y1 = prompt_lens[valid], text_lens[valid]
    x2, y2 = rmtext_lens[valid], text_lens[valid]

    plt.figure()
    plt.scatter(x1, y1, s=10, alpha=0.4, label="prompt vs text")
    plt.scatter(x2, y2, s=10, alpha=0.4, label="rm_text vs text")
    # y = x reference line spanning the combined data range (0 for empty arrays).
    lo = int(min(arr.min() if len(arr) else 0 for arr in (x1, x2, y1, y2)))
    hi = int(max(arr.max() if len(arr) else 0 for arr in (x1, x2, y1, y2)))
    plt.plot([lo, hi], [lo, hi])
    plt.title("Token count comparison")
    plt.xlabel("X tokens (prompt / rm_text)")
    plt.ylabel("Text tokens (Y)")
    plt.legend()
    plt.tight_layout()
    scatter_path = os.path.join(OUT_DIR, "scatter_compare.png")
    plt.savefig(scatter_path, dpi=200)
    plt.close()
    print(f"[saved] {scatter_path}")

    # NOTE(review): a summary table was planned here but never implemented.

if __name__ == "__main__":
    main()