|
|
|
|
|
import torch |
|
|
import joblib |
|
|
from functools import lru_cache |
|
|
from transformers import AutoTokenizer, AutoModelForSequenceClassification |
|
|
from huggingface_hub import hf_hub_download |
|
|
|
|
|
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") |
|
|
|
|
|
DEFAULT_MODEL = "maxcasado/BERT_overflow" |
|
|
|
|
|
|
|
|
@lru_cache(maxsize=4)
def load_model_and_tokenizer(model_name: str):
    """
    Load the tokenizer, model, class labels and problem type
    (multi-label or single-label) for the given Hugging Face repo.

    The result is cached so nothing is reloaded on every request.

    Returns:
        (tokenizer, model, classes, multi_label) where `classes` is the
        ordered list of label names and `multi_label` is a bool.
    """
    print(f"Loading model: {model_name}")

    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForSequenceClassification.from_pretrained(model_name)
    model.to(device)
    model.eval()

    # Best effort: a fitted MultiLabelBinarizer shipped with the repo is the
    # most reliable source of class names. Any failure (file absent, download
    # error, unpickling issue) falls through to the config-based fallbacks.
    try:
        binarizer = joblib.load(hf_hub_download(model_name, "mlb.joblib"))
        classes = list(binarizer.classes_)
        print(f"Loaded mlb.joblib for {model_name}")
    except Exception:
        # Fall back to the model config's label map, then to generic names.
        id2label = getattr(model.config, "id2label", None)
        if id2label:
            classes = [id2label[idx] for idx in range(len(id2label))]
            print(f"Using config.id2label for {model_name}")
        else:
            classes = [f"label_{idx}" for idx in range(model.config.num_labels)]
            print(f"Using generic label_i for {model_name}")

    # An explicit problem_type in the config wins; when the config is silent,
    # heuristically treat >2 labels as multi-label (tag-style prediction).
    problem_type = getattr(model.config, "problem_type", None)
    multi_label = (
        problem_type == "multi_label_classification"
        or (problem_type is None and model.config.num_labels > 2)
    )

    return tokenizer, model, classes, multi_label
|
|
|
|
|
|
|
|
def _to_device(batch):
    """Move every tensor in a tokenizer batch onto the module-level device."""
    moved = {}
    for key, tensor in batch.items():
        moved[key] = tensor.to(device)
    return moved
|
|
|
|
|
|
|
|
def predict_proba(text: str, top_k: int = 10, model_name: str | None = None):
    """
    Return the top_k tags and their probabilities for *text*,
    using the requested model (or the default one when none is given).

    Returns:
        A list of {"label": str, "score": float} dicts, sorted by
        descending probability.
    """
    resolved_name = DEFAULT_MODEL if model_name is None else model_name
    tokenizer, model, classes, multi_label = load_model_and_tokenizer(resolved_name)

    encoded = tokenizer(
        text,
        return_tensors="pt",
        truncation=True,
        padding=True,
        max_length=256,
    )

    with torch.no_grad():
        logits = model(**_to_device(encoded)).logits[0]

    # Sigmoid for independent tags, softmax for mutually exclusive classes.
    scores = (
        torch.sigmoid(logits) if multi_label else torch.softmax(logits, dim=-1)
    ).cpu().numpy()

    # Indices of the highest-probability labels, best first.
    ranked = scores.argsort()[::-1][:top_k]

    return [
        {"label": classes[int(idx)], "score": float(scores[idx])}
        for idx in ranked
    ]
|
|
|