File size: 1,698 Bytes
436b829 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 | from omegaconf import DictConfig
from ppd.utils.logger import Log
class LRTable:
    """Look up a learning rate for a key (parameter name) via pattern tables.

    Three optional tables map a pattern to a learning rate. They are consulted
    in this fixed order and the first match wins:

    * ``prefix_table``  -- ``key.startswith(pattern)``;
    * ``postfix_table`` -- ``key.endswith(pattern)``;
    * ``else_table``    -- ``pattern in key`` and the entry's value is not
      ``None`` (``None``-valued entries are skipped).

    Keys that match nothing fall back to ``default_lr``. Any mapping that
    supports iteration over its keys and ``table[pattern]`` lookup can be used
    as a table (e.g. an OmegaConf ``DictConfig`` or a plain ``dict``).
    """

    # Table index returned by ``match_table`` when no table matched.
    NO_MATCH = 3

    def __init__(self,
                 default_lr: float = 1e-4,
                 prefix_table: "DictConfig | None" = None,
                 postfix_table: "DictConfig | None" = None,
                 else_table: "DictConfig | None" = None):
        """Store the fallback learning rate and the three lookup tables.

        Args:
            default_lr: Learning rate returned when no table matches.
            prefix_table: Pattern -> lr, matched with ``str.startswith``.
            postfix_table: Pattern -> lr, matched with ``str.endswith``.
            else_table: Pattern -> lr, matched as a substring; entries whose
                value is ``None`` are ignored.
        """
        self.default_lr = default_lr
        self.prefix_table = prefix_table
        self.postfix_table = postfix_table
        self.else_table = else_table
        # Parallel lists indexed by the table index that match_table returns.
        self.tables = [self.prefix_table, self.postfix_table, self.else_table]
        self.tags = ["prefix_", "postfix_", "else_"]

    def match_table(self, key: str) -> "tuple[int, str | None]":
        """Find the first table entry whose pattern matches ``key``.

        Args:
            key: The name to classify.

        Returns:
            ``(table_idx, pattern)`` where ``table_idx`` is 0 (prefix),
            1 (postfix) or 2 (else) and ``pattern`` is the matching key of
            that table; ``(LRTable.NO_MATCH, None)`` when nothing matches.
        """
        if self.prefix_table is not None:
            for prefix in self.prefix_table:
                if key.startswith(prefix):
                    return 0, prefix
        if self.postfix_table is not None:
            for postfix in self.postfix_table:
                if key.endswith(postfix):
                    return 1, postfix
        if self.else_table is not None:
            for else_key in self.else_table:
                # Unlike prefix/postfix entries, a None value disables an
                # else entry. Look it up in the same table being iterated.
                if else_key in key and self.else_table[else_key] is not None:
                    return 2, else_key
        return self.NO_MATCH, None

    def get_lr(self, key: str) -> "tuple[str, float]":
        """Resolve the learning rate for ``key``.

        Args:
            key: The name to look up.

        Returns:
            ``(label, lr)``. ``label`` is ``'default'`` when no table matched;
            otherwise it is the table tag concatenated with the matching
            pattern (e.g. ``'prefix_' + pattern``). ``lr`` is the stored table
            value as-is, so a prefix/postfix entry whose value is ``None``
            yields ``None`` here.
        """
        table_idx, match_key = self.match_table(key)
        if table_idx == self.NO_MATCH:
            Log.debug(f'{key} is not matched to any table, use default lr: {self.default_lr}')
            return 'default', self.default_lr
        Log.debug(f'{key} is matched to table {self.tags[table_idx]}: {match_key}')
        return self.tags[table_idx] + match_key, self.tables[table_idx][match_key]