| from __future__ import annotations |
|
|
| import os |
| import json |
| import urllib.request |
| import urllib.error |
| from dataclasses import dataclass |
| from typing import Any |
|
|
|
|
@dataclass
class ModelBackendInfo:
    """Descriptive metadata that a generation backend reports about itself."""

    # Short identifier of the backend implementation (e.g. "stub").
    name: str
    # Backend-specific settings, keyed by setting name.
    details: dict[str, Any]
|
|
|
|
class SimpleBackend:
    """Minimal backend for the proof of concept.

    - In production: replace it with vLLM (in-process) or an
      OpenAI-compatible vLLM server.
    - Here: a "stub" backend that lets the API be tested without a GPU.
    """

    def __init__(self) -> None:
        # The stub never loads a model; the value is only echoed by ``info()``.
        self.model_name_or_path = os.getenv("MODEL_NAME_OR_PATH", "")

    def generate(self, prompt: str, max_tokens: int = 256) -> str:
        """Return a canned French disclaimer plus the first 200 chars of *prompt*.

        ``max_tokens`` is accepted so the signature matches the other
        backends; the stub does not use it.
        """
        disclaimer = (
            "Je suis un POC de triage. Pour mieux évaluer, je dois poser quelques questions et "
            "je ne remplace pas un avis médical.\n\n"
        )
        excerpt = prompt[:200]
        return f"{disclaimer}Résumé du message: {excerpt}"

    def info(self) -> ModelBackendInfo:
        """Describe this backend as ``"stub"`` together with the configured model path."""
        details = {"model_name_or_path": self.model_name_or_path}
        return ModelBackendInfo(name="stub", details=details)
|
|
|
|
class TransformersPeftBackend:
    """Transformers + PEFT (LoRA) backend.

    Meant to be enabled through an environment variable and loaded *lazily*
    on the first call, so that CI/Docker environments (where CUDA torch
    and/or the weights are not available) do not break.
    """

    def __init__(
        self,
        base_model_name_or_path: str,
        adapter_name_or_path: str | None,
    ) -> None:
        """Record which model/adapter to load; nothing is loaded here.

        Args:
            base_model_name_or_path: Identifier or path passed to
                ``from_pretrained`` for both tokenizer and model.
            adapter_name_or_path: Optional PEFT adapter applied on top of the
                base model; falsy values skip the adapter.
        """
        self.base_model_name_or_path = base_model_name_or_path
        self.adapter_name_or_path = adapter_name_or_path

        # Populated by ``_lazy_init`` on first use.
        self._tokenizer = None
        self._model = None
        self._device = None

    def _lazy_init(self) -> None:
        """Load tokenizer, model and optional adapter once; no-op afterwards."""
        if self._model is not None:
            return

        # Heavy imports stay local so importing this module never needs torch.
        import torch
        from transformers import AutoModelForCausalLM, AutoTokenizer

        device = "cuda" if torch.cuda.is_available() else "cpu"
        dtype = torch.float16 if device == "cuda" else torch.float32

        # NOTE(security): trust_remote_code=True lets the model repository run
        # its own Python code while loading; only use trusted repositories.
        tokenizer = AutoTokenizer.from_pretrained(
            self.base_model_name_or_path,
            trust_remote_code=True,
            use_fast=True,
        )

        model = AutoModelForCausalLM.from_pretrained(
            self.base_model_name_or_path,
            trust_remote_code=True,
            torch_dtype=dtype,
            low_cpu_mem_usage=True,
            device_map="auto" if device == "cuda" else None,
        )

        if self.adapter_name_or_path:
            from peft import PeftModel

            model = PeftModel.from_pretrained(model, self.adapter_name_or_path)

        model.eval()
        self._tokenizer = tokenizer
        self._model = model
        self._device = device

    def _build_prompt(self, user_message: str, lang: str) -> str:
        """Wrap *user_message* with the triage system prompt.

        The French system prompt is used when ``lang == "fr"``, the English
        one otherwise. If a tokenizer is already loaded and can render a chat
        template, that template is used; otherwise (tokenizer not loaded yet,
        or no usable chat template) a plain ``SYSTEM/USER/ASSISTANT`` layout
        is returned.
        """
        system_fr = (
            "Tu es un assistant de triage médical (POC éducatif). "
            "Tu dois être prudent, poser des questions, et proposer des étapes suivantes. "
            "N'invente pas de diagnostic. En cas de signes graves, recommande les urgences."
        )
        system_en = (
            "You are a medical triage assistant (educational POC). "
            "Be cautious, ask questions, and propose next steps. "
            "Do not invent a diagnosis. For red flags, recommend emergency care."
        )
        system = system_fr if lang == "fr" else system_en

        tok = self._tokenizer
        if tok is not None and hasattr(tok, "apply_chat_template"):
            messages = [
                {"role": "system", "content": system},
                {"role": "user", "content": user_message},
            ]
            try:
                return tok.apply_chat_template(
                    messages, tokenize=False, add_generation_prompt=True
                )
            except ValueError:
                # Tokenizers without a configured chat template raise
                # ValueError; use the plain layout instead of failing.
                pass

        return f"SYSTEM: {system}\nUSER: {user_message}\nASSISTANT:"

    def generate(self, prompt: str, max_tokens: int = 256) -> str:
        """Generate a completion for *prompt* and return only the new text.

        Args:
            prompt: Full prompt text fed to the tokenizer as-is.
            max_tokens: Upper bound on newly generated tokens.

        Raises:
            RuntimeError: If the model or tokenizer is missing after
                initialisation.
        """
        self._lazy_init()

        import torch

        model, tokenizer = self._model, self._tokenizer
        # Explicit check instead of ``assert`` so it survives ``python -O``.
        if model is None or tokenizer is None:
            raise RuntimeError(
                "TransformersPeftBackend: model/tokenizer not initialised"
            )

        device = self._device or ("cuda" if torch.cuda.is_available() else "cpu")
        encoded = tokenizer(prompt, return_tensors="pt")
        if device == "cuda":
            encoded = {k: v.to(device) for k, v in encoded.items()}

        input_len = int(encoded["input_ids"].shape[-1])
        with torch.inference_mode():
            out = model.generate(
                **encoded,
                max_new_tokens=max_tokens,
                do_sample=True,
                temperature=0.2,
                top_p=0.9,
            )

        # Drop the echoed prompt tokens; keep only what was generated.
        new_tokens = out[0][input_len:]
        text = tokenizer.decode(new_tokens, skip_special_tokens=True)
        return text.strip()

    def info(self) -> ModelBackendInfo:
        """Report backend name, base model, adapter and device.

        Before the model is loaded the device is probed through torch; if
        that probe fails for any reason the device is reported as ``None``.
        """
        device = self._device
        if device is None:
            try:
                import torch

                device = "cuda" if torch.cuda.is_available() else "cpu"
            except Exception:
                # Best effort: info() must work even without a usable torch.
                device = None
        return ModelBackendInfo(
            name="transformers-peft",
            details={
                "base_model": self.base_model_name_or_path,
                "adapter": self.adapter_name_or_path,
                "device": device,
            },
        )
