|
|
import os
from functools import lru_cache

import coremltools as ct
import numpy as np
from transformers import AutoTokenizer
|
|
|
|
|
# Silence the HF tokenizers fork/parallelism warning (and avoid potential
# deadlocks if this process ever forks after the tokenizer is loaded).
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
|
|
|
|
@lru_cache(maxsize=None)
def _load_models():
    """Load and cache the CoreML encoder/decoder pair.

    Loading .mlpackage models is expensive (compilation + weight load), so do
    it once per process instead of on every translate_text() call.
    """
    encoder = ct.models.MLModel("NLLB_Encoder_128.mlpackage",
                                compute_units=ct.ComputeUnit.ALL)
    decoder = ct.models.MLModel("NLLB_Decoder_128.mlpackage",
                                compute_units=ct.ComputeUnit.ALL)
    return encoder, decoder


@lru_cache(maxsize=None)
def _load_tokenizer():
    """Load and cache the NLLB tokenizer from the local ./tokenizer directory."""
    return AutoTokenizer.from_pretrained("./tokenizer")


def translate_text(text, source_lang="eng_Latn", target_lang="deu_Latn"):
    """
    Translate text using CoreML models.

    Runs the NLLB encoder once, then greedily decodes token-by-token with the
    fixed-length (128) CoreML decoder until EOS or the length limit.

    Args:
        text: Text to translate
        source_lang: Source language code (e.g., "eng_Latn", "fra_Latn")
        target_lang: Target language code (e.g., "deu_Latn", "spa_Latn")

    Returns:
        Translated text
    """
    MAX_LEN = 128  # both .mlpackage models were exported with this fixed sequence length

    encoder, decoder = _load_models()
    tokenizer = _load_tokenizer()
    # src_lang controls which language token the tokenizer prepends; set per call
    # since the tokenizer instance is shared/cached.
    tokenizer.src_lang = source_lang

    inputs = tokenizer(text, return_tensors="np",
                       padding="max_length",
                       max_length=MAX_LEN,
                       truncation=True)

    enc_outputs = encoder.predict({
        "input_ids": inputs["input_ids"].astype(np.int32),
        "attention_mask": inputs["attention_mask"].astype(np.int32)
    })

    # The encoder model has a single output; grab it without hard-coding its name.
    encoder_hidden_states = next(iter(enc_outputs.values()))

    # NLLB decoding starts with </s> followed by the target-language token,
    # and stops when </s> is generated again. Use the tokenizer's EOS id
    # instead of the magic constant 2 (which is </s> in the standard NLLB vocab).
    eos_id = tokenizer.eos_token_id
    forced_bos = tokenizer.convert_tokens_to_ids(target_lang)
    current_tokens = [eos_id, forced_bos]

    # Greedy decode: up to MAX_LEN - 2 generated tokens (2 slots are taken by
    # the EOS + language prefix above).
    for _ in range(MAX_LEN - 2):
        decoder_input = np.full((1, MAX_LEN), tokenizer.pad_token_id, dtype=np.int32)
        decoder_input[0, :len(current_tokens)] = current_tokens

        dec_outputs = decoder.predict({
            "decoder_input_ids": decoder_input,
            "encoder_hidden_states": encoder_hidden_states,
            "encoder_attention_mask": inputs["attention_mask"].astype(np.int32)
        })

        # Single-output decoder; logits shape is (1, MAX_LEN, vocab).
        logits = next(iter(dec_outputs.values()))
        # Argmax over the position of the last real token (no sampling).
        next_token = int(np.argmax(logits[0, len(current_tokens) - 1, :]))

        if next_token == eos_id:
            break

        current_tokens.append(next_token)

    # Drop the EOS + language-prefix tokens before decoding back to text.
    return tokenizer.decode(current_tokens[2:], skip_special_tokens=True)
|
|
|
|
|
|
|
|
def main():
    """Demo: translate one English sentence to German and print both."""
    text = "Hello, how are you today?"
    translation = translate_text(text, source_lang="eng_Latn", target_lang="deu_Latn")
    print(f"English: {text}")
    print(f"German: {translation}")


if __name__ == "__main__":
    main()
|
|
|