AlexTANG-JX committed
Commit 48652bc · verified · 1 parent: 95dc670

Update README.md

Files changed (1): README.md (+99 -3)
@@ -1,3 +1,99 @@
- ---
- license: unknown
- ---
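+ DistilBERT Question Detector Model
+ # DistilBERT Divination Question Detection Model
+
+ This project provides a `DistilBERT`-based divination question detection model that determines whether an input text is a question suitable for tarot divination.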
+ ## 📂 Directory Structure
+ - `model.safetensors`: The trained model weights.
+ - `config.json`: The configuration file for the model architecture.
+ - `tokenizer.json`: The tokenizer configuration.
+ - `special_tokens_map.json`: The special tokens configuration.
+ - `vocab.txt`: The vocabulary file for the tokenizer.
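Before loading the model, it can help to confirm that the files listed above are actually present. The following is a minimal sketch, not part of the repository: the helper name `missing_model_files` is hypothetical, and the checkpoint path is simply the one used in the README's own examples.

```python
from pathlib import Path

# File names taken from the directory listing above.
EXPECTED_FILES = [
    "model.safetensors",
    "config.json",
    "tokenizer.json",
    "special_tokens_map.json",
    "vocab.txt",
]


def missing_model_files(model_dir):
    """Return the expected files that are absent from model_dir (hypothetical helper)."""
    root = Path(model_dir)
    return [name for name in EXPECTED_FILES if not (root / name).is_file()]


# Assumed path: the same checkpoint directory the README's examples load from.
missing = missing_model_files("./distilbert-question-detector/checkpoint-5150")
print("Missing files:", missing or "none")
```

An empty result means every listed file was found; anything else names the files to download or copy before running inference.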
+ ---
+ ## 🚀 Quick Start
+ ### **1️⃣ Install Dependencies**
+ Make sure Python 3.8+ is installed in your environment, then run the following command to install the required dependencies:
+
+ `pip install torch transformers fastapi uvicorn safetensors`
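To confirm the installation succeeded, one option is to query the installed package versions from Python. This is a small sketch using only the standard library (`importlib.metadata`, available from Python 3.8); the helper name `installed_versions` is an assumption made for illustration.

```python
from importlib.metadata import PackageNotFoundError, version

# Package names from the pip command above.
REQUIRED = ["torch", "transformers", "fastapi", "uvicorn", "safetensors"]


def installed_versions(packages):
    """Map each package name to its installed version, or None if it is not installed."""
    found = {}
    for name in packages:
        try:
            found[name] = version(name)
        except PackageNotFoundError:
            found[name] = None
    return found


for name, ver in installed_versions(REQUIRED).items():
    print(f"{name}: {ver or 'NOT INSTALLED'}")
```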
+
+ ### **2️⃣ Run Inference Directly**
+ To test the model locally, run `inference.py`:
+ `python inference.py`
+ Example code (`inference.py`):
+ ```python
+ import torch
+ from transformers import DistilBertTokenizer, DistilBertForSequenceClassification
+
+ # 1. Load the tokenizer and model from the local checkpoint
+ model_path = "./distilbert-question-detector/checkpoint-5150"
+ tokenizer = DistilBertTokenizer.from_pretrained(model_path)
+ model = DistilBertForSequenceClassification.from_pretrained(model_path)
+ model.eval()  # disable dropout for inference
+
+ # 2. Run inference
+ text = "Is this a question?"
+ inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=128)
+
+ with torch.no_grad():  # no gradients are needed at inference time
+     outputs = model(**inputs)
+     logits = outputs.logits
+     probabilities = torch.nn.functional.softmax(logits, dim=-1)
+
+ predicted_class = torch.argmax(probabilities, dim=-1).item()
+
+ print(f"Probabilities: {probabilities}")
+ print(f"Predicted class: {predicted_class}")  # class index; see the result description section below for its meaning
+ ```
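The last two steps of the script, softmax followed by argmax, can be illustrated without loading the model. The sketch below reimplements both in plain Python; the logit values are made up for illustration and do not come from the actual model.

```python
import math


def softmax(logits):
    """Convert raw scores into probabilities that sum to 1."""
    peak = max(logits)  # subtracting the max keeps exp() numerically stable
    exps = [math.exp(x - peak) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def argmax(values):
    """Return the index of the largest value."""
    return max(range(len(values)), key=values.__getitem__)


logits = [2.0, -0.5]  # hypothetical two-class logits
probs = softmax(logits)
predicted = argmax(probs)
print(probs, predicted)
```

Because softmax preserves ordering, the predicted class is the index of the largest logit; the probabilities only add a confidence score on top of that choice.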
+ ### **3️⃣ Run the API**
+ You can also use FastAPI to expose an HTTP endpoint so that other applications can query the model over HTTP.
+ `uvicorn app:app --host 0.0.0.0 --port 8000`
+ Example API code (`app.py`):
+ ```python
+ from fastapi import FastAPI
+ from pydantic import BaseModel
+ import torch
+ from transformers import DistilBertTokenizer, DistilBertForSequenceClassification
+
+ app = FastAPI()
+
+ # Load the tokenizer and model once at startup
+ model_path = "./distilbert-question-detector/checkpoint-5150"
+ tokenizer = DistilBertTokenizer.from_pretrained(model_path)
+ model = DistilBertForSequenceClassification.from_pretrained(model_path)
+ model.eval()
+
+ class PredictRequest(BaseModel):
+     # JSON request body of the form {"text": "..."}
+     text: str
+
+ @app.post("/predict/")
+ async def predict(request: PredictRequest):
+     # Declaring a Pydantic model makes FastAPI read `text` from the JSON body;
+     # a bare `text: str` parameter would be treated as a query parameter.
+     inputs = tokenizer(request.text, return_tensors="pt", padding=True, truncation=True, max_length=128)
+
+     with torch.no_grad():
+         outputs = model(**inputs)
+         logits = outputs.logits
+         probabilities = torch.nn.functional.softmax(logits, dim=-1)
+
+     predicted_class = torch.argmax(probabilities, dim=-1).item()
+     return {"text": request.text, "probabilities": probabilities.tolist(), "predicted_class": predicted_class}
+ ```
+ Once the API is running, you can test it like this:
+ ```sh
+ curl -X 'POST' \
+   'http://127.0.0.1:8000/predict/' \
+   -H 'Content-Type: application/json' \
+   -d '{"text": "Is this a valid question?"}'
+ ```
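The same request can be sent from Python. The sketch below mirrors the curl command using only the standard library; the function names `build_predict_request` and `predict` are hypothetical, and the base URL assumes the server is running locally on port 8000 as in the uvicorn command.

```python
import json
import urllib.request


def build_predict_request(text, base_url="http://127.0.0.1:8000"):
    """Build the POST request that the curl command above sends."""
    payload = json.dumps({"text": text}).encode("utf-8")
    return urllib.request.Request(
        url=f"{base_url}/predict/",
        data=payload,
        headers={"Content-Type": "application/json"},
        method="POST",
    )


def predict(text, base_url="http://127.0.0.1:8000", timeout=10.0):
    """Send the request and decode the JSON response (requires a running server)."""
    request = build_predict_request(text, base_url)
    with urllib.request.urlopen(request, timeout=timeout) as response:
        return json.loads(response.read().decode("utf-8"))


# Usage, once the API is up:
#   result = predict("Is this a valid question?")
#   print(result["predicted_class"])
```

Splitting request construction from sending keeps the payload format easy to inspect without a live server.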
+ ## 📌 Result Description
+ - `predicted_class: 0` means the input text meets the criteria.
+ - `predicted_class: 1` means the input text does not meet the criteria.
+
+ Example result:
+ ```json
+ {
+   "text": "Is this a valid question?",
+   "probabilities": [[0.9266, 0.0734]],
+   "predicted_class": 0
+ }
+ ```
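A client will typically want a readable label and a confidence score rather than the raw fields. The sketch below shows one way to derive both from the example result above; the label mapping follows this section's description and should be treated as an assumption, and the helper name `summarize` is hypothetical.

```python
import json

# Label meanings as stated in this section (assumed, not read from the model config).
LABELS = {0: "meets the criteria", 1: "does not meet the criteria"}


def summarize(response):
    """Return (label, confidence) for a single-text /predict/ response."""
    predicted = response["predicted_class"]
    probabilities = response["probabilities"][0]  # one row per input text
    return LABELS[predicted], probabilities[predicted]


example = json.loads("""
{
  "text": "Is this a valid question?",
  "probabilities": [[0.9266, 0.0734]],
  "predicted_class": 0
}
""")
label, confidence = summarize(example)
print(f"{label} (confidence {confidence:.2%})")
```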
+ ## 🔧 Other Deployment Options (Optional)
+ If you want to deploy the model to the cloud, you can choose from:
+
+ - Hugging Face Hub: upload `model.safetensors` to 🤗 Hugging Face
+ - AWS SageMaker: run cloud inference with Amazon SageMaker
+ - Docker: package the FastAPI endpoint in a Docker container
+