Spaces:
Sleeping
Sleeping
feat: 將預測任務包裝成協程
Browse files
app.py
CHANGED
|
@@ -1,6 +1,7 @@
|
|
| 1 |
from contextlib import asynccontextmanager
|
| 2 |
|
| 3 |
import torch
|
|
|
|
| 4 |
from fastapi import FastAPI
|
| 5 |
from fastapi.params import Body
|
| 6 |
from transformers import (
|
|
@@ -49,24 +50,30 @@ async def predict(instance: FastAPI, text: str = Body(..., embed=True)):
|
|
| 49 |
:param text: 待分類的文本
|
| 50 |
:return: 預測結果,包括文本、預測類別和置信度
|
| 51 |
"""
|
| 52 |
-
# 獲取對象
|
| 53 |
-
tokenizer, model = instance.state.tokenizer, instance.state.model
|
| 54 |
|
| 55 |
-
|
| 56 |
-
|
|
|
|
| 57 |
|
| 58 |
-
|
| 59 |
-
|
| 60 |
-
outputs = model(**inputs)
|
| 61 |
|
| 62 |
-
|
| 63 |
-
|
| 64 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 65 |
|
| 66 |
return {
|
| 67 |
"text": text,
|
| 68 |
-
"label":
|
| 69 |
-
"confidence":
|
| 70 |
}
|
| 71 |
|
| 72 |
|
|
|
|
| 1 |
from contextlib import asynccontextmanager
|
| 2 |
|
| 3 |
import torch
|
| 4 |
+
from anyio.to_thread import run_sync
|
| 5 |
from fastapi import FastAPI
|
| 6 |
from fastapi.params import Body
|
| 7 |
from transformers import (
|
|
|
|
| 50 |
:param text: 待分類的文本
|
| 51 |
:return: 預測結果,包括文本、預測類別和置信度
|
| 52 |
"""
|
|
|
|
|
|
|
| 53 |
|
| 54 |
+
def _inference():
|
| 55 |
+
# 獲取對象
|
| 56 |
+
tokenizer, model = instance.state.tokenizer, instance.state.model
|
| 57 |
|
| 58 |
+
# 分詞處理
|
| 59 |
+
inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True).to(device)
|
|
|
|
| 60 |
|
| 61 |
+
# 推理
|
| 62 |
+
with torch.no_grad():
|
| 63 |
+
outputs = model(**inputs)
|
| 64 |
+
|
| 65 |
+
# 處理輸出
|
| 66 |
+
predictions = torch.nn.functional.softmax(outputs.logits, dim=-1)
|
| 67 |
+
confidences, classes = torch.max(predictions, dim=-1)
|
| 68 |
+
|
| 69 |
+
return classes.item(), confidences.item()
|
| 70 |
+
|
| 71 |
+
label, confidence = await run_sync(_inference)
|
| 72 |
|
| 73 |
return {
|
| 74 |
"text": text,
|
| 75 |
+
"label": label,
|
| 76 |
+
"confidence": confidence
|
| 77 |
}
|
| 78 |
|
| 79 |
|