Ym420 commited on
Commit
173c2d0
·
verified ·
1 Parent(s): ec4cd41

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +63 -77
app.py CHANGED
@@ -2,72 +2,80 @@ import gradio as gr
2
  import joblib
3
  from huggingface_hub import hf_hub_download
4
  import numpy as np
5
- import pandas as pd # Needed for DataFrame input to model
6
 
7
- # --- Download ensemble model from HF repo (single ensemble) ---
 
 
 
 
 
 
 
 
 
 
8
  repo_id = "Ym420/terminator-ensemble-classification"
9
  ensemble_path = hf_hub_download(repo_id=repo_id, filename="ensemble.pkl")
10
- ensemble = joblib.load(ensemble_path) # Load exactly as in Colab
11
 
12
  # --- Bendability dictionary ---
13
  bend_dict = {
14
- "AAA": -0.274, "AAC": -0.205, "AAG": -0.081, "AAT": -0.280,
15
- "ACA": -0.006, "ACC": -0.032, "ACG": -0.033, "ACT": -0.183,
16
- "AGA": 0.027, "AGC": 0.017, "AGG": -0.057, "AGT": -0.183,
17
- "ATA": 0.182, "ATC": -0.110, "ATG": 0.134, "ATT": -0.280,
18
- "CAA": 0.015, "CAC": 0.040, "CAG": 0.175, "CAT": 0.134,
19
- "CCA": -0.246, "CCC": -0.012, "CCG": -0.136, "CCT": -0.057,
20
- "CGA": -0.003, "CGC": -0.077, "CGG": -0.136, "CGT": -0.033,
21
- "CTA": 0.090, "CTC": 0.031, "CTG": 0.175, "CTT": -0.081,
22
- "GAA": -0.037, "GAC": -0.013, "GAG": 0.031, "GAT": -0.110,
23
- "GCA": 0.076, "GCC": 0.107, "GCG": -0.077, "GCT": 0.017,
24
- "GGA": 0.013, "GGC": 0.107, "GGG": -0.012, "GGT": -0.032,
25
- "GTA": 0.025, "GTC": -0.013, "GTG": 0.040, "GTT": -0.205,
26
- "TAA": 0.068, "TAC": 0.025, "TAG": 0.090, "TAT": 0.182,
27
- "TCA": 0.194, "TCC": 0.013, "TCG": -0.003, "TCT": 0.027,
28
- "TGA": 0.194, "TGC": 0.076, "TGG": -0.246, "TGT": -0.006,
29
- "TTA": 0.068, "TTC": -0.037, "TTG": 0.015, "TTT": -0.274
30
  }
31
 
32
- # --- Feature functions (match training exactly) ---
33
  def gc_content(seq):
34
  seq = seq.upper()
35
  return (seq.count("G") + seq.count("C")) / len(seq) if len(seq) > 0 else 0
36
 
37
  def cpg_ratio(seq):
38
  seq = seq.upper()
39
- length = len(seq)
40
- if length == 0: return 0
41
  g = seq.count("G")
42
  c = seq.count("C")
43
  cg = seq.count("CG")
44
- expected = (g * c) / length
45
  return cg / expected if expected > 0 else 0
46
 
47
  def deltaG_stem_loop(seq):
48
  seq = seq.upper()
49
- rna = seq.replace("T", "U")
50
- nn = {
51
- "AA": -0.9, "AU": -1.1, "UA": -1.3, "CA": -0.9,
52
- "CU": -2.1, "GA": -1.3, "GU": -1.1, "UU": -0.9,
53
- "AC": -0.9, "AG": -1.3, "UG": -1.5, "UC": -1.5,
54
- "CC": -1.7, "CG": -2.4, "GC": -3.4, "GG": -1.5
55
- }
56
  def rc(s):
57
- comp = str.maketrans("ATCG", "TAGC")
58
  return s.translate(comp)[::-1]
59
  deltaG = 0.0
60
  for i in range(len(seq)):
61
- for j in range(i + 4, len(seq)):
62
  left = rna[i:j]
