# model_utils.py import torch import joblib from functools import lru_cache from transformers import AutoTokenizer, AutoModelForSequenceClassification from huggingface_hub import hf_hub_download device = torch.device("cuda" if torch.cuda.is_available() else "cpu") DEFAULT_MODEL = "maxcasado/BERT_overflow" @lru_cache(maxsize=4) def load_model_and_tokenizer(model_name: str): """ Charge tokenizer, modèle, classes et type de problème (multi-label ou non) pour un modèle donné (HF repo). Résultat mis en cache pour éviter de recharger à chaque requête. """ print(f"Loading model: {model_name}") tokenizer = AutoTokenizer.from_pretrained(model_name) model = AutoModelForSequenceClassification.from_pretrained(model_name) model.to(device) model.eval() # Tenter de charger mlb.joblib si présent dans le repo classes = None try: mlb_path = hf_hub_download(model_name, "mlb.joblib") mlb = joblib.load(mlb_path) classes = list(mlb.classes_) print(f"Loaded mlb.joblib for {model_name}") except Exception: # Fallback sur id2label si pas de mlb.joblib id2label = getattr(model.config, "id2label", None) if id2label: classes = [id2label[i] for i in range(len(id2label))] print(f"Using config.id2label for {model_name}") else: classes = [f"label_{i}" for i in range(model.config.num_labels)] print(f"Using generic label_i for {model_name}") # Déterminer si c'est du multi-label problem_type = getattr(model.config, "problem_type", None) multi_label = problem_type == "multi_label_classification" # Par sécurité, si problem_type n'est pas précisé mais que le modèle a beaucoup de labels, # on peut supposer du multi-label (heuristique). if problem_type is None and model.config.num_labels > 2: multi_label = True return tokenizer, model, classes, multi_label def _to_device(batch): return {k: v.to(device) for k, v in batch.items()} def predict_proba(text: str, top_k: int = 10, model_name: str | None = None): """ Prend un texte et renvoie les top_k tags + probabilités pour le modèle spécifié (ou le modèle par défaut). """ if model_name is None: model_name = DEFAULT_MODEL tokenizer, model, classes, multi_label = load_model_and_tokenizer(model_name) enc = tokenizer( text, return_tensors="pt", truncation=True, padding=True, max_length=256, ) with torch.no_grad(): outputs = model(**_to_device(enc)) logits = outputs.logits[0] if multi_label: probs = torch.sigmoid(logits) else: probs = torch.softmax(logits, dim=-1) probs = probs.cpu().numpy() indices = probs.argsort()[::-1][:top_k] return [ { "label": classes[int(i)], "score": float(probs[i]), } for i in indices ]