| |
| """Minimal Streamlit interface for the EchoPilot ReAct agent.""" |
|
|
from __future__ import annotations

import shutil
import tempfile
import uuid
from pathlib import Path
from typing import Dict, List, Optional, Tuple

import streamlit as st

from agents import get_intelligent_agent
from config import Config
from utils.video_utils import convert_video_to_h264
|
|
# Absolute directory that contains this script (symlinks resolved).
PROJECT_ROOT = Path(__file__).resolve().parents[0]
|
|
|
|
@st.cache_resource(show_spinner=False)
def load_agent():
    """Build the EchoPilot agent on ``Config.DEVICE``.

    Decorated with ``st.cache_resource``, so the agent is constructed once and
    the same instance is returned on later script reruns instead of being
    rebuilt on every click.
    """
    agent_class, _unused = get_intelligent_agent()
    target_device = Config.DEVICE
    return agent_class(device=target_device)
|
|
|
|
def _persist_upload(upload) -> Tuple[Path, Path]:
    """Write an uploaded file into a fresh temporary directory.

    Args:
        upload: An upload object exposing ``name`` (may be empty/None) and
            ``getbuffer()`` returning the file contents.

    Returns:
        ``(video_path, temp_dir)``: the written file and the directory that
        contains it. The caller owns ``temp_dir`` and is responsible for
        deleting it.

    Raises:
        Whatever ``getbuffer()`` or the write raises; the temporary directory
        is removed first so nothing is leaked.
    """
    # Only the extension of the client-supplied name is reused; the name
    # itself never becomes part of the on-disk path.
    suffix = Path(upload.name or "input.mp4").suffix or ".mp4"
    temp_dir = Path(tempfile.mkdtemp(prefix="echopilot_"))
    video_path = temp_dir / f"input{suffix}"
    try:
        video_path.write_bytes(upload.getbuffer())
    except BaseException:
        # Clean up the directory we just created, then propagate the error.
        shutil.rmtree(temp_dir, ignore_errors=True)
        raise
    return video_path, temp_dir
|
|
|
|
# Headline metrics shown next to the video, in display order. Each entry is
# (candidate result keys, label, decimal places). The first candidate key with
# a usable value wins, so "EF" is still tried when "ejection_fraction" exists
# but is unusable.
_KEY_METRICS: Tuple[Tuple[Tuple[str, ...], str, int], ...] = (
    (("ejection_fraction", "EF"), "Ejection Fraction", 1),
    (("pulmonary_artery_pressure_continuous",), "Pulmonary Artery Pressure", 1),
    # NOTE(review): the key name suggests a dilated/not-dilated flag while the
    # label says "diameter" - confirm what the measurement tool returns here.
    (("dilated_ivc",), "IVC Diameter", 2),
)


def _format_metric_value(info: object, precision: int) -> Optional[str]:
    """Format one measurement record as ``"<value> <unit>"``.

    ``info`` is expected to be a dict with a ``"value"`` and an optional
    ``"unit"``. Values convertible to float are rendered with ``precision``
    decimals; anything else falls back to ``str()``. Returns None when
    ``info`` is not a dict or carries no value.
    """
    if not isinstance(info, dict):
        return None
    value = info.get("value")
    if value is None:
        return None
    try:
        value_str = f"{float(value):.{precision}f}"
    except (TypeError, ValueError):
        value_str = str(value)
    unit = info.get("unit")
    unit_str = "" if unit is None else str(unit).strip()
    # Separate value and unit with a space; omit it entirely when unitless.
    return f"{value_str} {unit_str}" if unit_str else value_str


def _extract_key_metrics(response) -> List[Tuple[str, str]]:
    """Pull headline measurements out of an agent response for display.

    Reads ``response.execution_result.results["tool_results"]
    ["echo_measurement_prediction"]`` and, when its ``status`` is
    ``"success"``, formats the metrics listed in ``_KEY_METRICS`` from the
    first entry's ``"measurements"`` dict.

    Returns:
        A list of ``(label, formatted_value)`` pairs; empty when the
        measurement tool did not run, failed, or returned an unexpected shape.
    """
    # Missing or malformed pieces simply mean "no metrics to show".
    execution_result = getattr(response, "execution_result", None)
    results = getattr(execution_result, "results", None) or {}
    if not isinstance(results, dict):
        return []
    tool_results = results.get("tool_results") or {}
    if not isinstance(tool_results, dict):
        return []
    measurement = tool_results.get("echo_measurement_prediction")
    if not isinstance(measurement, dict) or measurement.get("status") != "success":
        return []

    entries = measurement.get("measurements") or []
    if not isinstance(entries, (list, tuple)) or not entries:
        return []
    first_entry = entries[0]
    if not isinstance(first_entry, dict):
        return []
    data = first_entry.get("measurements") or {}
    if not isinstance(data, dict):
        return []

    metrics: List[Tuple[str, str]] = []
    for candidate_keys, label, precision in _KEY_METRICS:
        for key in candidate_keys:
            formatted = _format_metric_value(data.get(key), precision)
            if formatted is not None:
                metrics.append((label, formatted))
                break
    return metrics
