"""app.py — Gradio UI entry point (<200 lines, §11)."""
import os, json, tempfile, time
import pandas as pd, numpy as np
import gradio as gr
import plotly.express as px
import plotly.graph_objects as go
from agent import run_pipeline
# ── CSV preview on upload ────────────────────────────────────────────────────
def _preview(file):
if not file: return "Upload a Scopus CSV to begin."
df = pd.read_csv(file.name)
df.columns = df.columns.str.lower()
has_t = "title" in df.columns
has_a = "abstract" in df.columns
n = len(df)
blanks_t = int(df["title"].isna().sum()) if has_t else n
blanks_a = int(df["abstract"].isna().sum()) if has_a else n
ok = "✅" if has_t and has_a and blanks_t < n and blanks_a < n else "❌"
return (f"## {ok} CSV loaded — {n} entries\n\n"
f"| Column | Present | Blank rows |\n|---|---|---|\n"
f"| title | {'✅' if has_t else '❌'} | {blanks_t} |\n"
f"| abstract | {'✅' if has_a else '❌'} | {blanks_a} |\n\n"
f"**Usable papers:** {n - max(blanks_t,blanks_a)} / {n}")
# ── Pipeline runner ──────────────────────────────────────────────────────────
# Dark styling shared by the UMAP and Pareto figures.
_DARK_LAYOUT = dict(template="plotly_dark", paper_bgcolor="#0d1117",
                    plot_bgcolor="#161b22", font=dict(size=11))
# Trial-log columns shown in the "Trial Log" tab (only those present are kept).
_TRIAL_LOG_COLUMNS = ("trial", "discipline_pass", "n_clusters", "persistence",
                      "dbcv", "max_mass_pct", "min_size", "n_noise")
# Columns the Pareto scatter needs; without all of them it is left empty.
_PARETO_COLUMNS = frozenset({"trial", "discipline_pass", "max_mass_pct",
                             "persistence"})


def _by_index(mapping, i, default=None):
    """Return ``mapping[i]``, falling back to ``mapping[str(i)]`` then *default*.

    The pipeline result is read with int keys (1-4); the str fallback keeps
    lookups working if the dict ever round-trips through JSON.
    """
    if i in mapping:
        return mapping[i]
    return mapping.get(str(i), default)


def _summary_md(td, ar):
    """Build the Markdown criteria table shown in the Summary tab.

    Args:
        td: ``topic_data`` dict with ``discipline``, ``metrics``,
            ``n_trials_run`` and ``best_trial`` entries.
        ar: Agreement-rate dict; missing ``triple``/``two_or_more`` show as 0.
    """
    disc, met = td["discipline"], td["metrics"]

    def status(ok):
        return "✅ PASS" if ok else "❌ FAIL"

    # NOTE(review): the stability seed count is a literal here — confirm it
    # matches what the pipeline actually uses.
    return (f"## Pipeline Complete — {disc['n_clusters']} clusters discovered\n\n"
            "| Criterion | Value | Status |\n|---|---|---|\n"
            f"| Max cluster mass | {round(disc['max_mass_pct'] * 100, 1)}% "
            f"| {status(disc['max_mass_ok'])} |\n"
            f"| Min cluster size | {disc['min_size']} | {status(disc['min_size_ok'])} |\n"
            f"| Persistence (mean) | {round(met['persistence'], 4)} | — |\n"
            f"| DBCV | {round(met['dbcv'], 4)} | — |\n"
            f"| Stability (3 seeds) | {round(met['stability'], 4)} | — |\n\n"
            f"**Trials:** {td['n_trials_run']} (best #{td['best_trial']}) · "
            f"**Agreement:** Triple {ar.get('triple', 0)}% · "
            f"Two+ {ar.get('two_or_more', 0)}%")


def _umap_figure(td):
    """Scatter of the 2-D UMAP coordinates, coloured by cluster label."""
    u2d = np.asarray(td["umap_2d"], dtype=float)
    if u2d.size == 0:
        # np.asarray([]) is 1-D; give it two columns so the slicing works.
        u2d = u2d.reshape(0, 2)
    sdf = pd.DataFrame({"UMAP-1": u2d[:, 0], "UMAP-2": u2d[:, 1],
                        "Cluster": [str(lab) for lab in td["labels"]],
                        # str() guards against NaN/None documents.
                        "Doc": [str(doc)[:60] for doc in td["documents"]]})
    fig = px.scatter(sdf, x="UMAP-1", y="UMAP-2", color="Cluster",
                     hover_data=["Doc"], opacity=0.75,
                     title="2-D UMAP visualisation of SPECTER-2 embeddings")
    fig.update_layout(**_DARK_LAYOUT, height=500)
    return fig


