Spaces:
Sleeping
Sleeping
| """ | |
| CRISPR Array Detection - HuggingFace Spaces App | |
| """ | |
| import os | |
| import html | |
| import logging | |
| import tempfile | |
| os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' | |
| os.environ.setdefault("MPLCONFIGDIR", os.path.join(tempfile.gettempdir(), "matplotlib")) | |
| import gradio as gr | |
| import numpy as np | |
| import matplotlib | |
| matplotlib.use('Agg') | |
| import matplotlib.pyplot as plt | |
| from matplotlib.colors import TwoSlopeNorm, ListedColormap | |
| from matplotlib.collections import LineCollection | |
| import umap | |
| from sklearn.cluster import AgglomerativeClustering | |
| import plotly.graph_objects as go | |
| from plotly.subplots import make_subplots | |
| import plotly.express as px | |
| from inference import predict_sequence, embed_sequence, WINDOW_SIZE | |
| from inference.model_loader import get_model, warmup_model, get_gpu_status | |
| from inference.tokenizer import validate_sequence, strip_fasta_header | |
| from inference.inference import detect_crispr_regions | |
# Root logging verbosity comes from the LOG_LEVEL environment variable (default "INFO").
logging.basicConfig(level=os.environ.get("LOG_LEVEL", "INFO"))
logger = logging.getLogger(__name__)

# Runtime limits and defaults. Each value can be overridden through an
# environment variable; the literal passed to os.environ.get is the fallback.

