| |
|
|
| from libcpp.string cimport string |
| from libcpp.vector cimport vector |
| from itertools import chain |
| import os |
| import cython |
| cimport cdictree |
| cimport condiskpt |
| from math import log |
|
|
cpdef int fsign(float x):
    """Return the sign of float x as an int: 1 when x >= 0 (zero counts as positive), otherwise -1.

    It is defined here with static C types just so callers gain a little speed.
    """
    if x >= 0:
        return 1
    return -1
|
|
cdef bytes as_str(data):
    """Coerce `data` to a byte string.

    :data a bytes object (returned untouched) or a unicode object (encoded as UTF-8)
    :return bytes
    :raises TypeError for any other type
    """
    if isinstance(data, bytes):
        return data
    if not isinstance(data, unicode):
        raise TypeError('Cannot convert %s to string' % type(data))
    return data.encode('UTF-8')
|
|
cdef class Production(object):
    """
    General class that represents a context-free production or a flat contiguous phrase.

    Note: we can't extend from tuple yet (Cython 0.17.1 does not support it), so a few
    sequence protocols are implemented so that Production feels like a tuple.
    """

    # Left-hand side nonterminal (None for flat contiguous phrases).
    cdef readonly bytes lhs
    # Right-hand side symbols, always stored as a tuple.
    cdef readonly tuple rhs

    def __init__(self, rhs, lhs = None):
        """
        :rhs right-hand side of the production (or the flat contiguous phrase) - sequence of strings
        :lhs left-hand side nonterminal (or None in the case of flat contiguous phrases)
        """
        self.rhs = tuple(rhs)
        self.lhs = lhs

    def __len__(self):
        """Number of symbols on the right-hand side."""
        return len(self.rhs)

    def __getitem__(self, key):
        """
        Index into the right-hand side exactly like a tuple would
        (negative indices and slices are delegated to the underlying tuple).
        :raises IndexError if an integer index is out of range
        """
        try:
            return self.rhs[key]
        except IndexError:
            # The original code *returned* an (IndexError, message) pair here
            # instead of raising, so callers could never detect a bad index.
            raise IndexError('Index %s out of range' % str(key))

    def __iter__(self):
        """Iterate over the right-hand side symbols."""
        return iter(self.rhs)

    def __contains__(self, item):
        """Membership test against the right-hand side only."""
        return item in self.rhs

    def __reversed__(self):
        """Iterate over the right-hand side symbols in reverse order."""
        return reversed(self.rhs)

    def __hash__(self):
        # Hash the same tuple that __richcmp__ compares, so that two productions
        # that compare equal always share a hash. When lhs is unset this is
        # simply hash(self.rhs).
        return hash(self.as_tuple())

    def __str__(self):
        """Render as 'lhs -> rhs tokens' when an lhs is set, otherwise just the space-joined rhs."""
        if self.lhs:
            return '%s -> %s' % (self.lhs, ' '.join(self.rhs))
        return ' '.join(self.rhs)

    def __repr__(self):
        return repr(self.as_tuple())

    def as_tuple(self, lhs_first = False):
        """
        Returns a tuple (lhs) + rhs or rhs + (lhs) depending on the flag 'lhs_first'.
        When there is no lhs, the rhs tuple is returned as is.
        """
        if not self.lhs:
            return self.rhs
        if lhs_first:
            return (self.lhs,) + self.rhs
        return self.rhs + (self.lhs,)

    def __richcmp__(self, other, op):
        """
        The comparison uses 'as_tuple()', therefore in the CFG case, the lhs will be part of the production and it will be placed in the end
        (just to keep with Moses convention which has mostly to do with sorting for scoring on disk).
        Comparing against something that is not a Production yields NotImplemented,
        letting Python fall back to its default handling instead of failing with AttributeError.
        """
        if not isinstance(other, Production):
            return NotImplemented
        x = self.as_tuple()
        y = other.as_tuple()
        # op codes follow the CPython rich-comparison constants (Py_LT .. Py_GE).
        if op == 0:
            return x < y
        elif op == 1:
            return x <= y
        elif op == 2:
            return x == y
        elif op == 3:
            return x != y
        elif op == 4:
            return x > y
        elif op == 5:
            return x >= y
