# p14-space/src/triage_llm/api/model.py
# perachon's picture
# Deploy CPU FastAPI stub
# c9dcc3b
from __future__ import annotations
import os
import json
import urllib.request
import urllib.error
from dataclasses import dataclass
from typing import Any
@dataclass
class ModelBackendInfo:
    """Descriptive record returned by each backend's ``info()`` method."""

    # Short backend identifier; the backends in this module use values such
    # as "stub", "transformers-peft" and "vllm-openai".
    name: str
    # Free-form, backend-specific settings (model identifiers, device, URL...).
    details: dict[str, Any]
class SimpleBackend:
    """Minimal backend for the POC.

    - In production: replace it with vLLM (in-process) or an
      OpenAI-compatible vLLM server.
    - Here: a "stub" backend that lets the API be tested without a GPU.
    """

    def __init__(self) -> None:
        """Read ``MODEL_NAME_OR_PATH`` from the environment (empty if unset)."""
        self.model_name_or_path = os.getenv("MODEL_NAME_OR_PATH", "")

    def generate(self, prompt: str, max_tokens: int = 256) -> str:
        """Return a canned French reply that echoes the start of ``prompt``.

        Args:
            prompt: User message; only its first 200 characters are echoed.
            max_tokens: Accepted for interface parity with the other
                backends; the stub does not use it.

        Returns:
            A fixed disclaimer followed by the truncated prompt.
        """
        disclaimer = (
            "Je suis un POC de triage. Pour mieux évaluer, je dois poser quelques questions et "
            "je ne remplace pas un avis médical.\n\n"
        )
        excerpt = prompt[:200]
        return disclaimer + f"Résumé du message: {excerpt}"

    def info(self) -> ModelBackendInfo:
        """Describe this backend as ``"stub"`` along with the configured model name."""
        settings = {"model_name_or_path": self.model_name_or_path}
        return ModelBackendInfo(name="stub", details=settings)
class TransformersPeftBackend:
    """Transformers + PEFT (LoRA) backend.

    Designed to be enabled through an environment variable. Nothing heavy is
    imported or loaded at construction time: the tokenizer and model are
    loaded lazily on the first ``generate`` call, so that CI/Docker (where
    CUDA torch and/or the weights are not available) does not break.
    """

    def __init__(
        self,
        base_model_name_or_path: str,
        adapter_name_or_path: str | None,
    ) -> None:
        """Store the model locations without loading anything.

        Args:
            base_model_name_or_path: Identifier handed to ``from_pretrained``
                for both the tokenizer and the causal LM.
            adapter_name_or_path: Optional adapter location; when falsy the
                base model is used without an adapter.
        """
        self.base_model_name_or_path = base_model_name_or_path
        self.adapter_name_or_path = adapter_name_or_path
        # Filled in by _lazy_init() on first use.
        self._tokenizer = None
        self._model = None
        self._device = None

    def _lazy_init(self) -> None:
        """Load tokenizer, model and optional adapter once; later calls are no-ops."""
        if self._model is not None:
            return

        # Imported here rather than at module level, in line with the
        # lazy-loading goal stated in the class docstring.
        import torch
        from transformers import AutoModelForCausalLM, AutoTokenizer

        if torch.cuda.is_available():
            device, dtype, device_map = "cuda", torch.float16, "auto"
        else:
            device, dtype, device_map = "cpu", torch.float32, None

        tokenizer = AutoTokenizer.from_pretrained(
            self.base_model_name_or_path,
            trust_remote_code=True,
            use_fast=True,
        )
        model = AutoModelForCausalLM.from_pretrained(
            self.base_model_name_or_path,
            trust_remote_code=True,
            torch_dtype=dtype,
            low_cpu_mem_usage=True,
            device_map=device_map,
        )

        if self.adapter_name_or_path:
            # peft is only needed when an adapter is actually configured.
            from peft import PeftModel

            model = PeftModel.from_pretrained(model, self.adapter_name_or_path)
        model.eval()

        # Publish state only after every load step has succeeded.
        self._tokenizer, self._model, self._device = tokenizer, model, device

    def _build_prompt(self, user_message: str, lang: str) -> str:
        """Wrap ``user_message`` with the triage system prompt.

        The French system prompt is used when ``lang == "fr"``, the English
        one otherwise. If a tokenizer is loaded and exposes
        ``apply_chat_template`` it renders the conversation; otherwise a
        plain ``SYSTEM/USER/ASSISTANT`` layout is returned.
        """
        # Minimal prompt: keep behavior stable; no strict JSON parsing is attempted here.
        if lang == "fr":
            system = (
                "Tu es un assistant de triage médical (POC éducatif). "
                "Tu dois être prudent, poser des questions, et proposer des étapes suivantes. "
                "N'invente pas de diagnostic. En cas de signes graves, recommande les urgences."
            )
        else:
            system = (
                "You are a medical triage assistant (educational POC). "
                "Be cautious, ask questions, and propose next steps. "
                "Do not invent a diagnosis. For red flags, recommend emergency care."
            )

        tokenizer = self._tokenizer
        if tokenizer is None or not hasattr(tokenizer, "apply_chat_template"):
            return f"SYSTEM: {system}\nUSER: {user_message}\nASSISTANT:"

        conversation = [
            {"role": "system", "content": system},
            {"role": "user", "content": user_message},
        ]
        return tokenizer.apply_chat_template(
            conversation, tokenize=False, add_generation_prompt=True
        )

    def generate(self, prompt: str, max_tokens: int = 256) -> str:
        """Sample a completion for ``prompt`` and return only the new text.

        Triggers the lazy load on first use.

        Args:
            prompt: Text passed to the tokenizer as is.
            max_tokens: Forwarded as ``max_new_tokens``.

        Returns:
            The decoded continuation (prompt tokens removed), stripped of
            surrounding whitespace.
        """
        self._lazy_init()
        import torch

        model, tokenizer = self._model, self._tokenizer
        assert model is not None
        assert tokenizer is not None

        device = self._device
        if not device:
            device = "cuda" if torch.cuda.is_available() else "cpu"

        inputs = tokenizer(prompt, return_tensors="pt")
        if device == "cuda":
            inputs = {key: tensor.to(device) for key, tensor in inputs.items()}
        prompt_len = int(inputs["input_ids"].shape[-1])

        with torch.inference_mode():
            generated = model.generate(
                **inputs,
                max_new_tokens=max_tokens,
                do_sample=True,
                temperature=0.2,
                top_p=0.9,
            )

        # Slice off the first prompt_len positions so only the continuation is decoded.
        continuation = generated[0][prompt_len:]
        return tokenizer.decode(continuation, skip_special_tokens=True).strip()

    def info(self) -> ModelBackendInfo:
        """Describe the backend without forcing the model to load.

        If the model has not been loaded yet, the device is a best-effort
        guess from torch; it is reported as ``None`` when that probe fails.
        """
        resolved = self._device
        if resolved is None:
            try:
                import torch

                resolved = "cuda" if torch.cuda.is_available() else "cpu"
            except Exception:
                resolved = None

        details = {
            "base_model": self.base_model_name_or_path,
            "adapter": self.adapter_name_or_path,
            "device": resolved,
        }
        return ModelBackendInfo(name="transformers-peft", details=details)
