from __future__ import annotations
import base64
import os
from pathlib import Path
from typing import List
import pandas as pd
import networkx as nx
import streamlit as st
import plotly.express as px
import plotly.graph_objects as go
from pyvis.network import Network
import streamlit.components.v1 as components
# Hugging Face dataset repository to read data from; an empty string (the
# default when the variable is unset) makes the loaders use local files.
HF_REPO_ID = os.getenv("HF_REPO_ID", "")
def csv_download_link(data: bytes, filename: str, label: str) -> None:
    """Render an HTML download link for *data* in the Streamlit page.

    The payload is embedded in the link itself as a base64 ``data:`` URI, so
    no temporary file or server route is needed.

    Args:
        data: Raw CSV bytes to offer for download.
        filename: File name suggested to the browser (``download`` attribute).
        label: Visible link text; inserted as-is, so it may contain markup.
    """
    import html

    b64 = base64.b64encode(data).decode("ascii")
    # The file name lands inside a quoted attribute; escape it so quotes or
    # angle brackets in the name cannot break out of the attribute.
    safe_name = html.escape(filename, quote=True)
    st.markdown(
        f'<a href="data:text/csv;base64,{b64}" download="{safe_name}">'
        f'{label}</a>',
        unsafe_allow_html=True,
    )
# Optional Hugging Face access token; empty string when the variable is unset.
HF_TOKEN = os.getenv("HF_TOKEN", "")
# Page-level Streamlit settings: title, icon and wide layout.
# NOTE(review): Streamlit expects this call ahead of other st.* commands -
# confirm before moving it.
st.set_page_config(
    page_title="CitationHub",
    page_icon="📚",
    layout="wide",
)
# Citation-intent labels recognised by the app, in display order.
ALLOWED_INTENTS = [
    "background",
    "uses",
    "similarities",
    "motivation",
    "differences",
    "future_work",
    "extends",
]
# Hex colour assigned to each citation intent.
INTENT_COLORS = {
    "background": "#94a3b8",
    "uses": "#22c55e",
    "similarities": "#3b82f6",
    "motivation": "#f59e0b",
    "differences": "#ef4444",
    "future_work": "#8b5cf6",
    "extends": "#06b6d4",
}
# Light hex colour per node type.
# NOTE(review): not referenced in the visible part of this file; presumably
# consumed by another renderer further down - verify before removing.
NODE_COLORS = {
    "seed_paper": "#111827",
    "citing_paper": "#dbeafe",
    "citation_event": "#fde68a",
    "journal": "#ede9fe",
    "author": "#fee2e2",
    "affiliation": "#fae8ff",
    "city": "#cffafe",
    "country": "#ffedd5",
    "field": "#e0e7ff",
    "intent": "#dcfce7",
}
# Marker colour per node type for the Plotly figures. Key order is also the
# order in which the per-type traces (and legend entries) are created.
NODE_TYPE_COLORS = {
    "seed_paper": "#111827",
    "citing_paper": "#3b82f6",
    "citation_event": "#f59e0b",
    "journal": "#8b5cf6",
    "author": "#ef4444",
    "affiliation": "#ec4899",
    "city": "#06b6d4",
    "country": "#f97316",
    "field": "#6366f1",
    "intent": "#22c55e",
}
# Local directory holding the data files; overridable through the
# CITATIONHUB_DATA_DIR environment variable.
_data_dir_setting = os.getenv("CITATIONHUB_DATA_DIR", "/tmp/citationhub_data")
DEFAULT_DATA_DIR = Path(_data_dir_setting)
def fmt_num(x):
    """Format *x* as an integer with thousands separators.

    Args:
        x: Anything ``int()`` accepts (number, numeric string, ...).

    Returns:
        The formatted integer, e.g. ``"1,234,567"``; ``"-"`` when *x* cannot
        be converted (``None``, non-numeric text, NaN, infinity, ...).
    """
    try:
        return f"{int(x):,}"
    except (TypeError, ValueError, OverflowError):
        # TypeError: None / containers; ValueError: bad text or NaN;
        # OverflowError: +/-inf. Anything else (e.g. KeyboardInterrupt)
        # must propagate instead of being turned into a dash.
        return "-"
