# rm_code/study_token.py
# (Hugging Face upload header: "Upload folder using huggingface_hub", revision d8a76be)
# analyze_tokens.py
# -*- coding: utf-8 -*-
import os, json, math
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from transformers import AutoTokenizer
# ===================== Configuration (edit here) =====================
DATA_PATH = "/home/data/STUDY.parquet" # supports .parquet / .csv / .jsonl
TOKENIZER_PATH = "/home/rm3.4.1_9e-6" # e.g. "meta-llama/Meta-Llama-3-8B"
TEXT_COL = "text"
PROMPT_COL = "prompt"
RMTEXT_COL = "rm_text"
OUT_DIR = "./figs" # output directory for figures/summaries
LIMIT = 0 # when > 0, only the first N rows are processed
ADD_SPECIAL_TOKENS = False # include special tokens when counting
TRUNCATION = False # truncate sequences while counting
MAX_LENGTH = None # truncation length (only used when TRUNCATION=True)
BATCH_SIZE = 1024 # tokenizer batch size
# =====================================================================
def read_table(path: str) -> pd.DataFrame:
    """Load a tabular dataset into a DataFrame.

    Dispatches on the file extension: .parquet/.pq, .csv, or .jsonl/.json
    (treated as one JSON object per non-blank line).

    Raises:
        ValueError: for any other extension.
    """
    suffix = os.path.splitext(path)[1].lower()
    if suffix in (".parquet", ".pq"):
        return pd.read_parquet(path)
    if suffix == ".csv":
        return pd.read_csv(path)
    if suffix in (".jsonl", ".json"):
        with open(path, "r", encoding="utf-8") as fh:
            records = [json.loads(stripped)
                       for stripped in (line.strip() for line in fh)
                       if stripped]
        return pd.DataFrame(records)
    raise ValueError(f"Unsupported file type: {suffix}")
def to_str(x):
    """Coerce a cell value to str, mapping None and float NaN to ""."""
    is_missing = x is None or (isinstance(x, float) and math.isnan(x))
    return "" if is_missing else str(x)
def batch_token_lengths(texts, tokenizer, add_special_tokens=False,
                        truncation=False, max_length=None, batch_size=1024):
    """Tokenize *texts* in batches and return per-row token counts.

    Each element is coerced via to_str (None/NaN become ""), then the
    tokenizer is called on one batch at a time and the length of each
    row's input_ids is recorded.

    Returns:
        np.ndarray of shape (len(texts),), dtype int32.

    NOTE(review): when the tokenizer returns a tensor instead of lists,
    every row in the batch gets the shared second dimension — with padding
    that would be the padded width, not the true length. Verify the
    tokenizer is called without padding/return_tensors.
    """
    total = len(texts)
    counts = np.zeros(total, dtype=np.int32)
    for start in range(0, total, batch_size):
        stop = start + batch_size
        chunk = [to_str(item) for item in texts[start:stop]]
        encoding = tokenizer(
            chunk,
            add_special_tokens=add_special_tokens,
            truncation=truncation,
            max_length=max_length,
        )
        input_ids = encoding["input_ids"]
        if isinstance(input_ids, list):
            counts[start:stop] = [len(row) for row in input_ids]
        else:
            counts[start:stop] = input_ids.shape[1]
    return counts
def summarize(name, arr):
    """Print count/min/max/mean/median/std for an array of token lengths.

    Prints "[name] empty" and returns early when the array has no elements.
    """
    values = np.asarray(arr, dtype=np.int64)
    if values.size == 0:
        print(f"[{name}] empty")
        return
    stats_line = (
        f"[{name}] count={values.size} min={values.min()} max={values.max()} "
        f"mean={values.mean():.2f} median={np.median(values):.2f} std={values.std():.2f}"
    )
    print(stats_line)
def save_hist(data, title, out_path, bins=60):
    """Plot a histogram of token counts and save it as a PNG.

    Args:
        data: sequence/array of token counts.
        title: figure title.
        out_path: destination file path (saved at dpi=200).
        bins: number of histogram bins (default 60).
    """
    plt.figure()
    plt.hist(data, bins=bins)
    plt.title(title)
    plt.xlabel("Token count")
    plt.ylabel("Frequency")
    plt.tight_layout()
    plt.savefig(out_path, dpi=200)
    # Close the figure to free memory when called repeatedly.
    plt.close()
    print(f"[saved] {out_path}")