class VllmOpenAIBackend:
    """Backend that calls a vLLM OpenAI-compatible server.

    Intended for cloud deployment where vLLM runs on Linux/GPU and this API
    simply forwards generation requests.

    This class takes its settings as constructor arguments and reads no
    environment variables itself. ``make_backend_from_env`` fills them from:
    - VLLM_BASE_URL (default: http://127.0.0.1:8000)
    - VLLM_MODEL (default: BASE_MODEL_NAME_OR_PATH or Qwen/Qwen3-1.7B-Base)
    - VLLM_API_KEY (optional, default: empty)
    """

    def __init__(self, base_url: str, model: str, api_key: str | None = None) -> None:
        """Store the connection settings.

        Args:
            base_url: Server root URL including the scheme; trailing slashes
                are removed.
            model: Model name sent in each request payload.
            api_key: Optional bearer token; ``None`` is stored as ``""``.
        """
        self.base_url = base_url.rstrip("/")
        self.model = model
        self.api_key = api_key or ""

    def _system_prompt(self) -> str:
        """Return the bilingual (French + English) triage system prompt."""
        # Keep it bilingual because the API contract currently calls `generate()`
        # without passing `lang`.
        return (
            "FR: Tu es un assistant de triage médical (POC éducatif). "
            "Sois prudent, pose des questions, propose des étapes suivantes. "
            "N'invente pas de diagnostic. En cas de signes graves, recommande les urgences.\n\n"
            "EN: You are a medical triage assistant (educational POC). "
            "Be cautious, ask questions, propose next steps. "
            "Do not invent a diagnosis. For red flags, recommend emergency care."
        )

    def generate(self, prompt: str, max_tokens: int = 256) -> str:
        """POST a chat-completion request and return the assistant's reply.

        Args:
            prompt: Sent as the single user message.
            max_tokens: Forwarded as ``max_tokens`` in the payload.

        Returns:
            The first choice's message content, stripped; ``""`` when the
            server returns empty or null content.

        Raises:
            RuntimeError: If the request cannot be built or sent, the server
                answers with an HTTP error, or the response body does not
                have the expected shape.
        """
        url = f"{self.base_url}/v1/chat/completions"
        payload = {
            "model": self.model,
            "messages": [
                {"role": "system", "content": self._system_prompt()},
                {"role": "user", "content": prompt},
            ],
            "max_tokens": max_tokens,
            "temperature": 0.2,
            "top_p": 0.9,
        }
        data = json.dumps(payload).encode("utf-8")
        headers = {"Content-Type": "application/json"}
        if self.api_key:
            headers["Authorization"] = f"Bearer {self.api_key}"

        try:
            # Request() raises ValueError for a URL without a scheme
            # (e.g. base_url="localhost"). Building it inside the try makes a
            # misconfigured URL surface as the same RuntimeError as any other
            # transport failure.
            req = urllib.request.Request(url, data=data, headers=headers, method="POST")
            with urllib.request.urlopen(req, timeout=120) as resp:
                body = resp.read().decode("utf-8", errors="replace")
        except urllib.error.HTTPError as e:
            # Include a bounded excerpt of the error body to aid debugging.
            err = e.read().decode("utf-8", errors="replace")
            raise RuntimeError(f"vLLM request failed: HTTP {e.code} - {err[:300]}") from e
        except Exception as e:
            raise RuntimeError(f"vLLM request failed: {type(e).__name__}: {e}") from e

        try:
            obj = json.loads(body)
            return (obj["choices"][0]["message"]["content"] or "").strip()
        except (ValueError, KeyError, IndexError, TypeError, AttributeError) as e:
            # Covers invalid JSON (JSONDecodeError is a ValueError) and any
            # missing or mistyped field along the lookup path.
            raise RuntimeError(f"Unexpected vLLM response: {body[:300]}") from e

    def info(self) -> ModelBackendInfo:
        """Describe this backend; the API key is deliberately not included."""
        return ModelBackendInfo(
            name="vllm-openai",
            details={
                "base_url": self.base_url,
                "model": self.model,
            },
        )
