Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import random | |
| import time | |
| # --- Constants --- | |
# Selectable cell types. Each entry carries the internal id, the label shown
# in the UI, and a foreground/background colour pair used when drawing it.
CELL_TYPES = [
    {
        "id": "t-cell",
        "name": "T-Cell",
        "color": "#3b82f6",
        "bg": "#dbeafe",
    },
    {
        "id": "b-cell",
        "name": "B-Cell",
        "color": "#6366f1",
        "bg": "#e0e7ff",
    },
    {
        "id": "macro",
        "name": "Macrophage",
        "color": "#f97316",
        "bg": "#ffedd5",
    },
]
# Selectable conditions/treatments, in the same record layout as the cell
# types: internal id, display label, foreground colour, background colour.
CONDITIONS = [
    {
        "id": "healthy",
        "name": "Healthy",
        "color": "#3b82f6",
        "bg": "#dbeafe",
    },
    {
        "id": "drug-a",
        "name": "Drug A",
        "color": "#ef4444",
        "bg": "#fee2e2",
    },
    {
        "id": "viral",
        "name": "Viral Infection",
        "color": "#f97316",
        "bg": "#ffedd5",
    },
]
| # --- Helper Functions for Cell Visuals --- | |
def make_cell_svg(color, is_prompt=False, size=36):
    """Return inline SVG markup for a single cell glyph.

    Args:
        color: CSS colour used for every tinted shape in the glyph.
        is_prompt: When true, draw the elliptical "prompt" variant; otherwise
            draw the circular "query" variant with a dashed outline.
        size: Rendered width and height of the SVG in pixels (the viewBox is
            always 40x40).

    Returns:
        The SVG element as a string.
    """
    if not is_prompt:
        # Query cells: concentric circles with a dashed outer ring.
        return f'''
        <svg width="{size}" height="{size}" viewBox="0 0 40 40">
            <circle cx="20" cy="20" r="16" fill="{color}" opacity="0.2" stroke="{color}" stroke-width="1.5" stroke-dasharray="3,2"/>
            <circle cx="20" cy="20" r="10" fill="{color}" opacity="0.4"/>
            <circle cx="20" cy="20" r="5" fill="{color}"/>
            <circle cx="18" cy="18" r="1.5" fill="white" opacity="0.6"/>
        </svg>
        '''

    # Prompt cells: stacked ellipses give a flattened, irregular look.
    return f'''
        <svg width="{size}" height="{size}" viewBox="0 0 40 40">
            <ellipse cx="20" cy="20" rx="16" ry="14" fill="{color}" opacity="0.3"/>
            <ellipse cx="20" cy="20" rx="12" ry="10" fill="{color}" opacity="0.5"/>
            <circle cx="20" cy="20" r="6" fill="{color}"/>
            <circle cx="18" cy="18" r="2" fill="white" opacity="0.6"/>
        </svg>
        '''
def generate_cell_array(cell_type, condition, count=6, is_prompt=False):
    """Render a horizontal run of identical cell glyphs followed by an ellipsis.

    Args:
        cell_type: Id of an entry in ``CELL_TYPES``.
        condition: Id of an entry in ``CONDITIONS``.
        count: Number of cell glyphs to emit before the trailing ellipsis.
        is_prompt: When true the glyphs use the condition colour and the
            prompt shape; otherwise the cell-type colour and the query shape.

    Returns:
        An HTML fragment (a sequence of ``<div>`` wrappers) as a string.

    Raises:
        ValueError: If ``cell_type`` or ``condition`` is not a known id.
    """
    # Look ids up with an explicit default: a bare next() would leak
    # StopIteration, which turns into an opaque RuntimeError when this helper
    # is reached from inside a generator.
    cell = next((c for c in CELL_TYPES if c["id"] == cell_type), None)
    if cell is None:
        raise ValueError(f"Unknown cell type id: {cell_type!r}")
    cond = next((c for c in CONDITIONS if c["id"] == condition), None)
    if cond is None:
        raise ValueError(f"Unknown condition id: {condition!r}")

    # Use condition color for prompt cells, cell type color for query cells
    color = cond["color"] if is_prompt else cell["color"]

    # Every glyph in the row is identical, so render the markup once and
    # repeat it instead of rebuilding and concatenating inside a loop.
    single_cell_html = f'''
        <div style="display: flex; align-items: center; justify-content: center;">
            {make_cell_svg(color, is_prompt, 36)}
        </div>
        '''

    # Trailing ellipsis hints that the row continues beyond what is drawn.
    ellipsis_html = '''
        <div style="display: flex; align-items: center; justify-content: center; color: #94a3b8; font-weight: bold; letter-spacing: 2px;">
            ···
        </div>
    '''
    return single_cell_html * count + ellipsis_html
