astirn committed on
Commit
f311bf4
·
1 Parent(s): 89ffb34

fasta directory optimizations

Browse files
Files changed (1) hide show
  1. tiger.py +69 -48
tiger.py CHANGED
@@ -14,7 +14,9 @@ NUCLEOTIDE_COMPLEMENT = dict(zip(['A', 'C', 'G', 'T'], ['T', 'G', 'C', 'A']))
14
  NUM_TOP_GUIDES = 10
15
  NUM_MISMATCHES = 3
16
  REFERENCE_TRANSCRIPTS = ('gencode.v19.pc_transcripts.fa.gz', 'gencode.v19.lncRNA_transcripts.fa.gz')
17
- BATCH_SIZE = 500
 
 
18
 
19
  # configure GPUs
20
  for gpu in tf.config.list_physical_devices('GPU'):
@@ -42,7 +44,7 @@ def load_transcripts(fasta_files):
42
  # set index
43
  transcripts['id'] = transcripts['id'].apply(lambda s: s.split('|')[0])
44
  transcripts.set_index('id', inplace=True)
45
- assert not transcripts.index.has_duplicates
46
 
47
  return transcripts
48
 
@@ -72,7 +74,7 @@ def one_hot_encode_sequence(sequence: list, add_context_padding: bool = False):
72
  sequence = tf.concat([pad_5p, sequence, pad_3p], axis=1)
73
 
74
  # one-hot encode
75
- sequence = tf.one_hot(sequence, depth=4)
76
 
77
  return sequence
78
 
@@ -103,7 +105,7 @@ def predict_on_target(transcript_seq: str, model: tf.keras.Model):
103
  target_seq, guide_seq, model_inputs = process_data(transcript_seq)
104
 
105
  # get predictions
106
- normalized_lfc = model.predict_step(model_inputs)
107
  predictions = pd.DataFrame({'Guide': guide_seq, 'Normalized LFC': tf.squeeze(normalized_lfc).numpy()})
108
  predictions = predictions.sort_values('Normalized LFC')
109
 
@@ -118,7 +120,6 @@ def find_off_targets(top_guides: pd.DataFrame):
118
  # one-hot encode guides to form a filter
119
  guide_filter = one_hot_encode_sequence(sequence_complement(top_guides['Guide']), add_context_padding=False)
120
  guide_filter = tf.transpose(guide_filter, [1, 2, 0])
121
- guide_filter = tf.cast(guide_filter, tf.float16)
122
 
123
  # loop over transcripts in batches
124
  i = 0
@@ -126,43 +127,48 @@ def find_off_targets(top_guides: pd.DataFrame):
126
  off_targets = pd.DataFrame()
127
  while i < len(reference_transcripts):
128
  # select batch
129
- df_batch = reference_transcripts.iloc[i:min(i + BATCH_SIZE, len(reference_transcripts))]
130
- i += BATCH_SIZE
131
 
132
- # find and log off-targets
133
  transcripts = one_hot_encode_sequence(df_batch['seq'].values.tolist(), add_context_padding=False)
134
- transcripts = tf.cast(transcripts, guide_filter.dtype)
135
  num_mismatches = GUIDE_LEN - tf.nn.conv1d(transcripts, guide_filter, stride=1, padding='SAME')
136
  loc_off_targets = tf.where(tf.round(num_mismatches) <= NUM_MISMATCHES).numpy()
137
- off_targets = pd.concat([off_targets, pd.DataFrame({
138
- 'On-target ID': top_guides.iloc[loc_off_targets[:, 2]]['On-target ID'],
139
- 'Guide': top_guides.iloc[loc_off_targets[:, 2]]['Guide'],
140
- 'Off-target ID': df_batch.index.values[loc_off_targets[:, 0]],
141
- 'Target': df_batch['seq'].values[loc_off_targets[:, 0]],
142
- 'Mismatches': tf.gather_nd(num_mismatches, loc_off_targets).numpy().astype(int),
143
- 'Midpoint': loc_off_targets[:, 1],
144
- })])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
145
 
