Ym420 commited on
Commit
face98f
·
verified ·
1 Parent(s): 9c22afe

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +40 -49
app.py CHANGED
@@ -2,11 +2,9 @@ import gradio as gr
2
  import joblib
3
  from huggingface_hub import hf_hub_download
4
  import numpy as np
5
- import xgboost
6
- import pandas as pd
7
 
8
- # --- Download model and scaler from HF Hub model repo ---
9
- repo_id = "Ym420/terminator-classification" # public HF model repo
10
 
11
  best_model_path = hf_hub_download(repo_id=repo_id, filename="best_model.pkl")
12
  scaler_path = hf_hub_download(repo_id=repo_id, filename="scaler.pkl")
@@ -34,21 +32,20 @@ bend_dict = {
34
  "TTA": 0.068, "TTC": -0.037, "TTG": 0.015, "TTT": -0.274
35
  }
36
 
37
- # --- Feature extraction functions ---
38
  def gc_content(seq):
39
  seq = seq.upper()
40
- if len(seq) == 0:
41
- return 0
42
- return (seq.count("G") + seq.count("C")) / len(seq)
43
 
44
  def cpg_ratio(seq):
45
  seq = seq.upper()
 
 
 
46
  g = seq.count("G")
47
  c = seq.count("C")
48
  cg = seq.count("CG")
49
- if len(seq) == 0:
50
- return 0
51
- expected = (g * c) / len(seq)
52
  return cg / expected if expected > 0 else 0
53
 
54
  def tata_box_presence(seq):
@@ -61,7 +58,7 @@ def avg_bendability(seq):
61
  tri = seq[i:i+3]
62
  if tri in bend_dict:
63
  scores.append(bend_dict[tri])
64
- return np.mean(scores) if scores else 0
65
 
66
  def nucleotide_frequencies(seq):
67
  seq = seq.upper()
@@ -77,10 +74,11 @@ def nucleotide_frequencies(seq):
77
 
78
  def purine_pyrimidine_ratio(seq):
79
  seq = seq.upper()
80
- purines = seq.count("A") + seq.count("G")
81
- pyrimidines = seq.count("C") + seq.count("T")
82
- return purines / pyrimidines if pyrimidines > 0 else 0
83
 
 
84
  def extract_features(seq):
85
  seq = seq.upper()
86
  gc = gc_content(seq)
@@ -89,35 +87,43 @@ def extract_features(seq):
89
  bend = avg_bendability(seq)
90
  freq_a, freq_t, freq_g, freq_c = nucleotide_frequencies(seq)
91
  pur_pyr = purine_pyrimidine_ratio(seq)
 
 
92
  return [gc, cpg, tata, bend, freq_a, freq_t, freq_g, freq_c, pur_pyr]
93
 
94
- # --- Prediction function ---
95
- def predict_terminator(sequence: str) -> tuple[str, float]:
96
- X_new = [extract_features(sequence)]
97
- X_scaled = scaler.transform(X_new)
 
 
 
 
 
 
 
98
  y_pred = best_model.predict(X_scaled)[0]
99
- y_pred_proba = best_model.predict_proba(X_scaled)[0, 1] if hasattr(best_model, "predict_proba") else 0.0
 
 
 
 
100
 
101
  label = "Terminator" if y_pred == 1 else "Non-terminator"
102
- confidence = round(float(y_pred_proba), 4)
103
- return label, confidence
104
-
105
- def predict_terminator_table(sequence: str):
106
- clean_seq = "".join(sequence.split()).upper()
107
- label, confidence = predict_terminator(clean_seq)
108
- non_terminator_conf = round(1.0 - confidence, 4)
109
 
 
 
 
 
110
  return [
111
  ["Terminator", confidence],
112
- ["Non-terminator", non_terminator_conf]
113
  ]
114
 
115
- # --- Gradio Interface ---
116
  custom_css = """
117
- /* Hide Gradio footer */
118
- footer, .footer {
119
- display: none !important;
120
- }
121
  """
122
 
123
  with gr.Blocks(css=custom_css) as demo:
@@ -129,22 +135,7 @@ with gr.Blocks(css=custom_css) as demo:
129
  predict_btn = gr.Button("Predict", variant="primary", elem_id="predict-btn")
130
  clear_btn = gr.Button("Clear", elem_id="clear-btn")
131
 
132
- gr.HTML(
133
- """
134
- <style>
135
- #predict-btn {
136
- width: 48%;
137
- min-width: 120px;
138
- }
139
- #clear-btn {
140
- width: 48%;
141
- min-width: 100px;
142
- }
143
- </style>
144
- """
145
- )
146
-
147
- table = gr.Dataframe(headers=["Class", "Confidence"], datatype=["str","number"], interactive=False)
148
 
