re-type committed on
Commit
a20d1ea
·
verified ·
1 Parent(s): c59941c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +142 -51
app.py CHANGED
@@ -1,84 +1,175 @@
 
1
  import gradio as gr
2
  import torch
3
  import pickle
4
  import subprocess
5
  import pandas as pd
6
- from predictor import Predictor
7
  from tensorflow.keras.models import load_model
8
  from ml_simplified_tree import maximum_likelihood
 
 
 
 
 
9
 
10
  # --------- Load Models ---------
11
- boundary_model = Predictor("best_boundary_aware_model.pth")
12
- keras_model = load_model("best_model.keras")
 
 
 
 
 
 
 
 
13
 
14
- with open("kmer_to_index.pkl", "rb") as f:
15
- kmer_to_index = pickle.load(f)
 
 
 
 
 
 
 
 
16
 
17
  # --------- Utilities ---------
18
  def predict_with_keras(sequence):
19
- kmers = [sequence[i:i+6] for i in range(len(sequence)-5)]
20
- indices = [kmer_to_index.get(kmer, 0) for kmer in kmers]
21
- input_arr = torch.tensor([indices])
22
- prediction = keras_model.predict(input_arr)[0]
23
- return "".join(str(round(p, 3)) for p in prediction)
 
 
 
 
24
 
25
  def save_to_fasta(name, sequence, path):
26
- with open(path, "w") as f:
27
- f.write(f">{name}\n{sequence}\n")
 
 
 
 
 
 
28
 
29
  def save_to_csv(sequence, path):
30
- df = pd.DataFrame({"Sequence": [sequence]})
31
- df.to_csv(path, index=False)
 
 
 
 
 
 
32
 
33
  def run_mafft_and_iqtree(fasta_file="f_gene_sequences_aligned.fasta"):
34
  try:
35
  subprocess.run(["mafft", "--auto", fasta_file], check=True)
36
- subprocess.run(["iqtree", "-s", "f_gene_sequences.phy.treefile", "-m", "GTR"], check=True)
37
- return "MAFFT and IQTree executed successfully."
38
- except Exception as e:
39
- return f"Error running alignment/tree: {e}"
 
 
 
 
 
40
 
41
  def run_full_pipeline(dna_input):
42
- # 1. Boundary-Aware Prediction
43
- step1_out = boundary_model.predict(dna_input)
 
 
 
 
44
 
45
- # 2. Keras Prediction
46
- step2_out = predict_with_keras(step1_out)
 
47
 
48
- # 3. Save intermediate files
49
- save_to_fasta("Predicted_Seq", step2_out, "f_gene_sequences_aligned.fasta")
50
- save_to_csv(step2_out, "f gene clean dataset.csv")
51
 
52
- # 4. Run MAFFT + IQTree
53
- mafft_status = run_mafft_and_iqtree()
54
 
55
- # 5. Run ML tree
56
- try:
57
- ml_output = maximum_likelihood("f gene clean dataset.csv")
58
- except Exception as e:
59
- ml_output = f"ML Tree Error: {e}"
 
 
 
 
 
 
 
 
60
 
