Ym420 committed on
Commit
9c22afe
·
verified ·
1 Parent(s): 8a6500f

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +155 -0
app.py ADDED
@@ -0,0 +1,155 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import joblib
3
+ from huggingface_hub import hf_hub_download
4
+ import numpy as np
5
+ import xgboost
6
+ import pandas as pd
7
+
8
# --- Download model and scaler from HF Hub model repo ---
# Both artifacts are fetched at import time; hf_hub_download returns the local
# path of the downloaded file.
repo_id = "Ym420/terminator-classification"  # public HF model repo

best_model_path = hf_hub_download(repo_id=repo_id, filename="best_model.pkl")
scaler_path = hf_hub_download(repo_id=repo_id, filename="scaler.pkl")

# SECURITY NOTE(review): joblib.load unpickles the file, and unpickling can
# execute arbitrary code. This is only safe while the repository above is
# trusted; consider pinning a `revision` in hf_hub_download so a later upload
# cannot silently replace these files.
best_model = joblib.load(best_model_path)
scaler = joblib.load(scaler_path)
16
+
17
+ # --- Bendability dictionary ---
18
# Per-trinucleotide bendability score, keyed by upper-case A/C/G/T triplets
# (all 64 combinations). Consumed by avg_bendability().
# NOTE: each triplet carries the same value as its reverse complement
# (e.g. "AAC" and "GTT"); keep that in mind when editing a value.
bend_dict = {
    "AAA": -0.274, "AAC": -0.205, "AAG": -0.081, "AAT": -0.280,
    "ACA": -0.006, "ACC": -0.032, "ACG": -0.033, "ACT": -0.183,
    "AGA": 0.027, "AGC": 0.017, "AGG": -0.057, "AGT": -0.183,
    "ATA": 0.182, "ATC": -0.110, "ATG": 0.134, "ATT": -0.280,
    "CAA": 0.015, "CAC": 0.040, "CAG": 0.175, "CAT": 0.134,
    "CCA": -0.246, "CCC": -0.012, "CCG": -0.136, "CCT": -0.057,
    "CGA": -0.003, "CGC": -0.077, "CGG": -0.136, "CGT": -0.033,
    "CTA": 0.090, "CTC": 0.031, "CTG": 0.175, "CTT": -0.081,
    "GAA": -0.037, "GAC": -0.013, "GAG": 0.031, "GAT": -0.110,
    "GCA": 0.076, "GCC": 0.107, "GCG": -0.077, "GCT": 0.017,
    "GGA": 0.013, "GGC": 0.107, "GGG": -0.012, "GGT": -0.032,
    "GTA": 0.025, "GTC": -0.013, "GTG": 0.040, "GTT": -0.205,
    "TAA": 0.068, "TAC": 0.025, "TAG": 0.090, "TAT": 0.182,
    "TCA": 0.194, "TCC": 0.013, "TCG": -0.003, "TCT": 0.027,
    "TGA": 0.194, "TGC": 0.076, "TGG": -0.246, "TGT": -0.006,
    "TTA": 0.068, "TTC": -0.037, "TTG": 0.015, "TTT": -0.274,
}
36
+
37
+ # --- Feature extraction functions ---
38
def gc_content(seq: str) -> float:
    """Return the fraction of characters in *seq* that are G or C.

    Matching is case-insensitive. The denominator is the full string length,
    so any non-ACGT characters still count towards the length.

    Args:
        seq: Nucleotide sequence.

    Returns:
        A value in [0, 1]; 0.0 for an empty sequence.
    """
    seq = seq.upper()
    if not seq:
        # Avoid dividing by zero on empty input.
        return 0.0
    return (seq.count("G") + seq.count("C")) / len(seq)
43
+
44
def cpg_ratio(seq: str) -> float:
    """Return the observed/expected ratio of the "CG" dinucleotide in *seq*.

    Computed as ``count("CG") / (count("G") * count("C") / len(seq))`` on the
    upper-cased sequence.

    Args:
        seq: Nucleotide sequence.

    Returns:
        The ratio, or 0.0 when the sequence is empty or contains no G or no C
        (the expected count would be zero).
    """
    seq = seq.upper()
    length = len(seq)
    if length == 0:
        return 0.0

    expected = (seq.count("G") * seq.count("C")) / length
    if expected <= 0:
        return 0.0
    return seq.count("CG") / expected
53
+
54
def tata_box_presence(seq: str) -> int:
    """Return 1 if *seq* contains the substring "TATA" (case-insensitive), else 0."""
    return int("TATA" in seq.upper())
56
+
57
def avg_bendability(seq: str) -> float:
    """Return the mean bendability score over all overlapping trinucleotides.

    Every window of three consecutive characters of the upper-cased sequence
    is looked up in ``bend_dict``; windows that are not keys of the table
    (e.g. ones containing "N") are skipped.

    Args:
        seq: Nucleotide sequence.

    Returns:
        The mean of the scores found, or 0.0 when no window could be scored
        (sequence shorter than three characters, or no recognised triplet).
    """
    seq = seq.upper()
    windows = (seq[i:i + 3] for i in range(len(seq) - 2))
    scores = [bend_dict[tri] for tri in windows if tri in bend_dict]
    # np.mean is kept (rather than sum/len) so the feature value stays
    # numerically identical to what this function has always produced.
    return float(np.mean(scores)) if scores else 0.0
65
+
66
def nucleotide_frequencies(seq: str) -> tuple[float, float, float, float]:
    """Return the relative frequencies of A, T, G and C in *seq*.

    Matching is case-insensitive and each count is divided by the full string
    length, so the four values sum to less than 1 when other characters are
    present.

    Args:
        seq: Nucleotide sequence.

    Returns:
        ``(freq_A, freq_T, freq_G, freq_C)`` — note the A, T, G, C order.
        All zeros for an empty sequence.
    """
    seq = seq.upper()
    length = len(seq)
    if length == 0:
        return 0.0, 0.0, 0.0, 0.0
    return (
        seq.count("A") / length,
        seq.count("T") / length,
        seq.count("G") / length,
        seq.count("C") / length,
    )
