| use std::{convert::Infallible, time::Duration}; |
|
|
| use axum::{ |
| extract::Path, |
| http::StatusCode, |
| response::sse::{Event, KeepAlive, Sse}, |
| Json, |
| }; |
| use futures_util::stream; |
| use serde::{Deserialize, Serialize}; |
| use uuid::Uuid; |
|
|
| use crate::inference::python_bridge::PythonBridge; |
| use crate::inference::response_validator::{ |
| AutonomousAgentRolePayload, AutonomousApprovalPayload, AutonomousCheckpointPayload, |
| AutonomousSessionPayload, AutonomousTaskPayload, AutonomousTimelineEventPayload, |
| }; |
|
|
/// Delay, in milliseconds, between successive status polls in [`stream_status`].
///
/// The first poll of a stream is issued immediately; this delay is applied only
/// before each poll that follows it.
const AUTONOMOUS_STREAM_INTERVAL_MS: u64 = 900;
|
|
/// Per-connection state threaded through the `stream::unfold` loop in
/// [`stream_status`].
struct AutonomousStreamState {
    /// Session whose status is polled on every iteration.
    session_id: String,
    /// Cleared once a terminal status (`"completed"` / `"failed"`) or a bridge
    /// error has been emitted; the following iteration then ends the stream.
    active: bool,
    /// Set only for the first iteration, which polls without sleeping first.
    is_first: bool,
}
|
|
| #[derive(Deserialize)] |
| pub struct StartSessionRequest { |
| pub goal: String, |
| pub max_steps: Option<u32>, |
| pub persona_id: Option<String>, |
| pub session_id: Option<String>, |
| } |
|
|
| #[derive(Serialize)] |
| pub struct Task { |
| pub id: String, |
| pub description: String, |
| pub status: String, |
| pub result: Option<String>, |
| pub created_at: String, |
| pub tool: String, |
| pub depends_on: Vec<String>, |
| pub attempts: u32, |
| pub max_attempts: u32, |
| pub last_error: Option<String>, |
| } |
|
|
/// Response body returned by [`start_session`].
///
/// Every field except `session_id` and `goal` is copied unchanged from the
/// bridge's `AutonomousSessionPayload`. Field declaration order is also the
/// JSON key order produced by `Serialize`, so reordering fields changes the
/// wire format.
#[derive(Serialize)]
pub struct StartSessionResponse {
    /// Id of the started session (caller-supplied or generated).
    pub session_id: String,
    /// The goal from the request, echoed back.
    pub goal: String,
    pub status: String,
    pub tasks: Vec<Task>,
    pub progress_percent: u32,
    pub persona_id: String,
    pub persona_title: String,
    pub persona_summary: String,
    pub events: Vec<TimelineEvent>,
    pub checkpoints: Vec<Checkpoint>,
    pub approvals: Vec<Approval>,
    pub agent_roles: Vec<AgentRole>,
    pub replay_cursor: u32,
    pub resume_token: String,
    pub failover_mode: String,
}
|
|
/// Snapshot of a session, as returned by `get_status` and pushed by
/// `stream_status`.
///
/// Built by [`map_session_status`]: every field except `session_id` is copied
/// unchanged from the bridge's `AutonomousSessionPayload`. Field declaration
/// order is also the JSON key order produced by `Serialize`.
#[derive(Serialize)]
pub struct SessionStatus {
    /// Id the status was requested for.
    pub session_id: String,
    /// `stream_status` treats `"completed"` and `"failed"` as terminal.
    pub status: String,
    pub tasks: Vec<Task>,
    pub progress_percent: u32,
    pub persona_id: String,
    pub persona_title: String,
    pub persona_summary: String,
    pub events: Vec<TimelineEvent>,
    pub checkpoints: Vec<Checkpoint>,
    pub approvals: Vec<Approval>,
    pub agent_roles: Vec<AgentRole>,
    pub replay_cursor: u32,
    pub resume_token: String,
    pub failover_mode: String,
}
|
|
| #[derive(Serialize)] |
| pub struct TimelineEvent { |
| pub id: String, |
| pub event_type: String, |
| pub title: String, |
| pub detail: String, |
| pub agent_role: String, |
| pub level: String, |
| pub created_at: String, |
| pub task_id: Option<String>, |
| pub interruptible: bool, |
| } |
|
|
| #[derive(Serialize)] |
| pub struct Checkpoint { |
| pub id: String, |
| pub label: String, |
| pub status: String, |
| pub summary: String, |
| pub created_at: String, |
| pub task_id: Option<String>, |
| } |
|
|
| #[derive(Serialize)] |
| pub struct Approval { |
| pub id: String, |
| pub kind: String, |
| pub status: String, |
| pub title: String, |
| pub summary: String, |
| pub created_at: String, |
| pub task_id: Option<String>, |
| pub resolution_note: Option<String>, |
| } |
|
|
| #[derive(Serialize)] |
| pub struct AgentRole { |
| pub id: String, |
| pub title: String, |
| pub responsibility: String, |
| pub status: String, |
| } |
|
|
| fn map_task(task: AutonomousTaskPayload) -> Task { |
| Task { |
| id: task.id, |
| description: task.description, |
| status: task.status, |
| result: task.result, |
| created_at: task.created_at, |
| tool: task.tool, |
| depends_on: task.depends_on, |
| attempts: task.attempts, |
| max_attempts: task.max_attempts, |
| last_error: task.last_error, |
| } |
| } |
|
|
| fn map_session_status(session_id: String, result: AutonomousSessionPayload) -> SessionStatus { |
| SessionStatus { |
| session_id, |
| status: result.status, |
| progress_percent: result.progress_percent, |
| tasks: result.tasks.into_iter().map(map_task).collect(), |
| persona_id: result.persona_id, |
| persona_title: result.persona_title, |
| persona_summary: result.persona_summary, |
| events: result.events.into_iter().map(map_event).collect(), |
| checkpoints: result.checkpoints.into_iter().map(map_checkpoint).collect(), |
| approvals: result.approvals.into_iter().map(map_approval).collect(), |
| agent_roles: result.agent_roles.into_iter().map(map_agent_role).collect(), |
| replay_cursor: result.replay_cursor, |
| resume_token: result.resume_token, |
| failover_mode: result.failover_mode, |
| } |
| } |
|
|
| fn map_event(event: AutonomousTimelineEventPayload) -> TimelineEvent { |
| TimelineEvent { |
| id: event.id, |
| event_type: event.event_type, |
| title: event.title, |
| detail: event.detail, |
| agent_role: event.agent_role, |
| level: event.level, |
| created_at: event.created_at, |
| task_id: event.task_id, |
| interruptible: event.interruptible, |
| } |
| } |
|
|
| fn map_checkpoint(checkpoint: AutonomousCheckpointPayload) -> Checkpoint { |
| Checkpoint { |
| id: checkpoint.id, |
| label: checkpoint.label, |
| status: checkpoint.status, |
| summary: checkpoint.summary, |
| created_at: checkpoint.created_at, |
| task_id: checkpoint.task_id, |
| } |
| } |
|
|
| fn map_approval(approval: AutonomousApprovalPayload) -> Approval { |
| Approval { |
| id: approval.id, |
| kind: approval.kind, |
| status: approval.status, |
| title: approval.title, |
| summary: approval.summary, |
| created_at: approval.created_at, |
| task_id: approval.task_id, |
| resolution_note: approval.resolution_note, |
| } |
| } |
|
|
| fn map_agent_role(role: AutonomousAgentRolePayload) -> AgentRole { |
| AgentRole { |
| id: role.id, |
| title: role.title, |
| responsibility: role.responsibility, |
| status: role.status, |
| } |
| } |
|
|
| async fn fetch_status_payload(session_id: &str) -> Result<SessionStatus, String> { |
| let bridge = PythonBridge::new(); |
| let payload = serde_json::json!({ "session_id": session_id }); |
|
|
| bridge |
| .call::<AutonomousSessionPayload>("autonomous/status", &payload) |
| .await |
| .map(|result| map_session_status(session_id.to_string(), result)) |
| .map_err(|err| err.to_string()) |
| } |
|
|
| pub async fn start_session( |
| Json(req): Json<StartSessionRequest>, |
| ) -> Result<Json<StartSessionResponse>, (StatusCode, Json<serde_json::Value>)> { |
| let bridge = PythonBridge::new(); |
| let session_id = req.session_id.unwrap_or_else(|| Uuid::new_v4().to_string()); |
|
|
| let payload = serde_json::json!({ |
| "session_id": session_id, |
| "goal": req.goal, |
| "max_steps": req.max_steps.unwrap_or(10), |
| "persona_id": req.persona_id, |
| }); |
|
|
| match bridge |
| .call::<AutonomousSessionPayload>("autonomous/start", &payload) |
| .await |
| { |
| Ok(result) => { |
| let tasks = result.tasks.into_iter().map(map_task).collect(); |
|
|
| Ok(Json(StartSessionResponse { |
| session_id, |
| goal: req.goal, |
| status: result.status, |
| tasks, |
| progress_percent: result.progress_percent, |
| persona_id: result.persona_id, |
| persona_title: result.persona_title, |
| persona_summary: result.persona_summary, |
| events: result.events.into_iter().map(map_event).collect(), |
| checkpoints: result.checkpoints.into_iter().map(map_checkpoint).collect(), |
| approvals: result.approvals.into_iter().map(map_approval).collect(), |
| agent_roles: result.agent_roles.into_iter().map(map_agent_role).collect(), |
| replay_cursor: result.replay_cursor, |
| resume_token: result.resume_token, |
| failover_mode: result.failover_mode, |
| })) |
| } |
| Err(e) => Err(( |
| StatusCode::INTERNAL_SERVER_ERROR, |
| Json(serde_json::json!({ "error": e.to_string() })), |
| )), |
| } |
| } |
|
|
| pub async fn get_status( |
| Path(session_id): Path<String>, |
| ) -> Result<Json<SessionStatus>, (StatusCode, Json<serde_json::Value>)> { |
| match fetch_status_payload(&session_id).await { |
| Ok(result) => Ok(Json(result)), |
| Err(e) => Err(( |
| StatusCode::INTERNAL_SERVER_ERROR, |
| Json(serde_json::json!({ "error": e.to_string() })), |
| )), |
| } |
| } |
|
|
| pub async fn stream_status( |
| Path(session_id): Path<String>, |
| ) -> Sse<impl futures_util::Stream<Item = Result<Event, Infallible>>> { |
| let stream = stream::unfold( |
| AutonomousStreamState { |
| session_id, |
| active: true, |
| is_first: true, |
| }, |
| |state| async move { |
| if !state.active { |
| return None; |
| } |
|
|
| if !state.is_first { |
| tokio::time::sleep(Duration::from_millis(AUTONOMOUS_STREAM_INTERVAL_MS)).await; |
| } |
|
|
| let next_state = match fetch_status_payload(&state.session_id).await { |
| Ok(status) => { |
| let is_terminal = matches!(status.status.as_str(), "completed" | "failed"); |
| let event = Event::default() |
| .data(serde_json::to_string(&status).unwrap_or_else(|_| "{}".to_string())); |
| Some(( |
| Ok(event), |
| AutonomousStreamState { |
| session_id: state.session_id, |
| active: !is_terminal, |
| is_first: false, |
| }, |
| )) |
| } |
| Err(error) => Some(( |
| Ok(Event::default().event("maris_error").data( |
| serde_json::json!({ |
| "error": error, |
| }) |
| .to_string(), |
| )), |
| AutonomousStreamState { |
| session_id: state.session_id, |
| active: false, |
| is_first: false, |
| }, |
| )), |
| }; |
|
|
| next_state |
| }, |
| ); |
|
|
| Sse::new(stream).keep_alive( |
| KeepAlive::new() |
| .interval(Duration::from_secs(10)) |
| .text("keepalive"), |
| ) |
| } |
|
|