File size: 1,969 Bytes
36c95ba
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
# TODO: remove the type: ignore in below after deprecating python 3.6
from dataclasses import dataclass, field  # type: ignore
from enum import Enum

import torch.nn as nn

# import yaml  # type: ignore


class TrainerState(Enum):
    """The discrete states a trainer can be in.

    Every member is bound to an explicit integer value, listed in ascending
    order: STARTING (0), TRAINING (1), VALIDATE (2), TERMINATE (3).
    """

    STARTING = 0
    TRAINING = 1
    VALIDATE = 2
    TERMINATE = 3


# NOTE: redefine or extend this class with the parameters your training setup needs.
@dataclass
class Configuration:
    """Default paths and hyper-parameters for a training run.

    Every field carries a human-readable description under ``metadata["help"]``.
    """

    data_path: str = field(default="./", metadata={"help": "The input data directory."})
    batch_size: int = field(
        default=1, metadata={"help": "The number of samples per batch for the training dataloader."}
    )
    num_epochs: int = field(default=1, metadata={"help": "The number of epochs to run the training."})
    lr: float = field(default=1e-3, metadata={"help": "The learning rate to be used for the optimizer."})
    output_path: str = field(default="./output", metadata={"help": "The output data directory."})
    # A tuple default is immutable, so sharing it across instances is safe
    # (no default_factory needed, unlike a list default).
    image_size: tuple = field(default=(224, 224), metadata={"help": "The input image size."})

    # TODO: possibly remove, because hydra already does this.
    # def __init__(self, **entries):
    #     for k, v in entries.items():
    #         self.__dict__[k] = Configuration(**v) if isinstance(v, dict) else v

    # @classmethod
    # def from_yaml(cls, config_file: str):
    #     """Create an instance of the configuration from a yaml file."""
    #     with open(config_file) as f:
    #         data = yaml.safe_load(f)
    #     return cls(**data)


class Lambda(nn.Module):
    """Module to create a lambda function as nn.Module.

    Wraps an arbitrary callable so it can be used wherever an ``nn.Module`` is
    expected (e.g. as a stage of ``nn.Sequential``).

    Args:
        fcn: any callable. It receives the first input of :meth:`forward` plus
            any additional positional and keyword arguments passed along.

    Raises:
        TypeError: if ``fcn`` is not callable.

    Example:
        >>> import torch
        >>> import kornia as K
        >>> fcn = Lambda(lambda x: K.geometry.resize(x, (32, 16)))
        >>> fcn(torch.rand(1, 4, 64, 32)).shape
        torch.Size([1, 4, 32, 16])
    """

    def __init__(self, fcn):
        super().__init__()
        # Fail at construction time instead of on the first forward pass.
        if not callable(fcn):
            raise TypeError(f"fcn must be callable, got {type(fcn).__name__}")
        self.fcn = fcn

    def forward(self, x, *args, **kwargs):
        # ``x`` stays the first parameter so existing single-input calls are
        # unchanged; extra arguments are forwarded for multi-input callables.
        return self.fcn(x, *args, **kwargs)