146
  # progress update
147
  print('\rPercent complete: {:.2f}%'.format(100 * min(i / len(reference_transcripts), 1)), end='')
148
  print('')
149
 
150
- # trim transcripts to targets
151
- dict_off_targets = off_targets.to_dict('records')
152
- for row in dict_off_targets:
153
- start_location = row['Midpoint'] - (GUIDE_LEN // 2)
154
- if start_location < CONTEXT_5P:
155
- row['Target'] = row['Target'][0:GUIDE_LEN + CONTEXT_3P]
156
- row['Target'] = 'N' * (TARGET_LEN - len(row['Target'])) + row['Target']
157
- elif start_location + GUIDE_LEN + CONTEXT_3P > len(row['Target']):
158
- row['Target'] = row['Target'][start_location - CONTEXT_5P:]
159
- row['Target'] = row['Target'] + 'N' * (TARGET_LEN - len(row['Target']))
160
- else:
161
- row['Target'] = row['Target'][start_location - CONTEXT_5P:start_location + GUIDE_LEN + CONTEXT_3P]
162
- if row['Mismatches'] == 0 and 'N' not in row['Target']:
163
- assert row['Guide'] == sequence_complement([row['Target'][CONTEXT_5P:TARGET_LEN-CONTEXT_3P]])[0]
164
- off_targets = pd.DataFrame(dict_off_targets)
165
-
166
  return off_targets
167
 
168
 
@@ -175,7 +181,7 @@ def predict_off_target(off_targets: pd.DataFrame, model: tf.keras.Model):
175
  tf.reshape(one_hot_encode_sequence(off_targets['Target'], add_context_padding=False), [len(off_targets), -1]),
176
  tf.reshape(one_hot_encode_sequence(off_targets['Guide'], add_context_padding=True), [len(off_targets), -1]),
177
  ], axis=-1)
178
- off_targets['Normalized LFC'] = model.predict(model_inputs, batch_size=BATCH_SIZE, verbose=False)
179
 
180
  return off_targets.sort_values('Normalized LFC')
181
 
@@ -190,12 +196,17 @@ def tiger_exhibit(transcripts: pd.DataFrame):
190
  exit()
191
 
192
  # find top guides for each transcript
 
193
  on_target_predictions = pd.DataFrame(columns=['On-target ID', 'Guide', 'Normalized LFC'])
194
- for index, row in transcripts.iterrows():
195
  df = predict_on_target(row['seq'], model=tiger)
196
  df['On-target ID'] = index
197
  on_target_predictions = pd.concat([on_target_predictions, df.iloc[:NUM_TOP_GUIDES]])
198
 
 
 
 
 
199
  # predict off-target effects for top guides
200
  off_targets = find_off_targets(on_target_predictions)
201
  off_target_predictions = predict_off_target(off_targets, model=tiger)
@@ -220,16 +231,26 @@ if __name__ == '__main__':
220
  df_on_target.to_csv('on_target.csv')
221
  df_off_target.to_csv('off_target.csv')
222
 
223
- # # directory of fasta files
224
- # elif args.dir_in is not None and os.path.exists(args.fasta_path):
225
- # transcripts = pd.DataFrame()
226
- # for fasta in os.listdir(args.fasta_path):
227
- # df = pd.DataFrame([(t.id, str(t.seq)) for t in SeqIO.parse(fasta, 'fasta')], columns=['id', 'seq'])
228
- #
229
- # try:
230
- # for tran in SeqIO.parse(os.path.join(in_path, f), 'fasta'):
231
- # on_targets, off_targets = tiger_exhibit(str(tran.seq))
232
- # on_targets.to_csv(os.path.join(out_path, tran.id + '-top-guides.csv'))
233
- # off_targets.to_csv(os.path.join(out_path, tran.id + '-off-targets.csv'))
234
- # except Exception:
235
- # warnings.warn(f)
 
 
 
 
 
 
 
 
 
 
 