def generate_inference_display(prompt_cell, prompt_cond, query_cell, num_prompt=3, num_query=5):
    """Build the HTML overview panel for the inference view.

    The panel shows a prompt row and a query row of cell glyphs, an arrow
    labelled STACK, a small gene-module matrix preview and a two-entry legend.

    Args:
        prompt_cell: Id of the prompt cell type (entry in ``CELL_TYPES``).
        prompt_cond: Id of the prompt condition (entry in ``CONDITIONS``).
        query_cell: Id of the query cell type (entry in ``CELL_TYPES``).
        num_prompt: Number of glyphs drawn in the prompt row.
        num_query: Number of glyphs drawn in the query row.

    Returns:
        The panel markup as a string.
    """
    prompt_type = next(c for c in CELL_TYPES if c["id"] == prompt_cell)
    prompt_condition = next(c for c in CONDITIONS if c["id"] == prompt_cond)
    query_type = next(c for c in CELL_TYPES if c["id"] == query_cell)

    # The query row is always drawn in the "healthy" baseline condition.
    prompt_row = generate_cell_array(prompt_cell, prompt_cond, num_prompt, is_prompt=True)
    query_row = generate_cell_array(query_cell, "healthy", num_query, is_prompt=False)

    # Pull the pieces used by the template into plain locals.
    prompt_bg = prompt_condition["bg"]
    prompt_color = prompt_condition["color"]
    prompt_cond_name = prompt_condition["name"]
    prompt_type_name = prompt_type["name"]
    query_bg = query_type["bg"]
    query_color = query_type["color"]
    query_type_name = query_type["name"]
    matrix_html = generate_mini_matrix(prompt_color, query_color)

    return f'''
    <div style="background: linear-gradient(135deg, #fdf2f8 0%, #f8fafc 50%, #eff6ff 100%); padding: 30px; border-radius: 16px; font-family: system-ui, -apple-system, sans-serif;">
        <!-- Title -->
        <div style="text-align: center; margin-bottom: 24px;">
            <div style="font-size: 11px; font-weight: 600; color: #64748b; text-transform: uppercase; letter-spacing: 0.1em; margin-bottom: 4px;">In-context Learning</div>
            <div style="font-size: 13px; color: #94a3b8;">Gene expression counts → Predicted states</div>
        </div>
        <!-- Main Container -->
        <div style="display: flex; align-items: center; gap: 20px; justify-content: center;">
            <!-- Input Arrays Container -->
            <div style="display: flex; flex-direction: column; gap: 8px;">
                <!-- Prompt Array -->
                <div style="display: flex; align-items: center; gap: 12px;">
                    <div style="background: {prompt_bg}; border: 2px solid {prompt_color}40; border-radius: 12px; padding: 10px 16px; display: flex; gap: 6px; align-items: center;">
                        {prompt_row}
                    </div>
                </div>
                <!-- Query Array -->
                <div style="display: flex; align-items: center; gap: 12px;">
                    <div style="background: {query_bg}; border: 2px solid {query_color}40; border-radius: 12px; padding: 10px 16px; display: flex; gap: 6px; align-items: center;">
                        {query_row}
                    </div>
                </div>
            </div>
            <!-- Arrow -->
            <div style="display: flex; flex-direction: column; align-items: center; gap: 4px;">
                <svg width="40" height="24" viewBox="0 0 40 24">
                    <defs>
                        <marker id="arrowhead" markerWidth="10" markerHeight="7" refX="9" refY="3.5" orient="auto">
                            <polygon points="0 0, 10 3.5, 0 7" fill="#94a3b8"/>
                        </marker>
                    </defs>
                    <line x1="0" y1="12" x2="30" y2="12" stroke="#94a3b8" stroke-width="2" marker-end="url(#arrowhead)"/>
                </svg>
                <div style="font-size: 9px; color: #94a3b8; text-transform: uppercase; letter-spacing: 0.05em;">STACK</div>
            </div>
            <!-- Gene Module Matrix Preview -->
            <div style="display: flex; flex-direction: column; align-items: center; gap: 8px;">
                <div style="font-size: 9px; color: #64748b; text-transform: uppercase; letter-spacing: 0.05em;">Gene Modules × Cells</div>
                <div style="display: grid; grid-template-columns: repeat(5, 1fr); gap: 3px; padding: 8px; background: white; border-radius: 8px; border: 1px solid #e2e8f0;">
                    {matrix_html}
                </div>
            </div>
        </div>
        <!-- Labels -->
        <div style="display: flex; justify-content: center; gap: 40px; margin-top: 20px;">
            <div style="display: flex; align-items: center; gap: 8px;">
                <div style="width: 12px; height: 12px; background: {prompt_color}; border-radius: 50%; opacity: 0.7;"></div>
                <span style="font-size: 11px; color: #475569; font-weight: 500;">Prompt: {prompt_type_name} + {prompt_cond_name}</span>
            </div>
            <div style="display: flex; align-items: center; gap: 8px;">
                <div style="width: 12px; height: 12px; background: {query_color}; border-radius: 50%; opacity: 0.7;"></div>
                <span style="font-size: 11px; color: #475569; font-weight: 500;">Query: {query_type_name} (Healthy)</span>
            </div>
        </div>
    </div>
    '''
