import os
import re
import sys

import torch
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

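# FLORES-200 language codes mapped to the codes the M2M-100 tokenizer expects
# (the devtest pair folders are assumed to be named with FLORES-200 codes).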
langs_supported = {
    "eng_Latn": "en",
    "ben_Beng": "bn",
    "guj_Gujr": "gu",
    "hin_Deva": "hi",
    "kan_Knda": "kn",
    "mal_Mlym": "ml",
    "mar_Deva": "mr",
    "npi_Deva": "ne",
    "ory_Orya": "or",
    "pan_Guru": "pa",
    "snd_Arab": "sd",
    "tam_Taml": "ta",
    "urd_Arab": "ur",
}

def predict(batch, tokenizer, model, bos_token_id):
    """Translate one batch, forcing the target language via the BOS token."""
    encoded_batch = tokenizer(batch, padding=True, return_tensors="pt").to(model.device)
    generated_tokens = model.generate(
        **encoded_batch,
        num_beams=5,
        max_length=256,
        min_length=0,
        forced_bos_token_id=bos_token_id,
    )
    hypothesis = tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)
    return hypothesis

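# Assumes devtest_data_dir holds one sub-folder per language pair, named
# "<src>-<tgt>" in FLORES-200 codes, each containing one-sentence-per-line
# test.<src> and test.<tgt> files.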
def main(devtest_data_dir, batch_size):
    model_name = "facebook/m2m100-12B-last-ckpt"
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    # The CUDA check in __main__ guarantees a GPU here; predict() picks the
    # device back up via model.device. (The 12B checkpoint is large, so
    # loading in half precision via .half() may be needed to fit it.)
    model = AutoModelForSeq2SeqLM.from_pretrained(model_name).to("cuda")
    model.eval()

    for pair in sorted(os.listdir(devtest_data_dir)):
        if "-" not in pair:
            continue

        src_lang, tgt_lang = pair.split("-")

        if src_lang not in langs_supported or tgt_lang not in langs_supported:
            print(f"Skipping {src_lang}-{tgt_lang} ...")
            continue

| print(f"Evaluating {src_lang}-{tgt_lang} ...")
|
|
|
| infname = os.path.join(devtest_data_dir, pair, f"test.{src_lang}")
|
| outfname = os.path.join(devtest_data_dir, pair, f"test.{tgt_lang}.pred.m2m100")
|
|
|
| with open(infname, "r") as f:
|
| src_sents = f.read().split("\n")
|
|
|
| add_new_line = False
|
| if src_sents[-1] == "":
|
| add_new_line = True
|
| src_sents = src_sents[:-1]
|
|
|
|
|
| tokenizer.src_lang = langs_supported[src_lang]
|
|
|
|
|
| hypothesis = []
|
| for i in tqdm(range(0, len(src_sents), batch_size)):
|
| start, end = i, int(min(len(src_sents), i + batch_size))
|
| batch = src_sents[start:end]
|
| bos_token_id = tokenizer.lang_code_to_id[langs_supported[tgt_lang]]
|
| hypothesis += predict(batch, tokenizer, model, bos_token_id)
|
|
|
| assert len(hypothesis) == len(src_sents)
|
|
|
| hypothesis = [
|
| re.sub("\s+", " ", x.replace("\n", " ").replace("\t", " ")).strip()
|
| for x in hypothesis
|
| ]
|
| if add_new_line:
|
| hypothesis = hypothesis
|
|
|
| with open(outfname, "w") as f:
|
| f.write("\n".join(hypothesis))
|
|
|
|
|
|
|
|
|
        # Reverse direction: tgt_lang -> src_lang.
        infname = os.path.join(devtest_data_dir, pair, f"test.{tgt_lang}")
        outfname = os.path.join(devtest_data_dir, pair, f"test.{src_lang}.pred.m2m100")

        with open(infname, "r") as f:
            src_sents = f.read().split("\n")

        add_new_line = False
        if src_sents[-1] == "":
            add_new_line = True
            src_sents = src_sents[:-1]

        tokenizer.src_lang = langs_supported[tgt_lang]
        bos_token_id = tokenizer.lang_code_to_id[langs_supported[src_lang]]

        hypothesis = []
        for i in tqdm(range(0, len(src_sents), batch_size)):
            batch = src_sents[i : i + batch_size]
            hypothesis += predict(batch, tokenizer, model, bos_token_id)

        assert len(hypothesis) == len(src_sents)

        hypothesis = [
            re.sub(r"\s+", " ", x.replace("\n", " ").replace("\t", " ")).strip()
            for x in hypothesis
        ]
        if add_new_line:
            hypothesis += [""]  # restore the trailing newline

        with open(outfname, "w") as f:
            f.write("\n".join(hypothesis))

if __name__ == "__main__":
    # CLI: <devtest_data_dir> <batch_size>
    devtest_data_dir = sys.argv[1]
    batch_size = int(sys.argv[2])

    if not torch.cuda.is_available():
        print("No GPU available")
        sys.exit(1)

    main(devtest_data_dir, batch_size)
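
# Example invocation (the script name is illustrative):
#   python m2m100_inference.py <devtest_data_dir> <batch_size>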