from __future__ import annotations
import os
from pathlib import Path
from typing import List
import pandas as pd
import networkx as nx
import streamlit as st
import plotly.express as px
import plotly.graph_objects as go
from pyvis.network import Network
import streamlit.components.v1 as components
# Hugging Face dataset repo holding the parquet tables. An empty value makes
# _read() load from a local directory instead of downloading.
HF_REPO_ID = os.environ.get("HF_REPO_ID", "")
# Optional Hugging Face access token; _hf_download() passes an empty string on
# as None.
HF_TOKEN = os.environ.get("HF_TOKEN", "")
# NOTE(review): the page_icon literal looks like a mis-encoded emoji (UTF-8
# bytes decoded as Greek, trailing bytes lost); restore the intended emoji.
st.set_page_config(page_title="CitationHub", page_icon="π", layout="wide")
# Citation-intent labels, in display order.
ALLOWED_INTENTS = [
    "background",
    "uses",
    "similarities",
    "motivation",
    "differences",
    "future_work",
    "extends",
]

# Hex colour for each citation intent.
INTENT_COLORS = dict(
    background="#94a3b8",
    uses="#22c55e",
    similarities="#3b82f6",
    motivation="#f59e0b",
    differences="#ef4444",
    future_work="#8b5cf6",
    extends="#06b6d4",
)

# Light (pastel) fill per node type. Not referenced in the visible part of
# this file -- presumably consumed by the pyvis view; verify before changing.
NODE_COLORS = dict(
    seed_paper="#111827",
    citing_paper="#dbeafe",
    citation_event="#fde68a",
    journal="#ede9fe",
    author="#fee2e2",
    affiliation="#fae8ff",
    city="#cffafe",
    country="#ffedd5",
    field="#e0e7ff",
    intent="#dcfce7",
)

# Saturated colour per node type, used for the Plotly network/ontology markers.
# Keys are the same node types, in the same order, as NODE_COLORS.
NODE_TYPE_COLORS = dict(
    seed_paper="#111827",
    citing_paper="#3b82f6",
    citation_event="#f59e0b",
    journal="#8b5cf6",
    author="#ef4444",
    affiliation="#ec4899",
    city="#06b6d4",
    country="#f97316",
    field="#6366f1",
    intent="#22c55e",
)
# Default local directory for the parquet tables. Override it with the
# CITATIONHUB_DATA_DIR environment variable. The built-in fallback is the
# original author's Windows desktop folder ("바탕 화면" is Korean for
# "Desktop"); the previous literal held a mis-encoded copy of that folder name,
# so it could never match a real directory.
DEFAULT_DATA_DIR = Path(os.environ.get(
    "CITATIONHUB_DATA_DIR",
    r"C:\Users\user\OneDrive\바탕 화면\Citehub_huggingface\data",
))
def fmt_num(x: object) -> str:
    """Format *x* as an integer with thousands separators, or ``"-"`` if it can't be.

    ``int()`` does the conversion, so floats are truncated toward zero
    (``3.9`` -> ``"3"``) and integer strings such as ``"42"`` are accepted.
    ``None``, NaN, infinity and non-numeric values render as ``"-"``.
    """
    try:
        return f"{int(x):,}"
    except (TypeError, ValueError, ArithmeticError):
        # Only conversion failures map to "-"; a bare ``except:`` would also
        # swallow KeyboardInterrupt/SystemExit.
        return "-"
def _hf_download(filename: str) -> str:
    """Download ``data/<filename>`` from the ``HF_REPO_ID`` dataset repo.

    Returns whatever ``hf_hub_download`` returns (the local file path).
    ``HF_TOKEN`` is forwarded when non-empty; otherwise ``None`` is sent.
    """
    # Local import: huggingface_hub is only required when this function runs.
    from huggingface_hub import hf_hub_download

    remote_path = f"data/{filename}"
    token = HF_TOKEN if HF_TOKEN else None
    return hf_hub_download(
        repo_id=HF_REPO_ID,
        filename=remote_path,
        repo_type="dataset",
        token=token,
    )
def _read(filename: str, data_dir: Path | None = None) -> pd.DataFrame:
    """Load the parquet table *filename* into a DataFrame.

    When ``HF_REPO_ID`` is set, the file is downloaded from that Hugging Face
    dataset repo (see ``_hf_download``) and *data_dir* is ignored. Otherwise it
    is read from *data_dir*, falling back to ``DEFAULT_DATA_DIR`` when
    *data_dir* is ``None``. A ``str`` *data_dir* is accepted as well.
    """
    if HF_REPO_ID:
        return pd.read_parquet(_hf_download(filename))
    # data_dir defaults to None; use the module default rather than failing
    # with TypeError on ``None / filename``.
    base_dir = DEFAULT_DATA_DIR if data_dir is None else Path(data_dir)
    return pd.read_parquet(base_dir / filename)