def generate_mini_matrix(prompt_color, query_color):
    """Render a 4-row by 5-column swatch grid as newline-joined ``<div>`` tags.

    The first three columns use ``prompt_color`` and the last two use
    ``query_color``. Each swatch gets its own random opacity, drawn from
    ``0.2 + random.random() * 0.6`` and printed with one decimal place.

    Returns:
        Twenty ``<div>`` elements in row-major order, one per line.
    """
    column_colors = (prompt_color,) * 3 + (query_color,) * 2
    swatches = []
    for _ in range(4):
        for tint in column_colors:
            alpha = 0.2 + random.random() * 0.6
            swatches.append(
                f'<div style="width: 10px; height: 10px; background: {tint}; opacity: {alpha:.1f}; border-radius: 2px;"></div>'
            )
    return '\n'.join(swatches)
def generate_grid_html(step, masked_indices=None):
    """Generate the 5x5 grid HTML for the architecture view.

    Args:
        step: Current view mode. ``"intra"`` rings the middle column,
            ``"inter"`` rings the middle row; any other value draws no ring.
        masked_indices: Optional iterable of flat, row-major cell indices
            (0-24) to draw as greyed-out masked cells. ``None`` or empty
            means nothing is masked.

    Returns:
        The complete grid markup (axis labels plus 25 cells) as a string.
    """
    # Set membership keeps the per-cell masked test O(1).
    masked = set(masked_indices) if masked_indices else set()

    # Color palette for the matrix, one entry per column
    colors = ["#f97316", "#3b82f6", "#06b6d4", "#1e3a5f", "#1e3a5f"]

    header_html = '''
    <div style="display: flex; flex-direction: column; align-items: center; background: linear-gradient(135deg, #fdf2f8 0%, #f8fafc 100%); padding: 30px; border-radius: 16px;">
        <div style="display: flex; align-items: stretch;">
            <!-- Y-axis label (Genes) - Left side -->
            <div style="display: flex; align-items: center; justify-content: center; padding-right: 12px;">
                <div style="writing-mode: vertical-rl; text-orientation: mixed; transform: rotate(180deg); font-size: 11px; font-weight: 600; color: #64748b; text-transform: uppercase; letter-spacing: 0.1em;">
                    Genes
                </div>
            </div>
            <!-- Grid Container with top label -->
            <div style="display: flex; flex-direction: column; align-items: center;">
                <!-- X-axis label (Cells) - Top, centered -->
                <div style="font-size: 11px; font-weight: 600; color: #64748b; text-transform: uppercase; letter-spacing: 0.1em; margin-bottom: 12px; text-align: center;">
                    Cells
                </div>
                <!-- Main Grid -->
                <div style="display: grid; grid-template-columns: repeat(5, 44px); gap: 6px; background: white; padding: 16px; border-radius: 12px; border: 1px solid #e2e8f0; box-shadow: 0 4px 6px -1px rgba(0,0,0,0.05);">
    '''

    cell_parts = []
    for i in range(25):
        row_idx, col_idx = divmod(i, 5)
        is_masked = i in masked
        is_col_active = step == "intra" and col_idx == 2  # Intra-cellular highlights columns
        is_row_active = step == "inter" and row_idx == 2  # Inter-cellular highlights rows

        if is_masked:
            bg_color = "#e2e8f0"
            content = '<div style="font-size: 14px;">🔄</div>'
            cell_opacity = "0.2"
        else:
            # Colour follows the column; opacity grows with the row index.
            # Fixed two-decimal formatting keeps float noise such as
            # 0.44999999999999996 out of the emitted CSS.
            bg_color = colors[col_idx]
            content = ''
            cell_opacity = f"{0.3 + row_idx * 0.15:.2f}"

        ring_style = ""
        if is_row_active:
            ring_style = "box-shadow: 0 0 0 3px #60a5fa; z-index: 10;"
        elif is_col_active:
            ring_style = "box-shadow: 0 0 0 3px #34d399; z-index: 10;"

        cell_parts.append(f'''
                    <div style="width: 44px; height: 44px; background: {bg_color}; opacity: {cell_opacity}; border-radius: 6px; display: flex; align-items: center; justify-content: center; position: relative; transition: all 0.2s; {ring_style}">
                        {content}
                    </div>
        ''')

    footer_html = '''
                </div>
            </div>
        </div>
    </div>
    '''
    return header_html + "".join(cell_parts) + footer_html
