"""app.py — Gradio UI entry point (<200 lines, §11)."""
import os, json, tempfile, time
import pandas as pd, numpy as np
import gradio as gr
import plotly.express as px
import plotly.graph_objects as go
from agent import run_pipeline
# ── CSV preview on upload ────────────────────────────────────────────────────
def _preview(file):
    """Return a Markdown health report for a freshly uploaded CSV.

    Column names are lower-cased before lookup, so ``Title`` / ``Abstract``
    headers are matched case-insensitively.

    Args:
        file: Value of the ``gr.File`` input: an object exposing the path as
            ``.name``, a plain path string, or a falsy value when the input
            is cleared.

    Returns:
        A Markdown string: a headline with the row count, a table with
        presence and blank-row counts for ``title`` and ``abstract``, and the
        number of rows usable by the pipeline.
    """
    if not file:
        return "Upload a Scopus CSV to begin."

    # Accept both a tempfile-like wrapper and a bare path.
    path = getattr(file, "name", file)
    df = pd.read_csv(path)
    df.columns = df.columns.str.lower()

    n = len(df)
    has_t = "title" in df.columns
    has_a = "abstract" in df.columns
    # A missing column counts as blank on every row.
    blanks_t = int(df["title"].isna().sum()) if has_t else n
    blanks_a = int(df["abstract"].isna().sum()) if has_a else n

    # A row is usable only when BOTH fields are present. Subtracting the
    # larger blank count over-reports whenever the blanks sit on different
    # rows, so count the intersection directly.
    if has_t and has_a:
        usable = int((df["title"].notna() & df["abstract"].notna()).sum())
    else:
        usable = 0

    ok = "✅" if has_t and has_a and blanks_t < n and blanks_a < n else "❌"
    return (f"## {ok} CSV loaded — {n} entries\n\n"
            "| Column | Present | Blank rows |\n|---|---|---|\n"
            f"| title | {'✅' if has_t else '❌'} | {blanks_t} |\n"
            f"| abstract | {'✅' if has_a else '❌'} | {blanks_a} |\n\n"
            f"**Usable papers:** {usable} / {n}")
