| using System.Collections.Generic; |
| using Unity.Sentis; |
| using UnityEngine; |
| using UnityEngine.UI; |
| using UnityEngine.Video; |
| using Lays = Unity.Sentis.Layers; |
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
|
|
/// <summary>
/// Runs a YOLOv8n model with Unity Sentis on the frames of a looping video and overlays
/// labelled bounding boxes for the detections on top of a <see cref="RawImage"/>.
/// </summary>
public class RunYOLO8n : MonoBehaviour
{
    // File names; both are loaded from Application.streamingAssetsPath.
    const string modelName = "yolov8n.sentis";

    const string videoName = "giraffes.mp4";

    // Text asset with one class name per line; line index == class id.
    public TextAsset labelsAsset;

    // UI image that shows the processed frame and parents the box overlays.
    public RawImage displayImage;

    // Sliced sprite used as the frame of each bounding box.
    public Sprite boxTexture;

    // Font used for the class label drawn inside each box.
    public Font font;

    // Backend used for both the worker and the Ops helper.
    const BackendType backend = BackendType.GPUCompute;

    private Transform displayLocation;
    private Model model;
    private IWorker engine;
    private string[] labels;
    private RenderTexture targetRT;

    // Size of the square render texture that is converted into the model input tensor.
    private const int imageWidth = 640;
    private const int imageHeight = 640;

    // Number of class-score channels that follow the 4 box channels in "output0".
    private const int numClasses = 80;

    private VideoPlayer video;

    // Box GameObjects are created on demand and reused every frame.
    List<GameObject> boxPool = new List<GameObject>();

    // NOTE: these three values are baked into the model as constants by LoadModel(),
    // so changing them after Start() has no effect on the running model.
    [SerializeField, Range(0, 1)] float iouThreshold = 0.5f;
    [SerializeField, Range(0, 1)] float scoreThreshold = 0.5f;
    int maxOutputBoxes = 64;

    // Helper for the tensor operations performed outside of the model (Slice/Gather).
    Ops ops;

    /// <summary>
    /// A detection expressed in the local space of <see cref="displayImage"/>:
    /// centre position, size and the class label to show.
    /// </summary>
    public struct BoundingBox
    {
        public float centerX;
        public float centerY;
        public float width;
        public float height;
        public string label;
    }

    /// <summary>
    /// Configures the application, validates the inspector references, then creates the
    /// ops helper, labels, model, render target, worker and video input (in that order).
    /// </summary>
    void Start()
    {
        Application.targetFrameRate = 60;
        Screen.orientation = ScreenOrientation.LandscapeLeft;

        // Fail early with a readable message instead of a NullReferenceException
        // half-way through initialisation (and again on every Update).
        if (labelsAsset == null || displayImage == null)
        {
            Debug.LogError($"{nameof(RunYOLO8n)}: '{nameof(labelsAsset)}' and '{nameof(displayImage)}' must be assigned in the inspector.", this);
            enabled = false;
            return;
        }

        ops = WorkerFactory.CreateOps(backend, null);

        labels = ParseLabels(labelsAsset.text);

        LoadModel();

        targetRT = new RenderTexture(imageWidth, imageHeight, 0);

        // New boxes are parented to the display image so they move/scale with it.
        displayLocation = displayImage.transform;

        engine = WorkerFactory.CreateWorker(backend, model);

        SetupInput();
    }

    /// <summary>
    /// Splits the labels text into one entry per line. Entries are trimmed so that files
    /// with CRLF line endings do not leave a trailing '\r' on every label. Empty lines are
    /// kept so that the line index still matches the class id.
    /// </summary>
    /// <param name="text">Raw contents of the labels asset.</param>
    /// <returns>One trimmed label per line of <paramref name="text"/>.</returns>
    static string[] ParseLabels(string text)
    {
        string[] lines = text.Split('\n');
        for (int i = 0; i < lines.Length; i++)
        {
            lines[i] = lines[i].Trim();
        }
        return lines;
    }

