""" Lifecycle Snapshot Retriever -- compute bimonthly topic lifecycle snapshots. Computes Gartner-style hype cycle classification for research topics using all available paper data up to each snapshot month (every 2 months). Results are pushed to Elfsong/hf_paper_lifecycle. Usage: uv run python src/lifecycle_retrieve.py # latest snapshot uv run python src/lifecycle_retrieve.py --snapshot 2025-06 # specific snapshot uv run python src/lifecycle_retrieve.py --all # all missing snapshots uv run python src/lifecycle_retrieve.py --no-push # dry run """ import os os.environ["HF_HUB_DISABLE_PROGRESS_BARS"] = "1" os.environ["DATASETS_VERBOSITY"] = "error" from tqdm import tqdm # noqa: E402 from functools import partialmethod # noqa: E402 tqdm.__init__ = partialmethod(tqdm.__init__, disable=True) import argparse # noqa: E402 import json # noqa: E402 import logging # noqa: E402 import sys # noqa: E402 import time # noqa: E402 from collections import Counter, defaultdict # noqa: E402 from datetime import datetime, timezone # noqa: E402 from pathlib import Path # noqa: E402 import numpy as np # noqa: E402 from scipy.stats import linregress # noqa: E402 from dotenv import load_dotenv # noqa: E402 ROOT = Path(__file__).resolve().parent.parent load_dotenv(ROOT / ".env") for _name in ("datasets", "huggingface_hub", "huggingface_hub.utils", "fsspec", "datasets.utils", "datasets.arrow_writer"): logging.getLogger(_name).setLevel(logging.ERROR) # --------------------------------------------------------------------------- # ANSI helpers # --------------------------------------------------------------------------- _RESET = "\033[0m" _BOLD = "\033[1m" _DIM = "\033[2m" _GREEN = "\033[32m" _YELLOW = "\033[33m" _CYAN = "\033[36m" _GRAY = "\033[90m" # --------------------------------------------------------------------------- # Config # --------------------------------------------------------------------------- HF_DATASET_REPO = "Elfsong/hf_paper_summary" HF_LIFECYCLE_REPO = 
"Elfsong/hf_paper_lifecycle" # Bimonthly snapshot months (even months) SNAPSHOT_MONTHS = {2, 4, 6, 8, 10, 12} # --------------------------------------------------------------------------- # Helpers # --------------------------------------------------------------------------- def _get_env(key: str) -> str: val = os.getenv(key, "") if val: return val env_path = ROOT / ".env" if env_path.exists(): for line in env_path.read_text().splitlines(): if line.startswith(f"{key}="): return line.split("=", 1)[1].strip() return "" def _snapshot_to_split(snapshot_str: str) -> str: return "snapshot_" + snapshot_str.replace("-", "_") def _parse_paper_row(paper: dict) -> dict: for key in ("detailed_analysis", "detailed_analysis_zh"): v = paper.get(key, "{}") if isinstance(v, str): paper[key] = json.loads(v) if v else {} for key in ("topics", "topics_zh", "keywords", "keywords_zh"): v = paper.get(key, "[]") if isinstance(v, str): paper[key] = json.loads(v) if v else [] if not isinstance(paper.get("authors"), list): try: paper["authors"] = list(paper["authors"]) except Exception: paper["authors"] = [] return paper def _list_repo_files(repo: str) -> list[str]: from huggingface_hub import HfApi token = _get_env("HF_TOKEN") if not token: return [] try: api = HfApi(token=token) return list(api.list_repo_files(repo, repo_type="dataset")) except Exception: return [] def _load_all_papers(files: list[str]) -> list[dict]: """Download all parquet files and return papers with _date and _month.""" import pandas as pd from huggingface_hub import hf_hub_download token = _get_env("HF_TOKEN") parquet_files = [f for f in files if f.endswith(".parquet")] seen_ids: set[str] = set() papers: list[dict] = [] for i, pf in enumerate(parquet_files): fname = pf.split("/")[-1] date_part = fname.split("-00")[0] date_str = date_part.replace("date_", "").replace("_", "-") try: local_path = hf_hub_download( HF_DATASET_REPO, pf, repo_type="dataset", token=token, ) df = pd.read_parquet(local_path) for _, row in 
df.iterrows(): paper = row.to_dict() pid = paper.get("paper_id", "") if pid and pid not in seen_ids: seen_ids.add(pid) paper["_date"] = date_str paper["_month"] = date_str[:7] papers.append(_parse_paper_row(paper)) except Exception: continue if sys.stdout.isatty() and (i + 1) % 20 == 0: sys.stdout.write(f"\r {_DIM}Loading papers... {i+1}/{len(parquet_files)} files, {len(papers)} papers{_RESET}") sys.stdout.flush() if sys.stdout.isatty(): sys.stdout.write("\r\033[K") sys.stdout.flush() return papers # --------------------------------------------------------------------------- # Lifecycle computation # --------------------------------------------------------------------------- def _get_paper_topics(paper: dict, lang: str) -> list[str]: if lang == "zh": return paper.get("topics_zh", []) or paper.get("topics", []) return paper.get("topics", []) def compute_lifecycle(papers: list[str], lang: str = "en") -> tuple[dict, list[str], dict, dict]: """Compute lifecycle metrics for all topics from papers. Returns (lifecycle_dict, sorted_months, topics_by_month, total_by_month). 
""" topics_by_month: dict[str, Counter] = defaultdict(Counter) all_topics: Counter = Counter() for p in papers: month = p.get("_month", "") if not month: continue topics = _get_paper_topics(p, lang) topics_by_month[month].update(topics) all_topics.update(topics) sorted_months = sorted(topics_by_month.keys()) if len(sorted_months) < 2: return {}, sorted_months, {}, {} total_by_month = {m: sum(topics_by_month[m].values()) for m in sorted_months} n_months = len(sorted_months) min_papers = max(3, n_months) candidates = [t for t, c in all_topics.items() if c >= min_papers] lifecycle: dict = {} for topic in candidates: proportions = np.array([ topics_by_month[m].get(topic, 0) / total_by_month[m] if total_by_month[m] > 0 else 0 for m in sorted_months ]) counts = np.array([topics_by_month[m].get(topic, 0) for m in sorted_months]) nonzero = np.where(proportions > 0)[0] if len(nonzero) < 2: continue first_idx = int(nonzero[0]) peak_idx = int(np.argmax(proportions)) peak_val = float(proportions[peak_idx]) current_avg = float(np.mean(proportions[-min(3, n_months):])) window = min(6, n_months) recent = proportions[-window:] slope = float(linregress(np.arange(len(recent)), recent).slope) if len(recent) >= 3 else 0.0 decline_ratio = current_avg / peak_val if peak_val > 0 else 0 months_since_peak = n_months - 1 - peak_idx months_active = n_months - first_idx recent_window = min(8, len(counts)) recent_fraction = float(counts[-recent_window:].sum() / max(counts.sum(), 1)) # Phase classification (same thresholds as reference analysis script) dr, sl, ma, msp = decline_ratio, slope, months_active, months_since_peak tc = int(counts.sum()) rf = recent_fraction if ma <= 8 or (rf > 0.60 and tc < 200): phase = "Innovation Trigger" elif (dr > 0.70 and msp <= 6) or (sl > 0.001 and dr > 0.65): phase = "Peak of Inflated Expectations" elif dr < 0.65: phase = "Slope of Enlightenment" if sl > 0.0003 else "Trough of Disillusionment" elif sl < -0.001 and dr < 0.75: phase = "Trough of 
Disillusionment" elif dr < 0.85 and sl > 0.0005 and msp > 4: phase = "Slope of Enlightenment" else: phase = "Plateau of Productivity" lifecycle[topic] = { "topic": topic, "phase": phase, "total_count": tc, "peak_val": peak_val, "peak_month": sorted_months[peak_idx], "current_avg": current_avg, "slope": slope, "decline_ratio": decline_ratio, "months_since_peak": months_since_peak, "months_active": months_active, } # Convert Counters to plain dicts for serialisation tbm = {m: dict(topics_by_month[m]) for m in sorted_months} tbm_total = dict(total_by_month) return lifecycle, sorted_months, tbm, tbm_total # --------------------------------------------------------------------------- # Push to HuggingFace # --------------------------------------------------------------------------- def push_lifecycle_to_hf(lifecycle_en: dict, lifecycle_zh: dict, sorted_months: list[str], n_papers: int, snapshot_month: str, topics_by_month_en: dict | None = None, total_by_month_en: dict | None = None, topics_by_month_zh: dict | None = None, total_by_month_zh: dict | None = None): from datasets import Dataset token = _get_env("HF_TOKEN") if not token: raise RuntimeError("HF_TOKEN not set") row = { "lifecycle_data": json.dumps(lifecycle_en, ensure_ascii=False), "lifecycle_data_zh": json.dumps(lifecycle_zh, ensure_ascii=False), "sorted_months": json.dumps(sorted_months, ensure_ascii=False), "n_papers": n_papers, "n_months": len(sorted_months), "topics_by_month": json.dumps(topics_by_month_en or {}, ensure_ascii=False), "total_by_month": json.dumps(total_by_month_en or {}, ensure_ascii=False), "topics_by_month_zh": json.dumps(topics_by_month_zh or {}, ensure_ascii=False), "total_by_month_zh": json.dumps(total_by_month_zh or {}, ensure_ascii=False), } ds = Dataset.from_list([row]) split_name = _snapshot_to_split(snapshot_month) ds.push_to_hub(HF_LIFECYCLE_REPO, split=split_name, token=token) # --------------------------------------------------------------------------- # Run one snapshot # 
def run_snapshot(snapshot_month: str, all_papers: list[dict], existing_splits: set[str],
                 no_push: bool = False, force: bool = False):
    """Compute one bimonthly snapshot and (unless no_push) push it to HF.

    Skips early when the split already exists (unless force) or when no
    papers fall on or before snapshot_month.
    """
    split_id = _snapshot_to_split(snapshot_month)
    if not force and split_id in existing_splits:
        print(f" {_GRAY}⊘ {snapshot_month} — already on HF, skipping{_RESET}")
        return

    # All papers observed up to and including the snapshot month.
    subset = [p for p in all_papers if p.get("_month", "") <= snapshot_month]
    if not subset:
        print(f" {_YELLOW}⊘ {snapshot_month} — no papers, skipping{_RESET}")
        return

    print(f" {_CYAN}⟳ {snapshot_month}{_RESET} — {len(subset)} papers...", end="", flush=True)
    lc_en, months_en, tbm_en, tbt_en = compute_lifecycle(subset, lang="en")
    lc_zh, _, tbm_zh, tbt_zh = compute_lifecycle(subset, lang="zh")
    print(f" {len(lc_en)} topics (en), {len(lc_zh)} topics (zh)", end="", flush=True)

    if no_push:
        print(f" {_GRAY}[--no-push]{_RESET}")
        return
    try:
        push_lifecycle_to_hf(
            lc_en, lc_zh, months_en, len(subset), snapshot_month,
            topics_by_month_en=tbm_en, total_by_month_en=tbt_en,
            topics_by_month_zh=tbm_zh, total_by_month_zh=tbt_zh,
        )
    except Exception as e:
        print(f" {_YELLOW}✗ push failed: {e}{_RESET}")
    else:
        print(f" {_GREEN}✓ pushed{_RESET}")


# ---------------------------------------------------------------------------
# Main
# ---------------------------------------------------------------------------
def main():
    """CLI entry point: load papers, pick snapshot months, run each one."""
    parser = argparse.ArgumentParser(
        description="Compute bimonthly topic lifecycle snapshots and push to HuggingFace"
    )
    parser.add_argument("--snapshot", type=str, default=None,
                        help="Snapshot month (YYYY-MM, even month). Default: latest bimonthly.")
    parser.add_argument("--all", action="store_true",
                        help="Compute all missing bimonthly snapshots")
    parser.add_argument("--no-push", action="store_true",
                        help="Skip pushing results to HuggingFace")
    parser.add_argument("--force", action="store_true",
                        help="Re-compute and overwrite existing snapshots")
    args = parser.parse_args()

    print(f"\n {_BOLD}📊 Lifecycle Snapshot Retriever{_RESET}\n")

    # Step 1: List dataset files
    print(f" {_DIM}Listing dataset files...{_RESET}", end="", flush=True)
    all_files = _list_repo_files(HF_DATASET_REPO)
    if not all_files:
        print(f"\n {_YELLOW}Error: could not list files — check HF_TOKEN{_RESET}")
        return
    print(f" {len(all_files)} files")

    # Step 2: Load all papers
    print(f" {_DIM}Loading all papers...{_RESET}", end="", flush=True)
    t0 = time.time()
    all_papers = _load_all_papers(all_files)
    elapsed = time.time() - t0
    print(f" {len(all_papers)} papers in {elapsed:.1f}s")
    if not all_papers:
        print(f" {_YELLOW}No papers found{_RESET}")
        return

    # Step 3: Determine data range
    all_months = sorted(set(p["_month"] for p in all_papers if p.get("_month")))
    print(f" {_DIM}Data range: {all_months[0]} → {all_months[-1]} ({len(all_months)} months){_RESET}")

    # Splits already present on the lifecycle repo (parsed from file names,
    # e.g. "snapshot_2025_06-00000-of-00001.parquet" -> "snapshot_2025_06").
    lifecycle_files = _list_repo_files(HF_LIFECYCLE_REPO)
    existing_splits: set[str] = {
        part
        for f in lifecycle_files
        for part in f.split("/")[-1].replace(".parquet", "").replace(".arrow", "").split("-")
        if part.startswith("snapshot_")
    }

    # Step 4: Determine snapshots to compute
    if args.all:
        snapshots = [m for m in all_months if int(m[5:7]) in SNAPSHOT_MONTHS]
    elif args.snapshot:
        snapshots = [args.snapshot]
    else:
        # Default: the most recent fully-completed even month.
        now = datetime.now(timezone.utc)
        if now.month > 1:
            snap_year, last_completed = now.year, now.month - 1
        else:
            snap_year, last_completed = now.year - 1, 12
        snap_month = last_completed - (last_completed % 2)  # round down to even
        if snap_month <= 0:
            snap_year, snap_month = snap_year - 1, 12
        snapshots = [f"{snap_year}-{snap_month:02d}"]

    print(f" {_DIM}Snapshots to process: {len(snapshots)}{_RESET}\n")
    for snapshot in snapshots:
        run_snapshot(snapshot, all_papers, existing_splits,
                     no_push=args.no_push, force=args.force)
    print(f"\n {_GREEN}{_BOLD}✓{_RESET} Done\n")


if __name__ == "__main__":
    main()