63
  right = rna[j:]
64
- left_rc = rc(left).replace("T", "U")
65
  if left_rc in right:
66
  total = 0.0
67
  for k in range(len(left)-1):
68
  pair = left[k:k+2]
69
  if pair in nn: total += nn[pair]
70
- if total < deltaG or deltaG == 0.0: deltaG = total
71
  return deltaG
72
 
73
  def avg_bendability(seq):
@@ -80,83 +88,61 @@ def avg_bendability(seq):
80
 
81
  def nucleotide_frequencies(seq):
82
  seq = seq.upper()
83
- length = len(seq)
84
- if length == 0: return 0,0,0,0
85
- return seq.count("A")/length, seq.count("T")/length, seq.count("G")/length, seq.count("C")/length
86
 
87
  def purine_pyrimidine_ratio(seq):
88
  seq = seq.upper()
89
  pur = seq.count("A")+seq.count("G")
90
  pyr = seq.count("C")+seq.count("T")
91
- return pur/pyr if pyr > 0 else 0
92
 
93
- # --- Feature extraction ---
94
  def extract_features(seq):
95
  gc = gc_content(seq)
96
  cpg = cpg_ratio(seq)
97
  dg = deltaG_stem_loop(seq)
98
  bend = avg_bendability(seq)
99
- freq_a, freq_t, freq_g, freq_c = nucleotide_frequencies(seq)
100
  pur_pyr = purine_pyrimidine_ratio(seq)
101
- # Use SAME order as training
102
- return [gc, cpg, dg, bend, freq_a, freq_t, freq_g, freq_c, pur_pyr]
103
 
104
  # --- Prediction functions ---
105
  def predict_terminator(sequence: str) -> tuple[str, float]:
106
  clean_seq = "".join(sequence.split()).upper()
107
  X_new_df = pd.DataFrame([extract_features(clean_seq)], columns=[
108
- "gc_content",
109
- "cpg_ratio",
110
- "deltaG",
111
- "bendability",
112
- "freq_A",
113
- "freq_T",
114
- "freq_G",
115
- "freq_C",
116
- "purine_pyrimidine_ratio"
117
  ])
118
- y_pred_proba = ensemble.predict_proba(X_new_df)[0] # ✅ Single ensemble
119
- label = "Terminator" if y_pred_proba >= 0.5 else "Non-terminator"
120
- confidence = round(float(y_pred_proba), 4)
121
  return label, confidence
122
 
123
  def predict_terminator_table(sequence: str):
124
- label, confidence = predict_terminator(sequence)
125
- non_terminator_conf = round(1.0 - confidence, 4)
126
- return [
127
- ["Terminator", confidence],
128
- ["Non-terminator", non_terminator_conf]
129
- ]
130
 
131
  # --- Gradio UI ---
132
- custom_css = """
133
- footer, .footer {
134
- display: none !important;
135
- }
136
- """
137
-
138
  with gr.Blocks(css=custom_css, theme="default") as demo:
139
  gr.Markdown("## Intrinsic Terminator Prediction\nEnter a DNA sequence to predict terminator probability.")
140
-
141
  seq = gr.Textbox(label="Enter DNA sequence")
142
-
143
  with gr.Row():
144
  predict_btn = gr.Button("Predict", variant="primary", elem_id="predict-btn")
145
  clear_btn = gr.Button("Clear", elem_id="clear-btn")
146
-
147
  gr.HTML("""
148
  <style>
149
- #predict-btn { width: 48%; min-width: 120px; }
150
- #clear-btn { width: 48%; min-width: 100px; }
151
  </style>
152
- """)
153
-
154
  table = gr.Dataframe(headers=["Class","Confidence"], datatype=["str","number"], interactive=False)
155
-
156
  predict_btn.click(fn=predict_terminator_table, inputs=seq, outputs=table)
157
- clear_btn.click(fn=lambda: ("", []), outputs=[seq, table])
158
-
159
  gr.api(predict_terminator, api_name="predict_terminator")
160
 