def get_step_label(step):
    """Return the caption HTML shown under the architecture grid.

    Recognised steps are ``"idle"``, ``"intra"``, ``"inter"`` and
    ``"masking"``; any other value falls back to the idle hint. The caption
    is wrapped in a centred flex container.
    """
    def pill(background, foreground, text):
        # All active-step captions share one rounded "pill" layout.
        return (
            f'<div style="text-align: center; background: {background}; color: {foreground}; '
            'padding: 12px 24px; border-radius: 24px; font-size: 13px; font-weight: 600; '
            f'display: inline-block;">{text}</div>'
        )

    idle_hint = (
        '<div style="text-align: center; padding: 12px;">'
        '<span style="color: #94a3b8; font-size: 12px;">'
        'Select a learning step to visualize attention patterns</span></div>'
    )
    captions = {
        "idle": idle_hint,
        "intra": pill("#eff6ff", "#1e40af", "→ Intra-cellular: Learning gene dependencies within each cell"),
        "inter": pill("#ecfdf5", "#047857", "↓ Inter-cellular: Learning context across cell population"),
        "masking": pill("#f1f5f9", "#334155", "🔄 Pre-training: Masked gene expression reconstruction"),
    }
    caption = captions.get(step, idle_hint)
    return f'<div style="display: flex; justify-content: center; margin-top: 16px;">{caption}</div>'
| # --- Architecture View Functions --- | |
def update_architecture_view(step):
    """Return ``(grid_html, label_html)`` for the selected architecture step.

    For the ``"masking"`` step two consecutive grid rows, chosen at random,
    are passed to the grid renderer as masked cells; every other step is
    rendered with nothing masked.
    """
    if step == "masking":
        # randint is inclusive, so the first masked row is 0..3 and the
        # second (first + 1) always stays inside the 5-row grid.
        first_row = random.randint(0, 3)
        hidden = list(range(first_row * 5, (first_row + 2) * 5))
    else:
        hidden = []
    return generate_grid_html(step, hidden), get_step_label(step)
