Daksh C Jain
fix: change mistral model to open-mistral-nemo to circumvent 503 errors
7299c66
import glob
import json
import logging
import os

import gradio as gr
import plotly.io as pio
from dotenv import load_dotenv
from langchain_mistralai import ChatMistralAI
from langgraph.checkpoint.memory import MemorySaver
from langgraph.prebuilt import create_react_agent

from agent import SYSTEM_PROMPT, get_local_tools
# Presumably works around the OpenMP runtime aborting when a second copy of
# it gets loaded (common with torch/sklearn stacks) — confirm.
# NOTE(review): this runs *after* the imports above; if `agent` (or something
# it imports) initialises OpenMP at import time, this may be too late.
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"
# Populate os.environ from a local .env file (presumably the Mistral API key — confirm).
load_dotenv()
# Generated artifacts (rq4_*.csv reports, rq4_*.json charts) live in OUTPUT_DIR;
# per-stage JSON checkpoints live in OUTPUT_DIR/checkpoints, created up front.
OUTPUT_DIR = "outputs"
CHECKPOINT_DIR = os.path.join(OUTPUT_DIR, "checkpoints")
os.makedirs(CHECKPOINT_DIR, exist_ok=True)
# open-mistral-nemo was switched in to circumvent 503 errors from the previous
# model. temperature=0 for stable output; timeout is presumably in seconds;
# failed requests are retried up to 5 times.
llm = ChatMistralAI(model="open-mistral-nemo", temperature=0, timeout=300, max_retries=5)
# ReAct agent over the project's local tools. MemorySaver keeps per-thread
# conversation state in process memory only, so it does not survive a restart.
agent = create_react_agent(model=llm, tools=get_local_tools(), prompt=SYSTEM_PROMPT, checkpointer=MemorySaver())
# Count of chat messages handled (incremented in respond(); not read elsewhere in view).
_msg_count: int = 0
# Path of the most recently uploaded CSV, remembered across calls.
# Module-global, so it is shared by every connected client.
_uploaded: dict[str, str] = {"path": ""}
def _latest_output():
    """Return generated output files ordered by pipeline stage.

    Collects ``rq4_*.csv`` reports from OUTPUT_DIR and ``rq4_*.json``
    checkpoints from CHECKPOINT_DIR and sorts them so that earlier stages come
    first (summaries < labels < themes < taxonomy < comparison < narrative).

    Returns:
        list[str]: File paths in ascending stage order. Files whose names
        contain no stage keyword score 0 and sort first; ties keep glob order.
        Empty when nothing has been produced yet.
    """
    # Weight per stage keyword; a file's score is the sum of the weights of
    # every keyword contained in its name.
    stage_weights = {"summaries": 1, "labels": 2, "themes": 3, "taxonomy": 4, "comparison": 9, "narrative": 10}
    paths = glob.glob(f"{OUTPUT_DIR}/rq4_*.csv") + glob.glob(f"{CHECKPOINT_DIR}/rq4_*.json")

    def stage_score(path):
        # Score the file name only, so directory names can never skew the order.
        name = os.path.basename(path)
        return sum(weight for keyword, weight in stage_weights.items() if keyword in name)

    # sorted() is stable: equally-scored files keep their discovery order.
    return sorted(paths, key=stage_score)
def _build_progress():
    """Render the sidebar's pipeline progress line as Markdown text.

    A stage counts as done when its artifact exists on disk: the stage's
    checkpoint JSON in CHECKPOINT_DIR, or ``rq4_comparison.csv`` in
    OUTPUT_DIR for the final report.

    Returns:
        str: e.g. ``"✅ Load → ✅ Codes → ⬜ Themes → ⬜ PAJAIS → ⬜ Report"``.
    """
    stages = [
        ("Load", f"{CHECKPOINT_DIR}/rq4_*_summaries.json"),
        ("Codes", f"{CHECKPOINT_DIR}/rq4_*_labels.json"),
        ("Themes", f"{CHECKPOINT_DIR}/rq4_*_themes.json"),
        ("PAJAIS", f"{CHECKPOINT_DIR}/rq4_*_taxonomy_map.json"),
        ("Report", f"{OUTPUT_DIR}/rq4_comparison.csv"),
    ]
    return " → ".join(f"{'✅' if glob.glob(pattern) else '⬜'} {name}" for name, pattern in stages)
