| import torch.nn as nn | |
def linear(*args, **kwargs):
    """
    Build a fully connected layer.

    Every positional and keyword argument is passed through unchanged to
    ``nn.Linear``, and the module it constructs is returned.
    """
    layer = nn.Linear(*args, **kwargs)
    return layer
class LinerParameterTuner:
    """
    Linearly interpolate a parameter between two values over a step range.

    Before ``start`` the tuner yields ``start_value``; after ``end`` it yields
    ``end_value``; in between it moves along the straight line joining the
    two values.
    """

    def __init__(self, start, start_value, end_value, end):
        """
        Store the ramp boundaries.

        Args:
            start: Step at which the ramp begins.
            start_value: Value returned for steps before ``start``.
            end_value: Value returned for steps after ``end``.
            end: Step at which the ramp finishes.
        """
        self.start = start
        self.start_value = start_value
        self.end_value = end_value
        self.end = end
        # Length of the ramp in steps; zero when start == end.
        self.total_steps = self.end - self.start

    def get_value(self, step):
        """
        Return the parameter value for ``step``.

        Args:
            step: Current step.

        Returns:
            ``start_value`` if ``step < start``, ``end_value`` if
            ``step > end``, otherwise the linear interpolation between the
            two. For a zero-length ramp (``start == end``) the value at that
            single step is ``end_value``.
        """
        if step < self.start:
            return self.start_value
        if step > self.end:
            return self.end_value
        if self.total_steps == 0:
            # Degenerate ramp: start == end == step. Interpolating would
            # divide by zero, so treat the schedule as already finished.
            return self.end_value
        elapsed = step - self.start
        ratio = elapsed / self.total_steps
        return self.start_value + ratio * (self.end_value - self.start_value)
class StaticParameterTuner:
    """
    Parameter tuner that always yields one fixed value.
    """

    def __init__(self, v):
        # Keep the constant; it is handed back as-is on every query.
        self.v = v

    def get_value(self, step):
        """
        Return the stored value; ``step`` is accepted but not used.
        """
        constant = self.v
        return constant