use axum::{ extract::{ws::WebSocketUpgrade, Form, Json, Multipart, Query, State}, http::StatusCode, response::{Html, IntoResponse, Response, Sse}, }; use axum::response::sse::Event; use futures::{SinkExt, StreamExt}; use serde::Deserialize; use std::convert::Infallible; use std::pin::Pin; use futures::stream::Stream; use crate::{ errors::AppError, models::{FineTuneResponse, MedicalApiResponse, PatientId, StatusResponse, TriageForm, TriageRecord, TriageRequest}, pipeline::{medical, orchestrator::run_pipeline}, state::AppState, ui, }; pub async fn index() -> Html { Html(include_str!("../static/index.html").to_string()) } pub async fn health() -> impl IntoResponse { (StatusCode::OK, "healthy") } pub async fn status(State(state): State) -> Json { let device = if state.settings.vllm_url.contains("127.0.0.1") { "CPU-safe fallback / local development".to_string() } else { "AMD MI300X / vLLM endpoint".to_string() }; Json(StatusResponse { status: "running".to_string(), device, model: "Qwen2.5-7B-Instruct".to_string(), last_record_id: state.latest_record_id().await, triage_count: state.records().await.len(), }) } #[derive(Debug, Deserialize)] pub struct TriageQuery { pub patient_id: String, pub reason: String, pub note: String, pub consent_hash: String, } pub async fn triage_stream( State(state): State, Query(query): Query, ) -> Sse> + Send>>> { let request = match triage_request_from_query(query) { Ok(request) => request, Err(err) => { let event = Event::default().data(serde_json::json!({"kind":"error","message":err}).to_string()); let stream = Box::pin(tokio_stream::once(Ok(event))); return Sse::new(stream); } }; let (tx, rx) = tokio::sync::mpsc::channel::>(32); let mut bus_rx = state.agent_bus.subscribe(); tokio::spawn(async move { let _ = run_pipeline(&state, request, None).await; }); tokio::spawn(async move { loop { match bus_rx.recv().await { Ok(event) => { let payload = serde_json::to_string(&event).unwrap_or_else(|_| "{}".to_string()); if tx.send(Ok(Event::default().data(payload))).await.is_err() { break; } } Err(tokio::sync::broadcast::error::RecvError::Lagged(_)) => continue, Err(tokio::sync::broadcast::error::RecvError::Closed) => break, } } }); let stream = Box::pin(tokio_stream::wrappers::ReceiverStream::new(rx)); Sse::new(stream) } pub async fn triage_json( State(state): State, Json(payload): Json, ) -> Result, (StatusCode, String)> { let view_model = execute_pipeline(state, payload, None).await?; Ok(Json(view_model)) } pub async fn triage_html( State(state): State, Form(payload): Form, ) -> Result, (StatusCode, String)> { let request = triage_request_from_form(payload).map_err(bad_request)?; let view_model = execute_pipeline(state, request, None).await?; Ok(Html(ui::render_triage_fragment(&view_model))) } pub async fn dicom_redact( mut multipart: Multipart, ) -> Result, (StatusCode, String)> { let mut filename = "upload.dcm".to_string(); let mut bytes = Vec::new(); while let Some(field) = multipart.next_field().await.map_err(internal_error)? { if field.name().unwrap_or_default() == "dicom" { if let Some(file_name) = field.file_name() { filename = file_name.to_string(); } bytes = field.bytes().await.map_err(internal_error)?.to_vec(); } } if bytes.is_empty() { return Err((StatusCode::BAD_REQUEST, "No DICOM file received".to_string())); } let result = crate::pipeline::dicom::redact_dicom(&filename, &bytes); Ok(Html(ui::render_dicom_fragment(&result))) } pub async fn enrich_medical( State(state): State, Query(query): Query, ) -> Result, (StatusCode, String)> { let request = triage_request_from_query(query).map_err(bad_request)?; let enrichment = medical::enrich(&state, &request.note, &request.reason) .await .map_err(|err| (StatusCode::BAD_GATEWAY, err.to_string()))?; Ok(Json(MedicalApiResponse { query: enrichment.query.clone(), hits: enrichment.pubmed_hits.clone(), summary: format!("{} PubMed results gathered safely after triage.", enrichment.pubmed_hits.len()), })) } pub async fn dashboard(State(state): State) -> Html { Html(ui::render_dashboard_fragment(&state.records().await)) } pub async fn history_json(State(state): State) -> Json> { Json(state.records().await) } pub async fn agents_ws(ws: WebSocketUpgrade, State(state): State) -> Response { ws.on_upgrade(move |socket| async move { if let Err(err) = run_ws(socket, state).await { tracing::warn!("agent websocket ended: {}", err); } }) } pub async fn trigger_federated_tune(State(state): State) -> Json { let job_id = format!("lora-mi300x-{}", uuid::Uuid::now_v7()); let record_hint = state.latest_record_id().await.unwrap_or_else(|| "bootstrap".to_string()); let command = format!( "python -m peft.lora_train --model Qwen/Qwen2.5-7B-Instruct --dataset redacted-clinical-corpus --gpus 1 --precision bf16 --device mi300x --cid {record_hint}" ); tracing::info!("federated fine-tune job scheduled: {}", command); Json(FineTuneResponse { status: "submitted".to_string(), job_id, command, model: "Qwen2.5-7B-Instruct + LoRA adapter".to_string(), dataset_cid: record_hint, }) } pub async fn federation_round(State(state): State) -> Json { Json(serde_json::json!({ "round": state.records().await.len(), "active_nodes": ["General Hospital A", "Regional Medical Center", "University Clinic"], "latest_record_id": state.latest_record_id().await, "privacy": "redacted only; no PHI leaves the gateway", })) } async fn execute_pipeline( state: AppState, request: TriageRequest, dicom: Option<(String, Vec)>, ) -> Result { let outcome = run_pipeline(&state, request, dicom).await.map_err(internal_pipeline_error)?; Ok(crate::ui::TriageViewModel::from_outcome(outcome)) } fn triage_request_from_form(payload: TriageForm) -> Result { Ok(TriageRequest { patient_id: PatientId::new(payload.patient_id).map_err(|e| format!("patient_id: {e}"))?, reason: payload.reason.trim().to_string(), note: payload.note, consent_hash: crate::models::ConsentHash::new(payload.consent_hash).map_err(|e| format!("consent_hash: {e}"))?, }) } fn triage_request_from_query(query: TriageQuery) -> Result { Ok(TriageRequest { patient_id: PatientId::new(query.patient_id).map_err(|e| format!("patient_id: {e}"))?, reason: query.reason.trim().to_string(), note: query.note, consent_hash: crate::models::ConsentHash::new(query.consent_hash).map_err(|e| format!("consent_hash: {e}"))?, }) } fn bad_request(err: String) -> (StatusCode, String) { (StatusCode::BAD_REQUEST, err) } fn internal_error(err: E) -> (StatusCode, String) { (StatusCode::INTERNAL_SERVER_ERROR, err.to_string()) } fn internal_pipeline_error(err: AppError) -> (StatusCode, String) { (StatusCode::INTERNAL_SERVER_ERROR, err.to_string()) } async fn run_ws(socket: axum::extract::ws::WebSocket, state: AppState) -> Result<(), AppError> { let (mut sender, mut receiver) = socket.split(); let mut rx = state.agent_bus.subscribe(); loop { tokio::select! { maybe_event = rx.recv() => { match maybe_event { Ok(event) => { let payload = serde_json::to_string(&event)?; sender.send(axum::extract::ws::Message::Text(payload.into())).await?; } Err(tokio::sync::broadcast::error::RecvError::Lagged(_)) => continue, Err(tokio::sync::broadcast::error::RecvError::Closed) => break, } } maybe_msg = receiver.next() => { match maybe_msg { Some(Ok(axum::extract::ws::Message::Close(_))) | None => break, Some(Ok(axum::extract::ws::Message::Ping(_))) => {} Some(Ok(_)) => {} Some(Err(_)) => break, } } } } Ok(()) }