maxcasado commited on
Commit
bb37dca
·
verified ·
1 Parent(s): fc989ef

Update model_utils.py

Browse files
Files changed (1) hide show
  1. model_utils.py +57 -17
model_utils.py CHANGED
@@ -1,35 +1,71 @@
1
  # model_utils.py
2
  import torch
3
  import joblib
 
4
  from transformers import AutoTokenizer, AutoModelForSequenceClassification
5
  from huggingface_hub import hf_hub_download
6
 
7
- MODEL_NAME = "maxcasado/BERT_overflow"
8
-
9
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
10
 
11
- print("Loading tokenizer and model...")
12
- tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
13
- model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME)
14
- model.to(device)
15
- model.eval()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
16
 
17
- print("Loading MultiLabelBinarizer (mlb.joblib)...")
18
- mlb_path = hf_hub_download(MODEL_NAME, "mlb.joblib")
19
- mlb = joblib.load(mlb_path)
20
- classes = list(mlb.classes_) # index -> tag name
 
 
 
 
 
 
21
 
22
 
23
  def _to_device(batch):
24
  return {k: v.to(device) for k, v in batch.items()}
25
 
26
 
27
- def predict_proba(text: str, top_k: int = 10):
28
  """
29
- Multi-label prediction:
30
- - entrée : texte de la question
31
- - sortie : top_k tags avec leurs probabilités (sigmoid)
32
  """
 
 
 
 
 
33
  enc = tokenizer(
34
  text,
35
  return_tensors="pt",
@@ -40,8 +76,12 @@ def predict_proba(text: str, top_k: int = 10):
40
 
41
  with torch.no_grad():
42
  outputs = model(**_to_device(enc))
43
- logits = outputs.logits[0] # shape [num_labels]
44
- probs = torch.sigmoid(logits) # multi-label
 
 
 
 
45
 
46
  probs = probs.cpu().numpy()
47
  indices = probs.argsort()[::-1][:top_k]
 
1
  # model_utils.py
2
  import torch
3
  import joblib
4
+ from functools import lru_cache
5
  from transformers import AutoTokenizer, AutoModelForSequenceClassification
6
  from huggingface_hub import hf_hub_download
7
 
 
 
8
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
9
 
10
+ DEFAULT_MODEL = "maxcasado/BERT_overflow"
11
+
12
+
13
@lru_cache(maxsize=4)
def load_model_and_tokenizer(model_name: str):
    """
    Load tokenizer, model, class names and problem type (multi-label or not)
    for a given model (Hugging Face Hub repo).

    The result is cached (up to 4 distinct models) so repeated requests do
    not reload the same model from disk/network.

    Parameters
    ----------
    model_name : str
        Hugging Face Hub repo id of the model to load.

    Returns
    -------
    tuple
        ``(tokenizer, model, classes, multi_label)`` where ``classes`` maps
        index -> label name and ``multi_label`` tells callers whether to
        apply sigmoid (multi-label) or softmax (single-label) to the logits.
    """
    print(f"Loading model: {model_name}")
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForSequenceClassification.from_pretrained(model_name)
    model.to(device)  # module-level cuda/cpu target
    model.eval()

    # Try to load mlb.joblib if present in the repo.
    classes = None
    try:
        mlb_path = hf_hub_download(model_name, "mlb.joblib")
        mlb = joblib.load(mlb_path)
        classes = list(mlb.classes_)
        print(f"Loaded mlb.joblib for {model_name}")
    except Exception as err:
        # FIX: the original swallowed the failure reason entirely, making a
        # simply-absent mlb.joblib indistinguishable from a network or auth
        # error. Keep the best-effort fallback, but say why it triggered.
        print(f"mlb.joblib unavailable for {model_name}: {err!r}")
        # Fallback on id2label when there is no mlb.joblib.
        id2label = getattr(model.config, "id2label", None)
        if id2label:
            classes = [id2label[i] for i in range(len(id2label))]
            print(f"Using config.id2label for {model_name}")
        else:
            classes = [f"label_{i}" for i in range(model.config.num_labels)]
            print(f"Using generic label_i for {model_name}")

    # Determine whether this is multi-label.
    problem_type = getattr(model.config, "problem_type", None)
    multi_label = problem_type == "multi_label_classification"

    # Heuristic: if problem_type is unset but the model has many labels,
    # assume multi-label. NOTE(review): >2 labels can also be ordinary
    # multi-class — confirm this assumption against the deployed models.
    if problem_type is None and model.config.num_labels > 2:
        multi_label = True

    return tokenizer, model, classes, multi_label
53
 
54
 
55
def _to_device(batch):
    """Return a copy of *batch* with every tensor moved to the module-level device."""
    moved = {}
    for key, tensor in batch.items():
        moved[key] = tensor.to(device)
    return moved
57
 
58
 
59
+ def predict_proba(text: str, top_k: int = 10, model_name: str | None = None):
60
  """
61
+ Prend un texte et renvoie les top_k tags + probabilités
62
+ pour le modèle spécifié (ou le modèle par défaut).
 
63
  """
64
+ if model_name is None:
65
+ model_name = DEFAULT_MODEL
66
+
67
+ tokenizer, model, classes, multi_label = load_model_and_tokenizer(model_name)
68
+
69
  enc = tokenizer(
70
  text,
71
  return_tensors="pt",
 
76
 
77
  with torch.no_grad():
78
  outputs = model(**_to_device(enc))
79
+ logits = outputs.logits[0]
80
+
81
+ if multi_label:
82
+ probs = torch.sigmoid(logits)
83
+ else:
84
+ probs = torch.softmax(logits, dim=-1)
85
 
86
  probs = probs.cpu().numpy()
87
  indices = probs.argsort()[::-1][:top_k]