|
|
cdef class Alignment(list):
    """
    This represents a list of alignment points (pairs of integers).
    It should inherit from tuple, but that is not yet supported in Cython (as for Cython 0.17.1).
    """

    def __init__(self, alignment):
        """
        :alignment one of
            - None: no alignment information (empty list);
            - a string such as '0-0 1-2': whitespace-separated 's-t' points;
            - a list or tuple (or subclass, e.g. another Alignment) of pairs.
        :raises ValueError for any other kind of input (the original code built the
            error message but never raised it, silently yielding an empty alignment)
        """
        if alignment is None:
            # Callers in this module pass None when a table has no word alignment.
            pairs = []
        elif isinstance(alignment, (bytes, unicode)):
            # Pick a separator of the same string type as the input so that
            # split() works for both byte strings and unicode strings.
            sep = b'-' if isinstance(alignment, bytes) else u'-'
            pairs = []
            for point in alignment.split():
                s, t = point.split(sep)
                pairs.append((int(s), int(t)))
        elif isinstance(alignment, (list, tuple)):
            pairs = alignment
        else:
            raise ValueError('Cannot figure out pairs from: %s' % type(alignment))
        super(Alignment, self).__init__(pairs)

    def __str__(self):
        """Render the points as space-separated 's-t' pairs."""
        return ' '.join('%d-%d' % (s, t) for s, t in self)
|
|
cdef class FValues(list):
    """
    A list of feature values (floats).
    Ideally this would derive from tuple, which Cython does not support yet (as of Cython 0.17.1).
    """

    def __init__(self, values):
        """:values iterable holding the feature values"""
        list.__init__(self, values)

    def __str__(self):
        """Space-separated string form of every value."""
        return ' '.join([str(value) for value in self])
|
|
cdef class TargetProduction(Production):
    """This class specializes production making it the target side of a translation rule.
    On top of lhs and rhs it comes with alignment information and a list of real-valued features.
    """
    # Word-alignment points attached to this target side (possibly empty).
    cdef readonly Alignment alignment
    # Feature values attached to this target side.
    cdef readonly FValues scores

    def __init__(self, rhs, scores, alignment = None, lhs = None):
        """
        :rhs right-hand side tokens (sequence of terminals and nonterminals)
        :scores sequence of real-valued features
        :alignment alignment string or sequence of pairs of integers; None (the default) means no alignment
        :lhs left-hand side nonterminal (None in phrase-based)
        """
        super(TargetProduction, self).__init__(rhs, lhs)
        self.scores = FValues(scores)
        # Normalise "no alignment" to an empty sequence here rather than relying on
        # a shared mutable default or on how Alignment treats None.
        self.alignment = Alignment(alignment if alignment is not None else [])

    @staticmethod
    def desc(x, y, key = lambda r: r.scores[0]):
        """Returns the sign of key(y) - key(x).
        Can only be used if scores is not an empty vector as
        key defaults to scores[0]"""
        return fsign(key(y) - key(x))

    def __str__(self):
        """Returns a string such as: <words> ||| <scores> ||| <word-alignment info>
        (the alignment field is always emitted, and is empty when there are no points)."""
        if self.lhs:
            lhs = [self.lhs]
        else:
            lhs = []
        return ' ||| '.join((' '.join(chain(self.rhs, lhs)),
                             str(self.scores),
                             str(self.alignment)))

    def __repr__(self):
        return repr((repr(self.rhs), repr(self.lhs), repr(self.scores), repr(self.alignment)))
|
|
cdef class QueryResult(list):
    """List of results for one query, tagged with the source production that was queried."""

    # The source side the contained targets belong to.
    cdef readonly Production source

    def __init__(self, source, targets = []):
        """
        :source the queried source production
        :targets initial content of the list (copied by the list constructor, never mutated)
        """
        list.__init__(self, targets)
        self.source = source
|
|
|
|
cdef class DictionaryTree(object):
    """Abstract interface of the table wrappers: both methods must be overridden."""

    @classmethod
    def canLoad(cls, path, bint wa = False):
        """Tell whether `path` holds a valid table for this class (not implemented here)."""
        raise NotImplementedError()

    def query(self, line, converter = None, cmp = None, key = None):
        """
        Look up the target productions translating a source production (not implemented here).
        :line query (string)
        :converter function applied to each score
        :cmp comparison used to return a sorted list (keep it compatible with your converter)
        :key key of the comparison
        :return QueryResult
        """
        raise NotImplementedError()
