S-Dreamer committed on
Commit
c1430da
·
verified ·
1 Parent(s): b839a2b

Update model_inference.py

Browse files
Files changed (1) hide show
  1. model_inference.py +43 -16
model_inference.py CHANGED
@@ -1,30 +1,57 @@
1
- # model_inference.py
 
 
 
 
 
 
 
 
2
 
3
- from transformers import AutoModelForSequenceClassification, AutoTokenizer
4
- import torch
5
 
6
class ThreatModel:
    """Wraps a transformer classifier for threat categorization."""

    def __init__(self, model_path="bert-base-chinese", device=None):
        # Honor an explicit device; otherwise prefer CUDA when present.
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = device
        self.tokenizer = AutoTokenizer.from_pretrained(model_path)
        self.model = AutoModelForSequenceClassification.from_pretrained(model_path)
        self.model.to(self.device)

    def predict(self, text):
        """Return a list of per-class probabilities for *text*."""
        encoded = self.tokenizer(
            text,
            return_tensors="pt",
            truncation=True,
            padding=True,
        ).to(self.device)

        # Inference only — no autograd bookkeeping needed.
        with torch.no_grad():
            logits = self.model(**encoded).logits

        # logits were produced under no_grad, so this stays gradient-free.
        return torch.softmax(logits, dim=-1).cpu().tolist()[0]
 
1
+ from typing import List, Optional
2
+
3
+ try:
4
+ from transformers import AutoModelForSequenceClassification, AutoTokenizer
5
+ import torch
6
+ except ImportError:
7
+ AutoModelForSequenceClassification = None # type: ignore
8
+ AutoTokenizer = None # type: ignore
9
+ torch = None # type: ignore
10
 
 
 
11
 
12
class ThreatModel:
    """
    Wraps a transformer classifier for threat categorization.

    If `transformers` or `torch` are not installed, this class gracefully
    degrades and simply returns empty probability lists instead of crashing.
    Check `self.available` to distinguish dummy mode from a loaded model.
    """

    def __init__(self, model_path: str = "bert-base-chinese", device: Optional[str] = None):
        """
        Load the tokenizer and classification model.

        Args:
            model_path: HF hub id or local path of the model to load.
            device: Explicit torch device string; when None, picks "cuda"
                if available, else "cpu".
        """
        # The module-level import fallback sets these names to None when
        # transformers/torch are missing; detect that here.
        self.available = AutoModelForSequenceClassification is not None and torch is not None
        self.model = None
        self.tokenizer = None
        self.device = "cpu"

        if not self.available:
            # No transformers / torch in the environment; operate in dummy mode.
            return

        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        self.tokenizer = AutoTokenizer.from_pretrained(model_path)
        self.model = AutoModelForSequenceClassification.from_pretrained(model_path)
        self.model.to(self.device)

    def predict_proba(self, text: str) -> List[float]:
        """
        Return a list of probabilities per class.

        If the model is not available (e.g. transformers not installed),
        returns an empty list and lets the caller decide how to handle it.
        """
        if not self.available or self.model is None or self.tokenizer is None:
            return []

        inputs = self.tokenizer(
            text,
            return_tensors="pt",
            truncation=True,
            padding=True
        ).to(self.device)

        # Inference only — skip autograd bookkeeping.
        with torch.no_grad():
            outputs = self.model(**inputs)
            logits = outputs.logits
            probs = torch.softmax(logits, dim=-1).cpu().tolist()[0]

        return probs

    # Backward-compatible alias: the previous public API exposed this method
    # as `predict`; keep it working so existing callers don't break.
    predict = predict_proba