# SparseAE / summarize_tokens.py
# Uploaded by nancyH ("Upload folder using huggingface_hub", commit b46126b, verified).
import re
import glob
import numpy as np
import pandas as pd
LOG_FILE = "extract_log.txt"

# 1. Parse N hits from log file.
# Each matching line looks like "Saved token <id> (N=<hits>)". If a token is
# logged more than once, the last occurrence wins.
token_counts = {}
pattern = re.compile(r"Saved token (\d+) \(N=(\d+)\)")
# Decode explicitly as UTF-8 with errors="replace": the log may contain
# non-ASCII bytes (e.g. progress-bar glyphs). Without this, decoding would
# depend on the platform's locale encoding or abort the read. Only the ASCII
# "Saved token" lines are needed here.
with open(LOG_FILE, "r", encoding="utf-8", errors="replace") as f:
    for line in f:
        m = pattern.search(line)
        if m:
            t = int(m.group(1))
            n = int(m.group(2))
            token_counts[t] = n
print(f"Found {len(token_counts)} tokens with counts from log.")
# 2. For each token, load its PWM and phyloP arrays, then compute the mean
#    PWM entropy and the average phyloP score. One dict per token is
#    collected in this list.
rows = list()
def pwm_entropy(pwm, eps=1e-8):
    """Return the mean per-position Shannon entropy of a PWM, in bits.

    pwm: (L, 4) array of mean one-hot probabilities, one row per position.
    eps: small constant that guards the row normalisation against division
        by zero and keeps log2 finite where a probability is exactly 0.

    Each row is first rescaled to sum to ~1, so rows that are not exactly
    normalised are still handled. Note: a row of all zeros rescales to all
    zeros and therefore contributes 0 bits, the same as a fully conserved
    position.
    """
    row_totals = pwm.sum(axis=1, keepdims=True)
    probs = pwm / (row_totals + eps)  # renormalise each position
    per_position = -(probs * np.log2(probs + eps)).sum(axis=1)  # shape (L,)
    return per_position.mean()
# glob returns files in arbitrary order; sorting makes the row order (and
# therefore tie-breaking in the final sort) reproducible across runs.
for pwm_path in sorted(glob.glob("token*_pwm.npy")):
    # Token ID from the filename. Use fullmatch (not search) so that a stray
    # file the glob also matches, e.g. "tokenX_token12_pwm.npy", is skipped
    # instead of being attributed to token 12.
    m = re.fullmatch(r"token(\d+)_pwm\.npy", pwm_path)
    if not m:
        continue
    t = int(m.group(1))
    # Reuse the exact digit string from the PWM filename. A zero-padded name
    # (token007_pwm.npy) then pairs with token007_phy.npy, not token7_phy.npy.
    phy_path = f"token{m.group(1)}_phy.npy"
    pwm = np.load(pwm_path)  # expected (L, 4), as pwm_entropy assumes
    try:
        phy = np.load(phy_path)  # expected (L,)
    except FileNotFoundError:
        # Keep the token in the summary with a missing phyloP value rather
        # than aborting the whole run.
        print(f"Warning: {phy_path} not found; avg_phyloP set to NaN for token {t}.")
        avg_phy = float("nan")
    else:
        avg_phy = float(phy.mean())
    H = pwm_entropy(pwm)
    N_hits = token_counts.get(t)  # None if the token never appeared in the log
    rows.append({
        "token_id": t,
        "N_hits": N_hits,
        "pwm_entropy_bits": H,
        "avg_phyloP": avg_phy,
    })
# 3. Build the summary table: lowest mean PWM entropy first, ties broken by
#    highest average phyloP.
SUMMARY_COLUMNS = ["token_id", "N_hits", "pwm_entropy_bits", "avg_phyloP"]
# Passing the columns explicitly keeps the schema when no token files were
# found. Otherwise pd.DataFrame([]) has no columns and sort_values raises
# KeyError.
df = pd.DataFrame(rows, columns=SUMMARY_COLUMNS)
# Tokens absent from the log have N_hits=None, which would coerce the whole
# column to float (5 -> 5.0). The nullable Int64 dtype keeps whole numbers
# and shows missing counts as <NA> (written as an empty cell in the TSV).
df["N_hits"] = df["N_hits"].astype("Int64")
df = df.sort_values(["pwm_entropy_bits", "avg_phyloP"], ascending=[True, False])
print(df.head(20))
df.to_csv("token_summary.tsv", sep="\t", index=False)
print("\nSaved summary to token_summary.tsv")