Spaces:
Configuration error
Configuration error
| """ | |
| app.py — Gradio Blocks UI for the BERTopic Thematic Analysis Agent. | |
| Sections: (1) Data Input, (2) Agent Conversation, (3) Results | |
| """ | |
| from __future__ import annotations | |
| import json | |
| import uuid | |
| from pathlib import Path | |
| import os | |
| import gradio as gr | |
| import pandas as pd | |
| import plotly.io as pio | |
| from agent import agent | |
| # --------------------------------------------------------------------------- | |
| # Constants | |
| # --------------------------------------------------------------------------- | |
# Conversation thread id, generated once when the module is imported and
# reused by every agent call made from this process.
# NOTE(review): being module-level, this id looks shared by all browser
# sessions served by the process — confirm that is intended.
THREAD_ID = str(uuid.uuid4())

# Passed as `config=` to every `agent.invoke` call in `_call_agent`.
AGENT_CONFIG = {
    "configurable": {"thread_id": THREAD_ID},
    "recursion_limit": 100,
}

# Column headers of the editable topic review table, in display order.
REVIEW_COLUMNS = [
    "#",
    "Topic Label",
    "Top Evidence",
    "Sentences",
    "Papers",
    "Approve",
    "Rename To",
    "Reasoning",
]

# (code, name) pairs drawn by the phase progress bar. The list position is the
# phase index used by `_phase_bar_html` and `_detect_phase`.
PHASE_LABELS = [
    ("Phase 1", "Familiarisation"),
    ("Phase 2", "Initial Codes"),
    ("Phase 3", "Themes"),
    ("Phase 4", "Saturation"),
    ("Phase 5", "Naming"),
    ("Phase 5.5", "PAJAIS"),
    ("Phase 6", "Report"),
]

# Labels offered in the chart dropdown.
# NOTE(review): the character after the first word looks mis-encoded (probably
# a dash); it is left as found because these strings are used at runtime.
CHART_OPTIONS = [
    "Bar β Top 20 Topics",
    "Treemap β Topic Distribution",
    "Scatter β Cluster PCA",
    "Heatmap β Topic Similarity",
]

# Keys looked up in charts.json. Position-aligned with CHART_OPTIONS: the two
# lists are zipped together in `_get_chart_plot`.
_CHART_KEYS = ["bar_top20", "treemap", "scatter_pca", "heatmap"]
| # --------------------------------------------------------------------------- | |
| # Helpers | |
| # --------------------------------------------------------------------------- | |
def _phase_bar_html(active_index: int) -> str:
    """Render the phase progress bar as an inline-styled HTML string.

    Phases located before ``active_index`` are drawn as done, the phase at
    ``active_index`` as active, and every later phase as pending. A negative
    index therefore renders all phases as pending.
    """
    last_position = len(PHASE_LABELS) - 1
    fragments = []
    for position, (code, name) in enumerate(PHASE_LABELS):
        # Circle colours: active, done, or pending.
        if position == active_index:
            bg, fg = "#6366f1", "#ffffff"
        elif position < active_index:
            bg, fg = "#10b981", "#ffffff"
        else:
            bg, fg = "#e5e7eb", "#6b7280"
        fragments.append(
            '<div style="display:flex;flex-direction:column;align-items:center;gap:4px;flex:1;">'
            f'<div style="width:32px;height:32px;border-radius:50%;background:{bg};'
            f'color:{fg};display:flex;align-items:center;justify-content:center;'
            f'font-size:11px;font-weight:600;">{position + 1}</div>'
            '<span style="font-size:10px;color:#374151;text-align:center;line-height:1.2;">'
            f'{code}<br>{name}</span>'
            '</div>'
        )
        # A connector line follows every step except the last one; it is
        # green once the step on its left has been completed.
        if position != last_position:
            line_bg = "#e5e7eb" if position >= active_index else "#10b981"
            fragments.append(
                f'<div style="flex:1;height:2px;background:{line_bg};margin-top:16px;'
                'max-width:40px;"></div>'
            )
    steps_html = "".join(fragments)
    return "".join(
        [
            '<div style="padding:16px 8px;background:#f9fafb;border-radius:12px;',
            'border:1px solid #e5e7eb;margin-bottom:8px;">',
            '<div style="display:flex;align-items:flex-start;justify-content:space-between;">',
            steps_html,
            '</div></div>',
        ]
    )
def _empty_review_df() -> pd.DataFrame:
    """Return a zero-row DataFrame carrying the review-table column headers."""
    blank_table = pd.DataFrame(columns=REVIEW_COLUMNS)
    return blank_table
