# Author: rahull30 · commit 2595b94 — "downloadable result csv's"
"""
app.py — Gradio web application for SPECTER2-based scientific topic modelling.
Pipeline:
CSV Upload → Preprocessing → SPECTER2 Embeddings → UMAP → HDBSCAN →
Top Papers → LLM Label Generation (3 approaches) → AI Council →
TCCM Classification → KeyBERT Keywords → Results
PARALLELIZATION:
Per-cluster processing (labeling + AI Council + TCCM + keywords) is
executed in a ThreadPoolExecutor(max_workers=3), so clusters are labelled
concurrently rather than one after another (the original ~60 min sequential
→ ~5-8 min parallel estimate was quoted for 10 workers).
"""
import os
import io
import sys
import traceback
import numpy as np
import pandas as pd
import gradio as gr
import plotly.express as px
import plotly.graph_objects as go
from concurrent.futures import ThreadPoolExecutor, as_completed
# Local imports
from utils import (
load_env, build_paper_results, build_cluster_summary,
print_metrics_report, build_metrics_summary, build_council_summary
)
from preprocessing import load_and_preprocess
from embedding import load_or_generate_embeddings
from clustering import auto_cluster, get_top_papers, compute_silhouette, compute_cluster_coherence
from labeling import generate_all_labels
from ai_council import run_council, compute_label_confidence
from tccm_classifier import run_tccm_for_all_clusters, classify_tccm, extract_keywords
# Load environment configuration at import time, before any pipeline stage
# needs it. The usage notes mention GROQ_API_KEY in a `.env` file; presumably
# utils.load_env reads that file — verify in utils.
load_env()
# ─── PER-CLUSTER WORKER ──────────────────────────────────────────────────────
def _process_cluster(cid, papers, labels, df, np_labels):
    """
    Label, score, classify and extract keywords for a single cluster.

    Designed to run as a ThreadPoolExecutor worker. Runs:
    generate_all_labels → run_council → compute_label_confidence
    → classify_tccm → extract_keywords

    Args:
        cid: Id of the cluster being processed.
        papers: Top papers for this cluster; passed unchanged to the
            labeling, council and TCCM helpers.
        labels: Unused. Kept so existing call sites, which pass the label
            array for both ``labels`` and ``np_labels``, keep working.
        df: Preprocessed paper DataFrame, row-aligned with ``np_labels``.
            Its "combined_text_clean" column feeds keyword extraction.
        np_labels: Array-like of per-paper cluster ids, used to count and
            select this cluster's papers.

    Returns:
        ``(cid, cluster_result, tccm_result)``. If any step raises, the
        error and traceback are printed and a fallback pair is returned
        instead: label "Cluster <cid>", winning_approach "error",
        confidence 0.0, TCCM fields "Not specified", no keywords. This way
        one failing cluster does not abort the whole pipeline.
    """
    # Membership mask and cluster size are needed on both the success and
    # the fallback path; compute them once, outside the try.
    mask = np.asarray(np_labels) == cid
    n_papers = int(mask.sum())
    try:
        # Labels (3 approaches) — each approach calls the LLM once
        candidates = generate_all_labels(cid, papers)
        # AI Council — 3 candidates × 3 agents = 9 LLM calls, all parallel inside
        council = run_council(cid, candidates, papers)
        label_conf = compute_label_confidence(council)
        cluster_result = {
            **council,
            "label_confidence": label_conf,
            "n_papers": n_papers,
        }
        # TCCM classification
        tccm = classify_tccm(cid, papers)
        # KeyBERT keywords from the clean texts of this cluster's papers
        clean_texts = df[mask]["combined_text_clean"].tolist()
        tccm_result = {**tccm, "keywords": extract_keywords(clean_texts)}
        return cid, cluster_result, tccm_result
    except Exception as e:
        tb = traceback.format_exc()
        print(f"[Worker] Cluster {cid} FAILED: {e}\n{tb}")
        # Return safe fallback values so the pipeline doesn't crash
        return cid, {
            "final_label": f"Cluster {cid}",
            "winning_approach": "error",
            "candidates": {},
            "justification": f"Error: {e}",
            "label_confidence": 0.0,
            "n_papers": n_papers,
        }, {
            "theory": "Not specified", "context": "Not specified",
            "characteristics": "Not specified", "methodology": "Not specified",
            "keywords": [],
        }