def _cell_text(row: pd.Series, key: str, default: str = "-") -> str:
    """Return ``row[key]`` as text, or *default* if it is absent, NaN/NA or empty."""
    value = row.get(key)
    if value is None or (pd.api.types.is_scalar(value) and pd.isna(value)):
        return default
    text = str(value)
    return text if text else default


def plotly_network_fig(
    nodes_df: pd.DataFrame,
    edges_df: pd.DataFrame,
    title: str = "",
    height: int = 750,
    seed_node_ids: list | None = None,
) -> go.Figure:
    """Render a node/edge table as an SVG-based Plotly network graph (sharp when zoomed).

    Args:
        nodes_df: One row per node. ``node_id`` and ``node_type`` are required;
            ``label``, ``doi``, ``publication_name`` and ``group`` are optional
            and only used for the hover text (``label`` is also drawn next to
            ``seed_paper`` nodes).
        edges_df: One row per edge with ``source`` and ``target`` node ids and an
            optional ``edge_type``. Edges whose endpoints are not in *nodes_df*
            are dropped.
        title: Figure title; when empty the top margin is reduced.
        height: Figure height in pixels.
        seed_node_ids: Accepted for interface compatibility; not used here.

    Returns:
        A ``go.Figure`` with one grey edge trace and one marker trace per node
        type present (types missing from ``NODE_TYPE_COLORS`` are grouped under
        a grey "other" trace). An empty figure is returned when *nodes_df* has
        no rows.
    """
    G = nx.Graph()
    for _, row in nodes_df.iterrows():
        G.add_node(str(row["node_id"]))
    for _, row in edges_df.iterrows():
        s, t = str(row["source"]), str(row["target"])
        # Skip dangling edges; add_edge would otherwise create bare nodes.
        if s in G and t in G:
            G.add_edge(s, t, edge_type=row.get("edge_type", ""))
    if G.number_of_nodes() == 0:
        return go.Figure()

    # Preferred node spacing shrinks as the graph grows, clamped to [1.5, 3].
    k = max(1.5, 3.0 / (G.number_of_nodes() ** 0.4))
    pos = nx.spring_layout(G, seed=42, k=k, iterations=60)

    # -- edges: one trace; None separates segments so Plotly breaks the line --
    ex: list[float | None] = []
    ey: list[float | None] = []
    for src, tgt in G.edges():
        x0, y0 = pos[src]
        x1, y1 = pos[tgt]
        ex += [x0, x1, None]
        ey += [y0, y1, None]
    traces: list[go.BaseTraceType] = [
        go.Scatter(
            x=ex, y=ey, mode="lines",
            line=dict(width=0.8, color="#cbd5e1"),
            hoverinfo="none", showlegend=False,
        )
    ]

    # -- nodes: one trace per known type, plus a grey catch-all so nodes of
    # unlisted types are still drawn at the end of their edges --
    node_types = nodes_df["node_type"]
    groups = [
        (ntype, color, node_types == ntype)
        for ntype, color in NODE_TYPE_COLORS.items()
    ]
    groups.append(("other", "#94a3b8", ~node_types.isin(list(NODE_TYPE_COLORS))))

    for name, color, mask in groups:
        subset = nodes_df[mask]
        if subset.empty:
            continue
        is_seed = name == "seed_paper"
        xs, ys, hovers, texts = [], [], [], []
        for _, row in subset.iterrows():
            nid = str(row["node_id"])
            if nid not in pos:
                continue
            x, y = pos[nid]
            xs.append(x)
            ys.append(y)
            label = _cell_text(row, "label", default="")[:50]
            texts.append(label if is_seed else "")
            # Plotly hover labels use <br> (not "\n") as the line break.
            hovers.append(
                f"{label}<br>"
                f"Type: {_cell_text(row, 'node_type')}<br>"
                f"DOI: {_cell_text(row, 'doi')}<br>"
                f"Pub: {_cell_text(row, 'publication_name')}<br>"
                f"Group: {_cell_text(row, 'group')}"
            )
        traces.append(go.Scatter(
            x=xs, y=ys,
            mode="markers+text" if is_seed else "markers",
            text=texts, textposition="top center",
            hovertext=hovers, hoverinfo="text",
            name=name,
            marker=dict(
                size=20 if is_seed else 10,
                color=color,
                line=dict(width=1.5 if is_seed else 0.5, color="white"),
                symbol="circle",
            ),
        ))

    fig = go.Figure(data=traces)
    fig.update_layout(
        title=dict(text=title, font=dict(size=14)),
        showlegend=True,
        legend=dict(title="Node type", itemsizing="constant"),
        hovermode="closest",
        height=height,
        margin=dict(l=0, r=0, t=40 if title else 10, b=0),
        paper_bgcolor="white",
        plot_bgcolor="#f8fafc",
        xaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
        yaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
    )
    return fig
