File size: 3,217 Bytes
7597d0f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
#!/usr/bin/env python3
"""
ViroDNABERT2 embedding 示例(适合放在 Hugging Face 模型仓库中)。

流程:随机 DNA 序列 -> mean pooling -> 打印 embedding 形状与前几维。

模型加载逻辑参考 script/run_all.py(DNABERT2-virobench 分支):
  model_type=hyenadna + pytorch_model.bin -> DNABERT-2-117M 基座 + 覆盖预训练权重。

依赖: pip install torch transformers

"""

import os
import random

import torch
from transformers import AutoModel, AutoTokenizer

# Identifier of the ViroDNABERT2 repository; load_model looks for
# "pytorch_model.bin" under it.
REPO_DIR: str = "YDXX/ViroDNABERT2"
# DNABERT-2 base model onto which the exported training weights are overlaid.
BASE_MODEL: str = "zhihan1996/DNABERT-2-117M"
# Maximum number of tokens kept per sequence; longer inputs are truncated.
MAX_LENGTH: int = 512


def _load_custom_weights(model, weights_path: str) -> None:
    """加载本地 pytorch_model.bin(仅 bert backbone)。"""
    state = torch.load(weights_path, map_location="cpu", weights_only=True)
    remapped = {}
    for key, value in state.items():
        if key.startswith("model.model.bert."):
            remapped[key[len("model.model.bert."):]] = value
        elif key.startswith("model.bert."):
            remapped[key[len("model.bert."):]] = value
        elif key.startswith("bert."):
            remapped[key[len("bert."):]] = value
    if not remapped:
        raise ValueError(f"No DNABERT2 backbone keys found in {weights_path}")
    model.load_state_dict(remapped, strict=False)


def _get_device() -> torch.device:
    if torch.cuda.is_available():
        dev = torch.device("cuda:0")
        torch.cuda.set_device(0)
        return dev
    return torch.device("cpu")


def load_model():
    """Load the DNABERT-2 tokenizer/model and overlay the ViroDNABERT2 weights.

    ``BASE_MODEL`` supplies the architecture and tokenizer. ``pytorch_model.bin``
    is then looked up for ``REPO_DIR``: first as a local directory, then as a
    Hugging Face Hub repo id. If found, it is loaded on top via
    ``_load_custom_weights``. If it cannot be found, a ``RuntimeWarning`` is
    emitted and the un-fine-tuned base weights are used.

    Returns:
        ``(tokenizer, model, device)``, with the model moved to ``device`` and
        switched to eval mode.
    """
    import warnings

    device = _get_device()
    tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL, trust_remote_code=True)
    model = AutoModel.from_pretrained(BASE_MODEL, trust_remote_code=True)

    weights_path = _resolve_weights_path(REPO_DIR, "pytorch_model.bin")
    if weights_path is not None:
        _load_custom_weights(model, weights_path)
    else:
        # Keep the best-effort fallback, but make it visible: without the
        # overlay these are plain base-model embeddings.
        warnings.warn(
            f"pytorch_model.bin not found locally or on the Hub for {REPO_DIR!r}; "
            f"using un-fine-tuned {BASE_MODEL} weights",
            RuntimeWarning,
            stacklevel=2,
        )

    model.to(device).eval()
    return tokenizer, model, device


def _resolve_weights_path(repo: str, filename: str) -> "str | None":
    """Return a local path to ``filename`` from ``repo``, or None if unavailable."""
    local = os.path.join(repo, filename)
    if os.path.isfile(local):
        return local
    try:
        # transformers' own hub helper: checks local dirs and the cache, and
        # downloads when needed; returns None for missing files / no network
        # with these flags.
        from transformers.utils import cached_file

        return cached_file(
            repo,
            filename,
            _raise_exceptions_for_missing_entries=False,
            _raise_exceptions_for_connection_errors=False,
        )
    except (ImportError, OSError):
        return None


@torch.no_grad()
def get_embedding(sequence: str, tokenizer, model, device: torch.device) -> torch.Tensor:
    """Return a mean-pooled embedding for a single DNA sequence.

    Processing steps:
    - Upper-case the sequence.
    - Tokenize it, truncating to ``MAX_LENGTH`` tokens.
    - Run it through ``model``.
    - Average the final hidden states over positions that are attended and
      are not special tokens.

    Args:
        sequence: Nucleotide string.
        tokenizer: Tokenizer as returned by ``load_model``.
        model: Model as returned by ``load_model``.
        device: Device the model lives on.

    Returns:
        Tensor of shape ``(1, H)``. If no position survives the filter, the
        mean over all positions is returned instead.
    """
    enc = tokenizer(
        sequence.upper(),
        return_tensors="pt",
        truncation=True,
        max_length=MAX_LENGTH,
    )
    input_ids = enc["input_ids"].to(device)  # (1, L)
    attn_mask = enc.get("attention_mask")
    if attn_mask is None:
        attn_mask = torch.ones_like(input_ids)
    else:
        attn_mask = attn_mask.to(device)

    out = model(input_ids=input_ids, attention_mask=attn_mask)
    # ModelOutput objects expose .last_hidden_state. Some remote-code models
    # return a plain tuple instead; index 0 is assumed to hold the final
    # hidden states there (TODO confirm for the checkpoint in use).
    if hasattr(out, "last_hidden_state"):
        hidden = out.last_hidden_state
    else:
        hidden = out[0]  # (1, L, H)

    # Mean pooling that excludes special tokens.
    spec = tokenizer.get_special_tokens_mask(
        input_ids[0].tolist(), already_has_special_tokens=True
    )
    special = torch.tensor(spec, device=device, dtype=torch.bool)  # (L,)
    # Build the mask for the single row: a (1, L) mask cannot index
    # hidden[0], which is (L, H).
    valid = attn_mask[0].bool() & ~special  # (L,)
    if valid.any():
        return hidden[0][valid].mean(dim=0, keepdim=True)
    return hidden.mean(dim=1)


if __name__ == "__main__":
    # Demo: embed a random 512-nt sequence and show a preview of the result.
    dna = "".join(random.choices("ACGT", k=512))
    print(f"sequence (len={len(dna)}): {dna[:64]}...")

    tok, net, dev = load_model()
    vec = get_embedding(dna, tok, net, dev)

    head = vec[0, :8].cpu().tolist()
    print(f"embedding shape: {tuple(vec.shape)}")
    print(f"embedding[:8]:   {head}")