raptorkwok committed on
Commit
ddeed21
·
1 Parent(s): fe6f409

change tokenization strategy to PyCantonese

Browse files
Files changed (2) hide show
  1. chinesemeteor.py +9 -2
  2. requirements.txt +2 -1
chinesemeteor.py CHANGED
@@ -43,6 +43,7 @@ from nltk import word_tokenize
43
  import nltk
44
  import evaluate
45
  import re
 
46
 
47
  # Download once
48
  nltk.download("wordnet", quiet=True)
@@ -151,6 +152,10 @@ class ChineseMETEOR(evaluate.Metric):
151
  nltk.download('punkt_tab', quiet=True)
152
  # CwnGraph auto-downloads on first use
153
 
 
 
 
 
154
  def _compute(self, predictions: List[str], references: List[str]) -> Dict[str, float]:
155
  pred_seg = [" ".join(jieba.cut(p.strip())) for p in predictions]
156
  ref_seg = [" ".join(jieba.cut(r.strip())) for r in references]
@@ -218,8 +223,10 @@ class ChineseMETEOR(evaluate.Metric):
218
 
219
  scores = [
220
  meteor_score.single_meteor_score(
221
- word_tokenize(ref),
222
- word_tokenize(hyp),
 
 
223
  wordnet=chinese_wn
224
  )
225
  for ref, hyp in zip(ref_seg, pred_seg)
 
43
  import nltk
44
  import evaluate
45
  import re
46
+ import pycantonese
47
 
48
  # Download once
49
  nltk.download("wordnet", quiet=True)
 
152
  nltk.download('punkt_tab', quiet=True)
153
  # CwnGraph auto-downloads on first use
154
 
155
+ def _tokenize_chinese(self, sentence):
156
+ """Tokenize Chinese sentence using PyCantonese"""
157
+ return pycantonese.segment(sentence)
158
+
159
  def _compute(self, predictions: List[str], references: List[str]) -> Dict[str, float]:
160
  pred_seg = [" ".join(jieba.cut(p.strip())) for p in predictions]
161
  ref_seg = [" ".join(jieba.cut(r.strip())) for r in references]
 
223
 
224
  scores = [
225
  meteor_score.single_meteor_score(
226
+ #word_tokenize(ref),
227
+ self._tokenize_chinese(ref),
228
+ #word_tokenize(hyp),
229
+ self._tokenize_chinese(hyp),
230
  wordnet=chinese_wn
231
  )
232
  for ref, hyp in zip(ref_seg, pred_seg)
requirements.txt CHANGED
@@ -2,4 +2,5 @@ evaluate>=0.4.1
2
  jieba_fast
3
  CwnGraph>=0.3.0
4
  nltk>=3.8
5
- numpy
 
 
2
  jieba_fast
3
  CwnGraph>=0.3.0
4
  nltk>=3.8
5
+ numpy
6
+ pycantonese