File size: 8,424 Bytes
ee50027
91b56e9
ee50027
 
05df72c
91b56e9
ee50027
 
91b56e9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ee50027
 
 
91b56e9
 
 
 
 
 
 
 
 
 
 
 
ee50027
 
 
 
 
91b56e9
 
 
 
 
 
ee50027
91b56e9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0a624b3
05df72c
ee50027
 
91b56e9
ee50027
 
 
91b56e9
ee50027
91b56e9
 
 
 
 
e1de3b9
 
 
 
 
05df72c
91b56e9
 
 
 
 
 
ee50027
 
 
91b56e9
 
 
 
 
 
 
 
 
ee50027
 
91b56e9
 
 
 
 
e100b63
 
 
 
91b56e9
e100b63
 
 
91b56e9
ee50027
 
91b56e9
e100b63
ee50027
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
"""app.py β€” Gradio UI entry point (<200 lines, Β§11)."""
import os, json, tempfile, time
import pandas as pd, numpy as np
import gradio as gr
import plotly.express as px
import plotly.graph_objects as go
from agent import run_pipeline

# ── CSV preview on upload ────────────────────────────────────────────────────
def _preview(file):
    if not file: return "Upload a Scopus CSV to begin."
    df = pd.read_csv(file.name)
    df.columns = df.columns.str.lower()
    has_t = "title" in df.columns
    has_a = "abstract" in df.columns
    n = len(df)
    blanks_t = int(df["title"].isna().sum()) if has_t else n
    blanks_a = int(df["abstract"].isna().sum()) if has_a else n
    ok = "βœ…" if has_t and has_a and blanks_t < n and blanks_a < n else "❌"
    return (f"## {ok} CSV loaded β€” {n} entries\n\n"
        f"| Column | Present | Blank rows |\n|---|---|---|\n"
        f"| title | {'βœ…' if has_t else '❌'} | {blanks_t} |\n"
        f"| abstract | {'βœ…' if has_a else '❌'} | {blanks_a} |\n\n"
        f"**Usable papers:** {n - max(blanks_t,blanks_a)} / {n}")

# ── Pipeline runner ──────────────────────────────────────────────────────────
def _run(file, gk, mk, gek, n_trials, progress=gr.Progress(track_tqdm=True)):
    """Run the analysis pipeline on the uploaded CSV and build all UI outputs.

    Blank key boxes fall back to the GROQ_API_KEY / MISTRAL_API_KEY /
    GEMINI_API_KEY env vars. Returns (summary Markdown, UMAP fig, Pareto fig,
    trial log, cluster table, sheets 1-4, download paths or None, RQ-mismatch
    table). Raises gr.Error on missing file/keys or a pipeline-reported error.
    """
    if not file: raise gr.Error("Upload a CSV first.")
    gk = gk.strip() or os.getenv("GROQ_API_KEY","")
    mk = mk.strip() or os.getenv("MISTRAL_API_KEY","")
    gek = gek.strip() or os.getenv("GEMINI_API_KEY","")
    if not all([gk,mk,gek]): raise gr.Error("All 3 API keys required.")
    progress(0.05, desc="📥 Loading CSV…")
    progress(0.1, desc="🔬 Embedding with SPECTER-2 (this takes a few minutes)…")
    # Accept a plain path string as well as a file wrapper exposing .name.
    r = run_pipeline(getattr(file, "name", file), gk, mk, gek, int(n_trials))
    if r.get("error"): raise gr.Error(r["error"])
    progress(0.95, desc="📊 Building outputs…")
    td, interps = r["topic_data"], r.get("interpretations",{})
    disc, met = td["discipline"], td["metrics"]
    ar = r.get("agreement_rates",{})
    # ── Summary metrics (styled like reference) ──
    def _s(ok): return "✅ PASS" if ok else "❌ FAIL"
    summary = (f"## Pipeline Complete — {disc['n_clusters']} clusters discovered\n\n"
        f"| Criterion | Value | Status |\n|---|---|---|\n"
        f"| Max cluster mass | {round(disc['max_mass_pct']*100,1)}% | {_s(disc['max_mass_ok'])} |\n"
        f"| Min cluster size | {disc['min_size']} | {_s(disc['min_size_ok'])} |\n"
        f"| Persistence (mean) | {round(met['persistence'],4)} | — |\n"
        f"| DBCV | {round(met['dbcv'],4)} | — |\n"
        f"| Stability (3 seeds) | {round(met['stability'],4)} | — |\n\n"
        f"**Trials:** {td['n_trials_run']} (best #{td['best_trial']}) · "
        f"**Agreement:** Triple {ar.get('triple',0)}% · Two+ {ar.get('two_or_more',0)}%")
    # ── UMAP scatter ──
    u2d = np.asarray(td["umap_2d"])
    if u2d.size == 0: u2d = u2d.reshape(0, 2)  # empty result: keep 2 columns so slicing works
    sdf = pd.DataFrame({"UMAP-1":u2d[:,0],"UMAP-2":u2d[:,1],
        "Cluster":[str(lab) for lab in td["labels"]],
        "Doc":[str(d)[:60] for d in td["documents"]]})  # str(): tolerate non-string docs
    fig = px.scatter(sdf, x="UMAP-1", y="UMAP-2", color="Cluster",
        hover_data=["Doc"], opacity=0.75,
        title="2-D UMAP visualisation of SPECTER-2 embeddings")
    fig.update_layout(template="plotly_dark", height=500,
        paper_bgcolor="#0d1117", plot_bgcolor="#161b22",
        font=dict(size=11))
    # ── Trial log ──
    tl = pd.DataFrame(td["trial_log"])
    tl_cols = [c for c in ["trial","discipline_pass","n_clusters","persistence",
        "dbcv","max_mass_pct","min_size","n_noise"] if c in tl.columns]
    tl_show = tl[tl_cols] if not tl.empty else pd.DataFrame()
    # ── Pareto front ── (like tl_cols above, tolerate logs missing the columns it plots)
    pfig = go.Figure()
    if not tl.empty and {"discipline_pass","max_mass_pct","persistence"}.issubset(tl.columns):
        for passed, color, name in [(True,"#3dba7a","PASS"),(False,"#e04d4d","FAIL")]:
            sub = tl[tl["discipline_pass"]==passed]
            if not sub.empty:
                pfig.add_trace(go.Scatter(x=sub["max_mass_pct"],y=sub["persistence"],
                    mode="markers",marker=dict(size=8,color=color),name=name,
                    text=sub["trial"] if "trial" in sub else sub.index,
                    hovertemplate="Trial %{text}<br>Mass: %{x:.0%}<br>Pers: %{y:.3f}"))
        pfig.add_vline(x=0.25, line_dash="dash", line_color="#5a6480",
            annotation_text="25% rule")
    pfig.update_layout(template="plotly_dark", height=400,
        paper_bgcolor="#0d1117", plot_bgcolor="#161b22",
        title="Pareto front — Persistence vs Max cluster mass",
        xaxis_title="Max cluster mass (lower is better)",
        yaxis_title="Persistence (higher is better)", font=dict(size=11))
    # ── Cluster table ──
    rows = []
    for cid in sorted(interps.keys()):
        v = interps[cid]
        rows.append({"Cluster":cid,"Label":v["label"],"Agreement":v["agreement"],
            "Strong":v["strong"],"Weak":v["weak"],
            "Persistence":round(v.get("persistence",0),4),
            "Keyphrases":", ".join(v.get("keyphrases",[]))})
    cdf = pd.DataFrame(rows)
    # ── 4 separate sheets ──
    sheets = r.get("sheets",{})
    s1 = pd.DataFrame(sheets.get(1,[])); s2 = pd.DataFrame(sheets.get(2,[]))
    s3 = pd.DataFrame(sheets.get(3,[])); s4 = pd.DataFrame(sheets.get(4,[]))
    sp = r.get("sheet_paths",{})
    mdf = pd.DataFrame(r.get("mismatch_table",[]))
    progress(1.0, desc="✅ Done!")
    dl_files = [f for f in
        [sp.get(1), sp.get(2), sp.get(3), sp.get(4), r.get("json_path")]
        if f is not None]
    return (summary, fig, pfig, tl_show, cdf, s1, s2, s3, s4,
            dl_files if dl_files else None, mdf)