| # --- Inference View Functions --- | |
def update_inference_display(prompt_cell_name, prompt_cond_name, query_cell_name):
    """Translate the selected display names to ids and re-render the overview.

    Returns:
        A 4-tuple: the overview HTML followed by the prompt cell id, prompt
        condition id and query cell id.
    """
    def id_for(entries, display_name):
        # First entry whose display name matches the selected label.
        return next(entry["id"] for entry in entries if entry["name"] == display_name)

    prompt_cell = id_for(CELL_TYPES, prompt_cell_name)
    prompt_cond = id_for(CONDITIONS, prompt_cond_name)
    query_cell = id_for(CELL_TYPES, query_cell_name)
    overview_html = generate_inference_display(prompt_cell, prompt_cond, query_cell)
    return overview_html, prompt_cell, prompt_cond, query_cell
def run_inference(prompt_cell, prompt_cond, query_cell):
    """Generator handler that animates the mock zero-shot prediction.

    Yields two 3-tuples of ``(html, reset_button_update, run_button_update)``:
    first a "processing" panel with the reset button hidden and the run button
    shown, then, after a short pause, the result panel with the button
    visibility swapped.

    Args:
        prompt_cell: Id of the prompt cell type (entry in ``CELL_TYPES``).
        prompt_cond: Id of the prompt condition (entry in ``CONDITIONS``).
        query_cell: Id of the query cell type (entry in ``CELL_TYPES``).

    Raises:
        ValueError: If any id is unknown. A bare ``next()`` here would leak
            ``StopIteration``, which a generator reports as the opaque
            "generator raised StopIteration" RuntimeError.
    """
    def _require(entries, entry_id, kind):
        # Resolve an id to its record, failing with a readable error.
        found = next((entry for entry in entries if entry["id"] == entry_id), None)
        if found is None:
            raise ValueError(f"Unknown {kind} id: {entry_id!r}")
        return found

    prompt_cond_data = _require(CONDITIONS, prompt_cond, "condition")
    query_cell_data = _require(CELL_TYPES, query_cell, "cell type")
    prompt_cell_data = _require(CELL_TYPES, prompt_cell, "cell type")

    def _output_column(label):
        # One result card: a bar column of pseudo gene counts with a small
        # query-cell glyph and caption underneath.
        return f'''
            <div style="text-align: center;">
                <div style="background: white; border: 2px solid {query_cell_data["color"]}40; border-radius: 10px; padding: 10px 14px; display: flex; flex-direction: column; gap: 2px; box-shadow: 0 2px 8px rgba(0,0,0,0.05);">
                    {generate_output_column(prompt_cond_data["color"])}
                </div>
                <div style="display: flex; align-items: center; justify-content: center; gap: 4px; margin-top: 6px;">
                    {make_cell_svg(query_cell_data["color"], False, 18)}
                    <div style="font-size: 9px; color: #475569; font-weight: 500;">{label}</div>
                </div>
            </div>
        '''

    # Processing state
    processing_html = f'''
    <div style="background: linear-gradient(135deg, #fdf2f8 0%, #f8fafc 50%, #eff6ff 100%); padding: 40px; border-radius: 16px; text-align: center;">
        <div style="font-size: 40px; animation: spin 1s linear infinite;">🔄</div>
        <div style="margin-top: 16px; font-size: 12px; font-weight: 600; color: #6366f1; text-transform: uppercase; letter-spacing: 0.1em;">
            Processing gene expression context...
        </div>
        <div style="margin-top: 8px; font-size: 11px; color: #94a3b8;">
            Learning from {prompt_cell_data["name"]} patterns under {prompt_cond_data["name"]}
        </div>
    </div>
    '''
    yield processing_html, gr.update(visible=False), gr.update(visible=True)

    # Brief pause so the processing panel is visible before the result.
    time.sleep(0.5)

    # Result state - show predicted gene counts per query cell. The three
    # cards are identical apart from their caption; an ellipsis separates the
    # numbered cards from the final "Cell n" card.
    ellipsis_html = '<div style="display: flex; align-items: center; color: #94a3b8; font-size: 16px; font-weight: bold; letter-spacing: 3px; padding-bottom: 24px;">···</div>'
    columns_html = (
        _output_column("Cell 1")
        + _output_column("Cell 2")
        + ellipsis_html
        + _output_column("Cell n")
    )

    result_html = f'''
    <div style="background: linear-gradient(135deg, #fdf2f8 0%, #f8fafc 50%, #eff6ff 100%); padding: 30px; border-radius: 16px;">
        <!-- Header -->
        <div style="text-align: center; margin-bottom: 24px;">
            <div style="font-family: monospace; font-size: 10px; color: #64748b; letter-spacing: 0.15em; margin-bottom: 8px; text-transform: uppercase;">PREDICTION COMPLETE</div>
            <div style="font-size: 14px; font-weight: 600; color: #1e293b;">
                Predicted gene expression counts for {query_cell_data["name"]}
            </div>
            <div style="font-size: 11px; color: #64748b; margin-top: 4px;">
                under <span style="color: {prompt_cond_data["color"]}; font-weight: 600;">{prompt_cond_data["name"]}</span> condition
            </div>
        </div>
        <!-- Section label -->
        <div style="display: flex; align-items: center; gap: 12px; margin-bottom: 16px; justify-content: center;">
            <div style="height: 1px; width: 60px; background: linear-gradient(90deg, transparent, #cbd5e1);"></div>
            <div style="font-size: 10px; font-weight: 600; color: #475569; text-transform: uppercase; letter-spacing: 0.08em;">Predicted Gene Counts Per Query Cell</div>
            <div style="height: 1px; width: 60px; background: linear-gradient(90deg, #cbd5e1, transparent);"></div>
        </div>
        <!-- Output columns visualization - Gene counts per cell -->
        <div style="display: flex; justify-content: center; gap: 16px; margin-bottom: 20px;">
            {columns_html}
        </div>
        <!-- Description -->
        <div style="text-align: center; font-size: 11px; color: #64748b; max-width: 360px; margin: 0 auto; line-height: 1.6;">
            <strong style="color: #475569;">Zero-shot prediction:</strong> Using in-context learning from
            <span style="color: {prompt_cond_data["color"]}; font-weight: 500;">{prompt_cell_data["name"]}</span> response,
            STACK predicts gene counts for each <span style="color: {query_cell_data["color"]}; font-weight: 500;">{query_cell_data["name"]}</span>
            under the same perturbation.
        </div>
    </div>
    '''
    yield result_html, gr.update(visible=True), gr.update(visible=False)
