import argparse import torch from huggingface_hub import login from e1_fastplms.modeling_e1 import E1ForMaskedLM, E1Config model_dict = { 'Profluent-E1-150M': 'Profluent-Bio/E1-150m', 'Profluent-E1-300M': 'Profluent-Bio/E1-300m', 'Profluent-E1-600M': 'Profluent-Bio/E1-600m', } if __name__ == "__main__": # py -m e1_fastplms.get_e1_weights parser = argparse.ArgumentParser() parser.add_argument('--hf_token', type=str, default=None) parser.add_argument("--skip-weights", action="store_true") args = parser.parse_args() if args.hf_token: login(token=args.hf_token) for model_name in model_dict: repo_id = "Synthyra/" + model_name config = E1Config.from_pretrained(model_dict[model_name]) config.auto_map = { "AutoConfig": "modeling_e1.E1Config", "AutoModel": "modeling_e1.E1Model", "AutoModelForMaskedLM": "modeling_e1.E1ForMaskedLM", "AutoModelForSequenceClassification": "modeling_e1.E1ForSequenceClassification", "AutoModelForTokenClassification": "modeling_e1.E1ForTokenClassification" } if args.skip_weights: config.push_to_hub(repo_id) print(f"[skip-weights] uploaded config for {repo_id}") continue model = E1ForMaskedLM.from_pretrained(model_dict[model_name], config=config, dtype=torch.float32) model.push_to_hub(repo_id)