|
|
cdef class PhraseDictionaryTree(DictionaryTree):
    """This class encapsulates a Moses::PhraseDictionaryTree for operations over
    binary phrase tables."""

    # Underlying C++ object; allocated in __cinit__ and released in __dealloc__.
    cdef cdictree.PhraseDictionaryTree* tree
    cdef readonly bytes path
    cdef readonly unsigned nscores
    cdef readonly bint wa
    cdef readonly bytes delimiters
    cdef readonly unsigned tableLimit

    def __cinit__(self, bytes path, unsigned tableLimit = 20, unsigned nscores = 5, bint wa = False, delimiters = ' \t'):
        """
        :path stem of the table, e.g europarl.fr-en is the stem for europarl.fr-en.binphr.*
        :tableLimit maximum translations per source (defaults to 20 - use zero to impose no limit)
        :nscores number of scores (stored on the instance; not otherwise used in this class)
        :wa whether or not it has word-alignment information
        :delimiters for tokenization (defaults to space and tab)
        :raises ValueError if the expected table files are not found for this stem
        """
        if not PhraseDictionaryTree.canLoad(path, wa):
            raise ValueError("'%s' doesn't seem a valid binary table." % path)
        self.path = path
        self.tableLimit = tableLimit
        self.nscores = nscores
        self.wa = wa
        self.delimiters = delimiters
        self.tree = new cdictree.PhraseDictionaryTree()
        self.tree.NeedAlignmentInfo(wa)
        self.tree.Read(path)

    def __dealloc__(self):
        # Deleting a NULL pointer is a no-op in C++, which covers the case where
        # __cinit__ raised before the tree was allocated.
        del self.tree

    @classmethod
    def canLoad(cls, stem, bint wa = False):
        """This sanity check was added to the constructor, but you can access it from outside this class
        to determine whether or not you are providing a valid stem to BinaryPhraseTable.
        :stem path prefix of the '.binphr.*' files
        :wa when true, the source tree and target data files must be the '.wa' variants
        :return True only if all five files exist"""
        # Only the source tree and the target data carry the '.wa' suffix.
        suffix = ".wa" if wa else ""
        required = (".binphr.idx",
                    ".binphr.srctree" + suffix,
                    ".binphr.srcvoc",
                    ".binphr.tgtdata" + suffix,
                    ".binphr.tgtvoc")
        return all(os.path.isfile(stem + ext) for ext in required)

    cdef TargetProduction getTargetProduction(self, cdictree.StringTgtCand& cand, wa = None, converter = None):
        """Converts a StringTgtCandidate (c++ object) and possibly a word-alignment info (string) to a TargetProduction (python object).
        :cand candidate whose tokens and scores are copied into Python objects
        :wa word-alignment string for this candidate, or None
        :converter optional function applied to every score"""
        cdef list words = [cand.tokens[i].c_str() for i in xrange(cand.tokens.size())]
        cdef list scores = [score for score in cand.scores] if converter is None else [converter(score) for score in cand.scores]
        return TargetProduction(words, scores, wa)

    def query(self, line, converter = lambda x: log(x), cmp = lambda x, y: fsign(y.scores[2] - x.scores[2]), key = None):
        """
        Returns a list of target productions that translate a given source production
        :line query (string)
        :converter applies a transformation to the score (function) - defaults to the natural log (since by default binary phrase-tables store probabilities)
        :cmp define it to get a sorted list - defaults to sorting by t(e|f) (since by default binary phrase-tables are not sorted)
        :key defines the key of the comparison - defaults to none
        :return QueryResult (truncated to tableLimit entries when tableLimit > 0)
        """
        cdef bytes text = as_str(line)
        cdef vector[string] fphrase = cdictree.Tokenize(text, self.delimiters)
        cdef vector[cdictree.StringTgtCand]* rv = new vector[cdictree.StringTgtCand]()
        cdef vector[string]* wa = NULL
        cdef Production source
        cdef QueryResult results

        # The temporary C++ vectors must be released even if building a production
        # raises (for instance when the converter rejects a score), hence try/finally.
        try:
            source = Production(f.c_str() for f in fphrase)
            results = QueryResult(source)
            if not self.wa:
                self.tree.GetTargetCandidates(fphrase, rv[0])
                results.extend([self.getTargetProduction(candidate, None, converter) for candidate in rv[0]])
            else:
                wa = new vector[string]()
                self.tree.GetTargetCandidates(fphrase, rv[0], wa[0])
                results.extend([self.getTargetProduction(rv[0][i], wa[0][i].c_str(), converter) for i in range(rv.size())])
        finally:
            # wa is still NULL on the no-alignment path; deleting NULL is harmless.
            del wa
            del rv

        if cmp:
            results.sort(cmp=cmp, key=key)
        if self.tableLimit > 0:
            return QueryResult(source, results[0:self.tableLimit])
        else:
            return results
