Spaces:
Sleeping
Sleeping
| import numpy as np | |
| import onnxruntime as ort | |
| from transformers import MarianTokenizer | |
| import gradio as gr | |
# Locations of the exported artefacts, relative to the working directory.
tokenizer_path = "./onnx_model"
onnx_model_path = "./model.onnx"

# Tokenizer is loaded first, then the ONNX Runtime session; both are created
# once at import time and reused by every call to translate().
tokenizer = MarianTokenizer.from_pretrained(tokenizer_path)
session = ort.InferenceSession(onnx_model_path)
def translate(text, max_length=512):
    """Translate ``text`` by greedy decoding with the ONNX model.

    Args:
        text: Source text handed straight to the tokenizer. The Gradio UI
            passes a single string; a list of strings is presumably also
            accepted by the tokenizer, but only the first translation is
            returned either way.
        max_length: Truncation limit for the tokenized input and also the
            maximum number of decoding steps.

    Returns:
        The decoded translation of the first sequence in the batch, with
        special tokens removed.
    """
    # Tokenize input
    inputs = tokenizer(
        text,
        return_tensors="np",
        padding=True,
        truncation=True,
        max_length=max_length,
    )
    input_ids = inputs["input_ids"].astype(np.int64)
    attention_mask = inputs["attention_mask"].astype(np.int64)

    # Initialize decoder: each sequence is seeded with the pad token.
    batch_size = input_ids.shape[0]
    pad_token_id = tokenizer.pad_token_id
    eos_token_id = tokenizer.eos_token_id
    decoder_input_ids = np.full((batch_size, 1), pad_token_id, dtype=np.int64)
    eos_reached = np.zeros(batch_size, dtype=bool)

    for _ in range(max_length):
        # The whole model (encoder included) is re-run every step because the
        # exported graph only exposes these three inputs.
        onnx_outputs = session.run(
            None,
            {
                "input_ids": input_ids,
                "attention_mask": attention_mask,
                "decoder_input_ids": decoder_input_ids,
            },
        )
        # First output holds the logits; keep only the last decoder position.
        next_token_logits = onnx_outputs[0][:, -1, :]
        # Greedy choice: highest-scoring token per sequence.
        next_tokens = np.argmax(next_token_logits, axis=-1).astype(np.int64)

        # Sequences that already produced EOS must only receive padding;
        # otherwise the tokens predicted after EOS would leak into the decoded
        # text while slower sequences in the batch are still generating.
        next_tokens = np.where(eos_reached, pad_token_id, next_tokens)

        # Append new tokens to decoder input
        decoder_input_ids = np.hstack(
            [decoder_input_ids, next_tokens.reshape(-1, 1)]
        )

        # Stop if all sentences have reached EOS
        eos_reached |= next_tokens == eos_token_id
        if eos_reached.all():
            break

    # Decode output tokens
    translated_texts = tokenizer.batch_decode(
        decoder_input_ids, skip_special_tokens=True
    )
    return translated_texts[0]
# Gradio interface: one input box, one output box, and a button that runs
# translate() on the input and writes the result to the output.
with gr.Blocks() as interface:
    input_text = gr.Textbox(label="Input Text")
    output_translation = gr.Textbox(label="Translation")
    translate_button = gr.Button("Translate")
    translate_button.click(
        fn=translate,
        inputs=[input_text],
        outputs=[output_translation],
    )

# Start the web app only when this file is executed as a script.
if __name__ == "__main__":
    interface.launch()