import gradio as gr
import random
import time
# --- Constants ---
# Selectable cell types. Each entry has an ``id`` (internal key), a display
# ``name`` and two hex colours: ``color`` (foreground) and ``bg`` (background).
CELL_TYPES = [
    {"id": ident, "name": label, "color": fg, "bg": bg}
    for ident, label, fg, bg in (
        ("t-cell", "T-Cell", "#3b82f6", "#dbeafe"),
        ("b-cell", "B-Cell", "#6366f1", "#e0e7ff"),
        ("macro", "Macrophage", "#f97316", "#ffedd5"),
    )
]
# Experimental conditions, same field layout as CELL_TYPES. "healthy" is the
# condition used when rendering query cells and is excluded from the prompt
# condition dropdown.
CONDITIONS = [
    {"id": ident, "name": label, "color": fg, "bg": bg}
    for ident, label, fg, bg in (
        ("healthy", "Healthy", "#3b82f6", "#dbeafe"),
        ("drug-a", "Drug A", "#ef4444", "#fee2e2"),
        ("viral", "Viral Infection", "#f97316", "#ffedd5"),
    )
]
# --- Helper Functions for Cell Visuals ---
def make_cell_svg(color, is_prompt=False, size=36):
    """Return an inline SVG string depicting a single cell.

    Args:
        color: CSS colour used for the cell outline and its translucent fill.
        is_prompt: If True, draw an irregular outline for prompt cells
            (treated / different condition); otherwise draw a regular circle
            for query cells.
        size: Rendered width and height of the SVG in pixels.

    Returns:
        A self-contained ``<svg>...</svg>`` element as a string.
    """
    # Shapes are drawn on a fixed 36x36 viewBox and scaled by the browser to
    # ``size``, so callers can resize cells without changing the geometry.
    if is_prompt:
        # Prompt cells: lumpy, irregular membrane plus an off-centre nucleus.
        shape = (
            '<path d="M18 4 C26 3 33 9 32 17 C34 25 27 33 18 32 '
            'C10 34 3 27 5 19 C2 11 10 5 18 4 Z" '
            f'fill="{color}" fill-opacity="0.25" stroke="{color}" stroke-width="2"/>'
            f'<circle cx="17" cy="18" r="5" fill="{color}" fill-opacity="0.8"/>'
        )
    else:
        # Query cells: regular circular membrane with a centred nucleus.
        shape = (
            f'<circle cx="18" cy="18" r="14" fill="{color}" fill-opacity="0.2" '
            f'stroke="{color}" stroke-width="2"/>'
            f'<circle cx="18" cy="18" r="5" fill="{color}" fill-opacity="0.8"/>'
        )
    return (
        f'<svg width="{size}" height="{size}" viewBox="0 0 36 36" '
        f'xmlns="http://www.w3.org/2000/svg">{shape}</svg>'
    )
def generate_cell_array(cell_type, condition, count=6, is_prompt=False):
    """Build the HTML for a horizontal row of cell icons followed by an ellipsis.

    Args:
        cell_type: ``id`` of an entry in ``CELL_TYPES``.
        condition: ``id`` of an entry in ``CONDITIONS``.
        count: Number of cell icons drawn before the ellipsis (<= 0 draws none).
        is_prompt: Prompt rows are coloured by *condition* and drawn with the
            prompt shape; query rows are coloured by *cell type*.

    Returns:
        An HTML fragment string.

    Raises:
        ValueError: If ``cell_type`` or ``condition`` is not a known id.
    """
    # Explicit ValueError instead of a bare StopIteration from next(): a leaked
    # StopIteration becomes a confusing RuntimeError inside generators.
    cell = next((c for c in CELL_TYPES if c["id"] == cell_type), None)
    if cell is None:
        raise ValueError(f"Unknown cell type id: {cell_type!r}")
    cond = next((c for c in CONDITIONS if c["id"] == condition), None)
    if cond is None:
        raise ValueError(f"Unknown condition id: {condition!r}")
    # Prompt cells show the condition colour; query cells keep their type colour.
    color = cond["color"] if is_prompt else cell["color"]
    # Every icon in the row is identical, so render it once and repeat it
    # rather than concatenating with += in a loop.
    cell_html = f"\n{make_cell_svg(color, is_prompt, 36)}\n"
    return cell_html * count + "\n···\n"
def generate_inference_display(prompt_cell, prompt_cond, query_cell, num_prompt=3, num_query=5):
    """Render the stacked prompt/query cell visualization for the inference tab.

    Each id is looked up in CELL_TYPES / CONDITIONS first, so an unknown id
    raises StopIteration. The function then builds a prompt row in the
    prompt condition's colours and a query row rendered under "healthy".

    NOTE(review): in this copy the returned template holds no markup and does
    not interpolate either row — confirm whether the template content was lost.
    """
    # Validate all three ids in the original order (prompt type, prompt
    # condition, query type); an unknown id raises StopIteration here.
    for table, wanted in (
        (CELL_TYPES, prompt_cell),
        (CONDITIONS, prompt_cond),
        (CELL_TYPES, query_cell),
    ):
        next(row for row in table if row["id"] == wanted)

    prompt_row = generate_cell_array(prompt_cell, prompt_cond, num_prompt, is_prompt=True)
    query_row = generate_cell_array(query_cell, "healthy", num_query, is_prompt=False)

    # Template placeholder: currently evaluates to a single newline.
    fragment = """
"""
    return fragment
