rustvital-amd / src /handlers.rs
brainworm2024's picture
Final live AMD GPU integration, audit fix
74f2b46
use axum::{
extract::{ws::WebSocketUpgrade, Form, Json, Multipart, Query, State},
http::StatusCode,
response::{Html, IntoResponse, Response, Sse},
};
use axum::response::sse::Event;
use futures::{SinkExt, StreamExt};
use serde::Deserialize;
use std::convert::Infallible;
use std::pin::Pin;
use futures::stream::Stream;
use crate::{
errors::AppError,
models::{FineTuneResponse, MedicalApiResponse, PatientId, StatusResponse, TriageForm, TriageRecord, TriageRequest},
pipeline::{medical, orchestrator::run_pipeline},
state::AppState,
ui,
};
/// Serves the landing page.
///
/// The markup comes from `../static/index.html`, which is embedded into the
/// binary at compile time, so the handler only copies it into an owned string.
pub async fn index() -> Html<String> {
    const INDEX_PAGE: &str = include_str!("../static/index.html");
    Html(String::from(INDEX_PAGE))
}
/// Liveness probe: always answers `200 OK` with the plain-text body `healthy`.
pub async fn health() -> impl IntoResponse {
    let body = "healthy";
    (StatusCode::OK, body)
}
/// Reports the current service status as JSON.
///
/// The device label is chosen purely from the configured vLLM URL: a URL that
/// contains `127.0.0.1` is reported as the local fallback, anything else as
/// the AMD endpoint. The model name is a fixed string; the latest record id
/// and the record count are read from the shared application state.
pub async fn status(State(state): State<AppState>) -> Json<StatusResponse> {
    let targets_loopback = state.settings.vllm_url.contains("127.0.0.1");
    let device_label = if targets_loopback {
        "CPU-safe fallback / local development"
    } else {
        "AMD MI300X / vLLM endpoint"
    };

    // Read the state in the same order as the response fields are declared.
    let last_record_id = state.latest_record_id().await;
    let triage_count = state.records().await.len();

    let body = StatusResponse {
        status: String::from("running"),
        device: String::from(device_label),
        model: String::from("Qwen2.5-7B-Instruct"),
        last_record_id,
        triage_count,
    };
    Json(body)
}
/// Query-string parameters accepted by the handlers in this file that extract
/// `Query<TriageQuery>`.
///
/// Every field arrives as a raw string; `triage_request_from_query` converts
/// the struct into a `TriageRequest`.
#[derive(Debug, Deserialize)]
pub struct TriageQuery {
/// Raw patient identifier; passed to `PatientId::new` during conversion.
pub patient_id: String,
/// Reason for the triage request; surrounding whitespace is trimmed during conversion.
pub reason: String,
/// Free-text note; forwarded unchanged.
pub note: String,
/// Raw consent hash; passed to `ConsentHash::new` during conversion.
pub consent_hash: String,
}
pub async fn triage_stream(
State(state): State<AppState>,
Query(query): Query<TriageQuery>,
) -> Sse<Pin<Box<dyn Stream<Item = Result<Event, Infallible>> + Send>>> {
let request = match triage_request_from_query(query) {
Ok(request) => request,
Err(err) => {
let event = Event::default().data(serde_json::json!({"kind":"error","message":err}).to_string());
let stream = Box::pin(tokio_stream::once(Ok(event)));
return Sse::new(stream);
}
};
let (tx, rx) = tokio::sync::mpsc::channel::<Result<Event, Infallible>>(32);
let mut bus_rx = state.agent_bus.subscribe();
tokio::spawn(async move {
let _ = run_pipeline(&state, request, None).await;
});
tokio::spawn(async move {
loop {
match bus_rx.recv().await {
Ok(event) => {
let payload = serde_json::to_string(&event).unwrap_or_else(|_| "{}".to_string());
if tx.send(Ok(Event::default().data(payload))).await.is_err() {
break;
}
}
Err(tokio::sync::broadcast::error::RecvError::Lagged(_)) => continue,
Err(tokio::sync::broadcast::error::RecvError::Closed) => break,
}
}
});
let stream = Box::pin(tokio_stream::wrappers::ReceiverStream::new(rx));
Sse::new(stream)
}
pub async fn triage_json(
State(state): State<AppState>,
Json(payload): Json<TriageRequest>,
) -> Result<Json<crate::ui::TriageViewModel>, (StatusCode, String)> {
let view_model = execute_pipeline(state, payload, None).await?;
Ok(Json(view_model))
}
pub async fn triage_html(
State(state): State<AppState>,
Form(payload): Form<TriageForm>,
) -> Result<Html<String>, (StatusCode, String)> {
let request = triage_request_from_form(payload).map_err(bad_request)?;
let view_model = execute_pipeline(state, request, None).await?;
Ok(Html(ui::render_triage_fragment(&view_model)))
}
pub async fn dicom_redact(
mut multipart: Multipart,
) -> Result<Html<String>, (StatusCode, String)> {
let mut filename = "upload.dcm".to_string();
let mut bytes = Vec::new();
while let Some(field) = multipart.next_field().await.map_err(internal_error)? {
if field.name().unwrap_or_default() == "dicom" {
if let Some(file_name) = field.file_name() {
filename = file_name.to_string();
}
bytes = field.bytes().await.map_err(internal_error)?.to_vec();
}
}
if bytes.is_empty() {
return Err((StatusCode::BAD_REQUEST, "No DICOM file received".to_string()));
}
let result = crate::pipeline::dicom::redact_dicom(&filename, &bytes);
Ok(Html(ui::render_dicom_fragment(&result)))
}
/// Looks up medical enrichment for the note and reason given in the query
/// string and returns the query, the PubMed hits and a short summary as JSON.
///
/// # Errors
///
/// * `400 Bad Request` when the query parameters fail validation.
/// * `502 Bad Gateway` with the error text when `medical::enrich` fails.
pub async fn enrich_medical(
    State(state): State<AppState>,
    Query(query): Query<TriageQuery>,
) -> Result<Json<MedicalApiResponse>, (StatusCode, String)> {
    let request = triage_request_from_query(query).map_err(bad_request)?;

    let enrichment = match medical::enrich(&state, &request.note, &request.reason).await {
        Ok(enrichment) => enrichment,
        Err(err) => return Err((StatusCode::BAD_GATEWAY, err.to_string())),
    };

    let hit_count = enrichment.pubmed_hits.len();
    let response = MedicalApiResponse {
        query: enrichment.query.clone(),
        hits: enrichment.pubmed_hits.clone(),
        summary: format!("{} PubMed results gathered safely after triage.", hit_count),
    };
    Ok(Json(response))
}
/// Renders the dashboard HTML fragment from the records currently held in
/// the application state.
pub async fn dashboard(State(state): State<AppState>) -> Html<String> {
    let records = state.records().await;
    Html(ui::render_dashboard_fragment(&records))
}
/// Returns the triage records held in the application state as a JSON array.
pub async fn history_json(State(state): State<AppState>) -> Json<Vec<TriageRecord>> {
    let records = state.records().await;
    Json(records)
}
pub async fn agents_ws(ws: WebSocketUpgrade, State(state): State<AppState>) -> Response {
ws.on_upgrade(move |socket| async move {
if let Err(err) = run_ws(socket, state).await {
tracing::warn!("agent websocket ended: {}", err);
}
})
}
/// Builds a description of a LoRA fine-tune job and returns it as JSON.
///
/// The handler generates a job id from a v7 UUID, formats a training command
/// line that references the latest record id (or `bootstrap` when there is
/// none), logs that command and echoes everything back with status
/// `submitted`. Within this function the command is only formatted and
/// logged, not executed.
pub async fn trigger_federated_tune(State(state): State<AppState>) -> Json<FineTuneResponse> {
    let job_id = format!("lora-mi300x-{}", uuid::Uuid::now_v7());

    // Fall back to a fixed marker when no triage record exists yet.
    let record_hint = match state.latest_record_id().await {
        Some(id) => id,
        None => String::from("bootstrap"),
    };

    let command = format!(
        "python -m peft.lora_train --model Qwen/Qwen2.5-7B-Instruct --dataset redacted-clinical-corpus --gpus 1 --precision bf16 --device mi300x --cid {record_hint}"
    );
    tracing::info!("federated fine-tune job scheduled: {}", command);

    let response = FineTuneResponse {
        status: String::from("submitted"),
        job_id,
        command,
        model: String::from("Qwen2.5-7B-Instruct + LoRA adapter"),
        dataset_cid: record_hint,
    };
    Json(response)
}
/// Returns a JSON summary of the federation state.
///
/// `round` is the number of stored records, `latest_record_id` comes from the
/// application state, and the node list and privacy note are fixed values.
pub async fn federation_round(State(state): State<AppState>) -> Json<serde_json::Value> {
    let round = state.records().await.len();
    let latest_record_id = state.latest_record_id().await;

    let body = serde_json::json!({
        "round": round,
        "active_nodes": ["General Hospital A", "Regional Medical Center", "University Clinic"],
        "latest_record_id": latest_record_id,
        "privacy": "redacted only; no PHI leaves the gateway",
    });
    Json(body)
}
async fn execute_pipeline(
state: AppState,
request: TriageRequest,
dicom: Option<(String, Vec<u8>)>,
) -> Result<crate::ui::TriageViewModel, (StatusCode, String)> {
let outcome = run_pipeline(&state, request, dicom).await.map_err(internal_pipeline_error)?;
Ok(crate::ui::TriageViewModel::from_outcome(outcome))
}
/// Converts a submitted `TriageForm` into a `TriageRequest`.
///
/// The patient id is checked first, then the consent hash; the first failure
/// is returned as a message prefixed with the offending field name. The
/// reason is trimmed, the note is taken over unchanged.
fn triage_request_from_form(payload: TriageForm) -> Result<TriageRequest, String> {
    let patient_id = match PatientId::new(payload.patient_id) {
        Ok(id) => id,
        Err(e) => return Err(format!("patient_id: {e}")),
    };
    let reason = payload.reason.trim().to_string();
    let consent_hash = match crate::models::ConsentHash::new(payload.consent_hash) {
        Ok(hash) => hash,
        Err(e) => return Err(format!("consent_hash: {e}")),
    };

    Ok(TriageRequest {
        patient_id,
        reason,
        note: payload.note,
        consent_hash,
    })
}
/// Converts query-string parameters into a `TriageRequest`.
///
/// The patient id is checked first, then the consent hash; the first failure
/// is returned as a message prefixed with the offending field name. The
/// reason is trimmed, the note is taken over unchanged.
fn triage_request_from_query(query: TriageQuery) -> Result<TriageRequest, String> {
    let patient_id = match PatientId::new(query.patient_id) {
        Ok(id) => id,
        Err(e) => return Err(format!("patient_id: {e}")),
    };
    let reason = query.reason.trim().to_string();
    let consent_hash = match crate::models::ConsentHash::new(query.consent_hash) {
        Ok(hash) => hash,
        Err(e) => return Err(format!("consent_hash: {e}")),
    };

    Ok(TriageRequest {
        patient_id,
        reason,
        note: query.note,
        consent_hash,
    })
}
/// Pairs a validation message with `400 Bad Request`.
fn bad_request(err: String) -> (StatusCode, String) {
    let status = StatusCode::BAD_REQUEST;
    (status, err)
}
/// Pairs the display text of any error with `500 Internal Server Error`.
fn internal_error<E>(err: E) -> (StatusCode, String)
where
    E: std::fmt::Display,
{
    let message = err.to_string();
    (StatusCode::INTERNAL_SERVER_ERROR, message)
}
/// Pairs the display text of a pipeline `AppError` with
/// `500 Internal Server Error`.
fn internal_pipeline_error(err: AppError) -> (StatusCode, String) {
    let message = err.to_string();
    (StatusCode::INTERNAL_SERVER_ERROR, message)
}
async fn run_ws(socket: axum::extract::ws::WebSocket, state: AppState) -> Result<(), AppError> {
let (mut sender, mut receiver) = socket.split();
let mut rx = state.agent_bus.subscribe();
loop {
tokio::select! {
maybe_event = rx.recv() => {
match maybe_event {
Ok(event) => {
let payload = serde_json::to_string(&event)?;
sender.send(axum::extract::ws::Message::Text(payload.into())).await?;
}
Err(tokio::sync::broadcast::error::RecvError::Lagged(_)) => continue,
Err(tokio::sync::broadcast::error::RecvError::Closed) => break,
}
}
maybe_msg = receiver.next() => {
match maybe_msg {
Some(Ok(axum::extract::ws::Message::Close(_))) | None => break,
Some(Ok(axum::extract::ws::Message::Ping(_))) => {}
Some(Ok(_)) => {}
Some(Err(_)) => break,
}
}
}
}
Ok(())
}