def main():
    """Entry point: load the dataset, tokenize the text/prompt/rm_text
    columns, print summary statistics, and save histogram and scatter
    figures under OUT_DIR."""
    os.makedirs(OUT_DIR, exist_ok=True)
    print(f"[info] loading data: {DATA_PATH}")
    df = read_table(DATA_PATH)
    # Drop pandas index-artifact columns that are not real data columns.
    drop_cols = [c for c in df.columns if str(c).strip() in {"__index_level_0__", "index", "[__index_level_0__]"}]
    if drop_cols:
        df = df.drop(columns=drop_cols)
    # Fail fast if any required column is missing.
    for col in [TEXT_COL, PROMPT_COL, RMTEXT_COL]:
        if col not in df.columns:
            raise KeyError(f"Column '{col}' not found! Available: {list(df.columns)[:30]} ...")
    if LIMIT and LIMIT > 0:
        df = df.head(LIMIT).copy()
        print(f"[info] subsampled to first {len(df)} rows")
    print(f"[info] loading tokenizer: {TOKENIZER_PATH}")
    tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_PATH, use_fast=True)
    print("[info] tokenizing ...")
    text_lens = batch_token_lengths(df[TEXT_COL].tolist(), tokenizer, ADD_SPECIAL_TOKENS, TRUNCATION, MAX_LENGTH, BATCH_SIZE)
    prompt_lens = batch_token_lengths(df[PROMPT_COL].tolist(), tokenizer, ADD_SPECIAL_TOKENS, TRUNCATION, MAX_LENGTH, BATCH_SIZE)
    rmtext_lens = batch_token_lengths(df[RMTEXT_COL].tolist(), tokenizer, ADD_SPECIAL_TOKENS, TRUNCATION, MAX_LENGTH, BATCH_SIZE)
    # Print summary statistics for each column.
    summarize("text", text_lens)
    summarize("prompt", prompt_lens)
    summarize("rm_text", rmtext_lens)
    # Save one histogram PNG per column.
    save_hist(text_lens, "Text token count", os.path.join(OUT_DIR, "hist_text.png"))
    save_hist(prompt_lens, "Prompt token count", os.path.join(OUT_DIR, "hist_prompt.png"))
    save_hist(rmtext_lens, "RM_Text token count", os.path.join(OUT_DIR, "hist_rm_text.png"))
    # Comparison scatter: prompt-vs-text and rm_text-vs-text in one figure,
    # restricted to rows where all three columns are non-empty strings.
    mask = np.ones(len(df), dtype=bool)
    for col in [TEXT_COL, PROMPT_COL, RMTEXT_COL]:
        mask &= df[col].map(lambda x: isinstance(x, str) and len(x) > 0).values
    x1, y1 = prompt_lens[mask], text_lens[mask]
    x2, y2 = rmtext_lens[mask], text_lens[mask]
    plt.figure()
    plt.scatter(x1, y1, s=10, alpha=0.4, label="prompt vs text")
    plt.scatter(x2, y2, s=10, alpha=0.4, label="rm_text vs text")
    # y = x reference line spanning the joint data range (falls back to 0
    # for any empty series).
    mn = int(min(x1.min() if len(x1) else 0, x2.min() if len(x2) else 0, y1.min() if len(y1) else 0, y2.min() if len(y2) else 0))
    mx = int(max(x1.max() if len(x1) else 0, x2.max() if len(x2) else 0, y1.max() if len(y1) else 0, y2.max() if len(y2) else 0))
    plt.plot([mn, mx], [mn, mx])
    plt.title("Token count comparison")
    plt.xlabel("X tokens (prompt / rm_text)")
    plt.ylabel("Text tokens (Y)")
    plt.legend()
    plt.tight_layout()
    scatter_path = os.path.join(OUT_DIR, "scatter_compare.png")
    plt.savefig(scatter_path, dpi=200)
    plt.close()
    print(f"[saved] {scatter_path}")
    # TODO(review): original ends with a "summary table" comment but no
    # implementation — the table export was never written.
# Script entry point: run the full analysis when executed directly.
if __name__ == "__main__":
    main()