astirn commited on
Commit
6009ef4
·
2 Parent(s): 8c452e8 9f169cd

Merge branch 'main' of https://huggingface.co/spaces/Knowles-Lab/tiger into main

Browse files
Files changed (2) hide show
  1. app.py +62 -13
  2. tiger.py +10 -5
app.py CHANGED
@@ -1,6 +1,7 @@
1
  import pandas as pd
2
  import streamlit as st
3
- from tiger import tiger_exhibit, TARGET_LEN, NUCLEOTIDE_TOKENS
 
4
 
5
 
6
  @st.cache
@@ -8,23 +9,71 @@ def convert_df(df):
8
  # IMPORTANT: Cache the conversion to prevent computation on every rerun
9
  return df.to_csv().encode('utf-8')
10
 
11
-
12
  # title and instructions
13
  st.title('TIGER Cas13 Efficacy Prediction')
14
- st.session_state['userInput'] = ''
15
- st.session_state['userInput'] = st.text_input(
 
 
 
 
 
16
  label='Enter a target transcript:',
17
- # value='ATGCAGGACGCGGAGAACGTGGCGGTGCCCGAGGCGGCCGAGGAGCGCGC',
18
  placeholder='Upper or lower case')
 
 
 
 
 
 
 
 
 
 
 
 
19
 
20
- # input is too short
21
- if len(st.session_state['userInput']) < TARGET_LEN:
22
- transcript_len = len(st.session_state['userInput'])
23
- st.write('Transcript length ({:d}) must be at least {:d} bases.'.format(transcript_len, TARGET_LEN))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
24
 
25
  # valid input
26
- elif all([True if nt.upper() in NUCLEOTIDE_TOKENS.keys() else False for nt in st.session_state['userInput']]):
27
- on_target, off_target = tiger_exhibit(pd.DataFrame(dict(id=['ManualEntry'], seq=[st.session_state['userInput']])))
 
28
  st.write('On-target predictions: ', on_target)
29
  st.download_button(label='Download', data=convert_df(on_target), file_name='on_target.csv', mime='text/csv')
30
  if len(off_target) > 0:
@@ -34,5 +83,5 @@ elif all([True if nt.upper() in NUCLEOTIDE_TOKENS.keys() else False for nt in st
34
  st.write('We did not find any off-target effects!')
35
 
36
  # invalid input
37
- else:
38
- st.write('Nucleotides other than ACGT detected!')
 
1
  import pandas as pd
2
  import streamlit as st
3
+ import os, shutil
4
+ from tiger import tiger_exhibit, load_transcripts, TARGET_LEN, NUCLEOTIDE_TOKENS
5
 
6
 
7
  @st.cache
 
9
  # IMPORTANT: Cache the conversion to prevent computation on every rerun
10
  return df.to_csv().encode('utf-8')
11
 
 
12
  # title and instructions
13
  st.title('TIGER Cas13 Efficacy Prediction')
14
+ st.session_state["fasta_seq"] = ""
15
+ st.session_state["text_seq"] = ""
16
+ status_bar, status_text = None, None
17
+
18
+ # UserInput Form from text input
19
+ text_form = st.form("text")
20
+ text_input = text_form.text_input(
21
  label='Enter a target transcript:',
22
+ #value='ATGCAGGACGCGGAGAACGTGGCGGTGCCCGAGGCGGCCGAGGAGCGCGC',
23
  placeholder='Upper or lower case')
24
+ if text_input:
25
+ # input is too short
26
+ if len(text_input) < TARGET_LEN:
27
+ transcript_len = len(text_input)
28
+ text_form.write('Transcript length ({:d}) must be at least {:d} bases.'.format(transcript_len, TARGET_LEN))
29
+ else:
30
+ st.session_state["text_seq"] = text_input
31
+ text_calc = text_form.form_submit_button(label="calculate")
32
+ #status bar
33
+ status_text_textform = text_form.empty()
34
+ status_bar_textform = text_form.progress(0)
35
+
36
 
37
+ # UserInput Form from file
38
+ fasta_form = st.form("fasta")
39
+ fasta = fasta_form.file_uploader(label="upload fasta file")
40
+ if fasta:
41
+ if os.path.exists("temp"):
42
+ shutil.rmtree("temp")
43
+ os.makedirs("temp")
44
+ fname = fasta.name
45
+ st.write(fname)
46
+ fpath = os.path.join("temp", fname)
47
+ with open(fpath, "w") as f:
48
+ f.write(fasta.getvalue().decode("utf-8"))
49
+ transcript_tbl = load_transcripts([fpath])
50
+ fasta_form.text("fasta file contents")
51
+ fasta_form.write(transcript_tbl)
52
+ seq = transcript_tbl['seq'][0]
53
+ st.session_state["fasta_seq"] = seq
54
+ fasta_calc = fasta_form.form_submit_button(label="calculate")
55
+ status_text_fastaform = fasta_form.empty()
56
+ status_bar_fastaform = fasta_form.progress(0)
57
+ #st.write(text_calc)
58
+ #st.write(fasta_calc)
59
+
60
+ #Calculation
61
+ if text_calc:
62
+ src_seq = st.session_state["text_seq"]
63
+ status_text = status_text_textform
64
+ status_bar= status_bar_textform
65
+ elif fasta_calc:
66
+ src_seq = st.session_state["fasta_seq"]
67
+ status_text = status_text_fastaform
68
+ status_bar= status_bar_fastaform
69
+ else:
70
+ src_seq = ""
71
+ #st.write(src_seq)
72
 
73
  # valid input
74
+ if src_seq and all([True if nt.upper() in NUCLEOTIDE_TOKENS.keys() else False for nt in src_seq]):
75
+ on_target, off_target = tiger_exhibit(pd.DataFrame(dict(id=['ManualEntry'], seq=[src_seq])),
76
+ status_bar, status_text)
77
  st.write('On-target predictions: ', on_target)
78
  st.download_button(label='Download', data=convert_df(on_target), file_name='on_target.csv', mime='text/csv')
79
  if len(off_target) > 0:
 
83
  st.write('We did not find any off-target effects!')
84
 
85
  # invalid input
86
+ #else:
87
+ # st.write('Nucleotides other than ACGT detected!')
tiger.py CHANGED
@@ -24,7 +24,6 @@ for gpu in tf.config.list_physical_devices('GPU'):
24
  if len(tf.config.list_physical_devices('GPU')) > 0:
25
  tf.config.experimental.set_visible_devices(tf.config.list_physical_devices('GPU')[0], 'GPU')
26
 
27
-
28
  def load_transcripts(fasta_files):
29
 
30
  # load all transcripts from fasta files into a DataFrame
@@ -95,7 +94,7 @@ def process_data(transcript_seq: str):
95
  tf.reshape(one_hot_encode_sequence(target_seq, add_context_padding=False), [len(target_seq), -1]),
96
  tf.reshape(one_hot_encode_sequence(guide_seq, add_context_padding=True), [len(guide_seq), -1]),
97
  ], axis=-1)
