"""app.py — Gradio UI entry point (<200 lines, §11)."""
import json
import os
import tempfile
import time

import gradio as gr
import numpy as np
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go

from agent import run_pipeline
|
|
| |
def _preview(file):
    """Summarise an uploaded Scopus CSV as a Markdown health report.

    Args:
        file: Upload value — an object exposing the path as ``.name`` or the
            path itself; falsy when nothing has been uploaded.

    Returns:
        Markdown text: a prompt when no file is given, an error note when
        pandas cannot read the file, otherwise a headline, a presence /
        blank-count table for the ``title`` and ``abstract`` columns
        (matched case-insensitively) and the number of usable rows.
    """
    if not file:
        return "Upload a Scopus CSV to begin."
    # Handlers may receive a file wrapper (``.name``) or a bare path string.
    path = getattr(file, "name", file)
    try:
        df = pd.read_csv(path)
    except (pd.errors.EmptyDataError, pd.errors.ParserError,
            UnicodeDecodeError) as exc:
        return f"## ❌ Could not read CSV\n\n{exc}"
    df.columns = df.columns.str.lower()

    total = len(df)
    has_title = "title" in df.columns
    has_abstract = "abstract" in df.columns
    # A missing column counts as blank in every row.
    blank_title = int(df["title"].isna().sum()) if has_title else total
    blank_abstract = (int(df["abstract"].isna().sum())
                      if has_abstract else total)
    # A paper is usable only when the same row carries both fields.
    # Subtracting the larger blank count would overstate this whenever the
    # blanks fall on different rows.
    if has_title and has_abstract:
        usable = int(df[["title", "abstract"]].notna().all(axis=1).sum())
    else:
        usable = 0

    def _mark(present):
        return "✅" if present else "❌"

    return (f"## {_mark(usable > 0)} CSV loaded — {total} entries\n\n"
            "| Column | Present | Blank rows |\n|---|---|---|\n"
            f"| title | {_mark(has_title)} | {blank_title} |\n"
            f"| abstract | {_mark(has_abstract)} | {blank_abstract} |\n\n"
            f"**Usable papers:** {usable} / {total}")