# Longest accepted input sequence in bp (enforced by normalize_sequence_input).
MAX_SEQUENCE_LENGTH = int(os.environ.get("MAX_SEQUENCE_LENGTH", "50000"))
# Default is 2 MiB. Presumably the cap on uploaded file size -- usage is not visible in this part of the file.
MAX_UPLOAD_BYTES = int(os.environ.get("MAX_UPLOAD_BYTES", str(2 * 1024 * 1024)))
# Presumably the longest sequence rendered by the sequence viewer -- TODO confirm against its caller.
MAX_SEQUENCE_VIEWER_LENGTH = int(os.environ.get("MAX_SEQUENCE_VIEWER_LENGTH", "20000"))
# Read from GRADIO_QUEUE_MAX_SIZE (note the different variable name).
QUEUE_MAX_SIZE = int(os.environ.get("GRADIO_QUEUE_MAX_SIZE", "8"))
# Default window stride in bp; validate_stride accepts values in the range 50-500.
DEFAULT_STRIDE = int(os.environ.get("DEFAULT_STRIDE", "500"))
# Default score cutoff; validate_threshold accepts values in the range 0-1.
DEFAULT_THRESHOLD = float(os.environ.get("DEFAULT_THRESHOLD", "0.3"))
# Custom CSS - minimal monochrome design using Inter (UI text) and Geist Mono (code) fonts
| CUSTOM_CSS = """ | |
| @import url('https://fonts.googleapis.com/css2?family=Inter:wght@300;400;500;600&display=swap'); | |
| @font-face { | |
| font-family: 'Geist Mono'; | |
| src: url('https://cdn.jsdelivr.net/npm/geist@1.2.0/dist/fonts/geist-mono/GeistMono-Regular.woff2') format('woff2'); | |
| font-weight: 400; | |
| } | |
| @font-face { | |
| font-family: 'Geist Mono'; | |
| src: url('https://cdn.jsdelivr.net/npm/geist@1.2.0/dist/fonts/geist-mono/GeistMono-Medium.woff2') format('woff2'); | |
| font-weight: 500; | |
| } | |
| * { | |
| font-family: 'Inter', -apple-system, BlinkMacSystemFont, system-ui, sans-serif !important; | |
| } | |
| code, pre, .code, textarea, .prose code { | |
| font-family: 'Geist Mono', 'SF Mono', Consolas, monospace !important; | |
| } | |
| h1 { | |
| font-weight: 500 !important; | |
| letter-spacing: -0.02em !important; | |
| } | |
| h2, h3, h4 { | |
| font-weight: 500 !important; | |
| color: #18181b !important; | |
| } | |
| .gradio-container { | |
| max-width: 100% !important; | |
| background: #fafafa !important; | |
| } | |
| .gr-button-primary { | |
| background: #18181b !important; | |
| border: none !important; | |
| } | |
| .gr-button-primary:hover { | |
| background: #27272a !important; | |
| } | |
| .gr-button-secondary { | |
| background: #fff !important; | |
| border: 1px solid #e4e4e7 !important; | |
| color: #18181b !important; | |
| } | |
| .gr-panel { | |
| border: 1px solid #e4e4e7 !important; | |
| background: #fff !important; | |
| } | |
| /* Minimal table styling */ | |
| table { | |
| border-collapse: collapse !important; | |
| } | |
| th, td { | |
| border-bottom: 1px solid #e4e4e7 !important; | |
| padding: 8px 12px !important; | |
| } | |
| th { | |
| font-weight: 500 !important; | |
| text-transform: uppercase !important; | |
| font-size: 11px !important; | |
| letter-spacing: 0.05em !important; | |
| color: #71717a !important; | |
| } | |
| /* Slider styling */ | |
| input[type="range"] { | |
| accent-color: #18181b !important; | |
| } | |
| /* Tab styling */ | |
| .tab-nav button { | |
| font-weight: 400 !important; | |
| color: #52525b !important; | |
| } | |
| .tab-nav button.selected { | |
| color: #18181b !important; | |
| border-bottom: 2px solid #18181b !important; | |
| } | |
| """ | |
| # Real example sequences from training data | |
| CRISPR_EXAMPLE = """TCCCCATTCGAGAGCAAGATCCACTAAAACAAGGATTGAAACTCTGTTTACTTCCCTCTATATCTTTTTTTGTTCGGTCATTCGAGAGCAAGATCCACTAAAACAAGGATTGAAACTAAAATCACACTCACAGCCAATACAAGCGGGGGGGGAAATTCGAGAGCAAGATCCACTAAAACAAGGATTGAAACTTGCAGTAGGGCAGACTGGCAGTTTTCGGGTAATGATTCGAGAGCAAGATCCACTAAAACAAGGATTGAAACATTCATACGAATAATCATTTCCGAAAGACTCCTTTTATTCGAGAGCAAGATCCACTAAAACAAGGATTGAAACAGGTCATGAGCATTCAAAACGTTCTCCCCGTTCAATTCGAGAGCAAGATCCACTAAAACAAGGATTGAAACTAGCCTGGACCAAATAATGTACGAACCTCTCCATCTATTCGAGAGCAAGATCCACTAAAACAAGGATTGAAACATGAATTATATAACAGGGATTAAAATTTTTCTTATTATTCGAGAGCAAGATCCACTAAAACAAGGATTGAAACTAAATTTGAGCAAATACTAAAAAAATGAGACAAAAAGATTCGAGAGCAAGATCCACTAAAACAAGGATTGAAACTCCGGCAATGAATTGATAGGACTTAAAATAATTGTATTCGAGAGCAAGATCCACTAAAACAAGGATTGAAACACCAAGAAAATGAAAGAAATTTTCTTTGGAGAAACATTCGAGAGCAAGATCCACTAAAACAAGGATTGAAACTGAGCCATCGACGGTCTCCGGAAGTAAAACCCCAAAATTCGAGAGCAAGATCCACTAAAACAAGGATTGAAACACATTCAAGTCGCTGCCTACCGTTGAAACATGGAAATTCGAGAGCAAGATCCACTAAAACAAGGATTGAAACCTTAATGGAAAGGCACGTAATACAAACGCGGGTAAATTCGAGAGCAAGATCCACTAAAACAAGGATTGAAACTATCACGTTGAA""" | |
| NON_CRISPR_EXAMPLE = """TTCGTTCATTTTTCTGGTTTGACCAATAGCATTTAAAGCCGCCCCACATAAATCATTTGGCCCGGAAACCTTTTGGTAAGATAAAATAGTGCCATCACTAGAAAGCTCAATTCTAATATCACAGACCTTACCAAAAAAACGACTATCACCACGAAAATATGGCTTCACAGATTTATAGATAATACCTGCATACCTCGAACCCGCCTTACCGCCATCGCCTGCGCCCGCTGAAGCACCTGAACCCTGGCTGCCTTGTTTATTTATGTTACTACCTTGCAAAGCACTGCCGCCGCCCACATCACCGCCATTAAAAAAATCATCTAATGCATTTTGATCAGCTCGACGTTGCGCATCCGCTTTCGCTTTGGCTTCCGCATCTGCCTTAGCCTTGGCGTCAGCTTTCGCTTTAGCATCGGCCTTAGCTTTTGCTTCTGCGTCAGCTTTCGCTTTAGCATCCGCTTTGGCTTTTGCATCAGCTTCCGCCTTCGCTTTAGCCTGCGCCTCTGCTTTTGCCTTAGCTTCCGCCTCTGCTTTGGCTTTTGCCTCTTGTTTTGCTTTCTCGTCTGCCTGCACTTTGGCTTTTGCTTCTTCTTCCGCTTGTTTTGCAGCATCTGCTAAACGTTTTGCCTCAGCTTCAGCTTTGAGCTTAGCGGCTTCAGCAGCCTGTTTGGCTTTCGCCTCTTCTGCCTGTTTTTGCTTTTCCAAGGCTTCTAAGCGAGCTTTTTCTTCCTGTTGCTTTTTCTGTTCAGCCAAGAACCTTTGTCTTTCCAGCTCTTTTTGCTGTTCCAGTTCTTTCTGACGGGCAATTTCCTGTTGACGTTGTTGCTCTTTTAACACTTCCCGTTGTTTTTCTTCTTCCCGTTTCTGCTCTTCACGTTTTTGGTCTTCAAGGGCTTGTTTCTTTTGTCTGTCCGCCTGACCTTTTTTCTGTTGTTGAATTCGCCCCCATTCCTGCGCTGCCGAGCCCGTATCAACCATCACGGCGCCGATAACTTCACCG""" | |
| # Flanked CRISPR example: upstream (500bp) + CRISPR array (10 repeats) + downstream (500bp) | |
# This gives a clear visualization: low scores on the flanks and high scores in the middle
| FLANKED_CRISPR_EXAMPLE = """ATGCGATCGATCGATCGATCGTAGCTAGCTAGCTGATCGATCGATCGTAGCTAGCTAGCTGATCGATCGATCGATCGATCGTAGCTAGCTAGCTAGCTAGCTGATCGATCGATCGTAGCTAGCTAGCTGATCGATCGATCGATCGATCGTAGCTAGCTAGCTAGCTAGCTGATCGATCGATCGTAGCTAGCTAGCTGATCGATCGATCGATCGATCGTAGCTAGCTAGCTAGCTAGCTGATCGATCGATCGTAGCTAGCTAGCTGATCGATCGATCGATCGATCGTAGCTAGCTAGCTAGCTAGCTGATCGATCGATCGTAGCTAGCTAGCTGATCGATCGATCGATCGATCGTAGCTAGCTAGCTAGCTAGCTGATCGATCGATCGTAGCTAGCTAGCTGATCGATCGATCGATCGATCGTAGCTAGCTAGCTAGCTAGCTGATTCCCCATTCGAGAGCAAGATCCACTAAAACAAGGATTGAAACTCTGTTTACTTCCCTCTATATCTTTTTTTGTTCGGTCATTCGAGAGCAAGATCCACTAAAACAAGGATTGAAACTAAAATCACACTCACAGCCAATACAAGCGGGGGGGGAAATTCGAGAGCAAGATCCACTAAAACAAGGATTGAAACTTGCAGTAGGGCAGACTGGCAGTTTTCGGGTAATGATTCGAGAGCAAGATCCACTAAAACAAGGATTGAAACATTCATACGAATAATCATTTCCGAAAGACTCCTTTTATTCGAGAGCAAGATCCACTAAAACAAGGATTGAAACAGGTCATGAGCATTCAAAACGTTCTCCCCGTTCAATTCGAGAGCAAGATCCACTAAAACAAGGATTGAAACTAGCCTGGACCAAATAATGTACGAACCTCTCCATCTATTCGAGAGCAAGATCCACTAAAACAAGGATTGAAACATGAATTATATAACAGGGATTAAAATTTTTCTTATTATTCGAGAGCAAGATCCACTAAAACAAGGATTGAAACTAAATTTGAGCAAATACTAAAAAAATGAGACAAAAAGATTCGAGAGCAAGATCCACTAAAACAAGGATTGAAACTCCGGCAATGAATTGATAGGACTTAAAATAATTGTATTCGAGAGCAAGATCCACTAAAACAAGGATTGAAACTATCACGTTGAACGATCGATCGATCGATCGTAGCTAGCTAGCTGATCGATCGATCGTAGCTAGCTAGCTGATCGATCGATCGATCGATCGTAGCTAGCTAGCTAGCTAGCTGATCGATCGATCGTAGCTAGCTAGCTGATCGATCGATCGATCGATCGTAGCTAGCTAGCTAGCTAGCTGATCGATCGATCGTAGCTAGCTAGCTGATCGATCGATCGATCGATCGTAGCTAGCTAGCTAGCTAGCTGATCGATCGATCGTAGCTAGCTAGCTGATCGATCGATCGATCGATCGTAGCTAGCTAGCTAGCTAGCTGATCGATCGATCGTAGCTAGCTAGCTGATCGATCGATCGATCGATCGTAGCTAGCTAGCTAGCTAGCTGATCGATCGATCGTAGCTAGCTAGCTGATCGATCGATCGATCGATCGTAGCTAGCTAGCTAGCTAGCTGATCGATCGATCGTAGCTAGCTAGCTGAT""" | |
| # E. coli K-12 MG1655 CRISPR I-E example (based on real genomic region) | |
| # Contains the characteristic 29bp repeat: CGGTTTATCCCCGCTGGCGCGGGGAACTC | |
| # Structure: ~400bp upstream (cas genes) + CRISPR array (8 repeats + 7 spacers) + ~400bp downstream | |
| ECOLI_CRISPR_EXAMPLE = """ATGGATGAACGAAATCGTCAGGTGCTGGAACAACGCCTGCGCCAGCATATCGATGCGCTGGAAGCGCGCAGCAATGATGTCACCTGCCAGACGCTGGAACTGCTGCGCGATGGCGACGTACTGGATGCCGTGCTGGCGGATGCCCGCAAAGAGCTGGACGCACACCGCTTCCTGCTGGAAGACGGCTACACCACGCTGCAACAGATCGCCAACCTGCCGGGCGTGACCTCGATGCTGGACGACGGCGACATCCACCTGCACTGCGTGCTCGGCGTGCCGCAGCGCCGTGGCGAACATATCGAACAGTTCGCCCGCGAGCATTACCAGAATCCGCTGCAAACGCTGCGCGAGTGACGGTTTATCCCCGCTGGCGCGGGGAACTCGAAAGCTACGTTGATATTGCGCTATCTCATCGACGGTTTATCCCCGCTGGCGCGGGGAACTCTGCAGAACTCGAGGGATGAAACGGTCTTGCGGTTTATCCCCGCTGGCGCGGGGAACTCAATGAAGAAATGCTTCGATTTCGTAGCCGTTCGGTTTATCCCCGCTGGCGCGGGGAACTCGTTGTCTGGATGGATCGATCAATCTCATACAACGGTTTATCCCCGCTGGCGCGGGGAACTCCAGAACGATTCGCCACGGTCTGTTGATTAACCGGTTTATCCCCGCTGGCGCGGGGAACTCTGAAGTTGATGATGATTCCGATCAGCACCACGGTTTATCCCCGCTGGCGCGGGGAACTCATGATCTTGCAGGCGCGCCAGCACTTCAGCCATCGGTTTATCCCCGCTGGCGCGGGGAACTCGCGATGGCGATTTCATTACTGATGCGGCGTGAGCGTGGTGCAACATCCGCGCCCGCTGACGCGTTTTTTTGTATCCGGATAGCGTCAGCCGATGGCTGAAGCGGCGAGCAAGCTCTGAAGCGCAGCGCAATCGCGCCCTGATGGCGATGGCGCGTAATGATTTCACCGACGATATCGACATCGATATCGTCCAGGCTGCGCAGGATCAGGGCGATACGCAAACGCCCGCCTTCGCCAGCGATAATGCTGCCGCCACCCAGCAGCGCGCCCCAGAACACGGCGGCGAGGATGACGATGAAGCCGAAACGCCACAGCAGGCTGCCACAGCC""" | |
| # Longer examples for State-Dynamic Plot (upstream + CRISPR array + downstream) | |
| # Structure: ~600bp upstream | CRISPR array (25 repeats + 24 spacers) | ~600bp downstream | |
| # Total: ~3000 bp - ideal for seeing alternating patterns in State-Dynamic Plot | |
| EMBEDDING_CRISPR_EXAMPLE = """GACAGGTACAAGAAGGAGTATGCATCAATGTGGTCGTGTGGAACAAACGCCACTGGAGACTGGGTTAACCATTCGCTCCAGCGTCATGAAAGTCACTGTTAGGGCGACCTTCGATTCGGATGTGACATTTCATTACATTACGCTCAGGACTGCGAACGAAAGATTAAGAATGCTTAACCCGGTACCTAACCCATCTGATTTTTACACACTCTCCTTGGACTGGGAGGTATAAGGAATAGGCGGTAGACGCCTACTTAACTTTCATGGTGATCGTAAAGCGGAGCCTTACCATGCGGCAATTGTGAACTTTTAAATTCGATTTTTAGCTTTTCTATTATCCTAAACTTCGCTGTATATCACGCGGCGCGATGGGGCAGCCTGCCCCCACTGTGCGACCGGCCACTTAAGGCTTGAAAACTACGAGCAGATTACATGAATCTGTGTTGGGTGTGCCAGTGGCACCCGAAGGACGCACTGGTTCACTTTCGGGAACACGCACAGACGAGACACACTCTTCAAGTCGTGTTAAAAGGAGTAGGATTAACGTCGAGGATTGATTCCCGCTTATGTGCGTCTGCCGCTTATACGCATAATCTGCATGTTTTAGAGCTATGCTGTTTTGAATGGTCCCAAAACACCCAGCCATATTGGCGTTCTGCCAAATCGGAACCGGTTTTAGAGCTATGCTGTTTTGAATGGTCCCAAAACCCAATCAAATATTTACTACATTTACCGACCGCGCTCGTTTTAGAGCTATGCTGTTTTGAATGGTCACAAAACAGCAATCTTCGTAAATGCTAAAGGATCGGGGCACGAGTTTTAGTTCTATGCTGTGTTGAATGGTCCCAAAACCAGCTAGCTCCCTCAGCTCACCTACACCCGACCGTGGTTTTAGAGTAATGCCGTTTTGAATGGTCCCAAAACCATCTCAGTCCAGTTGTGTGAAATAGCTGGACTGGTGTTTTAGAGCTATGCTGTTTTGAATGGTCCCAAAACGTGTAAGGGTCGCGCTCTGCAACCAGCGGTTACGCCGTTTTAGAGCTATGCTGTTTTGAATGGTCCCAAAACAGAGTCTGATCTCTTAGGAACCCGGCGATGCCTGGCGTTTTAGATCTATGCTGTTTTGAATGGTCCCAAAACGTCCTCGGGTGTCCTCCTTTGGCCGTGCGGTCCTAAGTTTTAGAGCTATGCTGTTTTGAATGGTCCCAAAACGTATCTTATTAGTCACGTCCGGTAGCTCGGGACCGAGTTTTAGAGCTATGCTGTTTTGAATGGTCCCAAAACGCGGAATTACGAGAGGGACGAAGAGTCGCACTGCTGGGTTTAGAGCTATGCTGTTTTGAATGGTCCCAAAACACAATTAGGTTAAGCGTAACGTTATATGGTTATTGCGTTTTAGGGCTATGCTGTTTTGAATGGTCCCAAAACGTATGGCTCTATTATCAGATGTCGCCGCATCTTCCGGTTTTAGAGCTATGCTGTTTTGAATGGACCCAAAATCGGGCCAGAGCTATGTTAAAAGTCCCCGTAGTGTTAGTTCTAGAGCTATGCTGTTTTGAACGGTCCCAAAACTATGGTACTCTTCTACTCCTCGGAGTGAAGGGCAACGTTTTAGGGCTATGCTGTTTTGAATGGTCCCAAAACTCGTCCTTTACTACTTCGCGACTCAGGGGGTCGCCGGTTTTAGAGCTATGCTGTTTTGAATGGTCCCAAAACGACCGGATCCTATGCCTGCAGCAAGACATTGGGCCCGTTTTAGAGCTATGCTGTTTTGAATGGTCCCAAAACCAGGCACAGGGTGCACCACAATTGCGCTCAATCCGAGTTTTAGAGCTATGCTGTTTTGAATGGTCCCAAAACCGCGCCTTGATTTTTATAGTTGCGCCCGTAGCTCTCATTTTAGAGCTATGCTGTATTCATTGGTCCCAAAACCCGAGGACAAGAGTTCAACGACTATTATAGAGCGGA
GTTTTAGAGCTATGCTGTTTTGAATGGTCCCAAAACATGCGTTTAACAATGGGCGAGGCCGATGCGTGAGGTGTTTTAGAGCTATGCTGTTTTGTATGGTCCCAAAACGAGATACCATTGTGCCCGCACGTATTTACCTCGAAGGTTTTAGAGCTGTGGTGTTTTGAATGGTCCCAAAACAGCTACCTGGCCAATGAACCGTACCAAGTGATCAACGTTTTAGAGCTATGCTGTTTTGAATGGTCCCAAAACCCCCTCCAGGGCTCACCGTATGACGCTCGCCAGATCGTTTTAGAGCTATGCTGTTTTGAATGGTCCCAAAACAGATAAGTTCAGTTTTTCTCACAATTTGATGTTAGAGTTTTAGAGGTATGCTGTTTTGAATGGTCCCTAAACAGCTGGCTAAGCGCGCGCGCCAAAGTAACGTGCAAAAAGCTGGATCTGCCAATCTCAGAAGCTATGTAGCCTTCGGGTAAGAAAACGCAGGCGTTGGTCGGTTAACGGCAGGTGCAACCCATTGTTGCATCGTAGGCACCGTCGCTTGCCCTCGTGGCACTGTAGTCGATGAAGGATTCATCGGCTTAGCTGTTCTCTGTCCGTCAGCGGCCAGGATAGGTCGTTCAGGTTCGCGCGACTCGGTTTCCGTTAAGTTGCAGTCGTATCCAGGTAATGATACCCATTGACCGGCCTACCAGGTCTGCGGGAGCTCTGCGGGGGTGTGCCGGACGAAGTGTTCTCTGCATATTGTTTCTAGCGGGTTAAATGTAATTCCATCCATACGGTCGACACCTACCTTAGGTCCAATCGGGATAAGATAATCATATAACAGAATACAAGGGCTGAGTATTGCTACCGCTAAGACGGCTGCGAGTGTGACACCCACGCATATAAGTGGGCACGTTGTGCGAGAATCTGTTTTGGATTCAGCCATGCAGAGACCCGTGAAAGGCGCCCTACCGCGACGACAACCAGACGGTTATAATTGGGCAACTGTTA""" | |
| # Random genomic sequence (no CRISPR) - for comparison in State-Dynamic Plot | |
| EMBEDDING_RANDOM_EXAMPLE = """ATGCGATCGATCGATCGATCGATCGTAGCTAGCTAGCTAGCTAGCTGATCGATCGATCGTAGCTAGCTAGCTGATCGATCGATCGATCGATCGTAGCTAGCTAGCTAGCTAGCTGATCGATCGATCGTAGCTAGCTAGCTGATCGATCGATCGATCGATCGTAGCTAGCTAGCTAGCTAGCTGATCGATCGATCGTAGCTAGCTAGCTGATCGATCGATCGATCGATCGTAGCTAGCTAGCTAGCTAGCTGATCGATCGATCGTAGCTAGCTAGCTGATCGATCGATCGATCGATCGTAGCTAGCTAGCTAGCTAGCTGATCGATCGATCGTAGCTAGCTAGCTGATCGATCGATCGATCGATCGTAGCTAGCTAGCTAGCTAGCTGATCGATCGATCGTAGCTAGCTAGCTGATCGATCGATCGATCGATCGTAGCTAGCTAGCTAGCTAGCTGATCGATCGATCGTAGCTAGCTAGCTGATCGATCGATCGATCGATCGTAGCTAGCTAGCTAGCTAGCTGATCGATCGATCGTAGCTAGCTAGCTGATCGATCGATCGATCGATCGTAGCTAGCTAGCTAGCTAGCTGATCGATCGATCGTAGCTAGCTAGCTGATCGATCGATCGATCGATCGTAGCTAGCTAGCTAGCTAGCTGATCGATCGATCGTAGCTAGCTAGCTGATCGATCGATCGATCGATCGTAGCTAGCTAGCTAGCTAGCTGATCGATCGATCGTAGCTAGCTAGCTGATCGATCGATCGATCGATCGTAGCTAGCTAGCTAGCTAGCTGATCGATCGATCGTAGCTAGCTAGCTGATCGATCGATCGATCGATCGTAGCTAGCTAGCTAGCTAGCTGATCGATCGATCGTAGCTAGCTAGCTGATCGATCGATCGATCGATCGTAGCTAGCTAGCTAGCTAGCTGATCGATCGATCGTAGCTAGCTAGCTGATCGATCGATCGATCGATCGTAGCTAGCTAGCTAGCTAGCTGATCGATCGATCGTAGCTAGCTAGCTGATCGATCGATCGATCGATCGTAGCTAGCTAGCTAGCTAGCTGATCGATCGATCGTAGCTAGCTAGCTGATCGATCGATCGATCGATCGTAGCTAGCTAGCTAGCTAGCTGATCGATCGATCGTAGCTAGCTAGCTGATCGATCGATCGATCGATCGTAGCTAGCTAGCTAGCTAGCTGATCGATCGATCGTAGCTAGCTAGCTGATCGATCGATCGATCGATCGTAGCTAGCTAGCTAGCTAGCTGATCGATCGATCGTAGCTAGCTAGCTGATCGATCGATCGATCGATCGTAGCTAGCTAGCTAGCTAGCTGATCGATCGATCGTAGCTAGCTAGCTGATCGATCGATCGATCGATCGTAGCTAGCTAGCTAGCTAGCTGATCGATCGATCGTAGCTAGCTAGCTGATCGATCGATCGATCGATCGTAGCTAGCTAGCTAGCTAGCTGATCGATCGATCGTAGCTAGCTAGCTGATCGATCGATCGATCGATCGTAGCTAGCTAGCTAGCTAGCTGATCGATCGATCGTAGCTAGCTAGCTGATCGATCGATCGATCGATCGTAGCTAGCTAGCTAGCTAGCTGATCGATCGATCGTAGCTAGCTAGCTGATCGATCGATCGATCGATCGTAGCTAGCTAGCTAGCTAGCTGATCGATCGATCGTAGCTAGCTAGCTGATCGATCGATCGATCGATCGTAGCTAGCTAGCTAGCTAGCTGATCGATCGATCGTAGCTAGCTAGCTGATCGATCGATCGATCGATCGTAGCTAGCTAGCTAGCTAGCTGATCGATCGATCGTAGCTAGCTAGCTGATCGATCGATCGATCGATCGTAGCTAGCTAGCTAGCTAGCTGATCGATCGATCGTAGCTAGCTAGCTGATCGATCGATCGATCGATCGTAGCTAGCTAGCTAGCTAGCTGATCGATCGATCGTAGCT
AGCTAGCTGATCGATCGATCGATCGATCGTAGCTAGCTAGCTAGCTAGCTGATCGATCGATCGTAGCTAGCTAGCTGATCGATCGATCGATCGATCGTAGCTAGCTAGCTAGCTAGCTGATCGATCGATCGTAGCTAGCTAGCTGATCGATCGATCGATCGATCGTAGCTAGCTAGCTAGCTAGCTGATCGATCGATCGTAGCTAGCTAGCTGATCGATCGATCGATCGATCGTAGCTAGCTAGCTAGCTAGCTGATCGATCGATCGTAGCTAGCTAGCTGATCGATCGATCGATCGATCGTAGCTAGCTAGCTAGCTAGCTGATCGATCGATCGTAGCTAGCTAGCTGATCGATCGATCGATCGATCGTAGCTAGCTAGCTAGCTAGCTGATCGATCGATCGTAGCTAGCTAGCTGATCGATCGATCGATCGATCGTAGCTAGCTAGCTAGCTAGCTGATCGATCGATCGTAGCTAGCTAGCTGATCGATCGATCGATCGATCGTAGCTAGCTAGCTAGCTAGCTGATCGATCGATCGTAGCTAGCTAGCT""" | |
| def _count_fasta_records(text: str) -> int: | |
| return sum(1 for line in text.splitlines() if line.strip().startswith(">")) | |
def normalize_sequence_input(sequence: str) -> tuple[bool, str, str]:
    """Clean and validate a single-sequence FASTA/raw DNA input.

    Returns:
        A ``(is_valid, cleaned_sequence, error_message)`` tuple. On success
        the message is ``""``. For empty or multi-FASTA input the cleaned
        sequence is ``""``; for later failures it is the output of
        ``strip_fasta_header``.
    """
    empty_result = (False, "", "Sequence is empty")
    if sequence is None:
        return empty_result

    raw = str(sequence).strip()
    if raw == "":
        return empty_result

    # More than one ">" header line means the user pasted several records.
    if _count_fasta_records(raw) > 1:
        return (
            False,
            "",
            "Multi-FASTA input is not supported. Please submit one sequence at a time.",
        )

    cleaned = strip_fasta_header(raw)
    ok, problem = validate_sequence(cleaned)
    if not ok:
        return False, cleaned, problem

    n_bases = len(cleaned)
    if n_bases > MAX_SEQUENCE_LENGTH:
        too_long = f"Sequence too long: {n_bases:,} bp > {MAX_SEQUENCE_LENGTH:,} bp limit"
        return False, cleaned, too_long

    return True, cleaned, ""