# ── UI ───────────────────────────────────────────────────────────────────────
css = (".gradio-container{background:#0d1117!important;color:#c9d1d9!important}"
       "footer{display:none!important}")
with gr.Blocks(theme=gr.themes.Base(primary_hue="blue",neutral_hue="slate"),
               css=css, title="SPECTER-2 Topic Analyzer") as demo:
    gr.Markdown("# SPECTER-2 Topic Analyzer")
    with gr.Row():
        # Left column: upload, key inputs and run controls.
        with gr.Column(scale=1):
            file_in = gr.File(label="Upload Scopus CSV", file_types=[".csv"])
            preview_out = gr.Markdown("Upload a CSV to see stats.")
            groq_in = gr.Textbox(label="Groq API Key", type="password",
                placeholder="or set GROQ_API_KEY env var")
            mistral_in = gr.Textbox(label="Mistral API Key", type="password",
                placeholder="or set MISTRAL_API_KEY env var")
            gemini_in = gr.Textbox(label="Gemini API Key", type="password",
                placeholder="or set GEMINI_API_KEY env var")
            trials_in = gr.Slider(10,100,50,step=5,label="Optuna Trials")
            run_btn = gr.Button("▶ Run Full Pipeline", variant="primary", size="lg")
        # Right column: one tab per pipeline output.
        with gr.Column(scale=3):
            with gr.Tabs():
                with gr.Tab("Summary"): summary_out = gr.Markdown()
                with gr.Tab("2-D UMAP"): scatter_out = gr.Plot()
                with gr.Tab("Pareto Front"): pareto_out = gr.Plot()
                with gr.Tab("Trial Log"): trial_out = gr.Dataframe()
                with gr.Tab("Clusters"): cluster_out = gr.Dataframe()
                with gr.Tab("Sheet 1 — Groq"): s1_out = gr.Dataframe()
                with gr.Tab("Sheet 2 — Mistral"): s2_out = gr.Dataframe()
                with gr.Tab("Sheet 3 — Gemini"): s3_out = gr.Dataframe()
                with gr.Tab("Sheet 4 — Consolidated"): s4_out = gr.Dataframe()
                with gr.Tab("RQ Mismatch"): mismatch_out = gr.Dataframe()
                with gr.Tab("Downloads"):
                    dl_out = gr.File(label="All sheet CSVs + topics.json",
                                     file_count="multiple")
    # Re-render the CSV stats on every upload change (a cleared upload yields None).
    file_in.change(_preview, inputs=[file_in], outputs=[preview_out])
    # This outputs order must match the tuple returned by _run.
    run_btn.click(_run,
        inputs=[file_in, groq_in, mistral_in, gemini_in, trials_in],
        outputs=[summary_out, scatter_out, pareto_out, trial_out, cluster_out,
                 s1_out, s2_out, s3_out, s4_out, dl_out, mismatch_out])

if __name__ == "__main__":
    # Listen on all interfaces (not only localhost) at the fixed port 7860.
    demo.launch(server_port=7860, server_name="0.0.0.0")