TigerGraph-Hack / utils /token_count.py
Meshyboi's picture
Upload 27 files
90645a4 verified
Raw
History Blame Contribute Delete
7.36 kB
"""Calculate total token counts over data/ folder.
Reports three tiers:
1. Raw β€” full OpenAlex JSON as-is on disk
2. Optimised β€” only the fields the RAG pipeline actually reads
(id, title, abstract, authorships, topics, keywords,
referenced_works, cited_by_count, publication_year, doi,
best_oa_location.pdf_url)
3. LLM context β€” what actually reaches the embedding model and the LLM:
plain-text "Title: …\\n\\nAbstract: …" per paper
"""
import json
import sys
from pathlib import Path
# data/ lives at the project root, one level up from utils/
# (parents[0] is the utils/ directory holding this file, parents[1] is its parent).
DATA_DIR = Path(__file__).parents[1] / "data"
# Approximate token ratio (1 token ≈ 4 chars for English text — GPT/Llama heuristic)
CHARS_PER_TOKEN = 4
# Fields the pipeline actually uses. Top-level OpenAlex keys outside this set
# are dropped by optimise_paper().
# NOTE(review): the module docstring also lists best_oa_location.pdf_url, which
# is not kept here, while display_name, primary_topic and type are kept but not
# listed there — confirm which list is authoritative.
_KEEP_KEYS = {
    "id", "title", "display_name", "abstract_inverted_index",
    "publication_year", "cited_by_count",
    "authorships", "topics", "keywords", "referenced_works",
    "primary_topic", "doi", "type",
}
def reconstruct_abstract(inv_index: dict) -> str:
    """Turn an OpenAlex inverted-index abstract back into plain text.

    ``inv_index`` maps each word to the list of positions at which it occurs.
    Every (position, word) pair is collected, ordered by position (ties broken
    by the word itself), and the words are joined with single spaces.

    Returns an empty string when ``inv_index`` is falsy (``None`` or ``{}``).
    """
    if not inv_index:
        return ""
    placed = sorted(
        (position, token)
        for token, occurrences in inv_index.items()
        for position in occurrences
    )
    return " ".join(token for _, token in placed)
def slim_authorships(authorships: list) -> list:
    """Keep only author name + position (drop institutions, affiliations, etc.).

    Args:
        authorships: list of authorship dicts, each optionally carrying an
            ``"author"`` dict and an ``"author_position"`` value; ``None`` is
            treated as an empty list.

    Returns:
        One ``{"name": ..., "position": ...}`` dict per input entry. A missing
        value falls back to ``""``.
    """
    return [
        {
            # ``or {}`` (rather than a .get default) also covers an explicit
            # ``"author": null`` in the JSON, which would otherwise raise
            # AttributeError on the chained .get().
            "name": (entry.get("author") or {}).get("display_name", ""),
            "position": entry.get("author_position", ""),
        }
        for entry in (authorships or [])
    ]
def optimise_paper(raw: dict) -> dict:
    """Strip a raw OpenAlex JSON to only pipeline-relevant fields.

    Args:
        raw: one decoded OpenAlex work record (a dict).

    Returns:
        A new dict holding only the keys listed in ``_KEEP_KEYS``, with:
          * ``abstract_inverted_index`` replaced by a plain-text ``abstract``;
          * ``authorships`` reduced to name + position;
          * ``topics`` and ``keywords`` reduced to their ``display_name``;
          * ``referenced_works`` reduced to the last path segment of each entry.
        ``raw`` itself is not modified.
    """
    # Walk ``raw`` (not the _KEEP_KEYS set) so the output keeps the record's
    # own key order. Set iteration order for strings varies between processes,
    # which made the serialized result non-reproducible from run to run.
    optimised = {key: value for key, value in raw.items() if key in _KEEP_KEYS}

    # Replace inverted index with plain text abstract
    if "abstract_inverted_index" in optimised:
        optimised["abstract"] = reconstruct_abstract(optimised.pop("abstract_inverted_index"))

    # Slim authorships
    if "authorships" in optimised:
        optimised["authorships"] = slim_authorships(optimised["authorships"])

    # Slim topics and keywords to just display_name (a null list becomes [])
    for field in ("topics", "keywords"):
        if field in optimised:
            optimised[field] = [
                item.get("display_name", "") for item in (optimised[field] or [])
            ]

    # Slim referenced_works to just IDs. rsplit() returns the whole string
    # when there is no "/", so no separate membership check is needed.
    if "referenced_works" in optimised:
        optimised["referenced_works"] = [
            ref.rsplit("/", 1)[-1] for ref in (optimised["referenced_works"] or [])
        ]

    return optimised
def count_tokens(text: str) -> dict:
    """Return size statistics for ``text``.

    The result has three integer entries:
      * ``"chars"``      — number of characters in ``text``;
      * ``"words"``      — number of whitespace-separated chunks;
      * ``"tokens_est"`` — ``chars`` floor-divided by ``CHARS_PER_TOKEN``.
    """
    n_chars = len(text)
    return {
        "chars": n_chars,
        "words": len(text.split()),
        "tokens_est": n_chars // CHARS_PER_TOKEN,
    }
