File size: 1,441 Bytes
714cf46
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
import argparse
import torch    
from huggingface_hub import login

from e1_fastplms.modeling_e1 import E1ForMaskedLM, E1Config


# Mapping from the Synthyra display name to the upstream Profluent HF repo id.
# All three checkpoints follow the same naming scheme, differing only in size.
model_dict = {
    f'Profluent-E1-{size}': f'Profluent-Bio/E1-{size.lower()}'
    for size in ('150M', '300M', '600M')
}


if __name__ == "__main__":
    # py -m e1_fastplms.get_e1_weights

    parser = argparse.ArgumentParser()
    parser.add_argument('--hf_token', type=str, default=None)
    parser.add_argument("--skip-weights", action="store_true")
    args = parser.parse_args()

    if args.hf_token:
        login(token=args.hf_token)

    for model_name in model_dict:
        repo_id = "Synthyra/" + model_name
        config = E1Config.from_pretrained(model_dict[model_name])
        config.auto_map = {
            "AutoConfig": "modeling_e1.E1Config",
            "AutoModel": "modeling_e1.E1Model",
            "AutoModelForMaskedLM": "modeling_e1.E1ForMaskedLM",
            "AutoModelForSequenceClassification": "modeling_e1.E1ForSequenceClassification",
            "AutoModelForTokenClassification": "modeling_e1.E1ForTokenClassification"
        }
        if args.skip_weights:
            config.push_to_hub(repo_id)
            print(f"[skip-weights] uploaded config for {repo_id}")
            continue
        model = E1ForMaskedLM.from_pretrained(model_dict[model_name], config=config, dtype=torch.float32)
        model.push_to_hub(repo_id)