ynyg commited on
Commit
7327f2e
·
verified ·
1 Parent(s): 869883c

feat: 新增label映射

Browse files
Files changed (1) hide show
  1. app.py +5 -2
app.py CHANGED
@@ -16,6 +16,9 @@ MODEL_PATH = "models/Unified_Prompt_Guard"
16
  # 設備
17
  device = "cuda" if torch.cuda.is_available() else "cpu"
18
 
 
 
 
19
 
20
  def load_model() -> tuple[XLMRobertaTokenizer, XLMRobertaForSequenceClassification]:
21
  """加載模型"""
@@ -72,8 +75,8 @@ async def predict(request: Request, text: str = Body(..., embed=True)):
72
 
73
  return {
74
  "text": text,
75
- "label": label,
76
- "confidence": confidence
77
  }
78
 
79
 
 
16
  # 設備
17
  device = "cuda" if torch.cuda.is_available() else "cpu"
18
 
19
+ # 标签映射
20
+ LABEL_MAP = {0: "safe", "unsafe": 1}
21
+
22
 
23
  def load_model() -> tuple[XLMRobertaTokenizer, XLMRobertaForSequenceClassification]:
24
  """加載模型"""
 
75
 
76
  return {
77
  "text": text,
78
+ "label": LABEL_MAP.get(label),
79
+ "confidence": confidence,
80
  }
81
 
82