"""Gradio app titled "STACK Model Visualization".

The app has two tabs: an architecture view (a 5x5 grid with selectable
learning steps) and an inference view (prompt/query cell selection followed by
a simulated zero-shot prediction).
"""

import random
import time

import gradio as gr

# NOTE(review): the HTML/SVG templates in this file contain only text; their
# markup appears to have been stripped from this copy. Several locals (colors,
# opacities, sizes, ring styles) look like inputs to that markup and are kept
# even where the visible template text does not reference them. Confirm
# against the original before deleting any of them.

# --- Constants ---
CELL_TYPES = [
    {"id": "t-cell", "name": "T-Cell", "color": "#3b82f6", "bg": "#dbeafe"},
    {"id": "b-cell", "name": "B-Cell", "color": "#6366f1", "bg": "#e0e7ff"},
    {"id": "macro", "name": "Macrophage", "color": "#f97316", "bg": "#ffedd5"},
]

CONDITIONS = [
    {"id": "healthy", "name": "Healthy", "color": "#3b82f6", "bg": "#dbeafe"},
    {"id": "drug-a", "name": "Drug A", "color": "#ef4444", "bg": "#fee2e2"},
    {"id": "viral", "name": "Viral Infection", "color": "#f97316", "bg": "#ffedd5"},
]


def _find_entry(entries, field, value):
    """Return the first dict in ``entries`` whose ``field`` equals ``value``.

    Raises:
        ValueError: if no entry matches. A bare ``next()`` would raise
            ``StopIteration`` instead, which turns into an opaque
            ``RuntimeError`` when it escapes from inside a generator such as
            ``run_inference``.
    """
    for entry in entries:
        if entry[field] == value:
            return entry
    raise ValueError(f"no entry with {field}={value!r}")


# --- Helper Functions for Cell Visuals ---
def make_cell_svg(color, is_prompt=False, size=36):
    """Create an SVG cell representation like in the figure.

    Args:
        color: Color intended for the cell.
        is_prompt: Selects the prompt-cell template instead of the query one.
        size: Intended size of the cell.

    Returns:
        The template string for the chosen cell kind.
    """
    # NOTE(review): both templates are blank in this copy, so `color` and
    # `size` are currently unused and both branches return the same text.
    if is_prompt:
        # For prompt cells (treated/different condition) - show irregular shape
        return f''' '''
    else:
        # Query cells - regular circular shape
        return f''' '''


def generate_cell_array(cell_type, condition, count=6, is_prompt=False):
    """Generate an array of cells in a horizontal layout.

    Args:
        cell_type: ``id`` of an entry in ``CELL_TYPES``.
        condition: ``id`` of an entry in ``CONDITIONS``.
        count: Number of cells to emit before the trailing ellipsis.
        is_prompt: Use the condition's colors (prompt cells) rather than the
            cell type's colors (query cells).

    Returns:
        ``count`` copies of the cell markup followed by an ellipsis entry.

    Raises:
        ValueError: if ``cell_type`` or ``condition`` is not a known id.
    """
    cell = _find_entry(CELL_TYPES, "id", cell_type)
    cond = _find_entry(CONDITIONS, "id", condition)

    # Use condition color for prompt cells, cell type color for query cells
    palette = cond if is_prompt else cell
    color = palette["color"]
    bg_color = palette["bg"]  # NOTE(review): not referenced by the visible template

    # Every cell is identical, so render the markup once and repeat it.
    cell_markup = f'''
{make_cell_svg(color, is_prompt, 36)}
'''
    cells_html = cell_markup * count

    # Add ellipsis
    cells_html += '''
···
'''
    return cells_html


def generate_inference_display(prompt_cell, prompt_cond, query_cell, num_prompt=3, num_query=5):
    """Generate the full inference visualization with stacked arrays.

    Args:
        prompt_cell: ``id`` of the prompt cell type.
        prompt_cond: ``id`` of the prompt condition.
        query_cell: ``id`` of the query cell type.
        num_prompt: Number of prompt cells to draw.
        num_query: Number of query cells to draw.

    Returns:
        The rendered text for the inference panel. Query cells are always
        drawn under the ``"healthy"`` condition.

    Raises:
        ValueError: if any id is unknown.
    """
    prompt_cell_data = _find_entry(CELL_TYPES, "id", prompt_cell)
    prompt_cond_data = _find_entry(CONDITIONS, "id", prompt_cond)
    query_cell_data = _find_entry(CELL_TYPES, "id", query_cell)

    prompt_cells = generate_cell_array(prompt_cell, prompt_cond, num_prompt, is_prompt=True)
    query_cells = generate_cell_array(query_cell, "healthy", num_query, is_prompt=False)

    html = f'''
In-context Learning
Gene expression counts → Predicted states
{prompt_cells}
{query_cells}
STACK
Gene Modules × Cells
{generate_mini_matrix(prompt_cond_data["color"], query_cell_data["color"])}
Prompt: {prompt_cell_data["name"]} + {prompt_cond_data["name"]}
Query: {query_cell_data["name"]} (Healthy)
'''
    return html


