Upload 4 files
Browse files- hmm_model.pkl.bz2 +3 -0
- hmm_model_large.pkl.bz2 +3 -0
- py2hz.py +144 -0
- readme.md +58 -0
hmm_model.pkl.bz2
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:dfdce3ba8dba202798b09e74a34950e54c5c9014dbdd5d89395f466ecd053f8b
|
| 3 |
+
size 2343356
|
hmm_model_large.pkl.bz2
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:2d5214166cb693789749314a5bfc362fd317b24ecf9ae339f4fc8ea78c7c5cc6
|
| 3 |
+
size 3808875
|
py2hz.py
ADDED
|
@@ -0,0 +1,144 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# _*_ coding:utf-8 _*_
|
| 2 |
+
"""
|
| 3 |
+
@Version : 1.0.0
|
| 4 |
+
@Time : 2024年12月27日
|
| 5 |
+
@Author : DuYu (@duyu09, 202103180009@stu.qlu.edu.cn)
|
| 6 |
+
@File : py2hz.py
|
| 7 |
+
@Describe : 基于隐马尔可夫模型(HMM)的拼音转汉字程序。
|
| 8 |
+
@Copyright: Copyright (c) 2024 DuYu (No.202103180009), Faculty of Computer Science & Technology, Qilu University of Technology (Shandong Academy of Sciences).
|
| 9 |
+
@Note : 训练集csv文件,要求第一列为由汉字构成的句子,第二列为由汉语拼音构成的句子(拼音之间以空格分隔)。
|
| 10 |
+
"""
|
| 11 |
+
|
| 12 |
+
import re
|
| 13 |
+
import bz2
|
| 14 |
+
import pickle
|
| 15 |
+
import numpy as np
|
| 16 |
+
import pandas as pd
|
| 17 |
+
from hmmlearn import hmm
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
# 1. 数据预处理:加载CSV数据集
|
| 21 |
+
def load_dataset(file_path):
    """Read the training CSV and return its first two columns as lists.

    Args:
        file_path: Anything accepted by ``pandas.read_csv``.

    Returns:
        A ``(sentences, pinyins)`` tuple. ``sentences`` is column 0 (the
        Chinese-character sentences) and ``pinyins`` is column 1 (the pinyin
        sentences); both are plain Python lists.
    """
    frame = pd.read_csv(file_path)
    hanzi_column = frame.iloc[:, 0]
    pinyin_column = frame.iloc[:, 1]
    return hanzi_column.tolist(), pinyin_column.tolist()
|
| 26 |
+
|
| 27 |
+
# 分词函数,确保英文单词保持完整
|
| 28 |
+
def segment_sentence(sentence):
    """Tokenize a sentence into Chinese characters and whole English words.

    A run of ASCII letters becomes one token; every character inside the
    pattern's CJK range becomes its own token. Characters matching neither
    alternative (digits, punctuation, whitespace, ...) are not returned.

    Args:
        sentence: The text to tokenize.

    Returns:
        The list of tokens in their order of appearance.
    """
    token_pattern = r'[a-zA-Z]+|[一-鿿]'
    return re.findall(token_pattern, sentence)
|
| 31 |
+
|
| 32 |
+
# 2. 构建字典和状态观测集合
|
| 33 |
+
def build_vocab(sentences, pinyins):
    """Build the token <-> integer-id lookup tables for states and observations.

    Chinese-side tokens come from ``segment_sentence``; pinyin tokens come
    from splitting each pinyin sentence on whitespace.

    Tokens are sorted before ids are assigned so that the same dataset always
    produces the same ids. Iterating a ``set`` of strings directly gives an
    order that changes with hash randomization, which made retraining
    non-reproducible.

    Args:
        sentences: Iterable of Chinese-character sentences.
        pinyins: Iterable of whitespace-separated pinyin sentences, aligned
            with ``sentences``.

    Returns:
        ``(hanzi2id, id2hanzi, pinyin2id, id2pinyin)`` dictionaries.
    """
    hanzi_tokens = set()
    pinyin_tokens = set()

    for sentence, pinyin in zip(sentences, pinyins):
        hanzi_tokens.update(segment_sentence(sentence))
        pinyin_tokens.update(pinyin.split())

    # Sort for a deterministic id assignment across runs.
    id2hanzi = dict(enumerate(sorted(hanzi_tokens)))
    id2pinyin = dict(enumerate(sorted(pinyin_tokens)))

    hanzi2id = {token: idx for idx, token in id2hanzi.items()}
    pinyin2id = {token: idx for idx, token in id2pinyin.items()}

    return hanzi2id, id2hanzi, pinyin2id, id2pinyin
