"""Gradio front-end for the Math Mentor multi-agent JEE math-solving pipeline.

Wires a streaming LangGraph-style pipeline (``ui.callbacks.run_pipeline``)
into a Blocks UI with: multimodal input (text / image / audio), a live
pipeline-progress bar, human-in-the-loop (HITL) review, feedback capture,
and tabs for solution, explanation, diagram, agent trace, RAG context,
memory, and chat history.
"""

# NOTE(review): `json` and `uuid` appear unused in this module — kept in case
# another part of the project relies on this module re-exporting them; confirm
# before removing.
import json
import uuid

import gradio as gr

from config import settings
from ui.callbacks import (
    new_thread_id,
    run_pipeline,
    resume_after_hitl,
    submit_feedback,
    update_settings,
)

# ── Constants ─────────────────────────────────────────────────────────────

# Delimiter pairs Gradio Markdown components use to detect inline/display LaTeX.
LATEX_DELIMITERS = [
    {"left": "$$", "right": "$$", "display": True},
    {"left": "$", "right": "$", "display": False},
    {"left": "\\(", "right": "\\)", "display": False},
    {"left": "\\[", "right": "\\]", "display": True},
]

# (pipeline-node-id, display label) pairs, in execution order.  The node ids
# must match the `node` keys yielded by run_pipeline / resume_after_hitl.
AGENT_STEPS = [
    ("extract", "Extract"),
    ("guardrail", "Guardrail"),
    ("parse", "Parse"),
    ("route", "Route"),
    ("retrieve_context", "RAG"),
    ("retrieve_memory", "Memory"),
    ("solve", "Solve"),
    ("verify", "Verify"),
    ("explain", "Explain"),
    ("save_memory", "Save"),
]