def respond(message, chat_history, uploaded_file):
    """Forward a chat message to the agent and stream the reply into the chat.

    Generator used as a Gradio event handler. It yields twice: once right
    away with a "Working..." placeholder so the UI updates immediately, and
    again after ``agent.invoke`` returns or fails.

    Args:
        message: User text; an empty message is sent as "Analyze".
        chat_history: Chatbot history as a list of ``{"role", "content"}``
            dicts, appended to in place (a fresh list is used if None).
        uploaded_file: Path of a newly uploaded CSV, or falsy to keep using
            the path remembered from an earlier call.

    Yields:
        tuple: ``(chat_history, "", output_files)`` — the empty string clears
        the message textbox; ``output_files`` comes from ``_latest_output()``.
    """
    global _msg_count
    _msg_count += 1
    # Remember the latest upload so follow-up messages still point the agent
    # at the same CSV.
    _uploaded["path"] = uploaded_file or _uploaded.get("path", "")
    prompt = message or "Analyze"
    csv_note = f"\n[CSV: {_uploaded['path']}]" if _uploaded["path"] else "\n[No CSV]"
    if chat_history is None:
        chat_history = []
    chat_history.append({"role": "user", "content": prompt})
    chat_history.append({"role": "assistant", "content": "🔬 **Working...**"})
    yield chat_history, "", _latest_output()
    # NOTE(review): every client shares thread_id "session" (and the global
    # _uploaded), so concurrent users share one conversation — confirm intended.
    try:
        res = agent.invoke(
            {"messages": [("human", prompt + csv_note)]},
            config={"configurable": {"thread_id": "session"}},
        )
        reply = res["messages"][-1].content
    except Exception as err:  # UI boundary: report the failure instead of leaving the placeholder
        logging.getLogger(__name__).exception("Agent call failed in respond()")
        reply = f"⚠️ **Agent error:** {type(err).__name__}: {err}"
    chat_history[-1] = {"role": "assistant", "content": reply}
    yield chat_history, "", _latest_output()
def _load_chart(name):
    """Load a Plotly figure that was saved as JSON in OUTPUT_DIR.

    Args:
        name: File name relative to OUTPUT_DIR (e.g. a chart dropdown choice).

    Returns:
        The figure built by ``pio.from_json``, or None when ``name`` is empty
        or the file does not exist.
    """
    if not name:
        return None
    path = os.path.join(OUTPUT_DIR, name)
    if not os.path.exists(path):
        return None
    # Context manager so the handle is closed even if parsing fails.
    with open(path) as fh:
        return pio.from_json(fh.read())
def _get_chart_choices():
    """List the chart files offered in the Visual Analytics dropdown.

    Returns:
        list[str]: Base names of every ``rq4_*.json`` directly in OUTPUT_DIR,
        ordered by their full path.
    """
    chart_paths = glob.glob(f"{OUTPUT_DIR}/rq4_*.json")
    chart_paths.sort()
    return list(map(os.path.basename, chart_paths))