def generate_mini_matrix(prompt_color, query_color):
    """Generate a small matrix visualization showing gene modules × cells.

    Builds 4 x 5 = 20 entries; the first three columns use ``prompt_color``
    and the last two use ``query_color``. One ``random.random()`` draw is
    made per entry.

    Returns:
        The entries joined with newlines.
    """
    colors = [prompt_color, prompt_color, prompt_color, query_color, query_color]
    cells = []
    for _ in range(4):
        for col in range(5):
            # NOTE(review): opacity/color are not referenced by the visible
            # template; the random draw is kept so the RNG sequence is unchanged.
            opacity = 0.2 + random.random() * 0.6
            color = colors[col]
            cells.append(f'\n')
    return '\n'.join(cells)


def generate_grid_html(step, masked_indices=None):
    """Generate the 5x5 grid HTML for architecture view.

    Args:
        step: Current learning step; ``"intra"`` marks column 2 as active and
            ``"inter"`` marks row 2 as active.
        masked_indices: Flat indices (0-24, row-major) of masked cells.
            ``None`` means no cell is masked.

    Returns:
        The rendered grid text.
    """
    if masked_indices is None:
        masked_indices = []

    # Color palette for the matrix
    colors = ["#f97316", "#3b82f6", "#06b6d4", "#1e3a5f", "#1e3a5f"]

    grid_html = '''
Genes
Cells
'''

    for i in range(25):
        row_idx, col_idx = divmod(i, 5)
        is_masked = i in masked_indices
        is_col_active = step == "intra" and col_idx == 2  # Intra-cellular highlights columns
        is_row_active = step == "inter" and row_idx == 2  # Inter-cellular highlights rows

        # Determine cell color based on column
        base_color = colors[col_idx]

        if is_masked:
            bg_color = "#e2e8f0"
            content = '\n🔄\n'
        else:
            # Vary opacity based on position for visual interest
            opacity = 0.3 + (row_idx * 0.12) + (col_idx * 0.08)
            bg_color = base_color
            content = ''

        # The row highlight takes precedence over the column highlight.
        ring_style = ""
        if is_row_active:
            ring_style = "box-shadow: 0 0 0 3px #60a5fa; z-index: 10;"
        elif is_col_active:
            ring_style = "box-shadow: 0 0 0 3px #34d399; z-index: 10;"

        cell_opacity = "0.2" if is_masked else f"{0.3 + row_idx * 0.15}"

        grid_html += f'''
{content}
'''

    # NOTE(review): placed after the loop on the assumption that this closes
    # the grid container; the collapsed source does not show its nesting.
    grid_html += '''
'''
    return grid_html


def get_step_label(step):
    """Get the label for the current step.

    Unknown steps fall back to the ``"idle"`` label.
    """
    labels = {
        "idle": '\nSelect a learning step to visualize attention patterns\n',
        "intra": '\n→ Intra-cellular: Learning gene dependencies within each cell\n',
        "inter": '\n↓ Inter-cellular: Learning context across cell population\n',
        "masking": '\n🔄 Pre-training: Masked gene expression reconstruction\n',
    }
    return f'\n{labels.get(step, labels["idle"])}\n'


# --- Architecture View Functions ---
def update_architecture_view(step):
    """Update the architecture view based on selected step.

    Returns:
        A ``(grid_html, label_html)`` tuple. For ``"masking"`` two randomly
        chosen consecutive rows of the grid are masked.
    """
    masked_indices = []

    if step == "masking":
        # Mask two consecutive full rows (genes across all cells)
        start_row = random.randint(0, 3)  # 0-3 so we can have 2 consecutive rows
        masked_indices = list(range(start_row * 5, (start_row + 2) * 5))

    grid_html = generate_grid_html(step, masked_indices)
    label_html = get_step_label(step)
    return grid_html, label_html


