b3x0m committed on
Commit
7f8af52
·
verified ·
1 Parent(s): f2feb8e

Create onnx_infer.py

Browse files
Files changed (1) hide show
  1. onnx_infer.py +115 -0
onnx_infer.py ADDED
@@ -0,0 +1,115 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import time
2
+ import torch
3
+ import numpy as np
4
+ import onnxruntime as ort
5
+ from transformers import AutoTokenizer
6
+
7
# Torch device handle; nothing else in this section reads it, but it is kept
# so that any code importing it from this module continues to work.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
tokenizer = AutoTokenizer.from_pretrained("b3x0m/xomdich-tokenizer")

onnx_model_path = "hyper-xomdich.onnx"
so = ort.SessionOptions()
so.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
so.enable_mem_pattern = True
so.enable_cpu_mem_arena = True

# ONNX Runtime reads ``providers`` in decreasing order of priority, so the GPU
# provider must come first or it is never preferred over the CPU one. Only
# providers reported by this onnxruntime build are requested; the CPU provider
# is the fallback if the filtered list would otherwise be empty.
_preferred_providers = ['CUDAExecutionProvider', 'CPUExecutionProvider']
_available_providers = set(ort.get_available_providers())
_providers = [
    name for name in _preferred_providers if name in _available_providers
] or ['CPUExecutionProvider']
session = ort.InferenceSession(onnx_model_path, so, providers=_providers)
16
+
17
def prepare_input(text, tokenizer, max_length=512):
    """Tokenize ``text`` and build the feed dict for the first decoding step.

    A '。' is appended when ``text`` contains none, and the literal ``</s>``
    is appended before the string is handed to ``tokenizer``.

    Args:
        text: Source sentence to translate.
        tokenizer: Callable tokenizer that returns ``input_ids`` and
            ``attention_mask`` tensors and exposes ``bos_token_id``.
        max_length: Truncation length forwarded to the tokenizer.

    Returns:
        A tuple ``(inputs, decoder_input, src_len)`` where ``inputs`` maps the
        five model input names to numpy arrays, ``decoder_input`` is the
        ``[[bos]]`` tensor and ``src_len`` is the size of dimension 1 of
        ``input_ids``.
    """
    if '。' not in text:
        text = text + '。'

    encoded = tokenizer(
        f"{text}</s>",
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=max_length,
        add_special_tokens=True,
    )

    decoder_input = torch.tensor([[tokenizer.bos_token_id]])
    source_ids = encoded["input_ids"]
    src_len = source_ids.size(1)
    tgt_len = decoder_input.size(1)

    # Two singleton axes are inserted so the padding mask can be expanded to
    # the 4-D (.., .., query, key) layout used by both masks below.
    # Assumes the tokenizer returns a 2-D attention mask -- TODO confirm.
    padding = encoded["attention_mask"].unsqueeze(1).unsqueeze(1)

    # Additive masks: 0 where attention is allowed, -10000 where it is not.
    encoder_bias = (1.0 - padding.expand(-1, -1, src_len, src_len)) * -10000.0
    cross_bias = (1.0 - padding.expand(-1, -1, tgt_len, src_len)) * -10000.0

    # Strict upper triangle marks the positions after the current one.
    future = torch.triu(torch.ones((tgt_len, tgt_len)), diagonal=1).bool()
    causal_bias = future.unsqueeze(0).unsqueeze(0).float() * -10000.0

    feed = {
        "input_ids": source_ids.numpy().astype(np.int64),
        "attention_mask": encoder_bias.numpy().astype(np.float32),
        "decoder_input_ids": decoder_input.numpy().astype(np.int64),
        "decoder_attention_mask": causal_bias.numpy().astype(np.float32),
        "cross_attention_mask": cross_bias.numpy().astype(np.float32),
    }
    return feed, decoder_input, src_len
