Solving the issue "AttributeError: 'NllbTokenizerFast' object has no attribute 'lang_code_to_id'"
ffa8d95
verified
| import torch | |
| from transformers import set_seed, pipeline | |
| from transformers import M2M100ForConditionalGeneration, M2M100Tokenizer | |
| from transformers import AutoModelForSeq2SeqLM, AutoTokenizer | |
| import time | |
| ######### HELSINKI NLP ################## | |
# Translation pipelines keyed by Hugging Face model id, so each Opus model is
# loaded once per process instead of on every call.
_HELSINKI_PIPELINES = {}


def translate_helsinki_nlp(s: str, src_iso: str, dest_iso: str) -> str:
    '''
    Translate the text using HelsinkiNLP's Opus models for Mossi language.

    The pipeline for a language pair is built on first use and kept in
    ``_HELSINKI_PIPELINES`` for the rest of the process.

    Parameters
    ----------
    s: str
        The text
    src_iso: str
        Source language code exactly as it appears in the model id
        ``Helsinki-NLP/opus-mt-{src_iso}-{dest_iso}`` (e.g. "mos", "en",
        "fr") -- this is not necessarily a 3-letter ISO 639-3 code.
    dest_iso: str
        Destination language code, same convention as ``src_iso``.
    Returns
    ----------
    translation:str
        The translated text
    '''
    # Ensure replicability
    set_seed(555)
    model_id = f"Helsinki-NLP/opus-mt-{src_iso}-{dest_iso}"
    translator = _HELSINKI_PIPELINES.get(model_id)
    if translator is None:
        # Building the pipeline loads (and may download) the model, which is
        # expensive, so do it only the first time this pair is requested.
        translator = pipeline("translation", model=model_id)
        _HELSINKI_PIPELINES[model_id] = translator
    # Inference: the pipeline output is a list of dicts; take the first one.
    return translator(s)[0]['translation_text']
| ######### MASAKHANE ################## | |
# (model, tokenizer) pairs keyed by Hugging Face model id, so each fine-tuned
# M2M100 checkpoint is loaded once per process instead of on every call.
_MASAKHANE_MODELS = {}


def translate_masakhane(s: str, src_iso: str, dest_iso: str) -> str:
    '''
    Translate the text using Masakhane's M2M models for Mossi language.

    The model and tokenizer for a language pair are loaded on first use and
    kept in ``_MASAKHANE_MODELS`` for the rest of the process.

    Parameters
    ----------
    s: str
        The text
    src_iso: str
        Source language code exactly as it appears in the model id
        ``masakhane/m2m100_418m_{src_iso}_{dest_iso}_news`` (e.g. "mos",
        "fr") -- this is not necessarily a 3-letter ISO 639-3 code.
    dest_iso: str
        Destination language code, same convention as ``src_iso``.
    Returns
    ----------
    translation:str
        The translated text
    '''
    # Ensure replicability
    set_seed(555)
    model_id = f"masakhane/m2m100_418m_{src_iso}_{dest_iso}_news"
    cached = _MASAKHANE_MODELS.get(model_id)
    if cached is None:
        # Load model first, then tokenizer (same order as before caching).
        cached = (
            M2M100ForConditionalGeneration.from_pretrained(model_id),
            M2M100Tokenizer.from_pretrained(model_id),
        )
        _MASAKHANE_MODELS[model_id] = cached
    model, tokenizer = cached
    # Inference
    # NOTE(review): neither tokenizer.src_lang nor forced_bos_token_id is set,
    # unlike the usual M2M100 recipe; presumably these fine-tuned checkpoints
    # do not need them -- confirm against the Masakhane model cards.
    encoded = tokenizer(s, return_tensors="pt")
    generated_tokens = model.generate(**encoded)
    return tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)[0]
| ######### META ################## | |
# NLLB checkpoint shared by every language pair.
_NLLB_CHECKPOINT = "facebook/nllb-200-distilled-600M"
# The model does not depend on the language pair, so it is loaded once.
_NLLB_MODELS = {}
# The tokenizer is built with a fixed ``src_lang``, so cache one per source.
_NLLB_TOKENIZERS = {}