|
| 50 |
+
|
| 51 |
+
# 3. 模型训练
|
| 52 |
+
def train_hmm(sentences, pinyins, hanzi2id, pinyin2id):
    """Create an HMM whose parameters are normalized counts from aligned data.

    Hidden states are Chinese-side tokens and observations are pinyin tokens.
    The start, transition and emission matrices are filled by counting over
    the corpus and then normalized row-wise; no iterative fitting is run here.

    Args:
        sentences: Iterable of Chinese-character sentences.
        pinyins: Iterable of whitespace-separated pinyin sentences, aligned
            with ``sentences``.
        hanzi2id: Mapping from Chinese-side token to state id.
        pinyin2id: Mapping from pinyin token to observation id.

    Returns:
        An ``hmm.MultinomialHMM`` with ``startprob_``, ``transmat_`` and
        ``emissionprob_`` assigned.

    Raises:
        KeyError: If a token is missing from ``hanzi2id`` / ``pinyin2id``.
    """
    n_states = len(hanzi2id)
    n_observations = len(pinyin2id)

    # n_iter / tol are kept from the original constructor call; this function
    # never calls fit(), the parameters are assigned directly below.
    model = hmm.MultinomialHMM(n_components=n_states, n_iter=100, tol=1e-4)

    # Raw counts for start, transition and emission events.
    start_prob = np.zeros(n_states)
    trans_prob = np.zeros((n_states, n_states))
    emit_prob = np.zeros((n_states, n_observations))

    for sentence, pinyin in zip(sentences, pinyins):
        hanzi_seq = [hanzi2id[h] for h in segment_sentence(sentence)]
        pinyin_seq = [pinyin2id[p] for p in pinyin.split()]

        # A sentence that yields no tokens has no first state; skip it
        # instead of raising IndexError on hanzi_seq[0].
        if not hanzi_seq:
            continue

        # Initial-state count.
        start_prob[hanzi_seq[0]] += 1

        # Transition counts over consecutive state pairs.
        for prev_state, next_state in zip(hanzi_seq, hanzi_seq[1:]):
            trans_prob[prev_state, next_state] += 1

        # Emission counts. zip() stops at the shorter sequence.
        # NOTE(review): segment_sentence keeps only letters and CJK
        # characters while pinyin.split() keeps every whitespace-separated
        # token, so the two sequences can differ in length or alignment --
        # confirm the dataset guarantees a one-to-one pairing.
        for state, observation in zip(hanzi_seq, pinyin_seq):
            emit_prob[state, observation] += 1

    # Normalize; rows with no counts fall back to a uniform distribution.
    if start_prob.sum() == 0:
        start_prob += 1
    start_prob /= start_prob.sum()

    row_sums = trans_prob.sum(axis=1, keepdims=True)
    zero_rows = (row_sums == 0).flatten()
    trans_prob[zero_rows, :] = 1.0 / n_states
    trans_prob /= trans_prob.sum(axis=1, keepdims=True)

    emit_sums = emit_prob.sum(axis=1, keepdims=True)
    zero_emit_rows = (emit_sums == 0).flatten()
    emit_prob[zero_emit_rows, :] = 1.0 / n_observations
    emit_prob /= emit_prob.sum(axis=1, keepdims=True)

    model.startprob_ = start_prob
    model.transmat_ = trans_prob
    model.emissionprob_ = emit_prob

    return model
|
| 99 |
+
|
| 100 |
+
# 4. 保存和加载模型
|
| 101 |
+
def save_model(model, filepath):
    """Pickle ``model`` and write it, bz2-compressed, to ``filepath``.

    Args:
        model: Any picklable object.
        filepath: Destination path of the compressed pickle.
    """
    with bz2.open(filepath, 'wb') as sink:
        pickle.dump(model, sink)
|
| 104 |
+
|
| 105 |
+
def load_model(filepath):
    """Load an object from a bz2-compressed pickle file.

    SECURITY: unpickling can execute arbitrary code. Only call this on
    model files obtained from a source you trust.

    Args:
        filepath: Path of the compressed pickle.

    Returns:
        The unpickled object.
    """
    with bz2.open(filepath, 'rb') as source:
        model = pickle.load(source)
    return model
