File size: 10,668 Bytes
404d784
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
#!/usr/bin/env python3
"""

Test script for ColiFormer Streamlit GUI



This script tests the core functionality of the GUI without running the full Streamlit application.

"""

import sys
import os
import traceback
from pathlib import Path

# Make the repository root (the parent of this script's directory) importable,
# presumably so the ``CodonTransformer`` package resolves when this script is
# run directly rather than from an installed package.
sys.path.append(os.fspath(Path(__file__).parent.parent))

def test_imports():
    """Check that every third-party and project dependency of the GUI imports.

    Probes run in a fixed order; the first ``ImportError`` is printed as a
    FAIL line and stops the check.

    Returns:
        True if every probe imported cleanly, False at the first ImportError.
    """
    print("Testing imports...")

    # Each probe performs one import and returns the text (if any) that
    # follows the label on the OK line.
    def probe_streamlit():
        import streamlit as st
        return f": {st.__version__}"

    def probe_torch():
        import torch
        backend = "GPU" if torch.cuda.is_available() else "CPU"
        return f": {torch.__version__} ({backend})"

    def probe_plotly():
        import plotly
        return f": {plotly.__version__}"

    def probe_prediction():
        from CodonTransformer.CodonPrediction import predict_dna_sequence  # noqa: F401
        return ""

    def probe_evaluation():
        from CodonTransformer.CodonEvaluation import get_GC_content, calculate_tAI  # noqa: F401
        return ""

    probes = (
        ("Streamlit", probe_streamlit),
        ("PyTorch", probe_torch),
        ("Plotly", probe_plotly),
        ("CodonTransformer.CodonPrediction", probe_prediction),
        ("CodonTransformer.CodonEvaluation", probe_evaluation),
    )

    for label, probe in probes:
        try:
            detail = probe()
        except ImportError as err:
            print(f"  FAIL: {label}: {err}")
            return False
        print(f"  OK: {label}{detail}")

    return True

def _run_validation_cases(validate, cases):
    """Run ``validate`` over every case, print each outcome, and report success.

    Every case is executed and printed even after a mismatch, so the report
    is complete.

    Args:
        validate: Callable taking a sequence string and returning an
            ``(is_valid, message)`` pair.
        cases: Iterable of ``(sequence, expected_valid, description)`` triples.

    Returns:
        True only if every case's ``is_valid`` equalled ``expected_valid``
        (vacuously True for no cases).
    """
    all_matched = True
    for seq, expected_valid, description in cases:
        is_valid, message = validate(seq)
        matched = is_valid == expected_valid
        status = "OK" if matched else "FAIL"
        print(f"  {status} {description}: {message}")
        all_matched = all_matched and matched
    return all_matched


def test_protein_validation():
    """Check ``app.validate_protein_sequence`` against known-good and bad inputs.

    Returns:
        True if every case produced the expected validity. False if any case
        mismatched, or if an exception (including a failed import of ``app``)
        was raised. Previously this returned True even when cases printed FAIL.
    """
    print("\nTesting protein sequence validation...")

    try:
        from app import validate_protein_sequence

        # NOTE(review): the "Too short" / "Too long" cases assume the length
        # limits enforced by app.validate_protein_sequence — confirm in app.py.
        test_cases = [
            ("MKTVRQERLK", True, "Valid short sequence"),
            ("", False, "Empty sequence"),
            ("MKTVRQERLKX", False, "Invalid character X"),
            ("MK", False, "Too short"),
            ("M" * 501, False, "Too long"),
            ("mktvrqerlk", True, "Lowercase (should work)"),
            ("MKTVRQERLK*", True, "With stop codon"),
            ("MKTVRQERLK_", True, "With underscore stop"),
        ]

        return _run_validation_cases(validate_protein_sequence, test_cases)
    except Exception as e:
        print(f"  FAIL: Error in validation test: {e}")
        traceback.print_exc()
        return False

