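"""
ViroDNABERT2 embedding example (suitable for shipping inside a Hugging Face model repository).

Pipeline: random DNA sequence -> encoder forward pass -> mean pooling -> print the
embedding shape and its first few dimensions.

Model loading follows script/run_all.py (DNABERT2-virobench branch):
model_type=hyenadna + pytorch_model.bin -> DNABERT-2-117M base model, with the
pretrained pytorch_model.bin weights overlaid on top of it.

Dependencies: pip install torch transformers
"""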
|
|
import os
import random
import warnings

import torch
from transformers import AutoModel, AutoTokenizer
|
|
# Model identifier in Hugging Face "owner/name" form; load_model() looks for
# pytorch_model.bin under a directory of this name.
REPO_DIR: str = "YDXX/ViroDNABERT2"

# Source of the tokenizer and model architecture (loaded with trust_remote_code).
BASE_MODEL: str = "zhihan1996/DNABERT-2-117M"
# Truncation limit (in tokens) handed to the tokenizer.
MAX_LENGTH: int = 512
|
|
|
|
| def _load_custom_weights(model, weights_path: str) -> None: |
| """加载本地 pytorch_model.bin(仅 bert backbone)。""" |
| state = torch.load(weights_path, map_location="cpu", weights_only=True) |
| remapped = {} |
| for key, value in state.items(): |
| if key.startswith("model.model.bert."): |
| remapped[key[len("model.model.bert."):]] = value |
| elif key.startswith("model.bert."): |
| remapped[key[len("model.bert."):]] = value |
| elif key.startswith("bert."): |
| remapped[key[len("bert."):]] = value |
| if not remapped: |
| raise ValueError(f"No DNABERT2 backbone keys found in {weights_path}") |
| model.load_state_dict(remapped, strict=False) |
|
|
|
|
| def _get_device() -> torch.device: |
| if torch.cuda.is_available(): |
| dev = torch.device("cuda:0") |
| torch.cuda.set_device(0) |
| return dev |
| return torch.device("cpu") |
|
|
|
|
def load_model():
    """Build the ViroDNABERT2 encoder together with its tokenizer.

    The tokenizer and architecture come from ``BASE_MODEL``. The ViroDNABERT2
    ``pytorch_model.bin`` is then overlaid on the backbone.

    ``REPO_DIR`` is a Hugging Face repo id, so the checkpoint is looked up in
    this order:

    1. A local directory with that name.
    2. The directory containing this script (for when the script ships inside
       the model repository).
    3. The Hugging Face Hub.

    If the checkpoint cannot be found anywhere, a warning is emitted and the
    unmodified base weights are used.

    Returns:
        ``(tokenizer, model, device)``; the model is moved to ``device`` and
        switched to eval mode.
    """
    device = _get_device()
    weights_path = _resolve_weights_path()

    tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL, trust_remote_code=True)
    model = AutoModel.from_pretrained(BASE_MODEL, trust_remote_code=True)
    if weights_path is not None:
        _load_custom_weights(model, weights_path)
    else:
        # Without the checkpoint the embeddings come from base DNABERT-2, not
        # ViroDNABERT2; keep running, but make that visible.
        warnings.warn(
            f"pytorch_model.bin for {REPO_DIR} not found; "
            f"using the unmodified {BASE_MODEL} weights.",
            stacklevel=2,
        )

    model.to(device).eval()
    return tokenizer, model, device


def _resolve_weights_path(filename: str = "pytorch_model.bin"):
    """Return a local path to ``filename`` for ``REPO_DIR``, or None if unavailable."""
    candidates = [os.path.join(REPO_DIR, filename)]
    # When this script lives inside the model repository, the checkpoint sits beside it.
    script_file = globals().get("__file__")
    if script_file:
        script_dir = os.path.dirname(os.path.abspath(script_file))
        candidates.append(os.path.join(script_dir, filename))
    for path in candidates:
        if os.path.isfile(path):
            return path

    try:
        from transformers.utils import cached_file
    except ImportError:  # transformers too old to expose the hub helper
        return None
    try:
        # Downloads the file (or reuses the local HF cache) for a Hub repo id.
        return cached_file(REPO_DIR, filename)
    except (OSError, ValueError):
        # Unknown repo/file, invalid id, or offline with nothing cached.
        return None
|
|
|
|
@torch.no_grad()
def get_embedding(sequence: str, tokenizer, model, device: torch.device) -> torch.Tensor:
    """Return a mean-pooled embedding of ``sequence`` with shape ``(1, hidden)``.

    The sequence is upper-cased, tokenized (truncated to ``MAX_LENGTH``
    tokens) and run through ``model``. Token states of the single example are
    averaged over positions that are attended and not special tokens (as
    reported by ``tokenizer.get_special_tokens_mask``). If no such position
    exists, all positions are averaged instead.

    Args:
        sequence: DNA sequence string.
        tokenizer: Tokenizer providing ``__call__`` and ``get_special_tokens_mask``.
        model: Encoder returning either an object with ``last_hidden_state``
            or a tuple whose first element is the hidden states.
        device: Device the model lives on; inputs are moved there.

    Returns:
        Tensor of shape ``(1, hidden_size)`` on ``device``.
    """
    enc = tokenizer(
        sequence.upper(),
        return_tensors="pt",
        truncation=True,
        max_length=MAX_LENGTH,
    )
    input_ids = enc["input_ids"].to(device)
    attn_mask = enc.get("attention_mask")
    if attn_mask is None:
        attn_mask = torch.ones_like(input_ids)
    else:
        attn_mask = attn_mask.to(device)

    out = model(input_ids=input_ids, attention_mask=attn_mask)
    # The DNABERT-2 remote-code model is used as `model(inputs)[0]` in its
    # model card (it returns a plain tuple), so fall back to positional access
    # when there is no `last_hidden_state` attribute.
    hidden = getattr(out, "last_hidden_state", None)
    if hidden is None:
        hidden = out[0]

    special = tokenizer.get_special_tokens_mask(
        input_ids[0].tolist(), already_has_special_tokens=True
    )
    # hidden[0] is indexed along its first (token) axis, so the boolean mask
    # must be 1-D over tokens: use attn_mask[0], not the 2-D batch mask.
    valid = attn_mask[0].bool() & ~torch.tensor(special, device=device, dtype=torch.bool)
    if valid.any():
        return hidden[0][valid].mean(dim=0, keepdim=True)
    return hidden.mean(dim=1)
|
|
|
|
if __name__ == "__main__":
    # Demo input: 512 nucleotides drawn uniformly at random from A/C/G/T.
    demo_seq = "".join(random.choices("ACGT", k=512))
    print(f"sequence (len={len(demo_seq)}): {demo_seq[:64]}...")

    tok, encoder, dev = load_model()
    vector = get_embedding(demo_seq, tok, encoder, dev)

    print(f"embedding shape: {tuple(vector.shape)}")
    first_dims = vector[0, :8].cpu().tolist()
    print(f"embedding[:8]: {first_dims}")
|
|