    /// <summary>
    /// Loads the model from StreamingAssets and appends post-processing layers to it:
    /// box coordinates and class scores are sliced out of "output0", the best class per
    /// box is computed, and NonMaxSuppression is applied. The model outputs are then
    /// replaced with "boxCoords", "classIDs" and "NMS".
    /// </summary>
    void LoadModel()
    {
        model = ModelLoader.Load(Application.streamingAssetsPath + "/" + modelName);

        // Log the class names stored in the model metadata.
        Debug.Log($"Class names: \n{model.Metadata["names"]}");

        // Integer constants referenced by name as slice bounds / axes below.
        model.AddConstant(new Lays.Constant("0", new int[] { 0 }));
        model.AddConstant(new Lays.Constant("1", new int[] { 1 }));
        model.AddConstant(new Lays.Constant("4", new int[] { 4 }));

        // Values read from the fields once; later changes to the fields are not picked up.
        model.AddConstant(new Lays.Constant("classes_plus_4", new int[] { numClasses + 4 }));
        model.AddConstant(new Lays.Constant("maxOutputBoxes", new int[] { maxOutputBoxes }));
        model.AddConstant(new Lays.Constant("iouThreshold", new float[] { iouThreshold }));
        model.AddConstant(new Lays.Constant("scoreThreshold", new float[] { scoreThreshold }));

        // Entries [0, 4) along axis 1 are the box coordinates; they are transposed so the
        // box index comes before the coordinate index.
        // NOTE(review): presumably "output0" is (1, 4 + numClasses, numBoxes) - confirm.
        model.AddLayer(new Lays.Slice("boxCoords0", "output0", "0", "4", "1"));
        model.AddLayer(new Lays.Transpose("boxCoords", "boxCoords0", new int[] { 0, 2, 1 }));

        // Entries [4, 4 + numClasses) along axis 1 are the per-class scores; reduce them
        // over axis 1 to the best score and the id of the best class for every box.
        model.AddLayer(new Lays.Slice("scores0", "output0", "4", "classes_plus_4", "1"));
        model.AddLayer(new Lays.ReduceMax("scores", new[] { "scores0", "1" }));
        model.AddLayer(new Lays.ArgMax("classIDs", "scores0", 1));

        // Boxes are given as centre + size, hence CenterPointBox.Center.
        model.AddLayer(new Lays.NonMaxSuppression("NMS", "boxCoords", "scores",
            "maxOutputBoxes", "iouThreshold", "scoreThreshold",
            centerPointBox: Lays.CenterPointBox.Center
        ));

        // Expose only the tensors that ExecuteML() reads.
        model.outputs.Clear();
        model.AddOutput("boxCoords");
        model.AddOutput("classIDs");
        model.AddOutput("NMS");
    }

    /// <summary>
    /// Adds a <see cref="VideoPlayer"/> to this GameObject that loops the video from
    /// StreamingAssets. APIOnly render mode is used because the frames are read through
    /// <c>video.texture</c> rather than rendered by the player itself.
    /// </summary>
    void SetupInput()
    {
        video = gameObject.AddComponent<VideoPlayer>();
        video.renderMode = VideoRenderMode.APIOnly;
        video.source = VideoSource.Url;
        video.url = Application.streamingAssetsPath + "/" + videoName;
        video.isLooping = true;
        video.Play();
    }

    /// <summary>Runs detection every frame and quits the application on Escape.</summary>
    private void Update()
    {
        ExecuteML();

        if (Input.GetKeyDown(KeyCode.Escape))
        {
            Application.Quit();
        }
    }

    /// <summary>
    /// Hides the previous boxes, copies the current video frame into the render target,
    /// runs the model on it and draws one box per detection kept by NonMaxSuppression.
    /// Does nothing beyond hiding the boxes when the worker or a video frame is not
    /// available yet.
    /// </summary>
    public void ExecuteML()
    {
        ClearAnnotations();

        // Not initialised (or initialisation failed), or no decoded frame yet.
        if (engine == null || !video || !video.texture)
        {
            return;
        }

        // Horizontal scale of 1/aspect keeps the frame from being stretched into the
        // square target. NOTE(review): presumably assumes a landscape video - confirm.
        float aspect = video.width * 1f / video.height;
        Graphics.Blit(video.texture, targetRT, new Vector2(1f / aspect, 1), new Vector2(0, 0));
        displayImage.texture = targetRT;

        using var input = TextureConverter.ToTensor(targetRT, imageWidth, imageHeight, 3);
        engine.Execute(input);

        // Peeked outputs stay owned by the worker, so they are not disposed here.
        var boxCoords = engine.PeekOutput("boxCoords") as TensorFloat;
        var NMS = engine.PeekOutput("NMS") as TensorInt;
        var classIDs = engine.PeekOutput("classIDs") as TensorInt;

        // Take column 2 of the NMS result and use it to gather the selected boxes and
        // their class ids. NOTE(review): presumably column 2 is the box index, as in
        // ONNX NonMaxSuppression - confirm.
        using var boxIDs = ops.Slice(NMS, new int[] { 2 }, new int[] { 3 }, new int[] { 1 }, new int[] { 1 });
        using var boxIDsFlat = boxIDs.ShallowReshape(new TensorShape(boxIDs.shape.length)) as TensorInt;
        using var output = ops.Gather(boxCoords, boxIDsFlat, 1);
        using var labelIDs = ops.Gather(classIDs, boxIDsFlat, 2);

        // Make the results indexable from C#.
        output.MakeReadable();
        labelIDs.MakeReadable();

        float displayWidth = displayImage.rectTransform.rect.width;
        float displayHeight = displayImage.rectTransform.rect.height;

        // Factors that map model-input pixels to display-image units.
        float scaleX = displayWidth / imageWidth;
        float scaleY = displayHeight / imageHeight;

        int detectionCount = output.shape[1];
        for (int n = 0; n < detectionCount; n++)
        {
            // Guard against a labels file that has fewer lines than the model has
            // classes: fall back to the numeric id instead of throwing.
            int classId = labelIDs[0, 0, n];
            string labelText = classId >= 0 && classId < labels.Length
                ? labels[classId]
                : classId.ToString();

            var box = new BoundingBox
            {
                // Shift by half the display size so coordinates are relative to the
                // centre of the display image rather than its corner.
                centerX = output[0, n, 0] * scaleX - displayWidth / 2,
                centerY = output[0, n, 1] * scaleY - displayHeight / 2,
                width = output[0, n, 2] * scaleX,
                height = output[0, n, 3] * scaleY,
                label = labelText,
            };
            DrawBox(box, n);
        }
    }

