|
|
|
|
|
""" |
|
|
Extract reverse complement embeddings for TBX5 motif data using Evo2 40B model. |
|
|
- Extract embeddings from layer `blocks.26.mlp.l3` (see LAYER_NAME)
|
|
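- Locate each motif inside an 8192bp window around its site (only the motif itself is passed to the model; the window is not used as context)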
|
|
- Embed the reverse complement of each 61bp motif and average the layer output over its positions
|
|
- Create one feature vector per motif; its dimension is the layer's hidden size (assumed to be 4096 — verify for the selected model)
|
|
""" |
|
|
|
|
|
import pandas as pd |
|
|
import numpy as np |
|
|
import torch |
|
|
import gzip |
|
|
from Bio import SeqIO |
|
|
from Bio.Seq import Seq |
|
|
from evo2 import Evo2 |
|
|
import pickle |
|
|
from tqdm import tqdm |
|
|
import os |
|
|
import sys |
|
|
import argparse |
|
|
|
|
|
|
|
|
# Enable tqdm's pandas integration (e.g. DataFrame.progress_apply); nothing in
# the code shown here relies on it.
tqdm.pandas()

# Length in bp of the genomic window cut around each motif centre.
WINDOW_SIZE = 8_192

# Evo2 layer whose activations are captured and mean-pooled.
LAYER_NAME = "blocks.26.mlp.l3"

# Every motif interval is resized to this many bp before embedding.
SEQUENCE_LENGTH = 61

# Declared batch size; not referenced by the code shown here (sequences are
# embedded one at a time).
BATCH_SIZE = 8
|
|
|
|
|
def get_reverse_complement(sequence):
    """Return the reverse complement of a DNA string, computed with Biopython's ``Seq``."""
    rc_seq = Seq(sequence).reverse_complement()
    return str(rc_seq)
|
|
|
|
|
def load_fasta(fasta_path, chromosome):
    """Read the first record of a gzip-compressed FASTA file as an upper-case string.

    Args:
        fasta_path: Path to a ``.fa.gz`` file (opened with ``gzip`` in text mode).
        chromosome: Label used only in progress messages.

    Returns:
        The sequence of the first FASTA record, upper-cased, or ``None`` when
        the file contains no records. Any further records are ignored.
    """
    print(f"Loading chromosome {chromosome} FASTA file...")

    with gzip.open(fasta_path, "rt") as handle:
        first_record = next(SeqIO.parse(handle, "fasta"), None)
        if first_record is None:
            return None

        sequence = str(first_record.seq).upper()
        print(f"Loaded chromosome {chromosome}, length: {len(sequence):,} bp")
        return sequence
|
|
|
|
|
def normalize_sequence_length(df, sequence_length=None):
    """
    Resize every motif interval to exactly ``sequence_length`` bp.

    Intervals are 1-based and inclusive (length = end - start + 1), the same
    convention ``get_sequence_window`` expects. Each interval is resized
    around its middle:

    * too short: widened, with ``deficit // 2`` bp added on the left (never
      moving ``start`` below position 1; the end absorbs any shortfall) and
      the rest on the right;
    * too long: trimmed, with ``excess // 2`` bp removed from the left and the
      rest from the right;
    * already the right length: coordinates left unchanged.

    Args:
        df: DataFrame with numeric ``start`` and ``end`` columns. Not modified.
        sequence_length: Target length in bp; defaults to SEQUENCE_LENGTH.

    Returns:
        A copy of ``df`` with ``start``/``end`` rewritten and a ``length``
        column set to ``sequence_length`` for every row.
    """
    if sequence_length is None:
        sequence_length = SEQUENCE_LENGTH

    print(f"Normalizing sequence lengths to {sequence_length}bp...")

    df_normalized = df.copy()

    if len(df_normalized) > 0:
        starts = df_normalized["start"]
        lengths = df_normalized["end"] - starts + 1

        # Vectorised replacement for a per-row loop. Position 0 does not exist
        # in 1-based coordinates, so widened intervals are clamped at 1.
        widened = np.maximum(1, starts - (sequence_length - lengths) // 2)
        trimmed = starts + (lengths - sequence_length) // 2
        new_starts = np.where(
            lengths < sequence_length,
            widened,
            np.where(lengths > sequence_length, trimmed, starts),
        )

        df_normalized["start"] = new_starts
        df_normalized["end"] = new_starts + sequence_length - 1
        # After resizing every row spans exactly sequence_length bp.
        df_normalized["length"] = sequence_length

    print(f"Normalized {len(df_normalized)} sequences to {sequence_length}bp")
    return df_normalized
|
|
|
|
|
def get_sequence_window(chr_seq, start, end, window_size=WINDOW_SIZE):
    """
    Extract a sequence window centred on a motif site.

    Args:
        chr_seq: Full chromosome sequence (string).
        start: Start position of motif (1-based, inclusive).
        end: End position of motif (1-based, inclusive).
        window_size: Target window length in bp (default WINDOW_SIZE). The
            window is clipped at the chromosome ends, so it may be shorter.

    Returns:
        (seq_window, motif_start_in_window, motif_end_in_window): the window
        string and the 0-based, inclusive offsets of the motif inside it.
        Returns (None, None, None) when ``start``/``end`` do not describe a
        non-empty interval lying inside ``chr_seq``.
    """
    # Invalid coordinates would otherwise give negative or out-of-range
    # offsets that silently slice the wrong bases; callers treat a None
    # window as a failed motif.
    if start < 1 or end < start or end > len(chr_seq):
        return None, None, None

    # Convert to 0-based indices.
    start_0 = start - 1
    end_0 = end - 1
    motif_center = (start_0 + end_0) // 2

    half_window = window_size // 2
    window_start = max(0, motif_center - half_window)
    window_end = min(len(chr_seq), motif_center + half_window)
    seq_window = chr_seq[window_start:window_end]

    # Motif position relative to the window's first base.
    motif_start_in_window = start_0 - window_start
    motif_end_in_window = end_0 - window_start

    return seq_window, motif_start_in_window, motif_end_in_window
|
|
|
|
|
def extract_embeddings_batch(model, sequences, layer_name=LAYER_NAME):
    """
    Compute one mean-pooled embedding per DNA sequence.

    Despite the name, sequences are not batched: each one is tokenized,
    wrapped in a batch of size 1, moved to ``cuda:0`` and run through the
    model on its own. The activations captured at ``layer_name`` are averaged
    over axis 1 (presumably the token axis — assumes the layer output is
    (1, seq_len, hidden); TODO confirm), cast to float32 and moved to CPU.

    Args:
        model: Evo2 model wrapper exposing ``tokenizer.tokenize`` and a
            callable forward accepting ``return_embeddings``/``layer_names``.
        sequences: Iterable of DNA strings.
        layer_name: Name of the layer whose activations are pooled.

    Returns:
        numpy array with one pooled row per input sequence, in input order.
        ``np.vstack`` raises ValueError when ``sequences`` is empty.
    """
    pooled_rows = []

    for dna in sequences:
        token_ids = model.tokenizer.tokenize(dna)
        batch = torch.tensor(token_ids, dtype=torch.int).unsqueeze(0)
        batch = batch.to("cuda:0")

        # Inference only: skip autograd bookkeeping.
        with torch.no_grad():
            _, captured = model(batch, return_embeddings=True, layer_names=[layer_name])

        layer_output = captured[layer_name]
        pooled = layer_output.mean(dim=1).float().cpu().numpy()
        pooled_rows.append(pooled)

    return np.vstack(pooled_rows)
|
|
|
|
|
def process_motifs(model, chr_seq, motif_df, chromosome):
    """
    Embed the reverse complement of every motif in ``motif_df``.

    For each row the motif is cut out of ``chr_seq`` (via the window returned
    by ``get_sequence_window``), reverse-complemented, and passed on its own
    to ``extract_embeddings_batch``. Only the SEQUENCE_LENGTH-bp motif is fed
    to the model; the surrounding window is used solely to locate it.

    Rows whose window cannot be built, whose extracted motif is not exactly
    SEQUENCE_LENGTH bp, or whose processing raises are reported and skipped;
    processing continues with the next row (best effort).

    Args:
        model: Evo2 model.
        chr_seq: Chromosome sequence string (as returned by ``load_fasta``).
        motif_df: DataFrame with 1-based inclusive ``start``/``end`` columns
            and optional ``tbx5_score``/``label`` columns (default 0).
        chromosome: Chromosome identifier, stored with every record.

    Returns:
        Dict mapping the DataFrame index label of each successful row to a
        record with keys ``start``, ``end``, ``embedding``, ``tbx5_score``,
        ``label``, ``chromosome`` and ``sequence_type``.
    """
    embeddings_dict = {}
    failed_motifs = []

    print(f"Processing {len(motif_df)} motifs on chromosome {chromosome} (reverse complement)...")

    progress = tqdm(
        motif_df.iterrows(),
        total=len(motif_df),
        desc=f"Chr{chromosome} RC embeddings",
        ncols=100,
        leave=True,
        position=0,
    )

    for idx, row in progress:
        try:
            start = int(row['start'])
            end = int(row['end'])

            seq_window, motif_start, motif_end = get_sequence_window(chr_seq, start, end)
            if seq_window is None:
                failed_motifs.append(idx)
                continue

            # Offsets are 0-based and inclusive, hence the +1 on the slice end.
            motif_seq = seq_window[motif_start:motif_end + 1]
            if len(motif_seq) != SEQUENCE_LENGTH:
                print(f"Warning: Motif length {len(motif_seq)} != {SEQUENCE_LENGTH} at position {start}-{end}")
                failed_motifs.append(idx)
                continue

            motif_seq_rc = get_reverse_complement(motif_seq)
            motif_embedding = extract_embeddings_batch(model, [motif_seq_rc])[0]

            embeddings_dict[idx] = {
                "start": start,
                "end": end,
                "embedding": motif_embedding,
                "tbx5_score": row.get("tbx5_score", 0),
                "label": row.get("label", 0),
                "chromosome": chromosome,
                "sequence_type": "reverse_complement",
            }
        except Exception as e:
            # Best effort: one bad motif (or a transient GPU error) must not
            # abort the whole chromosome; it is reported and counted instead.
            print(f"Error processing motif at index {idx}: {e}")
            failed_motifs.append(idx)

    print(f"Successfully processed {len(embeddings_dict)} motifs (reverse complement)")
    if failed_motifs:
        preview = failed_motifs[:10]
        # Only signal truncation when some failed indices are actually omitted.
        suffix = "..." if len(failed_motifs) > len(preview) else ""
        print(f"Failed to process {len(failed_motifs)} motifs: {preview}{suffix}")

    return embeddings_dict
|
|
|
|
|
def save_embeddings(embeddings_dict, output_path, chromosome):
    """
    Persist motif embeddings as a pickle plus a compressed NumPy archive.

    Args:
        embeddings_dict: Mapping of motif index -> record as produced by
            ``process_motifs`` (keys ``start``, ``end``, ``embedding``,
            ``tbx5_score``, ``label``, ...).
        output_path: Destination of the pickle (expected to end in ``.pkl``).
            The arrays go next to it, with the trailing ``.pkl`` replaced by
            ``_arrays.npz``.
        chromosome: Chromosome identifier recorded in the metadata.

    Files written:
        * ``output_path``: pickle of ``{"embeddings": {...}, "metadata": {...}}``.
        * ``<stem>_arrays.npz``: parallel arrays ``indices``, ``starts``,
          ``ends``, ``embeddings``, ``tbx5_scores``, ``labels`` plus
          ``metadata`` (a dict stored as an object array, so reading that key
          needs ``np.load(..., allow_pickle=True)``). Skipped when
          ``embeddings_dict`` is empty.
    """
    print(f"Saving reverse complement embeddings to {output_path}")

    records = list(embeddings_dict.items())

    # Record the real vector width instead of a hard-coded constant: the
    # model (and so the layer's hidden size) is chosen on the command line.
    # Fall back to the previous 4096 when there is nothing to inspect.
    if records:
        embedding_dim = int(np.asarray(records[0][1]["embedding"]).shape[-1])
    else:
        embedding_dim = 4096

    save_data = {
        "embeddings": dict(records),
        "metadata": {
            "chromosome": chromosome,
            "window_size": WINDOW_SIZE,
            "sequence_length": SEQUENCE_LENGTH,
            "layer_name": LAYER_NAME,
            "embedding_dim": embedding_dim,
            "num_motifs": len(records),
            "sequence_type": "reverse_complement",
        },
    }

    with open(output_path, "wb") as f:
        pickle.dump(save_data, f)

    # Swap only a trailing ".pkl"; str.replace would also rewrite ".pkl"
    # occurring inside directory names.
    stem = output_path[:-len(".pkl")] if output_path.endswith(".pkl") else output_path
    np_output = f"{stem}_arrays.npz"

    if not records:
        print("No embeddings to save in numpy format")
        return

    np.savez_compressed(
        np_output,
        indices=np.array([idx for idx, _ in records]),
        starts=np.array([rec["start"] for _, rec in records]),
        ends=np.array([rec["end"] for _, rec in records]),
        embeddings=np.vstack([rec["embedding"] for _, rec in records]),
        tbx5_scores=np.array([rec["tbx5_score"] for _, rec in records]),
        labels=np.array([rec["label"] for _, rec in records]),
        metadata=save_data["metadata"],
    )
    print(f"Saved numpy arrays to {np_output}")
|
|
|
|
|
def main():
    """
    Command-line entry point: embed reverse-complement TBX5 motifs for one chromosome.

    Reads ``Homo_sapiens.GRCh38.dna.chromosome.<chrom>.fa.gz`` from
    ``--fasta-dir`` and the motif table from ``--csv-file``. It keeps the rows
    whose ``chromosome`` column equals the positional argument and normalises
    them to SEQUENCE_LENGTH bp. It then embeds them with the chosen Evo2 model
    and writes ``chr<chrom>_tbx5_embeddings_rc.pkl`` (plus an
    ``_arrays.npz`` companion) into ``--output-dir``.

    Returns:
        Process exit code: 0 on success, including the "no motifs on this
        chromosome" case, which still writes an empty pickle. Returns 1 when
        an input file is missing or the FASTA yields no record.
    """
    parser = argparse.ArgumentParser(
        description="Extract reverse complement embeddings for TBX5 motif data"
    )
    parser.add_argument(
        "chromosome", type=str, help="Chromosome to process (e.g., 1, 2, X, Y)"
    )
    parser.add_argument(
        "--fasta-dir",
        type=str,
        default="fasta",
        help="Directory containing FASTA files (default: fasta)",
    )
    parser.add_argument(
        "--csv-file",
        type=str,
        default="processed_data/all_tbx5_data.csv",
        help="TBX5 CSV file (default: processed_data/all_tbx5_data.csv)",
    )
    parser.add_argument(
        "--output-dir",
        type=str,
        default="tbx5_embeddings_reverse_complement",
        help="Output directory for reverse complement embeddings (default: tbx5_embeddings_reverse_complement)",
    )
    parser.add_argument(
        "--model",
        type=str,
        default="evo2_40b",
        help="Evo2 model to use (default: evo2_40b)",
    )

    args = parser.parse_args()
    chromosome = args.chromosome

    os.makedirs(args.output_dir, exist_ok=True)

    fasta_path = os.path.join(
        args.fasta_dir, f"Homo_sapiens.GRCh38.dna.chromosome.{chromosome}.fa.gz"
    )
    csv_path = args.csv_file
    output_path = os.path.join(args.output_dir, f"chr{chromosome}_tbx5_embeddings_rc.pkl")
    # The arrays file name swaps the trailing ".pkl" for "_arrays.npz".
    npz_path = output_path[:-len(".pkl")] + "_arrays.npz"

    if not os.path.exists(fasta_path):
        print(f"Error: FASTA file not found at {fasta_path}")
        return 1
    if not os.path.exists(csv_path):
        print(f"Error: CSV file not found at {csv_path}")
        return 1

    chr_seq = load_fasta(fasta_path, chromosome)
    if chr_seq is None:
        print(f"Error: Failed to load chromosome {chromosome} sequence")
        return 1

    print(f"Loading TBX5 data for chromosome {chromosome}...")
    # Read chromosome names as text. The CLI argument is a string, and by
    # default pandas parses a column of purely numeric names ("1", "2", ...)
    # as integers, so the equality filter below would never match.
    motif_df = pd.read_csv(csv_path, dtype={"chromosome": str})
    chr_motif_df = motif_df[motif_df['chromosome'] == chromosome].copy()

    if len(chr_motif_df) == 0:
        print(f"Warning: No chromosome {chromosome} motifs found in TBX5 data")
        # Still write an (empty) output pickle for this chromosome, using the
        # same writer and metadata layout as the non-empty case.
        save_embeddings({}, output_path, chromosome)
        print(f"Created empty reverse complement embeddings file for chromosome {chromosome}")
        return 0

    print(f"Found {len(chr_motif_df)} motifs on chromosome {chromosome}")

    chr_motif_df = normalize_sequence_length(chr_motif_df)

    print(f"Loading {args.model} model...")
    model = Evo2(args.model)
    model.model.eval()

    embeddings_dict = process_motifs(model, chr_seq, chr_motif_df, chromosome)
    save_embeddings(embeddings_dict, output_path, chromosome)

    print(f"Done processing chromosome {chromosome} (reverse complement)!")

    # Report the actual vector width rather than assuming one.
    if embeddings_dict:
        embedding_dim = len(next(iter(embeddings_dict.values()))["embedding"])
    else:
        embedding_dim = "n/a"

    print(f"\n=== Summary for Chromosome {chromosome} (Reverse Complement) ===")
    print(f"Total motifs processed: {len(embeddings_dict)}")
    print(f"Embedding dimension: {embedding_dim}")
    print(f"Sequence length: {SEQUENCE_LENGTH}bp")
    print(f"Window size: {WINDOW_SIZE}bp")
    print("Sequence type: Reverse complement")
    print("Output files:")
    print(f" - {output_path}")
    print(f" - {npz_path}")

    return 0
|
|
|
|
|
if __name__ == "__main__":
    # Use main()'s return value (0 = success, 1 = error) as the exit status.
    exit_status = main()
    sys.exit(exit_status)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|