File size: 897 Bytes
434b0b0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
import torch.nn as nn


def linear(*args, **kwargs):
    """Build and return an ``nn.Linear`` layer.

    Every positional and keyword argument is forwarded unchanged to the
    ``nn.Linear`` constructor; this helper adds no behavior of its own.
    """
    layer = nn.Linear(*args, **kwargs)
    return layer


class LinerParameterTuner:
    """Piecewise-linear schedule for a scalar hyper-parameter.

    The value is held at ``start_value`` before step ``start``. It then ramps
    linearly from ``start_value`` to ``end_value`` over ``[start, end)``. From
    step ``end`` onwards it is held at ``end_value``.

    Args:
        start: Step at which the ramp begins.
        start_value: Value at the beginning of the ramp (returned for steps
            before ``start``).
        end_value: Value at the end of the ramp (returned for steps at or
            after ``end``).
        end: Step at which the ramp finishes.

    Note:
        The class name keeps its historical misspelling for backward
        compatibility. ``LinearParameterTuner`` is an alias for it.
    """

    def __init__(self, start, start_value, end_value, end):
        self.start = start
        self.start_value = start_value
        self.end_value = end_value
        self.end = end
        # Length of the ramp; 0 when start == end (an instantaneous switch).
        self.total_steps = self.end - self.start

    def get_value(self, step):
        """Return the scheduled value at ``step``.

        Args:
            step: Current step.

        Returns:
            ``start_value`` for ``step < start``, ``end_value`` for
            ``step >= end``, and otherwise the linear interpolation between
            the two.
        """
        if step < self.start:
            return self.start_value
        # ``>=`` (rather than ``>``) returns end_value exactly at the final
        # step, and avoids dividing by a zero-length ramp when start == end.
        if step >= self.end:
            return self.end_value

        ratio = (step - self.start) / self.total_steps
        return self.start_value + ratio * (self.end_value - self.start_value)


# Correctly spelled alias; the original name is kept for existing callers.
LinearParameterTuner = LinerParameterTuner


class StaticParameterTuner:
    """Schedule that always yields one fixed value, regardless of the step.

    It offers the same ``get_value(step)`` method signature as the
    linear tuner defined in this module.
    """

    def __init__(self, v):
        # The constant returned by every get_value call.
        self.v = v

    def get_value(self, step):
        """Return the stored constant; ``step`` is accepted but ignored."""
        constant = self.v
        return constant