ViroCaduceus / get_embedding.py
YDXX's picture
Upload ViroCaduceus model
6176175 verified
#!/usr/bin/env python3
"""
ViroCaduceus embedding 示例(适合放在 Hugging Face 模型仓库中)。
流程:随机 DNA 序列 -> tokenizer 编码 -> 模型前向(取最后一层 hidden states) -> mean pooling -> 打印 embedding 形状与前几维。
依赖: pip install torch transformers mamba-ssm causal-conv1d
"""
import contextlib
import os
import random
import warnings

import torch
from transformers import AutoModelForMaskedLM, AutoTokenizer
# ViroCaduceus checkpoint identifier (Hub-style "owner/name");
# ``load_model`` looks for ``pytorch_model.bin`` relative to it.
REPO_DIR: str = "YDXX/ViroCaduceus"
# Caduceus base model that the training-exported weights
# (hyenadna + pytorch_model.bin) are overlaid on.
BASE_MODEL: str = "kuleshov-group/caduceus-ph_seqlen-131k_d_model-256_n_layer-16"
def _load_custom_weights(model, weights_path: str) -> None:
    """Overlay fine-tuned backbone weights from ``weights_path`` onto ``model``.

    Checkpoint keys are remapped onto the model's ``caduceus.*`` parameters:
    the training-export prefix ``model.model.caduceus.`` becomes
    ``caduceus.``, keys already starting with ``caduceus.`` are kept as-is,
    and every other key is dropped. Loading is non-strict, so parameters
    outside ``caduceus.*`` (presumably the masked-LM head) keep the values
    the base model was loaded with.

    Args:
        model: Module whose backbone parameters are named ``caduceus.*``.
        weights_path: Path to a tensor-only state dict (``torch.save``).

    Raises:
        ValueError: If no checkpoint key maps onto ``caduceus.*``. Without
            this check the call would silently load nothing and leave the
            base weights in place.

    Warns:
        UserWarning: If some ``caduceus.*`` parameters are absent from the
            checkpoint, or if remapped keys have no matching parameter.
    """
    export_prefix = "model.model.caduceus."
    state = torch.load(weights_path, map_location="cpu", weights_only=True)

    remapped = {}
    for key, value in state.items():
        if key.startswith(export_prefix):
            remapped["caduceus." + key[len(export_prefix):]] = value
        elif key.startswith("caduceus."):
            remapped[key] = value

    if not remapped:
        raise ValueError(
            f"No '{export_prefix}*' or 'caduceus.*' keys found in {weights_path!r}; "
            "refusing to continue with the un-fine-tuned base weights."
        )

    # strict=False is intentional (non-backbone params stay as loaded), but
    # the result is inspected so partial loads are not silent.
    result = model.load_state_dict(remapped, strict=False)
    missing = [k for k in result.missing_keys if k.startswith("caduceus.")]
    if missing:
        warnings.warn(
            f"{len(missing)} backbone parameter(s) not found in {weights_path!r} "
            f"keep their base-model values: {', '.join(missing[:5])}"
            + (" ..." if len(missing) > 5 else ""),
            stacklevel=2,
        )
    unexpected = list(result.unexpected_keys)
    if unexpected:
        warnings.warn(
            f"{len(unexpected)} checkpoint key(s) have no matching parameter "
            f"in the model and were ignored: {', '.join(unexpected[:5])}"
            + (" ..." if len(unexpected) > 5 else ""),
            stacklevel=2,
        )
def _get_device() -> torch.device:
    """Choose the inference device: GPU 0 when CUDA is usable, otherwise the CPU.

    When the GPU is chosen, it is also made the current CUDA device.
    """
    if not torch.cuda.is_available():
        return torch.device("cpu")
    torch.cuda.set_device(0)
    return torch.device("cuda:0")
def _resolve_weights_path(repo_dir: str, filename: str):
    """Return a local path to ``filename`` from ``repo_dir``, or ``None`` if unavailable.

    ``repo_dir`` is first treated as a local directory. If the file is not
    there, it is treated as a Hugging Face Hub repo id and the file is
    fetched or taken from the cache via ``transformers.utils.cached_file``.
    """
    local_path = os.path.join(repo_dir, filename)
    if os.path.isfile(local_path):
        return local_path
    # Imported lazily: only needed when no local checkout is present.
    from transformers.utils import cached_file

    try:
        # Raises OSError (missing repo/file, offline); invalid repo ids may
        # surface as ValueError subclasses depending on the library version.
        return cached_file(repo_dir, filename)
    except (OSError, ValueError):
        return None