def _new_totals() -> dict:
    """Return a zeroed accumulator for one reporting tier."""
    return {"chars": 0, "words": 0, "tokens_est": 0, "bytes": 0}


def _add_text(totals: dict, text: str) -> None:
    """Add the UTF-8 byte size and count_tokens() statistics of ``text`` to ``totals``."""
    totals["bytes"] += len(text.encode("utf-8"))
    for key, value in count_tokens(text).items():
        totals[key] += value


def _print_tier(header_lines: list, file_count: int, totals: dict) -> None:
    """Print one tier's banner followed by its file count and size totals."""
    rule = "=" * 60
    print(rule)
    for line in header_lines:
        print(line)
    print(rule)
    print(f" Files: {file_count:,}")
    print(f" Total bytes: {totals['bytes']:,} ({totals['bytes'] / 1e6:.1f} MB)")
    print(f" Total chars: {totals['chars']:,}")
    print(f" Total words: {totals['words']:,}")
    print(f" Est. tokens: {totals['tokens_est']:,} (~{totals['tokens_est'] / 1e6:.2f}M)")
    print()


def main():
    """Scan ``DATA_DIR`` for ``*.json`` files and print the three-tier size report.

    Every file contributes to Tier 1 (raw text). A file contributes to Tier 2
    (optimised JSON) and Tier 3 (title + abstract text) only if it decodes to a
    JSON object; any other file is skipped for those tiers and listed in a
    warning on stderr instead of being dropped silently.

    The reduction percentages compare Tier 2/3 bytes against all Tier 1 bytes,
    so skipped files inflate them; the warning is there to make that visible.

    Exits with status 1 when ``DATA_DIR`` does not exist.
    """
    if not DATA_DIR.exists():
        print(f"ERROR: {DATA_DIR} not found. Expected: {DATA_DIR.resolve()}")
        sys.exit(1)

    files = sorted(DATA_DIR.glob("*.json"))
    total = len(files)
    print(f"Found {total} JSON files in {DATA_DIR.resolve()}\n")

    raw_total = _new_totals()
    opt_total = _new_totals()
    llm_total = _new_totals()
    skipped = []  # human-readable "<file>: <reason>" entries

    for f in files:
        raw_text = f.read_text(encoding="utf-8")
        # Tier 1: the file exactly as read
        _add_text(raw_total, raw_text)

        # Keep the try body to the one call that can raise JSONDecodeError.
        try:
            data = json.loads(raw_text)
        except json.JSONDecodeError as exc:
            skipped.append(f"{f.name}: invalid JSON ({exc})")
            continue
        if not isinstance(data, dict):
            # A list/scalar at top level is not a paper record; the .get()
            # calls below would raise AttributeError on it.
            skipped.append(f"{f.name}: top-level JSON value is not an object")
            continue

        # Tier 2: optimised JSON (all pipeline-read fields)
        _add_text(opt_total, json.dumps(optimise_paper(data), ensure_ascii=False))

        # Tier 3: LLM context — the exact text sent to the embedding model and LLM
        # Mirrors indexer.py _doc_text() and setup.py embed text construction
        title = data.get("title") or ""
        abstract = reconstruct_abstract(data.get("abstract_inverted_index") or {})
        _add_text(llm_total, f"Title: {title}\n\nAbstract: {abstract}")

    parsed = total - len(skipped)
    if skipped:
        print(
            f"WARNING: {len(skipped)} file(s) excluded from tiers 2 and 3:",
            file=sys.stderr,
        )
        for reason in skipped:
            print(f"  {reason}", file=sys.stderr)

    _print_tier(
        [" TIER 1 — RAW (full OpenAlex JSON as-is on disk)"],
        total,
        raw_total,
    )
    _print_tier(
        [
            " TIER 2 — OPTIMISED (pipeline-relevant fields only)",
            " Fields: id, title, abstract, authorships, topics,",
            " keywords, referenced_works, doi, year,",
            " cited_by_count, best_oa_location.pdf_url",
        ],
        parsed,
        opt_total,
    )
    _print_tier(
        [
            " TIER 3 — LLM CONTEXT (what reaches the embedding model + LLM)",
            ' Format: "Title: …\\n\\nAbstract: …" per paper',
        ],
        parsed,
        llm_total,
    )

    # Guard against division by zero when the folder holds no (or only empty) files.
    raw_bytes = raw_total["bytes"]
    r2_pct = (1 - opt_total["bytes"] / raw_bytes) * 100 if raw_bytes else 0
    r3_pct = (1 - llm_total["bytes"] / raw_bytes) * 100 if raw_bytes else 0
    print(f" Tier 1 → Tier 2 reduction: {r2_pct:.1f}% "
          f"({raw_total['tokens_est'] - opt_total['tokens_est']:,} tokens saved)")
    print(f" Tier 1 → Tier 3 reduction: {r3_pct:.1f}% "
          f"({raw_total['tokens_est'] - llm_total['tokens_est']:,} tokens saved)")
# Run the report only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()