Spaces:
Sleeping
Sleeping
File size: 10,668 Bytes
404d784 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 | #!/usr/bin/env python3
"""
Test script for ColiFormer Streamlit GUI
This script tests the core functionality of the GUI without running the full Streamlit application.
"""
import sys
import os
import traceback
from pathlib import Path
# Make the directory above this script's folder importable (presumably the
# repository root, so that `CodonTransformer` resolves when this file is run
# directly -- confirm). `append` places it LAST on sys.path, so an installed
# package with the same name would take precedence over the local copy.
sys.path.append(str(Path(__file__).parent.parent))
def test_imports():
    """Verify that every GUI dependency can be imported.

    Dependencies are probed in a fixed order: Streamlit, PyTorch, Plotly,
    then the two CodonTransformer modules the GUI uses. Each success prints
    an OK line (with the library version where one is reported); the first
    ``ImportError`` is printed and stops the check.

    Returns:
        bool: True if every import succeeded, False at the first failure.
    """
    print("Testing imports...")

    # Each probe performs one import and returns the text shown after "OK: ".
    def probe_streamlit():
        import streamlit as st
        return f"Streamlit: {st.__version__}"

    def probe_torch():
        import torch
        device = "GPU" if torch.cuda.is_available() else "CPU"
        return f"PyTorch: {torch.__version__} ({device})"

    def probe_plotly():
        import plotly
        return f"Plotly: {plotly.__version__}"

    def probe_prediction():
        from CodonTransformer.CodonPrediction import predict_dna_sequence  # noqa: F401
        return "CodonTransformer.CodonPrediction"

    def probe_evaluation():
        from CodonTransformer.CodonEvaluation import get_GC_content, calculate_tAI  # noqa: F401
        return "CodonTransformer.CodonEvaluation"

    probes = (
        ("Streamlit", probe_streamlit),
        ("PyTorch", probe_torch),
        ("Plotly", probe_plotly),
        ("CodonTransformer.CodonPrediction", probe_prediction),
        ("CodonTransformer.CodonEvaluation", probe_evaluation),
    )
    for label, probe in probes:
        try:
            print(f" OK: {probe()}")
        except ImportError as err:
            print(f" FAIL: {label}: {err}")
            return False
    return True
def test_protein_validation():
    """Check ``app.validate_protein_sequence`` against a table of known cases.

    Each case is ``(sequence, expected_valid, description)``. Every case is
    run and reported with an OK/FAIL line, and the check fails if ANY case's
    verdict differs from its expectation.

    Returns:
        bool: True only when every case matched its expected verdict; False
        if any case mismatched, or if an exception was raised (the traceback
        is printed).
    """
    print("\nTesting protein sequence validation...")
    try:
        # Import the validation function
        from app import validate_protein_sequence
        # Test cases: (sequence, expected validity, description)
        test_cases = [
            ("MKTVRQERLK", True, "Valid short sequence"),
            ("", False, "Empty sequence"),
            ("MKTVRQERLKX", False, "Invalid character X"),
            ("MK", False, "Too short"),
            ("M" * 501, False, "Too long"),
            ("mktvrqerlk", True, "Lowercase (should work)"),
            ("MKTVRQERLK*", True, "With stop codon"),
            ("MKTVRQERLK_", True, "With underscore stop"),
        ]
        all_matched = True
        for seq, expected_valid, description in test_cases:
            is_valid, message = validate_protein_sequence(seq)
            matched = is_valid == expected_valid
            status = "OK" if matched else "FAIL"
            print(f" {status} {description}: {message}")
            # Keep reporting the remaining cases, but remember the mismatch so
            # the check actually fails (previously it always returned True).
            if not matched:
                all_matched = False
        return all_matched
    except Exception as e:
        print(f" FAIL: Error in validation test: {e}")
        traceback.print_exc()
        return False
def test_metrics_calculation():
    """Sanity-check ``app.calculate_input_metrics`` on a short protein.

    Requests metrics for a 10-residue protein and the "Escherichia coli
    general" organism, then checks that:

    * the keys length, gc_content, baseline_dna, cai and tai are all present
      (each is printed as it is found);
    * ``length`` equals the protein's length;
    * ``gc_content`` lies within [0, 100].

    Returns:
        bool: True if every check passes; False at the first failing check or
        if an exception is raised (the traceback is printed).
    """
    print("\nTesting metrics calculation...")
    try:
        from app import calculate_input_metrics

        protein = "MKTVRQERLK"
        host = "Escherichia coli general"
        results = calculate_input_metrics(protein, host)

        # Stop at the first metric the GUI would need but did not get.
        for name in ('length', 'gc_content', 'baseline_dna', 'cai', 'tai'):
            if name not in results:
                print(f" FAIL: Missing metric: {name}")
                return False
            print(f" OK: {name}: {results[name]}")

        if not (results['length'] == len(protein)):
            print(" FAIL: Length calculation incorrect")
            return False
        print(" OK: Length calculation correct")

        # gc_content is compared against 0..100, i.e. treated as a percentage.
        if not (0 <= results['gc_content'] <= 100):
            print(" FAIL: GC content out of range")
            return False
        print(" OK: GC content in valid range")

        return True
    except Exception as e:
        print(f" FAIL: Error in metrics calculation: {e}")
        traceback.print_exc()
        return False
