File size: 812 Bytes
36c95ba |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 |
from typing import Callable
import torch
import torch.nn as nn
class Lambda(nn.Module):
    """Applies a user-defined callable as a transform.

    Args:
        func: Callable invoked on every forward pass. It receives the input
            tensor followed by any extra positional and keyword arguments
            given to :meth:`forward`.

    Returns:
        The output of the user-defined callable.

    Raises:
        TypeError: If ``func`` is not callable.

    Example:
        >>> import kornia
        >>> x = torch.rand(1, 3, 5, 5)
        >>> f = Lambda(lambda x: kornia.color.rgb_to_grayscale(x))
        >>> f(x).shape
        torch.Size([1, 1, 5, 5])
    """

    def __init__(self, func: Callable) -> None:
        super().__init__()
        if not callable(func):
            # Name the actual constructor parameter (``func``) so the message
            # matches the signature users see.
            raise TypeError(f"Argument func should be callable, got {type(func).__name__!r}")
        self.func = func

    def forward(self, img: torch.Tensor, *args, **kwargs) -> torch.Tensor:
        """Apply the wrapped callable to ``img``.

        Args:
            img: Input tensor, passed as the first argument to ``func``.
            *args: Extra positional arguments forwarded unchanged to ``func``.
            **kwargs: Extra keyword arguments forwarded unchanged to ``func``.

        Returns:
            Whatever ``func`` returns; the result is not checked to be a tensor.
        """
        return self.func(img, *args, **kwargs)
|