File size: 1,935 Bytes
dc41c89
 
 
 
 
 
165f130
 
 
 
 
dc41c89
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
165f130
 
 
dc41c89
165f130
 
 
 
 
 
 
 
 
 
 
dc41c89
165f130
dc41c89
 
165f130
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
# Color manipulation functions
def hex_to_rgb(hex_color: str) -> tuple[int, int, int]:
    """Convert a hex color string to an ``(r, g, b)`` tuple of ints in 0-255.

    Accepted forms (leading ``#`` characters are optional and stripped):
    ``RRGGBB``, ``RRGGBBAA``, and the shorthand ``RGB`` / ``RGBA`` where each
    digit is doubled (``"abc"`` == ``"aabbcc"``). Any alpha channel is ignored.

    Raises:
        ValueError: if the string has an unsupported length or contains
            characters that are not hexadecimal digits.
    """
    digits = hex_color.lstrip("#")
    if len(digits) in (3, 4):
        # Expand CSS-style shorthand before parsing.
        digits = "".join(ch * 2 for ch in digits)
    # Validate explicitly: int(s, 16) would accept slices such as " f" or "+f",
    # and other lengths would otherwise be split into meaningless channels.
    if len(digits) not in (6, 8) or any(ch not in "0123456789abcdefABCDEF" for ch in digits):
        raise ValueError(f"Invalid hex color: {hex_color!r}")
    # Only the first three byte pairs are used; an alpha pair, if present, is dropped.
    return tuple(int(digits[i:i + 2], 16) for i in (0, 2, 4))

def blend_colors(color1, color2, blend_strength):
    """Linearly interpolate between two hex colors.

    Args:
        color1: hex color string weighted by ``blend_strength``.
        color2: hex color string weighted by ``1 - blend_strength``.
        blend_strength: mixing weight; 1 yields ``color1`` and 0 yields
            ``color2``. Values outside [0, 1] extrapolate, and each channel
            is then clamped to 0-255 by ``rgb_to_hex``.

    Returns:
        The blended color as a ``#rrggbb`` string (any input alpha is dropped).
    """
    rgb1 = hex_to_rgb(color1)
    rgb2 = hex_to_rgb(color2)
    # Round to nearest instead of truncating: int() biases every channel
    # downward and can drop a unit when floating-point error lands just
    # below an integer (e.g. weights like 1/3 and 2/3).
    new_color = tuple(
        round(c1 * blend_strength + c2 * (1 - blend_strength))
        for c1, c2 in zip(rgb1, rgb2)
    )
    return rgb_to_hex(*new_color)

def increase_brightness(r, g, b, factor):
    """Move each RGB channel toward white (255) by a fraction of its headroom.

    ``factor`` is the share of the remaining distance to 255 that is added:
    0 leaves a channel unchanged and 1 makes it 255. Results are truncated
    to ints and are not clamped.

    Returns:
        A ``(r, g, b)`` tuple of ints.
    """
    return tuple(int(channel + (255 - channel) * factor) for channel in (r, g, b))

def decrease_brightness(r, g, b, factor):
    """Scale each RGB channel toward black by multiplying it by ``factor``.

    A factor of 1 leaves the color unchanged and 0 yields black. Results are
    truncated to ints and are not clamped.

    Returns:
        A ``(r, g, b)`` tuple of ints.
    """
    channels = (r, g, b)
    return tuple(int(channel * factor) for channel in channels)

def increase_saturation(r, g, b, factor) -> tuple[int, int, int]:
    """Scale each channel's distance from the color's luma by ``factor``.

    The gray reference uses the BT.601 luma weights (0.299, 0.587, 0.114).
    ``factor`` > 1 increases saturation, values between 0 and 1 reduce it,
    and 0 yields the gray value for every channel. Results are truncated to
    ints and are not clamped, so they may fall outside 0-255.
    """
    luma = 0.299 * r + 0.587 * g + 0.114 * b

    def _scale(channel):
        # Same arithmetic order as a single expression per channel.
        return int(luma + (channel - luma) * factor)

    return (_scale(r), _scale(g), _scale(b))

def rgb_to_hex(r, g, b):
    """Format RGB channel values as a lowercase ``#rrggbb`` string.

    Each channel is rounded to the nearest integer and then clamped to
    0-255, so out-of-range and float channel values are accepted. The
    original version raised ValueError on floats because the ``x`` format
    code requires an int.
    """
    r, g, b = (min(max(round(x), 0), 255) for x in (r, g, b))
    return f"#{r:02x}{g:02x}{b:02x}"


# Color assignment function
def get_color_for_config(config: dict) -> str:
    """Pick a plot color for a benchmark configuration.

    The attention setup selects a pair of (light, dark) endpoint colors:
    eager (reds), SDPA with the math backend (blues), SDPA with the flash
    attention backend (greens), and flash_attention_2 (yellows). The shade
    within that pair is interpolated from ``compilation`` and ``kernelize``.
    With neither enabled the light endpoint is used, and with both enabled
    the dark endpoint. ``kernelize`` is weighted twice as heavily as
    ``compilation``.

    Args:
        config: mapping with keys ``"attn_implementation"``,
            ``"compilation"`` and ``"kernelize"``; ``"sdpa_backend"`` is read
            only when ``attn_implementation == "sdpa"``. ``compilation`` and
            ``kernelize`` are presumably booleans (0/1); confirm with callers.

    Returns:
        A ``#rrggbb`` color string.

    Raises:
        ValueError: for an unknown attention implementation, or an SDPA
            backend other than ``"math"`` / ``"flash_attention"``.
    """
    attn_implementation = config["attn_implementation"]
    # 1 when neither option is on, 0 when both are; used as the weight of
    # the light endpoint in blend_colors.
    barycenter = 1 - (config["compilation"] + 2 * config["kernelize"]) / 3

    if attn_implementation == "eager":
        light, dark = "#FA7F7FFF", "#FF2D2DFF"
    elif attn_implementation == "sdpa":
        sdpa_backend = config["sdpa_backend"]
        if sdpa_backend == "math":
            light, dark = "#7AB8FFFF", "#277CD0FF"
        elif sdpa_backend == "flash_attention":
            light, dark = "#81FF9CFF", "#219F3CFF"
        else:
            # Report the backend: "sdpa" itself is a known implementation.
            raise ValueError(f"Unknown sdpa_backend for attn_implementation='sdpa': {sdpa_backend}")
    elif attn_implementation == "flash_attention_2":
        light, dark = "#FFDB70FF", "#DFD002FF"
    else:
        raise ValueError(f"Unknown attention implementation: {attn_implementation}")

    return blend_colors(light, dark, barycenter)