| --- |
| license: apache-2.0 |
| --- |
| ONNX format of voxreality/src_ctx_aware_nllb_600M model |
|
|
| Model inference example: |
|
|
| ```python |
| from optimum.onnxruntime import ORTModelForSeq2SeqLM |
| from transformers import AutoTokenizer |
| |
| model_path = 'voxreality/src_ctx_aware_nllb_600M_onnx' |
| model = ORTModelForSeq2SeqLM.from_pretrained(model_path) |
| tokenizer = AutoTokenizer.from_pretrained(model_path) |
| |
| max_length = 100 |
| src_lang = 'eng_Latn' |
| tgt_lang = 'deu_Latn' |
| context_text = 'This is an optional context sentence.' |
| sentence_text = 'Text to be translated.' |
| |
| # If the context is provided |
| input_text = f'{context_text} {tokenizer.sep_token} {sentence_text}' |
| # If no context is provided, you can use just the sentence_text as input |
| # input_text = sentence_text |
| |
| tokenizer.src_lang = src_lang |
| |
| inputs = tokenizer(input_text, return_tensors='pt') |
| |
| inputs = inputs.to('cpu') |
| |
| forced_bos_token_id = tokenizer.convert_tokens_to_ids(tgt_lang) |
| |
| output = model.generate( |
| **inputs, |
| forced_bos_token_id=forced_bos_token_id, |
| max_length=max_length |
| ) |
| |
| output_text = tokenizer.batch_decode(output, skip_special_tokens=True)[0] |
| |
| print(output_text) |
| ``` |