File size: 2,146 Bytes
9d7cae5
 
10e1dd3
 
 
c89c192
 
 
 
 
ce0e702
2af4cbe
e674c9e
 
 
7eb63ce
e674c9e
 
 
2af4cbe
e674c9e
 
2af4cbe
 
e674c9e
 
 
 
 
 
 
 
 
 
7eb63ce
 
 
1105bc7
8d184e5
7eb63ce
 
 
 
 
8d184e5
67a51e8
9baba9f
 
c89c192
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
---
license: apache-2.0
metrics:
- accuracy
pipeline_tag: image-classification
language:
- zh
tags:
- aversarial attack
- Chinese text
---
# Siamese CNN
复现[Argot: Generating Adversarial Readable Chinese Texts IJCAI 2020](https://www.ijcai.org/Proceedings/2020/351) 字形变换的相似结构汉字筛选
## 介绍

基于CNN架构,采用孪生网络训练方式,对输入的汉字对进行编码并计算其欧式距离作为汉字字形相似度度量

![image/png](https://cdn-uploads.huggingface.co/production/uploads/637cd8039a5217b88b72b71c/e5w-gARTB_WgV3Ixg0K-Y.png)

## 架构
三层Conv2D 大小(Input_channel, output_channel, filter_size)= (3,64,8),(64,128,8),(128,128,8)  
每层卷积层后添加MaxPool(2)  
lr = 0.002
## 数据集
汉字来源:https://github.com/zzboy/chinese  
采用pygame生成图片数据,默认采用黑体字体,图片大小为200*200  
上述汉字每行作为相似字符,按照7:3划分数据集  
并参考https://github.com/avilash/pytorch-siamese-triplet 生成三元组训练数据、测试数据,实际训练、测试时采用50000对、10000对三元组数据
## 评估
loss = MarginRankingLoss(margin=1)
|   0% of margin    |    20% of margin             |      50% of margin     | loss |epoch|   
| :--------------- | :---------------------- | :---|:---|:-----|    
|0.9012                  |0.7998   | 0.5700| 0.4674| 10|    
  
0% of margin 相当于准确率  
## 使用
采用Pytorch加载,一般用加载一个CNN模型就可以使用,注意删除state_dict中的字典名字  
```py
  model_dict = torch.load('./checkpoint.pth')['state_dict']
  model_dict_mod = {}
  for key, value in model_dict.items():
    new_key = '.'.join(key.split('.')[1:])
    model_dict_mod[new_key] = value
  self.model.load_state_dict(model_dict_mod)
```
## 文件介绍
 ` prepare_data.py ` 生成数据集,将汉字转换为图片,默认黑体字体,也可以用别的,从C://Windows/Fonts Windows系统上下载  
`character`文件表示训练数据,同一个子文件夹表示其中的汉字是相似的,不同的子文件夹表示汉字不相似  
`train.py`为训练文件