import os
import torch
import numpy as np
from transformers import BertConfig, BertModel, BertForMaskedLM, DNATokenizer
from Bio import SeqIO
from tqdm import tqdm
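# NOTE: DNATokenizer is not part of upstream Hugging Face `transformers`; this
# script assumes the DNABERT fork of transformers, which bundles it, is installed.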

# ========== CONFIG ==========
MODEL_DIR = "/home/n5huang/dna_token/pretrain_output_adaptive/checkpoint-10000"
FASTA_DIR = "/home/n5huang/dna_token/cCRE_classes/chr1_files"
OUTPUT_DIR = "/home/n5huang/dna_token/outputs_cCREemb/"
os.makedirs(OUTPUT_DIR, exist_ok=True)

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

MODEL_CLASSES = {"dna": (BertConfig, BertForMaskedLM, DNATokenizer)}

# ========== LOAD MODEL ==========
def load_model(model_dir):
    config_class, model_class, tokenizer_class = MODEL_CLASSES['dna']
    print(f"Loading using: {config_class.__name__}, {model_class.__name__}, {tokenizer_class.__name__}")

    config = config_class.from_pretrained(model_dir)
    # Load the bare encoder: the MLM head used in pretraining is not needed
    # for embedding extraction, so BertModel (not model_class) is used here.
    model = BertModel.from_pretrained(model_dir, config=config)
    tokenizer = tokenizer_class.from_pretrained(model_dir)

    model.to(DEVICE)
    model.eval()

    print(f"✅ Model loaded on {DEVICE}, vocab size = {len(tokenizer)}")
    return model, tokenizer

# ========== SEQUENCE HELPERS ==========
def seq_to_kmers(seq, k=6):
    # Drop ambiguous bases entirely; note this splices together the flanks
    # around runs of N rather than breaking the sequence there.
    seq = seq.upper().replace("N", "")
    if len(seq) < k:
        return ""
    return " ".join(seq[i:i+k] for i in range(len(seq) - k + 1))
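# e.g. seq_to_kmers("ACGTACGT") -> "ACGTAC CGTACG GTACGT"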

def get_fasta_sequences(fasta_file):
    sequences = []
    for record in SeqIO.parse(fasta_file, "fasta"):
        seq = str(record.seq).upper()
        if len(seq) >= 50:  # skip very short records
            sequences.append(seq)
    return sequences

# ========== EMBEDDING GENERATION ==========
def get_cls_embeddings(batch_seqs, model, tokenizer, device, max_len=512):
    inputs = tokenizer.batch_encode_plus(
        batch_seqs,
        padding="max_length",
        truncation=True,
        max_length=max_len,
        return_tensors="pt"
    )
    # Move tensors to device
    inputs = {k: v.to(device) for k, v in inputs.items()}

    # Forward pass
    with torch.no_grad():
        outputs = model(**inputs)
    
    # Extract CLS embedding
    cls_embeddings = outputs[0][:, 0, :].cpu().numpy()
    return cls_embeddings
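
# Optional, hypothetical helper (a sketch; not called anywhere in this script):
# mean-pool the hidden states over non-padding positions instead of taking the
# [CLS] vector. It mirrors the tokenizer/model calls in get_cls_embeddings.
def get_mean_embeddings(batch_seqs, model, tokenizer, device, max_len=512):
    inputs = tokenizer.batch_encode_plus(
        batch_seqs,
        padding="max_length",
        truncation=True,
        max_length=max_len,
        return_tensors="pt"
    )
    inputs = {k: v.to(device) for k, v in inputs.items()}

    with torch.no_grad():
        outputs = model(**inputs)

    hidden = outputs[0]                                    # (batch, seq_len, dim)
    mask = inputs["attention_mask"].unsqueeze(-1).float()  # (batch, seq_len, 1)
    summed = (hidden * mask).sum(dim=1)                    # zero out padding, then sum
    counts = mask.sum(dim=1).clamp(min=1.0)                # non-padding token counts
    return (summed / counts).cpu().numpy()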

# ========== MAIN EXECUTION ==========
def main():
    model, tokenizer = load_model(MODEL_DIR)

    fasta_files = [f for f in os.listdir(FASTA_DIR) if f.endswith(".fa")]
    print(f"\nFound {len(fasta_files)} FASTA files in {FASTA_DIR}")

    for fasta_file in fasta_files:
        fasta_path = os.path.join(FASTA_DIR, fasta_file)
        print(f"\n🚀 Processing: {fasta_file}")

        sequences = get_fasta_sequences(fasta_path)
        if len(sequences) == 0:
            print(f"⚠️ No valid sequences found in {fasta_file}")
            continue

        # --- Remove duplicates (dict.fromkeys keeps first-seen order, unlike set) ---
        unique_sequences = list(dict.fromkeys(sequences))
        if len(unique_sequences) < len(sequences):
            print(f"⚠️ Removed {len(sequences) - len(unique_sequences)} duplicate sequences")

        # --- Convert to k-mers ---
        kmers = [seq_to_kmers(s) for s in unique_sequences]
        kmers = [k for k in kmers if k]  # drop entries emptied by N-stripping
        if not kmers:
            print(f"⚠️ No k-merizable sequences in {fasta_file}")
            continue

        # --- Sanity check on tokenization ---
        example_tokens = tokenizer.tokenize(kmers[0])[:10]
        print(f"🔹 Example tokens: {example_tokens}")

        # --- Batch embedding extraction ---
        all_embs = []
        batch_size = 16
        for i in tqdm(range(0, len(kmers), batch_size), desc=f"Embedding {fasta_file}"):
            batch = kmers[i:i+batch_size]
            batch_embs = get_cls_embeddings(batch, model, tokenizer, DEVICE)
            all_embs.append(batch_embs)

        all_embs = np.vstack(all_embs)
        out_path = os.path.join(OUTPUT_DIR, fasta_file.replace(".fa", "_emb.npy"))
        np.save(out_path, all_embs)

        print(f"✅ Saved {all_embs.shape} embeddings to {out_path}")

    print("\n🎉 All cell-type embeddings generated successfully!")

if __name__ == "__main__":
    main()