# NOTE(review): the three lines below were Hugging Face Space page residue
# ("Spaces: / Sleeping / Sleeping") accidentally pasted into the source;
# kept as comments so the module stays importable.
from contextlib import asynccontextmanager

import torch
from anyio.to_thread import run_sync
from fastapi import FastAPI, Request
from fastapi.params import Body
from transformers import (
    AutoModelForSequenceClassification,
    AutoTokenizer,
    XLMRobertaForSequenceClassification,
    XLMRobertaTokenizer,
)
# Path to the locally stored classifier checkpoint (loaded with
# local_files_only=True, so it must exist on disk).
MODEL_PATH = "models/Unified_Prompt_Guard"
# Run inference on GPU when available, otherwise fall back to CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
# Maps the model's class indices to human-readable labels.
LABEL_MAP = {0: "safe", 1: "unsafe"}
def load_model() -> tuple[XLMRobertaTokenizer, XLMRobertaForSequenceClassification]:
    """Load the tokenizer/model pair from the local checkpoint.

    Returns:
        A ``(tokenizer, model)`` tuple. The model is moved to ``device``
        and switched to evaluation mode before being returned.
    """
    # local_files_only=True: never hit the Hugging Face Hub at startup.
    tok = AutoTokenizer.from_pretrained(MODEL_PATH, local_files_only=True)
    clf = AutoModelForSequenceClassification.from_pretrained(
        MODEL_PATH, local_files_only=True
    )
    clf.to(device)
    clf.eval()
    return tok, clf
@asynccontextmanager
async def lifespan(instance: FastAPI):
    """
    Application lifespan manager.

    Loads the tokenizer and model once at startup and parks them on
    ``app.state`` so request handlers can reuse them; nothing needs to be
    torn down on shutdown.

    :param instance: the FastAPI application instance
    """
    # Fix: ``asynccontextmanager`` was imported but never applied, while
    # FastAPI's ``lifespan=`` parameter expects an async context manager.
    instance.state.tokenizer, instance.state.model = load_model()
    yield
| app = FastAPI(lifespan=lifespan) | |
@app.post("/predict")
async def predict(request: Request, text: str = Body(..., embed=True)):
    """
    Classify *text* with the model loaded at startup.

    :param request: incoming request; used to reach ``app.state`` for the
        tokenizer/model pair created by the lifespan handler
    :param text: text to classify (JSON body: ``{"text": "..."}``)
    :return: dict with the input text, predicted label and confidence
    """
    # Fix: the handler was defined but never registered on the app, so the
    # endpoint was unreachable; it is now exposed as POST /predict.

    def _inference():
        # Shared tokenizer/model stored on app.state at startup.
        tokenizer, model = request.app.state.tokenizer, request.app.state.model
        # Tokenize and move the input tensors to the model's device.
        inputs = tokenizer(
            text, return_tensors="pt", truncation=True, padding=True
        ).to(device)
        # Forward pass without autograd bookkeeping.
        with torch.no_grad():
            outputs = model(**inputs)
        # Softmax over logits -> per-class probabilities; keep the argmax
        # class and its probability as the confidence score.
        predictions = torch.nn.functional.softmax(outputs.logits, dim=-1)
        confidences, classes = torch.max(predictions, dim=-1)
        return classes.item(), confidences.item()

    # Run the blocking torch code in a worker thread so the event loop
    # stays responsive under concurrent requests.
    label, confidence = await run_sync(_inference)
    return {
        "text": text,
        "label": LABEL_MAP.get(label),
        "confidence": confidence,
    }
@app.get("/")
def greet_json():
    """
    Root endpoint returning a static JSON greeting.

    :return: the payload ``{"Hello": "World!"}``
    """
    # Fix: previously defined but never registered on the app; exposed as
    # GET / so the service has a reachable health/landing route.
    return {"Hello": "World!"}
if __name__ == '__main__':
    # Development entry point: `python app.py` serves on all interfaces.
    import uvicorn
    uvicorn.run("app:app", host="0.0.0.0", port=8000)