def validate_stride(stride) -> tuple[bool, int, str]:
    """Check that *stride* is a whole number of bp within 50-500.

    Returns:
        ``(is_valid, stride_as_int, error_message)``. When the value cannot be
        read as an integer at all, the integer slot is ``0``; when it is an
        integer outside the range, the parsed value is returned with the error.
    """
    not_an_integer = (False, 0, "Stride must be an integer between 50 and 500 bp")

    # bool is a subclass of int, so it has to be rejected explicitly.
    if isinstance(stride, bool):
        return not_an_integer
    # Fractional, infinite and NaN floats are not whole numbers.
    if isinstance(stride, float) and not stride.is_integer():
        return not_an_integer

    try:
        value = int(stride)
    except (TypeError, ValueError):
        return not_an_integer

    if value < 50 or value > 500:
        return False, value, "Stride must be between 50 and 500 bp"
    return True, value, ""
def validate_threshold(threshold) -> tuple[bool, float, str]:
    """Check that *threshold* converts to a float in the closed interval [0, 1].

    Returns:
        ``(is_valid, threshold_as_float, error_message)``. The float slot is
        ``0.0`` when the value cannot be converted, otherwise the converted value.
    """
    try:
        value = float(threshold)
    except (TypeError, ValueError):
        return False, 0.0, "Threshold must be a number between 0 and 1"

    # NaN fails both comparisons and is therefore reported as out of range.
    in_unit_interval = 0.0 <= value <= 1.0
    if in_unit_interval:
        return True, value, ""
    return False, value, "Threshold must be between 0 and 1"
def validate_min_length(min_length) -> tuple[bool, int, str]:
    """Check that *min_length* is a whole number of bp that is at least 1.

    Booleans are rejected, matching ``validate_stride``: ``bool`` is a subclass
    of ``int``, so ``True`` would otherwise be accepted as a length of 1.

    Returns:
        ``(is_valid, min_length_as_int, error_message)``. The integer slot is
        ``0`` when the value is not an integer; when it is an integer below 1,
        the parsed value is returned together with the error.
    """
    not_an_integer = (False, 0, "Minimum region length must be an integer")

    if isinstance(min_length, bool):
        return not_an_integer
    try:
        # Fractional, infinite and NaN floats are not whole numbers.
        if isinstance(min_length, float) and not min_length.is_integer():
            raise ValueError
        min_length = int(min_length)
    except (TypeError, ValueError):
        return not_an_integer

    if min_length < 1:
        return False, min_length, "Minimum region length must be at least 1 bp"
    return True, min_length, ""
