# Ibou17's picture
# Initial deployment — Radiology Agent
# 3a08315
"""
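Hugging Face model loading and text generation helpers.
- GPU available → 4-bit quantization (BitsAndBytes) by default, float16 when
  quantization is disabled
- CPU only → float32 without quantization; a small model (TinyLlama or a
  lightweight Mistral) is recommended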
"""
from __future__ import annotations
import logging
import os
from typing import Optional, Tuple
import torch
from transformers import (
AutoModelForCausalLM,
AutoTokenizer,
)
from config.config import CFG
# Module-level logger, named after this module.
LOGGER = logging.getLogger(__name__)

# Short aliases accepted by load_model, mapped to the model identifier
# that is handed to `from_pretrained`.
MODEL_REGISTRY = {
    "biomistral": "BioMistral/BioMistral-7B",
    "mistral": "mistralai/Mistral-7B-Instruct-v0.2",
    "tiny": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
}

# In-memory cache of already loaded (model, tokenizer) pairs, filled by
# load_model so repeated requests do not reload the weights.
_model_cache: dict = {}
def _get_bnb_config():
    """Build the 4-bit BitsAndBytes quantization config for GPU (T4 / A100)."""
    # Function-scope import: BitsAndBytesConfig is only pulled in when this
    # helper is actually called.
    from transformers import BitsAndBytesConfig

    quant_options = {
        "load_in_4bit": True,
        "bnb_4bit_compute_dtype": torch.float16,
        "bnb_4bit_use_double_quant": True,
        "bnb_4bit_quant_type": "nf4",
    }
    return BitsAndBytesConfig(**quant_options)
def load_model(
    model_key: str = "biomistral",
    quantize: bool = True,
    cache_dir: Optional[str] = None,
) -> Tuple[AutoModelForCausalLM, AutoTokenizer]:
    """
    Load a Hugging Face causal LM and its tokenizer, reusing cached instances.

    The loading mode is chosen from the hardware and the ``quantize`` flag:
    4-bit quantization on GPU when ``quantize`` is true, float16 on GPU
    otherwise, and float32 on CPU (where ``quantize`` is ignored).

    Args:
        model_key: Alias from ``MODEL_REGISTRY``; any other value is used
            as-is as the model identifier.
        quantize: Request 4-bit quantization. Only honoured when CUDA is
            available.
        cache_dir: Directory for downloaded files. Defaults to
            ``CFG.models_cache``.

    Returns:
        The ``(model, tokenizer)`` pair, with the model switched to eval mode.
    """
    model_id = MODEL_REGISTRY.get(model_key, model_key)
    has_gpu = torch.cuda.is_available()
    use_4bit = has_gpu and quantize

    # Key the cache on what actually gets loaded: the resolved model id and
    # the effective quantization mode. Keying on `model_key` alone would hand
    # back a 4-bit model to a caller asking for quantize=False (and vice
    # versa), and would load the same weights twice when a model is requested
    # once by alias and once by full identifier.
    cache_key = (model_id, use_4bit)
    cached = _model_cache.get(cache_key)
    if cached is not None:
        LOGGER.info("Modèle '%s' servi depuis cache mémoire.", model_key)
        return cached

    if cache_dir is None:
        cache_dir = str(CFG.models_cache)

    LOGGER.info("Chargement modèle: %s | GPU=%s | quantize=%s", model_id, has_gpu, quantize)

    # Tokenizer
    tokenizer = AutoTokenizer.from_pretrained(
        model_id,
        cache_dir=cache_dir,
        use_fast=True,
    )
    if tokenizer.pad_token is None:
        # No pad token defined: fall back to EOS.
        tokenizer.pad_token = tokenizer.eos_token

    # Model: arguments shared by every loading mode.
    # NOTE(review): trust_remote_code=True allows the model repository to run
    # its own Python code at load time. Kept for behaviour parity — confirm it
    # is really required for the models in use.
    load_kwargs = {
        "cache_dir": cache_dir,
        "trust_remote_code": True,
    }
    if use_4bit:
        load_kwargs.update(
            quantization_config=_get_bnb_config(),
            device_map="auto",
        )
    elif has_gpu:
        load_kwargs.update(
            torch_dtype=torch.float16,
            device_map="auto",
        )
    else:
        # CPU: float32, no quantization.
        LOGGER.warning("Pas de GPU — chargement CPU (lent, recommandé : model_key='tiny')")
        load_kwargs.update(
            torch_dtype=torch.float32,
            device_map="cpu",
            low_cpu_mem_usage=True,
        )

    model = AutoModelForCausalLM.from_pretrained(model_id, **load_kwargs)
    model.eval()

    loaded = (model, tokenizer)
    _model_cache[cache_key] = loaded
    LOGGER.info("✅ Modèle '%s' chargé.", model_key)
    return loaded
def generate(
    model,
    tokenizer,
    prompt: str,
    max_new_tokens: int = 200,
    temperature: float = 0.1,
) -> str:
    """
    Generate a text completion for ``prompt``.

    Args:
        model: Object exposing ``device`` and ``generate(...)``.
        tokenizer: Callable tokenizer exposing ``eos_token_id`` and ``decode``.
        prompt: Input text; tokenized with truncation at 2048 tokens.
        max_new_tokens: Upper bound on the number of generated tokens.
        temperature: Sampling temperature. A value <= 0 disables sampling.

    Returns:
        The decoded newly generated tokens only (prompt excluded), stripped
        of surrounding whitespace.
    """
    sampling = temperature > 0

    encoded = tokenizer(
        prompt,
        return_tensors="pt",
        truncation=True,
        max_length=2048,
    )
    # Move every input tensor to the device reported by the model.
    batch = {name: tensor.to(model.device) for name, tensor in encoded.items()}
    prompt_length = batch["input_ids"].shape[-1]

    with torch.no_grad():
        generated = model.generate(
            **batch,
            max_new_tokens=max_new_tokens,
            # When sampling is off, pass the neutral value 1.0 instead of the
            # non-positive temperature supplied by the caller.
            temperature=temperature if sampling else 1.0,
            do_sample=sampling,
            pad_token_id=tokenizer.eos_token_id,
        )

    # The output starts with the prompt tokens; keep only what follows them.
    completion_ids = generated[0][prompt_length:]
    decoded = tokenizer.decode(completion_ids, skip_special_tokens=True)
    return decoded.strip()