def generate_mini_matrix(prompt_color, query_color, rows=4):
    """Return HTML for a small gene-module x cell heat-map.

    The matrix has ``rows`` rows and 5 columns. The first three columns use
    ``prompt_color`` and the last two use ``query_color``. Each tile gets a
    random opacity in [0.2, 0.8) to suggest varying expression.

    Args:
        prompt_color: CSS colour for the three prompt-cell columns.
        query_color: CSS colour for the two query-cell columns.
        rows: Number of rows to draw (defaults to the original 4).

    Returns:
        One ``<div>`` per tile in row-major order, separated by newlines.
    """
    column_colors = [prompt_color] * 3 + [query_color] * 2
    tiles = []
    for _ in range(rows):
        for color in column_colors:
            opacity = 0.2 + random.random() * 0.6
            tiles.append(
                f'<div class="mini-cell" style="background:{color};'
                f'opacity:{opacity:.2f}"></div>'
            )
    return "\n".join(tiles)
def generate_grid_html(step, masked_indices=None):
    """Return HTML for the 5x5 matrix shown in the architecture view.

    Rows are gene modules and columns are cells. Each tile is coloured by its
    column, and ``step`` controls highlighting:

    * ``"intra"`` - highlight column 2 (intra-cellular step).
    * ``"inter"`` - highlight row 2 (inter-cellular step).
    * anything else (e.g. ``"idle"``, ``"masking"``) - no highlight.

    Args:
        step: Name of the current architecture step.
        masked_indices: Flat tile indices (``row * 5 + col``) to draw as
            masked (grey with a "?"); ``None`` or empty means no masking.

    Returns:
        An HTML fragment string: the axis labels followed by the grid.
    """
    # A set gives O(1) membership tests per tile; None/empty -> no mask.
    masked = set(masked_indices or ())
    # Colour palette for the matrix, one entry per column.
    colors = ["#f97316", "#3b82f6", "#06b6d4", "#1e3a5f", "#1e3a5f"]
    parts = [
        "\nGenes\nCells\n",
        '<div class="stack-grid" style="display:grid;'
        'grid-template-columns:repeat(5, 36px);gap:4px">',
    ]
    for i in range(25):
        row_idx, col_idx = divmod(i, 5)
        is_col_active = step == "intra" and col_idx == 2  # Intra-cellular highlights a column
        is_row_active = step == "inter" and row_idx == 2  # Inter-cellular highlights a row
        if i in masked:
            # Masked tiles are greyed out and marked with "?".
            bg_color, content = "#e2e8f0", "?"
        else:
            bg_color, content = colors[col_idx], ""
        border = (
            "3px solid #facc15" if is_col_active or is_row_active
            else "1px solid #cbd5e1"
        )
        parts.append(
            f'<div class="grid-cell" style="background:{bg_color};border:{border};'
            'width:36px;height:36px;display:flex;align-items:center;'
            f'justify-content:center;color:#475569;font-weight:bold">{content}</div>'
        )
    parts.append("</div>")
    return "".join(parts)
# --- Architecture View Functions ---
def update_architecture_view(step):
    """Recompute the architecture grid and its caption for ``step``.

    For ``"masking"``, two consecutive full rows of the 5x5 grid are chosen at
    random and masked. Every other step renders with no masked tiles.

    Returns:
        ``(grid_html, label_html)``.
    """
    hidden = []
    if step == "masking":
        # First masked row is drawn from 0..3 so the following row (first + 1)
        # still lies inside the 5-row grid; 10 flat indices = 2 rows of 5.
        first_row = random.randint(0, 3)
        hidden = list(range(first_row * 5, first_row * 5 + 10))
    return generate_grid_html(step, hidden), get_step_label(step)
# --- Inference View Functions ---
def update_inference_display(prompt_cell_name, prompt_cond_name, query_cell_name):
    """Map the UI display names to ids and re-render the inference view.

    An unrecognised name raises StopIteration.

    Returns:
        ``(display_html, prompt_cell_id, prompt_cond_id, query_cell_id)``.
    """
    def id_for(table, label):
        # First entry whose display name matches; StopIteration if none does.
        return next(entry["id"] for entry in table if entry["name"] == label)

    prompt_cell = id_for(CELL_TYPES, prompt_cell_name)
    prompt_cond = id_for(CONDITIONS, prompt_cond_name)
    query_cell = id_for(CELL_TYPES, query_cell_name)
    display = generate_inference_display(prompt_cell, prompt_cond, query_cell)
    return display, prompt_cell, prompt_cond, query_cell