def translate_facebook(s: str, src_iso: str, dest_iso: str, *, max_length: int = 30) -> str:
    '''
    Translate the text using Meta's NLLB model for Mossi language.

    The NLLB model is loaded on first use and reused afterwards. One tokenizer
    is cached per source language.

    Parameters
    ----------
    s: str
        The text
    src_iso: str
        ISO 639-3 code of the source language (e.g. "mos", "eng", "fra").
        The Latin-script suffix "_Latn" is appended to form the NLLB code.
    dest_iso: str
        ISO 639-3 code of the destination language, same convention.
    max_length: int, keyword-only
        Maximum length, in tokens, of the generated sequence. It is passed
        to ``model.generate``. Longer translations are cut off at this
        length. The default of 30 keeps the previous behaviour.
    Returns
    ----------
    translation:str
        The translated text
    '''
    # Ensure replicability
    set_seed(555)
    src_lang = f"{src_iso}_Latn"
    tokenizer = _NLLB_TOKENIZERS.get(src_lang)
    if tokenizer is None:
        tokenizer = AutoTokenizer.from_pretrained(_NLLB_CHECKPOINT, src_lang=src_lang)
        _NLLB_TOKENIZERS[src_lang] = tokenizer
    model = _NLLB_MODELS.get(_NLLB_CHECKPOINT)
    if model is None:
        model = AutoModelForSeq2SeqLM.from_pretrained(_NLLB_CHECKPOINT)
        _NLLB_MODELS[_NLLB_CHECKPOINT] = model
    # Inference
    encoded = tokenizer(s, return_tensors="pt")
    # Resolve the target-language token by name. The installed transformers
    # version's NllbTokenizerFast raised AttributeError on `lang_code_to_id`.
    # NOTE(review): an unknown code presumably maps to the unk token id
    # rather than raising -- validate dest_iso upstream if that matters.
    target_token_id = tokenizer.convert_tokens_to_ids(f"{dest_iso}_Latn")
    translated_tokens = model.generate(
        **encoded,
        forced_bos_token_id=target_token_id,
        max_length=max_length,
    )
    return tokenizer.batch_decode(translated_tokens, skip_special_tokens=True)[0]
| ######### ALL OF THE ABOVE ################## | |
# 3-letter codes that the Helsinki-NLP and Masakhane model ids spell with
# 2 letters; any other code (e.g. "mos") is used unchanged.
_MODEL_ID_CODES = {"eng": "en", "fra": "fr"}
# Language pairs ("src-dest", 3-letter codes) that have a Helsinki-NLP Opus model.
_HELSINKI_PAIRS = frozenset({"mos-eng", "eng-mos", "fra-mos"})
# Language pairs ("src-dest", 3-letter codes) that have a Masakhane M2M100 model.
_MASAKHANE_PAIRS = frozenset({"mos-fra", "fra-mos"})


def translate(s: str, src_iso: str, dest_iso: str) -> str:
    '''
    Translate the text using all available models (Meta, Masakhane, and Helsinki NLP where applicable).

    Meta's NLLB is always used. Helsinki-NLP and Masakhane outputs are
    appended only for the pairs listed in ``_HELSINKI_PAIRS`` and
    ``_MASAKHANE_PAIRS``. The elapsed wall-clock time is printed.

    Parameters
    ----------
    s: str
        The text
    src_iso: str
        The ISO 639-3 code of the source language (e.g. "mos", "eng", "fra")
    dest_iso: str
        The ISO 639-3 code of the destination language
    Returns
    ----------
    translation:str
        The translated text, concatenated over different models
    '''
    # Time the whole call; perf_counter is monotonic, unlike time.time().
    start_time = time.perf_counter()
    # Translate with Meta NLLB
    translation = "Meta's NLLB translation is:\n\n" + translate_facebook(s, src_iso, dest_iso)
    # The other models' ids use 2-letter codes for English/French. Compute the
    # converted codes once instead of rewriting src_iso/dest_iso in place.
    iso_pair = f"{src_iso}-{dest_iso}"
    src_code = _MODEL_ID_CODES.get(src_iso, src_iso)
    dest_code = _MODEL_ID_CODES.get(dest_iso, dest_iso)
    if iso_pair in _HELSINKI_PAIRS:
        translation += f"\n\n\nHelsinkiNLP's Opus translation is:\n\n {translate_helsinki_nlp(s, src_code, dest_code)}"
    if iso_pair in _MASAKHANE_PAIRS:
        translation += "\n\n\nMasakhane's M2M translation is:\n\n" + translate_masakhane(s, src_code, dest_code)
    print("Time elapsed: ", int(time.perf_counter() - start_time), " seconds")
    return translation