rkingzhong commited on
Commit
b5f497f
·
verified ·
1 Parent(s): 6de60c3

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +56 -3
README.md CHANGED
@@ -1,3 +1,56 @@
1
- ---
2
- license: apache-2.0
3
- ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ license: apache-2.0
3
+ ---
4
+
5
+ github repo: [github.com/ximeiorg/ochw](https://github.com/ximeiorg/ochw)
6
+
7
+ ![image/png](https://cdn-uploads.huggingface.co/production/uploads/65090e91f6a2bf0be7b61f22/JE6RBeidskDuD5Zgj-r9a.png)
8
+ python 推理代码:
9
+
10
+ ```python
11
+ def get_labels():
12
+ labels = []
13
+ with open("data/label.txt", "r", encoding="utf-8") as f:
14
+ for line in f:
15
+ # line: ! 0
16
+ line = line.strip()
17
+ label = line.split("\t")[0]
18
+ labels.append(label)
19
+ return labels
20
+
21
+ if __name__ == "__main__":
22
+ model = HandwritingTrainer.load_from_checkpoint("checkpoint-epoch=32-val_loss=0.156.ckpt")
23
+ model.eval()
24
+ model = model.to("cuda")
25
+ img = Image.open("./testdata/hui.png")
26
+ img = img.convert("RGB")
27
+ img = img.resize((96,96))
28
+ rans = transforms.Compose([
29
+ transforms.Resize((96, 96)),
30
+ transforms.ToTensor(),
31
+ transforms.Normalize(mean=[0.95], std=[0.2])
32
+ ])
33
+ img = trans(img)
34
+ img = img.unsqueeze(0)
35
+ img = img.to("cuda")
36
+ labels = get_labels()
37
+ with torch.no_grad():
38
+ output = model(img)
39
+ output = torch.nn.functional.softmax(output,dim=1)
40
+ # 获取top5的预测结果
41
+ top5_prob, top5_idx = torch.topk(output, 5)
42
+ top5_prob = top5_prob.cpu().numpy()
43
+ top5_idx = top5_idx.cpu().numpy()
44
+ for i in range(5):
45
+ idx = top5_idx[0][i]
46
+ print(f"Top {i+1} 预测标签: {labels[idx]}, 概率: {top5_prob[0][i]:.4f}")
47
+
48
+ ```
49
+
50
+ 得到的结果如下:
51
+ ```
52
+ Top 1 预测标签: 知, 概率: 0.9505
53
+ Top 2 预测标签: 勉, 概率: 0.0095
54
+ Top 3 预测标签: 贮, 概率: 0.0025
55
+ Top 4 预测标签: 处, 概率: 0.0025
56
+ Top 5 预测标签: ‰, 概率: 0.0025