def generate_output_column(color):
    """Render five labelled bars of pseudo gene-expression counts.

    Each row shows a gene label (g₁..g₅), a bar whose width and opacity scale
    with a random count drawn from ``random.randint(10, 500)``, and the count
    itself. Rows are joined with newlines.

    Args:
        color: CSS colour used to fill every bar.
    """
    rows = []
    for label in ("g₁", "g₂", "g₃", "g₄", "g₅"):
        # Generate pseudo gene count value
        value = random.randint(10, 500)
        fraction = value / 500
        fill_px = min(fraction * 40, 40)  # Scale to max 40px
        alpha = 0.4 + fraction * 0.5
        rows.append(f'''
        <div style="display: flex; align-items: center; gap: 4px; height: 18px;">
            <div style="font-size: 8px; color: #64748b; width: 14px; text-align: right;">{label}</div>
            <div style="width: 44px; height: 12px; background: #e2e8f0; border-radius: 2px; overflow: hidden; position: relative;">
                <div style="width: {fill_px}px; height: 100%; background: {color}; opacity: {alpha:.1f}; border-radius: 2px;"></div>
            </div>
            <div style="font-size: 8px; color: #475569; font-family: monospace; width: 22px;">{value}</div>
        </div>
        ''')
    return '\n'.join(rows)
def reset_inference(prompt_cell, prompt_cond, query_cell):
    """Re-render the overview panel and restore the initial button state.

    Returns:
        A 3-tuple: the overview HTML, an update that hides the first button
        output and an update that shows the second.
    """
    overview_html = generate_inference_display(prompt_cell, prompt_cond, query_cell)
    hide_first = gr.update(visible=False)
    show_second = gr.update(visible=True)
    return overview_html, hide_first, show_second
