re-type commited on
Commit
51f81df
·
verified ·
1 Parent(s): ac86461

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +160 -81
app.py CHANGED
@@ -1,134 +1,213 @@
1
  import gradio as gr
2
  import torch
3
  import numpy as np
4
- import json
5
- from typing import Optional, List, Dict, Tuple
6
- import logging
7
  import os
8
  import traceback
 
9
 
10
  # Configure logging
11
  logging.basicConfig(level=logging.INFO)
12
  logger = logging.getLogger(__name__)
13
 
14
- # Debug info
15
- print("=== Debug Info ===")
16
  print(f"Working directory: {os.getcwd()}")
17
- print(f"Files: {os.listdir('.')}")
18
  print(f"PyTorch version: {torch.__version__}")
19
- print("==================")
20
 
21
- # Initialize predictor
22
  predictor = None
23
- initialization_error = None
 
24
 
25
- try:
26
- from predictor import GenePredictor
27
- model_path = 'best_boundary_aware_model.pth'
28
 
29
- if os.path.exists(model_path):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
30
  predictor = GenePredictor(model_path=model_path)
31
- print("✅ Model loaded successfully")
32
- else:
33
- initialization_error = f"Model file {model_path} not found"
34
- print(f"❌ {initialization_error}")
35
-
36
- except Exception as e:
37
- initialization_error = str(e)
38
- print(f"❌ Error: {e}")
 
 
 
 
 
 
39
 
40
  def predict_genes(sequence):
41
- """Simple prediction function"""
42
  try:
43
- if not predictor:
44
- return f"❌ Model not loaded: {initialization_error}"
 
45
 
 
46
  if not sequence or not sequence.strip():
47
- return " Please enter a DNA sequence"
48
 
49
- sequence = sequence.strip().upper()
50
 
51
- # Validate sequence
52
- valid_chars = set('ACTGN')
53
- if not set(sequence).issubset(valid_chars):
54
- return "❌ Invalid characters. Use only A, C, T, G, N"
 
55
 
 
56
  if len(sequence) < 3:
57
- return " Sequence too short (minimum 3 nucleotides)"
58
 
59
  if len(sequence) > 10000:
60
- return " Sequence too long (maximum 10,000 nucleotides)"
 
 
61
 
62
  # Make prediction
63
  predictions, probs_dict, confidence = predictor.predict(sequence)
64
  regions = predictor.extract_gene_regions(predictions, sequence)
65
 
66
- # Format output
67
  if not regions:
68
- return f"🔍 No gene regions detected (Confidence: {confidence:.3f})"
 
 
 
 
 
 
69
 
70
- result = f"🧬 Found {len(regions)} gene region(s) (Confidence: {confidence:.3f})\n\n"
71
 
72
  for i, region in enumerate(regions, 1):
73
- result += f"**Gene {i}:**\n"
74
- result += f" • Position: {region['start']} - {region['end']}\n"
75
- result += f" • Length: {region['length']} bp\n"
76
 
77
- # Safe sequence preview
78
  seq = region.get('sequence', '')
79
  if seq:
80
- preview = seq[:60] + ('...' if len(seq) > 60 else '')
81
- result += f" Preview: {preview}\n\n"
82
- else:
83
- result += f" • Preview: Not available\n\n"
 
 
 
84
 
85
  return result
86
 
87
  except Exception as e:
 
88
  print(f"Prediction error: {e}")
89
- return f"❌ Prediction failed: {str(e)}"
 
 
 
 
 
90
 
