File size: 3,563 Bytes
74f2b46
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
use std::time::Duration;
use serde_json::{json, Value};
use tokio::time::sleep;
use crate::{errors::AppError, state::AppState};

/// Result of one inference request, as returned by `generate`.
///
/// Serializable in both directions via serde so it can be carried in
/// JSON payloads.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct InferenceOutcome {
    /// Text to show the caller: the model's reply, or a canned degraded-mode message.
    pub text: String,
    /// Human-readable label of the model the outcome is attributed to.
    pub model: String,
    /// Human-readable label of the device reported as having produced the text.
    pub device: String,
    /// Human-readable label of how the text was obtained (remote endpoint vs. local fallback).
    pub transport: String,
    /// Why the outcome is degraded, if it is; `generate` stores the last
    /// transport error here when it gives up on the backend and falls back.
    pub degraded_reason: Option<String>,
}

/// Sends `redacted_prompt` to the configured vLLM chat-completions endpoint and
/// returns the model's reply, retrying with exponential backoff on failure.
///
/// The request is POSTed to `state.settings.vllm_url` using `state.client`; when
/// `state.settings.vllm_api_key` is set it is sent as a bearer token. Up to
/// three attempts are made, waiting 250 ms and then 500 ms between them. No
/// wait happens after the final attempt.
///
/// # Returns
///
/// * On a successful HTTP status: an [`InferenceOutcome`] holding the extracted
///   text. If the response contains no recognizable text field, a placeholder
///   text is returned and `degraded_reason` is set so callers can tell.
/// * When every attempt fails (transport error or non-success status): `Ok`
///   with a canned degraded-mode message and `degraded_reason` set to the last
///   error observed. Backend unavailability is never reported as `Err`.
///
/// # Errors
///
/// Returns an error only when the backend answers with a success status but the
/// body cannot be decoded as JSON; that failure is propagated immediately and
/// is not retried.
pub async fn generate(state: &AppState, redacted_prompt: &str) -> Result<InferenceOutcome, AppError> {
    const MAX_ATTEMPTS: u8 = 3;
    const INITIAL_BACKOFF: Duration = Duration::from_millis(250);
    const MODEL_LABEL: &str = "Qwen2.5-7B-Instruct";

    let request_body = json!({
        "model": "Qwen/Qwen2.5-7B-Instruct",
        "messages": [
            { "role": "system", "content": "You are a HIPAA-aware emergency triage assistant. Produce concise clinical guidance, not a diagnosis." },
            { "role": "user", "content": redacted_prompt }
        ],
        "temperature": 0.2,
        "max_tokens": 320,
        "stream": false
    });

    // Borrow the settings instead of cloning them; `state` outlives the whole call.
    let endpoint = &state.settings.vllm_url;
    let api_key = state.settings.vllm_api_key.as_ref();

    let mut delay = INITIAL_BACKOFF;
    let mut last_error = String::new();
    for attempt in 1..=MAX_ATTEMPTS {
        let mut builder = state
            .client
            .post(endpoint)
            .json(&request_body)
            .header("ngrok-skip-browser-warning", "true");
        if let Some(key) = api_key {
            builder = builder.bearer_auth(key);
        }

        match builder.send().await {
            Ok(resp) if resp.status().is_success() => {
                // A success status with an undecodable body is surfaced to the
                // caller as an error rather than retried.
                let value: Value = resp.json().await?;
                let (text, degraded_reason) = match extract_text(&value) {
                    Some(text) => (text, None),
                    // The backend answered but gave us nothing usable: keep the
                    // placeholder text and flag the outcome as degraded.
                    None => (
                        "Degraded response: model returned an empty payload.".to_string(),
                        Some("model returned an empty payload".to_string()),
                    ),
                };
                return Ok(InferenceOutcome {
                    text,
                    model: MODEL_LABEL.to_string(),
                    device: "AMD MI300X (ROCm/HIP)".to_string(),
                    transport: "vLLM OpenAI-compatible endpoint".to_string(),
                    degraded_reason,
                });
            }
            Ok(resp) => {
                let status = resp.status();
                // Best effort: an unreadable error body is reported as empty.
                let body = resp.text().await.unwrap_or_default();
                last_error = format!("attempt {attempt}: HTTP {status} {body}");
            }
            Err(err) => {
                last_error = format!("attempt {attempt}: {err}");
            }
        }

        // Back off only when another attempt will follow; sleeping after the
        // final failure would just delay the fallback response.
        if attempt < MAX_ATTEMPTS {
            sleep(delay).await;
            delay = delay.saturating_mul(2);
        }
    }

    let fallback = format!(
        "Degraded triage mode: the MI300X inference backend was unavailable after {attempts} attempts. The gateway stayed deterministic and safe, redaction completed locally, and the clinician can still review the case. Last transport error: {last_error}",
        attempts = MAX_ATTEMPTS
    );

    Ok(InferenceOutcome {
        text: fallback,
        model: MODEL_LABEL.to_string(),
        device: "CPU fallback".to_string(),
        transport: "safe local fallback".to_string(),
        degraded_reason: Some(last_error),
    })
}

fn extract_text(value: &Value) -> Option<String> {
    if let Some(choice) = value.get("choices").and_then(|v| v.get(0)) {
        if let Some(content) = choice
            .get("message")
            .and_then(|m| m.get("content"))
            .and_then(|c| c.as_str())
        {
            return Some(content.trim().to_string());
        }
        if let Some(text) = choice.get("text").and_then(|t| t.as_str()) {
            return Some(text.trim().to_string());
        }
    }
    value.get("output_text").and_then(|v| v.as_str()).map(|s| s.trim().to_string())
}