File size: 1,797 Bytes
1c77735
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
import numpy as np
from app.preprocessing.base import PreprocessingStep, PreprocessingContext, PreprocessingError


# Names accepted for the "dtype" parameter, each mapped to the NumPy scalar
# type that carries the same name.
DTYPE_MAP = {
    type_name: getattr(np, type_name)
    for type_name in ("float32", "float64", "int32", "int64", "uint8")
}


class TensorizeStep(PreprocessingStep):
    """Pipeline step that stores a NumPy tensor on the preprocessing context.

    The tensor is built from ``ctx.patches`` when that is truthy, otherwise
    from ``ctx.image_array``.
    """

    name = "tensorize"
    description = "Convert processed image or patches to model-ready tensor format"
    version = "1.0.0"
    order = 8
    enabled = True
    required = False

    async def process(self, ctx: PreprocessingContext, params: dict) -> PreprocessingContext:
        """Build ``ctx.tensor`` and record its metadata in ``ctx.step_outputs``.

        Args:
            ctx: Context providing ``patches`` and/or ``image_array``; it is
                mutated in place.
            params: Step options, all optional:
                ``"dtype"`` - a key of ``DTYPE_MAP`` (default ``"float32"``);
                an unrecognised name falls back to ``np.float32``.
                ``"add_batch_dim"`` - prepend a length-1 axis (default True).
                ``"channel_first"`` - move the last axis of a 3-D
                ``image_array`` to the front (default True).

        Returns:
            The same ``ctx``, with ``tensor`` set and
            ``step_outputs["tensorize"]`` holding ``shape``, ``dtype`` and
            ``from``.

        Raises:
            PreprocessingError: If ``ctx.patches`` is falsy and
                ``ctx.image_array`` is None.
        """
        dtype_name = params.get("dtype", "float32")
        add_batch_dim = params.get("add_batch_dim", True)
        channel_first = params.get("channel_first", True)
        # Unknown dtype names are tolerated and fall back to float32.
        dtype = DTYPE_MAP.get(dtype_name, np.float32)

        if ctx.patches:
            # Patches are stacked along a new leading axis.
            # NOTE(review): channel_first is not applied on this path, so the
            # per-patch axis order is kept as-is - confirm that is intended.
            tensor = np.stack(ctx.patches, axis=0).astype(dtype)
            source = "patches"
        elif ctx.image_array is not None:
            tensor = ctx.image_array.astype(dtype)
            if tensor.ndim == 2:
                # Give 2-D input an explicit trailing axis.
                tensor = tensor[:, :, np.newaxis]
            if channel_first and tensor.ndim == 3:
                # Move the last axis to the front: (a, b, c) -> (c, a, b).
                tensor = np.transpose(tensor, (2, 0, 1))
            source = "image_array"
        else:
            raise PreprocessingError("No patches or image_array available for tensorization")

        if add_batch_dim:
            tensor = tensor[np.newaxis, ...]

        ctx.tensor = tensor
        ctx.step_outputs["tensorize"] = {
            "shape": list(tensor.shape),
            # Report the dtype actually applied, so the metadata stays correct
            # when an unrecognised name triggered the float32 fallback.
            "dtype": np.dtype(dtype).name,
            "from": source,
        }

        return ctx