161
- if __name__ == "__main__":
162
  demo.launch()
 
 
2
  import joblib
3
  from huggingface_hub import hf_hub_download
4
  import numpy as np
5
+ import pandas as pd # For DataFrame input to ensemble model
6
 
7
+ # --- Define EnsembleModel class (same as Colab) ---
8
+ class EnsembleModel:
9
+ def __init__(self, models):
10
+ self.models = models
11
+
12
+ def predict_proba(self, X):
13
+ # Average probabilities from all models
14
+ probs = [m.predict_proba(X)[:, 1] for m in self.models]
15
+ return np.mean(probs, axis=0)
16
+
17
+ # --- Download ensemble from HF repo ---
18
  repo_id = "Ym420/terminator-ensemble-classification"
19
  ensemble_path = hf_hub_download(repo_id=repo_id, filename="ensemble.pkl")
20
+ ensemble = joblib.load(ensemble_path) # Load Colab ensemble
21
 
22
  # --- Bendability dictionary ---
23
  bend_dict = {
24
+ "AAA": -0.274,"AAC": -0.205,"AAG": -0.081,"AAT": -0.280,
25
+ "ACA": -0.006,"ACC": -0.032,"ACG": -0.033,"ACT": -0.183,
26
+ "AGA": 0.027,"AGC": 0.017,"AGG": -0.057,"AGT": -0.183,
27
+ "ATA": 0.182,"ATC": -0.110,"ATG": 0.134,"ATT": -0.280,
28
+ "CAA": 0.015,"CAC": 0.040,"CAG": 0.175,"CAT": 0.134,
29
+ "CCA": -0.246,"CCC": -0.012,"CCG": -0.136,"CCT": -0.057,
30
+ "CGA": -0.003,"CGC": -0.077,"CGG": -0.136,"CGT": -0.033,
31
+ "CTA": 0.090,"CTC": 0.031,"CTG": 0.175,"CTT": -0.081,
32
+ "GAA": -0.037,"GAC": -0.013,"GAG": 0.031,"GAT": -0.110,
33
+ "GCA": 0.076,"GCC": 0.107,"GCG": -0.077,"GCT": 0.017,
34
+ "GGA": 0.013,"GGC": 0.107,"GGG": -0.012,"GGT": -0.032,
35
+ "GTA": 0.025,"GTC": -0.013,"GTG": 0.040,"GTT": -0.205,
36
+ "TAA": 0.068,"TAC": 0.025,"TAG": 0.090,"TAT": 0.182,
37
+ "TCA": 0.194,"TCC": 0.013,"TCG": -0.003,"TCT": 0.027,
38
+ "TGA": 0.194,"TGC": 0.076,"TGG": -0.246,"TGT": -0.006,
39
+ "TTA": 0.068,"TTC": -0.037,"TTG": 0.015,"TTT": -0.274
40
  }
41
 
42
+ # --- Feature functions (same as Colab) ---
43
  def gc_content(seq):
44
  seq = seq.upper()
45
  return (seq.count("G") + seq.count("C")) / len(seq) if len(seq) > 0 else 0
46
 
47
  def cpg_ratio(seq):
48
  seq = seq.upper()
49
+ l = len(seq)
50
+ if l == 0: return 0
51
  g = seq.count("G")
52
  c = seq.count("C")
53
  cg = seq.count("CG")
54
+ expected = (g * c) / l
55
  return cg / expected if expected > 0 else 0
56
 
57
  def deltaG_stem_loop(seq):
58
  seq = seq.upper()
59
+ rna = seq.replace("T","U")
60
+ nn = {"AA": -0.9,"AU": -1.1,"UA": -1.3,"CA": -0.9,
61
+ "CU": -2.1,"GA": -1.3,"GU": -1.1,"UU": -0.9,
62
+ "AC": -0.9,"AG": -1.3,"UG": -1.5,"UC": -1.5,
63
+ "CC": -1.7,"CG": -2.4,"GC": -3.4,"GG": -1.5}
 
 
64
  def rc(s):
