| import copy |
| import os |
| import torch |
| from huggingface_hub import HfApi, login |
| from transformers import AutoModelForMaskedLM, AutoTokenizer, AutoConfig |
|
|
| from esm_plusplus.load_official import load_official_model |
| from esm_plusplus.modeling_esm_plusplus import ESMplusplusForMaskedLM |
| from weight_parity_utils import assert_state_dict_equal, assert_model_parameters_fp32 |
|
|
|
|
| MODEL_DICT = { |
| "Synthyra/ESMplusplus_small": "esmc-300", |
| "Synthyra/ESMplusplus_large": "esmc-600", |
| } |
|
|
|
|
| def _resolve_repo_items(repo_ids: list[str] | None) -> list[tuple[str, str]]: |
| if repo_ids is None or len(repo_ids) == 0: |
| return list(MODEL_DICT.items()) |
|
|
| selected_items: list[tuple[str, str]] = [] |
| for repo_id in repo_ids: |
| assert repo_id in MODEL_DICT, ( |
| f"Unknown repo_id {repo_id}. " |
| f"Valid options: {sorted(MODEL_DICT.keys())}" |
| ) |
| selected_items.append((repo_id, MODEL_DICT[repo_id])) |
| return selected_items |
|
|
|
|
if __name__ == "__main__":
    # Script entry point: for each selected repo, load the official ESMC checkpoint,
    # map it into ESMplusplusForMaskedLM, verify fp32 weight parity, and (unless
    # --dry_run) push tokenizer, model and modeling code to the Hub, then re-download
    # and re-verify parity.
    import argparse
    import inspect

    parser = argparse.ArgumentParser()
    parser.add_argument("--hf_token", type=str, default=None, help="Token passed to huggingface_hub.login.")
    parser.add_argument(
        "--repo_ids",
        nargs="*",
        type=str,
        default=None,
        help="Subset of repo ids to process; defaults to every entry of MODEL_DICT.",
    )
    parser.add_argument("--dry_run", action="store_true", help="Validate without pushing anything to the Hub.")
    # "--skip_weights" is an alias matching the underscore style of the other flags;
    # dest keeps the attribute name args.skip_weights unchanged.
    parser.add_argument(
        "--skip-weights",
        "--skip_weights",
        dest="skip_weights",
        action="store_true",
        help="Only push config and tokenizer; do not build or push model weights.",
    )
    args = parser.parse_args()
    api = HfApi()

    if args.hf_token is not None:
        # parser.error exits with a usage message and, unlike assert, survives `python -O`.
        if not args.hf_token:
            parser.error("--hf_token cannot be empty.")
        login(token=args.hf_token)

    # Locate the modeling source through the imported class instead of a CWD-relative
    # path, so the upload works no matter which directory the script is launched from.
    modeling_source_path = inspect.getfile(ESMplusplusForMaskedLM)

    for repo_id, esmc_model_key in _resolve_repo_items(args.repo_ids):
        official_model, tokenizer = load_official_model(
            esmc_model_key, device=torch.device("cpu"), dtype=torch.float32
        )

        config = AutoConfig.from_pretrained(repo_id, trust_remote_code=True)
        # Route every Auto* class to the modeling file uploaded next to the weights.
        config.auto_map = {
            "AutoConfig": "modeling_esm_plusplus.ESMplusplusConfig",
            "AutoModel": "modeling_esm_plusplus.ESMplusplusModel",
            "AutoModelForMaskedLM": "modeling_esm_plusplus.ESMplusplusForMaskedLM",
            "AutoModelForSequenceClassification": "modeling_esm_plusplus.ESMplusplusForSequenceClassification",
            "AutoModelForTokenClassification": "modeling_esm_plusplus.ESMplusplusForTokenClassification",
        }
        config.tie_word_embeddings = False

        if args.skip_weights:
            if args.dry_run:
                print(f"[skip-weights][dry-run] validated config+tokenizer parity for {repo_id}")
                continue
            config.push_to_hub(repo_id)
            tokenizer.push_to_hub(repo_id)
            print(f"[skip-weights] uploaded config+tokenizer for {repo_id}")
            continue

        model = ESMplusplusForMaskedLM(config=config).eval().cpu().to(torch.float32)
        # strict=True raises on any missing/unexpected key, so the returned key report is not needed.
        model.load_state_dict(official_model.model.state_dict(), strict=True)

        # Replace the sequence-head parameters with independent deep copies of the official ones
        # (same layers and order as before: 0, 2, 3; weight then bias).
        # NOTE(review): load_state_dict already copied these values; this presumably guards
        # against parameter sharing/tying set up by the model constructor — confirm it is still needed.
        for layer_idx in (0, 2, 3):
            official_layer = official_model.model.sequence_head[layer_idx]
            mapped_layer = model.sequence_head[layer_idx]
            for param_name in ("weight", "bias"):
                setattr(mapped_layer, param_name, copy.deepcopy(getattr(official_layer, param_name)))

        assert_model_parameters_fp32(
            model=model,
            model_name=f"mapped ESM++ model ({esmc_model_key})",
        )
        assert_model_parameters_fp32(
            model=official_model.model,
            model_name=f"official ESM++ model ({esmc_model_key})",
        )
        assert_state_dict_equal(
            reference_state_dict=official_model.model.state_dict(),
            candidate_state_dict=model.state_dict(),
            context=f"ESMC/ESM++ weight parity ({esmc_model_key})",
        )

        if args.dry_run:
            print(f"[dry_run] validated ESM++ parity for {repo_id} <- {esmc_model_key}")
            continue

        tokenizer.push_to_hub(repo_id)
        model.push_to_hub(repo_id)
        # The modeling file must be on the Hub before the trust_remote_code reload below.
        api.upload_file(
            path_or_fileobj=modeling_source_path,
            path_in_repo="modeling_esm_plusplus.py",
            repo_id=repo_id,
            repo_type="model",
        )
        # Round-trip check: force a fresh download and compare against the official weights.
        downloaded_model = AutoModelForMaskedLM.from_pretrained(
            repo_id,
            dtype=torch.float32,
            device_map="cpu",
            force_download=True,
            trust_remote_code=True,
        )
        assert_state_dict_equal(
            reference_state_dict=official_model.model.state_dict(),
            candidate_state_dict=downloaded_model.state_dict(),
            context=f"ESMC/ESM++ weight parity post-download ({repo_id})",
        )
        print(f"uploaded and verified ESM++ parity for {repo_id} <- {esmc_model_key}")

        # Drop references before the next iteration loads another checkpoint, so peak
        # memory does not hold two iterations' models at once.
        del downloaded_model, model, official_model
|
|