14
  NUM_TOP_GUIDES = 10
15
  NUM_MISMATCHES = 3
16
  REFERENCE_TRANSCRIPTS = ('gencode.v19.pc_transcripts.fa.gz', 'gencode.v19.lncRNA_transcripts.fa.gz')
17
+ BATCH_SIZE_COMPUTE = 500
18
+ BATCH_SIZE_SCAN = 20
19
+ BATCH_SIZE_TRANSCRIPTS = 50
20
 
21
  # configure GPUs
22
  for gpu in tf.config.list_physical_devices('GPU'):
 
44
  # set index
45
  transcripts['id'] = transcripts['id'].apply(lambda s: s.split('|')[0])
46
  transcripts.set_index('id', inplace=True)
47
+ assert not transcripts.index.has_duplicates, "duplicate transcript ID's detected"
48
 
49
  return transcripts
50
 
 
74
  sequence = tf.concat([pad_5p, sequence, pad_3p], axis=1)
75
 
76
  # one-hot encode
77
+ sequence = tf.one_hot(sequence, depth=4, dtype=tf.float16)
78
 
79
  return sequence
80
 
 
105
  target_seq, guide_seq, model_inputs = process_data(transcript_seq)
106
 
107
  # get predictions
108
+ normalized_lfc = model.predict(model_inputs, batch_size=BATCH_SIZE_COMPUTE, verbose=False)
109
  predictions = pd.DataFrame({'Guide': guide_seq, 'Normalized LFC': tf.squeeze(normalized_lfc).numpy()})
110
  predictions = predictions.sort_values('Normalized LFC')
111
 
 
120
  # one-hot encode guides to form a filter
121
  guide_filter = one_hot_encode_sequence(sequence_complement(top_guides['Guide']), add_context_padding=False)
122
  guide_filter = tf.transpose(guide_filter, [1, 2, 0])
 
123
 
124
  # loop over transcripts in batches
125
  i = 0
 
127
  off_targets = pd.DataFrame()
128
  while i < len(reference_transcripts):
129
  # select batch
130
+ df_batch = reference_transcripts.iloc[i:min(i + BATCH_SIZE_SCAN, len(reference_transcripts))]
131
+ i += BATCH_SIZE_SCAN
132
 
133
+ # find locations of off-targets
134
  transcripts = one_hot_encode_sequence(df_batch['seq'].values.tolist(), add_context_padding=False)
 
135
  num_mismatches = GUIDE_LEN - tf.nn.conv1d(transcripts, guide_filter, stride=1, padding='SAME')
136
  loc_off_targets = tf.where(tf.round(num_mismatches) <= NUM_MISMATCHES).numpy()
