Spaces:
Runtime error
Runtime error
| # Copyright (c) Facebook, Inc. and its affiliates. | |
| # | |
| # This source code is licensed under the MIT license found in the | |
| # LICENSE file in the root directory of this source tree. | |
| import itertools | |
| import os | |
| import csv | |
| from collections import defaultdict | |
| from six.moves import zip | |
| import io | |
| import wget | |
| import sys | |
| from subprocess import check_call, check_output | |
# Scripts and data locations.
CWD = os.getcwd()
UTILS = f"{CWD}/utils"
# Please download mosesdecoder into this location (used for detokenization).
MOSES = f"{UTILS}/mosesdecoder"

# The working directory root must be supplied via the environment;
# abort early with a clear message otherwise.
WORKDIR_ROOT = os.environ.get('WORKDIR_ROOT', None)

if WORKDIR_ROOT is None or not WORKDIR_ROOT.strip():
    # Fixed: message previously read 'Exitting..."' (typo + stray quote).
    print('please specify your working directory root in OS environment variable WORKDIR_ROOT. Exiting...')
    sys.exit(-1)

# Path of the Moses detokenizer perl script invoked by de_tok().
detok_cmd = f'{MOSES}/scripts/tokenizer/detokenizer.perl'
def call(cmd):
    """Announce and execute a shell command; raises CalledProcessError on failure."""
    message = f"Executing: {cmd}"
    print(message)
    check_call(cmd, shell=True)