def _pareto_figure(tl):
    """Per-trial scatter of persistence vs. max cluster mass, split PASS/FAIL."""
    fig = go.Figure()
    if not tl.empty and _PARETO_COLUMNS.issubset(tl.columns):
        for passed, color, name in ((True, "#3dba7a", "PASS"),
                                    (False, "#e04d4d", "FAIL")):
            sub = tl[tl["discipline_pass"] == passed]
            if sub.empty:
                continue
            fig.add_trace(go.Scatter(
                x=sub["max_mass_pct"], y=sub["persistence"], mode="markers",
                marker=dict(size=8, color=color), name=name, text=sub["trial"],
                # Plotly hover labels are HTML: line breaks must be <br>.
                hovertemplate="Trial %{text}<br>Mass: %{x:.0%}<br>Pers: %{y:.3f}"))
        # Dashed reference line labelled "25% rule" at max mass 0.25.
        fig.add_vline(x=0.25, line_dash="dash", line_color="#5a6480",
                      annotation_text="25% rule")
    fig.update_layout(**_DARK_LAYOUT, height=400,
                      title="Pareto front — Persistence vs Max cluster mass",
                      xaxis_title="Max cluster mass (lower is better)",
                      yaxis_title="Persistence (higher is better)")
    return fig


def _cluster_table(interps):
    """One row per cluster interpretation, ordered by cluster id."""
    rows = []
    for cid in sorted(interps):
        v = interps[cid]
        rows.append({"Cluster": cid, "Label": v["label"],
                     "Agreement": v["agreement"],
                     "Strong": v["strong"], "Weak": v["weak"],
                     "Persistence": round(v.get("persistence") or 0, 4),
                     "Keyphrases": ", ".join(str(k) for k in v.get("keyphrases") or [])})
    return pd.DataFrame(rows)


def _run(file, gk, mk, gek, n_trials, progress=gr.Progress(track_tqdm=True)):
    """Run the topic pipeline on the uploaded CSV and build every output panel.

    Args:
        file: Uploaded CSV from ``gr.File`` (object with ``.name`` or a path str).
        gk, mk, gek: Groq / Mistral / Gemini API keys from the textboxes; a
            blank value falls back to GROQ_API_KEY / MISTRAL_API_KEY /
            GEMINI_API_KEY from the environment.
        n_trials: Optuna trial count (slider value, cast to int).
        progress: Progress tracker injected by Gradio.

    Returns:
        An 11-tuple matching the ``run_btn.click`` outputs: summary Markdown,
        UMAP figure, Pareto figure, trial-log DataFrame, cluster DataFrame,
        sheet 1-4 DataFrames, list of download paths (or None when there are
        none), and the RQ-mismatch DataFrame.

    Raises:
        gr.Error: No file, a missing API key, or ``run_pipeline`` returned an
            ``error`` entry.
    """
    if not file:
        raise gr.Error("Upload a CSV first.")
    # Textbox values win; blank (or None) falls back to the environment.
    gk = (gk or "").strip() or os.getenv("GROQ_API_KEY", "")
    mk = (mk or "").strip() or os.getenv("MISTRAL_API_KEY", "")
    gek = (gek or "").strip() or os.getenv("GEMINI_API_KEY", "")
    if not all([gk, mk, gek]):
        raise gr.Error("All 3 API keys required.")
    # Older gr.File passes a tempfile wrapper (.name); newer passes a path str.
    path = getattr(file, "name", file)

    progress(0.05, desc="📥 Loading CSV…")
    progress(0.1, desc="🔬 Embedding with SPECTER-2 (this takes a few minutes)…")
    r = run_pipeline(path, gk, mk, gek, int(n_trials))
    if r.get("error"):
        raise gr.Error(r["error"])

    progress(0.95, desc="📊 Building outputs…")
    td = r["topic_data"]
    summary = _summary_md(td, r.get("agreement_rates") or {})
    fig = _umap_figure(td)

    tl = pd.DataFrame(td["trial_log"])
    tl_cols = [c for c in _TRIAL_LOG_COLUMNS if c in tl.columns]
    tl_show = tl[tl_cols] if not tl.empty else pd.DataFrame()
    pfig = _pareto_figure(tl)

    cdf = _cluster_table(r.get("interpretations") or {})

    # The four per-model / consolidated sheets, keyed 1-4.
    sheets = r.get("sheets") or {}
    s1, s2, s3, s4 = (pd.DataFrame(_by_index(sheets, i, [])) for i in (1, 2, 3, 4))
    sp = r.get("sheet_paths") or {}
    mdf = pd.DataFrame(r.get("mismatch_table", []))

    progress(1.0, desc="✅ Done!")
    candidates = [_by_index(sp, i) for i in (1, 2, 3, 4)] + [r.get("json_path")]
    dl_files = [f for f in candidates if f is not None]
    return (summary, fig, pfig, tl_show, cdf, s1, s2, s3, s4,
            dl_files or None, mdf)
