ynyg commited on
Commit
9965cc0
·
verified ·
1 Parent(s): 0160edf

feat: 將預測任務包裝成協程

Browse files
Files changed (1) hide show
  1. app.py +19 -12
app.py CHANGED
@@ -1,6 +1,7 @@
1
  from contextlib import asynccontextmanager
2
 
3
  import torch
 
4
  from fastapi import FastAPI
5
  from fastapi.params import Body
6
  from transformers import (
@@ -49,24 +50,30 @@ async def predict(instance: FastAPI, text: str = Body(..., embed=True)):
49
  :param text: 待分類的文本
50
  :return: 預測結果,包括文本、預測類別和置信度
51
  """
52
- # 獲取對象
53
- tokenizer, model = instance.state.tokenizer, instance.state.model
54
 
55
- # 分詞處理
56
- inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True).to(device)
 
57
 
58
- #
59
- with torch.no_grad():
60
- outputs = model(**inputs)
61
 
62
- # 輸出
63
- predictions = torch.nn.functional.softmax(outputs.logits, dim=-1)
64
- confidences, classes = torch.max(predictions, dim=-1)
 
 
 
 
 
 
 
 
65
 
66
  return {
67
  "text": text,
68
- "label": classes.item(),
69
- "confidence": confidences.item()
70
  }
71
 
72
 
 
1
  from contextlib import asynccontextmanager
2
 
3
  import torch
4
+ from anyio.to_thread import run_sync
5
  from fastapi import FastAPI
6
  from fastapi.params import Body
7
  from transformers import (
 
50
  :param text: 待分類的文本
51
  :return: 預測結果,包括文本、預測類別和置信度
52
  """
 
 
53
 
54
+ def _inference():
55
+ # 獲取對象
56
+ tokenizer, model = instance.state.tokenizer, instance.state.model
57
 
58
+ # 分詞處理
59
+ inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True).to(device)
 
60
 
61
+ #
62
+ with torch.no_grad():
63
+ outputs = model(**inputs)
64
+
65
+ # 處理輸出
66
+ predictions = torch.nn.functional.softmax(outputs.logits, dim=-1)
67
+ confidences, classes = torch.max(predictions, dim=-1)
68
+
69
+ return classes.item(), confidences.item()
70
+
71
+ label, confidence = await run_sync(_inference)
72
 
73
  return {
74
  "text": text,
75
+ "label": label,
76
+ "confidence": confidence
77
  }
78
 
79