class MultiLingualAlignedCorpusReader(object):
    """Reader for the multilingual TED talks dataset.

    Loads the aligned ``all_talks_{train,test,dev}.tsv`` files found in
    ``corpus_path`` and exposes, per split, parallel 'source'/'target'
    sentence lists with empty or NULL-flagged pairs filtered out.
    """

    def __init__(self, corpus_path, delimiter='\t',
                 target_token=True, bilingual=True, corpus_type='file',
                 lang_dict=None,
                 eval_lang_dict=None, zero_shot=False,
                 detok=True,
                 ):
        """
        Args:
            corpus_path: directory holding the ``all_talks_*.tsv`` files.
            delimiter: field separator of the corpus files.
            target_token: when True, prefix each source sentence with a
                ``__<target_lang>__`` token.
            bilingual: when True (and not zero_shot), read the cross
                product of source and target languages.
            corpus_type: only 'file' is handled by read_aligned_corpus.
            lang_dict: ``{'source': [...], 'target': [...]}`` language
                lists; defaults to ``{'source': ['fr'], 'target': ['en']}``.
            eval_lang_dict: languages used for dev/test in zero-shot mode.
            zero_shot: zip source/target language lists pairwise instead
                of taking their cross product.
            detok: run the Moses detokenizer after save_file().
        """
        self.empty_line_flag = 'NULL'
        self.corpus_path = corpus_path
        self.delimiter = delimiter
        self.bilingual = bilingual
        # Fixed: the default used to be a mutable dict literal in the
        # signature, which is shared across all calls; build it fresh.
        if lang_dict is None:
            lang_dict = {'source': ['fr'], 'target': ['en']}
        self.lang_dict = lang_dict
        self.lang_set = set()
        self.target_token = target_token
        self.zero_shot = zero_shot
        self.eval_lang_dict = eval_lang_dict
        self.corpus_type = corpus_type
        self.detok = detok

        for list_ in self.lang_dict.values():
            for lang in list_:
                self.lang_set.add(lang)

        # Eagerly read every split at construction time.
        self.data = dict()
        self.data['train'] = self.read_aligned_corpus(split_type='train')
        self.data['test'] = self.read_aligned_corpus(split_type='test')
        self.data['dev'] = self.read_aligned_corpus(split_type='dev')

    def read_data(self, file_loc_):
        """Return the stripped lines of *file_loc_* as a list.

        Note: the original wrapped ``line.strip()`` in a try/except
        IndexError, but ``str.strip`` cannot raise IndexError, so that
        handler was dead code and has been removed.
        """
        with io.open(file_loc_, 'r', encoding='utf8') as fp:
            return [line.strip() for line in fp]

    def filter_text(self, dict_):
        """Drop sentence pairs where either side is empty or contains NULL.

        When target tokens are enabled, the leading ``__lang__`` token is
        skipped before testing the source sentence for emptiness, but the
        kept sentence retains the token.
        """
        if self.target_token:
            field_index = 1
        else:
            field_index = 0
        data_dict = defaultdict(list)
        list1 = dict_['source']
        list2 = dict_['target']
        for sent1, sent2 in zip(list1, list2):
            # Slicing a list never raises IndexError; the original
            # try/except around this line could never trigger.
            src_sent = ' '.join(sent1.split()[field_index:])
            if src_sent.find(self.empty_line_flag) != -1 or len(src_sent) == 0:
                continue
            elif sent2.find(self.empty_line_flag) != -1 or len(sent2) == 0:
                continue
            else:
                data_dict['source'].append(sent1)
                data_dict['target'].append(sent2)
        return data_dict

    def read_file(self, split_type, data_type):
        """Return the sentence list for *split_type* ('train'/'test'/'dev')
        and *data_type* ('source'/'target')."""
        return self.data[split_type][data_type]

    def save_file(self, path_, split_type, data_type, lang):
        """Write one side of a split to a '.tok.' file and, if detok is
        enabled, detokenize it via the module-level de_tok() helper."""
        tok_file = tok_file_name(path_, lang)
        with io.open(tok_file, 'w', encoding='utf8') as fp:
            for line in self.data[split_type][data_type]:
                fp.write(line + '\n')
        if self.detok:
            de_tok(tok_file, lang)

    def add_target_token(self, list_, lang_id):
        """Return a copy of *list_* with '__<lang_id>__ ' prepended to
        every sentence."""
        token = '__' + lang_id + '__'
        return [token + ' ' + sent for sent in list_]

    def read_from_single_file(self, path_, s_lang, t_lang):
        """Read the aligned *s_lang*/*t_lang* columns from one TSV file.

        Returns (source_sentences, target_sentences); source sentences
        carry the target-language token when self.target_token is set.
        """
        data_dict = defaultdict(list)
        with io.open(path_, 'r', encoding='utf8') as fp:
            reader = csv.DictReader(fp, delimiter='\t', quoting=csv.QUOTE_NONE)
            for row in reader:
                data_dict['source'].append(row[s_lang])
                data_dict['target'].append(row[t_lang])

        if self.target_token:
            text = self.add_target_token(data_dict['source'], t_lang)
            data_dict['source'] = text

        return data_dict['source'], data_dict['target']

    def read_aligned_corpus(self, split_type='train'):
        """Collect and filter all language-pair data for one split."""
        data_dict = defaultdict(list)
        iterable = []
        s_list = []
        t_list = []

        if self.zero_shot:
            if split_type == "train":
                iterable = zip(self.lang_dict['source'], self.lang_dict['target'])
            else:
                iterable = zip(self.eval_lang_dict['source'], self.eval_lang_dict['target'])
        elif self.bilingual:
            iterable = itertools.product(self.lang_dict['source'], self.lang_dict['target'])

        for s_lang, t_lang in iterable:
            if s_lang == t_lang:
                continue
            if self.corpus_type == 'file':
                split_type_file_path = os.path.join(self.corpus_path,
                                                    "all_talks_{}.tsv".format(split_type))
                s_list, t_list = self.read_from_single_file(split_type_file_path,
                                                            s_lang=s_lang,
                                                            t_lang=t_lang)
            data_dict['source'] += s_list
            data_dict['target'] += t_list

        new_data_dict = self.filter_text(data_dict)
        return new_data_dict
def read_langs(corpus_path):
    """Return the language codes present in the dev-split TSV header.

    Reads ``<corpus_path>/extracted/all_talks_dev.tsv`` and returns every
    column name except 'talk_name'.

    Fixed: the original called ``next(reader)`` on the DictReader, which
    consumes the first *data* row just to look at its keys and raises
    StopIteration on a header-only file; ``reader.fieldnames`` gives the
    header directly.
    """
    split_type_file_path = os.path.join(corpus_path, 'extracted',
                                        "all_talks_dev.tsv")
    with io.open(split_type_file_path, 'r', encoding='utf8') as fp:
        reader = csv.DictReader(fp, delimiter='\t', quoting=csv.QUOTE_NONE)
        return [k for k in reader.fieldnames if k != 'talk_name']
