# Paper_Espresso / src/lifecycle_retrieve.py
# Author: elfsong
# feat: add --force flag to allow re-computing and overwriting existing snapshots
# Commit: 5e842ff
"""
Lifecycle Snapshot Retriever -- compute bimonthly topic lifecycle snapshots.
Computes Gartner-style hype cycle classification for research topics using
all available paper data up to each snapshot month (every 2 months).
Results are pushed to Elfsong/hf_paper_lifecycle.
Usage:
uv run python src/lifecycle_retrieve.py # latest snapshot
uv run python src/lifecycle_retrieve.py --snapshot 2025-06 # specific snapshot
uv run python src/lifecycle_retrieve.py --all # all missing snapshots
uv run python src/lifecycle_retrieve.py --no-push # dry run
"""
import os
# These two env vars must be set BEFORE huggingface_hub / datasets are
# imported — they read them at import time.
os.environ["HF_HUB_DISABLE_PROGRESS_BARS"] = "1"
os.environ["DATASETS_VERBOSITY"] = "error"
from tqdm import tqdm  # noqa: E402
from functools import partialmethod  # noqa: E402
# Globally disable tqdm bars (datasets drives tqdm internally).
tqdm.__init__ = partialmethod(tqdm.__init__, disable=True)
import argparse  # noqa: E402
import json  # noqa: E402
import logging  # noqa: E402
import sys  # noqa: E402
import time  # noqa: E402
from collections import Counter, defaultdict  # noqa: E402
from datetime import datetime, timezone  # noqa: E402
from pathlib import Path  # noqa: E402
import numpy as np  # noqa: E402
from scipy.stats import linregress  # noqa: E402
from dotenv import load_dotenv  # noqa: E402
# Repo root (two levels up from src/); .env there supplies HF_TOKEN.
ROOT = Path(__file__).resolve().parent.parent
load_dotenv(ROOT / ".env")
# Silence the chatty loggers of the HF stack.
for _name in ("datasets", "huggingface_hub", "huggingface_hub.utils",
              "fsspec", "datasets.utils", "datasets.arrow_writer"):
    logging.getLogger(_name).setLevel(logging.ERROR)
# ---------------------------------------------------------------------------
# ANSI helpers
# ---------------------------------------------------------------------------
# Terminal escape codes used by the CLI status output below.
_RESET = "\033[0m"    # reset all attributes
_BOLD = "\033[1m"     # bold
_DIM = "\033[2m"      # dim / faint
_GREEN = "\033[32m"   # success
_YELLOW = "\033[33m"  # warnings / failures
_CYAN = "\033[36m"    # snapshot labels
_GRAY = "\033[90m"    # skipped / de-emphasised
# ---------------------------------------------------------------------------
# Config
# ---------------------------------------------------------------------------
# Source dataset (daily paper summaries) and destination lifecycle repo.
HF_DATASET_REPO = "Elfsong/hf_paper_summary"
HF_LIFECYCLE_REPO = "Elfsong/hf_paper_lifecycle"
# Bimonthly snapshot months (even months)
SNAPSHOT_MONTHS = {2, 4, 6, 8, 10, 12}
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _get_env(key: str) -> str:
val = os.getenv(key, "")
if val:
return val
env_path = ROOT / ".env"
if env_path.exists():
for line in env_path.read_text().splitlines():
if line.startswith(f"{key}="):
return line.split("=", 1)[1].strip()
return ""
def _snapshot_to_split(snapshot_str: str) -> str:
return "snapshot_" + snapshot_str.replace("-", "_")
def _parse_paper_row(paper: dict) -> dict:
for key in ("detailed_analysis", "detailed_analysis_zh"):
v = paper.get(key, "{}")
if isinstance(v, str):
paper[key] = json.loads(v) if v else {}
for key in ("topics", "topics_zh", "keywords", "keywords_zh"):
v = paper.get(key, "[]")
if isinstance(v, str):
paper[key] = json.loads(v) if v else []
if not isinstance(paper.get("authors"), list):
try:
paper["authors"] = list(paper["authors"])
except Exception:
paper["authors"] = []
return paper
def _list_repo_files(repo: str) -> list[str]:
    """Return every file path in the HF dataset *repo*, or [] on any failure.

    Failure includes a missing HF_TOKEN, network errors, and auth errors —
    callers treat an empty list as "could not list".
    """
    from huggingface_hub import HfApi

    token = _get_env("HF_TOKEN")
    if not token:
        return []
    try:
        files = HfApi(token=token).list_repo_files(repo, repo_type="dataset")
        return list(files)
    except Exception:
        return []
