BHAVIKBANKER commited on
Commit
853e1a5
Β·
verified Β·
1 Parent(s): f69277f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +390 -63
app.py CHANGED
@@ -1,12 +1,16 @@
1
- """app.py β€” Gradio UI entry point (<200 lines, Β§11)."""
2
- import os, json, tempfile, time
 
 
 
 
3
  import pandas as pd, numpy as np
4
  import gradio as gr
5
  import plotly.express as px
6
  import plotly.graph_objects as go
7
- from agent import run_pipeline
8
 
9
- # ── CSV preview on upload ────────────────────────────────────────────────────
10
  def _preview(file):
11
  if not file: return "Upload a Scopus CSV to begin."
12
  df = pd.read_csv(file.name)
@@ -19,53 +23,284 @@ def _preview(file):
19
  ok = "βœ…" if has_t and has_a and blanks_t < n and blanks_a < n else "❌"
20
  return (f"## {ok} CSV loaded β€” {n} entries\n\n"
21
  f"| Column | Present | Blank rows |\n|---|---|---|\n"
22
- f"| title | {'βœ…' if has_t else '❌'} | {blanks_t} |\n"
23
  f"| abstract | {'βœ…' if has_a else '❌'} | {blanks_a} |\n\n"
24
- f"**Usable papers:** {n - max(blanks_t,blanks_a)} / {n}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
25
 
26
  # ── Pipeline runner ──────────────────────────────────────────────────────────
27
- def _run(file, gk, mk, gek, n_trials, progress=gr.Progress(track_tqdm=True)):
 
28
  if not file: raise gr.Error("Upload a CSV first.")
29
- gk = gk.strip() or os.getenv("GROQ_API_KEY","")
30
- mk = mk.strip() or os.getenv("MISTRAL_API_KEY","")
31
  gek = gek.strip() or os.getenv("GEMINI_API_KEY","")
32
  if not all([gk,mk,gek]): raise gr.Error("All 3 API keys required.")
 
33
  progress(0.05, desc="πŸ“₯ Loading CSV…")
34
- progress(0.1, desc="πŸ”¬ Embedding with SPECTER-2 (this takes a few minutes)…")
35
- r = run_pipeline(file.name, gk, mk, gek, int(n_trials))
36
  if r.get("error"): raise gr.Error(r["error"])
37
- progress(0.95, desc="πŸ“Š Building outputs…")
38
- td, interps = r["topic_data"], r.get("interpretations",{})
39
- disc, met = td["discipline"], td["metrics"]
40
- ar = r.get("agreement_rates",{})
41
- # ── Summary metrics (styled like reference) ──
 
 
42
  def _s(ok): return "βœ… PASS" if ok else "❌ FAIL"
43
- summary = (f"## Pipeline Complete β€” {disc['n_clusters']} clusters discovered\n\n"
 
44
  f"| Criterion | Value | Status |\n|---|---|---|\n"
45
  f"| Max cluster mass | {round(disc['max_mass_pct']*100,1)}% | {_s(disc['max_mass_ok'])} |\n"
46
  f"| Min cluster size | {disc['min_size']} | {_s(disc['min_size_ok'])} |\n"
47
  f"| Persistence (mean) | {round(met['persistence'],4)} | β€” |\n"
48
  f"| DBCV | {round(met['dbcv'],4)} | β€” |\n"
49
- f"| Stability ({3} seeds) | {round(met['stability'],4)} | β€” |\n\n"
50
  f"**Trials:** {td['n_trials_run']} (best #{td['best_trial']}) Β· "
51
- f"**Agreement:** Triple {ar.get('triple',0)}% Β· Two+ {ar.get('two_or_more',0)}%")
52
- # ── UMAP scatter ──
 
 
53
  u2d = np.array(td["umap_2d"])
54
  sdf = pd.DataFrame({"UMAP-1":u2d[:,0],"UMAP-2":u2d[:,1],
55
  "Cluster":[str(l) for l in td["labels"]],
56
  "Doc":[d[:60] for d in td["documents"]]})
57
  fig = px.scatter(sdf, x="UMAP-1", y="UMAP-2", color="Cluster",
58
  hover_data=["Doc"], opacity=0.75,
59
- title=f"2-D UMAP visualisation of SPECTER-2 embeddings")
60
  fig.update_layout(template="plotly_dark", height=500,
61
- paper_bgcolor="#0d1117", plot_bgcolor="#161b22",
62
- font=dict(size=11))
63
- # ── Trial log ──
64
  tl = pd.DataFrame(td["trial_log"])
65
  tl_cols = [c for c in ["trial","discipline_pass","n_clusters","persistence",
66
  "dbcv","max_mass_pct","min_size","n_noise"] if c in tl.columns]
67
  tl_show = tl[tl_cols] if not tl.empty else pd.DataFrame()
68
- # ── Pareto front ──
69
  pfig = go.Figure()
70
  if not tl.empty:
71
  for passed, color, name in [(True,"#3dba7a","PASS"),(False,"#e04d4d","FAIL")]:
@@ -74,73 +309,165 @@ def _run(file, gk, mk, gek, n_trials, progress=gr.Progress(track_tqdm=True)):
74
  pfig.add_trace(go.Scatter(x=sub["max_mass_pct"],y=sub["persistence"],
75
  mode="markers",marker=dict(size=8,color=color),name=name,
76
  text=sub["trial"],hovertemplate="Trial %{text}<br>Mass: %{x:.0%}<br>Pers: %{y:.3f}"))
77
- pfig.add_vline(x=0.25, line_dash="dash", line_color="#5a6480",
78
- annotation_text="25% rule")
79
- pfig.update_layout(template="plotly_dark", height=400,
80
- paper_bgcolor="#0d1117", plot_bgcolor="#161b22",
81
  title="Pareto front β€” Persistence vs Max cluster mass",
82
- xaxis_title="Max cluster mass (lower is better)",
83
- yaxis_title="Persistence (higher is better)", font=dict(size=11))
84
- # ── Cluster table ──
85
- rows = []
86
  for cid in sorted(interps.keys()):
87
  v = interps[cid]
88
- rows.append({"Cluster":cid,"Label":v["label"],"Agreement":v["agreement"],
89
  "Strong":v["strong"],"Weak":v["weak"],
90
  "Persistence":round(v.get("persistence",0),4),
91
  "Keyphrases":", ".join(v.get("keyphrases",[]))})
92
- cdf = pd.DataFrame(rows)
93
- # ── 4 separate sheets ──
94
  sheets = r.get("sheets",{})
95
  s1 = pd.DataFrame(sheets.get(1,[])); s2 = pd.DataFrame(sheets.get(2,[]))
96
  s3 = pd.DataFrame(sheets.get(3,[])); s4 = pd.DataFrame(sheets.get(4,[]))
97
  sp = r.get("sheet_paths",{})
98
  mdf = pd.DataFrame(r.get("mismatch_table",[]))
 
 
 
 
 
 
 
 
 
 
 
 
99
  progress(1.0, desc="βœ… Done!")
100
- dl_files = [f for f in
101
- [sp.get(1), sp.get(2), sp.get(3), sp.get(4), r.get("json_path")]
102
- if f is not None]
103
- return (summary, fig, pfig, tl_show, cdf, s1, s2, s3, s4,
104
- dl_files if dl_files else None, mdf)
 
 
 
 
 
 
105
 
106
  # ── UI ─────────────────────��─────────────────────────────────────────────────
107
  css = ".gradio-container{background:#0d1117!important;color:#c9d1d9!important}" \
108
  "footer{display:none!important}"
109
- with gr.Blocks(theme=gr.themes.Base(primary_hue="blue",neutral_hue="slate"),
 
110
  css=css, title="SPECTER-2 Topic Analyzer") as demo:
111
  gr.Markdown("# πŸ“ SPECTER-2 Topic Analyzer")
 
112
  with gr.Row():
113
  with gr.Column(scale=1):
114
- file_in = gr.File(label="Upload Scopus CSV", file_types=[".csv"])
115
  preview_out = gr.Markdown("Upload a CSV to see stats.")
116
- groq_in = gr.Textbox(label="Groq API Key", type="password",
117
- placeholder="or set GROQ_API_KEY env var")
118
  mistral_in = gr.Textbox(label="Mistral API Key", type="password",
119
- placeholder="or set MISTRAL_API_KEY env var")
120
- gemini_in = gr.Textbox(label="Gemini API Key", type="password",
121
- placeholder="or set GEMINI_API_KEY env var")
122
- trials_in = gr.Slider(10,100,50,step=5,label="Optuna Trials")
123
- run_btn = gr.Button("β–Ά Run Full Pipeline", variant="primary", size="lg")
 
 
 
 
 
124
  with gr.Column(scale=3):
125
  with gr.Tabs():
126
- with gr.Tab("Summary"): summary_out = gr.Markdown()
127
- with gr.Tab("2-D UMAP"): scatter_out = gr.Plot()
128
- with gr.Tab("Pareto Front"): pareto_out = gr.Plot()
129
- with gr.Tab("Trial Log"): trial_out = gr.Dataframe()
130
- with gr.Tab("Clusters"): cluster_out = gr.Dataframe()
131
- with gr.Tab("Sheet 1 β€” Groq"): s1_out = gr.Dataframe()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
132
  with gr.Tab("Sheet 2 β€” Mistral"): s2_out = gr.Dataframe()
133
- with gr.Tab("Sheet 3 β€” Gemini"): s3_out = gr.Dataframe()
134
  with gr.Tab("Sheet 4 β€” Consolidated"): s4_out = gr.Dataframe()
135
- with gr.Tab("RQ Mismatch"): mismatch_out = gr.Dataframe()
136
  with gr.Tab("Downloads"):
137
  dl_out = gr.File(label="All sheet CSVs + topics.json",
138
  file_count="multiple")
 
139
  file_in.change(_preview, inputs=[file_in], outputs=[preview_out])
140
- run_btn.click(_run,
141
- inputs=[file_in, groq_in, mistral_in, gemini_in, trials_in],
142
- outputs=[summary_out, scatter_out, pareto_out, trial_out, cluster_out,
143
- s1_out, s2_out, s3_out, s4_out, dl_out, mismatch_out])
 
 
 
 
 
 
 
 
 
 
144
 
145
  if __name__ == "__main__":
146
- demo.launch(server_name="0.0.0.0", server_port=7860)
 
1
+ """app.py β€” Gradio UI entry point.
2
+ Tabs: Summary, UMAP, Pareto, Trial Log, Clusters, Top 3 Papers,
3
+ Methodology (3-LLM council + regex pipeline), Refinement Log,
4
+ Sheet 1-4, RQ Mismatch, Downloads.
5
+ """
6
+ import os, json
7
  import pandas as pd, numpy as np
8
  import gradio as gr
9
  import plotly.express as px
10
  import plotly.graph_objects as go
11
+ from agent import run_pipeline, METHODOLOGY_PATTERNS, TECHNIQUE_PATTERNS
12
 
13
+ # ── CSV preview ──────────────────────────────────────────────────────────────
14
  def _preview(file):
15
  if not file: return "Upload a Scopus CSV to begin."
16
  df = pd.read_csv(file.name)
 
23
  ok = "βœ…" if has_t and has_a and blanks_t < n and blanks_a < n else "❌"
24
  return (f"## {ok} CSV loaded β€” {n} entries\n\n"
25
  f"| Column | Present | Blank rows |\n|---|---|---|\n"
26
+ f"| title | {'βœ…' if has_t else '❌'} | {blanks_t} |\n"
27
  f"| abstract | {'βœ…' if has_a else '❌'} | {blanks_a} |\n\n"
28
+ f"**Usable papers:** {n - max(blanks_t, blanks_a)} / {n}")
29
+
30
+
31
+ # ── Helper builders ──────────────────────────────────────────────────────────
32
+ def _top_papers_df(top_papers: dict) -> pd.DataFrame:
33
+ rows = []
34
+ for cid in sorted(top_papers.keys()):
35
+ for p in top_papers[cid]:
36
+ rows.append({"Cluster": cid, "Label": p["cluster_label"],
37
+ "Rank": p["rank"], "Title": p["title"],
38
+ "Abstract Snippet": p["abstract_snippet"]})
39
+ return pd.DataFrame(rows)
40
+
41
+
42
+ def _methodology_summary_df(methodology_data: dict, interps: dict) -> pd.DataFrame:
43
+ rows = []
44
+ for cid in sorted(methodology_data.keys()):
45
+ md = methodology_data[cid]
46
+ label = interps.get(cid, {}).get("label", f"Cluster {cid}")
47
+ rows.append({
48
+ "Cluster": cid,
49
+ "Label": label,
50
+ "Dominant Method": md.get("dominant_method", "β€”"),
51
+ "Dominant Technique": md.get("dominant_technique", "β€”"),
52
+ "Empirical %": md.get("empirical_pct", 0),
53
+ "Theoretical %": md.get("theoretical_pct", 0),
54
+ "Mixed %": md.get("mixed_pct", 0),
55
+ "Methods (β‰₯2 LLMs)": ", ".join(
56
+ f"{m['name']} ({m['pct']}%, {m['agreement']})"
57
+ for m in md.get("methodologies", [])),
58
+ "Techniques (β‰₯2 LLMs)": ", ".join(
59
+ f"{t['name']} ({t['pct']}%, {t['agreement']})"
60
+ for t in md.get("techniques", [])),
61
+ "Regex Confirmed": ", ".join(md.get("regex_confirmed_consensus", [])) or "β€”",
62
+ "Regex Rejected": ", ".join(md.get("regex_rejected_consensus", [])) or "β€”",
63
+ })
64
+ return pd.DataFrame(rows)
65
+
66
+
67
+ def _extraction_pipeline_df(methodology_data: dict, interps: dict) -> pd.DataFrame:
68
+ """
69
+ One row per (cluster, method/technique) showing the full extraction trace:
70
+ which regex pattern fired, what text it matched, which LLMs confirmed it,
71
+ and whether it passed the β‰₯2-LLM gate.
72
+ """
73
+ rows = []
74
+ for cid in sorted(methodology_data.keys()):
75
+ md = methodology_data[cid]
76
+ label = interps.get(cid, {}).get("label", f"Cluster {cid}")
77
+ scan = md.get("regex_scan", {})
78
+
79
+ # Accepted items
80
+ for item in md.get("methodologies", []) + md.get("techniques", []):
81
+ name = item["name"]
82
+ # Find regex hits for this category name
83
+ regex_hits = scan.get("methods", {}).get(name, []) or \
84
+ scan.get("techniques", {}).get(name, [])
85
+ matched_text = ", ".join(
86
+ dict.fromkeys(h["match"] for h in regex_hits))[:80] if regex_hits else "β€”"
87
+ rows.append({
88
+ "Cluster": cid,
89
+ "Label": label,
90
+ "Item": name,
91
+ "Type": "Method" if item in md.get("methodologies",[]) else "Technique",
92
+ "Regex Match": matched_text,
93
+ "Regex Fired": "βœ…" if regex_hits else "❌",
94
+ "LLM Votes": item["llm_votes"],
95
+ "Agreement": item["agreement"],
96
+ "Avg Pct (%)": item["pct"],
97
+ "Evidence": item.get("evidence", "β€”"),
98
+ "Gate Passed": "οΏ½οΏ½ ACCEPTED",
99
+ })
100
+
101
+ # Rejected items (single LLM only)
102
+ for item in md.get("rejected_methods", []) + md.get("rejected_techniques", []):
103
+ name = item["name"]
104
+ regex_hits = scan.get("methods", {}).get(name, []) or \
105
+ scan.get("techniques", {}).get(name, [])
106
+ matched_text = ", ".join(
107
+ dict.fromkeys(h["match"] for h in regex_hits))[:80] if regex_hits else "β€”"
108
+ rows.append({
109
+ "Cluster": cid,
110
+ "Label": label,
111
+ "Item": name,
112
+ "Type": "Method" if item in md.get("rejected_methods",[]) else "Technique",
113
+ "Regex Match": matched_text,
114
+ "Regex Fired": "βœ…" if regex_hits else "❌",
115
+ "LLM Votes": item["llm_votes"],
116
+ "Agreement": item["agreement"],
117
+ "Avg Pct (%)": item["pct"],
118
+ "Evidence": item.get("evidence", "β€”"),
119
+ "Gate Passed": "❌ REJECTED (single LLM)",
120
+ })
121
+
122
+ return pd.DataFrame(rows) if rows else pd.DataFrame()
123
+
124
+
125
+ def _per_llm_methodology_df(methodology_data: dict, interps: dict) -> pd.DataFrame:
126
+ """Per-LLM raw methodology responses side-by-side."""
127
+ rows = []
128
+ for cid in sorted(methodology_data.keys()):
129
+ md = methodology_data[cid]
130
+ label = interps.get(cid, {}).get("label", f"Cluster {cid}")
131
+ raw = md.get("llm_raw", {})
132
+
133
+ def _fmt(r, key):
134
+ return " | ".join(
135
+ f"{i['name']} ({i.get('pct',0)}%)"
136
+ for i in r.get(key, [])
137
+ ) or "β€”"
138
+
139
+ rows.append({
140
+ "Cluster": cid,
141
+ "Label": label,
142
+ "Groq Methods": _fmt(raw.get("groq",{}), "methodologies"),
143
+ "Mistral Methods": _fmt(raw.get("mistral",{}), "methodologies"),
144
+ "Gemini Methods": _fmt(raw.get("gemini",{}), "methodologies"),
145
+ "Groq Techniques": _fmt(raw.get("groq",{}), "techniques"),
146
+ "Mistral Techniques": _fmt(raw.get("mistral",{}), "techniques"),
147
+ "Gemini Techniques": _fmt(raw.get("gemini",{}), "techniques"),
148
+ "Groq Emp/Theo/Mix": f"{raw.get('groq',{}).get('empirical_pct',0)}/"
149
+ f"{raw.get('groq',{}).get('theoretical_pct',0)}/"
150
+ f"{raw.get('groq',{}).get('mixed_pct',0)}",
151
+ "Mistral Emp/Theo/Mix":f"{raw.get('mistral',{}).get('empirical_pct',0)}/"
152
+ f"{raw.get('mistral',{}).get('theoretical_pct',0)}/"
153
+ f"{raw.get('mistral',{}).get('mixed_pct',0)}",
154
+ "Gemini Emp/Theo/Mix": f"{raw.get('gemini',{}).get('empirical_pct',0)}/"
155
+ f"{raw.get('gemini',{}).get('theoretical_pct',0)}/"
156
+ f"{raw.get('gemini',{}).get('mixed_pct',0)}",
157
+ })
158
+ return pd.DataFrame(rows)
159
+
160
+
161
+ def _regex_hits_df(methodology_data: dict, interps: dict) -> pd.DataFrame:
162
+ """
163
+ One row per (cluster, pattern, matched text) so the user can see exactly
164
+ which regex fired on which word in which paper.
165
+ """
166
+ rows = []
167
+ for cid in sorted(methodology_data.keys()):
168
+ md = methodology_data[cid]
169
+ label = interps.get(cid, {}).get("label", f"Cluster {cid}")
170
+ scan = md.get("regex_scan", {})
171
+
172
+ for category, hits in scan.get("methods", {}).items():
173
+ for h in hits:
174
+ rows.append({"Cluster": cid, "Label": label,
175
+ "Bank": "Methodology", "Pattern Category": category,
176
+ "Matched Text": h["match"], "Paper #": h["doc"],
177
+ "Char Span": f"{h['span'][0]}–{h['span'][1]}"})
178
+
179
+ for category, hits in scan.get("techniques", {}).items():
180
+ for h in hits:
181
+ rows.append({"Cluster": cid, "Label": label,
182
+ "Bank": "Technique", "Pattern Category": category,
183
+ "Matched Text": h["match"], "Paper #": h["doc"],
184
+ "Char Span": f"{h['span'][0]}–{h['span'][1]}"})
185
+
186
+ return pd.DataFrame(rows) if rows else pd.DataFrame()
187
+
188
+
189
+ def _methodology_bar_chart(methodology_data: dict, interps: dict) -> go.Figure:
190
+ labels_list, empirical, theoretical, mixed = [], [], [], []
191
+ for cid in sorted(methodology_data.keys()):
192
+ md = methodology_data[cid]
193
+ labels_list.append(interps.get(cid,{}).get("label", f"C{cid}")[:30])
194
+ empirical.append(md.get("empirical_pct", 0))
195
+ theoretical.append(md.get("theoretical_pct", 0))
196
+ mixed.append(md.get("mixed_pct", 0))
197
+
198
+ fig = go.Figure()
199
+ fig.add_trace(go.Bar(name="Empirical %", x=labels_list, y=empirical, marker_color="#3dba7a"))
200
+ fig.add_trace(go.Bar(name="Theoretical %", x=labels_list, y=theoretical, marker_color="#5b9cf6"))
201
+ fig.add_trace(go.Bar(name="Mixed %", x=labels_list, y=mixed, marker_color="#f5a623"))
202
+ fig.update_layout(
203
+ barmode="stack", template="plotly_dark", height=420,
204
+ paper_bgcolor="#0d1117", plot_bgcolor="#161b22",
205
+ title="Research Orientation per Cluster β€” Averaged across Groq + Mistral + Gemini",
206
+ xaxis_title="Cluster", yaxis_title="Percentage (%)",
207
+ font=dict(size=11), legend=dict(orientation="h", y=1.12),
208
+ xaxis_tickangle=-35,
209
+ )
210
+ return fig
211
+
212
+
213
+ def _regex_pattern_info() -> str:
214
+ m_list = "\n".join(f"- **{k}**: `{v.pattern}`" for k,v in METHODOLOGY_PATTERNS.items())
215
+ t_list = "\n".join(f"- **{k}**: `{v.pattern}`" for k,v in TECHNIQUE_PATTERNS.items())
216
+ return (
217
+ "### How Methodology Extraction Works\n\n"
218
+ "**Step 1 β€” Regex Pre-Scan** \n"
219
+ "Two compiled pattern banks (case-insensitive) are run against each representative abstract. "
220
+ "Every match is recorded with its exact character span, matched text, and paper number. "
221
+ "This produces ground-truth hints that are injected into the LLM prompt.\n\n"
222
+ "**Step 2 β€” 3-LLM Council** \n"
223
+ "Groq (llama-3.1-8b), Mistral (mistral-small), and Gemini (gemini-2.5-flash) each receive "
224
+ "the same prompt: the regex evidence + the full abstracts. Each LLM must confirm or reject "
225
+ "the regex hits and may add methods/techniques it finds in the text. "
226
+ "Each LLM also provides an evidence quote (≀15 words) for every item it names.\n\n"
227
+ "**Step 3 β€” Consolidation (β‰₯2-LLM gate)** \n"
228
+ "A method or technique only survives if at least 2 out of 3 LLMs named it. "
229
+ "Percentages are averaged across agreeing LLMs. Items named by only one LLM are marked "
230
+ "REJECTED and shown in the extraction pipeline table.\n\n"
231
+ "**Step 4 β€” Orientation Percentages** \n"
232
+ "Empirical / Theoretical / Mixed percentages are averaged across all 3 LLMs and shown "
233
+ "in the stacked bar chart above.\n\n"
234
+ "---\n\n"
235
+ "#### Methodology Pattern Bank\n" + m_list +
236
+ "\n\n#### Technique Pattern Bank\n" + t_list
237
+ )
238
+
239
+
240
+ def _refinement_df(refinement_log: list) -> pd.DataFrame:
241
+ if not refinement_log:
242
+ return pd.DataFrame(columns=["Cluster","Iteration","Old Label","New Label",
243
+ "Issues","Improvement","Hallucination Detected"])
244
+ return pd.DataFrame([{
245
+ "Cluster": r["cluster"],
246
+ "Iteration": r["iteration"],
247
+ "Old Label": r["old_label"],
248
+ "New Label": r["new_label"],
249
+ "Issues": "; ".join(r.get("issues",[])),
250
+ "Improvement": r["improvement_score"],
251
+ "Hallucination Detected":r["hallucination_detected"],
252
+ } for r in refinement_log])
253
+
254
 
255
  # ── Pipeline runner ──────────────────────────────────────────────────────────
256
+ def _run(file, gk, mk, gek, n_trials, n_optimize,
257
+ progress=gr.Progress(track_tqdm=True)):
258
  if not file: raise gr.Error("Upload a CSV first.")
259
+ gk = gk.strip() or os.getenv("GROQ_API_KEY","")
260
+ mk = mk.strip() or os.getenv("MISTRAL_API_KEY","")
261
  gek = gek.strip() or os.getenv("GEMINI_API_KEY","")
262
  if not all([gk,mk,gek]): raise gr.Error("All 3 API keys required.")
263
+
264
  progress(0.05, desc="πŸ“₯ Loading CSV…")
265
+ progress(0.10, desc="πŸ”¬ Embedding with SPECTER-2 (this takes a few minutes)…")
266
+ r = run_pipeline(file.name, gk, mk, gek, int(n_trials), int(n_optimize))
267
  if r.get("error"): raise gr.Error(r["error"])
268
+
269
+ progress(0.85, desc="πŸ“Š Building outputs…")
270
+ td, interps = r["topic_data"], r.get("interpretations", {})
271
+ disc, met = td["discipline"], td["metrics"]
272
+ ar = r.get("agreement_rates", {})
273
+ rl = r.get("refinement_log", [])
274
+
275
  def _s(ok): return "βœ… PASS" if ok else "❌ FAIL"
276
+ summary = (
277
+ f"## Pipeline Complete β€” {disc['n_clusters']} clusters discovered\n\n"
278
  f"| Criterion | Value | Status |\n|---|---|---|\n"
279
  f"| Max cluster mass | {round(disc['max_mass_pct']*100,1)}% | {_s(disc['max_mass_ok'])} |\n"
280
  f"| Min cluster size | {disc['min_size']} | {_s(disc['min_size_ok'])} |\n"
281
  f"| Persistence (mean) | {round(met['persistence'],4)} | β€” |\n"
282
  f"| DBCV | {round(met['dbcv'],4)} | β€” |\n"
283
+ f"| Stability (3 seeds) | {round(met['stability'],4)} | β€” |\n\n"
284
  f"**Trials:** {td['n_trials_run']} (best #{td['best_trial']}) Β· "
285
+ f"**Agreement:** Triple {ar.get('triple',0)}% Β· Two+ {ar.get('two_or_more',0)}% Β· "
286
+ f"**Optimization passes:** {n_optimize} Β· **Labels refined:** {len(rl)}"
287
+ )
288
+
289
  u2d = np.array(td["umap_2d"])
290
  sdf = pd.DataFrame({"UMAP-1":u2d[:,0],"UMAP-2":u2d[:,1],
291
  "Cluster":[str(l) for l in td["labels"]],
292
  "Doc":[d[:60] for d in td["documents"]]})
293
  fig = px.scatter(sdf, x="UMAP-1", y="UMAP-2", color="Cluster",
294
  hover_data=["Doc"], opacity=0.75,
295
+ title="2-D UMAP visualisation of SPECTER-2 embeddings")
296
  fig.update_layout(template="plotly_dark", height=500,
297
+ paper_bgcolor="#0d1117", plot_bgcolor="#161b22", font=dict(size=11))
298
+
 
299
  tl = pd.DataFrame(td["trial_log"])
300
  tl_cols = [c for c in ["trial","discipline_pass","n_clusters","persistence",
301
  "dbcv","max_mass_pct","min_size","n_noise"] if c in tl.columns]
302
  tl_show = tl[tl_cols] if not tl.empty else pd.DataFrame()
303
+
304
  pfig = go.Figure()
305
  if not tl.empty:
306
  for passed, color, name in [(True,"#3dba7a","PASS"),(False,"#e04d4d","FAIL")]:
 
309
  pfig.add_trace(go.Scatter(x=sub["max_mass_pct"],y=sub["persistence"],
310
  mode="markers",marker=dict(size=8,color=color),name=name,
311
  text=sub["trial"],hovertemplate="Trial %{text}<br>Mass: %{x:.0%}<br>Pers: %{y:.3f}"))
312
+ pfig.add_vline(x=0.25,line_dash="dash",line_color="#5a6480",annotation_text="25% rule")
313
+ pfig.update_layout(template="plotly_dark",height=400,
314
+ paper_bgcolor="#0d1117",plot_bgcolor="#161b22",
 
315
  title="Pareto front β€” Persistence vs Max cluster mass",
316
+ xaxis_title="Max cluster mass",yaxis_title="Persistence",font=dict(size=11))
317
+
318
+ cdf_rows = []
 
319
  for cid in sorted(interps.keys()):
320
  v = interps[cid]
321
+ cdf_rows.append({"Cluster":cid,"Label":v["label"],"Agreement":v["agreement"],
322
  "Strong":v["strong"],"Weak":v["weak"],
323
  "Persistence":round(v.get("persistence",0),4),
324
  "Keyphrases":", ".join(v.get("keyphrases",[]))})
325
+ cdf = pd.DataFrame(cdf_rows)
326
+
327
  sheets = r.get("sheets",{})
328
  s1 = pd.DataFrame(sheets.get(1,[])); s2 = pd.DataFrame(sheets.get(2,[]))
329
  s3 = pd.DataFrame(sheets.get(3,[])); s4 = pd.DataFrame(sheets.get(4,[]))
330
  sp = r.get("sheet_paths",{})
331
  mdf = pd.DataFrame(r.get("mismatch_table",[]))
332
+
333
+ md_data = r.get("methodology_data", {})
334
+
335
+ top_papers_df = _top_papers_df(r.get("top_papers", {}))
336
+ method_summary_df = _methodology_summary_df(md_data, interps)
337
+ method_chart = _methodology_bar_chart(md_data, interps)
338
+ extraction_df = _extraction_pipeline_df(md_data, interps)
339
+ per_llm_df = _per_llm_methodology_df(md_data, interps)
340
+ regex_hits_df = _regex_hits_df(md_data, interps)
341
+ pattern_info = _regex_pattern_info()
342
+ refine_df = _refinement_df(rl)
343
+
344
  progress(1.0, desc="βœ… Done!")
345
+ dl_files = [f for f in [sp.get(1),sp.get(2),sp.get(3),sp.get(4),r.get("json_path")] if f]
346
+
347
+ return (summary, fig, pfig, tl_show, cdf,
348
+ top_papers_df,
349
+ method_chart, method_summary_df, extraction_df, per_llm_df,
350
+ regex_hits_df, pattern_info,
351
+ refine_df,
352
+ s1, s2, s3, s4,
353
+ dl_files if dl_files else None,
354
+ mdf)
355
+
356
 
357
  # ── UI ─────────────────────��─────────────────────────────────────────────────
358
  css = ".gradio-container{background:#0d1117!important;color:#c9d1d9!important}" \
359
  "footer{display:none!important}"
360
+
361
+ with gr.Blocks(theme=gr.themes.Base(primary_hue="blue", neutral_hue="slate"),
362
  css=css, title="SPECTER-2 Topic Analyzer") as demo:
363
  gr.Markdown("# πŸ“ SPECTER-2 Topic Analyzer")
364
+
365
  with gr.Row():
366
  with gr.Column(scale=1):
367
+ file_in = gr.File(label="Upload Scopus CSV", file_types=[".csv"])
368
  preview_out = gr.Markdown("Upload a CSV to see stats.")
369
+ groq_in = gr.Textbox(label="Groq API Key", type="password",
370
+ placeholder="or set GROQ_API_KEY env var")
371
  mistral_in = gr.Textbox(label="Mistral API Key", type="password",
372
+ placeholder="or set MISTRAL_API_KEY env var")
373
+ gemini_in = gr.Textbox(label="Gemini API Key", type="password",
374
+ placeholder="or set GEMINI_API_KEY env var")
375
+ trials_in = gr.Slider(10, 100, 50, step=5, label="Optuna Trials")
376
+ optimize_in = gr.Slider(1, 5, 1, step=1,
377
+ label="πŸ” Optimization Passes",
378
+ info="Each pass: LLM critic audits labels for hallucinations. "
379
+ "1 = disabled. 2–5 = progressive refinement.")
380
+ run_btn = gr.Button("β–Ά Run Full Pipeline", variant="primary", size="lg")
381
+
382
  with gr.Column(scale=3):
383
  with gr.Tabs():
384
+
385
+ with gr.Tab("Summary"):
386
+ summary_out = gr.Markdown()
387
+
388
+ with gr.Tab("2-D UMAP"):
389
+ scatter_out = gr.Plot()
390
+
391
+ with gr.Tab("Pareto Front"):
392
+ pareto_out = gr.Plot()
393
+
394
+ with gr.Tab("Trial Log"):
395
+ trial_out = gr.Dataframe()
396
+
397
+ with gr.Tab("Clusters"):
398
+ cluster_out = gr.Dataframe()
399
+
400
+ with gr.Tab("πŸ—ž Top 3 Papers"):
401
+ gr.Markdown("### Top 3 Representative Papers per Cluster\n"
402
+ "Ranked by cosine similarity to the cluster centroid "
403
+ "in SPECTER-2 embedding space.")
404
+ top_papers_out = gr.Dataframe(
405
+ headers=["Cluster","Label","Rank","Title","Abstract Snippet"],
406
+ wrap=True)
407
+
408
+ with gr.Tab("πŸ”¬ Methodology β€” Summary"):
409
+ gr.Markdown("### Consolidated Methodology Results\n"
410
+ "Only items agreed by **β‰₯ 2 out of 3 LLMs** (Groq + Mistral + Gemini) "
411
+ "appear here. Percentages averaged across agreeing LLMs.")
412
+ method_chart_out = gr.Plot()
413
+ method_summary_out = gr.Dataframe(wrap=True)
414
+
415
+ with gr.Tab("βš™ Methodology β€” Extraction Pipeline"):
416
+ gr.Markdown("### Full Extraction Trace\n"
417
+ "One row per method/technique showing: which regex pattern fired, "
418
+ "the exact matched text, how many LLMs agreed, and whether it "
419
+ "passed the β‰₯2-LLM gate.")
420
+ extraction_out = gr.Dataframe(wrap=True)
421
+
422
+ with gr.Tab("πŸ€– Methodology β€” Per-LLM Votes"):
423
+ gr.Markdown("### Raw Per-LLM Methodology Responses\n"
424
+ "Side-by-side view of what each LLM independently extracted "
425
+ "before consolidation.")
426
+ per_llm_out = gr.Dataframe(wrap=True)
427
+
428
+ with gr.Tab("πŸ” Regex Hits"):
429
+ gr.Markdown("### Regex Pattern Matches\n"
430
+ "Every regex match with its exact character span, matched text, "
431
+ "and which paper (1–3) it came from. This is the ground-truth "
432
+ "evidence fed to all 3 LLMs.")
433
+ regex_hits_out = gr.Dataframe(wrap=True)
434
+ regex_info_out = gr.Markdown()
435
+
436
+ with gr.Tab("πŸ” Refinement Log"):
437
+ gr.Markdown("### Optimization Refinement Log\n"
438
+ "Changes made by the Groq critic per optimization pass. "
439
+ "A label is only changed when improvement_score > 0.15 "
440
+ "OR hallucination was detected, AND the new label passes "
441
+ "the keyphrase grounding check.")
442
+ refine_out = gr.Dataframe(
443
+ headers=["Cluster","Iteration","Old Label","New Label",
444
+ "Issues","Improvement","Hallucination Detected"],
445
+ wrap=True)
446
+
447
+ with gr.Tab("Sheet 1 β€” Groq"): s1_out = gr.Dataframe()
448
  with gr.Tab("Sheet 2 β€” Mistral"): s2_out = gr.Dataframe()
449
+ with gr.Tab("Sheet 3 β€” Gemini"): s3_out = gr.Dataframe()
450
  with gr.Tab("Sheet 4 β€” Consolidated"): s4_out = gr.Dataframe()
451
+ with gr.Tab("RQ Mismatch"): mismatch_out = gr.Dataframe()
452
  with gr.Tab("Downloads"):
453
  dl_out = gr.File(label="All sheet CSVs + topics.json",
454
  file_count="multiple")
455
+
456
  file_in.change(_preview, inputs=[file_in], outputs=[preview_out])
457
+
458
+ run_btn.click(
459
+ _run,
460
+ inputs=[file_in, groq_in, mistral_in, gemini_in, trials_in, optimize_in],
461
+ outputs=[
462
+ summary_out, scatter_out, pareto_out, trial_out, cluster_out,
463
+ top_papers_out,
464
+ method_chart_out, method_summary_out, extraction_out, per_llm_out,
465
+ regex_hits_out, regex_info_out,
466
+ refine_out,
467
+ s1_out, s2_out, s3_out, s4_out,
468
+ dl_out, mismatch_out,
469
+ ],
470
+ )
471
 
472
  if __name__ == "__main__":
473
+ demo.launch(server_name="0.0.0.0", server_port=7860)