def _hf_download(filename: str) -> str:
    """Fetch ``data/<filename>`` from the configured Hugging Face dataset repo.

    Uses the module-level ``HF_REPO_ID`` and ``HF_TOKEN``; an empty token is
    passed on as ``None``.

    Args:
        filename: File name inside the repository's ``data/`` folder.

    Returns:
        Whatever ``hf_hub_download`` returns (annotated here as ``str``).
    """
    # Imported lazily so the dependency is only needed when this is called.
    from huggingface_hub import hf_hub_download

    token = HF_TOKEN if HF_TOKEN else None
    remote_path = f"data/{filename}"
    return hf_hub_download(
        repo_id=HF_REPO_ID,
        filename=remote_path,
        repo_type="dataset",
        token=token,
    )
def _read(filename: str, data_dir: Path | None = None, columns: list | None = None) -> pd.DataFrame:
    """Load one parquet file, from the Hugging Face repo or from local disk.

    When ``HF_REPO_ID`` is set the file is obtained through ``_hf_download``;
    otherwise it is read from *data_dir*.

    Args:
        filename: Name of the parquet file.
        data_dir: Local directory to read from when no repo is configured.
            ``None`` (the default) falls back to ``DEFAULT_DATA_DIR``.
        columns: Optional subset of columns to load.

    Returns:
        The file contents as a DataFrame.
    """
    if HF_REPO_ID:
        path = _hf_download(filename)
    else:
        # The default used to reach `None / filename` and raise TypeError.
        base_dir = DEFAULT_DATA_DIR if data_dir is None else Path(data_dir)
        path = str(base_dir / filename)
    return pd.read_parquet(path, columns=columns, engine="pyarrow")
def _safe_cols(path: str, wanted: list) -> list:
    """Return the entries of *wanted* that exist as columns in a parquet file.

    Only the file's schema is read. The result keeps the order (and any
    duplicates) of *wanted*.

    Args:
        path: Location of the parquet file.
        wanted: Candidate column names.

    Returns:
        The candidates present in the file's schema.
    """
    import pyarrow.parquet as pq

    schema_names = set(pq.read_schema(path).names)
    present = []
    for column in wanted:
        if column in schema_names:
            present.append(column)
    return present
def _hover_field(value: object, fallback: str = "-") -> str:
    """Return *value* as display text, or *fallback* when it is missing/empty."""
    # `value != value` is true only for NaN, which is how pandas marks gaps.
    if value is None or (isinstance(value, float) and value != value):
        return fallback
    return str(value) or fallback