def run_inference(prompt_cell, prompt_cond, query_cell):
    """Generator that drives the (simulated) prediction sequence.

    First it yields a processing message with visibility updates
    (hidden, shown). It then pauses 0.5 s and yields the result message with
    the two visibilities swapped. Each yield is a 3-tuple
    ``(html, gr.update, gr.update)``.
    """
    def pick(table, wanted):
        # Look up an entry by id; StopIteration if none matches.
        return next(row for row in table if row["id"] == wanted)

    condition = pick(CONDITIONS, prompt_cond)
    query_type = pick(CELL_TYPES, query_cell)
    prompt_type = pick(CELL_TYPES, prompt_cell)
    prompt_name = prompt_type["name"]
    query_name = query_type["name"]

    busy_html = "\n".join([
        "",
        "🔄",
        "Processing gene expression context...",
        f"Learning from {prompt_name} patterns under {condition['name']}",
        "",
    ])
    yield busy_html, gr.update(visible=False), gr.update(visible=True)
    time.sleep(0.5)

    # NOTE(review): computed but not interpolated into the result template in
    # this copy of the file.
    predicted_cells = generate_cell_array(query_cell, prompt_cond, 5, is_prompt=True)
    done_html = "\n".join([
        "",
        "PREDICTION COMPLETE",
        f"Predicted gene expression counts for {query_name}",
        "Zero-shot prediction: Using in-context learning from",
        f"{prompt_name} response,",
        f"STACK predicts gene counts for each {query_name}",
        "under the same perturbation.",
        "",
    ])
    yield done_html, gr.update(visible=True), gr.update(visible=False)
def generate_output_column(color):
    """Return HTML for a vertical column of pseudo gene-expression counts.

    There is one row per gene (g₁ … g₅). Each row shows the gene label, a bar
    in ``color`` whose width (at most 40px) and opacity scale with the count,
    and the count itself. Counts are random integers in [10, 500].

    Args:
        color: CSS colour for the count bars.

    Returns:
        The row fragments joined with newlines.
    """
    max_count = 500
    max_bar_px = 40
    rows = []
    for gene in ("g₁", "g₂", "g₃", "g₄", "g₅"):
        count = random.randint(10, max_count)
        fraction = count / max_count
        bar_width = min(fraction * max_bar_px, max_bar_px)  # clamp to max bar width
        opacity = 0.4 + fraction * 0.5
        rows.append(
            '<div class="gene-row" style="display:flex;align-items:center;gap:6px">'
            f'<span class="gene-name">{gene}</span>'
            '<span class="gene-bar" style="display:inline-block;height:8px;'
            f'width:{bar_width:.1f}px;background:{color};opacity:{opacity:.2f}"></span>'
            f'<span class="gene-count">{count}</span></div>'
        )
    return "\n".join(rows)
def reset_inference(prompt_cell, prompt_cond, query_cell):
    """Return the inference tab to its pre-prediction state.

    Re-renders the visualization for the given ids and pairs it with
    ``gr.update(visible=False)`` followed by ``gr.update(visible=True)``.
    """
    display = generate_inference_display(prompt_cell, prompt_cond, query_cell)
    hidden_update = gr.update(visible=False)
    shown_update = gr.update(visible=True)
    return display, hidden_update, shown_update
# --- Main Gradio App ---
def create_app():
with gr.Blocks(title="STACK Model Visualization") as app:
# Header
gr.HTML('''
🧬
STACK
''')
with gr.Tabs() as tabs:
# --- ARCHITECTURE TAB ---
with gr.Tab("🏗️ Architecture"):
with gr.Row():
with gr.Column(scale=2):
grid_display = gr.HTML(generate_grid_html("idle"))
label_display = gr.HTML(get_step_label("idle"))
with gr.Column(scale=1):
gr.HTML('''
Learning Process
STACK learns from gene expression matrices where rows are gene modules and columns are cells.
💡 Key insight: By learning gene dependencies across the entire cell population, STACK can transfer knowledge from one cell type to another.
''')
# State for architecture view
arch_step = gr.State("idle")
def set_step(step_name):
grid, label = update_architecture_view(step_name)
return grid, label, step_name
intra_btn.click(lambda: set_step("intra"), outputs=[grid_display, label_display, arch_step])
inter_btn.click(lambda: set_step("inter"), outputs=[grid_display, label_display, arch_step])
masking_btn.click(lambda: set_step("masking"), outputs=[grid_display, label_display, arch_step])
# --- INFERENCE TAB ---
with gr.Tab("🔮 Inference"):
# Controls Row
with gr.Row():
with gr.Column(scale=1):
gr.HTML('''
🔴 Prompt Cells (Known Response)
''')
prompt_cell_radio = gr.Radio(
choices=[c["name"] for c in CELL_TYPES],
value=CELL_TYPES[0]["name"],
label="Cell Type",
container=True
)
prompt_cond_dropdown = gr.Dropdown(
choices=[c["name"] for c in CONDITIONS if c["id"] != "healthy"],
value=CONDITIONS[1]["name"],
label="Condition/Treatment",
container=True
)
with gr.Column(scale=1):
gr.HTML('''
🔵 Query Cells (To Predict)
''')
query_cell_radio = gr.Radio(
choices=[c["name"] for c in CELL_TYPES],
value=CELL_TYPES[1]["name"],
label="Cell Type",
container=True
)
gr.HTML('''