def _load_all_papers(files: list[str]) -> list[dict]:
    """Download all parquet shards and return de-duplicated paper dicts.

    Each paper gains '_date' ('YYYY-MM-DD', parsed from the shard filename)
    and '_month' ('YYYY-MM'). Duplicate paper_ids keep the first occurrence;
    shards that fail to download or parse are silently skipped (best-effort).
    """
    import pandas as pd
    from huggingface_hub import hf_hub_download

    token = _get_env("HF_TOKEN")
    shards = [f for f in files if f.endswith(".parquet")]
    known_ids: set[str] = set()
    collected: list[dict] = []
    for idx, shard in enumerate(shards):
        basename = shard.split("/")[-1]
        # e.g. "date_2025_06_01-00000....parquet" -> "2025-06-01"
        stem = basename.split("-00")[0]
        day = stem.replace("date_", "").replace("_", "-")
        try:
            local_path = hf_hub_download(
                HF_DATASET_REPO, shard, repo_type="dataset", token=token,
            )
            frame = pd.read_parquet(local_path)
            for _, record in frame.iterrows():
                entry = record.to_dict()
                paper_id = entry.get("paper_id", "")
                if not paper_id or paper_id in known_ids:
                    continue
                known_ids.add(paper_id)
                entry["_date"] = day
                entry["_month"] = day[:7]
                collected.append(_parse_paper_row(entry))
        except Exception:
            continue
        # Lightweight progress line (interactive terminals only).
        if sys.stdout.isatty() and (idx + 1) % 20 == 0:
            sys.stdout.write(f"\r {_DIM}Loading papers... {idx+1}/{len(shards)} files, {len(collected)} papers{_RESET}")
            sys.stdout.flush()
    if sys.stdout.isatty():
        # Erase the progress line.
        sys.stdout.write("\r\033[K")
        sys.stdout.flush()
    return collected
# ---------------------------------------------------------------------------
# Lifecycle computation
# ---------------------------------------------------------------------------
def _get_paper_topics(paper: dict, lang: str) -> list[str]:
if lang == "zh":
return paper.get("topics_zh", []) or paper.get("topics", [])
return paper.get("topics", [])
def compute_lifecycle(papers: list[dict], lang: str = "en") -> tuple[dict, list[str], dict, dict]:
    """Compute Gartner-style hype-cycle metrics for every topic in *papers*.

    Args:
        papers: paper dicts; each needs '_month' ('YYYY-MM') plus a 'topics'
            list (and optionally 'topics_zh').  (The original annotation said
            ``list[str]``, which contradicted the dict access below — fixed.)
        lang: 'en' reads 'topics'; 'zh' prefers 'topics_zh', falling back to
            'topics' when the Chinese list is missing or empty.

    Returns:
        (lifecycle, sorted_months, topics_by_month, total_by_month) where
        lifecycle maps topic -> metrics dict (phase, total_count, peak_val,
        peak_month, current_avg, slope, decline_ratio, months_since_peak,
        months_active).  With fewer than 2 distinct months of data, returns
        ({}, sorted_months, {}, {}).
    """
    def _topics_of(p: dict) -> list[str]:
        # Same fallback rule as _get_paper_topics, inlined so this function
        # is self-contained.
        if lang == "zh":
            return p.get("topics_zh", []) or p.get("topics", [])
        return p.get("topics", [])

    topics_by_month: dict[str, Counter] = defaultdict(Counter)
    all_topics: Counter = Counter()
    for p in papers:
        month = p.get("_month", "")
        if not month:
            continue
        topics = _topics_of(p)
        topics_by_month[month].update(topics)
        all_topics.update(topics)
    sorted_months = sorted(topics_by_month.keys())
    if len(sorted_months) < 2:
        return {}, sorted_months, {}, {}
    total_by_month = {m: sum(topics_by_month[m].values()) for m in sorted_months}
    n_months = len(sorted_months)
    # A topic needs enough total mentions before it is classified at all.
    min_papers = max(3, n_months)
    candidates = [t for t, c in all_topics.items() if c >= min_papers]
    lifecycle: dict = {}
    for topic in candidates:
        # Per-month share of all topic mentions that belong to this topic.
        proportions = np.array([
            topics_by_month[m].get(topic, 0) / total_by_month[m]
            if total_by_month[m] > 0 else 0
            for m in sorted_months
        ])
        counts = np.array([topics_by_month[m].get(topic, 0) for m in sorted_months])
        nonzero = np.where(proportions > 0)[0]
        if len(nonzero) < 2:
            continue
        first_idx = int(nonzero[0])
        peak_idx = int(np.argmax(proportions))
        peak_val = float(proportions[peak_idx])
        # Current level: mean share over the last <=3 months.
        current_avg = float(np.mean(proportions[-min(3, n_months):]))
        # Recent trend: least-squares slope over the last <=6 months.
        window = min(6, n_months)
        recent = proportions[-window:]
        slope = float(linregress(np.arange(len(recent)), recent).slope) if len(recent) >= 3 else 0.0
        decline_ratio = current_avg / peak_val if peak_val > 0 else 0
        months_since_peak = n_months - 1 - peak_idx
        months_active = n_months - first_idx
        # Fraction of all mentions falling in the last <=8 months.
        recent_window = min(8, len(counts))
        recent_fraction = float(counts[-recent_window:].sum() / max(counts.sum(), 1))
        # Phase classification (same thresholds as reference analysis script)
        dr, sl, ma, msp = decline_ratio, slope, months_active, months_since_peak
        tc = int(counts.sum())
        rf = recent_fraction
        if ma <= 8 or (rf > 0.60 and tc < 200):
            phase = "Innovation Trigger"
        elif (dr > 0.70 and msp <= 6) or (sl > 0.001 and dr > 0.65):
            phase = "Peak of Inflated Expectations"
        elif dr < 0.65:
            phase = "Slope of Enlightenment" if sl > 0.0003 else "Trough of Disillusionment"
        elif sl < -0.001 and dr < 0.75:
            phase = "Trough of Disillusionment"
        elif dr < 0.85 and sl > 0.0005 and msp > 4:
            phase = "Slope of Enlightenment"
        else:
            phase = "Plateau of Productivity"
        lifecycle[topic] = {
            "topic": topic, "phase": phase,
            "total_count": tc, "peak_val": peak_val,
            "peak_month": sorted_months[peak_idx],
            "current_avg": current_avg, "slope": slope,
            "decline_ratio": decline_ratio,
            "months_since_peak": months_since_peak,
            "months_active": months_active,
        }
    # Convert Counters to plain dicts for serialisation
    tbm = {m: dict(topics_by_month[m]) for m in sorted_months}
    tbm_total = dict(total_by_month)
    return lifecycle, sorted_months, tbm, tbm_total