# ─── PIPELINE ────────────────────────────────────────────────────────────────
def _resolve_csv_path(csv_file):
    """
    Turn the value Gradio passes for the uploaded CSV into a path string.

    Accepts a plain path (what ``gr.File(type="filepath")`` delivers) as
    well as an object that exposes its path via ``.name``.

    Raises:
        gr.Error: if no file was uploaded.
    """
    if csv_file is None:
        raise gr.Error("Please upload a CSV file before running the pipeline.")
    if isinstance(csv_file, (str, os.PathLike)):
        return os.fspath(csv_file)
    return csv_file.name


def _build_metrics_md(df, labels, silhouette, coherence, cluster_results):
    """Render the headline research metrics as a markdown table."""
    avg_coherence = float(np.mean(list(coherence.values()))) if coherence else 0
    confidences = [r.get("label_confidence", 0) for r in cluster_results.values()]
    avg_confidence = float(np.mean(confidences)) if confidences else 0
    n_noise = int(np.sum(labels == -1))
    noise_pct = 100 * n_noise / max(len(labels), 1)
    return (
        f"### 📊 Research Metrics\n"
        f"| Metric | Value |\n|---|---|\n"
        f"| Total Clusters | **{len(cluster_results)}** |\n"
        f"| Total Papers | **{len(df)}** |\n"
        f"| Noise Points | **{n_noise} ({noise_pct:.1f}%)** |\n"
        f"| Silhouette Score | **{silhouette:.4f}** |\n"
        f"| Avg Cluster Coherence | **{avg_coherence:.4f}** |\n"
        f"| Avg Label Confidence | **{avg_confidence:.4f}** |\n"
    )


