#!/usr/bin/env python3
"""
ViroCaduceus embedding example (suitable for shipping inside a Hugging Face model repo).

Flow: random DNA sequence -> mean pooling -> print the embedding shape and its first dims.

Dependencies:
  pip install torch transformers mamba-ssm causal-conv1d
"""

import contextlib
import os
import random
import warnings

import torch
from transformers import AutoModelForMaskedLM, AutoTokenizer

REPO_DIR = "YDXX/ViroCaduceus"
# Caduceus base model onto which the training-exported weights
# (hyenadna + pytorch_model.bin) are overlaid.
BASE_MODEL = "kuleshov-group/caduceus-ph_seqlen-131k_d_model-256_n_layer-16"

_WEIGHTS_FILE = "pytorch_model.bin"
# Key prefix used by the training export, and the prefix the loaded model expects.
_TRAIN_PREFIX = "model.model.caduceus."
_MODEL_PREFIX = "caduceus."


def _load_custom_weights(model, weights_path: str) -> None:
    """Overlay the backbone weights stored at *weights_path* onto *model*.

    Checkpoint keys starting with ``model.model.caduceus.`` are renamed to start
    with ``caduceus.``; keys that already start with ``caduceus.`` are kept as-is;
    every other key is dropped. Loading is non-strict, so parameters that are
    absent from the checkpoint keep the values *model* already has.

    A ``UserWarning`` is emitted when nothing in the checkpoint matches, or when
    some remapped keys do not exist in *model*, because in both cases part of
    the checkpoint is not applied.
    """
    state = torch.load(weights_path, map_location="cpu", weights_only=True)

    remapped = {}
    for key, value in state.items():
        if key.startswith(_TRAIN_PREFIX):
            remapped[_MODEL_PREFIX + key[len(_TRAIN_PREFIX):]] = value
        elif key.startswith(_MODEL_PREFIX):
            remapped[key] = value

    if not remapped:
        warnings.warn(
            f"no '{_MODEL_PREFIX}*' weights found in {weights_path}; "
            "the base model weights are left unchanged",
            stacklevel=2,
        )
        return

    result = model.load_state_dict(remapped, strict=False)
    unexpected = list(getattr(result, "unexpected_keys", ()))
    if unexpected:
        warnings.warn(
            f"{len(unexpected)} checkpoint key(s) from {weights_path} do not exist "
            f"in the model and were ignored (e.g. {unexpected[:3]})",
            stacklevel=2,
        )


def _get_device() -> torch.device:
    """Return ``cuda:0`` (also making it the current CUDA device) if available, else CPU."""
    if torch.cuda.is_available():
        dev = torch.device("cuda:0")
        torch.cuda.set_device(0)
        return dev
    return torch.device("cpu")


def _resolve_weights_path():
    """Return the first existing ``pytorch_model.bin`` candidate, or ``None``.

    ``REPO_DIR`` (relative to the current working directory) is tried first, so
    any setup that worked before keeps resolving to the same file. The directory
    containing this script is tried second, which covers running the script from
    inside the model repository itself.
    """
    candidates = [os.path.join(REPO_DIR, _WEIGHTS_FILE)]
    script_path = globals().get("__file__")
    if script_path:
        script_dir = os.path.dirname(os.path.abspath(script_path))
        candidates.append(os.path.join(script_dir, _WEIGHTS_FILE))
    for candidate in candidates:
        if os.path.isfile(candidate):
            return candidate
    return None


def load_model():
    """Load the tokenizer and model, overlay custom weights, and move the model to the device.

    Returns:
        A ``(tokenizer, model, device)`` tuple; the model is in eval mode.

    Warns:
        UserWarning: if no ``pytorch_model.bin`` can be located, in which case the
        returned model carries only the ``BASE_MODEL`` weights.
    """
    device = _get_device()
    tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL, trust_remote_code=True)
    model = AutoModelForMaskedLM.from_pretrained(BASE_MODEL, trust_remote_code=True)

    weights_path = _resolve_weights_path()
    if weights_path is None:
        warnings.warn(
            f"{_WEIGHTS_FILE} not found under '{REPO_DIR}' or next to this script; "
            f"embeddings will come from the base model '{BASE_MODEL}' only",
            stacklevel=2,
        )
    else:
        _load_custom_weights(model, weights_path)

    model.to(device).eval()
    return tokenizer, model, device


@torch.no_grad()
def get_embedding(
    sequence: str,
    tokenizer,
    model,
    device: torch.device,
    max_length: int = 1024,
) -> torch.Tensor:
    """Return the mean-pooled last-layer hidden state for one DNA sequence.

    Args:
        sequence: Nucleotide string; it is upper-cased before tokenization.
        tokenizer: Tokenizer returned by :func:`load_model`.
        model: Model returned by :func:`load_model`.
        device: Device the model lives on.
        max_length: Token limit passed to the tokenizer; longer inputs are truncated.

    Returns:
        The last hidden state averaged over positions that are neither padding
        nor any other special token. If no such position exists the divisor is
        clamped to 1, so the result is all zeros rather than NaN.
    """
    enc = tokenizer(
        sequence.upper(),
        return_tensors="pt",
        truncation=True,
        max_length=max_length,
    )
    input_ids = enc["input_ids"].to(device, non_blocking=False)

    pad_id = tokenizer.pad_token_id
    if pad_id is None:
        # Fall back to the tokenizer's private vocab table, then to id 4.
        pad_id = getattr(tokenizer, "_vocab_str_to_int", {}).get("[PAD]", 4)

    # Keep only positions that hold real sequence tokens.
    mask = input_ids.ne(pad_id)
    for sid in tokenizer.all_special_ids:
        mask = mask & input_ids.ne(sid)

    # On CUDA, run the forward pass with `device` selected as the current device.
    if device.type == "cuda":
        device_ctx = torch.cuda.device(device)
    else:
        device_ctx = contextlib.nullcontext()
    with device_ctx:
        out = model(input_ids=input_ids, output_hidden_states=True)

    hidden = out.hidden_states[-1]  # (1, L, D)
    mask_f = mask.unsqueeze(-1).float()
    return (hidden * mask_f).sum(dim=1) / mask_f.sum(dim=1).clamp(min=1.0)


def main() -> None:
    """Embed one random 512-nt sequence and print the embedding shape and first 8 values."""
    seq = "".join(random.choices("ACGT", k=512))
    print(f"sequence (len={len(seq)}): {seq[:64]}...")

    tokenizer, model, device = load_model()
    emb = get_embedding(seq, tokenizer, model, device)

    print(f"embedding shape: {tuple(emb.shape)}")
    print(f"embedding[:8]: {emb[0, :8].cpu().tolist()}")


if __name__ == "__main__":
    main()