Commit
·
8491e38
1
Parent(s):
a656d3e
Adding the main SyNET branches
Browse files- requirements.txt +9 -0
- synet/__init__.py +15 -0
- synet/__main__.py +14 -0
- synet/asymmetric.py +113 -0
- synet/backends/__init__.py +68 -0
- synet/backends/custom.py +82 -0
- synet/backends/ultralytics.py +404 -0
- synet/backends/yolov5.py +436 -0
- synet/base.py +1104 -0
- synet/data_subset.py +54 -0
- synet/demosaic.py +290 -0
- synet/katana.py +4 -0
- synet/layers.py +439 -0
- synet/legacy.py +151 -0
- synet/metrics.py +328 -0
- synet/quantize.py +143 -0
- synet/sabre.py +7 -0
- synet/tflite_utils.py +349 -0
- synet/ultralytics_patches.py +3 -0
- synet/zoo/__init__.py +35 -0
- synet/zoo/ultralytics/sabre-detect-vga.yaml +29 -0
- synet/zoo/ultralytics/sabre-keypoint-vga.yaml +29 -0
- synet/zoo/ultralytics/sabre-segment-vga.yaml +29 -0
- tests/test_demosaic.py +17 -0
- tests/test_keras.py +210 -0
- tests/test_ultralytics.py +18 -0
requirements.txt
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
gradio
|
| 2 |
+
ultralytics
|
| 3 |
+
transformers
|
| 4 |
+
datasets
|
| 5 |
+
trl
|
| 6 |
+
peft
|
| 7 |
+
tf-keras
|
| 8 |
+
bitsandbytes
|
| 9 |
+
pytest
|
synet/__init__.py
ADDED
|
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .backends import get_backend
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
__all__ = "backends", "base", "katana", "sabre", "quantize", "test", \
|
| 5 |
+
"metrics", "tflite_utils"
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
def get_model(model_path, backend, *args, **kwds):
    """Load a model through the named backend.

    For now, only the katananet model is supported in ultralytics
    format.

    :param model_path: path (or zoo name) of the model to load.
    :param backend: name of the backend module to dispatch to.
    :returns: whatever the backend's ``get_model`` returns.
    """
    print("loading", model_path)
    loader = get_backend(backend)
    return loader.get_model(model_path, *args, **kwds)
|
synet/__main__.py
ADDED
|
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from importlib import import_module
|
| 2 |
+
from sys import argv, exit
|
| 3 |
+
|
| 4 |
+
import synet
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
def main():
    """Dispatch ``python -m synet <command> ...`` to a submodule.

    The first CLI argument names either a top-level synet module
    (anything listed in ``synet.__all__``) or a backend module under
    ``synet.backends``.  It is popped from ``argv`` so the target
    module's ``main()`` sees only its own arguments.

    :returns: the target ``main()``'s return value (used as exit code),
        or 1 when no command was given.
    """
    if len(argv) < 2:
        # Fail with a usage hint instead of an IndexError traceback.
        print("usage: python -m synet <command> [args...]")
        return 1
    if argv[1] in synet.__all__:
        return import_module(f"synet.{argv.pop(1)}").main()
    return import_module(f"synet.backends.{argv.pop(1)}").main()


if __name__ == "__main__":
    exit(main())