65
+ comp = str.maketrans("ATCG","TAGC")
66
  return s.translate(comp)[::-1]
67
  deltaG = 0.0
68
  for i in range(len(seq)):
69
+ for j in range(i+4,len(seq)):
70
  left = rna[i:j]
71
  right = rna[j:]
72
+ left_rc = rc(left).replace("T","U")
73
  if left_rc in right:
74
  total = 0.0
75
  for k in range(len(left)-1):
76
  pair = left[k:k+2]
77
  if pair in nn: total += nn[pair]
78
+ if total < deltaG or deltaG==0.0: deltaG = total
79
  return deltaG
80
 
81
  def avg_bendability(seq):
 
88
 
89
  def nucleotide_frequencies(seq):
90
  seq = seq.upper()
91
+ l = len(seq)
92
+ if l == 0: return 0,0,0,0
93
+ return seq.count("A")/l, seq.count("T")/l, seq.count("G")/l, seq.count("C")/l
94
 
95
  def purine_pyrimidine_ratio(seq):
96
  seq = seq.upper()
97
  pur = seq.count("A")+seq.count("G")
98
  pyr = seq.count("C")+seq.count("T")
99
+ return pur/pyr if pyr>0 else 0
100
 
101
+ # --- Extract features ---
102
  def extract_features(seq):
103
  gc = gc_content(seq)
104
  cpg = cpg_ratio(seq)
105
  dg = deltaG_stem_loop(seq)
106
  bend = avg_bendability(seq)
107
+ freq_a,freq_t,freq_g,freq_c = nucleotide_frequencies(seq)
108
  pur_pyr = purine_pyrimidine_ratio(seq)
109
+ return [gc, cpg, dg, bend, freq_a,freq_t,freq_g,freq_c, pur_pyr]
 
110
 
111
  # --- Prediction functions ---
112
  def predict_terminator(sequence: str) -> tuple[str, float]:
113
  clean_seq = "".join(sequence.split()).upper()
114
  X_new_df = pd.DataFrame([extract_features(clean_seq)], columns=[
115
+ "gc_content", "cpg_ratio", "deltaG", "bendability",
116
+ "freq_A","freq_T","freq_G","freq_C","purine_pyrimidine_ratio"
 
 
 
 
 
 
 
117
  ])
118
+ y_pred_proba = ensemble.predict_proba(X_new_df)[0]
119
+ label = "Terminator" if y_pred_proba>=0.5 else "Non-terminator"
120
+ confidence = round(float(y_pred_proba),4)
121
  return label, confidence
122
 
123
  def predict_terminator_table(sequence: str):
124
+ label, conf = predict_terminator(sequence)
125
+ return [["Terminator", conf], ["Non-terminator", round(1-conf,4)]]
 
 
 
 
126
 
127
  # --- Gradio UI ---
128
+ custom_css = "footer, .footer {display:none !important;}"
 
 
 
 
 
129
  with gr.Blocks(css=custom_css, theme="default") as demo:
130
  gr.Markdown("## Intrinsic Terminator Prediction\nEnter a DNA sequence to predict terminator probability.")
 
131
  seq = gr.Textbox(label="Enter DNA sequence")
 
132
  with gr.Row():
133
  predict_btn = gr.Button("Predict", variant="primary", elem_id="predict-btn")
134
  clear_btn = gr.Button("Clear", elem_id="clear-btn")
 
135
  gr.HTML("""
136
  <style>
137
+ #predict-btn { width:48%; min-width:120px; }
138
+ #clear-btn { width:48%; min-width:100px; }
139
  </style>
140
+ """)
 
141
  table = gr.Dataframe(headers=["Class","Confidence"], datatype=["str","number"], interactive=False)
 
142
  predict_btn.click(fn=predict_terminator_table, inputs=seq, outputs=table)
143
+ clear_btn.click(fn=lambda: ("",[]), outputs=[seq, table])
 
144
  gr.api(predict_terminator, api_name="predict_terminator")
145
 
146
+ if __name__=="__main__":
147
  demo.launch()
148
+