File size: 1,923 Bytes
f3270e6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
from pathlib import Path

import numpy as np
import pytest

from doctr import datasets


def test_visiondataset():
    """``VisionDataset`` must raise without download permission, and be empty once downloaded."""
    archive_url = "https://github.com/mindee/doctr/releases/download/v0.6.0/mnist.zip"
    vision_cls = datasets.datasets.VisionDataset

    # Refusing the download is expected to raise ValueError
    # (presumably because the archive is not cached locally — not visible from here)
    with pytest.raises(ValueError):
        vision_cls(archive_url, download=False)

    # Once downloaded and extracted, the bare base class is expected to hold no samples
    vision_ds = vision_cls(archive_url, download=True, extract_archive=True)
    assert len(vision_ds) == 0
    assert repr(vision_ds) == "VisionDataset()"


def test_abstractdataset(mock_image_path):
    """Exercise ``AbstractDataset`` target validation, transforms and target isolation.

    Args:
        mock_image_path: pytest fixture giving the path of an image file on disk
            (presumably defined in the suite's conftest — not visible here).
    """
    # A root folder that presumably does not exist is expected to raise ValueError
    with pytest.raises(ValueError):
        datasets.datasets.AbstractDataset("my/fantasy/folder")

    # Build a dataset rooted at the folder that holds the mock image
    path = Path(mock_image_path)
    ds = datasets.datasets.AbstractDataset(path.parent)

    # Check target format: malformed targets must trip an assertion when the sample is read
    with pytest.raises(AssertionError):
        # bare int target
        ds.data = [(path.name, 0)]
        img, target = ds[0]
    with pytest.raises(AssertionError):
        # dict target missing the "labels" key
        ds.data = [(path.name, dict(boxes=np.array([[0, 0, 1, 1]])))]
        img, target = ds[0]
    with pytest.raises(AssertionError):
        # dict target missing the "boxes" key
        ds.data = [(path.name, {"label": "A"})]
        img, target = ds[0]

    # Patch some valid data: a numpy-array target
    ds.data = [(path.name, np.array([0]))]

    # Fetch the img
    img, target = ds[0]
    # np.array_equal instead of `==`: bare array comparison only works as a bool for size-1 arrays
    assert isinstance(target, np.ndarray) and np.array_equal(target, np.array([0]))

    # Check img_transforms: the image is transformed, the target is left unchanged
    ds.img_transforms = lambda x: 1 - x
    img2, target2 = ds[0]
    assert np.all(img2.numpy() == 1 - img.numpy())
    assert np.array_equal(target, target2)

    # Check sample_transforms: it receives (img, target) and its outputs are returned
    ds.img_transforms = None
    ds.sample_transforms = lambda x, y: (x, y + 1)
    img3, target3 = ds[0]
    assert np.all(img3.numpy() == img.numpy()) and np.array_equal(target3, target + 1)

    # Check inplace modifications
    def inplace_transfo(x, target):
        target += "B"
        return x, target

    # str `+=` rebinds rather than mutates, so this only checks repeated fetches agree
    ds.data = [(path.name, "A")]
    ds.sample_transforms = inplace_transfo
    _, t = ds[0]
    _, t = ds[0]
    assert t == "AB"

    # With a mutable target, an in-place edit made by sample_transforms must not
    # leak back into ds.data (otherwise the second fetch would see [1] and yield [2])
    def inplace_increment(x, target):
        target += 1  # ndarray `+=` mutates the array in place
        return x, target

    ds.data = [(path.name, np.array([0]))]
    ds.sample_transforms = inplace_increment
    _, t = ds[0]
    _, t = ds[0]
    assert np.array_equal(t, np.array([1]))
    assert np.array_equal(ds.data[0][1], np.array([0]))