77
+
78
def purine_pyrimidine_ratio(seq: str) -> float:
    """Return the ratio of purines (A + G) to pyrimidines (C + T) in *seq*.

    Matching is case-insensitive.

    Args:
        seq: Nucleotide sequence.

    Returns:
        ``(A + G) / (C + T)``, or 0.0 when the sequence has no C or T
        (including the empty sequence), to avoid dividing by zero.
    """
    seq = seq.upper()
    purines = seq.count("A") + seq.count("G")
    pyrimidines = seq.count("C") + seq.count("T")
    if pyrimidines == 0:
        return 0.0
    return purines / pyrimidines
83
+
84
def extract_features(seq: str) -> list[float]:
    """Build the nine-element feature vector for a nucleotide sequence.

    The sequence is upper-cased and passed to each feature helper.

    Args:
        seq: Nucleotide sequence.

    Returns:
        ``[gc, cpg, tata, bend, freq_a, freq_t, freq_g, freq_c, pur_pyr]``.
        NOTE(review): this order presumably has to match the column order the
        scaler and model were fitted with — confirm before reordering.
    """
    seq = seq.upper()
    freq_a, freq_t, freq_g, freq_c = nucleotide_frequencies(seq)
    return [
        gc_content(seq),
        cpg_ratio(seq),
        tata_box_presence(seq),
        avg_bendability(seq),
        freq_a,
        freq_t,
        freq_g,
        freq_c,
        purine_pyrimidine_ratio(seq),
    ]
93
+
94
+ # --- Prediction function ---
95
def predict_terminator(sequence: str) -> tuple[str, float]:
    """Classify a DNA sequence as terminator or non-terminator.

    This function is also registered directly as an API endpoint, so it
    strips all whitespace itself instead of relying on a caller to have done
    so; otherwise a pasted sequence with line breaks would produce different
    length-normalised features than the same sequence without them.

    Args:
        sequence: Nucleotide sequence; case and whitespace are ignored.

    Returns:
        ``(label, confidence)`` where ``label`` is "Terminator" or
        "Non-terminator" and ``confidence`` is the value reported for class
        index 1, rounded to four decimals.
        NOTE(review): assumes class index 1 of ``predict_proba`` is the
        terminator class — confirm against the training code.
    """
    clean_seq = "".join(sequence.split())
    features = [extract_features(clean_seq)]
    scaled = scaler.transform(features)

    is_terminator = best_model.predict(scaled)[0] == 1
    if hasattr(best_model, "predict_proba"):
        terminator_proba = best_model.predict_proba(scaled)[0, 1]
    else:
        # No probabilities available: fall back to the hard decision so the
        # reported score never contradicts the label (previously 0.0 was
        # returned even for a "Terminator" prediction).
        terminator_proba = 1.0 if is_terminator else 0.0

    label = "Terminator" if is_terminator else "Non-terminator"
    return label, round(float(terminator_proba), 4)
104
+
105
def predict_terminator_table(sequence: str) -> list[list]:
    """Return per-class scores for *sequence* as rows for the results table.

    Whitespace is removed and the sequence upper-cased before prediction.

    Args:
        sequence: Raw text from the input box.

    Returns:
        Two ``[class_name, score]`` rows: the "Terminator" score returned by
        ``predict_terminator`` and its complement (``1 - score``, rounded to
        four decimals) for "Non-terminator".
    """
    clean_seq = "".join(sequence.split()).upper()
    # Only the score is shown in the table; the predicted label is not used.
    _, confidence = predict_terminator(clean_seq)
    non_terminator_conf = round(1.0 - confidence, 4)

    return [
        ["Terminator", confidence],
        ["Non-terminator", non_terminator_conf],
    ]
114
+
115
+ # --- Gradio Interface ---
116
# Stylesheet passed to gr.Blocks(css=...); its only rule hides the page footer.
custom_css = """
/* Hide Gradio footer */
footer, .footer {
    display: none !important;
}
"""
122
+
123
# NOTE(review): the nesting below was reconstructed from a flattened diff view
# in which indentation was lost; in particular, confirm whether the gr.HTML
# style block belongs at this level or inside the button row.
with gr.Blocks(css=custom_css) as demo:
    gr.Markdown("## Intrinsic Terminator Prediction\nEnter a DNA sequence to predict terminator probability.")

    seq = gr.Textbox(label="Enter DNA sequence")

    # Action buttons side by side.
    with gr.Row():
        predict_btn = gr.Button("Predict", variant="primary", elem_id="predict-btn")
        clear_btn = gr.Button("Clear", elem_id="clear-btn")

    # Inline stylesheet sizing the two buttons via their element ids.
    gr.HTML(
        """
        <style>
        #predict-btn {
            width: 48%;
            min-width: 120px;
        }
        #clear-btn {
            width: 48%;
            min-width: 100px;
        }
        </style>
        """
    )

    table = gr.Dataframe(headers=["Class", "Confidence"], datatype=["str", "number"], interactive=False)

    # Predict fills the table; Clear resets both the textbox and the table.
    predict_btn.click(fn=predict_terminator_table, inputs=seq, outputs=table)
    clear_btn.click(fn=lambda: ("", []), outputs=[seq, table])

    # Expose the raw (label, confidence) predictor as a named API endpoint.
    gr.api(predict_terminator, api_name="predict_terminator")
153
+
154
# Start the Gradio server only when this file is run as a script.
if __name__ == "__main__":
    demo.launch()