# POC2PROD / model_utils.py
# (Hub page header preserved as a comment: uploaded by maxcasado,
#  commit bb37dca, verified — the raw lines were not valid Python.)
# model_utils.py
import torch
import joblib
from functools import lru_cache
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from huggingface_hub import hf_hub_download
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
DEFAULT_MODEL = "maxcasado/BERT_overflow"
@lru_cache(maxsize=4)
def load_model_and_tokenizer(model_name: str):
    """
    Load the tokenizer, model, class names and problem type (multi-label
    or not) for a given model (HF Hub repo id).

    The result is LRU-cached (up to 4 models) so repeated requests do not
    reload the weights.

    Parameters
    ----------
    model_name : str
        Hugging Face Hub repo id, e.g. "maxcasado/BERT_overflow".

    Returns
    -------
    tuple
        ``(tokenizer, model, classes, multi_label)`` where ``classes`` is a
        list of label names indexed by class id and ``multi_label`` is a bool.
    """
    print(f"Loading model: {model_name}")
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForSequenceClassification.from_pretrained(model_name)
    model.to(device)
    model.eval()

    # Try to load mlb.joblib (a fitted MultiLabelBinarizer) if the repo has one.
    classes = None
    try:
        mlb_path = hf_hub_download(model_name, "mlb.joblib")
        mlb = joblib.load(mlb_path)
        classes = list(mlb.classes_)
        print(f"Loaded mlb.joblib for {model_name}")
    except Exception:
        # Deliberate best-effort: any failure (file absent, download error,
        # unpickling error) falls back to config.id2label.
        id2label = getattr(model.config, "id2label", None)
        if id2label:
            # Configs deserialized from raw JSON may key id2label with string
            # indices ("0", "1", ...) instead of ints — accept both, and fall
            # back to a generic name rather than raising KeyError.
            classes = [
                id2label.get(i) or id2label.get(str(i)) or f"label_{i}"
                for i in range(len(id2label))
            ]
            print(f"Using config.id2label for {model_name}")
        else:
            classes = [f"label_{i}" for i in range(model.config.num_labels)]
            print(f"Using generic label_i for {model_name}")

    # Determine whether this is a multi-label problem.
    problem_type = getattr(model.config, "problem_type", None)
    multi_label = problem_type == "multi_label_classification"
    # Heuristic: if problem_type is unset but the model has many labels,
    # assume multi-label (safety net for repos with incomplete configs).
    if problem_type is None and model.config.num_labels > 2:
        multi_label = True

    return tokenizer, model, classes, multi_label
def _to_device(batch):
    """Return a copy of *batch* with every tensor moved to the module-level device."""
    moved = {}
    for key, tensor in batch.items():
        moved[key] = tensor.to(device)
    return moved
def predict_proba(text: str, top_k: int = 10, model_name: str | None = None):
    """
    Return the top_k tags with their probabilities for *text*,
    using the requested model (or the default one when None).

    Each result is a dict with "label" and "score" keys, ordered by
    descending probability.
    """
    resolved_name = DEFAULT_MODEL if model_name is None else model_name
    tokenizer, model, classes, multi_label = load_model_and_tokenizer(resolved_name)

    encoded = tokenizer(
        text,
        return_tensors="pt",
        truncation=True,
        padding=True,
        max_length=256,
    )
    with torch.no_grad():
        logits = model(**_to_device(encoded)).logits[0]

    # Multi-label: independent sigmoids; single-label: softmax over classes.
    scores = torch.sigmoid(logits) if multi_label else torch.softmax(logits, dim=-1)
    scores = scores.cpu().numpy()

    ranked = scores.argsort()[::-1][:top_k]
    return [
        {
            "label": classes[int(idx)],
            "score": float(scores[idx]),
        }
        for idx in ranked
    ]