| # --- Main Gradio App --- | |
def create_app():
    """Assemble the Gradio Blocks UI and return it without launching.

    The app has a header banner and two tabs: "Architecture" (grid plus three
    step buttons) and "Inference" (prompt/query selectors, an overview panel,
    and run/reset buttons).
    """
    cell_type_names = [c["name"] for c in CELL_TYPES]
    default_prompt_cell = CELL_TYPES[0]
    default_prompt_cond = CONDITIONS[1]
    default_query_cell = CELL_TYPES[1]

    with gr.Blocks(title="STACK Model Visualization") as app:
        # Header
        gr.HTML('''
        <div style="background: linear-gradient(90deg, #4f46e5 0%, #7c3aed 100%); padding: 20px 24px; display: flex; justify-content: space-between; align-items: center; border-radius: 12px 12px 0 0;">
            <div style="display: flex; align-items: center; gap: 12px;">
                <div style="background: white; padding: 10px; border-radius: 8px; display: flex; align-items: center; justify-content: center;">
                    <span style="font-size: 24px;">🧬</span>
                </div>
                <div>
                    <div style="font-weight: bold; color: white; font-size: 18px;">STACK</div>
                </div>
            </div>
        </div>
        ''')

        with gr.Tabs():
            # --- ARCHITECTURE TAB ---
            with gr.Tab("🏗️ Architecture"):
                with gr.Row():
                    with gr.Column(scale=2):
                        grid_display = gr.HTML(generate_grid_html("idle"))
                        label_display = gr.HTML(get_step_label("idle"))
                    with gr.Column(scale=1):
                        gr.HTML('''
                        <div style="padding: 16px; background: #f8fafc; border-radius: 12px; border: 1px solid #e2e8f0;">
                            <h3 style="margin: 0 0 16px 0; font-size: 14px; font-weight: 600; color: #1e293b;">Learning Process</h3>
                            <p style="font-size: 11px; color: #64748b; margin: 0 0 16px 0; line-height: 1.5;">
                                STACK learns from gene expression matrices where rows are gene modules and columns are cells.
                            </p>
                        </div>
                        ''')
                        intra_btn = gr.Button("→ 1. Intra-cellular Attention", size="lg", variant="secondary")
                        inter_btn = gr.Button("↓ 2. Inter-cellular Attention", size="lg", variant="secondary")
                        masking_btn = gr.Button("🔄 3. Masked Pre-training", size="lg", variant="secondary")
                        gr.HTML('''
                        <div style="margin-top: 16px; padding: 12px; background: #fffbeb; border-radius: 8px; border: 1px solid #fde68a; font-size: 11px; color: #92400e; line-height: 1.5;">
                            💡 <strong style="color: #92400e;">Key insight:</strong> By learning gene dependencies across the entire cell population, STACK can transfer knowledge from one cell type to another.
                        </div>
                        ''')

                # State for architecture view
                arch_step = gr.State("idle")

                def set_step(step_name):
                    grid, label = update_architecture_view(step_name)
                    return grid, label, step_name

                def handler_for(step_name):
                    # Bind the step eagerly; a lambda created directly in the
                    # loop below would only see the loop's final value.
                    return lambda: set_step(step_name)

                step_buttons = (
                    (intra_btn, "intra"),
                    (inter_btn, "inter"),
                    (masking_btn, "masking"),
                )
                for button, step_name in step_buttons:
                    button.click(
                        handler_for(step_name),
                        outputs=[grid_display, label_display, arch_step],
                    )

            # --- INFERENCE TAB ---
            with gr.Tab("🔮 Inference"):
                # Controls Row
                with gr.Row():
                    with gr.Column(scale=1):
                        gr.HTML('''
                        <div style="font-size: 12px; font-weight: 600; color: #dc2626; text-transform: uppercase; letter-spacing: 0.05em; margin-bottom: 8px;">
                            🔴 Prompt Cells (Known Response)
                        </div>
                        ''')
                        prompt_cell_radio = gr.Radio(
                            choices=cell_type_names,
                            value=default_prompt_cell["name"],
                            label="Cell Type",
                            container=True,
                        )
                        # "healthy" is the query baseline, so it is not
                        # offered as a prompt condition.
                        prompt_cond_dropdown = gr.Dropdown(
                            choices=[c["name"] for c in CONDITIONS if c["id"] != "healthy"],
                            value=default_prompt_cond["name"],
                            label="Condition/Treatment",
                            container=True,
                        )
                    with gr.Column(scale=1):
                        gr.HTML('''
                        <div style="font-size: 12px; font-weight: 600; color: #2563eb; text-transform: uppercase; letter-spacing: 0.05em; margin-bottom: 8px;">
                            🔵 Query Cells (To Predict)
                        </div>
                        ''')
                        query_cell_radio = gr.Radio(
                            choices=[c["name"] for c in CELL_TYPES],
                            value=default_query_cell["name"],
                            label="Cell Type",
                            container=True,
                        )
                        gr.HTML('''
                        <div style="padding: 10px 14px; background: #f1f5f9; border-radius: 8px; font-size: 12px; color: #64748b; margin-top: 8px;">
                            <strong>Initial state:</strong> Healthy (baseline gene expression)
                        </div>
                        ''')

                # Visualization Display
                inference_display = gr.HTML(
                    generate_inference_display(
                        default_prompt_cell["id"],
                        default_prompt_cond["id"],
                        default_query_cell["id"],
                    )
                )

                # Action Buttons
                with gr.Row():
                    run_btn = gr.Button("▶️ Run Zero-Shot Prediction", variant="primary", size="lg")
                    reset_btn = gr.Button("↩️ Reset", size="lg", visible=False)

                # States for inference
                prompt_cell_state = gr.State(default_prompt_cell["id"])
                prompt_cond_state = gr.State(default_prompt_cond["id"])
                query_cell_state = gr.State(default_query_cell["id"])

                # Update display when selections change: every selector
                # triggers the same refresh with the same inputs and outputs.
                selectors = (prompt_cell_radio, prompt_cond_dropdown, query_cell_radio)
                for selector in selectors:
                    selector.change(
                        update_inference_display,
                        inputs=list(selectors),
                        outputs=[inference_display, prompt_cell_state, prompt_cond_state, query_cell_state],
                    )

                # Run prediction and Reset share the same state inputs and
                # the same display/button outputs.
                button_actions = ((run_btn, run_inference), (reset_btn, reset_inference))
                for button, action in button_actions:
                    button.click(
                        action,
                        inputs=[prompt_cell_state, prompt_cond_state, query_cell_state],
                        outputs=[inference_display, reset_btn, run_btn],
                    )

    return app
def custom_css():
    """Return the stylesheet applied to the app.

    It defines the ``spin`` keyframes referenced by the processing panel's
    inline ``animation`` style, constrains the container width, and restyles
    buttons, tab buttons and radio labels.
    """
    stylesheet = """
    @keyframes spin {
        from { transform: rotate(0deg); }
        to { transform: rotate(360deg); }
    }
    .gradio-container {
        max-width: 1000px !important;
        margin: auto;
        font-family: system-ui, -apple-system, sans-serif;
    }
    button {
        border-radius: 10px !important;
        font-size: 13px !important;
        font-weight: 600 !important;
        transition: all 0.2s !important;
    }
    .tabs button {
        font-size: 13px !important;
        font-weight: 600 !important;
        padding: 12px 20px !important;
    }
    .tabs button[aria-selected="true"] {
        background: linear-gradient(90deg, #4f46e5 0%, #7c3aed 100%) !important;
        color: white !important;
    }
    input[type="radio"] + label {
        font-size: 13px !important;
    }
    """
    return stylesheet
| # --- Launch App --- | |
if __name__ == "__main__":
    # Script entry point: build the UI, then serve it with the custom
    # stylesheet applied.
    stylesheet = custom_css()
    app = create_app()
    app.launch(css=stylesheet)