ludocomito commited on
Commit
cb61c3f
·
1 Parent(s): 4e714af
Files changed (3) hide show
  1. README.md +59 -1
  2. app.py +551 -0
  3. requirements.txt +1 -0
README.md CHANGED
@@ -9,4 +9,62 @@ app_file: app.py
9
  pinned: false
10
  ---
11
 
12
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9
  pinned: false
10
  ---
11
 
12
+ # STACK Model Visualization
13
+
14
+ An interactive Gradio-based visualization of the STACK (Structured and Contextualized Knowledge) model architecture and inference capabilities.
15
+
16
+ ## Features
17
+
18
+ ### Architecture View
19
+ - **Visual Grid Representation**: 5x5 grid showing cells and gene modules
20
+ - **Interactive Steps**:
21
+ - **Intra-cellular**: Visualize gene dependencies within cells (row-wise attention)
22
+ - **Inter-cellular**: Show population context across cells (column-wise attention)
23
+ - **Pre-training**: Demonstrate masked reconstruction for model training
24
+
25
+ ### Inference View
26
+ - **Context Configuration**: Select cell type and condition for the prompt
27
+ - **Target Configuration**: Choose query cell type (always healthy initial state)
28
+ - **Zero-shot Prediction**: Run predictions using context from one cell type to predict another
29
+ - **Visual Results**: See predicted cell states with animations
30
+
31
+ ## Installation
32
+
33
+ ```bash
34
+ pip install -r requirements.txt
35
+ ```
36
+
37
+ ## Usage
38
+
39
+ Run the application:
40
+
41
+ ```bash
42
+ python app.py
43
+ ```
44
+
45
+ The app will launch in your browser at `http://localhost:7860`
46
+
47
+ ## How It Works
48
+
49
+ 1. **Architecture Tab**: Explore how STACK learns dependencies:
50
+ - Click "Intra-cellular" to see within-cell attention patterns
51
+ - Click "Inter-cellular" to see across-cell attention patterns
52
+ - Click "Pre-training" to see masked reconstruction in action
53
+
54
+ 2. **Inference Tab**: Test zero-shot predictions:
55
+ - Configure the context (prompt) with a cell type and condition
56
+ - Configure the target (query) with a different cell type
57
+ - Click "Run Prediction" to see the predicted cell state
58
+ - The model uses patterns from the prompt cells to predict query cell behavior
59
+
60
+ ## Cell Types
61
+
62
+ - 🔵 **T-Cell**: T lymphocytes
63
+ - 🔷 **B-Cell**: B lymphocytes
64
+ - 🟢 **Macro**: Macrophages
65
+
66
+ ## Conditions
67
+
68
+ - **Healthy**: Normal cell state
69
+ - **Drug A**: Drug-treated state (indicated by red marker)
70
+ - **Viral**: Virus-infected state (indicated by dashed border)
app.py ADDED
@@ -0,0 +1,551 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import random
3
+ import time
4
+
5
+ # --- Constants ---
6
+
7
+ CELL_TYPES = [
8
+ {"id": "t-cell", "name": "T-Cell", "color": "#3b82f6", "bg": "#dbeafe"},
9
+ {"id": "b-cell", "name": "B-Cell", "color": "#6366f1", "bg": "#e0e7ff"},
10
+ {"id": "macro", "name": "Macrophage", "color": "#f97316", "bg": "#ffedd5"},
11
+ ]
12
+
13
+ CONDITIONS = [
14
+ {"id": "healthy", "name": "Healthy", "color": "#3b82f6", "bg": "#dbeafe"},
15
+ {"id": "drug-a", "name": "Drug A", "color": "#ef4444", "bg": "#fee2e2"},
16
+ {"id": "viral", "name": "Viral Infection", "color": "#f97316", "bg": "#ffedd5"},
17
+ ]
18
+
19
+ # --- Helper Functions for Cell Visuals ---
20
+
21
+ def make_cell_svg(color, is_prompt=False, size=36):
22
+ """Create an SVG cell representation like in the figure"""
23
+ # For prompt cells (treated/different condition) - show irregular shape
24
+ if is_prompt:
25
+ return f'''
26
+ <svg width="{size}" height="{size}" viewBox="0 0 40 40">
27
+ <ellipse cx="20" cy="20" rx="16" ry="14" fill="{color}" opacity="0.3"/>
28
+ <ellipse cx="20" cy="20" rx="12" ry="10" fill="{color}" opacity="0.5"/>
29
+ <circle cx="20" cy="20" r="6" fill="{color}"/>
30
+ <circle cx="18" cy="18" r="2" fill="white" opacity="0.6"/>
31
+ </svg>
32
+ '''
33
+ else:
34
+ # Query cells - regular circular shape
35
+ return f'''
36
+ <svg width="{size}" height="{size}" viewBox="0 0 40 40">
37
+ <circle cx="20" cy="20" r="16" fill="{color}" opacity="0.2" stroke="{color}" stroke-width="1.5" stroke-dasharray="3,2"/>
38
+ <circle cx="20" cy="20" r="10" fill="{color}" opacity="0.4"/>
39
+ <circle cx="20" cy="20" r="5" fill="{color}"/>
40
+ <circle cx="18" cy="18" r="1.5" fill="white" opacity="0.6"/>
41
+ </svg>
42
+ '''
43
+
44
+ def generate_cell_array(cell_type, condition, count=6, is_prompt=False):
45
+ """Generate an array of cells in a horizontal layout"""
46
+ cell = next(c for c in CELL_TYPES if c["id"] == cell_type)
47
+ cond = next(c for c in CONDITIONS if c["id"] == condition)
48
+
49
+ # Use condition color for prompt cells, cell type color for query cells
50
+ color = cond["color"] if is_prompt else cell["color"]
51
+ bg_color = cond["bg"] if is_prompt else cell["bg"]
52
+
53
+ cells_html = ""
54
+ for i in range(count):
55
+ cells_html += f'''
56
+ <div style="display: flex; align-items: center; justify-content: center;">
57
+ {make_cell_svg(color, is_prompt, 36)}
58
+ </div>
59
+ '''
60
+
61
+ # Add ellipsis
62
+ cells_html += '''
63
+ <div style="display: flex; align-items: center; justify-content: center; color: #94a3b8; font-weight: bold; letter-spacing: 2px;">
64
+ ···
65
+ </div>
66
+ '''
67
+
68
+ return cells_html
69
+
70
+ def generate_inference_display(prompt_cell, prompt_cond, query_cell, num_prompt=3, num_query=5):
71
+ """Generate the full inference visualization with stacked arrays"""
72
+ prompt_cell_data = next(c for c in CELL_TYPES if c["id"] == prompt_cell)
73
+ prompt_cond_data = next(c for c in CONDITIONS if c["id"] == prompt_cond)
74
+ query_cell_data = next(c for c in CELL_TYPES if c["id"] == query_cell)
75
+
76
+ prompt_cells = generate_cell_array(prompt_cell, prompt_cond, num_prompt, is_prompt=True)
77
+ query_cells = generate_cell_array(query_cell, "healthy", num_query, is_prompt=False)
78
+
79
+ html = f'''
80
+ <div style="background: linear-gradient(135deg, #fdf2f8 0%, #f8fafc 50%, #eff6ff 100%); padding: 30px; border-radius: 16px; font-family: system-ui, -apple-system, sans-serif;">
81
+
82
+ <!-- Title -->
83
+ <div style="text-align: center; margin-bottom: 24px;">
84
+ <div style="font-size: 11px; font-weight: 600; color: #64748b; text-transform: uppercase; letter-spacing: 0.1em; margin-bottom: 4px;">In-context Learning</div>
85
+ <div style="font-size: 13px; color: #94a3b8;">Gene expression counts → Predicted states</div>
86
+ </div>
87
+
88
+ <!-- Main Container -->
89
+ <div style="display: flex; align-items: center; gap: 20px; justify-content: center;">
90
+
91
+ <!-- Input Arrays Container -->
92
+ <div style="display: flex; flex-direction: column; gap: 8px;">
93
+
94
+ <!-- Prompt Array -->
95
+ <div style="display: flex; align-items: center; gap: 12px;">
96
+ <div style="background: {prompt_cond_data["bg"]}; border: 2px solid {prompt_cond_data["color"]}40; border-radius: 12px; padding: 10px 16px; display: flex; gap: 6px; align-items: center;">
97
+ {prompt_cells}
98
+ </div>
99
+ </div>
100
+
101
+ <!-- Query Array -->
102
+ <div style="display: flex; align-items: center; gap: 12px;">
103
+ <div style="background: {query_cell_data["bg"]}; border: 2px solid {query_cell_data["color"]}40; border-radius: 12px; padding: 10px 16px; display: flex; gap: 6px; align-items: center;">
104
+ {query_cells}
105
+ </div>
106
+ </div>
107
+
108
+ </div>
109
+
110
+ <!-- Arrow -->
111
+ <div style="display: flex; flex-direction: column; align-items: center; gap: 4px;">
112
+ <svg width="40" height="24" viewBox="0 0 40 24">
113
+ <defs>
114
+ <marker id="arrowhead" markerWidth="10" markerHeight="7" refX="9" refY="3.5" orient="auto">
115
+ <polygon points="0 0, 10 3.5, 0 7" fill="#94a3b8"/>
116
+ </marker>
117
+ </defs>
118
+ <line x1="0" y1="12" x2="30" y2="12" stroke="#94a3b8" stroke-width="2" marker-end="url(#arrowhead)"/>
119
+ </svg>
120
+ <div style="font-size: 9px; color: #94a3b8; text-transform: uppercase; letter-spacing: 0.05em;">STACK</div>
121
+ </div>
122
+
123
+ <!-- Gene Module Matrix Preview -->
124
+ <div style="display: flex; flex-direction: column; align-items: center; gap: 8px;">
125
+ <div style="font-size: 9px; color: #64748b; text-transform: uppercase; letter-spacing: 0.05em;">Gene Modules × Cells</div>
126
+ <div style="display: grid; grid-template-columns: repeat(5, 1fr); gap: 3px; padding: 8px; background: white; border-radius: 8px; border: 1px solid #e2e8f0;">
127
+ {generate_mini_matrix(prompt_cond_data["color"], query_cell_data["color"])}
128
+ </div>
129
+ </div>
130
+
131
+ </div>
132
+
133
+ <!-- Labels -->
134
+ <div style="display: flex; justify-content: center; gap: 40px; margin-top: 20px;">
135
+ <div style="display: flex; align-items: center; gap: 8px;">
136
+ <div style="width: 12px; height: 12px; background: {prompt_cond_data["color"]}; border-radius: 50%; opacity: 0.7;"></div>
137
+ <span style="font-size: 11px; color: #475569; font-weight: 500;">Prompt: {prompt_cell_data["name"]} + {prompt_cond_data["name"]}</span>
138
+ </div>
139
+ <div style="display: flex; align-items: center; gap: 8px;">
140
+ <div style="width: 12px; height: 12px; background: {query_cell_data["color"]}; border-radius: 50%; opacity: 0.7;"></div>
141
+ <span style="font-size: 11px; color: #475569; font-weight: 500;">Query: {query_cell_data["name"]} (Healthy)</span>
142
+ </div>
143
+ </div>
144
+
145
+ </div>
146
+ '''
147
+ return html
148
+
149
+ def generate_mini_matrix(prompt_color, query_color):
150
+ """Generate a small matrix visualization showing gene modules × cells"""
151
+ cells = []
152
+ colors = [prompt_color, prompt_color, prompt_color, query_color, query_color]
153
+
154
+ for row in range(4):
155
+ for col in range(5):
156
+ opacity = 0.2 + random.random() * 0.6
157
+ color = colors[col]
158
+ cells.append(f'<div style="width: 10px; height: 10px; background: {color}; opacity: {opacity:.1f}; border-radius: 2px;"></div>')
159
+
160
+ return '\n'.join(cells)
161
+
162
+ def generate_grid_html(step, masked_indices=None):
163
+ """Generate the 5x5 grid HTML for architecture view"""
164
+ if masked_indices is None:
165
+ masked_indices = []
166
+
167
+ # Color palette for the matrix
168
+ colors = ["#f97316", "#3b82f6", "#06b6d4", "#1e3a5f", "#1e3a5f"]
169
+
170
+ grid_html = '''
171
+ <div style="display: flex; flex-direction: column; align-items: center; background: linear-gradient(135deg, #fdf2f8 0%, #f8fafc 100%); padding: 30px; border-radius: 16px;">
172
+
173
+ <div style="display: flex; align-items: stretch;">
174
+
175
+ <!-- Y-axis label (Genes) - Left side -->
176
+ <div style="display: flex; align-items: center; justify-content: center; padding-right: 12px;">
177
+ <div style="writing-mode: vertical-rl; text-orientation: mixed; transform: rotate(180deg); font-size: 11px; font-weight: 600; color: #64748b; text-transform: uppercase; letter-spacing: 0.1em;">
178
+ Genes
179
+ </div>
180
+ </div>
181
+
182
+ <!-- Grid Container with top label -->
183
+ <div style="display: flex; flex-direction: column; align-items: center;">
184
+
185
+ <!-- X-axis label (Cells) - Top, centered -->
186
+ <div style="font-size: 11px; font-weight: 600; color: #64748b; text-transform: uppercase; letter-spacing: 0.1em; margin-bottom: 12px; text-align: center;">
187
+ Cells
188
+ </div>
189
+
190
+ <!-- Main Grid -->
191
+ <div style="display: grid; grid-template-columns: repeat(5, 44px); gap: 6px; background: white; padding: 16px; border-radius: 12px; border: 1px solid #e2e8f0; box-shadow: 0 4px 6px -1px rgba(0,0,0,0.05);">
192
+ '''
193
+
194
+ for i in range(25):
195
+ row_idx = i // 5
196
+ col_idx = i % 5
197
+ is_masked = i in masked_indices
198
+ is_row_active = step == "intra" and row_idx == 2
199
+ is_col_active = step == "inter" and col_idx == 2
200
+
201
+ # Determine cell color based on column
202
+ base_color = colors[col_idx]
203
+
204
+ if is_masked:
205
+ bg_color = "#e2e8f0"
206
+ content = '<div style="font-size: 14px;">🔄</div>'
207
+ else:
208
+ # Vary opacity based on position for visual interest
209
+ opacity = 0.3 + (row_idx * 0.12) + (col_idx * 0.08)
210
+ bg_color = base_color
211
+ content = ''
212
+
213
+ ring_style = ""
214
+ if is_row_active:
215
+ ring_style = "box-shadow: 0 0 0 3px #60a5fa; z-index: 10;"
216
+ elif is_col_active:
217
+ ring_style = "box-shadow: 0 0 0 3px #34d399; z-index: 10;"
218
+
219
+ cell_opacity = "0.2" if is_masked else f"{0.3 + row_idx * 0.15}"
220
+
221
+ grid_html += f'''
222
+ <div style="width: 44px; height: 44px; background: {bg_color}; opacity: {cell_opacity}; border-radius: 6px; display: flex; align-items: center; justify-content: center; position: relative; transition: all 0.2s; {ring_style}">
223
+ {content}
224
+ </div>
225
+ '''
226
+
227
+ grid_html += '''
228
+ </div>
229
+ </div>
230
+ </div>
231
+ </div>
232
+ '''
233
+
234
+ return grid_html
235
+
236
+ def get_step_label(step):
237
+ """Get the label for the current step"""
238
+ labels = {
239
+ "idle": '<div style="text-align: center; padding: 12px;"><span style="color: #94a3b8; font-size: 12px;">Select a learning step to visualize attention patterns</span></div>',
240
+ "intra": '<div style="text-align: center; background: #eff6ff; color: #1e40af; padding: 12px 24px; border-radius: 24px; font-size: 13px; font-weight: 600; display: inline-block;">→ Intra-cellular: Learning gene dependencies within each cell</div>',
241
+ "inter": '<div style="text-align: center; background: #ecfdf5; color: #047857; padding: 12px 24px; border-radius: 24px; font-size: 13px; font-weight: 600; display: inline-block;">↓ Inter-cellular: Learning context across cell population</div>',
242
+ "masking": '<div style="text-align: center; background: #f1f5f9; color: #334155; padding: 12px 24px; border-radius: 24px; font-size: 13px; font-weight: 600; display: inline-block;">🔄 Pre-training: Masked gene expression reconstruction</div>',
243
+ }
244
+ return f'<div style="display: flex; justify-content: center; margin-top: 16px;">{labels.get(step, labels["idle"])}</div>'
245
+
246
+ # --- Architecture View Functions ---
247
+
248
+ def update_architecture_view(step):
249
+ """Update the architecture view based on selected step"""
250
+ masked_indices = []
251
+ if step == "masking":
252
+ # Mask two consecutive full rows (genes across all cells)
253
+ start_row = random.randint(0, 3) # 0-3 so we can have 2 consecutive rows
254
+ masked_indices = list(range(start_row * 5, (start_row + 2) * 5))
255
+
256
+ grid_html = generate_grid_html(step, masked_indices)
257
+ label_html = get_step_label(step)
258
+
259
+ return grid_html, label_html
260
+
261
+ # --- Inference View Functions ---
262
+
263
+ def update_inference_display(prompt_cell_name, prompt_cond_name, query_cell_name):
264
+ """Update the inference visualization when selections change"""
265
+ prompt_cell = next(c["id"] for c in CELL_TYPES if c["name"] == prompt_cell_name)
266
+ prompt_cond = next(c["id"] for c in CONDITIONS if c["name"] == prompt_cond_name)
267
+ query_cell = next(c["id"] for c in CELL_TYPES if c["name"] == query_cell_name)
268
+
269
+ return generate_inference_display(prompt_cell, prompt_cond, query_cell), prompt_cell, prompt_cond, query_cell
270
+
271
+ def run_inference(prompt_cell, prompt_cond, query_cell):
272
+ """Run the inference prediction"""
273
+ prompt_cond_data = next(c for c in CONDITIONS if c["id"] == prompt_cond)
274
+ query_cell_data = next(c for c in CELL_TYPES if c["id"] == query_cell)
275
+ prompt_cell_data = next(c for c in CELL_TYPES if c["id"] == prompt_cell)
276
+
277
+ # Processing state
278
+ processing_html = f'''
279
+ <div style="background: linear-gradient(135deg, #fdf2f8 0%, #f8fafc 50%, #eff6ff 100%); padding: 40px; border-radius: 16px; text-align: center;">
280
+ <div style="font-size: 40px; animation: spin 1s linear infinite;">🔄</div>
281
+ <div style="margin-top: 16px; font-size: 12px; font-weight: 600; color: #6366f1; text-transform: uppercase; letter-spacing: 0.1em;">
282
+ Processing gene expression context...
283
+ </div>
284
+ <div style="margin-top: 8px; font-size: 11px; color: #94a3b8;">
285
+ Learning from {prompt_cell_data["name"]} patterns under {prompt_cond_data["name"]}
286
+ </div>
287
+ </div>
288
+ '''
289
+ yield processing_html, gr.update(visible=False), gr.update(visible=True)
290
+
291
+ time.sleep(0.5)
292
+
293
+ # Result state - show predicted cells with the new condition applied
294
+ predicted_cells = generate_cell_array(query_cell, prompt_cond, 5, is_prompt=True)
295
+
296
+ result_html = f'''
297
+ <div style="background: linear-gradient(135deg, #fdf2f8 0%, #f8fafc 50%, #eff6ff 100%); padding: 30px; border-radius: 16px;">
298
+
299
+ <!-- Header -->
300
+ <div style="text-align: center; margin-bottom: 24px;">
301
+ <div style="font-family: monospace; font-size: 10px; color: #64748b; letter-spacing: 0.15em; margin-bottom: 8px; text-transform: uppercase;">PREDICTION COMPLETE</div>
302
+ <div style="font-size: 14px; font-weight: 600; color: #1e293b;">
303
+ {query_cell_data["name"]} gene expression under {prompt_cond_data["name"]}
304
+ </div>
305
+ </div>
306
+
307
+ <!-- Predicted Cells -->
308
+ <div style="display: flex; justify-content: center; margin-bottom: 24px;">
309
+ <div style="background: {prompt_cond_data["bg"]}20; border: 2px solid {prompt_cond_data["color"]}60; border-radius: 12px; padding: 16px 24px; display: flex; gap: 8px; align-items: center;">
310
+ {predicted_cells}
311
+ </div>
312
+ </div>
313
+
314
+ <!-- Output columns visualization -->
315
+ <div style="display: flex; justify-content: center; gap: 12px; margin-bottom: 20px;">
316
+ <div style="text-align: center;">
317
+ <div style="background: rgba(99, 102, 241, 0.2); border: 1px solid rgba(99, 102, 241, 0.4); border-radius: 8px; padding: 8px 12px; display: flex; flex-direction: column; gap: 4px;">
318
+ {generate_output_column(prompt_cond_data["color"])}
319
+ </div>
320
+ <div style="font-size: 10px; color: #64748b; margin-top: 4px;">c₁</div>
321
+ </div>
322
+ <div style="text-align: center;">
323
+ <div style="background: rgba(99, 102, 241, 0.2); border: 1px solid rgba(99, 102, 241, 0.4); border-radius: 8px; padding: 8px 12px; display: flex; flex-direction: column; gap: 4px;">
324
+ {generate_output_column(prompt_cond_data["color"])}
325
+ </div>
326
+ <div style="font-size: 10px; color: #64748b; margin-top: 4px;">c₂</div>
327
+ </div>
328
+ <div style="display: flex; align-items: center; color: #64748b; font-size: 14px; letter-spacing: 3px;">···</div>
329
+ <div style="text-align: center;">
330
+ <div style="background: rgba(99, 102, 241, 0.2); border: 1px solid rgba(99, 102, 241, 0.4); border-radius: 8px; padding: 8px 12px; display: flex; flex-direction: column; gap: 4px;">
331
+ {generate_output_column(prompt_cond_data["color"])}
332
+ </div>
333
+ <div style="font-size: 10px; color: #64748b; margin-top: 4px;">cₙ</div>
334
+ </div>
335
+ </div>
336
+
337
+ <!-- Description -->
338
+ <div style="text-align: center; font-size: 11px; color: #64748b; max-width: 300px; margin: 0 auto; line-height: 1.5;">
339
+ Zero-shot prediction of gene expression counts using in-context learning from {prompt_cell_data["name"]} response to {prompt_cond_data["name"]}.
340
+ </div>
341
+
342
+ </div>
343
+ '''
344
+
345
+ yield result_html, gr.update(visible=True), gr.update(visible=False)
346
+
347
+ def generate_output_column(color):
348
+ """Generate a vertical column of gene expression values"""
349
+ cells = []
350
+ for _ in range(5):
351
+ opacity = 0.2 + random.random() * 0.6
352
+ cells.append(f'<div style="width: 16px; height: 16px; background: {color}; opacity: {opacity:.1f}; border-radius: 3px;"></div>')
353
+ return '\n'.join(cells)
354
+
355
+ def reset_inference(prompt_cell, prompt_cond, query_cell):
356
+ """Reset inference view to initial state"""
357
+ return generate_inference_display(prompt_cell, prompt_cond, query_cell), gr.update(visible=False), gr.update(visible=True)
358
+
359
+ # --- Main Gradio App ---
360
+
361
+ def create_app():
362
+ with gr.Blocks(title="STACK Model Visualization") as app:
363
+
364
+ # Header
365
+ gr.HTML('''
366
+ <div style="background: linear-gradient(90deg, #4f46e5 0%, #7c3aed 100%); padding: 20px 24px; display: flex; justify-content: space-between; align-items: center; border-radius: 12px 12px 0 0;">
367
+ <div style="display: flex; align-items: center; gap: 12px;">
368
+ <div style="background: white; padding: 10px; border-radius: 8px; display: flex; align-items: center; justify-content: center;">
369
+ <span style="font-size: 24px;">🧬</span>
370
+ </div>
371
+ <div>
372
+ <div style="font-weight: bold; color: white; font-size: 18px;">STACK</div>
373
+ <div style="font-size: 11px; color: rgba(255,255,255,0.7);">Single-cell Transcriptomic Analysis with Contextual Knowledge</div>
374
+ </div>
375
+ </div>
376
+ </div>
377
+ ''')
378
+
379
+ with gr.Tabs() as tabs:
380
+
381
+ # --- ARCHITECTURE TAB ---
382
+ with gr.Tab("🏗️ Architecture"):
383
+ with gr.Row():
384
+ with gr.Column(scale=2):
385
+ grid_display = gr.HTML(generate_grid_html("idle"))
386
+ label_display = gr.HTML(get_step_label("idle"))
387
+
388
+ with gr.Column(scale=1):
389
+ gr.HTML('''
390
+ <div style="padding: 16px; background: #f8fafc; border-radius: 12px; border: 1px solid #e2e8f0;">
391
+ <h3 style="margin: 0 0 16px 0; font-size: 14px; font-weight: 600; color: #1e293b;">Learning Process</h3>
392
+ <p style="font-size: 11px; color: #64748b; margin: 0 0 16px 0; line-height: 1.5;">
393
+ STACK learns from gene expression matrices where rows are gene modules and columns are cells.
394
+ </p>
395
+ </div>
396
+ ''')
397
+
398
+ intra_btn = gr.Button("→ 1. Intra-cellular Attention", size="lg", variant="secondary")
399
+ inter_btn = gr.Button("↓ 2. Inter-cellular Attention", size="lg", variant="secondary")
400
+ masking_btn = gr.Button("🔄 3. Masked Pre-training", size="lg", variant="secondary")
401
+
402
+ gr.HTML('''
403
+ <div style="margin-top: 16px; padding: 12px; background: #fffbeb; border-radius: 8px; border: 1px solid #fde68a; font-size: 11px; color: #92400e; line-height: 1.5;">
404
+ 💡 <strong style="color: #92400e;">Key insight:</strong> By learning gene dependencies across the entire cell population, STACK can transfer knowledge from one cell type to another.
405
+ </div>
406
+ ''')
407
+
408
+ # State for architecture view
409
+ arch_step = gr.State("idle")
410
+
411
+ def set_step(step_name):
412
+ grid, label = update_architecture_view(step_name)
413
+ return grid, label, step_name
414
+
415
+ intra_btn.click(lambda: set_step("intra"), outputs=[grid_display, label_display, arch_step])
416
+ inter_btn.click(lambda: set_step("inter"), outputs=[grid_display, label_display, arch_step])
417
+ masking_btn.click(lambda: set_step("masking"), outputs=[grid_display, label_display, arch_step])
418
+
419
+ # --- INFERENCE TAB ---
420
+ with gr.Tab("🔮 Inference"):
421
+
422
+ # Controls Row
423
+ with gr.Row():
424
+ with gr.Column(scale=1):
425
+ gr.HTML('''
426
+ <div style="font-size: 12px; font-weight: 600; color: #dc2626; text-transform: uppercase; letter-spacing: 0.05em; margin-bottom: 8px;">
427
+ 🔴 Prompt Cells (Known Response)
428
+ </div>
429
+ ''')
430
+ prompt_cell_radio = gr.Radio(
431
+ choices=[c["name"] for c in CELL_TYPES],
432
+ value=CELL_TYPES[0]["name"],
433
+ label="Cell Type",
434
+ container=True
435
+ )
436
+ prompt_cond_dropdown = gr.Dropdown(
437
+ choices=[c["name"] for c in CONDITIONS if c["id"] != "healthy"],
438
+ value=CONDITIONS[1]["name"],
439
+ label="Condition/Treatment",
440
+ container=True
441
+ )
442
+
443
+ with gr.Column(scale=1):
444
+ gr.HTML('''
445
+ <div style="font-size: 12px; font-weight: 600; color: #2563eb; text-transform: uppercase; letter-spacing: 0.05em; margin-bottom: 8px;">
446
+ 🔵 Query Cells (To Predict)
447
+ </div>
448
+ ''')
449
+ query_cell_radio = gr.Radio(
450
+ choices=[c["name"] for c in CELL_TYPES],
451
+ value=CELL_TYPES[1]["name"],
452
+ label="Cell Type",
453
+ container=True
454
+ )
455
+ gr.HTML('''
456
+ <div style="padding: 10px 14px; background: #f1f5f9; border-radius: 8px; font-size: 12px; color: #64748b; margin-top: 8px;">
457
+ <strong>Initial state:</strong> Healthy (baseline gene expression)
458
+ </div>
459
+ ''')
460
+
461
+ # Visualization Display
462
+ inference_display = gr.HTML(
463
+ generate_inference_display(CELL_TYPES[0]["id"], CONDITIONS[1]["id"], CELL_TYPES[1]["id"])
464
+ )
465
+
466
+ # Action Buttons
467
+ with gr.Row():
468
+ run_btn = gr.Button("▶️ Run Zero-Shot Prediction", variant="primary", size="lg")
469
+ reset_btn = gr.Button("↩️ Reset", size="lg", visible=False)
470
+
471
+ # States for inference
472
+ prompt_cell_state = gr.State(CELL_TYPES[0]["id"])
473
+ prompt_cond_state = gr.State(CONDITIONS[1]["id"])
474
+ query_cell_state = gr.State(CELL_TYPES[1]["id"])
475
+
476
+ # Update display when selections change
477
+ prompt_cell_radio.change(
478
+ update_inference_display,
479
+ inputs=[prompt_cell_radio, prompt_cond_dropdown, query_cell_radio],
480
+ outputs=[inference_display, prompt_cell_state, prompt_cond_state, query_cell_state]
481
+ )
482
+
483
+ prompt_cond_dropdown.change(
484
+ update_inference_display,
485
+ inputs=[prompt_cell_radio, prompt_cond_dropdown, query_cell_radio],
486
+ outputs=[inference_display, prompt_cell_state, prompt_cond_state, query_cell_state]
487
+ )
488
+
489
+ query_cell_radio.change(
490
+ update_inference_display,
491
+ inputs=[prompt_cell_radio, prompt_cond_dropdown, query_cell_radio],
492
+ outputs=[inference_display, prompt_cell_state, prompt_cond_state, query_cell_state]
493
+ )
494
+
495
+ # Run prediction
496
+ run_btn.click(
497
+ run_inference,
498
+ inputs=[prompt_cell_state, prompt_cond_state, query_cell_state],
499
+ outputs=[inference_display, reset_btn, run_btn]
500
+ )
501
+
502
+ # Reset
503
+ reset_btn.click(
504
+ reset_inference,
505
+ inputs=[prompt_cell_state, prompt_cond_state, query_cell_state],
506
+ outputs=[inference_display, reset_btn, run_btn]
507
+ )
508
+
509
+ return app
510
+
511
+ def custom_css():
512
+ return """
513
+ @keyframes spin {
514
+ from { transform: rotate(0deg); }
515
+ to { transform: rotate(360deg); }
516
+ }
517
+
518
+ .gradio-container {
519
+ max-width: 1000px !important;
520
+ margin: auto;
521
+ font-family: system-ui, -apple-system, sans-serif;
522
+ }
523
+
524
+ button {
525
+ border-radius: 10px !important;
526
+ font-size: 13px !important;
527
+ font-weight: 600 !important;
528
+ transition: all 0.2s !important;
529
+ }
530
+
531
+ .tabs button {
532
+ font-size: 13px !important;
533
+ font-weight: 600 !important;
534
+ padding: 12px 20px !important;
535
+ }
536
+
537
+ .tabs button[aria-selected="true"] {
538
+ background: linear-gradient(90deg, #4f46e5 0%, #7c3aed 100%) !important;
539
+ color: white !important;
540
+ }
541
+
542
+ input[type="radio"] + label {
543
+ font-size: 13px !important;
544
+ }
545
+ """
546
+
547
+ # --- Launch App ---
548
+
549
+ if __name__ == "__main__":
550
+ app = create_app()
551
+ app.launch(css=custom_css())
requirements.txt ADDED
@@ -0,0 +1 @@
 
 
1
+ gradio>=4.0.0