| def _load_charts() -> dict: | |
| p = Path("charts.json") | |
| return json.loads(p.read_text()) if p.exists() else {} | |
def _call_agent(message: str, history: list) -> tuple[list, str]:
    """Send one user message to the agent and append the exchange to the chat.

    Args:
        message: Text sent to the agent as a single user message.
        history: Existing chat history as a list of ``{"role", "content"}``
            dicts. ``None`` is treated as an empty history.

    Returns:
        ``(updated_history, "")`` — a new history list with the user message
        and the agent reply appended, plus an empty string that callers use to
        clear the chat input box.
    """
    result = agent.invoke(
        {"messages": [{"role": "user", "content": message}]},
        config=AGENT_CONFIG,
    )
    reply = result["messages"][-1].content

    # The reply is shown in the chatbot and later lower-cased by the phase
    # detector, so it must be plain text. Message content can also arrive as a
    # list of content blocks; flatten that form into a single string.
    if isinstance(reply, list):
        pieces = []
        for part in reply:
            if isinstance(part, dict):
                pieces.append(str(part.get("text", "")))
            else:
                pieces.append(str(part))
        reply = "".join(pieces)
    elif not isinstance(reply, str):
        reply = "" if reply is None else str(reply)

    # `history` can be None before the first exchange; concatenating None
    # with a list would raise TypeError.
    updated_history = list(history or []) + [
        {"role": "user", "content": message},
        {"role": "assistant", "content": reply},
    ]
    return updated_history, ""
| def _submit_review( | |
| review_df: pd.DataFrame, | |
| history: list, | |
| ) -> tuple[list, str, pd.DataFrame]: | |
| """Read table edits, serialise to JSON, send to agent.""" | |
| approved = review_df[ | |
| review_df["Approve"].astype(str).str.lower() == "yes" | |
| ] if not review_df.empty else review_df | |
| groups = {} | |
| for _, row in approved.iterrows(): | |
| theme_name = str( | |
| row.get("Rename To") | |
| or row.get("Topic Label") | |
| or f"Theme_{row['#']}" | |
| ) | |
| topic_id = int(row["#"]) if str(row["#"]).isdigit() else 0 | |
| groups.setdefault(theme_name, []).append(topic_id) | |
| groups_list = [ | |
| {"theme_name": k, "topic_ids": v} | |
| for k, v in groups.items() | |
| ] | |
| summary = ( | |
| f"Review submitted. Approved topics: {len(approved)}.\n" | |
| f"Groups formed: {len(groups_list)}.\n\n" | |
| f"{json.dumps(groups_list, indent=2)}\n\n" | |
| f"Please consolidate these groups into themes." | |
| ) | |
| updated_history, _ = _call_agent(summary, history) | |
| return updated_history, "", review_df | |
| def _upload_csv(file_obj): | |
| if file_obj is None: | |
| return "", "No file uploaded." | |
| # π₯ CLEAR OLD FILES | |
| files_to_clear = [ | |
| "labelled_topics.json", | |
| "summaries.json", | |
| "taxonomy_mapping.json", | |
| "comparison.csv", | |
| "report.txt" | |
| ] | |
| list(map(lambda f: os.remove(f) if os.path.exists(f) else None, files_to_clear)) | |
| path = file_obj.name | |
| return path, f"β File ready: `{path}`" | |
| def _start_analysis(csv_path: str, history: list) -> tuple[list, str, str]: | |
| if not csv_path: | |
| return history, "", "β οΈ Please upload a CSV first." | |
| msg = ( | |
| f"I have uploaded a Scopus CSV at: {csv_path}\n" | |
| f"Please begin Phase 1 β Familiarisation. Load the CSV, report statistics, " | |
| f"and STOP after Phase 1." | |
| ) | |
| updated_history, _ = _call_agent(msg, history) | |
| phase_html = _phase_bar_html(0) | |
| return updated_history, "", phase_html | |
| def _send_message(user_msg: str, history: list, phase_html: str) -> tuple[list, str, str]: | |
| if not user_msg.strip(): | |
| return history, "", phase_html | |
| updated_history, _ = _call_agent(user_msg, history) | |
| last_ai = updated_history[-1]["content"] if updated_history else "" | |
| new_phase = _detect_phase(last_ai, phase_html) | |
| return updated_history, "", new_phase | |
| def _detect_phase(ai_text: str, current_html: str) -> str: | |
| phase_map = { | |
| "phase 1": 0, "phase 2": 1, "phase 3": 2, | |
| "phase 4": 3, "phase 5.5": 5, "phase 5": 4, "phase 6": 6, | |
| } | |
| lower = ai_text.lower() | |
| detected = current_html | |
| for key, idx in sorted(phase_map.items(), key=lambda x: -len(x[0])): | |
| if f"{key} complete" in lower or f"beginning {key}" in lower or f"starting {key}" in lower: | |
| detected = _phase_bar_html(idx) | |
| break | |
| return detected | |
def _get_chart_plot(chart_name: str):
    """Return the Plotly figure stored for the selected chart label.

    Returns ``None`` when the label is unknown, when charts.json holds no
    payload for it, or when the payload starts with ``<`` and therefore is not
    a JSON figure.
    """
    charts = _load_charts()
    label_to_key = {
        label: key for label, key in zip(CHART_OPTIONS, _CHART_KEYS)
    }
    payload = charts.get(label_to_key.get(chart_name, ""), "")
    if not payload:
        return None
    if str(payload).lstrip().startswith("<"):
        return None
    return pio.from_json(payload)
