| """ |
| OncoAgent — Conversational Clinical Copilot (UI Module). |
| |
| Provides a ChatGPT-style Gradio interface for the multi-agent |
| oncological triage system. Uses LangGraph streaming to show |
| real-time agent progress and prevent UI freezing. |
| """ |
|
|
| import os |
| import time |
| import random |
| import logging |
| import gradio as gr |
| from typing import Dict, Any, List, Tuple, Optional, Generator |
| from dotenv import load_dotenv |
|
|
| from agents.graph import build_oncoagent_graph |
| from ui.styles import CSS, FONTS_LINK |
|
|
# Load variables from a local .env file (if one exists) into the process
# environment before anything below is constructed. Presumably the agent
# graph built further down reads its configuration from there — confirm in
# agents.graph.
load_dotenv()
# Module-scoped logger; used below to record graph-streaming failures.
logger = logging.getLogger(__name__)
|
|
| |
| |
| |
# Inline 14x14 stroke-only SVG icons embedded in the decision-status badge:
# "check" is drawn in green (#10b981) for clinically safe output, "alert" in
# red (#f87171) when review is required.
ICONS: Dict[str, str] = dict(
    check='<svg xmlns="http://www.w3.org/2000/svg" width="14" height="14" viewBox="0 0 24 24" fill="none" stroke="#10b981" stroke-width="2.5" stroke-linecap="round" stroke-linejoin="round"><polyline points="20 6 9 17 4 12"/></svg>',
    alert='<svg xmlns="http://www.w3.org/2000/svg" width="14" height="14" viewBox="0 0 24 24" fill="none" stroke="#f87171" stroke-width="2.5" stroke-linecap="round" stroke-linejoin="round"><circle cx="12" cy="12" r="10"/><line x1="15" y1="9" x2="9" y2="15"/><line x1="9" y1="9" x2="15" y2="15"/></svg>',
)
|
|
| |
# Human-readable progress captions keyed by LangGraph node name; shown in the
# chat and status bar while each node runs. Callers look names up with
# ``.get(name, name)``, so an unlisted node is shown by its raw name.
NODE_LABELS: Dict[str, str] = dict(
    router="Routing case",
    ingestion="Extracting clinical entities",
    corrective_rag="Retrieving NCCN/ESMO guidelines",
    specialist="Generating clinical recommendation",
    critic="Validating medical safety",
    hitl_gate="Assessing acuity level",
    formatter="Formatting final report",
    fallback="Generating safe fallback response",
)
|
|
| |
| |
| |
# Build the multi-agent graph once, at import time; every UI request streams
# through this single shared object (see stream_triage). Sessions are told
# apart only by the "thread_id" passed in each run's config — presumably the
# graph is compiled with a checkpointer keyed on it; confirm in agents.graph.
agent_graph = build_oncoagent_graph()
|
|
|
|
def generate_patient_id() -> str:
    """Generate a random patient session identifier.

    The value is used as the LangGraph ``thread_id`` that isolates one
    session's conversation memory from another's, so accidental collisions
    must be negligible. 32 random bits keep that true for any realistic
    number of concurrent sessions, whereas a 4-digit suffix (9,000 values)
    collides with high probability after about a hundred sessions.

    Returns:
        A string of the form ``"PT-XXXXXXXX"`` where each ``X`` is an
        uppercase hexadecimal digit.
    """
    import secrets

    return f"PT-{secrets.token_hex(4).upper()}"
