Spaces:
Running
Running
| import numpy as np | |
| import pytest | |
| from doctr.contrib import artefacts | |
| from doctr.contrib.base import _BasePredictor | |
| from doctr.io import DocumentFile | |
def test_base_predictor():
    """Check the construction and hook contract of ``_BasePredictor``.

    Building a predictor with neither a ``url`` nor a ``model_path`` must raise
    ``ValueError``. Once built from the ``yolov8_artefact`` config URL, the
    base-class ``preprocess`` and ``postprocess`` hooks must raise
    ``NotImplementedError`` because subclasses are expected to implement them.
    """
    # No url and no model_path: the constructor must refuse.
    with pytest.raises(ValueError):
        _ = _BasePredictor(batch_size=2)

    artefact_url = artefacts.default_cfgs["yolov8_artefact"]["url"]
    predictor = _BasePredictor(batch_size=2, url=artefact_url)

    # Any small HxWxC array is enough; the hooks must raise before using it.
    blank = np.zeros((10, 10, 3))

    with pytest.raises(NotImplementedError):
        predictor.preprocess(blank)
    with pytest.raises(NotImplementedError):
        predictor.postprocess([blank], [[blank]])
def test_artefact_detector(mock_artefact_image_stream):
    """End-to-end check of ``ArtefactDetector`` on the mock artefact image.

    Runs the detector on a one-image document and validates:
    - the output is a list holding a single list of artefact dicts;
    - exactly 9 artefacts are detected on the mock image;
    - every artefact has a ``str`` ``label``, a ``float`` ``confidence`` and a
      4-element ``box`` of ``int`` coordinates;
    - ``show`` can be called without blocking the test run.
    """
    doc = DocumentFile.from_images([mock_artefact_image_stream])
    detector = artefacts.ArtefactDetector(batch_size=2, conf_threshold=0.5, iou_threshold=0.5)
    results = detector(doc)

    # Separate asserts so a failure pinpoints which structural check broke.
    assert isinstance(results, list)
    assert len(results) == 1
    page_artefacts = results[0]
    assert isinstance(page_artefacts, list)

    # Check the count before touching individual entries: an empty result now
    # fails here with a clear message instead of an IndexError on entry [0].
    # The mock image is expected to yield exactly 9 artefacts.
    assert len(page_artefacts) == 9

    for artefact in page_artefacts:
        assert isinstance(artefact, dict)
        # Validate the schema on every artefact, not only the first one.
        assert all(key in artefact for key in ("label", "confidence", "box"))
        assert isinstance(artefact["label"], str)
        assert isinstance(artefact["confidence"], float)
        assert len(artefact["box"]) == 4
        assert all(isinstance(coord, int) for coord in artefact["box"])

    # Visualization must be non-blocking so the test run is not held up.
    detector.show(block=False)