# analyze_tokens.py
# -*- coding: utf-8 -*-
"""Token-length statistics for a dataset's text columns.

Loads a table (.parquet / .csv / .jsonl), tokenizes three text columns with
a HuggingFace tokenizer, prints summary statistics, and saves histograms
plus a prompt/rm_text-vs-text comparison scatter plot to OUT_DIR.
"""
import os
import json
import math

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from transformers import AutoTokenizer

# ===================== Configuration (edit here) =====================
DATA_PATH = "/home/data/STUDY.parquet"  # supports .parquet / .csv / .jsonl
TOKENIZER_PATH = "/home/rm3.4.1_9e-6"   # e.g. "meta-llama/Meta-Llama-3-8B"
TEXT_COL = "text"
PROMPT_COL = "prompt"
RMTEXT_COL = "rm_text"
OUT_DIR = "./figs"          # output directory for figures / summaries
LIMIT = 0                   # if > 0, only keep the first N rows
ADD_SPECIAL_TOKENS = False  # include special tokens in the counts?
TRUNCATION = False          # truncate while counting?
MAX_LENGTH = None           # truncation length (only used when TRUNCATION=True)
BATCH_SIZE = 1024           # tokenizer batch size
# =====================================================================


def read_table(path: str) -> pd.DataFrame:
    """Load a table from .parquet/.pq, .csv, or .jsonl/.json (one JSON object per line).

    Raises:
        ValueError: if the file extension is not one of the supported types.
    """
    ext = os.path.splitext(path)[1].lower()
    if ext in (".parquet", ".pq"):
        return pd.read_parquet(path)
    if ext == ".csv":
        return pd.read_csv(path)
    if ext in (".jsonl", ".json"):
        # JSON Lines: parse each non-empty line as a standalone object.
        rows = []
        with open(path, "r", encoding="utf-8") as f:
            for line in f:
                s = line.strip()
                if s:
                    rows.append(json.loads(s))
        return pd.DataFrame(rows)
    raise ValueError(f"Unsupported file type: {ext}")


def to_str(x) -> str:
    """Coerce a cell value to str; None and float NaN become the empty string."""
    if x is None:
        return ""
    if isinstance(x, float) and math.isnan(x):
        return ""
    return str(x)


def batch_token_lengths(texts, tokenizer, add_special_tokens=False,
                        truncation=False, max_length=None, batch_size=1024):
    """Tokenize `texts` in batches and return an int32 array of token counts.

    Args:
        texts: sequence of (possibly non-string / NaN) values; each is
            coerced with `to_str` before tokenization.
        tokenizer: a HuggingFace tokenizer callable.
        add_special_tokens: whether special tokens count toward the length.
        truncation: whether to truncate; `max_length` applies only if True.
        max_length: maximum length used when `truncation` is True.
        batch_size: number of texts tokenized per call.

    Returns:
        np.ndarray of shape (len(texts),), dtype int32, one length per text.
    """
    n = len(texts)
    lens = np.zeros(n, dtype=np.int32)
    for i in range(0, n, batch_size):
        batch = [to_str(t) for t in texts[i:i + batch_size]]
        enc = tokenizer(
            batch,
            add_special_tokens=add_special_tokens,
            truncation=truncation,
            max_length=max_length,
        )
        ids = enc["input_ids"]
        if isinstance(ids, list):
            lens[i:i + batch_size] = [len(x) for x in ids]
        else:
            # Tensor output implies padding to a common length, so every
            # row of this batch is assigned that (padded) length.
            lens[i:i + batch_size] = ids.shape[1]
    return lens


def summarize(name, arr):
    """Print count/min/max/mean/median/std for an array of token counts."""
    arr = np.asarray(arr, dtype=np.int64)
    if arr.size == 0:
        print(f"[{name}] empty")
        return
    # NOTE: the original source had a raw line break inside this f-string
    # (a SyntaxError); reconstructed as the intended single-line message.
    print(
        f"[{name}] count={arr.size} min={arr.min()} max={arr.max()} "
        f"mean={arr.mean():.2f} median={np.median(arr):.2f} "
        f"std={arr.std():.2f}"
    )


