File size: 3,011 Bytes
cda0729
 
730bdeb
bb37dca
cda0729
ab80a46
cda0729
 
 
bb37dca
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cda0729
bb37dca
 
 
 
 
 
 
 
 
 
cda0729
 
 
 
 
 
bb37dca
cda0729
bb37dca
 
cda0729
bb37dca
 
 
 
 
cda0729
 
 
 
 
 
 
 
 
 
bb37dca
 
 
 
 
 
cda0729
 
 
 
ab80a46
cda0729
730bdeb
cda0729
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
# model_utils.py
import torch
import joblib
from functools import lru_cache
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from huggingface_hub import hf_hub_download

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

DEFAULT_MODEL = "maxcasado/BERT_overflow"


@lru_cache(maxsize=4)
def load_model_and_tokenizer(model_name: str):
    """
    Charge tokenizer, modèle, classes et type de problème (multi-label ou non)
    pour un modèle donné (HF repo).
    Résultat mis en cache pour éviter de recharger à chaque requête.
    """
    print(f"Loading model: {model_name}")
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForSequenceClassification.from_pretrained(model_name)
    model.to(device)
    model.eval()

    # Tenter de charger mlb.joblib si présent dans le repo
    classes = None
    try:
        mlb_path = hf_hub_download(model_name, "mlb.joblib")
        mlb = joblib.load(mlb_path)
        classes = list(mlb.classes_)
        print(f"Loaded mlb.joblib for {model_name}")
    except Exception:
        # Fallback sur id2label si pas de mlb.joblib
        id2label = getattr(model.config, "id2label", None)
        if id2label:
            classes = [id2label[i] for i in range(len(id2label))]
            print(f"Using config.id2label for {model_name}")
        else:
            classes = [f"label_{i}" for i in range(model.config.num_labels)]
            print(f"Using generic label_i for {model_name}")

    # Déterminer si c'est du multi-label
    problem_type = getattr(model.config, "problem_type", None)
    multi_label = problem_type == "multi_label_classification"

    # Par sécurité, si problem_type n'est pas précisé mais que le modèle a beaucoup de labels,
    # on peut supposer du multi-label (heuristique).
    if problem_type is None and model.config.num_labels > 2:
        multi_label = True

    return tokenizer, model, classes, multi_label


def _to_device(batch):
    return {k: v.to(device) for k, v in batch.items()}


def predict_proba(text: str, top_k: int = 10, model_name: str | None = None):
    """
    Prend un texte et renvoie les top_k tags + probabilités
    pour le modèle spécifié (ou le modèle par défaut).
    """
    if model_name is None:
        model_name = DEFAULT_MODEL

    tokenizer, model, classes, multi_label = load_model_and_tokenizer(model_name)

    enc = tokenizer(
        text,
        return_tensors="pt",
        truncation=True,
        padding=True,
        max_length=256,
    )

    with torch.no_grad():
        outputs = model(**_to_device(enc))
        logits = outputs.logits[0]

        if multi_label:
            probs = torch.sigmoid(logits)
        else:
            probs = torch.softmax(logits, dim=-1)

    probs = probs.cpu().numpy()
    indices = probs.argsort()[::-1][:top_k]

    return [
        {
            "label": classes[int(i)],
            "score": float(probs[i]),
        }
        for i in indices
    ]