File size: 1,698 Bytes
436b829
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
from omegaconf import DictConfig
from ppd.utils.logger import Log

class LRTable:
    """Look up a learning rate for a parameter name via pattern tables.

    Up to three optional tables map a pattern string to a learning rate.
    They are consulted in a fixed order and the first hit wins:

    0. ``prefix_table``  -- pattern matches when the name starts with it.
    1. ``postfix_table`` -- pattern matches when the name ends with it.
    2. ``else_table``    -- pattern matches when it occurs anywhere in the
       name; entries whose value is ``None`` are skipped.

    When nothing matches, ``default_lr`` is used.

    NOTE(review): only ``else_table`` skips ``None`` values. A prefix or
    postfix entry whose value is ``None`` still matches and that ``None`` is
    returned as the learning rate -- confirm whether this is intended.
    """

    # Table index reported by ``match_table`` when no table matches.
    _NO_MATCH = 3

    def __init__(self,
                 default_lr: float = 1e-4,
                 prefix_table: "DictConfig | None" = None,
                 postfix_table: "DictConfig | None" = None,
                 else_table: "DictConfig | None" = None):
        """Store the fallback learning rate and the three pattern tables.

        Args:
            default_lr: Learning rate returned when no pattern matches.
            prefix_table: Mapping ``prefix -> lr``, or ``None`` to disable.
            postfix_table: Mapping ``postfix -> lr``, or ``None`` to disable.
            else_table: Mapping ``substring -> lr``, or ``None`` to disable.
        """
        self.default_lr = default_lr
        self.prefix_table = prefix_table
        self.postfix_table = postfix_table
        self.else_table = else_table
        # ``tables`` and ``tags`` are parallel lists indexed by the table
        # index that ``match_table`` returns (0, 1 or 2).
        self.tables = [self.prefix_table, self.postfix_table, self.else_table]
        self.tags = ["prefix_", "postfix_", "else_"]

    def match_table(self, key: str) -> "tuple[int, str | None]":
        """Find the first table pattern that matches ``key``.

        Args:
            key: Name to match against the tables.

        Returns:
            ``(table_idx, pattern)`` where ``table_idx`` is 0 (prefix),
            1 (postfix) or 2 (else) and ``pattern`` is the matching table
            key; ``(3, None)`` when no table matches.
        """
        if self.prefix_table is not None:
            for prefix in self.prefix_table:
                if key.startswith(prefix):
                    return 0, prefix

        if self.postfix_table is not None:
            for postfix in self.postfix_table:
                if key.endswith(postfix):
                    return 1, postfix

        if self.else_table is not None:
            for else_key in self.else_table:
                # A ``None`` value disables the entry, so keep searching.
                if else_key in key and self.else_table[else_key] is not None:
                    return 2, else_key

        return self._NO_MATCH, None

    def get_lr(self, key: str) -> "tuple[str, float]":
        """Return the learning-rate group tag and value for ``key``.

        Args:
            key: Name to look up.

        Returns:
            ``(tag, lr)``. ``tag`` is ``'default'`` when nothing matches,
            otherwise the table tag (``'prefix_'``, ``'postfix_'`` or
            ``'else_'``) concatenated with the matching pattern. ``lr`` is
            ``default_lr`` or the value stored in the matching table.
        """
        table_idx, match_key = self.match_table(key)
        if table_idx == self._NO_MATCH:
            Log.debug(f'{key} is not matched to any table, use default lr: {self.default_lr}')
            return 'default', self.default_lr

        tag = self.tags[table_idx]
        Log.debug(f'{key} is matched to table {tag}: {match_key}')
        return tag + match_key, self.tables[table_idx][match_key]