File size: 20,805 Bytes
d596074 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 558 559 560 561 562 563 564 565 566 567 568 569 570 571 572 573 574 575 576 577 578 579 580 581 582 583 584 585 586 587 588 589 590 591 592 593 594 595 596 597 598 599 600 601 602 603 604 605 606 607 608 609 610 611 612 613 614 615 616 617 618 619 620 621 622 623 624 625 626 627 628 629 630 631 |
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
#
# Copyright 2021 Johns Hopkins University (Author: Ruizhe Huang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Usage:
./ngram_entropy_pruning.py \
-threshold 1e-8 \
-lm download/lm/4gram.arpa \
-write-lm download/lm/4gram_pruned_1e8.arpa
This file is from Kaldi `egs/wsj/s5/utils/lang/ngram_entropy_pruning.py`.
This is an implementation of ``Entropy-based Pruning of Backoff Language Models''
in the same way as SRILM.
"""
import argparse
import gzip
import logging
import math
import re
from collections import OrderedDict, defaultdict
from enum import Enum, unique
from io import StringIO
parser = argparse.ArgumentParser(
description="""
Prune an n-gram language model based on the relative entropy
between the original and the pruned model, based on Andreas Stolcke's paper.
An n-gram entry is removed, if the removal causes (training set) perplexity
of the model to increase by less than threshold relative.
The command takes an arpa file and a pruning threshold as input,
and outputs a pruned arpa file.
"""
)
parser.add_argument("-threshold", type=float, default=1e-6, help="Order of n-gram")
parser.add_argument("-lm", type=str, default=None, help="Path to the input arpa file")
parser.add_argument(
"-write-lm", type=str, default=None, help="Path to output arpa file after pruning"
)
parser.add_argument(
"-minorder",
type=int,
default=1,
help="The minorder parameter limits pruning to ngrams of that length and above.",
)
parser.add_argument(
"-encoding", type=str, default="utf-8", help="Encoding of the arpa file"
)
parser.add_argument(
"-verbose",
type=int,
default=2,
choices=[0, 1, 2, 3, 4, 5],
help="Verbose level, where 0 is most noisy; 5 is most silent",
)
args = parser.parse_args()
default_encoding = args.encoding
logging.basicConfig(
format="%(asctime)s — %(levelname)s — %(funcName)s:%(lineno)d — %(message)s",
level=args.verbose * 10,
)
class Context(dict):
"""
This class stores data for a context h.
It behaves like a python dict object, except that it has several
additional attributes.
"""
def __init__(self):
super().__init__()
self.log_bo = None
class Arpa:
"""
This is a class that implement the data structure of an APRA LM.
It (as well as some other classes) is modified based on the library
by Stefan Fischer:
https://github.com/sfischer13/python-arpa
"""
UNK = "<unk>"
SOS = "<s>"
EOS = "</s>"
FLOAT_NDIGITS = 7
base = 10
@staticmethod
def _check_input(my_input):
if not my_input:
raise ValueError
elif isinstance(my_input, tuple):
return my_input
elif isinstance(my_input, list):
return tuple(my_input)
elif isinstance(my_input, str):
return tuple(my_input.strip().split(" "))
else:
raise ValueError
@staticmethod
def _check_word(input_word):
if not isinstance(input_word, str):
raise ValueError
if " " in input_word:
raise ValueError
def _replace_unks(self, words):
return tuple((w if w in self else self._unk) for w in words)
def __init__(self, path=None, encoding=None, unk=None):
self._counts = OrderedDict()
self._ngrams = (
OrderedDict()
) # Use self._ngrams[len(h)][h][w] for saving the entry of (h,w)
self._vocabulary = set()
if unk is None:
self._unk = self.UNK
if path is not None:
self.loadf(path, encoding)
def __contains__(self, ngram):
h = ngram[:-1] # h is a tuple
w = ngram[-1] # w is a string/word
return h in self._ngrams[len(h)] and w in self._ngrams[len(h)][h]
def contains_word(self, word):
self._check_word(word)
return word in self._vocabulary
def add_count(self, order, count):
self._counts[order] = count
self._ngrams[order - 1] = defaultdict(Context)
def update_counts(self):
for order in range(1, self.order() + 1):
count = sum([len(wlist) for _, wlist in self._ngrams[order - 1].items()])
if count > 0:
self._counts[order] = count
def add_entry(self, ngram, p, bo=None, order=None):
# Note: ngram is a tuple of strings, e.g. ("w1", "w2", "w3")
h = ngram[:-1] # h is a tuple
w = ngram[-1] # w is a string/word
# Note that p and bo here are in fact in the log domain (self.base = 10)
h_context = self._ngrams[len(h)][h]
h_context[w] = p
if bo is not None:
self._ngrams[len(ngram)][ngram].log_bo = bo
for word in ngram:
self._vocabulary.add(word)
def counts(self):
return sorted(self._counts.items())
def order(self):
return max(self._counts.keys(), default=None)
def vocabulary(self, sort=True):
if sort:
return sorted(self._vocabulary)
else:
return self._vocabulary
def _entries(self, order):
return (
self._entry(h, w)
for h, wlist in self._ngrams[order - 1].items()
for w in wlist
)
def _entry(self, h, w):
# return the entry for the ngram (h, w)
ngram = h + (w,)
log_p = self._ngrams[len(h)][h][w]
log_bo = self._log_bo(ngram)
if log_bo is not None:
return (
round(log_p, self.FLOAT_NDIGITS),
ngram,
round(log_bo, self.FLOAT_NDIGITS),
)
else:
return round(log_p, self.FLOAT_NDIGITS), ngram
def _log_bo(self, ngram):
if len(ngram) in self._ngrams and ngram in self._ngrams[len(ngram)]:
return self._ngrams[len(ngram)][ngram].log_bo
else:
return None
def _log_p(self, ngram):
h = ngram[:-1] # h is a tuple
w = ngram[-1] # w is a string/word
if h in self._ngrams[len(h)] and w in self._ngrams[len(h)][h]:
return self._ngrams[len(h)][h][w]
else:
return None
def log_p_raw(self, ngram):
log_p = self._log_p(ngram)
if log_p is not None:
return log_p
else:
if len(ngram) == 1:
raise KeyError
else:
log_bo = self._log_bo(ngram[:-1])
if log_bo is None:
log_bo = 0
return log_bo + self.log_p_raw(ngram[1:])
def log_joint_prob(self, sequence):
# Compute the joint prob of the sequence based on the chain rule
# Note that sequence should be a tuple of strings
#
# Reference:
# https://github.com/BitSpeech/SRILM/blob/d571a4424fb0cf08b29fbfccfddd092ea969eae3/lm/src/LM.cc#L527
log_joint_p = 0
seq = sequence
while len(seq) > 0:
log_joint_p += self.log_p_raw(seq)
seq = seq[:-1]
# If we're computing the marginal probability of the unigram
# <s> context we have to look up </s> instead since the former
# has prob = 0.
if len(seq) == 1 and seq[0] == self.SOS:
seq = (self.EOS,)
return log_joint_p
def set_new_context(self, h):
old_context = self._ngrams[len(h)][h]
self._ngrams[len(h)][h] = Context()
return old_context
def log_p(self, ngram):
words = self._check_input(ngram)
if self._unk:
words = self._replace_unks(words)
return self.log_p_raw(words)
def log_s(self, sentence, sos=SOS, eos=EOS):
words = self._check_input(sentence)
if self._unk:
words = self._replace_unks(words)
if sos:
words = (sos,) + words
if eos:
words = words + (eos,)
result = sum(self.log_p_raw(words[:i]) for i in range(1, len(words) + 1))
if sos:
result = result - self.log_p_raw(words[:1])
return result
def p(self, ngram):
return self.base ** self.log_p(ngram)
def s(self, sentence):
return self.base ** self.log_s(sentence)
def write(self, fp):
fp.write("\n\\data\\\n")
for order, count in self.counts():
fp.write("ngram {}={}\n".format(order, count))
fp.write("\n")
for order, _ in self.counts():
fp.write("\\{}-grams:\n".format(order))
for e in self._entries(order):
prob = e[0]
ngram = " ".join(e[1])
if len(e) == 2:
fp.write("{}\t{}\n".format(prob, ngram))
elif len(e) == 3:
backoff = e[2]
fp.write("{}\t{}\t{}\n".format(prob, ngram, backoff))
else:
raise ValueError
fp.write("\n")
fp.write("\\end\\\n")
class ArpaParser:
"""
This is a class that implement a parser of an arpa file
"""
@unique
class State(Enum):
DATA = 1
COUNT = 2
HEADER = 3
ENTRY = 4
re_count = re.compile(r"^ngram (\d+)=(\d+)$")
re_header = re.compile(r"^\\(\d+)-grams:$")
re_entry = re.compile(
"^(-?\\d+(\\.\\d+)?([eE]-?\\d+)?)"
"\t"
"(\\S+( \\S+)*)"
"(\t((-?\\d+(\\.\\d+)?)([eE]-?\\d+)?))?$"
)
def _parse(self, fp):
self._result = []
self._state = self.State.DATA
self._tmp_model = None
self._tmp_order = None
for line in fp:
line = line.strip()
if self._state == self.State.DATA:
self._data(line)
elif self._state == self.State.COUNT:
self._count(line)
elif self._state == self.State.HEADER:
self._header(line)
elif self._state == self.State.ENTRY:
self._entry(line)
if self._state != self.State.DATA:
raise Exception(line)
return self._result
def _data(self, line):
if line == "\\data\\":
self._state = self.State.COUNT
self._tmp_model = Arpa()
else:
pass # skip comment line
def _count(self, line):
match = self.re_count.match(line)
if match:
order = match.group(1)
count = match.group(2)
self._tmp_model.add_count(int(order), int(count))
elif not line:
self._state = self.State.HEADER # there are no counts
else:
raise Exception(line)
def _header(self, line):
match = self.re_header.match(line)
if match:
self._state = self.State.ENTRY
self._tmp_order = int(match.group(1))
elif line == "\\end\\":
self._result.append(self._tmp_model)
self._state = self.State.DATA
self._tmp_model = None
self._tmp_order = None
elif not line:
pass # skip empty line
else:
raise Exception(line)
def _entry(self, line):
match = self.re_entry.match(line)
if match:
p = self._float_or_int(match.group(1))
ngram = tuple(match.group(4).split(" "))
bo_match = match.group(7)
bo = self._float_or_int(bo_match) if bo_match else None
self._tmp_model.add_entry(ngram, p, bo, self._tmp_order)
elif not line:
self._state = self.State.HEADER # last entry
else:
raise Exception(line)
@staticmethod
def _float_or_int(s):
f = float(s)
i = int(f)
if str(i) == s: # don't drop trailing ".0"
return i
else:
return f
def load(self, fp):
"""Deserialize fp (a file-like object) to a Python object."""
return self._parse(fp)
def loadf(self, path, encoding=None):
"""Deserialize path (.arpa, .gz) to a Python object."""
path = str(path)
if path.endswith(".gz"):
with gzip.open(path, mode="rt", encoding=encoding) as f:
return self.load(f)
else:
with open(path, mode="rt", encoding=encoding) as f:
return self.load(f)
def loads(self, s):
"""Deserialize s (a str) to a Python object."""
with StringIO(s) as f:
return self.load(f)
def dump(self, obj, fp):
"""Serialize obj to fp (a file-like object) in ARPA format."""
obj.write(fp)
def dumpf(self, obj, path, encoding=None):
"""Serialize obj to path in ARPA format (.arpa, .gz)."""
path = str(path)
if path.endswith(".gz"):
with gzip.open(path, mode="wt", encoding=encoding) as f:
return self.dump(obj, f)
else:
with open(path, mode="wt", encoding=encoding) as f:
self.dump(obj, f)
def dumps(self, obj):
"""Serialize obj to an ARPA formatted str."""
with StringIO() as f:
self.dump(obj, f)
return f.getvalue()
def add_log_p(prev_log_sum, log_p, base):
return math.log(base**log_p + base**prev_log_sum, base)
def compute_numerator_denominator(lm, h):
log_sum_seen_h = -math.inf
log_sum_seen_h_lower = -math.inf
base = lm.base
for w, log_p in lm._ngrams[len(h)][h].items():
log_sum_seen_h = add_log_p(log_sum_seen_h, log_p, base)
ngram = h + (w,)
log_p_lower = lm.log_p_raw(ngram[1:])
log_sum_seen_h_lower = add_log_p(log_sum_seen_h_lower, log_p_lower, base)
numerator = 1.0 - base**log_sum_seen_h
denominator = 1.0 - base**log_sum_seen_h_lower
return numerator, denominator
def prune(lm, threshold, minorder):
# Reference:
# https://github.com/BitSpeech/SRILM/blob/d571a4424fb0cf08b29fbfccfddd092ea969eae3/lm/src/NgramLM.cc#L2330
for i in range(
lm.order(), max(minorder - 1, 1), -1
): # i is the order of the ngram (h, w)
logging.info("processing %d-grams ..." % i)
count_pruned_ngrams = 0
h_dict = lm._ngrams[i - 1]
for h in list(h_dict.keys()):
# old backoff weight, BOW(h)
log_bow = lm._log_bo(h)
if log_bow is None:
log_bow = 0
# Compute numerator and denominator of the backoff weight,
# so that we can quickly compute the BOW adjustment due to
# leaving out one prob.
numerator, denominator = compute_numerator_denominator(lm, h)
# assert abs(math.log(numerator, lm.base) - math.log(denominator, lm.base) - h_dict[h].log_bo) < 1e-5
# Compute the marginal probability of the context, P(h)
h_log_p = lm.log_joint_prob(h)
all_pruned = True
pruned_w_set = set()
for w, log_p in h_dict[h].items():
ngram = h + (w,)
# lower-order estimate for ngramProb, P(w|h')
backoff_prob = lm.log_p_raw(ngram[1:])
# Compute BOW after removing ngram, BOW'(h)
new_log_bow = math.log(
numerator + lm.base**log_p, lm.base
) - math.log(denominator + lm.base**backoff_prob, lm.base)
# Compute change in entropy due to removal of ngram
delta_prob = backoff_prob + new_log_bow - log_p
delta_entropy = -(lm.base**h_log_p) * (
(lm.base**log_p) * delta_prob
+ numerator * (new_log_bow - log_bow)
)
# compute relative change in model (training set) perplexity
perp_change = lm.base**delta_entropy - 1.0
pruned = threshold > 0 and perp_change < threshold
# Make sure we don't prune ngrams whose backoff nodes are needed
if (
pruned
and len(ngram) in lm._ngrams
and len(lm._ngrams[len(ngram)][ngram]) > 0
):
pruned = False
logging.debug(
"CONTEXT "
+ str(h)
+ " WORD "
+ w
+ " CONTEXTPROB %f " % h_log_p
+ " OLDPROB %f " % log_p
+ " NEWPROB %f " % (backoff_prob + new_log_bow)
+ " DELTA-H %f " % delta_entropy
+ " DELTA-LOGP %f " % delta_prob
+ " PPL-CHANGE %f " % perp_change
+ " PRUNED "
+ str(pruned)
)
if pruned:
pruned_w_set.add(w)
count_pruned_ngrams += 1
else:
all_pruned = False
# If we removed all ngrams for this context we can
# remove the context itself, but only if the present
# context is not a prefix to a longer one.
if all_pruned and len(pruned_w_set) == len(h_dict[h]):
del h_dict[
h
] # this context h is no longer needed, as its ngram prob is stored at its own context h'
elif len(pruned_w_set) > 0:
# The pruning for this context h is actually done here
old_context = lm.set_new_context(h)
for w, p_w in old_context.items():
if w not in pruned_w_set:
lm.add_entry(
h + (w,), p_w
) # the entry hw is stored at the context h
# We need to recompute the back-off weight, but
# this can only be done after completing the pruning
# of the lower-order ngrams.
# Reference:
# https://github.com/BitSpeech/SRILM/blob/d571a4424fb0cf08b29fbfccfddd092ea969eae3/flm/src/FNgramLM.cc#L2124
logging.info("pruned %d %d-grams" % (count_pruned_ngrams, i))
# recompute backoff weights
for i in range(
max(minorder - 1, 1) + 1, lm.order() + 1
): # be careful of this order: from low- to high-order
for h in lm._ngrams[i - 1]:
numerator, denominator = compute_numerator_denominator(lm, h)
new_log_bow = math.log(numerator, lm.base) - math.log(denominator, lm.base)
lm._ngrams[len(h)][h].log_bo = new_log_bow
# update counts
lm.update_counts()
return
def check_h_is_valid(lm, h):
sum_under_h = sum(
[lm.base ** lm.log_p_raw(h + (w,)) for w in lm.vocabulary(sort=False)]
)
if abs(sum_under_h - 1.0) > 1e-6:
logging.info("warning: %s %f" % (str(h), sum_under_h))
return False
else:
return True
def validate_lm(lm):
# sanity check if the conditional probability sums to one under each context h
for i in range(lm.order(), 0, -1): # i is the order of the ngram (h, w)
logging.info("validating %d-grams ..." % i)
h_dict = lm._ngrams[i - 1]
for h in h_dict.keys():
check_h_is_valid(lm, h)
def compare_two_apras(path1, path2):
pass
if __name__ == "__main__":
# load an arpa file
logging.info("Loading the arpa file from %s" % args.lm)
parser = ArpaParser()
models = parser.loadf(args.lm, encoding=default_encoding)
lm = models[0] # ARPA files may contain several models.
logging.info("Stats before pruning:")
for i, cnt in lm.counts():
logging.info("ngram %d=%d" % (i, cnt))
# prune it, the language model will be modified in-place
logging.info("Start pruning the model with threshold=%.3E..." % args.threshold)
prune(lm, args.threshold, args.minorder)
# validate_lm(lm)
# write the arpa language model to a file
logging.info("Stats after pruning:")
for i, cnt in lm.counts():
logging.info("ngram %d=%d" % (i, cnt))
logging.info("Saving the pruned arpa file to %s" % args.write_lm)
parser.dumpf(lm, args.write_lm, encoding=default_encoding)
logging.info("Done.")
|