|
|
|
|
class VllmOpenAIBackend:
    """Backend that calls a vLLM OpenAI-compatible server.

    Intended for cloud deployment where vLLM runs on Linux/GPU and this API
    simply forwards generation requests.

    Required env vars:
    - VLLM_BASE_URL (default: http://127.0.0.1:8000)
    - VLLM_MODEL (default: BASE_MODEL_NAME_OR_PATH or Qwen/Qwen3-1.7B-Base)
    Optional:
    - VLLM_API_KEY (default: empty)
    """

    def __init__(
        self,
        base_url: str,
        model: str,
        api_key: str | None = None,
        timeout: float = 120.0,
    ) -> None:
        """Store connection settings; no request is made here.

        Args:
            base_url: Server root, with or without a trailing ``/`` or
                ``/v1`` suffix.
            model: Model name sent in each request payload.
            api_key: Optional bearer token; ``None`` or empty disables the
                ``Authorization`` header.
            timeout: Per-request timeout in seconds.
        """
        self.base_url = base_url.rstrip("/")
        self.model = model
        self.api_key = api_key or ""
        self.timeout = timeout

    def _chat_completions_url(self) -> str:
        """Return the chat-completions endpoint URL for ``base_url``."""
        # OpenAI-style base URLs often already end in /v1; avoid /v1/v1.
        if self.base_url.endswith("/v1"):
            return f"{self.base_url}/chat/completions"
        return f"{self.base_url}/v1/chat/completions"

    def _system_prompt(self) -> str:
        """Return the bilingual (French + English) triage system prompt."""
        return (
            "FR: Tu es un assistant de triage médical (POC éducatif). "
            "Sois prudent, pose des questions, propose des étapes suivantes. "
            "N'invente pas de diagnostic. En cas de signes graves, recommande les urgences.\n\n"
            "EN: You are a medical triage assistant (educational POC). "
            "Be cautious, ask questions, propose next steps. "
            "Do not invent a diagnosis. For red flags, recommend emergency care."
        )

    def generate(self, prompt: str, max_tokens: int = 256) -> str:
        """POST *prompt* to the chat-completions endpoint and return the reply.

        Args:
            prompt: User message; sent after the system prompt.
            max_tokens: Value of the ``max_tokens`` request field.

        Returns:
            The stripped content of the first choice ("" if it is empty/null).

        Raises:
            RuntimeError: On a malformed URL, transport failure, HTTP error
                status, or a response body that does not have the expected
                shape.
        """
        url = self._chat_completions_url()
        payload = {
            "model": self.model,
            "messages": [
                {"role": "system", "content": self._system_prompt()},
                {"role": "user", "content": prompt},
            ],
            "max_tokens": max_tokens,
            "temperature": 0.2,
            "top_p": 0.9,
        }

        data = json.dumps(payload).encode("utf-8")
        headers = {"Content-Type": "application/json"}
        if self.api_key:
            headers["Authorization"] = f"Bearer {self.api_key}"

        try:
            # Request() itself raises ValueError on a malformed URL, so it is
            # built inside the try to surface as RuntimeError like the rest.
            req = urllib.request.Request(url, data=data, headers=headers, method="POST")
            with urllib.request.urlopen(req, timeout=self.timeout) as resp:
                body = resp.read().decode("utf-8", errors="replace")
        except urllib.error.HTTPError as e:
            try:
                err = e.read().decode("utf-8", errors="replace")
            finally:
                # HTTPError wraps the response; close it to free the socket.
                e.close()
            raise RuntimeError(f"vLLM request failed: HTTP {e.code} - {err[:300]}") from e
        except Exception as e:
            # Boundary: every transport-level failure becomes RuntimeError.
            raise RuntimeError(f"vLLM request failed: {type(e).__name__}: {e}") from e

        try:
            obj = json.loads(body)
            return (obj["choices"][0]["message"]["content"] or "").strip()
        except (ValueError, LookupError, TypeError, AttributeError) as e:
            # Invalid JSON, missing keys/indices, or non-string content.
            raise RuntimeError(f"Unexpected vLLM response: {body[:300]}") from e

    def info(self) -> ModelBackendInfo:
        """Report backend name, configured base URL and model."""
        return ModelBackendInfo(
            name="vllm-openai",
            details={
                "base_url": self.base_url,
                "model": self.model,
            },
        )
