astirn committed on
Commit
f57c1f6
·
1 Parent(s): d78d0d1

find top guides for all transcripts and then scan off-targets simultaneously

Browse files
Files changed (2) hide show
  1. .gitignore +2 -0
  2. tiger.py +53 -25
.gitignore ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ off_target.csv
2
+ on_target.csv
tiger.py CHANGED
@@ -1,7 +1,6 @@
1
  import argparse
2
  import os
3
  import gzip
4
- import numpy as np
5
  import pandas as pd
6
  import tensorflow as tf
7
  from Bio import SeqIO
@@ -15,6 +14,7 @@ NUCLEOTIDE_COMPLEMENT = dict(zip(['A', 'C', 'G', 'T'], ['T', 'G', 'C', 'A']))
15
  NUM_TOP_GUIDES = 10
16
  NUM_MISMATCHES = 3
17
  REFERENCE_TRANSCRIPTS = ('gencode.v19.pc_transcripts.fa.gz', 'gencode.v19.lncRNA_transcripts.fa.gz')
 
18
 
19
  # configure GPUs
20
  for gpu in tf.config.list_physical_devices('GPU'):
@@ -105,41 +105,42 @@ def predict_on_target(transcript_seq: str, model: tf.keras.Model):
105
  # get predictions
106
  normalized_lfc = model.predict_step(model_inputs)
107
  predictions = pd.DataFrame({'Guide': guide_seq, 'Normalized LFC': tf.squeeze(normalized_lfc).numpy()})
108
- predictions = predictions.set_index('Guide').sort_values('Normalized LFC')
109
 
110
  return predictions
111
 
112
 
113
- def find_off_targets(guides, batch_size=500):
114
 
115
  # load reference transcripts
116
  reference_transcripts = load_transcripts([os.path.join('transcripts', f) for f in REFERENCE_TRANSCRIPTS])
117
 
118
  # one-hot encode guides to form a filter
119
- guide_filter = one_hot_encode_sequence(sequence_complement(guides), add_context_padding=False)
120
  guide_filter = tf.transpose(guide_filter, [1, 2, 0])
121
  guide_filter = tf.cast(guide_filter, tf.float16)
122
 
123
  # loop over transcripts in batches
124
  i = 0
125
  print('Scanning for off-targets')
126
- df_off_targets = pd.DataFrame()
127
  while i < len(reference_transcripts):
128
  # select batch
129
- df_batch = reference_transcripts.iloc[i:min(i + batch_size, len(reference_transcripts))]
130
- i += batch_size
131
 
132
  # find and log off-targets
133
  transcripts = one_hot_encode_sequence(df_batch['seq'].values.tolist(), add_context_padding=False)
134
  transcripts = tf.cast(transcripts, guide_filter.dtype)
135
  num_mismatches = GUIDE_LEN - tf.nn.conv1d(transcripts, guide_filter, stride=1, padding='SAME')
136
  loc_off_targets = tf.where(tf.round(num_mismatches) <= NUM_MISMATCHES).numpy()
137
- df_off_targets = pd.concat([df_off_targets, pd.DataFrame({
138
- 'Guide': np.array(guides)[loc_off_targets[:, 2]],
139
- 'Isoform': df_batch.index.values[loc_off_targets[:, 0]],
 
 
140
  'Mismatches': tf.gather_nd(num_mismatches, loc_off_targets).numpy().astype(int),
141
  'Midpoint': loc_off_targets[:, 1],
142
- 'Target': df_batch['seq'].values[loc_off_targets[:, 0]],
143
  })])
144
 
145
  # progress update
@@ -147,7 +148,7 @@ def find_off_targets(guides, batch_size=500):
147
  print('')
148
 
149
  # trim transcripts to targets
150
- dict_off_targets = df_off_targets.to_dict('records')
151
  for row in dict_off_targets:
152
  start_location = row['Midpoint'] - (GUIDE_LEN // 2)
153
  if start_location < CONTEXT_5P:
@@ -160,9 +161,9 @@ def find_off_targets(guides, batch_size=500):
160
  row['Target'] = row['Target'][start_location - CONTEXT_5P:start_location + GUIDE_LEN + CONTEXT_3P]
161
  if row['Mismatches'] == 0 and 'N' not in row['Target']:
162
  assert row['Guide'] == sequence_complement([row['Target'][CONTEXT_5P:TARGET_LEN-CONTEXT_3P]])[0]
163
- df_off_targets = pd.DataFrame(dict_off_targets)
164
 
165
- return df_off_targets
166
 
167
 
168
  def predict_off_target(off_targets: pd.DataFrame, model: tf.keras.Model):
@@ -174,12 +175,12 @@ def predict_off_target(off_targets: pd.DataFrame, model: tf.keras.Model):
174
  tf.reshape(one_hot_encode_sequence(off_targets['Target'], add_context_padding=False), [len(off_targets), -1]),
175
  tf.reshape(one_hot_encode_sequence(off_targets['Guide'], add_context_padding=True), [len(off_targets), -1]),
176
  ], axis=-1)