|
|
|
|
| |
| |
| |
def stream_triage(
    clinical_text: str,
    patient_id: str,
    tier_override: str,
) -> Generator[Tuple[str, str, Dict[str, Any]], None, None]:
    """Stream a clinical case through the LangGraph nodes, yielding progress.

    Args:
        clinical_text: Raw clinical notes from the user. ``None`` or
            whitespace-only input short-circuits with a prompt message.
        patient_id: Session identifier, passed as the LangGraph
            ``thread_id`` for memory isolation. Blank or ``None`` falls back
            to ``"PT-UNKNOWN"``.
        tier_override: Model tier selection (``"auto"`` / ``"9b"`` /
            ``"27b"``). ``"auto"`` is sent to the graph as no override.

    Yields:
        ``(node_name, progress_markdown, partial_state)`` tuples:

        * one per node update emitted by the graph, carrying that node's
          raw output;
        * ``("error", "Error: <message>", {})`` if streaming raises, after
          which the generator stops;
        * otherwise, finally ``("done", "Complete", accumulated_state)``
          where ``accumulated_state`` merges every dict-valued node output
          in arrival order.
    """
    # `or ""` keeps a None from the UI from raising AttributeError.
    if not (clinical_text or "").strip():
        yield ("done", "Please enter a clinical case.", {})
        return
    if not (patient_id or "").strip():
        patient_id = "PT-UNKNOWN"

    input_state: Dict[str, Any] = {
        "clinical_text": clinical_text,
        "messages": [("user", clinical_text)],
        # None means "no manual override": the graph chooses the tier itself.
        "manual_override": tier_override if tier_override != "auto" else None,
        "errors": [],
    }
    # The thread_id is what separates one session's memory from another's.
    config: Dict[str, Any] = {"configurable": {"thread_id": patient_id}}

    accumulated_state: Dict[str, Any] = {}
    try:
        # In "updates" mode each event maps node name -> that node's output.
        for event in agent_graph.stream(
            input_state, config=config, stream_mode="updates"
        ):
            for node_name, node_output in event.items():
                label = NODE_LABELS.get(node_name, node_name)
                yield (node_name, f"**{label}**...", node_output)
                if isinstance(node_output, dict):
                    accumulated_state.update(node_output)
    except Exception as e:  # UI boundary: report the failure, don't crash the stream
        logger.exception("Graph streaming error: %s", e)
        yield ("error", f"Error: {e}", {})
        return

    yield ("done", "Complete", accumulated_state)
|
|
|
|
def format_final_response(state: Dict[str, Any]) -> str:
    """Render the accumulated graph state as the assistant's chat reply.

    Args:
        state: Merged node outputs. Reads ``formatted_recommendation``
            (preferred) or ``clinical_recommendation``, plus
            ``safety_status``, ``is_safe`` and ``critic_feedback``; every key
            is optional.

    Returns:
        Markdown with embedded HTML: a safe/unsafe badge header, the
        recommendation, a safety-audit line and, when there is critic
        feedback, a card listing each item.
    """
    import html

    # `or` (not a .get default) also skips keys that are present but
    # None/empty, which would otherwise be rendered as the literal "None".
    recommendation: str = (
        state.get("formatted_recommendation")
        or state.get("clinical_recommendation")
        or "No recommendation generated."
    )
    safety_status: str = state.get("safety_status") or "Unknown"
    is_safe: bool = bool(state.get("is_safe", False))
    critic_feedback = state.get("critic_feedback") or []

    if is_safe:
        badge = f"<span class='badge-safe'>{ICONS['check']} Clinically Safe</span>"
    else:
        badge = f"<span class='badge-unsafe'>{ICONS['alert']} Review Required</span>"

    md = f"### Decision Status: {badge}\n\n"
    md += f"{recommendation}\n\n---\n"
    md += f"**Safety Audit:** {safety_status}\n"

    if critic_feedback:
        if isinstance(critic_feedback, (list, tuple)):
            items = list(critic_feedback)
        else:
            items = [critic_feedback]
        # The items go inside a raw HTML <div>, so escape them: text such as
        # "<5 mm" or "A & B" must display literally, not be parsed as markup.
        md += "\n<div class='critic-card'><strong>Critic Iterations:</strong><br/>"
        md += "<br/>".join(
            f"β {html.escape(str(fb), quote=False)}" for fb in items
        )
        md += "</div>"

    return md
|
|
|
|
def extract_evidence(state: Dict[str, Any]) -> Tuple[str, str, str]:
    """Build the three evidence-panel Markdown strings from the graph state.

    Args:
        state: Merged node outputs. Reads ``rag_sources``,
            ``graph_rag_context`` and ``api_evidence_context``; a missing or
            empty (falsy) entry yields that panel's placeholder sentence.

    Returns:
        ``(guidelines_md, graph_md, api_md)``. Guideline sources are joined
        one per line verbatim; graph and API items become ``- `` bullets.
    """

    def _panel(heading: str, entries: List[str], fallback: str, bullet: str) -> str:
        # A heading plus one line per entry, or the fallback when empty.
        if not entries:
            return fallback
        if bullet:
            lines = "\n".join(f"{bullet}{entry}" for entry in entries)
        else:
            lines = "\n".join(entries)
        return f"{heading}\n\n{lines}"

    guidelines_md = _panel(
        "### Medical Guidelines (NCCN / ESMO)",
        state.get("rag_sources", []),
        "No guideline sources retrieved.",
        bullet="",
    )
    graph_md = _panel(
        "### Clinical Knowledge Graph",
        state.get("graph_rag_context", []),
        "No graph relations extracted.",
        bullet="- ",
    )
    api_md = _panel(
        "### Real-Time Evidence (CIViC & ClinicalTrials)",
        state.get("api_evidence_context", []),
        "No real-time API evidence found.",
        bullet="- ",
    )
    return guidelines_md, graph_md, api_md