|
|
|
|
| def _default_adapter_path() -> str | None: |
| |
| candidates = [ |
| "checkpoints/qwen3-1.7b-dpo_LONG_20260318_1657", |
| "checkpoints/qwen3-1.7b-dpo_from_sft_lowvram", |
| ] |
| for cand in candidates: |
| if os.path.exists(cand): |
| return cand |
| return None |
|
|
|
|
def make_backend_from_env() -> SimpleBackend | TransformersPeftBackend | VllmOpenAIBackend:
    """Build the generation backend selected by ``TRIAGE_BACKEND``.

    Recognised values (case-insensitive, surrounding whitespace ignored):
    - ``stub`` / ``simple`` / ``noop`` (default): ``SimpleBackend``.
    - ``transformers`` / ``peft`` / ``transformers-peft``:
      ``TransformersPeftBackend`` using ``BASE_MODEL_NAME_OR_PATH`` and
      ``ADAPTER_NAME_OR_PATH`` (or the first existing default checkpoint).
    - ``vllm`` / ``vllm-openai`` / ``openai``: ``VllmOpenAIBackend`` using
      ``VLLM_BASE_URL``, ``VLLM_MODEL`` and ``VLLM_API_KEY``.

    Environment variables that are set but empty are treated as unset, so
    their defaults apply. Any other ``TRIAGE_BACKEND`` value falls back to
    the stub backend and logs a warning.
    """
    import logging

    default_model = "Qwen/Qwen3-1.7B-Base"

    backend = (os.getenv("TRIAGE_BACKEND") or "").strip().lower() or "stub"
    if backend in {"stub", "simple", "noop"}:
        return SimpleBackend()

    # ``or`` (not a getenv default) so an empty value also gets the default.
    base_model = os.getenv("BASE_MODEL_NAME_OR_PATH") or default_model

    if backend in {"transformers", "peft", "transformers-peft"}:
        adapter = os.getenv("ADAPTER_NAME_OR_PATH") or _default_adapter_path()
        return TransformersPeftBackend(
            base_model_name_or_path=base_model,
            adapter_name_or_path=adapter,
        )

    if backend in {"vllm", "vllm-openai", "openai"}:
        base_url = os.getenv("VLLM_BASE_URL") or "http://127.0.0.1:8000"
        model = os.getenv("VLLM_MODEL") or base_model
        api_key = os.getenv("VLLM_API_KEY")
        return VllmOpenAIBackend(base_url=base_url, model=model, api_key=api_key)

    # Deliberate best-effort fallback, but make the misconfiguration visible.
    logging.getLogger(__name__).warning(
        "Unknown TRIAGE_BACKEND %r; falling back to the stub backend", backend
    )
    return SimpleBackend()
|
|