Spaces:
Sleeping
Sleeping
| import os | |
| import glob | |
| import json | |
| import plotly.io as pio | |
| import gradio as gr | |
| from dotenv import load_dotenv | |
| from langchain_mistralai import ChatMistralAI | |
| from langgraph.prebuilt import create_react_agent | |
| from langgraph.checkpoint.memory import MemorySaver | |
| from agent import SYSTEM_PROMPT, get_local_tools | |
# Process-wide setup. The OpenMP flag is set before the agent and its tools
# are constructed; presumably it lets duplicate OpenMP runtimes loaded by
# native dependencies coexist -- TODO confirm which dependency needs it.
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"
load_dotenv()

# Output locations: exports go in OUTPUT_DIR, checkpoints in a subfolder of it.
OUTPUT_DIR = "outputs"
CHECKPOINT_DIR = os.path.join(OUTPUT_DIR, "checkpoints")
os.makedirs(CHECKPOINT_DIR, exist_ok=True)

# A single LLM client and a single ReAct agent are shared by every request.
llm = ChatMistralAI(
    model="open-mistral-nemo",
    temperature=0,
    timeout=300,
    max_retries=5,
)
agent = create_react_agent(
    model=llm,
    tools=get_local_tools(),
    prompt=SYSTEM_PROMPT,
    checkpointer=MemorySaver(),
)

# Module-level state mutated by respond(): a call counter and the most
# recently uploaded CSV path.
_msg_count = 0
_uploaded = {"path": ""}
def _latest_output():
    """Return all rq4 export paths, ordered by pipeline stage.

    Collects ``rq4_*.csv`` files from OUTPUT_DIR and ``rq4_*.json`` files
    from CHECKPOINT_DIR. Each path is scored by summing the weight of every
    stage keyword that occurs in it, and the paths are returned in ascending
    score order (ties keep their glob order because the sort is stable).
    """
    stage_weights = {
        "summaries": 1,
        "labels": 2,
        "themes": 3,
        "taxonomy": 4,
        "comparison": 9,
        "narrative": 10,
    }

    def _stage_score(path):
        # Substring match against the whole path, not only the file name.
        return sum(weight for keyword, weight in stage_weights.items() if keyword in path)

    candidates = glob.glob(f"{OUTPUT_DIR}/rq4_*.csv")
    candidates += glob.glob(f"{CHECKPOINT_DIR}/rq4_*.json")
    return sorted(candidates, key=_stage_score)
def _build_progress():
    """Build the one-line pipeline progress string shown in the sidebar.

    A stage counts as done when at least one file matches its glob pattern.
    Each stage is rendered as "<marker> <name>" and the stages are joined
    with a separator.
    """
    # NOTE(review): the marker and separator glyphs look mis-encoded in this
    # copy of the file; they are kept verbatim -- confirm against the repo.
    stage_patterns = (
        ("Load", f"{CHECKPOINT_DIR}/rq4_*_summaries.json"),
        ("Codes", f"{CHECKPOINT_DIR}/rq4_*_labels.json"),
        ("Themes", f"{CHECKPOINT_DIR}/rq4_*_themes.json"),
        ("PAJAIS", f"{CHECKPOINT_DIR}/rq4_*_taxonomy_map.json"),
        ("Report", f"{OUTPUT_DIR}/rq4_comparison.csv"),
    )
    badges = []
    for stage_name, pattern in stage_patterns:
        marker = "β " if glob.glob(pattern) else "β¬"
        badges.append(f"{marker} {stage_name}")
    return " β ".join(badges)
def respond(message, chat_history, uploaded_file):
    """Run one chat turn through the agent, streaming two UI updates.

    Args:
        message: Text typed by the user; an empty value becomes "Analyze".
        chat_history: List of ``{"role", "content"}`` dicts; mutated in place.
        uploaded_file: Path of the uploaded CSV, or a falsy value to reuse
            the path remembered from an earlier call.

    Yields:
        ``(chat_history, "", _latest_output())`` twice: first with a
        placeholder assistant message, then with the agent's final reply.
    """
    global _msg_count
    _msg_count += 1

    # Keep the previously remembered path when no new file is supplied.
    _uploaded["path"] = uploaded_file or _uploaded.get("path", "")
    csv_path = _uploaded["path"]

    prompt = message or "Analyze"
    csv_note = f"\n[CSV: {csv_path}]" if csv_path else "\n[No CSV]"

    chat_history.append({"role": "user", "content": prompt})
    chat_history.append({"role": "assistant", "content": "π¬ **Working...**"})
    yield chat_history, "", _latest_output()

    # Every call uses the same thread_id, so all turns share one agent thread.
    result = agent.invoke(
        {"messages": [("human", prompt + csv_note)]},
        config={"configurable": {"thread_id": "session"}},
    )
    final_reply = result["messages"][-1].content
    chat_history[-1] = {"role": "assistant", "content": final_reply}
    yield chat_history, "", _latest_output()
