# tiger.py — find top guides for all transcripts and then scan off-targets simultaneously
import argparse
import os
import gzip
import pandas as pd
import tensorflow as tf
from Bio import SeqIO
# guide geometry: 23-nt guides scored with 3 nt of 5' target context and none 3'
GUIDE_LEN = 23
CONTEXT_5P = 3
CONTEXT_3P = 0
TARGET_LEN = CONTEXT_5P + GUIDE_LEN + CONTEXT_3P
# nucleotide tokens; 255 marks 'N'/unknown (out of one-hot depth 4, so it encodes to all zeros)
NUCLEOTIDE_TOKENS = dict(zip(['A', 'C', 'G', 'T', 'N'], [0, 1, 2, 3, 255]))
NUCLEOTIDE_COMPLEMENT = dict(zip(['A', 'C', 'G', 'T'], ['T', 'G', 'C', 'A']))
# off-target scan settings
NUM_TOP_GUIDES = 10  # guides kept per transcript
NUM_MISMATCHES = 3  # maximum mismatches for a site to be reported as an off-target
REFERENCE_TRANSCRIPTS = ('gencode.v19.pc_transcripts.fa.gz', 'gencode.v19.lncRNA_transcripts.fa.gz')
BATCH_SIZE = 500  # transcripts per off-target scan batch; also the keras predict batch size
# configure GPUs: enable memory growth and restrict TensorFlow to the first GPU
# (hoisted the repeated tf.config.list_physical_devices('GPU') call into one lookup)
_gpus = tf.config.list_physical_devices('GPU')
for gpu in _gpus:
    tf.config.experimental.set_memory_growth(gpu, enable=True)
if len(_gpus) > 0:
    tf.config.experimental.set_visible_devices(_gpus[0], 'GPU')
def load_transcripts(fasta_files):
    """Load all transcripts from FASTA files into a DataFrame indexed by transcript ID.

    :param fasta_files: iterable of FASTA file paths; '.gz' files are opened with gzip
    :return: DataFrame indexed by 'id' (the text before the first '|' in the record ID)
        with a single column 'seq'
    """
    frames = []
    for file in fasta_files:
        try:
            if os.path.splitext(file)[1] == '.gz':
                with gzip.open(file, 'rt') as f:
                    df = pd.DataFrame([(t.id, str(t.seq)) for t in SeqIO.parse(f, 'fasta')], columns=['id', 'seq'])
            else:
                # bug fix: original parsed `f` (only bound inside the gzip branch) here,
                # raising a NameError that the broad except silently swallowed; parse by path
                df = pd.DataFrame([(t.id, str(t.seq)) for t in SeqIO.parse(file, 'fasta')], columns=['id', 'seq'])
        except Exception as e:
            # best-effort loading: report and skip unreadable files
            print(e, 'while loading', file)
            continue
        frames.append(df)
    # concatenate once (avoids quadratic growth) and tolerate the all-files-failed case
    transcripts = pd.concat(frames) if frames else pd.DataFrame(columns=['id', 'seq'])
    # keep only the accession portion of the FASTA record ID and use it as the index
    transcripts['id'] = transcripts['id'].apply(lambda s: s.split('|')[0])
    transcripts.set_index('id', inplace=True)
    assert not transcripts.index.has_duplicates
    return transcripts
def sequence_complement(sequence: list):
    """Return the nucleotide complement of every sequence in the list.

    :param sequence: list of nucleotide strings over A/C/G/T
    :return: list of complementary strings (KeyError on any other character)
    """
    complements = []
    for seq in sequence:
        complements.append(''.join(NUCLEOTIDE_COMPLEMENT[nt] for nt in seq))
    return complements
