qixun committed on
Commit
9b707b4
·
verified ·
1 Parent(s): fb8a941

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +79 -55
README.md CHANGED
@@ -10,81 +10,105 @@ tags:
10
  使用方法如下:
11
 
12
 
13
- import torch
14
- from transformers import AutoTokenizer, AutoModelForSequenceClassification
15
- import json
16
- import torch.nn.functional as F
17
- from zhconv import convert
18
- import re
19
 
20
- model_path = "qixun/qilv_classify"
21
 
22
- # 加载模型和分词器
23
- tokenizer = AutoTokenizer.from_pretrained(model_path)
24
- model = AutoModelForSequenceClassification.from_pretrained(model_path)
25
 
26
- # 如果GPU可用,将模型移动到GPU
27
- #device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
28
- #model.to(device)
29
 
30
- # 加载标签映射关系,label_mapping.json需要根据本机情况修改
31
- with open("label_mapping.json", "r", encoding="utf-8") as f:
32
- label_mapping = json.load(f)
33
 
34
 
35
- def classify_text(text):
36
 
37
- text = convert(text, 'zh-cn')
38
- # 去掉空格和换行
39
- text = text.replace(" ", "").replace("\n", "")
40
 
41
- # 检查文本长度是否为56个字符
42
- if len(text) != 64:
43
- return "请输入一首带标点的七言律诗"
44
 
45
- unique_characters = set(re.findall(r'[\u4e00-\u9fff]', text))
46
- if len(unique_characters) < 30:
47
- return "请输入一首正常的七言律诗"
48
 
49
- # 准备输入数据
50
- inputs = tokenizer(text, padding=True, truncation=True, return_tensors="pt", max_length=512)
51
 
52
- # 将输入数据移动到GPU
53
- inputs = {key: value.to(device) for key, value in inputs.items()}
54
 
55
- # 模型推断
56
- with torch.no_grad():
57
- outputs = model(**inputs)
58
 
59
- # 获取预测结果
60
- logits = outputs.logits
61
 
62
- # 计算每个类别的概率
63
- probabilities = F.softmax(logits, dim=-1)
64
 
65
- # 获取概率最高的三个分类及其概率
66
- top_k = 3
67
- top_probs, top_indices = torch.topk(probabilities, top_k, dim=-1)
68
 
69
- # 将预测结果转换为标签并附上概率
70
- results = []
71
- for j in range(top_k):
72
- label = label_mapping[str(top_indices[0][j].item())]
73
- prob = top_probs[0][j].item()
74
- results.append((label, prob))
75
 
76
- # 将结果格式化为字符串
77
- result_str = "文本: {}\n".format(text)
78
- for label, prob in results:
79
- result_str += "分类: {}, 概率: {:.4f}\n".format(label, prob)
80
 
81
- return result_str
82
 
83
- # 示例调用
84
- text = "胎禽消息渺难知,小萼妆容故故迟。城郭渐随寒碧敛,湖山刚与晚阴宜,再来恐或成孤往,此去何由问所之。坐对空亭喧冻雀,可堪暝色向人垂。"
85
- result = classify_text(text)
86
- print(result)
87
 
88
 
89
 
90
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10
  使用方法如下:
11
 
12
 
13
+ import torch
14
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification
15
+ import json
16
+ import torch.nn.functional as F
17
+ from zhconv import convert
18
+ import re
19
 
20
+ model_path = "qixun/qilv_classify"
21
 
22
+ # 加载模型和分词器
23
+ tokenizer = AutoTokenizer.from_pretrained(model_path)
24
+ model = AutoModelForSequenceClassification.from_pretrained(model_path)
25
 
26
+ # 如果GPU可用,将模型移动到GPU
27
+ #device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
28
+ #model.to(device)
29
 
30
+ # 加载标签映射关系;label_mapping.json 的路径需要根据本机实际存放位置修改
31
+ with open("label_mapping.json", "r", encoding="utf-8") as f:
32
+ label_mapping = json.load(f)
33
 
34
 
35
+ def classify_text(text):
36
 
37
+ text = convert(text, 'zh-cn')
38
+ # 去掉空格和换行
39
+ text = text.replace(" ", "").replace("\n", "")
40
 
41
+ # 检查文本长度是否为64个字符(56个汉字加8个标点)
42
+ if len(text) != 64:
43
+ return "请输入一首带标点的七言律诗"
44
 
45
+ unique_characters = set(re.findall(r'[\u4e00-\u9fff]', text))
46
+ if len(unique_characters) < 30:
47
+ return "请输入一首正常的七言律诗"
48
 
49
+ # 准备输入数据
50
+ inputs = tokenizer(text, padding=True, truncation=True, return_tensors="pt", max_length=512)
51
 
52
+ # 将输入数据移动到模型所在设备(注意:需先取消上文 device 定义及 model.to(device) 两行的注释,否则此处 device 未定义会报 NameError)
53
+ inputs = {key: value.to(device) for key, value in inputs.items()}
54
 
55
+ # 模型推断
56
+ with torch.no_grad():
57
+ outputs = model(**inputs)
58
 
59
+ # 获取预测结果
60
+ logits = outputs.logits
61
 
62
+ # 计算每个类别的概率
63
+ probabilities = F.softmax(logits, dim=-1)
64
 
65
+ # 获取概率最高的三个分类及其概率
66
+ top_k = 3
67
+ top_probs, top_indices = torch.topk(probabilities, top_k, dim=-1)
68
 
69
+ # 将预测结果转换为标签并附上概率
70
+ results = []
71
+ for j in range(top_k):
72
+ label = label_mapping[str(top_indices[0][j].item())]
73
+ prob = top_probs[0][j].item()
74
+ results.append((label, prob))
75
 
76
+ # 将结果格式化为字符串
77
+ result_str = "文本: {}\n".format(text)
78
+ for label, prob in results:
79
+ result_str += "分类: {}, 概率: {:.4f}\n".format(label, prob)
80
 
81
+ return result_str
82
 
83
+ # 示例调用
84
+ text = "胎禽消息渺难知,小萼妆容故故迟。城郭渐随寒碧敛,湖山刚与晚阴宜,再来恐或成孤往,此去何由问所之。坐对空亭喧冻雀,可堪暝色向人垂。"
85
+ result = classify_text(text)
86
+ print(result)
87
 
88
 
89
 
90
 
91
+
92
+
93
+ 也可以直接在 Hugging Face 页面上输入一首含标点共64个字符的简体七言律诗进行测试。label_mapping.json 的内容为:
94
+
95
+ {
96
+ "0": "中唐",
97
+ "1": "乱码",
98
+ "2": "冲塔",
99
+ "3": "同光",
100
+ "4": "复兴",
101
+ "5": "实验",
102
+ "6": "晚唐",
103
+ "7": "江西",
104
+ "8": "浙",
105
+ "9": "浣花",
106
+ "10": "理学",
107
+ "11": "盛唐",
108
+ "12": "艳体",
109
+ "13": "诗界xx",
110
+ "14": "赣",
111
+ "15": "闽"
112
+ }
113
+
114
+ 大家可按上述映射,自行将模型输出的类别编号转换为对应的标签。