#!/usr/bin/env python3
"""
ViroDNABERT2 embedding example (suitable for shipping inside a Hugging Face
model repository).

Pipeline: random DNA sequence -> mean pooling -> print the embedding shape and
its first few dimensions.

The model-loading logic follows script/run_all.py (DNABERT2-virobench branch):
    model_type=hyenadna + pytorch_model.bin -> DNABERT-2-117M base model with
    the pretrained weights overlaid on top.

Dependencies:
    pip install torch transformers
"""

import os
import random
import warnings

import torch
from transformers import AutoModel, AutoTokenizer

REPO_DIR = "YDXX/ViroDNABERT2"
# DNABERT-2 base model that the exported training weights are overlaid on.
BASE_MODEL = "zhihan1996/DNABERT-2-117M"
MAX_LENGTH = 512

# Checkpoint key prefixes under which the BERT backbone may be stored.
_BACKBONE_PREFIXES = ("model.model.bert.", "model.bert.", "bert.")


def _load_custom_weights(model, weights_path: str) -> None:
    """Overlay the BERT backbone weights from a local ``pytorch_model.bin``.

    Only checkpoint entries whose key starts with one of
    ``_BACKBONE_PREFIXES`` are used; the prefix is stripped before loading.
    All other entries are ignored.

    Args:
        model: Module that receives the weights (modified in place).
        weights_path: Path to the checkpoint file.

    Raises:
        ValueError: If the checkpoint holds no backbone keys, or if none of
            the backbone keys matches a parameter of ``model``.
    """
    state = torch.load(weights_path, map_location="cpu", weights_only=True)

    remapped = {}
    for key, value in state.items():
        for prefix in _BACKBONE_PREFIXES:
            if key.startswith(prefix):
                remapped[key[len(prefix):]] = value
                break
    if not remapped:
        raise ValueError(f"No DNABERT2 backbone keys found in {weights_path}")

    # strict=False tolerates partial mismatches, but it also hides the case
    # where nothing matched at all, so inspect the reported key sets.
    result = model.load_state_dict(remapped, strict=False)
    unexpected = list(result.unexpected_keys)
    missing = list(result.missing_keys)
    if len(unexpected) == len(remapped):
        raise ValueError(
            f"None of the {len(remapped)} backbone keys in {weights_path} "
            "match the model; no weights were loaded"
        )
    if unexpected or missing:
        warnings.warn(
            f"Partial weight load from {weights_path}: "
            f"{len(unexpected)} checkpoint key(s) not used by the model, "
            f"{len(missing)} model key(s) left at their base values",
            stacklevel=2,
        )


def _get_device() -> torch.device:
    """Return ``cuda:0`` (and make it the current device) if CUDA is
    available, otherwise the CPU device."""
    if torch.cuda.is_available():
        dev = torch.device("cuda:0")
        torch.cuda.set_device(0)
        return dev
    return torch.device("cpu")


def load_model():
    """Load the tokenizer and model and move the model to the chosen device.

    The base model is always loaded from ``BASE_MODEL``. If
    ``REPO_DIR/pytorch_model.bin`` exists locally its backbone weights are
    overlaid; otherwise a warning is emitted and the base weights are kept.

    Returns:
        A ``(tokenizer, model, device)`` tuple; the model is in eval mode.
    """
    device = _get_device()
    weights_path = os.path.join(REPO_DIR, "pytorch_model.bin")

    tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL, trust_remote_code=True)
    model = AutoModel.from_pretrained(BASE_MODEL, trust_remote_code=True)

    if os.path.isfile(weights_path):
        _load_custom_weights(model, weights_path)
    else:
        # Without this notice the caller would silently receive embeddings
        # from the plain base model.
        warnings.warn(
            f"Custom weights not found at {weights_path}; "
            f"using the unmodified base model {BASE_MODEL}",
            stacklevel=2,
        )

    model.to(device).eval()
    return tokenizer, model, device


@torch.no_grad()
def get_embedding(sequence: str, tokenizer, model, device: torch.device) -> torch.Tensor:
    """Return a mean-pooled embedding for a single DNA sequence.

    The sequence is upper-cased, tokenized (truncated to ``MAX_LENGTH``
    tokens) and run through the model. Hidden states are averaged over the
    positions that are attended to and are not special tokens; if no such
    position exists, all positions are averaged instead.

    Args:
        sequence: DNA sequence.
        tokenizer: Tokenizer matching ``model``.
        model: Encoder whose first output is the last hidden state.
        device: Device the model lives on.

    Returns:
        Tensor of shape ``(1, H)``.
    """
    enc = tokenizer(
        sequence.upper(),
        return_tensors="pt",
        truncation=True,
        max_length=MAX_LENGTH,
    )
    input_ids = enc["input_ids"].to(device)
    attn_mask = enc.get("attention_mask")
    if attn_mask is None:
        attn_mask = torch.ones_like(input_ids)
    else:
        attn_mask = attn_mask.to(device)

    out = model(input_ids=input_ids, attention_mask=attn_mask)
    # Remote-code models may return a plain tuple rather than a ModelOutput;
    # in both cases the first element is the last hidden state.
    hidden = out.last_hidden_state if hasattr(out, "last_hidden_state") else out[0]
    # hidden: (1, L, H)

    # Mean pooling, excluding special tokens.
    spec = tokenizer.get_special_tokens_mask(
        input_ids[0].tolist(), already_has_special_tokens=True
    )
    is_special = torch.tensor(spec, device=device, dtype=torch.bool)
    # Use the 1-D mask of the single batch row so it lines up with hidden[0].
    valid = attn_mask[0].bool() & ~is_special
    if valid.any():
        return hidden[0][valid].mean(dim=0, keepdim=True)
    return hidden.mean(dim=1)


if __name__ == "__main__":
    seq = "".join(random.choices("ACGT", k=512))
    print(f"sequence (len={len(seq)}): {seq[:64]}...")

    tokenizer, model, device = load_model()
    emb = get_embedding(seq, tokenizer, model, device)

    print(f"embedding shape: {tuple(emb.shape)}")
    print(f"embedding[:8]: {emb[0, :8].cpu().tolist()}")