91
- def load_example():
92
- """Load example DNA sequence"""
93
- return "ATGAAACGCATTAGCACCACCATTACCACCACCATCACCATTACCACAGGTAACGGTGCGGGCTGACGCGTACAGGAAACACAGAAAAAAGCCCGCACCTGACAGTGCGGGCTTTTTTTTTCGACCAAAGGTAACGAGGTAACAACCATGCGAGTGTTGAAGTTCGGCGGTACATCAGTGGCAAATGCAGAACGTTTTCTGCGGGTTGCCGATATTCTGGAAAGCAATGCCAGGCAGGGGCAGGTGGCCACCGTCCTCTCTGCCCCCGCCAAAATCACCAACCACCTGGTGGCGATGATTGAAAAAACCATTAGCGGCCAGGATGCTTTACCCAATATCAGCGATGCCGAACGTATTTTTGCCGAACTTTTGACGGGACTCGCCGCCGCCCAGCCGGGGTTCCCGCTGGCGCAATTGAAAACTTTCGTCGATCAGGAATTTGCCCAAATAAAACATGTCCTGCATGGCATTAGTTTGTTGGGGCAGTGCCCGGATAGCATCAACGCTGCGCTGATTTGCCGTGGCGAGAAAATGTCGATCGCCATTATGGCCGGCGTATTAGAAGCGCGCGGTCACAACGTTACTGTTATCGATCCGGTCGAAAAACTGCTGGCAGTGGGGCATTACCTCGAATCTACCGTCGATATTGCTGAGTCCACCCGCCGTATTGCGGCAAGCCGCATTCCGGCTGATCACATGGTGCTGATGGCAGGTTTCACCGCCGGTAATGAAAAAGGCGAACTGGTGGTGCTTGGACGCAACGGTTCCGACTACTCTGCTGCGGTGCTGGCTGCCTGTTTACGCGCCGATTGTTGCGAGATTTGGACGGACGTTGACGGGGTCTATACCTGCGACCCGCGTCAGGTGCCCGATGCGAGGTTGTTGAAGTCGATGTCCTACCAGGAAGCGATGGAGCTTTCCTACTTCGGCGCTAAAGTTCTTCACCCCCGCACCATTACCCCCATCGCCCAGTTCCAGATCCCTTGCCTGATTAAAAATACCGGAAATCCTCAAGCACCAGGTACGCTCATTGGTGCCAGCCGTGATGAAGACGAATTACCGGTCAAGGGCATTTCCAATCTGAATAACATGGCAATGTTCAGCGTTTCCGGCCCGGGGATGAAAGGGATGGTCGGCATGGCGGCGCGCGTCTTTGCAGCGATGTCACGCGCCCGTATTTCCGTGGTGCTGATTACGCAATCATCTTCCGAATACAGCATCAGTTTCTGCGTTCCACAAAGCGACTGTGTGCGAGCTGAACGGGCAATGCAGGAAGAGTTCTACCTGGAACTGAAAGAAGGCTTACTGGAGCCGCTGGCAGTGACGGAACGGCTGGCCATTATCTCGGTGGTAGG"
94
 
95
- # Status message
96
- if predictor:
97
- status_msg = "✅ Model loaded successfully!"
98
- status_color = "green"
99
  else:
100
- status_msg = f"❌ Model loading failed: {initialization_error}"
101
- status_color = "red"
 
 
102
 