# ── Pipeline runner ──────────────────────────────────────────────────────────
def _run(file, gk, mk, gek, n_trials, progress=gr.Progress(track_tqdm=True)):
    """Run the full pipeline on the uploaded CSV and build every UI output.

    Args:
        file: Value of the ``gr.File`` input (object with ``.name`` or a path).
        gk: Groq API key from the UI; blank falls back to ``GROQ_API_KEY``.
        mk: Mistral API key from the UI; blank falls back to ``MISTRAL_API_KEY``.
        gek: Gemini API key from the UI; blank falls back to ``GEMINI_API_KEY``.
        n_trials: Trial count from the slider; cast to ``int`` before use.
        progress: Gradio progress tracker used for the status messages.

    Returns:
        An 11-tuple in the order expected by the click handler: summary
        Markdown, UMAP scatter figure, Pareto figure, trial-log frame,
        cluster frame, sheet 1-4 frames, list of download paths (``None``
        when there are none), and the mismatch frame.

    Raises:
        gr.Error: No file was uploaded, one of the three keys is missing, or
            the pipeline result carries a truthy ``"error"`` entry.
    """
    if not file:
        raise gr.Error("Upload a CSV first.")

    # Textbox values may be None; treat that like an empty field.
    gk = (gk or "").strip() or os.getenv("GROQ_API_KEY", "")
    mk = (mk or "").strip() or os.getenv("MISTRAL_API_KEY", "")
    gek = (gek or "").strip() or os.getenv("GEMINI_API_KEY", "")
    if not all([gk, mk, gek]):
        raise gr.Error("All 3 API keys required.")

    # NOTE(review): the emoji in the progress messages were reconstructed from
    # mis-encoded text — confirm against the intended glyphs.
    progress(0.05, desc="📥 Loading CSV…")
    progress(0.1, desc="🔬 Embedding with SPECTER-2 (this takes a few minutes)…")
    path = getattr(file, "name", file)  # file wrapper or bare path
    r = run_pipeline(path, gk, mk, gek, int(n_trials))
    if r.get("error"):
        raise gr.Error(r["error"])

    progress(0.95, desc="📊 Building outputs…")
    td, interps = r["topic_data"], r.get("interpretations", {})
    disc, met = td["discipline"], td["metrics"]
    ar = r.get("agreement_rates", {})

    # ── Summary metrics ──
    def _status(ok):
        return "✅ PASS" if ok else "❌ FAIL"

    summary = (
        f"## Pipeline Complete — {disc['n_clusters']} clusters discovered\n\n"
        "| Criterion | Value | Status |\n|---|---|---|\n"
        f"| Max cluster mass | {round(disc['max_mass_pct']*100,1)}% | {_status(disc['max_mass_ok'])} |\n"
        f"| Min cluster size | {disc['min_size']} | {_status(disc['min_size_ok'])} |\n"
        f"| Persistence (mean) | {round(met['persistence'],4)} | — |\n"
        f"| DBCV | {round(met['dbcv'],4)} | — |\n"
        f"| Stability (3 seeds) | {round(met['stability'],4)} | — |\n\n"
        f"**Trials:** {td['n_trials_run']} (best #{td['best_trial']}) · "
        f"**Agreement:** Triple {ar.get('triple',0)}% · Two+ {ar.get('two_or_more',0)}%"
    )

    # ── UMAP scatter ──
    u2d = np.array(td["umap_2d"])
    sdf = pd.DataFrame({
        "UMAP-1": u2d[:, 0],
        "UMAP-2": u2d[:, 1],
        "Cluster": [str(label) for label in td["labels"]],  # str -> discrete colours
        "Doc": [d[:60] for d in td["documents"]],            # short hover text
    })
    fig = px.scatter(sdf, x="UMAP-1", y="UMAP-2", color="Cluster",
                     hover_data=["Doc"], opacity=0.75,
                     title="2-D UMAP visualisation of SPECTER-2 embeddings")
    fig.update_layout(template="plotly_dark", height=500,
                      paper_bgcolor="#0d1117", plot_bgcolor="#161b22",
                      font=dict(size=11))

    # ── Trial log ──
    tl = pd.DataFrame(td["trial_log"])
    wanted = ["trial", "discipline_pass", "n_clusters", "persistence",
              "dbcv", "max_mass_pct", "min_size", "n_noise"]
    tl_cols = [c for c in wanted if c in tl.columns]
    tl_show = tl[tl_cols] if not tl.empty else pd.DataFrame()

    # ── Pareto front ──
    pfig = go.Figure()
    # The table above tolerates missing columns; apply the same tolerance here
    # instead of raising KeyError on a partial trial log.
    pareto_cols = {"discipline_pass", "max_mass_pct", "persistence", "trial"}
    if not tl.empty and pareto_cols.issubset(tl.columns):
        for passed, color, name in [(True, "#3dba7a", "PASS"),
                                    (False, "#e04d4d", "FAIL")]:
            sub = tl[tl["discipline_pass"] == passed]
            if not sub.empty:
                pfig.add_trace(go.Scatter(
                    x=sub["max_mass_pct"], y=sub["persistence"],
                    mode="markers", marker=dict(size=8, color=color), name=name,
                    text=sub["trial"],
                    hovertemplate="Trial %{text}<br>Mass: %{x:.0%}<br>Pers: %{y:.3f}"))
        pfig.add_vline(x=0.25, line_dash="dash", line_color="#5a6480",
                       annotation_text="25% rule")
    pfig.update_layout(template="plotly_dark", height=400,
                       paper_bgcolor="#0d1117", plot_bgcolor="#161b22",
                       title="Pareto front — Persistence vs Max cluster mass",
                       xaxis_title="Max cluster mass (lower is better)",
                       yaxis_title="Persistence (higher is better)",
                       font=dict(size=11))

    # ── Cluster table ──
    rows = []
    for cid in sorted(interps.keys()):
        v = interps[cid]
        rows.append({"Cluster": cid, "Label": v["label"], "Agreement": v["agreement"],
                     "Strong": v["strong"], "Weak": v["weak"],
                     "Persistence": round(v.get("persistence", 0), 4),
                     "Keyphrases": ", ".join(v.get("keyphrases", []))})
    cdf = pd.DataFrame(rows)

    # ── 4 separate sheets ──
    sheets = r.get("sheets", {})
    s1, s2, s3, s4 = (pd.DataFrame(sheets.get(i, [])) for i in (1, 2, 3, 4))
    sp = r.get("sheet_paths", {})
    mdf = pd.DataFrame(r.get("mismatch_table", []))

    progress(1.0, desc="✅ Done!")
    # Only offer files the pipeline actually produced.
    candidates = [sp.get(1), sp.get(2), sp.get(3), sp.get(4), r.get("json_path")]
    dl_files = [f for f in candidates if f is not None]
    return (summary, fig, pfig, tl_show, cdf, s1, s2, s3, s4,
            dl_files if dl_files else None, mdf)