# --- Inference View Functions ---
def update_inference_display(prompt_cell_name, prompt_cond_name, query_cell_name):
    """Update the inference visualization when selections change.

    Args are the display names chosen in the UI; they are mapped back to ids.

    Returns:
        ``(display_html, prompt_cell_id, prompt_cond_id, query_cell_id)``.

    Raises:
        ValueError: if a name matches no known cell type or condition.
    """
    prompt_cell = _find_entry(CELL_TYPES, "name", prompt_cell_name)["id"]
    prompt_cond = _find_entry(CONDITIONS, "name", prompt_cond_name)["id"]
    query_cell = _find_entry(CELL_TYPES, "name", query_cell_name)["id"]

    display = generate_inference_display(prompt_cell, prompt_cond, query_cell)
    return display, prompt_cell, prompt_cond, query_cell


def run_inference(prompt_cell, prompt_cond, query_cell):
    """Run the inference prediction.

    Generator that yields twice: first a "processing" panel, then (after a
    0.5 s pause) the result panel. Each yield is a 3-tuple of the panel text
    and two ``gr.update`` visibility changes, wired in ``create_app`` to the
    reset button and the run button respectively.

    Raises:
        ValueError: if any id is unknown.
    """
    prompt_cond_data = _find_entry(CONDITIONS, "id", prompt_cond)
    query_cell_data = _find_entry(CELL_TYPES, "id", query_cell)
    prompt_cell_data = _find_entry(CELL_TYPES, "id", prompt_cell)

    # Processing state
    processing_html = f'''
🔄
Processing gene expression context...
Learning from {prompt_cell_data["name"]} patterns under {prompt_cond_data["name"]}
'''
    yield processing_html, gr.update(visible=False), gr.update(visible=True)

    time.sleep(0.5)

    # Result state - show predicted gene counts per query cell
    result_html = f'''
PREDICTION COMPLETE
Predicted gene expression counts for {query_cell_data["name"]}
under {prompt_cond_data["name"]} condition
Predicted Gene Counts Per Query Cell
{generate_output_column(prompt_cond_data["color"])}
{make_cell_svg(query_cell_data["color"], False, 18)}
Cell 1
{generate_output_column(prompt_cond_data["color"])}
{make_cell_svg(query_cell_data["color"], False, 18)}
Cell 2
···
{generate_output_column(prompt_cond_data["color"])}
{make_cell_svg(query_cell_data["color"], False, 18)}
Cell n
Zero-shot prediction: Using in-context learning from {prompt_cell_data["name"]} response, STACK predicts gene counts for each {query_cell_data["name"]} under the same perturbation.
'''
    yield result_html, gr.update(visible=True), gr.update(visible=False)


def generate_output_column(color):
    """Generate a vertical column of gene expression counts showing explicit values.

    Draws one random count in [10, 500] for each of five gene labels.

    Returns:
        The per-gene entries joined with newlines.
    """
    gene_names = ["g₁", "g₂", "g₃", "g₄", "g₅"]
    cells = []
    for gene in gene_names:
        # Generate pseudo gene count value
        count = random.randint(10, 500)
        # NOTE(review): bar_width/opacity (and `color`) are not referenced by
        # the visible template.
        bar_width = min(count / 500 * 40, 40)  # Scale to max 40px
        opacity = 0.4 + (count / 500) * 0.5
        cells.append(f'''
{gene}
{count}
''')
    return '\n'.join(cells)


def reset_inference(prompt_cell, prompt_cond, query_cell):
    """Reset inference view to initial state.

    Returns:
        The regenerated inference display plus two ``gr.update`` visibility
        changes, wired in ``create_app`` to hide the reset button and show
        the run button.
    """
    display = generate_inference_display(prompt_cell, prompt_cond, query_cell)
    return display, gr.update(visible=False), gr.update(visible=True)


