File size: 2,905 Bytes
6176175
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
#!/usr/bin/env python3
"""
ViroCaduceus embedding 示例(适合放在 Hugging Face 模型仓库中)。

流程:随机 DNA 序列 -> mean pooling -> 打印 embedding 形状与前几维。

依赖: pip install torch transformers mamba-ssm causal-conv1d
"""

import os
import random
import warnings

import torch
from transformers import AutoModelForMaskedLM, AutoTokenizer

# Name of the ViroCaduceus repository (Hugging Face Hub "owner/name" format).
# It is joined with "pytorch_model.bin" via os.path.join when looking for the
# fine-tuned checkpoint.
REPO_DIR = "YDXX/ViroCaduceus"
# Caduceus base model onto which the training-exported weights
# (hyenadna + pytorch_model.bin) are overlaid.
BASE_MODEL = "kuleshov-group/caduceus-ph_seqlen-131k_d_model-256_n_layer-16"


def _load_custom_weights(model, weights_path: str) -> None:
    state = torch.load(weights_path, map_location="cpu", weights_only=True)
    remapped = {}
    for key, value in state.items():
        if key.startswith("model.model.caduceus."):
            remapped["caduceus." + key[len("model.model.caduceus."):]] = value
        elif key.startswith("caduceus."):
            remapped[key] = value
    model.load_state_dict(remapped, strict=False)


def _get_device() -> torch.device:
    if torch.cuda.is_available():
        dev = torch.device("cuda:0")
        torch.cuda.set_device(0)
        return dev
    return torch.device("cpu")


def _resolve_weights_path(filename: str = "pytorch_model.bin"):
    """Return a local filesystem path to ``filename`` from ``REPO_DIR``, or ``None``.

    ``REPO_DIR`` is written as a Hub repo id ("owner/name"). As a directory, it
    only resolves if a copy happens to sit at ``./<REPO_DIR>`` relative to the
    working directory. That local path is still tried first. Otherwise the file is
    resolved through the Hugging Face Hub cache, downloading it if needed.

    Returns:
        Path to the checkpoint, or ``None`` if it is neither present locally nor
        obtainable from the Hub (for example offline, or absent from the repo).
    """
    local_path = os.path.join(REPO_DIR, filename)
    if os.path.isfile(local_path):
        return local_path

    from transformers.utils import cached_file

    try:
        return cached_file(REPO_DIR, filename)
    except OSError:
        # transformers reports missing repos/files and connection failures as
        # OSError. Treat that as "no fine-tuned weights" and let the caller warn.
        return None


def load_model():
    """Build the tokenizer and the ViroCaduceus model, ready for inference.

    The Caduceus base model (``BASE_MODEL``) is loaded first. The fine-tuned
    checkpoint from ``REPO_DIR`` is then overlaid onto it when it can be located.

    Returns:
        ``(tokenizer, model, device)``, with ``model`` moved to ``device`` and
        switched to eval mode.
    """
    device = _get_device()

    tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL, trust_remote_code=True)
    model = AutoModelForMaskedLM.from_pretrained(BASE_MODEL, trust_remote_code=True)

    weights_path = _resolve_weights_path()
    if weights_path is None:
        # Warn rather than fall through silently: otherwise the script would print
        # plain base-model embeddings under the ViroCaduceus name.
        warnings.warn(
            f"Could not find pytorch_model.bin for {REPO_DIR!r}; embeddings will "
            f"come from the base model {BASE_MODEL!r} only."
        )
    else:
        _load_custom_weights(model, weights_path)

    model.to(device).eval()
    return tokenizer, model, device


@torch.no_grad()
def get_embedding(
    sequence: str,
    tokenizer,
    model,
    device: torch.device,
    max_length: int = 1024,
) -> torch.Tensor:
    """Mean-pool the last hidden layer of ``model`` over the real tokens of ``sequence``.

    Args:
        sequence: DNA sequence; it is upper-cased before tokenization.
        tokenizer: called with ``return_tensors="pt"``; must return ``input_ids``
            and expose ``pad_token_id`` and ``all_special_ids``.
        model: called with ``output_hidden_states=True``; its output must provide
            ``hidden_states``, and the last entry is pooled.
        device: device the inputs are moved to (the model is expected to be there).
        max_length: truncation length passed to the tokenizer. Defaults to 1024,
            the value that was previously hard-coded.

    Returns:
        Pooled embedding of shape ``(1, D)`` for a single string, on ``device``.
        Padding and special tokens are excluded from the average. If no token
        remains, the result is all zeros (the divisor is clamped to at least 1).
    """
    enc = tokenizer(
        sequence.upper(),
        return_tensors="pt",
        truncation=True,
        max_length=max_length,
    )
    input_ids = enc["input_ids"].to(device, non_blocking=False)

    # Some tokenizers leave pad_token_id unset. Fall back to their "[PAD]" vocab
    # entry, or to id 4 when that lookup table is absent too.
    pad_id = tokenizer.pad_token_id
    if pad_id is None:
        pad_id = getattr(tokenizer, "_vocab_str_to_int", {}).get("[PAD]", 4)
    # Keep only sequence tokens: drop padding and every special token.
    mask = input_ids.ne(pad_id)
    for sid in tokenizer.all_special_ids:
        mask = mask & input_ids.ne(sid)

    if device.type == "cuda":
        # Make ``device`` the current CUDA device for the forward pass.
        with torch.cuda.device(device):
            out = model(input_ids=input_ids, output_hidden_states=True)
    else:
        out = model(input_ids=input_ids, output_hidden_states=True)
    hidden = out.hidden_states[-1]  # (1, L, D)
    mask_f = mask.unsqueeze(-1).float()
    return (hidden * mask_f).sum(dim=1) / mask_f.sum(dim=1).clamp(min=1.0)


if __name__ == "__main__":
    # Demo input: 512 random nucleotides drawn uniformly from A/C/G/T.
    demo_sequence = "".join(random.choices("ACGT", k=512))
    print(f"sequence (len={len(demo_sequence)}): {demo_sequence[:64]}...")

    tok, mdl, dev = load_model()
    vector = get_embedding(demo_sequence, tok, mdl, dev)

    # Show the pooled embedding's shape and its first eight components.
    leading_dims = vector[0, :8].cpu().tolist()
    print(f"embedding shape: {tuple(vector.shape)}")
    print(f"embedding[:8]:   {leading_dims}")