def prediction_error_outputs(message: str):
    """Build the nine-element output tuple returned when a prediction fails.

    The second slot carries the Markdown-formatted error text; every other
    slot is an empty placeholder (``None``, ``[]`` or ``""``). Presumably the
    slots line up with the prediction handler's output components -- confirm
    against the UI wiring.
    """
    error_markdown = f"**Error**: {message}"
    return (None, error_markdown, [], None, None, None, None, None, "")
def embedding_error_outputs(message: str):
    """Build the four-element output tuple returned when embedding fails.

    Only the second slot is populated, with the Markdown-formatted error text;
    the remaining slots are ``None``.
    """
    error_markdown = f"**Error**: {message}"
    return (None, error_markdown, None, None)
def make_output_dir(prefix: str) -> str:
    """Create a new temporary directory named ``<prefix>_...`` and return its path.

    The directory is created with ``tempfile.mkdtemp`` and is not removed by
    this function.
    """
    directory_prefix = "{}_".format(prefix)
    output_dir = tempfile.mkdtemp(prefix=directory_prefix)
    return output_dir
def symmetric_activation_norm(values) -> TwoSlopeNorm:
    """Return a ``TwoSlopeNorm`` centred on 0 with symmetric limits ``[-vmax, vmax]``.

    ``vmax`` is the largest absolute finite value in *values*. It falls back to
    1.0 when there are no finite values or when they are all zero, so the norm
    always has a non-zero span.
    """
    data = np.asarray(values, dtype=float)
    finite_values = data[np.isfinite(data)]

    vmax = 1.0
    if finite_values.size > 0:
        lowest = abs(float(np.nanmin(finite_values)))
        highest = abs(float(np.nanmax(finite_values)))
        largest_magnitude = max(lowest, highest)
        if largest_magnitude > 0:
            vmax = largest_magnitude

    return TwoSlopeNorm(vmin=-vmax, vcenter=0, vmax=vmax)
def create_prediction_plot(positions, probabilities, threshold=0.3, regions=None):
    """Create a matplotlib figure showing the prediction curve (for PNG/PDF export).

    Args:
        positions: Sequence (list or NumPy array) of x positions in bp.
        probabilities: Scores aligned with *positions*; plotted on a 0-1 axis.
        threshold: Score cutoff drawn as a dashed line; points at or above it
            are shaded as predicted CRISPR.
        regions: Accepted for signature parity with
            ``create_interactive_prediction_plot``; not drawn in this figure.

    Returns:
        The matplotlib ``Figure``.
    """
    fig, ax = plt.subplots(figsize=(12, 4))

    # Plot probability curve
    ax.fill_between(positions, probabilities, alpha=0.3, color='blue')
    ax.plot(positions, probabilities, color='blue', linewidth=0.5)

    # Add threshold line
    ax.axhline(y=threshold, color='red', linestyle='--', alpha=0.7, label=f'Threshold ({threshold})')

    # Highlight regions above threshold
    above_threshold = np.asarray(probabilities) >= threshold
    if above_threshold.any():
        ax.fill_between(positions, probabilities, where=above_threshold,
                        alpha=0.5, color='red', label='Predicted CRISPR')

    ax.set_xlabel('Position (bp)')
    ax.set_ylabel('CRISPR Probability')
    ax.set_title('CRISPR Array Detection Score')
    ax.set_ylim(0, 1)

    # Use len() rather than truthiness: `if positions` raises ValueError for a
    # NumPy array with more than one element.
    has_positions = positions is not None and len(positions) > 0
    x_min = min(positions) if has_positions else 1
    x_max = max(positions) if has_positions else 1000
    ax.set_xlim(x_min, x_max)

    ax.legend(loc='upper right')
    ax.grid(True, alpha=0.3)
    plt.tight_layout()
    return fig
def create_interactive_prediction_plot(positions, probabilities, threshold=0.3, regions=None):
    """Create an interactive Plotly figure showing the prediction curve with minimap.

    Args:
        positions: Sequence (list or NumPy array) of x positions in bp.
        probabilities: Scores aligned with *positions*.
        threshold: Score cutoff drawn as a dashed horizontal line.
        regions: Optional iterable of dicts with ``'start'``, ``'end'`` and
            ``'region_id'`` keys; each is drawn as a shaded vertical band
            labelled ``#<region_id>``.

    Returns:
        The Plotly ``Figure`` (the minimap is the x-axis range slider).
    """
    fig = go.Figure()

    # Use len() rather than truthiness: `if positions` raises ValueError for a
    # NumPy array with more than one element.
    has_positions = positions is not None and len(positions) > 0
    min_pos = min(positions) if has_positions else 1
    max_pos = max(positions) if has_positions else 1000

    # Main probability curve with fill - monochrome
    fig.add_trace(go.Scatter(
        x=positions,
        y=probabilities,
        mode='lines',
        name='Score',
        line=dict(color='#18181b', width=1.5),
        fill='tozeroy',
        fillcolor='rgba(24, 24, 27, 0.08)',
        hovertemplate='Position: %{x:,} bp<br>Score: %{y:.3f}<extra></extra>'
    ))

    # Add threshold line - dashed gray
    fig.add_hline(
        y=threshold,
        line_dash="dash",
        line_color="#71717a",
        annotation_text=f"threshold={threshold}",
        annotation_position="top right",
        annotation_font_size=10,
        annotation_font_color="#71717a"
    )

    # Highlight detected CRISPR regions - subtle gray
    if regions:
        for r in regions:
            fig.add_vrect(
                x0=r['start'], x1=r['end'],
                fillcolor="rgba(24, 24, 27, 0.06)",
                layer="below",
                line_width=1,
                line_color="rgba(24, 24, 27, 0.2)",
                annotation_text=f"#{r['region_id']}",
                annotation_position="top left",
                annotation_font_size=9,
                annotation_font_color="#52525b"
            )

    fig.update_layout(
        title=None,
        xaxis=dict(
            title=dict(text='Position (bp)', font=dict(size=11, color='#52525b')),
            range=[min_pos, max_pos],
            gridcolor='#f4f4f5',
            showgrid=True,
            zeroline=False,
            linecolor='#e4e4e7',
            tickfont=dict(size=10, color='#71717a'),
            rangeslider=dict(
                visible=True,
                thickness=0.06,
                bgcolor='#fafafa',
                bordercolor='#e4e4e7',
                borderwidth=1
            ),
        ),
        yaxis=dict(
            title=dict(text='Score', font=dict(size=11, color='#52525b')),
            # Slight headroom above 1 so a curve at the maximum is not clipped.
            range=[0, 1.05],
            gridcolor='#f4f4f5',
            showgrid=True,
            zeroline=False,
            linecolor='#e4e4e7',
            tickfont=dict(size=10, color='#71717a'),
            tickformat='.1f'
        ),
        hovermode='x unified',
        showlegend=False,
        height=420,
        plot_bgcolor='#fafafa',
        paper_bgcolor='#fafafa',
        margin=dict(t=50, b=60, l=50, r=20),
        font=dict(family='Inter, system-ui, sans-serif')
    )
    return fig
def create_embedding_heatmap(embedding, title="Sequence Embedding", cols=30):
    """Draw a single embedding vector as a 2D heatmap, *cols* dimensions per row.

    The vector is laid out row by row; the unused cells of the last row are
    filled with NaN. Colours use a diverging map centred on zero.

    Returns:
        The matplotlib ``Figure``.
    """
    vector = np.array(embedding)
    n_dims = len(vector)
    rows = int(np.ceil(n_dims / cols))

    # NaN-pad up to a full rows x cols rectangle, then fold into the grid.
    cells = np.full(rows * cols, np.nan)
    cells[:n_dims] = vector
    grid = cells.reshape(rows, cols)

    fig, ax = plt.subplots(figsize=(14, max(3, rows * 0.25)))

    # Symmetric limits keep zero at the colormap midpoint, even for constant input.
    image = ax.imshow(
        grid,
        cmap='RdBu_r',
        norm=symmetric_activation_norm(vector),
        aspect='auto',
    )

    colorbar = plt.colorbar(image, ax=ax, shrink=0.8, pad=0.02)
    colorbar.set_label('Activation', fontsize=10)

    ax.set_xlabel(f'Dimension (columns of {cols})', fontsize=10)
    ax.set_ylabel('Row', fontsize=10)
    ax.set_title(f'{title} ({n_dims} dimensions)', fontsize=12, fontweight='bold')

    # Column ticks every 5 cells; each row tick shows the index range it covers.
    ax.set_xticks(np.arange(0, cols, 5))
    ax.set_xticklabels([str(column) for column in range(0, cols, 5)], fontsize=8)

    row_labels = []
    for row in range(rows):
        first_dim = row * cols
        last_dim = min(first_dim + cols - 1, n_dims - 1)
        row_labels.append(f'{first_dim}-{last_dim}')
    ax.set_yticks(np.arange(rows))
    ax.set_yticklabels(row_labels, fontsize=8)

    plt.tight_layout()
    return fig
def create_trajectory_heatmap(embeddings, title="Embedding Trajectory"):
    """Create a heatmap showing how embeddings change across windows.

    Args:
        embeddings: 2D array-like with one row per window and one column per
            embedding dimension.
        title: Title prefix; the number of windows is appended.

    Returns:
        The matplotlib ``Figure``. Above 100 dimensions, every ``n_dims // 100``-th
        column is kept; the x-axis label mentions subsampling only when columns
        are actually dropped.
    """
    embeddings = np.array(embeddings)
    n_windows, n_dims = embeddings.shape

    # Subsample dimensions if too many
    dim_label = 'Dimension'
    if n_dims > 100:
        step = n_dims // 100
        # For 101-199 dimensions the step is 1 and nothing is dropped, so the
        # plain label is kept rather than claiming "every 1th".
        if step > 1:
            embeddings = embeddings[:, ::step]
            dim_label = f'Dimension (subsampled, every {step}th)'

    fig, ax = plt.subplots(figsize=(14, max(4, n_windows * 0.3)))

    # Use diverging colormap; constant embeddings need a non-zero span.
    norm = symmetric_activation_norm(embeddings)
    im = ax.imshow(embeddings, cmap='RdBu_r', norm=norm, aspect='auto')

    cbar = plt.colorbar(im, ax=ax, shrink=0.8, pad=0.02)
    cbar.set_label('Activation', fontsize=10)

    ax.set_xlabel(dim_label, fontsize=10)
    ax.set_ylabel('Window', fontsize=10)
    ax.set_title(f'{title} ({n_windows} windows)', fontsize=12, fontweight='bold')

    plt.tight_layout()
    return fig