# --- Main Gradio App ---
def create_app():
    """Build and return the ``gr.Blocks`` app with both tabs and their events."""
    with gr.Blocks(title="STACK Model Visualization") as app:
        # Header
        gr.HTML('''
🧬
STACK
''')

        with gr.Tabs():
            # --- ARCHITECTURE TAB ---
            with gr.Tab("🏗️ Architecture"):
                with gr.Row():
                    with gr.Column(scale=2):
                        grid_display = gr.HTML(generate_grid_html("idle"))
                        label_display = gr.HTML(get_step_label("idle"))

                    with gr.Column(scale=1):
                        gr.HTML('''

Learning Process

STACK learns from gene expression matrices where rows are gene modules and columns are cells.

''')
                        intra_btn = gr.Button("→ 1. Intra-cellular Attention", size="lg", variant="secondary")
                        inter_btn = gr.Button("↓ 2. Inter-cellular Attention", size="lg", variant="secondary")
                        masking_btn = gr.Button("🔄 3. Masked Pre-training", size="lg", variant="secondary")

                        gr.HTML('''
💡 Key insight: By learning gene dependencies across the entire cell population, STACK can transfer knowledge from one cell type to another.
''')

                # State for architecture view
                arch_step = gr.State("idle")

                def set_step(step_name):
                    """Render the grid and label for ``step_name`` and record the step."""
                    grid, label = update_architecture_view(step_name)
                    return grid, label, step_name

                arch_outputs = [grid_display, label_display, arch_step]
                intra_btn.click(lambda: set_step("intra"), outputs=arch_outputs)
                inter_btn.click(lambda: set_step("inter"), outputs=arch_outputs)
                masking_btn.click(lambda: set_step("masking"), outputs=arch_outputs)

            # --- INFERENCE TAB ---
            with gr.Tab("🔮 Inference"):
                # Controls Row
                with gr.Row():
                    with gr.Column(scale=1):
                        gr.HTML('''
🔴 Prompt Cells (Known Response)
''')
                        prompt_cell_radio = gr.Radio(
                            choices=[c["name"] for c in CELL_TYPES],
                            value=CELL_TYPES[0]["name"],
                            label="Cell Type",
                            container=True,
                        )
                        prompt_cond_dropdown = gr.Dropdown(
                            choices=[c["name"] for c in CONDITIONS if c["id"] != "healthy"],
                            value=CONDITIONS[1]["name"],
                            label="Condition/Treatment",
                            container=True,
                        )

                    with gr.Column(scale=1):
                        gr.HTML('''
🔵 Query Cells (To Predict)
''')
                        query_cell_radio = gr.Radio(
                            choices=[c["name"] for c in CELL_TYPES],
                            value=CELL_TYPES[1]["name"],
                            label="Cell Type",
                            container=True,
                        )
                        gr.HTML('''
Initial state: Healthy (baseline gene expression)
''')

                # Visualization Display
                inference_display = gr.HTML(
                    generate_inference_display(CELL_TYPES[0]["id"], CONDITIONS[1]["id"], CELL_TYPES[1]["id"])
                )

                # Action Buttons
                with gr.Row():
                    run_btn = gr.Button("▶️ Run Zero-Shot Prediction", variant="primary", size="lg")
                    reset_btn = gr.Button("↩️ Reset", size="lg", visible=False)

                # States for inference
                prompt_cell_state = gr.State(CELL_TYPES[0]["id"])
                prompt_cond_state = gr.State(CONDITIONS[1]["id"])
                query_cell_state = gr.State(CELL_TYPES[1]["id"])
                selection_states = [prompt_cell_state, prompt_cond_state, query_cell_state]

                # Update display when selections change: every selector feeds
                # the same handler with the same inputs and outputs.
                selectors = [prompt_cell_radio, prompt_cond_dropdown, query_cell_radio]
                for selector in selectors:
                    selector.change(
                        update_inference_display,
                        inputs=selectors,
                        outputs=[inference_display, *selection_states],
                    )

                # Run prediction
                run_btn.click(
                    run_inference,
                    inputs=selection_states,
                    outputs=[inference_display, reset_btn, run_btn],
                )

                # Reset
                reset_btn.click(
                    reset_inference,
                    inputs=selection_states,
                    outputs=[inference_display, reset_btn, run_btn],
                )

    return app


def custom_css():
    """Return the stylesheet string passed to ``launch`` for the app."""
    # Adjacent literals concatenate to exactly the original stylesheet text.
    return (
        " @keyframes spin {"
        " from { transform: rotate(0deg); }"
        " to { transform: rotate(360deg); }"
        " }"
        " .gradio-container {"
        " max-width: 1000px !important;"
        " margin: auto;"
        " font-family: system-ui, -apple-system, sans-serif;"
        " }"
        " button {"
        " border-radius: 10px !important;"
        " font-size: 13px !important;"
        " font-weight: 600 !important;"
        " transition: all 0.2s !important;"
        " }"
        " .tabs button {"
        " font-size: 13px !important;"
        " font-weight: 600 !important;"
        " padding: 12px 20px !important;"
        " }"
        ' .tabs button[aria-selected="true"] { \n'
        "background: linear-gradient(90deg, #4f46e5 0%, #7c3aed 100%) !important;"
        " color: white !important;"
        " }"
        ' input[type="radio"] + label {'
        " font-size: 13px !important;"
        " } "
    )


# --- Launch App ---
if __name__ == "__main__":
    app = create_app()
    # NOTE(review): `css=` on launch() is presumably only accepted by newer
    # Gradio releases; older ones take it on gr.Blocks(...). Confirm against
    # the pinned Gradio version.
    app.launch(css=custom_css())