def test_metrics_calculation():
    """Check ``app.calculate_input_metrics`` on a short protein.

    Verifies three things:
    - every expected metric key is present;
    - the reported length equals the protein length;
    - GC content lies in [0, 100].

    Returns:
        True when every check passes. False on the first failed check or on
        any exception (including a failed import of ``app``).
    """
    print("\nTesting metrics calculation...")

    try:
        from app import calculate_input_metrics

        protein = "MKTVRQERLK"
        metrics = calculate_input_metrics(protein, "Escherichia coli general")

        # Stop at the first missing key, printing each present one as we go.
        for name in ('length', 'gc_content', 'baseline_dna', 'cai', 'tai'):
            if name not in metrics:
                print(f"  FAIL: Missing metric: {name}")
                return False
            print(f"  OK: {name}: {metrics[name]}")

        if metrics['length'] != len(protein):
            print("  FAIL: Length calculation incorrect")
            return False
        print("  OK: Length calculation correct")

        if not 0 <= metrics['gc_content'] <= 100:
            print("  FAIL: GC content out of range")
            return False
        print("  OK: GC content in valid range")

        return True
    except Exception as e:
        print(f"  FAIL: Error in metrics calculation: {e}")
        traceback.print_exc()
        return False

def test_visualization_functions():
    """Smoke-test the chart builders exported by ``app``.

    Only checks that each builder runs without raising; the returned
    objects are not inspected.

    Returns:
        True if both builders ran, False on any exception (including a
        failed import of ``app``).
    """
    print("\nTesting visualization functions...")

    try:
        from app import create_gc_content_plot, create_metrics_comparison_chart

        sample_dna = "ATGGCGAAAGCGCTGTATCGCGAAAGCGCTGTATCGCGAAAGCGCTGTATCGC"
        create_gc_content_plot(sample_dna)
        print("  OK: GC content plot created")

        # Arguments are the "before" and "after" metric dicts, in that order.
        create_metrics_comparison_chart(
            {'gc_content': 50.0, 'cai': 0.5, 'tai': 0.3},
            {'gc_content': 52.0, 'cai': 0.6, 'tai': 0.4},
        )
        print("  OK: Metrics comparison chart created")

        return True
    except Exception as e:
        print(f"  FAIL: Error in visualization test: {e}")
        traceback.print_exc()
        return False

def test_codon_evaluation():
    """Exercise the ``CodonTransformer.CodonEvaluation`` helpers directly.

    GC content must compute. tAI is best-effort: if it raises, a NOTE is
    printed and the check still passes.

    Returns:
        True unless importing the helpers or computing GC content raises.
    """
    print("\nTesting CodonEvaluation functions...")

    try:
        from CodonTransformer.CodonEvaluation import (
            get_GC_content,
            calculate_tAI,
            get_ecoli_tai_weights,
        )

        dna = "ATGGCGAAAGCG"
        print(f"  OK: GC content calculation: {get_GC_content(dna):.1f}%")

        # tAI is optional here: a failure is reported but does not fail the check.
        try:
            tai = calculate_tAI(dna, get_ecoli_tai_weights())
            print(f"  OK: tAI calculation: {tai:.3f}")
        except Exception as e:
            print(f"  NOTE: tAI calculation (may need scipy): {e}")

        return True
    except Exception as e:
        print(f"  FAIL: Error in CodonEvaluation test: {e}")
        traceback.print_exc()
        return False

def _find_checkpoint(relative_path, search_roots):
    """Return the first existing ``root / relative_path``, or None if none exist.

    Args:
        relative_path: Checkpoint path (str or Path) relative to a search root.
        search_roots: Directories to try, in priority order.
    """
    for root in search_roots:
        candidate = Path(root) / relative_path
        if candidate.exists():
            return candidate
    return None


def test_model_loading():
    """Check the model-loading stack without loading the full model weights.

    Steps:
    - Loads the CodonTransformer tokenizer via
      ``AutoTokenizer.from_pretrained``. This may need network access or a
      populated local Hugging Face cache.
    - Confirms ``load_model`` and ``BigBirdForMaskedLM`` can be imported.
    - Reports whether the fine-tuned checkpoint is present.

    Returns:
        True if the tokenizer loaded and the imports succeeded. A missing
        checkpoint is only a NOTE. False on any exception.
    """
    print("\nTesting model loading (mock)...")

    try:
        from transformers import AutoTokenizer
        # Imported only to prove the symbol exists; it is not called here.
        from CodonTransformer.CodonPrediction import load_model  # noqa: F401

        print("  Testing tokenizer loading...")
        AutoTokenizer.from_pretrained("adibvafa/CodonTransformer")
        print("  OK: Tokenizer loaded successfully")

        print("  Testing load_model function...")
        from transformers import BigBirdForMaskedLM  # noqa: F401
        print("  OK: Model class available: BigBirdForMaskedLM")

        # Resolve against the repository root (the parent of this script's
        # directory), matching test_file_structure, so the result no longer
        # depends on the working directory. The old CWD-relative location is
        # kept as a fallback.
        relative = Path("models") / "alm-enhanced-training" / "balanced_alm_finetune.ckpt"
        repo_root = Path(__file__).parent.parent
        found = _find_checkpoint(relative, (repo_root, Path.cwd()))
        if found is not None:
            print(f"  OK: Fine-tuned model found: {found}")
        else:
            print(f"  NOTE: Fine-tuned model not found at: {repo_root / relative}")

        # Note: We won't actually load the full model here as it's ~2GB
        print("  NOTE: Full model loading skipped in test (too large)")

        return True
    except Exception as e:
        print(f"  FAIL: Error in model loading test: {e}")
        traceback.print_exc()
        return False

