maxcasado commited on
Commit
cda0729
·
verified ·
1 Parent(s): fc94431

Create model_utils.py

Browse files
Files changed (1) hide show
  1. model_utils.py +60 -0
model_utils.py ADDED
@@ -0,0 +1,60 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # model_utils.py
2
+ import torch
3
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification
4
+
5
+ # 🔁 Mets ici le chemin ou le repo HF de ton modèle
6
+ MODEL_NAME = "ton-username/stackoverflow-tags-bert"
7
+
8
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
9
+
10
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
11
+ model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME)
12
+ model.to(device)
13
+ model.eval()
14
+
15
+ # On essaie de récupérer les labels à partir de la config
16
+ id2label = getattr(model.config, "id2label", None)
17
+ if not id2label:
18
+ id2label = {i: f"label_{i}" for i in range(model.config.num_labels)}
19
+
20
+
21
+ def _to_device(batch):
22
+ return {k: v.to(device) for k, v in batch.items()}
23
+
24
+
25
+ def predict_proba(text: str, top_k: int = 10):
26
+ """
27
+ Prend une question en entrée, renvoie les top_k tags avec leurs probas.
28
+ Gère multi-class et multi-label.
29
+ """
30
+ enc = tokenizer(
31
+ text,
32
+ return_tensors="pt",
33
+ truncation=True,
34
+ padding=True,
35
+ max_length=256,
36
+ )
37
+
38
+ with torch.no_grad():
39
+ outputs = model(**_to_device(enc))
40
+ logits = outputs.logits[0]
41
+
42
+ # Heuristique : si problème multi-label
43
+ if getattr(model.config, "problem_type", None) == "multi_label_classification":
44
+ probs = torch.sigmoid(logits)
45
+ else:
46
+ probs = torch.softmax(logits, dim=-1)
47
+
48
+ probs = probs.cpu().numpy()
49
+
50
+ # indices triés par proba décroissante
51
+ indices = probs.argsort()[::-1][:top_k]
52
+
53
+ results = [
54
+ {
55
+ "label": id2label.get(int(i), f"label_{int(i)}"),
56
+ "score": float(probs[i]),
57
+ }
58
+ for i in indices
59
+ ]
60
+ return results