LiDAR-Perfect-Depth / code /ppd /utils /lr_table.py
chenming-wu's picture
code
436b829 verified
from omegaconf import DictConfig
from ppd.utils.logger import Log
class LRTable:
    """Look up a learning rate for a key by matching it against three tables.

    Each table maps a pattern string to a value (presumably a learning rate --
    confirm against the config that builds this object). Tables are consulted
    in a fixed priority order and the first hit wins:

    0. ``prefix_table``  -- the key starts with the pattern,
    1. ``postfix_table`` -- the key ends with the pattern,
    2. ``else_table``    -- the pattern occurs anywhere in the key *and* the
       table value for that pattern is not ``None``.

    Keys that match nothing fall back to ``default_lr``.

    Annotations are written as strings so they are never evaluated at runtime
    and therefore need no extra imports or a specific Python version.
    """

    # Table index reported by ``match_table`` when nothing matched; it is one
    # past the last valid index into ``self.tables`` / ``self.tags``.
    _NO_MATCH = 3

    def __init__(self,
                 default_lr: float = 1e-4,
                 prefix_table: "DictConfig | None" = None,
                 postfix_table: "DictConfig | None" = None,
                 else_table: "DictConfig | None" = None):
        """Store the fallback value and the three (optional) lookup tables.

        Args:
            default_lr: Value returned by ``get_lr`` when no table matches.
            prefix_table: Mapping ``pattern -> value`` matched with
                ``str.startswith``; ``None`` disables this table.
            postfix_table: Mapping ``pattern -> value`` matched with
                ``str.endswith``; ``None`` disables this table.
            else_table: Mapping ``pattern -> value`` matched by substring;
                entries whose value is ``None`` are ignored. ``None`` disables
                this table.
        """
        self.default_lr = default_lr
        self.prefix_table = prefix_table
        self.postfix_table = postfix_table
        self.else_table = else_table
        # ``tables`` and ``tags`` are parallel lists indexed by the table
        # index that ``match_table`` returns.
        self.tables = [self.prefix_table, self.postfix_table, self.else_table]
        self.tags = ["prefix_", "postfix_", "else_"]

    def match_table(self, key) -> "tuple[int, str | None]":
        """Find the first table entry whose pattern matches ``key``.

        Args:
            key: String to match (must support ``startswith``/``endswith``
                and substring tests).

        Returns:
            ``(table_index, matched_pattern)`` where ``table_index`` is 0 for
            the prefix table, 1 for the postfix table and 2 for the else
            table. When nothing matches, returns ``(3, None)``.
        """
        if self.prefix_table is not None:
            for prefix in self.prefix_table:
                if key.startswith(prefix):
                    return 0, prefix

        if self.postfix_table is not None:
            for postfix in self.postfix_table:
                if key.endswith(postfix):
                    return 1, postfix

        if self.else_table is not None:
            for else_key in self.else_table:
                # Unlike the other two tables, an entry here only counts when
                # it carries a non-None value.
                if else_key in key and self.tables[2][else_key] is not None:
                    return 2, else_key

        return self._NO_MATCH, None

    def get_lr(self, key) -> "tuple[str, float]":
        """Return a group name and the value configured for ``key``.

        Args:
            key: String to look up; see ``match_table`` for matching rules.

        Returns:
            ``(group_name, value)``. ``group_name`` is ``'default'`` when no
            table matched, otherwise the table tag (``"prefix_"``,
            ``"postfix_"`` or ``"else_"``) concatenated with the matched
            pattern. ``value`` is ``default_lr`` or the matched table entry.
        """
        table_idx, match_key = self.match_table(key)
        if table_idx == self._NO_MATCH:
            Log.debug(f'{key} is not matched to any table, use default lr: {self.default_lr}')
            return 'default', self.default_lr

        Log.debug(f'{key} is matched to table {self.tags[table_idx]}: {match_key}')
        return self.tags[table_idx] + match_key, self.tables[table_idx][match_key]