def test_file_structure():
    """Check that the GUI's companion files sit next to this script.

    The fine-tuned checkpoint under ``<repo>/models/...`` is also looked for.
    Its absence is only a NOTE and does not affect the result.

    Returns:
        True if every required GUI file exists, False otherwise.
    """
    print("\nTesting file structure...")

    here = Path(__file__).parent
    repo_root = here.parent

    # Report every file, counting the missing ones, rather than stopping early.
    missing = 0
    for name in ("app.py", "run_gui.py", "requirements.txt", "README.md"):
        if (here / name).exists():
            print(f"  OK: {name}")
        else:
            print(f"  FAIL: {name} missing")
            missing += 1

    checkpoint = repo_root.joinpath(
        "models", "alm-enhanced-training", "balanced_alm_finetune.ckpt"
    )
    print(
        "  OK: Fine-tuned model checkpoint found"
        if checkpoint.exists()
        else "  NOTE: Fine-tuned model checkpoint not found"
    )

    return missing == 0

def test_post_processing():
    """Report whether ``app`` exposes optional post-processing / DNAChisel support.

    Unavailable features are reported as NOTE lines, not failures.

    Returns:
        True if the availability flags could be read from ``app``, False on
        any exception (including a failed import of ``app``).
    """
    print("\nTesting post-processing features...")

    try:
        from app import POST_PROCESSING_AVAILABLE, DNACHISEL_AVAILABLE

        if not POST_PROCESSING_AVAILABLE:
            print("  NOTE: Post-processing module not available")
            return True

        print("  OK: Post-processing module available")
        print(
            "  OK: DNAChisel available"
            if DNACHISEL_AVAILABLE
            else "  NOTE: DNAChisel not available"
        )
        return True
    except Exception as e:
        print(f"  FAIL: Error in post-processing test: {e}")
        return False

def main():
    """Run every GUI check in order and print a summary.

    A check passes only when it returns a truthy value. An exception escaping
    a check is printed as ERROR and counted as a failure.

    Returns:
        True if every check passed, False otherwise.
    """
    print("ENCOT GUI Test Suite")
    print("=" * 50)

    suite = (
        ("File Structure", test_file_structure),
        ("Imports", test_imports),
        ("Protein Validation", test_protein_validation),
        ("Metrics Calculation", test_metrics_calculation),
        ("Visualization Functions", test_visualization_functions),
        ("CodonEvaluation Functions", test_codon_evaluation),
        ("Model Loading", test_model_loading),
        ("Post-Processing", test_post_processing),
    )

    passed = 0
    for name, check in suite:
        try:
            if check():
                passed += 1
                print(f"OK: {name}: PASSED")
            else:
                print(f"FAIL: {name}: FAILED")
        except Exception as e:
            print(f"FAIL: {name}: ERROR - {e}")

    total = len(suite)
    all_passed = passed == total

    print("\n" + "=" * 50)
    print(f"Test Results: {passed}/{total} tests passed")

    if all_passed:
        for line in (
            "All tests passed. The GUI should work correctly.",
            "\nTo run the GUI:",
            "  python run_gui.py",
            "  or",
            "  cd streamlit_gui && streamlit run app.py --server.address=0.0.0.0",
        ):
            print(line)
    else:
        print("Some tests failed. Please check the issues above.")

    print("\nNotes:")
    for feature in (
        "Fine-tuned model integration",
        "Enhanced constrained beam search",
        "Post-processing with DNAChisel",
        "Advanced sequence analysis",
        "Improved parameter controls",
    ):
        print(f"  • {feature}")

    return all_passed

if __name__ == "__main__":
    # Exit with 0 when every check passed and 1 otherwise, so shell/CI callers
    # can branch on the suite result.
    exit_code = 0 if main() else 1
    sys.exit(exit_code)