52
+
53
def stable_softmax(logits):
    """Return the softmax of ``logits`` taken over the last axis.

    The per-row maximum is subtracted before exponentiating so that large
    logits do not overflow; the result is unchanged by that shift.
    """
    row_max = np.max(logits, axis=-1, keepdims=True)
    weights = np.exp(logits - row_max)
    normaliser = np.sum(weights, axis=-1, keepdims=True)
    return weights / normaliser
58
+
59
def generate_translation_onnx(text, session=session, tokenizer=tokenizer, max_length=512, temperature=0.1):
    """Translate ``text`` by autoregressive decoding with the ONNX session.

    The whole model is re-run once per generated token with the decoder input
    grown by one position, until the EOS token is produced or ``max_length``
    steps have been taken.

    Args:
        text: Source sentence to translate.
        session: ONNX Runtime session exposing a ``logits`` output.
        tokenizer: Tokenizer providing ``bos_token_id``, ``eos_token_id`` and
            ``decode``.
        max_length: Tokenizer truncation length and maximum number of
            decoding steps.
        temperature: Sampling temperature. A value of ``None`` or ``<= 0``
            selects greedy (argmax) decoding instead of dividing by zero.

    Returns:
        A tuple ``(translation, performance_info)``: the decoded string with
        special tokens skipped, and a human-readable timing summary.
    """
    inputs, _, _ = prepare_input(text, tokenizer, max_length)
    generated_tokens = [tokenizer.bos_token_id]

    # The first query row of the initial cross-attention mask holds the source
    # padding pattern; it is repeated for every target position on later steps
    # so that masking stays consistent with what prepare_input produced.
    source_mask_row = inputs["cross_attention_mask"][:, :, :1, :]
    use_sampling = temperature is not None and temperature > 0

    start_time = time.perf_counter()
    for _ in range(max_length):
        logits = session.run(["logits"], inputs)[0]
        # Assumes logits are (batch, target_position, vocab) -- the code only
        # ever reads the last target position of the first batch entry.
        next_token_logits = logits[:, -1, :]

        if use_sampling:
            # float64 keeps the probabilities summing to 1 within the
            # tolerance np.random.choice checks for.
            probs = stable_softmax(next_token_logits.astype(np.float64) / temperature)
            next_token = int(np.random.choice(len(probs[0]), p=probs[0]))
        else:
            next_token = int(np.argmax(next_token_logits[0]))

        generated_tokens.append(next_token)
        if next_token == tokenizer.eos_token_id:
            break

        tgt_len = len(generated_tokens)
        # Additive causal mask: -10000 strictly above the diagonal, 0 elsewhere.
        causal = np.triu(
            np.full((tgt_len, tgt_len), -10000.0, dtype=np.float32), k=1
        )

        inputs["decoder_input_ids"] = np.array([generated_tokens], dtype=np.int64)
        inputs["decoder_attention_mask"] = causal[None, None, :, :]
        inputs["cross_attention_mask"] = np.repeat(source_mask_row, tgt_len, axis=2)

    duration = time.perf_counter() - start_time
    speed = len(text) / duration if duration > 0 else 0

    performance_info = f"Time: {duration:.2f}s | Speed: {speed:.2f} chars/s"

    return tokenizer.decode(generated_tokens, skip_special_tokens=True), performance_info
102
+
103
def interactive_translation():
    """Prompt for text on stdin and print each translation until the user quits.

    The loop ends when the user enters ``q`` (case-insensitive, surrounding
    whitespace ignored), or when the prompt is interrupted by end-of-file or
    Ctrl-C, which previously surfaced as an unhandled traceback.
    """
    while True:
        try:
            text = input("Text input (press q to exit): ")
        except (EOFError, KeyboardInterrupt):
            # Finish the prompt line so the shell prompt starts cleanly.
            print()
            break

        if text.strip().lower() == 'q':
            break

        translation, performance = generate_translation_onnx(text, session, tokenizer)
        print(f"Translation: {translation}")
        print(f"{performance}")


if __name__ == "__main__":
    interactive_translation()