| use axum::{ |
| extract::{ws::WebSocketUpgrade, Form, Json, Multipart, Query, State}, |
| http::StatusCode, |
| response::{Html, IntoResponse, Response, Sse}, |
| }; |
| use axum::response::sse::Event; |
| use futures::{SinkExt, StreamExt}; |
| use serde::Deserialize; |
| use std::convert::Infallible; |
| use std::pin::Pin; |
| use futures::stream::Stream; |
|
|
| use crate::{ |
| errors::AppError, |
| models::{FineTuneResponse, MedicalApiResponse, PatientId, StatusResponse, TriageForm, TriageRecord, TriageRequest}, |
| pipeline::{medical, orchestrator::run_pipeline}, |
| state::AppState, |
| ui, |
| }; |
|
|
| pub async fn index() -> Html<String> { |
| Html(include_str!("../static/index.html").to_string()) |
| } |
|
|
| pub async fn health() -> impl IntoResponse { |
| (StatusCode::OK, "healthy") |
| } |
|
|
| pub async fn status(State(state): State<AppState>) -> Json<StatusResponse> { |
| let device = if state.settings.vllm_url.contains("127.0.0.1") { |
| "CPU-safe fallback / local development".to_string() |
| } else { |
| "AMD MI300X / vLLM endpoint".to_string() |
| }; |
| Json(StatusResponse { |
| status: "running".to_string(), |
| device, |
| model: "Qwen2.5-7B-Instruct".to_string(), |
| last_record_id: state.latest_record_id().await, |
| triage_count: state.records().await.len(), |
| }) |
| } |
|
|
| #[derive(Debug, Deserialize)] |
| pub struct TriageQuery { |
| pub patient_id: String, |
| pub reason: String, |
| pub note: String, |
| pub consent_hash: String, |
| } |
|
|
| pub async fn triage_stream( |
| State(state): State<AppState>, |
| Query(query): Query<TriageQuery>, |
| ) -> Sse<Pin<Box<dyn Stream<Item = Result<Event, Infallible>> + Send>>> { |
| let request = match triage_request_from_query(query) { |
| Ok(request) => request, |
| Err(err) => { |
| let event = Event::default().data(serde_json::json!({"kind":"error","message":err}).to_string()); |
| let stream = Box::pin(tokio_stream::once(Ok(event))); |
| return Sse::new(stream); |
| } |
| }; |
|
|
| let (tx, rx) = tokio::sync::mpsc::channel::<Result<Event, Infallible>>(32); |
| let mut bus_rx = state.agent_bus.subscribe(); |
| tokio::spawn(async move { |
| let _ = run_pipeline(&state, request, None).await; |
| }); |
|
|
| tokio::spawn(async move { |
| loop { |
| match bus_rx.recv().await { |
| Ok(event) => { |
| let payload = serde_json::to_string(&event).unwrap_or_else(|_| "{}".to_string()); |
| if tx.send(Ok(Event::default().data(payload))).await.is_err() { |
| break; |
| } |
| } |
| Err(tokio::sync::broadcast::error::RecvError::Lagged(_)) => continue, |
| Err(tokio::sync::broadcast::error::RecvError::Closed) => break, |
| } |
| } |
| }); |
|
|
| let stream = Box::pin(tokio_stream::wrappers::ReceiverStream::new(rx)); |
| Sse::new(stream) |
| } |
|
|
| pub async fn triage_json( |
| State(state): State<AppState>, |
| Json(payload): Json<TriageRequest>, |
| ) -> Result<Json<crate::ui::TriageViewModel>, (StatusCode, String)> { |
| let view_model = execute_pipeline(state, payload, None).await?; |
| Ok(Json(view_model)) |
| } |
|
|
| pub async fn triage_html( |
| State(state): State<AppState>, |
| Form(payload): Form<TriageForm>, |
| ) -> Result<Html<String>, (StatusCode, String)> { |
| let request = triage_request_from_form(payload).map_err(bad_request)?; |
| let view_model = execute_pipeline(state, request, None).await?; |
| Ok(Html(ui::render_triage_fragment(&view_model))) |
| } |
|
|
| pub async fn dicom_redact( |
| mut multipart: Multipart, |
| ) -> Result<Html<String>, (StatusCode, String)> { |
| let mut filename = "upload.dcm".to_string(); |
| let mut bytes = Vec::new(); |
|
|
| while let Some(field) = multipart.next_field().await.map_err(internal_error)? { |
| if field.name().unwrap_or_default() == "dicom" { |
| if let Some(file_name) = field.file_name() { |
| filename = file_name.to_string(); |
| } |
| bytes = field.bytes().await.map_err(internal_error)?.to_vec(); |
| } |
| } |
|
|
| if bytes.is_empty() { |
| return Err((StatusCode::BAD_REQUEST, "No DICOM file received".to_string())); |
| } |
|
|
| let result = crate::pipeline::dicom::redact_dicom(&filename, &bytes); |
| Ok(Html(ui::render_dicom_fragment(&result))) |
| } |
|
|
| pub async fn enrich_medical( |
| State(state): State<AppState>, |
| Query(query): Query<TriageQuery>, |
| ) -> Result<Json<MedicalApiResponse>, (StatusCode, String)> { |
| let request = triage_request_from_query(query).map_err(bad_request)?; |
| let enrichment = medical::enrich(&state, &request.note, &request.reason) |
| .await |
| .map_err(|err| (StatusCode::BAD_GATEWAY, err.to_string()))?; |
| Ok(Json(MedicalApiResponse { |
| query: enrichment.query.clone(), |
| hits: enrichment.pubmed_hits.clone(), |
| summary: format!("{} PubMed results gathered safely after triage.", enrichment.pubmed_hits.len()), |
| })) |
| } |
|
|
| pub async fn dashboard(State(state): State<AppState>) -> Html<String> { |
| Html(ui::render_dashboard_fragment(&state.records().await)) |
| } |
|
|
| pub async fn history_json(State(state): State<AppState>) -> Json<Vec<TriageRecord>> { |
| Json(state.records().await) |
| } |
|
|
| pub async fn agents_ws(ws: WebSocketUpgrade, State(state): State<AppState>) -> Response { |
| ws.on_upgrade(move |socket| async move { |
| if let Err(err) = run_ws(socket, state).await { |
| tracing::warn!("agent websocket ended: {}", err); |
| } |
| }) |
| } |
|
|
| pub async fn trigger_federated_tune(State(state): State<AppState>) -> Json<FineTuneResponse> { |
| let job_id = format!("lora-mi300x-{}", uuid::Uuid::now_v7()); |
| let record_hint = state.latest_record_id().await.unwrap_or_else(|| "bootstrap".to_string()); |
| let command = format!( |
| "python -m peft.lora_train --model Qwen/Qwen2.5-7B-Instruct --dataset redacted-clinical-corpus --gpus 1 --precision bf16 --device mi300x --cid {record_hint}" |
| ); |
| tracing::info!("federated fine-tune job scheduled: {}", command); |
| Json(FineTuneResponse { |
| status: "submitted".to_string(), |
| job_id, |
| command, |
| model: "Qwen2.5-7B-Instruct + LoRA adapter".to_string(), |
| dataset_cid: record_hint, |
| }) |
| } |
|
|
| pub async fn federation_round(State(state): State<AppState>) -> Json<serde_json::Value> { |
| Json(serde_json::json!({ |
| "round": state.records().await.len(), |
| "active_nodes": ["General Hospital A", "Regional Medical Center", "University Clinic"], |
| "latest_record_id": state.latest_record_id().await, |
| "privacy": "redacted only; no PHI leaves the gateway", |
| })) |
| } |
|
|
| async fn execute_pipeline( |
| state: AppState, |
| request: TriageRequest, |
| dicom: Option<(String, Vec<u8>)>, |
| ) -> Result<crate::ui::TriageViewModel, (StatusCode, String)> { |
| let outcome = run_pipeline(&state, request, dicom).await.map_err(internal_pipeline_error)?; |
| Ok(crate::ui::TriageViewModel::from_outcome(outcome)) |
| } |
|
|
| fn triage_request_from_form(payload: TriageForm) -> Result<TriageRequest, String> { |
| Ok(TriageRequest { |
| patient_id: PatientId::new(payload.patient_id).map_err(|e| format!("patient_id: {e}"))?, |
| reason: payload.reason.trim().to_string(), |
| note: payload.note, |
| consent_hash: crate::models::ConsentHash::new(payload.consent_hash).map_err(|e| format!("consent_hash: {e}"))?, |
| }) |
| } |
|
|
| fn triage_request_from_query(query: TriageQuery) -> Result<TriageRequest, String> { |
| Ok(TriageRequest { |
| patient_id: PatientId::new(query.patient_id).map_err(|e| format!("patient_id: {e}"))?, |
| reason: query.reason.trim().to_string(), |
| note: query.note, |
| consent_hash: crate::models::ConsentHash::new(query.consent_hash).map_err(|e| format!("consent_hash: {e}"))?, |
| }) |
| } |
|
|
| fn bad_request(err: String) -> (StatusCode, String) { |
| (StatusCode::BAD_REQUEST, err) |
| } |
|
|
| fn internal_error<E: std::fmt::Display>(err: E) -> (StatusCode, String) { |
| (StatusCode::INTERNAL_SERVER_ERROR, err.to_string()) |
| } |
|
|
| fn internal_pipeline_error(err: AppError) -> (StatusCode, String) { |
| (StatusCode::INTERNAL_SERVER_ERROR, err.to_string()) |
| } |
|
|
| async fn run_ws(socket: axum::extract::ws::WebSocket, state: AppState) -> Result<(), AppError> { |
| let (mut sender, mut receiver) = socket.split(); |
| let mut rx = state.agent_bus.subscribe(); |
|
|
| loop { |
| tokio::select! { |
| maybe_event = rx.recv() => { |
| match maybe_event { |
| Ok(event) => { |
| let payload = serde_json::to_string(&event)?; |
| sender.send(axum::extract::ws::Message::Text(payload.into())).await?; |
| } |
| Err(tokio::sync::broadcast::error::RecvError::Lagged(_)) => continue, |
| Err(tokio::sync::broadcast::error::RecvError::Closed) => break, |
| } |
| } |
| maybe_msg = receiver.next() => { |
| match maybe_msg { |
| Some(Ok(axum::extract::ws::Message::Close(_))) | None => break, |
| Some(Ok(axum::extract::ws::Message::Ping(_))) => {} |
| Some(Ok(_)) => {} |
| Some(Err(_)) => break, |
| } |
| } |
| } |
| } |
|
|
| Ok(()) |
| } |
|
|