lookbzz commited on
Commit
c99dcd5
·
verified ·
1 Parent(s): f55a4f2

Upload folder using huggingface_hub

Browse files
.gitattributes CHANGED
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ axmodel/backbone_encoder.axmodel filter=lfs diff=lfs merge=lfs -text
37
+ axmodel/decoder.axmodel filter=lfs diff=lfs merge=lfs -text
README.md CHANGED
@@ -1,3 +1,88 @@
1
- ---
2
- license: bsd-3-clause-clear
3
- ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # satrn
2
+
3
+ [original repo](https://github.com/open-mmlab/mmocr/blob/main/configs/textrecog/satrn/README.md)
4
+
5
+ ## Convert tools links:
6
+
7
+ For those interested in model conversion, you can export the ONNX or axmodel files through
8
+
9
+ [satrn.axera](https://github.com/AXERA-TECH/satrn.axera)
10
+
11
+
12
+ ## Support Platform
13
+
14
+ - AX650
15
+ - [M4N-Dock(爱芯派Pro)](https://wiki.sipeed.com/hardware/zh/maixIV/m4ndock/m4ndock.html)
16
+ - [M.2 Accelerator card](https://axcl-docs.readthedocs.io/zh-cn/latest/doc_guide_hardware.html)
17
+
18
+
19
+ Speed measurements (under different NPU configurations) of the two parts of SATRN:
20
+
21
+ (1) backbone+encoder
22
+
23
+ (2) decoder
24
+
25
+
26
+ ||backbone+encoder(ms)|decoder(ms)|
27
+ |--|--|--|
28
+ |NPU1|20.494|2.648|
29
+ |NPU2|9.785|1.504|
30
+ |NPU3|6.085|1.384|
31
+
32
+ ## How to use
33
+
34
+ Download all files from this repository to the device
35
+
36
+ ```
37
+ .
38
+ ├── axmodel
39
+ │ ├── backbone_encoder.axmodel
40
+ │ └── decoder.axmodel
41
+ ├── demo_text_recog.jpg
42
+ ├── onnx
43
+ │ ├── satrn_backbone_encoder.onnx
44
+ │ └── satrn_decoder_sim.onnx
45
+ ├── README.md
46
+ ├── run_axmodel.py
47
+ ├── run_model.py
48
+ └── run_onnx.py
49
+ ```
50
+
51
+ ### python env requirement
52
+
53
+ #### 1. pyaxengine
54
+
55
+ https://github.com/AXERA-TECH/pyaxengine
56
+
57
+ ```
58
+ wget https://github.com/AXERA-TECH/pyaxengine/releases/download/0.1.1rc0/axengine-0.1.1-py3-none-any.whl
59
+ pip install axengine-0.1.1-py3-none-any.whl
60
+ ```
61
+
62
+ #### 2. satrn
63
+
64
+ [satrn installation](https://github.com/open-mmlab/mmocr/tree/main?tab=readme-ov-file#installation)
65
+
66
+ #### Inference with the ONNX model
67
+ ```
68
+ python run_onnx.py
69
+ ```
70
+ input:
71
+
72
+ ![](demo_text_recog.jpg)
73
+
74
+ output:
75
+
76
+ ```
77
+ pred_text: STAR
78
+ score: [0.9384028315544128, 0.9574984908103943, 0.9993689656257629, 0.9994958639144897]
79
+ ```
80
+
81
+ #### Inference with AX650 Host
82
+
83
+ ```
84
+ python run_axmodel.py
85
+ ```
86
+
87
+
88
+
axmodel/backbone_encoder.axmodel ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ca1bc3964ad5b7d57a2c5b08b0ca53619127501aed402f02829a53c26b021756
3
+ size 47589096
axmodel/decoder.axmodel ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:1914ff3a36e5e2d9d2e6174bdbb8e5c369374e3c8420e22c445771ab1406347e
3
+ size 27697793
demo_text_recog.jpg ADDED
onnx/satrn_backbone_encoder.onnx ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:66870bd86006213fcb0e5db1d5b0e376d6a4f30c0e20e9f34cda66d3c259f39c
3
+ size 161383339
onnx/satrn_decoder_sim.onnx ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d417bc9618b68e4f691d3f9571c93f9d101b30da92e47924d4e25ad3e37f8198
3
+ size 101341850
run_axmodel.py ADDED
@@ -0,0 +1,154 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from mmocr.apis import MMOCRInferencer
2
+ from mmocr.apis.inferencers.base_mmocr_inferencer import BaseMMOCRInferencer
3
+ import torch
4
+ from rich.progress import track
5
+ import torch.nn as nn
6
+ import axengine as axe
7
+ import numpy as np
8
+
9
# AXera runtime sessions for the two exported halves of the SATRN model.
# The repository ships these files under axmodel/ (see the README file
# tree); the original top-level file names do not exist in the repo.
onnx_bb_encoder = axe.InferenceSession("axmodel/backbone_encoder.axmodel")
onnx_decoder = axe.InferenceSession("axmodel/decoder.axmodel")
11
+
12
+
13
class BackboneEncoderOnly(nn.Module):
    """Wrapper exposing only the backbone + encoder half of a SATRN model.

    Mirrors the exported ``backbone_encoder`` model: image tensor in,
    encoder features out.
    """

    def __init__(self, original_model):
        super().__init__()
        # Keep references to just the backbone and encoder sub-modules.
        self.backbone = original_model.backbone
        self.encoder = original_model.encoder

    def forward(self, x):
        # Feature extraction followed by transformer encoding.
        return self.encoder(self.backbone(x))
23
+
24
+
25
class DecoderOnly(nn.Module):
    """Single-step SATRN decoder extracted from the full model.

    Re-uses the sub-modules of ``original_model.decoder``; ``forward``
    returns the classifier logits for one decoding position, so the
    caller drives the autoregressive loop.
    """

    def __init__(self, original_model):
        super().__init__()
        decoder = original_model.decoder
        # Borrow the trained decoder sub-modules as-is.
        self.classifier = decoder.classifier
        self.trg_word_emb = decoder.trg_word_emb
        self.position_enc = decoder.position_enc
        self._get_target_mask = decoder._get_target_mask
        self.dropout = decoder.dropout
        self.layer_stack = decoder.layer_stack
        self.layer_norm = decoder.layer_norm
        self._get_source_mask = decoder._get_source_mask
        self.postprocessor = decoder.postprocessor
        # Special token ids / sequence length, hard-coded to the exported
        # model.  NOTE(review): confirm these match the checkpoint's
        # dictionary if the model is retrained.
        self.start_idx = 90
        self.padding_idx = 91
        self.max_seq_len = 25
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, trg_seq, src, src_mask, step):
        """Return classifier logits for position ``step`` of ``trg_seq``."""
        self_mask = self._get_target_mask(trg_seq)
        embedded = self.position_enc(self.trg_word_emb(trg_seq))
        hidden = self.dropout(embedded)

        for layer in self.layer_stack:
            hidden = layer(
                hidden,
                src,
                self_attn_mask=self_mask,
                dec_enc_attn_mask=src_mask,
            )
        hidden = self.layer_norm(hidden)  # bsz * seq_len * C
        return self.classifier(hidden[:, step, :])
65
+
66
+
67
+
68
def normalize_tensor(tensor: torch.Tensor) -> torch.Tensor:
    """Normalize a uint8 image tensor with per-channel mean/std.

    Args:
        tensor: Image tensor of shape ``(3, 32, 100)`` with dtype
            ``uint8`` (RGB channel order).

    Returns:
        ``float32`` tensor of the same shape, normalized per channel as
        ``(x - mean) / std``.

    Raises:
        ValueError: If the input shape is not ``(3, 32, 100)``.
        TypeError: If the input dtype is not ``torch.uint8``.
    """
    # Validate with real exceptions instead of `assert`, which is
    # silently stripped when Python runs with -O.
    if tuple(tensor.shape) != (3, 32, 100):
        raise ValueError("input tensor must have shape (3, 32, 100)")
    if tensor.dtype != torch.uint8:
        raise TypeError("input tensor dtype must be uint8")

    tensor = tensor.float()

    # Per-channel normalization constants in RGB order, 0-255 scale
    # (presumably the mmocr SATRN preprocessing stats — keep in sync
    # with the exported model's config).
    mean = torch.tensor([123.675, 116.28, 103.53], dtype=torch.float32).view(3, 1, 1)
    std = torch.tensor([58.395, 57.12, 57.375], dtype=torch.float32).view(3, 1, 1)

    return (tensor - mean) / std
91
+
92
+
93
# Build the torch SATRN model (used here only for the decoder helper
# modules and the postprocessor); the actual compute goes through the
# axengine sessions loaded above.
infer = MMOCRInferencer(rec='satrn')
model = infer.textrec_inferencer.model
model.eval()
model.cpu()

# The demo image ships at the repository root (see README file tree),
# not under mmor_demo/demo/.
input_path = 'demo_text_recog.jpg'
ori_inputs = infer._inputs_to_list([input_path])
base = BaseMMOCRInferencer(model='satrn')
chunked_inputs = base._get_chunk_data(ori_inputs, 1)
for ori_inputs in track(chunked_inputs, description='Inference'):
    # `item` instead of `input` to avoid shadowing the builtin.
    item = ori_inputs[0][1]
    input_img = item['inputs']
    input_image = normalize_tensor(input_img).unsqueeze(0)
    input_sample = item['data_samples']

    # Split wrappers kept for parity with run_model.py; inference below
    # uses the NPU sessions instead of these modules' forward passes.
    model_backbone_encoder = BackboneEncoderOnly(model)
    model_decoder = DecoderOnly(model)

    # backbone + encoder on the NPU
    out_enc = onnx_bb_encoder.run(None, {"input": np.array(input_image.cpu())})[0]
    out_enc = torch.tensor(out_enc)
    data_samples = None

    N = out_enc.size(0)
    # Autoregressive buffer: <start> token then padding; every decoded
    # step writes its argmax into the following slot.
    init_target_seq = torch.full((N, model_decoder.max_seq_len + 1),
                                 model_decoder.padding_idx,
                                 device=out_enc.device,
                                 dtype=torch.long)
    init_target_seq[:, 0] = model_decoder.start_idx

    # valid_ratios and src_mask do not depend on the step, so compute
    # them once outside the decoding loop.
    valid_ratios = [1.0 for _ in range(out_enc.size(0))]
    if data_samples is not None:
        valid_ratios = [ds.get('valid_ratio') for ds in data_samples]
    src_mask = model_decoder._get_source_mask(out_enc, valid_ratios)

    outputs = []
    for step in range(0, model_decoder.max_seq_len):
        # One decoder step on the NPU; [0][0] takes the first output
        # tensor, then its first row.
        step_result = onnx_decoder.run(None, {'init_target_seq': np.array(init_target_seq),
                                              'out_enc': np.array(out_enc),
                                              'src_mask': np.array(src_mask),
                                              'step': np.array([step])})[0][0]
        step_result = torch.tensor(step_result)
        outputs.append(step_result)
        _, step_max_index = torch.max(step_result, dim=-1)
        init_target_seq[:, step + 1] = step_max_index
    outputs = torch.stack(outputs, dim=1)
    out_dec = model_decoder.softmax(outputs)

    # Decode logits into text and per-character scores.
    output = model_decoder.postprocessor(out_dec, [input_sample])
    outstr = output[0].pred_text.item
    outscore = output[0].pred_text.score

    print('pred_text:', outstr)
    print('score:', outscore)
150
+
151
+
152
+
153
+
154
+
run_model.py ADDED
@@ -0,0 +1,140 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from mmocr.apis import MMOCRInferencer
2
+ from mmocr.apis.inferencers.base_mmocr_inferencer import BaseMMOCRInferencer
3
+ import torch
4
+ from rich.progress import track
5
+ import torch.nn as nn
6
+
7
class BackboneEncoderOnly(nn.Module):
    """Sub-model containing only the SATRN backbone and encoder."""

    def __init__(self, original_model):
        super().__init__()
        # Keep only the backbone and encoder sub-modules.
        self.backbone = original_model.backbone
        self.encoder = original_model.encoder

    def forward(self, x):
        # Backbone feature extraction, then transformer encoding.
        x = self.backbone(x)
        return self.encoder(x)
17
+
18
+
19
class DecoderOnly(nn.Module):
    """Sub-model wrapping the SATRN decoder for single-step decoding.

    ``forward`` returns the classifier logits for one target position;
    the caller drives the autoregressive loop.
    """

    def __init__(self, original_model):
        super().__init__()
        # Re-use the trained decoder sub-modules.
        original_decoder = original_model.decoder
        # self._attention = original_decoder._attention
        self.classifier = original_decoder.classifier
        self.trg_word_emb = original_decoder.trg_word_emb
        self.position_enc = original_decoder.position_enc
        self._get_target_mask = original_decoder._get_target_mask
        self.dropout = original_decoder.dropout
        self.layer_stack = original_decoder.layer_stack
        self.layer_norm = original_decoder.layer_norm
        self._get_source_mask = original_decoder._get_source_mask
        self.postprocessor = original_decoder.postprocessor
        # Special token ids / max length, hard-coded to the exported
        # SATRN dictionary.  NOTE(review): confirm against the checkpoint.
        self.start_idx = 90
        self.padding_idx = 91
        self.max_seq_len = 25
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, trg_seq,src,src_mask,step):
        """Return classifier logits for position ``step`` of ``trg_seq``."""
        # decoder_output = self._attention(init_target_seq, out_enc, src_mask=src_mask)
        trg_embedding = self.trg_word_emb(trg_seq)
        trg_pos_encoded = self.position_enc(trg_embedding)
        trg_mask = self._get_target_mask(trg_seq)
        tgt_seq = self.dropout(trg_pos_encoded)

        output = tgt_seq
        for dec_layer in self.layer_stack:
            output = dec_layer(
                output,
                src,
                self_attn_mask=trg_mask,
                dec_enc_attn_mask=src_mask)
        output = self.layer_norm(output)
        # bsz * seq_len * C
        step_result = self.classifier(output[:, step, :])
        return step_result
59
+
60
+
61
def normalize_tensor(tensor: torch.Tensor) -> torch.Tensor:
    """Normalize a uint8 image tensor with per-channel mean/std.

    Args:
        tensor: Image tensor of shape ``(3, 32, 100)`` with dtype
            ``uint8`` (RGB channel order).

    Returns:
        ``float32`` tensor of the same shape, normalized per channel as
        ``(x - mean) / std``.

    Raises:
        ValueError: If the input shape is not ``(3, 32, 100)``.
        TypeError: If the input dtype is not ``torch.uint8``.
    """
    # Validate with real exceptions instead of `assert`, which is
    # silently stripped when Python runs with -O.
    if tuple(tensor.shape) != (3, 32, 100):
        raise ValueError("input tensor must have shape (3, 32, 100)")
    if tensor.dtype != torch.uint8:
        raise TypeError("input tensor dtype must be uint8")

    tensor = tensor.float()

    # Per-channel normalization constants in RGB order, 0-255 scale
    # (presumably the mmocr SATRN preprocessing stats — keep in sync
    # with the exported model's config).
    mean = torch.tensor([123.675, 116.28, 103.53], dtype=torch.float32).view(3, 1, 1)
    std = torch.tensor([58.395, 57.12, 57.375], dtype=torch.float32).view(3, 1, 1)

    return (tensor - mean) / std
84
+
85
+
86
# Build the torch SATRN model and run the full pipeline in plain PyTorch
# (reference implementation for the onnx/axmodel variants).
infer = MMOCRInferencer(rec='satrn')
model = infer.textrec_inferencer.model
model.eval()
model.cpu()

# The demo image ships at the repository root (see README file tree),
# not under mmor_demo/demo/.
input_path = 'demo_text_recog.jpg'
ori_inputs = infer._inputs_to_list([input_path])
base = BaseMMOCRInferencer(model='satrn')
chunked_inputs = base._get_chunk_data(ori_inputs, 1)
for ori_inputs in track(chunked_inputs, description='Inference'):
    # `item` instead of `input` to avoid shadowing the builtin.
    item = ori_inputs[0][1]
    input_img = item['inputs']
    input_image = normalize_tensor(input_img).unsqueeze(0)
    input_sample = item['data_samples']

    # Split the model into the two exported parts.
    model_backbone_encoder = BackboneEncoderOnly(model)
    model_decoder = DecoderOnly(model)

    with torch.no_grad():  # inference only — skip autograd bookkeeping
        out_enc = model_backbone_encoder(input_image)
        data_samples = None

        N = out_enc.size(0)
        # Autoregressive buffer: <start> token then padding; every
        # decoded step writes its argmax into the following slot.
        init_target_seq = torch.full((N, model_decoder.max_seq_len + 1),
                                     model_decoder.padding_idx,
                                     device=out_enc.device,
                                     dtype=torch.long)
        init_target_seq[:, 0] = model_decoder.start_idx

        # valid_ratios and src_mask do not depend on the step, so
        # compute them once outside the decoding loop.
        valid_ratios = [1.0 for _ in range(out_enc.size(0))]
        if data_samples is not None:
            valid_ratios = [ds.get('valid_ratio') for ds in data_samples]
        src_mask = model_decoder._get_source_mask(out_enc, valid_ratios)

        outputs = []
        for step in range(0, model_decoder.max_seq_len):
            step_result = model_decoder(init_target_seq, out_enc, src_mask, step)
            outputs.append(step_result)
            _, step_max_index = torch.max(step_result, dim=-1)
            init_target_seq[:, step + 1] = step_max_index
        outputs = torch.stack(outputs, dim=1)
        out_dec = model_decoder.softmax(outputs)

    # Decode logits into text and per-character scores.
    output = model_decoder.postprocessor(out_dec, [input_sample])
    outstr = output[0].pred_text.item
    outscore = output[0].pred_text.score

    print('pred_text:', outstr)
    print('score:', outscore)
136
+
137
+
138
+
139
+
140
+
run_onnx.py ADDED
@@ -0,0 +1,154 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from mmocr.apis import MMOCRInferencer
2
+ from mmocr.apis.inferencers.base_mmocr_inferencer import BaseMMOCRInferencer
3
+ import torch
4
+ from rich.progress import track
5
+ import torch.nn as nn
6
+ import onnxruntime as ort
7
+ import numpy as np
8
+
9
# ONNX Runtime sessions for the two exported halves of the SATRN model;
# paths match the repository's onnx/ directory (see README file tree).
onnx_bb_encoder = ort.InferenceSession("onnx/satrn_backbone_encoder.onnx")
onnx_decoder = ort.InferenceSession("onnx/satrn_decoder_sim.onnx")
11
+
12
+
13
class BackboneEncoderOnly(nn.Module):
    """Sub-model containing only the SATRN backbone and encoder."""

    def __init__(self, original_model):
        super().__init__()
        # Keep only the backbone and encoder sub-modules.
        self.backbone = original_model.backbone
        self.encoder = original_model.encoder

    def forward(self, x):
        # Backbone feature extraction, then transformer encoding.
        x = self.backbone(x)
        return self.encoder(x)
23
+
24
+
25
class DecoderOnly(nn.Module):
    """Sub-model wrapping the SATRN decoder for single-step decoding.

    ``forward`` returns the classifier logits for one target position;
    the caller drives the autoregressive loop.
    """

    def __init__(self, original_model):
        super().__init__()
        # Re-use the trained decoder sub-modules.
        original_decoder = original_model.decoder
        # self._attention = original_decoder._attention
        self.classifier = original_decoder.classifier
        self.trg_word_emb = original_decoder.trg_word_emb
        self.position_enc = original_decoder.position_enc
        self._get_target_mask = original_decoder._get_target_mask
        self.dropout = original_decoder.dropout
        self.layer_stack = original_decoder.layer_stack
        self.layer_norm = original_decoder.layer_norm
        self._get_source_mask = original_decoder._get_source_mask
        self.postprocessor = original_decoder.postprocessor
        # Special token ids / max length, hard-coded to the exported
        # SATRN dictionary.  NOTE(review): confirm against the checkpoint.
        self.start_idx = 90
        self.padding_idx = 91
        self.max_seq_len = 25
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, trg_seq,src,src_mask,step):
        """Return classifier logits for position ``step`` of ``trg_seq``."""
        # decoder_output = self._attention(init_target_seq, out_enc, src_mask=src_mask)
        trg_embedding = self.trg_word_emb(trg_seq)
        trg_pos_encoded = self.position_enc(trg_embedding)
        trg_mask = self._get_target_mask(trg_seq)
        tgt_seq = self.dropout(trg_pos_encoded)

        output = tgt_seq
        for dec_layer in self.layer_stack:
            output = dec_layer(
                output,
                src,
                self_attn_mask=trg_mask,
                dec_enc_attn_mask=src_mask)
        output = self.layer_norm(output)
        # bsz * seq_len * C
        step_result = self.classifier(output[:, step, :])
        return step_result
65
+
66
+
67
+
68
def normalize_tensor(tensor: torch.Tensor) -> torch.Tensor:
    """Normalize a uint8 image tensor with per-channel mean/std.

    Args:
        tensor: Image tensor of shape ``(3, 32, 100)`` with dtype
            ``uint8`` (RGB channel order).

    Returns:
        ``float32`` tensor of the same shape, normalized per channel as
        ``(x - mean) / std``.

    Raises:
        ValueError: If the input shape is not ``(3, 32, 100)``.
        TypeError: If the input dtype is not ``torch.uint8``.
    """
    # Validate with real exceptions instead of `assert`, which is
    # silently stripped when Python runs with -O.
    if tuple(tensor.shape) != (3, 32, 100):
        raise ValueError("input tensor must have shape (3, 32, 100)")
    if tensor.dtype != torch.uint8:
        raise TypeError("input tensor dtype must be uint8")

    tensor = tensor.float()

    # Per-channel normalization constants in RGB order, 0-255 scale
    # (presumably the mmocr SATRN preprocessing stats — keep in sync
    # with the exported model's config).
    mean = torch.tensor([123.675, 116.28, 103.53], dtype=torch.float32).view(3, 1, 1)
    std = torch.tensor([58.395, 57.12, 57.375], dtype=torch.float32).view(3, 1, 1)

    return (tensor - mean) / std
91
+
92
+
93
# Build the torch SATRN model (used only for the decoder helper modules
# and the postprocessor); the actual compute runs on ONNX Runtime.
infer = MMOCRInferencer(rec='satrn')
model = infer.textrec_inferencer.model
model.eval()
model.cpu()
input_path = 'demo_text_recog.jpg'
ori_inputs = infer._inputs_to_list([input_path])
base = BaseMMOCRInferencer(model='satrn')
chunked_inputs = base._get_chunk_data(ori_inputs, 1)
for ori_inputs in track(chunked_inputs, description='Inference'):
    # `item` instead of `input` to avoid shadowing the builtin.
    item = ori_inputs[0][1]
    input_img = item['inputs']
    input_image = normalize_tensor(input_img).unsqueeze(0)
    input_sample = item['data_samples']

    # Split wrappers kept for parity with run_model.py; inference below
    # goes through the ONNX sessions instead.
    model_backbone_encoder = BackboneEncoderOnly(model)
    model_decoder = DecoderOnly(model)

    # backbone + encoder
    out_enc = onnx_bb_encoder.run(None, {"input": np.array(input_image.cpu())})[0]
    out_enc = torch.tensor(out_enc)
    data_samples = None

    N = out_enc.size(0)
    # Autoregressive buffer: <start> token then padding; every decoded
    # step writes its argmax into the following slot.
    init_target_seq = torch.full((N, model_decoder.max_seq_len + 1),
                                 model_decoder.padding_idx,
                                 device=out_enc.device,
                                 dtype=torch.long)
    init_target_seq[:, 0] = model_decoder.start_idx

    # valid_ratios and src_mask do not depend on the step, so compute
    # them once outside the decoding loop.
    valid_ratios = [1.0 for _ in range(out_enc.size(0))]
    if data_samples is not None:
        valid_ratios = [ds.get('valid_ratio') for ds in data_samples]
    src_mask = model_decoder._get_source_mask(out_enc, valid_ratios)

    outputs = []
    for step in range(0, model_decoder.max_seq_len):
        # One decoder step on ONNX Runtime; [0][0] takes the first
        # output tensor, then its first row.
        step_result = onnx_decoder.run(None, {'init_target_seq': np.array(init_target_seq),
                                              'out_enc': np.array(out_enc),
                                              'src_mask': np.array(src_mask),
                                              'step': np.array([step])})[0][0]
        step_result = torch.tensor(step_result)
        outputs.append(step_result)
        _, step_max_index = torch.max(step_result, dim=-1)
        init_target_seq[:, step + 1] = step_max_index
    outputs = torch.stack(outputs, dim=1)
    out_dec = model_decoder.softmax(outputs)

    # Decode logits into text and per-character scores.
    output = model_decoder.postprocessor(out_dec, [input_sample])
    outstr = output[0].pred_text.item
    outscore = output[0].pred_text.score

    print('pred_text:', outstr)
    print('score:', outscore)
150
+
151
+
152
+
153
+
154
+