File size: 2,593 Bytes
0c4803b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
from constants import *

def print_list(list):
    """Print each element of a one-dimensional sequence on its own line.

    :param list: List[int] -- any iterable works; each element is passed to ``print``.
        (The name shadows the builtin ``list`` but is kept so keyword callers still work.)
    :return: None
    """
    for item in list:
        print(item)
        
def get_dict_depth(d, depth=0):
    """Return how deeply dicts are nested inside ``d``.

    Each dict level adds one to ``depth``. A non-dict value, or an empty dict,
    contributes no further levels.

    :param d: any value; only ``dict`` instances (including subclasses) are descended into.
    :param depth: depth already accumulated by the caller.
    :return: the largest depth reached along any path through the nested values.
    """
    if isinstance(d, dict):
        child_depths = [get_dict_depth(child, depth + 1) for child in d.values()]
        # An empty dict has no children, so it reports the depth it was called with.
        return max(child_depths, default=depth)
    return depth

def latency_to_string(latency_in_s, precision=2):
    """Format a latency given in seconds using the largest unit it reaches.

    Units, largest first: days, hours, minutes, s, ms. Anything below 1 ms
    (including 0 and negative values) is shown in microseconds.

    :param latency_in_s: latency in seconds, or None.
    :param precision: number of decimal places passed to ``round``.
    :return: e.g. "1.5 minutes" or "250.0 ms"; "None" when ``latency_in_s`` is None.
    """
    if latency_in_s is None:
        return "None"
    day = 24 * 60 * 60
    hour = 60 * 60
    minute = 60
    ms = 1 / 1000
    us = 1 / 1000000
    # Every threshold uses ">=" so a value exactly on a unit boundary is shown in
    # that unit. The former strict ">" for seconds/ms rendered 1 s as "1000.0 ms".
    if latency_in_s >= day:
        return str(round(latency_in_s / day, precision)) + " days"
    elif latency_in_s >= hour:
        return str(round(latency_in_s / hour, precision)) + " hours"
    elif latency_in_s >= minute:
        return str(round(latency_in_s / minute, precision)) + " minutes"
    elif latency_in_s >= 1:
        # Not divided, so int inputs stay ints (e.g. "5 s").
        return str(round(latency_in_s, precision)) + " s"
    elif latency_in_s >= ms:
        return str(round(latency_in_s / ms, precision)) + " ms"
    else:
        return str(round(latency_in_s / us, precision)) + " us"
    
def num_to_string(num, precision=2):
    """Format a number with a decimal metric suffix (K, M, G, T).

    :param num: numeric value, or None.
    :param precision: number of decimal places passed to ``round``.
    :return: e.g. "1.5 K" or "2.0 G". Values whose magnitude is below 1000 get
        no suffix, and None yields "None".
    """
    if num is None:
        return "None"
    # Pick the unit by magnitude so negative values are scaled too.
    magnitude = abs(num)
    for scale, suffix in (
        (10.0**12, " T"),
        (10.0**9, " G"),
        (10.0**6, " M"),
        (10.0**3, " K"),
    ):
        if magnitude >= scale:
            return str(round(num / scale, precision)) + suffix
    # Honor `precision` here as well; sub-thousand floats were previously unrounded.
    # round() on an int returns the int unchanged.
    return str(round(num, precision))

def get_readable_summary_dict(summary_dict: dict, title="Summary") -> str:
    """Render a summary dict as a banner-framed, human-readable block of text.

    Values are formatted according to their key, checked in this order:
    - key contains "num_tokens", "num_params" or "flops": ``num_to_string``.
    - key is exactly "gpu_hours": truncated with ``int``.
    - key contains "memory" but not "efficiency": ``num_to_string`` plus "B".
    - key contains "latency": ``latency_to_string``.
    - otherwise: the value's default string form.
    A value of None is rendered as "None" regardless of its key.

    ``PRINT_LINE_WIDTH`` is presumably provided by the ``constants`` star import.

    :param summary_dict: mapping of str keys to values.
    :param title: text centered in the dashed header line.
    :return: multi-line string that starts and ends with a newline.
    """
    lines = [f"\n{title.center(PRINT_LINE_WIDTH, '-')}"]
    for key, value in summary_dict.items():
        if value is None:
            # int(None) used to raise for "gpu_hours", and memory keys rendered "NoneB".
            formatted = "None"
        elif "num_tokens" in key or "num_params" in key or "flops" in key:
            formatted = num_to_string(value)
        elif key == "gpu_hours":
            formatted = str(int(value))
        elif "memory" in key and "efficiency" not in key:
            formatted = f"{num_to_string(value)}B"
        elif "latency" in key:
            formatted = latency_to_string(value)
        else:
            formatted = f"{value}"
        lines.append(f"{key}: {formatted}")
    lines.append("-" * PRINT_LINE_WIDTH)
    return "\n".join(lines) + "\n"

def within_range(val, target, tolerance):
    """Return True if ``val`` lies within a relative ``tolerance`` of ``target``.

    The relative deviation is ``|val - target| / |target|``, and the comparison
    with ``tolerance`` is strict.

    :raises ZeroDivisionError: if ``target`` is 0 (for Python int/float), since
        the relative deviation is undefined there.
    """
    # abs(target): with a signed denominator, any negative target produced a
    # negative deviation, so every value was reported as within range.
    return abs(val - target) / abs(target) < tolerance

def average(lst):
    """Return the arithmetic mean of ``lst``, or None if it is None or empty.

    Any finite iterable is accepted. It is materialized into a list first, so
    generators work: the original consumed them in ``sum`` and then failed on ``len``.
    """
    if lst is None:
        return None
    values = list(lst)
    if not values:
        return None
    return sum(values) / len(values)

def max_value(lst):
    """Return the largest element of ``lst``, or None if it is None or empty.

    Any iterable is accepted. ``max(..., default=None)`` also covers empty
    iterators, which pass a truthiness check but previously made ``max`` raise
    ValueError.
    """
    if lst is None:
        return None
    return max(lst, default=None)