| from omegaconf import DictConfig |
| from ppd.utils.logger import Log |
|
|
| class LRTable: |
| def __init__(self, |
| default_lr: float = 1e-4, |
| prefix_table: DictConfig = None, |
| postfix_table: DictConfig = None, |
| else_table: DictConfig = None): |
| |
| self.default_lr = default_lr |
| self.prefix_table = prefix_table |
| self.postfix_table = postfix_table |
| self.else_table = else_table |
| self.tables = [self.prefix_table, self.postfix_table, self.else_table] |
| self.tags = ["prefix_", "postfix_", "else_"] |
| |
| def match_table(self, key) -> bool: |
| if self.prefix_table is not None: |
| for prefix in self.prefix_table: |
| if key.startswith(prefix): |
| return 0, prefix |
| |
| if self.postfix_table is not None: |
| for postfix in self.postfix_table: |
| if key.endswith(postfix): |
| return 1, postfix |
| |
| if self.else_table is not None: |
| for else_key in self.else_table: |
| if else_key in key and self.tables[2][else_key] is not None: |
| return 2, else_key |
| |
| return 3, None |
| |
| def get_lr(self, key) -> float: |
| table_idx, match_key = self.match_table(key) |
| if table_idx == 3: |
| Log.debug(f'{key} is not matched to any table, use default lr: {self.default_lr}') |
| return 'default', self.default_lr |
| else: |
| Log.debug(f'{key} is matched to table {self.tags[table_idx]}: {match_key}') |
| return self.tags[table_idx] + match_key, self.tables[table_idx][match_key] |