def extra_english(corpus_path, split):
    """Dump the English column of ``all_talks_<split>.tsv`` to a ``.en``
    file alongside it, then detokenize the output in place."""
    in_path = os.path.join(corpus_path, f"all_talks_{split}.tsv")
    out_path = os.path.join(corpus_path, f"all_talks_{split}.en")
    with io.open(in_path, 'r', encoding='utf8') as fp, \
            io.open(out_path, 'w', encoding='utf8') as fw:
        reader = csv.DictReader(fp, delimiter='\t', quoting=csv.QUOTE_NONE)
        for row in reader:
            fw.write(row['en'] + '\n')
    de_tok(out_path, 'en')
def tok_file_name(filename, lang):
    """Insert a 'tok' component before the final extension of *filename*.

    E.g. 'train.fr-en.fr' -> 'train.fr-en.tok.fr'. *lang* is unused but
    kept for interface compatibility with callers.
    """
    parts = filename.split('.')
    return '.'.join(parts[:-1] + ['tok', parts[-1]])
def de_tok(tok_file, lang):
    """Detokenize *tok_file* with the Moses detokenizer, writing the
    output next to it with the '.tok.' path component removed."""
    de_tok_file = tok_file.replace('.tok.', '.')
    # Moses expects a bare two-letter language code.
    cmd = f'perl {detok_cmd} -l {lang[:2]} < {tok_file} > {de_tok_file}'
    call(cmd)
def extra_bitex(
    ted_data_path,
    lsrc_lang,
    ltrg_lang,
    target_token,
    output_data_path,
):
    """Extract one bilingual direction from the TED corpus and write
    fairseq-style train/valid/test files under *output_data_path*."""

    def get_ted_lang(lang):
        # Map fairseq-style codes (e.g. 'pt_BR', 'fr_XX') onto the codes
        # used as column names in the TED tsv files.
        long_langs = ['pt-br', 'zh-cn', 'zh-tw', 'fr-ca']
        if lang[:5] in long_langs:
            return lang[:5]
        elif lang[:4] == 'calv':
            return lang[:5]
        elif lang in ['pt_BR', 'zh_CN', 'zh_TW', 'fr_CA']:
            return lang.lower().replace('_', '-')
        return lang[:2]

    src_lang = get_ted_lang(lsrc_lang)
    trg_lang = get_ted_lang(ltrg_lang)
    reader = MultiLingualAlignedCorpusReader(
        corpus_path=ted_data_path,
        lang_dict={'source': [src_lang], 'target': [trg_lang]},
        target_token=target_token,
        corpus_type='file',
        eval_lang_dict={'source': [src_lang], 'target': [trg_lang]},
        zero_shot=False,
        bilingual=True,
    )

    os.makedirs(output_data_path, exist_ok=True)
    lsrc_lang = lsrc_lang.replace('-', '_')
    ltrg_lang = ltrg_lang.replace('-', '_')
    pair = f"{lsrc_lang}-{ltrg_lang}"
    # Output file names use 'valid' for the corpus 'dev' split.
    for file_split, corpus_split in (('train', 'train'),
                                     ('test', 'test'),
                                     ('valid', 'dev')):
        reader.save_file(output_data_path + f"/{file_split}.{pair}.{lsrc_lang}",
                         split_type=corpus_split, data_type='source', lang=src_lang)
        reader.save_file(output_data_path + f"/{file_split}.{pair}.{ltrg_lang}",
                         split_type=corpus_split, data_type='target', lang=trg_lang)
def bar_custom(current, total, width=80):
    """Progress callback for wget.download: print percent and KB counts,
    overwriting the same terminal line."""
    percent = current / total * 100
    print("Downloading: %d%% [%d / %d] Ks" % (percent, current / 1000, total / 1000), end='\r')