def test_visualization_functions():
    """Smoke-test the GUI's chart-building functions.

    Builds a GC-content plot from a fixed DNA string and a comparison chart
    from fixed before/after metric dicts. Only checks that neither call
    raises; the returned figures are not inspected.

    Returns:
        bool: True if both charts were built, False if anything raised (the
        traceback is printed).
    """
    print("\nTesting visualization functions...")
    try:
        from app import create_gc_content_plot, create_metrics_comparison_chart

        sample_dna = "ATGGCGAAAGCGCTGTATCGCGAAAGCGCTGTATCGCGAAAGCGCTGTATCGC"
        create_gc_content_plot(sample_dna)
        print(" OK: GC content plot created")

        baseline = {'gc_content': 50.0, 'cai': 0.5, 'tai': 0.3}
        improved = {'gc_content': 52.0, 'cai': 0.6, 'tai': 0.4}
        create_metrics_comparison_chart(baseline, improved)
        print(" OK: Metrics comparison chart created")

        return True
    except Exception as e:
        print(f" FAIL: Error in visualization test: {e}")
        traceback.print_exc()
        return False
def test_codon_evaluation():
    """Exercise the ``CodonTransformer.CodonEvaluation`` helpers directly.

    Computes the GC content of a 12-nt DNA string, then attempts a tAI
    calculation using ``get_ecoli_tai_weights()``. Any failure in the tAI step
    (including formatting its result) is printed as a NOTE and tolerated.

    Returns:
        bool: True unless the import or the GC-content step raised, in which
        case the traceback is printed and False is returned.
    """
    print("\nTesting CodonEvaluation functions...")
    try:
        from CodonTransformer.CodonEvaluation import get_GC_content, calculate_tAI, get_ecoli_tai_weights

        sample = "ATGGCGAAAGCG"
        gc_percent = get_GC_content(sample)
        print(f" OK: GC content calculation: {gc_percent:.1f}%")

        # tAI is optional here: the print stays inside the try so that even an
        # unformattable result is downgraded to a NOTE rather than a failure.
        try:
            weights = get_ecoli_tai_weights()
            tai_score = calculate_tAI(sample, weights)
            print(f" OK: tAI calculation: {tai_score:.3f}")
        except Exception as tai_error:
            print(f" NOTE: tAI calculation (may need scipy): {tai_error}")

        return True
    except Exception as e:
        print(f" FAIL: Error in CodonEvaluation test: {e}")
        traceback.print_exc()
        return False
def test_model_loading():
    """Lightweight check of the model-loading prerequisites (weights not loaded).

    Imports torch, ``transformers.AutoTokenizer``,
    ``CodonTransformer.CodonPrediction.load_model`` and
    ``transformers.BigBirdForMaskedLM``, loads the "adibvafa/CodonTransformer"
    tokenizer with ``AutoTokenizer.from_pretrained`` (presumably fetched from
    the Hugging Face Hub unless cached -- confirm), and reports whether the
    fine-tuned checkpoint file exists. The full model is deliberately not
    loaded.

    Returns:
        bool: True when all imports and the tokenizer load succeed (a missing
        checkpoint is only a NOTE); False if anything raised (the traceback
        is printed).
    """
    print("\nTesting model loading (mock)...")
    try:
        # The imports themselves are the check; the names are not otherwise used.
        import torch  # noqa: F401
        from transformers import AutoTokenizer
        from CodonTransformer.CodonPrediction import load_model  # noqa: F401

        print(" Testing tokenizer loading...")
        AutoTokenizer.from_pretrained("adibvafa/CodonTransformer")
        print(" OK: Tokenizer loaded successfully")

        print(" Testing load_model function...")
        from transformers import BigBirdForMaskedLM  # noqa: F401
        print(" OK: Model class available: BigBirdForMaskedLM")

        # NOTE(review): this path is resolved against the current working
        # directory, whereas test_file_structure resolves the same checkpoint
        # against the script's parent directory; confirm which one the app uses.
        checkpoint = "models/alm-enhanced-training/balanced_alm_finetune.ckpt"
        if not os.path.exists(checkpoint):
            print(f" NOTE: Fine-tuned model not found at: {checkpoint}")
        else:
            print(f" OK: Fine-tuned model found: {checkpoint}")

        # Loading the full model is skipped on purpose (~2GB per the author).
        print(" NOTE: Full model loading skipped in test (too large)")
        return True
    except Exception as e:
        print(f" FAIL: Error in model loading test: {e}")
        traceback.print_exc()
        return False
