Spaces:
Runtime error
Runtime error
Commit ·
4a303ce
1
Parent(s): 69d7c1c
change cas9
Browse files- app.py +9 -8
- cas12lstm.py +188 -0
- cas12lstmvcf.py +287 -0
- cas9att.py +299 -0
- cas9attvcf.py +397 -0
- cas9on.py +1 -3
- requirements.txt +3 -0
app.py
CHANGED
|
@@ -1,6 +1,7 @@
|
|
| 1 |
import os
|
| 2 |
import tiger
|
| 3 |
-
import
|
|
|
|
| 4 |
import cas9off
|
| 5 |
import cas12
|
| 6 |
import pandas as pd
|
|
@@ -22,8 +23,8 @@ st.divider()
|
|
| 22 |
CRISPR_MODELS = ['Cas9', 'Cas12', 'Cas13d']
|
| 23 |
|
| 24 |
selected_model = st.selectbox('Select CRISPR model:', CRISPR_MODELS, key='selected_model')
|
| 25 |
-
|
| 26 |
-
cas12_path = 'cas12_model/
|
| 27 |
|
| 28 |
#plot functions
|
| 29 |
def generate_coolbox_plot(bigwig_path, region, output_image_path):
|
|
@@ -182,8 +183,8 @@ if selected_model == 'Cas9':
|
|
| 182 |
# Process predictions
|
| 183 |
if predict_button and gene_symbol:
|
| 184 |
with st.spinner('Predicting... Please wait'):
|
| 185 |
-
predictions, gene_sequence, exons =
|
| 186 |
-
sorted_predictions = sorted(predictions
|
| 187 |
st.session_state['on_target_results'] = sorted_predictions
|
| 188 |
st.session_state['gene_sequence'] = gene_sequence # Save gene sequence in session state
|
| 189 |
st.session_state['exons'] = exons # Store exon data
|
|
@@ -283,9 +284,9 @@ if selected_model == 'Cas9':
|
|
| 283 |
|
| 284 |
|
| 285 |
# Generate files
|
| 286 |
-
|
| 287 |
-
|
| 288 |
-
|
| 289 |
|
| 290 |
# Prepare an in-memory buffer for the ZIP file
|
| 291 |
zip_buffer = io.BytesIO()
|
|
|
|
| 1 |
import os
|
| 2 |
import tiger
|
| 3 |
+
import cas9att
|
| 4 |
+
import cas9attvcf
|
| 5 |
import cas9off
|
| 6 |
import cas12
|
| 7 |
import pandas as pd
|
|
|
|
| 23 |
CRISPR_MODELS = ['Cas9', 'Cas12', 'Cas13d']
|
| 24 |
|
| 25 |
selected_model = st.selectbox('Select CRISPR model:', CRISPR_MODELS, key='selected_model')
|
| 26 |
+
cas9att_path = 'cas9_model/Cas9_MultiHeadAttention_weights.keras'
|
| 27 |
+
cas12_path = 'cas12_model/BiLSTM_Cpf1_weights.keras'
|
| 28 |
|
| 29 |
#plot functions
|
| 30 |
def generate_coolbox_plot(bigwig_path, region, output_image_path):
|
|
|
|
| 183 |
# Process predictions
|
| 184 |
if predict_button and gene_symbol:
|
| 185 |
with st.spinner('Predicting... Please wait'):
|
| 186 |
+
predictions, gene_sequence, exons = cas9att.process_gene(gene_symbol, cas9att_path)
|
| 187 |
+
sorted_predictions = sorted(predictions)[:10]
|
| 188 |
st.session_state['on_target_results'] = sorted_predictions
|
| 189 |
st.session_state['gene_sequence'] = gene_sequence # Save gene sequence in session state
|
| 190 |
st.session_state['exons'] = exons # Store exon data
|
|
|
|
| 284 |
|
| 285 |
|
| 286 |
# Generate files
|
| 287 |
+
cas9att.generate_genbank_file_from_df(df, gene_sequence, gene_symbol, genbank_file_path)
|
| 288 |
+
cas9att.create_bed_file_from_df(df, bed_file_path)
|
| 289 |
+
cas9att.create_csv_from_df(df, csv_file_path)
|
| 290 |
|
| 291 |
# Prepare an in-memory buffer for the ZIP file
|
| 292 |
zip_buffer = io.BytesIO()
|
cas12lstm.py
ADDED
|
@@ -0,0 +1,188 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import tensorflow as tf
|
| 2 |
+
from keras import regularizers
|
| 3 |
+
from keras.layers import Input, Dense, Dropout, Activation, Conv1D
|
| 4 |
+
from keras.layers import GlobalAveragePooling1D, AveragePooling1D
|
| 5 |
+
from keras.layers import Bidirectional, LSTM
|
| 6 |
+
from keras import Model
|
| 7 |
+
from keras.metrics import MeanSquaredError
|
| 8 |
+
|
| 9 |
+
import pandas as pd
|
| 10 |
+
import numpy as np
|
| 11 |
+
|
| 12 |
+
import requests
|
| 13 |
+
from functools import reduce
|
| 14 |
+
from operator import add
|
| 15 |
+
import tabulate
|
| 16 |
+
from difflib import SequenceMatcher
|
| 17 |
+
|
| 18 |
+
import cyvcf2
|
| 19 |
+
import parasail
|
| 20 |
+
|
| 21 |
+
import re
|
| 22 |
+
|
| 23 |
+
ntmap = {'A': (1, 0, 0, 0),
|
| 24 |
+
'C': (0, 1, 0, 0),
|
| 25 |
+
'G': (0, 0, 1, 0),
|
| 26 |
+
'T': (0, 0, 0, 1)
|
| 27 |
+
}
|
| 28 |
+
|
| 29 |
+
def get_seqcode(seq):
|
| 30 |
+
return np.array(reduce(add, map(lambda c: ntmap[c], seq.upper()))).reshape((1, len(seq), -1))
|
| 31 |
+
|
| 32 |
+
def BiLSTM_model(input_shape):
|
| 33 |
+
input = Input(shape=input_shape)
|
| 34 |
+
|
| 35 |
+
conv1 = Conv1D(128, 5, activation="relu")(input)
|
| 36 |
+
pool1 = AveragePooling1D(2)(conv1)
|
| 37 |
+
drop1 = Dropout(0.1)(pool1)
|
| 38 |
+
|
| 39 |
+
conv2 = Conv1D(128, 5, activation="relu")(drop1)
|
| 40 |
+
pool2 = AveragePooling1D(2)(conv2)
|
| 41 |
+
drop2 = Dropout(0.1)(pool2)
|
| 42 |
+
|
| 43 |
+
lstm1 = Bidirectional(LSTM(128,
|
| 44 |
+
dropout=0.1,
|
| 45 |
+
activation='tanh',
|
| 46 |
+
return_sequences=True,
|
| 47 |
+
kernel_regularizer=regularizers.l2(1e-4)))(drop2)
|
| 48 |
+
avgpool = GlobalAveragePooling1D()(lstm1)
|
| 49 |
+
|
| 50 |
+
dense1 = Dense(128,
|
| 51 |
+
kernel_regularizer=regularizers.l2(1e-4),
|
| 52 |
+
bias_regularizer=regularizers.l2(1e-4),
|
| 53 |
+
activation="relu")(avgpool)
|
| 54 |
+
drop3 = Dropout(0.1)(dense1)
|
| 55 |
+
|
| 56 |
+
dense2 = Dense(32,
|
| 57 |
+
kernel_regularizer=regularizers.l2(1e-4),
|
| 58 |
+
bias_regularizer=regularizers.l2(1e-4),
|
| 59 |
+
activation="relu")(drop3)
|
| 60 |
+
drop4 = Dropout(0.1)(dense2)
|
| 61 |
+
|
| 62 |
+
dense3 = Dense(32,
|
| 63 |
+
kernel_regularizer=regularizers.l2(1e-4),
|
| 64 |
+
bias_regularizer=regularizers.l2(1e-4),
|
| 65 |
+
activation="relu")(drop4)
|
| 66 |
+
drop5 = Dropout(0.1)(dense3)
|
| 67 |
+
|
| 68 |
+
output = Dense(1, activation="linear")(drop5)
|
| 69 |
+
|
| 70 |
+
model = Model(inputs=[input], outputs=[output])
|
| 71 |
+
return model
|
| 72 |
+
|
| 73 |
+
def fetch_ensembl_transcripts(gene_symbol):
|
| 74 |
+
url = f"https://rest.ensembl.org/lookup/symbol/homo_sapiens/{gene_symbol}?expand=1;content-type=application/json"
|
| 75 |
+
response = requests.get(url)
|
| 76 |
+
if response.status_code == 200:
|
| 77 |
+
gene_data = response.json()
|
| 78 |
+
if 'Transcript' in gene_data:
|
| 79 |
+
return gene_data['Transcript']
|
| 80 |
+
else:
|
| 81 |
+
print("No transcripts found for gene:", gene_symbol)
|
| 82 |
+
return None
|
| 83 |
+
else:
|
| 84 |
+
print(f"Error fetching gene data from Ensembl: {response.text}")
|
| 85 |
+
return None
|
| 86 |
+
|
| 87 |
+
def fetch_ensembl_sequence(transcript_id):
|
| 88 |
+
url = f"https://rest.ensembl.org/sequence/id/{transcript_id}?content-type=application/json"
|
| 89 |
+
response = requests.get(url)
|
| 90 |
+
if response.status_code == 200:
|
| 91 |
+
sequence_data = response.json()
|
| 92 |
+
if 'seq' in sequence_data:
|
| 93 |
+
return sequence_data['seq']
|
| 94 |
+
else:
|
| 95 |
+
print("No sequence found for transcript:", transcript_id)
|
| 96 |
+
return None
|
| 97 |
+
else:
|
| 98 |
+
print(f"Error fetching sequence data from Ensembl: {response.text}")
|
| 99 |
+
return None
|
| 100 |
+
|
| 101 |
+
def find_crispr_targets(sequence, chr, start, end, strand, transcript_id, exon_id, pam="TTTN", target_length=34):
|
| 102 |
+
targets = []
|
| 103 |
+
len_sequence = len(sequence)
|
| 104 |
+
#complement = {'A': 'T', 'T': 'A', 'C': 'G', 'G': 'C'}
|
| 105 |
+
dnatorna = {'A': 'A', 'T': 'U', 'C': 'C', 'G': 'G'}
|
| 106 |
+
|
| 107 |
+
for i in range(len_sequence - target_length + 1):
|
| 108 |
+
target_seq = sequence[i:i + target_length]
|
| 109 |
+
if target_seq[4:7] == 'TTT':
|
| 110 |
+
if strand == -1:
|
| 111 |
+
tar_start = end - i - target_length + 1
|
| 112 |
+
tar_end = end -i
|
| 113 |
+
#seq_in_ref = ''.join([complement[base] for base in target_seq])[::-1]
|
| 114 |
+
else:
|
| 115 |
+
tar_start = start + i
|
| 116 |
+
tar_end = start + i + target_length - 1
|
| 117 |
+
#seq_in_ref = target_seq
|
| 118 |
+
gRNA = ''.join([dnatorna[base] for base in target_seq[8:28]])
|
| 119 |
+
targets.append([target_seq, gRNA, chr, str(tar_start), str(tar_end), str(strand), transcript_id, exon_id])
|
| 120 |
+
#targets.append([target_seq, gRNA, chr, str(tar_start), str(tar_end), str(strand), transcript_id, exon_id, seq_in_ref])
|
| 121 |
+
return targets
|
| 122 |
+
|
| 123 |
+
def format_prediction_output(targets, model_path):
|
| 124 |
+
# Loading weights for the model
|
| 125 |
+
Crispr_BiLSTM = BiLSTM_model(input_shape=(34, 4))
|
| 126 |
+
Crispr_BiLSTM.load_weights(model_path)
|
| 127 |
+
|
| 128 |
+
formatted_data = []
|
| 129 |
+
for target in targets:
|
| 130 |
+
# Predict
|
| 131 |
+
encoded_seq = get_seqcode(target[0])
|
| 132 |
+
prediction = float(list(Crispr_BiLSTM.predict(encoded_seq, verbose=0)[0])[0])
|
| 133 |
+
if prediction > 100:
|
| 134 |
+
prediction = 100
|
| 135 |
+
|
| 136 |
+
# Format output
|
| 137 |
+
gRNA = target[1]
|
| 138 |
+
chr = target[2]
|
| 139 |
+
start = target[3]
|
| 140 |
+
end = target[4]
|
| 141 |
+
strand = target[5]
|
| 142 |
+
transcript_id = target[6]
|
| 143 |
+
exon_id = target[7]
|
| 144 |
+
#seq_in_ref = target[8]
|
| 145 |
+
#formatted_data.append([chr, start, end, strand, transcript_id, exon_id, target[0], gRNA, seq_in_ref, prediction])
|
| 146 |
+
formatted_data.append([chr, start, end, strand, transcript_id, exon_id, target[0], gRNA, prediction])
|
| 147 |
+
|
| 148 |
+
return formatted_data
|
| 149 |
+
|
| 150 |
+
|
| 151 |
+
def process_gene(gene_symbol, model_path):
|
| 152 |
+
transcripts = fetch_ensembl_transcripts(gene_symbol)
|
| 153 |
+
results = []
|
| 154 |
+
all_exons = [] # To accumulate all exons
|
| 155 |
+
all_gene_sequences = [] # To accumulate all gene sequences
|
| 156 |
+
|
| 157 |
+
if transcripts:
|
| 158 |
+
for transcript in transcripts:
|
| 159 |
+
Exons = transcript['Exon']
|
| 160 |
+
all_exons.extend(Exons) # Add all exons from this transcript to the list
|
| 161 |
+
transcript_id = transcript['id']
|
| 162 |
+
|
| 163 |
+
for Exon in Exons:
|
| 164 |
+
exon_id = Exon['id']
|
| 165 |
+
gene_sequence = fetch_ensembl_sequence(exon_id)
|
| 166 |
+
if gene_sequence:
|
| 167 |
+
all_gene_sequences.append(gene_sequence) # Add this gene sequence to the list
|
| 168 |
+
chr = Exon['seq_region_name']
|
| 169 |
+
start = Exon['start']
|
| 170 |
+
end = Exon['end']
|
| 171 |
+
strand = Exon['strand']
|
| 172 |
+
|
| 173 |
+
targets = find_crispr_targets(gene_sequence, chr, start, end, strand, transcript_id, exon_id)
|
| 174 |
+
if targets:
|
| 175 |
+
# Predict on-target efficiency for each gRNA site
|
| 176 |
+
formatted_data = format_prediction_output(targets, model_path)
|
| 177 |
+
results.extend(formatted_data) # Flatten the results
|
| 178 |
+
else:
|
| 179 |
+
print(f"Failed to retrieve gene sequence for exon {exon_id}.")
|
| 180 |
+
else:
|
| 181 |
+
print("Failed to retrieve transcripts.")
|
| 182 |
+
|
| 183 |
+
# Sort results based on prediction score (assuming score is at the 8th index)
|
| 184 |
+
sorted_results = sorted(results, key=lambda x: x[8], reverse=True)
|
| 185 |
+
|
| 186 |
+
# Return the sorted output, combined gene sequences, and all exons
|
| 187 |
+
return sorted_results, all_gene_sequences, all_exons
|
| 188 |
+
|
cas12lstmvcf.py
ADDED
|
@@ -0,0 +1,287 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import tensorflow as tf
|
| 2 |
+
from keras import regularizers
|
| 3 |
+
from keras.layers import Input, Dense, Dropout, Activation, Conv1D
|
| 4 |
+
from keras.layers import GlobalAveragePooling1D, AveragePooling1D
|
| 5 |
+
from keras.layers import Bidirectional, LSTM
|
| 6 |
+
from keras import Model
|
| 7 |
+
from keras.metrics import MeanSquaredError
|
| 8 |
+
|
| 9 |
+
import pandas as pd
|
| 10 |
+
import numpy as np
|
| 11 |
+
|
| 12 |
+
import requests
|
| 13 |
+
from functools import reduce
|
| 14 |
+
from operator import add
|
| 15 |
+
import tabulate
|
| 16 |
+
from difflib import SequenceMatcher
|
| 17 |
+
|
| 18 |
+
import cyvcf2
|
| 19 |
+
import parasail
|
| 20 |
+
|
| 21 |
+
import re
|
| 22 |
+
|
| 23 |
+
ntmap = {'A': (1, 0, 0, 0),
|
| 24 |
+
'C': (0, 1, 0, 0),
|
| 25 |
+
'G': (0, 0, 1, 0),
|
| 26 |
+
'T': (0, 0, 0, 1)
|
| 27 |
+
}
|
| 28 |
+
|
| 29 |
+
def get_seqcode(seq):
|
| 30 |
+
return np.array(reduce(add, map(lambda c: ntmap[c], seq.upper()))).reshape((1, len(seq), -1))
|
| 31 |
+
|
| 32 |
+
def BiLSTM_model(input_shape):
|
| 33 |
+
input = Input(shape=input_shape)
|
| 34 |
+
|
| 35 |
+
conv1 = Conv1D(128, 5, activation="relu")(input)
|
| 36 |
+
pool1 = AveragePooling1D(2)(conv1)
|
| 37 |
+
drop1 = Dropout(0.1)(pool1)
|
| 38 |
+
|
| 39 |
+
conv2 = Conv1D(128, 5, activation="relu")(drop1)
|
| 40 |
+
pool2 = AveragePooling1D(2)(conv2)
|
| 41 |
+
drop2 = Dropout(0.1)(pool2)
|
| 42 |
+
|
| 43 |
+
lstm1 = Bidirectional(LSTM(128,
|
| 44 |
+
dropout=0.1,
|
| 45 |
+
activation='tanh',
|
| 46 |
+
return_sequences=True,
|
| 47 |
+
kernel_regularizer=regularizers.l2(1e-4)))(drop2)
|
| 48 |
+
avgpool = GlobalAveragePooling1D()(lstm1)
|
| 49 |
+
|
| 50 |
+
dense1 = Dense(128,
|
| 51 |
+
kernel_regularizer=regularizers.l2(1e-4),
|
| 52 |
+
bias_regularizer=regularizers.l2(1e-4),
|
| 53 |
+
activation="relu")(avgpool)
|
| 54 |
+
drop3 = Dropout(0.1)(dense1)
|
| 55 |
+
|
| 56 |
+
dense2 = Dense(32,
|
| 57 |
+
kernel_regularizer=regularizers.l2(1e-4),
|
| 58 |
+
bias_regularizer=regularizers.l2(1e-4),
|
| 59 |
+
activation="relu")(drop3)
|
| 60 |
+
drop4 = Dropout(0.1)(dense2)
|
| 61 |
+
|
| 62 |
+
dense3 = Dense(32,
|
| 63 |
+
kernel_regularizer=regularizers.l2(1e-4),
|
| 64 |
+
bias_regularizer=regularizers.l2(1e-4),
|
| 65 |
+
activation="relu")(drop4)
|
| 66 |
+
drop5 = Dropout(0.1)(dense3)
|
| 67 |
+
|
| 68 |
+
output = Dense(1, activation="linear")(drop5)
|
| 69 |
+
|
| 70 |
+
model = Model(inputs=[input], outputs=[output])
|
| 71 |
+
return model
|
| 72 |
+
|
| 73 |
+
def fetch_ensembl_transcripts(gene_symbol):
|
| 74 |
+
url = f"https://rest.ensembl.org/lookup/symbol/homo_sapiens/{gene_symbol}?expand=1;content-type=application/json"
|
| 75 |
+
response = requests.get(url)
|
| 76 |
+
if response.status_code == 200:
|
| 77 |
+
gene_data = response.json()
|
| 78 |
+
if 'Transcript' in gene_data:
|
| 79 |
+
return gene_data['Transcript']
|
| 80 |
+
else:
|
| 81 |
+
print("No transcripts found for gene:", gene_symbol)
|
| 82 |
+
return None
|
| 83 |
+
else:
|
| 84 |
+
print(f"Error fetching gene data from Ensembl: {response.text}")
|
| 85 |
+
return None
|
| 86 |
+
|
| 87 |
+
def fetch_ensembl_sequence(transcript_id):
|
| 88 |
+
url = f"https://rest.ensembl.org/sequence/id/{transcript_id}?content-type=application/json"
|
| 89 |
+
response = requests.get(url)
|
| 90 |
+
if response.status_code == 200:
|
| 91 |
+
sequence_data = response.json()
|
| 92 |
+
if 'seq' in sequence_data:
|
| 93 |
+
return sequence_data['seq']
|
| 94 |
+
else:
|
| 95 |
+
print("No sequence found for transcript:", transcript_id)
|
| 96 |
+
return None
|
| 97 |
+
else:
|
| 98 |
+
print(f"Error fetching sequence data from Ensembl: {response.text}")
|
| 99 |
+
return None
|
| 100 |
+
|
| 101 |
+
def apply_mutation(ref_sequence, offset, ref, alt):
|
| 102 |
+
"""
|
| 103 |
+
Apply a single mutation to the sequence.
|
| 104 |
+
"""
|
| 105 |
+
if len(ref) == len(alt) and alt != "*": # SNP
|
| 106 |
+
mutated_seq = ref_sequence[:offset] + alt + ref_sequence[offset+len(alt):]
|
| 107 |
+
|
| 108 |
+
elif len(ref) < len(alt): # Insertion
|
| 109 |
+
mutated_seq = ref_sequence[:offset] + alt + ref_sequence[offset+1:]
|
| 110 |
+
|
| 111 |
+
elif len(ref) == len(alt) and alt == "*": # Deletion
|
| 112 |
+
mutated_seq = ref_sequence[:offset] + ref_sequence[offset+1:]
|
| 113 |
+
|
| 114 |
+
elif len(ref) > len(alt) and alt != "*": # Deletion
|
| 115 |
+
mutated_seq = ref_sequence[:offset] + alt + ref_sequence[offset+len(ref):]
|
| 116 |
+
|
| 117 |
+
elif len(ref) > len(alt) and alt == "*": # Deletion
|
| 118 |
+
mutated_seq = ref_sequence[:offset] + ref_sequence[offset+len(ref):]
|
| 119 |
+
|
| 120 |
+
return mutated_seq
|
| 121 |
+
|
| 122 |
+
|
| 123 |
+
def construct_combinations(sequence, mutations):
|
| 124 |
+
"""
|
| 125 |
+
Construct all combinations of mutations.
|
| 126 |
+
mutations is a list of tuples (position, ref, [alts])
|
| 127 |
+
"""
|
| 128 |
+
if not mutations:
|
| 129 |
+
return [sequence]
|
| 130 |
+
|
| 131 |
+
# Take the first mutation and recursively construct combinations for the rest
|
| 132 |
+
first_mutation = mutations[0]
|
| 133 |
+
rest_mutations = mutations[1:]
|
| 134 |
+
offset, ref, alts = first_mutation
|
| 135 |
+
|
| 136 |
+
sequences = []
|
| 137 |
+
for alt in alts:
|
| 138 |
+
mutated_sequence = apply_mutation(sequence, offset, ref, alt)
|
| 139 |
+
sequences.extend(construct_combinations(mutated_sequence, rest_mutations))
|
| 140 |
+
|
| 141 |
+
return sequences
|
| 142 |
+
|
| 143 |
+
def needleman_wunsch_alignment(query_seq, ref_seq):
|
| 144 |
+
"""
|
| 145 |
+
Use Needleman-Wunsch alignment to find the maximum alignment position in ref_seq
|
| 146 |
+
Use this position to represent the position of target sequence with mutations
|
| 147 |
+
"""
|
| 148 |
+
# Needleman-Wunsch alignment
|
| 149 |
+
alignment = parasail.nw_trace(query_seq, ref_seq, 10, 1, parasail.blosum62)
|
| 150 |
+
|
| 151 |
+
# extract CIGAR object
|
| 152 |
+
cigar = alignment.cigar
|
| 153 |
+
cigar_string = cigar.decode.decode("utf-8")
|
| 154 |
+
|
| 155 |
+
# record ref_pos
|
| 156 |
+
ref_pos = 0
|
| 157 |
+
|
| 158 |
+
matches = re.findall(r'(\d+)([MIDNSHP=X])', cigar_string)
|
| 159 |
+
max_num_before_equal = 0
|
| 160 |
+
max_equal_index = -1
|
| 161 |
+
total_before_max_equal = 0
|
| 162 |
+
|
| 163 |
+
for i, (num_str, op) in enumerate(matches):
|
| 164 |
+
num = int(num_str)
|
| 165 |
+
if op == '=':
|
| 166 |
+
if num > max_num_before_equal:
|
| 167 |
+
max_num_before_equal = num
|
| 168 |
+
max_equal_index = i
|
| 169 |
+
total_before_max_equal = sum(int(matches[j][0]) for j in range(max_equal_index))
|
| 170 |
+
|
| 171 |
+
ref_pos = total_before_max_equal
|
| 172 |
+
|
| 173 |
+
return ref_pos
|
| 174 |
+
|
| 175 |
+
def find_gRNA_with_mutation(ref_sequence, exon_chr, start, end, strand, transcript_id,
|
| 176 |
+
exon_id, gene_symbol, vcf_reader, pam="TTTN", target_length=34):
|
| 177 |
+
# initialization
|
| 178 |
+
mutated_sequences = [ref_sequence]
|
| 179 |
+
|
| 180 |
+
# find mutations within interested region
|
| 181 |
+
mutations = vcf_reader(f"{exon_chr}:{start}-{end}")
|
| 182 |
+
if mutations:
|
| 183 |
+
# find mutations
|
| 184 |
+
mutation_list = []
|
| 185 |
+
for mutation in mutations:
|
| 186 |
+
offset = mutation.POS - start
|
| 187 |
+
ref = mutation.REF
|
| 188 |
+
alts = mutation.ALT[:-1]
|
| 189 |
+
mutation_list.append((offset, ref, alts))
|
| 190 |
+
|
| 191 |
+
# replace reference sequence of mutation
|
| 192 |
+
mutated_sequences = construct_combinations(ref_sequence, mutation_list)
|
| 193 |
+
|
| 194 |
+
# find gRNA in ref_sequence or all mutated_sequences
|
| 195 |
+
targets = []
|
| 196 |
+
for seq in mutated_sequences:
|
| 197 |
+
len_sequence = len(seq)
|
| 198 |
+
dnatorna = {'A': 'A', 'T': 'U', 'C': 'C', 'G': 'G'}
|
| 199 |
+
for i in range(len_sequence - target_length + 1):
|
| 200 |
+
target_seq = seq[i:i + target_length]
|
| 201 |
+
if target_seq[4:7] == 'TTT':
|
| 202 |
+
pos = ref_sequence.find(target_seq)
|
| 203 |
+
if pos != -1:
|
| 204 |
+
is_mut = False
|
| 205 |
+
if strand == -1:
|
| 206 |
+
tar_start = end - pos - target_length + 1
|
| 207 |
+
else:
|
| 208 |
+
tar_start = start + pos
|
| 209 |
+
else:
|
| 210 |
+
is_mut = True
|
| 211 |
+
nw_pos = needleman_wunsch_alignment(target_seq, ref_sequence)
|
| 212 |
+
if strand == -1:
|
| 213 |
+
tar_start = str(end - nw_pos - target_length + 1) + '*'
|
| 214 |
+
else:
|
| 215 |
+
tar_start = str(start + nw_pos) + '*'
|
| 216 |
+
gRNA = ''.join([dnatorna[base] for base in target_seq[8:28]])
|
| 217 |
+
targets.append([target_seq, gRNA, exon_chr, str(strand), str(tar_start), transcript_id, exon_id, gene_symbol, is_mut])
|
| 218 |
+
|
| 219 |
+
# filter duplicated targets
|
| 220 |
+
unique_targets_set = set(tuple(element) for element in targets)
|
| 221 |
+
unique_targets = [list(element) for element in unique_targets_set]
|
| 222 |
+
|
| 223 |
+
return unique_targets
|
| 224 |
+
|
| 225 |
+
def format_prediction_output_with_mutation(targets, model_path):
|
| 226 |
+
Crispr_BiLSTM = BiLSTM_model(input_shape=(34, 4))
|
| 227 |
+
Crispr_BiLSTM.load_weights(model_path)
|
| 228 |
+
|
| 229 |
+
formatted_data = []
|
| 230 |
+
for target in targets:
|
| 231 |
+
# Predict
|
| 232 |
+
encoded_seq = get_seqcode(target[0])
|
| 233 |
+
prediction = float(list(Crispr_BiLSTM.predict(encoded_seq, verbose=0)[0])[0])
|
| 234 |
+
if prediction > 100:
|
| 235 |
+
prediction = 100
|
| 236 |
+
|
| 237 |
+
# Format output
|
| 238 |
+
gRNA = target[1]
|
| 239 |
+
exon_chr = target[2]
|
| 240 |
+
strand = target[3]
|
| 241 |
+
tar_start = target[4]
|
| 242 |
+
transcript_id = target[5]
|
| 243 |
+
exon_id = target[6]
|
| 244 |
+
gene_symbol = target[7]
|
| 245 |
+
is_mut = target[8]
|
| 246 |
+
formatted_data.append([gene_symbol, exon_chr, strand, tar_start, transcript_id, exon_id, target[0], gRNA, prediction, is_mut])
|
| 247 |
+
|
| 248 |
+
return formatted_data
|
| 249 |
+
|
| 250 |
+
def process_gene(gene_symbol, vcf_reader, model_path):
|
| 251 |
+
transcripts = fetch_ensembl_transcripts(gene_symbol)
|
| 252 |
+
results = []
|
| 253 |
+
all_exons = [] # To accumulate all exons
|
| 254 |
+
all_gene_sequences = [] # To accumulate all gene sequences
|
| 255 |
+
|
| 256 |
+
if transcripts:
|
| 257 |
+
for transcript in transcripts:
|
| 258 |
+
Exons = transcript['Exon']
|
| 259 |
+
all_exons.extend(Exons) # Add all exons from this transcript to the list
|
| 260 |
+
transcript_id = transcript['id']
|
| 261 |
+
|
| 262 |
+
for Exon in Exons:
|
| 263 |
+
exon_id = Exon['id']
|
| 264 |
+
gene_sequence = fetch_ensembl_sequence(exon_id) # Reference exon sequence
|
| 265 |
+
if gene_sequence:
|
| 266 |
+
all_gene_sequences.append(gene_sequence) # Add this gene sequence to the list
|
| 267 |
+
exon_chr = Exon['seq_region_name']
|
| 268 |
+
start = Exon['start']
|
| 269 |
+
end = Exon['end']
|
| 270 |
+
strand = Exon['strand']
|
| 271 |
+
|
| 272 |
+
targets = find_gRNA_with_mutation(gene_sequence, exon_chr, start, end, strand, transcript_id, exon_id, gene_symbol, vcf_reader)
|
| 273 |
+
if targets:
|
| 274 |
+
# Predict on-target efficiency for each gRNA site
|
| 275 |
+
formatted_data = format_prediction_output_with_mutation(targets, model_path)
|
| 276 |
+
results.extend(formatted_data) # Flatten the results
|
| 277 |
+
else:
|
| 278 |
+
print(f"Failed to retrieve gene sequence for exon {exon_id}.")
|
| 279 |
+
else:
|
| 280 |
+
print("Failed to retrieve transcripts.")
|
| 281 |
+
|
| 282 |
+
# Sort results based on prediction score (assuming score is at the 8th index)
|
| 283 |
+
sorted_results = sorted(results, key=lambda x: x[8], reverse=True)
|
| 284 |
+
|
| 285 |
+
# Return the sorted output, combined gene sequences, and all exons
|
| 286 |
+
return sorted_results, all_gene_sequences, all_exons
|
| 287 |
+
|
cas9att.py
ADDED
|
@@ -0,0 +1,299 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import requests
|
| 2 |
+
import tensorflow as tf
|
| 3 |
+
import pandas as pd
|
| 4 |
+
import numpy as np
|
| 5 |
+
from operator import add
|
| 6 |
+
from functools import reduce
|
| 7 |
+
import random
|
| 8 |
+
import tabulate
|
| 9 |
+
|
| 10 |
+
from keras import Model
|
| 11 |
+
from keras import regularizers
|
| 12 |
+
from keras.optimizers import Adam
|
| 13 |
+
from keras.layers import Conv2D, BatchNormalization, ReLU, Input, Flatten, Softmax
|
| 14 |
+
from keras.layers import Concatenate, Activation, Dense, GlobalAveragePooling2D, Dropout
|
| 15 |
+
from keras.layers import AveragePooling1D, Bidirectional, LSTM, GlobalAveragePooling1D, MaxPool1D, Reshape
|
| 16 |
+
from keras.layers import LayerNormalization, Conv1D, MultiHeadAttention, Layer
|
| 17 |
+
from keras.models import load_model
|
| 18 |
+
from keras.callbacks import EarlyStopping, ReduceLROnPlateau
|
| 19 |
+
from Bio import SeqIO
|
| 20 |
+
from Bio.SeqRecord import SeqRecord
|
| 21 |
+
from Bio.SeqFeature import SeqFeature, FeatureLocation
|
| 22 |
+
from Bio.Seq import Seq
|
| 23 |
+
|
| 24 |
+
import cyvcf2
|
| 25 |
+
import parasail
|
| 26 |
+
|
| 27 |
+
import re
|
| 28 |
+
|
| 29 |
+
ntmap = {'A': (1, 0, 0, 0),
|
| 30 |
+
'C': (0, 1, 0, 0),
|
| 31 |
+
'G': (0, 0, 1, 0),
|
| 32 |
+
'T': (0, 0, 0, 1)
|
| 33 |
+
}
|
| 34 |
+
|
| 35 |
+
def get_seqcode(seq):
|
| 36 |
+
return np.array(reduce(add, map(lambda c: ntmap[c], seq.upper()))).reshape((1, len(seq), -1))
|
| 37 |
+
|
| 38 |
+
class PositionalEncoding(Layer):
|
| 39 |
+
def __init__(self, sequence_len=None, embedding_dim=None,**kwargs):
|
| 40 |
+
super(PositionalEncoding, self).__init__()
|
| 41 |
+
self.sequence_len = sequence_len
|
| 42 |
+
self.embedding_dim = embedding_dim
|
| 43 |
+
|
| 44 |
+
def call(self, x):
|
| 45 |
+
|
| 46 |
+
position_embedding = np.array([
|
| 47 |
+
[pos / np.power(10000, 2. * i / self.embedding_dim) for i in range(self.embedding_dim)]
|
| 48 |
+
for pos in range(self.sequence_len)])
|
| 49 |
+
|
| 50 |
+
position_embedding[:, 0::2] = np.sin(position_embedding[:, 0::2]) # dim 2i
|
| 51 |
+
position_embedding[:, 1::2] = np.cos(position_embedding[:, 1::2]) # dim 2i+1
|
| 52 |
+
position_embedding = tf.cast(position_embedding, dtype=tf.float32)
|
| 53 |
+
|
| 54 |
+
return position_embedding+x
|
| 55 |
+
|
| 56 |
+
def get_config(self):
|
| 57 |
+
config = super().get_config().copy()
|
| 58 |
+
config.update({
|
| 59 |
+
'sequence_len' : self.sequence_len,
|
| 60 |
+
'embedding_dim' : self.embedding_dim,
|
| 61 |
+
})
|
| 62 |
+
return config
|
| 63 |
+
|
| 64 |
+
def MultiHeadAttention_model(input_shape):
|
| 65 |
+
input = Input(shape=input_shape)
|
| 66 |
+
|
| 67 |
+
conv1 = Conv1D(256, 3, activation="relu")(input)
|
| 68 |
+
pool1 = AveragePooling1D(2)(conv1)
|
| 69 |
+
drop1 = Dropout(0.4)(pool1)
|
| 70 |
+
|
| 71 |
+
conv2 = Conv1D(256, 3, activation="relu")(drop1)
|
| 72 |
+
pool2 = AveragePooling1D(2)(conv2)
|
| 73 |
+
drop2 = Dropout(0.4)(pool2)
|
| 74 |
+
|
| 75 |
+
lstm = Bidirectional(LSTM(128,
|
| 76 |
+
dropout=0.5,
|
| 77 |
+
activation='tanh',
|
| 78 |
+
return_sequences=True,
|
| 79 |
+
kernel_regularizer=regularizers.l2(0.01)))(drop2)
|
| 80 |
+
|
| 81 |
+
pos_embedding = PositionalEncoding(sequence_len=int(((23-3+1)/2-3+1)/2), embedding_dim=2*128)(lstm)
|
| 82 |
+
atten = MultiHeadAttention(num_heads=2,
|
| 83 |
+
key_dim=64,
|
| 84 |
+
dropout=0.2,
|
| 85 |
+
kernel_regularizer=regularizers.l2(0.01))(pos_embedding, pos_embedding)
|
| 86 |
+
|
| 87 |
+
flat = Flatten()(atten)
|
| 88 |
+
|
| 89 |
+
dense1 = Dense(512,
|
| 90 |
+
kernel_regularizer=regularizers.l2(1e-4),
|
| 91 |
+
bias_regularizer=regularizers.l2(1e-4),
|
| 92 |
+
activation="relu")(flat)
|
| 93 |
+
drop3 = Dropout(0.1)(dense1)
|
| 94 |
+
|
| 95 |
+
dense2 = Dense(128,
|
| 96 |
+
kernel_regularizer=regularizers.l2(1e-4),
|
| 97 |
+
bias_regularizer=regularizers.l2(1e-4),
|
| 98 |
+
activation="relu")(drop3)
|
| 99 |
+
drop4 = Dropout(0.1)(dense2)
|
| 100 |
+
|
| 101 |
+
dense3 = Dense(256,
|
| 102 |
+
kernel_regularizer=regularizers.l2(1e-4),
|
| 103 |
+
bias_regularizer=regularizers.l2(1e-4),
|
| 104 |
+
activation="relu")(drop4)
|
| 105 |
+
drop5 = Dropout(0.1)(dense3)
|
| 106 |
+
|
| 107 |
+
output = Dense(1, activation="linear")(drop5)
|
| 108 |
+
|
| 109 |
+
model = Model(inputs=[input], outputs=[output])
|
| 110 |
+
return model
|
| 111 |
+
|
| 112 |
+
def fetch_ensembl_transcripts(gene_symbol):
|
| 113 |
+
url = f"https://rest.ensembl.org/lookup/symbol/homo_sapiens/{gene_symbol}?expand=1;content-type=application/json"
|
| 114 |
+
response = requests.get(url)
|
| 115 |
+
if response.status_code == 200:
|
| 116 |
+
gene_data = response.json()
|
| 117 |
+
if 'Transcript' in gene_data:
|
| 118 |
+
return gene_data['Transcript']
|
| 119 |
+
else:
|
| 120 |
+
print("No transcripts found for gene:", gene_symbol)
|
| 121 |
+
return None
|
| 122 |
+
else:
|
| 123 |
+
print(f"Error fetching gene data from Ensembl: {response.text}")
|
| 124 |
+
return None
|
| 125 |
+
|
| 126 |
+
def fetch_ensembl_sequence(transcript_id):
|
| 127 |
+
url = f"https://rest.ensembl.org/sequence/id/{transcript_id}?content-type=application/json"
|
| 128 |
+
response = requests.get(url)
|
| 129 |
+
if response.status_code == 200:
|
| 130 |
+
sequence_data = response.json()
|
| 131 |
+
if 'seq' in sequence_data:
|
| 132 |
+
return sequence_data['seq']
|
| 133 |
+
else:
|
| 134 |
+
print("No sequence found for transcript:", transcript_id)
|
| 135 |
+
return None
|
| 136 |
+
else:
|
| 137 |
+
print(f"Error fetching sequence data from Ensembl: {response.text}")
|
| 138 |
+
return None
|
| 139 |
+
|
| 140 |
+
def find_crispr_targets(sequence, chr, start, end, strand, transcript_id, exon_id, pam="NGG", target_length=20):
|
| 141 |
+
targets = []
|
| 142 |
+
len_sequence = len(sequence)
|
| 143 |
+
#complement = {'A': 'T', 'T': 'A', 'C': 'G', 'G': 'C'}
|
| 144 |
+
dnatorna = {'A': 'A', 'T': 'U', 'C': 'C', 'G': 'G'}
|
| 145 |
+
|
| 146 |
+
for i in range(len_sequence - len(pam) + 1):
|
| 147 |
+
if sequence[i + 1:i + 3] == pam[1:]:
|
| 148 |
+
if i >= target_length:
|
| 149 |
+
target_seq = sequence[i - target_length:i + 3]
|
| 150 |
+
if strand == -1:
|
| 151 |
+
tar_start = end - (i + 2)
|
| 152 |
+
tar_end = end - (i - target_length)
|
| 153 |
+
#seq_in_ref = ''.join([complement[base] for base in target_seq])[::-1]
|
| 154 |
+
else:
|
| 155 |
+
tar_start = start + i - target_length
|
| 156 |
+
tar_end = start + i + 3 - 1
|
| 157 |
+
#seq_in_ref = target_seq
|
| 158 |
+
gRNA = ''.join([dnatorna[base] for base in sequence[i - target_length:i]])
|
| 159 |
+
#targets.append([target_seq, gRNA, chr, str(tar_start), str(tar_end), str(strand), transcript_id, exon_id, seq_in_ref])
|
| 160 |
+
targets.append([target_seq, gRNA, chr, str(tar_start), str(tar_end), str(strand), transcript_id, exon_id])
|
| 161 |
+
|
| 162 |
+
return targets
|
| 163 |
+
|
| 164 |
+
# Function to predict on-target efficiency and format output
|
| 165 |
+
def format_prediction_output(targets, model_path):
|
| 166 |
+
model = MultiHeadAttention_model(input_shape=(23, 4))
|
| 167 |
+
model.load_weights(model_path)
|
| 168 |
+
|
| 169 |
+
formatted_data = []
|
| 170 |
+
|
| 171 |
+
for target in targets:
|
| 172 |
+
# Encode the gRNA sequence
|
| 173 |
+
encoded_seq = get_seqcode(target[0])
|
| 174 |
+
|
| 175 |
+
# Predict on-target efficiency using the model
|
| 176 |
+
prediction = float(list(model.predict(encoded_seq, verbose=0)[0])[0])
|
| 177 |
+
if prediction > 100:
|
| 178 |
+
prediction = 100
|
| 179 |
+
|
| 180 |
+
# Format output
|
| 181 |
+
gRNA = target[1]
|
| 182 |
+
chr = target[2]
|
| 183 |
+
start = target[3]
|
| 184 |
+
end = target[4]
|
| 185 |
+
strand = target[5]
|
| 186 |
+
transcript_id = target[6]
|
| 187 |
+
exon_id = target[7]
|
| 188 |
+
#seq_in_ref = target[8]
|
| 189 |
+
#formatted_data.append([chr, start, end, strand, transcript_id, exon_id, target[0], gRNA, seq_in_ref, prediction[0]])
|
| 190 |
+
formatted_data.append([chr, start, end, strand, transcript_id, exon_id, target[0], gRNA, prediction])
|
| 191 |
+
|
| 192 |
+
return formatted_data
|
| 193 |
+
|
| 194 |
+
def process_gene(gene_symbol, model_path):
|
| 195 |
+
# Fetch transcripts for the given gene symbol
|
| 196 |
+
transcripts = fetch_ensembl_transcripts(gene_symbol)
|
| 197 |
+
results = []
|
| 198 |
+
all_exons = [] # To accumulate all exons
|
| 199 |
+
all_gene_sequences = [] # To accumulate all gene sequences
|
| 200 |
+
|
| 201 |
+
if transcripts:
|
| 202 |
+
for transcript in transcripts:
|
| 203 |
+
Exons = transcript['Exon']
|
| 204 |
+
all_exons.extend(Exons) # Add all exons from this transcript to the list
|
| 205 |
+
transcript_id = transcript['id']
|
| 206 |
+
|
| 207 |
+
for exon in Exons:
|
| 208 |
+
exon_id = exon['id']
|
| 209 |
+
gene_sequence = fetch_ensembl_sequence(exon_id)
|
| 210 |
+
if gene_sequence:
|
| 211 |
+
all_gene_sequences.append(gene_sequence) # Add this gene sequence to the list
|
| 212 |
+
start = exon['start']
|
| 213 |
+
end = exon['end']
|
| 214 |
+
strand = exon['strand']
|
| 215 |
+
chr = exon['seq_region_name']
|
| 216 |
+
# Find potential CRISPR targets within the exon
|
| 217 |
+
targets = find_crispr_targets(gene_sequence, chr, start, end, strand, transcript_id, exon_id)
|
| 218 |
+
if targets:
|
| 219 |
+
# Format the prediction output for the targets found
|
| 220 |
+
formatted_data = format_prediction_output(targets, model_path)
|
| 221 |
+
results.extend(formatted_data) # Append results
|
| 222 |
+
else:
|
| 223 |
+
print(f"Failed to retrieve gene sequence for exon {exon_id}.")
|
| 224 |
+
else:
|
| 225 |
+
print("Failed to retrieve transcripts.")
|
| 226 |
+
|
| 227 |
+
# Sort results based on prediction score (assuming score is at the 8th index)
|
| 228 |
+
sorted_results = sorted(results, key=lambda x: x[8], reverse=True)
|
| 229 |
+
|
| 230 |
+
# Return the sorted output, combined gene sequences, and all exons
|
| 231 |
+
return sorted_results, all_gene_sequences, all_exons
|
| 232 |
+
|
| 233 |
+
|
| 234 |
+
def create_genbank_features(data):
|
| 235 |
+
features = []
|
| 236 |
+
|
| 237 |
+
# If the input data is a DataFrame, convert it to a list of lists
|
| 238 |
+
if isinstance(data, pd.DataFrame):
|
| 239 |
+
formatted_data = data.values.tolist()
|
| 240 |
+
elif isinstance(data, list):
|
| 241 |
+
formatted_data = data
|
| 242 |
+
else:
|
| 243 |
+
raise TypeError("Data should be either a list or a pandas DataFrame.")
|
| 244 |
+
|
| 245 |
+
for row in formatted_data:
|
| 246 |
+
try:
|
| 247 |
+
start = int(row[1])
|
| 248 |
+
end = int(row[2])
|
| 249 |
+
except ValueError as e:
|
| 250 |
+
print(f"Error converting start/end to int: {row[1]}, {row[2]} - {e}")
|
| 251 |
+
continue
|
| 252 |
+
|
| 253 |
+
strand = 1 if row[3] == '+' else -1
|
| 254 |
+
location = FeatureLocation(start=start, end=end, strand=strand)
|
| 255 |
+
feature = SeqFeature(location=location, type="misc_feature", qualifiers={
|
| 256 |
+
'label': row[7], # Use gRNA as the label
|
| 257 |
+
'note': f"Prediction: {row[8]}" # Include the prediction score
|
| 258 |
+
})
|
| 259 |
+
features.append(feature)
|
| 260 |
+
|
| 261 |
+
return features
|
| 262 |
+
|
| 263 |
+
|
| 264 |
+
def generate_genbank_file_from_df(df, gene_sequence, gene_symbol, output_path):
|
| 265 |
+
# Ensure gene_sequence is a string before creating Seq object
|
| 266 |
+
if not isinstance(gene_sequence, str):
|
| 267 |
+
gene_sequence = str(gene_sequence)
|
| 268 |
+
|
| 269 |
+
features = create_genbank_features(df)
|
| 270 |
+
|
| 271 |
+
# Now gene_sequence is guaranteed to be a string, suitable for Seq
|
| 272 |
+
seq_obj = Seq(gene_sequence)
|
| 273 |
+
record = SeqRecord(seq_obj, id=gene_symbol, name=gene_symbol,
|
| 274 |
+
description=f'CRISPR Cas9 predicted targets for {gene_symbol}', features=features)
|
| 275 |
+
record.annotations["molecule_type"] = "DNA"
|
| 276 |
+
SeqIO.write(record, output_path, "genbank")
|
| 277 |
+
|
| 278 |
+
|
| 279 |
+
def create_bed_file_from_df(df, output_path):
|
| 280 |
+
with open(output_path, 'w') as bed_file:
|
| 281 |
+
for index, row in df.iterrows():
|
| 282 |
+
chrom = row["Chr"]
|
| 283 |
+
start = int(row["Start Pos"])
|
| 284 |
+
end = int(row["End Pos"])
|
| 285 |
+
strand = '+' if row["Strand"] == '1' else '-'
|
| 286 |
+
gRNA = row["gRNA"]
|
| 287 |
+
score = str(row["Prediction"])
|
| 288 |
+
# transcript_id is not typically part of the standard BED columns but added here for completeness
|
| 289 |
+
transcript_id = row["Transcript"]
|
| 290 |
+
|
| 291 |
+
# Writing only standard BED columns; additional columns can be appended as needed
|
| 292 |
+
bed_file.write(f"{chrom}\t{start}\t{end}\t{gRNA}\t{score}\t{strand}\n")
|
| 293 |
+
|
| 294 |
+
|
| 295 |
+
def create_csv_from_df(df, output_path):
|
| 296 |
+
df.to_csv(output_path, index=False)
|
| 297 |
+
|
| 298 |
+
|
| 299 |
+
|
cas9attvcf.py
ADDED
|
@@ -0,0 +1,397 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import requests
|
| 2 |
+
import tensorflow as tf
|
| 3 |
+
import pandas as pd
|
| 4 |
+
import numpy as np
|
| 5 |
+
from operator import add
|
| 6 |
+
from functools import reduce
|
| 7 |
+
import random
|
| 8 |
+
import tabulate
|
| 9 |
+
|
| 10 |
+
from keras import Model
|
| 11 |
+
from keras import regularizers
|
| 12 |
+
from keras.optimizers import Adam
|
| 13 |
+
from keras.layers import Conv2D, BatchNormalization, ReLU, Input, Flatten, Softmax
|
| 14 |
+
from keras.layers import Concatenate, Activation, Dense, GlobalAveragePooling2D, Dropout
|
| 15 |
+
from keras.layers import AveragePooling1D, Bidirectional, LSTM, GlobalAveragePooling1D, MaxPool1D, Reshape
|
| 16 |
+
from keras.layers import LayerNormalization, Conv1D, MultiHeadAttention, Layer
|
| 17 |
+
from keras.models import load_model
|
| 18 |
+
from keras.callbacks import EarlyStopping, ReduceLROnPlateau
|
| 19 |
+
from Bio import SeqIO
|
| 20 |
+
from Bio.SeqRecord import SeqRecord
|
| 21 |
+
from Bio.SeqFeature import SeqFeature, FeatureLocation
|
| 22 |
+
from Bio.Seq import Seq
|
| 23 |
+
|
| 24 |
+
import cyvcf2
|
| 25 |
+
import parasail
|
| 26 |
+
|
| 27 |
+
import re
|
| 28 |
+
|
| 29 |
+
ntmap = {'A': (1, 0, 0, 0),
|
| 30 |
+
'C': (0, 1, 0, 0),
|
| 31 |
+
'G': (0, 0, 1, 0),
|
| 32 |
+
'T': (0, 0, 0, 1)
|
| 33 |
+
}
|
| 34 |
+
|
| 35 |
+
def get_seqcode(seq):
|
| 36 |
+
return np.array(reduce(add, map(lambda c: ntmap[c], seq.upper()))).reshape((1, len(seq), -1))
|
| 37 |
+
|
| 38 |
+
class PositionalEncoding(Layer):
|
| 39 |
+
def __init__(self, sequence_len=None, embedding_dim=None,**kwargs):
|
| 40 |
+
super(PositionalEncoding, self).__init__()
|
| 41 |
+
self.sequence_len = sequence_len
|
| 42 |
+
self.embedding_dim = embedding_dim
|
| 43 |
+
|
| 44 |
+
def call(self, x):
|
| 45 |
+
|
| 46 |
+
position_embedding = np.array([
|
| 47 |
+
[pos / np.power(10000, 2. * i / self.embedding_dim) for i in range(self.embedding_dim)]
|
| 48 |
+
for pos in range(self.sequence_len)])
|
| 49 |
+
|
| 50 |
+
position_embedding[:, 0::2] = np.sin(position_embedding[:, 0::2]) # dim 2i
|
| 51 |
+
position_embedding[:, 1::2] = np.cos(position_embedding[:, 1::2]) # dim 2i+1
|
| 52 |
+
position_embedding = tf.cast(position_embedding, dtype=tf.float32)
|
| 53 |
+
|
| 54 |
+
return position_embedding+x
|
| 55 |
+
|
| 56 |
+
def get_config(self):
|
| 57 |
+
config = super().get_config().copy()
|
| 58 |
+
config.update({
|
| 59 |
+
'sequence_len' : self.sequence_len,
|
| 60 |
+
'embedding_dim' : self.embedding_dim,
|
| 61 |
+
})
|
| 62 |
+
return config
|
| 63 |
+
|
| 64 |
+
def MultiHeadAttention_model(input_shape):
|
| 65 |
+
input = Input(shape=input_shape)
|
| 66 |
+
|
| 67 |
+
conv1 = Conv1D(256, 3, activation="relu")(input)
|
| 68 |
+
pool1 = AveragePooling1D(2)(conv1)
|
| 69 |
+
drop1 = Dropout(0.4)(pool1)
|
| 70 |
+
|
| 71 |
+
conv2 = Conv1D(256, 3, activation="relu")(drop1)
|
| 72 |
+
pool2 = AveragePooling1D(2)(conv2)
|
| 73 |
+
drop2 = Dropout(0.4)(pool2)
|
| 74 |
+
|
| 75 |
+
lstm = Bidirectional(LSTM(128,
|
| 76 |
+
dropout=0.5,
|
| 77 |
+
activation='tanh',
|
| 78 |
+
return_sequences=True,
|
| 79 |
+
kernel_regularizer=regularizers.l2(0.01)))(drop2)
|
| 80 |
+
|
| 81 |
+
pos_embedding = PositionalEncoding(sequence_len=int(((23-3+1)/2-3+1)/2), embedding_dim=2*128)(lstm)
|
| 82 |
+
atten = MultiHeadAttention(num_heads=2,
|
| 83 |
+
key_dim=64,
|
| 84 |
+
dropout=0.2,
|
| 85 |
+
kernel_regularizer=regularizers.l2(0.01))(pos_embedding, pos_embedding)
|
| 86 |
+
|
| 87 |
+
flat = Flatten()(atten)
|
| 88 |
+
|
| 89 |
+
dense1 = Dense(512,
|
| 90 |
+
kernel_regularizer=regularizers.l2(1e-4),
|
| 91 |
+
bias_regularizer=regularizers.l2(1e-4),
|
| 92 |
+
activation="relu")(flat)
|
| 93 |
+
drop3 = Dropout(0.1)(dense1)
|
| 94 |
+
|
| 95 |
+
dense2 = Dense(128,
|
| 96 |
+
kernel_regularizer=regularizers.l2(1e-4),
|
| 97 |
+
bias_regularizer=regularizers.l2(1e-4),
|
| 98 |
+
activation="relu")(drop3)
|
| 99 |
+
drop4 = Dropout(0.1)(dense2)
|
| 100 |
+
|
| 101 |
+
dense3 = Dense(256,
|
| 102 |
+
kernel_regularizer=regularizers.l2(1e-4),
|
| 103 |
+
bias_regularizer=regularizers.l2(1e-4),
|
| 104 |
+
activation="relu")(drop4)
|
| 105 |
+
drop5 = Dropout(0.1)(dense3)
|
| 106 |
+
|
| 107 |
+
output = Dense(1, activation="linear")(drop5)
|
| 108 |
+
|
| 109 |
+
model = Model(inputs=[input], outputs=[output])
|
| 110 |
+
return model
|
| 111 |
+
|
| 112 |
+
def fetch_ensembl_transcripts(gene_symbol):
|
| 113 |
+
url = f"https://rest.ensembl.org/lookup/symbol/homo_sapiens/{gene_symbol}?expand=1;content-type=application/json"
|
| 114 |
+
response = requests.get(url)
|
| 115 |
+
if response.status_code == 200:
|
| 116 |
+
gene_data = response.json()
|
| 117 |
+
if 'Transcript' in gene_data:
|
| 118 |
+
return gene_data['Transcript']
|
| 119 |
+
else:
|
| 120 |
+
print("No transcripts found for gene:", gene_symbol)
|
| 121 |
+
return None
|
| 122 |
+
else:
|
| 123 |
+
print(f"Error fetching gene data from Ensembl: {response.text}")
|
| 124 |
+
return None
|
| 125 |
+
|
| 126 |
+
def fetch_ensembl_sequence(transcript_id):
|
| 127 |
+
url = f"https://rest.ensembl.org/sequence/id/{transcript_id}?content-type=application/json"
|
| 128 |
+
response = requests.get(url)
|
| 129 |
+
if response.status_code == 200:
|
| 130 |
+
sequence_data = response.json()
|
| 131 |
+
if 'seq' in sequence_data:
|
| 132 |
+
return sequence_data['seq']
|
| 133 |
+
else:
|
| 134 |
+
print("No sequence found for transcript:", transcript_id)
|
| 135 |
+
return None
|
| 136 |
+
else:
|
| 137 |
+
print(f"Error fetching sequence data from Ensembl: {response.text}")
|
| 138 |
+
return None
|
| 139 |
+
|
| 140 |
+
def apply_mutation(ref_sequence, offset, ref, alt):
|
| 141 |
+
"""
|
| 142 |
+
Apply a single mutation to the sequence.
|
| 143 |
+
"""
|
| 144 |
+
if len(ref) == len(alt) and alt != "*": # SNP
|
| 145 |
+
mutated_seq = ref_sequence[:offset] + alt + ref_sequence[offset+len(alt):]
|
| 146 |
+
|
| 147 |
+
elif len(ref) < len(alt): # Insertion
|
| 148 |
+
mutated_seq = ref_sequence[:offset] + alt + ref_sequence[offset+1:]
|
| 149 |
+
|
| 150 |
+
elif len(ref) == len(alt) and alt == "*": # Deletion
|
| 151 |
+
mutated_seq = ref_sequence[:offset] + ref_sequence[offset+1:]
|
| 152 |
+
|
| 153 |
+
elif len(ref) > len(alt) and alt != "*": # Deletion
|
| 154 |
+
mutated_seq = ref_sequence[:offset] + alt + ref_sequence[offset+len(ref):]
|
| 155 |
+
|
| 156 |
+
elif len(ref) > len(alt) and alt == "*": # Deletion
|
| 157 |
+
mutated_seq = ref_sequence[:offset] + ref_sequence[offset+len(ref):]
|
| 158 |
+
|
| 159 |
+
return mutated_seq
|
| 160 |
+
|
| 161 |
+
def construct_combinations(sequence, mutations):
|
| 162 |
+
"""
|
| 163 |
+
Construct all combinations of mutations.
|
| 164 |
+
mutations is a list of tuples (position, ref, [alts])
|
| 165 |
+
"""
|
| 166 |
+
if not mutations:
|
| 167 |
+
return [sequence]
|
| 168 |
+
|
| 169 |
+
# Take the first mutation and recursively construct combinations for the rest
|
| 170 |
+
first_mutation = mutations[0]
|
| 171 |
+
rest_mutations = mutations[1:]
|
| 172 |
+
offset, ref, alts = first_mutation
|
| 173 |
+
|
| 174 |
+
sequences = []
|
| 175 |
+
for alt in alts:
|
| 176 |
+
mutated_sequence = apply_mutation(sequence, offset, ref, alt)
|
| 177 |
+
sequences.extend(construct_combinations(mutated_sequence, rest_mutations))
|
| 178 |
+
|
| 179 |
+
return sequences
|
| 180 |
+
|
| 181 |
+
def needleman_wunsch_alignment(query_seq, ref_seq):
|
| 182 |
+
"""
|
| 183 |
+
Use Needleman-Wunsch alignment to find the maximum alignment position in ref_seq
|
| 184 |
+
Use this position to represent the position of target sequence with mutations
|
| 185 |
+
"""
|
| 186 |
+
# Needleman-Wunsch alignment
|
| 187 |
+
alignment = parasail.nw_trace(query_seq, ref_seq, 10, 1, parasail.blosum62)
|
| 188 |
+
|
| 189 |
+
# extract CIGAR object
|
| 190 |
+
cigar = alignment.cigar
|
| 191 |
+
cigar_string = cigar.decode.decode("utf-8")
|
| 192 |
+
|
| 193 |
+
# record ref_pos
|
| 194 |
+
ref_pos = 0
|
| 195 |
+
|
| 196 |
+
matches = re.findall(r'(\d+)([MIDNSHP=X])', cigar_string)
|
| 197 |
+
max_num_before_equal = 0
|
| 198 |
+
max_equal_index = -1
|
| 199 |
+
total_before_max_equal = 0
|
| 200 |
+
|
| 201 |
+
for i, (num_str, op) in enumerate(matches):
|
| 202 |
+
num = int(num_str)
|
| 203 |
+
if op == '=':
|
| 204 |
+
if num > max_num_before_equal:
|
| 205 |
+
max_num_before_equal = num
|
| 206 |
+
max_equal_index = i
|
| 207 |
+
total_before_max_equal = sum(int(matches[j][0]) for j in range(max_equal_index))
|
| 208 |
+
|
| 209 |
+
ref_pos = total_before_max_equal
|
| 210 |
+
|
| 211 |
+
return ref_pos
|
| 212 |
+
|
| 213 |
+
def find_gRNA_with_mutation(ref_sequence, exon_chr, start, end, strand, transcript_id,
|
| 214 |
+
exon_id, gene_symbol, vcf_reader, pam="NGG", target_length=20):
|
| 215 |
+
# initialization
|
| 216 |
+
mutated_sequences = [ref_sequence]
|
| 217 |
+
|
| 218 |
+
# find mutations within interested region
|
| 219 |
+
mutations = vcf_reader(f"{exon_chr}:{start}-{end}")
|
| 220 |
+
if mutations:
|
| 221 |
+
# find mutations
|
| 222 |
+
mutation_list = []
|
| 223 |
+
for mutation in mutations:
|
| 224 |
+
offset = mutation.POS - start
|
| 225 |
+
ref = mutation.REF
|
| 226 |
+
alts = mutation.ALT[:-1]
|
| 227 |
+
mutation_list.append((offset, ref, alts))
|
| 228 |
+
|
| 229 |
+
# replace reference sequence of mutation
|
| 230 |
+
mutated_sequences = construct_combinations(ref_sequence, mutation_list)
|
| 231 |
+
|
| 232 |
+
# find gRNA in ref_sequence or all mutated_sequences
|
| 233 |
+
targets = []
|
| 234 |
+
for seq in mutated_sequences:
|
| 235 |
+
len_sequence = len(seq)
|
| 236 |
+
dnatorna = {'A': 'A', 'T': 'U', 'C': 'C', 'G': 'G'}
|
| 237 |
+
for i in range(len_sequence - len(pam) + 1):
|
| 238 |
+
if seq[i + 1:i + 3] == pam[1:]:
|
| 239 |
+
if i >= target_length:
|
| 240 |
+
target_seq = seq[i - target_length:i + 3]
|
| 241 |
+
pos = ref_sequence.find(target_seq)
|
| 242 |
+
if pos != -1:
|
| 243 |
+
is_mut = False
|
| 244 |
+
if strand == -1:
|
| 245 |
+
tar_start = end - pos - target_length - 2
|
| 246 |
+
else:
|
| 247 |
+
tar_start = start + pos
|
| 248 |
+
else:
|
| 249 |
+
is_mut = True
|
| 250 |
+
nw_pos = needleman_wunsch_alignment(target_seq, ref_sequence)
|
| 251 |
+
if strand == -1:
|
| 252 |
+
tar_start = str(end - nw_pos - target_length - 2) + '*'
|
| 253 |
+
else:
|
| 254 |
+
tar_start = str(start + nw_pos) + '*'
|
| 255 |
+
gRNA = ''.join([dnatorna[base] for base in seq[i - target_length:i]])
|
| 256 |
+
targets.append([target_seq, gRNA, exon_chr, str(strand), str(tar_start), transcript_id, exon_id, gene_symbol, is_mut])
|
| 257 |
+
|
| 258 |
+
# filter duplicated targets
|
| 259 |
+
unique_targets_set = set(tuple(element) for element in targets)
|
| 260 |
+
unique_targets = [list(element) for element in unique_targets_set]
|
| 261 |
+
|
| 262 |
+
return unique_targets
|
| 263 |
+
|
| 264 |
+
def format_prediction_output_with_mutation(targets, model_path):
|
| 265 |
+
model = MultiHeadAttention_model(input_shape=(23, 4))
|
| 266 |
+
model.load_weights(model_path)
|
| 267 |
+
|
| 268 |
+
formatted_data = []
|
| 269 |
+
|
| 270 |
+
for target in targets:
|
| 271 |
+
# Encode the gRNA sequence
|
| 272 |
+
encoded_seq = get_seqcode(target[0])
|
| 273 |
+
|
| 274 |
+
|
| 275 |
+
# Predict on-target efficiency using the model
|
| 276 |
+
prediction = float(list(model.predict(encoded_seq, verbose=0)[0])[0])
|
| 277 |
+
if prediction > 100:
|
| 278 |
+
prediction = 100
|
| 279 |
+
|
| 280 |
+
# Format output
|
| 281 |
+
gRNA = target[1]
|
| 282 |
+
exon_chr = target[2]
|
| 283 |
+
strand = target[3]
|
| 284 |
+
tar_start = target[4]
|
| 285 |
+
transcript_id = target[5]
|
| 286 |
+
exon_id = target[6]
|
| 287 |
+
gene_symbol = target[7]
|
| 288 |
+
is_mut = target[8]
|
| 289 |
+
formatted_data.append([gene_symbol, exon_chr, strand, tar_start, transcript_id,
|
| 290 |
+
exon_id, target[0], gRNA, prediction, is_mut])
|
| 291 |
+
|
| 292 |
+
return formatted_data
|
| 293 |
+
|
| 294 |
+
|
| 295 |
+
def process_gene(gene_symbol, vcf_reader, model_path):
|
| 296 |
+
transcripts = fetch_ensembl_transcripts(gene_symbol)
|
| 297 |
+
results = []
|
| 298 |
+
all_exons = [] # To accumulate all exons
|
| 299 |
+
all_gene_sequences = [] # To accumulate all gene sequences
|
| 300 |
+
|
| 301 |
+
if transcripts:
|
| 302 |
+
for transcript in transcripts:
|
| 303 |
+
Exons = transcript['Exon']
|
| 304 |
+
all_exons.extend(Exons) # Add all exons from this transcript to the list
|
| 305 |
+
transcript_id = transcript['id']
|
| 306 |
+
|
| 307 |
+
for Exon in Exons:
|
| 308 |
+
exon_id = Exon['id']
|
| 309 |
+
gene_sequence = fetch_ensembl_sequence(exon_id) # Reference exon sequence
|
| 310 |
+
if gene_sequence:
|
| 311 |
+
all_gene_sequences.append(gene_sequence) # Add this gene sequence to the list
|
| 312 |
+
exon_chr = Exon['seq_region_name']
|
| 313 |
+
start = Exon['start']
|
| 314 |
+
end = Exon['end']
|
| 315 |
+
strand = Exon['strand']
|
| 316 |
+
|
| 317 |
+
targets = find_gRNA_with_mutation(gene_sequence, exon_chr, start, end, strand,
|
| 318 |
+
transcript_id, exon_id, gene_symbol, vcf_reader)
|
| 319 |
+
if targets:
|
| 320 |
+
# Predict on-target efficiency for each gRNA site including mutations
|
| 321 |
+
formatted_data = format_prediction_output_with_mutation(targets, model_path)
|
| 322 |
+
results.extend(formatted_data)
|
| 323 |
+
else:
|
| 324 |
+
print(f"Failed to retrieve gene sequence for exon {exon_id}.")
|
| 325 |
+
else:
|
| 326 |
+
print("Failed to retrieve transcripts.")
|
| 327 |
+
|
| 328 |
+
# Sort results based on prediction score (assuming score is at the 8th index)
|
| 329 |
+
sorted_results = sorted(results, key=lambda x: x[8], reverse=True)
|
| 330 |
+
|
| 331 |
+
# Return the sorted output, combined gene sequences, and all exons
|
| 332 |
+
return sorted_results, all_gene_sequences, all_exons
|
| 333 |
+
|
| 334 |
+
|
| 335 |
+
def create_genbank_features(data):
|
| 336 |
+
features = []
|
| 337 |
+
|
| 338 |
+
# If the input data is a DataFrame, convert it to a list of lists
|
| 339 |
+
if isinstance(data, pd.DataFrame):
|
| 340 |
+
formatted_data = data.values.tolist()
|
| 341 |
+
elif isinstance(data, list):
|
| 342 |
+
formatted_data = data
|
| 343 |
+
else:
|
| 344 |
+
raise TypeError("Data should be either a list or a pandas DataFrame.")
|
| 345 |
+
|
| 346 |
+
for row in formatted_data:
|
| 347 |
+
try:
|
| 348 |
+
start = int(row[1])
|
| 349 |
+
end = int(row[2])
|
| 350 |
+
except ValueError as e:
|
| 351 |
+
print(f"Error converting start/end to int: {row[1]}, {row[2]} - {e}")
|
| 352 |
+
continue
|
| 353 |
+
|
| 354 |
+
strand = 1 if row[3] == '+' else -1
|
| 355 |
+
location = FeatureLocation(start=start, end=end, strand=strand)
|
| 356 |
+
feature = SeqFeature(location=location, type="misc_feature", qualifiers={
|
| 357 |
+
'label': row[7], # Use gRNA as the label
|
| 358 |
+
'note': f"Prediction: {row[8]}" # Include the prediction score
|
| 359 |
+
})
|
| 360 |
+
features.append(feature)
|
| 361 |
+
|
| 362 |
+
return features
|
| 363 |
+
|
| 364 |
+
|
| 365 |
+
def generate_genbank_file_from_df(df, gene_sequence, gene_symbol, output_path):
|
| 366 |
+
# Ensure gene_sequence is a string before creating Seq object
|
| 367 |
+
if not isinstance(gene_sequence, str):
|
| 368 |
+
gene_sequence = str(gene_sequence)
|
| 369 |
+
|
| 370 |
+
features = create_genbank_features(df)
|
| 371 |
+
|
| 372 |
+
# Now gene_sequence is guaranteed to be a string, suitable for Seq
|
| 373 |
+
seq_obj = Seq(gene_sequence)
|
| 374 |
+
record = SeqRecord(seq_obj, id=gene_symbol, name=gene_symbol,
|
| 375 |
+
description=f'CRISPR Cas9 predicted targets for {gene_symbol}', features=features)
|
| 376 |
+
record.annotations["molecule_type"] = "DNA"
|
| 377 |
+
SeqIO.write(record, output_path, "genbank")
|
| 378 |
+
|
| 379 |
+
|
| 380 |
+
def create_bed_file_from_df(df, output_path):
|
| 381 |
+
with open(output_path, 'w') as bed_file:
|
| 382 |
+
for index, row in df.iterrows():
|
| 383 |
+
chrom = row["Chr"]
|
| 384 |
+
start = int(row["Start Pos"])
|
| 385 |
+
end = int(row["End Pos"])
|
| 386 |
+
strand = '+' if row["Strand"] == '1' else '-'
|
| 387 |
+
gRNA = row["gRNA"]
|
| 388 |
+
score = str(row["Prediction"])
|
| 389 |
+
# transcript_id is not typically part of the standard BED columns but added here for completeness
|
| 390 |
+
transcript_id = row["Transcript"]
|
| 391 |
+
|
| 392 |
+
# Writing only standard BED columns; additional columns can be appended as needed
|
| 393 |
+
bed_file.write(f"{chrom}\t{start}\t{end}\t{gRNA}\t{score}\t{strand}\n")
|
| 394 |
+
|
| 395 |
+
|
| 396 |
+
def create_csv_from_df(df, output_path):
|
| 397 |
+
df.to_csv(output_path, index=False)
|
cas9on.py
CHANGED
|
@@ -8,9 +8,7 @@ from Bio import SeqIO
|
|
| 8 |
from Bio.SeqRecord import SeqRecord
|
| 9 |
from Bio.SeqFeature import SeqFeature, FeatureLocation
|
| 10 |
from Bio.Seq import Seq
|
| 11 |
-
|
| 12 |
-
import random
|
| 13 |
-
import pyBigWig
|
| 14 |
|
| 15 |
# configure GPUs
|
| 16 |
for gpu in tf.config.list_physical_devices('GPU'):
|
|
|
|
| 8 |
from Bio.SeqRecord import SeqRecord
|
| 9 |
from Bio.SeqFeature import SeqFeature, FeatureLocation
|
| 10 |
from Bio.Seq import Seq
|
| 11 |
+
|
|
|
|
|
|
|
| 12 |
|
| 13 |
# configure GPUs
|
| 14 |
for gpu in tf.config.list_physical_devices('GPU'):
|
requirements.txt
CHANGED
|
@@ -4,5 +4,8 @@ pandas==1.5.2
|
|
| 4 |
tensorflow==2.11.0
|
| 5 |
tensorflow-probability==0.19.0
|
| 6 |
plotly==5.18.0
|
|
|
|
|
|
|
|
|
|
| 7 |
gtracks
|
| 8 |
pyGenomeTracks
|
|
|
|
| 4 |
tensorflow==2.11.0
|
| 5 |
tensorflow-probability==0.19.0
|
| 6 |
plotly==5.18.0
|
| 7 |
+
tabulate
|
| 8 |
+
cyvcf2
|
| 9 |
+
parasail
|
| 10 |
gtracks
|
| 11 |
pyGenomeTracks
|