# analyze_tokens.py
# -*- coding: utf-8 -*-
import os, json, math
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from transformers import AutoTokenizer
# ===================== Configuration (edit here) =====================
DATA_PATH = "/home/data/STUDY.parquet" # supports .parquet / .csv / .jsonl
TOKENIZER_PATH = "/home/rm3.4.1_9e-6" # e.g.: "meta-llama/Meta-Llama-3-8B"
TEXT_COL = "text"       # column with the full text
PROMPT_COL = "prompt"   # column with the prompt portion
RMTEXT_COL = "rm_text"  # column with the reward-model text
OUT_DIR = "./figs" # output directory for figures / summaries
LIMIT = 0 # when > 0, only the first N rows are processed
ADD_SPECIAL_TOKENS = False # whether special tokens are included in the counts
TRUNCATION = False # whether to truncate while counting
MAX_LENGTH = None # truncation length (only used when TRUNCATION=True)
BATCH_SIZE = 1024 # tokenizer batch size
# ===========================================================
def read_table(path: str) -> pd.DataFrame:
    """Load a tabular file into a DataFrame.

    Supported extensions: .parquet/.pq, .csv, and .jsonl/.json
    (the latter is read line-by-line as JSON Lines; blank lines are skipped).

    Raises:
        ValueError: for any other file extension.
    """
    suffix = os.path.splitext(path)[1].lower()
    if suffix in (".parquet", ".pq"):
        return pd.read_parquet(path)
    if suffix == ".csv":
        return pd.read_csv(path)
    if suffix in (".jsonl", ".json"):
        with open(path, "r", encoding="utf-8") as fh:
            records = [json.loads(s) for s in (line.strip() for line in fh) if s]
        return pd.DataFrame(records)
    raise ValueError(f"Unsupported file type: {suffix}")
def to_str(x):
    """Coerce a cell value to str; None and float NaN map to the empty string."""
    is_missing = x is None or (isinstance(x, float) and math.isnan(x))
    return "" if is_missing else str(x)
def batch_token_lengths(texts, tokenizer, add_special_tokens=False,
                        truncation=False, max_length=None, batch_size=1024):
    """Tokenize `texts` in batches and return per-text token counts.

    Non-string entries (None / NaN) are coerced to "" before tokenization,
    matching the module-level `to_str` helper (inlined here).

    Returns:
        np.ndarray of int32 with one token count per input text.
    """
    total = len(texts)
    counts = np.zeros(total, dtype=np.int32)
    for start in range(0, total, batch_size):
        cleaned = []
        for item in texts[start:start + batch_size]:
            if item is None or (isinstance(item, float) and math.isnan(item)):
                cleaned.append("")
            else:
                cleaned.append(str(item))
        enc = tokenizer(
            cleaned,
            add_special_tokens=add_special_tokens,
            truncation=truncation,
            max_length=max_length,
        )
        ids = enc["input_ids"]
        if isinstance(ids, list):
            # Fast/slow tokenizers return ragged Python lists: one length each.
            counts[start:start + len(ids)] = [len(seq) for seq in ids]
        else:
            # Tensor output is rectangular: every row shares the padded length.
            counts[start:start + batch_size] = ids.shape[1]
    return counts
def summarize(name, arr):
    """Print count/min/max/mean/median/std of token lengths, tagged by `name`."""
    values = np.asarray(arr, dtype=np.int64)
    if values.size == 0:
        print(f"[{name}] empty")
        return
    line = (
        f"[{name}] count={values.size} min={values.min()} max={values.max()} "
        f"mean={values.mean():.2f} median={np.median(values):.2f} std={values.std():.2f}"
    )
    print(line)
def save_hist(data, title, out_path, bins=60):
    """Render a histogram of token counts and write it to `out_path` as a 200-dpi PNG."""
    fig, ax = plt.subplots()
    ax.hist(data, bins=bins)
    ax.set_title(title)
    ax.set_xlabel("Token count")
    ax.set_ylabel("Frequency")
    fig.tight_layout()
    fig.savefig(out_path, dpi=200)
    plt.close(fig)
    print(f"[saved] {out_path}")
