---
license: gpl-3.0
pipeline_tag: text-classification
tags:
- art
widget:
- text: "牛犊初生敢问天,为官一任史无前。钎锤巧构蓝图景,岩壁砺磨钢铁肩。玉汝于成堪大智,红旗永艳有群贤。十风五雨千秋业,铸就惊天动地篇。"
- text: "胎禽消息渺难知,小萼妆容故故迟。城郭渐随寒碧敛,湖山刚与晚阴宜,再来恐或成孤往,此去何由问所之。坐对空亭喧冻雀,可堪暝色向人垂。"
- text: "异域风吹残帜斜,呜呼水木不清华。未闻史载分赃制,时见官乘夺路槎。有术掠民腾物价,无能让土息胡笳。两朝竭力推经济,遍地催开血色花。"
---

This model classifies Simplified Chinese seven-character regulated verse (七言律诗, qilü) by style. For details, see https://mp.weixin.qq.com/s/P8FVCkI8-anDuLWQIAgs2w

Usage:

```python
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import json
import torch.nn.functional as F
from zhconv import convert
import re

model_path = "qixun/qilv_classify"

# Load the model and tokenizer
tokenizer = AutoTokenizer.from_pretrained(model_path)
model = AutoModelForSequenceClassification.from_pretrained(model_path)

# If a GPU is available, move the model to it
#device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
#model.to(device)

# Load the label mapping; adjust the path to label_mapping.json for your machine
with open("label_mapping.json", "r", encoding="utf-8") as f:
    label_mapping = json.load(f)


def classify_text(text):
    # Convert the input to Simplified Chinese
    text = convert(text, 'zh-cn')
    # Remove spaces and newlines
    text = text.replace(" ", "").replace("\n", "")

    # Check that the text is 64 characters long
    # (56 characters of verse plus 8 punctuation marks)
    if len(text) != 64:
        return "Please enter a punctuated seven-character regulated verse"

    # Reject inputs with fewer than 30 distinct Chinese characters
    unique_characters = set(re.findall(r'[\u4e00-\u9fff]', text))
    if len(unique_characters) < 30:
        return "Please enter a well-formed seven-character regulated verse"

    # Prepare the model inputs
    inputs = tokenizer(text, padding=True, truncation=True, return_tensors="pt", max_length=512)

    # If a GPU is available, move the inputs to it
    #inputs = {key: value.to(device) for key, value in inputs.items()}

    # Run inference
    with torch.no_grad():
        outputs = model(**inputs)

    # Get the raw predictions
    logits = outputs.logits

    # Compute the probability of each class
    probabilities = F.softmax(logits, dim=-1)

    # Take the three most probable classes and their probabilities
    top_k = 3
    top_probs, top_indices = torch.topk(probabilities, top_k, dim=-1)

    # Convert the predicted indices to labels and attach the probabilities
    results = []
    for j in range(top_k):
        label = label_mapping[str(top_indices[0][j].item())]
        prob = top_probs[0][j].item()
        results.append((label, prob))

    # Format the results as a string
    result_str = "Text: {}\n".format(text)
    for label, prob in results:
        result_str += "Label: {}, probability: {:.4f}\n".format(label, prob)

    return result_str


# Example call
text = "胎禽消息渺难知,小萼妆容故故迟。城郭渐随寒碧敛,湖山刚与晚阴宜,再来恐或成孤往,此去何由问所之。坐对空亭喧冻雀,可堪暝色向人垂。"
result = classify_text(text)
print(result)
```
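
The length check in the snippet above expects exactly 64 characters, which corresponds to 8 lines of 7 characters (56) plus 8 punctuation marks. The following is a minimal, standard-library-only sketch for inspecting an input before sending it to the model. The helper name `check_qilv_shape` is hypothetical and not part of the model repository, and unlike the snippet above it does not convert Traditional to Simplified Chinese.

```python
import re

# Same character range as the uniqueness check in the usage snippet.
HANZI = re.compile(r"[\u4e00-\u9fff]")


def check_qilv_shape(text: str) -> dict:
    """Report the counts that the 64-character check relies on (hypothetical helper)."""
    # Strip spaces and newlines, as the usage snippet does.
    cleaned = text.replace(" ", "").replace("\n", "")
    hanzi = HANZI.findall(cleaned)
    return {
        "total": len(cleaned),                    # expected: 64
        "hanzi": len(hanzi),                      # expected: 56
        "punctuation": len(cleaned) - len(hanzi), # expected: 8
        "unique_hanzi": len(set(hanzi)),          # the snippet requires at least 30
    }


poem = "胎禽消息渺难知,小萼妆容故故迟。城郭渐随寒碧敛,湖山刚与晚阴宜,再来恐或成孤往,此去何由问所之。坐对空亭喧冻雀,可堪暝色向人垂。"
print(check_qilv_shape(poem))
```

If `total` is not 64, the most likely causes are missing punctuation or extra characters such as a title or author line.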

You can also test the model directly on Hugging Face by entering a Simplified Chinese seven-character regulated verse that is 64 characters long including punctuation. The contents of `label_mapping.json` are:

    {
      "0": "中唐",
      "1": "乱码",
      "2": "冲塔",
      "3": "同光",
      "4": "复兴",
      "5": "实验",
      "6": "晚唐",
      "7": "江西",
      "8": "浙",
      "9": "浣花",
      "10": "理学",
      "11": "盛唐",
      "12": "艳体",
      "13": "诗界xx",
      "14": "赣",
      "15": "闽"
    }

Use this mapping to convert the predicted class indices to labels yourself.
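
One way to do that conversion is sketched below, using only the standard library. It embeds the mapping listed above, can write it to `label_mapping.json` for the usage snippet to load, and looks up a label by class index. The names `LABELS`, `save_label_mapping`, and `index_to_label` are hypothetical and chosen for illustration.

```python
import json
from pathlib import Path

# The index-to-style mapping exactly as listed in this model card.
LABELS = {
    "0": "中唐", "1": "乱码", "2": "冲塔", "3": "同光",
    "4": "复兴", "5": "实验", "6": "晚唐", "7": "江西",
    "8": "浙", "9": "浣花", "10": "理学", "11": "盛唐",
    "12": "艳体", "13": "诗界xx", "14": "赣", "15": "闽",
}


def save_label_mapping(path="label_mapping.json"):
    """Write the mapping as UTF-8 JSON so the usage snippet can load it."""
    Path(path).write_text(
        json.dumps(LABELS, ensure_ascii=False, indent=2), encoding="utf-8"
    )


def index_to_label(index, mapping=LABELS):
    """Convert a predicted class index (int or str) to its style label."""
    return mapping[str(index)]


if __name__ == "__main__":
    save_label_mapping()
    print(index_to_label(11))
```

The keys are strings because JSON object keys are always strings, which is why the usage snippet wraps the predicted index in `str(...)` before the lookup.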