# nllb-200-coreml-256 / example.py
# cstr: Upload NLLB-200 CoreML 256-token models with tokenizer
# commit 0dc7561 (verified)
import coremltools as ct
import numpy as np
from transformers import AutoTokenizer
import os
# Set TOKENIZERS_PARALLELISM to "false" at import time, before any tokenizer
# is created. NOTE(review): presumably this turns off the Hugging Face
# tokenizers library's internal parallelism (and its fork warning) -- confirm.
os.environ["TOKENIZERS_PARALLELISM"] = "false"
# Holds the loaded (encoder, decoder, tokenizer) triple so that repeated calls
# to translate_text() do not reload the CoreML packages from disk.
_LOADED_RESOURCES = {}


def _load_resources():
    """Return the (encoder, decoder, tokenizer) triple, loading it on first use."""
    if "triple" not in _LOADED_RESOURCES:
        encoder = ct.models.MLModel(
            "NLLB_Encoder_256.mlpackage",
            compute_units=ct.ComputeUnit.ALL,
        )
        decoder = ct.models.MLModel(
            "NLLB_Decoder_256.mlpackage",
            compute_units=ct.ComputeUnit.ALL,
        )
        # Load tokenizer from local directory
        tokenizer = AutoTokenizer.from_pretrained("./tokenizer")
        # Store only once everything has loaded, so a failed load is retried
        # on the next call instead of leaving a half-filled cache behind.
        _LOADED_RESOURCES["triple"] = (encoder, decoder, tokenizer)
    return _LOADED_RESOURCES["triple"]


def _language_token_id(tokenizer, lang_code):
    """Return the token id for a language code, rejecting unknown codes."""
    token_id = tokenizer.convert_tokens_to_ids(lang_code)
    # An unrecognised token maps to the unknown-token id; decoding with that
    # would silently produce garbage, so fail loudly instead.
    if token_id is None or token_id == tokenizer.unk_token_id:
        raise ValueError(f"Unknown language code: {lang_code!r}")
    return token_id


def translate_text(text, source_lang="eng_Latn", target_lang="deu_Latn"):
    """Translate text with the CoreML encoder/decoder using greedy decoding.

    The models and tokenizer are loaded on the first call and reused on
    later calls.

    Args:
        text: Text to translate (up to ~150-180 words). Input is truncated
            to 256 tokens by the tokenizer.
        source_lang: Source language code (e.g., "eng_Latn", "fra_Latn")
        target_lang: Target language code (e.g., "deu_Latn", "spa_Latn")

    Returns:
        Translated text. An empty or whitespace-only string yields "".

    Raises:
        ValueError: If source_lang or target_lang is not a token known to
            the tokenizer.
    """
    MAX_LEN = 256

    # Nothing to translate: skip model loading and inference entirely.
    if isinstance(text, str) and not text.strip():
        return ""

    encoder, decoder, tokenizer = _load_resources()

    # Validate both codes before running any inference.
    _language_token_id(tokenizer, source_lang)
    forced_bos = _language_token_id(tokenizer, target_lang)
    tokenizer.src_lang = source_lang

    # Encode
    inputs = tokenizer(
        text,
        return_tensors="np",
        padding="max_length",
        max_length=MAX_LEN,
        truncation=True,
    )
    # Cast once; the same mask feeds the encoder and every decoder step.
    attention_mask = inputs["attention_mask"].astype(np.int32)
    enc_outputs = encoder.predict({
        "input_ids": inputs["input_ids"].astype(np.int32),
        "attention_mask": attention_mask,
    })
    # The first output of the encoder is taken as the hidden states.
    encoder_hidden_states = next(iter(enc_outputs.values()))

    # Decode
    # The original hard-coded 2 for both the decoder start token and the stop
    # check; take the id from the tokenizer instead.
    # NOTE(review): assumes the decoder starts from EOS, as the original's
    # [2, forced_bos] prefix implies -- confirm against the model config.
    eos_id = tokenizer.eos_token_id
    current_tokens = [eos_id, forced_bos]

    # Fixed-size decoder input, padded on the right and filled in as tokens
    # are generated (one buffer reused across steps).
    decoder_input = np.full((1, MAX_LEN), tokenizer.pad_token_id, dtype=np.int32)
    decoder_input[0, :len(current_tokens)] = current_tokens

    for _ in range(MAX_LEN - 2):
        dec_outputs = decoder.predict({
            "decoder_input_ids": decoder_input,
            "encoder_hidden_states": encoder_hidden_states,
            "encoder_attention_mask": attention_mask,
        })
        logits = next(iter(dec_outputs.values()))
        # Greedy pick at the position of the last real (non-pad) token.
        next_token = int(np.argmax(logits[0, len(current_tokens) - 1, :]))
        if next_token == eos_id:
            break
        decoder_input[0, len(current_tokens)] = next_token
        current_tokens.append(next_token)

    # Drop the two-token prefix (start token + target language tag).
    return tokenizer.decode(current_tokens[2:], skip_special_tokens=True)
if __name__ == "__main__":
    # Demo: translate one English sentence into German and print both.
    sample = "Hello, how are you today?"
    german = translate_text(
        sample,
        source_lang="eng_Latn",
        target_lang="deu_Latn",
    )
    print(f"English: {sample}")
    print(f"German: {german}")