dgarrett-synaptics commited on
Commit
65118ce
·
verified ·
1 Parent(s): bc1d16d

Delete tests/test_keras.py

Browse files
Files changed (1) hide show
  1. tests/test_keras.py +0 -210
tests/test_keras.py DELETED
@@ -1,210 +0,0 @@
1
- from numpy import absolute
2
- from torch import rand
3
- from torch.nn.init import uniform_
4
-
5
- from synet.base import askeras
6
-
7
-
8
- BATCH_SIZE = 2
9
- IN_CHANNELS = 5
10
- OUT_CHANNELS = 7
11
- SHAPES = [(i, i) for i in range(4, 8)]
12
- MAX_DIFF = -1
13
- TOLERANCE = 2e-4
14
-
15
-
16
- def diff_arr(out1, out2):
17
- """compare two arrays. Return the max difference."""
18
- if isinstance(out1, (list, tuple)):
19
- assert isinstance(out2, (list, tuple))
20
- return max(diff_arr(o1, o2) for o1, o2 in zip(out1, out2))
21
- assert all(s1 == s2 for s1, s2 in zip(out1.shape, out2.shape)), \
22
- (out1.shape, out2.shape)
23
- return absolute(out1 - out2).max()
24
-
25
-
26
- def t_actv_to_k(actv):
27
- if isinstance(actv, (tuple, list)):
28
- return [t_actv_to_k(a) for a in actv]
29
- if len(actv.shape) == 4:
30
- tp = 0, 2, 3, 1
31
- elif len(actv.shape) == 3:
32
- tp = 0, 2, 1
33
- elif len(actv.shape) == 2:
34
- tp = 0, 1
35
- return actv.detach().numpy().transpose(*tp)
36
-
37
-
38
- def k_to_numpy(actv):
39
- if isinstance(actv, (list, tuple)):
40
- return [k_to_numpy(k) for k in actv]
41
- if hasattr(actv, "numpy"):
42
- return actv.numpy()
43
- return actv
44
-
45
-
46
- def validate_layer(layer, torch_inp, **akwds):
47
- """Given synet layer, test on some torch input activations and
48
- return max error between two output activations
49
-
50
- """
51
- tout = layer(torch_inp[:])
52
- with askeras(imgsz=torch_inp[0].shape[-2:], **akwds):
53
- kout = k_to_numpy(layer(t_actv_to_k(torch_inp)))
54
- if isinstance(tout, dict):
55
- assert len(tout) == len(kout)
56
- return max(diff_arr(t_actv_to_k(tout[key]), kout[key])
57
- for key in tout)
58
- elif isinstance(tout, list):
59
- assert len(tout) == len(kout)
60
- return max(diff_arr(t_actv_to_k(t), k)
61
- for t, k in zip(tout, kout))
62
- return diff_arr(t_actv_to_k(tout), kout)
63
-
64
-
65
- def validate(layer, batch_size=BATCH_SIZE,
66
- in_channels=IN_CHANNELS, shapes=SHAPES, **akwds):
67
- """Run validate_layer on a set of random input shapes. Prints the max
68
- difference between all configurations.
69
-
70
- """
71
- for param in layer.parameters():
72
- uniform_(param, -1)
73
- max_diff = max(validate_layer(layer,
74
- [rand(batch_size, in_channels, *s)*2-1
75
- for s in shape]
76
- if len(shape) and isinstance(shape[0], tuple)
77
- else rand(batch_size, in_channels, *shape
78
- )*2-1,
79
- **akwds)
80
- for shape in shapes)
81
- print("max_diff:", max_diff)
82
- assert max_diff < TOLERANCE
83
-
84
-
85
- def test_conv2d():
86
- from synet.base import Conv2d
87
- print("testing Conv2d")
88
- in_channels = 12
89
- out_channels = 24
90
- for bias in True, False:
91
- for kernel, stride in ((1, 1), (2, 1), (3, 1), (3, 2), (4, 1),
92
- (4, 2), (4, 3)):
93
- for padding in True, False:
94
- for groups in 1, 2, 3:
95
- validate(Conv2d(in_channels, out_channels, kernel,
96
- stride, bias, padding),
97
- in_channels=in_channels)
98
-
99
- def test_dw_conv2d():
100
- from synet.layers import DepthwiseConv2d
101
- print("testing dw Conv2d")
102
- channels = 32
103
- for bias in True, False:
104
- for kernel, stride in ((1, 1), (2, 1), (3, 1), (3, 2), (4, 1),
105
- (4, 2), (4, 3)):
106
- for padding in True, False:
107
- validate(DepthwiseConv2d(channels, kernel,
108
- stride, bias, padding),
109
- in_channels=channels)
110
-
111
-
112
- def test_convtranspose():
113
- from synet.base import ConvTranspose2d
114
- validate(ConvTranspose2d(IN_CHANNELS, OUT_CHANNELS, 2, 2, 0, bias=True))
115
-
116
-
117
- def test_relu():
118
- from synet.base import ReLU
119
- validate(ReLU(.6))
120
-
121
-
122
- def test_upsample():
123
- from synet.base import Upsample
124
- for scale_factor in 1, 2, 3:
125
- for mode in Upsample.allowed_modes:
126
- validate(Upsample(scale_factor, mode))
127
-
128
-
129
- def test_globavgpool():
130
- from synet.base import GlobalAvgPool
131
- validate(GlobalAvgPool())
132
-
133
-
134
- def test_dropout():
135
- from synet.base import Dropout
136
- for p in 0.0, 0.5, 1.0:
137
- for inplace in True, False:
138
- layer = Dropout(p, inplace=inplace)
139
- layer.eval()
140
- validate(layer)
141
-
142
-
143
- def test_linear():
144
- from synet.base import Linear
145
- for bias in True, False:
146
- validate(Linear(IN_CHANNELS, OUT_CHANNELS, bias), shapes=[()])
147
-
148
-
149
- def test_batchnorm():
150
- from synet.base import BatchNorm
151
- validate(BatchNorm(IN_CHANNELS), train=True)
152
-
153
-
154
- def test_ultralytics_detect():
155
- from synet.backends.ultralytics import Detect
156
- for sm_split in ((True, None), (2, True)):
157
- layer = Detect(80, (IN_CHANNELS, IN_CHANNELS), *sm_split)
158
- layer.eval()
159
- layer.export = True
160
- layer.format = "tflite"
161
- layer.stride[0], layer.stride[1] = 1, 2
162
- validate(layer,
163
- shapes=[((4, 6), (2, 3)),
164
- ((5, 7), (3, 4)),
165
- ((6, 8), (3, 4))],
166
- xywh=True)
167
-
168
-
169
- def test_ultralytics_pose():
170
- from synet.backends.ultralytics import Pose
171
- for sm_split in ((True, None), (2, True)):
172
- for kpt_shape in ([17, 2], [17, 3]):
173
- layer = Pose(80, kpt_shape, (IN_CHANNELS, IN_CHANNELS), *sm_split)
174
- layer.eval()
175
- layer.export = True
176
- layer.format = "tflite"
177
- layer.stride[0], layer.stride[1] = 1, 2
178
- validate(layer,
179
- shapes=[((4, 6), (2, 3)),
180
- ((5, 7), (3, 4)),
181
- ((6, 8), (3, 4))],
182
- xywh=True)
183
-
184
-
185
- def test_ultralytics_segment():
186
- from synet.backends.ultralytics import Segment
187
- layer = Segment(nc=80, nm=32, npr=256, ch=(IN_CHANNELS, IN_CHANNELS))
188
- layer.eval()
189
- layer.export = True
190
- layer.format = "tflite"
191
- layer.stride[0], layer.stride[1] = 1, 2
192
- validate(layer,
193
- shapes=[(( 4, 4), (2, 2)),
194
- (( 8, 8), (4, 4)),
195
- ((12, 12), (6, 6))],
196
- xywh=True)
197
-
198
-
199
- def test_ultralytics_classify():
200
- from synet.backends.ultralytics import Classify
201
- layer = Classify(None, c1=IN_CHANNELS, c2=OUT_CHANNELS)
202
- layer.eval()
203
- layer.export = True
204
- layer.format = 'tflite'
205
- validate(layer)
206
-
207
-
208
- def test_channelslice():
209
- from synet.base import ChannelSlice
210
- validate(ChannelSlice(slice(4, 8)), in_channels=12)