103
- # Try gr.Interface first (simpler approach)
104
- demo = gr.Interface(
105
- fn=predict_genes,
106
- inputs=gr.Textbox(
107
- label="DNA Sequence",
108
- placeholder="Enter DNA sequence (A, C, T, G, N)...",
109
- lines=6
110
- ),
111
- outputs=gr.Textbox(
112
- label="Prediction Results",
113
- lines=15
114
- ),
115
- title="🧬 Gene Prediction Tool",
116
- description=f"""
117
- Predict gene regions in DNA sequences using deep learning.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
118
 
119
- **Status:** {status_msg}
 
 
 
120
 
121
- **Instructions:**
122
- 1. Enter a DNA sequence using only A, C, T, G, N characters
123
- 2. Click Submit
124
- 3. View predicted gene regions with positions and confidence scores
125
- """,
126
- examples=[
127
- ["ATGAAACGCATTAGCACCACCATTACCACCACCATCACCATTACCACAGGTAACGGTGCGGGCTGACGCGTACAGGAAACACAGAAAAAAGCCCGCACCTGACAGTGCGGGCTTTTTTTTTCGACCAAAGGTAACGAGGTAACAACCATGCGAGTGTTGAAGTTCGGCGGTACATCAGTGGCAAATGCAGAACGTTTTCTGCGGGTTGCCGATATTCTGGAAAGCAATGCCAGGCAGGGGCAGGTGGCCACCGTCCTCTCTGCCCCCGCCAAAATCACCAACCACCTGGTGGCGATGATTGAAAAAACCATTAGCGGCCAGGATGCTTTACCCAATATCAGCGATGCCGAACGTATTTTTGCCGAACTTTTGACGGGACTCGCCGCCGCCCAGCCGGGGTTCCCGCTGGCGCAATTGAAAACTTTCGTCGATCAGGAATTTGCCCAAATAAAACATGTCCTGCATGGCATTAGTTTGTTGGGGCAGTGCCCGGATAGCATCAACGCTGCGCTGATTTGCCGTGGCGAGAAAATGTCGATCGCCATTATGGCCGGCGTATTAGAAGCGCGCGGTCACAACGTTACTGTTATCGATCCGGTCGAAAAACTGCTGGCAGTGGGGCATTACCTCGAATCTACCGTCGATATTGCTGAGTCCACCCGCCGTATTGCGGCAAGCCGCATTCCGGCTGATCACATGGTGCTGATGGCAGGTTTCACCGCCGGTAATGAAAAAGGCGAACTGGTGGTGCTTGGACGCAACGGTTCCGACTACTCTGCTGCGGTGCTGGCTGCCTGTTTACGCGCCGATTGTTGCGAGATTTGGACGGACGTTGACGGGGTCTATACCTGCGACCCGCGTCAGGTGCCCGATGCGAGGTTGTTGAAGTCGATGTCCTACCAGGAAGCGATGGAGCTTTCCTACTTCGGCGCTAAAGTTCTTCACCCCCGCACCATTACCCCCATCGCCCAGTTCCAGATCCCTTGCCTGATTAAAAATACCGGAAATCCTCAAGCACCAGGTACGCTCATTGGTGCCAGCCGTGATGAAGACGAATTACCGGTCAAGGGCATTTCCAATCTGAATAACATGGCAATGTTCAGCGTTTCCGGCCCGGGGATGAAAGGGATGGTCGGCATGGCGGCGCGCGTCTTTGCAGCGATGTCACGCGCCCGTATTTCCGTGGTGCTGATTACGCAATCATCTTCCGAATACAGCATCAGTTTCTGCGTTCCACAAAGCGACTGTGTGCGAGCTGAACGGGCAATGCAGGAAGAGTTCTACCTGGAACTGAAAGAAGGCTTACTGGAGCCGCTGGCAGTGACGGAACGGCTGGCCATTATCTCGGTGGTAGG"]
128
- ],
129
- allow_flagging="never"
130
- )
131
 
132
- # Launch the app
133
  if __name__ == "__main__":
134
- demo.launch()
 
 
 
 
 
 
 
1
  import gradio as gr
2
  import torch
3
  import numpy as np
 
 
 
4
  import os
5
  import traceback
6
+ import logging
7
 
8
  # Configure logging
9
  logging.basicConfig(level=logging.INFO)
10
  logger = logging.getLogger(__name__)
11
 
12
+ print("=== Gene Prediction App Starting ===")
 
13
  print(f"Working directory: {os.getcwd()}")
14
+ print(f"Available files: {os.listdir('.')}")
15
  print(f"PyTorch version: {torch.__version__}")
16
+ print(f"Gradio version: {gr.__version__}")
17
 
18
+ # Global variables
19
  predictor = None
20
+ model_loaded = False
21
+ error_message = ""
22
 
23
+ def initialize_model():
24
+ """Initialize the model with proper error handling"""
25
+ global predictor, model_loaded, error_message
26
 
