import json import glob import logging from datetime import datetime, timedelta, timezone from pathlib import Path import streamlit as st logging.basicConfig( level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s", datefmt="%H:%M:%S", ) log = logging.getLogger("paper_reader") # --------------------------------------------------------------------------- # Page config # --------------------------------------------------------------------------- st.set_page_config( page_title="Paper Espresso", page_icon="☕️", layout="wide", initial_sidebar_state="collapsed", ) # --------------------------------------------------------------------------- # Custom CSS – HuggingFace-inspired design # --------------------------------------------------------------------------- st.markdown( """ """, unsafe_allow_html=True, ) # --------------------------------------------------------------------------- # Data helpers # --------------------------------------------------------------------------- DATA_DIR = Path(__file__).resolve().parent.parent / "data" HF_DATASET_REPO = "Elfsong/hf_paper_summary" HF_TRENDING_REPO = "Elfsong/hf_paper_daily_trending" HF_MONTHLY_TRENDING_REPO = "Elfsong/hf_paper_monthly_trending" HF_LIFECYCLE_REPO = "Elfsong/hf_paper_lifecycle" def _get_hf_token() -> str | None: import os token = os.getenv("HF_TOKEN", "") if token: return token env_path = Path(__file__).resolve().parent.parent / ".env" if env_path.exists(): for line in env_path.read_text().splitlines(): if line.startswith("HF_TOKEN="): return line.split("=", 1)[1].strip() return None def _date_to_split(date_str: str) -> str: """Convert '2026-03-11' to 'date_2026_03_11' for valid split name.""" return "date_" + date_str.replace("-", "_") def _split_to_date(split_name: str) -> str: """Convert 'date_2026_03_11' back to '2026-03-11'.""" return split_name.replace("date_", "", 1).replace("_", "-") def _month_to_split(month_str: str) -> str: """Convert '2026-03' to 'month_2026_03'.""" return "month_" + month_str.replace("-", "_") def _last_day_of_month(year: int, month: int): """Return the last date of the given month.""" if month == 12: return datetime(year + 1, 1, 1, tzinfo=timezone.utc).date() - timedelta(days=1) return datetime(year, month + 1, 1, tzinfo=timezone.utc).date() - timedelta(days=1) def _list_repo_files(repo: str) -> list[str]: """List all files in a HF dataset repo (uncached, usable from any thread).""" from huggingface_hub import HfApi log.info("[_list_repo_files] listing files for %s", repo) token = _get_hf_token() api = HfApi(token=token) try: result = list(api.list_repo_files(repo, repo_type="dataset")) log.info("[_list_repo_files] %s → %d files", repo, len(result)) return result except Exception as e: log.error("[_list_repo_files] %s failed: %s", repo, e) return [] @st.cache_data(ttl=300, show_spinner=False) def _list_repo_files_cached(repo: str) -> list[str]: """List all files in a HF dataset repo (Streamlit cached).""" return _list_repo_files(repo) def _extract_splits(files: list[str], prefix: str = "date_") -> list[str]: """Extract sorted split names from a list of repo file paths.""" splits = set() for f in files: name = f.split("/")[-1] for part in name.replace(".parquet", "").replace(".arrow", "").split("-"): if part.startswith(prefix): splits.add(part) break return sorted(splits, reverse=True) @st.cache_data(ttl=300, show_spinner=False) def _list_dataset_splits() -> list[str]: """List available date splits from the HF dataset repo without loading data.""" return _extract_splits(_list_repo_files_cached(HF_DATASET_REPO)) def _download_split_rows(repo: str, split_name: str) -> list[dict]: """Download only the parquet files for ONE split, return rows as list[dict]. Uses hf_hub_download (per-file) instead of load_dataset (all-splits).""" import pandas as pd from huggingface_hub import hf_hub_download log.info("[_download_split_rows] repo=%s split=%s", repo, split_name) token = _get_hf_token() files = _list_repo_files_cached(repo) split_files = [f for f in files if split_name in f and f.endswith(".parquet")] log.debug("[_download_split_rows] matched %d parquet files: %s", len(split_files), split_files) if not split_files: return [] dfs = [] for f in split_files: try: log.info("[_download_split_rows] downloading %s", f) local_path = hf_hub_download( repo, f, repo_type="dataset", token=token ) log.info("[_download_split_rows] reading parquet %s", local_path) dfs.append(pd.read_parquet(local_path)) except Exception as e: log.error("[_download_split_rows] failed on %s: %s", f, e) continue if not dfs: return [] result = pd.concat(dfs, ignore_index=True).to_dict("records") log.info("[_download_split_rows] returning %d rows", len(result)) return result def _parse_paper_row(paper: dict) -> dict: """Decode JSON string fields in a paper row.""" for key in ("detailed_analysis", "detailed_analysis_zh"): v = paper.get(key, "{}") if isinstance(v, str): paper[key] = json.loads(v) if v else {} for key in ("topics", "topics_zh", "keywords", "keywords_zh"): v = paper.get(key, "[]") if isinstance(v, str): paper[key] = json.loads(v) if v else [] # pandas may convert list columns to numpy arrays if not isinstance(paper.get("authors"), list): try: paper["authors"] = list(paper["authors"]) except Exception: paper["authors"] = [] return paper @st.cache_data(ttl=300, show_spinner=False) def pull_from_hf_dataset(target_date: str | None = None) -> dict[str, list[dict]]: """Load a single date split from HF dataset. Returns {date_str: papers_list}.""" log.info("[pull_from_hf_dataset] target_date=%s", target_date) splits = _list_dataset_splits() if not splits: return {} if target_date: target_split = _date_to_split(target_date) if target_split not in splits: return {} split_to_load = target_split else: split_to_load = splits[0] date_str = _split_to_date(split_to_load) rows = _download_split_rows(HF_DATASET_REPO, split_to_load) if not rows: return {} papers = [_parse_paper_row(r) for r in rows] return {date_str: papers} @st.cache_data(ttl=300, show_spinner=False) def list_available_dates() -> list[str]: """Return available dates (YYYY-MM-DD) from HF dataset and local files, sorted descending.""" log.info("[list_available_dates] START") dates = set() # From HF dataset splits for split in _list_dataset_splits(): dates.add(_split_to_date(split)) # From local JSON files for date_str in find_json_files(): dates.add(date_str) result = sorted(dates, reverse=True) log.info("[list_available_dates] found %d dates", len(result)) return result def find_json_files() -> dict[str, Path]: """Return {date_str: path} for all summarized JSON files.""" files: dict[str, Path] = {} for fp in glob.glob(str(DATA_DIR / "hf_papers_*_summarized.json")): p = Path(fp) for part in p.stem.split("_"): if len(part) == 10 and part[4] == "-" and part[7] == "-": files[part] = p break return dict(sorted(files.items(), reverse=True)) def load_papers(source) -> list[dict]: if isinstance(source, (str, Path)): with open(source, "r", encoding="utf-8") as f: return json.load(f) return json.loads(source.read()) HF_THUMB = "https://cdn-thumbnails.huggingface.co/social-thumbnails/papers/{pid}.png" @st.cache_data(ttl=600, show_spinner=False) def load_papers_for_dates(dates: tuple[str, ...]) -> list[dict]: """Load and deduplicate papers across multiple dates (for monthly).""" all_papers: list[dict] = [] seen_ids: set[str] = set() for date_str in dates: day_papers: list[dict] = [] hf_data = pull_from_hf_dataset(target_date=date_str) if hf_data and date_str in hf_data: day_papers = hf_data[date_str] if not day_papers: json_files = find_json_files() if date_str in json_files: day_papers = load_papers(json_files[date_str]) for p in day_papers: pid = p.get("paper_id", "") if pid and pid not in seen_ids: seen_ids.add(pid) all_papers.append(p) return all_papers # --------------------------------------------------------------------------- # Trending summary # --------------------------------------------------------------------------- def _deserialize_trending_row(row: dict) -> dict: """Deserialize JSON string fields in a trending row.""" for key in ("top_topics", "top_topics_zh", "keywords", "keywords_zh"): v = row.get(key, "[]") if isinstance(v, str): row[key] = json.loads(v) if v else [] for key in ("topic_mapping", "topic_mapping_zh"): v = row.get(key) if isinstance(v, str): row[key] = json.loads(v) if v else {} return row @st.cache_data(ttl=300, show_spinner=False) def pull_trending_from_hf(target_date: str | None = None) -> dict | None: """Load trending summary from HF dataset. Returns dict or None.""" log.info("[pull_trending_from_hf] target_date=%s", target_date) files = _list_repo_files_cached(HF_TRENDING_REPO) splits = _extract_splits(files) if not splits: return None if target_date: target_split = _date_to_split(target_date) if target_split not in splits: return None split_to_load = target_split else: split_to_load = splits[0] rows = _download_split_rows(HF_TRENDING_REPO, split_to_load) if not rows: return None return _deserialize_trending_row(rows[0]) def get_cached_trending(date_str: str) -> tuple[dict | None, str]: """Try to load trending from HF cache only (no generation). Returns (trending_dict, date_range_str).""" log.info("[get_cached_trending] date_str=%s", date_str) trending = pull_trending_from_hf(target_date=date_str) if trending: return trending, trending.get("date_range", "") return None, "" # --------------------------------------------------------------------------- # Monthly trending (read-only from HF, generated by monthly_retrieve.py) # --------------------------------------------------------------------------- @st.cache_data(ttl=300, show_spinner=False) def pull_monthly_trending_from_hf(month_str: str) -> dict | None: """Load monthly trending summary from HF dataset.""" log.info("[pull_monthly_trending] month_str=%s", month_str) files = _list_repo_files_cached(HF_MONTHLY_TRENDING_REPO) splits = _extract_splits(files, prefix="month_") if not splits: return None target_split = _month_to_split(month_str) if target_split not in splits: return None rows = _download_split_rows(HF_MONTHLY_TRENDING_REPO, target_split) if not rows: return None return _deserialize_trending_row(rows[0]) # --------------------------------------------------------------------------- # Topic lifecycle (read-only from HF, generated by lifecycle_retrieve.py) # --------------------------------------------------------------------------- _PHASES_ORDER = [ "Innovation Trigger", "Peak of Inflated Expectations", "Trough of Disillusionment", "Slope of Enlightenment", "Plateau of Productivity", ] @st.cache_data(ttl=300, show_spinner=False) def pull_lifecycle_from_hf(snapshot_str: str) -> dict | None: """Load a pre-computed lifecycle snapshot from HF.""" log.info("[pull_lifecycle] snapshot_str=%s", snapshot_str) files = _list_repo_files_cached(HF_LIFECYCLE_REPO) splits = _extract_splits(files, prefix="snapshot_") target_split = "snapshot_" + snapshot_str.replace("-", "_") if target_split not in splits: return None rows = _download_split_rows(HF_LIFECYCLE_REPO, target_split) if not rows: return None row = rows[0] return { "lifecycle_data": json.loads(row.get("lifecycle_data", "{}")), "lifecycle_data_zh": json.loads(row.get("lifecycle_data_zh", "{}")), "sorted_months": json.loads(row.get("sorted_months", "[]")), "n_papers": row.get("n_papers", 0), "n_months": row.get("n_months", 0), "topics_by_month": json.loads(row.get("topics_by_month", "{}")), "total_by_month": json.loads(row.get("total_by_month", "{}")), "topics_by_month_zh": json.loads(row.get("topics_by_month_zh", "{}")), "total_by_month_zh": json.loads(row.get("total_by_month_zh", "{}")), } def _render_hype_cycle(lifecycle_data: dict, lang: bool): """Render a Gartner-style hype cycle figure with matplotlib.""" import matplotlib matplotlib.use("Agg") import matplotlib.pyplot as plt import matplotlib.patheffects as pe import matplotlib.font_manager as fm import numpy as np from scipy.interpolate import CubicSpline from collections import defaultdict # Use a CJK-capable font when rendering Chinese if lang: cjk_candidates = [ "Noto Sans CJK JP", "Noto Sans CJK SC", "Noto Sans CJK TC", "Noto Sans CJK", "PingFang SC", "Heiti SC", "Heiti TC", "Microsoft YaHei", "Noto Sans SC", "WenQuanYi Micro Hei", "Source Han Sans SC", "SimHei", "AR PL UMing CN", ] available = {f.name for f in fm.fontManager.ttflist} cjk_font = next((f for f in cjk_candidates if f in available), None) if cjk_font: plt.rcParams["font.family"] = "sans-serif" plt.rcParams["font.sans-serif"] = [cjk_font] plt.rcParams["axes.unicode_minus"] = False phase_groups: dict[str, list] = defaultdict(list) for lc in lifecycle_data.values(): phase_groups[lc["phase"]].append(lc) sort_keys = { "Innovation Trigger": lambda x: -x["current_avg"], "Peak of Inflated Expectations": lambda x: -x["total_count"], "Trough of Disillusionment": lambda x: -x["total_count"], "Slope of Enlightenment": lambda x: -x["total_count"], "Plateau of Productivity": lambda x: -x["total_count"], } for phase, key_fn in sort_keys.items(): if phase in phase_groups: phase_groups[phase].sort(key=key_fn) max_per_phase = { "Innovation Trigger": 3, "Peak of Inflated Expectations": 4, "Trough of Disillusionment": 4, "Slope of Enlightenment": 3, "Plateau of Productivity": 2, } selected = [] for phase in _PHASES_ORDER: selected.extend(phase_groups.get(phase, [])[:max_per_phase[phase]]) if not selected: return None # Hype cycle curve (cubic spline through control points) ctrl_x = np.array([0, 0.5, 1, 1.5, 2, 2.3, 2.8, 3.2, 3.8, 4.5, 5.5, 6.5, 7.5, 8.5, 10]) ctrl_y = np.array([.02, .08, .22, .58, .98, .78, .38, .16, .10, .15, .26, .36, .42, .45, .47]) cs = CubicSpline(ctrl_x, ctrl_y) curve_x = np.linspace(0, 10, 500) curve_y = cs(curve_x) phase_ranges = { "Innovation Trigger": (0.3, 1.3), "Peak of Inflated Expectations": (1.4, 2.6), "Trough of Disillusionment": (2.8, 4.3), "Slope of Enlightenment": (4.8, 7.2), "Plateau of Productivity": (7.5, 9.5), } phase_colors = { "Innovation Trigger": "#16a34a", "Peak of Inflated Expectations": "#dc2626", "Trough of Disillusionment": "#2563eb", "Slope of Enlightenment": "#d97706", "Plateau of Productivity": "#6b7280", } phase_labels = { "Innovation Trigger": "技术\n萌芽期" if lang else "Innovation\nTrigger", "Peak of Inflated Expectations": "期望\n膨胀期" if lang else "Peak of Inflated\nExpectations", "Trough of Disillusionment": "泡沫\n破裂期" if lang else "Trough of\nDisillusionment", "Slope of Enlightenment": "稳步\n爬升期" if lang else "Slope of\nEnlightenment", "Plateau of Productivity": "生产\n成熟期" if lang else "Plateau of\nProductivity", } offset_patterns = { "Innovation Trigger": [(0.10, "bottom"), (-0.08, "top"), (0.15, "bottom")], "Peak of Inflated Expectations": [(0.16, "bottom"), (-0.13, "top"), (0.10, "bottom"), (-0.09, "top")], "Trough of Disillusionment": [(0.10, "bottom"), (-0.07, "top"), (0.14, "bottom"), (-0.10, "top")], "Slope of Enlightenment": [(0.10, "bottom"), (-0.08, "top"), (0.14, "bottom")], "Plateau of Productivity": [(0.10, "bottom"), (-0.08, "top")], } fig, ax = plt.subplots(figsize=(10, 5)) ax.plot(curve_x, curve_y, color="#d1d5db", linewidth=3.5, zorder=1, solid_capstyle="round") ax.fill_between(curve_x, 0, curve_y, alpha=0.03, color="#9ca3af") for bx in [1.35, 2.7, 4.55, 7.35]: ax.axvline(bx, color="#e5e7eb", linewidth=0.6, linestyle="--", zorder=0) for phase in _PHASES_ORDER: x_lo, x_hi = phase_ranges[phase] pts = [lc for lc in selected if lc["phase"] == phase] if not pts: continue x_positions = np.linspace(x_lo, x_hi, len(pts) + 2)[1:-1] color = phase_colors[phase] offsets = offset_patterns[phase] for i, lc in enumerate(pts): xp = x_positions[i] yp = float(cs(xp)) dot_size = max(4, min(10, lc["total_count"] / 30)) ax.plot(xp, yp, "o", color=color, markersize=dot_size, zorder=3, markeredgecolor="white", markeredgewidth=0.6) offset_y, va = offsets[i % len(offsets)] ax.annotate( lc["topic"], xy=(xp, yp), xytext=(xp, yp + offset_y), fontsize=7, color=color, fontweight="bold", ha="center", va=va, arrowprops=dict(arrowstyle="-", color=color, alpha=0.3, lw=0.5), path_effects=[pe.withStroke(linewidth=2.5, foreground="white")], ) for phase in _PHASES_ORDER: x_lo, x_hi = phase_ranges[phase] ax.text((x_lo + x_hi) / 2, -0.10, phase_labels[phase], fontsize=7, ha="center", va="top", color=phase_colors[phase], fontweight="bold", style="italic") ax.set_xlim(-0.3, 10.3) ax.set_ylim(-0.22, 1.20) ax.set_ylabel("关注度" if lang else "Visibility", fontsize=10) ax.annotate( "成熟度" if lang else "Maturity", xy=(10.2, -0.03), xytext=(8.5, -0.03), fontsize=9, arrowprops=dict(arrowstyle="->", color="#6b7280", lw=1.2), color="#6b7280", va="center", ) ax.spines["top"].set_visible(False) ax.spines["right"].set_visible(False) ax.spines["bottom"].set_visible(False) ax.set_xticks([]) ax.set_yticks([]) fig.tight_layout() return fig # --------------------------------------------------------------------------- # Summary dialog # --------------------------------------------------------------------------- @st.dialog("📄 Summary", width="large") def show_summary(paper: dict): st.markdown(f"### {paper.get('title', '')}") # Authors authors = paper.get("authors", []) if authors: st.caption(", ".join(authors)) # Resource links links_html = f"""""" st.markdown(links_html, unsafe_allow_html=True) # Use global language toggle lang = st.session_state.get("lang_toggle", False) # Topics & Keywords if lang: topics = paper.get("topics_zh", []) or paper.get("topics", []) kws = paper.get("keywords_zh", []) or paper.get("keywords", []) else: topics = paper.get("topics", []) kws = paper.get("keywords", []) if topics or kws: lines = [] if topics: topic_spans = "".join( f'{t}' for t in topics ) lines.append(f'
{topic_spans}
') if kws: kw_spans = "".join( f'{k}' for k in kws ) lines.append(f'
{kw_spans}
') st.markdown( f'
{"".join(lines)}
', unsafe_allow_html=True, ) # TL;DR if lang: concise = paper.get("concise_summary_zh", "") or paper.get( "concise_summary", "" ) else: concise = paper.get("concise_summary", "") if concise: st.markdown("#### 📝 TL;DR") st.markdown(concise) # Detailed Analysis if lang: analysis = paper.get("detailed_analysis_zh", {}) or paper.get( "detailed_analysis", {} ) else: analysis = paper.get("detailed_analysis", {}) if analysis: st.divider() st.markdown("#### 🔬 Detailed Analysis" if not lang else "#### 🔬 详细分析") st.markdown(analysis.get("summary", "")) st.divider() col_a, col_b = st.columns(2) with col_a: pros = analysis.get("pros", []) pros_html = "".join(f'
{p}
' for p in pros) label = "✓ Strengths" if not lang else "✓ 优势" st.markdown( f'
{label}
{pros_html}
', unsafe_allow_html=True, ) with col_b: cons = analysis.get("cons", []) cons_html = "".join(f'
{c}
' for c in cons) label = "✗ Limitations" if not lang else "✗ 不足" st.markdown( f'
{label}
{cons_html}
', unsafe_allow_html=True, ) # --------------------------------------------------------------------------- # Render paper card # --------------------------------------------------------------------------- def render_card(paper: dict, rank: int, tab_key: str = ""): pid = paper.get("paper_id", "") title = paper.get("title", "Untitled") authors = paper.get("authors", []) thumb_url = HF_THUMB.format(pid=pid) if authors: authors_str = ", ".join(authors) else: authors_str = "Unknown authors" with st.container(border=True): # Thumbnail st.image(thumb_url, width="stretch") # Title as clickable button if st.button(f"**{title}**", key=f"card-{tab_key}-{rank}", use_container_width=True): show_summary(paper) # Authors lang = st.session_state.get("lang_toggle", False) if lang: topics = paper.get("topics_zh", []) or paper.get("topics", []) else: topics = paper.get("topics", []) topic_spans = "".join( f'{t}' for t in topics ) html = f"""
{authors_str}
{topic_spans}
""" st.markdown(html, unsafe_allow_html=True) # --------------------------------------------------------------------------- # Shared rendering helpers # --------------------------------------------------------------------------- def _render_trending_content(trending: dict, trending_date_range: str, lang: bool, placeholder): """Render a trending summary dict into the given placeholder.""" if lang: summary_text = trending.get("trending_summary_zh", "") or trending.get("trending_summary", "") topics = trending.get("top_topics_zh", []) or trending.get("top_topics", []) keywords = trending.get("keywords_zh", []) or trending.get("keywords", []) else: summary_text = trending.get("trending_summary", "") topics = trending.get("top_topics", []) keywords = trending.get("keywords", []) topics_html = " ".join( f'{t}' for t in topics ) keywords_html = " ".join( f'{k}' for k in keywords ) date_range_label = ( f'({trending_date_range})' if trending_date_range else "" ) placeholder.markdown( f"""
{"🔥 趋势" if lang else "🔥 Trending"} {date_range_label}
{summary_text}
{topics_html}
{keywords_html}
""", unsafe_allow_html=True, ) def _render_trending(date_str: str, lang: bool, placeholder): """Load and render trending summary into the given placeholder.""" _trending_cache_key = f"trending_{date_str}" trending = None trending_date_range = "" if _trending_cache_key in st.session_state: trending, trending_date_range = st.session_state[_trending_cache_key] else: with st.spinner("Loading trending summary..."): trending, trending_date_range = get_cached_trending(date_str) if trending: st.session_state[_trending_cache_key] = (trending, trending_date_range) if not trending: return _render_trending_content(trending, trending_date_range, lang, placeholder) def _get_paper_topics(paper: dict, lang: bool) -> list[str]: """Get topic labels for a paper, respecting language preference.""" if lang: return paper.get("topics_zh", []) or paper.get("topics", []) return paper.get("topics", []) def _render_papers_section( papers: list[dict], lang: bool, date_str: str, tab_key: str, clustered_topics: list[str] | None = None, topic_mapping: dict[str, list[str]] | None = None, trending_data: tuple[dict, str] | None = None, ): """Render trending, topic filters, and paper grid for a list of papers.""" if not papers: st.error("No papers retrieved. Please check back later.") return papers.sort(key=lambda p: p.get("upvotes", 0), reverse=True) trending_placeholder = st.empty() if clustered_topics: all_topics = clustered_topics else: all_topics = [] seen_topics: set[str] = set() for p in papers: for t in _get_paper_topics(p, lang): if t not in seen_topics: seen_topics.add(t) all_topics.append(t) selected_topics: list[str] = [] if all_topics: selected_topics = st.pills( "🏷️ Filter by topic" if not lang else "🏷️ 按主题筛选", options=all_topics, selection_mode="multi", default=None, key=f"topic_filter_{tab_key}", ) if selected_topics: if topic_mapping: match_set: set[str] = set() for sel in selected_topics: match_set.update(topic_mapping.get(sel, [sel])) else: match_set = set(selected_topics) display_papers = [ p for p in papers if any(t in match_set for t in _get_paper_topics(p, lang)) ] else: display_papers = papers NUM_COLS = 3 for row_start in range(0, len(display_papers), NUM_COLS): cols = st.columns(NUM_COLS, gap="medium") for col_idx, col in enumerate(cols): paper_idx = row_start + col_idx if paper_idx >= len(display_papers): break with col: render_card(display_papers[paper_idx], rank=paper_idx + 1, tab_key=tab_key) if trending_data: _render_trending_content(trending_data[0], trending_data[1], lang, trending_placeholder) else: _render_trending(date_str, lang, trending_placeholder) # --------------------------------------------------------------------------- # Main content # --------------------------------------------------------------------------- yesterday_str = (datetime.now(timezone.utc) - timedelta(days=1)).strftime("%Y-%m-%d") # --- Header --- today = datetime.now(timezone.utc).date() hdr = st.columns([1, 2, 1.2, 0.8, 4.5], vertical_alignment="center") with hdr[0]: st.markdown("**☕️ Paper Espresso**") with hdr[2]: active_tab = st.selectbox( "Mode", ["Daily", "Monthly", "Lifecycle"], label_visibility="collapsed", key="mode_select", ) with hdr[3]: is_chinese = st.toggle("中文", key="lang_toggle") lang = is_chinese # ---- Daily ---- if active_tab == "Daily": with hdr[1]: available_dates = sorted( [_split_to_date(s) for s in _list_dataset_splits()], reverse=True, ) selected_date = st.date_input( "Select date", value=( datetime.strptime(available_dates[0], "%Y-%m-%d").date() if available_dates else (today - timedelta(days=1)).date() ), format="YYYY-MM-DD", label_visibility="collapsed", key="daily_date", ) selected_date_str = selected_date.strftime("%Y-%m-%d") papers: list[dict] = [] _papers_cache_key = f"papers_daily_{selected_date_str}" if _papers_cache_key not in st.session_state: with st.spinner("Loading papers..."): hf_data = pull_from_hf_dataset(target_date=selected_date_str) if hf_data: papers = hf_data[selected_date_str] if not papers: json_files = find_json_files() if selected_date_str in json_files: papers = load_papers(json_files[selected_date_str]) st.session_state[_papers_cache_key] = papers else: papers = st.session_state[_papers_cache_key] if papers: st.toast(f"**{len(papers)}** papers found for {selected_date_str}", icon="📰") else: st.toast(f"No papers found for {selected_date_str}", icon="⚠️") _render_papers_section(papers, lang, selected_date_str, "daily") # ---- Monthly tab ---- elif active_tab == "Monthly": _monthly_splits_key = "monthly_available_splits" if _monthly_splits_key not in st.session_state: trending_files = _list_repo_files_cached(HF_MONTHLY_TRENDING_REPO) st.session_state[_monthly_splits_key] = sorted( [s.replace("month_", "").replace("_", "-") for s in _extract_splits(trending_files, prefix="month_")], reverse=True, ) month_options = st.session_state[_monthly_splits_key] if not month_options: st.info("No monthly data available yet. Run `uv run python src/monthly_retrieve.py` to generate.") else: with hdr[1]: selected_month = st.selectbox( "Select month", options=month_options, label_visibility="collapsed", key="monthly_select", ) year, month_num = int(selected_month[:4]), int(selected_month[5:7]) first_day = datetime(year, month_num, 1, tzinfo=timezone.utc).date() last_day = _last_day_of_month(year, month_num) month_dates = tuple( (first_day + timedelta(days=i)).strftime("%Y-%m-%d") for i in range((last_day - first_day).days + 1) ) # --- Load trending from HF (pre-generated by monthly_retrieve.py) --- _mt_cache_key = f"monthly_trending_{selected_month}" monthly_trending = None if _mt_cache_key in st.session_state: monthly_trending = st.session_state[_mt_cache_key] else: monthly_trending = pull_monthly_trending_from_hf(selected_month) if monthly_trending: st.session_state[_mt_cache_key] = monthly_trending # --- Load papers --- _monthly_cache_key = f"papers_monthly_{selected_month}" if _monthly_cache_key not in st.session_state: with st.spinner(f"Loading papers for {selected_month}..."): st.session_state[_monthly_cache_key] = load_papers_for_dates(month_dates) monthly_papers = st.session_state[_monthly_cache_key] if not monthly_papers: st.warning(f"No papers found for {selected_month}") else: # --- Statistics + histogram --- from collections import Counter total_papers = len(monthly_papers) st.metric("Papers", f"{total_papers:,}") date_counts = Counter() for p in monthly_papers: d = p.get("_date", "") or p.get("published_at", "")[:10] if d: date_counts[d] += 1 if date_counts: import pandas as pd import altair as alt all_days = [ (first_day + timedelta(days=i)).strftime("%Y-%m-%d") for i in range((last_day - first_day).days + 1) ] weekdays = ["Mon", "Tue", "Wed", "Thu", "Fri", "Sat", "Sun"] labels = [ f"{d[-2:]} ({weekdays[datetime.strptime(d, '%Y-%m-%d').weekday()]})" for d in all_days ] df = pd.DataFrame({ "date": all_days, "label": labels, "papers": [date_counts.get(d, 0) for d in all_days], }) weekday_names = [weekdays[datetime.strptime(d, "%Y-%m-%d").weekday()] for d in all_days] df["weekday"] = weekday_names weekday_colors = alt.Scale( domain=weekdays, range=["#2563eb", "#7c3aed", "#0891b2", "#059669", "#d97706", "#e11d48", "#dc2626"], ) chart = alt.Chart(df).mark_bar(cornerRadiusTopLeft=3, cornerRadiusTopRight=3).encode( x=alt.X("label:N", sort=None, axis=alt.Axis(title=None, labelAngle=-45, labelFontSize=9)), y=alt.Y("papers:Q", axis=alt.Axis(title=None, labels=False, ticks=False)), color=alt.Color("weekday:N", scale=weekday_colors, legend=None), tooltip=["date:N", "papers:Q"], ).properties(height=180).configure_bar( discreteBandSize=12, ) st.altair_chart(chart, use_container_width=True) # --- Trending insights --- topics = [] topic_mapping = {} if monthly_trending: if lang: summary_text = monthly_trending.get("trending_summary_zh", "") or monthly_trending.get("trending_summary", "") topics = monthly_trending.get("top_topics_zh", []) or monthly_trending.get("top_topics", []) keywords = monthly_trending.get("keywords_zh", []) or monthly_trending.get("keywords", []) topic_mapping = monthly_trending.get("topic_mapping_zh", {}) or monthly_trending.get("topic_mapping", {}) else: summary_text = monthly_trending.get("trending_summary", "") topics = monthly_trending.get("top_topics", []) keywords = monthly_trending.get("keywords", []) topic_mapping = monthly_trending.get("topic_mapping", {}) st.markdown( f"""
{"🔥 月度趋势" if lang else "🔥 Monthly Insights"}
{summary_text}
""", unsafe_allow_html=True, ) if keywords: kw_html = " ".join( f'{k}' for k in keywords ) st.markdown( f'
{kw_html}
', unsafe_allow_html=True, ) # --- Topic Co-occurrence Heatmap --- all_paper_topics = [_get_paper_topics(p, lang) for p in monthly_papers] all_paper_topics = [ts for ts in all_paper_topics if ts] if all_paper_topics: import pandas as pd import altair as alt from matplotlib.colors import Normalize, LinearSegmentedColormap topic_freq = Counter() for ts in all_paper_topics: topic_freq.update(ts) top_n = 40 top_cooc_topics = [t for t, _ in topic_freq.most_common(top_n)] top_set = set(top_cooc_topics) n = len(top_cooc_topics) topic_idx = {t: i for i, t in enumerate(top_cooc_topics)} cooc_counts = Counter() for ts in all_paper_topics: filtered = sorted(set(t for t in ts if t in top_set)) for i in range(len(filtered)): for j in range(i + 1, len(filtered)): cooc_counts[(filtered[i], filtered[j])] += 1 if cooc_counts: import numpy as np # Build symmetric co-occurrence matrix matrix = np.zeros((n, n), dtype=int) for (t1, t2), count in cooc_counts.items(): i, j = topic_idx[t1], topic_idx[t2] matrix[i, j] = count matrix[j, i] = count # Per-topic paper counts for Jaccard topic_paper_count = Counter() for ts in all_paper_topics: for t in set(ts): if t in topic_idx: topic_paper_count[t] += 1 # Jaccard matrix jaccard = np.zeros((n, n)) for i in range(n): for j in range(n): if i == j: continue intersection = matrix[i, j] union = topic_paper_count[top_cooc_topics[i]] + topic_paper_count[top_cooc_topics[j]] - intersection jaccard[i, j] = intersection / union if union > 0 else 0 # Pre-compute hex colors for the split heatmap cmap_count = LinearSegmentedColormap.from_list("gray_red", ["#d0d0d0", "#e04040"]) cmap_jaccard = LinearSegmentedColormap.from_list("gray_blue", ["#d0d0d0", "#4080e0"]) off_diag = matrix[~np.eye(n, dtype=bool)] vmax_count = int(off_diag.max()) if off_diag.size > 0 else 1 vmax_jaccard = float(jaccard[~np.eye(n, dtype=bool)].max()) or 1.0 norm_count = Normalize(vmin=0, vmax=vmax_count) norm_jaccard = Normalize(vmin=0, vmax=vmax_jaccard) def _rgba_to_hex(rgba): r, g, b = (int(c * 255) for c in rgba[:3]) return f"#{r:02x}{g:02x}{b:02x}" cooc_rows = [] for ri, t_row in enumerate(top_cooc_topics): for ci, t_col in enumerate(top_cooc_topics): if ri == ci: cooc_rows.append({"topic_a": t_col, "topic_b": t_row, "count": 0, "jaccard": 0.0, "metric": "—", "color": "#f6f8fa"}) elif ri > ci: # lower triangle: counts val = int(matrix[ri, ci]) cooc_rows.append({"topic_a": t_col, "topic_b": t_row, "count": val, "jaccard": 0.0, "metric": "count", "color": _rgba_to_hex(cmap_count(norm_count(val)))}) else: # upper triangle: jaccard jval = float(jaccard[ri, ci]) cooc_rows.append({"topic_a": t_col, "topic_b": t_row, "count": 0, "jaccard": round(jval, 4), "metric": "jaccard", "color": _rgba_to_hex(cmap_jaccard(norm_jaccard(jval)))}) cooc_df = pd.DataFrame(cooc_rows) heatmap = alt.Chart(cooc_df).mark_rect(cornerRadius=2).encode( x=alt.X("topic_a:N", sort=top_cooc_topics, title=None, axis=alt.Axis(labelAngle=-45, labelFontSize=9, labelOverlap=False)), y=alt.Y("topic_b:N", sort=top_cooc_topics, title=None, axis=alt.Axis(labelFontSize=9, labelOverlap=False)), color=alt.Color("color:N", scale=None), tooltip=[ alt.Tooltip("topic_a:N", title="Topic X"), alt.Tooltip("topic_b:N", title="Topic Y"), alt.Tooltip("count:Q", title="Co-occurrence"), alt.Tooltip("jaccard:Q", title="Jaccard", format=".3f"), ], ).properties( width=alt.Step(25), height=alt.Step(25), ) with st.expander( "🔗 " + ("主题共现图" if lang else "Topic Co-occurrence"), expanded=False, ): st.altair_chart(heatmap, use_container_width=False) # --- Topic filter --- if not topics: seen: set[str] = set() for p in monthly_papers: for t in _get_paper_topics(p, lang): if t not in seen: seen.add(t) topics.append(t) selected_topics: list[str] = [] if topics: selected_topics = st.pills( "🏷️ Filter by topic" if not lang else "🏷️ 按主题筛选", options=topics, selection_mode="multi", default=None, key="topic_filter_monthly", ) # --- Filter papers --- if selected_topics: if topic_mapping: match_set: set[str] = set() for sel in selected_topics: match_set.update(topic_mapping.get(sel, [sel])) else: match_set = set(selected_topics) display_papers = [ p for p in monthly_papers if any(t in match_set for t in _get_paper_topics(p, lang)) ] else: display_papers = monthly_papers display_papers.sort(key=lambda p: p.get("upvotes", 0), reverse=True) if selected_topics: st.caption(f"Showing {len(display_papers)} of {total_papers} papers") # --- Paper card grid (3 columns) --- NUM_COLS = 3 for row_start in range(0, len(display_papers), NUM_COLS): cols = st.columns(NUM_COLS, gap="medium") for col_idx, col in enumerate(cols): paper_idx = row_start + col_idx if paper_idx >= len(display_papers): break with col: render_card(display_papers[paper_idx], rank=paper_idx + 1, tab_key="monthly") # ---- Lifecycle tab ---- elif active_tab == "Lifecycle": _lc_splits_key = "lifecycle_available_snapshots" if _lc_splits_key not in st.session_state: lc_files = _list_repo_files_cached(HF_LIFECYCLE_REPO) st.session_state[_lc_splits_key] = sorted( [s.replace("snapshot_", "").replace("_", "-") for s in _extract_splits(lc_files, prefix="snapshot_")], reverse=True, ) snapshot_options = st.session_state[_lc_splits_key] if not snapshot_options: st.info("No lifecycle data available yet. Run `uv run python src/lifecycle_retrieve.py --all` to generate.") else: with hdr[1]: selected_snapshot = st.selectbox( "Select snapshot", options=snapshot_options, label_visibility="collapsed", key="lifecycle_select", ) _lc_cache_key = f"lifecycle_{selected_snapshot}" lc_raw = None if _lc_cache_key in st.session_state: lc_raw = st.session_state[_lc_cache_key] else: lc_raw = pull_lifecycle_from_hf(selected_snapshot) if lc_raw: st.session_state[_lc_cache_key] = lc_raw if not lc_raw: st.warning(f"Could not load lifecycle data for {selected_snapshot}") else: lc_data = lc_raw["lifecycle_data_zh"] if lang else lc_raw["lifecycle_data"] sorted_months = lc_raw["sorted_months"] st.metric("Papers", f"{lc_raw['n_papers']:,}") if sorted_months: st.caption( f"{lc_raw['n_months']} months ({sorted_months[0]} → {sorted_months[-1]})" ) if not lc_data: st.warning("Not enough data for lifecycle analysis.") else: fig = _render_hype_cycle(lc_data, lang) if fig: import matplotlib.pyplot as plt st.pyplot(fig, use_container_width=True) plt.close(fig) # --- Topic selector & time-series chart --- tbm = lc_raw.get("topics_by_month_zh" if lang else "topics_by_month") or {} tbt = lc_raw.get("total_by_month_zh" if lang else "total_by_month") or {} if tbm and tbt: all_topic_names = sorted( lc_data.keys(), key=lambda t: -lc_data[t]["total_count"], ) selected_topics = st.multiselect( "📊 " + ("选择主题(最多5个)" if lang else "Select topics (max 5)"), options=all_topic_names, default=all_topic_names[:3], max_selections=5, key="lifecycle_topic_select", ) if selected_topics: import pandas as pd import altair as alt count_rows = [] prop_rows = [] for m in sorted_months: month_topics = tbm.get(m, {}) month_total = tbt.get(m, 0) for t in selected_topics: c = month_topics.get(t, 0) count_rows.append({"Month": m, "Topic": t, "Count": c}) prop_rows.append({ "Month": m, "Topic": t, "Proportion": round(c / month_total, 4) if month_total > 0 else 0, }) df_count = pd.DataFrame(count_rows) df_prop = pd.DataFrame(prop_rows) def _alt_line(df, y_field, y_title): nearest = alt.selection_point( nearest=True, on="pointerover", fields=["Month"], empty=False, ) line = alt.Chart(df).mark_line( interpolate="monotone", strokeWidth=2, ).encode( x=alt.X("Month:N", sort=sorted_months, title=None, axis=alt.Axis(labelAngle=-45, labelFontSize=8)), y=alt.Y(f"{y_field}:Q", title=y_title, axis=alt.Axis(titleFontSize=10)), color=alt.Color("Topic:N", legend=alt.Legend( orient="top", title=None, labelFontSize=9)), ) points = line.mark_point(size=40).encode( opacity=alt.condition(nearest, alt.value(1), alt.value(0)), tooltip=[ alt.Tooltip("Month:N"), alt.Tooltip("Topic:N"), alt.Tooltip(f"{y_field}:Q", title=y_title, format=".4f" if y_field == "Proportion" else "d"), ], ).add_params(nearest) rule = alt.Chart(df).mark_rule(color="gray", strokeDash=[4, 4]).encode( x="Month:N", ).transform_filter(nearest) return (line + points + rule).properties(height=260) col_a, col_b = st.columns(2) with col_a: chart_c = _alt_line(df_count, "Count", "论文数量" if lang else "Paper Count") st.altair_chart(chart_c, use_container_width=True) with col_b: chart_p = _alt_line(df_prop, "Proportion", "占比" if lang else "Proportion") st.altair_chart(chart_p, use_container_width=True) _phase_labels_zh = { "Innovation Trigger": "技术萌芽期", "Peak of Inflated Expectations": "期望膨胀期", "Trough of Disillusionment": "泡沫破裂期", "Slope of Enlightenment": "稳步爬升期", "Plateau of Productivity": "生产成熟期", } phase_icons = { "Innovation Trigger": "🌱", "Peak of Inflated Expectations": "🔥", "Trough of Disillusionment": "📉", "Slope of Enlightenment": "📈", "Plateau of Productivity": "⚙️", } for phase in _PHASES_ORDER: topics_in_phase = sorted( [lc for lc in lc_data.values() if lc["phase"] == phase], key=lambda x: -x["total_count"], ) if not topics_in_phase: continue icon = phase_icons[phase] label = _phase_labels_zh[phase] if lang else phase unit = "个主题" if lang else "topics" with st.expander(f"{icon} {label} ({len(topics_in_phase)} {unit})"): for lc in topics_in_phase: st.markdown( f"**{lc['topic']}** — {lc['total_count']} papers, " f"peak: {lc['peak_month']}, trend: {lc['slope']:+.4f}" )