177
- off_targets['Normalized LFC'] = model.predict_step(model_inputs)
178
 
179
- return off_targets.set_index('Guide').sort_values('Normalized LFC')
180
 
181
 
182
- def tiger_exhibit(transcript):
183
 
184
  # load model
185
  if os.path.exists('model'):
@@ -188,20 +189,47 @@ def tiger_exhibit(transcript):
188
  print('no saved model!')
189
  exit()
190
 
191
- # on-target predictions
192
- on_target_predictions = predict_on_target(transcript, model=tiger)
193
-
194
- # keep only top guides
195
- on_target_predictions = on_target_predictions.iloc[:NUM_TOP_GUIDES]
 
196
 
197
  # predict off-target effects for top guides
198
- off_targets = find_off_targets(on_target_predictions.index.values.tolist())
199
  off_target_predictions = predict_off_target(off_targets, model=tiger)
200
 
201
- return on_target_predictions, off_target_predictions
202
 
203
 
204
  if __name__ == '__main__':
205
 
 
 
 
 
 
 
206
  # simple test case
207
- print(tiger_exhibit('ATGCAGGACGCGGAGAACGTGGCGGTGCCCGAGGCGGCCGAGGAGCGCGC'.lower())) # first 50 from EIF3B-003's CDS
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import argparse
2
  import os
3
  import gzip
 
4
  import pandas as pd
5
  import tensorflow as tf
6
  from Bio import SeqIO
 
14
  NUM_TOP_GUIDES = 10
15
  NUM_MISMATCHES = 3
16
  REFERENCE_TRANSCRIPTS = ('gencode.v19.pc_transcripts.fa.gz', 'gencode.v19.lncRNA_transcripts.fa.gz')
17
+ BATCH_SIZE = 500
18
 
19
  # configure GPUs
20
  for gpu in tf.config.list_physical_devices('GPU'):
 
105
  # get predictions
106
  normalized_lfc = model.predict_step(model_inputs)
107
  predictions = pd.DataFrame({'Guide': guide_seq, 'Normalized LFC': tf.squeeze(normalized_lfc).numpy()})
108
+ predictions = predictions.sort_values('Normalized LFC')
109
 
110
  return predictions
111
 
112
 
113
+ def find_off_targets(top_guides: pd.DataFrame):
114
 
115
  # load reference transcripts
116
  reference_transcripts = load_transcripts([os.path.join('transcripts', f) for f in REFERENCE_TRANSCRIPTS])
117
 
118
  # one-hot encode guides to form a filter
119
+ guide_filter = one_hot_encode_sequence(sequence_complement(top_guides['Guide']), add_context_padding=False)
120
  guide_filter = tf.transpose(guide_filter, [1, 2, 0])
121
  guide_filter = tf.cast(guide_filter, tf.float16)
122
 
123
  # loop over transcripts in batches
124
  i = 0
125
  print('Scanning for off-targets')
126
+ off_targets = pd.DataFrame()
127
  while i < len(reference_transcripts):
128
  # select batch
129
+ df_batch = reference_transcripts.iloc[i:min(i + BATCH_SIZE, len(reference_transcripts))]
130
+ i += BATCH_SIZE
131
 
132
  # find and log off-targets
133
  transcripts = one_hot_encode_sequence(df_batch['seq'].values.tolist(), add_context_padding=False)
134
  transcripts = tf.cast(transcripts, guide_filter.dtype)
135
  num_mismatches = GUIDE_LEN - tf.nn.conv1d(transcripts, guide_filter, stride=1, padding='SAME')
136
  loc_off_targets = tf.where(tf.round(num_mismatches) <= NUM_MISMATCHES).numpy()
