import os

import numpy as np
import torch
from Bio import SeqIO
from tqdm import tqdm
from transformers import BertConfig, BertModel, BertForMaskedLM, DNATokenizer

# ========== CONFIG ==========
MODEL_DIR = "/home/n5huang/dna_token/pretrain_output_adaptive/checkpoint-10000"
FASTA_DIR = "/home/n5huang/dna_token/cCRE_classes/chr1_files"
OUTPUT_DIR = "/home/n5huang/dna_token/outputs_cCREemb/"
os.makedirs(OUTPUT_DIR, exist_ok=True)

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Note: DNATokenizer is provided by the DNABERT fork of transformers,
# not the stock Hugging Face package.
MODEL_CLASSES = {"dna": (BertConfig, BertForMaskedLM, DNATokenizer)}


# ========== LOAD MODEL ==========
def load_model(model_dir):
    config_class, model_class, tokenizer_class = MODEL_CLASSES["dna"]
    print(f"Loading using: {config_class.__name__}, {model_class.__name__}, {tokenizer_class.__name__}")
    config = config_class.from_pretrained(model_dir)
    # Load the encoder only (BertModel) so the forward pass returns hidden
    # states rather than masked-LM logits.
    model = BertModel.from_pretrained(model_dir, config=config)
    tokenizer = tokenizer_class.from_pretrained(model_dir)
    model.to(DEVICE)
    model.eval()
    print(f"āœ… Model loaded on {DEVICE}, vocab size = {len(tokenizer)}")
    return model, tokenizer


# ========== SEQUENCE HELPERS ==========
def seq_to_kmers(seq, k=6):
    """Convert a DNA sequence into a space-separated string of overlapping k-mers."""
    seq = seq.upper().replace("N", "")
    if len(seq) < k:
        return ""
    return " ".join(seq[i:i + k] for i in range(len(seq) - k + 1))


def get_fasta_sequences(fasta_file):
    """Read a FASTA file, keeping only sequences of at least 50 bp."""
    sequences = []
    for record in SeqIO.parse(fasta_file, "fasta"):
        seq = str(record.seq).upper()
        if len(seq) >= 50:
            sequences.append(seq)
    return sequences


# ========== EMBEDDING GENERATION ==========
def get_cls_embeddings(batch_seqs, model, tokenizer, device, max_len=512):
    inputs = tokenizer.batch_encode_plus(
        batch_seqs,
        padding="max_length",
        truncation=True,
        max_length=max_len,
        return_tensors="pt",
    )
    # Move tensors to device
    inputs = {k: v.to(device) for k, v in inputs.items()}

    # Forward pass
    with torch.no_grad():
        outputs = model(**inputs)

    # Extract the [CLS] embedding (first token of the last hidden state)
    cls_embeddings = outputs[0][:, 0, :].cpu().numpy()
    return cls_embeddings


# ========== MAIN EXECUTION ==========
def main():
    model, tokenizer = load_model(MODEL_DIR)

    fasta_files = [f for f in os.listdir(FASTA_DIR) if f.endswith(".fa")]
    print(f"\nFound {len(fasta_files)} FASTA files in {FASTA_DIR}")

    for fasta_file in fasta_files:
        fasta_path = os.path.join(FASTA_DIR, fasta_file)
        print(f"\nšŸš€ Processing: {fasta_file}")

        sequences = get_fasta_sequences(fasta_path)
        if len(sequences) == 0:
            print(f"āš ļø No valid sequences found in {fasta_file}")
            continue

        # --- Remove duplicates ---
        unique_sequences = list(set(sequences))
        if len(unique_sequences) < len(sequences):
            print(f"āš ļø Removed {len(sequences) - len(unique_sequences)} duplicate sequences")

        # --- Convert to k-mers ---
        # Filter AFTER conversion: a sequence can pass the length check but
        # still yield an empty string once Ns are stripped in seq_to_kmers.
        kmers = [seq_to_kmers(s) for s in unique_sequences]
        kmers = [k for k in kmers if k]
        if not kmers:
            print(f"āš ļø No tokenizable sequences left in {fasta_file}")
            continue

        # --- Sanity check on tokenization ---
        example_tokens = tokenizer.tokenize(kmers[0])[:10]
        print(f"šŸ”¹ Example tokens: {example_tokens}")

        # --- Batch embedding extraction ---
        all_embs = []
        batch_size = 16
        for i in tqdm(range(0, len(kmers), batch_size), desc=f"Embedding {fasta_file}"):
            batch = kmers[i:i + batch_size]
            batch_embs = get_cls_embeddings(batch, model, tokenizer, DEVICE)
            all_embs.append(batch_embs)

        all_embs = np.vstack(all_embs)
        out_path = os.path.join(OUTPUT_DIR, fasta_file.replace(".fa", "_emb.npy"))
        np.save(out_path, all_embs)
        print(f"āœ… Saved {all_embs.shape} embeddings to {out_path}")

    print("\nšŸŽ‰ All cell-type embeddings generated successfully!")


if __name__ == "__main__":
    main()
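
# ---------------------------------------------------------------------------
# Optional variant (a sketch, not part of the original pipeline): instead of
# taking only the [CLS] vector, average the hidden states of all non-padding
# tokens. Mean pooling is a common alternative for BERT-style sequence
# embeddings. This hypothetical helper assumes the same tokenizer/model API as
# get_cls_embeddings above and can be swapped in for it inside main().
def get_mean_pooled_embeddings(batch_seqs, model, tokenizer, device, max_len=512):
    inputs = tokenizer.batch_encode_plus(
        batch_seqs,
        padding="max_length",
        truncation=True,
        max_length=max_len,
        return_tensors="pt",
    )
    inputs = {k: v.to(device) for k, v in inputs.items()}
    with torch.no_grad():
        outputs = model(**inputs)
    hidden = outputs[0]                                    # (batch, seq_len, hidden)
    mask = inputs["attention_mask"].unsqueeze(-1).float()  # (batch, seq_len, 1)
    summed = (hidden * mask).sum(dim=1)                    # sum over real tokens only
    counts = mask.sum(dim=1).clamp(min=1e-9)               # avoid division by zero
    return (summed / counts).cpu().numpy()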