|
| 108 |
+
|
| 109 |
+
def predict(model, pinyin_seq, pinyin2id, id2hanzi):
    """Decode a pinyin token sequence into a string of Chinese-side tokens.

    Each pinyin token is encoded as a one-hot row and the model's Viterbi
    decoder is asked for the most likely state sequence.

    Side effect: sets ``model.n_trials`` to 1.

    Args:
        model: Object exposing ``decode(X, algorithm=...)`` that returns
            ``(log_prob, state_sequence)``.
        pinyin_seq: Sequence of pinyin tokens.
        pinyin2id: Mapping from pinyin token to observation column.
        id2hanzi: Mapping from state id to output token.

    Returns:
        The decoded tokens joined into one string; ``''`` for empty input.
    """
    # With no observations there is nothing to decode.
    if len(pinyin_seq) == 0:
        return ''

    # One-hot encode the observations. Unknown pinyin tokens fall back to
    # column 0 (an arbitrary known token), as before.
    obs_seq = np.zeros((len(pinyin_seq), len(pinyin2id)))
    for t, token in enumerate(pinyin_seq):
        obs_seq[t, pinyin2id.get(token, 0)] = 1

    # n_trials is the number of draws per observation row, not a number of
    # decoding runs; every one-hot row is exactly one draw.
    # NOTE(review): recent hmmlearn versions appear to require n_trials to be
    # set before decode() -- confirm against the installed version.
    model.n_trials = 1
    _, state_seq = model.decode(obs_seq, algorithm='viterbi')
    return ''.join(id2hanzi[s] for s in state_seq)
|
| 123 |
+
|
| 124 |
+
|
| 125 |
+
def train(dataset_path='train.csv', model_path='hmm_model.pkl.bz2'):
    """Train a model from a CSV dataset and save it as a compressed pickle.

    The four vocabulary mappings are stored as attributes on the model so
    that they are saved alongside it.

    Args:
        dataset_path: Path of the training CSV.
        model_path: Destination path for the saved model.
    """
    sentences, pinyins = load_dataset(dataset_path)
    vocab = build_vocab(sentences, pinyins)
    hanzi2id, id2hanzi, pinyin2id, id2pinyin = vocab

    model = train_hmm(sentences, pinyins, hanzi2id, pinyin2id)

    # Attach the lookup tables so prediction needs nothing but the model file.
    attached = (
        ('pinyin2id', pinyin2id),
        ('id2hanzi', id2hanzi),
        ('hanzi2id', hanzi2id),
        ('id2pinyin', id2pinyin),
    )
    for attr_name, mapping in attached:
        setattr(model, attr_name, mapping)

    save_model(model, model_path)
|
| 134 |
+
|
| 135 |
+
|
| 136 |
+
def pred(model_path='hmm_model.pkl.bz2', pinyin_str='ce4 shi4'):
    """Load a saved model, decode one pinyin sentence and print the result.

    Args:
        model_path: Path of the saved model file.
        pinyin_str: Whitespace-separated pinyin tokens to convert.
    """
    model = load_model(model_path)
    tokens = pinyin_str.split()
    decoded = predict(model, tokens, model.pinyin2id, model.id2hanzi)
    print('预测结果:', decoded)
|
| 141 |
+
|
| 142 |
+
if __name__ == '__main__':
    # To train your own model first, uncomment the next line and adjust
    # its arguments.
    # train(dataset_path='train.csv', model_path='hmm_model.pkl.bz2')
    demo_model = 'hmm_model.pkl.bz2'
    demo_pinyin = 'hong2 yan2 bo2 ming4'
    pred(model_path=demo_model, pinyin_str=demo_pinyin)