def test_file_structure():
    """Check that the GUI's own files sit next to this script.

    Looks for app.py, run_gui.py, requirements.txt and README.md in this
    script's directory, printing OK/FAIL for each. Also reports (as a NOTE
    only) whether models/alm-enhanced-training/balanced_alm_finetune.ckpt
    exists under the script directory's parent.

    Returns:
        bool: True when all four GUI files exist; the checkpoint does not
        affect the result.
    """
    print("\nTesting file structure...")
    here = Path(__file__).parent
    root = here.parent

    missing_count = 0
    for name in ("app.py", "run_gui.py", "requirements.txt", "README.md"):
        if not (here / name).exists():
            print(f" FAIL: {name} missing")
            missing_count += 1
        else:
            print(f" OK: {name}")

    # The checkpoint is optional: its absence is reported but never fails the check.
    checkpoint = root / "models" / "alm-enhanced-training" / "balanced_alm_finetune.ckpt"
    if checkpoint.exists():
        print(" OK: Fine-tuned model checkpoint found")
    else:
        print(" NOTE: Fine-tuned model checkpoint not found")

    return missing_count == 0
def test_post_processing():
    """Report which optional post-processing features ``app`` exposes.

    Reads the ``POST_PROCESSING_AVAILABLE`` and ``DNACHISEL_AVAILABLE`` flags
    from ``app``. An unavailable feature is reported as a NOTE and does not
    fail the check; DNAChisel is only reported when post-processing is
    available.

    Returns:
        bool: True if the flags could be imported; False if importing them
        raised (the traceback is printed, as in the other checks).
    """
    print("\nTesting post-processing features...")
    try:
        from app import POST_PROCESSING_AVAILABLE, DNACHISEL_AVAILABLE
        if POST_PROCESSING_AVAILABLE:
            print(" OK: Post-processing module available")
            if DNACHISEL_AVAILABLE:
                print(" OK: DNAChisel available")
            else:
                print(" NOTE: DNAChisel not available")
        else:
            print(" NOTE: Post-processing module not available")
        return True
    except Exception as e:
        print(f" FAIL: Error in post-processing test: {e}")
        # Print the traceback like the other checks do, so a failure while
        # importing app is diagnosable instead of reduced to one line.
        traceback.print_exc()
        return False
def main():
    """Run every GUI check in order and print a summary.

    Each check is called in turn. A truthy result counts as a pass; a falsy
    result, or an exception escaping the check, counts as a failure.
    Exceptions are reported and not re-raised. When every check passes, the
    commands for launching the GUI are printed. A fixed list of feature
    notes is always printed last.

    Returns:
        bool: True only if every check passed.
    """
    print("ENCOT GUI Test Suite")
    print("=" * 50)

    suite = (
        ("File Structure", test_file_structure),
        ("Imports", test_imports),
        ("Protein Validation", test_protein_validation),
        ("Metrics Calculation", test_metrics_calculation),
        ("Visualization Functions", test_visualization_functions),
        ("CodonEvaluation Functions", test_codon_evaluation),
        ("Model Loading", test_model_loading),
        ("Post-Processing", test_post_processing),
    )

    passed = 0
    for name, check in suite:
        # The verdict is evaluated inside the try so that any error while
        # running or judging a check is reported as an ERROR for that check.
        try:
            if check():
                passed += 1
                print(f"OK: {name}: PASSED")
            else:
                print(f"FAIL: {name}: FAILED")
        except Exception as exc:
            print(f"FAIL: {name}: ERROR - {exc}")

    total = len(suite)
    everything_passed = passed == total

    print("\n" + "=" * 50)
    print(f"Test Results: {passed}/{total} tests passed")
    if not everything_passed:
        print("Some tests failed. Please check the issues above.")
    else:
        print("All tests passed. The GUI should work correctly.")
        launch_help = (
            "\nTo run the GUI:",
            " python run_gui.py",
            " or",
            " cd streamlit_gui && streamlit run app.py --server.address=0.0.0.0",
        )
        for line in launch_help:
            print(line)

    print("\nNotes:")
    features = (
        "Fine-tuned model integration",
        "Enhanced constrained beam search",
        "Post-processing with DNAChisel",
        "Advanced sequence analysis",
        "Improved parameter controls",
    )
    for feature in features:
        print(f" • {feature}")

    return everything_passed
if __name__ == "__main__":
    # Turn the suite's overall verdict into a process exit status (0 = all
    # checks passed, 1 = at least one failed) so a calling shell can detect it.
    sys.exit(0 if main() else 1)
|