# ── UI ───────────────────────────────────────────────────────────────────────
# Page-level dark theme overrides; also hides the Gradio footer.
css = (".gradio-container{background:#0d1117!important;color:#c9d1d9!important}"
       "footer{display:none!important}")
with gr.Blocks(theme=gr.themes.Base(primary_hue="blue", neutral_hue="slate"),
               css=css, title="SPECTER-2 Topic Analyzer") as demo:
    gr.Markdown("# 📐 SPECTER-2 Topic Analyzer")
    with gr.Row():
        # Left column: CSV upload + preview, API keys, trial count, run button.
        with gr.Column(scale=1):
            file_in = gr.File(label="Upload Scopus CSV", file_types=[".csv"])
            preview_out = gr.Markdown("Upload a CSV to see stats.")
            groq_in = gr.Textbox(label="Groq API Key", type="password",
                                 placeholder="or set GROQ_API_KEY env var")
            mistral_in = gr.Textbox(label="Mistral API Key", type="password",
                                    placeholder="or set MISTRAL_API_KEY env var")
            gemini_in = gr.Textbox(label="Gemini API Key", type="password",
                                   placeholder="or set GEMINI_API_KEY env var")
            trials_in = gr.Slider(10, 100, 50, step=5, label="Optuna Trials")
            run_btn = gr.Button("▶ Run Full Pipeline", variant="primary", size="lg")
        # Right column: one tab per result panel.
        with gr.Column(scale=3):
            with gr.Tabs():
                with gr.Tab("Summary"):
                    summary_out = gr.Markdown()
                with gr.Tab("2-D UMAP"):
                    scatter_out = gr.Plot()
                with gr.Tab("Pareto Front"):
                    pareto_out = gr.Plot()
                with gr.Tab("Trial Log"):
                    trial_out = gr.Dataframe()
                with gr.Tab("Clusters"):
                    cluster_out = gr.Dataframe()
                with gr.Tab("Sheet 1 — Groq"):
                    s1_out = gr.Dataframe()
                with gr.Tab("Sheet 2 — Mistral"):
                    s2_out = gr.Dataframe()
                with gr.Tab("Sheet 3 — Gemini"):
                    s3_out = gr.Dataframe()
                with gr.Tab("Sheet 4 — Consolidated"):
                    s4_out = gr.Dataframe()
                with gr.Tab("RQ Mismatch"):
                    mismatch_out = gr.Dataframe()
                with gr.Tab("Downloads"):
                    dl_out = gr.File(label="All sheet CSVs + topics.json",
                                     file_count="multiple")
    # Event wiring must happen inside the Blocks context.
    file_in.change(_preview, inputs=[file_in], outputs=[preview_out])
    # Output order must match the tuple returned by _run.
    run_btn.click(_run,
                  inputs=[file_in, groq_in, mistral_in, gemini_in, trials_in],
                  outputs=[summary_out, scatter_out, pareto_out, trial_out,
                           cluster_out, s1_out, s2_out, s3_out, s4_out,
                           dl_out, mismatch_out])

# gr.Progress only streams updates through the request queue (opt-in on
# Gradio 3.x, already the default on 4.x). The queue also keeps the
# multi-minute pipeline run off a plain HTTP request.
demo.queue()

if __name__ == "__main__":
    # Defaults are unchanged; the standard Gradio env vars can override them.
    demo.launch(server_name=os.getenv("GRADIO_SERVER_NAME", "0.0.0.0"),
                server_port=int(os.getenv("GRADIO_SERVER_PORT", "7860")))