def _load_review_table():
    """Build the rows of the Review & Refine table from the latest checkpoint.

    Reads the lexicographically last ``rq4_*.json`` in CHECKPOINT_DIR.
    NOTE(review): "last" is by file name, not modification time — confirm.

    Returns:
        list[list]: One row per entry, laid out as
        ``[topic_id, label, evidence, sentence_count, paper_count,
        approve, rename, reasoning]`` with approve defaulting to True.
        A single "No data" placeholder row when no checkpoint exists.
    """
    paths = sorted(glob.glob(f"{CHECKPOINT_DIR}/rq4_*.json"))
    if not paths:
        return [[0, "No data", "", 0, 0, False, "", ""]]
    with open(paths[-1]) as fh:
        data = json.load(fh)
    rows = []
    for position, entry in enumerate(data):
        # Prefer the real topic id: paper lookup and review decisions both key
        # on "topic_id", which need not equal the list position (e.g. -1).
        topic_id = entry.get("topic_id")
        if topic_id is None:
            topic_id = position
        label = entry.get("label", entry.get("top_words", "")) or ""
        # Guard against a missing, None, or empty "nearest" list.
        nearest = entry.get("nearest") or [{}]
        evidence = nearest[0].get("sentence") or ""
        rows.append([
            topic_id,
            label[:60],
            evidence[:120],
            entry.get("sentence_count", 0),
            entry.get("paper_count", 0),
            True,  # approve by default
            "",    # rename
            "",    # reasoning
        ])
    return rows
def _show_papers_by_select(table_data, evt: gr.SelectData):
    """Show the paper titles of the topic whose table row was clicked.

    The ``gr.SelectData`` annotation is what makes Gradio inject the select
    event, so it must stay.

    Args:
        table_data: Review table contents (pandas DataFrame or list of rows);
            column 0 holds the topic id.
        evt: Select event; ``evt.index[0]`` is the clicked row.

    Returns:
        str: "Topic N: <label>" followed by one "- <title>" line per paper,
        or "Not found" when no checkpoint entry has that topic_id.
    """
    row = evt.index[0]
    if hasattr(table_data, "iloc"):
        topic_id = int(table_data.iloc[row, 0])
    else:
        topic_id = int(table_data[row][0])
    # Prefer labelled checkpoints; fall back to summaries only if none exist.
    paths = sorted(glob.glob(f"{CHECKPOINT_DIR}/rq4_*_labels.json")) or sorted(
        glob.glob(f"{CHECKPOINT_DIR}/rq4_*_summaries.json")
    )
    for path in paths:
        with open(path) as fh:
            topics = json.load(fh)
        for topic in topics:
            if topic.get("topic_id") == topic_id:
                titles = "\n".join(f"- {p}" for p in topic.get("paper_titles", []))
                return f"Topic {topic_id}: {topic.get('label', '')}\n\n" + titles
    return "Not found"
def _format_review_row(row):
    """Render one review-table row as a "Topic N: DECISION" line for the agent.

    Row layout: [#, Label, Key Evidence, Sents, Papers, Approve, Rename,
    Reasoning]. A non-blank Rename takes precedence over the Approve checkbox;
    non-blank Reasoning is appended so the reviewer's rationale is not lost.
    """
    # Blank/NaN cells (pandas fills empty cells with float NaN) mean "no value".
    rename = row[6].strip() if isinstance(row[6], str) else ""
    reasoning = row[7].strip() if len(row) > 7 and isinstance(row[7], str) else ""
    if rename:
        decision = "RENAME to " + rename
    elif row[5]:
        decision = "APPROVE"
    else:
        decision = "REJECT"
    line = f"Topic {int(row[0])}: {decision}"
    return f"{line} (reason: {reasoning})" if reasoning else line


def _submit_review(table_data, chat_history):
    """Send the reviewer's table decisions to the agent.

    Generator used as a Gradio event handler. It yields a "Processing..."
    placeholder first, then the agent's reply along with refreshed outputs.

    Args:
        table_data: Review Dataframe contents (pandas DataFrame or list of rows).
        chat_history: Chatbot history (list of role/content dicts), appended
            to in place (a fresh list is used if None).

    Yields:
        tuple: (chat_history, output_files, chart-dropdown update,
        review-table rows or no-op update, progress markdown).
    """
    rows = table_data.values.tolist() if hasattr(table_data, "iloc") else list(table_data)
    prompt = "Review decisions:\n" + "\n".join(_format_review_row(r) for r in rows)
    if chat_history is None:
        chat_history = []
    chat_history.append({"role": "user", "content": "Submitted review"})
    chat_history.append({"role": "assistant", "content": "🔬 **Processing...**"})
    # Interim update: leave the chart dropdown and the table untouched.
    yield chat_history, _latest_output(), gr.update(), gr.update(), _build_progress()
    try:
        res = agent.invoke(
            {"messages": [("human", prompt)]},
            config={"configurable": {"thread_id": "session"}},
        )
        reply = res["messages"][-1].content
    except Exception as err:  # UI boundary: report the failure instead of leaving the placeholder
        logging.getLogger(__name__).exception("Agent call failed in _submit_review()")
        reply = f"⚠️ **Agent error:** {type(err).__name__}: {err}"
    chat_history[-1] = {"role": "assistant", "content": reply}
    yield chat_history, _latest_output(), gr.update(choices=_get_chart_choices()), _load_review_table(), _build_progress()