def create_state_dynamic_plot(embeddings, n_clusters=8, stride=100):
    """
    Render the State-Dynamic Plot: per-window embeddings projected to 2D with
    UMAP and joined in sequence order, shown twice side by side.

    The left panel colours windows by agglomerative cluster, the right panel
    by window index. Similar to Figure 3 from the DFG SPP 2141 report - it
    visualizes how different sequence regions (repeats, spacers, etc.) cluster
    in embedding space. With fewer than 5 windows a placeholder figure with an
    explanatory message is returned instead.
    """
    window_vectors = np.array(embeddings)
    n_windows, _ = window_vectors.shape

    if n_windows < 5:
        # Too few points for a meaningful projection: show a message instead.
        placeholder, message_ax = plt.subplots(figsize=(10, 8))
        message_ax.text(
            0.5, 0.5,
            "Need longer sequence for State-Dynamic Plot\n(minimum ~1500 bp)",
            ha='center', va='center', fontsize=14,
        )
        message_ax.set_xlim(0, 1)
        message_ax.set_ylim(0, 1)
        message_ax.axis('off')
        return placeholder

    # 2D projection; the neighbourhood size cannot exceed n_windows - 1.
    projector = umap.UMAP(
        n_components=2,
        n_neighbors=min(15, n_windows - 1),
        min_dist=0.1,
        metric='euclidean',
        random_state=42,
    )
    coords = projector.fit_transform(window_vectors)

    # Clustering runs on the full-dimensional embeddings, not on the projection.
    n_clusters = min(n_clusters, n_windows)
    cluster_labels = AgglomerativeClustering(n_clusters=n_clusters).fit_predict(window_vectors)

    fig, (cluster_ax, position_ax) = plt.subplots(1, 2, figsize=(16, 7))

    palette = plt.cm.tab10(np.linspace(0, 1, n_clusters))
    cluster_cmap = ListedColormap(palette)

    # Line segments joining each window to the next one along the sequence.
    path = coords.reshape(-1, 1, 2)
    segments = np.concatenate([path[:-1], path[1:]], axis=1)

    def mark_endpoints(ax, start_label, end_label):
        """Overlay the first and last window as large markers."""
        ax.scatter(coords[0, 0], coords[0, 1],
                   c='green', s=200, marker='^', edgecolors='black',
                   linewidths=2, label=start_label, zorder=10)
        ax.scatter(coords[-1, 0], coords[-1, 1],
                   c='red', s=200, marker='s', edgecolors='black',
                   linewidths=2, label=end_label, zorder=10)

    def decorate(ax, panel_title):
        """Apply the axis labels, title and legend shared by both panels."""
        ax.set_xlabel('UMAP 1', fontsize=11)
        ax.set_ylabel('UMAP 2', fontsize=11)
        ax.set_title(panel_title, fontsize=12, fontweight='bold')
        ax.legend(loc='upper right')

    # === Left panel: coloured by cluster ===
    # Each segment takes the colour of the cluster of its starting window.
    cluster_segment_colors = [palette[label] for label in cluster_labels[:-1]]
    cluster_ax.add_collection(
        LineCollection(segments, colors=cluster_segment_colors, alpha=0.3, linewidths=1)
    )
    cluster_points = cluster_ax.scatter(
        coords[:, 0], coords[:, 1],
        c=cluster_labels, cmap=cluster_cmap,
        s=60, alpha=0.8, edgecolors='white', linewidths=0.5,
    )
    mark_endpoints(cluster_ax, 'Start', 'End')
    decorate(cluster_ax, 'State-Dynamic Plot (by cluster)')

    cluster_bar = plt.colorbar(cluster_points, ax=cluster_ax, shrink=0.8)
    cluster_bar.set_label('Cluster', fontsize=10)
    cluster_bar.set_ticks(np.arange(n_clusters))

    # === Right panel: coloured by position ===
    window_index = np.arange(n_windows)
    index_norm = plt.Normalize(0, n_windows - 1)
    position_segment_colors = plt.cm.viridis(index_norm(window_index[:-1]))
    position_ax.add_collection(
        LineCollection(segments, colors=position_segment_colors, alpha=0.4, linewidths=1.5)
    )
    position_points = position_ax.scatter(
        coords[:, 0], coords[:, 1],
        c=window_index, cmap='viridis',
        s=60, alpha=0.8, edgecolors='white', linewidths=0.5,
    )
    mark_endpoints(position_ax, "Start (5')", "End (3')")
    decorate(position_ax, 'State-Dynamic Plot (by position)')

    position_bar = plt.colorbar(position_points, ax=position_ax, shrink=0.8)
    position_bar.set_label(f'Window position (×{stride} bp)', fontsize=10)

    plt.tight_layout()
    return fig
def create_sequence_cluster_map(cluster_labels, stride=100, window_size=1000):
    """
    Create a linear map showing cluster assignments along the sequence.
    Like a chromosome ideogram colored by activation cluster.

    Args:
        cluster_labels: One integer cluster label per window, in sequence order.
        stride: Offset in bp between the starts of consecutive windows.
        window_size: Width in bp of each window.

    Returns:
        The matplotlib ``Figure``. Window ``i`` is drawn as a translucent band
        covering ``[i * stride, i * stride + window_size]``; overlapping
        windows blend.
    """
    n_windows = len(cluster_labels)
    unique_labels = np.unique(cluster_labels)
    n_clusters = len(unique_labels)

    # Create figure
    fig, ax = plt.subplots(figsize=(14, 3))

    # The palette is indexed by the label value itself, so it must cover the
    # largest label present, not only the number of distinct labels (labels
    # need not be exactly 0..k-1). The minimum of 10 matches the tab10 size.
    max_label = int(unique_labels.max()) if n_clusters else -1
    palette_size = max(n_clusters, max_label + 1, 10)
    colors = plt.cm.tab10(np.linspace(0, 1, palette_size))

    # Draw colored blocks for each window
    for i, cluster in enumerate(cluster_labels):
        start_pos = i * stride
        end_pos = start_pos + window_size
        ax.axvspan(start_pos, end_pos, alpha=0.7, color=colors[cluster],
                   linewidth=0)

    # Legend entries come from the labels actually present, so each entry's
    # name and colour match the bands drawn above.
    handles = [plt.Rectangle((0, 0), 1, 1, color=colors[label], alpha=0.7)
               for label in unique_labels]
    ax.legend(handles, [f'Cluster {label}' for label in unique_labels],
              loc='upper right', ncol=min(n_clusters, 5), fontsize=8)

    # Right edge of the last window.
    seq_len = (n_windows - 1) * stride + window_size
    ax.set_xlim(0, seq_len)
    ax.set_ylim(0, 1)
    ax.set_xlabel('Position (bp)', fontsize=11)
    ax.set_ylabel('')
    ax.set_yticks([])
    ax.set_title('Sequence colored by embedding cluster', fontsize=12, fontweight='bold')

    # Add position markers every 500 bp
    for pos in range(0, int(seq_len), 500):
        ax.axvline(pos, color='black', alpha=0.2, linewidth=0.5)

    plt.tight_layout()
    return fig