def run_full_pipeline(csv_file, progress=gr.Progress(track_tqdm=True)):
    """
    Run the end-to-end topic-modelling pipeline for one uploaded CSV.

    Stages: preprocessing → SPECTER2 embeddings → UMAP + HDBSCAN → top
    papers → metrics → per-cluster labeling / AI Council / TCCM / keywords
    (thread pool) → result tables, scatter plot and markdown summaries.
    The four result tables are also written as CSV files to the current
    working directory so the download buttons can serve them.

    Args:
        csv_file: The uploaded file as delivered by ``gr.File``: a path
            string, or an object exposing the path via ``.name``.
        progress: Gradio progress tracker (injected by Gradio).

    Returns:
        A 12-tuple: cluster table, paper table, scatter figure, metrics
        markdown, overview markdown, council markdown, cluster-dropdown
        update, status update, and four download-button updates.

    Raises:
        gr.Error: if no file was uploaded or any stage fails.
    """
    try:
        # ── Step 1: Preprocessing
        progress(0.05, desc="🔍 Preprocessing CSV...")
        df, preprocess_stats = load_and_preprocess(_resolve_csv_path(csv_file))

        # ── Step 2: Embeddings
        progress(0.15, desc="🧬 Generating SPECTER2 embeddings (may take a few minutes)...")
        embeddings = load_or_generate_embeddings(df, batch_size=64)

        # ── Step 3+4: UMAP + HDBSCAN (with strict 15 clusters and noise absorption)
        progress(0.38, desc="📐 Running UMAP + HDBSCAN (targeting exactly 15 clusters)...")
        reduced_nd, reduced_2d, labels, probs = auto_cluster(embeddings)

        # ── Step 5: Top Papers
        progress(0.52, desc="📄 Selecting top papers per cluster...")
        top_papers = get_top_papers(df, reduced_nd, labels, probs)

        # ── Metrics
        progress(0.56, desc="📊 Computing research metrics...")
        silhouette = compute_silhouette(reduced_nd, labels)
        coherence = compute_cluster_coherence(embeddings, labels)

        # ── Step 6+7+8: Labeling + AI Council + TCCM — in a thread pool
        cluster_ids = sorted(top_papers.keys())
        n_total = len(cluster_ids)
        progress(0.58, desc=f"🤖 Labeling & classifying {n_total} clusters in parallel...")
        cluster_results: dict = {}
        tccm_results: dict = {}
        completed = 0
        # Per-cluster work is presumably dominated by LLM calls (I/O-bound),
        # so threads let those waits overlap.
        with ThreadPoolExecutor(max_workers=3) as executor:
            # `labels` fills both the (unused) `labels` and the `np_labels`
            # parameter of _process_cluster.
            futures = {
                executor.submit(
                    _process_cluster, cid, top_papers[cid], labels, df, labels
                ): cid
                for cid in cluster_ids
            }
            for future in as_completed(futures):
                cid_done = futures[future]
                try:
                    cid, cluster_result, tccm_result = future.result()
                    cluster_results[cid] = cluster_result
                    tccm_results[cid] = tccm_result
                except Exception as e:
                    # _process_cluster converts its own failures into fallback
                    # results, so only truly unexpected errors land here.
                    print(f"[Pipeline] Unexpected error for cluster {cid_done}: {e}")
                completed += 1
                pct = 0.58 + 0.37 * (completed / max(n_total, 1))
                progress(pct, desc=f"✅ Cluster {completed}/{n_total} done...")

        # ── Step 9: Build outputs
        progress(0.97, desc="📋 Compiling results...")
        paper_df = build_paper_results(df, labels, cluster_results)
        cluster_df = build_cluster_summary(
            cluster_results, top_papers, coherence, silhouette, tccm_results
        )
        metrics_df = build_metrics_summary(silhouette, coherence, cluster_results, labels)
        council_df = build_council_summary(cluster_results)
        print_metrics_report(silhouette, coherence, cluster_results, labels)

        # ── UI renderings: scatter plot, dataset overview, metrics, council table
        fig = _make_scatter(df, reduced_2d, labels, cluster_results)
        overview_md = _build_overview_md(preprocess_stats)
        metrics_md = _build_metrics_md(df, labels, silhouette, coherence, cluster_results)
        council_md = _build_council_md(cluster_results)

        # ── Save CSV files to disk (current working directory)
        paper_df.to_csv("paper_results.csv", index=False)
        cluster_df.to_csv("cluster_summary.csv", index=False)
        metrics_df.to_csv("metrics_summary.csv", index=False)
        council_df.to_csv("council_scores.csv", index=False)

        # Cluster options for filtering
        cids = sorted(int(c) for c in cluster_results)
        cluster_choices = ["All Clusters"] + [f"Cluster {c}" for c in cids]
        progress(1.0, desc="✅ Done! (Results saved to project folder)")
        return (
            cluster_df,
            paper_df,
            fig,
            metrics_md,
            overview_md,
            council_md,
            gr.update(choices=cluster_choices, value="All Clusters"),
            gr.update(value="✅ **Pipeline complete.** Results saved as CSV files in the project folder.", visible=True),
            gr.update(value="cluster_summary.csv", interactive=True),
            gr.update(value="paper_results.csv", interactive=True),
            gr.update(value="metrics_summary.csv", interactive=True),
            gr.update(value="council_scores.csv", interactive=True),
        )
    except gr.Error:
        # Already a user-facing message (e.g. no file uploaded); don't re-wrap it.
        raise
    except Exception as e:
        tb = traceback.format_exc()
        print(f"[Pipeline Error] {tb}")
        raise gr.Error(f"Pipeline failed: {str(e)}\n\nDetails:\n{tb}") from e
# ─── HELPER BUILDERS ─────────────────────────────────────────────────────────
def _build_overview_md(stats: dict) -> str:
    """
    Render the preprocessing statistics as a markdown "Dataset Overview" table.

    Missing keys in ``stats`` count as 0. The "short / invalid" row is
    whatever is neither kept nor a duplicate, clamped at zero.
    """
    n_total, n_missing, n_dupes, n_final = (
        stats.get(key, 0)
        for key in ("total", "missing_abstracts", "duplicates_removed", "final_count")
    )
    n_invalid = max(0, n_total - n_final - n_dupes)
    stage_rows = [
        ("Papers in CSV", n_total),
        ("Missing abstracts", n_missing),
        ("Duplicates removed", n_dupes),
        ("Short / invalid texts removed", n_invalid),
    ]
    body = "".join(f"| {stage} | **{count}** |\n" for stage, count in stage_rows)
    return (
        "### 📂 Dataset Overview\n"
        "| Stage | Count |\n|---|---|\n"
        + body
        + f"| **Papers used for analysis** | **{n_final}** |\n"
    )
