""" |
|
|
Usage:
    ./ngram_entropy_pruning.py \
        -threshold 1e-8 \
        -lm download/lm/4gram.arpa \
        -write-lm download/lm/4gram_pruned_1e8.arpa

This file is from Kaldi `egs/wsj/s5/utils/lang/ngram_entropy_pruning.py`.

This is an implementation of ``Entropy-based Pruning of Backoff Language Models''
(Andreas Stolcke), done in the same way as SRILM.
"""
|
|
|
|
|
|
|
|
import argparse |
|
|
import gzip |
|
|
import logging |
|
|
import math |
|
|
import re |
|
|
from collections import OrderedDict, defaultdict |
|
|
from enum import Enum, unique |
|
|
from io import StringIO |
|
|
|
|
|
parser = argparse.ArgumentParser( |
|
|
description=""" |
|
|
Prune an n-gram language model based on the relative entropy |
|
|
between the original and the pruned model, based on Andreas Stolcke's paper. |
|
|
An n-gram entry is removed if removing it increases the (training set) perplexity
of the model by less than the given relative threshold.
|
|
|
|
|
The command takes an arpa file and a pruning threshold as input, |
|
|
and outputs a pruned arpa file. |
|
|
""" |
|
|
) |
|
|
parser.add_argument("-threshold", type=float, default=1e-6, help="Pruning threshold (relative perplexity increase)")
|
|
parser.add_argument("-lm", type=str, default=None, help="Path to the input arpa file") |
|
|
parser.add_argument( |
|
|
"-write-lm", type=str, default=None, help="Path to output arpa file after pruning" |
|
|
) |
|
|
parser.add_argument( |
|
|
"-minorder", |
|
|
type=int, |
|
|
default=1, |
|
|
help="The minorder parameter limits pruning to ngrams of that length and above.", |
|
|
) |
|
|
parser.add_argument( |
|
|
"-encoding", type=str, default="utf-8", help="Encoding of the arpa file" |
|
|
) |
|
|
parser.add_argument( |
|
|
"-verbose", |
|
|
type=int, |
|
|
default=2, |
|
|
choices=[0, 1, 2, 3, 4, 5], |
|
|
help="Verbose level, where 0 is most noisy; 5 is most silent", |
|
|
) |
|
|
args = parser.parse_args() |
|
|
|
|
|
default_encoding = args.encoding |
|
|
logging.basicConfig( |
|
|
format="%(asctime)s — %(levelname)s — %(funcName)s:%(lineno)d — %(message)s", |
|
|
level=args.verbose * 10, |
|
|
) |
|
|
|
|
|
|
|
|
class Context(dict): |
|
|
""" |
|
|
This class stores data for a context h. |
|
|
    It behaves like a Python dict mapping each word w to log P(w | h), and
    additionally stores the context's log backoff weight in `log_bo`.
|
|
""" |
|
|
|
|
|
def __init__(self): |
|
|
super().__init__() |
|
|
self.log_bo = None |
|
|
|
|
|
|
|
|
class Arpa: |
|
|
""" |
|
|
    This class implements the data structure of an ARPA LM.
|
|
It (as well as some other classes) is modified based on the library |
|
|
by Stefan Fischer: |
|
|
https://github.com/sfischer13/python-arpa |
|
|
""" |
|
|
|
|
|
UNK = "<unk>" |
|
|
SOS = "<s>" |
|
|
EOS = "</s>" |
|
|
FLOAT_NDIGITS = 7 |
|
|
base = 10 |
|
|
|
|
|
@staticmethod |
|
|
def _check_input(my_input): |
|
|
if not my_input: |
|
|
raise ValueError |
|
|
elif isinstance(my_input, tuple): |
|
|
return my_input |
|
|
elif isinstance(my_input, list): |
|
|
return tuple(my_input) |
|
|
elif isinstance(my_input, str): |
|
|
return tuple(my_input.strip().split(" ")) |
|
|
else: |
|
|
raise ValueError |
|
|
|
|
|
@staticmethod |
|
|
def _check_word(input_word): |
|
|
if not isinstance(input_word, str): |
|
|
raise ValueError |
|
|
if " " in input_word: |
|
|
raise ValueError |
|
|
|
|
|
def _replace_unks(self, words): |
|
|
        # Map words that are not in the LM's vocabulary to the <unk> symbol.
        return tuple((w if w in self._vocabulary else self._unk) for w in words)
|
|
|
|
|
def __init__(self, path=None, encoding=None, unk=None): |
|
|
self._counts = OrderedDict() |
|
|
        self._ngrams = OrderedDict()  # self._ngrams[len(h)][h][w] -> log prob of the ngram h + (w,)
        self._vocabulary = set()
        self._unk = self.UNK if unk is None else unk
|
|
|
|
|
if path is not None: |
|
|
self.loadf(path, encoding) |
|
|
|
|
|
def __contains__(self, ngram): |
|
|
h = ngram[:-1] |
|
|
w = ngram[-1] |
|
|
return h in self._ngrams[len(h)] and w in self._ngrams[len(h)][h] |
|
|
|
|
|
def contains_word(self, word): |
|
|
self._check_word(word) |
|
|
return word in self._vocabulary |
|
|
|
|
|
def add_count(self, order, count): |
|
|
self._counts[order] = count |
|
|
self._ngrams[order - 1] = defaultdict(Context) |
|
|
|
|
|
def update_counts(self): |
|
|
for order in range(1, self.order() + 1): |
|
|
count = sum([len(wlist) for _, wlist in self._ngrams[order - 1].items()]) |
|
|
if count > 0: |
|
|
self._counts[order] = count |
|
|
|
|
|
def add_entry(self, ngram, p, bo=None, order=None): |
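        # p and bo are log-domain values (base self.base); `order` is accepted
        # (ArpaParser passes it) but not used here.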
|
|
|
|
|
h = ngram[:-1] |
|
|
w = ngram[-1] |
|
|
|
|
|
|
|
|
h_context = self._ngrams[len(h)][h] |
|
|
h_context[w] = p |
|
|
if bo is not None: |
|
|
self._ngrams[len(ngram)][ngram].log_bo = bo |
|
|
|
|
|
for word in ngram: |
|
|
self._vocabulary.add(word) |
|
|
|
|
|
def counts(self): |
|
|
return sorted(self._counts.items()) |
|
|
|
|
|
def order(self): |
|
|
return max(self._counts.keys(), default=None) |
|
|
|
|
|
def vocabulary(self, sort=True): |
|
|
if sort: |
|
|
return sorted(self._vocabulary) |
|
|
else: |
|
|
return self._vocabulary |
|
|
|
|
|
def _entries(self, order): |
|
|
return ( |
|
|
self._entry(h, w) |
|
|
for h, wlist in self._ngrams[order - 1].items() |
|
|
for w in wlist |
|
|
) |
|
|
|
|
|
def _entry(self, h, w): |
|
|
|
|
|
ngram = h + (w,) |
|
|
log_p = self._ngrams[len(h)][h][w] |
|
|
log_bo = self._log_bo(ngram) |
|
|
if log_bo is not None: |
|
|
return ( |
|
|
round(log_p, self.FLOAT_NDIGITS), |
|
|
ngram, |
|
|
round(log_bo, self.FLOAT_NDIGITS), |
|
|
) |
|
|
else: |
|
|
return round(log_p, self.FLOAT_NDIGITS), ngram |
|
|
|
|
|
def _log_bo(self, ngram): |
|
|
if len(ngram) in self._ngrams and ngram in self._ngrams[len(ngram)]: |
|
|
return self._ngrams[len(ngram)][ngram].log_bo |
|
|
else: |
|
|
return None |
|
|
|
|
|
def _log_p(self, ngram): |
|
|
h = ngram[:-1] |
|
|
w = ngram[-1] |
|
|
if h in self._ngrams[len(h)] and w in self._ngrams[len(h)][h]: |
|
|
return self._ngrams[len(h)][h][w] |
|
|
else: |
|
|
return None |
|
|
|
|
|
def log_p_raw(self, ngram): |
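        # Return log P(w | h) with backoff: if the ngram has no explicit entry,
        # recurse as log BOW(h) + log P(w | h'), where h' drops the first word of h.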
|
|
log_p = self._log_p(ngram) |
|
|
if log_p is not None: |
|
|
return log_p |
|
|
else: |
|
|
if len(ngram) == 1: |
|
|
raise KeyError |
|
|
else: |
|
|
log_bo = self._log_bo(ngram[:-1]) |
|
|
if log_bo is None: |
|
|
log_bo = 0 |
|
|
return log_bo + self.log_p_raw(ngram[1:]) |
|
|
|
|
|
def log_joint_prob(self, sequence): |
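        # Joint probability of the sequence via the chain rule: accumulate
        # log P(last word | preceding words) while shortening the sequence
        # from the right.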
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
log_joint_p = 0 |
|
|
seq = sequence |
|
|
while len(seq) > 0: |
|
|
log_joint_p += self.log_p_raw(seq) |
|
|
seq = seq[:-1] |
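            # <s> carries no probability mass of its own in an ARPA LM, so once
            # only <s> is left, look up </s> instead (same convention as SRILM).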
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if len(seq) == 1 and seq[0] == self.SOS: |
|
|
seq = (self.EOS,) |
|
|
|
|
|
return log_joint_p |
|
|
|
|
|
def set_new_context(self, h): |
|
|
old_context = self._ngrams[len(h)][h] |
|
|
self._ngrams[len(h)][h] = Context() |
|
|
return old_context |
|
|
|
|
|
def log_p(self, ngram): |
|
|
words = self._check_input(ngram) |
|
|
if self._unk: |
|
|
words = self._replace_unks(words) |
|
|
return self.log_p_raw(words) |
|
|
|
|
|
def log_s(self, sentence, sos=SOS, eos=EOS): |
|
|
words = self._check_input(sentence) |
|
|
if self._unk: |
|
|
words = self._replace_unks(words) |
|
|
if sos: |
|
|
words = (sos,) + words |
|
|
if eos: |
|
|
words = words + (eos,) |
|
|
result = sum(self.log_p_raw(words[:i]) for i in range(1, len(words) + 1)) |
|
|
if sos: |
|
|
result = result - self.log_p_raw(words[:1]) |
|
|
return result |
|
|
|
|
|
def p(self, ngram): |
|
|
return self.base ** self.log_p(ngram) |
|
|
|
|
|
def s(self, sentence): |
|
|
return self.base ** self.log_s(sentence) |
|
|
|
|
|
def write(self, fp): |
|
|
fp.write("\n\\data\\\n") |
|
|
for order, count in self.counts(): |
|
|
fp.write("ngram {}={}\n".format(order, count)) |
|
|
fp.write("\n") |
|
|
for order, _ in self.counts(): |
|
|
fp.write("\\{}-grams:\n".format(order)) |
|
|
for e in self._entries(order): |
|
|
prob = e[0] |
|
|
ngram = " ".join(e[1]) |
|
|
if len(e) == 2: |
|
|
fp.write("{}\t{}\n".format(prob, ngram)) |
|
|
elif len(e) == 3: |
|
|
backoff = e[2] |
|
|
fp.write("{}\t{}\t{}\n".format(prob, ngram, backoff)) |
|
|
else: |
|
|
raise ValueError |
|
|
fp.write("\n") |
|
|
fp.write("\\end\\\n") |
|
|
|
|
|
|
|
|
class ArpaParser: |
|
|
""" |
|
|
    This class implements a parser for ARPA files.
|
|
""" |
|
|
|
|
|
@unique |
|
|
class State(Enum): |
|
|
DATA = 1 |
|
|
COUNT = 2 |
|
|
HEADER = 3 |
|
|
ENTRY = 4 |
|
|
|
|
|
re_count = re.compile(r"^ngram (\d+)=(\d+)$") |
|
|
re_header = re.compile(r"^\\(\d+)-grams:$") |
|
|
re_entry = re.compile( |
|
|
"^(-?\\d+(\\.\\d+)?([eE]-?\\d+)?)" |
|
|
"\t" |
|
|
"(\\S+( \\S+)*)" |
|
|
"(\t((-?\\d+(\\.\\d+)?)([eE]-?\\d+)?))?$" |
|
|
) |
|
|
|
|
|
def _parse(self, fp): |
|
|
self._result = [] |
|
|
self._state = self.State.DATA |
|
|
self._tmp_model = None |
|
|
self._tmp_order = None |
|
|
for line in fp: |
|
|
line = line.strip() |
|
|
if self._state == self.State.DATA: |
|
|
self._data(line) |
|
|
elif self._state == self.State.COUNT: |
|
|
self._count(line) |
|
|
elif self._state == self.State.HEADER: |
|
|
self._header(line) |
|
|
elif self._state == self.State.ENTRY: |
|
|
self._entry(line) |
|
|
if self._state != self.State.DATA: |
|
|
raise Exception(line) |
|
|
return self._result |
|
|
|
|
|
def _data(self, line): |
|
|
if line == "\\data\\": |
|
|
self._state = self.State.COUNT |
|
|
self._tmp_model = Arpa() |
|
|
else: |
|
|
pass |
|
|
|
|
|
def _count(self, line): |
|
|
match = self.re_count.match(line) |
|
|
if match: |
|
|
order = match.group(1) |
|
|
count = match.group(2) |
|
|
self._tmp_model.add_count(int(order), int(count)) |
|
|
elif not line: |
|
|
self._state = self.State.HEADER |
|
|
else: |
|
|
raise Exception(line) |
|
|
|
|
|
def _header(self, line): |
|
|
match = self.re_header.match(line) |
|
|
if match: |
|
|
self._state = self.State.ENTRY |
|
|
self._tmp_order = int(match.group(1)) |
|
|
elif line == "\\end\\": |
|
|
self._result.append(self._tmp_model) |
|
|
self._state = self.State.DATA |
|
|
self._tmp_model = None |
|
|
self._tmp_order = None |
|
|
elif not line: |
|
|
pass |
|
|
else: |
|
|
raise Exception(line) |
|
|
|
|
|
def _entry(self, line): |
|
|
match = self.re_entry.match(line) |
|
|
if match: |
|
|
p = self._float_or_int(match.group(1)) |
|
|
ngram = tuple(match.group(4).split(" ")) |
|
|
bo_match = match.group(7) |
|
|
bo = self._float_or_int(bo_match) if bo_match else None |
|
|
self._tmp_model.add_entry(ngram, p, bo, self._tmp_order) |
|
|
elif not line: |
|
|
self._state = self.State.HEADER |
|
|
else: |
|
|
raise Exception(line) |
|
|
|
|
|
@staticmethod |
|
|
def _float_or_int(s): |
|
|
f = float(s) |
|
|
i = int(f) |
|
|
if str(i) == s: |
|
|
return i |
|
|
else: |
|
|
return f |
|
|
|
|
|
def load(self, fp): |
|
|
"""Deserialize fp (a file-like object) to a Python object.""" |
|
|
return self._parse(fp) |
|
|
|
|
|
def loadf(self, path, encoding=None): |
|
|
"""Deserialize path (.arpa, .gz) to a Python object.""" |
|
|
path = str(path) |
|
|
if path.endswith(".gz"): |
|
|
with gzip.open(path, mode="rt", encoding=encoding) as f: |
|
|
return self.load(f) |
|
|
else: |
|
|
with open(path, mode="rt", encoding=encoding) as f: |
|
|
return self.load(f) |
|
|
|
|
|
def loads(self, s): |
|
|
"""Deserialize s (a str) to a Python object.""" |
|
|
with StringIO(s) as f: |
|
|
return self.load(f) |
|
|
|
|
|
def dump(self, obj, fp): |
|
|
"""Serialize obj to fp (a file-like object) in ARPA format.""" |
|
|
obj.write(fp) |
|
|
|
|
|
def dumpf(self, obj, path, encoding=None): |
|
|
"""Serialize obj to path in ARPA format (.arpa, .gz).""" |
|
|
path = str(path) |
|
|
if path.endswith(".gz"): |
|
|
with gzip.open(path, mode="wt", encoding=encoding) as f: |
|
|
return self.dump(obj, f) |
|
|
else: |
|
|
with open(path, mode="wt", encoding=encoding) as f: |
|
|
self.dump(obj, f) |
|
|
|
|
|
def dumps(self, obj): |
|
|
"""Serialize obj to an ARPA formatted str.""" |
|
|
with StringIO() as f: |
|
|
self.dump(obj, f) |
|
|
return f.getvalue() |
|
|
|
|
|
|
|
|
def add_log_p(prev_log_sum, log_p, base): |
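    # Sum two probabilities given in the log domain:
    # log_base(base**prev_log_sum + base**log_p).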
|
|
return math.log(base**log_p + base**prev_log_sum, base) |
|
|
|
|
|
|
|
|
def compute_numerator_denominator(lm, h): |
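    # For a context h, accumulate the probability mass of all words w seen
    # after h, both under h itself and under the shortened context h' (h with
    # its first word dropped).  The left-over masses
    #     numerator   = 1 - sum_w P(w | h)
    #     denominator = 1 - sum_w P(w | h')
    # are the numerator and denominator of the backoff weight BOW(h).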
|
|
log_sum_seen_h = -math.inf |
|
|
log_sum_seen_h_lower = -math.inf |
|
|
base = lm.base |
|
|
for w, log_p in lm._ngrams[len(h)][h].items(): |
|
|
log_sum_seen_h = add_log_p(log_sum_seen_h, log_p, base) |
|
|
|
|
|
ngram = h + (w,) |
|
|
log_p_lower = lm.log_p_raw(ngram[1:]) |
|
|
log_sum_seen_h_lower = add_log_p(log_sum_seen_h_lower, log_p_lower, base) |
|
|
|
|
|
numerator = 1.0 - base**log_sum_seen_h |
|
|
denominator = 1.0 - base**log_sum_seen_h_lower |
|
|
return numerator, denominator |
|
|
|
|
|
|
|
|
def prune(lm, threshold, minorder): |
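    # Entropy-based pruning (Stolcke): go from the highest order down to
    # `minorder`, and for each context h decide which explicitly stored
    # probabilities P(w | h) can be replaced by their backoff estimates while
    # changing the model's (training set) perplexity by less than `threshold`
    # in relative terms.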
|
|
|
|
|
|
|
|
|
|
|
    for i in range(lm.order(), max(minorder - 1, 1), -1):
|
|
logging.info("processing %d-grams ..." % i) |
|
|
count_pruned_ngrams = 0 |
|
|
|
|
|
h_dict = lm._ngrams[i - 1] |
|
|
for h in list(h_dict.keys()): |
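            # Current backoff weight of the context, log BOW(h); a missing
            # entry means BOW(h) = 1, i.e. log BOW(h) = 0.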
|
|
|
|
|
log_bow = lm._log_bo(h) |
|
|
if log_bow is None: |
|
|
log_bow = 0 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
numerator, denominator = compute_numerator_denominator(lm, h) |
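            # Marginal probability of the context, log P(h); it weights the
            # entropy change contributed by ngrams with this history.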
|
|
|
|
|
|
|
|
|
|
|
|
|
|
h_log_p = lm.log_joint_prob(h) |
|
|
|
|
|
all_pruned = True |
|
|
pruned_w_set = set() |
|
|
|
|
|
for w, log_p in h_dict[h].items(): |
|
|
ngram = h + (w,) |
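                # Lower-order estimate log P(w | h') that would be used for this
                # ngram after pruning (h' drops the first word of h).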
|
|
|
|
|
|
|
|
backoff_prob = lm.log_p_raw(ngram[1:]) |
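                # Backoff weight of h if this ngram were removed: both the
                # numerator and the denominator regain the mass of the removed
                # entry.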
|
|
|
|
|
|
|
|
new_log_bow = math.log( |
|
|
numerator + lm.base**log_p, lm.base |
|
|
) - math.log(denominator + lm.base**backoff_prob, lm.base) |
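                # Change in log prob for this ngram if pruned, and the resulting
                # change in training-set entropy, weighted by P(h), following
                # Stolcke's derivation.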
|
|
|
|
|
|
|
|
delta_prob = backoff_prob + new_log_bow - log_p |
|
|
delta_entropy = -(lm.base**h_log_p) * ( |
|
|
(lm.base**log_p) * delta_prob |
|
|
+ numerator * (new_log_bow - log_bow) |
|
|
) |
|
|
|
|
|
|
|
|
perp_change = lm.base**delta_entropy - 1.0 |
|
|
|
|
|
pruned = threshold > 0 and perp_change < threshold |
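                # Never prune an ngram that is itself the history of longer
                # surviving ngrams: its backoff weight is still needed there.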
|
|
|
|
|
|
|
|
if ( |
|
|
pruned |
|
|
and len(ngram) in lm._ngrams |
|
|
and len(lm._ngrams[len(ngram)][ngram]) > 0 |
|
|
): |
|
|
pruned = False |
|
|
|
|
|
logging.debug( |
|
|
"CONTEXT " |
|
|
+ str(h) |
|
|
+ " WORD " |
|
|
+ w |
|
|
+ " CONTEXTPROB %f " % h_log_p |
|
|
+ " OLDPROB %f " % log_p |
|
|
+ " NEWPROB %f " % (backoff_prob + new_log_bow) |
|
|
+ " DELTA-H %f " % delta_entropy |
|
|
+ " DELTA-LOGP %f " % delta_prob |
|
|
+ " PPL-CHANGE %f " % perp_change |
|
|
+ " PRUNED " |
|
|
+ str(pruned) |
|
|
) |
|
|
|
|
|
if pruned: |
|
|
pruned_w_set.add(w) |
|
|
count_pruned_ngrams += 1 |
|
|
else: |
|
|
all_pruned = False |
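            # If every word under this context was pruned, drop the whole
            # context; otherwise rebuild it with only the surviving entries.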
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if all_pruned and len(pruned_w_set) == len(h_dict[h]): |
|
|
                del h_dict[h]
|
|
elif len(pruned_w_set) > 0: |
|
|
|
|
|
old_context = lm.set_new_context(h) |
|
|
|
|
|
for w, p_w in old_context.items(): |
|
|
if w not in pruned_w_set: |
|
|
                        lm.add_entry(h + (w,), p_w)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
logging.info("pruned %d %d-grams" % (count_pruned_ngrams, i)) |
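    # Second pass: recompute the backoff weight of every remaining context so
    # that the conditional distributions still sum to one after pruning.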
|
|
|
|
|
|
|
|
    for i in range(max(minorder - 1, 1) + 1, lm.order() + 1):
|
|
for h in lm._ngrams[i - 1]: |
|
|
numerator, denominator = compute_numerator_denominator(lm, h) |
|
|
new_log_bow = math.log(numerator, lm.base) - math.log(denominator, lm.base) |
|
|
lm._ngrams[len(h)][h].log_bo = new_log_bow |
|
|
|
|
|
|
|
|
lm.update_counts() |
|
|
|
|
|
return |
|
|
|
|
|
|
|
|
def check_h_is_valid(lm, h): |
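    # Sanity check: the conditional distribution P(. | h) over the full
    # vocabulary should sum to 1 (up to a small numerical tolerance).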
|
|
sum_under_h = sum( |
|
|
[lm.base ** lm.log_p_raw(h + (w,)) for w in lm.vocabulary(sort=False)] |
|
|
) |
|
|
if abs(sum_under_h - 1.0) > 1e-6: |
|
|
logging.info("warning: %s %f" % (str(h), sum_under_h)) |
|
|
return False |
|
|
else: |
|
|
return True |
|
|
|
|
|
|
|
|
def validate_lm(lm): |
|
|
|
|
|
for i in range(lm.order(), 0, -1): |
|
|
logging.info("validating %d-grams ..." % i) |
|
|
h_dict = lm._ngrams[i - 1] |
|
|
for h in h_dict.keys(): |
|
|
check_h_is_valid(lm, h) |
|
|
|
|
|
|
|
|
def compare_two_apras(path1, path2): |
|
|
pass |
|
|
|
|
|
|
|
|
if __name__ == "__main__": |
|
|
|
|
|
logging.info("Loading the arpa file from %s" % args.lm) |
|
|
parser = ArpaParser() |
|
|
models = parser.loadf(args.lm, encoding=default_encoding) |
|
|
lm = models[0] |
|
|
logging.info("Stats before pruning:") |
|
|
for i, cnt in lm.counts(): |
|
|
logging.info("ngram %d=%d" % (i, cnt)) |
|
|
|
|
|
|
|
|
logging.info("Start pruning the model with threshold=%.3E..." % args.threshold) |
|
|
prune(lm, args.threshold, args.minorder) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
logging.info("Stats after pruning:") |
|
|
for i, cnt in lm.counts(): |
|
|
logging.info("ngram %d=%d" % (i, cnt)) |
|
|
logging.info("Saving the pruned arpa file to %s" % args.write_lm) |
|
|
parser.dumpf(lm, args.write_lm, encoding=default_encoding) |
|
|
logging.info("Done.") |
|
|
|