137
+ off_targets = pd.concat([off_targets, pd.DataFrame({
138
+ 'On-target ID': top_guides.iloc[loc_off_targets[:, 2]]['On-target ID'],
139
+ 'Guide': top_guides.iloc[loc_off_targets[:, 2]]['Guide'],
140
+ 'Off-target ID': df_batch.index.values[loc_off_targets[:, 0]],
141
+ 'Target': df_batch['seq'].values[loc_off_targets[:, 0]],
142
  'Mismatches': tf.gather_nd(num_mismatches, loc_off_targets).numpy().astype(int),
143
  'Midpoint': loc_off_targets[:, 1],
 
144
  })])
145
 
146
  # progress update
 
148
  print('')
149
 
150
  # trim transcripts to targets
151
+ dict_off_targets = off_targets.to_dict('records')
152
  for row in dict_off_targets:
153
  start_location = row['Midpoint'] - (GUIDE_LEN // 2)
154
  if start_location < CONTEXT_5P:
 
161
  row['Target'] = row['Target'][start_location - CONTEXT_5P:start_location + GUIDE_LEN + CONTEXT_3P]
162
  if row['Mismatches'] == 0 and 'N' not in row['Target']:
163
  assert row['Guide'] == sequence_complement([row['Target'][CONTEXT_5P:TARGET_LEN-CONTEXT_3P]])[0]
164
+ off_targets = pd.DataFrame(dict_off_targets)
165
 
166
+ return off_targets
167
 
168
 
169
  def predict_off_target(off_targets: pd.DataFrame, model: tf.keras.Model):
 
175
  tf.reshape(one_hot_encode_sequence(off_targets['Target'], add_context_padding=False), [len(off_targets), -1]),
176
  tf.reshape(one_hot_encode_sequence(off_targets['Guide'], add_context_padding=True), [len(off_targets), -1]),
177
  ], axis=-1)
178
+ off_targets['Normalized LFC'] = model.predict(model_inputs, batch_size=BATCH_SIZE, verbose=False)
179
 
180
+ return off_targets.sort_values('Normalized LFC')
181
 
182
 
183
+ def tiger_exhibit(transcripts: pd.DataFrame):
184
 
185
  # load model
186
  if os.path.exists('model'):
 
189
  print('no saved model!')
190
  exit()
191
 
192
+ # find top guides for each transcript
193
+ on_target_predictions = pd.DataFrame(columns=['On-target ID', 'Guide', 'Normalized LFC'])
194
+ for index, row in transcripts.iterrows():
195
+ df = predict_on_target(row['seq'], model=tiger)
196
+ df['On-target ID'] = index
197
+ on_target_predictions = pd.concat([on_target_predictions, df.iloc[:NUM_TOP_GUIDES]])
198
 
199
  # predict off-target effects for top guides
200
+ off_targets = find_off_targets(on_target_predictions)
201
  off_target_predictions = predict_off_target(off_targets, model=tiger)
202
 
203
+ return on_target_predictions.reset_index(drop=True), off_target_predictions.reset_index(drop=True)
204
 
205
 
206
  if __name__ == '__main__':
207
 
208
+ # common arguments
209
+ parser = argparse.ArgumentParser()
210
+ parser.add_argument('--fasta_path', type=str, default=None)
211
+ parser.add_argument('--simple_test', action='store_true', default=False)
212
+ args = parser.parse_args()
213
+
214
  # simple test case
215
+ if args.simple_test:
216
+ # first 50 from EIF3B-003's CDS
217
+ simple_test = pd.DataFrame(dict(id=['user entry'], seq=['ATGCAGGACGCGGAGAACGTGGCGGTGCCCGAGGCGGCCGAGGAGCGCGC']))
218
+ simple_test.set_index('id', inplace=True)
219
+ df_on_target, df_off_target = tiger_exhibit(simple_test)
220
+ df_on_target.to_csv('on_target.csv')
221
+ df_off_target.to_csv('off_target.csv')
222
+
223
+ # # directory of fasta files
224
+ # elif args.dir_in is not None and os.path.exists(args.fasta_path):
225
+ # transcripts = pd.DataFrame()
226
+ # for fasta in os.listdir(args.fasta_path):
227
+ # df = pd.DataFrame([(t.id, str(t.seq)) for t in SeqIO.parse(fasta, 'fasta')], columns=['id', 'seq'])
228
+ #
229
+ # try:
230
+ # for tran in SeqIO.parse(os.path.join(in_path, f), 'fasta'):
231
+ # on_targets, off_targets = tiger_exhibit(str(tran.seq))
232
+ # on_targets.to_csv(os.path.join(out_path, tran.id + '-top-guides.csv'))
233
+ # off_targets.to_csv(os.path.join(out_path, tran.id + '-off-targets.csv'))
234
+ # except Exception:
235
+ # warnings.warn(f)