| | |
| | |
| | |
| | |
| | |
| |
|
| | import argparse |
| | import fileinput |
| |
|
| | from tqdm import tqdm |
| |
|
| |
|
def main():
    """Split fairseq-generate output into a parallel back-translation corpus.

    Reads ``fairseq-generate`` stdout from the given files (``fileinput``
    falls back to stdin when no files are given). Each ``S-*`` line is paired
    with the first ``H-*`` line that follows it. The pair is optionally
    filtered by token length and length ratio. The two sides are written to
    ``<output>.<srclang>`` (hypothesis text) and ``<output>.<tgtlang>``
    (``S-*`` text). Further hypotheses for the same ``S-*`` line are ignored.
    """
    parser = argparse.ArgumentParser(
        description=(
            "Extract back-translations from the stdout of fairseq-generate. "
            "If there are multiple hypotheses for a source, we only keep the first one."
        )
    )
    parser.add_argument("--output", required=True, help="output prefix")
    parser.add_argument(
        "--srclang", required=True, help="source language (extracted from H-* lines)"
    )
    parser.add_argument(
        "--tgtlang", required=True, help="target language (extracted from S-* lines)"
    )
    parser.add_argument("--minlen", type=int, help="min length filter")
    parser.add_argument("--maxlen", type=int, help="max length filter")
    parser.add_argument("--ratio", type=float, help="ratio filter")
    parser.add_argument("files", nargs="*", help="input files")
    args = parser.parse_args()

    def validate(src, tgt):
        """Return True if the (src, tgt) pair passes every enabled filter.

        Lengths are counted in space-separated tokens, and an empty string
        has length 0. A filter whose command-line option was not given is
        skipped.
        """
        srclen = len(src.split(" ")) if src != "" else 0
        tgtlen = len(tgt.split(" ")) if tgt != "" else 0
        shorter, longer = min(srclen, tgtlen), max(srclen, tgtlen)
        # Both sides must be at least minlen, so only the shorter one matters.
        if args.minlen is not None and shorter < args.minlen:
            return False
        # Both sides must be at most maxlen, so only the longer one matters.
        if args.maxlen is not None and longer > args.maxlen:
            return False
        if args.ratio is not None:
            # With an empty side the ratio is undefined. Dividing would raise
            # ZeroDivisionError, so such pairs fail the ratio filter instead.
            if shorter == 0 or longer / shorter > args.ratio:
                return False
        return True

    def safe_index(toks, index, default):
        """Return ``toks[index]``, or ``default`` if the index is out of range."""
        try:
            return toks[index]
        except IndexError:
            return default

    src_path = args.output + "." + args.srclang
    tgt_path = args.output + "." + args.tgtlang
    with open(src_path, "w") as src_h, open(tgt_path, "w") as tgt_h:
        # Text of the most recent S-* line still waiting for its hypothesis.
        # None means nothing is pending. Initialising it here means an H-*
        # line seen before any S-* line is skipped instead of raising
        # UnboundLocalError.
        tgt = None
        # A FileInput instance avoids fileinput.input()'s module-global state
        # and is closed deterministically by the with-statement.
        with fileinput.FileInput(args.files) as lines:
            for line in tqdm(lines):
                if line.startswith("S-"):
                    # The text is the 2nd tab-separated field. rstrip() can
                    # drop an empty trailing field, so fall back to "".
                    tgt = safe_index(line.rstrip().split("\t"), 1, "")
                elif line.startswith("H-") and tgt is not None:
                    # The text is the 3rd tab-separated field (the 2nd is
                    # presumably the hypothesis score).
                    src = safe_index(line.rstrip().split("\t"), 2, "")
                    if validate(src, tgt):
                        print(src, file=src_h)
                        print(tgt, file=tgt_h)
                    # Keep only the first hypothesis for each S-* line.
                    tgt = None
|
| |
|
# Entry point: run the extractor only when this file is executed as a script,
# never as a side effect of importing it.
if __name__ == "__main__":
    main()
| |
|