File size: 2,087 Bytes
c9d7167
 
65ed74c
c9d7167
 
7a6d259
 
f80fc89
fc36581
7a6d259
65ed74c
 
c9d7167
7a6d259
 
 
0760540
 
 
7a6d259
0760540
7a6d259
 
 
 
 
0760540
f80fc89
0760540
 
 
 
 
 
 
 
7a6d259
 
 
 
 
c4b718f
7a6d259
 
 
0760540
7a6d259
 
 
 
65ed74c
 
7a6d259
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
import numpy as np
import onnxruntime as ort
from transformers import MarianTokenizer
import gradio as gr

# Load tokenizer — the directory is expected to contain the tokenizer files
# (vocab / sentencepiece models) exported next to the ONNX model.
tokenizer_path = "./onnx_model"
tokenizer = MarianTokenizer.from_pretrained(tokenizer_path)

# Load ONNX model — presumably a full encoder-decoder graph taking
# input_ids / attention_mask / decoder_input_ids (see translate()); confirm
# against the export script.
onnx_model_path = "./model.onnx"
session = ort.InferenceSession(onnx_model_path)

def translate(text, max_length=512):
    """Greedily translate *text* with the ONNX-exported Marian model.

    Args:
        text: Source-language string (or batch handled by the tokenizer;
            only the first result is returned).
        max_length: Cap on the tokenized input length and on the number of
            decoding steps.

    Returns:
        The translated string for the first sequence in the batch.
    """
    # Tokenize input; truncation keeps the encoder input within max_length.
    inputs = tokenizer(text, return_tensors="np", padding=True, truncation=True, max_length=max_length)
    input_ids = inputs["input_ids"].astype(np.int64)
    attention_mask = inputs["attention_mask"].astype(np.int64)

    # Seed the decoder with the pad token (used here as the decoder start
    # token) and track which sequences have finished.
    batch_size = input_ids.shape[0]
    decoder_input_ids = np.full((batch_size, 1), tokenizer.pad_token_id, dtype=np.int64)
    eos_reached = np.zeros(batch_size, dtype=bool)

    for _ in range(max_length):
        onnx_outputs = session.run(
            None,
            {
                "input_ids": input_ids,
                "attention_mask": attention_mask,
                "decoder_input_ids": decoder_input_ids,
            },
        )

        # Greedy decoding: argmax over the vocabulary at the last position.
        next_token_logits = onnx_outputs[0][:, -1, :]
        next_tokens = np.argmax(next_token_logits, axis=-1).reshape(-1, 1)

        # Fix: sequences that already emitted EOS get pad tokens instead of
        # spurious extra argmax tokens (pads are stripped at decode time).
        next_tokens[eos_reached] = tokenizer.pad_token_id

        # Append new tokens to the growing decoder input.
        decoder_input_ids = np.hstack([decoder_input_ids, next_tokens])

        # Stop once every sequence in the batch has produced EOS.
        eos_reached |= (next_tokens == tokenizer.eos_token_id).flatten()
        if eos_reached.all():
            break

    # skip_special_tokens removes the start pad, EOS, and any pad fill.
    translated_texts = tokenizer.batch_decode(decoder_input_ids, skip_special_tokens=True)
    return translated_texts[0]

# Gradio UI: one text box in, the translated text out, wired to translate().
with gr.Blocks() as interface:
    source_box = gr.Textbox(label="Input Text")
    result_box = gr.Textbox(label="Translation")
    run_button = gr.Button("Translate")
    run_button.click(fn=translate, inputs=[source_box], outputs=[result_box])

if __name__ == "__main__":
    interface.launch()