maxcasado committed on
Commit
ab80a46
·
verified ·
1 Parent(s): 2286c50

Update model_utils.py

Browse files
Files changed (1) hide show
  1. model_utils.py +12 -9
model_utils.py CHANGED
@@ -2,19 +2,22 @@
2
  import torch
3
  import joblib
4
  from transformers import AutoTokenizer, AutoModelForSequenceClassification
 
5
 
6
  MODEL_NAME = "maxcasado/BERT_overflow"
7
 
8
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
9
 
 
10
  tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
11
  model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME)
12
  model.to(device)
13
  model.eval()
14
 
15
- # MultiLabelBinarizer pour récupérer les noms de tags
16
- mlb = joblib.load("mlb.joblib") # le fichier présent dans ton repo modèle
17
- classes = list(mlb.classes_) # index -> nom de tag
 
18
 
19
 
20
  def _to_device(batch):
@@ -23,7 +26,9 @@ def _to_device(batch):
23
 
24
  def predict_proba(text: str, top_k: int = 10):
25
  """
26
- Multi-label : renvoie top_k tags avec proba (sigmoid).
 
 
27
  """
28
  enc = tokenizer(
29
  text,
@@ -35,18 +40,16 @@ def predict_proba(text: str, top_k: int = 10):
35
 
36
  with torch.no_grad():
37
  outputs = model(**_to_device(enc))
38
- logits = outputs.logits[0]
39
- probs = torch.sigmoid(logits) # multi-label !
40
 
41
  probs = probs.cpu().numpy()
42
-
43
  indices = probs.argsort()[::-1][:top_k]
44
 
45
- results = [
46
  {
47
  "label": classes[int(i)],
48
  "score": float(probs[i]),
49
  }
50
  for i in indices
51
  ]
52
- return results
 
2
  import torch
3
  import joblib
4
  from transformers import AutoTokenizer, AutoModelForSequenceClassification
5
+ from huggingface_hub import hf_hub_download
6
 
7
  MODEL_NAME = "maxcasado/BERT_overflow"
8
 
9
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
10
 
11
+ print("Loading tokenizer and model...")
12
  tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
13
  model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME)
14
  model.to(device)
15
  model.eval()
16
 
17
+ print("Loading MultiLabelBinarizer (mlb.joblib)...")
18
+ mlb_path = hf_hub_download(MODEL_NAME, "mlb.joblib")
19
+ mlb = joblib.load(mlb_path)
20
+ classes = list(mlb.classes_) # index -> tag name
21
 
22
 
23
  def _to_device(batch):
 
26
 
27
  def predict_proba(text: str, top_k: int = 10):
28
  """
29
+ Multi-label prediction:
30
+ - entrée : texte de la question
31
+ - sortie : top_k tags avec leurs probabilités (sigmoid)
32
  """
33
  enc = tokenizer(
34
  text,
 
40
 
41
  with torch.no_grad():
42
  outputs = model(**_to_device(enc))
43
+ logits = outputs.logits[0] # shape [num_labels]
44
+ probs = torch.sigmoid(logits) # multi-label
45
 
46
  probs = probs.cpu().numpy()
 
47
  indices = probs.argsort()[::-1][:top_k]
48
 
49
+ return [
50
  {
51
  "label": classes[int(i)],
52
  "score": float(probs[i]),
53
  }
54
  for i in indices
55
  ]