|
|
| |
def _run(file, gk, mk, gek, n_trials, progress=gr.Progress(track_tqdm=True)):
    """Run the full pipeline on the uploaded CSV and build every UI output.

    Args:
        file: Uploaded CSV — an object exposing its path as ``.name`` or the
            path itself.
        gk: Groq API key; blank falls back to ``GROQ_API_KEY``.
        mk: Mistral API key; blank falls back to ``MISTRAL_API_KEY``.
        gek: Gemini API key; blank falls back to ``GEMINI_API_KEY``.
        n_trials: Trial count, forwarded to ``run_pipeline`` as an ``int``.
        progress: Gradio progress tracker.

    Returns:
        An 11-tuple: summary Markdown, UMAP scatter figure, Pareto figure,
        trial-log frame, cluster frame, sheets 1-4 as frames, the list of
        download paths (``None`` when there are none) and the mismatch frame.

    Raises:
        gr.Error: If no file is given, any API key is missing, or the
            pipeline result carries a truthy ``"error"`` entry.
    """
    if not file:
        raise gr.Error("Upload a CSV first.")
    # Blank (or missing) textbox values fall back to the environment.
    gk = (gk or "").strip() or os.getenv("GROQ_API_KEY", "")
    mk = (mk or "").strip() or os.getenv("MISTRAL_API_KEY", "")
    gek = (gek or "").strip() or os.getenv("GEMINI_API_KEY", "")
    if not all([gk, mk, gek]):
        raise gr.Error("All 3 API keys required.")

    # NOTE(review): the status glyphs in this function were restored from
    # mis-decoded text; confirm them against the intended UI copy.
    progress(0.05, desc="📥 Loading CSV…")
    progress(0.1, desc="🔬 Embedding with SPECTER-2 (this takes a few minutes)…")
    # Handlers may receive a file wrapper (``.name``) or a bare path string.
    r = run_pipeline(getattr(file, "name", file), gk, mk, gek, int(n_trials))
    if r.get("error"):
        raise gr.Error(r["error"])
    progress(0.95, desc="📊 Building outputs…")

    td, interps = r["topic_data"], r.get("interpretations", {})
    disc, met = td["discipline"], td["metrics"]
    ar = r.get("agreement_rates", {})

    def _s(ok):
        return "✅ PASS" if ok else "❌ FAIL"

    # --- Summary tab -------------------------------------------------------
    summary = (
        f"## Pipeline Complete — {disc['n_clusters']} clusters discovered\n\n"
        "| Criterion | Value | Status |\n|---|---|---|\n"
        f"| Max cluster mass | {round(disc['max_mass_pct'] * 100, 1)}% "
        f"| {_s(disc['max_mass_ok'])} |\n"
        f"| Min cluster size | {disc['min_size']} | {_s(disc['min_size_ok'])} |\n"
        f"| Persistence (mean) | {round(met['persistence'], 4)} | — |\n"
        f"| DBCV | {round(met['dbcv'], 4)} | — |\n"
        f"| Stability (3 seeds) | {round(met['stability'], 4)} | — |\n\n"
        f"**Trials:** {td['n_trials_run']} (best #{td['best_trial']}) · "
        f"**Agreement:** Triple {ar.get('triple', 0)}% · "
        f"Two+ {ar.get('two_or_more', 0)}%"
    )

    # --- 2-D UMAP scatter --------------------------------------------------
    u2d = np.array(td["umap_2d"])
    sdf = pd.DataFrame({
        "UMAP-1": u2d[:, 0],
        "UMAP-2": u2d[:, 1],
        # String labels so each cluster gets its own discrete colour.
        "Cluster": [str(label) for label in td["labels"]],
        # Hover text is cut to the first 60 characters of each document.
        "Doc": [doc[:60] for doc in td["documents"]],
    })
    fig = px.scatter(sdf, x="UMAP-1", y="UMAP-2", color="Cluster",
                     hover_data=["Doc"], opacity=0.75,
                     title="2-D UMAP visualisation of SPECTER-2 embeddings")
    fig.update_layout(template="plotly_dark", height=500,
                      paper_bgcolor="#0d1117", plot_bgcolor="#161b22",
                      font=dict(size=11))

    # --- Trial log ---------------------------------------------------------
    tl = pd.DataFrame(td["trial_log"])
    wanted = ["trial", "discipline_pass", "n_clusters", "persistence",
              "dbcv", "max_mass_pct", "min_size", "n_noise"]
    tl_cols = [c for c in wanted if c in tl.columns]
    tl_show = tl[tl_cols] if not tl.empty else pd.DataFrame()

    # --- Pareto front ------------------------------------------------------
    pfig = go.Figure()
    # Only plot when every column the traces read is present; an empty or
    # partial trial log then yields an empty figure instead of a KeyError.
    if {"trial", "discipline_pass", "max_mass_pct", "persistence"} <= set(tl.columns):
        for passed, color, name in [(True, "#3dba7a", "PASS"),
                                    (False, "#e04d4d", "FAIL")]:
            sub = tl[tl["discipline_pass"] == passed]
            if not sub.empty:
                pfig.add_trace(go.Scatter(
                    x=sub["max_mass_pct"], y=sub["persistence"],
                    mode="markers", marker=dict(size=8, color=color),
                    name=name, text=sub["trial"],
                    hovertemplate=("Trial %{text}<br>Mass: %{x:.0%}"
                                   "<br>Pers: %{y:.3f}")))
        pfig.add_vline(x=0.25, line_dash="dash", line_color="#5a6480",
                       annotation_text="25% rule")
    pfig.update_layout(template="plotly_dark", height=400,
                       paper_bgcolor="#0d1117", plot_bgcolor="#161b22",
                       title="Pareto front — Persistence vs Max cluster mass",
                       xaxis_title="Max cluster mass (lower is better)",
                       yaxis_title="Persistence (higher is better)",
                       font=dict(size=11))

    # --- Cluster interpretations ------------------------------------------
    cdf = pd.DataFrame([
        {"Cluster": cid,
         "Label": interps[cid]["label"],
         "Agreement": interps[cid]["agreement"],
         "Strong": interps[cid]["strong"],
         "Weak": interps[cid]["weak"],
         "Persistence": round(interps[cid].get("persistence", 0), 4),
         "Keyphrases": ", ".join(interps[cid].get("keyphrases", []))}
        for cid in sorted(interps)
    ])

    # --- Sheets, mismatch table and downloads ------------------------------
    sheets = r.get("sheets", {})
    s1, s2, s3, s4 = (pd.DataFrame(sheets.get(i, [])) for i in range(1, 5))
    sp = r.get("sheet_paths", {})
    mdf = pd.DataFrame(r.get("mismatch_table", []))
    progress(1.0, desc="✅ Done!")
    candidates = [sp.get(1), sp.get(2), sp.get(3), sp.get(4), r.get("json_path")]
    dl_files = [p for p in candidates if p is not None]
    return (summary, fig, pfig, tl_show, cdf, s1, s2, s3, s4,
            dl_files or None, mdf)
