ludocomito committed
Commit ba0d0ab · 1 Parent(s): cb61c3f

correct attention patterns

  1. README.md +178 -42
README.md CHANGED
@@ -1,70 +1,206 @@
1
  ---
2
  title: Stack Viz
3
- emoji: ๐ŸŒ
4
- colorFrom: green
5
- colorTo: pink
6
  sdk: gradio
7
  sdk_version: 6.3.0
8
  app_file: app.py
9
  pinned: false
10
  ---
11
 
12
- # STACK Model Visualization
13
 
14
- An interactive Gradio-based visualization of the STACK (Structured and Contextualized Knowledge) model architecture and inference capabilities.
15
 
16
- ## Features
17
 
18
- ### Architecture View
19
- - **Visual Grid Representation**: 5x5 grid showing cells and gene modules
20
- - **Interactive Steps**:
21
- - **Intra-cellular**: Visualize gene dependencies within cells (row-wise attention)
22
- - **Inter-cellular**: Show population context across cells (column-wise attention)
23
- - **Pre-training**: Demonstrate masked reconstruction for model training
24
 
25
- ### Inference View
26
- - **Context Configuration**: Select cell type and condition for the prompt
27
- - **Target Configuration**: Choose query cell type (always healthy initial state)
28
- - **Zero-shot Prediction**: Run predictions using context from one cell type to predict another
29
- - **Visual Results**: See predicted cell states with animations
30
 
31
- ## Installation
32
 
33
  ```bash
34
- pip install -r requirements.txt
35
  ```
36
 
37
- ## Usage
38
 
39
- Run the application:
40
 
41
- ```bash
42
- python app.py
43
  ```
44
 
45
- The app will launch in your browser at `http://localhost:7860`
46
 
47
- ## How It Works
 
 
48
 
49
- 1. **Architecture Tab**: Explore how STACK learns dependencies:
50
- - Click "Intra-cellular" to see within-cell attention patterns
51
- - Click "Inter-cellular" to see across-cell attention patterns
52
- - Click "Pre-training" to see masked reconstruction in action
53
 
54
- 2. **Inference Tab**: Test zero-shot predictions:
55
- - Configure the context (prompt) with a cell type and condition
56
- - Configure the target (query) with a different cell type
57
- - Click "Run Prediction" to see the predicted cell state
58
- - The model uses patterns from the prompt cells to predict query cell behavior
59
 
60
- ## Cell Types
61
 
62
- 🔵 **T-Cell**: T lymphocytes
63
- 🔷 **B-Cell**: B lymphocytes
64
- 🟢 **Macro**: Macrophages
65
 
66
- ## Conditions
67
 
68
- - **Healthy**: Normal cell state
69
- - **Drug A**: Drug-treated state (indicated by red marker)
70
- - **Viral**: Virus-infected state (indicated by dashed border)
 
1
  ---
2
  title: Stack Viz
3
+ emoji: 🧬
4
+ colorFrom: indigo
5
+ colorTo: purple
6
  sdk: gradio
7
  sdk_version: 6.3.0
8
  app_file: app.py
9
  pinned: false
10
  ---
11
 
12
+ # STACK Model Visualization 🧬
13
 
14
+ An interactive web-based visualization tool for understanding **STACK** (Single-cell Transcriptomic Analysis with Contextual Knowledge), a transformer-based model designed for single-cell RNA sequencing analysis with in-context learning capabilities.
15
 
16
+ ## 📖 What is STACK?
17
 
18
+ STACK is a deep learning model that learns from gene expression patterns in single-cell data to make **zero-shot predictions** across different cell types and experimental conditions. The key innovation is its ability to:
19
 
20
+ - **Learn contextual patterns**: Understand how gene expression changes under different conditions (e.g., drug treatments, viral infections)
21
+ - **Transfer knowledge**: Apply patterns learned from one cell type to predict behavior in a completely different cell type
22
+ - **Perform in-context learning**: Use "prompt" cells (examples) to guide predictions on "query" cells (targets), similar to how large language models work
 
 
23
 
24
+ ### How It Works
25
 
26
+ STACK operates on gene expression matrices where:
27
+ - **Rows** represent gene modules (groups of related genes)
28
+ - **Columns** represent individual cells
29
+
30
+ The model uses a dual-attention mechanism:
31
+ 1. **Intra-cellular attention**: Learns dependencies between genes within each cell (row-wise)
32
+ 2. **Inter-cellular attention**: Learns patterns across the cell population (column-wise)
33
+
34
+ This allows STACK to capture both individual cell biology and population-level context, enabling powerful zero-shot predictions without requiring retraining for new cell types or conditions.
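The dual-attention idea can be illustrated with a minimal, parameter-free sketch. This is not STACK's implementation: a real transformer layer would add learned query/key/value projections, multiple heads, and normalization. The function names `self_attention`, `row_wise_attention`, and `column_wise_attention` are hypothetical, and each grid entry is assumed to be a small embedding vector. The sketch only shows how the same attention operation can be applied first along rows and then along columns of a grid.

```python
import math

def self_attention(tokens):
    """Scaled dot-product self-attention over a list of equal-length vectors.

    Assumption: no learned projections, so queries, keys and values
    are the tokens themselves.
    """
    dim = len(tokens[0])
    outputs = []
    for query in tokens:
        scores = [sum(q * k for q, k in zip(query, key)) / math.sqrt(dim)
                  for key in tokens]
        peak = max(scores)  # subtract the max for a numerically stable softmax
        exps = [math.exp(s - peak) for s in scores]
        total = sum(exps)
        weights = [e / total for e in exps]
        # Each output is a weighted average of all tokens in the sequence.
        outputs.append([sum(w * value[i] for w, value in zip(weights, tokens))
                        for i in range(dim)])
    return outputs

def row_wise_attention(grid):
    """Attend along each row independently."""
    return [self_attention(row) for row in grid]

def column_wise_attention(grid):
    """Attend along each column independently, keeping the grid layout."""
    n_rows, n_cols = len(grid), len(grid[0])
    attended = [self_attention([grid[r][c] for r in range(n_rows)])
                for c in range(n_cols)]
    return [[attended[c][r] for c in range(n_cols)] for r in range(n_rows)]

# A 2x3 grid of 2-dimensional embeddings (made-up values).
grid = [[[float(r + c), 1.0] for c in range(3)] for r in range(2)]
mixed = column_wise_attention(row_wise_attention(grid))
```

Alternating the two passes lets every grid position receive information from its whole row and its whole column without attending over all positions at once.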
35
+
36
+ ---
37
+
38
+ ## 🚀 Quick Setup Guide
39
+
40
+ ### Prerequisites
41
+ - Python 3.8 or higher
42
+ - pip package manager
43
+
44
+ ### Installation
45
+
46
+ 1. **Clone or download this repository**, then change into its directory:
47
+ ```bash
48
+ cd stack-viz
49
+ ```
50
+
51
+ 2. **Install dependencies**:
52
+ ```bash
53
+ pip install -r requirements.txt
54
+ ```
55
+
56
+ This will install:
57
+ - `gradio>=4.0.0` - Web interface framework
58
+
59
+ 3. **Run the application**:
60
+ ```bash
61
+ python app.py
62
+ ```
63
+
64
+ 4. **Access the interface**:
65
+ - Open your browser and navigate to `http://localhost:7860` (Gradio's default port)
66
+ - Gradio prints the local URL in the terminal once the server is running
67
+
68
+ ### Alternative: Run with auto-reload
69
+ For development purposes, you can run with auto-reload:
70
  ```bash
71
+ gradio app.py
72
  ```
73
 
74
+ ---
75
 
76
+ ## 🎯 Using the Visualization
77
 
78
+ ### Architecture Tab 🏗️
79
+ Explore how STACK learns from gene expression data:
80
+
81
+ - **Intra-cellular Attention** (→): Click to highlight a row, showing how the model learns gene dependencies within individual cells
82
+ - **Inter-cellular Attention** (↓): Click to highlight a column, showing how the model learns context from the cell population
83
+ - **Masked Pre-training** (🔄): Click to see masked gene reconstruction - the model learns by predicting masked gene expression values (two consecutive rows are masked)
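The masking step can be sketched in a few lines. The snippet below is an illustrative assumption, not the code in `app.py`: `mask_consecutive_rows` and `reconstruction_error` are hypothetical helpers, `None` stands in for a mask token, and the 5x5 values are made up. It hides two consecutive rows and scores a reconstruction only on the hidden positions.

```python
import random

MASK = None  # placeholder for a hidden expression value (assumption)

def mask_consecutive_rows(matrix, n_masked=2, rng=None):
    """Return a copy of `matrix` with `n_masked` consecutive rows hidden."""
    if rng is None:
        rng = random.Random()
    start = rng.randrange(len(matrix) - n_masked + 1)
    masked_rows = list(range(start, start + n_masked))
    masked = [
        [MASK] * len(row) if i in masked_rows else list(row)
        for i, row in enumerate(matrix)
    ]
    return masked, masked_rows

def reconstruction_error(original, predicted, masked_rows):
    """Mean squared error computed only over the masked positions."""
    errors = [
        (original[r][c] - predicted[r][c]) ** 2
        for r in masked_rows
        for c in range(len(original[r]))
    ]
    return sum(errors) / len(errors)

# Toy 5x5 expression matrix with made-up values.
matrix = [[float(r * 5 + c) for c in range(5)] for r in range(5)]
masked, rows = mask_consecutive_rows(matrix, rng=random.Random(0))
```

Restricting the error to masked positions means the model is only rewarded for values it could not see.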
84
+
85
+ ### Inference Tab 🔮
86
+ Test zero-shot predictions:
87
+
88
+ 1. **Configure Prompt Cells** 🔴 (Known Response):
89
+ - Select a cell type (T-Cell, B-Cell, or Macrophage)
90
+ - Select a condition (Drug A or Viral Infection)
91
+ - These cells provide the "context" for prediction
92
+
93
+ 2. **Configure Query Cells** 🔵 (To Predict):
94
+ - Select a different cell type
95
+ - Always starts in healthy/baseline state
96
+ - The model will predict how these cells would respond to the condition
97
+
98
+ 3. **Run Prediction**:
99
+ - Click "Run Zero-Shot Prediction"
100
+ - Watch the model process the context
101
+ - View predicted gene expression patterns for your query cells
102
+
103
+ **Example Use Case**:
104
+ *"If I know how T-Cells respond to Drug A, how would B-Cells respond to the same drug?"*
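The constraints in the steps above (a perturbed prompt, a different query cell type that starts healthy) can be written down as a small data structure. This is a hedged sketch only: `ZeroShotRequest`, `validation_errors`, and `describe` are hypothetical names that do not come from `app.py`, and the allowed values simply mirror the options listed in this README.

```python
from dataclasses import dataclass

# Options as listed in this README (assumed to match the UI choices).
ALLOWED_CELL_TYPES = ("T-Cell", "B-Cell", "Macrophage")
ALLOWED_CONDITIONS = ("Drug A", "Viral Infection")

@dataclass(frozen=True)
class ZeroShotRequest:
    prompt_cell_type: str   # cells with a known response
    condition: str          # perturbation applied to the prompt cells
    query_cell_type: str    # cells to predict; always start healthy

def validation_errors(request):
    """Return a list of human-readable problems (empty if the request is valid)."""
    errors = []
    for cell_type in (request.prompt_cell_type, request.query_cell_type):
        if cell_type not in ALLOWED_CELL_TYPES:
            errors.append(f"unknown cell type: {cell_type}")
    if request.condition not in ALLOWED_CONDITIONS:
        errors.append(f"unknown condition: {request.condition}")
    if request.prompt_cell_type == request.query_cell_type:
        errors.append("query cell type must differ from prompt cell type")
    return errors

def describe(request):
    """Summarize the request in one line."""
    return (f"Prompt: {request.prompt_cell_type} under {request.condition}; "
            f"query: healthy {request.query_cell_type}")

example = ZeroShotRequest("T-Cell", "Drug A", "B-Cell")
```

The example request corresponds to the use case quoted above: T-Cells under Drug A as the prompt, healthy B-Cells as the query.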
105
+
106
+ ---
107
+
108
+ ## 📁 Code Organization
109
+
110
+ ### Project Structure
111
+
112
+ ```
113
+ stack-viz/
114
+ ├── app.py              # Main application file
115
+ ├── requirements.txt    # Python dependencies
116
+ ├── README.md           # This file
117
+ └── .git/               # Git repository
118
  ```
119
 
120
+ ### Code Architecture (`app.py`)
121
+
122
+ The application is organized into logical sections:
123
+
124
+ #### 1. **Constants & Configuration** (Lines 6-17)
125
+ - `CELL_TYPES`: Definitions for T-Cell, B-Cell, and Macrophage with colors
126
+ - `CONDITIONS`: Definitions for Healthy, Drug A, and Viral conditions
127
+ - Color schemes for visual consistency
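For orientation, the two constants might look roughly like the following. The dictionary layout, the `label`/`color`/`indicator` keys, the hex values, and the `cell_color` helper are all assumptions chosen to match the colors named in this README (blue, indigo, orange, red). The actual definitions in `app.py` may differ.

```python
# Hypothetical shape of the constants; keys and hex codes are assumptions.
CELL_TYPES = {
    "T-Cell": {"label": "T lymphocytes", "color": "#3b82f6"},            # blue
    "B-Cell": {"label": "B lymphocytes", "color": "#6366f1"},            # indigo
    "Macrophage": {"label": "Innate immune cells", "color": "#f97316"},  # orange
}

CONDITIONS = {
    "Healthy": {"label": "Normal baseline state", "indicator": None},
    "Drug A": {"label": "Pharmaceutical treatment", "indicator": "#ef4444"},  # red
    "Viral": {"label": "Infectious disease state", "indicator": "#f97316"},   # orange
}

def cell_color(cell_type):
    """Look up the display color for a cell type."""
    return CELL_TYPES[cell_type]["color"]
```

Keeping both tables as plain dictionaries means a new cell type or condition is a one-entry change.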
128
+
129
+ #### 2. **Cell Visualization Helpers** (Lines 19-68)
130
+ - `make_cell_svg()`: Generates SVG representations of cells
131
+ - Prompt cells: Irregular shape (treated/perturbed state)
132
+ - Query cells: Circular shape with dashed borders (baseline state)
133
+ - `generate_cell_array()`: Creates horizontal arrays of cells for display
134
+ - Uses color coding to distinguish cell types and conditions
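A minimal sketch of how such helpers could produce the two cell shapes follows. The signatures, sizes, and the ten-point polygon used for the "irregular" prompt shape are illustrative assumptions. Only the function names and the circle-with-dashed-border convention for query cells come from the description above.

```python
import math

def make_cell_svg(color, is_query, size=40):
    """Return an inline SVG string for one cell (hypothetical signature)."""
    half = size / 2
    radius = half - 4
    if is_query:
        # Query cells: circle with a dashed border (baseline state).
        shape = (f'<circle cx="{half}" cy="{half}" r="{radius}" fill="{color}" '
                 f'fill-opacity="0.3" stroke="{color}" stroke-width="2" '
                 f'stroke-dasharray="4 3"/>')
    else:
        # Prompt cells: irregular outline built from alternating radii.
        points = []
        for i in range(10):
            angle = 2 * math.pi * i / 10
            r = radius if i % 2 == 0 else radius * 0.75
            points.append(f"{half + r * math.cos(angle):.1f},"
                          f"{half + r * math.sin(angle):.1f}")
        point_list = " ".join(points)
        shape = (f'<polygon points="{point_list}" fill="{color}" '
                 f'stroke="{color}" stroke-width="2"/>')
    return (f'<svg xmlns="http://www.w3.org/2000/svg" width="{size}" '
            f'height="{size}" viewBox="0 0 {size} {size}">{shape}</svg>')

def generate_cell_array(color, is_query, count=5):
    """Lay out `count` cells in a horizontal flex row."""
    cells = "".join(make_cell_svg(color, is_query) for _ in range(count))
    return f'<div style="display:flex;gap:6px">{cells}</div>'
```

Because the helpers return plain strings, their output can be passed directly to an HTML component.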
135
+
136
+ #### 3. **Inference Display Functions** (Lines 70-160)
137
+ - `generate_inference_display()`: Creates the main visualization showing prompt and query cells
138
+ - `generate_mini_matrix()`: Generates preview of gene expression matrix
139
+ - Combines cell arrays with arrow indicators and matrix representations
140
+
141
+ #### 4. **Architecture Grid Functions** (Lines 162-244)
142
+ - `generate_grid_html()`: Creates the 5×5 gene expression matrix visualization
143
+ - Rows = Genes
144
+ - Columns = Cells
145
+ - Interactive highlighting for different attention patterns
146
+ - `get_step_label()`: Returns descriptive labels for each learning step
147
+ - `update_architecture_view()`: Manages state changes between visualization modes
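The grid rendering can be pictured with a short sketch. The parameters (`highlight_row`, `highlight_col`, `masked_rows`), the colors, and the table-based markup below are assumptions for illustration. The real `generate_grid_html()` in `app.py` may take different arguments and emit different HTML.

```python
def generate_grid_html(n_rows=5, n_cols=5, highlight_row=None,
                       highlight_col=None, masked_rows=()):
    """Render a grid as an HTML table with optional highlighting (sketch)."""
    rows_html = []
    for r in range(n_rows):
        cells = []
        for c in range(n_cols):
            if r in masked_rows:
                background, text = "#9ca3af", "?"   # masked value
            elif r == highlight_row or c == highlight_col:
                background, text = "#6366f1", ""    # attended position
            else:
                background, text = "#e5e7eb", ""    # idle position
            cells.append(
                f'<td style="width:32px;height:32px;background:{background};'
                f'text-align:center">{text}</td>'
            )
        rows_html.append("<tr>" + "".join(cells) + "</tr>")
    return ('<table style="border-collapse:separate;border-spacing:4px">'
            + "".join(rows_html) + "</table>")

# One call per visualization mode described above.
row_view = generate_grid_html(highlight_row=1)
column_view = generate_grid_html(highlight_col=2)
masked_view = generate_grid_html(masked_rows=(2, 3))
```

Driving all three modes from one function keeps the state handler small: it only has to choose which arguments to pass.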
148
+
149
+ #### 5. **Inference Logic** (Lines 261-357)
150
+ - `update_inference_display()`: Updates visualization when user changes selections
151
+ - `run_inference()`: Simulates the prediction process with animation
152
+ - Shows "processing" state with spinner
153
+ - Generates predicted cell states
154
+ - Displays gene expression output columns
155
+ - `reset_inference()`: Resets to initial state
156
+ - `generate_output_column()`: Creates visual representation of predicted gene counts
157
+
158
+ #### 6. **Gradio Interface** (Lines 359-509)
159
+ - `create_app()`: Main application builder
160
+ - **Header**: Branding and title
161
+ - **Architecture Tab**: Grid visualization with control buttons
162
+ - **Inference Tab**: Interactive prediction interface with controls
163
+ - State management using `gr.State()`
164
+ - Event handlers for button clicks and selections
165
+
166
+ #### 7. **Styling** (Lines 511-545)
167
+ - `custom_css()`: Custom CSS for enhanced UI
168
+ - Animations (spinner for loading)
169
+ - Button styling and transitions
170
+ - Tab styling with gradients
171
+ - Responsive design
172
+
173
+ #### 8. **Application Entry Point** (Lines 547-551)
174
+ - Launches the Gradio app with custom CSS
175
+ - Default port: 7860
176
+
177
+ ### Customization Points
178
+
179
+ To modify the visualization:
180
+
181
+ - **Add new cell types**: Edit `CELL_TYPES` constant
182
+ - **Add new conditions**: Edit `CONDITIONS` constant
183
+ - **Change grid size**: Modify the grid generation in `generate_grid_html()`
184
+ - **Adjust masking logic**: Update the masking selection in `update_architecture_view()`
185
+ - **Customize styling**: Edit `custom_css()` function
186
 
187
+ ---
188
+
189
+ ## 🔬 Cell Types & Conditions
190
 
191
+ ### Available Cell Types
192
+ - 🔵 **T-Cell**: T lymphocytes (immune cells, blue)
193
+ - 🔷 **B-Cell**: B lymphocytes (antibody producers, indigo)
194
+ - 🟠 **Macrophage**: Innate immune cells (orange)
195
 
196
+ ### Experimental Conditions
197
+ - ✅ **Healthy**: Normal baseline state
197
+ - 💊 **Drug A**: Pharmaceutical treatment (red indicator)
198
+ - 🦠 **Viral Infection**: Infectious disease state (orange indicator)
 
200
 
 
201
 
202
+ ---
 
 
203
 
204
+ ## 📄 License
205
 
206
+ This project is open source and available for educational and research purposes.