def one_hot_encode_sequence(sequence: list, add_context_padding: bool = False):
    """One-hot encode a list of nucleotide sequences.

    :param sequence: list of nucleotide strings (A/C/G/T/N); may have varying lengths
    :param add_context_padding: if True, prepend CONTEXT_5P and append CONTEXT_3P columns
        of the padding token so guides line up with context-bearing targets
    :return: tensor of shape [num_sequences, max_length (+ padding), 4]; token 255
        ('N', ragged fill, and padding) lies outside depth 4, so tf.one_hot emits an
        all-zero row for those positions
    """
    # stack list of sequences into a ragged tensor of single characters
    sequence = tf.ragged.stack([tf.constant(list(seq)) for seq in sequence], axis=0)
    # tokenize sequence via a static lookup table built from NUCLEOTIDE_TOKENS
    nucleotide_table = tf.lookup.StaticVocabularyTable(
        initializer=tf.lookup.KeyValueTensorInitializer(
            keys=tf.constant(list(NUCLEOTIDE_TOKENS.keys()), dtype=tf.string),
            values=tf.constant(list(NUCLEOTIDE_TOKENS.values()), dtype=tf.int64)),
        num_oov_buckets=1)
    # lookup on the flat values, then rebuild the ragged shape and densify,
    # filling short rows with 255 (the 'N'/unknown token)
    sequence = tf.RaggedTensor.from_row_splits(values=nucleotide_table.lookup(sequence.values),
                                               row_splits=sequence.row_splits).to_tensor(255)
    # add context padding if requested (padding uses token 255 -> all-zero one-hot)
    if add_context_padding:
        pad_5p = 255 * tf.ones([sequence.shape[0], CONTEXT_5P], dtype=sequence.dtype)
        pad_3p = 255 * tf.ones([sequence.shape[0], CONTEXT_3P], dtype=sequence.dtype)
        sequence = tf.concat([pad_5p, sequence, pad_3p], axis=1)
    # one-hot encode (indices >= 4 become all-zero rows per tf.one_hot semantics)
    sequence = tf.one_hot(sequence, depth=4)
    return sequence
def process_data(transcript_seq: str):
    """Enumerate every candidate target site on a transcript and build model inputs.

    :param transcript_seq: transcript nucleotide sequence
    :return: (target sequences, complementary guide sequences, model input tensor);
        model inputs concatenate the flattened one-hot target and padded guide encodings
    """
    transcript_seq = transcript_seq.upper()
    # slide a TARGET_LEN window over the transcript to get all target sites
    num_sites = len(transcript_seq) - TARGET_LEN + 1
    target_seq = [transcript_seq[start:start + TARGET_LEN] for start in range(num_sites)]
    # guides are the complement of each target with the 5'/3' context trimmed away
    guide_seq = sequence_complement([target[CONTEXT_5P:len(target) - CONTEXT_3P] for target in target_seq])
    # one-hot encode both, flatten per sequence, and concatenate along the feature axis
    target_one_hot = one_hot_encode_sequence(target_seq, add_context_padding=False)
    guide_one_hot = one_hot_encode_sequence(guide_seq, add_context_padding=True)
    model_inputs = tf.concat([
        tf.reshape(target_one_hot, [len(target_seq), -1]),
        tf.reshape(guide_one_hot, [len(guide_seq), -1]),
    ], axis=-1)
    return target_seq, guide_seq, model_inputs
def predict_on_target(transcript_seq: str, model: tf.keras.Model):
    """Score every candidate guide on a transcript, best (lowest normalized LFC) first.

    :param transcript_seq: transcript nucleotide sequence
    :param model: trained keras model taking process_data's input encoding
    :return: DataFrame with 'Guide' and 'Normalized LFC', sorted ascending by LFC
    """
    # parse transcript into per-site model inputs
    target_seq, guide_seq, model_inputs = process_data(transcript_seq)
    # run the model and flatten its output to one score per guide
    scores = tf.squeeze(model.predict_step(model_inputs)).numpy()
    predictions = pd.DataFrame({'Guide': guide_seq, 'Normalized LFC': scores})
    return predictions.sort_values('Normalized LFC')
