File size: 864 Bytes
9627ce0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
# Model names
# Canonical name of the ESM3 open small model; every alias below is mapped
# to this value by normalize_model_name().
ESM3_OPEN_SMALL = "esm3_sm_open_v1"
# Alternate spellings of ESM3_OPEN_SMALL. normalize_model_name() rewrites each
# of these to ESM3_OPEN_SMALL, and model_is_locally_supported() accepts them
# alongside the canonical name.
ESM3_OPEN_SMALL_ALIAS_1 = "esm3-open-2024-03"
ESM3_OPEN_SMALL_ALIAS_2 = "esm3-sm-open-v1"
ESM3_OPEN_SMALL_ALIAS_3 = "esm3-open"
# Names of the ESM3 structure encoder/decoder and function decoder
# (presumably identifiers used when loading those sub-models — confirm
# against the loader that consumes them).
ESM3_STRUCTURE_ENCODER_V0 = "esm3_structure_encoder_v0"
ESM3_STRUCTURE_DECODER_V0 = "esm3_structure_decoder_v0"
ESM3_FUNCTION_DECODER_V0 = "esm3_function_decoder_v0"
# ESM C model names. Note: model_is_locally_supported() returns False for
# these; only ESM3_OPEN_SMALL and its aliases are accepted there.
ESMC_600M = "esmc_600m"
ESMC_300M = "esmc_300m"


def forge_only_return_single_layer_hidden_states(model_name: str):
    """Return whether *model_name* begins with the prefix ``"esmc-6b"``.

    Judging by this function's name, models matching the prefix are those for
    which Forge returns hidden states for only a single layer (assumption
    inferred from the name — confirm against the Forge client).

    Args:
        model_name: Model name to test. The match is exact and case-sensitive.

    Returns:
        True if *model_name* starts with ``"esmc-6b"``, otherwise False.
    """
    # NOTE(review): this prefix uses a hyphen, while the local ESMC_* name
    # constants use underscores ("esmc_600m"); presumably it targets
    # Forge-side model names — confirm.
    single_layer_prefix = "esmc-6b"
    return model_name.startswith(single_layer_prefix)


def model_is_locally_supported(x: str) -> bool:
    """Return whether *x* is a model name this module treats as locally supported.

    Only the ESM3 open small model qualifies. It is accepted under its canonical
    name (``ESM3_OPEN_SMALL``) or under any alias that ``normalize_model_name``
    maps to it.

    Args:
        x: Model name to check.

    Returns:
        True if *x* is ``ESM3_OPEN_SMALL`` or one of its aliases, else False.
    """
    # Resolve aliases through normalize_model_name so the alias list is
    # maintained in exactly one place instead of being repeated here.
    return normalize_model_name(x) == ESM3_OPEN_SMALL


def normalize_model_name(x: str):
    """Map an ESM3 open small alias to its canonical model name.

    Args:
        x: Model name, possibly one of the ``ESM3_OPEN_SMALL_ALIAS_*`` spellings.

    Returns:
        ``ESM3_OPEN_SMALL`` when *x* is one of its aliases; otherwise *x*
        unchanged (including when *x* is already the canonical name).
    """
    is_open_small_alias = x in {
        ESM3_OPEN_SMALL_ALIAS_1,
        ESM3_OPEN_SMALL_ALIAS_2,
        ESM3_OPEN_SMALL_ALIAS_3,
    }
    return ESM3_OPEN_SMALL if is_open_small_alias else x