| | |
| | |
| |
|
| | import os, json, math |
| | import numpy as np |
| | import pandas as pd |
| | import matplotlib.pyplot as plt |
| | from transformers import AutoTokenizer |
| |
|
| | |
# ---------------------------------------------------------------------------
# Configuration constants (edit here; there is no CLI argument parsing).
# ---------------------------------------------------------------------------

DATA_PATH = "/home/data/STUDY.parquet"   # input table; read_table() also accepts .csv/.jsonl/.json
TOKENIZER_PATH = "/home/rm3.4.1_9e-6"    # HF tokenizer dir — presumably a reward-model checkpoint; TODO confirm

# Column names expected in the input table (validated in main()).
TEXT_COL = "text"
PROMPT_COL = "prompt"
RMTEXT_COL = "rm_text"

OUT_DIR = "./figs"           # output directory for histogram / scatter PNGs
LIMIT = 0                    # 0 = use all rows; >0 = analyze only the first LIMIT rows
ADD_SPECIAL_TOKENS = False   # count content tokens only (no BOS/EOS etc.)
TRUNCATION = False           # measure full, untruncated lengths
MAX_LENGTH = None            # only meaningful when TRUNCATION is enabled
BATCH_SIZE = 1024            # tokenizer mini-batch size
| | |
| |
|
| |
|
def read_table(path: str) -> pd.DataFrame:
    """Load a table from *path*, dispatching on the file extension.

    Supports parquet (``.parquet``/``.pq``), ``.csv``, JSON-lines
    (``.jsonl``), and plain ``.json`` (a whole-document array/object,
    with a JSON-lines fallback for ``.json`` files that actually contain
    one object per line).

    Raises:
        ValueError: if the extension is none of the supported types.
    """
    ext = os.path.splitext(path)[1].lower()
    if ext in [".parquet", ".pq"]:
        return pd.read_parquet(path)
    if ext == ".csv":
        return pd.read_csv(path)
    if ext == ".jsonl":
        return _read_json_lines(path)
    if ext == ".json":
        # A .json file usually holds a single JSON document (array or
        # object) rather than one object per line, which would break the
        # per-line parser; try the whole document first.
        try:
            with open(path, "r", encoding="utf-8") as f:
                obj = json.load(f)
        except json.JSONDecodeError:
            # Fall back: JSON-lines content stored under a .json extension.
            return _read_json_lines(path)
        if isinstance(obj, list):
            return pd.DataFrame(obj)
        return pd.DataFrame([obj])
    raise ValueError(f"Unsupported file type: {ext}")


def _read_json_lines(path: str) -> pd.DataFrame:
    """Read one JSON object per non-empty line into a DataFrame."""
    rows = []
    with open(path, "r", encoding="utf-8") as f:
        for line in f:
            s = line.strip()
            if s:
                rows.append(json.loads(s))
    return pd.DataFrame(rows)
| |
|
| |
|
def to_str(x):
    """Coerce a cell value to ``str``, mapping missing values to ``""``.

    Treats ``None``, float NaN, and pandas' ``NaT``/``NA`` sentinels as
    missing; everything else goes through ``str()``.
    """
    if x is None:
        return ""
    if isinstance(x, float) and math.isnan(x):
        return ""
    # Pandas missing-value sentinels (NaT from datetime columns, NA from
    # nullable dtypes) would otherwise stringify as "NaT" / "<NA>".
    if x is pd.NaT or x is pd.NA:
        return ""
    return str(x)
| |
|
| |
|
def batch_token_lengths(texts, tokenizer, add_special_tokens=False,
                        truncation=False, max_length=None, batch_size=1024):
    """Return per-row token counts for *texts* as an int32 numpy array.

    Texts are tokenized in mini-batches of *batch_size*; None/NaN entries
    are coerced to "" before tokenization (same policy as ``to_str``).

    Args:
        texts: sequence of (possibly missing) text values.
        tokenizer: HF-style callable returning a dict with "input_ids".
        add_special_tokens / truncation / max_length: forwarded to the
            tokenizer call unchanged.
        batch_size: number of texts tokenized per call.
    """
    n = len(texts)
    lens = np.zeros(n, dtype=np.int32)
    for i in range(0, n, batch_size):
        # Sanitize missing values so the tokenizer never sees None/NaN.
        batch = [
            "" if t is None or (isinstance(t, float) and math.isnan(t)) else str(t)
            for t in texts[i:i + batch_size]
        ]
        enc = tokenizer(
            batch,
            add_special_tokens=add_special_tokens,
            truncation=truncation,
            max_length=max_length,
        )
        ids = enc["input_ids"]
        if isinstance(ids, list):
            # Un-padded list-of-lists output: one length per row.
            lens[i:i + batch_size] = [len(x) for x in ids]
        else:
            # Tensor output implies padding to a common width; the padded
            # width overstates real lengths, so prefer the per-row
            # attention_mask sum when the tokenizer provides one.
            mask = enc.get("attention_mask") if hasattr(enc, "get") else None
            if mask is not None:
                lens[i:i + batch_size] = np.asarray(mask).sum(axis=1)
            else:
                lens[i:i + batch_size] = ids.shape[1]
    return lens