def create_interactive_state_plot(embeddings, n_clusters=8, stride=100, use_3d=False,
                                  window_size=1000):
    """
    Create an interactive Plotly State-Dynamic Plot with a 2D or 3D UMAP projection.

    Every row of ``embeddings`` is treated as one sliding window. The rows are
    projected with UMAP (for plotting) and, independently, grouped with
    agglomerative clustering on the original vectors (for colouring).

    Args:
        embeddings: Array-like that ``np.array`` converts to a 2-D
            ``(n_windows, n_dims)`` array.
        n_clusters: Requested number of clusters; clamped to ``n_windows``.
        stride: Distance in bp between consecutive window starts. Only used to
            derive the positions shown in hover labels and the sequence map.
        use_3d: If True, return a single 3-D scatter; otherwise return a
            three-panel 2-D figure (by cluster, by position, sequence map).
        window_size: Window length in bp used for hover labels and the width of
            the sequence-map blocks. Defaults to 1000, the value that was
            previously hard-coded.

    Returns:
        A ``plotly.graph_objects.Figure``. When fewer than 5 windows are given,
        the figure only contains a "need longer sequence" annotation.
    """
    embeddings = np.array(embeddings)
    # Unpacking doubles as a shape check: anything but a 2-D array raises here.
    n_windows, _ = embeddings.shape

    if n_windows < 5:
        fig = go.Figure()
        fig.add_annotation(text="Need longer sequence (minimum ~1500 bp)",
                           xref="paper", yref="paper", x=0.5, y=0.5,
                           showarrow=False, font=dict(size=14, color='#71717a'))
        fig.update_layout(plot_bgcolor='#fafafa', paper_bgcolor='#fafafa')
        return fig

    # UMAP reduction
    n_components = 3 if use_3d else 2
    n_neighbors = min(15, n_windows - 1)
    reducer = umap.UMAP(
        n_components=n_components,
        n_neighbors=n_neighbors,
        min_dist=0.1,
        metric='euclidean',
        random_state=42
    )
    embedding_reduced = reducer.fit_transform(embeddings)

    # Clustering (on the full-dimensional embeddings, not on the projection)
    n_clusters = min(n_clusters, n_windows)
    clustering = AgglomerativeClustering(n_clusters=n_clusters)
    cluster_labels = clustering.fit_predict(embeddings)

    # Create position info for hover
    positions = np.arange(n_windows) * stride
    hover_text = [f"Window {i}<br>Position: {pos}-{pos + window_size} bp<br>Cluster: {c}"
                  for i, (pos, c) in enumerate(zip(positions, cluster_labels))]

    # Colorful palette for clusters. Cycle through the palette so that a
    # cluster count larger than the palette cannot cause an IndexError.
    palette = px.colors.qualitative.Set1
    colors = [palette[i % len(palette)] for i in range(n_clusters)]

    if use_3d:
        fig = go.Figure()

        # Trajectory line
        fig.add_trace(go.Scatter3d(
            x=embedding_reduced[:, 0],
            y=embedding_reduced[:, 1],
            z=embedding_reduced[:, 2],
            mode='lines',
            line=dict(color='rgba(100,100,100,0.3)', width=2),
            name='Trajectory',
            hoverinfo='skip'
        ))

        # Points - colorful by cluster. Explicit per-point colours are used
        # instead of colorscale='Set1': Set1 is a qualitative palette, not a
        # named continuous colorscale, and explicit colours also keep the 3D
        # view consistent with the 2D cluster colours.
        fig.add_trace(go.Scatter3d(
            x=embedding_reduced[:, 0],
            y=embedding_reduced[:, 1],
            z=embedding_reduced[:, 2],
            mode='markers',
            marker=dict(
                size=5,
                color=[colors[label] for label in cluster_labels],
                opacity=0.85,
                line=dict(width=0.5, color='white')
            ),
            text=hover_text,
            hovertemplate='%{text}<extra></extra>',
            name='Windows'
        ))

        # Start marker - green
        fig.add_trace(go.Scatter3d(
            x=[embedding_reduced[0, 0]],
            y=[embedding_reduced[0, 1]],
            z=[embedding_reduced[0, 2]],
            mode='markers',
            marker=dict(size=10, color='green', symbol='diamond'),
            name="5' start"
        ))

        # End marker - red
        fig.add_trace(go.Scatter3d(
            x=[embedding_reduced[-1, 0]],
            y=[embedding_reduced[-1, 1]],
            z=[embedding_reduced[-1, 2]],
            mode='markers',
            marker=dict(size=10, color='red', symbol='square'),
            name="3' end"
        ))

        fig.update_layout(
            title=None,
            scene=dict(
                xaxis=dict(title='UMAP 1', gridcolor='#e4e4e7', backgroundcolor='#fafafa'),
                yaxis=dict(title='UMAP 2', gridcolor='#e4e4e7', backgroundcolor='#fafafa'),
                zaxis=dict(title='UMAP 3', gridcolor='#e4e4e7', backgroundcolor='#fafafa'),
            ),
            height=550,
            showlegend=True,
            legend=dict(font=dict(size=10), bgcolor='rgba(250,250,250,0.9)'),
            plot_bgcolor='#fafafa',
            paper_bgcolor='#fafafa',
            font=dict(family='Inter, system-ui, sans-serif', color='#52525b')
        )
    else:
        # 2D Plot with subplots: two scatter panels on top, a full-width
        # sequence map underneath.
        fig = make_subplots(
            rows=2, cols=2,
            specs=[[{"type": "scatter"}, {"type": "scatter"}],
                   [{"type": "scatter", "colspan": 2}, None]],
            subplot_titles=('by cluster', 'by position', 'sequence map'),
            row_heights=[0.6, 0.4],
            vertical_spacing=0.12
        )

        # Left plot: by cluster
        fig.add_trace(go.Scatter(
            x=embedding_reduced[:, 0],
            y=embedding_reduced[:, 1],
            mode='lines',
            line=dict(color='rgba(113,113,122,0.15)', width=1),
            hoverinfo='skip',
            showlegend=False
        ), row=1, col=1)

        # One trace per cluster so each cluster gets its own legend entry.
        for c in range(n_clusters):
            mask = cluster_labels == c
            fig.add_trace(go.Scatter(
                x=embedding_reduced[mask, 0],
                y=embedding_reduced[mask, 1],
                mode='markers',
                marker=dict(size=7, color=colors[c], opacity=0.8,
                            line=dict(width=0.5, color='white')),
                text=[hover_text[i] for i in np.where(mask)[0]],
                hovertemplate='%{text}<extra></extra>',
                name=f'{c}',
                legendgroup=f'c{c}'
            ), row=1, col=1)

        # Start/End markers
        fig.add_trace(go.Scatter(
            x=[embedding_reduced[0, 0]], y=[embedding_reduced[0, 1]],
            mode='markers', marker=dict(size=12, color='green', symbol='triangle-up',
                                        line=dict(width=1, color='black')),
            name="5'", showlegend=True
        ), row=1, col=1)
        fig.add_trace(go.Scatter(
            x=[embedding_reduced[-1, 0]], y=[embedding_reduced[-1, 1]],
            mode='markers', marker=dict(size=12, color='red', symbol='square',
                                        line=dict(width=1, color='black')),
            name="3'", showlegend=True
        ), row=1, col=1)

        # Right plot: by position - viridis gradient
        fig.add_trace(go.Scatter(
            x=embedding_reduced[:, 0],
            y=embedding_reduced[:, 1],
            mode='lines+markers',
            line=dict(color='rgba(100,100,100,0.3)', width=1),
            marker=dict(size=7, color=np.arange(n_windows), colorscale='Viridis',
                        showscale=True, colorbar=dict(title=dict(text='window', font=dict(size=10)),
                                                      x=1.02, tickfont=dict(size=9))),
            text=hover_text,
            hovertemplate='%{text}<extra></extra>',
            showlegend=False
        ), row=1, col=2)
        fig.add_trace(go.Scatter(
            x=[embedding_reduced[0, 0]], y=[embedding_reduced[0, 1]],
            mode='markers', marker=dict(size=12, color='green', symbol='triangle-up',
                                        line=dict(width=1, color='black')),
            showlegend=False
        ), row=1, col=2)
        fig.add_trace(go.Scatter(
            x=[embedding_reduced[-1, 0]], y=[embedding_reduced[-1, 1]],
            mode='markers', marker=dict(size=12, color='red', symbol='square',
                                        line=dict(width=1, color='black')),
            showlegend=False
        ), row=1, col=2)

        # Bottom: sequence map - one filled rectangle per window, drawn in
        # window order so later windows paint over earlier ones.
        for cluster, pos in zip(cluster_labels, positions):
            fig.add_trace(go.Scatter(
                x=[pos, pos + window_size, pos + window_size, pos, pos],
                y=[0, 0, 1, 1, 0],
                fill='toself',
                fillcolor=colors[cluster],
                line=dict(width=0),
                hoverinfo='text',
                text=f'Position {pos}-{pos + window_size} bp<br>Cluster {cluster}',
                showlegend=False
            ), row=2, col=1)

        fig.update_xaxes(title_text='UMAP 1', row=1, col=1, gridcolor='#f4f4f5',
                         tickfont=dict(size=9, color='#71717a'))
        fig.update_yaxes(title_text='UMAP 2', row=1, col=1, gridcolor='#f4f4f5',
                         tickfont=dict(size=9, color='#71717a'))
        fig.update_xaxes(title_text='UMAP 1', row=1, col=2, gridcolor='#f4f4f5',
                         tickfont=dict(size=9, color='#71717a'))
        fig.update_yaxes(title_text='UMAP 2', row=1, col=2, gridcolor='#f4f4f5',
                         tickfont=dict(size=9, color='#71717a'))
        fig.update_xaxes(title_text='position (bp)', row=2, col=1, gridcolor='#f4f4f5',
                         tickfont=dict(size=9, color='#71717a'))
        fig.update_yaxes(visible=False, row=2, col=1)

        fig.update_layout(
            title=None,
            height=650,
            showlegend=True,
            legend=dict(orientation='h', yanchor='bottom', y=1.02, xanchor='right', x=1,
                        font=dict(size=9), bgcolor='rgba(250,250,250,0.9)'),
            plot_bgcolor='#fafafa',
            paper_bgcolor='#fafafa',
            font=dict(family='Inter, system-ui, sans-serif', color='#52525b', size=11),
            margin=dict(t=40, b=40)
        )

        # Style subplot titles (make_subplots stores them as layout annotations)
        for annotation in fig['layout']['annotations']:
            annotation['font'] = dict(size=11, color='#52525b')

    return fig
def parse_fasta_file(file_path):
    """Read an uploaded FASTA/text file and return its cleaned sequence.

    Returns None when no path is given. Raises ``gr.Error`` when the file is
    larger than ``MAX_UPLOAD_BYTES`` or when ``normalize_sequence_input``
    reports the content as invalid.
    """
    if file_path is None:
        return None

    # Reject oversized uploads before reading anything into memory.
    n_bytes = os.path.getsize(file_path)
    if n_bytes > MAX_UPLOAD_BYTES:
        raise gr.Error(f"Uploaded file is too large ({n_bytes:,} bytes > {MAX_UPLOAD_BYTES:,} byte limit).")

    # Undecodable bytes are replaced rather than aborting the read.
    with open(file_path, 'r', encoding='utf-8', errors='replace') as handle:
        raw_text = handle.read()

    ok, sequence, message = normalize_sequence_input(raw_text)
    if ok:
        return sequence
    raise gr.Error(message)
def create_gff3_export(regions, sequence_length, sequence_id="input_sequence", output_dir=None):
    """Write the detected regions to ``crispr_regions.gff3`` and return its path.

    ``regions`` is an iterable of mappings providing ``region_id``, ``start``,
    ``end`` and ``mean_score``. When ``output_dir`` is falsy a fresh directory
    is obtained from ``make_output_dir("crispr_export")``.
    """
    target_dir = output_dir or make_output_dir("crispr_export")
    gff_path = os.path.join(target_dir, "crispr_regions.gff3")

    with open(gff_path, 'w') as out:
        # Header: format version plus the extent of the annotated sequence.
        out.write("##gff-version 3\n")
        out.write(f"##sequence-region {sequence_id} 1 {sequence_length}\n")

        for region in regions:
            rid = region['region_id']
            score = f"{region['mean_score']:.3f}"
            # Columns: seqid source type start end score strand phase attributes
            attributes = f"ID=CRISPR_{rid};Name=CRISPR_array_{rid};score={score}"
            out.write(f"{sequence_id}\tCRISPR-BERT\tCRISPR_array\t{region['start']}\t{region['end']}\t{score}\t.\t.\t{attributes}\n")

    return gff_path
def create_sequence_viewer_html(sequence, positions, probabilities, threshold=0.3, chunk_size=100):
    """Create an HTML visualization of the sequence with grayscale intensity scores.

    Args:
        sequence: The sequence to render, one ``<span>`` per character.
        positions: Coordinates the ``probabilities`` were computed at. Only used
            when there is not exactly one probability per base; in that case the
            scores are linearly interpolated onto base indices.
            NOTE(review): positions are assumed to be 0-based base indices -
            confirm against the caller.
        probabilities: Scores, expected in [0, 1].
        threshold: Bases scoring at or above this value are rendered bold.
        chunk_size: Number of bases per displayed row (values below 1 are
            treated as 1).

    Returns:
        An HTML string. For sequences longer than ``MAX_SEQUENCE_VIEWER_LENGTH``
        a short notice is returned instead of the viewer.
    """
    seq_len = len(sequence)
    if seq_len > MAX_SEQUENCE_VIEWER_LENGTH:
        return (
            '<div style="background: #fafafa; padding: 16px; border: 1px solid #e4e4e7;">'
            f'Sequence viewer disabled for sequences longer than {MAX_SEQUENCE_VIEWER_LENGTH:,} bp '
            f'(current sequence: {seq_len:,} bp). Use the plot and downloads for full results.'
            '</div>'
        )

    per_base_scores = np.asarray(probabilities, dtype=float)
    if len(per_base_scores) != seq_len:
        score_positions = np.asarray(positions, dtype=float)
        if (per_base_scores.ndim == 1 and per_base_scores.size
                and score_positions.shape == per_base_scores.shape):
            # Spread the sparse scores over every base by linear interpolation.
            # np.resize would tile the scores cyclically, pairing bases with
            # scores from unrelated positions.
            order = np.argsort(score_positions, kind='stable')
            per_base_scores = np.interp(
                np.arange(seq_len), score_positions[order], per_base_scores[order]
            )
        else:
            # No usable positions: keep the previous best-effort behaviour.
            per_base_scores = np.resize(per_base_scores, seq_len)

    # Colours need a finite value in [0, 1]; the raw score is still what gets
    # compared against the threshold and shown in the tooltip.
    shades = np.clip(np.nan_to_num(per_base_scores, nan=0.0, posinf=1.0, neginf=0.0), 0.0, 1.0)

    # range() rejects a zero step, so never go below one base per row.
    chunk_size = max(1, int(chunk_size))

    # Generate HTML - monochrome style
    html_parts = ['<div style="font-family: \'Geist Mono\', \'SF Mono\', Consolas, monospace; font-size: 11px; line-height: 1.9; background: #fafafa; padding: 16px; border: 1px solid #e4e4e7; max-height: 400px; overflow-y: auto;">']
    html_parts.append('<div style="margin-bottom: 12px; font-family: Inter, system-ui, sans-serif; font-size: 11px; color: #71717a;">')
    html_parts.append('<span style="background: linear-gradient(to right, #fafafa, #18181b); padding: 3px 24px; border: 1px solid #e4e4e7; display: inline-block;">low → high</span>')
    html_parts.append(f'<span style="margin-left: 12px;">threshold: {html.escape(str(threshold))}</span>')
    html_parts.append('</div>')

    # Process sequence in chunks (one row per chunk)
    for chunk_start in range(0, seq_len, chunk_size):
        chunk_end = min(chunk_start + chunk_size, seq_len)
        chunk_seq = sequence[chunk_start:chunk_end]
        chunk_scores = per_base_scores[chunk_start:chunk_end]
        chunk_shades = shades[chunk_start:chunk_end]

        # Position marker (1-based index of the first base in the row)
        html_parts.append(f'<div><span style="color: #a1a1aa; width: 55px; display: inline-block; font-size: 10px;">{chunk_start+1:,}</span>')

        for i, (base, score, shade) in enumerate(zip(chunk_seq, chunk_scores, chunk_shades)):
            # Grayscale intensity based on score
            intensity = int(255 - shade * 200)  # Higher score = darker
            color = f'rgb({intensity},{intensity},{intensity})'
            bg_intensity = int(250 - shade * 40)
            bg_color = f'rgb({bg_intensity},{bg_intensity},{bg_intensity})'
            font_weight = '600' if score >= threshold else '400'
            safe_base = html.escape(base)
            html_parts.append(f'<span style="color: {color}; background-color: {bg_color}; font-weight: {font_weight};" title="pos {chunk_start + i + 1}: {score:.3f}">{safe_base}</span>')

        html_parts.append('</div>')

    html_parts.append('</div>')
    return ''.join(html_parts)