# ---------------------------------------------------------------------------
# Push to HuggingFace
# ---------------------------------------------------------------------------
def push_lifecycle_to_hf(lifecycle_en: dict, lifecycle_zh: dict,
                         sorted_months: list[str], n_papers: int,
                         snapshot_month: str,
                         topics_by_month_en: dict | None = None,
                         total_by_month_en: dict | None = None,
                         topics_by_month_zh: dict | None = None,
                         total_by_month_zh: dict | None = None):
    """Serialise one snapshot as a single-row dataset and push it to HF.

    The split name is derived from *snapshot_month* ('snapshot_YYYY_MM');
    all dict/list payloads are stored as JSON strings.

    Raises:
        RuntimeError: when HF_TOKEN cannot be resolved.
    """
    from datasets import Dataset

    token = _get_env("HF_TOKEN")
    if not token:
        raise RuntimeError("HF_TOKEN not set")

    def _dump(obj) -> str:
        # Keep non-ASCII (Chinese topic names) readable in the stored JSON.
        return json.dumps(obj, ensure_ascii=False)

    row = {
        "lifecycle_data": _dump(lifecycle_en),
        "lifecycle_data_zh": _dump(lifecycle_zh),
        "sorted_months": _dump(sorted_months),
        "n_papers": n_papers,
        "n_months": len(sorted_months),
        "topics_by_month": _dump(topics_by_month_en or {}),
        "total_by_month": _dump(total_by_month_en or {}),
        "topics_by_month_zh": _dump(topics_by_month_zh or {}),
        "total_by_month_zh": _dump(total_by_month_zh or {}),
    }
    dataset = Dataset.from_list([row])
    dataset.push_to_hub(HF_LIFECYCLE_REPO,
                        split=_snapshot_to_split(snapshot_month),
                        token=token)