|
synet/asymmetric.py
ADDED
|
@@ -0,0 +1,113 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""asymetric.py diverges from base.py and layers.py in that its core
|
| 2 |
+
assumption is switched. In base.py/layers.py, the output of a module
|
| 3 |
+
in keras vs torch is identical, while asymetric modules act as the
|
| 4 |
+
identity function in keras. To get non-identity behavior in keras
|
| 5 |
+
mode, you should call module.clf(). 'clf' should be read as 'channels
|
| 6 |
+
last forward'; such methods take in and return a channels-last numpy
|
| 7 |
+
array.
|
| 8 |
+
|
| 9 |
+
The main use case for these modules is for uniform preprocessing to
|
| 10 |
+
bridge the gap between 'standard' training scenarios and actual
|
| 11 |
+
execution environments. So far, the main examples implemented are
|
| 12 |
+
conversions to grayscale, bayer, and camera augmented images. This
|
| 13 |
+
way, you can train your model on a standard RGB pipeline. The
|
| 14 |
+
resulting tflite model will not have these extra layers, and is ready
|
| 15 |
+
to operate on the raw input at deployment.
|
| 16 |
+
|
| 17 |
+
The cfl methods are mainly used for python demos where the sensor
|
| 18 |
+
still needs to be simulated, but not included in the model.
|
| 19 |
+
|
| 20 |
+
"""
|
| 21 |
+
|
| 22 |
+
from os.path import join, dirname
|
| 23 |
+
from cv2 import GaussianBlur as cv2GaussianBlur
|
| 24 |
+
from numpy import array, interp, ndarray
|
| 25 |
+
from numpy.random import normal
|
| 26 |
+
from torch import empty, tensor, no_grad, rand
|
| 27 |
+
from torchvision.transforms import GaussianBlur
|
| 28 |
+
|
| 29 |
+
from .demosaic import Demosaic, UnfoldedDemosaic, Mosaic
|
| 30 |
+
from .base import askeras, Module
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
class Grayscale(Module):
    """Channel-averaging layer that is a no-op in keras mode.

    Training frameworks often fix input channels to 3; placing this
    layer at the front of a model converts the input to grayscale in
    torch mode.  The layer is skipped when converting to tflite, so the
    pytorch model can take any number of input channels while the
    tensorflow (tflite) model expects exactly one input channel.
    """

    def forward(self, x):
        # Asymmetric behavior: identity under keras, channel-mean
        # (dim 1, dimension kept) under torch.
        return x if askeras.use_keras else x.mean(1, keepdims=True)
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
class Camera(Module):
    """Simulated camera-sensor augmentation for torch training.

    Applies mosaic -> blur -> per-channel gamma -> gaussian noise ->
    demosaic, each of the middle three gated by a random draw against
    ``ratio``.  Acts as the identity in keras mode, so the exported
    tflite model operates directly on raw sensor input (see module
    docstring).

    :param gamma: per-channel gamma exponents, indexed by the channel
        letters of ``bayer_pattern`` (used as ``self.gamma[chan]``).
    :param bayer_pattern: bayer layout passed to Mosaic/UnfoldedDemosaic.
    :param from_bayer: input is already a bayer image; skip mosaicing.
    :param to_bayer: return the bayer image; skip demosaicing.
    :param ratio: per-stage apply probabilities (blur, gamma, noise).
    :param blur_sigma: sigma of the 3x3 gaussian blur.
    :param noise_sigma: mean of the randomly drawn noise sigma.
    """

    def __init__(self,
                 gamma,
                 bayer_pattern='gbrg',
                 from_bayer=False,
                 to_bayer=False,
                 ratio=(1, 1, 1),
                 blur_sigma=0.4,
                 noise_sigma=10/255):
        super().__init__()
        self.mosaic = Mosaic(bayer_pattern)
        # demosaic weights are fixed (malvar); never trained
        self.demosaic = UnfoldedDemosaic('malvar', bayer_pattern
                                         ).requires_grad_(False)
        self.blur_sigma = blur_sigma
        self.noise_sigma = noise_sigma
        self.from_bayer = from_bayer
        self.to_bayer = to_bayer
        self.gamma = gamma
        self.blur = GaussianBlur(3, blur_sigma)
        self.ratio = ratio

    def gamma_correction(self, image):
        """Apply per-channel gamma in place on a bayer image.

        Each (yoff, xoff) offset addresses one bayer site of the 2x2
        tile; assumes ``image`` is HxW-leading (sliced as
        ``image[yoff::2, xoff::2]``) — TODO confirm layout.
        """
        for yoff, xoff, chan in zip(self.mosaic.rows,
                                    self.mosaic.cols,
                                    self.mosaic.bayer_pattern):
            # the gamma correction (from experiments) is channel dependent
            image[yoff::2, xoff::2] = ((image[yoff::2, xoff::2]
                                        ) ** self.gamma[chan])
        return image

    #@no_grad  # NOTE(review): decorator left disabled in the original
    def forward(self, im):
        # identity in keras mode: the sensor simulation must not end up
        # in the exported graph
        if askeras.use_keras:
            return im
        if not self.from_bayer:
            im = self.mosaic(im)
        # each augmentation stage fires with its own probability
        if rand(1) < self.ratio[0]:
            im = self.blur(im)
        if rand(1) < self.ratio[1]:
            im = self.gamma_correction(im)
        if rand(1) < self.ratio[2]:
            # draw a per-call noise level; clamp at 0 so a negative draw
            # degenerates to no noise
            this_noise_sigma, = empty(1).normal_(self.noise_sigma, 2/255)
            im += empty(im.shape, device=im.device
                        ).normal_(0.0, max(0, this_noise_sigma))
        if not self.to_bayer:
            im = self.demosaic(im)
        # values are assumed to be in [0, 1] here (contrast clf below,
        # which clips to [0, 255]) — TODO confirm
        return im.clip(0, 1)

    def clf(self, im):
        # Channels-last numpy variant for python demos.  Deliberately
        # disabled: it predates a refactor (e.g. it still calls
        # self.map_to_linear, which no longer exists on this class).
        assert False, "didn't update this function after refactor"
        # augmentation should always be done on bayer image.
        if not self.from_bayer:
            im = self.mosaic.clf(im)
        # let the noise level vary
        this_noise_sigma = normal(self.noise_sigma, 2)
        # if you blur too much, the image becomes grayscale
        im = cv2GaussianBlur(im, [3, 3], self.blur_sigma)
        im = self.map_to_linear(im)
        # GaussianBlur likes to remove singleton channel dimension
        im = im[..., None] + normal(0.0, this_noise_sigma, im.shape + (1,))
        # depending on scenario, you may not want to return an RGB image.
        if not self.to_bayer:
            im = self.demosaic.clf(im)
        return im.clip(0, 255)
|
synet/backends/__init__.py
ADDED
|
@@ -0,0 +1,68 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from importlib import import_module
|
| 2 |
+
from shutil import copy
|
| 3 |
+
|
| 4 |
+
from ..zoo import in_zoo, get_config, get_configs, get_weights
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
class Backend:
    """Abstract interface every synet backend module must implement.

    Concrete backends (ultralytics, yolov5, custom, ...) subclass this
    and are instantiated via :func:`get_backend`.  Subclasses are
    expected to define a ``name`` attribute used for zoo lookups.
    """

    def get_model(self, model_path):
        """Load model from config, or pretrained save."""
        raise NotImplementedError("Please subclass and implement")

    def get_shape(self, model):
        """Get shape of model."""
        raise NotImplementedError("Please subclass and implement")

    def patch(self):
        """Initialize backend to utilize Synet Modules"""
        raise NotImplementedError("Please subclass and implement")

    def val_post(self, weights, tflite, val_post, conf_thresh=.25,
                 iou_thresh=.7):
        """Default conf_thresh and iou_thresh (.25 and .7 resp.)
        taken from ultralytics/cfg/default.yaml.

        """
        raise NotImplementedError("Please subclass and implement")

    def tf_post(self, tflite, val_post, conf_thresh, iou_thresh):
        """Loads the tflite, loads the image, preprocesses the image,
        evaluates the tflite on the pre-processed image, and performs
        post-processing on the tflite output with a given confidence
        and iou threshold.

        :param tflite: Path to tflite file, or a raw tflite buffer
        :param val_post: Path to image to evaluate on.
        :param conf_thresh: Confidence threshold. See val_post docstring
            above for default value details.
        :param iou_thresh: IoU threshold for NMS. See val_post docstring
            above for default value details.

        """
        raise NotImplementedError("Please subclass and implement")

    def get_chip(self, model):
        """Get chip of model."""
        raise NotImplementedError("Please subclass and implement")

    def maybe_grab_from_zoo(self, model_path):
        """Materialize ``model_path`` from the model zoo when possible.

        Copies a zoo config into place when the name is known to the
        zoo; otherwise, for weight files (.pt/.tflite), asks the zoo
        for the weights.  Always returns ``model_path`` unchanged so
        callers can use it fluently.
        """
        if in_zoo(model_path, self.name):
            copy(get_config(model_path, self.name), model_path)
        elif model_path.endswith((".pt", ".tflite")):
            get_weights(model_path, self.name)
        return model_path

    def get_configs(self):
        """List the zoo configs available for this backend."""
        return get_configs(self.name)

    def get_data(self, data):
        """return a {split:files} where files is either a string or
        list of strings denoting path(s) to file(s) or
        directory(ies). Files should be newline-separated lists of
        image paths, and directories should (recursively) contain only
        images or directories."""
        raise NotImplementedError("Please subclass and implement")
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
def get_backend(name):
    """Import the backend submodule ``name`` (relative to this
    package) and return a fresh instance of its ``Backend`` class."""
    backend_module = import_module(f".{name}", __name__)
    return backend_module.Backend()
|
synet/backends/custom.py
ADDED
|
@@ -0,0 +1,82 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import tensorflow as tf
from tensorflow.math import ceil

from object_detection.models.tf import TFDetect as PC_TFDetect

# base.py lives in the parent package (synet/base.py, as in the other
# backends); a single-dot import would look for synet/backends/base.py,
# which does not exist.
from ..base import askeras
|
| 6 |
+
class TFDetect(PC_TFDetect):
    """TF detection head derived from object_detection's TFDetect.

    Differences from the parent: grid sizes are computed with ceiling
    division (so non-multiple-of-stride image sizes work), and a
    separate ``deploy`` path emits split box1/box2/cls outputs.
    """

    def __init__(self, nc=80, anchors=(), ch=(), imgsz=(640, 640), w=None):
        super().__init__(nc, anchors, ch, imgsz, w)
        # rebuild grids with ceil division; parent presumably used //
        for i in range(self.nl):
            ny, nx = (ceil(self.imgsz[0] / self.stride[i]),
                      ceil(self.imgsz[1] / self.stride[i]))
            self.grid[i] = self._make_grid(nx, ny)

    # copy call method, but replace // with ceil div
    def call(self, inputs):
        if askeras.kwds.get('deploy'):
            return self.deploy(inputs)
        z = []  # inference output
        x = []
        for i in range(self.nl):
            x.append(self.m[i](inputs[i]))
            # x(bs,20,20,255) to x(bs,3,20,20,85)
            ny, nx = (ceil(self.imgsz[0] / self.stride[i]),
                      ceil(self.imgsz[1] / self.stride[i]))
            x[i] = tf.transpose(tf.reshape(x[i], [-1, ny * nx, self.na, self.no]), [0, 2, 1, 3])

            if not self.training:  # inference
                y = tf.sigmoid(x[i])
                xy = (y[..., 0:2] * 2 - 0.5 + self.grid[i]) * self.stride[i]  # xy
                wh = (y[..., 2:4] * 2) ** 2 * self.anchor_grid[i]
                # Normalize xywh to 0-1 to reduce calibration error
                xy /= tf.constant([[self.imgsz[1], self.imgsz[0]]], dtype=tf.float32)
                wh /= tf.constant([[self.imgsz[1], self.imgsz[0]]], dtype=tf.float32)
                y = tf.concat([xy, wh, y[..., 4:]], -1)
                # y = tf.concat([xy, wh, y[..., 4:]], 3)
                z.append(tf.reshape(y, [-1, self.na * ny * nx, self.no]))

        return x if self.training else (tf.concat(z, 1), x)

    def deploy(self, inputs):
        """Deployment-mode head: emit named box corners and class scores.

        Returns (box1, box2, cls) where box1/box2 are the top-left and
        bottom-right corners (xy -/+ wh/2) and cls is objectness times
        per-class score.  Batch size must be 1.
        """
        assert inputs[0].shape[0] == 1, 'requires batch_size == 1'
        box1, box2, cls = [], [], []
        for mi, xi, gi, ai, si in zip(self.m, inputs, self.grid, self.anchor_grid, self.stride):
            x = tf.reshape(tf.sigmoid(mi(xi)), (1, -1, self.na, self.no))
            xy = (x[..., 0:2] * 2 + (tf.transpose(gi, (0, 2, 1, 3)) - .5)) * si
            wh = (x[..., 2:4] * 2) ** 2 * tf.transpose(ai, (0, 2, 1, 3))
            box1.append(tf.reshape(xy - wh/2, (1, -1, 2)))
            box2.append(tf.reshape(xy + wh/2, (1, -1, 2)))
            # objectness (ch 4) modulates the per-class scores (ch 5:)
            cls.append(tf.reshape(x[..., 4:5]*x[..., 5:], (1, -1, x.shape[-1]-5)))
        # named outputs so post-processing can find them in the tflite
        return (tf.concat(box1, 1, name='box1'),
                tf.concat(box2, 1, name='box2'),
                tf.concat(cls, 1, name='cls'))
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
from object_detection.models.yolo import Detect as PC_PTDetect
|
| 56 |
+
class Detect(PC_PTDetect):
    """Torch Detect head that can also emit the keras (TFDetect) graph.

    Constructor args are recorded so the TF twin can be constructed
    later with the same configuration.
    """

    def __init__(self, *args, **kwds):
        # NOTE(review): a 4th positional arg is silently dropped here —
        # presumably an extra config-file field the torch parent does
        # not accept; confirm against the yaml parser.
        if len(args) == 4:
            args = args[:3]
        # construct normally
        super().__init__(*args, **kwds)
        # save args/kwargs for later construction of TF model
        self.args = args
        self.kwds = kwds

    def forward(self, x, theta=None):
        if askeras.use_keras:
            # theta is a torch-only feature; it must not leak into the
            # keras export path
            assert theta is None
            return self.as_keras(x)
        return super().forward(x, theta=theta)

    def as_keras(self, x):
        # build a TFDetect with the recorded config, sharing weights
        # via w=self, and apply it to the keras tensors
        return TFDetect(*self.args, imgsz=askeras.kwds["imgsz"],
                        w=self, **self.kwds
                        )(x)
|
| 74 |
+
|
| 75 |
+
from object_detection.models import yolo
|
| 76 |
+
from importlib import import_module
|
| 77 |
+
def patch_custom(chip):
    """Monkey-patch object_detection's yolo module for a given chip.

    Imports the chip's layer module from this package and installs it
    (plus the synet Detect head) on ``yolo`` so models built from
    configs resolve to synet layers.
    """
    # patch custom.models.yolo
    module = import_module(f'..{chip}', __name__)
    setattr(yolo, chip, module)
    yolo.Concat = module.Cat
    # rebind on both yolo and the chip module so either lookup wins
    yolo.Detect = module.Detect = Detect
|
synet/backends/ultralytics.py
ADDED
|
@@ -0,0 +1,404 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
from importlib import import_module
|
| 3 |
+
from sys import argv
|
| 4 |
+
|
| 5 |
+
from cv2 import imread, imwrite, resize
|
| 6 |
+
from numpy import array
|
| 7 |
+
from torch import tensor
|
| 8 |
+
from torch.nn import ModuleList
|
| 9 |
+
from ultralytics import YOLO
|
| 10 |
+
from ultralytics.data.utils import check_det_dataset, check_cls_dataset
|
| 11 |
+
from ultralytics.engine import validator, predictor, trainer
|
| 12 |
+
from ultralytics.engine.results import Results
|
| 13 |
+
from ultralytics.models.yolo import model as yolo_model
|
| 14 |
+
from ultralytics.nn import tasks
|
| 15 |
+
from ultralytics.nn.autobackend import AutoBackend
|
| 16 |
+
from ultralytics.nn.modules.block import DFL as Torch_DFL, Proto as Torch_Proto
|
| 17 |
+
from ultralytics.nn.modules.head import (Pose as Torch_Pose,
|
| 18 |
+
Detect as Torch_Detect,
|
| 19 |
+
Segment as Torch_Segment,
|
| 20 |
+
Classify as Torch_Classify)
|
| 21 |
+
from ultralytics.utils import dist
|
| 22 |
+
from ultralytics.utils.ops import non_max_suppression, process_mask
|
| 23 |
+
from ultralytics.utils.checks import check_imgsz
|
| 24 |
+
|
| 25 |
+
from . import Backend as BaseBackend
|
| 26 |
+
from ..base import (askeras, Conv2d, ReLU, Upsample, GlobalAvgPool,
|
| 27 |
+
Dropout, Linear)
|
| 28 |
+
from .. import layers
|
| 29 |
+
from .. import asymmetric
|
| 30 |
+
from ..layers import Sequential, CoBNRLU
|
| 31 |
+
from ..tflite_utils import tf_run, concat_reshape
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
class DFL(Torch_DFL):
    """Distribution-focal-loss decoder with a keras export path.

    Replaces the parent's conv with a frozen synet Conv2d carrying the
    same (fixed) weights.  ``sm_split`` optionally evaluates the
    softmax in chunks — presumably to suit a deployment constraint on
    softmax size; confirm with the quantize/export pipeline.
    """

    def __init__(self, c1=16, sm_split=None):
        super().__init__(c1)
        # transplant the parent's integral weights into a frozen synet
        # Conv2d so the keras path can reuse the same layer object
        weight = self.conv.weight
        self.conv = Conv2d(c1, 1, 1, bias=False).requires_grad_(False)
        self.conv.conv.weight.data[:] = weight.data
        self.sm_split = sm_split

    def forward(self, x):
        if askeras.use_keras:
            return self.as_keras(x)
        return super().forward(x)

    def as_keras(self, x):
        # b, ay, ax, c = x.shape
        from tensorflow.keras.layers import Reshape, Softmax
        # hasattr guard: older pickled models may predate sm_split
        if hasattr(self, "sm_split") and self.sm_split is not None:
            from tensorflow.keras.layers import Concatenate
            # total ltrb logits must split evenly into sm_split chunks
            assert not (x.shape[0]*x.shape[1]*x.shape[2]*4) % self.sm_split
            x = Reshape((self.sm_split, -1, self.c1))(x)
            # tensorflow really wants to be indented like this. I relent...
            return Reshape((-1, 4))(
                self.conv(
                    Concatenate(1)([
                        Softmax(-1)(x[:, i:i+1])
                        for i in range(x.shape[1])
                    ])
                )
            )

        # plain path: softmax over the c1 bins, then the frozen 1x1
        # conv computes the expected value, one scalar per l/t/r/b
        return Reshape((-1, 4)
                       )(self.conv(Softmax(-1)(Reshape((-1, 4, self.c1))(x))))
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
class Proto(Torch_Proto):
    """Segmentation prototype head using synet layers.

    Rebuilds the parent's convs with CoBNRLU/Upsample so they follow
    the synet keras/torch duality; the final conv is named 'proto' so
    the tflite output can be located by name.
    """

    def __init__(self, c1, c_=256, c2=32):
        """arguments understood as in_channels, number of protos, and
        number of masks"""
        super().__init__(c1, c_, c2)
        self.cv1 = CoBNRLU(c1, c_, 3)
        self.upsample = Upsample(scale_factor=2, mode='bilinear')
        self.cv2 = CoBNRLU(c_, c_, 3)
        self.cv3 = CoBNRLU(c_, c2, 1, name='proto')
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
def generate_anchors(H, W, stride, offset):
    """Build the concatenated (x, y) anchor-center table for all strides.

    For each stride s, lays out a ceil(H/s) x ceil(W/s) grid of cell
    coordinates, shifts by ``offset`` (e.g. .5 for cell centers), and
    scales back to input-pixel units; all levels are concatenated into
    one Nx2 tensor.  ``stride`` is iterated and each element's
    ``.item()`` is taken, so it is expected to be a sequence of
    0-d torch tensors — TODO confirm.
    """
    from tensorflow import meshgrid, range, stack, reshape, concat
    from tensorflow.math import ceil
    # inner generator yields (s, (row-grid, col-grid)) per level;
    # note x comes from sx (columns) and y from sy (rows)
    return concat([stack((reshape((sx + offset) * s, (-1,)),
                          reshape((sy + offset) * s, (-1,))),
                         -1)
                   for s, (sy, sx) in ((s.item(),
                                        meshgrid(range(ceil(H/s)),
                                                 range(ceil(W/s)),
                                                 indexing="ij"))
                                       for s in stride)],
                  -2)
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
class Detect(Torch_Detect):
    """Ultralytics Detect head rebuilt with synet layers + keras path.

    cv2/cv3 are replaced with synet Conv2d/ReLU6 stacks so the same
    weights drive both the torch graph and the exported keras graph.
    """

    def __init__(self, nc=80, ch=(), sm_split=None, junk=None):
        super().__init__(nc, ch)
        c2 = max((16, ch[0] // 4, self.reg_max * 4))
        # box-regression branch (4 * reg_max outputs per level)
        self.cv2 = ModuleList(Sequential(Conv2d(x, c2, 3, bias=True),
                                         ReLU(6),
                                         Conv2d(c2, c2, 3, bias=True),
                                         ReLU(6),
                                         Conv2d(c2, 4 * self.reg_max, 1,
                                                bias=True))
                              for x in ch)
        # classification branch (nc outputs per level)
        self.cv3 = ModuleList(Sequential(Conv2d(x, x, 3, bias=True),
                                         ReLU(6),
                                         Conv2d(x, x, 3, bias=True),
                                         ReLU(6),
                                         Conv2d(x, self.nc, 1, bias=True))
                              for x in ch)
        # NOTE(review): sm_split is only honored when `junk` is given —
        # looks intentional (a config-format marker?) but worth
        # confirming; a plain `Detect(nc, ch, sm_split=k)` ignores k.
        if junk is None:
            sm_split = None
        self.dfl = DFL(sm_split=sm_split)

    def forward(self, x):
        if askeras.use_keras:
            # explicit Detect.as_keras so subclasses (Pose/Segment)
            # reuse this exact implementation via self.detect
            return Detect.as_keras(self, x)
        return super().forward(x)

    def as_keras(self, x):
        from tensorflow.keras.layers import Reshape
        from tensorflow import stack
        from tensorflow.keras.layers import (Concatenate, Subtract,
                                             Add, Activation)
        from tensorflow.keras.activations import sigmoid
        # decode per-level DFL distances and scale to pixels
        ltrb = Concatenate(-2)([self.dfl(cv2(xi)) * s.item()
                                for cv2, xi, s in
                                zip(self.cv2, x, self.stride)])
        H, W = askeras.kwds['imgsz']
        anchors = generate_anchors(H, W, self.stride, .5)  # Nx2
        anchors = stack([anchors for batch in range(x[0].shape[0])])  # BxNx2
        # corners from left/top and right/bottom distances; named for
        # post-processing lookup in the tflite
        box1 = Subtract(name="box1")((anchors, ltrb[:, :, :2]))
        box2 = Add(name="box2")((anchors, ltrb[:, :, 2:]))
        if askeras.kwds.get("xywh"):
            box1, box2 = (box1 + box2) / 2, box2 - box1

        cls = Activation(sigmoid, name='cls')(
            Concatenate(-2)([
                Reshape((-1, self.nc))(cv3(xi))
                for cv3, xi in zip(self.cv3, x)
            ])
        )
        out = [box1, box2, cls]
        if askeras.kwds.get("quant_export"):
            return out
        # everything after here needs to be implemented by post-processing
        out[:2] = (box/array((W, H)) for box in out[:2])
        return Concatenate(-1)(out)
|
| 148 |
+
|
| 149 |
+
|
| 150 |
+
class Pose(Torch_Pose, Detect):
    """Ultralytics Pose head with synet layers and a keras path.

    The keypoint branch (cv4) is rebuilt with synet Conv2d/ReLU6; box
    and class decoding is delegated to Detect via ``self.detect``.
    """

    def __init__(self, nc, kpt_shape, ch, sm_split=None, junk=None):
        super().__init__(nc, kpt_shape, ch)
        Detect.__init__(self, nc, ch, sm_split, junk=junk)
        # unbound Detect.forward, later invoked as self.detect(self, x)
        self.detect = Detect.forward
        c4 = max(ch[0] // 4, self.nk)
        self.cv4 = ModuleList(Sequential(Conv2d(x, c4, 3),
                                         ReLU(6),
                                         Conv2d(c4, c4, 3),
                                         ReLU(6),
                                         Conv2d(c4, self.nk, 1))
                              for x in ch)

    def forward(self, *args, **kwds):
        if askeras.use_keras:
            return self.as_keras(*args, **kwds)
        return super().forward(*args, **kwds)

    def s(self, stride):
        # per-output scale: keypoint x/y scale with stride, the
        # presence logit (every 3rd channel) does not
        if self.kpt_shape[1] == 3:
            from tensorflow import constant
            return constant([stride, stride, 1]*self.kpt_shape[0])
        return stride

    def as_keras(self, x):

        from tensorflow.keras.layers import Reshape, Concatenate, Add
        from tensorflow import stack, reshape
        from tensorflow.keras.activations import sigmoid

        if self.kpt_shape[1] == 3:
            # kpt layout is (x, y, presence) triples; pull the presence
            # channels (index 2 of each triple) off the last conv.
            # NOTE(review): hard-codes 17 keypoints here rather than
            # kpt_shape[0] — confirm for non-COCO keypoint counts.
            presence_chans = [i*3+2 for i in range(17)]
            pres, kpts = zip(*((Reshape((-1, self.kpt_shape[0], 1)
                                        )(presence(xi)),
                                Reshape((-1, self.kpt_shape[0], 2)
                                        )(keypoint(xi)*s*2))
                               for presence, keypoint, xi, s in
                               ((*cv[-1].split_channels(presence_chans),
                                 cv[:-1](xi), s.item())
                                for cv, xi, s in
                                zip(self.cv4, x, self.stride))))
            pres = Concatenate(-3, name="pres")([sigmoid(p) for p in pres])
        else:
            kpts = [Reshape((-1, self.kpt_shape[0], 2))(cv(xi)*s*2)
                    for cv, xi, s in
                    zip(self.cv4, x, self.stride)]

        H, W = askeras.kwds['imgsz']
        anchors = generate_anchors(H, W, self.stride, offset=0)  # Nx2
        anchors = reshape(anchors, (-1, 1, 2))  # Nx1x2
        anchors = stack([anchors for batch in range(x[0].shape[0])])  # BxNx1x2
        kpts = Add(name='kpts')((Concatenate(-3)(kpts), anchors))

        x = self.detect(self, x)

        if askeras.kwds.get("quant_export"):
            if self.kpt_shape[1] == 3:
                return *x, kpts, pres
            return *x, kpts

        # everything after here needs to be implemented by post-processing
        if self.kpt_shape[1] == 3:
            kpts = Concatenate(-1)((kpts, pres))

        return Concatenate(-1)((x, Reshape((-1, self.nk))(kpts)))
|
| 215 |
+
|
| 216 |
+
|
| 217 |
+
class Segment(Torch_Segment, Detect):
    """YOLOv8 Segment head for segmentation models."""

    def __init__(self, nc=80, nm=32, npr=256, ch=(), sm_split=None, junk=None):
        super().__init__(nc, nm, npr, ch)
        Detect.__init__(self, nc, ch, sm_split, junk=junk)
        # unbound Detect.forward, invoked below as self.detect(self, x)
        self.detect = Detect.forward
        self.proto = Proto(ch[0], self.npr, self.nm)  # protos
        c4 = max(ch[0] // 4, self.nm)
        # mask-coefficient branch, rebuilt with synet layers
        self.cv4 = ModuleList(Sequential(CoBNRLU(x, c4, 3),
                                         CoBNRLU(c4, c4, 3),
                                         Conv2d(c4, self.nm, 1))
                              for x in ch)

    def forward(self, x):
        if askeras.use_keras:
            return self.as_keras(x)
        return super().forward(x)

    def as_keras(self, x):
        from tensorflow.keras.layers import Reshape, Concatenate
        # prototypes from the highest-resolution level
        p = self.proto(x[0])
        # per-anchor mask coefficients, named 'seg' for tflite lookup
        mc = Concatenate(-2, name='seg')([Reshape((-1, self.nm))(cv4(xi))
                                          for cv4, xi in zip(self.cv4, x)])
        x = self.detect(self, x)
        if askeras.kwds.get("quant_export"):
            return *x, mc, p
        # everything after here needs to be implemented by post-processing
        return Concatenate(-1)((x, mc)), p
|
| 246 |
+
|
| 247 |
+
|
| 248 |
+
class Classify(Torch_Classify):
    """Ultralytics Classify head rebuilt with synet layers.

    Signature takes a leading ``junk`` positional — presumably the
    extra config-file field also dropped by Detect; confirm against
    the yaml parser.
    """

    def __init__(self, junk, c1, c2, k=1, s=1, p=None, g=1):
        super().__init__(c1, c2, k=k, s=s, p=p, g=g)
        c_ = 1280  # fixed embedding width before the classifier
        assert p is None
        self.conv = CoBNRLU(c1, c_, k, s, groups=g)
        self.pool = GlobalAvgPool()
        self.drop = Dropout(p=0.0, inplace=True)
        self.linear = Linear(c_, c2)

    def forward(self, x):
        if askeras.use_keras:
            return self.as_keras(x)
        return super().forward(x)

    def as_keras(self, x):
        from keras.layers import Concatenate, Flatten, Softmax
        # multi-level inputs are fused on the channel axis
        if isinstance(x, list):
            x = Concatenate(-1)(x)
        x = self.linear(self.drop(Flatten()(self.pool(self.conv(x)))))
        # raw logits while training, probabilities for inference/export
        return x if self.training else Softmax()(x)
|
| 269 |
+
|
| 270 |
+
|
| 271 |
+
class Backend(BaseBackend):
|
| 272 |
+
|
| 273 |
+
models = {}
|
| 274 |
+
name = "ultralytics"
|
| 275 |
+
|
| 276 |
+
def get_model(self, model_path, full=False):
|
| 277 |
+
|
| 278 |
+
model_path = self.maybe_grab_from_zoo(model_path)
|
| 279 |
+
|
| 280 |
+
if model_path in self.models:
|
| 281 |
+
model = self.models[model_path]
|
| 282 |
+
else:
|
| 283 |
+
model = self.models[model_path] = YOLO(model_path)
|
| 284 |
+
|
| 285 |
+
if full:
|
| 286 |
+
return model
|
| 287 |
+
return model.model
|
| 288 |
+
|
| 289 |
+
def get_shape(self, model):
|
| 290 |
+
if isinstance(model, str):
|
| 291 |
+
model = self.get_model(model)
|
| 292 |
+
return model.yaml["image_shape"]
|
| 293 |
+
|
| 294 |
+
def patch(self, model_path=None):
|
| 295 |
+
for module in layers, asymmetric:
|
| 296 |
+
for name in dir(module):
|
| 297 |
+
if name[0] != "_":
|
| 298 |
+
setattr(tasks, name, getattr(module, name))
|
| 299 |
+
tasks.Concat = layers.Cat
|
| 300 |
+
tasks.Pose = Pose
|
| 301 |
+
tasks.Detect = Detect
|
| 302 |
+
tasks.Segment = Segment
|
| 303 |
+
tasks.Classify = Classify
|
| 304 |
+
orig_ddp_file = dist.generate_ddp_file
|
| 305 |
+
|
| 306 |
+
def generate_ddp_file(trainer):
|
| 307 |
+
fname = orig_ddp_file(trainer)
|
| 308 |
+
fstr = open(fname).read()
|
| 309 |
+
open(fname, 'w').write(f"""\
|
| 310 |
+
from synet.backends import get_backend
|
| 311 |
+
get_backend('ultralytics').patch()
|
| 312 |
+
{fstr}""")
|
| 313 |
+
return fname
|
| 314 |
+
dist.generate_ddp_file = generate_ddp_file
|
| 315 |
+
|
| 316 |
+
def tflite_check_imgsz(*args, **kwds):
|
| 317 |
+
kwds['stride'] = 1
|
| 318 |
+
return check_imgsz(*args, **kwds)
|
| 319 |
+
trainer.check_imgsz = tflite_check_imgsz
|
| 320 |
+
if model_path is not None and model_path.endswith('tflite'):
|
| 321 |
+
print('SyNet: model provided is tflite. Modifying validators'
|
| 322 |
+
' to anticipate tflite output')
|
| 323 |
+
task_map = yolo_model.YOLO(model_path).task_map
|
| 324 |
+
for task in task_map:
|
| 325 |
+
for mode in 'predictor', 'validator':
|
| 326 |
+
class Wrap(task_map[task][mode]):
|
| 327 |
+
def postprocess(self, preds, *args, **kwds):
|
| 328 |
+
# concate_reshape currently expect ndarry
|
| 329 |
+
# with batch size of 1, so remove and
|
| 330 |
+
# re-add batch and tensorship.
|
| 331 |
+
preds = concat_reshape([p[0].numpy()
|
| 332 |
+
for p in preds],
|
| 333 |
+
self.args.task,
|
| 334 |
+
classes_to_index=False,
|
| 335 |
+
xywh=True)
|
| 336 |
+
if isinstance(preds, tuple):
|
| 337 |
+
preds = (tensor(preds[0][None])
|
| 338 |
+
.permute(0, 2, 1),
|
| 339 |
+
tensor(preds[1][None])
|
| 340 |
+
.permute(0, 2, 3, 1))
|
| 341 |
+
else:
|
| 342 |
+
preds = tensor(preds[None]).permute(0, 2, 1)
|
| 343 |
+
return super().postprocess(preds, *args, **kwds)
|
| 344 |
+
if task != 'classify':
|
| 345 |
+
task_map[task][mode] = Wrap
|
| 346 |
+
yolo_model.YOLO.task_map = task_map
|
| 347 |
+
|
| 348 |
+
class TfliteAutoBackend(AutoBackend):
|
| 349 |
+
def __init__(self, *args, **kwds):
|
| 350 |
+
super().__init__(*args, **kwds)
|
| 351 |
+
self.output_details.sort(key=lambda x: x['name'])
|
| 352 |
+
if len(self.output_details) == 1: # classify
|
| 353 |
+
num_classes = self.output_details[0]['shape'][-1]
|
| 354 |
+
else:
|
| 355 |
+
num_classes = self.output_details[2]['shape'][2]
|
| 356 |
+
self.kpt_shape = (self.output_details[-1]['shape'][-2], 3)
|
| 357 |
+
self.names = {k: self.names[k] for k in range(num_classes)}
|
| 358 |
+
|
| 359 |
+
validator.check_imgsz = tflite_check_imgsz
|
| 360 |
+
predictor.check_imgsz = tflite_check_imgsz
|
| 361 |
+
validator.AutoBackend = TfliteAutoBackend
|
| 362 |
+
predictor.AutoBackend = TfliteAutoBackend
|
| 363 |
+
|
| 364 |
+
    def get_data(self, data):
        """Resolve a dataset descriptor into a dataset dict.

        Tries to parse *data* as a detection dataset first, then falls
        back to a classification dataset.

        Raises:
            The original detection-parsing exception if neither parser
            succeeds; the classification error is only printed so the
            (usually more informative) detection error propagates.
        """
        try:
            return check_det_dataset(data)
        except Exception as e:
            try:
                return check_cls_dataset(data)
            except Exception as e2:
                print("unable to load data as classification or detection dataset")
                print(e2)
                raise e
|
| 374 |
+
|
| 375 |
+
|
| 376 |
+
def main():
    """Command-line entry point wrapping the ultralytics CLI.

    Scans ``argv`` for a ``model=...`` argument, resolving the model
    from the synet zoo if needed, applies the synet patches to
    ultralytics, appends an ``imgsz=...`` argument when the user did
    not supply one, and finally hands control to the ultralytics
    entrypoint.

    NOTE(review): if no ``model=`` argument is present, ``backend.patch``
    is never called — assumed intentional (plain ultralytics passthrough);
    confirm.
    """

    backend = Backend()

    # copy model from zoo if necessary
    for ind, val in enumerate(argv):
        if val.startswith("model="):
            model = backend.maybe_grab_from_zoo(val.split("=")[1])
            argv[ind] = "model=" + model

            # add synet ml modules to ultralytics
            backend.patch(model_path=model)

            # add imgsz if not explicitly given (use a distinct loop
            # variable so the outer `val` is not shadowed)
            for arg in argv:
                if arg.startswith("imgsz="):
                    break
            else:
                argv.append(f"imgsz={max(backend.get_shape(model))}")

            # only the first model= argument is honored
            break

    # launch ultralytics; the cfg module moved between ultralytics
    # releases, so fall back to the legacy path on ImportError only
    # (a bare `except:` here would also mask KeyboardInterrupt etc.)
    try:
        from ultralytics.cfg import entrypoint
    except ImportError:
        from ultralytics.yolo.cfg import entrypoint
    entrypoint()
|
synet/backends/yolov5.py
ADDED
|
@@ -0,0 +1,436 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from types import SimpleNamespace
|
| 2 |
+
from importlib import import_module
|
| 3 |
+
|
| 4 |
+
import numpy
|
| 5 |
+
import tensorflow as tf
|
| 6 |
+
from tensorflow.math import ceil
|
| 7 |
+
from torch import load, no_grad, tensor
|
| 8 |
+
from yolov5 import val
|
| 9 |
+
from yolov5.models import yolo, common
|
| 10 |
+
from yolov5.models.yolo import Detect as Yolo_PTDetect, Model
|
| 11 |
+
from yolov5.models.tf import TFDetect as Yolo_TFDetect
|
| 12 |
+
from yolov5.utils.general import non_max_suppression
|
| 13 |
+
from yolov5.val import (Path, Callbacks, create_dataloader,
|
| 14 |
+
select_device, DetectMultiBackend,
|
| 15 |
+
check_img_size, LOGGER, check_dataset, torch,
|
| 16 |
+
np, ConfusionMatrix, coco80_to_coco91_class,
|
| 17 |
+
Profile, tqdm, scale_boxes, xywh2xyxy,
|
| 18 |
+
output_to_target, ap_per_class, pd,
|
| 19 |
+
increment_path, os, colorstr, TQDM_BAR_FORMAT,
|
| 20 |
+
process_batch, plot_images, save_one_txt)
|
| 21 |
+
|
| 22 |
+
from .base import askeras
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def get_yolov5_model(model_path, low_thld=0, raw=False, **kwds):
    """Convenience function to load a yolov5 model.

    Args:
        model_path: path to a ``.pt`` checkpoint, or a ``.yml``/``.yaml``
            architecture file (the latter is only valid with ``raw=True``
            since it carries no weights).
        low_thld: confidence threshold forwarded to NMS by the wrapped
            inference callable.
        raw: if True, return the bare torch ``Model`` instead of the
            NMS-wrapping inference function.
        **kwds: accepted for call-site compatibility; currently unused.

    Returns:
        A yolov5 ``Model`` when ``raw`` is True, otherwise a callable
        mapping a single image array to ``(xyxy_boxes, scores)``.
        # assumes the input image is CHW so unsqueeze(0) adds a batch
        # dim — TODO confirm against callers.
    """
    # str.endswith accepts a tuple of suffixes — one call, not two
    if model_path.endswith((".yml", ".yaml")):
        assert raw
        return Model(model_path)
    ckpt = load(model_path)
    ckpt = ckpt['model'] if isinstance(ckpt, dict) else ckpt
    raw_model = Model(ckpt.yaml)
    raw_model.load_state_dict(ckpt.state_dict())
    if raw:
        return raw_model
    raw_model.eval()

    def model(x):
        # wrap raw inference + NMS; gradients are never needed here
        with no_grad():
            xyxyoc = non_max_suppression(raw_model(tensor(x)
                                                   .unsqueeze(0).float()),
                                         conf_thres=low_thld,
                                         iou_thres=.3,
                                         multi_label=True
                                         )[0].numpy()
        # columns 0:4 are xyxy boxes; the remaining columns multiply
        # into a single per-box score
        return xyxyoc[:, :4], xyxyoc[:, 4:].prod(1)
    return model
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
class TFDetect(Yolo_TFDetect):
    """Modify Tensorflow Detect head to allow for arbitrary input
    shape (need not be multiple of 32).

    The upstream TF head computes the per-level grid size with floor
    division by the stride; this subclass uses ceiling division so an
    image size that is not a multiple of the stride still covers the
    whole feature map.
    """

    # use orig __init__, but make nx, ny calculated via ceil div
    def __init__(self, nc=80, anchors=(), ch=(), imgsz=(640, 640), w=None):
        super().__init__(nc, anchors, ch, imgsz, w)
        # rebuild each level's grid with ceil(imgsz / stride) so no
        # feature-map cell is dropped for non-multiple-of-32 inputs
        for i in range(self.nl):
            ny, nx = (ceil(self.imgsz[0] / self.stride[i]),
                      ceil(self.imgsz[1] / self.stride[i]))
            self.grid[i] = self._make_grid(nx, ny)

    # copy call method, but replace // with ceil div
    def call(self, inputs):
        z = []  # inference output
        x = []  # raw per-level head outputs
        for i in range(self.nl):
            x.append(self.m[i](inputs[i]))
            # x(bs,20,20,255) to x(bs,3,20,20,85)
            ny, nx = (ceil(self.imgsz[0] / self.stride[i]),
                      ceil(self.imgsz[1] / self.stride[i]))
            x[i] = tf.reshape(x[i], [-1, ny * nx, self.na, self.no])

            if not self.training:  # inference
                y = x[i]
                # grid/anchor_grid are stored transposed relative to how
                # they are consumed here; the -0.5 / *4 fold the YOLOv5
                # decode constants into the tensors up front
                grid = tf.transpose(self.grid[i], [0, 2, 1, 3]) - 0.5
                anchor_grid = tf.transpose(self.anchor_grid[i], [0, 2, 1, 3])*4
                xy = (tf.sigmoid(y[..., 0:2]) * 2 + grid) * self.stride[i]
                wh = tf.sigmoid(y[..., 2:4]) ** 2 * anchor_grid
                # Normalize xywh to 0-1 to reduce calibration error
                xy /= tf.constant([[self.imgsz[1], self.imgsz[0]]],
                                  dtype=tf.float32)
                wh /= tf.constant([[self.imgsz[1], self.imgsz[0]]],
                                  dtype=tf.float32)
                # objectness + class scores get sigmoid; any trailing
                # channels (e.g. extra outputs) pass through unchanged
                y = tf.concat([xy, wh, tf.sigmoid(y[..., 4:5 + self.nc]),
                               y[..., 5 + self.nc:]], -1)
                z.append(tf.reshape(y, [-1, self.na * ny * nx, self.no]))

        # training returns the raw per-level maps; inference returns a
        # single concatenated (1-tuple) prediction tensor
        return tf.transpose(x, [0, 2, 1, 3]) \
            if self.training else (tf.concat(z, 1),)
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
class Detect(Yolo_PTDetect):
    """Make YOLOv5 Detect head compatible with synet tflite export.

    Behaves exactly like the torch Detect head except when the global
    ``askeras`` context is active, in which case the forward pass is
    routed through a freshly-constructed :class:`TFDetect`.
    """

    def __init__(self, *args, **kwds):
        # to account for args hack: a 4th positional arg is dropped
        # before delegating to the upstream constructor
        if len(args) == 4:
            args = args[:3]
        # construct normally
        super().__init__(*args, **kwds)
        # save args/kwargs for later construction of TF model
        self.args = args
        self.kwds = kwds

    def forward(self, x):
        # route through the keras head only while exporting
        if askeras.use_keras:
            return self.as_keras(x)
        return super().forward(x)

    def as_keras(self, x):
        # build a TF twin of this head (weights copied via w=self) and
        # apply it immediately; imgsz comes from the export context
        return TFDetect(*self.args, imgsz=askeras.kwds["imgsz"],
                        w=self, **self.kwds
                        )(x)
|
| 116 |
+
|
| 117 |
+
|
| 118 |
+
def val_run_tflite(
        data,
        weights=None,  # model.pt path(s)
        batch_size=None,  # batch size
        batch=None,  # batch size (alias accepted for CLI compatibility)
        imgsz=None,  # inference size (pixels)
        img=None,  # inference size (pixels, alias)
        conf_thres=0.001,  # confidence threshold
        iou_thres=0.6,  # NMS IoU threshold
        max_det=300,  # maximum detections per image
        task='val',  # train, val, test, speed or study
        device='',  # cuda device, i.e. 0 or 0,1,2,3 or cpu
        workers=8,  # max dataloader workers (per RANK in DDP mode)
        single_cls=False,  # treat as single-class dataset
        augment=False,  # augmented inference
        verbose=False,  # verbose output
        save_txt=False,  # save results to *.txt
        save_hybrid=False,  # save label+prediction hybrid results to *.txt
        save_conf=False,  # save confidences in --save-txt labels
        save_json=False,  # save a COCO-JSON results file
        project='runs/val',  # save to project/name
        name='exp',  # save to project/name
        exist_ok=False,  # existing project/name ok, do not increment
        half=True,  # use FP16 half-precision inference
        dnn=False,  # use OpenCV DNN for ONNX inference
        model=None,
        dataloader=None,
        save_dir=Path(''),
        plots=True,
        callbacks=Callbacks(),  # NOTE: mutable default kept for upstream yolov5 API parity
        compute_loss=None,
):
    """Modified copy of yolov5's ``val.run`` with SYNET tflite support.

    Differences from upstream are marked ``SYNET MODIFICATION``: tflite
    models force the tflite input shape, rectangular zero-pad loading,
    and grayscale input. Returns the usual yolov5 tuple
    ``((mp, mr, map50, map, *losses), maps, map50s, t)`` plus writes a
    results.html table to ``save_dir``.
    """

    # resolve CLI aliases: batch/img are accepted alongside
    # batch_size/imgsz
    if imgsz is None and img is None:
        imgsz = 640
    elif img is not None:
        imgsz = img
    if batch_size is None and batch is None:
        batch_size = 32
    elif batch is not None:
        batch_size = batch

    # Initialize/load model and set device
    training = model is not None
    if training:  # called by train.py
        device, pt, jit, engine = next(model.parameters()).device, True, False, False  # get model device, PyTorch model
        half &= device.type != 'cpu'  # half precision only supported on CUDA
        model.half() if half else model.float()

        # SYNET MODIFICATION: never train tflite
        tflite = False

    else:  # called directly
        device = select_device(device, batch_size=batch_size)
        half &= device.type != 'cpu'  # half precision only supported on CUDA, dont remove!

        # Directories
        save_dir = increment_path(Path(project) / name, exist_ok=exist_ok)  # increment run
        (save_dir / 'labels' if save_txt else save_dir).mkdir(parents=True, exist_ok=True)  # make dir

        # Load model
        model = DetectMultiBackend(weights, device=device, dnn=dnn, data=data, fp16=half)

        # SYNET MODIFICATION: check for tflite
        tflite = hasattr(model, "interpreter")

        stride, pt, jit, engine = model.stride, model.pt, model.jit, model.engine

        # SYNET MODIFICATION: if tflite, use that shape
        if tflite:
            sn = model.input_details[0]['shape']
            imgsz = int(max(sn[2], sn[1]))

        if not isinstance(imgsz, (list, tuple)):
            imgsz = check_img_size(imgsz, s=stride)  # check image size
        half = model.fp16  # FP16 supported on limited backends with CUDA
        if engine:
            batch_size = model.batch_size
        else:
            device = model.device
            if not (pt or jit):
                batch_size = 1  # export.py models default to batch-size 1
                LOGGER.info(f'Forcing --batch-size 1 square inference (1,3,{imgsz},{imgsz}) for non-PyTorch models')

        # Data
        data = check_dataset(data)  # check

    # Configure
    model.eval()
    cuda = device.type != 'cpu'  # half precision only supported on CUDA, dont remove!
    is_coco = isinstance(data.get('val'), str) and data['val'].endswith(f'coco{os.sep}val2017.txt')  # COCO dataset
    nc = 1 if single_cls else int(data['nc'])  # number of classes
    iouv = torch.linspace(0.5, 0.95, 10, device=device)  # iou vector for mAP@0.5:0.95
    niou = iouv.numel()

    # Dataloader
    if not training:
        if pt and not single_cls:  # check --weights are trained on --data
            ncm = model.model.nc
            assert ncm == nc, f'{weights} ({ncm} classes) trained on different --data than what you passed ({nc} ' \
                              f'classes). Pass correct combination of --weights and --data that are trained together.'
        if not isinstance(imgsz, (list, tuple)):
            model.warmup(imgsz=(1 if pt else batch_size, 3, imgsz, imgsz))  # warmup

        pad, rect = (0.0, False) if task == 'speed' else (0.5, pt)  # square inference for benchmarks

        # SYNET MODIFICATION: if tflite, use rect with no padding
        if tflite:
            pad, rect = 0.0, True
            stride = np.gcd(sn[2], sn[1])

        task = task if task in ('train', 'val', 'test') else 'val'  # path to train/val/test images
        dataloader = create_dataloader(data[task],
                                       imgsz,
                                       batch_size,
                                       stride,
                                       single_cls,
                                       pad=pad,
                                       rect=rect,
                                       workers=workers,
                                       prefix=colorstr(f'{task}: '))[0]

    # BUGFIX: save_one_json is not in the explicit import list at the
    # top of this module; bind it lazily so the save_json path does not
    # raise NameError.
    if save_json:
        from yolov5.val import save_one_json

    seen = 0
    confusion_matrix = ConfusionMatrix(nc=nc)
    names = model.names if hasattr(model, 'names') else model.module.names  # get class names
    if isinstance(names, (list, tuple)):  # old format
        names = dict(enumerate(names))
    class_map = coco80_to_coco91_class() if is_coco else list(range(1000))
    s = ('%22s' + '%11s' * 6) % ('Class', 'Images', 'Instances', 'P', 'R', 'mAP50', 'mAP50-95')
    tp, fp, p, r, f1, mp, mr, map50, ap50, map = 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0
    dt = Profile(), Profile(), Profile()  # profiling times
    loss = torch.zeros(3, device=device)
    jdict, stats, ap, ap_class = [], [], [], []
    callbacks.run('on_val_start')
    pbar = tqdm(dataloader, desc=s, bar_format=TQDM_BAR_FORMAT)  # progress bar
    for batch_i, (im, targets, paths, shapes) in enumerate(pbar):
        callbacks.run('on_val_batch_start')
        with dt[0]:
            if cuda:
                im = im.to(device, non_blocking=True)
                targets = targets.to(device)
            im = im.half() if half else im.float()  # uint8 to fp16/32
            im /= 255  # 0 - 255 to 0.0 - 1.0
            nb, _, height, width = im.shape  # batch size, channels, height, width

            # SYNET MODIFICATION: if tflite, make grayscale
            if tflite:
                im = im.mean(1, keepdims=True)

        # Inference
        with dt[1]:
            preds, train_out = model(im) if compute_loss else (model(im, augment=augment), None)

        # Loss
        if compute_loss:
            loss += compute_loss(train_out, targets)[1]  # box, obj, cls

        # NMS
        targets[:, 2:] *= torch.tensor((width, height, width, height), device=device)  # to pixels
        lb = [targets[targets[:, 0] == i, 1:] for i in range(nb)] if save_hybrid else []  # for autolabelling
        with dt[2]:
            preds = non_max_suppression(preds,
                                        conf_thres,
                                        iou_thres,
                                        labels=lb,
                                        multi_label=True,
                                        agnostic=single_cls,
                                        max_det=max_det)

        # Metrics
        for si, pred in enumerate(preds):
            labels = targets[targets[:, 0] == si, 1:]
            nl, npr = labels.shape[0], pred.shape[0]  # number of labels, predictions
            path, shape = Path(paths[si]), shapes[si][0]
            correct = torch.zeros(npr, niou, dtype=torch.bool, device=device)  # init
            seen += 1

            if npr == 0:
                if nl:
                    stats.append((correct, *torch.zeros((2, 0), device=device), labels[:, 0]))
                    if plots:
                        confusion_matrix.process_batch(detections=None, labels=labels[:, 0])
                continue

            # Predictions
            if single_cls:
                pred[:, 5] = 0
            predn = pred.clone()
            scale_boxes(im[si].shape[1:], predn[:, :4], shape, shapes[si][1])  # native-space pred

            # Evaluate
            if nl:
                tbox = xywh2xyxy(labels[:, 1:5])  # target boxes
                scale_boxes(im[si].shape[1:], tbox, shape, shapes[si][1])  # native-space labels
                labelsn = torch.cat((labels[:, 0:1], tbox), 1)  # native-space labels
                correct = process_batch(predn, labelsn, iouv)
                if plots:
                    confusion_matrix.process_batch(predn, labelsn)
            stats.append((correct, pred[:, 4], pred[:, 5], labels[:, 0]))  # (correct, conf, pcls, tcls)

            # Save/log
            if save_txt:
                save_one_txt(predn, save_conf, shape, file=save_dir / 'labels' / f'{path.stem}.txt')
            if save_json:
                save_one_json(predn, jdict, path, class_map)  # append to COCO-JSON dictionary
            callbacks.run('on_val_image_end', pred, predn, path, names, im[si])

        # Plot images
        if plots and batch_i < 3:
            plot_images(im, targets, paths, save_dir / f'val_batch{batch_i}_labels.jpg', names)  # labels
            plot_images(im, output_to_target(preds), paths, save_dir / f'val_batch{batch_i}_pred.jpg', names)  # pred

        callbacks.run('on_val_batch_end', batch_i, im, targets, paths, shapes, preds)

    # Compute metrics
    stats = [torch.cat(x, 0).cpu().numpy() for x in zip(*stats)]  # to numpy
    if len(stats) and stats[0].any():
        tp, fp, p, r, f1, ap, ap_class = ap_per_class(*stats, plot=plots, save_dir=save_dir, names=names)
        ap50, ap = ap[:, 0], ap.mean(1)  # AP@0.5, AP@0.5:0.95
        mp, mr, map50, map = p.mean(), r.mean(), ap50.mean(), ap.mean()
    nt = np.bincount(stats[3].astype(int), minlength=nc)  # number of targets per class

    # Print results
    pf = '%22s' + '%11i' * 2 + '%11.3g' * 4  # print format
    LOGGER.info(pf % ('all', seen, nt.sum(), mp, mr, map50, map))
    if nt.sum() == 0:
        LOGGER.warning(f'WARNING ⚠️ no labels found in {task} set, can not compute metrics without labels')

    # Print results per class
    if (verbose or (nc < 50 and not training)) and nc > 1 and len(stats):
        for i, c in enumerate(ap_class):
            LOGGER.info(pf % (names[c], seen, nt[c], p[i], r[i], ap50[i], ap[i]))

    # Export results as html
    header = "Class Images Labels P R mAP@.5 mAP@.5:.95"
    headers = header.split()
    data = []
    data.append(['all', seen, nt.sum(), f"{float(mp):0.3f}", f"{float(mr):0.3f}", f"{float(map50):0.3f}", f"{float(map):0.3f}"])
    for i, c in enumerate(ap_class):
        data.append([names[c], seen, nt[c], f"{float(p[i]):0.3f}", f"{float(r[i]):0.3f}", f"{float(ap50[i]):0.3f}", f"{float(ap[i]):0.3f}"])
    results_df = pd.DataFrame(data, columns=headers)
    # BUGFIX: use a context manager so the file handle is always closed
    with open(save_dir / "results.html", "w") as text_file:
        text_file.write(results_df.to_html())

    # Print speeds
    t = tuple(x.t / seen * 1E3 for x in dt)  # speeds per image
    if not training:
        if isinstance(imgsz, (list, tuple)):
            shape = (batch_size, 3, *imgsz)
        else:
            shape = (batch_size, 3, imgsz, imgsz)
        LOGGER.info(f'Speed: %.1fms pre-process, %.1fms inference, %.1fms NMS per image at shape {shape}' % t)

    # Plots
    if plots:
        confusion_matrix.plot(save_dir=save_dir, names=list(names.values()))
        callbacks.run('on_val_end', nt, tp, fp, p, r, f1, ap, ap50, ap_class, confusion_matrix)

    # Save JSON
    if save_json and len(jdict):
        # BUGFIX: json was used without being imported anywhere in this
        # module; import it where it is needed.
        import json
        w = Path(weights[0] if isinstance(weights, list) else weights).stem if weights is not None else ''  # weights
        anno_json = str(Path('../datasets/coco/annotations/instances_val2017.json'))  # annotations
        pred_json = str(save_dir / f"{w}_predictions.json")  # predictions
        LOGGER.info(f'\nEvaluating pycocotools mAP... saving {pred_json}...')
        with open(pred_json, 'w') as f:
            json.dump(jdict, f)

        try:  # https://github.com/cocodataset/cocoapi/blob/master/PythonAPI/pycocoEvalDemo.ipynb
            # BUGFIX: check_requirements was used but never imported;
            # any failure here is caught by the except below, matching
            # the original best-effort behavior.
            from yolov5.utils.general import check_requirements
            check_requirements('pycocotools>=2.0.6')
            from pycocotools.coco import COCO
            from pycocotools.cocoeval import COCOeval

            anno = COCO(anno_json)  # init annotations api
            pred = anno.loadRes(pred_json)  # init predictions api
            coco_eval = COCOeval(anno, pred, 'bbox')  # renamed: don't shadow builtin eval
            if is_coco:
                coco_eval.params.imgIds = [int(Path(x).stem) for x in dataloader.dataset.im_files]  # image IDs to evaluate
            coco_eval.evaluate()
            coco_eval.accumulate()
            coco_eval.summarize()
            map, map50 = coco_eval.stats[:2]  # update results (mAP@0.5:0.95, mAP@0.5)
        except Exception as e:
            LOGGER.info(f'pycocotools unable to run: {e}')

    # Return results
    model.float()  # for training
    if not training:
        s = f"\n{len(list(save_dir.glob('labels/*.txt')))} labels saved to {save_dir / 'labels'}" if save_txt else ''
        LOGGER.info(f"Results saved to {colorstr('bold', save_dir)}{s}")
    maps = np.zeros(nc) + map
    for i, c in enumerate(ap_class):
        maps[c] = ap[i]
    map50s = np.zeros(nc) + map50
    for i, c in enumerate(ap_class):
        map50s[c] = ap50[i]
    return (mp, mr, map50, map, *(loss.cpu() / len(dataloader)).tolist()), maps, map50s, t
|
| 416 |
+
|
| 417 |
+
|
| 418 |
+
def patch_yolov5(chip=None):
    """Apply modifications to YOLOv5 for synet.

    chip: optional name of a sibling synet backend module (imported
    relative to this package) whose layers replace yolov5's Concat and
    Detect. With chip=None only the validation/export patches apply.
    """

    # enable the chip if given: import the sibling module and swap its
    # layers into yolov5's model-construction namespace
    if chip is not None:
        module = import_module(f"..{chip}", __name__)
        setattr(yolo, chip, module)
        yolo.Concat = module.Cat
        yolo.Detect = module.Detect = Detect

    # use modified val run function for tflites
    val.run = val_run_tflite

    # yolo uses uint8. Change to int8 — done on a SimpleNamespace copy
    # of numpy so the real numpy module is not mutated globally
    common.np = SimpleNamespace(**vars(numpy))
    common.np.uint8 = common.np.int8

    # imported here to avoid a circular import at module load time
    import synet
    synet.get_model_backend = get_yolov5_model
|
synet/base.py
ADDED
|
@@ -0,0 +1,1104 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""base.py is the "export" layer of synet. As such, it includes the
|
| 2 |
+
logic of how to run as a keras model. This is handled by checking the
|
| 3 |
+
'askeras' context manager, and running in "keras mode" if that context
|
| 4 |
+
is enabled. As a rule of thumb to differentiate between base.py,
|
| 5 |
+
layers.py:
|
| 6 |
+
|
| 7 |
+
- base.py should only import from torch, keras, and tensorflow.
|
| 8 |
+
- layers.py should only import from base.py.
|
| 9 |
+
"""
|
| 10 |
+
|
| 11 |
+
from typing import Tuple, Union, Optional, List
|
| 12 |
+
from torch import cat as torch_cat, minimum, tensor, no_grad, empty
|
| 13 |
+
from torch.nn import (Module as Torch_Module,
|
| 14 |
+
Conv2d as Torch_Conv2d,
|
| 15 |
+
BatchNorm2d as Torch_Batchnorm,
|
| 16 |
+
ModuleList,
|
| 17 |
+
ReLU as Torch_ReLU,
|
| 18 |
+
ConvTranspose2d as Torch_ConvTranspose2d,
|
| 19 |
+
Upsample as Torch_Upsample,
|
| 20 |
+
AdaptiveAvgPool2d as Torch_AdaptiveAvgPool,
|
| 21 |
+
Dropout as Torch_Dropout,
|
| 22 |
+
Linear as Torch_Linear)
|
| 23 |
+
from torch.nn.functional import pad
|
| 24 |
+
import torch.nn as nn
|
| 25 |
+
import torch
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
class AsKeras:
    """Context manager that switches synet modules into Keras-export
    mode. While it is active, ``Module.forward`` dispatches to the
    ``as_keras`` implementations instead of the PyTorch ones. See
    test.py and quantize.py for examples.
    """

    def __init__(self):
        # Export mode off; default export kwargs (inference mode).
        self.use_keras = False
        self.kwds = {"train": False}

    def __call__(self, **kwds):
        # Merge extra export options, then return self so the object
        # can be used directly as `with askeras(imgsz=...):`.
        self.kwds.update(kwds)
        return self

    def __enter__(self):
        self.use_keras = True

    def __exit__(self, exc_type, exc, tb):
        # Reset to the pristine state so repeated use starts clean.
        self.__init__()