| def _get_download_files() -> list[str]: | |
| candidates = [ | |
| "comparison_abstract_vs_title.csv", | |
| "narrative.md", | |
| "topics.json", | |
| "labelled_topics.json", | |
| "themes.json", | |
| "taxonomy_mapping.json", | |
| "summaries.json", | |
| ] | |
| return list(filter(lambda p: Path(p).exists(), candidates)) | |
def _refresh_review_table() -> pd.DataFrame:
    """Build the review table from ``labelled_topics.json``.

    Returns an empty table with the review columns when the file does not
    exist. Otherwise returns one row per topic, using at most the first 100
    entries; every row starts approved ("Yes") with an empty rename field.

    Raises:
        KeyError: if a topic entry has no ``topic_id``.
    """
    source = Path("labelled_topics.json")
    if not source.exists():
        return _empty_review_df()

    # Presumably a JSON list of topic dicts — verify against the writer.
    topics = json.loads(source.read_text())
    rows = [
        {
            "#": topic["topic_id"],
            "Topic Label": topic.get("label", f"Topic {topic['topic_id']}"),
            "Top Evidence": " | ".join(topic.get("top_sentences", [])[:2]),
            "Sentences": topic.get("sentence_count", 0),
            "Papers": "",
            "Approve": "Yes",
            "Rename To": "",
            "Reasoning": topic.get("reasoning", ""),
        }
        for topic in topics[:100]
    ]
    # Passing the columns explicitly keeps the eight-column layout even when
    # the file holds no topics; a bare DataFrame([]) would have no columns.
    return pd.DataFrame(rows, columns=REVIEW_COLUMNS)
def _refresh_downloads() -> list[str] | None:
    """Return the existing output files, or ``None`` when there are none.

    ``None`` rather than an empty list is handed to the download widget when
    nothing is available.
    """
    available = _get_download_files()
    return available if available else None