# ---------------------------------------------------------------------------
# Run one snapshot
# ---------------------------------------------------------------------------
def run_snapshot(snapshot_month: str, all_papers: list[dict],
                 existing_splits: set[str], no_push: bool = False,
                 force: bool = False):
    """Compute one bimonthly snapshot (en + zh) and optionally push it.

    Skipped when the split already exists on HF (unless *force*) or when no
    papers fall on/before *snapshot_month*. Push failures are reported, not
    raised.
    """
    split_name = _snapshot_to_split(snapshot_month)
    if split_name in existing_splits and not force:
        print(f" {_GRAY}{snapshot_month} — already on HF, skipping{_RESET}")
        return
    # Everything up to and including the snapshot month
    # ('YYYY-MM' compares correctly as strings).
    papers = [p for p in all_papers if p.get("_month", "") <= snapshot_month]
    if not papers:
        print(f" {_YELLOW}{snapshot_month} — no papers, skipping{_RESET}")
        return
    print(f" {_CYAN}{snapshot_month}{_RESET} — {len(papers)} papers...", end="", flush=True)
    lc_en, months_en, tbm_en, tbt_en = compute_lifecycle(papers, lang="en")
    lc_zh, _, tbm_zh, tbt_zh = compute_lifecycle(papers, lang="zh")
    print(f" {len(lc_en)} topics (en), {len(lc_zh)} topics (zh)", end="", flush=True)
    if no_push:
        print(f" {_GRAY}[--no-push]{_RESET}")
        return
    try:
        push_lifecycle_to_hf(
            lc_en, lc_zh, months_en, len(papers), snapshot_month,
            topics_by_month_en=tbm_en, total_by_month_en=tbt_en,
            topics_by_month_zh=tbm_zh, total_by_month_zh=tbt_zh,
        )
    except Exception as exc:
        print(f" {_YELLOW}✗ push failed: {exc}{_RESET}")
    else:
        print(f" {_GREEN}✓ pushed{_RESET}")
# ---------------------------------------------------------------------------
# Main
# ---------------------------------------------------------------------------
def main():
    """CLI entry point: load all papers, pick snapshot months, process each."""
    parser = argparse.ArgumentParser(
        description="Compute bimonthly topic lifecycle snapshots and push to HuggingFace"
    )
    parser.add_argument("--snapshot", type=str, default=None,
                        help="Snapshot month (YYYY-MM, even month). Default: latest bimonthly.")
    parser.add_argument("--all", action="store_true",
                        help="Compute all missing bimonthly snapshots")
    parser.add_argument("--no-push", action="store_true",
                        help="Skip pushing results to HuggingFace")
    parser.add_argument("--force", action="store_true",
                        help="Re-compute and overwrite existing snapshots")
    args = parser.parse_args()

    print(f"\n {_BOLD}📊 Lifecycle Snapshot Retriever{_RESET}\n")

    # Step 1: enumerate files in the summary dataset.
    print(f" {_DIM}Listing dataset files...{_RESET}", end="", flush=True)
    all_files = _list_repo_files(HF_DATASET_REPO)
    if not all_files:
        print(f"\n {_YELLOW}Error: could not list files — check HF_TOKEN{_RESET}")
        return
    print(f" {len(all_files)} files")

    # Step 2: download and parse every shard.
    print(f" {_DIM}Loading all papers...{_RESET}", end="", flush=True)
    started = time.time()
    all_papers = _load_all_papers(all_files)
    elapsed = time.time() - started
    print(f" {len(all_papers)} papers in {elapsed:.1f}s")
    if not all_papers:
        print(f" {_YELLOW}No papers found{_RESET}")
        return

    # Step 3: data range + splits already present on the lifecycle repo.
    all_months = sorted({p["_month"] for p in all_papers if p.get("_month")})
    print(f" {_DIM}Data range: {all_months[0]}{all_months[-1]} ({len(all_months)} months){_RESET}")
    existing_splits: set[str] = set()
    for path in _list_repo_files(HF_LIFECYCLE_REPO):
        stem = path.split("/")[-1].replace(".parquet", "").replace(".arrow", "")
        existing_splits.update(
            piece for piece in stem.split("-") if piece.startswith("snapshot_")
        )

    # Step 4: choose which snapshot months to (re)compute.
    if args.all:
        snapshots = [m for m in all_months if int(m[5:7]) in SNAPSHOT_MONTHS]
    elif args.snapshot:
        snapshots = [args.snapshot]
    else:
        # Default: the most recent fully-completed even month.
        now = datetime.now(timezone.utc)
        last_completed = now.month - 1 if now.month > 1 else 12
        snap_year = now.year if now.month > 1 else now.year - 1
        snap_month = last_completed if last_completed % 2 == 0 else last_completed - 1
        if snap_month <= 0:
            snap_month = 12
            snap_year -= 1
        snapshots = [f"{snap_year}-{snap_month:02d}"]
    print(f" {_DIM}Snapshots to process: {len(snapshots)}{_RESET}\n")

    for snapshot in snapshots:
        run_snapshot(snapshot, all_papers, existing_splits,
                     no_push=args.no_push, force=args.force)
    print(f"\n {_GREEN}{_BOLD}{_RESET} Done\n")


if __name__ == "__main__":
    main()