def _build_council_md(cluster_results: dict) -> str:
    """
    Build a markdown comparison table of AI Council scores per cluster.

    Emits one row per (cluster, candidate approach), with clusters in sorted
    id order. Each candidate's "scores" dict supplies semantic / keyword /
    clarity (2 decimals) and final (3 decimals); a missing score shows as 0.
    Labels are truncated to 45 characters. The approach equal to the
    cluster's "winning_approach" is marked ✅.

    Returns:
        The markdown string, or "" when there is nothing to show (no
        clusters, or no candidates in any cluster).
    """
    if not cluster_results:
        return ""

    def _cell(value) -> str:
        # LLM-generated text may contain "|" or line breaks, which would
        # otherwise split or terminate the markdown table row.
        return str(value).replace("|", "\\|").replace("\n", " ")

    rows = []
    for cid, result in sorted(cluster_results.items()):
        winner = result.get("winning_approach", "")
        for approach, eval_data in result.get("candidates", {}).items():
            sc = eval_data.get("scores", {})
            # `or ""` also covers an explicit None label.
            label = str(eval_data.get("label") or "")[:45]
            mark = "✅" if approach == winner else ""
            rows.append(
                f"| {cid} | {_cell(approach)} | {_cell(label)} "
                f"| {sc.get('semantic', 0):.2f} | {sc.get('keyword', 0):.2f} "
                f"| {sc.get('clarity', 0):.2f} | {sc.get('final', 0):.3f} | {mark} |"
            )
    if not rows:
        return ""
    header = [
        "### 🏛️ AI Council Score Comparison\n",
        "| Cluster | Approach | Label | Semantic | Keyword | Clarity | Final | Winner |",
        "|---|---|---|---|---|---|---|---|",
    ]
    return "\n".join(header + rows)
def _make_scatter(df, reduced_2d, labels, cluster_results):
    """
    Create the interactive 2D UMAP scatter plot, one trace per cluster.

    Args:
        df: Paper DataFrame. Its "Title" column, truncated to 80 chars, is
            shown on hover. Row i corresponds to labels[i] and reduced_2d[i].
        reduced_2d: 2D coordinates, read as reduced_2d[:, 0] / reduced_2d[:, 1].
        labels: Per-paper cluster ids; -1 marks noise.
        cluster_results: Mapping cluster id -> result dict with a
            "final_label" entry, used to name legend entries.

    Returns:
        A plotly ``go.Figure``. Clusters are drawn, colored and listed in
        ascending numeric cluster-id order. Noise, if present, is drawn last
        as small grey points.
    """
    def _trace_name(cid):
        # Legend/trace name for a cluster id; each name embeds the id.
        if cid == -1:
            return "Noise"
        if cid in cluster_results:
            return f"[{cid}] {cluster_results[cid]['final_label'][:40]}"
        return f"Cluster {cid}"

    point_ids = [int(c) for c in labels[: len(df)]]
    point_names = [_trace_name(cid) for cid in point_ids]
    # Names embed their cluster id, so this name -> id mapping is unambiguous.
    name_to_id = dict(zip(point_names, point_ids))

    plot_df = pd.DataFrame({
        "x": reduced_2d[:, 0],
        "y": reduced_2d[:, 1],
        "cluster": point_names,
        "title": df["Title"].str[:80],
    })
    noise_mask = plot_df["cluster"] == "Noise"
    non_noise = plot_df[~noise_mask]

    # Sort by numeric id. A plain string sort would put "[10] …" before
    # "[2] …", scrambling the legend order and color assignment.
    cluster_names = sorted(non_noise["cluster"].unique(), key=lambda name: name_to_id[name])
    colors = px.colors.qualitative.Alphabet + px.colors.qualitative.Dark24

    fig = go.Figure()
    for i, cname in enumerate(cluster_names):
        cdata = non_noise[non_noise["cluster"] == cname]
        fig.add_trace(go.Scatter(
            x=cdata["x"], y=cdata["y"],
            mode="markers",
            name=cname,
            text=cdata["title"],
            hovertemplate="%{text}<extra>%{fullData.name}</extra>",
            marker=dict(size=5, color=colors[i % len(colors)], opacity=0.75),
        ))
    if noise_mask.any():
        ndata = plot_df[noise_mask]
        fig.add_trace(go.Scatter(
            x=ndata["x"], y=ndata["y"],
            mode="markers",
            name="Noise",
            text=ndata["title"],
            hovertemplate="%{text}<extra>Noise</extra>",
            marker=dict(size=3, color="#aaaaaa", opacity=0.4),
        ))
    fig.update_layout(
        title="UMAP 2D Projection — Colored by Cluster",
        xaxis_title="UMAP Dimension 1",
        yaxis_title="UMAP Dimension 2",
        legend=dict(font=dict(size=10), itemsizing="constant", orientation="v", x=1.01),
        height=620,
        plot_bgcolor="#0f1117",
        paper_bgcolor="#0f1117",
        font=dict(color="#e0e0e0"),
        xaxis=dict(gridcolor="#2a2a3a", zeroline=False),
        yaxis=dict(gridcolor="#2a2a3a", zeroline=False),
        margin=dict(l=40, r=200, t=50, b=40),
    )
    return fig
