---
license: afl-3.0
language:
- en
metrics:
- accuracy
base_model:
- distilbert/distilbert-base-uncased
pipeline_tag: text-classification
tags:
- tarot
- question-detector
---

# DistilBERT Question Detector Model (Tarot Divination)

This project provides a divination-question detection model based on `DistilBERT`. It determines whether an input text is a question suitable for a tarot reading.

## 📂 Directory Structure

- `model.safetensors`: The trained model weights.
- `config.json`: The configuration file for the model architecture.
- `tokenizer.json`: The tokenizer configuration.
- `special_tokens_map.json`: The special tokens configuration.
- `vocab.txt`: The vocabulary file for the tokenizer.
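
Before loading, you may want to confirm that your local copy of the model contains every file listed above. The following is a minimal sketch that uses only the standard library. The helper name `missing_model_files` is hypothetical, and the example path simply mirrors the one used in `inference.py`:

```python
from pathlib import Path

# Files listed in the directory structure above
EXPECTED_FILES = [
    "model.safetensors",
    "config.json",
    "tokenizer.json",
    "special_tokens_map.json",
    "vocab.txt",
]


def missing_model_files(model_dir):
    """Return the expected files that are absent from model_dir (hypothetical helper)."""
    root = Path(model_dir)
    return [name for name in EXPECTED_FILES if not (root / name).is_file()]


if __name__ == "__main__":
    missing = missing_model_files("./distilbert-question-detector")
    if missing:
        print("Missing files:", ", ".join(missing))
    else:
        print("All model files found.")
```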

---

## 🚀 Quick Start

### **1️⃣ Install Dependencies**

Make sure Python 3.8+ is installed in your environment, then run the following command to install the required libraries:

```sh
pip install torch transformers fastapi uvicorn safetensors
```
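
To confirm the environment meets these requirements, you can run a quick check. The sketch below verifies only the Python version and whether each package is importable. It does not check package versions. The helper names are hypothetical:

```python
import importlib.util
import sys

# Import names of the packages installed by the pip command above
REQUIRED_MODULES = ("torch", "transformers", "fastapi", "uvicorn", "safetensors")


def python_version_ok(minimum=(3, 8)):
    """True if the running interpreter is at least `minimum` (Python 3.8+ per this README)."""
    return sys.version_info[:2] >= minimum


def missing_modules(names=REQUIRED_MODULES):
    """Return the module names that cannot be found on the current import path."""
    return [name for name in names if importlib.util.find_spec(name) is None]


if __name__ == "__main__":
    print("Python version OK:", python_version_ok())
    print("Missing modules:", missing_modules() or "none")
```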

### **2️⃣ Run Inference Directly**

To test the model locally, run `inference.py`:

```sh
python inference.py
```

Example code (`inference.py`):

```python
import torch
from transformers import DistilBertTokenizer, DistilBertForSequenceClassification

# 1. Load the model (directory containing the files listed above)
model_path = "./distilbert-question-detector"
tokenizer = DistilBertTokenizer.from_pretrained(model_path)
model = DistilBertForSequenceClassification.from_pretrained(model_path)
model.eval()  # inference mode: disables dropout

# 2. Run inference
text = "Is this a question?"
inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=128)

with torch.no_grad():  # gradients are not needed for inference
    outputs = model(**inputs)
    logits = outputs.logits
    probabilities = torch.nn.functional.softmax(logits, dim=-1)

predicted_class = torch.argmax(probabilities, dim=-1).item()

print(f"Probabilities: {probabilities}")
print(f"Predicted class: {predicted_class}")  # see the results section below for what each class means
```
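
The post-processing in `inference.py` converts the model's raw logits into probabilities with softmax, then uses argmax to pick the most probable class. The sketch below reproduces that arithmetic in plain Python for illustration only. The logit values are made up and are not real model output:

```python
import math


def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    peak = max(logits)
    exps = [math.exp(x - peak) for x in logits]  # subtract the max to avoid overflow
    total = sum(exps)
    return [e / total for e in exps]


def argmax(values):
    """Index of the largest value (returns the first index on ties)."""
    return max(range(len(values)), key=lambda i: values[i])


logits = [1.5, -1.0]  # illustrative values, not produced by the model
probabilities = softmax(logits)
predicted_class = argmax(probabilities)
print(probabilities, predicted_class)
```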

### **3️⃣ Run the API**

You can also serve the model through an HTTP endpoint with FastAPI, so other applications can query it over HTTP. Start the server with:

```sh
uvicorn app:app --host 0.0.0.0 --port 8000
```

Example API code (`app.py`):

```python
import torch
from fastapi import FastAPI
from pydantic import BaseModel
from transformers import DistilBertTokenizer, DistilBertForSequenceClassification

app = FastAPI()

# Load the model once at startup (adjust the path to where your model files live)
model_path = "./distilbert-question-detector/checkpoint-5150"
tokenizer = DistilBertTokenizer.from_pretrained(model_path)
model = DistilBertForSequenceClassification.from_pretrained(model_path)
model.eval()


# Request body schema, so the endpoint accepts JSON such as {"text": "..."}.
# A bare `text: str` parameter would be read from the query string instead.
class PredictRequest(BaseModel):
    text: str


# A plain (non-async) handler: FastAPI runs it in a thread pool,
# so the blocking model call does not stall the event loop.
@app.post("/predict/")
def predict(request: PredictRequest):
    inputs = tokenizer(request.text, return_tensors="pt", padding=True, truncation=True, max_length=128)

    with torch.no_grad():
        outputs = model(**inputs)
        logits = outputs.logits
        probabilities = torch.nn.functional.softmax(logits, dim=-1)

    predicted_class = torch.argmax(probabilities, dim=-1).item()

    return {"text": request.text, "probabilities": probabilities.tolist(), "predicted_class": predicted_class}
```

Once the API is running, you can test it with:

```sh
curl -X 'POST' \
  'http://127.0.0.1:8000/predict/' \
  -H 'Content-Type: application/json' \
  -d '{"text": "Is this a valid question?"}'
```
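
You can also send the same request from Python using only the standard library. This sketch assumes the server is listening at `http://127.0.0.1:8000`, as in the curl example. The names `build_request` and `predict` are hypothetical helpers:

```python
import json
import urllib.request

API_URL = "http://127.0.0.1:8000/predict/"  # assumed local address, as in the curl example


def build_request(text, url=API_URL):
    """Build a POST request whose JSON body matches the curl example."""
    body = json.dumps({"text": text}).encode("utf-8")
    return urllib.request.Request(
        url,
        data=body,
        headers={"Content-Type": "application/json"},
        method="POST",
    )


def predict(text, url=API_URL, timeout=10):
    """Send the request and decode the JSON response (requires the API to be running)."""
    with urllib.request.urlopen(build_request(text, url), timeout=timeout) as response:
        return json.loads(response.read().decode("utf-8"))
```

With the server running, `predict("Is this a valid question?")` returns the response as a Python dict.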
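
## 📌 Interpreting the Results

- `predicted_class: 0`: the input text meets the criteria (a question suitable for tarot divination).
- `predicted_class: 1`: the input text does not meet the criteria.

`probabilities` holds the softmax probability of each class, indexed by class ID.

Example result:

```json
{
  "text": "Is this a valid question?",
  "probabilities": [[0.9266, 0.0734]],
  "predicted_class": 0
}
```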
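
To turn an API response into a readable verdict, a small helper can apply the class mapping described above. This sketch assumes the response has the shape shown in the example result. The function `interpret` and the label strings are hypothetical:

```python
# Class meanings as described in the results section above
LABELS = {
    0: "meets the criteria",
    1: "does not meet the criteria",
}


def interpret(response):
    """Return (label, confidence) for a /predict/ response dict."""
    predicted_class = response["predicted_class"]
    # probabilities is nested: one row per input text, one column per class
    confidence = response["probabilities"][0][predicted_class]
    return LABELS[predicted_class], confidence


example = {
    "text": "Is this a valid question?",
    "probabilities": [[0.9266, 0.0734]],
    "predicted_class": 0,
}
label, confidence = interpret(example)
print(f"{label} (confidence {confidence:.2%})")
```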