Spaces:
Running
on
CPU Upgrade
Running
on
CPU Upgrade
fasta directory optimizations
Browse files
tiger.py
CHANGED
|
@@ -14,7 +14,9 @@ NUCLEOTIDE_COMPLEMENT = dict(zip(['A', 'C', 'G', 'T'], ['T', 'G', 'C', 'A']))
|
|
| 14 |
NUM_TOP_GUIDES = 10
|
| 15 |
NUM_MISMATCHES = 3
|
| 16 |
REFERENCE_TRANSCRIPTS = ('gencode.v19.pc_transcripts.fa.gz', 'gencode.v19.lncRNA_transcripts.fa.gz')
|
| 17 |
-
|
|
|
|
|
|
|
| 18 |
|
| 19 |
# configure GPUs
|
| 20 |
for gpu in tf.config.list_physical_devices('GPU'):
|
|
@@ -42,7 +44,7 @@ def load_transcripts(fasta_files):
|
|
| 42 |
# set index
|
| 43 |
transcripts['id'] = transcripts['id'].apply(lambda s: s.split('|')[0])
|
| 44 |
transcripts.set_index('id', inplace=True)
|
| 45 |
-
assert not transcripts.index.has_duplicates
|
| 46 |
|
| 47 |
return transcripts
|
| 48 |
|
|
@@ -72,7 +74,7 @@ def one_hot_encode_sequence(sequence: list, add_context_padding: bool = False):
|
|
| 72 |
sequence = tf.concat([pad_5p, sequence, pad_3p], axis=1)
|
| 73 |
|
| 74 |
# one-hot encode
|
| 75 |
-
sequence = tf.one_hot(sequence, depth=4)
|
| 76 |
|
| 77 |
return sequence
|
| 78 |
|
|
@@ -103,7 +105,7 @@ def predict_on_target(transcript_seq: str, model: tf.keras.Model):
|
|
| 103 |
target_seq, guide_seq, model_inputs = process_data(transcript_seq)
|
| 104 |
|
| 105 |
# get predictions
|
| 106 |
-
normalized_lfc = model.
|
| 107 |
predictions = pd.DataFrame({'Guide': guide_seq, 'Normalized LFC': tf.squeeze(normalized_lfc).numpy()})
|
| 108 |
predictions = predictions.sort_values('Normalized LFC')
|
| 109 |
|
|
@@ -118,7 +120,6 @@ def find_off_targets(top_guides: pd.DataFrame):
|
|
| 118 |
# one-hot encode guides to form a filter
|
| 119 |
guide_filter = one_hot_encode_sequence(sequence_complement(top_guides['Guide']), add_context_padding=False)
|
| 120 |
guide_filter = tf.transpose(guide_filter, [1, 2, 0])
|
| 121 |
-
guide_filter = tf.cast(guide_filter, tf.float16)
|
| 122 |
|
| 123 |
# loop over transcripts in batches
|
| 124 |
i = 0
|
|
@@ -126,43 +127,48 @@ def find_off_targets(top_guides: pd.DataFrame):
|
|
| 126 |
off_targets = pd.DataFrame()
|
| 127 |
while i < len(reference_transcripts):
|
| 128 |
# select batch
|
| 129 |
-
df_batch = reference_transcripts.iloc[i:min(i +
|
| 130 |
-
i +=
|
| 131 |
|
| 132 |
-
# find
|
| 133 |
transcripts = one_hot_encode_sequence(df_batch['seq'].values.tolist(), add_context_padding=False)
|
| 134 |
-
transcripts = tf.cast(transcripts, guide_filter.dtype)
|
| 135 |
num_mismatches = GUIDE_LEN - tf.nn.conv1d(transcripts, guide_filter, stride=1, padding='SAME')
|
| 136 |
loc_off_targets = tf.where(tf.round(num_mismatches) <= NUM_MISMATCHES).numpy()
|
| 137 |
-
|
| 138 |
-
|
| 139 |
-
|
| 140 |
-
|
| 141 |
-
|
| 142 |
-
|
| 143 |
-
|
| 144 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 145 |
|
| 146 |
# progress update
|
| 147 |
print('\rPercent complete: {:.2f}%'.format(100 * min(i / len(reference_transcripts), 1)), end='')
|
| 148 |
print('')
|
| 149 |
|
| 150 |
-
# trim transcripts to targets
|
| 151 |
-
dict_off_targets = off_targets.to_dict('records')
|
| 152 |
-
for row in dict_off_targets:
|
| 153 |
-
start_location = row['Midpoint'] - (GUIDE_LEN // 2)
|
| 154 |
-
if start_location < CONTEXT_5P:
|
| 155 |
-
row['Target'] = row['Target'][0:GUIDE_LEN + CONTEXT_3P]
|
| 156 |
-
row['Target'] = 'N' * (TARGET_LEN - len(row['Target'])) + row['Target']
|
| 157 |
-
elif start_location + GUIDE_LEN + CONTEXT_3P > len(row['Target']):
|
| 158 |
-
row['Target'] = row['Target'][start_location - CONTEXT_5P:]
|
| 159 |
-
row['Target'] = row['Target'] + 'N' * (TARGET_LEN - len(row['Target']))
|
| 160 |
-
else:
|
| 161 |
-
row['Target'] = row['Target'][start_location - CONTEXT_5P:start_location + GUIDE_LEN + CONTEXT_3P]
|
| 162 |
-
if row['Mismatches'] == 0 and 'N' not in row['Target']:
|
| 163 |
-
assert row['Guide'] == sequence_complement([row['Target'][CONTEXT_5P:TARGET_LEN-CONTEXT_3P]])[0]
|
| 164 |
-
off_targets = pd.DataFrame(dict_off_targets)
|
| 165 |
-
|
| 166 |
return off_targets
|
| 167 |
|
| 168 |
|
|
@@ -175,7 +181,7 @@ def predict_off_target(off_targets: pd.DataFrame, model: tf.keras.Model):
|
|
| 175 |
tf.reshape(one_hot_encode_sequence(off_targets['Target'], add_context_padding=False), [len(off_targets), -1]),
|
| 176 |
tf.reshape(one_hot_encode_sequence(off_targets['Guide'], add_context_padding=True), [len(off_targets), -1]),
|
| 177 |
], axis=-1)
|
| 178 |
-
off_targets['Normalized LFC'] = model.predict(model_inputs, batch_size=
|
| 179 |
|
| 180 |
return off_targets.sort_values('Normalized LFC')
|
| 181 |
|
|
@@ -190,12 +196,17 @@ def tiger_exhibit(transcripts: pd.DataFrame):
|
|
| 190 |
exit()
|
| 191 |
|
| 192 |
# find top guides for each transcript
|
|
|
|
| 193 |
on_target_predictions = pd.DataFrame(columns=['On-target ID', 'Guide', 'Normalized LFC'])
|
| 194 |
-
for index, row in transcripts.iterrows():
|
| 195 |
df = predict_on_target(row['seq'], model=tiger)
|
| 196 |
df['On-target ID'] = index
|
| 197 |
on_target_predictions = pd.concat([on_target_predictions, df.iloc[:NUM_TOP_GUIDES]])
|
| 198 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 199 |
# predict off-target effects for top guides
|
| 200 |
off_targets = find_off_targets(on_target_predictions)
|
| 201 |
off_target_predictions = predict_off_target(off_targets, model=tiger)
|
|
@@ -220,16 +231,26 @@ if __name__ == '__main__':
|
|
| 220 |
df_on_target.to_csv('on_target.csv')
|
| 221 |
df_off_target.to_csv('off_target.csv')
|
| 222 |
|
| 223 |
-
#
|
| 224 |
-
|
| 225 |
-
|
| 226 |
-
|
| 227 |
-
|
| 228 |
-
|
| 229 |
-
|
| 230 |
-
|
| 231 |
-
|
| 232 |
-
|
| 233 |
-
|
| 234 |
-
|
| 235 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 14 |
NUM_TOP_GUIDES = 10
|
| 15 |
NUM_MISMATCHES = 3
|
| 16 |
REFERENCE_TRANSCRIPTS = ('gencode.v19.pc_transcripts.fa.gz', 'gencode.v19.lncRNA_transcripts.fa.gz')
|
| 17 |
+
BATCH_SIZE_COMPUTE = 500
|
| 18 |
+
BATCH_SIZE_SCAN = 20
|
| 19 |
+
BATCH_SIZE_TRANSCRIPTS = 50
|
| 20 |
|
| 21 |
# configure GPUs
|
| 22 |
for gpu in tf.config.list_physical_devices('GPU'):
|
|
|
|
| 44 |
# set index
|
| 45 |
transcripts['id'] = transcripts['id'].apply(lambda s: s.split('|')[0])
|
| 46 |
transcripts.set_index('id', inplace=True)
|
| 47 |
+
assert not transcripts.index.has_duplicates, "duplicate transcript ID's detected"
|
| 48 |
|
| 49 |
return transcripts
|
| 50 |
|
|
|
|
| 74 |
sequence = tf.concat([pad_5p, sequence, pad_3p], axis=1)
|
| 75 |
|
| 76 |
# one-hot encode
|
| 77 |
+
sequence = tf.one_hot(sequence, depth=4, dtype=tf.float16)
|
| 78 |
|
| 79 |
return sequence
|
| 80 |
|
|
|
|
| 105 |
target_seq, guide_seq, model_inputs = process_data(transcript_seq)
|
| 106 |
|
| 107 |
# get predictions
|
| 108 |
+
normalized_lfc = model.predict(model_inputs, batch_size=BATCH_SIZE_COMPUTE, verbose=False)
|
| 109 |
predictions = pd.DataFrame({'Guide': guide_seq, 'Normalized LFC': tf.squeeze(normalized_lfc).numpy()})
|
| 110 |
predictions = predictions.sort_values('Normalized LFC')
|
| 111 |
|
|
|
|
| 120 |
# one-hot encode guides to form a filter
|
| 121 |
guide_filter = one_hot_encode_sequence(sequence_complement(top_guides['Guide']), add_context_padding=False)
|
| 122 |
guide_filter = tf.transpose(guide_filter, [1, 2, 0])
|
|
|
|
| 123 |
|
| 124 |
# loop over transcripts in batches
|
| 125 |
i = 0
|
|
|
|
| 127 |
off_targets = pd.DataFrame()
|
| 128 |
while i < len(reference_transcripts):
|
| 129 |
# select batch
|
| 130 |
+
df_batch = reference_transcripts.iloc[i:min(i + BATCH_SIZE_SCAN, len(reference_transcripts))]
|
| 131 |
+
i += BATCH_SIZE_SCAN
|
| 132 |
|
| 133 |
+
# find locations of off-targets
|
| 134 |
transcripts = one_hot_encode_sequence(df_batch['seq'].values.tolist(), add_context_padding=False)
|
|
|
|
| 135 |
num_mismatches = GUIDE_LEN - tf.nn.conv1d(transcripts, guide_filter, stride=1, padding='SAME')
|
| 136 |
loc_off_targets = tf.where(tf.round(num_mismatches) <= NUM_MISMATCHES).numpy()
|
| 137 |
+
|
| 138 |
+
# off-targets discovered
|
| 139 |
+
if len(loc_off_targets) > 0:
|
| 140 |
+
|
| 141 |
+
# log off-targets
|
| 142 |
+
dict_off_targets = pd.DataFrame({
|
| 143 |
+
'On-target ID': top_guides.iloc[loc_off_targets[:, 2]]['On-target ID'],
|
| 144 |
+
'Guide': top_guides.iloc[loc_off_targets[:, 2]]['Guide'],
|
| 145 |
+
'Off-target ID': df_batch.index.values[loc_off_targets[:, 0]],
|
| 146 |
+
'Target': df_batch['seq'].values[loc_off_targets[:, 0]],
|
| 147 |
+
'Mismatches': tf.gather_nd(num_mismatches, loc_off_targets).numpy().astype(int),
|
| 148 |
+
'Midpoint': loc_off_targets[:, 1],
|
| 149 |
+
}).to_dict('records')
|
| 150 |
+
|
| 151 |
+
# trim transcripts to targets
|
| 152 |
+
for row in dict_off_targets:
|
| 153 |
+
start_location = row['Midpoint'] - (GUIDE_LEN // 2)
|
| 154 |
+
if start_location < CONTEXT_5P:
|
| 155 |
+
row['Target'] = row['Target'][0:GUIDE_LEN + CONTEXT_3P]
|
| 156 |
+
row['Target'] = 'N' * (TARGET_LEN - len(row['Target'])) + row['Target']
|
| 157 |
+
elif start_location + GUIDE_LEN + CONTEXT_3P > len(row['Target']):
|
| 158 |
+
row['Target'] = row['Target'][start_location - CONTEXT_5P:]
|
| 159 |
+
row['Target'] = row['Target'] + 'N' * (TARGET_LEN - len(row['Target']))
|
| 160 |
+
else:
|
| 161 |
+
row['Target'] = row['Target'][start_location - CONTEXT_5P:start_location + GUIDE_LEN + CONTEXT_3P]
|
| 162 |
+
if row['Mismatches'] == 0 and 'N' not in row['Target']:
|
| 163 |
+
assert row['Guide'] == sequence_complement([row['Target'][CONTEXT_5P:TARGET_LEN - CONTEXT_3P]])[0]
|
| 164 |
+
|
| 165 |
+
# append new off-targets
|
| 166 |
+
off_targets = pd.concat([off_targets, pd.DataFrame(dict_off_targets)])
|
| 167 |
|
| 168 |
# progress update
|
| 169 |
print('\rPercent complete: {:.2f}%'.format(100 * min(i / len(reference_transcripts), 1)), end='')
|
| 170 |
print('')
|
| 171 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 172 |
return off_targets
|
| 173 |
|
| 174 |
|
|
|
|
| 181 |
tf.reshape(one_hot_encode_sequence(off_targets['Target'], add_context_padding=False), [len(off_targets), -1]),
|
| 182 |
tf.reshape(one_hot_encode_sequence(off_targets['Guide'], add_context_padding=True), [len(off_targets), -1]),
|
| 183 |
], axis=-1)
|
| 184 |
+
off_targets['Normalized LFC'] = model.predict(model_inputs, batch_size=BATCH_SIZE_COMPUTE, verbose=False)
|
| 185 |
|
| 186 |
return off_targets.sort_values('Normalized LFC')
|
| 187 |
|
|
|
|
| 196 |
exit()
|
| 197 |
|
| 198 |
# find top guides for each transcript
|
| 199 |
+
print('Finding top guides for each transcript')
|
| 200 |
on_target_predictions = pd.DataFrame(columns=['On-target ID', 'Guide', 'Normalized LFC'])
|
| 201 |
+
for i, (index, row) in enumerate(transcripts.iterrows()):
|
| 202 |
df = predict_on_target(row['seq'], model=tiger)
|
| 203 |
df['On-target ID'] = index
|
| 204 |
on_target_predictions = pd.concat([on_target_predictions, df.iloc[:NUM_TOP_GUIDES]])
|
| 205 |
|
| 206 |
+
# progress update
|
| 207 |
+
print('\rPercent complete: {:.2f}%'.format(100 * min((i + 1) / len(transcripts), 1)), end='')
|
| 208 |
+
print('')
|
| 209 |
+
|
| 210 |
# predict off-target effects for top guides
|
| 211 |
off_targets = find_off_targets(on_target_predictions)
|
| 212 |
off_target_predictions = predict_off_target(off_targets, model=tiger)
|
|
|
|
| 231 |
df_on_target.to_csv('on_target.csv')
|
| 232 |
df_off_target.to_csv('off_target.csv')
|
| 233 |
|
| 234 |
+
# directory of fasta files
|
| 235 |
+
elif args.fasta_path is not None and os.path.exists(args.fasta_path):
|
| 236 |
+
|
| 237 |
+
# load transcripts
|
| 238 |
+
df_transcripts = load_transcripts([os.path.join(args.fasta_path, f) for f in os.listdir(args.fasta_path)])
|
| 239 |
+
|
| 240 |
+
# process in batches
|
| 241 |
+
df_on_target = pd.DataFrame()
|
| 242 |
+
df_off_target = pd.DataFrame()
|
| 243 |
+
batch = 1
|
| 244 |
+
num_batches = len(df_transcripts) // BATCH_SIZE_TRANSCRIPTS
|
| 245 |
+
num_batches += (len(df_transcripts) % BATCH_SIZE_TRANSCRIPTS > 0)
|
| 246 |
+
for t in range(0, len(df_transcripts), BATCH_SIZE_TRANSCRIPTS):
|
| 247 |
+
print('Batch {:d} of {:d}'.format(batch, num_batches))
|
| 248 |
+
t_stop = min(t + BATCH_SIZE_TRANSCRIPTS, len(df_transcripts))
|
| 249 |
+
df_on_target_new, df_off_target_new = tiger_exhibit(df_transcripts[t:t_stop])
|
| 250 |
+
df_on_target = pd.concat([df_on_target, df_on_target_new])
|
| 251 |
+
df_off_target = pd.concat([df_off_target, df_off_target_new])
|
| 252 |
+
batch += 1
|
| 253 |
+
|
| 254 |
+
# save results
|
| 255 |
+
df_on_target.to_csv('on_target.csv')
|
| 256 |
+
df_off_target.to_csv('off_target.csv')
|