| 47 |
+
|
| 48 |
+
|
| 49 |
+
# Module-level singleton checked by every Module.forward to decide
# between the torch and keras code paths.
askeras = AsKeras()
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
class Module(Torch_Module):
    """Base class for synet layers.

    Subclasses store their torch computation as ``self.module`` (or
    override ``forward``) and optionally provide an ``as_keras``
    method; while the ``askeras`` context is active, ``forward``
    routes through ``as_keras`` so the same object graph can emit a
    Keras model.
    """

    def forward(self, x):
        # Route to the Keras builder only when export mode is on AND
        # this subclass actually provides a Keras translation.
        if askeras.use_keras and hasattr(self, 'as_keras'):
            return self.as_keras(x)
        return self.module(x)

    def __getattr__(self, name):
        # Unknown attributes fall through to the wrapped torch module.
        # 'module' itself must re-raise, otherwise lookup would recurse
        # forever when the wrapped module has not been assigned yet.
        try:
            return super().__getattr__(name)
        except AttributeError as e:
            if name == 'module':
                raise e
            return getattr(self.module, name)

    def to_keras(self, imgsz, in_channels=1, batch_size=1, **kwds):
        # Build a keras.Model by tracing this module over a symbolic
        # NHWC Input of shape (H, W, C) inside the askeras context.
        from keras import Input, Model
        inp = Input(list(imgsz) + [in_channels], batch_size=batch_size)
        with askeras(imgsz=imgsz, **kwds):
            return Model(inp, self(inp))
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
class Conv2d(Module):
    """Convolution operator which ensures padding is done equivalently
    between PyTorch and TensorFlow.

    TensorFlow's 'same' padding is right/bottom-heavy when the total
    pad is odd, while torch's built-in padding is symmetric; the torch
    path therefore pads manually in ``forward`` and the inner
    Torch_Conv2d never pads.
    """

    def __init__(self,
                 in_channels: int,
                 out_channels: int,
                 kernel_size: Union[int, Tuple[int, int]],
                 stride: int = 1,
                 bias: bool = False,
                 padding: Optional[bool] = True,
                 groups: Optional[int] = 1):
        """
        Implementation of torch Conv2D with option for supporting keras
        inference.

        :param in_channels: Number of channels in the input
        :param out_channels: Number of channels produced by the convolution
        :param kernel_size: Size of the kernel (int or (kH, kW))
        :param stride: convolution stride (single int, both axes)
        :param bias: whether the convolution learns an additive bias
        :param padding: truthy -> 'same' padding, falsy -> 'valid'
        :param groups: used for pointwise/depthwise convolutions
        """
        super().__init__()
        if isinstance(kernel_size, int):
            kernel_size = (kernel_size, kernel_size)
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = "same" if padding else 'valid'
        self.groups = groups
        self.conv = Torch_Conv2d(in_channels=in_channels,
                                 out_channels=out_channels,
                                 kernel_size=kernel_size,
                                 bias=bias,
                                 stride=stride,
                                 groups=self.groups)
        self.use_bias = bias

    def forward(self, x):

        # temporary code for backwards compatibility with checkpoints
        # pickled before these attributes existed
        if not hasattr(self, 'padding'):
            self.padding = 'same'
        if not hasattr(self, 'groups'):
            self.groups = 1
        if not isinstance(self.padding, str):
            self.padding = "same" if self.padding else 'valid'

        if askeras.use_keras:
            return self.as_keras(x)

        if self.padding == "valid":
            return self.conv(x)

        # make padding like in tensorflow, which right aligns convolution.
        # Per axis, compute how much to pad left/right (top/bottom) so the
        # extra pixel, if any, lands on the right/bottom like TF 'same'.
        # NOTE(review): the arithmetic below is intended to reproduce TF
        # SAME padding for all kernel/stride combinations -- confirm for
        # even kernels with stride > 1.
        H, W = (s if isinstance(s, int) else s.item() for s in x.shape[-2:])
        # radius of the kernel and carry. Border size + carry. All in y
        ry, rcy = divmod(self.kernel_size[0] - 1, 2)
        by, bcy = divmod((H - 1) % self.stride - rcy, 2)
        # radius of the kernel and carry. Border size + carry. All in x
        rx, rcx = divmod(self.kernel_size[1] - 1, 2)
        bx, bcx = divmod((W - 1) % self.stride - rcx, 2)
        # apply pad; torch pad order is (left, right, top, bottom)
        return self.conv(
            pad(x, (rx - bx - bcx, rx - bx, ry - by - bcy, ry - by)))

    def as_keras(self, x):
        # When a 'demosaic' option is present it is consumed exactly once
        # (the key is deleted), so only the first conv in the graph gets
        # the demosaic front-end fused into it.
        if askeras.kwds.get('demosaic'):
            from .demosaic import Demosaic, reshape_conv
            demosaic = Demosaic(*askeras.kwds['demosaic'].split('-'))
            del askeras.kwds['demosaic']
            return reshape_conv(self)(demosaic(x))
        from keras.layers import Conv2D as Keras_Conv2d
        # Keras input is NHWC: last axis must match in_channels.
        assert x.shape[-1] == self.in_channels, (x.shape, self.in_channels)
        conv = Keras_Conv2d(filters=self.out_channels,
                            kernel_size=self.kernel_size,
                            strides=self.stride,
                            padding=self.padding,
                            use_bias=self.use_bias,
                            groups=self.groups)
        conv.build(x.shape)
        if isinstance(self.conv, Torch_Conv2d):
            tconv = self.conv
        else:
            # for NNI compatibility
            tconv = self.conv.module
        # torch kernel (out, in, kH, kW) -> keras kernel (kH, kW, in, out)
        weight = tconv.weight.detach().numpy().transpose(2, 3, 1, 0)
        conv.set_weights([weight, tconv.bias.detach().numpy()]
                         if self.use_bias else
                         [weight])
        return conv(x)

    def requires_grad_(self, val):
        # Delegate to the wrapped conv, preserving the fluent interface.
        self.conv = self.conv.requires_grad_(val)
        return self

    def __getattr__(self, name):
        # Expose the inner conv's parameters as if they were our own.
        if name in ("bias", "weight"):
            return getattr(self.conv, name)
        return super().__getattr__(name)

    def __setattr__(self, name, value):
        if name in ("bias", "weight"):
            return setattr(self.conv, name, value)
        return super().__setattr__(name, value)

    def split_channels(self, chans):
        """Split the output channels into two new Conv2d layers: one
        producing exactly `chans` (in the given order) and one producing
        the remaining channels, with weights/biases copied over.

        NOTE(review): padding and groups are not forwarded to the new
        convs (they get the defaults) -- confirm splits are only used
        with 'same' padding and groups=1.
        """

        with no_grad():
            split = Conv2d(self.in_channels, len(chans),
                           self.kernel_size, self.stride, self.use_bias)
            split.weight[:] = self.weight[chans]

            rest_chans = [i for i in range(self.out_channels)
                          if i not in chans]
            rest = Conv2d(self.in_channels, self.out_channels - len(chans),
                          self.kernel_size, self.stride, self.use_bias)
            rest.weight[:] = self.weight[rest_chans]

            if self.use_bias:
                split.bias[:] = self.bias[chans]
                rest.bias[:] = self.bias[rest_chans]

            return split, rest


# don't try to move this assignment into class def. It won't work.
# This is for compatibility with NNI so it does not treat this like a
# pytorch conv2d, and instead finds the nested conv2d.
Conv2d.__name__ = "Synet_Conv2d"
|
| 206 |
+
|
| 207 |
+
class DepthwiseConv2d(Conv2d):
    """DepthwiseConv2d operator implemented as a group convolution for
    pytorch and DepthwiseConv2d operator for keras.

    Inherits the torch forward (and its TF-style padding) from Conv2d;
    only the keras translation differs.
    """
    def as_keras(self, x):
        # Same one-shot demosaic fusion as Conv2d.as_keras.
        if askeras.kwds.get('demosaic'):
            from .demosaic import Demosaic, reshape_conv
            demosaic = Demosaic(*askeras.kwds['demosaic'].split('-'))
            del askeras.kwds['demosaic']
            return reshape_conv(self)(demosaic(x))
        from keras.layers import DepthwiseConv2D as Keras_DWConv2d
        assert x.shape[-1] == self.in_channels, (x.shape, self.in_channels)
        conv = Keras_DWConv2d(kernel_size=self.kernel_size,
                              strides=self.stride,
                              padding=self.padding,
                              use_bias=self.use_bias)
        conv.build(x.shape)
        if isinstance(self.conv, Torch_Conv2d):
            tconv = self.conv
        else:
            # for NNI compatibility
            tconv = self.conv.module
        # torch depthwise kernel (C, 1, kH, kW) -> keras (kH, kW, C, 1);
        # note the transpose differs from Conv2d's (2, 3, 1, 0).
        weight = tconv.weight.detach().numpy().transpose(2, 3, 0, 1)
        conv.set_weights([weight, tconv.bias.detach().numpy()]
                         if self.use_bias else
                         [weight])
        return conv(x)

# see the note on Conv2d.__name__: keeps NNI from treating this as a
# plain pytorch conv.
DepthwiseConv2d.__name__ = "Synet_DepthwiseConv2d"
|
| 236 |
+
|
| 237 |
+
class ConvTranspose2d(Module):
    """Transposed convolution exportable to keras Conv2DTranspose.

    torch padding==0 maps to keras 'valid', anything else to 'same'.
    NOTE(review): torch and TF transposed-conv padding only coincide
    for some kernel/stride combinations -- mostly untested, per the
    warning below.
    """
    def __init__(self, in_channels, out_channels, kernel_size, stride, padding,
                 bias=False):
        print("WARNING: synet ConvTranspose2d mostly untested")
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = "valid" if padding == 0 else "same"
        self.use_bias = bias
        self.module = Torch_ConvTranspose2d(in_channels, out_channels,
                                            kernel_size, stride,
                                            padding, bias=bias)

    def as_keras(self, x):
        from keras.layers import Conv2DTranspose as Keras_ConvTrans
        conv = Keras_ConvTrans(self.out_channels, self.kernel_size, self.stride,
                               self.padding, use_bias=self.use_bias)
        conv.build(x.shape)
        if isinstance(self.module, Torch_ConvTranspose2d):
            tconv = self.module
        else:
            # for NNI compatibility
            tconv = self.module.module
        # torch kernel (in, out, kH, kW) -> keras kernel (kH, kW, out, in)
        weight = tconv.weight.detach().numpy().transpose(2, 3, 1, 0)
        conv.set_weights([weight, tconv.bias.detach().numpy()]
                         if self.use_bias else
                         [weight])
        return conv(x)
|
| 267 |
+
|
| 268 |
+
|
| 269 |
+
class Cat(Module):
    """Concatenate a sequence of tensors along the feature axis.

    Torch tensors are NCHW, so the feature axis is dim 1; Keras
    tensors are NHWC, so the keras path concatenates the last axis.
    """

    def __init__(self, *args):
        # *args accepted (and ignored) for signature compatibility.
        super().__init__()

    def forward(self, xs):
        if not askeras.use_keras:
            return torch_cat(xs, dim=1)
        return self.as_keras(xs)

    def as_keras(self, xs):
        from keras.layers import Concatenate as Keras_Concatenate
        # Only 4D NHWC tensors are supported on the keras path.
        assert all(len(x.shape) == 4 for x in xs)
        return Keras_Concatenate(-1)(xs)
|
| 284 |
+
|
| 285 |
+
|
| 286 |
+
class ReLU(Module):
    """ReLU with an optional upper clamp (ReLU6-style), exportable to
    keras.layers.ReLU(max_value=...)."""

    def __init__(self, max_val=None, name=None):
        super().__init__()
        # Stored as a float64 torch scalar so `minimum` can broadcast.
        # NOTE(review): clamping a float32 input against this float64
        # scalar may promote the result dtype -- confirm intended.
        self.max_val = None if max_val is None else tensor(max_val,
                                                           dtype=float)
        self.name = name
        self.relu = Torch_ReLU()

    def forward(self, x):
        if askeras.use_keras:
            return self.as_keras(x)
        if self.max_val is None:
            return self.relu(x)
        return minimum(self.relu(x), self.max_val)

    def as_keras(self, x):
        # temporary code for backwards compatibility with checkpoints
        # pickled before `name` existed
        if not hasattr(self, 'name'):
            self.name = None
        from keras.layers import ReLU as Keras_ReLU
        # NOTE(review): max_val is a torch scalar tensor here, passed
        # as Keras max_value -- confirm Keras accepts it on export.
        return Keras_ReLU(self.max_val, name=self.name)(x)
|
| 307 |
+
|
| 308 |
+
|
| 309 |
+
class BatchNorm(Module):
    """BatchNorm2d wrapper exportable to keras BatchNormalization.

    NOTE(review): torch and keras define `momentum` oppositely (torch
    weights the *new* batch statistics by momentum, keras weights the
    *running* average by it), yet the same value is passed verbatim to
    both here -- confirm this is intended; with the default 0.999 the
    torch running stats track almost only the latest batch.
    """
    def __init__(self, features, epsilon=1e-3, momentum=0.999):
        super().__init__()
        self.epsilon = epsilon
        self.momentum = momentum
        self.module = Torch_Batchnorm(features, epsilon, momentum)

    def forward(self, x):
        # temporary code for backwards compatibility with old pickles
        # that stored the layer under `batchnorm`
        if hasattr(self, 'batchnorm'):
            self.module = self.batchnorm
        return super().forward(x)

    def as_keras(self, x):

        from keras.layers import BatchNormalization as Keras_Batchnorm
        batchnorm = Keras_Batchnorm(momentum=self.momentum,
                                    epsilon=self.epsilon)
        batchnorm.build(x.shape)
        if isinstance(self.module, Torch_Batchnorm):
            bn = self.module
        else:
            # for NNI compatibility
            bn = self.module.module
        # Keras weight order: gamma, beta, moving_mean, moving_variance.
        weights = bn.weight.detach().numpy()
        bias = bn.bias.detach().numpy()
        running_mean = bn.running_mean.detach().numpy()
        running_var = bn.running_var.detach().numpy()
        batchnorm.set_weights([weights, bias, running_mean, running_var])
        return batchnorm(x, training=askeras.kwds["train"])
