| |
| """ |
| ViroCaduceus embedding 示例(适合放在 Hugging Face 模型仓库中)。 |
| |
| 流程:随机 DNA 序列 -> mean pooling -> 打印 embedding 形状与前几维。 |
| |
| 依赖: pip install torch transformers mamba-ssm causal-conv1d |
| """ |
|
|
| import os |
| import random |
|
|
| import torch |
| from transformers import AutoModelForMaskedLM, AutoTokenizer |
|
|
# Identifier of the fine-tuned ViroCaduceus model ("owner/name" Hub style).
# load_model() combines it with "pytorch_model.bin" to locate the fine-tuned weights.
REPO_DIR: str = "YDXX/ViroCaduceus"

# Base Caduceus checkpoint. It is passed to AutoTokenizer / AutoModelForMaskedLM
# `from_pretrained(..., trust_remote_code=True)`, so it supplies the tokenizer,
# the architecture code and the weights that the fine-tuned checkpoint overwrites.
BASE_MODEL: str = (
    "kuleshov-group/caduceus-ph_seqlen-131k_d_model-256_n_layer-16"
)
|
|
|
|
| def _load_custom_weights(model, weights_path: str) -> None: |
| state = torch.load(weights_path, map_location="cpu", weights_only=True) |
| remapped = {} |
| for key, value in state.items(): |
| if key.startswith("model.model.caduceus."): |
| remapped["caduceus." + key[len("model.model.caduceus."):]] = value |
| elif key.startswith("caduceus."): |
| remapped[key] = value |
| model.load_state_dict(remapped, strict=False) |
|
|
|
|
| def _get_device() -> torch.device: |
| if torch.cuda.is_available(): |
| dev = torch.device("cuda:0") |
| torch.cuda.set_device(0) |
| return dev |
| return torch.device("cpu") |
|
|
|
|
| def load_model(): |
| device = _get_device() |
| weights_path = os.path.join(REPO_DIR, "pytorch_model.bin") |
|
|
| tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL, trust_remote_code=True) |
| model = AutoModelForMaskedLM.from_pretrained(BASE_MODEL, trust_remote_code=True) |
| if os.path.isfile(weights_path): |
| _load_custom_weights(model, weights_path) |
|
|
| model.to(device).eval() |
| return tokenizer, model, device |
|
|
|
|
@torch.no_grad()
def get_embedding(
    sequence: str,
    tokenizer,
    model,
    device: torch.device,
    *,
    max_length: int = 1024,
) -> torch.Tensor:
    """Return the mean-pooled last-layer hidden state for one DNA sequence.

    Steps:
    - Upper-case the sequence and tokenize it, truncating to ``max_length`` tokens.
    - Run ``model`` with ``output_hidden_states=True``.
    - Average the last entry of ``hidden_states`` over every position whose
      token is neither the pad id nor one of ``tokenizer.all_special_ids``.

    Args:
        sequence: Nucleotide string; case-insensitive.
        tokenizer: Tokenizer for ``model`` (e.g. as returned by ``load_model``).
        model: Model called as ``model(input_ids=..., output_hidden_states=True)``.
        device: Device the inputs are moved to (should match ``model``).
        max_length: Truncation length in tokens (keyword-only, default 1024).

    Returns:
        Pooled embedding of shape ``(batch, hidden_dim)``, i.e.
        ``(1, hidden_dim)`` for a single string. This assumes the hidden
        states are ``(batch, seq_len, hidden_dim)``; TODO confirm for the
        Caduceus variant in use. Rows with no regular tokens come out as zeros.
    """
    from contextlib import nullcontext

    enc = tokenizer(
        sequence.upper(),
        return_tensors="pt",
        truncation=True,
        max_length=max_length,
    )
    input_ids = enc["input_ids"].to(device)

    # Token ids excluded from pooling: the pad id plus every special token.
    pad_id = tokenizer.pad_token_id
    if pad_id is None:
        # Fallback for tokenizers with no pad_token_id: look up "[PAD]" in the
        # tokenizer's private vocab mapping, defaulting to id 4.
        pad_id = getattr(tokenizer, "_vocab_str_to_int", {}).get("[PAD]", 4)
    excluded = torch.tensor(
        sorted({pad_id, *tokenizer.all_special_ids}),
        dtype=input_ids.dtype,
        device=input_ids.device,
    )
    mask = ~torch.isin(input_ids, excluded)

    # Make `device` the current CUDA device during the forward pass; no-op on CPU.
    device_ctx = torch.cuda.device(device) if device.type == "cuda" else nullcontext()
    with device_ctx:
        out = model(input_ids=input_ids, output_hidden_states=True)

    hidden = out.hidden_states[-1]
    mask_f = mask.unsqueeze(-1).float()
    # clamp avoids division by zero when every token was excluded.
    return (hidden * mask_f).sum(dim=1) / mask_f.sum(dim=1).clamp(min=1.0)
|
|
|
|
if __name__ == "__main__":
    # Demo input: a random 512-nt sequence drawn from the ACGT alphabet.
    alphabet = "ACGT"
    seq = "".join(random.choices(alphabet, k=512))
    print(f"sequence (len={len(seq)}): {seq[:64]}...")

    tokenizer, model, device = load_model()
    emb = get_embedding(
        sequence=seq,
        tokenizer=tokenizer,
        model=model,
        device=device,
    )

    # Report the pooled embedding's shape, then its first 8 components.
    print(f"embedding shape: {tuple(emb.shape)}")
    print(f"embedding[:8]: {emb[0, :8].cpu().tolist()}")
|
|