|
|
| |
# Dark page background; the Gradio footer is hidden.
css = (
    ".gradio-container{background:#0d1117!important;color:#c9d1d9!important}"
    "footer{display:none!important}"
)

# Layout: inputs in a narrow left column, tabbed results in a wide right one.
with gr.Blocks(theme=gr.themes.Base(primary_hue="blue", neutral_hue="slate"),
               css=css, title="SPECTER-2 Topic Analyzer") as demo:
    # NOTE(review): the heading, button and tab glyphs were restored from
    # mis-decoded text; confirm them against the intended UI copy.
    gr.Markdown("# 📊 SPECTER-2 Topic Analyzer")
    with gr.Row():
        with gr.Column(scale=1):
            file_in = gr.File(label="Upload Scopus CSV", file_types=[".csv"])
            preview_out = gr.Markdown("Upload a CSV to see stats.")
            groq_in = gr.Textbox(label="Groq API Key", type="password",
                                 placeholder="or set GROQ_API_KEY env var")
            mistral_in = gr.Textbox(label="Mistral API Key", type="password",
                                    placeholder="or set MISTRAL_API_KEY env var")
            gemini_in = gr.Textbox(label="Gemini API Key", type="password",
                                   placeholder="or set GEMINI_API_KEY env var")
            trials_in = gr.Slider(minimum=10, maximum=100, value=50, step=5,
                                  label="Optuna Trials")
            run_btn = gr.Button("▶ Run Full Pipeline", variant="primary",
                                size="lg")
        with gr.Column(scale=3):
            with gr.Tabs():
                with gr.Tab("Summary"):
                    summary_out = gr.Markdown()
                with gr.Tab("2-D UMAP"):
                    scatter_out = gr.Plot()
                with gr.Tab("Pareto Front"):
                    pareto_out = gr.Plot()
                with gr.Tab("Trial Log"):
                    trial_out = gr.Dataframe()
                with gr.Tab("Clusters"):
                    cluster_out = gr.Dataframe()
                with gr.Tab("Sheet 1 — Groq"):
                    s1_out = gr.Dataframe()
                with gr.Tab("Sheet 2 — Mistral"):
                    s2_out = gr.Dataframe()
                with gr.Tab("Sheet 3 — Gemini"):
                    s3_out = gr.Dataframe()
                with gr.Tab("Sheet 4 — Consolidated"):
                    s4_out = gr.Dataframe()
                with gr.Tab("RQ Mismatch"):
                    mismatch_out = gr.Dataframe()
                with gr.Tab("Downloads"):
                    dl_out = gr.File(label="All sheet CSVs + topics.json",
                                     file_count="multiple")

    # Refresh the CSV health report whenever the upload changes.
    file_in.change(_preview, inputs=[file_in], outputs=[preview_out])
    # The output order must match the tuple returned by _run.
    run_btn.click(
        _run,
        inputs=[file_in, groq_in, mistral_in, gemini_in, trials_in],
        outputs=[summary_out, scatter_out, pareto_out, trial_out, cluster_out,
                 s1_out, s2_out, s3_out, s4_out, dl_out, mismatch_out],
    )
|
|
if __name__ == "__main__":
    # The pipeline runs for minutes and reports through gr.Progress, so the
    # request queue is enabled explicitly for Gradio releases where it is
    # not on by default. Listens on all interfaces, port 7860.
    demo.queue().launch(server_name="0.0.0.0", server_port=7860)
|
|