---
license: afl-3.0
language:
- en
metrics:
- accuracy
base_model:
- distilbert/distilbert-base-uncased
pipeline_tag: text-classification
tags:
- tarot
- question-detector
---
# DistilBERT Tarot Question Detector

This project provides a DistilBERT-based tarot question detector that determines whether an input text is a valid question for a tarot reading.
## 📂 Directory Structure

- `model.safetensors`: The trained model weights.
- `config.json`: The model architecture configuration.
- `tokenizer.json`: The tokenizer configuration.
- `special_tokens_map.json`: The special tokens configuration.
- `vocab.txt`: The vocabulary file for the tokenizer.
---
## 🚀 Quick Start

### **1️⃣ Install Dependencies**

Make sure your environment has Python 3.8+, then install the required libraries:

```sh
pip install torch transformers fastapi uvicorn safetensors
```

### **2️⃣ Run Inference Directly**

To test the model locally, run `inference.py`:

```sh
python inference.py
```

Example code (`inference.py`):
```python
import torch
from transformers import DistilBertTokenizer, DistilBertForSequenceClassification

# 1. Load the model
model_path = "./distilbert-question-detector"
tokenizer = DistilBertTokenizer.from_pretrained(model_path)
model = DistilBertForSequenceClassification.from_pretrained(model_path)
model.eval()

# 2. Run inference
text = "Is this a question?"
inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=128)

with torch.no_grad():
    outputs = model(**inputs)
    logits = outputs.logits
    probabilities = torch.nn.functional.softmax(logits, dim=-1)

predicted_class = torch.argmax(probabilities, dim=-1).item()

print(f"Probabilities: {probabilities}")
print(f"Predicted class: {predicted_class}")  # 0 = valid tarot question, 1 = not
```
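For intuition, the softmax-then-argmax step above can be reproduced in plain Python (a minimal sketch with hypothetical logits; no model or torch required):

```python
import math

def softmax(logits):
    # Numerically stable softmax: subtract the max before exponentiating
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

# Hypothetical logits for the two classes [valid, not valid]
logits = [2.5, 0.0]
probs = softmax(logits)
predicted_class = max(range(len(probs)), key=lambda i: probs[i])

print(probs)            # the two probabilities sum to 1.0
print(predicted_class)  # index of the highest probability
```

This is exactly what `torch.nn.functional.softmax` followed by `torch.argmax` computes in the snippet above, just without batching.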
### **3️⃣ Run the API**

You can also deploy an HTTP endpoint with FastAPI so other applications can query the model:

```sh
uvicorn app:app --host 0.0.0.0 --port 8000
```

Example API code (`app.py`):
```python
from fastapi import FastAPI
from pydantic import BaseModel
import torch
from transformers import DistilBertTokenizer, DistilBertForSequenceClassification

app = FastAPI()

# Load the model
model_path = "./distilbert-question-detector/checkpoint-5150"
tokenizer = DistilBertTokenizer.from_pretrained(model_path)
model = DistilBertForSequenceClassification.from_pretrained(model_path)
model.eval()

# A bare `text: str` parameter would be read from the query string;
# a Pydantic model makes FastAPI accept the JSON body the client sends.
class PredictRequest(BaseModel):
    text: str

@app.post("/predict/")
async def predict(request: PredictRequest):
    inputs = tokenizer(request.text, return_tensors="pt", padding=True, truncation=True, max_length=128)

    with torch.no_grad():
        outputs = model(**inputs)
        logits = outputs.logits
        probabilities = torch.nn.functional.softmax(logits, dim=-1)

    predicted_class = torch.argmax(probabilities, dim=-1).item()
    return {"text": request.text, "probabilities": probabilities.tolist(), "predicted_class": predicted_class}
```
Once the API is running, test it with:
```sh
curl -X 'POST' \
  'http://127.0.0.1:8000/predict/' \
  -H 'Content-Type: application/json' \
  -d '{"text": "Is this a valid question?"}'
```
## 📌 Interpreting the Results

- `predicted_class: 0` means the input text is a valid tarot question.
- `predicted_class: 1` means the input text is not.

Example result:
```json
{
    "text": "Is this a valid question?",
    "probabilities": [[0.9266, 0.0734]],
    "predicted_class": 0
}
```