brainworm2024's picture
Final live AMD GPU integration, audit fix
74f2b46
use std::time::Duration;
use serde_json::{json, Value};
use tokio::time::sleep;
use crate::{errors::AppError, state::AppState};
/// Result of one inference request, as returned by `generate`.
///
/// Every field is a plain string so the value can be serialized and deserialized with serde
/// as-is.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct InferenceOutcome {
/// Text handed back to the caller: the model's reply, or a fixed degraded-mode message.
pub text: String,
/// Label of the model the outcome is attributed to.
pub model: String,
/// Label of the device the outcome is attributed to
/// (`generate` uses "AMD MI300X (ROCm/HIP)" or "CPU fallback").
pub device: String,
/// Label of how the text was obtained
/// (`generate` uses "vLLM OpenAI-compatible endpoint" or "safe local fallback").
pub transport: String,
/// Why the outcome is degraded, when a reason was recorded. `generate` stores the last
/// backend error here when it falls back after all attempts have failed.
pub degraded_reason: Option<String>,
}
/// Requests triage guidance for `redacted_prompt` from the configured vLLM chat endpoint.
///
/// The request is POSTed to `state.settings.vllm_url` through `state.client`, with a bearer
/// token attached when `state.settings.vllm_api_key` is set. Up to three attempts are made;
/// the pause between attempts starts at 250 ms and doubles each time. No pause follows the
/// final attempt, so the fallback is returned without extra latency.
///
/// An attempt counts as failed on a transport error, a non-success HTTP status, or a success
/// response whose body cannot be decoded as JSON. When every attempt fails, a fixed fallback
/// outcome is returned with the last failure recorded in `degraded_reason`. A success response
/// that yields no text is returned immediately with placeholder text and a `degraded_reason`.
///
/// # Errors
///
/// This function currently always returns `Ok`: backend failures are reported through the
/// degraded outcome instead. The `Result` return type is kept so callers are unaffected.
pub async fn generate(state: &AppState, redacted_prompt: &str) -> Result<InferenceOutcome, AppError> {
    // Identifier sent to the backend vs. the label reported back to callers.
    const REQUEST_MODEL: &str = "Qwen/Qwen2.5-7B-Instruct";
    const REPORTED_MODEL: &str = "Qwen2.5-7B-Instruct";
    const MAX_ATTEMPTS: u8 = 3;

    let request_body = json!({
        "model": REQUEST_MODEL,
        "messages": [
            { "role": "system", "content": "You are a HIPAA-aware emergency triage assistant. Produce concise clinical guidance, not a diagnosis." },
            { "role": "user", "content": redacted_prompt }
        ],
        "temperature": 0.2,
        "max_tokens": 320,
        "stream": false
    });

    // Borrow the settings instead of cloning them; they are only read here.
    let endpoint = &state.settings.vllm_url;
    let api_key = state.settings.vllm_api_key.as_ref();

    let mut delay = Duration::from_millis(250);
    let mut last_error = String::new();

    for attempt in 1..=MAX_ATTEMPTS {
        // NOTE(review): the extra header presumably bypasses ngrok's browser interstitial
        // when the endpoint is tunnelled — confirm against the deployment.
        let mut builder = state
            .client
            .post(endpoint)
            .json(&request_body)
            .header("ngrok-skip-browser-warning", "true");
        if let Some(key) = api_key {
            builder = builder.bearer_auth(key);
        }

        match builder.send().await {
            Ok(resp) if resp.status().is_success() => match resp.json::<Value>().await {
                Ok(value) => {
                    // A 2xx reply without usable text is still a degraded outcome, so the
                    // reason is recorded alongside the placeholder text.
                    let (text, degraded_reason) = match extract_text(&value) {
                        Some(text) => (text, None),
                        None => (
                            "Degraded response: model returned an empty payload.".to_string(),
                            Some("model returned an empty payload".to_string()),
                        ),
                    };
                    return Ok(InferenceOutcome {
                        text,
                        model: REPORTED_MODEL.to_string(),
                        device: "AMD MI300X (ROCm/HIP)".to_string(),
                        transport: "vLLM OpenAI-compatible endpoint".to_string(),
                        degraded_reason,
                    });
                }
                // An undecodable body is handled like any other backend failure rather than
                // aborting the whole call with a hard error.
                Err(err) => {
                    last_error = format!("attempt {attempt}: invalid JSON body: {err}");
                }
            },
            Ok(resp) => {
                let status = resp.status();
                let body = resp.text().await.unwrap_or_default();
                last_error = format!("attempt {attempt}: HTTP {status} {body}");
            }
            Err(err) => {
                last_error = format!("attempt {attempt}: {err}");
            }
        }

        // Back off only when another attempt will actually follow.
        if attempt < MAX_ATTEMPTS {
            sleep(delay).await;
            delay = delay.saturating_mul(2);
        }
    }

    let fallback = format!(
        "Degraded triage mode: the MI300X inference backend was unavailable after {MAX_ATTEMPTS} attempts. The gateway stayed deterministic and safe, redaction completed locally, and the clinician can still review the case. Last transport error: {last_error}"
    );
    Ok(InferenceOutcome {
        text: fallback,
        model: REPORTED_MODEL.to_string(),
        device: "CPU fallback".to_string(),
        transport: "safe local fallback".to_string(),
        degraded_reason: Some(last_error),
    })
}
fn extract_text(value: &Value) -> Option<String> {
if let Some(choice) = value.get("choices").and_then(|v| v.get(0)) {
if let Some(content) = choice
.get("message")
.and_then(|m| m.get("content"))
.and_then(|c| c.as_str())
{
return Some(content.trim().to_string());
}
if let Some(text) = choice.get("text").and_then(|t| t.as_str()) {
return Some(text.trim().to_string());
}
}
value.get("output_text").and_then(|v| v.as_str()).map(|s| s.trim().to_string())
}