#!/usr/bin/env python3 """ Test script for ColiFormer Streamlit GUI This script tests the core functionality of the GUI without running the full Streamlit application. """ import sys import os import traceback from pathlib import Path # Add parent directory to path for imports sys.path.append(str(Path(__file__).parent.parent)) def test_imports(): """Test if all required imports work""" print("Testing imports...") try: import streamlit as st print(f" OK: Streamlit: {st.__version__}") except ImportError as e: print(f" FAIL: Streamlit: {e}") return False try: import torch device = "GPU" if torch.cuda.is_available() else "CPU" print(f" OK: PyTorch: {torch.__version__} ({device})") except ImportError as e: print(f" FAIL: PyTorch: {e}") return False try: import plotly print(f" OK: Plotly: {plotly.__version__}") except ImportError as e: print(f" FAIL: Plotly: {e}") return False try: from CodonTransformer.CodonPrediction import predict_dna_sequence print(" OK: CodonTransformer.CodonPrediction") except ImportError as e: print(f" FAIL: CodonTransformer.CodonPrediction: {e}") return False try: from CodonTransformer.CodonEvaluation import get_GC_content, calculate_tAI print(" OK: CodonTransformer.CodonEvaluation") except ImportError as e: print(f" FAIL: CodonTransformer.CodonEvaluation: {e}") return False return True def test_protein_validation(): """Test protein sequence validation""" print("\nTesting protein sequence validation...") try: # Import the validation function from app import validate_protein_sequence # Test cases test_cases = [ ("MKTVRQERLK", True, "Valid short sequence"), ("", False, "Empty sequence"), ("MKTVRQERLKX", False, "Invalid character X"), ("MK", False, "Too short"), ("M" * 501, False, "Too long"), ("mktvrqerlk", True, "Lowercase (should work)"), ("MKTVRQERLK*", True, "With stop codon"), ("MKTVRQERLK_", True, "With underscore stop"), ] for seq, expected_valid, description in test_cases: is_valid, message = validate_protein_sequence(seq) status = "OK" if is_valid == expected_valid else "FAIL" print(f" {status} {description}: {message}") return True except Exception as e: print(f" FAIL: Error in validation test: {e}") traceback.print_exc() return False def test_metrics_calculation(): """Test metrics calculation""" print("\nTesting metrics calculation...") try: from app import calculate_input_metrics test_protein = "MKTVRQERLK" organism = "Escherichia coli general" metrics = calculate_input_metrics(test_protein, organism) # Check if all expected metrics are present expected_keys = ['length', 'gc_content', 'baseline_dna', 'cai', 'tai'] for key in expected_keys: if key in metrics: print(f" OK: {key}: {metrics[key]}") else: print(f" FAIL: Missing metric: {key}") return False # Validate metric values if metrics['length'] == len(test_protein): print(" OK: Length calculation correct") else: print(" FAIL: Length calculation incorrect") return False if 0 <= metrics['gc_content'] <= 100: print(" OK: GC content in valid range") else: print(" FAIL: GC content out of range") return False return True except Exception as e: print(f" FAIL: Error in metrics calculation: {e}") traceback.print_exc() return False def test_visualization_functions(): """Test visualization functions""" print("\nTesting visualization functions...") try: from app import create_gc_content_plot, create_metrics_comparison_chart # Test GC content plot test_dna = "ATGGCGAAAGCGCTGTATCGCGAAAGCGCTGTATCGCGAAAGCGCTGTATCGC" fig = create_gc_content_plot(test_dna) print(" OK: GC content plot created") # Test metrics comparison chart before_metrics = {'gc_content': 50.0, 'cai': 0.5, 'tai': 0.3} after_metrics = {'gc_content': 52.0, 'cai': 0.6, 'tai': 0.4} fig = create_metrics_comparison_chart(before_metrics, after_metrics) print(" OK: Metrics comparison chart created") return True except Exception as e: print(f" FAIL: Error in visualization test: {e}") traceback.print_exc() return False def test_codon_evaluation(): """Test CodonEvaluation functions directly""" print("\nTesting CodonEvaluation functions...") try: from CodonTransformer.CodonEvaluation import get_GC_content, calculate_tAI, get_ecoli_tai_weights # Test GC content calculation test_dna = "ATGGCGAAAGCG" gc_content = get_GC_content(test_dna) print(f" OK: GC content calculation: {gc_content:.1f}%") # Test tAI calculation try: tai_weights = get_ecoli_tai_weights() tai_value = calculate_tAI(test_dna, tai_weights) print(f" OK: tAI calculation: {tai_value:.3f}") except Exception as e: print(f" NOTE: tAI calculation (may need scipy): {e}") return True except Exception as e: print(f" FAIL: Error in CodonEvaluation test: {e}") traceback.print_exc() return False def test_model_loading(): """Test model loading functionality""" print("\nTesting model loading (mock)...") try: import torch from transformers import AutoTokenizer from CodonTransformer.CodonPrediction import load_model # Test tokenizer loading (this is fast) print(" Testing tokenizer loading...") tokenizer = AutoTokenizer.from_pretrained("adibvafa/CodonTransformer") print(" OK: Tokenizer loaded successfully") # Test load_model function print(" Testing load_model function...") from transformers import BigBirdForMaskedLM print(" OK: Model class available: BigBirdForMaskedLM") # Check if fine-tuned model exists import os model_path = "models/alm-enhanced-training/balanced_alm_finetune.ckpt" if os.path.exists(model_path): print(f" OK: Fine-tuned model found: {model_path}") else: print(f" NOTE: Fine-tuned model not found at: {model_path}") # Note: We won't actually load the full model here as it's ~2GB print(" NOTE: Full model loading skipped in test (too large)") return True except Exception as e: print(f" FAIL: Error in model loading test: {e}") traceback.print_exc() return False def test_file_structure(): """Test if all required files exist""" print("\nTesting file structure...") gui_dir = Path(__file__).parent parent_dir = gui_dir.parent required_files = [ "app.py", "run_gui.py", "requirements.txt", "README.md" ] all_present = True for file_name in required_files: file_path = gui_dir / file_name if file_path.exists(): print(f" OK: {file_name}") else: print(f" FAIL: {file_name} missing") all_present = False # Check for model checkpoint model_path = parent_dir / "models" / "alm-enhanced-training" / "balanced_alm_finetune.ckpt" if model_path.exists(): print(" OK: Fine-tuned model checkpoint found") else: print(" NOTE: Fine-tuned model checkpoint not found") return all_present def test_post_processing(): """Test post-processing functionality""" print("\nTesting post-processing features...") try: from app import POST_PROCESSING_AVAILABLE, DNACHISEL_AVAILABLE if POST_PROCESSING_AVAILABLE: print(" OK: Post-processing module available") if DNACHISEL_AVAILABLE: print(" OK: DNAChisel available") else: print(" NOTE: DNAChisel not available") else: print(" NOTE: Post-processing module not available") return True except Exception as e: print(f" FAIL: Error in post-processing test: {e}") return False def main(): """Run all tests""" print("ENCOT GUI Test Suite") print("=" * 50) tests = [ ("File Structure", test_file_structure), ("Imports", test_imports), ("Protein Validation", test_protein_validation), ("Metrics Calculation", test_metrics_calculation), ("Visualization Functions", test_visualization_functions), ("CodonEvaluation Functions", test_codon_evaluation), ("Model Loading", test_model_loading), ("Post-Processing", test_post_processing), ] passed = 0 total = len(tests) for test_name, test_func in tests: try: result = test_func() if result: passed += 1 print(f"OK: {test_name}: PASSED") else: print(f"FAIL: {test_name}: FAILED") except Exception as e: print(f"FAIL: {test_name}: ERROR - {e}") print("\n" + "=" * 50) print(f"Test Results: {passed}/{total} tests passed") if passed == total: print("All tests passed. The GUI should work correctly.") print("\nTo run the GUI:") print(" python run_gui.py") print(" or") print(" cd streamlit_gui && streamlit run app.py --server.address=0.0.0.0") else: print("Some tests failed. Please check the issues above.") print("\nNotes:") print(" • Fine-tuned model integration") print(" • Enhanced constrained beam search") print(" • Post-processing with DNAChisel") print(" • Advanced sequence analysis") print(" • Improved parameter controls") return passed == total if __name__ == "__main__": success = main() sys.exit(0 if success else 1)