def load_model():
    """Load the tokenizer and the ViroCaduceus model, ready for inference.

    The Caduceus base model ``BASE_MODEL`` is loaded via ``transformers``
    and then overlaid with the fine-tuned ``pytorch_model.bin`` found under
    ``REPO_DIR``. ``REPO_DIR`` may be a local directory or a Hub repo id.
    Previously only the local path was probed, so running outside a local
    ``YDXX/ViroCaduceus/`` folder silently fell back to the base model.
    When the fine-tuned weights cannot be found, a warning is now emitted
    and the base model is returned.

    Returns:
        ``(tokenizer, model, device)``, with the model moved to ``device``
        and switched to eval mode.
    """
    device = _get_device()
    weights_path = _resolve_weights_path(REPO_DIR, "pytorch_model.bin")
    tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL, trust_remote_code=True)
    model = AutoModelForMaskedLM.from_pretrained(BASE_MODEL, trust_remote_code=True)
    if weights_path is None:
        warnings.warn(
            f"Fine-tuned weights 'pytorch_model.bin' not found in {REPO_DIR!r} "
            "(checked local directory and Hugging Face Hub); using the "
            f"un-fine-tuned base model {BASE_MODEL!r}.",
            stacklevel=2,
        )
    else:
        _load_custom_weights(model, weights_path)
    model.to(device).eval()
    return tokenizer, model, device
@torch.no_grad()
def get_embedding(
    sequence: str,
    tokenizer,
    model,
    device: torch.device,
    *,
    max_length: int = 1024,
) -> torch.Tensor:
    """Return the mean-pooled last-layer embedding of one DNA sequence.

    The sequence is upper-cased and tokenized, with truncation to
    ``max_length`` tokens. It is then run through ``model`` with
    ``output_hidden_states=True``. The last hidden state is averaged over
    every position that is neither padding nor a special token.

    Args:
        sequence: Nucleotide string; case-insensitive.
        tokenizer: Callable tokenizer returning a mapping with ``"input_ids"``.
        model: Model returning an object with a ``hidden_states`` sequence.
        device: Device the model lives on; the input ids are moved there.
        max_length: Maximum number of tokens kept after truncation (default
            1024, the value previously hard-coded).

    Returns:
        Tensor of shape ``(1, D)``, where ``D`` is the hidden size. It is all
        zeros if no regular tokens remain.
    """
    enc = tokenizer(
        sequence.upper(),
        return_tensors="pt",
        truncation=True,
        max_length=max_length,
    )
    input_ids = enc["input_ids"].to(device)

    pad_id = tokenizer.pad_token_id
    if pad_id is None:
        # Fall back to the tokenizer's internal vocab lookup, then to id 4.
        pad_id = getattr(tokenizer, "_vocab_str_to_int", {}).get("[PAD]", 4)

    # Keep only positions whose id is neither padding nor a special token.
    excluded_ids = {pad_id, *tokenizer.all_special_ids}
    mask = torch.ones_like(input_ids, dtype=torch.bool)
    for token_id in excluded_ids:
        mask &= input_ids.ne(token_id)

    # Make ``device`` the current CUDA device for the forward pass; no-op on CPU.
    device_ctx = (
        torch.cuda.device(device) if device.type == "cuda" else contextlib.nullcontext()
    )
    with device_ctx:
        out = model(input_ids=input_ids, output_hidden_states=True)

    hidden = out.hidden_states[-1]  # expected (1, L, D): one tokenized sequence
    mask_f = mask.unsqueeze(-1).float()
    # clamp avoids 0/0 when every position is masked out.
    return (hidden * mask_f).sum(dim=1) / mask_f.sum(dim=1).clamp(min=1.0)
if __name__ == "__main__":
    # Demo: embed a random 512-nt sequence and print a short summary.
    demo_seq = "".join(random.choices("ACGT", k=512))
    preview = demo_seq[:64]
    print(f"sequence (len={len(demo_seq)}): {preview}...")

    tok, mdl, dev = load_model()
    vec = get_embedding(demo_seq, tok, mdl, dev)

    head = vec[0, :8].cpu().tolist()
    print(f"embedding shape: {tuple(vec.shape)}")
    print(f"embedding[:8]: {head}")