61
- return {
62
- "Boundary Model Output": step1_out,
63
- "Keras Model Output": step2_out,
64
- "MAFFT + IQTree Status": mafft_status,
65
- "Maximum Likelihood Tree Output": ml_output
66
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
67
 
68
  # --------- Gradio Interface ---------
69
- gr_interface = gr.Interface(
70
- fn=run_full_pipeline,
71
- inputs=gr.Textbox(label="Input DNA Sequence"),
72
- outputs=[
73
- gr.Textbox(label="Boundary Model Output"),
74
- gr.Textbox(label="Keras Model Output"),
75
- gr.Textbox(label="MAFFT + IQTree Status"),
76
- gr.Textbox(label="ML Tree Output")
77
- ],
78
- title="Sequential Phylogenetic Inference Pipeline",
79
- description="This pipeline runs sequentially: Boundary-Aware Model → Keras Model → Tree Building"
80
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
81
 
82
  # --------- Launch ---------
83
  if __name__ == "__main__":
84
- gr_interface.launch()
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
  import gradio as gr
3
  import torch
4
  import pickle
5
  import subprocess
6
  import pandas as pd
7
+ from predictor import GenePredictor
8
  from tensorflow.keras.models import load_model
9
  from ml_simplified_tree import maximum_likelihood
10
+ import logging
11
+ import os
12
+
# Configure logging: timestamped records at INFO level and above.
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)

# --------- Load Models ---------
# Both models are loaded once at import time. Any failure is logged and then
# re-raised, so the module does not finish importing without its models.
try:
    boundary_model = GenePredictor("best_boundary_aware_model.pth")
    keras_model = load_model("best_model.keras")
except FileNotFoundError as missing:
    logging.error(f"Model file not found: {missing}")
    raise
except Exception as failure:
    logging.error(f"Error loading models: {failure}")
    raise
else:
    logging.info("Models loaded successfully")

# The k-mer -> index lookup table used by predict_with_keras.
# NOTE(review): pickle.load can execute arbitrary code; this assumes
# kmer_to_index.pkl is a trusted artifact shipped with the app - confirm.
try:
    with open("kmer_to_index.pkl", "rb") as f:
        kmer_to_index = pickle.load(f)
except FileNotFoundError:
    logging.error("kmer_to_index.pkl not found")
    raise
except Exception as failure:
    logging.error(f"Error loading kmer_to_index.pkl: {failure}")
    raise
else:
    logging.info("kmer_to_index.pkl loaded successfully")
38
 
39
  # --------- Utilities ---------
def predict_with_keras(sequence):
    """Run the Keras model on the 6-mer index encoding of ``sequence``.

    The sequence is cut into overlapping 6-character windows, each window is
    looked up in ``kmer_to_index`` (windows missing from the table map to 0),
    and the index list is handed to ``keras_model.predict`` as a batch of one.

    Returns:
        The first prediction row with every value rounded to 3 decimals,
        converted to text and concatenated without a separator. If anything
        raises, the error is logged and a string of the form
        ``"Error in Keras prediction: <exception>"`` is returned instead.
    """
    k = 6
    try:
        window_count = len(sequence) - (k - 1)
        indices = []
        for start in range(window_count):
            indices.append(kmer_to_index.get(sequence[start:start + k], 0))
        batch = torch.tensor([indices])
        first_row = keras_model.predict(batch)[0]
        rounded = [str(round(value, 3)) for value in first_row]
        return "".join(rounded)
    except Exception as exc:
        message = f"Error in Keras prediction: {exc}"
        logging.error(message)
        return message
50
 
def save_to_fasta(name, sequence, path):
    """Write a single-record FASTA file and report the outcome as text.

    The file at ``path`` is overwritten with a ``>name`` header line followed
    by the sequence line, each terminated by a newline.

    Returns:
        ``"FASTA file saved successfully"`` on success, otherwise
        ``"Error saving FASTA: <exception>"`` (the error is also logged).
    """
    try:
        with open(path, "w") as out:
            record_lines = (f">{name}", f"{sequence}", "")
            out.write("\n".join(record_lines))
    except Exception as exc:
        failure = f"Error saving FASTA: {exc}"
        logging.error(failure)
        return failure
    else:
        logging.info(f"FASTA file saved to {path}")
        return "FASTA file saved successfully"
60
 
def save_to_csv(sequence, path):
    """Write ``sequence`` as a one-row CSV with a single ``Sequence`` column.

    Returns:
        ``"CSV file saved successfully"`` on success, otherwise
        ``"Error saving CSV: <exception>"`` (the error is also logged).
    """
    try:
        pd.DataFrame({"Sequence": [sequence]}).to_csv(path, index=False)
    except Exception as exc:
        failure = f"Error saving CSV: {exc}"
        logging.error(failure)
        return failure
    else:
        logging.info(f"CSV file saved to {path}")
        return "CSV file saved successfully"
70
 
def run_mafft_and_iqtree(fasta_file="f_gene_sequences_aligned.fasta"):
    """Run ``mafft --auto`` and then ``iqtree -s ... -m GTR`` on ``fasta_file``.

    Both tools are invoked on the same input path; the second runs only if
    the first exits with status 0.

    Returns:
        A status string: a success message, an ``"Error running
        MAFFT/IQTree: ..."`` message when a tool exits non-zero, or an
        installation hint when an executable cannot be found.
    """
    # NOTE(review): MAFFT's output is not captured or redirected here, so
    # IQ-TREE reads the same file MAFFT was given - confirm this is intended.
    commands = (
        ["mafft", "--auto", fasta_file],
        ["iqtree", "-s", fasta_file, "-m", "GTR"],
    )
    try:
        for argv in commands:
            subprocess.run(argv, check=True)
    except subprocess.CalledProcessError as exc:
        failure = f"Error running MAFFT/IQTree: {exc}"
        logging.error(failure)
        return failure
    except FileNotFoundError:
        logging.error("MAFFT or IQTree not found. Ensure they are installed.")
        return "MAFFT or IQTree not found. Ensure they are installed and in PATH."
    else:
        success = "MAFFT and IQTree executed successfully"
        logging.info(success)
        return success
83
 
def run_full_pipeline(dna_input):
    """Run the boundary model, Keras model, file export and tree steps in order.

    Args:
        dna_input: The DNA sequence text entered by the user.

    Returns:
        A 7-tuple, in the same order as the ``outputs`` list wired to the
        "Run Pipeline" button:
        (boundary model output, Keras model output, FASTA save status,
        CSV save status, MAFFT + IQTree status, maximum-likelihood output,
        tree HTML path or ``None``).

        Values are returned positionally because Gradio only accepts a dict
        return value when its keys are the output components themselves; a
        dict keyed by label strings is rejected at runtime.

        If any step outside the ML-tree step raises, the first element is
        ``"Error: <exception>"``, the text fields are ``"N/A"`` and the file
        slot is ``None``.
    """
    fasta_path = "f_gene_sequences_aligned.fasta"
    csv_path = "f gene clean dataset.csv"
    try:
        # 1. Boundary-Aware Prediction (only the label predictions are used).
        predictions, _, _ = boundary_model.predict(dna_input)
        gene_regions = boundary_model.extract_gene_regions(predictions, dna_input)
        # Fall back to the raw input when no gene region is reported.
        step1_out = gene_regions[0]["sequence"] if gene_regions else dna_input
        logging.info(f"Boundary model output: {step1_out}")

        # 2. Keras Prediction
        step2_out = predict_with_keras(step1_out)
        logging.info(f"Keras model output: {step2_out}")

        # 3. Save intermediate files
        fasta_status = save_to_fasta("Predicted_Seq", step2_out, fasta_path)
        csv_status = save_to_csv(step2_out, csv_path)

        # 4. Run MAFFT + IQTree (uses its default FASTA path)
        mafft_status = run_mafft_and_iqtree()

        # 5. Run ML tree and check for the HTML output.
        # NOTE(review): assumes maximum_likelihood writes "tree.html" into the
        # working directory - confirm against ml_simplified_tree.
        html_file = "tree.html"
        try:
            ml_output = maximum_likelihood(csv_path)
            if os.path.exists(html_file):
                logging.info(f"HTML tree file generated: {html_file}")
            else:
                logging.warning(f"HTML tree file {html_file} not found")
                html_file = None  # Nothing to offer for download
        except Exception as e:
            # A tree failure is reported in its own field; earlier results
            # are still returned.
            logging.error(f"ML Tree Error: {e}")
            ml_output = f"ML Tree Error: {e}"
            html_file = None

        return (
            step1_out,
            step2_out,
            fasta_status,
            csv_status,
            mafft_status,
            ml_output,
            html_file,
        )
    except Exception as e:
        logging.error(f"Pipeline failed: {e}")
        return (f"Error: {e}", "N/A", "N/A", "N/A", "N/A", "N/A", None)
137
 
# --------- Gradio Interface ---------
# Components are created in display order: title, description, input box,
# run button, six text outputs and one file slot.
with gr.Blocks() as gr_interface:
    gr.Markdown(value="# Sequential Phylogenetic Inference Pipeline")
    gr.Markdown(
        value="This pipeline runs sequentially: Boundary-Aware Model → Keras Model → Tree Building"
    )

    dna_input = gr.Textbox(label="Input DNA Sequence")
    submit_button = gr.Button(value="Run Pipeline")

    boundary_output = gr.Textbox(label="Boundary Model Output")
    keras_output = gr.Textbox(label="Keras Model Output")
    fasta_status = gr.Textbox(label="FASTA Save Status")
    csv_status = gr.Textbox(label="CSV Save Status")
    mafft_status = gr.Textbox(label="MAFFT + IQTree Status")
    ml_output = gr.Textbox(label="Maximum Likelihood Tree Output")
    tree_download = gr.File(label="Download Tree (HTML)")

    # Clicking the button calls run_full_pipeline with the text box contents
    # and routes its results to the components listed below.
    submit_button.click(
        run_full_pipeline,
        inputs=[dna_input],
        outputs=[
            boundary_output,
            keras_output,
            fasta_status,
            csv_status,
            mafft_status,
            ml_output,
            tree_download,
        ],
    )
167
 
# --------- Launch ---------
if __name__ == "__main__":
    # When run as a script, launch() keeps the main thread busy until the
    # server stops, so a message logged after it would only appear at
    # shutdown. The start-up message is therefore logged before the call.
    logging.info("Launching Gradio interface on 0.0.0.0:7860")
    try:
        gr_interface.launch(server_name="0.0.0.0", server_port=7860)
    except Exception as e:
        logging.error(f"Error launching Gradio interface: {e}")
        raise