27
+ try:
28
+ print("Attempting to import predictor...")
29
+ from predictor import GenePredictor
30
+ print("✅ Predictor imported successfully")
31
+
32
+ model_path = 'best_boundary_aware_model.pth'
33
+ print(f"Looking for model file: {model_path}")
34
+
35
+ if not os.path.exists(model_path):
36
+ error_message = f"❌ Model file '{model_path}' not found in directory"
37
+ print(error_message)
38
+ print(f"Available files: {[f for f in os.listdir('.') if f.endswith('.pth')]}")
39
+ return False
40
+
41
+ print(f"Model file found. File size: {os.path.getsize(model_path)} bytes")
42
+
43
  predictor = GenePredictor(model_path=model_path)
44
+ model_loaded = True
45
+ print("✅ Model initialized successfully")
46
+ return True
47
+
48
+ except ImportError as e:
49
+ error_message = f"❌ Failed to import predictor: {str(e)}"
50
+ print(error_message)
51
+ return False
52
+ except Exception as e:
53
+ error_message = f"❌ Model initialization failed: {str(e)}"
54
+ print(error_message)
55
+ print("Full traceback:")
56
+ traceback.print_exc()
57
+ return False
58
 
59
  def predict_genes(sequence):
60
+ """Gene prediction function with comprehensive error handling"""
61
  try:
62
+ # Check if model is loaded
63
+ if not model_loaded or predictor is None:
64
+ return f"🚫 **Model Error**\n\n{error_message}\n\nPlease check the logs for more details."
65
 
66
+ # Input validation
67
  if not sequence or not sequence.strip():
68
+ return "⚠️ **Input Error**\n\nPlease enter a DNA sequence."
69
 
70
+ sequence = sequence.strip().upper().replace(' ', '').replace('\n', '').replace('\t', '')
71
 
72
+ # Character validation
73
+ valid_chars = set('ATCGN')
74
+ invalid_chars = set(sequence) - valid_chars
75
+ if invalid_chars:
76
+ return f"⚠️ **Invalid Characters**\n\nFound invalid characters: {', '.join(sorted(invalid_chars))}\n\nPlease use only: A, T, C, G, N"
77
 
78
+ # Length validation
79
  if len(sequence) < 3:
80
+ return f"⚠️ **Sequence Too Short**\n\nMinimum length: 3 nucleotides\nYour sequence: {len(sequence)} nucleotides"
81
 
82
  if len(sequence) > 10000:
83
+ return f"⚠️ **Sequence Too Long**\n\nMaximum length: 10,000 nucleotides\nYour sequence: {len(sequence)} nucleotides"
84
+
85
+ print(f"Processing sequence of length: {len(sequence)}")
86
 
87
  # Make prediction
88
  predictions, probs_dict, confidence = predictor.predict(sequence)
89
  regions = predictor.extract_gene_regions(predictions, sequence)
90
 
91
+ # Format results
92
  if not regions:
93
+ return f"🔍 **No Gene Regions Detected**\n\nSequence length: {len(sequence)} bp\nConfidence: {confidence:.3f}\n\nThe model did not detect any gene regions in this sequence."
94
+
95
+ result = f"🧬 **Gene Prediction Results**\n\n"
96
+ result += f"📊 **Summary:**\n"
97
+ result += f"�� Found: {len(regions)} gene region(s)\n"
98
+ result += f"• Sequence length: {len(sequence)} bp\n"
99
+ result += f"• Overall confidence: {confidence:.3f}\n\n"
100
 
101
+ result += f"📍 **Detected Regions:**\n\n"
102
 
103
  for i, region in enumerate(regions, 1):
104
+ result += f"**Region {i}:**\n"
105
+ result += f"• Position: {region['start']:,} - {region['end']:,}\n"
106
+ result += f"• Length: {region['length']:,} bp\n"
107
 
108
+ # Sequence preview
109
  seq = region.get('sequence', '')
110
  if seq:
111
+ if len(seq) <= 100:
112
+ result += f"• Sequence: `{seq}`\n"
113
+ else:
114
+ preview = seq[:50] + '...' + seq[-50:]
115
+ result += f"• Preview: `{preview}`\n"
116
+
117
+ result += "\n"
118
 
119
  return result
120
 
121
  except Exception as e:
122
+ error_msg = f"🚫 **Prediction Error**\n\nAn error occurred during prediction:\n```\n{str(e)}\n```"
123
  print(f"Prediction error: {e}")