# ── UI ───────────────────────────────────────────────────────────────────────
# Page-level CSS: dark container colours and a hidden Gradio footer.
css = ".gradio-container{background:#0d1117!important;color:#c9d1d9!important}" \
      "footer{display:none!important}"

# Layout: a narrow input column (upload, keys, trial count, run button) next
# to a wide column of result tabs.
with gr.Blocks(theme=gr.themes.Base(primary_hue="blue", neutral_hue="slate"),
               css=css, title="SPECTER-2 Topic Analyzer") as demo:
    # NOTE(review): the heading emoji was reconstructed from mis-encoded
    # text — confirm the intended glyph.
    gr.Markdown("# 📚 SPECTER-2 Topic Analyzer")
    with gr.Row():
        with gr.Column(scale=1):
            file_in = gr.File(label="Upload Scopus CSV", file_types=[".csv"])
            preview_out = gr.Markdown("Upload a CSV to see stats.")
            groq_in = gr.Textbox(label="Groq API Key", type="password",
                                 placeholder="or set GROQ_API_KEY env var")
            mistral_in = gr.Textbox(label="Mistral API Key", type="password",
                                    placeholder="or set MISTRAL_API_KEY env var")
            gemini_in = gr.Textbox(label="Gemini API Key", type="password",
                                   placeholder="or set GEMINI_API_KEY env var")
            trials_in = gr.Slider(10, 100, 50, step=5, label="Optuna Trials")
            run_btn = gr.Button("▶ Run Full Pipeline", variant="primary", size="lg")
        with gr.Column(scale=3):
            with gr.Tabs():
                with gr.Tab("Summary"):
                    summary_out = gr.Markdown()
                with gr.Tab("2-D UMAP"):
                    scatter_out = gr.Plot()
                with gr.Tab("Pareto Front"):
                    pareto_out = gr.Plot()
                with gr.Tab("Trial Log"):
                    trial_out = gr.Dataframe()
                with gr.Tab("Clusters"):
                    cluster_out = gr.Dataframe()
                with gr.Tab("Sheet 1 — Groq"):
                    s1_out = gr.Dataframe()
                with gr.Tab("Sheet 2 — Mistral"):
                    s2_out = gr.Dataframe()
                with gr.Tab("Sheet 3 — Gemini"):
                    s3_out = gr.Dataframe()
                with gr.Tab("Sheet 4 — Consolidated"):
                    s4_out = gr.Dataframe()
                with gr.Tab("RQ Mismatch"):
                    mismatch_out = gr.Dataframe()
                with gr.Tab("Downloads"):
                    dl_out = gr.File(label="All sheet CSVs + topics.json",
                                     file_count="multiple")

    # Refresh the CSV report whenever the upload changes.
    file_in.change(_preview, inputs=[file_in], outputs=[preview_out])
    # The outputs list must stay in the same order as the tuple _run returns.
    run_btn.click(_run,
                  inputs=[file_in, groq_in, mistral_in, gemini_in, trials_in],
                  outputs=[summary_out, scatter_out, pareto_out, trial_out, cluster_out,
                           s1_out, s2_out, s3_out, s4_out, dl_out, mismatch_out])

if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860)