|
| 338 |
+
|
| 339 |
+
|
| 340 |
+
class Upsample(Module):
    """Integer-factor upsampling exportable to keras UpSampling2D."""

    # keras UpSampling2D only supports these interpolations
    allowed_modes = "bilinear", "nearest"

    def __init__(self, scale_factor, mode="nearest"):
        assert mode in self.allowed_modes
        # Integer factors only, so torch and keras sizes agree exactly.
        if not isinstance(scale_factor, int):
            for sf in scale_factor:
                assert isinstance(sf, int)
        super().__init__()
        self.scale_factor = scale_factor
        self.mode = mode
        self.module = Torch_Upsample(scale_factor=scale_factor, mode=mode)

    def forward(self, x):
        # temporary code for backwards compatibility with old pickles
        # that stored the layer under `upsample`
        if not hasattr(self, 'module'):
            self.module = self.upsample
        return super().forward(x)

    def as_keras(self, x):
        from keras.layers import UpSampling2D
        return UpSampling2D(size=self.scale_factor,
                            interpolation=self.mode,
                            )(x)
|
| 364 |
+
|
| 365 |
+
|
| 366 |
+
class Sequential(Module):
    """Minimal sequential container backed by a ModuleList.

    Integer indexing returns the underlying layer; slicing returns a
    new Sequential over the selected layers.
    """

    def __init__(self, *sequence):
        super().__init__()
        self.ml = ModuleList(sequence)

    def forward(self, x):
        out = x
        for stage in self.ml:
            out = stage(out)
        return out

    def __getitem__(self, i):
        if not isinstance(i, int):
            return Sequential(*self.ml[i])
        return self.ml[i]
|
| 380 |
+
|
| 381 |
+
|
| 382 |
+
class GlobalAvgPool(Module):
    """Global average pooling over the spatial axes, keeping them as
    size-1 dims (torch AdaptiveAvgPool2d(1) / keras keepdims=True)."""

    def __init__(self):
        super().__init__()
        self.module = Torch_AdaptiveAvgPool(1)

    def as_keras(self, x):
        from keras.layers import GlobalAveragePooling2D
        # keepdims=True matches torch's (N, C, 1, 1) output shape.
        return GlobalAveragePooling2D(keepdims=True)(x)
|
| 390 |
+
|
| 391 |
+
|
| 392 |
+
class Dropout(Module):
    """Dropout exportable to keras; active on export only when the
    askeras context was entered with train=True."""

    def __init__(self, p=0, inplace=False):
        super().__init__()
        # drop probability, forwarded to both torch and keras layers
        self.p = p
        self.module = Torch_Dropout(p, inplace=inplace)

    def as_keras(self, x):
        from keras.layers import Dropout
        return Dropout(self.p)(x, training=askeras.kwds["train"])
|
| 401 |
+
|
| 402 |
+
|
| 403 |
+
class Linear(Module):
    """Fully connected layer exportable to keras Dense."""

    def __init__(self, in_c, out_c, bias=True):
        super().__init__()
        self.use_bias = bias
        self.module = Torch_Linear(in_c, out_c, bias)

    def as_keras(self, x):
        from keras.layers import Dense
        # torch weight is (out, in); keras Dense kernel is (in, out)
        out_c, in_c = self.module.weight.shape
        params = [self.module.weight.detach().numpy().transpose(1, 0)]
        if self.use_bias:
            params.append(self.module.bias.detach().numpy())
        dense = Dense(out_c, use_bias=self.use_bias)
        # NOTE(review): Keras build() usually receives the full input
        # shape including the batch dim; x.shape[1:] drops it --
        # confirm the kernel is built with the right shape.
        dense.build(x.shape[1:])
        dense.set_weights(params)
        return dense(x)
|
| 419 |
+
|
| 420 |
+
|
| 421 |
+
class Transpose(Module):
    """Permute tensor dimensions under both PyTorch and TensorFlow.

    In keras mode, `keep_channel_last` remaps a permutation written
    for torch NCHW indices through TensorFlow's NHWC layout so callers
    do not have to translate indices themselves.
    """

    def forward(self, x, perm: Union[Tuple[int], List[int]],
                keep_channel_last: bool = False):
        """Transpose `x` by `perm`.

        Parameters:
            x (Tensor): The input tensor to be transposed.
            perm (tuple or list): The permutation of dimensions to apply.
            keep_channel_last (bool, optional): In keras mode, remap a
                NCHW-indexed permutation to NHWC. Default is False.

        Returns:
            Tensor: The transposed tensor.
        """
        if askeras.use_keras:
            return self.as_keras(x, perm, keep_channel_last)
        # Use PyTorch's permute method for the operation
        return x.permute(*perm)

    def as_keras(self, x, perm: Union[Tuple[int], List[int]],
                 keep_channel_last: bool):
        """Transpose in TensorFlow, converting torch inputs first.

        Parameters:
            x (Tensor): The input tensor, possibly a PyTorch tensor.
            perm (tuple or list): The permutation of dimensions to apply.
            keep_channel_last (bool): If True, remap the permutation
                through TensorFlow's channels-last ordering.

        Returns:
            Tensor: The transposed tensor in TensorFlow format.
        """
        import tensorflow as tf

        # NCHW position i corresponds to NHWC position tf_format[i].
        # NOTE(review): this remap assumes 4D tensors -- confirm.
        tf_format = [0, 3, 1, 2]

        # Map PyTorch indices to TensorFlow indices if channel retention is enabled
        mapped_indices = [tf_format[index] for index in
                          perm] if keep_channel_last else perm

        # Convert PyTorch tensors to TensorFlow tensors if necessary
        x_tf = tf.convert_to_tensor(x.detach().numpy(),
                                    dtype=tf.float32) if isinstance(x,
                                                                    torch.Tensor) else x

        # Apply the transposition with TensorFlow's transpose method
        x_tf_transposed = tf.transpose(x_tf, perm=mapped_indices)

        return x_tf_transposed
|
| 488 |
+
|
| 489 |
+
|
| 490 |
+
class Reshape(Module):
    """Reshape a tensor to a given shape under both PyTorch and
    TensorFlow (Keras), converting torch inputs to TF tensors on the
    keras path."""

    def forward(self, x, shape: Union[Tuple[int], List[int]]):
        """Reshape `x` to `shape`.

        Parameters:
            x (Tensor): The input tensor to be reshaped.
            shape (tuple or list): The new shape; may contain a single
                `-1` to infer that dimension.

        Returns:
            Tensor: The reshaped tensor.
        """
        if askeras.use_keras:
            return self.as_keras(x, shape)
        # Use PyTorch's reshape method for the operation
        return x.reshape(*shape)

    def as_keras(self, x, shape: Union[Tuple[int], List[int]]):
        """Reshape with TensorFlow, converting torch inputs first.

        Parameters:
            x (Tensor): The input tensor, possibly a PyTorch tensor.
            shape (tuple or list): The new shape, `-1` allowed.

        Returns:
            Tensor: The reshaped tensor in TensorFlow format.
        """
        import tensorflow as tf
        # Convert PyTorch tensors to TensorFlow tensors if necessary
        x_tf = tf.convert_to_tensor(x.detach().numpy(),
                                    dtype=tf.float32) if isinstance(x,
                                                                    torch.Tensor) else x

        # Use TensorFlow's reshape function to adjust the tensor's dimensions
        x_tf_reshaped = tf.reshape(x_tf, shape)

        return x_tf_reshaped
|
| 547 |
+
|
| 548 |
+
|
| 549 |
+
class Flip(Module):
    """Flip a tensor along the given dimensions under both PyTorch and
    TensorFlow (Keras).

    Note that dimension indices are passed through unchanged, so in
    keras mode they refer to the NHWC layout, not torch's NCHW.
    """

    def forward(self, x, dims: Union[List[int], Tuple[int]]):
        """Flip `x` along `dims`.

        Parameters:
            x (Tensor): The input tensor to be flipped.
            dims (list or tuple): The dimensions along which to flip.

        Returns:
            Tensor: The flipped tensor.
        """
        # Check if Keras usage is flagged and handle accordingly
        if askeras.use_keras:
            return self.as_keras(x, dims)
        # Use PyTorch's flip function for the operation
        return torch.flip(x, dims)

    def as_keras(self, x, dims: Union[List[int], Tuple[int]]):
        """Flip with TensorFlow, converting torch inputs first.

        Parameters:
            x (Tensor): The input tensor, possibly a PyTorch tensor.
            dims (list or tuple): The dimensions along which to flip.

        Returns:
            Tensor: The flipped tensor in TensorFlow format.
        """
        import tensorflow as tf
        # Convert PyTorch tensors to TensorFlow tensors if necessary
        x_tf = tf.convert_to_tensor(x.detach().numpy(),
                                    dtype=tf.float32) if isinstance(x,
                                                                    torch.Tensor) else x

        # Use TensorFlow's reverse function for flipping along specified dimensions
        return tf.reverse(x_tf, axis=dims)
|
| 599 |
+
|
| 600 |
+
|
| 601 |
+
class Add(Module):
    """Element-wise addition that works under both torch and keras,
    converting any torch inputs to TF tensors on the keras path."""

    def forward(self, x, y):
        """Return x + y element-wise, dispatching on the export mode.

        Parameters:
            x (Tensor): The first input tensor.
            y (Tensor): The second input tensor.

        Returns:
            Tensor: The element-wise sum of `x` and `y`.
        """
        if not askeras.use_keras:
            return torch.add(x, y)
        return self.as_keras(x, y)

    def as_keras(self, x, y):
        """Add two tensors with TensorFlow, converting torch inputs.

        Parameters:
            x (Tensor): The first input, possibly a PyTorch tensor.
            y (Tensor): The second input, possibly a PyTorch tensor.

        Returns:
            Tensor: The element-wise sum in TensorFlow format.
        """
        import tensorflow as tf

        def to_tf(t):
            # Detach + numpy round-trips a torch tensor into TF land.
            if isinstance(t, torch.Tensor):
                return tf.convert_to_tensor(t.detach().numpy(),
                                            dtype=tf.float32)
            return t

        return tf.add(to_tf(x), to_tf(y))
|
| 654 |
+
|
| 655 |
+
|
| 656 |
+
class Shape(Module):
    """Return a tensor's shape in torch (channels-first) ordering,
    regardless of whether the tensor comes from torch or keras."""

    def forward(self, x):
        """Return the shape of `x`.

        In keras mode, the channels-last (NHWC/HWC) shape is converted
        to channels-first (NCHW/CHW) so downstream torch-style code can
        consume it unchanged.

        Parameters:
            x (Tensor): The tensor whose shape is requested.

        Returns:
            Tuple: The (possibly reordered) shape.
        """
        if askeras.use_keras:
            return self.as_keras(x)
        # torch tensors are already channels-first
        return x.shape

    def as_keras(self, x):
        """Translate a channels-last shape to channels-first.

        Handles 4D (NHWC -> NCHW) and 3D (HWC -> CHW) tensors; 2D
        shapes are returned unchanged.

        Parameters:
            x (Tensor): The keras-side tensor.

        Returns:
            Tuple: The channels-first shape.
        """
        if len(x.shape) == 4:
            # BUGFIX: previously unpacked as `N, W, H, C`, which swapped
            # the spatial axes in the result for non-square tensors.
            N, H, W, C = x.shape
            x_shape = (N, C, H, W)
        elif len(x.shape) == 3:
            H, W, C = x.shape
            x_shape = (C, H, W)
        else:  # 2D tensor: no channel dimension involved
            H, W = x.shape
            x_shape = (H, W)

        return x_shape
|
| 710 |
+
|
| 711 |
+
|
| 712 |
+
class GenericRNN(nn.Module):
|
| 713 |
+
"""
|
| 714 |
+
A base class for customizable RNN models supporting RNN, GRU, and LSTM networks.
|
| 715 |
+
"""
|
| 716 |
+
|
| 717 |
+
def __init__(self, input_size: int, hidden_size: int, num_layers: int = 1,
|
| 718 |
+
bidirectional: bool = False, bias: bool = True,
|
| 719 |
+
batch_first: bool = True, dropout: float = 0) -> None:
|
| 720 |
+
super(GenericRNN, self).__init__()
|
| 721 |
+
self.bidirectional = 2 if bidirectional else 1
|
| 722 |
+
self.hidden_size = hidden_size
|
| 723 |
+
self.num_layers = num_layers
|
| 724 |
+
self.bias = bias
|
| 725 |
+
self.dropout = dropout
|
| 726 |
+
self.input_size = input_size
|
| 727 |
+
self.batch_first = batch_first
|
| 728 |
+
|
| 729 |
+
def init_rnn(self, input_size: int, hidden_size: int, num_layers: int,
|
| 730 |
+
bidirectional: bool, bias: bool, batch_first: bool,
|
| 731 |
+
dropout: float) -> None:
|
| 732 |
+
|
| 733 |
+
raise NotImplementedError("Must be implemented by subclass.")
|
| 734 |
+
|
| 735 |
+
def forward(self, x, h0=None, c0=None):
|
| 736 |
+
|
| 737 |
+
raise NotImplementedError("Must be implemented by subclass.")
|
| 738 |
+
|
| 739 |
+
def as_keras(self, x):
|
| 740 |
+
raise NotImplementedError("Must be implemented by subclass.")
|
| 741 |
+
|
| 742 |
+
    def generic_as_keras(self, x, RNNBase):
        """Rebuild this recurrent stack as a Keras model and run it on ``x``.

        Converts a PyTorch input tensor to TensorFlow, stacks
        ``self.num_layers`` instances of ``RNNBase`` (each wrapped in
        ``Bidirectional`` when ``self.bidirectional == 2``) inside a Keras
        ``Sequential`` model, copies this module's PyTorch weights into it
        via ``set_keras_weights``, and applies the model to the input.

        Parameters:
            x: input tensor; a ``torch.Tensor`` is converted to a float32
                TF tensor, anything else is passed through unchanged.
            RNNBase: Keras recurrent layer class to instantiate
                (e.g. ``SimpleRNN``, ``GRU``, ``LSTM``).

        Returns:
            Tuple[Tensor, None]: the Keras model output and ``None``
            (the final hidden state is not reconstructed, unlike the
            PyTorch API).

        Raises:
            ImportError: if TensorFlow or Keras are not installed.
        """
        # Import necessary modules from Keras and TensorFlow
        from keras.layers import Bidirectional
        from keras.models import Sequential as KerasSequential
        import tensorflow as tf

        # Convert PyTorch tensor to TensorFlow tensor if necessary
        if isinstance(x, torch.Tensor):
            x_tf = tf.convert_to_tensor(x.detach().numpy(), dtype=tf.float32)
        else:
            x_tf = x

        # Create a Keras Sequential model for stacking layers
        model = KerasSequential()

        # Add RNN layers to the Keras model
        for i in range(self.num_layers):
            # Only the first layer needs an explicit input shape.
            # NOTE(review): Keras applies per-layer `dropout` to the layer
            # *inputs*, while torch applies inter-layer dropout to the
            # *outputs* of every layer but the last — confirm the two
            # placements are acceptable for this export path.
            if i == 0:
                layer = RNNBase(units=self.hidden_size, return_sequences=True,
                                input_shape=list(x.shape[1:]),
                                use_bias=self.bias,
                                dropout=self.dropout if i < self.num_layers - 1 else 0)
            else:
                layer = RNNBase(units=self.hidden_size, return_sequences=True,
                                use_bias=self.bias,
                                dropout=self.dropout if i < self.num_layers - 1 else 0)

            # Wrap the layer with Bidirectional if needed.
            # NOTE(review): assumes `self.bidirectional` stores the number
            # of directions (2) rather than a bool — confirm against the
            # constructor, which is outside this view.
            if self.bidirectional == 2:
                layer = Bidirectional(layer)

            model.add(layer)

        # Apply previously extracted PyTorch weights to the Keras model
        self.set_keras_weights(model)

        # Process the input through the Keras model
        output = model(x_tf)

        # Return (output, None) for compatibility with the PyTorch
        # (output, hidden_state) convention.
        return output, None
|
| 805 |
+
|
| 806 |
+
def extract_pytorch_rnn_weights(self):
|
| 807 |
+
"""
|
| 808 |
+
Extracts weights from a PyTorch model's RNN layers and prepares them for
|
| 809 |
+
transfer to a Keras model.
|
| 810 |
+
|
| 811 |
+
This function iterates through the named parameters of a PyTorch model,
|
| 812 |
+
detaching them from the GPU (if applicable),
|
| 813 |
+
moving them to CPU memory, and converting them to NumPy arrays.
|
| 814 |
+
It organizes these weights in a dictionary,
|
| 815 |
+
using the parameter names as keys, which facilitates their later use in
|
| 816 |
+
setting weights for a Keras model.
|
| 817 |
+
|
| 818 |
+
Returns:
|
| 819 |
+
A dictionary containing the weights of the PyTorch model, with parameter
|
| 820 |
+
names as keys and their corresponding NumPy array representations as values.
|
| 821 |
+
"""
|
| 822 |
+
|
| 823 |
+
weights = {} # Initialize a dictionary to store weights
|
| 824 |
+
|
| 825 |
+
# Iterate through the model's named parameters
|
| 826 |
+
for name, param in self.named_parameters():
|
| 827 |
+
# Process the parameter name to extract the relevant part
|
| 828 |
+
# and use it as the key in the weights dictionary
|
| 829 |
+
key = name.split('.')[
|
| 830 |
+
-1] # Extract the last part of the parameter name
|
| 831 |
+
|
| 832 |
+
# Detach the parameter from the computation graph, move it to CPU,
|
| 833 |
+
# and convert to NumPy array
|
| 834 |
+
weights[key] = param.detach().cpu().numpy()
|
| 835 |
+
|
| 836 |
+
return weights # Return the dictionary of weights
|
| 837 |
+
|
| 838 |
+
    def set_keras_weights(self, keras_model):
        """Copy this module's weights into ``keras_model``; abstract here.

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError("Must be implemented by subclass.")
|
| 840 |
+
|
| 841 |
+
    def generic_set_keras_weights(self, keras_model, RNNBase: str):
        """Copy weights from this PyTorch module into ``keras_model``.

        For each Keras recurrent layer (unwrapping ``Bidirectional``
        wrappers), gathers the matching PyTorch ``weight_ih_*`` /
        ``weight_hh_*`` matrices (transposed to Keras orientation) and the
        summed ``bias_ih_* + bias_hh_*`` vectors, then installs them with
        ``set_weights``.

        Parameters:
            keras_model: the Keras model whose layers receive the weights.
            RNNBase: one of ``'SimpleRNN'``, ``'GRU'``, ``'LSTM'``;
                selects how many parameter groups to collect (1/3/4).
        """
        # Import necessary modules
        from keras.layers import Bidirectional
        import numpy as np

        # Extract weights from PyTorch model
        pytorch_weights = self.extract_pytorch_rnn_weights()

        # Iterate over each layer in the Keras model
        for layer in keras_model.layers:
            # Check if layer is bidirectional and set layers to update
            # accordingly.
            # NOTE(review): both directions receive the same weights below;
            # the PyTorch "*_reverse" parameters are never consulted —
            # confirm this is intended for bidirectional export.
            if isinstance(layer, Bidirectional):
                layers_to_update = [layer.layer, layer.backward_layer]
            else:
                layers_to_update = [layer]

            # Update weights for each RNN layer in layers_to_update
            for rnn_layer in layers_to_update:

                num_gates = {'SimpleRNN': 1, 'GRU': 3, 'LSTM': 4}.get(
                    RNNBase, 0)

                # Initialize lists for input-hidden, hidden-hidden weights,
                # and biases
                ih_weights, hh_weights, biases = [], [], []

                # Process weights and biases for each gate.
                # NOTE(review): the suffix built here is '_l{i}', which in
                # PyTorch naming is a *layer* index (each weight_ih_l{i}
                # already contains all gates stacked), yet the loop bound is
                # the gate count — for multi-layer models this vstacks
                # weights from up to num_gates layers into one Keras layer.
                # Verify against a known-good single-layer export.
                for i in range(num_gates):
                    gate_suffix = f'_l{i}'
                    for prefix in ('weight_ih', 'weight_hh'):
                        key = f'{prefix}{gate_suffix}'
                        if key in pytorch_weights:
                            weights = \
                                pytorch_weights[key].T  # Transpose to match Keras shape

                            (ih_weights if prefix == 'weight_ih' else hh_weights) \
                                .append(weights)

                    bias_keys = (
                        f'bias_ih{gate_suffix}', f'bias_hh{gate_suffix}')
                    if all(key in pytorch_weights for key in bias_keys):
                        # Keras keeps a single bias vector, so sum the
                        # input-hidden and hidden-hidden biases.
                        biases.append(
                            sum(pytorch_weights[key] for key in bias_keys))

                # Combine weights and biases into a format suitable for Keras
                keras_weights = [np.vstack(ih_weights),
                                 np.vstack(hh_weights), np.hstack(biases)]

                # Set the weights for the Keras layer; bidirectional
                # sub-layers expose their trainables through .cell.
                if not isinstance(layer, Bidirectional):
                    rnn_layer.set_weights(keras_weights)
                else:
                    rnn_layer.cell.set_weights(keras_weights)
|
| 907 |
+
|
| 908 |
+
|
| 909 |
+
class RNN(GenericRNN):
    """Vanilla recurrent layer backed by ``torch.nn.RNN``, exportable to
    ``keras.layers.SimpleRNN`` through the GenericRNN machinery."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        config = {name: kwargs[name]
                  for name in ('input_size', 'hidden_size', 'num_layers',
                               'bias', 'batch_first', 'dropout',
                               'bidirectional')}
        self.rnn = nn.RNN(**config)

    def forward(self, x, h0=None):
        """Run the RNN on ``x`` (Keras path when exporting).

        Returns:
            (output, hidden_state) from ``torch.nn.RNN``, or
            (keras_output, None) in export mode.
        """
        if askeras.use_keras:
            return self.as_keras(x)
        return self.rnn(x, h0)

    def as_keras(self, x):
        """Rebuild this stack as ``keras.layers.SimpleRNN`` layers and
        apply it to ``x``.

        Returns:
            Tuple[Tensor, None]: Keras output and ``None`` (no hidden
            state is reconstructed).

        Raises:
            ImportError: if TensorFlow/Keras are unavailable.
        """
        from keras.layers import SimpleRNN
        output, _ = super().generic_as_keras(x, SimpleRNN)
        return output, None

    def set_keras_weights(self, keras_model):
        """Transfer this module's PyTorch weights into ``keras_model``
        using the SimpleRNN (single parameter group) layout."""
        self.generic_set_keras_weights(keras_model, 'SimpleRNN')
|
| 968 |
+
|
| 969 |
+
|
| 970 |
+
class GRU(GenericRNN):
    """Gated Recurrent Unit layer backed by ``torch.nn.GRU``, exportable
    to ``keras.layers.GRU`` through the GenericRNN machinery."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        config = {name: kwargs[name]
                  for name in ('input_size', 'hidden_size', 'num_layers',
                               'bias', 'batch_first', 'dropout',
                               'bidirectional')}
        self.rnn = nn.GRU(**config)

    def forward(self, x, h0=None):
        """Run the GRU on ``x`` (Keras path when exporting).

        Returns:
            (output, hidden_state) from ``torch.nn.GRU``, or
            (keras_output, None) in export mode.
        """
        if askeras.use_keras:
            return self.as_keras(x)
        return self.rnn(x, h0)

    def as_keras(self, x):
        """Rebuild this stack as ``keras.layers.GRU`` layers and apply it
        to ``x``.

        Returns:
            Tuple[Tensor, None]: Keras output and ``None`` (no hidden
            state is reconstructed).

        Raises:
            ImportError: if TensorFlow/Keras are unavailable.
        """
        from keras.layers import GRU as KerasGRU
        output, _ = super().generic_as_keras(x, KerasGRU)
        return output, None

    def set_keras_weights(self, keras_model):
        """Transfer this module's PyTorch weights into ``keras_model``
        using the GRU (three parameter group) layout."""
        self.generic_set_keras_weights(keras_model, 'GRU')
|
| 1029 |
+
|
| 1030 |
+
|
| 1031 |
+
class LSTM(GenericRNN):
    """Long Short-Term Memory layer backed by ``torch.nn.LSTM``,
    exportable to ``keras.layers.LSTM`` through the GenericRNN
    machinery."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # BUG FIX: this previously instantiated nn.GRU, which has no cell
        # state and whose 3-gate weight layout contradicts the 'LSTM'
        # (4-gate) layout assumed by set_keras_weights below.
        self.rnn = nn.LSTM(input_size=kwargs['input_size'],
                           hidden_size=kwargs['hidden_size'],
                           num_layers=kwargs['num_layers'],
                           bias=kwargs['bias'],
                           batch_first=kwargs['batch_first'],
                           dropout=kwargs['dropout'],
                           bidirectional=kwargs['bidirectional'])

    def forward(self, x, h0=None, c0=None):
        """Run the LSTM on ``x`` (Keras path when exporting).

        Parameters:
            x: input sequence tensor.
            h0: optional initial hidden state.
            c0: optional initial cell state.

        Returns:
            (output, (h_n, c_n)) from ``torch.nn.LSTM``, or
            (keras_output, None) in export mode.
        """
        if askeras.use_keras:
            return self.as_keras(x)

        # BUG FIX: nn.LSTM expects hx=None (zero-initialized states) or a
        # tuple of tensors; the old code passed (None, None) when the
        # defaults were used, which crashes inside torch.
        hx = None if h0 is None or c0 is None else (h0, c0)
        out, h = self.rnn(x, hx)
        return out, h

    def as_keras(self, x):
        """Rebuild this stack as ``keras.layers.LSTM`` layers and apply it
        to ``x``.

        Returns:
            Tuple[Tensor, None]: Keras output and ``None`` (no hidden
            state is reconstructed).

        Raises:
            ImportError: if TensorFlow/Keras are unavailable.
        """
        from keras.layers import LSTM as KerasLSTM
        output, _ = super().generic_as_keras(x, KerasLSTM)
        return output, None

    def set_keras_weights(self, keras_model):
        """Transfer this module's PyTorch weights into ``keras_model``
        using the LSTM (four parameter group) layout."""
        self.generic_set_keras_weights(keras_model, 'LSTM')
