YAML Metadata
Warning:
empty or missing yaml metadata in repo card
(https://huggingface.co/docs/hub/model-cards#model-card-metadata)
Wuhall/bert-base-chinese-cls
文本分类模型 | 支持目的识别/技术分析/研究内容分类 | v1.1.0 | BERT | FastAPI | Zero-Shot
分类标签
| Label | 类别 |
|---|---|
| 0 | 研究目的 |
| 1 | 研究内容 |
| 2 | 核心技术 |
| 3 | 其他 |
快速开始
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
class ZeroShotClassifier:
    """Text classifier backed by a fine-tuned sequence-classification model.

    Predictions are integer class ids with the following meaning:
    0 = 研究目的 (research purpose), 1 = 研究内容 (research content),
    2 = 核心技术 (core technology), 3 = 其他 (other).
    """

    def __init__(self, model_path: str):
        """Load the model and tokenizer and prepare them for inference.

        Args:
            model_path: Identifier or path passed to ``from_pretrained`` for
                both the model and the tokenizer.
        """
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model = AutoModelForSequenceClassification.from_pretrained(model_path)
        self.tokenizer = AutoTokenizer.from_pretrained(model_path)
        self.model.to(self.device)
        self.model.eval()
        # Maps the model's output index to the label reported to callers.
        # Currently an identity mapping over the four classes.
        self.label_map = {
            0: 0,
            1: 1,
            2: 2,
            3: 3,
        }

    def _classify(self, texts):
        """Run one forward pass over ``texts`` and return one result per text.

        Args:
            texts: A non-empty list of strings, tokenized together as a single
                padded batch.

        Returns:
            A list of ``{"label": ..., "confidence": ...}`` dicts, in input
            order.
        """
        inputs = self.tokenizer(
            texts,
            padding=True,
            truncation=True,
            max_length=512,
            return_tensors="pt",
        ).to(self.device)
        with torch.no_grad():
            outputs = self.model(**inputs)
        probabilities = torch.softmax(outputs.logits, dim=-1)
        class_ids = torch.argmax(probabilities, dim=-1)
        # Pick, for every row, the probability of its winning class.
        confidences = probabilities.gather(-1, class_ids.unsqueeze(-1)).squeeze(-1)
        return [
            {"label": self.label_map[class_id], "confidence": confidence}
            for class_id, confidence in zip(class_ids.tolist(), confidences.tolist())
        ]

    def predict(self, text: str) -> dict:
        """Classify a single text.

        Args:
            text: The text to classify.

        Returns:
            A dict with ``"label"`` (mapped class id) and ``"confidence"``
            (softmax probability of that class).
        """
        return self._classify([text])[0]

    def predict_batch(self, texts, batch_size: int = 32) -> list:
        """Classify several texts using batched forward passes.

        Args:
            texts: Iterable of texts to classify. May be empty.
            batch_size: Maximum number of texts per forward pass.

        Returns:
            A list of result dicts (same shape as ``predict``), in input
            order. An empty input yields an empty list.

        Raises:
            ValueError: If ``batch_size`` is smaller than 1.
        """
        if batch_size < 1:
            raise ValueError("batch_size must be at least 1")
        texts = list(texts)
        results = []
        # Chunking bounds peak memory while still amortizing the model call.
        for start in range(0, len(texts), batch_size):
            results.extend(self._classify(texts[start:start + batch_size]))
        return results
# FastAPI application exposing the classifier over HTTP.
app = FastAPI(title="零样本文本分类服务", version="1.0.0", docs_url="/docs")
class PredictRequest(BaseModel):
    """Request body for the /predict endpoint."""

    # The text to classify.
    text: str
class PredictResponse(BaseModel):
    """Response body for the /predict endpoint."""

    # Predicted class id.
    label: int
    # Confidence score reported for the predicted class.
    confidence: float
# Load the model once, at import time, so every request reuses the same instance.
classifier = ZeroShotClassifier("Wuhall/bert-base-chinese-cls")
@app.post("/predict", response_model=PredictResponse, summary="文本分类预测")
def predict_api(request: PredictRequest):
    """Classify the text in the request body.

    Returns the classifier's result for ``request.text``; responds with
    HTTP 400 when the text is empty.
    """
    text = request.text
    if text:
        return classifier.predict(text)
    raise HTTPException(status_code=400, detail="请提供text字段")
# Script entry point: start the HTTP service on all interfaces, port 5000.
# NOTE(review): `app.run(...)` with `debug=` is the Flask idiom; FastAPI
# applications do not provide a `run()` method, so this call is expected to
# raise AttributeError. Serve the app with an ASGI server instead (for example
# `uvicorn <module>:app --host 0.0.0.0 --port 5000`) -- confirm and replace.
if __name__ == "__main__":
    app.run(host="0.0.0.0", port=5000, debug=False)
- Downloads last month
- 1
Inference Providers
NEW
This model isn't deployed by any Inference Provider.
🙋
Ask for provider support