149
  predict_btn.click(fn=predict_terminator_table, inputs=seq, outputs=table)
150
  clear_btn.click(fn=lambda: ("", []), outputs=[seq, table])
 
2
  import joblib
3
  from huggingface_hub import hf_hub_download
4
  import numpy as np
 
 
5
 
6
+ # --- Download model and scaler from your HF repo ---
7
+ repo_id = "Ym420/terminator-classification"
8
 
9
  best_model_path = hf_hub_download(repo_id=repo_id, filename="best_model.pkl")
10
  scaler_path = hf_hub_download(repo_id=repo_id, filename="scaler.pkl")
 
32
  "TTA": 0.068, "TTC": -0.037, "TTG": 0.015, "TTT": -0.274
33
  }
34
 
35
+ # --- Feature functions (match training exactly) ---
36
  def gc_content(seq):
37
  seq = seq.upper()
38
+ return (seq.count("G") + seq.count("C")) / len(seq) if len(seq) > 0 else 0
 
 
39
 
40
  def cpg_ratio(seq):
41
  seq = seq.upper()
42
+ length = len(seq)
43
+ if length == 0:
44
+ return 0
45
  g = seq.count("G")
46
  c = seq.count("C")
47
  cg = seq.count("CG")
48
+ expected = (g * c) / length
 
 
49
  return cg / expected if expected > 0 else 0
50
 
51
  def tata_box_presence(seq):
 
58
  tri = seq[i:i+3]
59
  if tri in bend_dict:
60
  scores.append(bend_dict[tri])
61
+ return float(np.mean(scores)) if scores else 0.0
62
 
63
  def nucleotide_frequencies(seq):
64
  seq = seq.upper()
 
74
 
75
  def purine_pyrimidine_ratio(seq):
76
  seq = seq.upper()
77
+ pur = seq.count("A") + seq.count("G")
78
+ pyr = seq.count("C") + seq.count("T")
79
+ return pur / pyr if pyr > 0 else 0
80
 
81
+ # ✅ Critical — must match training order exactly
82
  def extract_features(seq):
83
  seq = seq.upper()
84
  gc = gc_content(seq)
 
87
  bend = avg_bendability(seq)
88
  freq_a, freq_t, freq_g, freq_c = nucleotide_frequencies(seq)
89
  pur_pyr = purine_pyrimidine_ratio(seq)
90
+
91
+ # SAME order as X_train
92
  return [gc, cpg, tata, bend, freq_a, freq_t, freq_g, freq_c, pur_pyr]
93
 
94
+ # --- Prediction functions ---
95
+ def predict_terminator(sequence: str):
96
+ # clean input
97
+ clean = "".join(sequence.split()).upper()
98
+ clean = "".join([b for b in clean if b in {"A","C","G","T"}])
99
+
100
+ if len(clean) < 10:
101
+ return "Sequence too short", 0.0
102
+
103
+ X_new = np.array([extract_features(clean)]) # shape (1,9)
104
+ X_scaled = scaler.transform(X_new) # apply exact training scaler
105
  y_pred = best_model.predict(X_scaled)[0]
106
+
107
+ if hasattr(best_model, "predict_proba"):
108
+ proba = float(best_model.predict_proba(X_scaled)[0][1])
109
+ else:
110
+ proba = float(y_pred)
111
 
112
  label = "Terminator" if y_pred == 1 else "Non-terminator"
113
+ return label, round(proba, 4)
 
 
 
 
 
 
114
 
115
+ def predict_terminator_table(sequence: str):
116
+ label, confidence = predict_terminator(sequence)
117
+ if label == "Sequence too short":
118
+ return [["Error", 0.0]]
119
  return [
120
  ["Terminator", confidence],
121
+ ["Non-terminator", round(1-confidence, 4)]
122
  ]
123
 
124
+ # --- Gradio UI (no changes needed) ---
125
  custom_css = """
126
+ footer, .footer { display: none !important; }
 
 
 
127
  """
128
 
129
  with gr.Blocks(css=custom_css) as demo:
 
135
  predict_btn = gr.Button("Predict", variant="primary", elem_id="predict-btn")
136
  clear_btn = gr.Button("Clear", elem_id="clear-btn")
137
 
138
+ table = gr.Dataframe(headers=["Class","Confidence"], datatype=["str","number"], interactive=False)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
139
 
140
  predict_btn.click(fn=predict_terminator_table, inputs=seq, outputs=table)
141
  clear_btn.click(fn=lambda: ("", []), outputs=[seq, table])