ViroDNABERT2 / get_embedding.py
YDXX's picture
Upload ViroDNABERT2 model
7597d0f verified
#!/usr/bin/env python3
"""
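ViroDNABERT2 embedding example (suitable for placing in the Hugging Face model repository).
Pipeline: random DNA sequence -> mean pooling -> print the embedding shape and its first few dimensions.
The model-loading logic follows script/run_all.py (DNABERT2-virobench branch):
model_type=hyenadna + pytorch_model.bin -> DNABERT-2-117M base + overriding pretrained weights.
Dependencies: pip install torch transformers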
"""
import os
import random
import torch
from transformers import AutoModel, AutoTokenizer
# Location of the ViroDNABERT2 checkpoint. load_model() joins this with
# "pytorch_model.bin" and checks for that file on the local filesystem.
REPO_DIR = "YDXX/ViroDNABERT2"
# DNABERT-2 base model that the exported training weights are loaded on top of.
BASE_MODEL = "zhihan1996/DNABERT-2-117M"
# Token limit handed to the tokenizer as max_length (with truncation=True)
# in get_embedding().
MAX_LENGTH = 512
def _load_custom_weights(model, weights_path: str) -> None:
    """Load BERT-backbone weights from a local ``pytorch_model.bin`` into ``model``.

    Only checkpoint entries whose key starts with ``model.model.bert.``,
    ``model.bert.`` or ``bert.`` are used. That prefix is stripped so the
    remaining key can line up with the parameter names of ``model``; every
    other checkpoint entry is ignored.

    Args:
        model: Module that receives the weights through ``load_state_dict``.
        weights_path: Path of the checkpoint file read with ``torch.load``.

    Raises:
        ValueError: If the checkpoint has no key with a recognised prefix, or
            if none of the stripped keys exists in ``model`` (nothing would
            have been loaded).
    """
    # weights_only=True restricts unpickling to tensors and plain containers.
    state = torch.load(weights_path, map_location="cpu", weights_only=True)

    # The three prefixes are mutually exclusive, so the order is irrelevant.
    prefixes = ("model.model.bert.", "model.bert.", "bert.")
    remapped = {}
    for key, value in state.items():
        for prefix in prefixes:
            if key.startswith(prefix):
                remapped[key[len(prefix):]] = value
                break
    if not remapped:
        raise ValueError(f"No DNABERT2 backbone keys found in {weights_path}")

    # strict=False tolerates a partial overlap (e.g. a missing pooler), but it
    # would also hide a total mismatch and leave the model untouched. Detect
    # that case instead of silently returning unchanged weights.
    result = model.load_state_dict(remapped, strict=False)
    if len(result.unexpected_keys) >= len(remapped):
        raise ValueError(
            f"None of the {len(remapped)} backbone keys in {weights_path} "
            "match the model's parameters; no weights were loaded"
        )
def _get_device() -> torch.device:
    """Pick the compute device: the first CUDA GPU if available, else the CPU.

    When CUDA is available, GPU 0 is also made the current CUDA device.
    """
    if not torch.cuda.is_available():
        return torch.device("cpu")
    gpu = torch.device("cuda:0")
    torch.cuda.set_device(0)
    return gpu
def load_model():
    """Build the tokenizer and model used for embedding extraction.

    The tokenizer and model are created from ``BASE_MODEL``. If
    ``REPO_DIR/pytorch_model.bin`` exists as a local file, its backbone
    weights are loaded on top of the base model; otherwise the base model is
    returned unchanged and a ``RuntimeWarning`` is emitted so the fallback is
    not silent.

    Returns:
        A ``(tokenizer, model, device)`` tuple; the model has been moved to
        ``device`` and switched to eval mode.
    """
    import warnings

    device = _get_device()
    weights_path = os.path.join(REPO_DIR, "pytorch_model.bin")

    # NOTE(security): trust_remote_code=True executes Python code shipped with
    # the BASE_MODEL repository; only use it with a source you trust.
    tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL, trust_remote_code=True)
    model = AutoModel.from_pretrained(BASE_MODEL, trust_remote_code=True)

    if os.path.isfile(weights_path):
        _load_custom_weights(model, weights_path)
    else:
        # Keep the best-effort fallback, but make it visible: embeddings from
        # the plain base model differ from those of the fine-tuned checkpoint.
        warnings.warn(
            f"Custom weights not found at {weights_path!r}; "
            f"using the unmodified {BASE_MODEL} base model.",
            RuntimeWarning,
            stacklevel=2,
        )

    model.to(device).eval()
    return tokenizer, model, device
@torch.no_grad()
def get_embedding(
    sequence: str,
    tokenizer,
    model,
    device: torch.device,
    max_length: "int | None" = None,
) -> torch.Tensor:
    """Return a mean-pooled embedding for a single DNA sequence.

    The sequence is upper-cased, tokenized with truncation, and passed
    through ``model``. The hidden states of all attended, non-special tokens
    are averaged; if no such token exists, every position is averaged.

    Args:
        sequence: DNA sequence text.
        tokenizer: Callable tokenizer that also provides
            ``get_special_tokens_mask``.
        model: Model called with ``input_ids`` and ``attention_mask``. Its
            output may expose ``last_hidden_state`` or be a tuple whose first
            element is the hidden-state tensor.
        device: Device the input tensors are moved to.
        max_length: Optional truncation limit; defaults to ``MAX_LENGTH``.

    Returns:
        Tensor of shape ``(1, H)``, where ``H`` is the hidden size.
    """
    limit = MAX_LENGTH if max_length is None else max_length
    enc = tokenizer(
        sequence.upper(),
        return_tensors="pt",
        truncation=True,
        max_length=limit,
    )
    input_ids = enc["input_ids"].to(device)
    attn_mask = enc.get("attention_mask")
    if attn_mask is None:
        # No mask supplied by the tokenizer: treat every position as attended.
        attn_mask = torch.ones_like(input_ids)
    else:
        attn_mask = attn_mask.to(device)

    out = model(input_ids=input_ids, attention_mask=attn_mask)
    # Accept both a ModelOutput-style object and a plain tuple whose first
    # element holds the hidden states.
    hidden = out.last_hidden_state if hasattr(out, "last_hidden_state") else out[0]
    # hidden: (1, L, H) for the single input sequence.

    # Mean pooling, excluding special tokens.
    spec = tokenizer.get_special_tokens_mask(
        input_ids[0].tolist(), already_has_special_tokens=True
    )
    is_special = torch.tensor(spec, device=device, dtype=torch.bool)
    # Build the mask per token, shape (L,), so it can index hidden[0] of
    # shape (L, H). A (1, L) mask would not match the leading dimension.
    valid = attn_mask[0].bool() & ~is_special
    if valid.any():
        return hidden[0][valid].mean(dim=0, keepdim=True)
    # Only special or padded tokens present: average over every position.
    return hidden.mean(dim=1)
if __name__ == "__main__":
    # Demo run: embed one random 512-character ACGT sequence and print a
    # short summary of the result.
    dna = "".join(random.choices("ACGT", k=512))
    head = dna[:64]
    print(f"sequence (len={len(dna)}): {head}...")

    tokenizer, model, device = load_model()
    embedding = get_embedding(dna, tokenizer, model, device)

    shape = tuple(embedding.shape)
    print(f"embedding shape: {shape}")
    first_dims = embedding[0, :8].cpu().tolist()
    print(f"embedding[:8]: {first_dims}")