# Custom stylesheet handed to demo.launch(css=...). Forces a dark palette
# (near-black backgrounds, near-white text) and styles the classes this file
# assigns: .header-text (gr.HTML headings), .chatbot-container (chatbot),
# .primary-btn / .secondary-btn (buttons). .sidebar and .tab-nav presumably
# target Gradio's own sidebar/tab markup — confirm against the rendered DOM.
CSS = """
.gradio-container { background: #0b0f19 !important; color: #f8fafc !important; }
.sidebar { background: #111827 !important; border-right: 1px solid #1f2937 !important; }
.header-text { font-family: 'Outfit', sans-serif; color: #ffffff !important; letter-spacing: -0.02em; }
.tab-nav { border-bottom: 1px solid #1f2937 !important; background: transparent !important; }
.chatbot-container { border-radius: 12px !important; border: 1px solid #1f2937 !important; overflow: hidden; }
.primary-btn { background: #4f46e5 !important; color: #ffffff !important; border-radius: 8px !important; font-weight: 600 !important; }
.secondary-btn { background: #1f2937 !important; color: #f8fafc !important; border: 1px solid #374151 !important; border-radius: 8px !important; }
body, .gr-form, .gr-input, .gr-button, p, span, h1, h2, h3, h4, h5, h6, label, .gr-markdown {
color: #f8fafc !important;
}
.primary-btn span, .primary-btn {
color: #ffffff !important;
}
.sidebar span, .sidebar p, .sidebar h2, .sidebar label {
color: #f8fafc !important;
}
/* Ensure inputs are dark but readable */
input, textarea, select {
background-color: #1f2937 !important;
color: #f8fafc !important;
border: 1px solid #374151 !important;
}
"""
# Dark variant of Gradio's Soft theme: indigo/violet accents on a slate
# neutral palette, Outfit for UI text and JetBrains Mono for code.
# Passed to demo.launch(theme=...) together with CSS.
theme = gr.themes.Soft(
    font=gr.themes.GoogleFont("Outfit"),
    font_mono=gr.themes.GoogleFont("JetBrains Mono"),
    neutral_hue="slate",
    primary_hue="indigo",
    secondary_hue="violet",
)
# Override individual theme variables; .set() returns the theme, so rebind.
theme = theme.set(
    body_background_fill="#0b0f19",
    body_text_color="#f8fafc",
    block_background_fill="#111827",
    block_label_text_color="#94a3b8",
    block_title_text_weight="700",
    button_primary_background_fill="*primary_600",
    button_primary_text_color="white",
)
with gr.Blocks(title="Thematic Analysis AI") as demo:
    # Sidebar: dataset upload, pipeline progress, and run configuration.
    with gr.Sidebar(label="Data Hub", open=True):
        gr.HTML("<h2 class='header-text'>📁 Resource Center</h2>")
        upload = gr.File(label="Dataset (Scopus CSV)", file_types=[".csv"], elem_id="file-upload")
        progress = gr.Markdown(value=_build_progress(), elem_id="progress-display")
        gr.HTML("<hr>")
        # Read the model name from the client itself so this label cannot
        # drift from the model actually in use.
        gr.Markdown(f"### 🛠️ Configuration\nModel: `{llm.model}`\nPipeline: `BERTopic + Agglomerative`")
    gr.HTML("<h1 class='header-text' style='margin-bottom: 20px;'>🔬 Topic Modelling Agentic AI</h1>")
    with gr.Tabs():
        with gr.Tab("💬 Agent Chat"):
            chatbot = gr.Chatbot(height=450, show_label=False, elem_classes="chatbot-container")
            with gr.Row():
                msg = gr.Textbox(placeholder="Ask the agent to analyze, group, or export...", show_label=False, scale=9)
                send = gr.Button("Send", variant="primary", scale=1, elem_classes="primary-btn")
        with gr.Tab("📋 Review & Refine"):
            gr.Markdown("### 🔍 Topic Validation Table\nReview the identified themes and rename or reject as needed.")
            table = gr.Dataframe(
                headers=["#", "Label", "Key Evidence", "Sents", "Papers", "Approve", "Rename", "Reasoning"],
                datatype=["number", "str", "str", "number", "number", "bool", "str", "str"],
                interactive=True,
            )
            with gr.Row():
                submit = gr.Button("Submit Review Decisions", variant="primary", scale=2, elem_classes="primary-btn")
                clear = gr.Button("Refresh Table", variant="secondary", scale=1, elem_classes="secondary-btn")
            papers = gr.Textbox(label="Full Context: Papers in Selected Topic", lines=6, interactive=False)
        with gr.Tab("📊 Visual Analytics"):
            gr.Markdown("### 📈 Interactive Topic Visualizations")
            with gr.Row():
                selector = gr.Dropdown(choices=[], label="Select Visualization Type", scale=7)
                refresh_viz = gr.Button("Refresh Charts", variant="secondary", scale=1)
            display = gr.Plot()
        with gr.Tab("📥 Export Control"):
            gr.Markdown("### 💾 Final Outputs\nDownload generated papers, narratives, and comparison matrices.")
            download = gr.File(label="Available Exports", file_count="multiple")

    def _chart_updates():
        """Return (dropdown update, figure) that selects the newest chart, if any."""
        choices = _get_chart_choices()
        latest = choices[-1] if choices else None
        # _load_chart(None) returns None, which clears the plot.
        return gr.update(choices=choices, value=latest), _load_chart(latest)

    def respond_with_viz(m, h, u):
        """Run `respond` and refresh every dependent component after each step.

        Yields (chat history, cleared textbox, exports, chart dropdown, plot,
        review table rows, progress markdown).
        """
        for hist, _, exports in respond(m, h, u):
            selector_update, figure = _chart_updates()
            yield hist, "", exports, selector_update, figure, _load_review_table(), _build_progress()

    def upload_handler(f, h):
        """Start an analysis as soon as a CSV is uploaded."""
        yield from respond_with_viz("Analyze CSV", h, f)

    chat_outputs = [chatbot, msg, download, selector, display, table, progress]
    msg.submit(respond_with_viz, [msg, chatbot, upload], chat_outputs)
    send.click(respond_with_viz, [msg, chatbot, upload], chat_outputs)
    selector.change(_load_chart, [selector], [display])
    table.select(_show_papers_by_select, [table], [papers])
    submit.click(_submit_review, [table, chatbot], [chatbot, download, selector, table, progress])
    # Wire the two refresh buttons, which previously had no handlers.
    clear.click(_load_review_table, None, [table])
    refresh_viz.click(_chart_updates, None, [selector, display])
    # `.upload` rather than `.change`: clearing the file input must not re-run
    # the whole analysis against the previously remembered CSV path.
    upload.upload(upload_handler, [upload, chatbot], chat_outputs)
if __name__ == "__main__":
    # Serve on every interface at port 7860 with server-side rendering off;
    # theme and CSS are supplied at launch time rather than on gr.Blocks.
    launch_options = {
        "server_name": "0.0.0.0",
        "server_port": 7860,
        "ssr_mode": False,
        "theme": theme,
        "css": CSS,
    }
    demo.launch(**launch_options)