|
|
|
|
def _run_analysis(uploaded_video, question: str) -> Tuple[object, Optional[Path]]:
    """Run the agent on an uploaded study and prepare a browser-playable copy.

    Args:
        uploaded_video: The Streamlit upload to analyze.
        question: The already-stripped clinical question.

    Returns:
        ``(response, display_video)``: whatever ``agent.process_query``
        returned, and the path of the H.264 display copy (None if the
        converter returned nothing).
    """
    video_path, temp_dir = _persist_upload(uploaded_video)
    try:
        # Agent loading and transcoding also run inside the spinner, so the
        # user sees feedback for the whole wait rather than only the query.
        with st.spinner("EchoPilot is analyzing the study..."):
            agent = load_agent()
            display_dir = PROJECT_ROOT / "temp"
            display_dir.mkdir(parents=True, exist_ok=True)
            # NOTE(review): display copies are never removed and accumulate
            # across runs; consider a cleanup policy.
            display_target = display_dir / f"display_{uuid.uuid4().hex}.mp4"
            converted = convert_video_to_h264(str(video_path), str(display_target))
            display_video = Path(converted) if converted else None
            response = agent.process_query(question, str(video_path))
    finally:
        # Always drop the uploaded copy, even if conversion or analysis raised.
        # rmtree also copes with non-file entries, which unlink()+rmdir() did not.
        shutil.rmtree(temp_dir, ignore_errors=True)
    return response, display_video


def _render_results(response, display_video: Optional[Path], question: str) -> None:
    """Show the display video, headline measurements, and the agent's answer."""
    st.success("Analysis complete")
    metrics = _extract_key_metrics(response)

    video_col, metrics_col = st.container().columns([2, 1])
    if display_video is not None and display_video.exists():
        with video_col:
            st.video(str(display_video))
    if metrics:
        with metrics_col:
            st.markdown("#### Key Measurements")
            for label, value in metrics:
                st.metric(label, value)

    st.divider()
    st.markdown("#### EchoPilot Response")
    st.chat_message("user").write(question)
    st.chat_message("assistant").write(response.response_text)


def main() -> None:
    """Render the EchoPilot page and, when requested, run and display an analysis."""
    st.set_page_config(page_title="EchoPilot Agent", page_icon="🫀", layout="wide")
    st.title("EchoPilot · Echocardiography Co-Pilot")
    st.caption("Upload a study, ask a focused question, and EchoPilot will run the appropriate tools to answer.")

    upload_col, info_col = st.columns([2, 1])
    with upload_col:
        uploaded_video = st.file_uploader(
            "Echo video file",
            type=["mp4", "mov", "m4v", "avi", "wmv"],
            help="Standard ultrasound formats are supported.",
        )
        default_question = "Estimate the ejection fraction and note any major abnormalities."
        query = st.text_area("Clinical question", value=default_question, height=120)
    with info_col:
        st.markdown("### How it works")
        st.write(
            "- EchoPilot uses a ReAct loop to decide which tools to call.\n"
            "- It may segment chambers, compute EchoPrime measurements, or run disease classifiers.\n"
            "- Results are summarized below; raw tool logs are hidden for clarity."
        )

    question = query.strip()
    run_clicked = st.button(
        "Run Analysis",
        type="primary",
        use_container_width=True,
        disabled=not uploaded_video or not question,
    )
    if not run_clicked:
        return

    response, display_video = _run_analysis(uploaded_video, question)
    if response is not None:
        _render_results(response, display_video, question)
|
|
|
|
# Script entry point; `streamlit run` executes this file as __main__, so main()
# renders the page on every rerun.
if __name__ == "__main__":
    main()
|
|