def predict(sequence: str, stride: int = DEFAULT_STRIDE, threshold: float = DEFAULT_THRESHOLD):
    """Score a sequence for CRISPR arrays and build every output of the Prediction tab.

    Returns a 9-tuple: (figure, markdown summary, regions, png path, pdf path,
    csv path, summary-text path, gff3 path or None, sequence-viewer HTML).
    If any input fails validation, the value of ``prediction_error_outputs`` is
    returned instead.
    """
    import csv
    import time

    started = time.time()

    # Validate inputs one at a time; each validator returns (ok, value, error).
    ok, sequence, problem = normalize_sequence_input(sequence)
    if not ok:
        return prediction_error_outputs(problem)
    ok, stride, problem = validate_stride(stride)
    if not ok:
        return prediction_error_outputs(problem)
    ok, threshold, problem = validate_threshold(threshold)
    if not ok:
        return prediction_error_outputs(problem)

    result = predict_sequence(sequence, stride=stride, aggregation="mean")

    # Hand the prediction result over so the model only runs once per analysis.
    regions = detect_crispr_regions(
        sequence,
        threshold=threshold,
        min_length=100,
        stride=stride,
        prediction_result=result,
    )

    # Core inference is 0-based; anything shown to the user is 1-based.
    display_positions = [position + 1 for position in result.positions]

    # The same matplotlib figure is used for display and for export.
    # Plotly + Gradio 6.x + heavy CUSTOM_CSS was freezing the browser after inference;
    # matplotlib is the boring-and-reliable fallback to rule that out.
    output_dir = make_output_dir("crispr_prediction")
    fig = create_prediction_plot(display_positions, result.probabilities, threshold, regions)
    png_path, pdf_path = save_figure_to_file(fig, "crispr_prediction", output_dir)

    # Per-position scores as CSV
    csv_path = os.path.join(output_dir, "crispr_predictions.csv")
    with open(csv_path, 'w', newline='') as handle:
        writer = csv.writer(handle)
        writer.writerow(['position_1based', 'probability', 'above_threshold'])
        writer.writerows(
            [position + 1, f"{prob:.4f}", prob >= threshold]
            for position, prob in zip(result.positions, result.probabilities)
        )

    # GFF3 annotations are only produced when something was detected.
    gff_path = None
    if regions:
        gff_path = create_gff3_export(regions, result.sequence_length, output_dir=output_dir)

    # Sequence viewer disabled as a freeze-diagnostic: 1000+ inline-styled spans
    # seem to block the browser after render. Re-enable once the root cause is known.
    seq_viewer_html = ""

    elapsed_time = time.time() - started
    highest = max(result.probabilities)
    lowest = min(result.probabilities)

    # Plain-text summary, written to disk for download
    summary_path = os.path.join(output_dir, "crispr_summary.txt")
    summary_text = f"""CRISPR Array Detection Summary
==============================
Sequence length: {result.sequence_length:,} bp
Windows processed: {result.num_windows}
Stride: {stride} bp
Threshold: {threshold}
Inference time: {elapsed_time:.2f} seconds
Overall score: {result.overall_score:.4f}
Max score: {highest:.4f}
Min score: {lowest:.4f}
Detected CRISPR Regions: {len(regions)}
"""
    if regions:
        summary_text += "\nRegion Details:\n" + "".join(
            f" Region {r['region_id']}: {r['start']:,}-{r['end']:,} bp ({r['length']} bp), mean score: {r['mean_score']:.3f}\n"
            for r in regions
        )
    with open(summary_path, 'w') as handle:
        handle.write(summary_text)

    # Markdown summary shown in the UI
    summary = f"""## Results
| Metric | Value |
|--------|-------|
| Sequence length | {result.sequence_length:,} bp |
| Windows processed | {result.num_windows} |
| Overall score | {result.overall_score:.4f} |
| Max score | {highest:.4f} |
| Regions detected | {len(regions)} |
| Inference time | {elapsed_time:.2f}s |
"""
    if regions:
        summary += "### Detected CRISPR Regions\n\n" + "".join(
            f"- **Region {r['region_id']}**: positions {r['start']:,}-{r['end']:,} ({r['length']} bp), score: {r['mean_score']:.3f}\n"
            for r in regions
        )

    return fig, summary, regions, png_path, pdf_path, csv_path, summary_path, gff_path, seq_viewer_html
def detect(sequence: str, threshold: float = 0.3, min_length: int = 160):
    """Find CRISPR array regions and describe them in Markdown.

    Returns ``(regions, summary)``. On invalid input, ``regions`` is an empty
    list and ``summary`` carries the error message; when nothing is detected,
    ``summary`` says so.
    """
    def _failure(message):
        # Shared shape of every validation failure.
        return [], f"**Error**: {message}"

    ok, sequence, problem = normalize_sequence_input(sequence)
    if not ok:
        return _failure(problem)

    ok, threshold, problem = validate_threshold(threshold)
    if not ok:
        return _failure(problem)

    ok, min_length, problem = validate_min_length(min_length)
    if not ok:
        return _failure(problem)

    regions = detect_crispr_regions(
        sequence,
        threshold=threshold,
        min_length=min_length,
        stride=DEFAULT_STRIDE
    )
    if not regions:
        return [], "**No CRISPR arrays detected** above the specified threshold."

    # One Markdown bullet per detected region under a count heading.
    bullets = [
        f"- **Region {r['region_id']}**: positions {r['start']:,}-{r['end']:,} ({r['length']} bp), score: {r['mean_score']:.3f}\n"
        for r in regions
    ]
    summary = f"## Detected {len(regions)} CRISPR region(s)\n\n" + "".join(bullets)
    return regions, summary
def save_figure_to_file(fig, prefix="plot", output_dir=None):
    """Save ``fig`` as ``<prefix>.png`` and ``<prefix>.pdf`` and return both paths.

    ``fig`` must offer a matplotlib-style ``savefig``. When ``output_dir`` is
    falsy a directory is obtained from ``make_output_dir(prefix)``.
    """
    target_dir = output_dir or make_output_dir(prefix)
    png_path = os.path.join(target_dir, f"{prefix}.png")
    pdf_path = os.path.join(target_dir, f"{prefix}.pdf")

    # Options common to both formats; only the raster output needs a DPI.
    common = dict(bbox_inches='tight', facecolor='white')
    fig.savefig(png_path, dpi=150, **common)
    fig.savefig(pdf_path, **common)

    return png_path, pdf_path