|
| 1091 |
+
|
| 1092 |
+
|
| 1093 |
+
class ChannelSlice(Module):
    """Select a subset of channels from the input tensor.

    Slices axis 1 (channels-first, torch layout) in eager mode and the
    last axis (channels-last) in Keras-export mode.
    """

    def __init__(self, slice):
        # NOTE: the parameter name shadows the builtin ``slice`` in this
        # scope only; kept unchanged for caller compatibility.
        super().__init__()
        self.slice = slice

    def forward(self, x):
        """Return the configured channel slice of ``x``."""
        return self.as_keras(x) if askeras.use_keras else x[:, self.slice]

    def as_keras(self, x):
        """Slice the last (channels-last) axis of a Keras tensor."""
        leading = (slice(None),) * (len(x.shape) - 1)
        return x[leading + (self.slice,)]
|
synet/data_subset.py
ADDED
|
@@ -0,0 +1,54 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from os.path import splitext
|
| 2 |
+
def get_label(im):
    """Return the YOLO label-file (.txt) path for image path *im*."""
    base, _ = splitext(get_labels(im))
    return base + '.txt'


def get_labels(ims):
    """Map an images path to its sibling labels path.

    Replaces the last '/images/' component of *ims* with '/labels/';
    paths without an '/images/' component are returned unchanged.
    """
    head, sep, tail = ims.rpartition("/images/")
    return head + "/labels/" + tail if sep else ims
|
| 7 |
+
|
| 8 |
+
from argparse import ArgumentParser
|
| 9 |
+
def parse_opt(argv=None):
    """Parse command-line options for the dataset-subset tool.

    Parameters:
        argv: optional list of argument strings; ``None`` (the default)
            parses ``sys.argv[1:]``, preserving the original CLI
            behaviour while making the parser unit-testable.

    Returns:
        argparse.Namespace with ``max_bg_ratio``, ``old_yaml`` and
        ``new_yaml`` attributes.
    """
    parser = ArgumentParser()
    parser.add_argument("--max-bg-ratio", type=float, default=.999,
                        help="maximum fraction of kept images that may be "
                             "background-only (no matching labels)")
    parser.add_argument('old_yaml', help="source dataset yaml")
    parser.add_argument('new_yaml', help="destination dataset yaml")
    return parser.parse_args(argv)
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
from os import listdir, makedirs, symlink
|
| 18 |
+
from os.path import join, abspath, isfile
|
| 19 |
+
from random import shuffle
|
| 20 |
+
from yaml import safe_load as load
|
| 21 |
+
def run(old_yaml, new_yaml, max_bg_ratio):
    """Build a class-remapped subset dataset from an existing YOLO dataset.

    Reads the old and new dataset YAMLs, maps old class indices to new
    ones by matching class *names*, and for each split symlinks every
    image whose labels survive the remap, rewriting its label file.
    Images left with no labels are treated as background and only a
    capped fraction of them (``max_bg_ratio`` of the total kept images)
    is symlinked, chosen at random.

    Parameters:
        old_yaml: path to the source dataset YAML.
        new_yaml: path to the destination dataset YAML (defines the new
            class names and output directories).
        max_bg_ratio: maximum fraction of kept images that may be
            background-only.  NOTE(review): a value of exactly 1.0
            raises ZeroDivisionError below — confirm callers never pass 1.
    """
    old = load(open(old_yaml))
    new = load(open(new_yaml))
    # Invert the new names mapping: class name -> new index.
    l_n = {l:n for n, l in new['names'].items()}
    # old index (as str) -> new index (as str), only for shared names.
    old_cls_new = {str(o): str(l_n[l]) for o, l in old['names'].items()
                   if l in l_n}
    splits = ['val', 'train']
    if 'test' in new:
        splits.append('test')
    for split in splits:
        fg = 0           # count of foreground (labelled) images kept
        background = []  # (old, new) path pairs with no surviving labels
        for d in new, old:
            # Resolve the split path relative to the optional 'path' root.
            d[split] = join(d.get('path', ''), d[split])
        # NOTE(review): makedirs without exist_ok — reruns over an
        # existing output tree will raise FileExistsError; presumably
        # intentional to avoid mixing datasets.
        makedirs(new[split])
        makedirs(get_labels(new[split]))
        for imf in listdir(old[split]):
            oldim = join(old[split], imf)
            newim = join(new[split], imf)
            # Remap each label line's class id; drop lines whose class
            # has no counterpart in the new dataset.  The walrus targets
            # (parts, oldlb) bind in evaluation order: the trailing `if`
            # runs first, so oldlb is defined before the comprehension.
            labels = [" ".join([old_cls_new[parts[0]], parts[1]])
                      for label in open(oldlb).readlines()
                      if (parts := label.split(" ", 1))[0] in old_cls_new
                      ] if isfile(oldlb := get_label(oldim)) else []
            if not labels:
                background.append((oldim, newim))
            else:
                fg += 1
                symlink(abspath(oldim), newim)
                open(get_label(newim), 'w').writelines(labels)

        # Keep at most max_bg_ratio of the final image count as
        # background: bg <= ratio * (fg + bg)  =>  bg <= ratio*fg/(1-ratio).
        shuffle(background)
        background = background[:int(max_bg_ratio * fg / (1 - max_bg_ratio))]
        for oldim, newim in background:
            symlink(abspath(oldim), newim)
