# DNABERT_save/examples/gen_cCRE_emb_final.py
import os
import torch
import numpy as np
# NOTE: DNATokenizer is provided by the DNABERT fork of transformers, not the stock library.
from transformers import BertConfig, BertModel, BertForMaskedLM, DNATokenizer
from Bio import SeqIO
from tqdm import tqdm
# ========== CONFIG ==========
MODEL_DIR = "/home/n5huang/dna_token/pretrain_output_adaptive/checkpoint-10000"
FASTA_DIR = "/home/n5huang/dna_token/cCRE_classes/chr1_files"
OUTPUT_DIR = "/home/n5huang/dna_token/outputs_cCREemb/"
os.makedirs(OUTPUT_DIR, exist_ok=True)
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
MODEL_CLASSES = {"dna": (BertConfig, BertForMaskedLM, DNATokenizer)}
# ========== LOAD MODEL ==========
def load_model(model_dir):
    config_class, model_class, tokenizer_class = MODEL_CLASSES["dna"]
    print(f"Loading using: {config_class.__name__}, {model_class.__name__}, {tokenizer_class.__name__}")
    config = config_class.from_pretrained(model_dir)
    # Load the bare encoder (BertModel) rather than the MLM head: only hidden states are needed.
    model = BertModel.from_pretrained(model_dir, config=config)
    tokenizer = tokenizer_class.from_pretrained(model_dir)
    model.to(DEVICE)
    model.eval()
    print(f"✅ Model loaded on {DEVICE}, vocab size = {len(tokenizer)}")
    return model, tokenizer

# ========== SEQUENCE HELPERS ==========
def seq_to_kmers(seq, k=6):
    # Drop ambiguous bases (N) before extracting overlapping k-mers.
    seq = seq.upper().replace("N", "")
    if len(seq) < k:
        return ""
    return " ".join(seq[i:i + k] for i in range(len(seq) - k + 1))
def get_fasta_sequences(fasta_file):
    sequences = []
    for record in SeqIO.parse(fasta_file, "fasta"):
        seq = str(record.seq).upper()
        if len(seq) >= 50:
            sequences.append(seq)
    return sequences

# ========== EMBEDDING GENERATION ==========
def get_cls_embeddings(batch_seqs, model, tokenizer, device, max_len=512):
    inputs = tokenizer.batch_encode_plus(
        batch_seqs,
        padding="max_length",
        truncation=True,
        max_length=max_len,
        return_tensors="pt",
    )
    # Move tensors to device
    inputs = {k: v.to(device) for k, v in inputs.items()}
    # Forward pass
    with torch.no_grad():
        outputs = model(**inputs)
    # Extract the [CLS] embedding from the last hidden state
    cls_embeddings = outputs[0][:, 0, :].cpu().numpy()
    return cls_embeddings
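
# Optional alternative (a sketch, not used below): mean-pool over real tokens
# instead of taking the [CLS] vector, excluding padded positions via the
# attention mask. get_mean_embeddings is a hypothetical helper, not part of
# the original pipeline.
def get_mean_embeddings(batch_seqs, model, tokenizer, device, max_len=512):
    inputs = tokenizer.batch_encode_plus(
        batch_seqs,
        padding="max_length",
        truncation=True,
        max_length=max_len,
        return_tensors="pt",
    )
    inputs = {k: v.to(device) for k, v in inputs.items()}
    with torch.no_grad():
        hidden = model(**inputs)[0]  # (batch, seq_len, hidden_size)
    mask = inputs["attention_mask"].unsqueeze(-1).float()
    summed = (hidden * mask).sum(dim=1)
    counts = mask.sum(dim=1).clamp(min=1.0)
    return (summed / counts).cpu().numpy()
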
# ========== MAIN EXECUTION ==========
def main():
    model, tokenizer = load_model(MODEL_DIR)
    fasta_files = [f for f in os.listdir(FASTA_DIR) if f.endswith(".fa")]
    print(f"\nFound {len(fasta_files)} FASTA files in {FASTA_DIR}")
    for fasta_file in fasta_files:
        fasta_path = os.path.join(FASTA_DIR, fasta_file)
        print(f"\n🚀 Processing: {fasta_file}")
        sequences = get_fasta_sequences(fasta_path)
        if len(sequences) == 0:
            print(f"⚠️ No valid sequences found in {fasta_file}")
            continue
        # --- Remove duplicates (dict.fromkeys keeps first-seen order, unlike set) ---
        unique_sequences = list(dict.fromkeys(sequences))
        if len(unique_sequences) < len(sequences):
            print(f"⚠️ Removed {len(sequences) - len(unique_sequences)} duplicate sequences")
        # --- Convert to k-mers, skipping sequences that end up empty (e.g. all Ns) ---
        kmers = [km for km in (seq_to_kmers(s) for s in unique_sequences) if km]
        if not kmers:
            print(f"⚠️ No usable sequences after k-mer conversion in {fasta_file}")
            continue
        # --- Sanity check on tokenization ---
        example_tokens = tokenizer.tokenize(kmers[0])[:10]
        print(f"🔹 Example tokens: {example_tokens}")
        # --- Batch embedding extraction ---
        all_embs = []
        batch_size = 16
        for i in tqdm(range(0, len(kmers), batch_size), desc=f"Embedding {fasta_file}"):
            batch = kmers[i:i + batch_size]
            batch_embs = get_cls_embeddings(batch, model, tokenizer, DEVICE)
            all_embs.append(batch_embs)
        all_embs = np.vstack(all_embs)
        out_path = os.path.join(OUTPUT_DIR, fasta_file.replace(".fa", "_emb.npy"))
        np.save(out_path, all_embs)
        print(f"✅ Saved {all_embs.shape} embeddings to {out_path}")
    print("\n🎉 All cell-type embeddings generated successfully!")


if __name__ == "__main__":
    main()
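
# Example downstream usage (assumes an output such as "chr1_classA_emb.npy"
# exists in OUTPUT_DIR; the filename here is hypothetical):
#   import numpy as np
#   embs = np.load("/home/n5huang/dna_token/outputs_cCREemb/chr1_classA_emb.npy")
#   print(embs.shape)  # (num_unique_sequences, hidden_size)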