| # --------------------------------------------------------------------------- | |
| # Build UI | |
| # --------------------------------------------------------------------------- | |
# Top-level UI definition: component layout first, event wiring afterwards.
# NOTE(review): the indentation of this block was lost before review, so the
# nesting below is reconstructed from the call structure. Confirm the
# placement of `start_btn`, `review_table` and `submit_review_btn`.
with gr.Blocks(
    title="BERTopic Thematic Analysis Agent",
) as demo:
    # ---- State ----
    # Holds the path reported by `_upload_csv` until the analysis is started.
    csv_path_state = gr.State("")
    # ---- Header ----
    gr.HTML(
        '<div style="padding:24px 0 8px;">'
        '<h1 style="font-size:1.6rem;font-weight:600;margin:0;color:#1e1b4b;">'
        'π BERTopic Thematic Analysis Agent</h1>'
        '<p style="color:#6b7280;margin:4px 0 0;font-size:0.95rem;">'
        'Braun & Clarke (2006) Β· Six-Phase Pipeline Β· PAJAIS Taxonomy</p>'
        '</div>'
    )
    # ---- Phase Progress Bar ----
    # Index -1 renders every phase as pending.
    phase_bar = gr.HTML(value=_phase_bar_html(-1), label="Phase Progress")
    # ────────────────────────────────────────────────────────
    # SECTION 1 — Data Input
    # ────────────────────────────────────────────────────────
    with gr.Group():
        gr.Markdown("## 1 Β· Data Input")
        with gr.Row():
            with gr.Column(scale=2):
                file_upload = gr.File(
                    label="Upload Scopus CSV",
                    file_types=[".csv"],
                    type="filepath",
                )
                file_status = gr.Markdown("_No file uploaded._")
            with gr.Column(scale=1):
                # NOTE(review): `run_config` is not listed in any event's
                # inputs below, so the selected value never reaches a handler.
                run_config = gr.Radio(
                    choices=["abstract", "title"],
                    value="abstract",
                    label="Run Config (field to cluster)",
                )
                start_btn = gr.Button("βΆ Start Analysis", variant="primary", size="lg")
    # ────────────────────────────────────────────────────────
    # SECTION 2 — Agent Conversation
    # ────────────────────────────────────────────────────────
    with gr.Group():
        gr.Markdown("## 2 Β· Agent Conversation")
        chatbot = gr.Chatbot(
            label="Thematic Analysis Agent"
        )
        with gr.Row():
            chat_input = gr.Textbox(
                placeholder="Type a message or instruction⦠(e.g. 'proceed to Phase 2')",
                label="",
                scale=5,
                show_label=False,
                lines=1,
            )
            send_btn = gr.Button("Send", variant="primary", scale=1)
    # ────────────────────────────────────────────────────────
    # SECTION 3 — Results
    # ────────────────────────────────────────────────────────
    with gr.Group():
        gr.Markdown("## 3 Β· Results")
        with gr.Tabs():
            # --- Tab 1: Review Table ---
            with gr.TabItem("π Review Table"):
                with gr.Row():
                    refresh_table_btn = gr.Button("π Refresh Table", size="sm")
                # One datatype entry per REVIEW_COLUMNS header, in order.
                review_table = gr.Dataframe(
                    value=_empty_review_df(),
                    headers=REVIEW_COLUMNS,
                    datatype=[
                        "number", "str", "str", "number",
                        "str", "str", "str", "str",
                    ],
                    column_count=(8, "fixed"),
                    interactive=True,
                    wrap=True,
                    label="Topic Review Table (edit Approve / Rename To / Reasoning)"
                )
                submit_review_btn = gr.Button(
                    "β Submit Review", variant="primary", size="lg"
                )
            # --- Tab 2: Charts ---
            with gr.TabItem("π Charts"):
                chart_dropdown = gr.Dropdown(
                    choices=CHART_OPTIONS,
                    value=CHART_OPTIONS[0],
                    label="Select Chart",
                    interactive=True,
                )
                # Filled only by the dropdown's change event wired below.
                chart_display = gr.Plot(label="Chart")
            # --- Tab 3: Download ---
            with gr.TabItem("β¬ Download"):
                refresh_dl_btn = gr.Button("π Refresh Files", size="sm")
                download_files = gr.File(
                    label="Download Analysis Outputs",
                    file_count="multiple",
                    interactive=False,
                    value=None,
                )
    # ────────────────────────────────────────────────────────
    # Event wiring
    # ────────────────────────────────────────────────────────
    # Upload CSV -> store path and show the status line
    file_upload.change(
        fn=_upload_csv,
        inputs=[file_upload],
        outputs=[csv_path_state, file_status],
    )
    # Start analysis button
    start_btn.click(
        fn=_start_analysis,
        inputs=[csv_path_state, chatbot],
        outputs=[chatbot, chat_input, phase_bar],
    )
    # Send message (button)
    send_btn.click(
        fn=_send_message,
        inputs=[chat_input, chatbot, phase_bar],
        outputs=[chatbot, chat_input, phase_bar],
    )
    # Send message (Enter key) — same handler and wiring as the button
    chat_input.submit(
        fn=_send_message,
        inputs=[chat_input, chatbot, phase_bar],
        outputs=[chatbot, chat_input, phase_bar],
    )
    # Submit review table
    submit_review_btn.click(
        fn=_submit_review,
        inputs=[review_table, chatbot],
        outputs=[chatbot, chat_input, review_table],
    )
    # Refresh review table
    refresh_table_btn.click(
        fn=_refresh_review_table,
        inputs=[],
        outputs=[review_table],
    )
    # Chart dropdown
    chart_dropdown.change(
        fn=_get_chart_plot,
        inputs=[chart_dropdown],
        outputs=[chart_display],
    )
    # Refresh downloads
    refresh_dl_btn.click(
        fn=_refresh_downloads,
        inputs=[],
        outputs=[download_files],
    )
| # --------------------------------------------------------------------------- | |
| # Launch | |
| # --------------------------------------------------------------------------- | |
if __name__ == "__main__":
    # Start the server only when this file is executed as a script.
    launch_options = {
        "server_name": "0.0.0.0",
        "server_port": 7860,
        "show_error": True,
        "theme": gr.themes.Soft(primary_hue="indigo"),
    }
    demo.launch(**launch_options)