|
|
|
|
| |
| |
| |
# Visual theme: Gradio's "Soft" preset with sky-blue primary accents, slate
# secondary/neutral palettes, and the Inter Google font backed by system
# fallbacks. It is passed to demo.launch() in the __main__ guard below.
_FONT_STACK = [gr.themes.GoogleFont("Inter"), "system-ui", "sans-serif"]
theme = gr.themes.Soft(
    neutral_hue="slate",
    secondary_hue="slate",
    primary_hue="sky",
    font=_FONT_STACK,
)
|
|
| |
| |
| |
def _kpi_tile_html(label: str, value: str, elem_id: str) -> str:
    """Return the markup for one sidebar KPI tile (caption above value)."""
    return (
        f"<div class='kpi-label'>{label}</div>"
        f"<div class='kpi-value' id='{elem_id}'>{value}</div>"
    )


def _status_bar(inner_html: str) -> str:
    """Wrap already-safe HTML in the status-bar container."""
    return f"<div class='status-bar'>{inner_html}</div>"


# Idle state of the two KPI tiles (placeholder value, no run yet).
_KPI_PLACEHOLDER = "β"
_CONFIDENCE_IDLE = _kpi_tile_html("Confidence", _KPI_PLACEHOLDER, "kpi-confidence")
_SOURCES_IDLE = _kpi_tile_html("Sources", _KPI_PLACEHOLDER, "kpi-sources")