CSS = """\
/* ── Global ── */
.gradio-container { max-width: 100% !important; padding: 0 !important; background: #0f0f1a !important; font-family: 'Inter', -apple-system, BlinkMacSystemFont, sans-serif !important; }
/* ── Header ── */
.app-header { background: linear-gradient(135deg, #1a1a2e 0%, #16213e 50%, #0f3460 100%); padding: 28px 40px; border-bottom: 1px solid rgba(255,255,255,0.06); margin-bottom: 0; }
.app-header h1 { margin: 0 !important; font-size: 2em !important; font-weight: 700 !important; background: linear-gradient(135deg, #00d2ff, #7b2ff7); -webkit-background-clip: text; -webkit-text-fill-color: transparent; letter-spacing: -0.5px; }
.app-header p { margin: 6px 0 0 !important; color: #8892b0 !important; font-size: 0.95em !important; }
/* ── Main layout ── */
.main-container { padding: 24px 32px !important; gap: 24px !important; }
.input-panel { background: #1a1a2e !important; border: 1px solid rgba(255,255,255,0.06) !important; border-radius: 16px !important; padding: 24px !important; min-width: 360px !important; max-width: 400px !important; }
.output-panel { background: #1a1a2e !important; border: 1px solid rgba(255,255,255,0.06) !important; border-radius: 16px !important; padding: 24px !important; }
/* ── Cards ── */
.card { background: rgba(255,255,255,0.03) !important; border: 1px solid rgba(255,255,255,0.06) !important; border-radius: 12px !important; padding: 16px !important; margin-bottom: 12px !important; }
/* ── Pipeline progress ── */
.pipeline-bar { display: flex; gap: 4px; padding: 12px 16px; background: rgba(255,255,255,0.02); border-radius: 12px; border: 1px solid rgba(255,255,255,0.05); overflow-x: auto; margin-bottom: 16px; }
/* ── Buttons ── */
.solve-btn { background: linear-gradient(135deg, #7b2ff7, #00d2ff) !important; border: none !important; color: white !important; font-weight: 600 !important; font-size: 1.05em !important; padding: 12px 24px !important; border-radius: 12px !important; transition: all 0.3s ease !important; text-transform: uppercase !important; letter-spacing: 0.5px !important; }
.solve-btn:hover { transform: translateY(-1px) !important; box-shadow: 0 8px 25px rgba(123, 47, 247, 0.35) !important; }
.new-btn { background: rgba(255,255,255,0.05) !important; border: 1px solid rgba(255,255,255,0.1) !important; color: #8892b0 !important; border-radius: 10px !important; }
/* ── Tabs ── */
.output-tabs .tab-nav { background: rgba(255,255,255,0.02) !important; border-radius: 12px !important; padding: 4px !important; border: 1px solid rgba(255,255,255,0.05) !important; }
.output-tabs .tab-nav button { border-radius: 8px !important; font-weight: 500 !important; font-size: 0.88em !important; padding: 8px 14px !important; transition: all 0.2s ease !important; }
.output-tabs .tab-nav button.selected { background: linear-gradient(135deg, #7b2ff7, #00d2ff) !important; color: white !important; }
/* ── Solution/Explanation content ── */
.solution-content { font-size: 1.05em !important; line-height: 1.8 !important; padding: 20px !important; min-height: 250px !important; color: #e6e6e6 !important; }
.solution-content p, .solution-content li { margin-bottom: 8px !important; }
/* ── Status badge ── */
.status-badge { min-height: 36px !important; }
.status-badge .prose { padding: 8px 16px !important; border-radius: 20px !important; background: rgba(123, 47, 247, 0.1) !important; border: 1px solid rgba(123, 47, 247, 0.2) !important; display: inline-block !important; font-size: 0.9em !important; }
/* ── Confidence indicator ── */
.confidence-box textarea { text-align: center !important; font-weight: 700 !important; font-size: 1.1em !important; border-radius: 10px !important; background: rgba(0, 210, 255, 0.08) !important; border: 1px solid rgba(0, 210, 255, 0.2) !important; color: #00d2ff !important; }
/* ── Trace display ── */
.trace-content { font-family: 'JetBrains Mono', 'Fira Code', monospace !important; font-size: 0.88em !important; line-height: 1.6 !important; padding: 16px !important; }
/* ── Settings accordion ── */
.settings-accordion { border: 1px solid rgba(255,255,255,0.06) !important; border-radius: 12px !important; background: rgba(255,255,255,0.02) !important; margin: 0 32px 0 32px !important; }
/* ── HITL panel ── */
.hitl-panel { background: rgba(255, 165, 0, 0.05) !important; border: 1px solid rgba(255, 165, 0, 0.2) !important; border-radius: 12px !important; padding: 16px !important; }
/* ── Feedback section ── */
.feedback-section { border-top: 1px solid rgba(255,255,255,0.06); padding-top: 16px; margin-top: 16px; }
.feedback-btn-correct { background: rgba(0, 200, 83, 0.1) !important; border: 1px solid rgba(0, 200, 83, 0.3) !important; color: #00c853 !important; border-radius: 8px !important; }
.feedback-btn-incorrect { background: rgba(255, 82, 82, 0.1) !important; border: 1px solid rgba(255, 82, 82, 0.3) !important; color: #ff5252 !important; border-radius: 8px !important; }
/* ── Dataframe styling ── */
.context-table, .memory-table { border-radius: 10px !important; overflow: hidden !important; }
/* ── Input radio ── */
.input-mode-radio label { border-radius: 8px !important; padding: 8px 16px !important; }
/* ── Extracted text ── */
.extracted-text textarea { background: rgba(255,255,255,0.03) !important; border: 1px solid rgba(255,255,255,0.08) !important; border-radius: 10px !important; font-size: 0.92em !important; }
/* ── Section labels ── */
.section-label { font-size: 0.78em !important; text-transform: uppercase !important; letter-spacing: 1.2px !important; color: #5a6080 !important; margin-bottom: 8px !important; font-weight: 600 !important; }
.section-label p { margin: 0 !important; }
/* ── Hide borders on dark theme ── */
.dark .block { border: none !important; background: transparent !important; }
"""

# ── Helpers ──────────────────────────────────────────────────────────────

# Icon shown next to each agent name in the trace tab.  Agent names are the
# `agent` values recorded in the pipeline's agent_trace entries.
_AGENT_ICONS = {
    "extractor": "📥",
    "guardrail": "🛡️",
    "parser": "📝",
    "router": "🔀",
    "retriever": "📚",
    "memory_retriever": "🧠",
    "solver": "🧮",
    "verifier": "✅",
    "explainer": "💡",
    "memory_saver": "💾",
}