def _default_adapter_path() -> str | None:
    """Return the first known adapter checkpoint path that exists, else ``None``.

    The candidates are relative paths, so they are checked against the
    current working directory. They are listed in order of preference, with
    the latest known long-run adapter first.
    """
    known_adapters = (
        "checkpoints/qwen3-1.7b-dpo_LONG_20260318_1657",
        "checkpoints/qwen3-1.7b-dpo_from_sft_lowvram",
    )
    # The generator is lazy, so probing stops at the first existing path.
    return next((path for path in known_adapters if os.path.exists(path)), None)
def make_backend_from_env() -> SimpleBackend | TransformersPeftBackend | VllmOpenAIBackend:
    """Build the backend selected by the ``TRIAGE_BACKEND`` environment variable.

    The value is stripped and lower-cased before matching:
    - "transformers", "peft", "transformers-peft": ``TransformersPeftBackend``,
      configured from ``BASE_MODEL_NAME_OR_PATH`` and ``ADAPTER_NAME_OR_PATH``
      (falling back to ``_default_adapter_path()`` for the adapter).
    - "vllm", "vllm-openai", "openai": ``VllmOpenAIBackend``, configured from
      ``VLLM_BASE_URL``, ``VLLM_MODEL`` (falling back to
      ``BASE_MODEL_NAME_OR_PATH``) and ``VLLM_API_KEY``.
    - "stub", "simple", "noop", unset, or any unrecognised value:
      ``SimpleBackend``.
    """
    fallback_model = "Qwen/Qwen3-1.7B-Base"
    choice = os.getenv("TRIAGE_BACKEND", "stub").strip().lower()

    if choice in ("transformers", "peft", "transformers-peft"):
        return TransformersPeftBackend(
            base_model_name_or_path=os.getenv("BASE_MODEL_NAME_OR_PATH", fallback_model),
            adapter_name_or_path=os.getenv("ADAPTER_NAME_OR_PATH") or _default_adapter_path(),
        )

    if choice in ("vllm", "vllm-openai", "openai"):
        return VllmOpenAIBackend(
            base_url=os.getenv("VLLM_BASE_URL", "http://127.0.0.1:8000"),
            model=os.getenv("VLLM_MODEL") or os.getenv("BASE_MODEL_NAME_OR_PATH", fallback_model),
            api_key=os.getenv("VLLM_API_KEY"),
        )

    # Explicit stub names and unknown values both land here: the stub is the
    # safe fallback.
    return SimpleBackend()