with gr.Blocks(title="OncoAgent β Clinical Triage") as demo:

    # ---- Header -------------------------------------------------------------
    gr.HTML(FONTS_LINK)
    gr.HTML(
        "<div class='header-bar'>"
        "<span class='brand-name'>OncoAgent</span>"
        "<span class='hw-badge'>AMD Instinct MI300X</span>"
        "</div>"
    )

    with gr.Row():
        # ---- Sidebar: session controls, KPIs, evidence, status ---------------
        with gr.Column(scale=1, min_width=280, elem_classes="sidebar-column"):

            with gr.Column(elem_classes="card"):
                gr.HTML("<div class='section-title'>Session</div>")
                patient_id_input = gr.Textbox(
                    label="Patient ID",
                    # A callable default is re-evaluated on each page load,
                    # so every visitor starts with a fresh session ID.
                    value=generate_patient_id,
                    interactive=True,
                    info="Unique session for memory isolation",
                )
                tier_override_input = gr.Dropdown(
                    label="Model Tier",
                    choices=["auto", "9b", "27b"],
                    value="auto",
                    info="Auto-routes based on case complexity",
                )
                new_session_btn = gr.Button("β» New Session", variant="secondary", size="sm")

            # KPI tiles. The visible HTML tiles are themselves the event
            # outputs, so the computed values actually reach the screen (an
            # invisible component would never display them).
            with gr.Row():
                with gr.Column(elem_classes="kpi-tile", min_width=100):
                    confidence_val = gr.HTML(_CONFIDENCE_IDLE)
                with gr.Column(elem_classes="kpi-tile", min_width=100):
                    sources_val = gr.HTML(_SOURCES_IDLE)

            # Evidence panels, filled from the final graph state.
            with gr.Tabs(elem_classes="card"):
                with gr.Tab("Guidelines"):
                    output_sources = gr.Markdown(
                        "NCCN and ESMO guideline evidence will appear here."
                    )
                with gr.Tab("Knowledge Graph"):
                    output_graph = gr.Markdown(
                        "Knowledge graph connections will appear here."
                    )
                with gr.Tab("API Evidence"):
                    output_api = gr.Markdown(
                        "Real-time data from CIViC and ClinicalTrials.gov."
                    )

            with gr.Column(elem_classes="card"):
                gr.HTML("<div class='section-title'>System Status</div>")
                status_box = gr.Markdown(
                    _status_bar("System ready."),
                    elem_id="status-box",
                )

        # ---- Main chat column ---------------------------------------------------
        with gr.Column(scale=3):
            with gr.Column(elem_classes="card", min_width=600):
                chatbot = gr.Chatbot(
                    label="OncoAgent",
                    show_label=False,
                    elem_classes="gr-chatbot",
                    height=620,
                )
                case_input = gr.Textbox(
                    placeholder="Describe the clinical case or ask a follow-up question...",
                    show_label=False,
                    container=False,
                    submit_btn="β",
                    elem_classes="chat-input-integrated"
                )

    def process_and_stream(
        history: List[Dict[str, str]], text: str, pid: str, tier: str,
    ):
        """Run one triage turn and stream UI updates step by step.

        Every yield is an 8-tuple matching ``outputs`` below: (chat history,
        cleared input box, confidence tile HTML, sources tile HTML,
        guidelines markdown, graph markdown, API markdown, status-bar HTML).
        """
        import html

        if not (text or "").strip():
            yield (
                history, "", _CONFIDENCE_IDLE, _SOURCES_IDLE, "", "", "",
                _status_bar("System ready."),
            )
            return

        # Build a new list so the history object Gradio passed in is left as-is.
        history = (history or []) + [
            {"role": "user", "content": text},
            {"role": "assistant", "content": ""},
        ]

        # Placeholder text shown in the evidence tabs while the graph runs.
        pending_evidence = (
            "Retrieving NCCN/ESMO guidelines...",
            "Building knowledge graph...",
            "Querying real-time evidence...",
        )
        yield (
            history, "", _CONFIDENCE_IDLE, _SOURCES_IDLE, *pending_evidence,
            _status_bar("Processing triage via LangGraph..."),
        )

        accumulated: Dict[str, Any] = {}
        for node_name, progress, node_output in stream_triage(text, pid, tier):
            if isinstance(node_output, dict):
                accumulated.update(node_output)

            if node_name == "done":
                break
            if node_name == "error":
                # stream_triage already prefixes its message with "Error: ";
                # drop it so the chat doesn't read "Error: Error: ...".
                prefix = "Error: "
                detail = progress[len(prefix):] if progress.startswith(prefix) else progress
                history[-1]["content"] = f"**Error:** {detail}"
                yield (
                    history, "", _CONFIDENCE_IDLE, _SOURCES_IDLE, "", "", "",
                    _status_bar(html.escape(progress)),
                )
                return

            label = NODE_LABELS.get(node_name, node_name)
            history[-1]["content"] = f"*Processing: {label}...*"
            yield (
                history, "", _CONFIDENCE_IDLE, _SOURCES_IDLE, *pending_evidence,
                _status_bar(f"<span class='node-step active'>{html.escape(label)}</span>"),
            )

        history[-1]["content"] = format_final_response(accumulated)
        sources_md, graph_md, api_md = extract_evidence(accumulated)

        # `or` guards against keys that are present but hold None.
        conf = accumulated.get("rag_confidence") or 0.0
        src_count = len(accumulated.get("rag_sources") or [])
        conf_text = f"{conf * 100:.1f}%" if conf else _KPI_PLACEHOLDER
        src_text = str(src_count) if src_count else _KPI_PLACEHOLDER

        yield (
            history,
            "",
            _kpi_tile_html("Confidence", conf_text, "kpi-confidence"),
            _kpi_tile_html("Sources", src_text, "kpi-sources"),
            sources_md,
            graph_md,
            api_md,
            _status_bar(f"Triage completed for {html.escape(str(pid))}"),
        )

    outputs = [
        chatbot, case_input, confidence_val, sources_val,
        output_sources, output_graph, output_api, status_box,
    ]
    inputs = [chatbot, case_input, patient_id_input, tier_override_input]

    case_input.submit(fn=process_and_stream, inputs=inputs, outputs=outputs)

    def _reset_session():
        """Clear chat and evidence, reset the tier, and issue a fresh patient ID."""
        return (
            [], "", generate_patient_id(), "auto",
            _CONFIDENCE_IDLE, _SOURCES_IDLE,
            "", "", "",
            _status_bar("System ready."),
        )

    new_session_btn.click(
        _reset_session,
        outputs=[
            chatbot, case_input, patient_id_input, tier_override_input,
            confidence_val, sources_val,
            output_sources, output_graph, output_api, status_box,
        ],
    )
|
|
| |
| |
| |
if __name__ == "__main__":
    # Bind to all interfaces on port 7860 by default. Gradio's standard
    # GRADIO_SERVER_NAME / GRADIO_SERVER_PORT environment variables, including
    # values from the .env file loaded at import time, can override the
    # defaults; hard-coding them here would silently ignore those variables.
    demo.launch(
        server_name=os.getenv("GRADIO_SERVER_NAME", "0.0.0.0"),
        server_port=int(os.getenv("GRADIO_SERVER_PORT", "7860")),
        share=False,
        theme=theme,
        css=CSS,
    )
|
|