|
readme.md
ADDED
|
@@ -0,0 +1,58 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
## 基于预训练隐马尔可夫模型的汉语拼音序列转汉字语句序列程序
|
| 2 |
+
|
| 3 |
+
**Chinese Pinyin (Hanyu pinyin) to Chinese Character (Hanzi) Conversion Program Based on Pretrained Hidden Markov Model (HMM)**
|
| 4 |
+
|
| 5 |
+
### 项目原理
|
| 6 |
+
|
| 7 |
+
本项目基于 **隐马尔可夫模型 (HMM)** 实现汉语拼音序列到汉字序列的转换。HMM是一种概率模型,假设观察序列(拼音)由隐藏状态序列(汉字)生成,并通过状态转移和发射概率描述序列关系。模型训练时,程序首先加载训练数据,提取拼音和汉字构建词汇表,并统计初始状态、状态转移和发射概率矩阵。训练过程中,HMM使用**最大似然估计**优化这些概率,以捕捉拼音与汉字的映射关系。解码阶段,利用**维特比算法 (Viterbi Algorithm)** 寻找最可能的汉字序列作为输出结果。本项目适用于语言序列建模和序列标注等任务。
|
| 8 |
+
|
| 9 |
+
### 数据集准备
|
| 10 |
+
|
| 11 |
+
- 需要`CSV`格式文件,其应包含两列,要求第一列为由汉字构成的句子,第二列为由汉语拼音构成的句子(拼音之间以空格分隔)。
|
| 12 |
+
- 数据集示例:
|
| 13 |
+
|
| 14 |
+
| 第一列 (汉字语句) | 第二列 (拼音语句) |
|
| 15 |
+
| ----- | ----- |
|
| 16 |
+
| 我们试试看! | wo3 men shi4 shi4 kan4 ! |
|
| 17 |
+
| 我该去睡觉了。 | wo3 gai1 qu4 shui4 jiao4 le 。 |
|
| 18 |
+
| 你在干什么啊? | ni3 zai4 gan4 shen2 me a ? |
|
| 19 |
+
| 这是什么啊? | zhe4 shi4 shen2 me a ? |
|
| 20 |
+
| 我会尽量不打扰你复习。 | wo3 hui4 jin3 liang4 bu4 da3 rao3 ni3 fu4 xi2 。 |
|
| 21 |
+
|
| 22 |
+
### 训练和推理方法
|
| 23 |
+
|
| 24 |
+
修改`py2hz.py`的主函数代码以运行。我们已开源了基于多领域文本预训练的模型权重`hmm_model.pkl.bz2`和`hmm_model_large.pkl.bz2`,可以直接使用。`hmm_model.pkl.bz2`规模稍小,可满足日常汉语的转换需求,其解压缩后约为800MB;`hmm_model_large.pkl.bz2`覆盖了几乎所有汉字的读音,并在规模更大的语料库上进行训练,其解压缩后约为4.5GB。
|
| 25 |
+
|
| 26 |
+
若需自行训练则取消train函数的注释,并修改函数参数。训练完成后模型将会被压缩保存,原因是模型中存在非常稀疏的大矩阵,适合压缩存储。
|
| 27 |
+
|
| 28 |
+
```python
|
| 29 |
+
# dataset_path:数据集路径
|
| 30 |
+
# model_path:模型保存路径
|
| 31 |
+
# pinyin_str:待解析的拼音语句。
|
| 32 |
+
if __name__ == '__main__':
|
| 33 |
+
train(dataset_path='train.csv', model_path='hmm_model.pkl.bz2')
|
| 34 |
+
pred(model_path='hmm_model.pkl.bz2', pinyin_str='hong2 yan2 bo2 ming4')
|
| 35 |
+
```
|
| 36 |
+
|
| 37 |
+
### 预训练模型效果
|
| 38 |
+
|
| 39 |
+
下表展示了预训练模型`hmm_model.pkl.bz2`的使用效果。
|
| 40 |
+
|
| 41 |
+
| 输入 | 输出 |
|
| 42 |
+
| ----- | ----- |
|
| 43 |
+
| hong2 yan2 bo2 ming4 | 红颜薄命 |
|
| 44 |
+
| guo2 jia1 chao1 suan4 ji3 nan2 zhong1 xin1 | 国家超算济南中心 |
|
| 45 |
+
| liu3 an4 hua1 ming2 you4 yi1 cun1 | 柳暗花明又一村 |
|
| 46 |
+
| gu3 zhi4 shu1 song1 zheng4 | 骨质疏松症 |
|
| 47 |
+
| xi1 an1 dian4 zi3 ke1 ji4 da4 xue2 | 西安电子科技大学 |
|
| 48 |
+
| ye4 mian4 zhi4 huan4 suan4 fa3 | 页面置换算法 |
|
| 49 |
+
|
| 50 |
+
### 作者声明及访客统计
|
| 51 |
+
|
| 52 |
+
Author: Du Yu (202103180009@stu.qlu.edu.cn),
|
| 53 |
+
Faculty of Computer Science and Technology, Qilu University of Technology (Shandong Academy of Sciences).
|
| 54 |
+
|
| 55 |
+
<div><b>Number of Total Visits (All of Duyu09's GitHub Projects): </b><br><img src="https://profile-counter.glitch.me/duyu09/count.svg" /></div>
|
| 56 |
+
|
| 57 |
+
<div><b>Number of Total Visits (py2hz): </b>
|
| 58 |
+
<br><img src="https://profile-counter.glitch.me/py2hz/count.svg" /></div>
|