File size: 812 Bytes
36c95ba
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
from typing import Callable

import torch
import torch.nn as nn


class Lambda(nn.Module):
    """Applies user-defined lambda as a transform.

    Wraps an arbitrary callable in an ``nn.Module`` so it can be used
    wherever a module is expected (e.g. inside ``nn.Sequential``).

    Args:
        func: Callable function. It is invoked with the input passed to
            :meth:`forward` followed by any extra positional and keyword
            arguments.

    Returns:
        The output of the user-defined lambda.

    Raises:
        TypeError: If ``func`` is not callable.

    Example:
        >>> import kornia
        >>> x = torch.rand(1, 3, 5, 5)
        >>> f = Lambda(lambda x: kornia.color.rgb_to_grayscale(x))
        >>> f(x).shape
        torch.Size([1, 1, 5, 5])
    """

    def __init__(self, func: Callable) -> None:
        super().__init__()
        if not callable(func):
            # Name the actual parameter (`func`) so the message matches the signature.
            raise TypeError(f"Argument func should be callable, got {type(func).__name__!r}")

        self.func = func

    def forward(self, img: torch.Tensor, *args, **kwargs) -> torch.Tensor:
        """Call the wrapped function on ``img``, forwarding any extra arguments."""
        return self.func(img, *args, **kwargs)