124
+ traceback.print_exc()
125
+ return error_msg
126
+
127
+ # Initialize model on startup
128
+ print("Initializing model...")
129
+ model_status = initialize_model()
130
 
131
+ # Create interface
132
+ print("Creating Gradio interface...")
 
133
 
134
+ # Determine status message and color
135
+ if model_loaded:
136
+ status_html = '<div style="padding: 10px; background-color: #d4edda; border: 1px solid #c3e6cb; border-radius: 5px; color: #155724;"><strong>✅ Model Status:</strong> Ready for predictions!</div>'
 
137
  else:
138
+ status_html = f'<div style="padding: 10px; background-color: #f8d7da; border: 1px solid #f5c6cb; border-radius: 5px; color: #721c24;"><strong>❌ Model Status:</strong> {error_message}</div>'
139
+
140
+ # Example sequence
141
+ example_sequence = "ATGAAACGCATTAGCACCACCATTACCACCACCATCACCATTACCACAGGTAACGGTGCGGGCTGACGCGTACAGGAAACACAGAAAAAAGCCCGCACCTGACAGTGCGGGCTTTTTTTTTCGACCAAAGGTAACGAGGTAACAACCATGCGAGTGTTGAAGTTCGGCGGTACATCAGTGGCAAATGCAGAACGTTTTCTGCG"
142
 
143
+ with gr.Blocks(title="🧬 Gene Prediction Tool", theme=gr.themes.Soft()) as demo:
144
+ gr.HTML("<h1 style='text-align: center; color: #2E8B57;'>🧬 Gene Prediction Tool</h1>")
145
+
146
+ gr.HTML(status_html)
147
+
148
+ gr.Markdown("""
149
+ ### Instructions:
150
+ 1. **Enter a DNA sequence** using only A, T, C, G, N characters
151
+ 2. **Click Submit** to analyze the sequence
152
+ 3. **View results** showing predicted gene regions with positions and confidence scores
153
+
154
+ **Sequence Requirements:**
155
+ - Only A, T, C, G, N characters allowed
156
+ - Minimum length: 3 nucleotides
157
+ - Maximum length: 10,000 nucleotides
158
+ """)
159
+
160
+ with gr.Row():
161
+ with gr.Column(scale=2):
162
+ sequence_input = gr.Textbox(
163
+ label="DNA Sequence",
164
+ placeholder="Enter DNA sequence (A, T, C, G, N)...",
165
+ lines=8,
166
+ max_lines=15
167
+ )
168
+
169
+ with gr.Row():
170
+ submit_btn = gr.Button("🔬 Analyze Sequence", variant="primary", size="lg")
171
+ clear_btn = gr.Button("🗑️ Clear", variant="secondary")
172
+ example_btn = gr.Button("📝 Load Example", variant="secondary")
173
+
174
+ with gr.Column(scale=3):
175
+ output = gr.Textbox(
176
+ label="Prediction Results",
177
+ lines=20,
178
+ max_lines=30,
179
+ show_copy_button=True
180
+ )
181
+
182
+ # Event handlers
183
+ submit_btn.click(
184
+ fn=predict_genes,
185
+ inputs=sequence_input,
186
+ outputs=output
187
+ )
188
+
189
+ clear_btn.click(
190
+ fn=lambda: ("", ""),
191
+ outputs=[sequence_input, output]
192
+ )
193
 
194
+ example_btn.click(
195
+ fn=lambda: example_sequence,
196
+ outputs=sequence_input
197
+ )
198
 
199
+ # Also allow Enter key to submit
200
+ sequence_input.submit(
201
+ fn=predict_genes,
202
+ inputs=sequence_input,
203
+ outputs=output
204
+ )
 
 
 
 
205
 
 
206
  if __name__ == "__main__":
207
+ print("Launching Gradio app...")
208
+ demo.launch(
209
+ server_name="0.0.0.0",
210
+ server_port=7860,
211
+ show_error=True,
212
+ show_api=False
213
+ )