def get_embedding(sequence: str, mode: str = "mean", use_3d: bool = False):
    """Embed a sequence and return a figure, a Markdown summary and export paths.

    ``mode`` selects the output: "trajectory" gives a per-window heatmap,
    "state-dynamics" gives the interactive UMAP/cluster plot (with a static
    copy saved for download), and any other accepted mode gives a heatmap of
    the single pooled embedding. ``use_3d`` only affects "state-dynamics".

    Returns ``(fig, summary, png_path, pdf_path)``, or the value of
    ``embedding_error_outputs`` when the mode or the sequence is rejected.
    """
    if mode not in {"state-dynamics", "mean", "max", "trajectory", "cls"}:
        return embedding_error_outputs(
            "Mode must be one of: state-dynamics, mean, max, trajectory, cls"
        )

    ok, sequence, problem = normalize_sequence_input(sequence)
    if not ok:
        return embedding_error_outputs(problem)

    # The state-dynamics view is built from the trajectory embeddings.
    backend_mode = "trajectory" if mode == "state-dynamics" else mode
    result = embed_sequence(sequence, mode=backend_mode)
    output_dir = make_output_dir("crispr_embedding")

    if mode == "trajectory":
        # Heatmap of windows x embedding dimensions
        fig = create_trajectory_heatmap(
            result.embeddings,
            title="Embedding Trajectory Across Sequence"
        )
        png_path, pdf_path = save_figure_to_file(fig, "trajectory_embedding", output_dir)
        summary = f"""## Trajectory Embedding
| Property | Value |
|----------|-------|
| Sequence length | {result.sequence_length:,} bp |
| Windows | {result.num_windows} |
| Embedding dim | {result.embedding_dim} |
Each row shows the embedding for one sliding window position.
Blue = negative activation, Red = positive activation.
"""
    elif mode == "state-dynamics":
        embeddings = np.array(result.embeddings)
        window_count = embeddings.shape[0]
        # Roughly one cluster per three windows, kept between 3 and 8.
        n_clusters = min(8, max(3, window_count // 3))

        # Interactive Plotly figure for the UI
        fig = create_interactive_state_plot(embeddings, n_clusters=n_clusters, stride=100, use_3d=use_3d)

        # Static matplotlib figure, used only for the downloadable files
        static_fig = create_state_dynamic_plot(embeddings, n_clusters=n_clusters, stride=100)
        png_path, pdf_path = save_figure_to_file(static_fig, "state_dynamic_plot", output_dir)
        plt.close(static_fig)

        dim_text = "3D" if use_3d else "2D"
        gesture = "Rotate" if use_3d else "Double-click"
        gesture_effect = "rotate 3D view" if use_3d else "reset zoom"
        summary = f"""## Interactive State-Dynamic Plot ({dim_text})
| Property | Value |
|----------|-------|
| Sequence length | {result.sequence_length:,} bp |
| Windows analyzed | {result.num_windows} |
| Clusters identified | {n_clusters} |
| Visualization | {dim_text} UMAP |
**Interactive controls:**
- **Hover** over points to see window position and cluster
- **Zoom** by scrolling or selecting region
- **Pan** by dragging
- **{gesture}** to {gesture_effect}
- **Download**: Use buttons below for PNG/PDF, or camera icon in plot toolbar
**Interpretation:**
- Points colored by cluster - similar activation patterns group together
- Trajectory shows path through embedding space along the sequence
- Alternating colors in CRISPR arrays indicate repeating structural elements (repeats vs spacers)
"""
    else:
        # Heatmap of the single pooled embedding vector
        fig = create_embedding_heatmap(
            result.embedding,
            title=f"Sequence Embedding ({result.method})"
        )
        png_path, pdf_path = save_figure_to_file(fig, f"embedding_{mode}", output_dir)
        summary = f"""## Embedding Extracted
| Property | Value |
|----------|-------|
| Sequence length | {result.sequence_length:,} bp |
| Pooling method | {result.method} |
| Embedding dim | {result.embedding_dim} |
Each cell represents one dimension of the {result.embedding_dim}-dimensional embedding.
Blue = negative activation, Red = positive activation.
"""

    return fig, summary, png_path, pdf_path
# Build interface
# NOTE(review): the nesting of the layout containers below was reconstructed
# from a listing without indentation - confirm the Row/Column structure.
with gr.Blocks(
    title="CRISPR Array Detection",
    theme=gr.themes.Base(
        primary_hue=gr.themes.colors.zinc,
        secondary_hue=gr.themes.colors.zinc,
        neutral_hue=gr.themes.colors.zinc,
        font=gr.themes.GoogleFont("Inter"),
        font_mono=gr.themes.GoogleFont("Geist Mono"),
    ),
    css=CUSTOM_CSS,
    # (frequency, age) in seconds for Gradio's temp-file cleanup - see Gradio docs.
    delete_cache=(3600, 86400),
) as demo:
    gr.Markdown("""
# crispr-detect
BERT-based CRISPR array detection. 24-layer transformer (430M params) trained on metagenomic sequences.
Sliding window analysis with per-position probability scores. Export to GFF3/CSV.
""")

    # ---- Prediction tab: inputs and downloads on the left, plot on the right ----
    with gr.Tab("Prediction"):
        with gr.Row():
            with gr.Column(scale=1):
                seq_input = gr.Textbox(
                    label="sequence",
                    placeholder="Paste DNA sequence (FASTA format accepted)...",
                    lines=6,
                    value=CRISPR_EXAMPLE,
                    info="min 1000 bp"
                )
                file_upload = gr.File(
                    label="upload fasta",
                    file_types=[".fasta", ".fa", ".fna", ".txt"],
                    type="filepath"
                )
                with gr.Row():
                    stride_input = gr.Slider(
                        minimum=50, maximum=500, value=DEFAULT_STRIDE, step=50,
                        label="stride",
                        info="500 = fast on CPU; lower = higher resolution but slower"
                    )
                    # Start from the configured default, like the stride slider
                    # and predict()'s own default, instead of a literal 0.3.
                    threshold_input = gr.Slider(
                        minimum=0.1, maximum=0.9, value=DEFAULT_THRESHOLD, step=0.05,
                        label="threshold",
                        info="lower = sensitive, higher = specific"
                    )
                with gr.Row():
                    predict_btn = gr.Button("run", variant="primary", size="lg")
                gr.Markdown("*examples:*")
                # Each example button overwrites the sequence textbox.
                with gr.Row():
                    gr.Button("flanked", size="sm").click(
                        lambda: FLANKED_CRISPR_EXAMPLE, outputs=seq_input
                    )
                    gr.Button("e.coli", size="sm").click(
                        lambda: ECOLI_CRISPR_EXAMPLE, outputs=seq_input
                    )
                with gr.Row():
                    gr.Button("crispr", size="sm").click(
                        lambda: CRISPR_EXAMPLE, outputs=seq_input
                    )
                    gr.Button("control", size="sm").click(
                        lambda: NON_CRISPR_EXAMPLE, outputs=seq_input
                    )
                result_summary = gr.Markdown()
                with gr.Accordion("export", open=False) as download_accordion:
                    with gr.Row():
                        pred_download_png = gr.File(label="png", interactive=False)
                        pred_download_pdf = gr.File(label="pdf", interactive=False)
                    with gr.Row():
                        pred_download_csv = gr.File(label="csv", interactive=False)
                        pred_download_gff = gr.File(label="gff3", interactive=False)
                    with gr.Row():
                        pred_download_summary = gr.File(label="summary", interactive=False)
            with gr.Column(scale=2):
                plot_output = gr.Plot(label="prediction")
                # Hidden components: they still receive predict()'s outputs.
                seq_viewer_html = gr.HTML(visible=False)
                regions_output = gr.JSON(label="Detected Regions", visible=False)

        # Handle file upload - load content into textbox
        def load_file_to_textbox(file_path):
            """Return the parsed sequence of an uploaded file, or leave the textbox unchanged."""
            if file_path:
                return parse_fasta_file(file_path)
            return gr.update()

        file_upload.change(
            load_file_to_textbox,
            inputs=[file_upload],
            outputs=[seq_input]
        )

        def predict_and_show_downloads(*args):
            """Run predict() and turn any unexpected exception into the error outputs."""
            try:
                return predict(*args)
            except Exception as exc:
                logger.exception("Prediction failed")
                return prediction_error_outputs(f"Analysis failed: {exc}")

        # The outputs list must stay in the order of predict()'s return tuple.
        predict_btn.click(
            predict_and_show_downloads,
            inputs=[seq_input, stride_input, threshold_input],
            outputs=[plot_output, result_summary, regions_output, pred_download_png, pred_download_pdf,
                     pred_download_csv, pred_download_summary, pred_download_gff, seq_viewer_html],
            api_name="predict",
            concurrency_limit=1,
        )

    # ---- Embeddings tab ----
    with gr.Tab("Embeddings"):
        gr.Markdown("""
### embeddings
768-dim hidden states from transformer layer 21. UMAP projection + agglomerative clustering.
Repeats cluster together, spacers form distinct groups.
""")
        with gr.Row():
            with gr.Column(scale=1):
                embed_seq = gr.Textbox(
                    label="sequence",
                    placeholder="Paste DNA sequence...",
                    lines=6,
                    value=EMBEDDING_CRISPR_EXAMPLE,
                    info="min ~2000 bp for clustering"
                )
                embed_mode = gr.Radio(
                    choices=["state-dynamics", "mean", "max", "trajectory"],
                    value="state-dynamics",
                    label="mode",
                    info=""
                )
                use_3d = gr.Checkbox(
                    label="3D",
                    value=False,
                    info="",
                    visible=True
                )
                with gr.Row():
                    embed_btn = gr.Button("extract", variant="primary")
                with gr.Row():
                    gr.Button("crispr 3kb", size="sm").click(
                        lambda: EMBEDDING_CRISPR_EXAMPLE, outputs=embed_seq
                    )
                    gr.Button("control 3kb", size="sm").click(
                        lambda: EMBEDDING_RANDOM_EXAMPLE, outputs=embed_seq
                    )
                gr.Markdown("*example: 600bp upstream | 25 repeats + 24 spacers | 600bp downstream*")
                embed_summary = gr.Markdown()
                with gr.Accordion("export", open=False) as embed_download_accordion:
                    with gr.Row():
                        download_png = gr.File(label="png", interactive=False)
                        download_pdf = gr.File(label="pdf", interactive=False)
            with gr.Column(scale=2):
                embed_plot = gr.Plot(label="embedding")

        # Show/hide 3D checkbox based on mode
        embed_mode.change(
            lambda m: gr.update(visible=(m == "state-dynamics")),
            inputs=[embed_mode],
            outputs=[use_3d]
        )

        def embed_and_show_downloads(*args):
            """Run get_embedding() and turn any unexpected exception into the error outputs."""
            try:
                return get_embedding(*args)
            except Exception as exc:
                logger.exception("Embedding failed")
                return embedding_error_outputs(f"Embedding failed: {exc}")

        # The outputs list must stay in the order of get_embedding()'s return tuple.
        embed_btn.click(
            embed_and_show_downloads,
            inputs=[embed_seq, embed_mode, use_3d],
            outputs=[embed_plot, embed_summary, download_png, download_pdf],
            api_name="get_embedding",
            concurrency_limit=1,
        )

    # ---- Static documentation tabs ----
    with gr.Tab("API"):
        gr.Markdown("""
### api
```python
from gradio_client import Client
client = Client("genomenet/crispr-array-detection")
# predict
result = client.predict(
sequence="ATGC...",
stride=500,
threshold=0.3,
api_name="/predict"
)
# embeddings
result = client.predict(
sequence="ATGC...",
mode="state-dynamics",
use_3d=False,
api_name="/get_embedding"
)
```
**output formats**: CSV (scores), GFF3 (annotations), PNG/PDF (figures)
**local**:
```bash
git clone https://huggingface.co/spaces/genomenet/crispr-array-detection
pip install -r requirements.txt && python app.py
```
""")

    with gr.Tab("About"):
        gr.Markdown("""
### about
| | |
|---|---|
| architecture | BERT, 24 layers, 768 hidden, 12 heads, 430M params |
| training | metagenomic contigs, microbial genomes, CRISPRCasdb |
| window | 1000 bp |
| embedding | layer 21 (768-dim) |
**parameters**
| param | default | range |
|-------|---------|-------|
| stride | 500 bp | 50-500 |
| threshold | 0.3 | 0.1-0.9 |
**citation**
Mu, Z. (2024). Deep Learning-Based CRISPR Array Detection. Master's Thesis, HZI.
**acknowledgements**
DFG SPP 2141 (MC 172) / BMBF de.NBI GenomeNet / HZI BIFO
""")
if __name__ == "__main__":
    # Load the model and run the warm-up pass before the server starts.
    print("Loading model...")
    model = get_model()
    warmup_model(model)
    print(f"Model ready! GPU: {get_gpu_status()}")

    # Bounded request queue; handlers run one at a time unless an event says otherwise.
    queue_options = {
        "max_size": QUEUE_MAX_SIZE,
        "default_concurrency_limit": 1,
    }
    # Listen on all interfaces, port 7860.
    launch_options = {
        "server_name": "0.0.0.0",
        "server_port": 7860,
        "max_threads": 4,
        "show_error": True,
    }

    demo.queue(**queue_options)
    demo.launch(**launch_options)