def _format_trace(trace: list[dict]) -> str:
    """Render the agent trace as numbered Markdown lines.

    Each trace entry is a dict with (at least) ``agent``, ``action``,
    ``summary`` and ``timestamp`` keys; missing keys degrade gracefully.
    Returns a placeholder string when the trace is empty.
    """
    if not trace:
        return "*Waiting for agent activity...*"
    lines = []
    for i, t in enumerate(trace):
        agent = t.get("agent", "?")
        action = t.get("action", "?")
        summary = t.get("summary", "")
        # Keep only the time-of-day tail (e.g. "12:34:56") of an ISO timestamp.
        ts = t.get("timestamp", "")[-8:]
        icon = _AGENT_ICONS.get(agent, "⚙️")
        step_num = i + 1
        lines.append(
            f"**{step_num}.** {icon} `{agent}` → **{action}** — {summary} \n{ts}"
        )
    return "\n\n".join(lines)


def _format_chunks(chunks: list[dict]) -> list[list]:
    """Convert retrieved RAG chunks into Dataframe rows [source, text, score].

    Text is truncated to 200 chars; a single placeholder row is returned
    when nothing was retrieved so the table never renders empty.
    """
    rows = [
        [c.get("source", ""), c.get("text", "")[:200], c.get("rrf_score", "")]
        for c in chunks
    ]
    return rows or [["—", "No context retrieved", ""]]


def _format_similar(similar: list[dict]) -> list[list]:
    """Convert similar past problems into Dataframe rows [question, answer, sim]."""
    rows = [
        [
            s.get("extracted_text", "")[:100],
            s.get("solution", "")[:100],
            s.get("similarity", ""),
        ]
        for s in similar
    ]
    return rows or [["—", "No similar past problems", ""]]


def _format_chat(chat_history: list) -> list[dict]:
    """Convert internal chat history to Gradio Chatbot messages format.

    Accepts either plain ``{"role", "content"}`` dicts or message objects
    (e.g. LangChain messages exposing ``.type`` and ``.content``).
    """
    messages = []
    for msg in chat_history:
        if isinstance(msg, dict):
            role = msg.get("role", "user")
            content = msg.get("content", "")
        else:
            # Message objects: `.type == "ai"` marks assistant turns.
            role = "assistant" if getattr(msg, "type", "") == "ai" else "user"
            content = msg.content if hasattr(msg, "content") else str(msg)
        messages.append({"role": role, "content": content})
    return messages


def _build_pipeline_html(active_node: str = "", completed: set | None = None) -> str:
    """Build HTML for the pipeline progress bar.

    ``active_node`` is highlighted purple with an hourglass, nodes in
    ``completed`` show green with a check, everything else is dimmed.
    """
    # NOTE(review): the original inline HTML markup was lost when this file's
    # formatting was mangled; the tags/styles below are reconstructed around
    # the surviving style variables — verify against the rendered UI.
    completed = completed or set()
    steps_html = []
    for node_id, label in AGENT_STEPS:
        if node_id == active_node:
            color = "#7b2ff7"
            bg = "rgba(123, 47, 247, 0.2)"
            border = "rgba(123, 47, 247, 0.5)"
            text_color = "#c4a1ff"
            icon = "⏳"
        elif node_id in completed:
            color = "#00c853"
            bg = "rgba(0, 200, 83, 0.12)"
            border = "rgba(0, 200, 83, 0.3)"
            text_color = "#69f0ae"
            icon = "✓"
        else:
            color = "#3a3f58"
            bg = "rgba(255,255,255,0.02)"
            border = "rgba(255,255,255,0.06)"
            text_color = "#5a6080"
            icon = ""
        steps_html.append(
            f'<div style="display:flex;align-items:center;gap:6px;'
            f'padding:6px 12px;border-radius:8px;white-space:nowrap;'
            f'font-size:0.82em;font-weight:600;'
            f'background:{bg};border:1px solid {border};color:{text_color};">'
            f'<span style="color:{color};">{icon}</span> {label}</div>'
        )
    return (
        '<div style="display:flex;gap:4px;align-items:center;overflow-x:auto;">'
        f'{" ".join(steps_html)}</div>'
    )


