bztxb commited on
Commit
9e62f27
·
verified ·
1 Parent(s): 9881e4a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +14 -0
app.py CHANGED
@@ -10,6 +10,10 @@ tokenizer = AutoTokenizer.from_pretrained(model_name)
10
  model = AutoModelForSequenceClassification.from_pretrained(model_name)
11
  model.eval()
12
 
 
 
 
 
13
  # 标签列表
14
  labels_list = ['上供', '中人', '中央亞', '中央行政', '中央軍', '主副食', '交通', '人事', '人文敎育',
15
  '人物', '任免', '住生活', '佛敎', '保健', '倉庫', '倫理', '倭', '儀式', '儒學', '元',
@@ -33,10 +37,20 @@ labels_list = ['上供', '中人', '中央亞', '中央行政', '中央軍', '
33
  # 推理函数
34
  @torch.no_grad()
35
  def classify(text: str, threshold: float = 0.5):
 
36
  enc = tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=512)
 
 
 
 
 
37
  logits = model(**enc).logits.sigmoid() # 获取预测的概率
 
 
38
  scores = logits[0].tolist()
39
  preds = [labels_list[i] for i, score in enumerate(scores) if score >= threshold]
 
 
40
  return preds, {labels_list[i]: score for i, score in enumerate(scores)}
41
 
42
  # 创建 Gradio 接口
 
10
  model = AutoModelForSequenceClassification.from_pretrained(model_name)
11
  model.eval()
12
 
13
+ # 确定设备(CPU 或 GPU)
14
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
15
+ model.to(device) # 将模型加载到正确的设备
16
+
17
  # 标签列表
18
  labels_list = ['上供', '中人', '中央亞', '中央行政', '中央軍', '主副食', '交通', '人事', '人文敎育',
19
  '人物', '任免', '住生活', '佛敎', '保健', '倉庫', '倫理', '倭', '儀式', '儒學', '元',
 
37
  # 推理函数
38
  @torch.no_grad()
39
  def classify(text: str, threshold: float = 0.5):
40
+ # 将输入文本转化为模型需要的格式
41
  enc = tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=512)
42
+
43
+ # 将输入数据移动到正确的设备(GPU/CPU)
44
+ enc = {key: value.to(device) for key, value in enc.items()}
45
+
46
+ # 获取模型的预测值
47
  logits = model(**enc).logits.sigmoid() # 获取预测的概率
48
+
49
+ # 获取预测结果
50
  scores = logits[0].tolist()
51
  preds = [labels_list[i] for i, score in enumerate(scores) if score >= threshold]
52
+
53
+ # 返回预测标签及得分
54
  return preds, {labels_list[i]: score for i, score in enumerate(scores)}
55
 
56
  # 创建 Gradio 接口