def plotly_network_fig(
    nodes_df: pd.DataFrame,
    edges_df: pd.DataFrame,
    title: str = "",
    height: int = 750,
    seed_node_ids: list | None = None,
) -> go.Figure:
    """Draw a node/edge table pair as a Plotly scatter network.

    Nodes are placed with a seeded spring layout, so the same input gives the
    same picture. One marker trace is created per node type listed in
    ``NODE_TYPE_COLORS``; only ``seed_paper`` nodes get a visible text label.

    Args:
        nodes_df: One row per node. Reads ``node_id`` and ``node_type``, plus
            the optional ``label``, ``doi``, ``publication_name`` and
            ``group`` columns for the hover text.
        edges_df: One row per edge. Reads ``source`` and ``target`` (and the
            optional ``edge_type``); edges whose endpoints are not both in
            *nodes_df* are ignored.
        title: Figure title; an empty title also shrinks the top margin.
        height: Figure height in pixels.
        seed_node_ids: Accepted for interface compatibility; not used.

    Returns:
        The assembled figure, or an empty ``go.Figure`` when there are no
        nodes.
    """
    # NOTE: nodes whose type is not a key of NODE_TYPE_COLORS get no marker,
    # although edges touching them are still drawn.
    if nodes_df.empty:
        return go.Figure()

    # Column-wise iteration keeps ids in their own dtype (iterrows can upcast
    # a whole row, turning 1 into 1.0 and changing its string form).
    graph = nx.Graph()
    graph.add_nodes_from(str(nid) for nid in nodes_df["node_id"].tolist())

    if not edges_df.empty:
        if "edge_type" in edges_df.columns:
            edge_types = edges_df["edge_type"].tolist()
        else:
            edge_types = [""] * len(edges_df)
        endpoints = zip(
            edges_df["source"].tolist(),
            edges_df["target"].tolist(),
            edge_types,
        )
        for src, tgt, edge_type in endpoints:
            src, tgt = str(src), str(tgt)
            if src in graph and tgt in graph:
                graph.add_edge(src, tgt, edge_type=edge_type)

    # Optimal node distance shrinks slowly with graph size, floored at 1.5.
    k = max(1.5, 3.0 / (graph.number_of_nodes() ** 0.4))
    pos = nx.spring_layout(graph, seed=42, k=k, iterations=60)

    # All edges go into a single line trace; None breaks the line between
    # consecutive segments.
    edge_x, edge_y = [], []
    for src, tgt in graph.edges():
        x0, y0 = pos[src]
        x1, y1 = pos[tgt]
        edge_x += [x0, x1, None]
        edge_y += [y0, y1, None]

    traces: list[go.BaseTraceType] = [
        go.Scatter(
            x=edge_x, y=edge_y, mode="lines",
            line=dict(width=0.8, color="#cbd5e1"),
            hoverinfo="none", showlegend=False,
        )
    ]

    for ntype, color in NODE_TYPE_COLORS.items():
        subset = nodes_df[nodes_df["node_type"] == ntype]
        if subset.empty:
            continue
        is_seed = ntype == "seed_paper"
        xs, ys, hovers, texts = [], [], [], []
        for rec in subset.to_dict("records"):
            nid = str(rec["node_id"])
            if nid not in pos:
                continue
            x, y = pos[nid]
            xs.append(x)
            ys.append(y)
            label = _hover_field(rec.get("label"), "")[:50]
            texts.append(label if is_seed else "")
            # Plotly renders <br> (not "\n") as a line break in hover text.
            hovers.append("<br>".join((
                label,
                f"Type: {ntype}",
                f"DOI: {_hover_field(rec.get('doi'))}",
                f"Pub: {_hover_field(rec.get('publication_name'))}",
                f"Group: {_hover_field(rec.get('group'))}",
            )))
        traces.append(go.Scatter(
            x=xs, y=ys,
            mode="markers+text" if is_seed else "markers",
            text=texts, textposition="top center",
            hovertext=hovers, hoverinfo="text",
            name=ntype,
            marker=dict(
                size=20 if is_seed else 10,
                color=color,
                line=dict(width=1.5 if is_seed else 0.5, color="white"),
                symbol="circle",
            ),
        ))

    fig = go.Figure(data=traces)
    fig.update_layout(
        title=dict(text=title, font=dict(size=14)),
        showlegend=True,
        legend=dict(title="Node type", itemsizing="constant"),
        hovermode="closest",
        height=height,
        margin=dict(l=0, r=0, t=40 if title else 10, b=0),
        paper_bgcolor="white",
        plot_bgcolor="#f8fafc",
        xaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
        yaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
    )
    return fig
