import os
import torch
import numpy as np
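# NOTE: DNATokenizer is provided by the DNABERT fork of transformers,
# not by the upstream Hugging Face package.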
from transformers import BertConfig, BertModel, BertForMaskedLM, DNATokenizer
from Bio import SeqIO
from tqdm import tqdm
|
|
MODEL_DIR = "/home/n5huang/dna_token/pretrain_output_adaptive/checkpoint-10000"
FASTA_DIR = "/home/n5huang/dna_token/cCRE_classes/chr1_files"
OUTPUT_DIR = "/home/n5huang/dna_token/outputs_cCREemb/"
os.makedirs(OUTPUT_DIR, exist_ok=True)
|
|
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
|
MODEL_CLASSES = {"dna": (BertConfig, BertForMaskedLM, DNATokenizer)}
|
|
def load_model(model_dir):
    config_class, model_class, tokenizer_class = MODEL_CLASSES['dna']
    print(f"Loading using: {config_class.__name__}, {model_class.__name__}, {tokenizer_class.__name__}")
|
|
    config = config_class.from_pretrained(model_dir)
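    # Load the bare encoder (BertModel) for feature extraction; the MLM head in
    # model_class (BertForMaskedLM) is not needed for embeddings.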
    model = BertModel.from_pretrained(model_dir, config=config)
    tokenizer = tokenizer_class.from_pretrained(model_dir)
|
|
    model.to(DEVICE)
    model.eval()
|
|
| print(f"โ
Model loaded on {DEVICE}, vocab size = {len(tokenizer)}") |
| return model, tokenizer |
|
|
def seq_to_kmers(seq, k=6):
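    # NOTE: stripping "N" joins the flanking bases, so a k-mer can span the
    # removed gap; see the gap-aware alternative sketched after this function.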
    seq = seq.upper().replace("N", "")
    if len(seq) < k:
        return ""
    return " ".join([seq[i:i+k] for i in range(len(seq)-k+1)])
|
|
def get_fasta_sequences(fasta_file):
    sequences = []
    for record in SeqIO.parse(fasta_file, "fasta"):
        seq = str(record.seq).upper()
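        # Keep only records of at least 50 bp; shorter fragments are skipped.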
        if len(seq) >= 50:
            sequences.append(seq)
    return sequences
|
|
def get_cls_embeddings(batch_seqs, model, tokenizer, device, max_len=512):
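    # Tokenize the space-joined k-mer strings, padding/truncating every example
    # to max_len tokens so the batch forms a rectangular tensor.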
    inputs = tokenizer.batch_encode_plus(
        batch_seqs,
        padding="max_length",
        truncation=True,
        max_length=max_len,
        return_tensors="pt"
    )
    inputs = {k: v.to(device) for k, v in inputs.items()}
|
|
    with torch.no_grad():
        outputs = model(**inputs)
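    # outputs[0] is the last hidden state (batch, seq_len, hidden); the [CLS]
    # token at position 0 serves as the whole-sequence embedding.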
    cls_embeddings = outputs[0][:, 0, :].cpu().numpy()
    return cls_embeddings
|
|
def main():
    model, tokenizer = load_model(MODEL_DIR)
|
|
    fasta_files = [f for f in os.listdir(FASTA_DIR) if f.endswith(".fa")]
    print(f"\nFound {len(fasta_files)} FASTA files in {FASTA_DIR}")
|
|
    for fasta_file in fasta_files:
        fasta_path = os.path.join(FASTA_DIR, fasta_file)
        print(f"\n📂 Processing: {fasta_file}")
|
|
        sequences = get_fasta_sequences(fasta_path)
        if len(sequences) == 0:
            print(f"⚠️ No valid sequences found in {fasta_file}")
            continue
|
|
        # Deduplicate while preserving input order; list(set(...)) would make
        # the row order of the saved embeddings nondeterministic.
        unique_sequences = list(dict.fromkeys(sequences))
        if len(unique_sequences) < len(sequences):
            print(f"⚠️ Removed {len(sequences) - len(unique_sequences)} duplicate sequences")
|
|
        # Drop sequences whose k-mer string comes back empty (fewer than k
        # usable bases after N removal).
        kmers = [km for km in (seq_to_kmers(s) for s in unique_sequences) if km]
|
|
        if not kmers:
            print(f"⚠️ No tokenizable sequences left in {fasta_file}")
            continue
        example_tokens = tokenizer.tokenize(kmers[0])[:10]
        print(f"🔹 Example tokens: {example_tokens}")
|
|
        all_embs = []
        batch_size = 16
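        # 16 sequences x 512 tokens per forward pass is conservative; it can
        # likely be raised on larger GPUs.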
        for i in tqdm(range(0, len(kmers), batch_size), desc=f"Embedding {fasta_file}"):
            batch = kmers[i:i+batch_size]
            batch_embs = get_cls_embeddings(batch, model, tokenizer, DEVICE)
            all_embs.append(batch_embs)
|
|
        all_embs = np.vstack(all_embs)
        out_path = os.path.join(OUTPUT_DIR, fasta_file.replace(".fa", "_emb.npy"))
        np.save(out_path, all_embs)
|
|
| print(f"โ
Saved {all_embs.shape} embeddings to {out_path}") |
|
|
| print("\n๐ All cell-type embeddings generated successfully!") |
|
|
if __name__ == "__main__":
    main()
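
# Downstream usage sketch (the filename "PLS_emb.npy" is hypothetical; actual
# names follow the input FASTA files, "<class>.fa" -> "<class>_emb.npy"):
#
#   embs = np.load(os.path.join(OUTPUT_DIR, "PLS_emb.npy"))
#   print(embs.shape)  # (num_unique_sequences, hidden_size)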
|
|