maxcasado commited on
Commit
730bdeb
·
verified ·
1 Parent(s): ca2fc92

Update model_utils.py

Browse files
Files changed (1) hide show
  1. model_utils.py +7 -15
model_utils.py CHANGED
@@ -1,8 +1,8 @@
1
  # model_utils.py
2
  import torch
 
3
  from transformers import AutoTokenizer, AutoModelForSequenceClassification
4
 
5
- # 🔁 Mets ici le chemin ou le repo HF de ton modèle
6
  MODEL_NAME = "maxcasado/BERT_overflow"
7
 
8
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
@@ -12,10 +12,9 @@ model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME)
12
  model.to(device)
13
  model.eval()
14
 
15
- # On essaie de récupérer les labels à partir de la config
16
- id2label = getattr(model.config, "id2label", None)
17
- if not id2label:
18
- id2label = {i: f"label_{i}" for i in range(model.config.num_labels)}
19
 
20
 
21
  def _to_device(batch):
@@ -24,8 +23,7 @@ def _to_device(batch):
24
 
25
  def predict_proba(text: str, top_k: int = 10):
26
  """
27
- Prend une question en entrée, renvoie les top_k tags avec leurs probas.
28
- Gère multi-class et multi-label.
29
  """
30
  enc = tokenizer(
31
  text,
@@ -38,21 +36,15 @@ def predict_proba(text: str, top_k: int = 10):
38
  with torch.no_grad():
39
  outputs = model(**_to_device(enc))
40
  logits = outputs.logits[0]
41
-
42
- # Heuristique : si problème multi-label
43
- if getattr(model.config, "problem_type", None) == "multi_label_classification":
44
- probs = torch.sigmoid(logits)
45
- else:
46
- probs = torch.softmax(logits, dim=-1)
47
 
48
  probs = probs.cpu().numpy()
49
 
50
- # indices triés par proba décroissante
51
  indices = probs.argsort()[::-1][:top_k]
52
 
53
  results = [
54
  {
55
- "label": id2label.get(int(i), f"label_{int(i)}"),
56
  "score": float(probs[i]),
57
  }
58
  for i in indices
 
1
  # model_utils.py
2
  import torch
3
+ import joblib
4
  from transformers import AutoTokenizer, AutoModelForSequenceClassification
5
 
 
6
  MODEL_NAME = "maxcasado/BERT_overflow"
7
 
8
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
12
  model.to(device)
13
  model.eval()
14
 
15
+ # MultiLabelBinarizer pour récupérer les noms de tags
16
+ mlb = joblib.load("mlb.joblib") # le fichier présent dans ton repo modèle
17
+ classes = list(mlb.classes_) # index -> nom de tag
 
18
 
19
 
20
  def _to_device(batch):
 
23
 
24
  def predict_proba(text: str, top_k: int = 10):
25
  """
26
+ Multi-label : renvoie top_k tags avec proba (sigmoid).
 
27
  """
28
  enc = tokenizer(
29
  text,
 
36
  with torch.no_grad():
37
  outputs = model(**_to_device(enc))
38
  logits = outputs.logits[0]
39
+ probs = torch.sigmoid(logits) # multi-label !
 
 
 
 
 
40
 
41
  probs = probs.cpu().numpy()
42
 
 
43
  indices = probs.argsort()[::-1][:top_k]
44
 
45
  results = [
46
  {
47
+ "label": classes[int(i)],
48
  "score": float(probs[i]),
49
  }
50
  for i in indices