Spaces:
Sleeping
Sleeping
File size: 1,797 Bytes
1c77735 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 | import numpy as np
from app.preprocessing.base import PreprocessingStep, PreprocessingContext, PreprocessingError
# Accepted values for the ``dtype`` param of TensorizeStep, mapped to the
# numpy scalar type used for the ``astype`` cast.
DTYPE_MAP = dict(
    float32=np.float32,
    float64=np.float64,
    int32=np.int32,
    int64=np.int64,
    uint8=np.uint8,
)
class TensorizeStep(PreprocessingStep):
    """Pack ``ctx.patches`` or ``ctx.image_array`` into ``ctx.tensor``.

    Patches take priority over the full image when both are available.

    Params (all optional):
        dtype: key into ``DTYPE_MAP`` (default ``"float32"``). Unrecognized
            names fall back to float32.
        add_batch_dim: prepend a length-1 leading axis (default True).
        channel_first: image_array path only; when the (possibly expanded)
            array is 3-D, move its last axis to the front (default True).

    Records ``{"shape", "dtype", "from"}`` in
    ``ctx.step_outputs["tensorize"]``.
    """

    name = "tensorize"
    description = "Convert processed image or patches to model-ready tensor format"
    version = "1.0.0"
    order = 8
    enabled = True
    required = False

    async def process(self, ctx: PreprocessingContext, params: dict) -> PreprocessingContext:
        """Build ``ctx.tensor`` from patches or the image and return ``ctx``.

        Raises:
            PreprocessingError: if there are no patches and ``image_array``
                is None.
        """
        dtype_name = params.get("dtype", "float32")
        add_batch_dim = params.get("add_batch_dim", True)
        channel_first = params.get("channel_first", True)
        # Deliberate best-effort: unknown names fall back to float32. The
        # dtype actually applied is what gets reported below.
        dtype = DTYPE_MAP.get(dtype_name, np.float32)

        patches = ctx.patches
        # Explicit emptiness test: plain truthiness raises ValueError when
        # patches is a multi-element numpy array rather than a list.
        if patches is not None and len(patches) > 0:
            tensor = self._tensorize_patches(patches, dtype, add_batch_dim)
            source = "patches"
        elif ctx.image_array is not None:
            tensor = self._tensorize_image(
                ctx.image_array, dtype, add_batch_dim, channel_first
            )
            source = "image_array"
        else:
            raise PreprocessingError("No patches or image_array available for tensorization")

        ctx.tensor = tensor
        ctx.step_outputs["tensorize"] = {
            "shape": list(tensor.shape),
            # Report the dtype really used, not the requested name, so an
            # unrecognized request is not recorded as if it had been honored.
            "dtype": np.dtype(dtype).name,
            "from": source,
        }
        return ctx

    @staticmethod
    def _tensorize_patches(patches, dtype, add_batch_dim):
        """Stack patches along a new axis 0, cast, and optionally add a batch axis."""
        # NOTE(review): channel_first is not applied here; each patch keeps
        # its own axis order inside the stack. Confirm this matches the model.
        stacked = np.stack(patches, axis=0).astype(dtype)
        if add_batch_dim:
            stacked = stacked[np.newaxis, ...]
        return stacked

    @staticmethod
    def _tensorize_image(image, dtype, add_batch_dim, channel_first):
        """Cast one image, ensure a channel axis, optionally go channel-first and batch it."""
        arr = image.astype(dtype)
        if arr.ndim == 2:
            # 2-D input gets an explicit trailing channel axis of size 1.
            arr = arr[:, :, np.newaxis]
        if channel_first and arr.ndim == 3:
            # Move the last axis to the front (presumably HWC -> CHW). Copy to
            # C-contiguous memory: np.transpose alone returns a strided view.
            arr = np.ascontiguousarray(np.transpose(arr, (2, 0, 1)))
        if add_batch_dim:
            arr = arr[np.newaxis, ...]
        return arr
|