def _load_chart(name):
    """Load a Plotly figure from a JSON file in OUTPUT_DIR.

    Args:
        name: File name relative to OUTPUT_DIR; may be empty or None.

    Returns:
        The figure produced by ``pio.from_json``, or None when ``name`` is
        falsy or does not refer to an existing regular file.
    """
    if not name:
        return None
    path = os.path.join(OUTPUT_DIR, name)
    # isfile (not exists) so a directory name returns None instead of
    # raising when opened.
    if not os.path.isfile(path):
        return None
    # Context manager closes the handle; the original left it open.
    with open(path, encoding="utf-8") as fh:
        return pio.from_json(fh.read())
def _get_chart_choices():
    """Return the file names of ``rq4_*.json`` files in OUTPUT_DIR, sorted by path."""
    chart_paths = glob.glob(f"{OUTPUT_DIR}/rq4_*.json")
    chart_paths.sort()
    return list(map(os.path.basename, chart_paths))
def _load_review_table():
    """Build the rows of the review table from the last checkpoint file.

    The checkpoint used is the lexicographically last ``rq4_*.json`` path in
    CHECKPOINT_DIR. Each row is
    ``[index, label, evidence, sentence_count, paper_count, True, "", ""]``,
    with the label cut to 60 characters and the evidence to 120.

    Returns:
        A list of rows, or a single "No data" placeholder row when no
        checkpoint file exists.
    """
    checkpoints = sorted(glob.glob(f"{CHECKPOINT_DIR}/rq4_*.json"))
    if not checkpoints:
        return [[0, "No data", "", 0, 0, False, "", ""]]

    # Context manager closes the handle; the original left it open.
    with open(checkpoints[-1], encoding="utf-8") as fh:
        topics = json.load(fh)

    rows = []
    for index, topic in enumerate(topics):
        label = topic.get("label", topic.get("top_words", ""))
        # An empty or null "nearest" list would otherwise raise on [0].
        nearest = topic.get("nearest") or [{}]
        evidence = nearest[0].get("sentence", "")
        rows.append([
            index,
            label[:60],
            evidence[:120],
            topic.get("sentence_count", 0),
            topic.get("paper_count", 0),
            True,
            "",
            "",
        ])
    return rows
def _show_papers_by_select(table_data, evt: gr.SelectData):
    """Describe the topic in the clicked table row and list its paper titles.

    Args:
        table_data: Current table value, either an object exposing ``iloc``
            or a list of rows.
        evt: Selection event; ``evt.index[0]`` is the clicked row.

    Returns:
        "Topic <id>: <label>" followed by one "- <title>" line per paper, or
        "Not found" when no checkpoint entry has a matching ``topic_id``.
    """
    row = evt.index[0]
    if hasattr(table_data, "iloc"):
        topic_id = int(table_data.iloc[row, 0])
    else:
        topic_id = int(table_data[row][0])

    # Label checkpoints take precedence; summaries are only a fallback.
    paths = sorted(glob.glob(f"{CHECKPOINT_DIR}/rq4_*_labels.json")) or sorted(
        glob.glob(f"{CHECKPOINT_DIR}/rq4_*_summaries.json")
    )
    for path in paths:
        # Context manager closes each file; the original leaked one handle
        # per checkpoint scanned.
        with open(path, encoding="utf-8") as fh:
            topics = json.load(fh)
        for topic in topics:
            if topic.get("topic_id") == topic_id:
                titles = "\n".join(f"- {p}" for p in topic.get("paper_titles", []))
                return f"Topic {topic_id}: {topic.get('label', '')}\n\n" + titles
    return "Not found"