# ── Main solve function ──────────────────────────────────────────────────


def solve(input_text, input_image, input_audio, input_mode, thread_state, chat_state):
    """Run the pipeline and stream UI updates as each node completes.

    Generator wired to the Solve button.  Every yield is an 18-tuple matching
    the button's ``outputs`` list: (status, pipeline html, extracted text,
    extraction confidence, context rows, memory rows, solution, explanation,
    diagram path, thread id, chat history, hitl-group update, hitl-textbox
    update, hitl-approve update, hitl-reject update, trace markdown,
    hitl type, chat messages).

    Terminates early (after one final yield) on pipeline error, on a HITL
    interrupt, or on guardrail rejection.
    """
    thread_id = thread_state or new_thread_id()
    chat_history = chat_state or []
    final_state = {}
    trace = []
    completed_nodes = set()

    for update in run_pipeline(
        input_text, input_image, input_audio, input_mode, thread_id, chat_history
    ):
        node = update["node"]
        output = update["output"]

        if node == "error":
            yield (
                f"**Error:** {output.get('error', 'Unknown error')}",
                _build_pipeline_html(),
                "",
                0,
                _format_chunks([]),
                _format_similar([]),
                "",
                "",
                None,
                thread_id,
                chat_history,
                gr.update(visible=False),
                gr.update(visible=False),
                gr.update(visible=False),
                gr.update(visible=False),
                "",
                "",
                _format_chat(chat_history),
            )
            return

        # Merge this node's partial state into the accumulated pipeline state.
        final_state.update(output)
        trace = final_state.get("agent_trace", trace)
        completed_nodes.add(node)

        # Check HITL interrupt — only extract/parse/verify can pause for review.
        if final_state.get("needs_human_review") and node in ("extract", "parse", "verify"):
            extracted = final_state.get("extracted_text", "")
            confidence = final_state.get("extraction_confidence", 0)
            hitl_type = (
                "extraction"
                if node == "extract"
                else ("clarification" if node == "parse" else "verification")
            )
            yield (
                f"**Review Required** — {hitl_type.title()}",
                _build_pipeline_html(node, completed_nodes),
                extracted,
                confidence,
                _format_chunks(final_state.get("retrieved_chunks", [])),
                _format_similar(final_state.get("similar_past_problems", [])),
                "",
                "",
                None,
                thread_id,
                chat_history,
                gr.update(visible=True),
                gr.update(visible=True, value=extracted),
                gr.update(visible=True),
                gr.update(visible=True),
                _format_trace(trace),
                hitl_type,
                _format_chat(chat_history),
            )
            return

        # Check guardrail rejection (explicit False; absent/None means "not decided").
        if final_state.get("is_valid_input") is False:
            reason = final_state.get("rejection_reason", "Input rejected by guardrail.")
            yield (
                f"**Rejected:** {reason}",
                _build_pipeline_html(node, completed_nodes),
                "",
                0,
                _format_chunks([]),
                _format_similar([]),
                "",
                "",
                None,
                thread_id,
                chat_history,
                gr.update(visible=False),
                gr.update(visible=False),
                gr.update(visible=False),
                gr.update(visible=False),
                _format_trace(trace),
                "",
                _format_chat(chat_history),
            )
            return

        # Streaming update with whatever partial results exist so far.
        extracted = final_state.get("extracted_text", "")
        confidence = final_state.get("extraction_confidence", 0)
        solution = final_state.get("solution", "")
        explanation = final_state.get("explanation", "")
        diagram = final_state.get("diagram_path") or None

        # Find next pending node for progress display.
        next_node = ""
        for nid, _ in AGENT_STEPS:
            if nid not in completed_nodes:
                next_node = nid
                break

        yield (
            "Processing...",
            _build_pipeline_html(next_node, completed_nodes),
            extracted,
            confidence,
            _format_chunks(final_state.get("retrieved_chunks", [])),
            _format_similar(final_state.get("similar_past_problems", [])),
            solution,
            explanation,
            diagram,
            thread_id,
            chat_history,
            gr.update(visible=False),
            gr.update(visible=False),
            gr.update(visible=False),
            gr.update(visible=False),
            _format_trace(trace),
            "",
            _format_chat(chat_history),
        )

    # Final output after the pipeline generator is exhausted.
    extracted = final_state.get("extracted_text", "")
    confidence = final_state.get("extraction_confidence", 0)
    solution = final_state.get("solution", "")
    explanation = final_state.get("explanation", "")
    diagram = final_state.get("diagram_path") or None
    final_conf = final_state.get("final_confidence", 0)

    chat_history.append(
        {"role": "user", "content": input_text or "[image/audio input]"}
    )
    chat_history.append({"role": "assistant", "content": explanation or solution})

    conf_text = f"Confidence: {final_conf:.0%}" if final_conf else ""
    yield (
        f"**Solved!** {conf_text}",
        _build_pipeline_html("", completed_nodes),
        extracted,
        confidence,
        _format_chunks(final_state.get("retrieved_chunks", [])),
        _format_similar(final_state.get("similar_past_problems", [])),
        solution,
        explanation,
        diagram,
        thread_id,
        chat_history,
        gr.update(visible=False),
        gr.update(visible=False),
        gr.update(visible=False),
        gr.update(visible=False),
        _format_trace(final_state.get("agent_trace", [])),
        "",
        _format_chat(chat_history),
    )