def save_hist(data, title, out_path, bins=60):
    """Save a histogram of `data` to `out_path` as a PNG."""
    plt.figure()
    plt.hist(data, bins=bins)
    plt.title(title)
    plt.xlabel("Token count")
    plt.ylabel("Frequency")
    plt.tight_layout()
    plt.savefig(out_path, dpi=200)
    plt.close()
    print(f"[saved] {out_path}")


def main():
    os.makedirs(OUT_DIR, exist_ok=True)

    print(f"[info] loading data: {DATA_PATH}")
    df = read_table(DATA_PATH)

    # Drop non-business columns such as leaked pandas index columns.
    drop_cols = [c for c in df.columns
                 if str(c).strip() in {"__index_level_0__", "index", "[__index_level_0__]"}]
    if drop_cols:
        df = df.drop(columns=drop_cols)

    for col in [TEXT_COL, PROMPT_COL, RMTEXT_COL]:
        if col not in df.columns:
            raise KeyError(f"Column '{col}' not found! Available: {list(df.columns)[:30]} ...")

    if LIMIT and LIMIT > 0:
        df = df.head(LIMIT).copy()
        print(f"[info] subsampled to first {len(df)} rows")

    print(f"[info] loading tokenizer: {TOKENIZER_PATH}")
    tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_PATH, use_fast=True)

    print("[info] tokenizing ...")
    text_lens = batch_token_lengths(df[TEXT_COL].tolist(), tokenizer,
                                    ADD_SPECIAL_TOKENS, TRUNCATION, MAX_LENGTH, BATCH_SIZE)
    prompt_lens = batch_token_lengths(df[PROMPT_COL].tolist(), tokenizer,
                                      ADD_SPECIAL_TOKENS, TRUNCATION, MAX_LENGTH, BATCH_SIZE)
    rmtext_lens = batch_token_lengths(df[RMTEXT_COL].tolist(), tokenizer,
                                      ADD_SPECIAL_TOKENS, TRUNCATION, MAX_LENGTH, BATCH_SIZE)

    # Summary statistics
    summarize("text", text_lens)
    summarize("prompt", prompt_lens)
    summarize("rm_text", rmtext_lens)

    # Save histograms (PNG)
    save_hist(text_lens, "Text token count", os.path.join(OUT_DIR, "hist_text.png"))
    save_hist(prompt_lens, "Prompt token count", os.path.join(OUT_DIR, "hist_prompt.png"))
    save_hist(rmtext_lens, "RM_Text token count", os.path.join(OUT_DIR, "hist_rm_text.png"))

    # Comparison scatter: prompt vs text and rm_text vs text in one figure,
    # with a y = x reference line. Only rows where all three columns are
    # non-empty strings are plotted.
    mask = np.ones(len(df), dtype=bool)
    for col in [TEXT_COL, PROMPT_COL, RMTEXT_COL]:
        mask &= df[col].map(lambda x: isinstance(x, str) and len(x) > 0).values

    x1, y1 = prompt_lens[mask], text_lens[mask]
    x2, y2 = rmtext_lens[mask], text_lens[mask]

    plt.figure()
    plt.scatter(x1, y1, s=10, alpha=0.4, label="prompt vs text")
    plt.scatter(x2, y2, s=10, alpha=0.4, label="rm_text vs text")

    # y = x reference line spanning the full data range (0 if a series is empty).
    mn = int(min(x1.min() if len(x1) else 0, x2.min() if len(x2) else 0,
                 y1.min() if len(y1) else 0, y2.min() if len(y2) else 0))
    mx = int(max(x1.max() if len(x1) else 0, x2.max() if len(x2) else 0,
                 y1.max() if len(y1) else 0, y2.max() if len(y2) else 0))
    plt.plot([mn, mx], [mn, mx])

    plt.title("Token count comparison")
    plt.xlabel("X tokens (prompt / rm_text)")
    plt.ylabel("Text tokens (Y)")
    plt.legend()
    plt.tight_layout()
    scatter_path = os.path.join(OUT_DIR, "scatter_compare.png")
    plt.savefig(scatter_path, dpi=200)
    plt.close()
    print(f"[saved] {scatter_path}")


if __name__ == "__main__":
    main()