def _submit_review(table_data, chat_history):
    """Send the reviewer's per-topic decisions to the agent and refresh the UI.

    Each table row becomes one line: "RENAME to <name>" when the Rename cell
    (column 6) holds non-blank text, otherwise "APPROVE" or "REJECT" based on
    the Approve cell (column 5).

    Args:
        table_data: Review table value, either an object exposing ``iloc``
            or a list of rows.
        chat_history: List of ``{"role", "content"}`` dicts; mutated in place.

    Yields:
        Two 5-tuples for (chatbot, download, selector, table, progress):
        first with a placeholder reply and no change to selector/table, then
        with the agent's reply, refreshed chart choices and a reloaded table.
    """
    # Accept both shapes, consistent with _show_papers_by_select.
    if hasattr(table_data, "iloc"):
        rows = table_data.values.tolist()
    else:
        rows = list(table_data)

    decisions = []
    for row in rows:
        # A cleared cell may arrive as None/NaN; only real text is a rename.
        rename = row[6].strip() if isinstance(row[6], str) else ""
        if rename:
            verdict = "RENAME to " + rename
        elif row[5]:
            verdict = "APPROVE"
        else:
            verdict = "REJECT"
        decisions.append(f"Topic {int(row[0])}: {verdict}")
    msg = "Review decisions:\n" + "\n".join(decisions)

    chat_history.append({"role": "user", "content": "Submitted review"})
    chat_history.append({"role": "assistant", "content": "π¬ **Processing...**"})
    yield chat_history, _latest_output(), gr.update(), gr.update(), _build_progress()

    res = agent.invoke(
        {"messages": [("human", msg)]},
        config={"configurable": {"thread_id": "session"}},
    )
    chat_history[-1] = {"role": "assistant", "content": res["messages"][-1].content}
    yield (
        chat_history,
        _latest_output(),
        gr.update(choices=_get_chart_choices()),
        _load_review_table(),
        _build_progress(),
    )
# Custom stylesheet handed to demo.launch(): dark backgrounds, light text,
# and the classes referenced through elem_classes in the layout.
CSS = """
.gradio-container { background: #0b0f19 !important; color: #f8fafc !important; }
.sidebar { background: #111827 !important; border-right: 1px solid #1f2937 !important; }
.header-text { font-family: 'Outfit', sans-serif; color: #ffffff !important; letter-spacing: -0.02em; }
.tab-nav { border-bottom: 1px solid #1f2937 !important; background: transparent !important; }
.chatbot-container { border-radius: 12px !important; border: 1px solid #1f2937 !important; overflow: hidden; }
.primary-btn { background: #4f46e5 !important; color: #ffffff !important; border-radius: 8px !important; font-weight: 600 !important; }
.secondary-btn { background: #1f2937 !important; color: #f8fafc !important; border: 1px solid #374151 !important; border-radius: 8px !important; }
body, .gr-form, .gr-input, .gr-button, p, span, h1, h2, h3, h4, h5, h6, label, .gr-markdown {
color: #f8fafc !important;
}
.primary-btn span, .primary-btn {
color: #ffffff !important;
}
.sidebar span, .sidebar p, .sidebar h2, .sidebar label {
color: #f8fafc !important;
}
/* Ensure inputs are dark but readable */
input, textarea, select {
background-color: #1f2937 !important;
color: #f8fafc !important;
border: 1px solid #374151 !important;
}
"""