def handle_hitl_approve(hitl_text, thread_state, chat_state, hitl_type_state):
    """Resume the pipeline after a human approves/edits the HITL text.

    Streams the same 18-tuple shape as :func:`solve`.  ``hitl_type_state``
    is accepted to match the wiring but is not consulted here.
    """
    thread_id = thread_state
    chat_history = chat_state or []
    final_state = {}
    completed_nodes = set()

    for update in resume_after_hitl(thread_id, human_text=hitl_text, approved=True):
        node = update["node"]
        output = update["output"]

        if node == "error":
            yield (
                f"**Error:** {output.get('error', 'Unknown error')}",
                _build_pipeline_html(),
                "",
                0,
                _format_chunks([]),
                _format_similar([]),
                "",
                "",
                None,
                thread_id,
                chat_history,
                gr.update(visible=False),
                gr.update(visible=False),
                gr.update(visible=False),
                gr.update(visible=False),
                "",
                "",
                _format_chat(chat_history),
            )
            return

        final_state.update(output)
        completed_nodes.add(node)

        extracted = final_state.get("extracted_text", "")
        confidence = final_state.get("extraction_confidence", 0)
        solution = final_state.get("solution", "")
        explanation = final_state.get("explanation", "")
        diagram = final_state.get("diagram_path") or None

        # Find next pending node for progress display.
        next_node = ""
        for nid, _ in AGENT_STEPS:
            if nid not in completed_nodes:
                next_node = nid
                break

        yield (
            "Resumed — processing...",
            _build_pipeline_html(next_node, completed_nodes),
            extracted,
            confidence,
            _format_chunks(final_state.get("retrieved_chunks", [])),
            _format_similar(final_state.get("similar_past_problems", [])),
            solution,
            explanation,
            diagram,
            thread_id,
            chat_history,
            gr.update(visible=False),
            gr.update(visible=False),
            gr.update(visible=False),
            gr.update(visible=False),
            _format_trace(final_state.get("agent_trace", [])),
            "",
            _format_chat(chat_history),
        )

    # Final output after resumption completes.
    solution = final_state.get("solution", "")
    explanation = final_state.get("explanation", "")
    final_conf = final_state.get("final_confidence", 0)
    conf_text = f"Confidence: {final_conf:.0%}" if final_conf else ""

    chat_history.append({"role": "user", "content": "[resumed after review]"})
    chat_history.append({"role": "assistant", "content": explanation or solution})

    yield (
        f"**Solved!** {conf_text}",
        _build_pipeline_html("", completed_nodes),
        final_state.get("extracted_text", ""),
        final_state.get("extraction_confidence", 0),
        _format_chunks(final_state.get("retrieved_chunks", [])),
        _format_similar(final_state.get("similar_past_problems", [])),
        solution,
        explanation,
        final_state.get("diagram_path") or None,
        thread_id,
        chat_history,
        gr.update(visible=False),
        gr.update(visible=False),
        gr.update(visible=False),
        gr.update(visible=False),
        _format_trace(final_state.get("agent_trace", [])),
        "",
        _format_chat(chat_history),
    )