137
+
138
+ # off-targets discovered
139
+ if len(loc_off_targets) > 0:
140
+
141
+ # log off-targets
142
+ dict_off_targets = pd.DataFrame({
143
+ 'On-target ID': top_guides.iloc[loc_off_targets[:, 2]]['On-target ID'],
144
+ 'Guide': top_guides.iloc[loc_off_targets[:, 2]]['Guide'],
145
+ 'Off-target ID': df_batch.index.values[loc_off_targets[:, 0]],
146
+ 'Target': df_batch['seq'].values[loc_off_targets[:, 0]],
147
+ 'Mismatches': tf.gather_nd(num_mismatches, loc_off_targets).numpy().astype(int),
148
+ 'Midpoint': loc_off_targets[:, 1],
149
+ }).to_dict('records')
150
+
151
+ # trim transcripts to targets
152
+ for row in dict_off_targets:
153
+ start_location = row['Midpoint'] - (GUIDE_LEN // 2)
154
+ if start_location < CONTEXT_5P:
155
+ row['Target'] = row['Target'][0:GUIDE_LEN + CONTEXT_3P]
156
+ row['Target'] = 'N' * (TARGET_LEN - len(row['Target'])) + row['Target']
157
+ elif start_location + GUIDE_LEN + CONTEXT_3P > len(row['Target']):
158
+ row['Target'] = row['Target'][start_location - CONTEXT_5P:]
159
+ row['Target'] = row['Target'] + 'N' * (TARGET_LEN - len(row['Target']))
160
+ else:
161
+ row['Target'] = row['Target'][start_location - CONTEXT_5P:start_location + GUIDE_LEN + CONTEXT_3P]
162
+ if row['Mismatches'] == 0 and 'N' not in row['Target']:
163
+ assert row['Guide'] == sequence_complement([row['Target'][CONTEXT_5P:TARGET_LEN - CONTEXT_3P]])[0]
164
+
165
+ # append new off-targets
166
+ off_targets = pd.concat([off_targets, pd.DataFrame(dict_off_targets)])
167
 
168
  # progress update
169
  print('\rPercent complete: {:.2f}%'.format(100 * min(i / len(reference_transcripts), 1)), end='')
170
  print('')
171
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
172
  return off_targets
173
 
174
 
 
181
  tf.reshape(one_hot_encode_sequence(off_targets['Target'], add_context_padding=False), [len(off_targets), -1]),
182
  tf.reshape(one_hot_encode_sequence(off_targets['Guide'], add_context_padding=True), [len(off_targets), -1]),
183
  ], axis=-1)
184
+ off_targets['Normalized LFC'] = model.predict(model_inputs, batch_size=BATCH_SIZE_COMPUTE, verbose=False)
185
 
186
  return off_targets.sort_values('Normalized LFC')
187
 
 
196
  exit()
197
 
198
  # find top guides for each transcript
199
+ print('Finding top guides for each transcript')
200
  on_target_predictions = pd.DataFrame(columns=['On-target ID', 'Guide', 'Normalized LFC'])
201
+ for i, (index, row) in enumerate(transcripts.iterrows()):
202
  df = predict_on_target(row['seq'], model=tiger)
203
  df['On-target ID'] = index
204
  on_target_predictions = pd.concat([on_target_predictions, df.iloc[:NUM_TOP_GUIDES]])
205
 
206
+ # progress update
207
+ print('\rPercent complete: {:.2f}%'.format(100 * min((i + 1) / len(transcripts), 1)), end='')
208
+ print('')
209
+
210
  # predict off-target effects for top guides
211
  off_targets = find_off_targets(on_target_predictions)
212
  off_target_predictions = predict_off_target(off_targets, model=tiger)
 
231
  df_on_target.to_csv('on_target.csv')
232
  df_off_target.to_csv('off_target.csv')
233
 
234
+ # directory of fasta files
235
+ elif args.fasta_path is not None and os.path.exists(args.fasta_path):
236
+
237
+ # load transcripts
238
+ df_transcripts = load_transcripts([os.path.join(args.fasta_path, f) for f in os.listdir(args.fasta_path)])
239
+
240
+ # process in batches
241
+ df_on_target = pd.DataFrame()
242
+ df_off_target = pd.DataFrame()
243
+ batch = 1
244
+ num_batches = len(df_transcripts) // BATCH_SIZE_TRANSCRIPTS
245
+ num_batches += (len(df_transcripts) % BATCH_SIZE_TRANSCRIPTS > 0)
246
+ for t in range(0, len(df_transcripts), BATCH_SIZE_TRANSCRIPTS):
247
+ print('Batch {:d} of {:d}'.format(batch, num_batches))
248
+ t_stop = min(t + BATCH_SIZE_TRANSCRIPTS, len(df_transcripts))
249
+ df_on_target_new, df_off_target_new = tiger_exhibit(df_transcripts[t:t_stop])
250
+ df_on_target = pd.concat([df_on_target, df_on_target_new])
251
+ df_off_target = pd.concat([df_off_target, df_off_target_new])
252
+ batch += 1
253
+
254
+ # save results
255
+ df_on_target.to_csv('on_target.csv')
256
+ df_off_target.to_csv('off_target.csv')