LHMPP / core /models /utils.py
Lingteng Qiu (邱陵腾)
rm assets & wheels
434b0b0
import torch.nn as nn
def linear(*args, **kwargs):
    """Build and return a ``torch.nn.Linear`` layer.

    Every positional and keyword argument is forwarded unchanged to the
    ``nn.Linear`` constructor; this helper adds no behavior of its own.
    """
    layer = nn.Linear(*args, **kwargs)
    return layer
class LinerParameterTuner:
    """Linearly ramp a scalar parameter between two steps.

    For steps before ``start`` the value is ``start_value``; for steps at or
    after ``end`` it is ``end_value``; in between it is linearly interpolated.

    The class name keeps its original spelling ("Liner") so existing callers
    continue to work.
    """

    def __init__(self, start, start_value, end_value, end):
        """
        Args:
            start: First step of the ramp.
            start_value: Value returned for steps before ``start``.
            end_value: Value returned for steps at or after ``end``.
            end: Step at which the ramp reaches ``end_value``.
        """
        self.start = start
        self.start_value = start_value
        self.end_value = end_value
        self.end = end
        self.total_steps = self.end - self.start

    def get_value(self, step):
        """Return the scheduled value at ``step``.

        Returns ``start_value`` when ``step < start``, ``end_value`` when
        ``step >= end``, and the linear interpolation between them otherwise.
        """
        if step < self.start:
            return self.start_value
        # ``>=`` (rather than ``>``) returns end_value exactly at the endpoint
        # and, when start == end, avoids a 0 / 0 ZeroDivisionError.
        if step >= self.end:
            return self.end_value
        # Reaching this point implies start <= step < end, so total_steps > 0.
        ratio = (step - self.start) / self.total_steps
        return self.start_value + ratio * (self.end_value - self.start_value)
class StaticParameterTuner:
    """Schedule that ignores the step and always yields one fixed value.

    It exposes the same ``get_value(step)`` call shape as the other tuners, so
    the two can be swapped for each other.
    """

    def __init__(self, v):
        # Stored under ``v`` so any caller reading the attribute still works.
        self.v = v

    def get_value(self, step):
        """Return the stored constant; ``step`` is accepted only for API parity."""
        constant = self.v
        return constant