# ─── CLUSTER FILTER HANDLER ──────────────────────────────────────────────────
def filter_papers_by_cluster(paper_df_raw: pd.DataFrame, cluster_choice: str):
    """
    Filter the per-paper results table to a single cluster for the UI.

    Args:
        paper_df_raw: Per-paper results with a "Cluster_ID" column. Any
            non-DataFrame value is returned unchanged.
        cluster_choice: Dropdown value: "All Clusters", "Cluster <id>",
            or empty/None.

    Returns:
        The rows whose Cluster_ID equals the selected id. Returns
        ``paper_df_raw`` unchanged when no specific cluster is selected, or
        when the selection or table cannot be interpreted (non-numeric id,
        missing "Cluster_ID" column).
    """
    if not isinstance(paper_df_raw, pd.DataFrame):
        return paper_df_raw
    if not cluster_choice or cluster_choice == "All Clusters":
        return paper_df_raw
    try:
        cid = int(str(cluster_choice).replace("Cluster ", ""))
        # Coerce so ids that round-tripped through the UI as strings still match.
        ids = pd.to_numeric(paper_df_raw["Cluster_ID"], errors="coerce")
    except (ValueError, KeyError):
        return paper_df_raw
    return paper_df_raw[ids == cid]
# ─── GRADIO UI ───────────────────────────────────────────────────────────────
CSS = """
@import url('https://fonts.googleapis.com/css2?family=Space+Mono:ital,wght@0,400;0,700;1,400&family=DM+Serif+Display:ital@0;1&display=swap');
:root {
--bg-deep: #f8fafc;
--bg-panel: #ffffff;
--bg-card: #ffffff;
--border: #e2e8f0;
--accent: #2563eb;
--accent2: #059669;
--accent3: #dc2626;
--text: #1e293b;
--text-muted: #64748b;
--font-mono: 'Space Mono', monospace;
--font-serif: 'DM Serif Display', serif;
}
body, .gradio-container {
background: var(--bg-deep) !important;
font-family: var(--font-mono) !important;
color: var(--text) !important;
}
.main-title {
font-family: var(--font-serif) !important;
font-size: 2.8rem !important;
font-weight: 400 !important;
color: #1e293b !important;
text-align: center;
margin: 1.5rem 0 0.2rem;
letter-spacing: -0.02em;
line-height: 1.1;
}
.subtitle {
font-family: var(--font-mono) !important;
font-size: 0.78rem !important;
color: var(--text-muted) !important;
text-align: center;
letter-spacing: 0.15em;
text-transform: uppercase;
margin-bottom: 2rem;
}
.pipeline-badge {
display: inline-block;
background: #f1f5f9;
border: 1px solid var(--border);
border-radius: 6px;
padding: 0.6rem 1.2rem;
font-size: 0.7rem;
color: var(--accent);
letter-spacing: 0.1em;
text-align: center;
margin: 0.3rem;
}
label, .label-wrap {
font-family: var(--font-mono) !important;
font-size: 0.75rem !important;
color: var(--accent) !important;
letter-spacing: 0.1em !important;
text-transform: uppercase !important;
}
button.primary {
background: linear-gradient(135deg, #2563eb, #1d4ed8) !important;
border: none !important;
font-family: var(--font-mono) !important;
font-size: 0.85rem !important;
letter-spacing: 0.1em !important;
text-transform: uppercase !important;
padding: 0.8rem 2rem !important;
border-radius: 4px !important;
color: white !important;
transition: all 0.2s !important;
box-shadow: 0 4px 6px -1px rgba(0, 0, 0, 0.1), 0 2px 4px -1px rgba(0, 0, 0, 0.06) !important;
}
button.primary:hover {
background: linear-gradient(135deg, #1d4ed8, #1e40af) !important;
box-shadow: 0 10px 15px -3px rgba(0, 0, 0, 0.1), 0 4px 6px -2px rgba(0, 0, 0, 0.05) !important;
transform: translateY(-1px) !important;
}
.block, .panel, .gr-box {
background: var(--bg-panel) !important;
border: 1px solid var(--border) !important;
border-radius: 8px !important;
box-shadow: 0 1px 3px 0 rgba(0, 0, 0, 0.1), 0 1px 2px 0 rgba(0, 0, 0, 0.06) !important;
}
.tab-nav button {
font-family: var(--font-mono) !important;
font-size: 0.75rem !important;
letter-spacing: 0.1em !important;
text-transform: uppercase !important;
color: var(--text-muted) !important;
background: transparent !important;
border: none !important;
border-bottom: 2px solid transparent !important;
transition: all 0.2s !important;
}
.tab-nav button.selected {
color: var(--accent) !important;
border-bottom: 2px solid var(--accent) !important;
}
.metrics-box {
background: #f8fafc !important;
border: 1px solid var(--border) !important;
border-radius: 8px !important;
padding: 1rem !important;
}
table {
font-family: var(--font-mono) !important;
font-size: 0.78rem !important;
color: var(--text) !important;
}
thead th {
color: var(--accent) !important;
text-transform: uppercase !important;
letter-spacing: 0.08em !important;
font-size: 0.7rem !important;
border-bottom: 1px solid var(--border) !important;
background: #f1f5f9 !important;
}
.hint-text {
font-family: var(--font-mono) !important;
font-size: 0.72rem !important;
color: var(--text-muted) !important;
line-height: 1.6 !important;
}
.status-ok {
color: var(--accent2);
font-size: 0.75rem;
font-family: var(--font-mono);
}
"""
HEADER_HTML = """
<div style="text-align:center; padding: 1rem 0 0.5rem;">
<div class="main-title">Scientific Topic Modelling</div>
<div class="subtitle">SPECTER2 · UMAP · HDBSCAN · AI Council · TCCM</div>
<div style="display:flex; flex-wrap:wrap; justify-content:center; gap:0.3rem; margin:1rem 0;">
<div class="pipeline-badge">① SPECTER2 Embeddings</div>
<div class="pipeline-badge">② UMAP Reduction</div>
<div class="pipeline-badge">③ HDBSCAN (15 clusters)</div>
<div class="pipeline-badge">④ LLM Label Generation</div>
<div class="pipeline-badge">⑤ AI Council Scoring</div>
<div class="pipeline-badge">⑥ TCCM Classification</div>
</div>
</div>
"""
INSTRUCTIONS_MD = """
### How to use
1. **Prepare your CSV** — Scopus export format with columns: `Title`, `Abstract`, `DOI`
2. **Set API keys** — Add `GROQ_API_KEY` to your `.env` file
3. **Upload & Run** — Click *Run Pipeline* and wait for results (~10-15 min)
4. **Explore** — Browse cluster labels, top papers, UMAP plot, AI Council scores, TCCM, and keywords
### Requirements
- Minimum **50 papers** recommended
- For best results: **200–5000 papers**
- First run downloads SPECTER2 model (~440 MB) — subsequent runs use cache
### Output Tabs
- **📋 Cluster Summary** — Final labels, TCCM, keywords, AI Council scores per cluster
- **📄 Paper Results** — Every paper with its assigned cluster and label
- **🗺️ UMAP Plot** — Interactive 2D scatter with hover tooltips
- **📊 Metrics** — Silhouette score, cluster coherence, label confidence
- **🏛️ AI Council** — Per-label score breakdown for all candidates
- **📂 Dataset Overview** — Preprocessing statistics
"""
def build_app():
    """
    Build (but do not launch) the Gradio Blocks UI and wire its events.

    Layout: the header, then a row with the usage notes beside the upload /
    run controls, then six result tabs. The run button calls
    run_full_pipeline, whose 12 return values map in order onto the
    ``outputs`` list below.

    Returns:
        The ``gr.Blocks`` app.
    """
    with gr.Blocks() as demo:
        gr.HTML(HEADER_HTML)
        with gr.Row():
            with gr.Column(scale=1):
                # INSTRUCTIONS_MD is markdown. Through gr.HTML its "###", "**"
                # and backtick syntax was shown literally, so render it as markdown.
                gr.Markdown(INSTRUCTIONS_MD, elem_classes=["hint-text"])
            with gr.Column(scale=1):
                csv_input = gr.File(
                    label="Upload Scopus CSV",
                    file_types=[".csv"],
                    type="filepath",
                )
                run_btn = gr.Button("▶ Run Full Pipeline", variant="primary", size="lg")
                status_box = gr.Markdown("", visible=False, elem_classes=["status-ok"])

        # Unfiltered paper table from the last successful run. The cluster
        # filter always starts from this, not from the displayed table (which
        # may already be filtered). Otherwise "All Clusters" or another cluster
        # could no longer be restored after the first selection.
        full_paper_state = gr.State(None)

        with gr.Tabs():
            with gr.Tab("📋 Cluster Summary"):
                cluster_dl_btn = gr.DownloadButton("📥 Download Cluster Summary CSV", interactive=False)
                cluster_table = gr.DataFrame(
                    label="Cluster Results",
                    wrap=True,
                    interactive=False,
                    buttons=["copy", "fullscreen"],
                )
            with gr.Tab("📄 Paper Results"):
                paper_dl_btn = gr.DownloadButton("📥 Download Paper Results CSV", interactive=False)
                cluster_filter = gr.Dropdown(
                    label="Filter by Cluster",
                    choices=["All Clusters"],
                    value="All Clusters",
                )
                paper_table = gr.DataFrame(
                    label="Per-Paper Results",
                    wrap=True,
                    interactive=False,
                    buttons=["copy", "fullscreen"],
                )
            with gr.Tab("🗺️ UMAP Plot"):
                scatter_plot = gr.Plot(label="2D UMAP Projection")
            with gr.Tab("📊 Research Metrics"):
                metrics_dl_btn = gr.DownloadButton("📥 Download Metrics Summary CSV", interactive=False)
                metrics_md = gr.Markdown("")
            with gr.Tab("🏛️ AI Council"):
                council_dl_btn = gr.DownloadButton("📥 Download AI Council Scores CSV", interactive=False)
                council_md = gr.Markdown("")
            with gr.Tab("📂 Dataset Overview"):
                overview_md = gr.Markdown("")

        # ── EVENT HANDLERS ──────────────────────────────────────────────────
        run_event = run_btn.click(
            fn=run_full_pipeline,
            inputs=[csv_input],
            outputs=[
                cluster_table, paper_table, scatter_plot, metrics_md,
                overview_md, council_md,
                cluster_filter,
                status_box,
                cluster_dl_btn, paper_dl_btn, metrics_dl_btn, council_dl_btn
            ],
        )
        # Snapshot the freshly produced (unfiltered) paper table after a successful run.
        run_event.success(
            fn=lambda full_df: full_df,
            inputs=[paper_table],
            outputs=[full_paper_state],
        )
        # Filtering logic. `.input` fires only on user selection, not when
        # run_full_pipeline resets the dropdown to "All Clusters".
        cluster_filter.input(
            fn=filter_papers_by_cluster,
            inputs=[full_paper_state, cluster_filter],
            outputs=[paper_table],
        )
    return demo
if __name__ == "__main__":
app = build_app()
app.launch(
server_name="0.0.0.0",
server_port=7860,
share=True,
css=CSS,
)