    /// <summary>
    /// Shows <paramref name="box"/> using the pooled GameObject at index
    /// <paramref name="id"/>, creating a new pooled box when the pool is too small.
    /// </summary>
    /// <param name="box">Position, size and label in display-image local space.</param>
    /// <param name="id">Index of the detection within the current frame.</param>
    public void DrawBox(BoundingBox box, int id)
    {
        GameObject panel;
        if (id < boxPool.Count)
        {
            panel = boxPool[id];
            panel.SetActive(true);
        }
        else
        {
            panel = CreateNewBox(Color.yellow);
        }

        // Y is negated because the box's Y grows downwards while UI local Y grows upwards.
        panel.transform.localPosition = new Vector3(box.centerX, -box.centerY);

        RectTransform rt = panel.GetComponent<RectTransform>();
        rt.sizeDelta = new Vector2(box.width, box.height);

        var label = panel.GetComponentInChildren<Text>();
        label.text = box.label;
    }

    /// <summary>
    /// Creates a box GameObject (sliced image plus a child text label) under the display
    /// image, adds it to the pool and returns it.
    /// </summary>
    /// <param name="color">Colour applied to both the box image and the label text.</param>
    /// <returns>The newly created box panel.</returns>
    public GameObject CreateNewBox(Color color)
    {
        // The box itself.
        var panel = new GameObject("ObjectBox");
        panel.AddComponent<CanvasRenderer>();
        Image img = panel.AddComponent<Image>();
        img.color = color;
        img.sprite = boxTexture;
        img.type = Image.Type.Sliced;
        panel.transform.SetParent(displayLocation, false);

        // The label, as a child of the box.
        var text = new GameObject("ObjectLabel");
        text.AddComponent<CanvasRenderer>();
        text.transform.SetParent(panel.transform, false);
        Text txt = text.AddComponent<Text>();
        txt.font = font;
        txt.color = color;
        txt.fontSize = 40;
        txt.horizontalOverflow = HorizontalWrapMode.Overflow;

        // Offsets of (20, 0) and (0, 30), then anchors stretched over the whole box.
        RectTransform rt2 = text.GetComponent<RectTransform>();
        rt2.offsetMin = new Vector2(20, 0);
        rt2.offsetMax = new Vector2(0, 30);
        rt2.anchorMin = new Vector2(0, 0);
        rt2.anchorMax = new Vector2(1, 1);

        boxPool.Add(panel);
        return panel;
    }

    /// <summary>Hides every pooled box; the GameObjects are kept for reuse.</summary>
    public void ClearAnnotations()
    {
        foreach (var box in boxPool)
        {
            box.SetActive(false);
        }
    }

    /// <summary>
    /// Disposes the worker and ops helper and frees the render texture created in Start().
    /// </summary>
    private void OnDestroy()
    {
        engine?.Dispose();
        ops?.Dispose();

        if (targetRT != null)
        {
            // Do not leave the display image pointing at a destroyed texture.
            if (displayImage != null && displayImage.texture == targetRT)
            {
                displayImage.texture = null;
            }

            targetRT.Release();
            Destroy(targetRT);
        }
    }
}
|
|