def download_and_extract(download_to, extract_to):
    """Download the TED talks tarball into *download_to* and extract it
    into *extract_to*, skipping each step whose output already exists
    (the function is idempotent).

    Fixed: both the skip message and the tar command previously contained
    the literal placeholder '(unknown)' where the archive path belongs,
    so extraction could never work.
    """
    url = 'http://phontron.com/data/ted_talks.tar.gz'
    filename = f"{download_to}/ted_talks.tar.gz"
    if os.path.exists(filename):
        print(f'{filename} has already been downloaded so skip')
    else:
        filename = wget.download(url, filename, bar=bar_custom)

    if os.path.exists(f'{extract_to}/all_talks_train.tsv'):
        print(f'Already extracted so skip')
    else:
        extract_cmd = f'tar xzfv "{filename}" -C "{extract_to}"'
        call(extract_cmd)
if __name__ == "__main__":
    import argparse
    parser = argparse.ArgumentParser()
    # Root directory that will hold downloads/, extracted/ and ML50/raw/.
    parser.add_argument('--ted_data_path', type=str, default=WORKDIR_ROOT, required=False)
    # Comma-separated 'src-tgt' pairs; each side is a fairseq-style
    # language code (e.g. 'bn_IN'). The default is the ML50 benchmark
    # direction list minus the directions replaced by other corpora.
    parser.add_argument(
        '--direction-list',
        type=str,
        # default=None,
        #for ML50
        default=(
            "bn_IN-en_XX,he_IL-en_XX,fa_IR-en_XX,id_ID-en_XX,sv_SE-en_XX,pt_XX-en_XX,ka_GE-en_XX,ka_GE-en_XX,th_TH-en_XX,"
            "mr_IN-en_XX,hr_HR-en_XX,uk_UA-en_XX,az_AZ-en_XX,mk_MK-en_XX,gl_ES-en_XX,sl_SI-en_XX,mn_MN-en_XX,"
            #non-english directions
            # "fr_XX-de_DE," # replaced with wmt20
            # "ja_XX-ko_KR,es_XX-pt_XX,ru_RU-sv_SE,hi_IN-bn_IN,id_ID-ar_AR,cs_CZ-pl_PL,ar_AR-tr_TR"
        ),
        required=False)
    # Prepend a '__<lang>__' token to every source sentence.
    parser.add_argument('--target-token', action='store_true', default=False)
    # Only dump the English column of each split and exit.
    parser.add_argument('--extract-all-english', action='store_true', default=False)
    args = parser.parse_args()

    import sys
    import json

    # TED Talks data directory
    ted_data_path = args.ted_data_path

    download_to = f'{ted_data_path}/downloads'
    extract_to = f'{ted_data_path}/extracted'

    #DESTDIR=${WORKDIR_ROOT}/ML50/raw/
    output_path = f'{ted_data_path}/ML50/raw'
    os.makedirs(download_to, exist_ok=True)
    os.makedirs(extract_to, exist_ok=True)
    os.makedirs(output_path, exist_ok=True)
    # Idempotent: skips download/extract when outputs already exist.
    download_and_extract(download_to, extract_to)

    if args.extract_all_english:
        for split in ['train', 'dev', 'test']:
            extra_english(ted_data_path, split)
        exit(0)
    if args.direction_list is not None:
        directions = args.direction_list.strip().split(',')
        # split('-', 1) keeps codes that themselves contain '-' intact
        # on the target side.
        directions = [tuple(d.strip().split('-', 1)) for d in directions if d]
    else:
        # No explicit list: pair English with every language found in
        # the extracted corpus header.
        langs = read_langs(ted_data_path)
        # directions = [
        #     '{}.{}'.format(src, tgt)
        #     for src in langs
        #     for tgt in langs
        #     if src < tgt
        # ]
        directions = [('en', tgt) for tgt in langs if tgt != 'en']
    print(f'num directions={len(directions)}: {directions}')

    for src_lang, trg_lang in directions:
        print('--working on {}-{}'.format(src_lang, trg_lang))
        extra_bitex(
            extract_to,
            src_lang,
            trg_lang,
            target_token=args.target_token,
            output_data_path=output_path
        )