def plotly_ontology_fig(height: int = 820) -> go.Figure:
    """Build the static ontology diagram (node classes and their relations).

    The ten node classes and nine relations are hard-coded below. Positions
    come from a seeded spring layout, so the picture is stable between runs.
    Each class is drawn as a labelled marker coloured via
    ``NODE_TYPE_COLORS``, with its property list shown underneath and in the
    hover text; each relation is a line with its name at the midpoint.

    Args:
        height: Figure height in pixels.

    Returns:
        The assembled figure.
    """
    # Property summary per node type; "\n" marks the intended line breaks.
    node_props = {
        "seed_paper": "doi · title · journal\nauthor · affiliation\ncountry · field · citedby_count",
        "citation_event": "event_id · citing_year\nprimary_intent · context\nis_influential",
        "citing_paper": "doi · title\nyear · venue · oa_pdf",
        "intent": "background · uses\nsimilarities · motivation\ndifferences · future_work · extends",
        "journal": "journal_name",
        "author": "author_name · author_id",
        "affiliation": "affiliation_name",
        "city": "city_name",
        "country": "country_name",
        "field": "field_name",
    }
    # (graph id, displayed class name, node type used for colour/props).
    node_defs = [
        ("seed", "Top5PctCitedPaper", "seed_paper"),
        ("event", "CitationEvent", "citation_event"),
        ("citing", "CitingPaper", "citing_paper"),
        ("intent", "Intent", "intent"),
        ("journal", "Journal", "journal"),
        ("author", "Author", "author"),
        ("affiliation", "Affiliation", "affiliation"),
        ("city", "City", "city"),
        ("country", "Country", "country"),
        ("field", "Field", "field"),
    ]
    # (source id, target id, relation name).
    edge_defs = [
        ("event", "citing", "hasCitingPaper"),
        ("event", "seed", "hasCitedPaper"),
        ("event", "intent", "hasPrimaryIntent"),
        ("seed", "journal", "publishedInJournal"),
        ("seed", "author", "hasAuthor"),
        ("seed", "affiliation", "hasAffiliation"),
        ("seed", "city", "locatedInCity"),
        ("seed", "country", "locatedInCountry"),
        ("seed", "field", "belongsToField"),
    ]

    graph = nx.DiGraph()
    graph.add_nodes_from(nid for nid, _, _ in node_defs)
    graph.add_edges_from((src, tgt) for src, tgt, _ in edge_defs)
    pos = nx.spring_layout(graph, seed=7, k=2.5, iterations=80)

    edge_x, edge_y = [], []
    annotations = []
    for src, tgt, relation in edge_defs:
        x0, y0 = pos[src]
        x1, y1 = pos[tgt]
        # None breaks the single line trace between consecutive segments.
        edge_x += [x0, x1, None]
        edge_y += [y0, y1, None]
        annotations.append(dict(
            x=(x0 + x1) / 2, y=(y0 + y1) / 2, text=relation,
            showarrow=False, font=dict(size=9, color="#64748b"),
            bgcolor="rgba(255,255,255,0.75)",
        ))

    traces: list[go.BaseTraceType] = [
        go.Scatter(
            x=edge_x, y=edge_y, mode="lines",
            line=dict(width=1.2, color="#94a3b8"),
            hoverinfo="none", showlegend=False,
        )
    ]

    for nid, label, ntype in node_defs:
        x, y = pos[nid]
        color = NODE_TYPE_COLORS.get(ntype, "#94a3b8")
        props = node_props.get(ntype, "")
        # Plotly renders <br> (not "\n") as a line break in text and hovers.
        prop_html = props.replace("\n", "<br>")
        traces.append(go.Scatter(
            x=[x], y=[y], mode="markers+text",
            text=[label], textposition="top center",
            hoverinfo="text",
            hovertext=f"{label}<br>Type: {ntype}<br>" + prop_html,
            name=label, showlegend=False,
            marker=dict(size=24, color=color,
                        line=dict(width=1.5, color="white")),
            textfont=dict(size=11, color="#1e293b"),
        ))
        if props:
            # Property list sits just below the marker (shifted 22 px down).
            annotations.append(dict(
                x=x, y=y,
                text=prop_html,
                showarrow=False,
                xanchor="center",
                yanchor="top",
                yshift=-22,
                font=dict(size=8, color="#64748b"),
                bgcolor="rgba(248,250,252,0.85)",
                borderpad=2,
            ))

    fig = go.Figure(data=traces)
    fig.update_layout(
        showlegend=False, hovermode="closest", height=height,
        annotations=annotations,
        margin=dict(l=10, r=10, t=20, b=10),
        paper_bgcolor="white", plot_bgcolor="#f8fafc",
        xaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
        yaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
    )
    return fig
def inject_fullscreen(html: str) -> str:
extra = """