|
synet/demosaic.py
ADDED
|
@@ -0,0 +1,290 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from torch import zeros, tensor, arange, where, float32, empty
|
| 2 |
+
from numpy import empty as npempty
|
| 3 |
+
|
| 4 |
+
from .base import Module, Conv2d, askeras
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
class Mosaic(Module):
    """Re-mosaic an RGB image into a single-channel Bayer image.

    Each position of the 2x2 Bayer cell samples the colour channel named
    by ``bayer_pattern`` (e.g. "rggb") from the RGB input.
    """

    def __init__(self, bayer_pattern, real_keras=False):
        super().__init__()
        pattern = bayer_pattern.lower()
        # Channel index (0=r, 1=g, 2=b) for each of the four Bayer cells,
        # paired with the (row, col) offsets below.
        self.bayer_pattern = tensor(['rgb'.index(c) for c in pattern])
        self.rows = tensor([0, 0, 1, 1])
        self.cols = tensor([0, 1, 0, 1])
        # Whether the Keras export emits real ops or acts as identity.
        self.real_keras = real_keras

    def forward(self, x):
        """Mosaic a (..., 3, H, W) tensor into (..., 1, H, W)."""
        if askeras.use_keras:
            # Export path: either emit real Keras ops or pass through.
            return self.as_keras(x) if self.real_keras else x
        *batch, _, height, width = x.shape
        out = empty((*batch, 1, height, width), dtype=x.dtype,
                    device=x.device)
        for row, col, chan in zip(self.rows, self.cols, self.bayer_pattern):
            out[..., 0, row::2, col::2] = x[..., chan, row::2, col::2]
        return out

    def clf(self, x):
        """Channels-last NumPy variant of :meth:`forward`:
        (..., H, W, 3) -> (..., H, W, 1)."""
        out = npempty((*x.shape[:-1], 1), dtype=x.dtype)
        for row, col, chan in zip(self.rows, self.cols, self.bayer_pattern):
            out[..., row::2, col::2, 0] = x[..., row::2, col::2, chan]
        return out

    def as_keras(self, x):
        """Keras implementation: subsample the four Bayer phases and
        re-interleave them into one (H, W, 1) channel."""
        from keras.layers import Concatenate, Reshape
        _, H, W, _ = x.shape
        phases = [x[..., int(row)::2, int(col)::2, int(chan):int(chan) + 1]
                  for row, col, chan in
                  zip(self.rows, self.cols, self.bayer_pattern)]
        top = Concatenate(-1)((phases[0], phases[1]))
        bottom = Concatenate(-1)((phases[2], phases[3]))
        return Reshape((H, W, 1))(Concatenate(-2)((top, bottom)))
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
class MosaicGamma(Mosaic):
    # Mosaic variant that additionally applies a per-channel gamma curve
    # on the Keras export path.

    def __init__(self, *args, normalized=True, gammas=[], **kwds):
        # gammas: per-channel exponents indexed by rgb channel.
        # NOTE(review): mutable default argument — benign here since the
        # list is never mutated, but callers should pass their own list.
        super().__init__(*args, **kwds)
        self.gammas = gammas
        if normalized:
            # Input already in [0, 1]: apply the exponent directly.
            self.gamma_func = self.normalized_gamma
        else:
            # Input in [0, 255]: normalize, apply gamma, rescale.
            self.gamma_func = self.unnormalized_gamma

    def normalized_gamma(self, x, gamma):
        # Gamma for inputs in [0, 1].
        return x**gamma

    def unnormalized_gamma(self, x, gamma):
        # Gamma for inputs in [0, 255].
        return ((x / 255)**gamma) * 255

    def as_keras(self, x):
        # Subsample the four Bayer phases, gamma-correct each with its
        # channel's exponent, and interleave into the half-height layout
        # (see Mosaic.as_keras; note: no final Reshape here — that is
        # added by UnfoldedMosaicGamma).
        from keras.layers import Concatenate
        a, b, c, d = [self.gamma_func(x[..., int(yoff)::2, int(xoff)::2,
                                        int(chan):int(chan) + 1],
                                      self.gammas[chan])
                      for yoff, xoff, chan in
                      zip(self.rows, self.cols, self.bayer_pattern)]
        return Concatenate(-2)((
            Concatenate(-1)((a, b)),
            Concatenate(-1)((c, d))))
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
class UnfoldedMosaicGamma(MosaicGamma):
    """MosaicGamma whose Keras output is folded back into a single
    (H, W, 1) Bayer plane."""

    def as_keras(self, x):
        """Apply the parent gamma/mosaic transform, then reshape the
        interleaved result to the full-resolution single channel."""
        from keras.layers import Reshape
        _, height, width, _ = x.shape
        interleaved = super().as_keras(x)
        return Reshape((height, width, 1))(interleaved)
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
class Demosaic(Module):
    """Demosaic a Bayer image with a single strided convolution.

    Builds a Conv2d(1 -> 12, stride 2) whose 12 output channels encode a
    2x2 block of RGB pixels, from classic demosaicing kernels selected by
    ``dfilter`` ('simple' or 'malvar').
    """

    def __init__(self, dfilter, bayer_pattern, *scales):
        # dfilter: name of the kernel-initialisation method to use
        #          ('simple' or 'malvar' -> simple_init/malvar_init).
        # bayer_pattern: one of "rggb", "bggr", "grbg", "gbrg".
        # *scales: optional per-output-channel multipliers (e.g. a YUV
        #          diagonal), applied after filtering.
        super().__init__()
        assert bayer_pattern.lower() in ("rggb", "bggr", "grbg", "gbrg")
        bayer_pattern = tensor(['rgb'.index(c) for c in bayer_pattern.lower()])
        rows = tensor([0, 0, 1, 1])
        cols = tensor([0, 1, 0, 1])
        # assign kernels from specific filter method
        getattr(self, dfilter+'_init')()
        # The basic idea is to apply kxk kernels to two consecutive
        # rows/columns simultaneously, so we will need a (k+1)x(k+1)
        # kernel to give it the proper receptive field. For a given
        # row or column in the 2x2 bayer grid, we need to either slice
        # the kxk kernel into the first or last k rows/columns of the
        # (k+1)x(k+1) generated kernel.
        kslice = slice(None, -1), slice(1, None)
        weight = zeros(4, 3, self.k+1, self.k+1, requires_grad=False)

        # Set values for which the bayer image IS ground truth.
        # +self.k//2 because 2x2 bayer is centered in the (k+1)x(k+1)
        # kernel.
        weight[arange(4), bayer_pattern, rows+self.k//2, cols+self.k//2] = 1

        # Finishing off red bayer locations
        r = bayer_pattern == 'rgb'.index('r')
        slicey, slicex = kslice[rows[r]], kslice[cols[r]]
        weight[r, 'rgb'.index('g'), slicey, slicex] = self.GatR
        weight[r, 'rgb'.index('b'), slicey, slicex] = self.BatR

        # Finishing off blue bayer locations
        b = bayer_pattern == 'rgb'.index('b')
        slicey, slicex = kslice[rows[b]], kslice[cols[b]]
        weight[b, 'rgb'.index('g'), slicey, slicex] = self.GatB
        weight[b, 'rgb'.index('r'), slicey, slicex] = self.RatB

        # greens get a bit more interesting because there are two
        # types: one in red rows, and one in blue rows.
        g, = where(bayer_pattern == 'rgb'.index('g'))
        # read "gbr" as green pixel in blue row, red column
        if any(b[:2]):  # if b is in the first row.
            gbr, grb = g
        else:
            grb, gbr = g
        slicey, slicex = kslice[rows[grb]], kslice[cols[grb]]
        weight[grb, 'rgb'.index('r'), slicey, slicex] = self.RatGRB
        weight[grb, 'rgb'.index('b'), slicey, slicex] = self.BatGRB
        slicey, slicex = kslice[rows[gbr]], kslice[cols[gbr]]
        weight[gbr, 'rgb'.index('r'), slicey, slicex] = self.RatGBR
        weight[gbr, 'rgb'.index('b'), slicey, slicex] = self.BatGBR

        # apply YUV to RGB transform if necessary. This is equivalent
        # to scaling values AFTER applying filter.
        for i, scale in enumerate(scales):
            weight[:, i] *= float(scale)

        # create the convolution: 12 outputs = 2x2 block of 3 channels,
        # stride 2 so each application produces one full RGB 2x2 block.
        self.module = Conv2d(1, 12, (self.k+1, self.k+1), 2)
        self.module.weight.data[:] = weight.reshape(12, 1, self.k+1, self.k+1)

    def simple_init(self):
        """Bilinear-style 3x3 demosaicing kernels."""
        # generated by reading a 'demosaic.cpp' sent to me
        self.GatR = tensor([[0, 1, 0],
                            [1, 0, 1],
                            [0, 1, 0]]
                           ) / 4
        # read "GRB" as green bayer location in red row, blue column.
        self.RatGRB = tensor([[0, 0, 0],
                              [1, 0, 1],
                              [0, 0, 0]]
                             ) / 2
        self.RatB = tensor([[1, 0, 1],
                            [0, 0, 0],
                            [1, 0, 1]],
                           ) / 4
        self.k = 3
        self.basic_init()

    def malvar_init(self):
        """Malvar et al. 5x5 gradient-corrected demosaicing kernels."""
        # kernels taken from https://www.microsoft.com/en-us/research/wp-content/uploads/2016/02/Demosaicing_ICASSP04.pdf
        self.GatR = tensor([[ 0 , 0 ,-1 , 0 , 0 ],
                            [ 0 , 0 , 2 , 0 , 0 ],
                            [-1 , 2 , 4 , 2 ,-1 ],
                            [ 0 , 0 , 2 , 0 , 0 ],
                            [ 0 , 0 ,-1 , 0 , 0 ]],
                           dtype=float32) / 8
        # read "GRB" as green bayer location in red row, blue column.
        self.RatGRB = tensor([[ 0 , 0 , 0.5, 0 , 0 ],
                              [ 0 ,-1 , 0 ,-1 , 0 ],
                              [-1 , 4 , 5 , 4 ,-1 ],
                              [ 0 ,-1 , 0 ,-1 , 0 ],
                              [ 0 , 0 , 0.5, 0 , 0 ]],
                             dtype=float32) / 8
        self.RatB = tensor([[ 0 , 0 ,-1.5, 0 , 0 ],
                            [ 0 , 2 , 0 , 2 , 0 ],
                            [-1.5, 0 , 6 , 0 ,-1.5],
                            [ 0 , 2 , 0 , 2 , 0 ],
                            [ 0 , 0 ,-1.5, 0 , 0 ]],
                           dtype=float32) / 8
        self.k = 5
        self.basic_init()

    def basic_init(self):
        """Derive the remaining kernels from the symmetric base set."""
        self.GatB = self.GatR
        # read "GRB" as green bayer location in red row, blue column.
        self.BatGBR = self.RatGRB
        self.RatGBR = self.RatGRB.T
        self.BatGRB = self.RatGBR
        self.BatR = self.RatB
|
| 190 |
+
|
| 191 |
+
def reshape_conv(old_conv):
    """Fold a stride-2 RGB conv into a conv over UnfoldedDemosaic output.

    Rewrites a Conv2d(3 -> out, kernel 3x3/3x4/4x3/4x4, stride 2, no
    bias) into an equivalent Conv2d(12 -> out, kernel 2x2, stride 1) that
    consumes the 12-channel (2x2 block x RGB) demosaic representation
    directly, avoiding the intermediate full-resolution RGB image.

    Parameters:
        old_conv: the original convolution; must satisfy the assert below.

    Returns:
        A new Conv2d with the remapped weights.
    """
    assert (all(k in (3, 4) for k in old_conv.kernel_size)
            and old_conv.stride == 2
            and old_conv.in_channels == 3
            and not old_conv.use_bias)
    # first 2x2 is demosaic output, second 2x2 is this new conv's
    # input.
    weight = zeros(old_conv.out_channels, 2, 2, 3, 2, 2)
    old_weight = old_conv.weight.data
    # This is the best image I can use to try and describe (for 3x3
    # starting kernel):
    #
    #       0       1
    #   l       l       l
    # +---+---+---+ - + -
    # 0 |rgb rgblrgb|rgb
    # +   +   +   +   + 0
    # 1 |rgb rgblrgb|rgb
    # + - + - + - + - + -
    # 2 |rgb rgblrgb|rgb
    # +---+---+---+   + 1
    #   lrgb rgblrgb rgb
    # + - + - + - + - + -
    #   l       l       l
    #       0       1
    #
    # The left/top coordinates and ('|', '---') are in terms of the
    # original kernel, and the right/bottom coordinates and ('l', ' -
    # ') are in terms of the new input coordinates. I use the
    # coordinates above in the later comments. The 3x3 box is the
    # orignal conv kernel. Each of the 4 2x2 blocks above have been
    # transformed into one pixel of the demosaic output.

    # (0, 0) in right/bottom coordinates
    weight[:, 0, 0, :, 0, 0] = old_weight[:, :, 0, 0]
    weight[:, 0, 1, :, 0, 0] = old_weight[:, :, 0, 1]
    weight[:, 1, 0, :, 0, 0] = old_weight[:, :, 1, 0]
    weight[:, 1, 1, :, 0, 0] = old_weight[:, :, 1, 1]
    # (0, 1) in right/bottom coordinates
    weight[:, 0, 0, :, 0, 1] = old_weight[:, :, 0, 2]
    weight[:, 1, 0, :, 0, 1] = old_weight[:, :, 1, 2]
    if old_conv.kernel_size[1] == 4:
        # Width-4 kernels contribute their last column too.
        weight[:, 0, 1, :, 0, 1] = old_weight[:, :, 0, 3]
        weight[:, 1, 1, :, 0, 1] = old_weight[:, :, 1, 3]
    # (1, 0) in right/bottom coordinates
    weight[:, 0, 0, :, 1, 0] = old_weight[:, :, 2, 0]
    weight[:, 0, 1, :, 1, 0] = old_weight[:, :, 2, 1]
    if old_conv.kernel_size[0] == 4:
        # Height-4 kernels contribute their last row too.
        weight[:, 1, 0, :, 1, 0] = old_weight[:, :, 3, 0]
        weight[:, 1, 1, :, 1, 0] = old_weight[:, :, 3, 1]
    # (1, 1) in right/bottom coordinates
    weight[:, 0, 0, :, 1, 1] = old_weight[:, :, 2, 2]
    if old_conv.kernel_size[1] == 4:
        weight[:, 0, 1, :, 1, 1] = old_weight[:, :, 2, 3]
    if old_conv.kernel_size[0] == 4:
        weight[:, 1, 0, :, 1, 1] = old_weight[:, :, 3, 2]
    if all(k == 4 for k in old_conv.kernel_size):
        weight[:, 1, 1, :, 1, 1] = old_weight[:, :, 3, 3]

    conv = Conv2d(12, old_conv.out_channels, 2, 1,
                  bias=old_conv.use_bias,
                  padding=old_conv.padding == "same",
                  groups=old_conv.groups)
    conv.weight.data[:] = weight.reshape(old_conv.out_channels,
                                         12, 2, 2)
    return conv
|
| 257 |
+
|
| 258 |
+
|
| 259 |
+
class UnfoldedDemosaic(Demosaic):
    """Demosaic variant that folds the 12-channel conv output back into a
    full-resolution RGB image (torch path) or keeps/reshapes it for the
    Keras export path."""

    def forward(self, x):
        """Demosaic (..., 1, H, W) Bayer input to (..., 3, 2H, 2W) RGB."""
        x = self.module(x)
        if askeras.use_keras:
            return self.as_keras(x)
        *B, C, H, W = x.shape
        assert C == 12
        # 12 channels = (2 rows, 2 cols, 3 rgb); interleave the 2x2 block
        # back into the spatial dimensions (pixel-shuffle by hand).
        permute = 2, 3, 0, 4, 1
        permute = tuple(range(len(B))) + tuple(v + len(B) for v in permute)
        return x.reshape(*B, 2, 2, 3, H, W
                         ).permute(permute
                                   ).reshape(*B, 3, 2 * H, 2 * W)

    def clf(self, x):
        """Channels-last NumPy variant: (..., H, W, 1) Bayer input to
        (..., H, W, 3) RGB, via the torch conv."""
        *B, H, W, C = x.shape
        # To channels-first for the conv.
        permute = 2, 0, 1
        permute = tuple(range(len(B))) + tuple(v + len(B) for v in permute)
        x = self.module(tensor(x, dtype=float32).permute(permute))
        # Back to channels-last while unfolding the 2x2 block.
        permute = 3, 0, 4, 1, 2
        permute = tuple(range(len(B))) + tuple(v + len(B) for v in permute)
        return x.reshape(*B, 2, 2, 3, H // 2, W // 2
                         ).permute(permute
                                   ).reshape(*B, H, W, 3).detach().numpy()

    def as_keras(self, x):
        """Keras pixel-shuffle: reshape the 12-channel (H, W) map into a
        (2H, 2W, 3) RGB image (Permute axes are 1-based in Keras)."""
        from keras.layers import Reshape, Permute
        *_, H, W, _ = x.shape
        return Reshape((H * 2, W * 2, 3))(
            Permute((1, 3, 2, 4, 5))(
                Reshape((H, W, 2, 2, 3))(x)
            )
        )
|
synet/katana.py
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""katana.py includes imports which are compatible with Katana, and layer definitions that are only compatible with Katana. However, Katana's capabilities are currently a subset of all other chip's capabilities, so it includes only imports for now."""
|
| 2 |
+
|
| 3 |
+
from .layers import (Conv2dInvertedResidual, Head, SWSBiRNN, SRNN)
|
| 4 |
+
from .base import askeras, Conv2d, Cat, ReLU, BatchNorm, Grayscale
|
synet/layers.py
ADDED
|
@@ -0,0 +1,439 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""layers.py is the high level model building layer of synet. It
|
| 2 |
+
defines useful composite layers which are compatible with multiple
|
| 3 |
+
chips. Because it is built with layers from base.py, exports come
|
| 4 |
+
"free". As a rule of thumb to differentiate between base.py,
|
| 5 |
+
layers.py:
|
| 6 |
+
|
| 7 |
+
- base.py should only import from torch, keras, and tensorflow.
|
| 8 |
+
- layers.py should only import from base.py.
|
| 9 |
+
|
| 10 |
+
If you sublcass from something in base.py OTHER than Module, you
|
| 11 |
+
should add a test case for it in tests/test_keras.py.
|
| 12 |
+
|
| 13 |
+
"""
|
| 14 |
+
from typing import Union, Tuple, Optional
|
| 15 |
+
|
| 16 |
+
from .base import (ReLU, BatchNorm, Conv2d, Module, Cat, Sequential,
|
| 17 |
+
RNN, GRU, LSTM, Transpose, Reshape, Flip, Add,
|
| 18 |
+
Shape, ModuleList, ChannelSlice, DepthwiseConv2d)
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
# Because this module only reinterprets Conv2d parameters, the keras
# test case is omitted.
class DepthwiseConv2d(DepthwiseConv2d):
    """Depthwise convolution: one filter per input channel.

    Thin convenience wrapper over the base DepthwiseConv2d that exposes
    a single ``channels`` argument and fixes ``groups == channels``.
    """

    def __init__(self,
                 channels: int,
                 kernel_size: Union[int, Tuple[int, int]],
                 stride: int = 1,
                 bias: bool = False,
                 padding: Optional[bool] = True):
        # A depthwise conv is a grouped conv with in == out == groups.
        super().__init__(channels, channels, kernel_size, stride,
                         bias, padding, groups=channels)
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
class InvertedResidual(Module):
    """
    Conv2d -> activation -> linear pointwise block with an additive
    residual. Inspired by the Inverted Residual blocks at the core of
    MobileNet: stable, low peak memory before and after, and extremely
    efficient to compute on our chips.
    """

    def __init__(self, in_channels, expansion_factor,
                 out_channels=None, stride=1, kernel_size=3,
                 skip=True):
        """Expand in_channels to int(in_channels * expansion_factor)
        with a kernel_size convolution, then — after BatchNorm and
        ReLU(6) — project back down to out_channels (default:
        in_channels) with a 1x1 convolution. When the channel counts
        match, skip is enabled and stride == 1, the block's input is
        added to its output."""
        super().__init__()
        out_channels = in_channels if out_channels is None else out_channels
        hidden = int(in_channels * expansion_factor)
        expand = Conv2d(in_channels,
                        out_channels=hidden,
                        kernel_size=kernel_size,
                        stride=stride)
        project = Conv2d(in_channels=hidden,
                         out_channels=out_channels,
                         kernel_size=1)
        self.layers = Sequential(expand,
                                 BatchNorm(hidden),
                                 ReLU(6),
                                 project,
                                 BatchNorm(out_channels))
        self.stride = stride
        # The residual add is only valid when the shape is preserved.
        self.cheq = in_channels == out_channels and skip

    def forward(self, x):
        y = self.layers(x)
        if self.stride != 1 or not self.cheq:
            return y
        return x + y
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
# Backwards-compatible alias: older model configs refer to the block by
# its original name.
Conv2dInvertedResidual = InvertedResidual
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
class Head(Module):
    """A short stack of 3x3 Conv2d + ReLU(6) stages used as a task head."""

    def __init__(self, in_channels, out_channels, num=4):
        """Create a sequence of convolutions with ReLU(6) activations.

        The first convolution converts in_channels features to
        out_channels only on the final stage; all earlier stages keep
        in_channels going in and out. num (default 4) convolutions are
        used in total.
        """
        super().__init__()
        self.relu = ReLU(6)
        # Every stage but the last preserves the width; the last stage
        # projects to out_channels.
        widths = [in_channels] * (num - 1) + [out_channels]
        stages = [Sequential(Conv2d(in_channels, width, 3, bias=True),
                             self.relu)
                  for width in widths]
        self.model = Sequential(*stages)

    def forward(self, x):
        return self.model(x)
|
| 99 |
+
|
| 100 |
+
|
| 101 |
+
class CoBNRLU(Module):
    """Conv2d -> BatchNorm -> clipped ReLU: the standard conv block."""

    def __init__(self, in_channels, out_channels, kernel_size=3,
                 stride=1, bias=False, padding=True, groups=1,
                 max_val=6, name=None):
        # max_val bounds the ReLU; name is forwarded to the ReLU layer
        # (useful for locating it in an exported graph).
        super().__init__()
        conv = Conv2d(in_channels, out_channels, kernel_size, stride,
                      bias, padding, groups)
        self.module = Sequential(conv,
                                 BatchNorm(out_channels),
                                 ReLU(max_val, name=name))

    def forward(self, x):
        return self.module(x)
|
| 114 |
+
|
| 115 |
+
|
| 116 |
+
class GenericSRNN(Module):
    """
    Implements GenericSRNN (Generic Separable RNN), which processes an
    input feature map sequentially along its X-axis and Y-axis using two
    RNNs: the X-axis RNN runs first over each row, then its re-arranged
    output is fed to the Y-axis RNN over each column.

    Args:
    - hidden_size_x (int): The number of features in the hidden state of
      the X-axis RNN.
    - hidden_size_y (int): The number of features in the hidden state of
      the Y-axis RNN, which also determines the output channel count.

    forward() returns only the Y-axis RNN output (subclasses supply the
    actual RNN modules).
    """

    def __init__(self, hidden_size_x: int, hidden_size_y: int) -> None:
        super(GenericSRNN, self).__init__()
        # Channel counts produced by each sweep; used by forward() when
        # reshaping between the two RNNs.
        self.output_size_x = hidden_size_x
        self.output_size_y = hidden_size_y

        # Layout helpers from base.py, so the same graph ops appear in
        # the keras export.
        self.transpose = Transpose()
        self.reshape = Reshape()
        self.get_shape = Shape()

    def forward(self, x, rnn_x: Module, rnn_y: Module):
        """
        Performs the forward pass over both axes.

        Args:
        - x (tensor): Input tensor with shape
          [batch_size, channels, height, width]
        - rnn_x (Module): RNN applied along the width of each row.
        - rnn_y (Module): RNN applied along the height of each column.

        Returns:
        - output (tensor): The final output tensor from the Y-axis RNN,
          shaped [batch_size, output_size_y, height, width].
        """
        batch_size, channels, H, W = self.get_shape(x)

        # Rearrange the tensor to (batch_size*H, W, channels) for
        # RNN processing: this aligns the data along the width, treating
        # each row as an independent sequence.
        # keep_channel_last=True signals the channels-last backend that
        # the transpose dims below are written for channels-first input,
        # so the channels must stay in the last dim.
        x_w = self.transpose(x, (0, 2, 3, 1),
                             keep_channel_last=True)  # (batch_size, H, W, channels)
        x_w = self.reshape(x_w, (batch_size * H, W, channels))

        output_x, _ = rnn_x(x_w)

        # Prepare the X-axis output for Y-axis processing by rearranging
        # it to (batch_size*W, H, output_size_x), so each column becomes
        # an independent sequence.
        output_x_reshape = self.reshape(output_x,
                                        (batch_size, H, W, self.output_size_x))
        output_x_permute = self.transpose(output_x_reshape, (0, 2, 1, 3))
        output_x_permute = self.reshape(output_x_permute,
                                        (batch_size * W, H, self.output_size_x))

        output, _ = rnn_y(output_x_permute)

        # Reshape and rearrange the final output back to
        # (batch_size, channels, height, width), restoring the original
        # input layout with the transformed data.
        output = self.reshape(output, (batch_size, W, H, self.output_size_y))
        output = self.transpose(output, (0, 3, 2, 1), keep_channel_last=True)

        return output
|
| 186 |
+
|
| 187 |
+
|
| 188 |
+
class SRNN(GenericSRNN):
    """
    Separable Recurrent Neural Network (SRNN): a standard RNN extended
    with separability — one RNN swept along the x-axis followed by one
    swept along the y-axis (see GenericSRNN).
    """

    def __init__(self, input_size: int, hidden_size_x: int, hidden_size_y: int,
                 base: str = 'RNN', num_layers: int = 1, bias: bool = True,
                 batch_first: bool = True, dropout: float = 0.0,
                 bidirectional: bool = False) -> None:
        """
        Build the two axis RNNs.

        Parameters:
        - input_size: The number of expected features in the input `x`
        - hidden_size_x: The number of features in the hidden state `x`
        - hidden_size_y: The number of features in the hidden state `y`
        - base: The type of RNN to use ('RNN', 'GRU' or 'LSTM')
        - num_layers: Number of stacked recurrent layers per axis
        - bias: If `False`, the layers do not use bias weights `b_ih`
          and `b_hh`. Default: `True`
        - batch_first: If `True`, tensors are (batch, seq, feature).
          Default: `True`
        - dropout: If non-zero, adds Dropout on the outputs of each RNN
          layer except the last, with this probability. Default: 0
        - bidirectional: If `True`, both axis RNNs become torch
          bidirectional RNNs. Default: `False`

        From our experiments the best results were obtained with
        base='RNN', num_layers=1, bias=True, batch_first=True, dropout=0.
        """
        super(SRNN, self).__init__(hidden_size_x, hidden_size_y)

        # Resolve the requested base cell; raises KeyError for an
        # unknown base.
        cell = {'RNN': RNN, 'GRU': GRU, 'LSTM': LSTM}[base]
        shared = dict(num_layers=num_layers, bias=bias,
                      batch_first=batch_first, dropout=dropout,
                      bidirectional=bidirectional)

        self.rnn_x = cell(input_size=input_size,
                          hidden_size=hidden_size_x, **shared)
        self.rnn_y = cell(input_size=hidden_size_x,
                          hidden_size=hidden_size_y, **shared)

        # Output sizes of the model in the `x` and `y` dimensions.
        self.output_size_x = hidden_size_x
        self.output_size_y = hidden_size_y

    def forward(self, x):
        """Sweep x through the x-axis RNN, then the y-axis RNN, and
        return the result (see GenericSRNN.forward)."""
        return super().forward(x, self.rnn_x, self.rnn_y)
|
| 266 |
+
|
| 267 |
+
|
| 268 |
+
class SWSBiRNN(GenericSRNN):
    """
    Separable Weight-Shared Bi-directional RNN: the separable x-then-y
    sweep of GenericSRNN, with each axis handled by a WSBiRNN so the
    forward and backward passes share one set of weights.
    """

    def __init__(self, input_size: int, hidden_size_x: int, hidden_size_y: int,
                 base: str = 'RNN', num_layers: int = 1, bias: bool = True,
                 batch_first: bool = True, dropout: float = 0.0) -> None:
        """
        Build the two axis WSBiRNNs.

        Parameters:
        - input_size: The number of expected features in the input `x`
        - hidden_size_x: The number of features in the hidden state `x`
        - hidden_size_y: The number of features in the hidden state `y`
        - base: The type of RNN cell to use ('RNN', 'GRU' or 'LSTM')
        - num_layers: Number of stacked recurrent layers per axis
        - bias: If `False`, the layers do not use bias weights `b_ih`
          and `b_hh`. Default: `True`
        - batch_first: If `True`, tensors are (batch, seq, feature).
          Default: `True`
        - dropout: If non-zero, adds Dropout on the outputs of each RNN
          layer except the last, with this probability. Default: 0

        From our experiments the best results were obtained with
        base='RNN', num_layers=1, bias=True, batch_first=True, dropout=0.
        """
        super(SWSBiRNN, self).__init__(hidden_size_x, hidden_size_y)

        shared = dict(num_layers=num_layers, base=base, bias=bias,
                      batch_first=batch_first, dropout=dropout)
        self.rnn_x = WSBiRNN(input_size=input_size,
                             hidden_size=hidden_size_x, **shared)
        self.rnn_y = WSBiRNN(input_size=hidden_size_x,
                             hidden_size=hidden_size_y, **shared)

        # Output sizes of the model in the `x` and `y` dimensions.
        self.output_size_x = hidden_size_x
        self.output_size_y = hidden_size_y

    def forward(self, x):
        """Sweep x through the x-axis then the y-axis WSBiRNN and return
        the result (see GenericSRNN.forward)."""
        return super().forward(x, self.rnn_x, self.rnn_y)
|
| 337 |
+
|
| 338 |
+
|
| 339 |
+
class WSBiRNN(Module):
    """
    Weight-Shared Bidirectional RNN. A single base RNN (vanilla RNN,
    GRU, or LSTM) processes the sequence in both the forward and the
    reversed direction — sharing one set of weights — and the two
    outputs are summed, halving the parameters of a standard
    bidirectional RNN.

    Attributes:
        rnn (Module): the shared RNN used for both directions.
        flip (Flip): reverses sequence order (exportable op).
        add (Add): sums the forward and reverse outputs (exportable op).
    """

    def __init__(self, input_size: int, hidden_size: int, num_layers: int = 1,
                 base: str = 'RNN', bias: bool = True, batch_first: bool = True,
                 dropout: float = 0.0) -> None:
        """
        Parameters:
            input_size (int): The number of expected features in the input `x`.
            hidden_size (int): The number of features in the hidden state `h`.
            num_layers (int, optional): Number of recurrent layers. Default: 1.
            base (str, optional): Type of RNN ('RNN', 'GRU', 'LSTM'). Default: 'RNN'.
            bias (bool, optional): If False, the layer does not use bias weights. Default: True.
            batch_first (bool, optional): If True, tensors are (batch, seq, feature). Default: True.
            dropout (float, optional): If non-zero, adds Dropout on the outputs of
                each RNN layer except the last. Default: 0.
        """
        super(WSBiRNN, self).__init__()

        # Resolve the requested base cell; raises KeyError for an
        # unknown base.
        cells = {'RNN': RNN, 'GRU': GRU, 'LSTM': LSTM}
        self.rnn = cells[base](input_size=input_size,
                               hidden_size=hidden_size,
                               num_layers=num_layers,
                               bias=bias,
                               batch_first=batch_first,
                               dropout=dropout,
                               bidirectional=False)

        # Utilities for reversing sequences and combining outputs.
        self.flip = Flip()
        self.add = Add()

    def forward(self, x):
        """
        Parameters:
            x (Tensor): The input sequence tensor.

        Returns:
            Tensor: Sum of the forward-direction output and the
                re-aligned reverse-direction output.
            state: Hidden state of the reverse pass, kept for
                compatibility with the (output, state) RNN convention.
        """
        # Reverse the sequence axis for the backward pass.
        reversed_x = self.flip(x, [1])

        # Run the shared RNN in both directions.
        forward_out, _ = self.rnn(x)
        reverse_out, state = self.rnn(reversed_x)

        # Flip the reverse output back so it lines up with the forward
        # output position-by-position.
        aligned_reverse = self.flip(reverse_out, [1])

        return self.add(aligned_reverse, forward_out), state
|
| 415 |
+
|
| 416 |
+
|
| 417 |
+
class S2f(Module):
    """Synaptics C2f block, reconstructed from tflite inspection.

    Slices the first channels off the input, refines that slice through
    a chain of InvertedResiduals, concatenates the input with every
    intermediate result, and decodes the stack with a Conv/BN/ReLU.
    """

    def __init__(self, in_channels, n, out_channels=None, nslice=None,
                 skip=True):
        # n is the number of InvertedResidual refinement steps; nslice
        # overrides the default half-channel slice width.
        super().__init__()
        out_channels = in_channels if out_channels is None else out_channels
        c = nslice if nslice is not None else in_channels // 2
        self.slice = ChannelSlice(slice(c))
        # skip must be exactly True (not merely truthy) to enable the
        # residual adds inside the refinement chain.
        self.ir = ModuleList([InvertedResidual(c, 1, skip=skip is True)
                              for _ in range(n)])
        self.cat = Cat()
        self.decode = CoBNRLU(in_channels + n * c, out_channels, bias=True)

    def forward(self, x):
        feats = [x]
        y = self.slice(x)
        for block in self.ir:
            y = block(y)
            feats.append(y)
        return self.decode(self.cat(feats))
|
synet/legacy.py
ADDED
|
@@ -0,0 +1,151 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from numpy import array
|
| 2 |
+
|
| 3 |
+
from os.path import join, dirname
|
| 4 |
+
from json import load
|
| 5 |
+
from tensorflow.keras.models import load_model
|
| 6 |
+
def get_katananet_model(model_path, input_shape, low_thld, **kwds):
    """Load a katananet Keras model and wrap it in a detection callable.

    model_path: str
        path to the model .h5 file; an "anchors.json" file is expected
        in the same directory.
    input_shape: iterable of ints
        shape of the cell run, used to generate the anchors.
    low_thld: float
        minimum best-class score a prediction must reach before NMS.

    Returns a callable mapping an image to (boxes, scores).
    """
    raw_model = load_model(model_path, compile=False)
    with open(join(dirname(model_path), "anchors.json")) as cfg:
        anchor_params = load(cfg)
    anchors = gen_anchors(input_shape, **anchor_params)

    def model(image):
        (deltas,), (scores,) = raw_model.predict_on_batch(preproc(image))
        # Drop anchors whose best class score is below the low threshold.
        keep = scores.max(1) > low_thld
        deltas, kept_anchors, scores = deltas[keep], anchors[keep], scores[keep]
        # get_abs_coords only yields coordinates relative to the cell.
        boxes = get_abs_coords(deltas, kept_anchors,
                               training_scale=.2, training_shift=0,
                               maxx=image.shape[-1], maxy=image.shape[-2])
        # Suppress overlapping detections.
        boxes, scores = nms(boxes, scores, threshold=.3)
        return boxes, scores.squeeze(-1)

    return model
|
| 32 |
+
|
| 33 |
+
from numpy import float32, expand_dims
|
| 34 |
+
def preproc(image):
    """Scale integer pixel values in [0, 255] to float32 in [-1, 1).

    A leading batch axis is added when the image has fewer than three
    dimensions.
    """
    batched = image if len(image.shape) >= 3 else expand_dims(image, 0)
    return batched.astype(float32) / 128 - 1
|
| 40 |
+
|
| 41 |
+
from numpy import zeros, arange, concatenate
|
| 42 |
+
from math import ceil
|
| 43 |
+
def gen_anchors(image_shape, strides, sizes, ratios, scales):
    """Generate xyxy anchor boxes for every pyramid level.

    image_shape: (height, width) of the input image.
    strides, sizes: per-level feature strides and base anchor sizes,
        iterated in lockstep.
    ratios, scales: aspect ratios and size multipliers applied at every
        grid position of every level.

    Returns an (N, 4) array of (x1, y1, x2, y2) anchors with levels
    concatenated in order.
    """
    imy, imx = image_shape
    scales = array(scales).reshape(-1, 1)
    # sqrt of the ratio is divided into w and multiplied into h, so the
    # anchor area is preserved across ratios.
    ratios = array(ratios).reshape(-1, 1, 1)**.5
    levels = []
    for stride, size in zip(strides, sizes):
        py, px = ceil(imy / stride), ceil(imx / stride)
        anchors = zeros((py, px, len(ratios), len(scales), 4))
        # Start as (xc, yc, w, h) centred at the origin.
        anchors[..., 2:] = size * scales
        anchors[..., 2] /= ratios[..., 0]
        anchors[..., 3] *= ratios[..., 0]
        # Convert to (x1, y1, x2, y2) around the origin.
        anchors[..., :2] -= anchors[..., 2:] / 2
        anchors[..., 2:] /= 2
        # Shift every anchor to its grid-cell centre; the reshapes make
        # the x offsets broadcast along axis 1 (px) and the y offsets
        # along axis 0 (py).
        anchors[..., 0::2] += ((arange(px) + 0.5) * stride).reshape(-1, 1, 1, 1)
        anchors[..., 1::2] += ((arange(py) + 0.5) * stride).reshape(-1, 1, 1, 1, 1)
        levels.append(anchors.reshape(-1, 4))
    return concatenate(levels)
|
| 64 |
+
|
| 65 |
+
from numpy import clip, newaxis
|
| 66 |
+
def get_abs_coords(deltas, anchors, training_scale, training_shift,
                   maxx, maxy):
    """Convert model output (deltas) into "absolute" coordinates.

    Note: absolute coordinates here are still relative to the grid cell
    being run.

    deltas: ndarray
        nx4 array of xyxy offsets predicted by the model.
    anchors: ndarray
        nx4 array of xyxy anchor boxes.
    training_scale: float
        scale specific to our training code. For us always .2.
    training_shift: float
        shift specific to our training code. For us always 0.
    maxx, maxy: float
        clip bounds so the final boxes fit inside the cell.
    """
    width, height = (anchors[:, 2:4] - anchors[:, 0:2]).T
    # Undo the training-time normalisation, then scale the x columns by
    # anchor width and the y columns by anchor height.
    scaled = deltas * training_scale + training_shift
    scaled[:, 0::2] *= width[..., newaxis]
    scaled[:, 1::2] *= height[..., newaxis]
    boxes = scaled + anchors
    # Clamp into the cell.
    boxes[:, 0::2] = clip(boxes[:, 0::2], 0, maxx)
    boxes[:, 1::2] = clip(boxes[:, 1::2], 0, maxy)
    return boxes
|
| 94 |
+
|
| 95 |
+
from numpy import argsort, maximum, minimum
|
| 96 |
+
def nms(boxes, score, threshold):
    """
    Greedy non-maximum suppression to remove redundant boxes.

    :param boxes: (N, 4) xyxy box coordinates
    :param score: (N, C) per-class confidence scores
    :param threshold: boxes with IoU >= threshold against an
        already-picked box are suppressed
    :return: the surviving boxes and their scores as arrays
    """
    if not len(boxes):
        return boxes, score

    # Coordinates of all bounding boxes.
    x1s, y1s = boxes[:, 0], boxes[:, 1]
    x2s, y2s = boxes[:, 2], boxes[:, 3]

    kept_boxes = []
    kept_scores = []

    # Box areas; the +1 keeps the original integer-pixel convention.
    areas = (y2s - y1s + 1) * (x2s - x1s + 1)

    # Candidate indices, best confidence first.
    order = argsort(-score.max(-1))

    while order.size > 0:
        # Take the remaining box with the largest confidence score.
        best = order[0]
        order = order[1:]

        kept_boxes.append(boxes[best])
        kept_scores.append(score[best])

        # Intersection of the picked box with each remaining candidate.
        iy1 = maximum(y1s[best], y1s[order])
        ix1 = maximum(x1s[best], x1s[order])
        iy2 = minimum(y2s[best], y2s[order])
        ix2 = minimum(x2s[best], x2s[order])
        iw = maximum(0.0, ix2 - ix1 + 1)
        ih = maximum(0.0, iy2 - iy1 + 1)
        intersection = iw * ih

        # Intersection over union against the picked box.
        iou = intersection / (areas[best] + areas[order] - intersection)

        # Keep only candidates that do not overlap too much.
        order = order[iou < threshold]

    return array(kept_boxes), array(kept_scores)
|
synet/metrics.py
ADDED
|
@@ -0,0 +1,328 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from argparse import ArgumentParser, Namespace
|
| 2 |
+
from glob import glob
|
| 3 |
+
from os import listdir, makedirs
|
| 4 |
+
from os.path import join, basename, isfile, isdir, isabs
|
| 5 |
+
|
| 6 |
+
from numpy import genfromtxt
|
| 7 |
+
from torch import tensor, stack, cat, empty
|
| 8 |
+
from ultralytics.utils.ops import xywh2xyxy
|
| 9 |
+
from ultralytics.data.utils import check_det_dataset, img2label_paths
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
aP_curve_points = 10000
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
def parse_opt() -> Namespace:
    """Build the metrics CLI and return only the recognized arguments.

    Unknown arguments are deliberately left unparsed so the remaining
    argv can be forwarded to the validation sub-command by run().
    """
    cli = ArgumentParser()
    cli.add_argument("data_yamls", nargs="+")
    cli.add_argument("--out-dirs", nargs="+")
    cli.add_argument("--project")
    cli.add_argument("--name")
    cli.add_argument("--print-jobs", action="store_true")
    cli.add_argument("--precisions", nargs="+", type=float, required=True,
                     help="CANNOT BE SPECIFIED WITH --precisions=...' "
                          "SYNTAX: MUST BE '--precisions PREC1 PREC2 ...'")
    known, _unknown = cli.parse_known_args()
    return known
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
def txt2xyxy(txt : str, conf=False) -> tensor:
    """Read a YOLO-format label file into (cls, x1, y1, x2, y2[, conf]).

    Rows are (class, x, y, w, h[, conf]); the center/size box columns
    are rewritten in corner (xyxy) form. An empty file yields an empty
    tensor with the column count implied by `conf`.
    """
    rows = tensor(genfromtxt(txt, ndmin=2))
    if len(rows) == 0:
        # no labels: keep column count consistent with the conf flag
        return empty(0, 5 + int(conf))
    rows[:, 1:5] = xywh2xyxy(rows[:, 1:5])
    return rows
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
def get_gt(data_yaml : str) -> dict:
    """Obtain {"###.txt" : (Mx5 array)} dictionary mapping each data
    sample to an array of ground truths. See get_pred().

    """
    # resolve the dataset config; prefer the test split, fall back to val
    cfg = check_det_dataset(data_yaml)
    path = cfg.get('test', cfg['val'])
    # collect image paths: a split entry is either a directory (glob
    # everything beneath it) or a listing file with one image path per
    # line, resolved relative to the listing file's directory
    f = []
    for p in path if isinstance(path, list) else [path]:
        if isdir(p):
            f += glob(join(p, "**", "*.*"), recursive=True)
        else:
            f += [t if isabs(t) else join(dirname(p), t)
                  for t in open(p).read().splitlines()]
    print('getting gt')
    # map image paths to their label txt paths and key by label basename
    # (no conf column in ground truth, hence conf=False -> Mx5)
    return {basename(l): txt2xyxy(l, conf=False)
            for l in img2label_paths(f)}
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
def get_pred(pred : str) -> dict:
    """Map every "###.txt" in the validation output dir `pred` to an
    Mx6 tensor of predictions (cls, x1, y1, x2, y2, conf). See get_gt().
    """
    print('getting pred')
    prediction_files = listdir(pred)
    return {fname : txt2xyxy(join(pred, fname), conf=True)
            for fname in prediction_files}
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
from yolov5.val import process_batch
|
| 68 |
+
# from torchvision.ops import box_iou
|
| 69 |
+
# def validate_preds(sample_pred, sample_gt):
|
| 70 |
+
# box_iou(sample_pred)
|
| 71 |
+
# correct = zeros(len(sample_pred)).astype(bool)
|
| 72 |
+
|
| 73 |
+
def get_tp_ngt(gt : dict, pred : dict) -> tuple:
    """From ground truth and prediction dictionaries (as given by
    get_gt() and get_pred() funcs resp.), generate a single Mx3 array, tp,
    for the entire dataset, as well as a dictionary, gt = { (class : int)
    : (count : int) }, giving the count of each ground truth.  The array
    is interpreted as meaning there are M predictions denoted by (conf,
    class, TP) giving the network predicted confidence, the network
    predicted class, and a flag TP which is 1 if the sample is considered
    a true positive.

    """
    # after this point, we don't care about which pred came from which
    # data sample in this data split
    # column reorder [1,2,3,4,5,0] feeds process_batch boxes as
    # (x1, y1, x2, y2, conf, cls); tensor([.5]) is a single IoU
    # threshold of 0.5, and squeeze(1) drops that one-threshold axis
    tp = cat([stack((pred[fname][:, 5], # conf
                     pred[fname][:, 0], # class
                     process_batch(pred[fname][:,[1,2,3,4,5,0]],
                                   gt[fname],
                                   tensor([.5])
                                   ).squeeze(1)), # TP
                    -1)
              for fname in pred])
    # NOTE(review): ground truths are only counted for images that have
    # a prediction file -- images with no predictions contribute no gt
    # counts here; confirm that is intended
    l = cat([gt[fname][:, 0] for fname in pred])
    ngt = {int(c.item()) : (l == c).sum() for c in l.unique()}
    return tp, ngt
|
| 97 |
+
|
| 98 |
+
from torch import cumsum, arange, linspace
|
| 99 |
+
from numpy import interp
|
| 100 |
+
from matplotlib.pyplot import (plot, legend, title, xlabel, ylabel,
|
| 101 |
+
savefig, clf, scatter, grid, xlim, ylim)
|
| 102 |
+
def get_aps(tp : tensor, ngt : dict, precisions : list, label : str,
            project : str, glob_confs : [list, None] = None) -> list:
    """This is the main metrics AND plotting function. All other
    functions exist to "wrangle" the data into an optimal format for this
    function. From a 'tp' tensor and 'ngt' dict (see get_tp_ngt()),
    compute various metrics, including the operating point at
    'precisions'[c] for each class c. Plots are labeled and named based
    on 'label', and placed in the output dir 'project'. Additionally, if
    glob_confs is also given, plot the operating point at that confidence
    threshold. Returns the confidence threshold corresponding to each
    precision threshold in 'precisions'.

    """
    # if there are fewer precision thresholds specified than classes
    # present, and only one precision is specified, use that precision
    # for all classes.  Class indices run 0..max(ngt), so the expanded
    # list needs max(ngt)+1 entries: the previous `precisions *= max(ngt)`
    # was one short (dropping the last class from the loop below) and
    # mutated the caller's list, breaking subsequent calls.
    if max(ngt) > len(precisions) - 1:
        if len(precisions) == 1:
            print("applying same precision to all classes")
            precisions = precisions * (max(ngt) + 1)
        else:
            print("specified", len(precisions), "precisions, but have",
                  max(ngt)+1, "classes")
            exit()
    # Main loop. One for each class. AP calculated at the end
    AP, confs, op_P, op_R, half_P, half_R = [], [], [], [], [], []
    if glob_confs is not None: glob_P, glob_R = [], []
    for cls, prec in enumerate(precisions):
        print("For class:", cls)

        # choose class and omit class field: columns 0 and 2 of the
        # (conf, class, TP) rows, i.e. (conf, TP)
        selected = tp[tp[:,1] == cls][:,::2]

        # sort descending by confidence
        selected = selected[selected[:, 0].argsort(descending=True)]

        # calculate PR values; running TP count over the confidence-
        # ranked predictions gives precision/recall at every threshold.
        # NOTE: R divides by zero (-> inf/nan) for a class with no
        # ground truths; ngt.get keeps that from raising.
        assert len(selected.shape) == 2
        tpcount = cumsum(selected[:,1], 0).numpy()
        P = tpcount / arange(1, len(tpcount) + 1)
        R = tpcount / ngt.get(cls, 0)
        # enforce that P should be monotone (standard AP interpolation)
        P = P.flip(0).cummax(0)[0].flip(0)

        # calculate operating point from precision.
        # operating index is where the precision last surpasses precision thld
        # argmax on bool array returns first time condition is met.
        # Precision is not monotone, so need to reverse, argmax, then find ind.
        # If precision never falls below prec, argmax is 0 and index -1
        # selects the lowest-confidence prediction (everything kept).
        assert len(P.shape) == 1
        confs.append(selected[(P < prec).byte().argmax() -1, 0])
        op_ind = (selected[:,0] <= confs[-1]).byte().argmax() - 1
        op_P.append(P[op_ind])
        op_R.append(R[op_ind])
        print(f"Conf, Precision, Recall at operating point precision={prec}")
        print(f"{confs[-1]:.6f}, {op_P[-1]:.6f}, {op_R[-1]:.6f}")

        if glob_confs is not None:
            # if glob threshold is passed, also find that PR point
            glob_ind = (selected[:,0] <= glob_confs[cls]).byte().argmax() - 1
            glob_P.append(P[glob_ind])
            glob_R.append(R[glob_ind])
            print("Conf, Precision, Recall at global operating point:")
            print(f"{glob_confs[cls]:.6f}, {glob_P[-1]:.6f}, {glob_R[-1]:.6f}")

        # show .5 conf operating point
        half_ind = (selected[:,0] <= .5).byte().argmax() - 1
        half_P.append(P[half_ind])
        half_R.append(R[half_ind])
        print(f"Conf, Precision, Recall at C=.5 point")
        print(f"{.5:.6f}, {half_P[-1]:.6f}, {half_R[-1]:.6f}")

        # generate plotting points/AP calc points on a uniform recall grid
        Ri = linspace(0, 1, aP_curve_points)
        Pi = interp(Ri, R, P)
        # use these values for AP calc over raw to avoid machine error
        AP.append(Pi.sum() / aP_curve_points)
        print("class AP:", AP[-1].item(), end="\n\n")
        plot(Ri, Pi, label=f"{cls}: AP={AP[-1]:.6f}")

    # calculate mAP (unweighted mean over classes)
    mAP = sum(AP)/len(AP)
    print("mAP:", mAP, end="\n\n\n")
    title(f"{basename(label)} mAP={mAP:.6f}")

    # plot other points
    scatter(op_R, op_P, label="precision operating point")
    scatter(half_R, half_P, label=".5 conf")
    if glob_confs is not None:
        scatter(glob_R, glob_P, label="global operating point")

    # save plot
    legend()
    xlabel("Recall")
    ylabel("Precision")
    grid()
    xlim(0, 1)
    ylim(0, 1)
    savefig(join(project, f"{basename(label)}.png"))
    clf()

    return confs
|
| 204 |
+
|
| 205 |
+
def metrics(data_yamls : list, out_dirs : list, precisions : list,
            project : str) -> None:
    """High level function for computing metrics and generating plots
    for the combined data plus each data split. Requires list of data
    yamls, data_yamls, model output dirs, out_dirs, classwise precision
    thresholds, precisions, and output dir, project.

    """
    # gather per-split (tp, ngt) pairs keyed by the data yaml path
    tp_ngt = {yaml_path : get_tp_ngt(get_gt(yaml_path),
                                     get_pred(join(run_dir, 'labels')))
              for yaml_path, run_dir in zip(data_yamls, out_dirs)}
    print("Done reading results. Results across all data yamls:", end="\n\n")
    # combine all splits: concatenate prediction rows and sum per-class
    # ground-truth counts over every class seen in any split
    combined_tp = cat([pair[0] for pair in tp_ngt.values()])
    seen_classes = set()
    for _, ngt in tp_ngt.values():
        seen_classes |= set(ngt.keys())
    combined_ngt = {c : sum(ngt.get(c, 0) for _, ngt in tp_ngt.values())
                    for c in seen_classes}
    confs = get_aps(combined_tp, combined_ngt, precisions, "all", project)
    # a single split is identical to the combined report -- stop here
    if len(tp_ngt) == 1:
        return
    # per-split reports, reusing the combined confidence thresholds as
    # the "global" operating point
    for yaml_path, (tp, ngt) in tp_ngt.items():
        print("Results for", yaml_path, end="\n\n")
        get_aps(tp, ngt, precisions, yaml_path, project, confs)
|
| 230 |
+
|
| 231 |
+
from sys import argv
|
| 232 |
+
from synet.__main__ import main as synet_main
|
| 233 |
+
def run(data_yamls : list, out_dirs : [list, None], print_jobs : bool,
        precisions : list, project : [str, None], name : None):
    """Entrypoint function. Compute metrics of model on data_yamls.

    If out_dirs is specified, it should be a list of output directories
    used for validation runs on the datasets specified by data_yamls (in
    the same order).

    If out_dirs is not specified, then all necessary validation args
    should be specified in command-line args (sys.argv). In this case, it
    will run validation of your model on each specified data yaml before
    attempting to compute metrics.

    If print_jobs is specified, then the commands to run the various
    validation jobs are printed instead. This is useful if you would like
    to run the validation jobs in parallel.

    If project is specified, this will be used as the base output
    directory for plots and generated validation jobs.

    name should never be specified. validation job names are generated by
    this function, so you must not try to specify your own.

    precisions is a list of precision thresholds. This is used as an
    operating point which is also reported by the metrics here. It is
    either one value (used for all classes), or a list of values
    corresponding to the labels in order.

    """

    # decide output dir.  Default it BEFORE creating it -- the previous
    # order called makedirs(None) when --project was omitted.
    assert name is None, "--name specified by metrics. Do not specify"
    if project is None:
        project = "metrics"
    makedirs(project, exist_ok=True)
    argv.append(f"--project={project}")

    # if val was already run, just compute metrics
    if out_dirs is not None:
        assert len(out_dirs) == len(data_yamls), \
            "Please specify one output for each data yaml"
        print("Using prerun results to compute metrics")
        return metrics(data_yamls, out_dirs, precisions, project)

    ## modify argv so the remainder can drive the validation sub-command
    # add necessary flags
    for flag in "--save-conf", '--save-txt', '--exist-ok':
        if flag not in argv:
            argv.append(flag)
    # remove precisions flag from args
    argv.remove("--precisions")
    if print_jobs:
        argv.remove("--print-jobs")

    # remove the precision values that followed --precisions.  (The old
    # code used str.isnumeric(), which is False for strings like "0.9",
    # and referenced an undefined name via a `precison` typo, so the
    # values were never actually stripped.)
    def _is_precision_value(arg: str) -> bool:
        """True if arg parses as a float within 1e-4 of any threshold."""
        try:
            value = float(arg)
        except ValueError:
            return False
        return any(abs(value - precision) <= .0001
                   for precision in precisions)
    for r in [arg for arg in argv if _is_precision_value(arg)]:
        if r in argv: argv.remove(r)

    # remove data yamls from args
    for data_yaml in data_yamls: argv.remove(data_yaml)
    # run validation
    argv.insert(1, "val")

    ## generate val jobs
    if print_jobs:
        print("Submit the following jobs:")
    out_dirs = []
    for i, data_yaml in enumerate(data_yamls):
        # specify data and out dir.
        flags = f"--data={data_yaml}", f"--name=data-split{i}"
        argv.extend(flags)
        # run/print the job
        print(" ".join(argv))
        if not print_jobs:
            print("starting job")
            synet_main()
            # main removes job type from argv, re-add it
            argv.insert(1, "val")
        # revert argv for next job
        for flag in flags:
            argv.remove(flag)
        # keep track of output dirs in order
        out_dirs.append(join(project, f"data-split{i}"))

    ## calculate metrics
    if print_jobs:
        print("Once jobs finish, run:")
        print(" ".join([argv[0], "metrics", *data_yamls, "--out-dirs",
                        *out_dirs, "--precisions",
                        *(str(prec) for prec in precisions)]))
    if not print_jobs:
        print("computing metrics")
        return metrics(data_yamls, out_dirs, precisions, project)
|
| 326 |
+
|
| 327 |
+
def main():
    """CLI entrypoint: parse known metrics args and dispatch to run()."""
    options = parse_opt()
    run(**vars(options))
|
synet/quantize.py
ADDED
|
@@ -0,0 +1,143 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python
|
| 2 |
+
from argparse import ArgumentParser
|
| 3 |
+
from glob import glob
|
| 4 |
+
from os.path import dirname, isabs, isdir, join, splitext
|
| 5 |
+
from random import shuffle
|
| 6 |
+
|
| 7 |
+
from cv2 import imread, resize
|
| 8 |
+
from keras import Input, Model
|
| 9 |
+
from numpy import float32
|
| 10 |
+
from numpy.random import rand
|
| 11 |
+
from tensorflow import int8, lite
|
| 12 |
+
from torch import no_grad
|
| 13 |
+
|
| 14 |
+
from .base import askeras
|
| 15 |
+
from .backends import get_backend
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
def parse_opt(args=None):
    """Build and parse the quantization CLI.

    Kept under the name parse_opt (rather than parse_args) so it stays
    compatible with how yolov5 obtains arguments.
    """
    cli = ArgumentParser()
    cli.add_argument("--backend", type=get_backend,
                     default=get_backend('ultralytics'))
    cli.add_argument("--model", "--cfg", '--weights')
    cli.add_argument("--image-shape", nargs=2, type=int)
    cli.add_argument("--data")
    cli.add_argument("--kwds", nargs="+", default=[])
    cli.add_argument("--channels", "-c", default=3, type=int)
    cli.add_argument("--number", "-n", default=500, type=int)
    cli.add_argument("--val-post",
                     help="path to sample image to validate on.")
    cli.add_argument("--tflite",
                     help="path to existing tflite (for validating).")
    return cli.parse_args(args=args)
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
def run(backend, image_shape, model, data, number, channels, kwds,
        val_post, tflite):
    """Entrypoint to quantize.py. Quantize the model specified by
    weights (falling back to cfg), using samples from the data yaml with
    image shape image_shape, using only number samples.

    """
    # let the backend install whatever patches it needs before the model
    # is loaded (see the backends module), and allow `model` to name a
    # zoo entry rather than a local path
    backend.patch()
    model = backend.maybe_grab_from_zoo(model)

    # skip conversion entirely when an already-quantized tflite is given
    if tflite is None:
        tflite = get_tflite(backend, image_shape, model, data,
                            number, channels, kwds)

    # optionally compare the tflite against the source model on a sample
    # image (val_post is the path to that image)
    if val_post:
        backend.val_post(model, tflite, val_post, image_shape=image_shape)
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
def get_tflite(backend, image_shape, model_path, data, number,
               channels, kwds):
    """Trace the torch model at model_path into a keras Model and
    quantize it, writing the tflite next to the source weights and
    returning the quantized buffer.
    """
    # fall back to the backend's notion of the model's input shape
    if image_shape is None:
        image_shape = backend.get_shape(model_path)

    # generate keras model: build a keras Input and trace the torch
    # model through it under the askeras export context
    torch_model = backend.get_model(model_path)
    keras_in = Input(image_shape + [channels], batch_size=1)
    extra_kwds = dict(pair.split("=") for pair in kwds)
    with askeras(imgsz=image_shape, quant_export=True, **extra_kwds), \
         no_grad():
        keras_model = Model(keras_in, torch_model(keras_in))

    print('model params:', keras_model.count_params())

    # quantize the model, dropping the tflite alongside the weights
    tflite_path = splitext(model_path)[0] + ".tflite"
    return quantize(keras_model, data, image_shape, number, tflite_path,
                    channels, backend=backend)
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
def quantize(kmodel, data, image_shape, N=500, out_path=None, channels=1,
             generator=None, backend=None):
    """Given a keras model, kmodel, and data yaml at data, quantize
    using N samples reshaped to image_shape and place the output model at
    out_path.

    """
    # pick the calibration source first: an explicit generator wins,
    # then random ("phony") samples when no data yaml is given, and
    # otherwise real samples drawn from the dataset via the backend
    if generator:
        calibration = generator
    elif data is None:
        calibration = lambda: phony_data(image_shape, channels)
    else:
        calibration = lambda: representative_data(
            backend.get_data(data), image_shape, N, channels)

    # configure full-integer (int8 in/out) post-training quantization
    converter = lite.TFLiteConverter.from_keras_model(kmodel)
    converter.optimizations = [lite.Optimize.DEFAULT]
    converter.inference_input_type = int8
    converter.inference_output_type = int8
    converter.representative_dataset = calibration

    # quantize
    quantized = converter.convert()

    # write out tflite
    if out_path:
        with open(out_path, "wb") as f:
            f.write(quantized)

    return quantized
|
| 111 |
+
|
| 112 |
+
|
| 113 |
+
def representative_data(data, image_shape, N, channels):
    """Yield up to N preprocessed calibration samples from the dataset.

    The test split is preferred, falling back to val. Each split entry
    is either a directory (globbed recursively) or a listing file of
    image paths. Samples are shuffled, resized to image_shape when
    needed, reduced to one channel when channels == 1, and yielded as
    [batch] where batch is float32 in [0, 1] with a leading batch dim.
    """
    source = data.get('test', data['val'])
    files = []
    for entry in source if isinstance(source, list) else [source]:
        if isdir(entry):
            files.extend(glob(join(entry, "**", "*.*"), recursive=True))
        else:
            # listing file: one image path per line, relative paths are
            # resolved against the listing file's directory
            with open(entry) as listing:
                lines = listing.read().splitlines()
            files.extend(line if isabs(line) else join(dirname(entry), line)
                         for line in lines)
    shuffle(files)
    for path in files[:N]:
        image = imread(path)
        if image.shape[0] != image_shape[0] or image.shape[1] != image_shape[1]:
            # cv2.resize takes (width, height), hence the reversal
            image = resize(image, image_shape[::-1])
        if image.shape[-1] != channels:
            # only grayscale reduction is supported
            assert channels == 1
            image = image.mean(-1, keepdims=True)
        yield [image[None].astype(float32) / 255]
|
| 135 |
+
|
| 136 |
+
|
| 137 |
+
def phony_data(image_shape, channels):
    """Yield two random calibration batches shaped (1, *image_shape, channels).

    Used when no calibration dataset is available; values are uniform in
    [0, 1) to mimic normalized image input.
    """
    for _ in range(2):
        sample = rand(1, *image_shape, channels).astype(float32)
        yield [sample]
|
| 140 |
+
|
| 141 |
+
|
| 142 |
+
def main(args=None):
    """CLI entrypoint: parse quantization args and hand them to run()."""
    options = parse_opt(args)
    return run(**vars(options))
|
synet/sabre.py
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""katana.py includes imports which are compatible with Katana, and
|
| 2 |
+
layer definitions that are only compatible with Katana. However,
|
| 3 |
+
Katana's capabilities are currently a subset of all other chip's
|
| 4 |
+
capabilities, so it includes only imports for now."""
|
| 5 |
+
|
| 6 |
+
from .layers import Conv2dInvertedResidual, Head
|
| 7 |
+
from .base import askeras, Conv2d, Cat, ReLU, BatchNorm, Grayscale
|
synet/tflite_utils.py
ADDED
|
@@ -0,0 +1,349 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python
|
| 2 |
+
"""This module exists to hold all tflite related processing. The main
|
| 3 |
+
benefit of keeping this in a seperate modules is so that large
|
| 4 |
+
dependencies (like ultralytics) need not be imported when simulating
|
| 5 |
+
tflite execution (like for demos). However, visualization
|
| 6 |
+
(interpretation} of the model is left to ultralytics. This module
|
| 7 |
+
also serves as a reference for C and/or other implementations;
|
| 8 |
+
however, do read any "Notes" sections in any function docstrings
|
| 9 |
+
|
| 10 |
+
"""
|
| 11 |
+
|
| 12 |
+
from argparse import ArgumentParser
|
| 13 |
+
from typing import Optional, List
|
| 14 |
+
|
| 15 |
+
from cv2 import (imread, rectangle, addWeighted, imwrite, resize,
|
| 16 |
+
circle, putText, FONT_HERSHEY_TRIPLEX)
|
| 17 |
+
from numpy import (newaxis, ndarray, int8, float32 as npfloat32,
|
| 18 |
+
concatenate as cat, max as npmax, argmax, empty,
|
| 19 |
+
array)
|
| 20 |
+
from tensorflow import lite
|
| 21 |
+
from torch import tensor, float32 as torchfloat32, sigmoid, tensordot, \
|
| 22 |
+
repeat_interleave
|
| 23 |
+
from torchvision.ops import nms
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
def parse_opt(args=None):
    """Parse CLI arguments for simulating a tflite model on one image.

    Positionals are the tflite path and the image path; options control
    post-processing thresholds, backend, task, and an optional resize.
    """
    cli = ArgumentParser()
    for positional in ('tflite', 'img'):
        cli.add_argument(positional)
    cli.add_argument('--conf-thresh', type=float, default=.25)
    cli.add_argument('--iou-thresh', type=float, default=.5)
    cli.add_argument('--backend', default='ultralytics')
    cli.add_argument('--task', default='segment')
    cli.add_argument('--image-shape', nargs=2, type=int, default=None)
    return cli.parse_args(args=args)
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
def tf_run(tflite, img, conf_thresh=.25, iou_thresh=.7,
           backend='ultralytics', task='segment', image_shape=None):
    """Run a tflite model on an image, including post-processing.

    Loads the tflite, loads the image, preprocesses the image,
    evaluates the tflite on the pre-processed image, and performs
    post-processing on the tflite output with a given confidence and
    iou threshold.

    Parameters
    ----------
    tflite : str or buffer
        Path to tflite file, or a raw tflite buffer
    img : str or ndarray
        Path to image to evaluate on, or the image as read by cv2.imread.
    conf_thresh : float
        Confidence threshold applied before NMS
    iou_thresh : float
        IoU threshold for NMS.  NOTE(review): defaults to .7 here but
        the CLI in parse_opt() defaults to .5 -- confirm which is meant.
    backend : {"ultralytics"}
        The backend which is used. For now, only "ultralytics" is supported.
    task : {"classify", "detect", "segment", "pose"}
        The computer vision task to which the tflite model corresponds.
    image_shape : (height, width) or None
        If given, the image is resized to this shape before inference.

    Returns
    -------
    ndarray or tuple of ndarrys
        Return the result of running preprocessing, tflite evaluation,
        and postprocessing on the input image. Segmentation models
        produce two outputs as a tuple.

    """

    # initialize tflite interpreter.  Accept either a path or an
    # in-memory buffer; the reference op resolver is used so results
    # are bit-exact with the reference kernels.
    interpreter = lite.Interpreter(
        **{"model_path" if isinstance(tflite, str)
           else "model_content": tflite,
           'experimental_op_resolver_type':
           lite.experimental.OpResolverType.BUILTIN_REF}
    )

    # read the image if given as path
    if isinstance(img, str):
        img = imread(img)

    # optional resize; cv2.resize takes (width, height)
    if image_shape is not None:
        img = resize(img, image_shape[::-1])

    # make image RGB (not BGR) channel order, BCHW dimensions, and
    # in the range [0, 1]. cv2's imread reads in BGR channel
    # order, with dimensions in Height, Width, Channel order.
    # Also, imread keeps images as integers in [0,255]. Normalize
    # to floats in [0, 1]. Also, model expects a batch dimension,
    # so add a dimension at the beginning
    img = img[newaxis, ..., ::-1] / 255
    # FW TEAM NOTE: It might be strange converting to float here, but
    # the model might have been quantized to use a subset of the [0,1]
    # range, i.e. 220 could map to 255

    # Run tflite interpreter on the input image
    preds = run_interpreter(interpreter, img)

    # classification output needs no box post-processing
    if task == 'classify':
        return preds

    # Process the tflite output to be one tensor
    preds = concat_reshape(preds, task)

    # perform nms
    return apply_nms(preds, conf_thresh, iou_thresh)
|
| 109 |
+
|
| 110 |
+
|
| 111 |
+
def run_interpreter(interpreter: Optional[lite.Interpreter],
                    input_arr: ndarray) -> List[ndarray]:
    """Evaluating tflite interpreter on input data

    Quantizes the float input into the model's int8 domain, invokes the
    interpreter, and dequantizes every output back to float32 using the
    affine mapping real = (quantized - zero_point) * scale.

    Parameters
    ----------
    interpreter : Interpreter
        the tflite interpreter to run
    input_arr : 4d ndarray
        tflite model input with shape (batch, height, width, channels)

    Returns
    -------
    list
        List of output arrays from running interpreter. The order and
        content of the output is specific to the task and if model
        outputs xywh or xyxy.
    """

    interpreter.allocate_tensors()
    # input quantization parameters (scale, zero_point)
    in_scale, in_zero = interpreter.get_input_details()[0]['quantization']
    # per-output (scale, zero_point, tensor_index), sorted by tensor
    # name so the output order is deterministic across runs
    out_scale_zero_index = [(*detail['quantization'], detail['index'])
                            for detail in
                            sorted(interpreter.get_output_details(),
                                   key=lambda x:x['name'])]
    # run tflite on image; the model is assumed to have a single int8
    # input at tensor index 0
    assert interpreter.get_input_details()[0]['index'] == 0
    assert interpreter.get_input_details()[0]['dtype'] is int8
    interpreter.set_tensor(0, (input_arr / in_scale + in_zero).astype(int8))
    interpreter.invoke()
    # indexing below with [0] removes the batch dimension, which is always 1.
    return [(interpreter.get_tensor(index)[0].astype(npfloat32) - zero) * scale
            for scale, zero, index in out_scale_zero_index]
|
| 144 |
+
|
| 145 |
+
|
| 146 |
+
def concat_reshape(model_output: List[ndarray],
                   task: str,
                   xywh: Optional[bool] = False,
                   classes_to_index: Optional[bool] = True
                   ) -> ndarray:
    """Concatenate, reshape, and transpose model output to match pytorch.

    This method reorders the tflite output structure to be fit to run
    post process such as NMS etc.

    Parameters
    ----------
    model_output : list
        Output from running tflite.
    task : {"detect", "segment", "pose"}
        The task the model performs.
    xywh : bool, default=False
        If true, model output should be converted to xywh. Only use for
        python evaluation.
    classes_to_index : bool, default=True
        If true, convert the classes output logits to single class index

    Returns
    -------
    ndarray or tuple
        Final output after concatenating and reshaping input. Returns
        an ndarray for every task except "segment" which returns a
        tuple of two arrays.

    Raises
    ------
    ValueError
        If `task` is not one of the supported tasks (previously an
        unknown task fell through and crashed later with a NameError).

    Notes
    -----
    The python implementation here concats all output before applying
    nms. This is to mirror the original pytorch implementation. For
    a more efficient implementation, you may want to perform
    confidence thresholding and nms on the boxes and scores, masking
    other tensors appropriately, before reshaping and concatenating.
    """

    # Interpret the input tuple of tensors based on task. Though the
    # produced tflite always has output names like
    # "StatefulPartitionedCall:0", the numbers at the end are in fact
    # alphabetically ordered by the final layer name for each output,
    # even though those names are discarded. Hence, the following
    # variables are named to match the corresponding output layer
    # name and always appear in alphabetical order.
    if task == "pose":
        box1, box2, cls, kpts, pres = model_output
        _, num_kpts, _ = kpts.shape
    elif task == "segment":
        box1, box2, cls, proto, seg = model_output
    elif task == "detect":
        box1, box2, cls = model_output
    else:
        raise ValueError(f"unsupported task: {task!r}")

    # obtain class confidences.
    if classes_to_index:
        # for yolov5, treat objectness separately
        cls = (npmax(cls, axis=1, keepdims=True),
               argmax(cls, axis=1, keepdims=True))
    else:
        cls = cls,

    # xywh is only necessary for python evaluation.
    if xywh:
        bbox_xy_center = (box1 + box2) / 2
        bbox_wh = box2 - box1
        bbox = cat([bbox_xy_center, bbox_wh], -1)
    else:
        bbox = cat([box1, box2], -1)

    # return final concatenated output
    # FW TEAM NOTE: Though this procedure creates output consistent
    # with the original pytorch behavior of these models, you probably
    # want to do something more clever, i.e. perform NMS reading from
    # the arrays without concatenating. At the very least, maybe do a
    # confidence filter before trying to copy the full tensors. Also,
    # future models might have several times the output size, so keep
    # that in mind.
    if task == "segment":
        # FW TEAM NOTE: the proto array is HUGE (HxWx64). You
        # probably want to compute individual instance masks for your
        # implementation. See the YOLACT paper on arxiv.org:
        # https://arxiv.org/abs/1904.02689. Basically, for each
        # instance that survives NMS, generate the segmentation (only
        # HxW for each instance) by taking the inner product of seg
        # with each pixel in proto.
        return cat((bbox, *cls, seg), axis=-1), proto
    if task == 'pose':
        # keypoints are flattened to (num_kpts * 3,) per detection:
        # interleaved (x, y, presence) triples.
        return cat((bbox, *cls, cat((kpts, pres), -1
                                    ).reshape(-1, num_kpts * 3)),
                   axis=-1)
    return cat((bbox, *cls), axis=-1)
|
| 242 |
+
|
| 243 |
+
|
| 244 |
+
def apply_nms(preds: ndarray, conf_thresh: float, iou_thresh: float):
    """Confidence-filter prepared model output, then run NMS on it.

    Parameters
    ----------
    preds : ndarray or tuple of ndarray
        prepared model output. Is a tuple of two arrays for the
        "segment" task.
    conf_thresh : float
        confidence threshold applied before NMS.
    iou_thresh : float
        IoU threshold used by NMS.

    Returns
    -------
    tensor or tuple of tensor
        same structure as preds, with suppressed rows removed.

    Notes
    -----
    Arrays are converted to pytorch tensors for two reasons:
    - the nms code requires torch tensor inputs
    - the output format becomes identical to that used by
      ultralytics, and so can be passed to an ultralytics visualizer.

    As mentioned in the concat_reshape function, you may want to
    perform nms and thresholding before combining all the output.

    THIS FUNCTION IS CURRENTLY HARD-CODED FOR SINGLE CLASS.
    """

    # the segmentation task hands us (detections, prototype masks)
    proto = None
    is_tuple = isinstance(preds, tuple)
    if is_tuple:
        preds, proto = preds

    # Keep rows whose confidence (column 4) clears the threshold, and
    # move to torch for nms. The trickiness here is that yolov5 has an
    # objectness score plus per-class probabilities while ultralytics
    # has just per-class scores. Yolov5 uses objectness for confidence
    # thresholding, but then uses objectness * per-class probabilities
    # for confidences thereafter.
    kept = tensor(preds[preds[:, 4] > conf_thresh], dtype=torchfloat32)

    # Non-maximum suppression on xyxy boxes / confidence column.
    # https://pytorch.org/vision/stable/generated/torchvision.ops.nms.html
    kept = kept[nms(kept[:, :4], kept[:, 4], iou_thresh)]
    if is_tuple:
        return kept, tensor(proto)
    return kept
|
| 290 |
+
|
| 291 |
+
|
| 292 |
+
def build_masks(preds, proto):
    """Expand per-detection mask coefficients into box-clipped masks.

    preds holds one row per detection: 6 box/score columns followed by
    k mask coefficients; proto is the (k, h, w) prototype stack. The
    result is (preds[:, :6], masks) with masks of shape (N, 8h, 8w).
    """
    # YOLACT-style composition: contract coefficients (N x k) with the
    # prototypes (k x h x w) to get one low-res mask per detection.
    masks = sigmoid(tensordot(preds[:, 6:], proto, dims=1))
    # upsample the masks 8x along both spatial dimensions
    masks = repeat_interleave(masks, repeats=8, dim=1)
    masks = repeat_interleave(masks, repeats=8, dim=2)
    # zero everything outside each detection's box. Integer truncation
    # may be off-by-one near the border in this application.
    for box, mask in zip(preds[:, :4], masks):
        left, top, right, bottom = box
        mask[:int(top)] = 0
        mask[int(bottom):] = 0
        mask[:, :int(left)] = 0
        mask[:, int(right):] = 0
    return preds[:, :6], masks
|
| 308 |
+
|
| 309 |
+
|
| 310 |
+
def main(args=None):
    """CLI entry point: run a quantized tflite model on one image and
    write a task-named visualization PNG to the working directory."""
    # read options, run model
    opt = parse_opt(args)
    img = opt.img = imread(opt.img)
    if opt.image_shape is not None:
        # cv2 resize takes (width, height), hence the [::-1]
        img = opt.img = resize(opt.img, opt.image_shape[::-1])
        opt.image_shape = None

    preds = tf_run(**vars(opt))
    if opt.task == 'segment':
        preds, masks = build_masks(*preds)

        # shrink mask if upsample gave too large of area
        for imdim, maskdim in (0, 1), (1, 2):
            # split the surplus evenly; carry holds the odd leftover pixel
            extra, carry = divmod(masks.shape[maskdim] - img.shape[imdim], 2)
            if extra == carry == 0:
                continue
            # crop `extra` from the front and `extra+carry` from the back
            # of axis `maskdim` only
            masks = masks[(*(slice(None),)*maskdim, slice(extra, -(extra+carry)))]

        # visualize masks and rectangles: green overlay blended 50/50
        img_overlay = img.copy()
        img_overlay[masks.max(0).values > .5] = (0, 255, 0)
        img = addWeighted(img, .5, img_overlay, .5, 0)
    elif opt.task == 'pose':
        # kpts are interleaved (x, y, presence) triples per detection
        for x1, y1, x2, y2, conf, cls, *kpts in preds:
            for x, y, p in zip(kpts[0::3], kpts[1::3], kpts[2::3]):
                if p > .5:
                    circle(img, (int(x), int(y)), 3, (255, 0, 0), -1)
    elif opt.task != 'classify':
        # detect (and segment boxes would land here too if not handled
        # above): draw red rectangles
        for x1, y1, x2, y2, *cls in preds:
            rectangle(img, (int(x1), int(y1)),
                      (int(x2), int(y2)),
                      (0, 0, 255), 2)
    elif opt.task == 'classify':
        # reached exactly when task == 'classify' (the preceding branch
        # excludes it); preds is the raw classifier output
        putText(img, str(*preds), (20, 40), FONT_HERSHEY_TRIPLEX, 1.0, (0, 0, 0))
    imwrite(opt.task+'.png', img)
|
| 346 |
+
|
| 347 |
+
|
| 348 |
+
# script entry point: visualize tflite model output on an image
if __name__ == '__main__':
    main()
|
synet/ultralytics_patches.py
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Legacy shim: re-exports the ultralytics backend under the old module
# path so previously-pickled models that reference it still load.
from .backends.ultralytics import *
print("WARNING: ultralytics_patches.py exists for backwards model "
      "compatibility only. Do not import this module if possible.")
|
synet/zoo/__init__.py
ADDED
|
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from os import listdir
|
| 2 |
+
from os.path import abspath, dirname, join, isfile, commonpath
|
| 3 |
+
from urllib import request
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
WEIGHT_URL_ROOT = "http://profiler/"
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
def in_zoo(model, backend):
    """Return True if *model* refers to something in the SyNet zoo."""
    zoo_dir = dirname(__file__)
    # an existing file counts only when it lives under the zoo directory
    if isfile(model):
        return zoo_dir == commonpath((__file__, abspath(model)))
    # otherwise interpret the name as relative to <zoo>/<backend>/
    return isfile(join(zoo_dir, backend, model))
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
def get_config(model, backend):
    """Return the path to a model config, resolving zoo names.

    An existing file path is returned unchanged; any other name is
    resolved relative to <zoo>/<backend>/.
    """
    return model if isfile(model) else join(dirname(__file__), backend, model)
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def get_weights(model, backend):
    """Return a local path to *model* weights, downloading if necessary.

    Parameters
    ----------
    model : str
        An existing local file path (returned as-is), or a bare weight
        file name to fetch from the weight server into the current
        working directory.
    backend : str
        Backend subdirectory on the weight server.

    Returns
    -------
    str
        Path to the (possibly just downloaded) weight file.
    """
    if isfile(model):
        return model
    # Build the URL explicitly: os.path.join is wrong for URLs — it
    # inserts backslashes on Windows and discards earlier components
    # when a later one looks absolute.
    url = "/".join((WEIGHT_URL_ROOT.rstrip("/"), backend, model))
    with request.urlopen(url) as remotefile:
        with open(model, 'wb') as localfile:
            localfile.write(remotefile.read())
    return model
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
def get_configs(backend):
    """List the config files shipped in the zoo for *backend*."""
    backend_dir = join(dirname(__file__), backend)
    return listdir(backend_dir)
|
synet/zoo/ultralytics/sabre-detect-vga.yaml
ADDED
|
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
nc: 1 # number of classes
|
| 2 |
+
#kpt_shape: [17, 3]
|
| 3 |
+
depth_multiple: 1 # model depth multiple
|
| 4 |
+
width_multiple: 1 # layer channel multiple
|
| 5 |
+
chip: sabre
|
| 6 |
+
image_shape: [480, 640]
|
| 7 |
+
# anchors:
|
| 8 |
+
# # autogenerated by yolo
|
| 9 |
+
# - [0,0, 0,0, 0,0] # P3/8
|
| 10 |
+
# - [0,0, 0,0, 0,0] # P4/16
|
| 11 |
+
# - [0,0, 0,0, 0,0] # P5/32
|
| 12 |
+
backbone:
|
| 13 |
+
# [from, number, module, args]
|
| 14 |
+
#src num layer params id rf notes
|
| 15 |
+
[[-1, 1, InvertedResidual, [3, 4, 12, 2]], # 0 c1 1 stride -> 2
|
| 16 |
+
[-1, 1, InvertedResidual, [12, 4, 48, 2]], # 1 c2 3 stride -> 4
|
| 17 |
+
[-1, 1, InvertedResidual, [48, 4, 48, 2]], # 2 7 stride -> 8
|
| 18 |
+
[-1, 2, InvertedResidual, [48, 6, 48]], # 3 c3
|
| 19 |
+
[-1, 1, InvertedResidual, [48, 5, 64, 2]], # 4 15 stride -> 16
|
| 20 |
+
[-1, 2, InvertedResidual, [64, 4, 64]], # 5 c4 47
|
| 21 |
+
[-1, 1, InvertedResidual, [64, 3, 64, 2]], # 6 63 stride -> 32
|
| 22 |
+
[-1, 2, InvertedResidual, [64, 2]]] # 7 c5 127
|
| 23 |
+
|
| 24 |
+
# YOLOv5 v6.0 head
|
| 25 |
+
head:
|
| 26 |
+
[[ 5, 1, Head, [64, 64, 3]], # 8 o4 95
|
| 27 |
+
[ 7, 1, Head, [64, 64, 3]], # 9 o5 191
|
| 28 |
+
[[ 8, 9], 1, Detect, [nc, [64, 64], 2]] # 43 Detect(P4-P6)
|
| 29 |
+
] # rfs 127 255
|
synet/zoo/ultralytics/sabre-keypoint-vga.yaml
ADDED
|
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
nc: 1 # number of classes
|
| 2 |
+
kpt_shape: [17, 3]
|
| 3 |
+
depth_multiple: 1 # model depth multiple
|
| 4 |
+
width_multiple: 1 # layer channel multiple
|
| 5 |
+
chip: sabre
|
| 6 |
+
image_shape: [480, 640]
|
| 7 |
+
# anchors:
|
| 8 |
+
# # autogenerated by yolo
|
| 9 |
+
# - [0,0, 0,0, 0,0] # P3/8
|
| 10 |
+
# - [0,0, 0,0, 0,0] # P4/16
|
| 11 |
+
# - [0,0, 0,0, 0,0] # P5/32
|
| 12 |
+
backbone:
|
| 13 |
+
# [from, number, module, args]
|
| 14 |
+
#src num layer params id rf notes
|
| 15 |
+
[[-1, 1, InvertedResidual, [3, 4, 12, 2]], # 0 c1 1 stride -> 2
|
| 16 |
+
[-1, 1, InvertedResidual, [12, 4, 48, 2]], # 1 c2 3 stride -> 4
|
| 17 |
+
[-1, 1, InvertedResidual, [48, 4, 48, 2]], # 2 7 stride -> 8
|
| 18 |
+
[-1, 2, InvertedResidual, [48, 5, 48]], # 3 c3
|
| 19 |
+
[-1, 1, InvertedResidual, [48, 4, 64, 2]], # 4 15 stride -> 16
|
| 20 |
+
[-1, 2, InvertedResidual, [64, 3, 64]], # 5 c4 47
|
| 21 |
+
[-1, 1, InvertedResidual, [64, 2, 64, 2]], # 6 63 stride -> 32
|
| 22 |
+
[-1, 2, InvertedResidual, [64, 1]]] # 7 c5 127
|
| 23 |
+
|
| 24 |
+
# YOLOv5 v6.0 head
|
| 25 |
+
head:
|
| 26 |
+
[[ 5, 1, Head, [64, 64, 3]], # 8 o4 95
|
| 27 |
+
[ 7, 1, Head, [64, 64, 3]], # 9 o5 191
|
| 28 |
+
[[ 8, 9], 1, Pose, [nc, [17,3], [64, 64], 2]] # 43 Detect(P4-P6)
|
| 29 |
+
] # rfs 127 255
|
synet/zoo/ultralytics/sabre-segment-vga.yaml
ADDED
|
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
nc: 1 # number of classes
|
| 2 |
+
#kpt_shape: [17, 3]
|
| 3 |
+
depth_multiple: 1 # model depth multiple
|
| 4 |
+
width_multiple: 1 # layer channel multiple
|
| 5 |
+
chip: sabre
|
| 6 |
+
image_shape: [480, 640]
|
| 7 |
+
# anchors:
|
| 8 |
+
# # autogenerated by yolo
|
| 9 |
+
# - [0,0, 0,0, 0,0] # P3/8
|
| 10 |
+
# - [0,0, 0,0, 0,0] # P4/16
|
| 11 |
+
# - [0,0, 0,0, 0,0] # P5/32
|
| 12 |
+
backbone:
|
| 13 |
+
# [from, number, module, args]
|
| 14 |
+
#src num layer params id rf notes
|
| 15 |
+
[[-1, 1, InvertedResidual, [3, 4, 12, 2]], # 0 c1 1 stride -> 2
|
| 16 |
+
[-1, 1, InvertedResidual, [12, 4, 48, 2]], # 1 c2 3 stride -> 4
|
| 17 |
+
[-1, 1, InvertedResidual, [48, 4, 48, 2]], # 2 7 stride -> 8
|
| 18 |
+
[-1, 2, InvertedResidual, [48, 5, 48]], # 3 c3
|
| 19 |
+
[-1, 1, InvertedResidual, [48, 4, 64, 2]], # 4 15 stride -> 16
|
| 20 |
+
[-1, 2, InvertedResidual, [64, 3, 64]], # 5 c4 47
|
| 21 |
+
[-1, 1, InvertedResidual, [64, 2, 64, 2]], # 6 63 stride -> 32
|
| 22 |
+
[-1, 2, InvertedResidual, [64, 1]]] # 7 c5 127
|
| 23 |
+
|
| 24 |
+
# YOLOv5 v6.0 head
|
| 25 |
+
head:
|
| 26 |
+
[[ 5, 1, Head, [64, 64, 3]], # 8 o4 95
|
| 27 |
+
[ 7, 1, Head, [64, 64, 3]], # 9 o5 191
|
| 28 |
+
[[ 8, 9], 1, Segment, [nc, 32, 96, [64, 64], 2]] # 43 Detect(P4-P6)
|
| 29 |
+
] # rfs 127 255
|
tests/test_demosaic.py
ADDED
|
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from torch import rand
|
| 2 |
+
from torch.nn.init import uniform_
|
| 3 |
+
|
| 4 |
+
from synet.base import Conv2d
|
| 5 |
+
from synet.demosaic import reshape_conv
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
def test_reshape_conv():
    # Checks that reshape_conv on a stride-2 conv matches the original
    # conv once the input is space-to-depth rearranged: each 2x2 spatial
    # block is folded into channels (3x480x640 -> 12x240x320).
    conv = Conv2d(3, 13, 3, 2)
    for param in conv.parameters():
        uniform_(param, -1)
    reshaped_conv = reshape_conv(conv)
    inp = rand(3, 480, 640)
    reshaped_inp = inp.reshape(3, 240, 2, 320, 2
                               ).permute(2, 4, 0, 1, 3
                                         ).reshape(12, 240, 320)
    assert (reshaped_conv(reshaped_inp) - conv(inp)).abs().max() < 1e-5
|
tests/test_keras.py
ADDED
|
@@ -0,0 +1,210 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from numpy import absolute
|
| 2 |
+
from torch import rand
|
| 3 |
+
from torch.nn.init import uniform_
|
| 4 |
+
|
| 5 |
+
from synet.base import askeras
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
BATCH_SIZE = 2
|
| 9 |
+
IN_CHANNELS = 5
|
| 10 |
+
OUT_CHANNELS = 7
|
| 11 |
+
SHAPES = [(i, i) for i in range(4, 8)]
|
| 12 |
+
MAX_DIFF = -1
|
| 13 |
+
TOLERANCE = 2e-4
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
def diff_arr(out1, out2):
    """Return the max absolute difference between two (possibly nested)
    arrays, asserting their shapes agree."""
    if isinstance(out1, (list, tuple)):
        assert isinstance(out2, (list, tuple))
        # recurse pairwise and report the worst element
        return max(diff_arr(a, b) for a, b in zip(out1, out2))
    shapes = (out1.shape, out2.shape)
    assert all(d1 == d2 for d1, d2 in zip(*shapes)), shapes
    return absolute(out1 - out2).max()
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
def t_actv_to_k(actv):
    """Convert torch channels-first activations to keras channels-last
    numpy arrays, recursing through lists and tuples.

    Raises
    ------
    ValueError
        If the tensor rank is not 2, 3, or 4. (Previously an
        unsupported rank fell through and crashed with an
        UnboundLocalError on `tp`.)
    """
    if isinstance(actv, (tuple, list)):
        return [t_actv_to_k(a) for a in actv]
    rank = len(actv.shape)
    if rank == 4:
        tp = 0, 2, 3, 1   # NCHW -> NHWC
    elif rank == 3:
        tp = 0, 2, 1      # NCL -> NLC
    elif rank == 2:
        tp = 0, 1         # (batch, features): already channel-last
    else:
        raise ValueError(f"unsupported activation rank: {rank}")
    return actv.detach().numpy().transpose(*tp)
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
def k_to_numpy(actv):
    """Recursively convert framework tensors to numpy arrays; anything
    without a .numpy() method passes through unchanged."""
    if isinstance(actv, (list, tuple)):
        return [k_to_numpy(item) for item in actv]
    return actv.numpy() if hasattr(actv, "numpy") else actv
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
def validate_layer(layer, torch_inp, **akwds):
    """Given synet layer, test on some torch input activations and
    return max error between two output activations

    Runs the layer once natively (torch), then again inside the
    askeras context on channels-last input, and compares the outputs
    after converting both to channels-last numpy.
    """
    tout = layer(torch_inp[:])
    # re-run the same layer in keras mode on the transposed input;
    # imgsz is taken from the (first) input's spatial dims
    with askeras(imgsz=torch_inp[0].shape[-2:], **akwds):
        kout = k_to_numpy(layer(t_actv_to_k(torch_inp)))
    if isinstance(tout, dict):
        # dict-valued outputs: compare value-by-value under shared keys
        assert len(tout) == len(kout)
        return max(diff_arr(t_actv_to_k(tout[key]), kout[key])
                   for key in tout)
    elif isinstance(tout, list):
        # multi-output layers: compare positionally
        assert len(tout) == len(kout)
        return max(diff_arr(t_actv_to_k(t), k)
                   for t, k in zip(tout, kout))
    return diff_arr(t_actv_to_k(tout), kout)
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
def validate(layer, batch_size=BATCH_SIZE,
             in_channels=IN_CHANNELS, shapes=SHAPES, **akwds):
    """Run validate_layer on a set of random input shapes. Prints the max
    difference between all configurations.

    Each entry of `shapes` is either one spatial shape (single-input
    layer) or a tuple of spatial shapes (multi-input layer); inputs are
    drawn uniformly from [-1, 1).
    """
    # randomize all parameters so the comparison is non-trivial
    for param in layer.parameters():
        uniform_(param, -1)
    # a shape whose first element is itself a tuple means "one input
    # per sub-shape"; otherwise build a single input of that shape
    max_diff = max(validate_layer(layer,
                                  [rand(batch_size, in_channels, *s)*2-1
                                   for s in shape]
                                  if len(shape) and isinstance(shape[0], tuple)
                                  else rand(batch_size, in_channels, *shape
                                            )*2-1,
                                  **akwds)
                   for shape in shapes)
    print("max_diff:", max_diff)
    assert max_diff < TOLERANCE
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
def test_conv2d():
    # torch/keras parity sweep over Conv2d configurations
    from synet.base import Conv2d
    print("testing Conv2d")
    in_channels = 12
    out_channels = 24
    for bias in True, False:
        for kernel, stride in ((1, 1), (2, 1), (3, 1), (3, 2), (4, 1),
                               (4, 2), (4, 3)):
            for padding in True, False:
                for groups in 1, 2, 3:
                    # NOTE(review): `groups` is never passed to Conv2d,
                    # so all three iterations validate the same layer —
                    # confirm whether Conv2d should take a groups arg.
                    validate(Conv2d(in_channels, out_channels, kernel,
                                    stride, bias, padding),
                             in_channels=in_channels)
|
| 98 |
+
|
| 99 |
+
def test_dw_conv2d():
    # torch/keras parity sweep over depthwise conv configurations
    from synet.layers import DepthwiseConv2d
    print("testing dw Conv2d")
    channels = 32
    for bias in True, False:
        for kernel, stride in ((1, 1), (2, 1), (3, 1), (3, 2), (4, 1),
                               (4, 2), (4, 3)):
            for padding in True, False:
                validate(DepthwiseConv2d(channels, kernel,
                                         stride, bias, padding),
                         in_channels=channels)
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
def test_convtranspose():
    # parity check for a 2x upsampling transposed convolution
    from synet.base import ConvTranspose2d
    validate(ConvTranspose2d(IN_CHANNELS, OUT_CHANNELS, 2, 2, 0, bias=True))
|
| 115 |
+
|
| 116 |
+
|
| 117 |
+
def test_relu():
    # parity check for ReLU constructed with .6 — presumably a clip or
    # slope parameter; see synet.base.ReLU for the exact semantics
    from synet.base import ReLU
    validate(ReLU(.6))
|
| 120 |
+
|
| 121 |
+
|
| 122 |
+
def test_upsample():
    # parity check for every supported mode at integer scale factors
    from synet.base import Upsample
    for scale_factor in 1, 2, 3:
        for mode in Upsample.allowed_modes:
            validate(Upsample(scale_factor, mode))
|
| 127 |
+
|
| 128 |
+
|
| 129 |
+
def test_globavgpool():
    # parity check for global average pooling
    from synet.base import GlobalAvgPool
    validate(GlobalAvgPool())
|
| 132 |
+
|
| 133 |
+
|
| 134 |
+
def test_dropout():
    # in eval() mode dropout is an identity regardless of p, so torch
    # and keras outputs must agree exactly
    from synet.base import Dropout
    for p in 0.0, 0.5, 1.0:
        for inplace in True, False:
            layer = Dropout(p, inplace=inplace)
            layer.eval()
            validate(layer)
|
| 141 |
+
|
| 142 |
+
|
| 143 |
+
def test_linear():
    # shapes=[()] makes validate feed 2-d (batch, channels) inputs
    from synet.base import Linear
    for bias in True, False:
        validate(Linear(IN_CHANNELS, OUT_CHANNELS, bias), shapes=[()])
|
| 147 |
+
|
| 148 |
+
|
| 149 |
+
def test_batchnorm():
    # train=True is forwarded to askeras via **akwds — TODO confirm
    # what effect it has on the keras conversion
    from synet.base import BatchNorm
    validate(BatchNorm(IN_CHANNELS), train=True)
|
| 152 |
+
|
| 153 |
+
|
| 154 |
+
def test_ultralytics_detect():
    # parity check for the patched Detect head in tflite-export mode
    from synet.backends.ultralytics import Detect
    # NOTE(review): each sm_split pair presumably selects a head
    # variant — confirm against the Detect constructor signature
    for sm_split in ((True, None), (2, True)):
        layer = Detect(80, (IN_CHANNELS, IN_CHANNELS), *sm_split)
        layer.eval()
        layer.export = True
        layer.format = "tflite"
        layer.stride[0], layer.stride[1] = 1, 2
        # two feature maps per case, the second at lower resolution
        validate(layer,
                 shapes=[((4, 6), (2, 3)),
                         ((5, 7), (3, 4)),
                         ((6, 8), (3, 4))],
                 xywh=True)
|
| 167 |
+
|
| 168 |
+
|
| 169 |
+
def test_ultralytics_pose():
    # parity check for the patched Pose head in tflite-export mode,
    # with and without a per-keypoint presence channel (17x2 vs 17x3)
    from synet.backends.ultralytics import Pose
    for sm_split in ((True, None), (2, True)):
        for kpt_shape in ([17, 2], [17, 3]):
            layer = Pose(80, kpt_shape, (IN_CHANNELS, IN_CHANNELS), *sm_split)
            layer.eval()
            layer.export = True
            layer.format = "tflite"
            layer.stride[0], layer.stride[1] = 1, 2
            validate(layer,
                     shapes=[((4, 6), (2, 3)),
                             ((5, 7), (3, 4)),
                             ((6, 8), (3, 4))],
                     xywh=True)
|
| 183 |
+
|
| 184 |
+
|
| 185 |
+
def test_ultralytics_segment():
    # parity check for the patched Segment head in tflite-export mode
    from synet.backends.ultralytics import Segment
    layer = Segment(nc=80, nm=32, npr=256, ch=(IN_CHANNELS, IN_CHANNELS))
    layer.eval()
    layer.export = True
    layer.format = "tflite"
    layer.stride[0], layer.stride[1] = 1, 2
    validate(layer,
             shapes=[(( 4, 4), (2, 2)),
                     (( 8, 8), (4, 4)),
                     ((12, 12), (6, 6))],
             xywh=True)
|
| 197 |
+
|
| 198 |
+
|
| 199 |
+
def test_ultralytics_classify():
    # parity check for the patched Classify head in tflite-export mode
    from synet.backends.ultralytics import Classify
    layer = Classify(None, c1=IN_CHANNELS, c2=OUT_CHANNELS)
    layer.eval()
    layer.export = True
    layer.format = 'tflite'
    validate(layer)
|
| 206 |
+
|
| 207 |
+
|
| 208 |
+
def test_channelslice():
    # parity check for slicing channels 4:8 out of a 12-channel input
    from synet.base import ChannelSlice
    validate(ChannelSlice(slice(4, 8)), in_channels=12)
|
tests/test_ultralytics.py
ADDED
|
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from os import chdir
|
| 2 |
+
from synet.backends import get_backend
|
| 3 |
+
from synet.quantize import main
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
backend = get_backend("ultralytics")
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
def test_quantize(tmp_path):
    # End-to-end smoke test: run the quantization CLI over every zoo
    # config, writing artifacts into pytest's tmp_path; the second run
    # overrides the input resolution with odd dimensions.
    chdir(tmp_path)
    for config in backend.get_configs():
        main(("--backend=ultralytics",
              "--model="+config,
              "--number=1"))
        main(("--backend=ultralytics",
              "--model="+config,
              "--number=1",
              "--image-shape", "321", "319"))
|