| |
|
| |
|
def summarize(name, arr):
    """Print count/min/max/mean/median/std for a sequence of token lengths."""
    values = np.asarray(arr, dtype=np.int64)
    if values.size == 0:
        print(f"[{name}] empty")
        return
    report = (
        f"[{name}] count={values.size} min={values.min()} max={values.max()} "
        f"mean={values.mean():.2f} median={np.median(values):.2f} std={values.std():.2f}"
    )
    print(report)
| |
|
| |
|
def save_hist(data, title, out_path, bins=60):
    """Render a histogram of *data* (token counts) and save it to *out_path* as PNG."""
    fig, ax = plt.subplots()
    ax.hist(data, bins=bins)
    ax.set_title(title)
    ax.set_xlabel("Token count")
    ax.set_ylabel("Frequency")
    fig.tight_layout()
    fig.savefig(out_path, dpi=200)
    plt.close(fig)
    print(f"[saved] {out_path}")
| |
|
| |
|
def main():
    """Tokenize three text columns of the study table and plot length stats.

    Reads DATA_PATH, computes per-row token counts for TEXT_COL, PROMPT_COL
    and RMTEXT_COL with the tokenizer at TOKENIZER_PATH, prints summary
    statistics, and writes three histograms plus one scatter plot to OUT_DIR.
    """
    os.makedirs(OUT_DIR, exist_ok=True)

    print(f"[info] loading data: {DATA_PATH}")
    df = read_table(DATA_PATH)

    # Drop stray index columns (e.g. left behind by parquet round-trips).
    drop_cols = [c for c in df.columns if str(c).strip() in {"__index_level_0__", "index", "[__index_level_0__]"}]
    if drop_cols:
        df = df.drop(columns=drop_cols)

    # Fail fast if any required column is missing.
    for col in [TEXT_COL, PROMPT_COL, RMTEXT_COL]:
        if col not in df.columns:
            raise KeyError(f"Column '{col}' not found! Available: {list(df.columns)[:30]} ...")

    # Optional subsampling for quick runs (LIMIT == 0 means "all rows").
    if LIMIT and LIMIT > 0:
        df = df.head(LIMIT).copy()
        print(f"[info] subsampled to first {len(df)} rows")

    print(f"[info] loading tokenizer: {TOKENIZER_PATH}")
    tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_PATH, use_fast=True)

    print("[info] tokenizing ...")
    text_lens = batch_token_lengths(df[TEXT_COL].tolist(), tokenizer, ADD_SPECIAL_TOKENS, TRUNCATION, MAX_LENGTH, BATCH_SIZE)
    prompt_lens = batch_token_lengths(df[PROMPT_COL].tolist(), tokenizer, ADD_SPECIAL_TOKENS, TRUNCATION, MAX_LENGTH, BATCH_SIZE)
    rmtext_lens = batch_token_lengths(df[RMTEXT_COL].tolist(), tokenizer, ADD_SPECIAL_TOKENS, TRUNCATION, MAX_LENGTH, BATCH_SIZE)

    # Console summaries of the three length distributions.
    summarize("text", text_lens)
    summarize("prompt", prompt_lens)
    summarize("rm_text", rmtext_lens)

    # Per-column histograms.
    save_hist(text_lens, "Text token count", os.path.join(OUT_DIR, "hist_text.png"))
    save_hist(prompt_lens, "Prompt token count", os.path.join(OUT_DIR, "hist_prompt.png"))
    save_hist(rmtext_lens, "RM_Text token count", os.path.join(OUT_DIR, "hist_rm_text.png"))

    # For the scatter plot, keep only rows where all three cells are
    # non-empty strings (missing/NaN cells were tokenized as "" above).
    mask = np.ones(len(df), dtype=bool)
    for col in [TEXT_COL, PROMPT_COL, RMTEXT_COL]:
        mask &= df[col].map(lambda x: isinstance(x, str) and len(x) > 0).values

    x1, y1 = prompt_lens[mask], text_lens[mask]
    x2, y2 = rmtext_lens[mask], text_lens[mask]

    plt.figure()
    plt.scatter(x1, y1, s=10, alpha=0.4, label="prompt vs text")
    plt.scatter(x2, y2, s=10, alpha=0.4, label="rm_text vs text")
    # Diagonal y=x reference line spanning the combined data range.
    # NOTE(review): the 0 fallbacks for empty arrays can widen the range
    # artificially when one side is empty — acceptable for a reference line.
    mn = int(min(x1.min() if len(x1) else 0, x2.min() if len(x2) else 0, y1.min() if len(y1) else 0, y2.min() if len(y2) else 0))
    mx = int(max(x1.max() if len(x1) else 0, x2.max() if len(x2) else 0, y1.max() if len(y1) else 0, y2.max() if len(y2) else 0))
    plt.plot([mn, mx], [mn, mx])
    plt.title("Token count comparison")
    plt.xlabel("X tokens (prompt / rm_text)")
    plt.ylabel("Text tokens (Y)")
    plt.legend()
    plt.tight_layout()
    scatter_path = os.path.join(OUT_DIR, "scatter_compare.png")
    plt.savefig(scatter_path, dpi=200)
    plt.close()
    print(f"[saved] {scatter_path}")
| |
|
| | |
| |
|
| |
|
| | if __name__ == "__main__": |
| | main() |
| |
|