Frenchizer committed on
Commit
0760540
·
verified ·
1 Parent(s): 80505e2

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +50 -43
app.py CHANGED
@@ -4,58 +4,65 @@ from transformers import MarianTokenizer
4
  import gradio as gr
5
 
6
# Tokenizer assets live alongside the ONNX export in ./onnx_model.
model_path = "./onnx_model"
tokenizer = MarianTokenizer.from_pretrained(model_path)

# Bring up an ONNX Runtime inference session for the exported model.
onnx_model_path = "./model.onnx"
session = ort.InferenceSession(onnx_model_path)
13
 
14
def translate_text(input_texts):
    """Translate text with the MarianMT ONNX model using greedy decoding.

    Args:
        input_texts: A string or list of strings to translate.

    Returns:
        A list of translated strings, one per input sequence.
    """
    # Tokenize input texts (batch processing).
    tokenized_input = tokenizer(
        input_texts, return_tensors="np", padding=True, truncation=True, max_length=512
    )
    input_ids = tokenized_input["input_ids"].astype(np.int64)
    attention_mask = tokenized_input["attention_mask"].astype(np.int64)

    # BUG FIX: the decoder start row must match the encoder batch size;
    # a fixed (1, 1) array breaks any batch with more than one sequence.
    # MarianMT uses pad_token_id as the decoder start token.
    batch_size = input_ids.shape[0]
    decoder_input_ids = np.full((batch_size, 1), tokenizer.pad_token_id, dtype=np.int64)

    # BUG FIX: the ONNX model outputs logits, not token IDs, and one forward
    # pass yields only a single decoding step. Decode autoregressively,
    # taking the argmax token each step (greedy search).
    for _ in range(512):
        ort_outputs = session.run(
            None,
            {
                "input_ids": input_ids,
                "attention_mask": attention_mask,
                "decoder_input_ids": decoder_input_ids,
            },
        )
        # Last-position logits: shape (batch_size, vocab_size).
        next_tokens = np.argmax(ort_outputs[0][:, -1, :], axis=-1)
        decoder_input_ids = np.concatenate(
            [decoder_input_ids, next_tokens[:, None]], axis=-1
        )
        # Stop once every sequence has produced an EOS token.
        if all(tokenizer.eos_token_id in seq for seq in decoder_input_ids):
            break

    # Decode the accumulated tokens back to text.
    translated_texts = tokenizer.batch_decode(decoder_input_ids, skip_special_tokens=True)
    return translated_texts
 
 
50
 
51
# Wire the translator into a minimal text-in/text-out Gradio UI.
ui_config = dict(
    fn=translate_text,
    inputs="text",
    outputs="text",
    title="Frenchizer Translation Model",
    description="Translate text with MarianMT ONNX model and encoding by batches."
)
interface = gr.Interface(**ui_config)

# Serve the interface.
interface.launch()
 
4
  import gradio as gr
5
 
6
# The tokenizer files are stored in the local ./onnx_model folder.
model_path = "./onnx_model"
tokenizer = MarianTokenizer.from_pretrained(model_path)

# Start an ONNX Runtime session over the exported translation model.
onnx_model_path = "./model.onnx"
session = ort.InferenceSession(onnx_model_path)
14
def translate_text(input_texts, max_length=512):
    """Translate text with the MarianMT ONNX model using greedy decoding.

    Args:
        input_texts: A string or list of strings to translate.
        max_length: Maximum tokenized input length and maximum number of
            decoding steps.

    Returns:
        A list of translated strings, one per input sequence.
    """
    # Tokenize the input texts (padding makes the batch rectangular;
    # truncation bounds the encoder input length).
    inputs = tokenizer(
        input_texts, return_tensors="np", padding=True, truncation=True, max_length=max_length
    )
    input_ids = inputs["input_ids"].astype(np.int64)
    attention_mask = inputs["attention_mask"].astype(np.int64)

    # MarianMT uses pad_token_id as the decoder start token.
    batch_size = input_ids.shape[0]
    decoder_input_ids = np.full((batch_size, 1), tokenizer.pad_token_id, dtype=np.int64)
    # Track which sequences have already emitted EOS.
    finished = np.zeros(batch_size, dtype=bool)

    # Generate output tokens iteratively (greedy search).
    for _ in range(max_length):
        ort_outputs = session.run(
            None,
            {
                "input_ids": input_ids,
                "attention_mask": attention_mask,
                "decoder_input_ids": decoder_input_ids,
            },
        )
        # Logits for the last decoder position: (batch_size, vocab_size).
        next_token_logits = ort_outputs[0][:, -1, :]
        next_tokens = np.argmax(next_token_logits, axis=-1)

        # BUG FIX: once a sequence has produced EOS, force pad for it.
        # Otherwise argmax keeps appending real vocabulary tokens after EOS,
        # and batch_decode(skip_special_tokens=True) would include that
        # garbage in the finished translation.
        next_tokens = np.where(finished, tokenizer.pad_token_id, next_tokens)
        finished |= next_tokens == tokenizer.eos_token_id

        decoder_input_ids = np.concatenate(
            [decoder_input_ids, next_tokens[:, None]], axis=-1
        )

        # Stop as soon as every sequence in the batch is finished.
        if finished.all():
            break

    # Decode the accumulated token IDs back to text.
    return tokenizer.batch_decode(decoder_input_ids, skip_special_tokens=True)
52
 
53
  # Gradio interface
54
def gradio_translate(input_texts):
    """Gradio-facing wrapper: forward the textbox contents to translate_text."""
    return translate_text(input_texts)
57
# Build the Gradio UI: one input textbox, one output textbox.
input_box = gr.Textbox(lines=2, placeholder="Enter text to translate...", label="Input Text")
output_box = gr.Textbox(label="Translated Text")
interface = gr.Interface(
    fn=gradio_translate,
    inputs=input_box,
    outputs=output_box,
    title="ONNX English to French Translation",
    description="Translate English text to French using a MarianMT ONNX model.",
)

# Launch the Gradio app.
interface.launch()