# NOTE(review): removed non-Python residue from a blame/file-viewer export
# (file-size banner, commit-hash column, line-number gutter) that made the
# module unimportable.
# model_utils.py
import torch
import joblib
from functools import lru_cache
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from huggingface_hub import hf_hub_download
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
DEFAULT_MODEL = "maxcasado/BERT_overflow"
@lru_cache(maxsize=4)
def load_model_and_tokenizer(model_name: str):
    """Load the tokenizer, model, class names and problem type for an HF repo.

    Returns a tuple ``(tokenizer, model, classes, multi_label)``.  The result
    is memoized (up to 4 models) so repeated requests do not reload weights.

    Args:
        model_name: HuggingFace Hub repo id of the classification model.
    """
    print(f"Loading model: {model_name}")
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForSequenceClassification.from_pretrained(model_name)
    model.to(device)
    model.eval()

    # Prefer label names from an mlb.joblib shipped alongside the model, if any.
    try:
        mlb = joblib.load(hf_hub_download(model_name, "mlb.joblib"))
        classes = list(mlb.classes_)
        print(f"Loaded mlb.joblib for {model_name}")
    except Exception:
        # No mlb.joblib in the repo (or it failed to load): fall back to the
        # model config's id2label mapping, then to generic placeholder names.
        id2label = getattr(model.config, "id2label", None)
        if id2label:
            classes = [id2label[i] for i in range(len(id2label))]
            print(f"Using config.id2label for {model_name}")
        else:
            classes = [f"label_{i}" for i in range(model.config.num_labels)]
            print(f"Using generic label_i for {model_name}")

    # Multi-label when the config says so; when problem_type is unspecified,
    # assume multi-label for models with more than two labels (heuristic).
    problem_type = getattr(model.config, "problem_type", None)
    multi_label = (
        problem_type == "multi_label_classification"
        or (problem_type is None and model.config.num_labels > 2)
    )
    return tokenizer, model, classes, multi_label
def _to_device(batch):
    """Move every tensor in *batch* onto the module-level ``device``."""
    moved = {}
    for key, tensor in batch.items():
        moved[key] = tensor.to(device)
    return moved
def predict_proba(text: str, top_k: int = 10, model_name: str | None = None):
    """Return the ``top_k`` tags with the highest probability for *text*.

    Args:
        text: Input text to classify.
        top_k: Number of top-scoring labels to return.
        model_name: HF repo id; falls back to ``DEFAULT_MODEL`` when ``None``.

    Returns:
        A list of ``{"label": str, "score": float}`` dicts, highest score first.
    """
    tokenizer, model, classes, multi_label = load_model_and_tokenizer(
        model_name if model_name is not None else DEFAULT_MODEL
    )

    enc = tokenizer(
        text,
        return_tensors="pt",
        truncation=True,
        padding=True,
        max_length=256,
    )
    with torch.no_grad():
        logits = model(**_to_device(enc)).logits[0]

    # Sigmoid for independent tags (multi-label), softmax for mutually
    # exclusive classes.
    if multi_label:
        scores = torch.sigmoid(logits)
    else:
        scores = torch.softmax(logits, dim=-1)
    scores = scores.cpu().numpy()

    ranked = scores.argsort()[::-1][:top_k]
    return [
        {"label": classes[int(idx)], "score": float(scores[idx])}
        for idx in ranked
    ]
# NOTE(review): removed stray trailing "|" left over from the viewer export.