def find_off_targets(top_guides: pd.DataFrame):
    """Scan the reference transcriptome for off-target sites of the given guides.

    Each guide's complement is one-hot encoded into a conv filter; convolving it over
    the one-hot encoded transcripts yields, at each position, the count of matching
    nucleotides, so GUIDE_LEN minus that count is the mismatch count.

    :param top_guides: DataFrame with columns 'On-target ID' and 'Guide'
    :return: DataFrame of hits (<= NUM_MISMATCHES mismatches) with the transcript
        trimmed to a TARGET_LEN window around each hit ('N'-padded at transcript ends)
    """
    # load reference transcripts
    reference_transcripts = load_transcripts([os.path.join('transcripts', f) for f in REFERENCE_TRANSCRIPTS])
    # one-hot encode guides to form a filter of shape [width, channels, num_guides]
    guide_filter = one_hot_encode_sequence(sequence_complement(top_guides['Guide']), add_context_padding=False)
    guide_filter = tf.transpose(guide_filter, [1, 2, 0])
    guide_filter = tf.cast(guide_filter, tf.float16)  # half precision to shrink the scan's memory footprint
    # loop over transcripts in batches
    i = 0
    print('Scanning for off-targets')
    off_targets = pd.DataFrame()
    while i < len(reference_transcripts):
        # select batch
        df_batch = reference_transcripts.iloc[i:min(i + BATCH_SIZE, len(reference_transcripts))]
        i += BATCH_SIZE
        # find and log off-targets: conv output = matches, so GUIDE_LEN - conv = mismatches
        transcripts = one_hot_encode_sequence(df_batch['seq'].values.tolist(), add_context_padding=False)
        transcripts = tf.cast(transcripts, guide_filter.dtype)
        num_mismatches = GUIDE_LEN - tf.nn.conv1d(transcripts, guide_filter, stride=1, padding='SAME')
        # round first to guard against float16 rounding error at the threshold
        loc_off_targets = tf.where(tf.round(num_mismatches) <= NUM_MISMATCHES).numpy()
        # loc_off_targets columns: [transcript index in batch, position, guide index]
        off_targets = pd.concat([off_targets, pd.DataFrame({
            'On-target ID': top_guides.iloc[loc_off_targets[:, 2]]['On-target ID'],
            'Guide': top_guides.iloc[loc_off_targets[:, 2]]['Guide'],
            'Off-target ID': df_batch.index.values[loc_off_targets[:, 0]],
            'Target': df_batch['seq'].values[loc_off_targets[:, 0]],
            'Mismatches': tf.gather_nd(num_mismatches, loc_off_targets).numpy().astype(int),
            'Midpoint': loc_off_targets[:, 1],  # with 'SAME' padding this is the match window's center — TODO confirm
        })])
        # progress update
        print('\rPercent complete: {:.2f}%'.format(100 * min(i / len(reference_transcripts), 1)), end='')
    print('')
    # trim each full transcript down to the TARGET_LEN window around the hit,
    # padding with 'N' where the window runs off either end of the transcript
    dict_off_targets = off_targets.to_dict('records')
    for row in dict_off_targets:
        start_location = row['Midpoint'] - (GUIDE_LEN // 2)
        if start_location < CONTEXT_5P:
            # hit too close to the 5' end: left-pad with 'N' to TARGET_LEN
            row['Target'] = row['Target'][0:GUIDE_LEN + CONTEXT_3P]
            row['Target'] = 'N' * (TARGET_LEN - len(row['Target'])) + row['Target']
        elif start_location + GUIDE_LEN + CONTEXT_3P > len(row['Target']):
            # hit too close to the 3' end: right-pad with 'N' to TARGET_LEN
            row['Target'] = row['Target'][start_location - CONTEXT_5P:]
            row['Target'] = row['Target'] + 'N' * (TARGET_LEN - len(row['Target']))
        else:
            row['Target'] = row['Target'][start_location - CONTEXT_5P:start_location + GUIDE_LEN + CONTEXT_3P]
        # sanity check: an exact match's guide must equal the complement of its target core
        if row['Mismatches'] == 0 and 'N' not in row['Target']:
            assert row['Guide'] == sequence_complement([row['Target'][CONTEXT_5P:TARGET_LEN-CONTEXT_3P]])[0]
    off_targets = pd.DataFrame(dict_off_targets)
    return off_targets
def predict_off_target(off_targets: pd.DataFrame, model: tf.keras.Model):
    """Predict normalized LFC for each off-target site.

    :param off_targets: DataFrame with 'Target' and 'Guide' columns (from find_off_targets)
    :param model: trained TIGER keras model
    :return: copy of off_targets with a 'Normalized LFC' column, sorted ascending;
        an empty DataFrame when there are no off-targets
    """
    if len(off_targets) == 0:
        return pd.DataFrame()
    # assemble model inputs: flattened one-hot target (with context) and padded guide
    model_inputs = tf.concat([
        tf.reshape(one_hot_encode_sequence(off_targets['Target'], add_context_padding=False), [len(off_targets), -1]),
        tf.reshape(one_hot_encode_sequence(off_targets['Guide'], add_context_padding=True), [len(off_targets), -1]),
    ], axis=-1)
    # bug fix: write predictions onto a copy so the caller's DataFrame is not mutated
    off_targets = off_targets.copy()
    off_targets['Normalized LFC'] = model.predict(model_inputs, batch_size=BATCH_SIZE, verbose=False)
    return off_targets.sort_values('Normalized LFC')
def tiger_exhibit(transcripts: pd.DataFrame):
    """Run TIGER end-to-end: top on-target guides per transcript, then off-target effects.

    :param transcripts: DataFrame indexed by transcript ID with a 'seq' column
    :return: (on-target predictions, off-target predictions), both with fresh integer indices
    :raises FileNotFoundError: if no saved model directory exists
    """
    # load model; bug fix: the original printed and called exit(), terminating the
    # process with a SUCCESS status on a missing model — raise instead
    if not os.path.exists('model'):
        raise FileNotFoundError('no saved model!')
    tiger = tf.keras.models.load_model('model')
    # find top guides for each transcript (collect then concat once, not per-iteration)
    top_guide_tables = []
    for index, row in transcripts.iterrows():
        df = predict_on_target(row['seq'], model=tiger)
        df['On-target ID'] = index
        top_guide_tables.append(df.iloc[:NUM_TOP_GUIDES])
    if top_guide_tables:
        on_target_predictions = pd.concat(top_guide_tables)
    else:
        on_target_predictions = pd.DataFrame(columns=['On-target ID', 'Guide', 'Normalized LFC'])
    # predict off-target effects for top guides
    off_targets = find_off_targets(on_target_predictions)
    off_target_predictions = predict_off_target(off_targets, model=tiger)
    return on_target_predictions.reset_index(drop=True), off_target_predictions.reset_index(drop=True)
if __name__ == '__main__':
    # command-line arguments
    parser = argparse.ArgumentParser()
    parser.add_argument('--fasta_path', type=str, default=None, help='FASTA file of transcripts to score')
    parser.add_argument('--simple_test', action='store_true', default=False, help='run the built-in test sequence')
    args = parser.parse_args()

    if args.simple_test:
        # simple test case: first 50 from EIF3B-003's CDS
        transcripts = pd.DataFrame(dict(id=['user entry'], seq=['ATGCAGGACGCGGAGAACGTGGCGGTGCCCGAGGCGGCCGAGGAGCGCGC']))
        transcripts.set_index('id', inplace=True)
    elif args.fasta_path is not None and os.path.exists(args.fasta_path):
        # bug fix: --fasta_path was accepted but never used (only dead commented-out
        # code referencing a nonexistent args.dir_in); load it via load_transcripts
        transcripts = load_transcripts([args.fasta_path])
    else:
        parser.error('specify --simple_test or an existing --fasta_path')

    df_on_target, df_off_target = tiger_exhibit(transcripts)
    df_on_target.to_csv('on_target.csv')
    df_off_target.to_csv('off_target.csv')