lookbzz commited on
Commit
9b38255
·
verified ·
1 Parent(s): 3b48904

Upload folder using huggingface_hub

Browse files
.gitattributes CHANGED
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ rtdetr_msda.axmodel filter=lfs diff=lfs merge=lfs -text
37
+ ssd_horse.jpg filter=lfs diff=lfs merge=lfs -text
axmodel_inference.py ADDED
@@ -0,0 +1,135 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # import onnxruntime
2
+ import axengine as axe
3
+
4
+ CLASS_NAMES = [
5
+ "person", "bicycle", "car", "motorcycle", "airplane", "bus", "train", "truck", "boat", "traffic light",
6
+ "fire hydrant", "stop sign", "parking meter", "bench", "bird", "cat", "dog", "horse", "sheep", "cow",
7
+ "elephant", "bear", "zebra", "giraffe", "backpack", "umbrella", "handbag", "tie", "suitcase", "frisbee",
8
+ "skis", "snowboard", "sports ball", "kite", "baseball bat", "baseball glove", "skateboard", "surfboard",
9
+ "tennis racket", "bottle", "wine glass", "cup", "fork", "knife", "spoon", "bowl", "banana", "apple",
10
+ "sandwich", "orange", "broccoli", "carrot", "hot dog", "pizza", "donut", "cake", "chair", "couch",
11
+ "potted plant", "bed", "dining table", "toilet", "tv", "laptop", "mouse", "remote", "keyboard", "cell phone",
12
+ "microwave", "oven", "toaster", "sink", "refrigerator", "book", "clock", "vase", "scissors", "teddy bear",
13
+ "hair drier", "toothbrush"]
14
+
15
+
16
+ class axmodel_inferencer:
17
+
18
+ def __init__(self, model_path) -> None:
19
+ # self.onnx_model_sess = onnxruntime.InferenceSession(model_path)
20
+ self.onnx_model_sess = axe.InferenceSession(model_path)
21
+ self.output_names = []
22
+ self.input_names = []
23
+ print(model_path)
24
+ for i in range(len(self.onnx_model_sess.get_inputs())):
25
+ self.input_names.append(self.onnx_model_sess.get_inputs()[i].name)
26
+ print(" input:", i,
27
+ self.onnx_model_sess.get_inputs()[i].name,
28
+ self.onnx_model_sess.get_inputs()[i].shape)
29
+
30
+ for i in range(len(self.onnx_model_sess.get_outputs())):
31
+ self.output_names.append(
32
+ self.onnx_model_sess.get_outputs()[i].name)
33
+ print(" output:", i,
34
+ self.onnx_model_sess.get_outputs()[i].name,
35
+ self.onnx_model_sess.get_outputs()[i].shape)
36
+ print("")
37
+
38
+ def get_input_count(self):
39
+ return len(self.input_names)
40
+
41
+ def get_input_shape(self, idx: int):
42
+ return self.onnx_model_sess.get_inputs()[idx].shape
43
+
44
+ def get_input_names(self):
45
+ return self.input_names
46
+
47
+ def get_output_count(self):
48
+ return len(self.output_names)
49
+
50
+ def get_output_shape(self, idx: int):
51
+ return self.onnx_model_sess.get_outputs()[idx].shape
52
+
53
+ def get_output_names(self):
54
+ return self.output_names
55
+
56
+ def inference(self, tensor):
57
+ return self.onnx_model_sess.run(
58
+ self.output_names, input_feed={self.input_names[0]: tensor})
59
+
60
+ def inference_multi_input(self, tensors: list):
61
+ inputs = dict()
62
+ for idx, tensor in enumerate(tensors):
63
+ inputs[self.input_names[idx]] = tensor
64
+ return self.onnx_model_sess.run(input_feed=inputs)
65
+
66
+ def numpy_sigmoid(self,x):
67
+ """
68
+ 用NumPy实现的sigmoid函数
69
+
70
+ 参数:
71
+ x (np.ndarray): 输入数组
72
+
73
+ 返回:
74
+ np.ndarray: 经过sigmoid处理后的数组
75
+ """
76
+ return 1 / (1 + np.exp(-x))
77
+
78
+
79
+
80
+
81
+ if __name__ == "__main__":
82
+ axmodel_model_path = "rtdetr_msda.axmodel"
83
+ test_model = axmodel_inferencer(axmodel_model_path)
84
+
85
+ # import onnxruntime as ort
86
+ from PIL import Image, ImageDraw
87
+ # from torchvision.transforms import ToTensor
88
+ import numpy as np
89
+ # import torch
90
+
91
+ # # print(onnx.helper.printable_graph(mm.graph))
92
+
93
+
94
+ image = Image.open('ssd_horse.jpg').convert('RGB')
95
+ im = image.resize((640, 640))
96
+ im_data = np.array([im])
97
+ print(im_data.shape)
98
+
99
+ pred_logits,pred_boxes = test_model.inference(im_data)
100
+
101
+ pred_logits = np.array(pred_logits)
102
+ pred_boxes = np.array(pred_boxes)
103
+ print(pred_boxes.shape,pred_logits.shape)
104
+
105
+
106
+ # pred_logits = 1/(1+np.exp(-pred_logits))
107
+
108
+ pred_logits = test_model.numpy_sigmoid(pred_logits)
109
+
110
+
111
+ # print(pred["pred_logits"].shape,pred["pred_boxes"].shape)
112
+ # argmax = torch.argmax(pred_logits,2).reshape(-1)
113
+ argmax = np.argmax(pred_logits, axis=2).reshape(-1)
114
+ print(argmax.shape)
115
+
116
+ # pred_logits = pred["pred_logits"]
117
+ # pred_boxes = pred["pred_boxes"]
118
+ draw = ImageDraw.Draw(image)
119
+
120
+ for i,idx in enumerate(argmax):
121
+ score = pred_logits[0,i,idx]
122
+ if score > 0.6:
123
+ print(score,idx)
124
+ bbox = pred_boxes[0,i]
125
+ print(bbox)
126
+ cx,cy,w,h = bbox
127
+ x0 = (cx-0.5*w)*image.width
128
+ y0 = (cy-0.5*h)*image.height
129
+ x1 = (cx+0.5*w)*image.width
130
+ y1 = (cy+0.5*h)*image.height
131
+ draw.rectangle([x0,y0,x1,y1],outline="red")
132
+ draw.text([x0,y0],CLASS_NAMES[idx]+" %.2f"%score)
133
+ image.save("output.jpg")
134
+
135
+
requirements.txt ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ torch==2.0.1
2
+ torchvision==0.15.2
3
+ onnx==1.14.0
4
+ onnxruntime==1.15.1
5
+ pycocotools
6
+ PyYAML
7
+ scipy
8
+ transformers
rtdetr_msda.axmodel ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:1aa8573d79dff26d54eba74c6ac835296a39ef5439b365beb58578d7275b07c3
3
+ size 22428394
rtdetr_r18vd_5x_coco_objects365_from_paddle_opt.onnx ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e5ffbfa35923b2d28b7764f2f8c559e4bf32ba5cbb6826c777fe63dfac632565
3
+ size 81191543
ssd_horse.jpg ADDED

Git LFS Details

  • SHA256: ed22f6b4c8c33e50e391e089ede14e8fa9402c623b09dbcf010e804770698fbb
  • Pointer size: 131 Bytes
  • Size of remote file: 123 kB