// File: maris-ai-master/backend-rust/src/api/autonomous.rs
// Repository page header: author avatar "MarisUK"; commit message "Maris AI model sync";
// commit f440f03 (verified).
use std::{convert::Infallible, time::Duration};
use axum::{
extract::Path,
http::StatusCode,
response::sse::{Event, KeepAlive, Sse},
Json,
};
use futures_util::stream;
use serde::{Deserialize, Serialize};
use uuid::Uuid;
use crate::inference::python_bridge::PythonBridge;
use crate::inference::response_validator::{
AutonomousAgentRolePayload, AutonomousApprovalPayload, AutonomousCheckpointPayload,
AutonomousSessionPayload, AutonomousTaskPayload, AutonomousTimelineEventPayload,
};
/// Delay, in milliseconds, that `stream_status` sleeps before every status poll
/// except the first one of a stream.
const AUTONOMOUS_STREAM_INTERVAL_MS: u64 = 900;
/// State carried from one iteration to the next by the `stream::unfold` loop in
/// `stream_status`.
struct AutonomousStreamState {
/// Id of the session whose status is fetched on every iteration.
session_id: String,
/// Set to `false` after a terminal status or a fetch error has been emitted; the next
/// iteration then ends the stream.
active: bool,
/// `true` only for the initial state, so the very first poll is not preceded by a sleep.
is_first: bool,
}
/// JSON request body accepted by `start_session`.
#[derive(Deserialize)]
pub struct StartSessionRequest {
/// Goal text; forwarded to the bridge under the `"goal"` key and echoed back in the response.
pub goal: String,
/// Optional step limit; `start_session` sends 10 to the bridge when this is absent.
pub max_steps: Option<u32>,
/// Optional persona id; forwarded to the bridge unchanged (`null` when absent).
pub persona_id: Option<String>,
/// Optional caller-chosen session id; `start_session` generates a v4 UUID when absent.
pub session_id: Option<String>,
}
/// One task of an autonomous session as exposed in API responses.
///
/// Built by `map_task`, which copies every field unchanged from `AutonomousTaskPayload`.
#[derive(Serialize)]
pub struct Task {
pub id: String,
pub description: String,
// NOTE(review): free-form string supplied by the bridge; the set of possible values is
// not visible in this file.
pub status: String,
pub result: Option<String>,
// NOTE(review): timestamp format is whatever the bridge emits — confirm before parsing.
pub created_at: String,
pub tool: String,
// NOTE(review): presumably ids of tasks this one depends on — confirm against the bridge.
pub depends_on: Vec<String>,
pub attempts: u32,
pub max_attempts: u32,
pub last_error: Option<String>,
}
/// JSON response body returned by `start_session`.
///
/// `session_id` is the id that was sent to the bridge and `goal` echoes the request; every
/// other field is taken from the bridge's `AutonomousSessionPayload` (collections are
/// converted element by element with the `map_*` helpers).
#[derive(Serialize)]
pub struct StartSessionResponse {
pub session_id: String,
pub goal: String,
pub status: String,
pub tasks: Vec<Task>,
pub progress_percent: u32,
pub persona_id: String,
pub persona_title: String,
pub persona_summary: String,
pub events: Vec<TimelineEvent>,
pub checkpoints: Vec<Checkpoint>,
pub approvals: Vec<Approval>,
pub agent_roles: Vec<AgentRole>,
// NOTE(review): replay/resume/failover semantics are defined by the bridge and are not
// visible in this file; the values are passed through untouched.
pub replay_cursor: u32,
pub resume_token: String,
pub failover_mode: String,
}
/// Snapshot of a session returned by `get_status` and serialized as the data of each
/// status event emitted by `stream_status`.
///
/// Built by `map_session_status` from the bridge's `AutonomousSessionPayload`.
#[derive(Serialize)]
pub struct SessionStatus {
/// Id the snapshot was requested for (supplied by the caller, not read from the payload).
pub session_id: String,
/// Status string from the bridge; `stream_status` stops polling once it is
/// `"completed"` or `"failed"`.
pub status: String,
pub tasks: Vec<Task>,
pub progress_percent: u32,
pub persona_id: String,
pub persona_title: String,
pub persona_summary: String,
pub events: Vec<TimelineEvent>,
pub checkpoints: Vec<Checkpoint>,
pub approvals: Vec<Approval>,
pub agent_roles: Vec<AgentRole>,
// NOTE(review): replay/resume/failover semantics are defined by the bridge and are not
// visible in this file; the values are passed through untouched.
pub replay_cursor: u32,
pub resume_token: String,
pub failover_mode: String,
}
/// One entry of a session's event list as exposed in API responses.
///
/// Built by `map_event`, which copies every field unchanged from
/// `AutonomousTimelineEventPayload`.
#[derive(Serialize)]
pub struct TimelineEvent {
pub id: String,
pub event_type: String,
pub title: String,
pub detail: String,
pub agent_role: String,
// NOTE(review): presumably a severity such as info/warning/error — confirm against the bridge.
pub level: String,
pub created_at: String,
/// Id of the related task, when the payload provides one.
pub task_id: Option<String>,
pub interruptible: bool,
}
/// One checkpoint of a session as exposed in API responses.
///
/// Built by `map_checkpoint`, which copies every field unchanged from
/// `AutonomousCheckpointPayload`.
#[derive(Serialize)]
pub struct Checkpoint {
pub id: String,
pub label: String,
pub status: String,
pub summary: String,
pub created_at: String,
/// Id of the related task, when the payload provides one.
pub task_id: Option<String>,
}
/// One approval entry of a session as exposed in API responses.
///
/// Built by `map_approval`, which copies every field unchanged from
/// `AutonomousApprovalPayload`.
#[derive(Serialize)]
pub struct Approval {
pub id: String,
pub kind: String,
pub status: String,
pub title: String,
pub summary: String,
pub created_at: String,
/// Id of the related task, when the payload provides one.
pub task_id: Option<String>,
// NOTE(review): presumably filled in once the approval is resolved — confirm against the bridge.
pub resolution_note: Option<String>,
}
/// One agent role of a session as exposed in API responses.
///
/// Built by `map_agent_role`, which copies every field unchanged from
/// `AutonomousAgentRolePayload`.
#[derive(Serialize)]
pub struct AgentRole {
pub id: String,
pub title: String,
pub responsibility: String,
pub status: String,
}
fn map_task(task: AutonomousTaskPayload) -> Task {
Task {
id: task.id,
description: task.description,
status: task.status,
result: task.result,
created_at: task.created_at,
tool: task.tool,
depends_on: task.depends_on,
attempts: task.attempts,
max_attempts: task.max_attempts,
last_error: task.last_error,
}
}
fn map_session_status(session_id: String, result: AutonomousSessionPayload) -> SessionStatus {
SessionStatus {
session_id,
status: result.status,
progress_percent: result.progress_percent,
tasks: result.tasks.into_iter().map(map_task).collect(),
persona_id: result.persona_id,
persona_title: result.persona_title,
persona_summary: result.persona_summary,
events: result.events.into_iter().map(map_event).collect(),
checkpoints: result.checkpoints.into_iter().map(map_checkpoint).collect(),
approvals: result.approvals.into_iter().map(map_approval).collect(),
agent_roles: result.agent_roles.into_iter().map(map_agent_role).collect(),
replay_cursor: result.replay_cursor,
resume_token: result.resume_token,
failover_mode: result.failover_mode,
}
}
fn map_event(event: AutonomousTimelineEventPayload) -> TimelineEvent {
TimelineEvent {
id: event.id,
event_type: event.event_type,
title: event.title,
detail: event.detail,
agent_role: event.agent_role,
level: event.level,
created_at: event.created_at,
task_id: event.task_id,
interruptible: event.interruptible,
}
}
fn map_checkpoint(checkpoint: AutonomousCheckpointPayload) -> Checkpoint {
Checkpoint {
id: checkpoint.id,
label: checkpoint.label,
status: checkpoint.status,
summary: checkpoint.summary,
created_at: checkpoint.created_at,
task_id: checkpoint.task_id,
}
}
fn map_approval(approval: AutonomousApprovalPayload) -> Approval {
Approval {
id: approval.id,
kind: approval.kind,
status: approval.status,
title: approval.title,
summary: approval.summary,
created_at: approval.created_at,
task_id: approval.task_id,
resolution_note: approval.resolution_note,
}
}
fn map_agent_role(role: AutonomousAgentRolePayload) -> AgentRole {
AgentRole {
id: role.id,
title: role.title,
responsibility: role.responsibility,
status: role.status,
}
}
/// Fetches the current status of `session_id` from the Python bridge.
///
/// Calls the `autonomous/status` bridge endpoint with `{"session_id": ...}` and converts
/// the reply with `map_session_status`.
///
/// # Errors
/// Returns the bridge error rendered as a `String` when the call fails.
async fn fetch_status_payload(session_id: &str) -> Result<SessionStatus, String> {
    let bridge = PythonBridge::new();
    let request = serde_json::json!({ "session_id": session_id });

    let outcome = bridge
        .call::<AutonomousSessionPayload>("autonomous/status", &request)
        .await;

    match outcome {
        Ok(session) => Ok(map_session_status(session_id.to_string(), session)),
        Err(err) => Err(err.to_string()),
    }
}
pub async fn start_session(
Json(req): Json<StartSessionRequest>,
) -> Result<Json<StartSessionResponse>, (StatusCode, Json<serde_json::Value>)> {
let bridge = PythonBridge::new();
let session_id = req.session_id.unwrap_or_else(|| Uuid::new_v4().to_string());
let payload = serde_json::json!({
"session_id": session_id,
"goal": req.goal,
"max_steps": req.max_steps.unwrap_or(10),
"persona_id": req.persona_id,
});
match bridge
.call::<AutonomousSessionPayload>("autonomous/start", &payload)
.await
{
Ok(result) => {
let tasks = result.tasks.into_iter().map(map_task).collect();
Ok(Json(StartSessionResponse {
session_id,
goal: req.goal,
status: result.status,
tasks,
progress_percent: result.progress_percent,
persona_id: result.persona_id,
persona_title: result.persona_title,
persona_summary: result.persona_summary,
events: result.events.into_iter().map(map_event).collect(),
checkpoints: result.checkpoints.into_iter().map(map_checkpoint).collect(),
approvals: result.approvals.into_iter().map(map_approval).collect(),
agent_roles: result.agent_roles.into_iter().map(map_agent_role).collect(),
replay_cursor: result.replay_cursor,
resume_token: result.resume_token,
failover_mode: result.failover_mode,
}))
}
Err(e) => Err((
StatusCode::INTERNAL_SERVER_ERROR,
Json(serde_json::json!({ "error": e.to_string() })),
)),
}
}
pub async fn get_status(
Path(session_id): Path<String>,
) -> Result<Json<SessionStatus>, (StatusCode, Json<serde_json::Value>)> {
match fetch_status_payload(&session_id).await {
Ok(result) => Ok(Json(result)),
Err(e) => Err((
StatusCode::INTERNAL_SERVER_ERROR,
Json(serde_json::json!({ "error": e.to_string() })),
)),
}
}
/// Streams the status of the session named in the URL path as server-sent events.
///
/// The status is fetched immediately and then again after every
/// `AUTONOMOUS_STREAM_INTERVAL_MS` milliseconds. Each successful fetch is emitted as an
/// unnamed event whose data is the JSON-serialized `SessionStatus` (`{}` if serialization
/// fails). The stream ends after a status of `"completed"` or `"failed"` has been emitted.
/// If a fetch fails, a single `maris_error` event with `{"error": "<message>"}` is emitted
/// and the stream ends. A `keepalive` comment is sent every 10 seconds.
pub async fn stream_status(
    Path(session_id): Path<String>,
) -> Sse<impl futures_util::Stream<Item = Result<Event, Infallible>>> {
    let initial = AutonomousStreamState {
        session_id,
        active: true,
        is_first: true,
    };

    let updates = stream::unfold(
        initial,
        |AutonomousStreamState {
             session_id,
             active,
             is_first,
         }| async move {
            // The previous iteration emitted a terminal status or an error: end the stream.
            if !active {
                return None;
            }
            // Only the very first poll goes out without waiting.
            if !is_first {
                tokio::time::sleep(Duration::from_millis(AUTONOMOUS_STREAM_INTERVAL_MS)).await;
            }

            let (event, keep_polling) = match fetch_status_payload(&session_id).await {
                Ok(status) => {
                    let finished = matches!(status.status.as_str(), "completed" | "failed");
                    let body =
                        serde_json::to_string(&status).unwrap_or_else(|_| "{}".to_string());
                    (Event::default().data(body), !finished)
                }
                Err(error) => {
                    let body = serde_json::json!({
                        "error": error,
                    })
                    .to_string();
                    (Event::default().event("maris_error").data(body), false)
                }
            };

            Some((
                Ok(event),
                AutonomousStreamState {
                    session_id,
                    active: keep_polling,
                    is_first: false,
                },
            ))
        },
    );

    let keep_alive = KeepAlive::new()
        .interval(Duration::from_secs(10))
        .text("keepalive");

    Sse::new(updates).keep_alive(keep_alive)
}