File size: 2,378 Bytes
dc41c89
 
79e7993
dc41c89
 
 
79e7993
165f130
 
 
79e7993
 
 
 
 
 
165f130
dc41c89
79e7993
dc41c89
 
 
79e7993
dc41c89
 
 
79e7993
dc41c89
 
 
 
79e7993
dc41c89
 
 
 
 
 
 
79e7993
 
 
 
 
 
165f130
79e7993
dc41c89
165f130
 
 
 
 
79e7993
165f130
79e7993
 
 
165f130
dc41c89
165f130
79e7993
 
 
dc41c89
 
165f130
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
# Color manipulation functions
def hex_to_rgb(hex_color):
    """Convert a hex color string to an ``(r, g, b)`` tuple of ints in 0-255.

    Leading ``#`` characters are stripped. Accepted forms are ``RGB``,
    ``RGBA``, ``RRGGBB`` and ``RRGGBBAA``; any alpha component is ignored
    (the palettes in this module are written as ``#RRGGBBAA``).

    Raises:
        ValueError: if the string has an unsupported length or its RGB
            digits are not valid hexadecimal.
    """
    digits = hex_color.lstrip("#")
    if len(digits) in (3, 4):
        # CSS shorthand: every nibble is doubled, e.g. "f80" -> "ff8800".
        digits = "".join(ch * 2 for ch in digits)
    if len(digits) not in (6, 8):
        raise ValueError(f"Invalid hex color: {hex_color!r}")
    return tuple(int(digits[i : i + 2], 16) for i in (0, 2, 4))


def blend_colors(color1, color2, blend_strength):
    """Linearly interpolate between two hex colors.

    Args:
        color1: hex color returned when ``blend_strength`` is 1.
        color2: hex color returned when ``blend_strength`` is 0.
        blend_strength: weight of ``color1``; ``color2`` gets
            ``1 - blend_strength``. Values outside [0, 1] extrapolate, and
            the resulting channels are clamped by ``rgb_to_hex``.

    Returns:
        The blended color as a ``#rrggbb`` string (only the RGB part of the
        inputs is used, so any alpha digits are dropped).
    """
    rgb1 = hex_to_rgb(color1)
    rgb2 = hex_to_rgb(color2)
    # Interpolate as c2 + (c1 - c2) * t and round. The previous form,
    # int(c1 * t + c2 * (1 - t)), truncates: floating-point error can leave
    # the sum a hair below an integer and lose a whole channel step. This
    # form also returns the channel exactly when c1 == c2.
    blended = (
        round(c2 + (c1 - c2) * blend_strength) for c1, c2 in zip(rgb1, rgb2)
    )
    return rgb_to_hex(*blended)


def increase_brightness(r, g, b, factor):
    """Move each channel toward 255 by ``factor`` of its remaining headroom.

    Each channel becomes ``x + (255 - x) * factor``, truncated with ``int``.
    A factor of 0 leaves the color unchanged. The result is not clamped, so
    factors outside [0, 1] can produce channels outside 0-255.

    Returns:
        A tuple ``(r, g, b)`` of ints.
    """
    brightened = []
    for channel in (r, g, b):
        headroom = 255 - channel
        brightened.append(int(channel + headroom * factor))
    return tuple(brightened)


def decrease_brightness(r, g, b, factor):
    """Scale every channel by ``factor``, truncating each result with ``int``.

    A factor of 1 keeps the color and 0 yields black. The result is not
    clamped.

    Returns:
        A tuple ``(r, g, b)`` of ints.
    """
    scaled = [int(value * factor) for value in (r, g, b)]
    return tuple(scaled)


def increase_saturation(r, g, b, factor) -> tuple[int, int, int]:
    """Scale each channel's distance from the color's weighted gray level.

    The gray level uses the 0.299 / 0.587 / 0.114 channel weights. A factor
    above 1 pushes channels away from it (more saturated), a factor below 1
    pulls them toward it, and 0 yields that gray. Results are truncated with
    ``int`` but not clamped, so they may fall outside 0-255.
    """
    luma = 0.299 * r + 0.587 * g + 0.114 * b

    def push(channel):
        return int(luma + (channel - luma) * factor)

    return push(r), push(g), push(b)


def rgb_to_hex(r, g, b):
    """Format an RGB triple as a lowercase ``#rrggbb`` string.

    Each channel is rounded to the nearest integer and clamped to 0-255.
    Float channels and slightly out-of-range values are therefore accepted;
    without the rounding, the ``x`` format code rejects floats.
    """
    channels = (min(max(round(c), 0), 255) for c in (r, g, b))
    return "#" + "".join(f"{c:02x}" for c in channels)


# Color assignment function
def get_color_for_config(config: dict):
    """Return the plot color for a benchmark configuration.

    The attention implementation (and, for SDPA, the backend) selects a pair
    of hex colors. The config's compile/kernelize settings then choose a
    point between the two via ``blend_colors``.

    Args:
        config: mapping that must contain ``"attn_implementation"``,
            ``"sdpa_backend"``, ``"compile_mode"`` and ``"kernelize"``.
            ``kernelize`` is used arithmetically below; presumably a bool
            (TODO confirm).

    Returns:
        The color string produced by ``blend_colors``.

    Raises:
        KeyError: if one of the required keys is missing.
        ValueError: for an unknown attention implementation, or for an
            unknown backend when the implementation is SDPA.
    """
    attn_implementation = config["attn_implementation"]
    sdpa_backend = config["sdpa_backend"]
    compile_mode = config["compile_mode"] is not None
    # Weight of the first color: 1 when neither compile nor kernelize is on,
    # shrinking toward 0 as they are enabled (kernelize counts double).
    barycenter = 1 - (compile_mode + 2 * config["kernelize"]) / 3

    if attn_implementation == "eager":
        first, second = "#FA7F7FFF", "#FF2D2DFF"
    elif attn_implementation == "sdpa":
        # An unset backend (None) shares the flash-attention colors.
        if sdpa_backend == "math":
            first, second = "#7AB8FFFF", "#277CD0FF"
        elif sdpa_backend in (None, "flash_attention"):
            first, second = "#81FF9CFF", "#219F3CFF"
        elif sdpa_backend == "efficient_attention":
            first, second = "#DB81FFFF", "#9C33B1FF"
        else:
            # Name the backend: the generic message below would wrongly
            # report "sdpa" itself as the unknown implementation.
            raise ValueError(f"Unknown SDPA backend: {sdpa_backend}")
    elif attn_implementation == "flash_attention_2":
        first, second = "#FFDB70FF", "#DFD002FF"
    elif attn_implementation == "flex_attention":
        # NOTE(review): identical pair to SDPA efficient attention, so the two
        # cannot be told apart in a plot; consider a distinct palette.
        first, second = "#DB81FFFF", "#9C33B1FF"
    else:
        raise ValueError(f"Unknown attention implementation: {attn_implementation}")
    return blend_colors(first, second, barycenter)