98
-
99
  return target_seq, guide_seq, model_inputs
100
 
101
 
@@ -112,7 +111,7 @@ def predict_on_target(transcript_seq: str, model: tf.keras.Model):
112
  return predictions
113
 
114
 
115
- def find_off_targets(top_guides: pd.DataFrame):
116
 
117
  # load reference transcripts
118
  reference_transcripts = load_transcripts([os.path.join('transcripts', f) for f in REFERENCE_TRANSCRIPTS])
@@ -166,6 +165,9 @@ def find_off_targets(top_guides: pd.DataFrame):
166
  off_targets = pd.concat([off_targets, pd.DataFrame(dict_off_targets)])
167
 
168
  # progress update
 
 
 
169
  print('\rPercent complete: {:.2f}%'.format(100 * min(i / len(reference_transcripts), 1)), end='')
170
  print('')
171
 
@@ -186,7 +188,7 @@ def predict_off_target(off_targets: pd.DataFrame, model: tf.keras.Model):
186
  return off_targets.sort_values('Normalized LFC')
187
 
188
 
189
- def tiger_exhibit(transcripts: pd.DataFrame):
190
 
191
  # load model
192
  if os.path.exists('model'):
@@ -204,11 +206,14 @@ def tiger_exhibit(transcripts: pd.DataFrame):
204
  on_target_predictions = pd.concat([on_target_predictions, df.iloc[:NUM_TOP_GUIDES]])
205
 
206
  # progress update
 
 
 
207
  print('\rPercent complete: {:.2f}%'.format(100 * min((i + 1) / len(transcripts), 1)), end='')
208
  print('')
209
 
210
  # predict off-target effects for top guides
211
- off_targets = find_off_targets(on_target_predictions)
212
  off_target_predictions = predict_off_target(off_targets, model=tiger)
213
 
214
  # reverse guide sequences
 
24
  if len(tf.config.list_physical_devices('GPU')) > 0:
25
  tf.config.experimental.set_visible_devices(tf.config.list_physical_devices('GPU')[0], 'GPU')
26
 
 
27
  def load_transcripts(fasta_files):
28
 
29
  # load all transcripts from fasta files into a DataFrame
 
94
  tf.reshape(one_hot_encode_sequence(target_seq, add_context_padding=False), [len(target_seq), -1]),
95
  tf.reshape(one_hot_encode_sequence(guide_seq, add_context_padding=True), [len(guide_seq), -1]),
96
  ], axis=-1)
97
+ print(model_inputs)
98
  return target_seq, guide_seq, model_inputs
99
 
100
 
 
111
  return predictions
112
 
113
 
114
+ def find_off_targets(top_guides: pd.DataFrame, status_bar, status_text):
115
 
116
  # load reference transcripts
117
  reference_transcripts = load_transcripts([os.path.join('transcripts', f) for f in REFERENCE_TRANSCRIPTS])
 
165
  off_targets = pd.concat([off_targets, pd.DataFrame(dict_off_targets)])
166
 
167
  # progress update
168
+ if status_bar:
169
+ status_text.text("Scanning for off-targets Percent complete: {:.2f}%".format(int(100 * min(i / len(reference_transcripts), 1))))
170
+ status_bar.progress(int(100 * min(i / len(reference_transcripts), 1)))
171
  print('\rPercent complete: {:.2f}%'.format(100 * min(i / len(reference_transcripts), 1)), end='')
172
  print('')
173
 
 
188
  return off_targets.sort_values('Normalized LFC')
189
 
190
 
191
+ def tiger_exhibit(transcripts: pd.DataFrame, status_bar=None, status_text=None):
192
 
193
  # load model
194
  if os.path.exists('model'):
 
206
  on_target_predictions = pd.concat([on_target_predictions, df.iloc[:NUM_TOP_GUIDES]])
207
 
208
  # progress update
209
+ if status_bar:
210
+ status_text.text("Scanning for on-targets Percent complete: {:.2f}%".format(100 * min((i + 1) / len(transcripts), 1)))
211
+ status_bar.progress(int(100 * min((i + 1) / len(transcripts), 1)))
212
  print('\rPercent complete: {:.2f}%'.format(100 * min((i + 1) / len(transcripts), 1)), end='')
213
  print('')
214
 
215
  # predict off-target effects for top guides
216
+ off_targets = find_off_targets(on_target_predictions, status_bar, status_text)
217
  off_target_predictions = predict_off_target(off_targets, model=tiger)
218
 
219
  # reverse guide sequences