def main():
    """Tokenize the text/prompt/rm_text columns of the configured table, print
    length statistics, and save histogram + comparison-scatter PNGs to OUT_DIR."""
    os.makedirs(OUT_DIR, exist_ok=True)
    print(f"[info] loading data: {DATA_PATH}")
    df = read_table(DATA_PATH)
    # Drop pandas index artifacts and similar non-business columns.
    drop_cols = [c for c in df.columns if str(c).strip() in {"__index_level_0__", "index", "[__index_level_0__]"}]
    if drop_cols:
        df = df.drop(columns=drop_cols)
    # Fail fast if any required column is missing.
    for col in [TEXT_COL, PROMPT_COL, RMTEXT_COL]:
        if col not in df.columns:
            raise KeyError(f"Column '{col}' not found! Available: {list(df.columns)[:30]} ...")
    if LIMIT and LIMIT > 0:
        df = df.head(LIMIT).copy()
        print(f"[info] subsampled to first {len(df)} rows")
    print(f"[info] loading tokenizer: {TOKENIZER_PATH}")
    tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_PATH, use_fast=True)
    print("[info] tokenizing ...")
    text_lens = batch_token_lengths(df[TEXT_COL].tolist(), tokenizer, ADD_SPECIAL_TOKENS, TRUNCATION, MAX_LENGTH, BATCH_SIZE)
    prompt_lens = batch_token_lengths(df[PROMPT_COL].tolist(), tokenizer, ADD_SPECIAL_TOKENS, TRUNCATION, MAX_LENGTH, BATCH_SIZE)
    rmtext_lens = batch_token_lengths(df[RMTEXT_COL].tolist(), tokenizer, ADD_SPECIAL_TOKENS, TRUNCATION, MAX_LENGTH, BATCH_SIZE)
    # Summary statistics per column.
    summarize("text", text_lens)
    summarize("prompt", prompt_lens)
    summarize("rm_text", rmtext_lens)
    # Per-column histograms (PNG).
    save_hist(text_lens, "Text token count", os.path.join(OUT_DIR, "hist_text.png"))
    save_hist(prompt_lens, "Prompt token count", os.path.join(OUT_DIR, "hist_prompt.png"))
    save_hist(rmtext_lens, "RM_Text token count", os.path.join(OUT_DIR, "hist_rm_text.png"))
    # Comparison scatter: prompt vs text and rm_text vs text on one figure,
    # with a y=x reference line. Only rows where all three columns are
    # non-empty strings are plotted.
    mask = np.ones(len(df), dtype=bool)
    for col in [TEXT_COL, PROMPT_COL, RMTEXT_COL]:
        mask &= df[col].map(lambda x: isinstance(x, str) and len(x) > 0).values
    x1, y1 = prompt_lens[mask], text_lens[mask]
    x2, y2 = rmtext_lens[mask], text_lens[mask]
    plt.figure()
    plt.scatter(x1, y1, s=10, alpha=0.4, label="prompt vs text")
    plt.scatter(x2, y2, s=10, alpha=0.4, label="rm_text vs text")
    # y = x reference line spanning the joint min/max of all four series
    # (each guarded against an empty selection).
    mn = int(min(x1.min() if len(x1) else 0, x2.min() if len(x2) else 0, y1.min() if len(y1) else 0, y2.min() if len(y2) else 0))
    mx = int(max(x1.max() if len(x1) else 0, x2.max() if len(x2) else 0, y1.max() if len(y1) else 0, y2.max() if len(y2) else 0))
    plt.plot([mn, mx], [mn, mx])
    plt.title("Token count comparison")
    plt.xlabel("X tokens (prompt / rm_text)")
    plt.ylabel("Text tokens (Y)")
    plt.legend()
    plt.tight_layout()
    scatter_path = os.path.join(OUT_DIR, "scatter_compare.png")
    plt.savefig(scatter_path, dpi=200)
    plt.close()
    print(f"[saved] {scatter_path}")
    # TODO: a summary-table export was planned here but is not implemented.


if __name__ == "__main__":
    main()