| use std::time::Duration; |
| use serde_json::{json, Value}; |
| use tokio::time::sleep; |
| use crate::{errors::AppError, state::AppState}; |
|
|
| #[derive(Debug, Clone, serde::Serialize, serde::Deserialize)] |
| pub struct InferenceOutcome { |
| pub text: String, |
| pub model: String, |
| pub device: String, |
| pub transport: String, |
| pub degraded_reason: Option<String>, |
| } |
|
|
| pub async fn generate(state: &AppState, redacted_prompt: &str) -> Result<InferenceOutcome, AppError> { |
| let request_body = json!({ |
| "model": "Qwen/Qwen2.5-7B-Instruct", |
| "messages": [ |
| { "role": "system", "content": "You are a HIPAA-aware emergency triage assistant. Produce concise clinical guidance, not a diagnosis." }, |
| { "role": "user", "content": redacted_prompt } |
| ], |
| "temperature": 0.2, |
| "max_tokens": 320, |
| "stream": false |
| }); |
|
|
| let endpoint = state.settings.vllm_url.clone(); |
| let api_key = state.settings.vllm_api_key.clone(); |
|
|
| let mut delay = Duration::from_millis(250); |
| let mut last_error = String::new(); |
| for attempt in 1..=3u8 { |
| let mut builder = state.client.post(&endpoint).json(&request_body); |
| builder = builder.header("ngrok-skip-browser-warning", "true"); |
| if let Some(key) = api_key.as_ref() { |
| builder = builder.bearer_auth(key); |
| } |
| let response = builder.send().await; |
| match response { |
| Ok(resp) if resp.status().is_success() => { |
| let value: Value = resp.json().await?; |
| let text = extract_text(&value) |
| .unwrap_or_else(|| "Degraded response: model returned an empty payload.".to_string()); |
| return Ok(InferenceOutcome { |
| text, |
| model: "Qwen2.5-7B-Instruct".to_string(), |
| device: "AMD MI300X (ROCm/HIP)".to_string(), |
| transport: "vLLM OpenAI-compatible endpoint".to_string(), |
| degraded_reason: None, |
| }); |
| } |
| Ok(resp) => { |
| let status = resp.status(); |
| let body = resp.text().await.unwrap_or_default(); |
| last_error = format!("attempt {attempt}: HTTP {status} {body}"); |
| } |
| Err(err) => { |
| last_error = format!("attempt {attempt}: {err}"); |
| } |
| } |
| sleep(delay).await; |
| delay = delay.saturating_mul(2); |
| } |
|
|
| let fallback = format!( |
| "Degraded triage mode: the MI300X inference backend was unavailable after 3 attempts. The gateway stayed deterministic and safe, redaction completed locally, and the clinician can still review the case. Last transport error: {last_error}" |
| ); |
|
|
| Ok(InferenceOutcome { |
| text: fallback, |
| model: "Qwen2.5-7B-Instruct".to_string(), |
| device: "CPU fallback".to_string(), |
| transport: "safe local fallback".to_string(), |
| degraded_reason: Some(last_error), |
| }) |
| } |
|
|
| fn extract_text(value: &Value) -> Option<String> { |
| if let Some(choice) = value.get("choices").and_then(|v| v.get(0)) { |
| if let Some(content) = choice |
| .get("message") |
| .and_then(|m| m.get("content")) |
| .and_then(|c| c.as_str()) |
| { |
| return Some(content.trim().to_string()); |
| } |
| if let Some(text) = choice.get("text").and_then(|t| t.as_str()) { |
| return Some(text.trim().to_string()); |
| } |
| } |
| value.get("output_text").and_then(|v| v.as_str()).map(|s| s.trim().to_string()) |
| } |
|
|