|
|
cdef class OnDiskWrapper(DictionaryTree):
    """Wraps a condiskpt.OnDiskWrapper and its query object for lookups over an on-disk table."""

    # Both C++ objects are allocated in __cinit__ and released in __dealloc__.
    cdef condiskpt.OnDiskWrapper *wrapper
    cdef condiskpt.OnDiskQuery *finder
    cdef readonly bytes delimiters
    cdef readonly unsigned tableLimit

    def __cinit__(self, bytes path, unsigned tableLimit = 20, delimiters = ' \t'):
        """
        :path directory of the on-disk table
        :tableLimit limit handed to the target phrase collection lookup (defaults to 20)
        :delimiters for tokenization (defaults to space and tab)
        """
        self.delimiters = delimiters
        self.tableLimit = tableLimit
        self.wrapper = new condiskpt.OnDiskWrapper()
        # NOTE(review): the result of BeginLoad, if any, is not checked - confirm whether it reports failure.
        self.wrapper.BeginLoad(string(path))
        self.finder = new condiskpt.OnDiskQuery(self.wrapper[0])

    def __dealloc__(self):
        # Release in reverse order of construction: the finder was built from the wrapper.
        # Deleting NULL is a no-op, so a partially constructed object is handled too.
        del self.finder
        del self.wrapper

    @classmethod
    def canLoad(cls, stem, bint wa = False):
        """Whether `stem` is a directory containing the five '.dat' files of an on-disk table.
        The `wa` flag is accepted for interface compatibility and is not used."""
        required = ("/Misc.dat",
                    "/Source.dat",
                    "/TargetColl.dat",
                    "/TargetInd.dat",
                    "/Vocab.dat")
        return all(os.path.isfile(stem + name) for name in required)

    cdef Production getSourceProduction(self, vector[string] ftokens):
        """Builds the source Production from the query tokens: the last token is taken as the lhs,
        everything before it as the rhs."""
        cdef list tokens = [f.c_str() for f in ftokens]
        return Production(tokens[:-1], tokens[-1])

    def query(self, line, converter = None, cmp = None, key = None):
        """
        Returns a list of target productions that translate a given source production
        :line query (string)
        :converter applies a transformation to the score (function) - defaults to None (since by default OnDiskWrapper store the ln(prob))
        :cmp define it to get a sorted list - defaults to None (since by default OnDiskWrapper is already sorted)
        :key defines the key of the comparison - defaults to none
        :return QueryResult (empty when the source is not found)
        """
        cdef bytes text = as_str(line)
        cdef vector[string] ftokens = cdictree.Tokenize(text, self.delimiters)
        # NOTE(review): ownership of the node returned by Query is not visible here - confirm whether it must be freed.
        cdef condiskpt.PhraseNode *node = <condiskpt.PhraseNode *>self.finder.Query(ftokens)
        if node == NULL:
            # Keep the documented return type on a miss; a QueryResult is still an
            # empty list for callers that only test or iterate the result.
            if ftokens.size() == 0:
                return QueryResult(None)
            return QueryResult(self.getSourceProduction(ftokens))
        cdef Production source = self.getSourceProduction(ftokens)
        # NOTE(review): the returned pointer is dereferenced and copied by value; if the callee
        # hands over ownership of the collection this leaks it - confirm against the C++ API.
        cdef condiskpt.TargetPhraseCollection ephrases = node.GetTargetPhraseCollection(self.tableLimit, self.wrapper[0])[0]
        cdef condiskpt.Vocab vocab = self.wrapper.GetVocab()
        cdef condiskpt.TargetPhrase ephr
        cdef condiskpt.Word e
        cdef unsigned i, j
        cdef QueryResult results = QueryResult(source)
        for i in xrange(ephrases.GetSize()):
            ephr = ephrases.GetTargetPhrase(i)
            # As on the source side, the last word is used as the lhs of the target production.
            words = [ephr.GetWord(j).GetString(vocab).c_str() for j in xrange(ephr.GetSize())]
            if converter is None:
                results.append(TargetProduction(words[:-1], ephr.GetScores(), ephr.GetAlign(), words[-1]))
            else:
                scores = tuple(ephr.GetScores())
                results.append(TargetProduction(words[:-1], (converter(score) for score in scores), ephr.GetAlign(), words[-1]))
        if cmp:
            results.sort(cmp=cmp, key=key)
        return results
|
|
def load(path, nscores, limit):
    """Pick the dictionary tree implementation matching the content of 'path' and return an instance of it.

    A binary phrase table without word alignment is tried first, then one with
    word alignment, then an on-disk table.
    :raises ValueError when none of them recognises 'path'
    """
    for has_wa in (False, True):
        if PhraseDictionaryTree.canLoad(path, has_wa):
            return PhraseDictionaryTree(path, limit, nscores, has_wa)
    if OnDiskWrapper.canLoad(path):
        return OnDiskWrapper(path, limit)
    raise ValueError('%s does not seem to be a valid table' % path)
|
|