def plotly_ontology_fig(height: int = 750) -> go.Figure:
    """Draw the CitationHub ontology structure (classes and relations) as a Plotly SVG figure.

    Args:
        height: Figure height in pixels.

    Returns:
        A ``go.Figure`` with one labelled marker per ontology class (coloured via
        ``NODE_TYPE_COLORS``), grey lines for the relations, and each relation
        name drawn as an annotation at the midpoint of its line. The spring
        layout uses a fixed seed, so the arrangement is repeatable.
    """
    # (layout id, display label, node type used for the colour lookup)
    node_defs = [
        ("seed", "Top5PctCitedPaper", "seed_paper"),
        ("event", "CitationEvent", "citation_event"),
        ("citing", "CitingPaper", "citing_paper"),
        ("intent", "Intent", "intent"),
        ("journal", "Journal", "journal"),
        ("author", "Author", "author"),
        ("affiliation", "Affiliation", "affiliation"),
        ("city", "City", "city"),
        ("country", "Country", "country"),
        ("field", "Field", "field"),
    ]
    # (source id, target id, relation name); lines are drawn without arrowheads.
    edge_defs = [
        ("event", "citing", "hasCitingPaper"), ("event", "seed", "hasCitedPaper"),
        ("event", "intent", "hasPrimaryIntent"), ("seed", "journal", "publishedInJournal"),
        ("seed", "author", "hasAuthor"), ("seed", "affiliation", "hasAffiliation"),
        ("seed", "city", "locatedInCity"), ("seed", "country", "locatedInCountry"),
        ("seed", "field", "belongsToField"),
    ]

    G = nx.DiGraph()
    for nid, _, _ in node_defs:
        G.add_node(nid)
    for s, t, _ in edge_defs:
        G.add_edge(s, t)
    pos = nx.spring_layout(G, seed=7, k=2.5, iterations=80)

    # -- relation lines (one trace, None-separated) and their midpoint labels --
    ex: list[float | None] = []
    ey: list[float | None] = []
    annotations = []
    for s, t, rel in edge_defs:
        x0, y0 = pos[s]
        x1, y1 = pos[t]
        ex += [x0, x1, None]
        ey += [y0, y1, None]
        annotations.append(dict(
            x=(x0 + x1) / 2, y=(y0 + y1) / 2, text=rel,
            showarrow=False, font=dict(size=9, color="#64748b"),
            bgcolor="rgba(255,255,255,0.7)",
        ))
    traces: list[go.BaseTraceType] = [
        go.Scatter(x=ex, y=ey, mode="lines",
                   line=dict(width=1.2, color="#94a3b8"),
                   hoverinfo="none", showlegend=False)
    ]

    # -- one marker trace per ontology class --
    for nid, label, ntype in node_defs:
        x, y = pos[nid]
        color = NODE_TYPE_COLORS.get(ntype, "#94a3b8")
        traces.append(go.Scatter(
            x=[x], y=[y], mode="markers+text",
            text=[label], textposition="top center",
            # Plotly hover labels use <br> (not "\n") as the line break.
            hoverinfo="text", hovertext=f"{label}<br>Type: {ntype}",
            name=label, showlegend=False,
            marker=dict(size=22, color=color,
                        line=dict(width=1.5, color="white")),
            textfont=dict(size=11),
        ))

    fig = go.Figure(data=traces)
    fig.update_layout(
        showlegend=False, hovermode="closest", height=height,
        annotations=annotations,
        margin=dict(l=0, r=0, t=10, b=0),
        paper_bgcolor="white", plot_bgcolor="#f8fafc",
        xaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
        yaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
    )
    return fig
def inject_fullscreen(html: str) -> str:
extra = """