def handle_feedback(feedback_type, comment, thread_state):
    """Forward user feedback ("correct"/"incorrect" + comment) to the backend."""
    return submit_feedback(thread_state, feedback_type, comment)


# ── Build UI ─────────────────────────────────────────────────────────────

THEME = gr.themes.Base(
    primary_hue=gr.themes.colors.purple,
    secondary_hue=gr.themes.colors.cyan,
    neutral_hue=gr.themes.colors.slate,
    font=gr.themes.GoogleFont("Inter"),
).set(
    body_background_fill="#0f0f1a",
    body_background_fill_dark="#0f0f1a",
    block_background_fill="#1a1a2e",
    block_background_fill_dark="#1a1a2e",
    block_border_width="0px",
    input_background_fill="#12122a",
    input_background_fill_dark="#12122a",
    button_primary_background_fill="linear-gradient(135deg, #7b2ff7, #00d2ff)",
    button_primary_text_color="white",
)


def build_ui():
    """Construct and return the Gradio Blocks app (does not launch it)."""
    blocks_kw = {"title": "Math Mentor"}
    # Gradio < 6 takes theme/css on Blocks(); >= 6 takes them on launch()
    # (see the entry point below).
    if int(gr.__version__.split(".")[0]) < 6:
        blocks_kw.update(theme=THEME, css=CSS)

    with gr.Blocks(**blocks_kw) as demo:
        # ── Header ──
        # NOTE(review): the original header markup was lost in the formatting
        # mangle; reconstructed from the surviving text content.
        with gr.Row(elem_classes="app-header"):
            gr.HTML(
                "<div>"
                "<h1>Math Mentor</h1>"
                "<p>AI-Powered JEE Math Solver — Multi-Agent RAG Pipeline "
                "with Human-in-the-Loop</p>"
                "</div>"
            )

        # Hidden state carried between events.
        thread_state = gr.State(value=new_thread_id())
        chat_state = gr.State(value=[])
        hitl_type_state = gr.State(value="")

        # ── Settings ── (opens automatically when no LLM is configured)
        with gr.Accordion(
            "Settings"
            + (" — NOT CONFIGURED" if not settings.is_llm_configured else ""),
            open=not settings.is_llm_configured,
            elem_classes="settings-accordion",
        ):
            gr.Markdown(
                "Configure your LLM endpoint. Supports any OpenAI-compatible API "
                "(Ollama, Together, OpenRouter, etc.)",
                elem_classes="section-label",
            )
            with gr.Row():
                settings_url = gr.Textbox(
                    label="Base URL",
                    placeholder="e.g. http://localhost:11434/v1",
                    value="",
                    scale=3,
                )
                settings_model = gr.Textbox(
                    label="Model",
                    placeholder="e.g. llama3, gpt-4o",
                    value="",
                    scale=2,
                )
                settings_key = gr.Textbox(
                    label="API Key",
                    placeholder="Leave empty if not needed",
                    value="",
                    type="password",
                    scale=2,
                )
                settings_btn = gr.Button("Save", scale=1, variant="secondary")
            settings_status = gr.Textbox(label="Status", interactive=False, visible=True)
            settings_btn.click(
                update_settings,
                inputs=[settings_url, settings_model, settings_key],
                outputs=[settings_status],
            )

        # ── Pipeline progress bar ──
        with gr.Row(visible=True):
            pipeline_progress = gr.HTML(
                value=_build_pipeline_html(),
                elem_classes="pipeline-bar",
            )

        # ── Main layout ──
        with gr.Row(elem_classes="main-container", equal_height=False):
            # ── Left: Input panel ──
            with gr.Column(scale=1, min_width=360, elem_classes="input-panel"):
                gr.Markdown("INPUT", elem_classes="section-label")
                input_mode = gr.Radio(
                    ["Text", "Image", "Audio"],
                    label="Input Mode",
                    value="Text",
                    elem_classes="input-mode-radio",
                )
                text_input = gr.Textbox(
                    label="Math Problem",
                    placeholder=(
                        "e.g. Find the derivative of x^3 + 2x^2 - 5x + 1\n\n"
                        "Or ask a follow-up question..."
                    ),
                    lines=4,
                    visible=True,
                )
                image_input = gr.Image(
                    label="Upload Problem Image",
                    type="filepath",
                    visible=False,
                )
                audio_input = gr.Audio(
                    label="Record or Upload Audio",
                    type="filepath",
                    visible=False,
                )
                gr.Examples(
                    examples=[
                        ["Solve x^2 - 5x + 6 = 0"],
                        ["Find the derivative of sin(x) * e^x"],
                        ["Evaluate the integral of x^2 from 0 to 3"],
                        ["If P(A) = 0.6 and P(B|A) = 0.3, find P(A and B)"],
                        ["Find eigenvalues of [[2,1],[1,2]]"],
                        ["Plot y = x^3 - 3x + 1"],
                    ],
                    inputs=[text_input],
                    label="Try an example",
                )
                with gr.Row():
                    solve_btn = gr.Button(
                        "Solve",
                        variant="primary",
                        size="lg",
                        elem_classes="solve-btn",
                        scale=3,
                    )
                    new_btn = gr.Button(
                        "New",
                        size="lg",
                        elem_classes="new-btn",
                        scale=1,
                    )

                # Extracted text display (OCR/ASR result).
                extracted_text = gr.Textbox(
                    label="Extracted Text",
                    interactive=False,
                    lines=2,
                    visible=True,
                    elem_classes="extracted-text",
                )
                confidence_display = gr.Number(
                    label="OCR/ASR Confidence",
                    precision=3,
                    visible=True,
                )

                # ── HITL panel ── (shown only when the pipeline pauses)
                with gr.Group(elem_classes="hitl-panel", visible=False) as hitl_group:
                    gr.Markdown("HUMAN REVIEW REQUIRED", elem_classes="section-label")
                    hitl_textbox = gr.Textbox(
                        label="Review & Edit",
                        lines=3,
                        visible=False,
                        interactive=True,
                    )
                    with gr.Row():
                        hitl_approve = gr.Button(
                            "Approve", visible=False, variant="primary"
                        )
                        # NOTE(review): no click handler is wired to the reject
                        # button anywhere in this file — confirm intended behavior.
                        hitl_reject = gr.Button("Reject", visible=False, variant="stop")

                # ── Feedback ──
                with gr.Group(elem_classes="feedback-section"):
                    gr.Markdown("FEEDBACK", elem_classes="section-label")
                    with gr.Row():
                        correct_btn = gr.Button(
                            "Correct",
                            size="sm",
                            elem_classes="feedback-btn-correct",
                        )
                        incorrect_btn = gr.Button(
                            "Incorrect",
                            size="sm",
                            elem_classes="feedback-btn-incorrect",
                        )
                    feedback_comment = gr.Textbox(label="Comment (optional)", lines=1)
                    feedback_status = gr.Textbox(label="", interactive=False)

            # ── Right: Output panel ──
            with gr.Column(scale=4, elem_classes="output-panel"):
                with gr.Row():
                    status_display = gr.Markdown(
                        "Ready to solve.",
                        elem_classes="status-badge",
                    )

                with gr.Tabs(elem_classes="output-tabs"):
                    with gr.Tab("Solution"):
                        solution_display = gr.Markdown(
                            label="Solution",
                            latex_delimiters=LATEX_DELIMITERS,
                            elem_classes="solution-content",
                        )
                    with gr.Tab("Explanation"):
                        explanation_display = gr.Markdown(
                            label="Explanation",
                            latex_delimiters=LATEX_DELIMITERS,
                            elem_classes="solution-content",
                        )
                    with gr.Tab("Diagram"):
                        diagram_display = gr.Image(label="Diagram", visible=True)
                    with gr.Tab("Agent Trace"):
                        trace_display = gr.Markdown(
                            "*Waiting for agent activity...*",
                            latex_delimiters=LATEX_DELIMITERS,
                            elem_classes="trace-content",
                        )
                    with gr.Tab("Retrieved Context"):
                        context_table = gr.Dataframe(
                            headers=["Source", "Content", "Score"],
                            label="Retrieved Chunks",
                            elem_classes="context-table",
                        )
                    with gr.Tab("Memory"):
                        memory_table = gr.Dataframe(
                            headers=["Question", "Answer", "Similarity"],
                            label="Similar Past Problems",
                            elem_classes="memory-table",
                        )
                    with gr.Tab("Chat History"):
                        # Gradio < 6 needs type="messages" for dict-style history;
                        # in >= 6 the tuples format was removed so it's the default.
                        chatbot_kwargs = {"label": "Conversation", "height": 400}
                        if int(gr.__version__.split(".")[0]) < 6:
                            chatbot_kwargs["type"] = "messages"
                        chat_display = gr.Chatbot(**chatbot_kwargs)

        # ── Event handlers ──

        def toggle_inputs(mode):
            """Show only the input widget matching the selected mode."""
            return (
                gr.update(visible=mode == "Text"),
                gr.update(visible=mode == "Image"),
                gr.update(visible=mode == "Audio"),
            )

        input_mode.change(
            toggle_inputs,
            inputs=[input_mode],
            outputs=[text_input, image_input, audio_input],
        )

        # Solve — output order must match the 18-tuples yielded by solve().
        solve_btn.click(
            solve,
            inputs=[
                text_input,
                image_input,
                audio_input,
                input_mode,
                thread_state,
                chat_state,
            ],
            outputs=[
                status_display,
                pipeline_progress,
                extracted_text,
                confidence_display,
                context_table,
                memory_table,
                solution_display,
                explanation_display,
                diagram_display,
                thread_state,
                chat_state,
                hitl_group,
                hitl_textbox,
                hitl_approve,
                hitl_reject,
                trace_display,
                hitl_type_state,
                chat_display,
            ],
        )

        # New problem — fresh thread id and cleared displays.
        def reset():
            """Return default values for every output component."""
            return (
                new_thread_id(),
                [],
                "Ready to solve.",
                _build_pipeline_html(),
                "",
                0,
                [["—", "No context retrieved", ""]],
                [["—", "No similar past problems", ""]],
                "",
                "",
                None,
                "*Waiting for agent activity...*",
                gr.update(visible=False),
                gr.update(visible=False),
                gr.update(visible=False),
                gr.update(visible=False),
                [],
            )

        new_btn.click(
            reset,
            outputs=[
                thread_state,
                chat_state,
                status_display,
                pipeline_progress,
                extracted_text,
                confidence_display,
                context_table,
                memory_table,
                solution_display,
                explanation_display,
                diagram_display,
                trace_display,
                hitl_group,
                hitl_textbox,
                hitl_approve,
                hitl_reject,
                chat_display,
            ],
        )

        # HITL approve — same output shape as solve.
        hitl_approve.click(
            handle_hitl_approve,
            inputs=[hitl_textbox, thread_state, chat_state, hitl_type_state],
            outputs=[
                status_display,
                pipeline_progress,
                extracted_text,
                confidence_display,
                context_table,
                memory_table,
                solution_display,
                explanation_display,
                diagram_display,
                thread_state,
                chat_state,
                hitl_group,
                hitl_textbox,
                hitl_approve,
                hitl_reject,
                trace_display,
                hitl_type_state,
                chat_display,
            ],
        )

        # Feedback
        correct_btn.click(
            lambda c, t: handle_feedback("correct", c, t),
            inputs=[feedback_comment, thread_state],
            outputs=[feedback_status],
        )
        incorrect_btn.click(
            lambda c, t: handle_feedback("incorrect", c, t),
            inputs=[feedback_comment, thread_state],
            outputs=[feedback_status],
        )

    return demo


# ── Entry point ──────────────────────────────────────────────────────────
if __name__ == "__main__":
    demo = build_ui()
    launch_kw = {"server_name": "0.0.0.0", "server_port": 7860}
    # Gradio >= 6 moved theme/css from Blocks() to launch().
    if int(gr.__version__.split(".")[0]) >= 6:
        launch_kw.update(theme=THEME, css=CSS)
    demo.launch(**launch_kw)