import logging
import sys
from collections import Counter

LOG = logging.getLogger(__name__)

BOS = "<s>"
EOS = "</s>"
UNK = "<unk>"


def replace_tags(tokens, tags, vocab):
    """Replace each out-of-vocabulary token with its tag, in place."""
    for i, t in enumerate(tokens):
        if t not in vocab:
            if i < len(tags):
                tokens[i] = tags[i]
            else:
                # Token and tag sequences are out of sync; fall back to
                # the unknown-word symbol.
                LOG.error("Missing tag for token index %d", i)
                LOG.error("Tokens: %s", ' '.join(tokens))
                LOG.error("Tags: %s", ' '.join(tags))
                tokens[i] = UNK
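
# A minimal usage sketch (hypothetical tokens, tags, and vocabulary): an
# out-of-vocabulary token is replaced by its tag, in place.
#
#   >>> toks = ["the", "aardwolf", "runs"]
#   >>> replace_tags(toks, ["DT", "NN", "VBZ"], {"the", "runs"})
#   >>> toks
#   ['the', 'NN', 'runs']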


def replace_unks(tokens, vocab):
    """Replace each out-of-vocabulary token with <unk>, in place."""
    for i, t in enumerate(tokens):
        if t not in vocab:
            tokens[i] = UNK
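
# A minimal usage sketch (hypothetical tokens and vocabulary): any token
# outside the vocabulary becomes <unk>, in place.
#
#   >>> toks = ["the", "aardwolf", "runs"]
#   >>> replace_unks(toks, {"the", "runs"})
#   >>> toks
#   ['the', '<unk>', 'runs']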


def numberize(line, m, n, svocab, tvocab):
    """Map an extracted ngram line to vocabulary ids.

    The line holds 2*m+1 source-window words followed by n target words;
    each word is replaced by its id in the corresponding vocabulary.
    """
    words = line.split()
    source_words = words[:2 * m + 1]
    target_words = words[-n:]

    line = ' '.join(str(svocab[item]) for item in source_words) + ' '
    line += ' '.join(str(tvocab[item]) for item in target_words) + '\n'

    return line
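
# A minimal usage sketch (hypothetical vocabularies): with m=2 and n=3 the
# input line carries 2*m+1 = 5 source-window words followed by 3 target words.
#
#   >>> svocab = {"<s>": 0, "der": 1, "hund": 2, "</s>": 3}
#   >>> tvocab = {"<s>": 0, "the": 1, "dog": 2}
#   >>> numberize("<s> <s> der hund </s> <s> the dog", 2, 3, svocab, tvocab)
#   '0 0 1 2 3 0 1 2\n'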


def get_ngrams(corpus_stem, align_file, tagged_stem, svocab, tvocab, slang,
               tlang, m, n, ofh):
    """
    m - source context size (words on each side of the aligned source word)
    n - target context size (the history words plus the predicted word)

    Writes one ngram per line to ofh; returns a Counter of the tags used.
    """
    tags = Counter()
    sfh = open(corpus_stem + "." + slang)
    tfh = open(corpus_stem + "." + tlang)
    afh = open(align_file)
    fhs = [sfh, tfh, afh]
    if tagged_stem:
        fhs.append(open(tagged_stem + "." + slang))
        fhs.append(open(tagged_stem + "." + tlang))

    count = 0
    ngrams = 0
    LOG.info("Extracting ngrams")
    for lines in zip(*fhs):
        stokens = lines[0].split()
        ttokens = lines[1].split()
        stokens.append(EOS)
        ttokens.append(EOS)
        if tagged_stem:
            stags = lines[3].split()
            ttags = lines[4].split()
            stags.append(EOS)
            ttags.append(EOS)
            tags.update(stags)
            tags.update(ttags)
            replace_tags(stokens, stags, svocab)
            replace_tags(ttokens, ttags, tvocab)
        else:
            replace_unks(stokens, svocab)
            replace_unks(ttokens, tvocab)

        # Collect the aligned source positions for each target position.
        target_aligns = [[] for _ in range(len(ttokens))]
        for atoken in lines[2].split():
            spos, tpos = atoken.split("-")
            spos, tpos = int(spos), int(tpos)
            target_aligns[tpos].append(spos)

        # The appended target EOS is always aligned to the source EOS.
        target_aligns[-1] = [len(stokens) - 1]

        for tpos, spos_list in enumerate(target_aligns):
            if not spos_list:
                # Unaligned target word: borrow the alignment of the
                # nearest aligned neighbour, scanning right and left.
                rpos = tpos + 1
                lpos = tpos - 1
                while rpos < len(ttokens) or lpos >= 0:
                    if rpos < len(ttokens) and target_aligns[rpos]:
                        spos_list = target_aligns[rpos]
                        break
                    if lpos >= 0 and target_aligns[lpos]:
                        spos_list = target_aligns[lpos]
                        break
                    rpos += 1
                    lpos -= 1

            if not spos_list:
                raise Exception(
                    "No alignments in sentence\nSRC: " +
                    lines[0].rstrip() + "\nTGT: " + lines[1].rstrip())
            # For a multiply-aligned word, take the middle source position.
            midpos = (len(spos_list) - 1) // 2
            spos = sorted(spos_list)[midpos]

            # Write the source window, padded with BOS/EOS at the sentence
            # edges, followed by the n-word target history ending at tpos.
            for i in range(max(0, m - spos)):
                print(BOS, end=' ', file=ofh)
            print(" ".join(stokens[max(0, spos - m):spos + m + 1]),
                  end=' ', file=ofh)
            for i in range(max(0, spos + m + 1 - len(stokens))):
                print(EOS, end=' ', file=ofh)
            for i in range(max(0, n - (tpos + 1))):
                print(BOS, end=' ', file=ofh)
            print(" ".join(ttokens[max(0, tpos + 1 - n):tpos + 1]), file=ofh)
            ngrams += 1

        count += 1
        if count % 1000 == 0:
            sys.stderr.write(".")
        if count % 50000 == 0:
            sys.stderr.write(" [%d]\n" % count)
    for fh in fhs:
        fh.close()
    ofh.close()
    sys.stderr.write("\n")
    LOG.info("Extracted %d ngrams", ngrams)
    return tags
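
# A usage sketch for get_ngrams (hypothetical file names and vocabularies):
# given corpus.de / corpus.en plus Moses-style "src-tgt" word alignments in
# corpus.align, this extracts a 5-word source window (m=2) and a 3-word
# target history (n=3) for every target position.
#
#   svocab = {"<s>": 0, "</s>": 1, "<unk>": 2, "der": 3, "hund": 4}
#   tvocab = {"<s>": 0, "</s>": 1, "<unk>": 2, "the": 3, "dog": 4}
#   ofh = open("train.ngrams", "w")
#   tags = get_ngrams("corpus", "corpus.align", None, svocab, tvocab,
#                     "de", "en", 2, 3, ofh)  # ofh is closed on return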