| |
| |
| |
| |
|
|
| import os |
| from collections import Counter |
|
|
| import torch |
| from fairseq.file_io import PathManager |
| from fairseq.tokenizer import tokenize_line |
| from typing import List, Dict |
|
|
|
|
def safe_readline(f):
    """Read one line from text-mode file ``f``, tolerating a mid-character seek.

    If ``f`` was positioned (via ``seek``) inside a multi-byte character,
    ``readline`` fails with ``UnicodeDecodeError``. On failure the position
    is moved back one unit at a time and the read retried until it decodes.

    Args:
        f: a seekable file object opened in text mode.

    Returns:
        The line read, starting from the first position that decoded
        successfully; ``""`` at end of file.

    Raises:
        UnicodeDecodeError: if decoding still fails after backing up to the
            start of the file, i.e. the data itself cannot be decoded.
    """
    pos = f.tell()
    while True:
        try:
            return f.readline()
        except UnicodeDecodeError:
            if pos <= 0:
                # Nothing left to back up over: the bytes are undecodable.
                # Surface the real error rather than letting seek(-1) fail
                # with an unrelated "negative seek position" ValueError.
                raise
            pos -= 1
            f.seek(pos)
|
|
|
|
class Binarizer:
    """Convert text files into per-line tensors/objects, optionally chunk-wise.

    Each public method reads a file (resolved through
    ``PathManager.get_local_path``) line by line and hands each converted line
    to a caller-supplied ``consumer``. ``offset``/``end`` restrict processing
    to one chunk of the file; chunk boundaries come from :meth:`find_offsets`.
    """

    @staticmethod
    def _past_end(f, end):
        """Return True once ``f``'s read position is beyond chunk boundary ``end``.

        ``end <= 0`` means "no boundary" (read to end of file). In text mode
        ``f.tell()`` returns an opaque cookie rather than a guaranteed byte
        offset. The ``< end + 2**32`` bound appears to guard against the
        cookie being such a large opaque value; those positions are treated
        as not past the boundary.
        """
        if end <= 0:
            return False
        pos = f.tell()
        return end < pos < end + 2 ** 32

    @staticmethod
    def binarize(
        filename,
        dict,
        consumer,
        tokenize=tokenize_line,
        append_eos=True,
        reverse_order=False,
        offset=0,
        end=-1,
        already_numberized=False,
    ) -> Dict[str, int]:
        """Encode each line of ``filename`` and pass the id tensor to ``consumer``.

        Args:
            filename: path of the UTF-8 text file to read.
            dict: vocabulary object providing ``encode_line``, ``eos()``,
                ``unk_index`` and ``unk_word``. (The name shadows the builtin
                but is kept for keyword-argument compatibility.)
            consumer: callable invoked once per processed line with its ids.
            tokenize: line tokenizer forwarded to ``dict.encode_line``.
            append_eos: append the EOS id to every sequence.
            reverse_order: reverse each sequence's ids. In the numberized path
                this happens before EOS is appended; otherwise it is forwarded
                to ``dict.encode_line``.
            offset: position to start reading from.
            end: stop once the read position passes this value; ``<= 0``
                reads to end of file.
            already_numberized: lines already hold whitespace-separated
                integer ids; parse them instead of calling ``encode_line``.

        Returns:
            A dict with ``nseq`` (lines processed), ``ntok`` (total length of
            all emitted id tensors), ``nunk`` (number of words reported via
            ``encode_line``'s consumer callback as mapped to ``unk_index``,
            excluding the literal unk word) and ``replaced`` (a ``Counter``
            of those words).
        """
        nseq, ntok = 0, 0
        replaced = Counter()

        def replaced_consumer(word, idx):
            # Tally words the vocabulary mapped to unk, ignoring the literal
            # unk word itself.
            if idx == dict.unk_index and word != dict.unk_word:
                replaced[word] += 1

        with open(PathManager.get_local_path(filename), "r", encoding="utf-8") as f:
            f.seek(offset)
            # safe_readline copes with ``offset`` not being on a character
            # boundary.
            line = safe_readline(f)
            while line:
                # The check runs after the line was read: a line whose end
                # lies past ``end`` belongs to the next chunk.
                if Binarizer._past_end(f, end):
                    break
                if already_numberized:
                    id_list = [int(tok) for tok in line.strip().split()]
                    if reverse_order:
                        id_list.reverse()
                    if append_eos:
                        id_list.append(dict.eos())
                    ids = torch.IntTensor(id_list)
                else:
                    ids = dict.encode_line(
                        line=line,
                        line_tokenizer=tokenize,
                        add_if_not_exist=False,
                        consumer=replaced_consumer,
                        append_eos=append_eos,
                        reverse_order=reverse_order,
                    )
                nseq += 1
                ntok += len(ids)
                consumer(ids)
                line = f.readline()
        return {
            "nseq": nseq,
            "nunk": sum(replaced.values()),
            "ntok": ntok,
            "replaced": replaced,
        }

    @staticmethod
    def binarize_alignments(
        filename, alignment_parser, consumer, offset=0, end=-1
    ) -> Dict[str, int]:
        """Parse each line of an alignment file and pass the result to ``consumer``.

        Args:
            filename: path of the text file to read (decoded as UTF-8).
            alignment_parser: callable mapping one line to the object that
                is handed to ``consumer``.
            consumer: callable invoked once per processed line.
            offset: position to start reading from.
            end: stop once the read position passes this value; ``<= 0``
                reads to end of file.

        Returns:
            ``{"nseq": <number of lines processed>}``.
        """
        nseq = 0
        # Decode as UTF-8, like binarize and find_offsets. When the chunk
        # offsets come from find_offsets they were computed on a UTF-8
        # reading of the file; the locale default encoding could disagree.
        with open(PathManager.get_local_path(filename), "r", encoding="utf-8") as f:
            f.seek(offset)
            line = safe_readline(f)
            while line:
                if Binarizer._past_end(f, end):
                    break
                ids = alignment_parser(line)
                nseq += 1
                consumer(ids)
                line = f.readline()
        return {"nseq": nseq}

    @staticmethod
    def find_offsets(filename, num_chunks) -> List[int]:
        """Split ``filename`` into ``num_chunks`` line-aligned position ranges.

        Returns a list of ``num_chunks + 1`` positions; chunk ``i`` spans
        ``offsets[i]`` to ``offsets[i + 1]``. Interior boundaries start at
        ``size // num_chunks * i`` and are moved forward to the start of the
        following line, so no line is split. The last entry is left at 0,
        which ``binarize``/``binarize_alignments`` interpret as "read to end
        of file" (``end <= 0``).

        Raises:
            ValueError: if ``num_chunks < 1``.
        """
        if num_chunks < 1:
            raise ValueError(f"num_chunks must be >= 1, got {num_chunks}")
        with open(PathManager.get_local_path(filename), "r", encoding="utf-8") as f:
            size = os.fstat(f.fileno()).st_size
            chunk_size = size // num_chunks
            offsets = [0] * (num_chunks + 1)
            for i in range(1, num_chunks):
                f.seek(chunk_size * i)
                # The seek may land mid-line (or mid-character): discard the
                # remainder of that line so the boundary is a line start.
                safe_readline(f)
                offsets[i] = f.tell()
            return offsets
|
|