# Soft theme: palette and fonts are constructor arguments, individual
# variable overrides are applied afterwards through .set().
_theme_base = dict(
    primary_hue="indigo",
    secondary_hue="violet",
    neutral_hue="slate",
    font=gr.themes.GoogleFont("Outfit"),
    font_mono=gr.themes.GoogleFont("JetBrains Mono"),
)
_theme_overrides = dict(
    body_background_fill="#0b0f19",
    block_background_fill="#111827",
    block_title_text_weight="700",
    button_primary_background_fill="*primary_600",
    button_primary_text_color="white",
    body_text_color="#f8fafc",
    block_label_text_color="#94a3b8",
)
theme = gr.themes.Soft(**_theme_base).set(**_theme_overrides)
# UI layout and event wiring. The layout is a sidebar (upload + progress)
# and four tabs: chat, review table, charts, and exports.
with gr.Blocks(title="Thematic Analysis AI") as demo:
    with gr.Sidebar(label="Data Hub", open=True):
        gr.HTML("<h2 class='header-text'>π Resource Center</h2>")
        upload = gr.File(label="Dataset (Scopus CSV)", file_types=[".csv"], elem_id="file-upload")
        progress = gr.Markdown(value=_build_progress(), elem_id="progress-display")
        gr.HTML("<hr>")
        # The model named here must match the one the module-level
        # ChatMistralAI client is constructed with.
        gr.Markdown("### π οΈ Configuration\nModel: `open-mistral-nemo`\nPipeline: `BERTopic + Agglomerative`")
    gr.HTML("<h1 class='header-text' style='margin-bottom: 20px;'>π¬ Topic Modelling Agentic AI</h1>")
    with gr.Tabs():
        with gr.Tab("π¬ Agent Chat"):
            chatbot = gr.Chatbot(height=450, show_label=False, elem_classes="chatbot-container")
            with gr.Row():
                msg = gr.Textbox(placeholder="Ask the agent to analyze, group, or export...", show_label=False, scale=9)
                send = gr.Button("Send", variant="primary", scale=1, elem_classes="primary-btn")
        with gr.Tab("π Review & Refine"):
            gr.Markdown("### π Topic Validation Table\nReview the identified themes and rename or reject as needed.")
            table = gr.Dataframe(
                headers=["#", "Label", "Key Evidence", "Sents", "Papers", "Approve", "Rename", "Reasoning"],
                datatype=["number", "str", "str", "number", "number", "bool", "str", "str"],
                interactive=True,
            )
            with gr.Row():
                submit = gr.Button("Submit Review Decisions", variant="primary", scale=2, elem_classes="primary-btn")
                clear = gr.Button("Refresh Table", variant="secondary", scale=1, elem_classes="secondary-btn")
            papers = gr.Textbox(label="Full Context: Papers in Selected Topic", lines=6, interactive=False)
        with gr.Tab("π Visual Analytics"):
            gr.Markdown("### π Interactive Topic Visualizations")
            with gr.Row():
                selector = gr.Dropdown(choices=[], label="Select Visualization Type", scale=7)
                refresh_viz = gr.Button("Refresh Charts", variant="secondary", scale=1)
            display = gr.Plot()
        with gr.Tab("π₯ Export Control"):
            gr.Markdown("### πΎ Final Outputs\nDownload generated papers, narratives, and comparison matrices.")
            download = gr.File(label="Available Exports", file_count="multiple")

    # Components refreshed by every chat-style event, in handler yield order.
    chat_outputs = [chatbot, msg, download, selector, display, table, progress]

    def respond_with_viz(m, h, u):
        """Wrap respond() so each update also refreshes charts, table and progress."""
        for hist, _, dl in respond(m, h, u):
            cs = _get_chart_choices()
            latest = cs[-1] if cs else None
            yield (
                hist,
                "",
                dl,
                gr.update(choices=cs, value=latest),
                _load_chart(latest),
                _load_review_table(),
                _build_progress(),
            )

    def upload_handler(f, h):
        """Start an analysis when a file is uploaded; ignore a cleared upload."""
        if not f:
            # The change event also fires when the file is removed; leave
            # every output untouched instead of re-running the agent.
            yield tuple(gr.update() for _ in chat_outputs)
            return
        yield from respond_with_viz("Analyze CSV", h, f)

    def refresh_charts():
        """Re-scan the chart files, select the last one and display it."""
        cs = _get_chart_choices()
        latest = cs[-1] if cs else None
        return gr.update(choices=cs, value=latest), _load_chart(latest)

    msg.submit(respond_with_viz, [msg, chatbot, upload], chat_outputs)
    send.click(respond_with_viz, [msg, chatbot, upload], chat_outputs)
    selector.change(_load_chart, [selector], [display])
    table.select(_show_papers_by_select, [table], [papers])
    submit.click(_submit_review, [table, chatbot], [chatbot, download, selector, table, progress])
    upload.change(upload_handler, [upload, chatbot], chat_outputs)
    # These two buttons were rendered but had no handler attached.
    clear.click(_load_review_table, None, [table])
    refresh_viz.click(refresh_charts, None, [selector, display])
# Script entry point: serve the app on all interfaces at port 7860, with
# the theme and stylesheet supplied at launch time.
if __name__ == "__main__":
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        ssr_mode=False,
        theme=theme,
        css=CSS,
    )