kaveh commited on
Commit
31a6a59
·
1 Parent(s): 4ec8735

separated notebooks for clarity

Browse files
notebooks/Singlecell_evaluation.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
notebooks/Singlecell_inference.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
notebooks/Spheroid_evaluation.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
notebooks/Spheroid_inference.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
notebooks/demo.ipynb DELETED
@@ -1,248 +0,0 @@
1
- {
2
- "cells": [
3
- {
4
- "cell_type": "markdown",
5
- "metadata": {},
6
- "source": [
7
- "# S2F Model Evaluation\n",
8
- "\n",
9
- "This notebook shows how to evaluate a trained Shape2Force (S2F) model on your dataset.\n",
10
- "\n",
11
- "**Metrics computed:**\n",
12
- "- **MSE** – Mean Squared Error\n",
13
- "- **MS-SSIM** – Multi-Scale Structural Similarity\n",
14
- "- **Pixel Correlation** – Pearson correlation between predicted and ground-truth heatmaps\n",
15
- "- **Relative Magnitude Error** – WFM-style weighted relative error\n",
16
- "- **Force Sum/Mean Correlation** – Correlation of total force per sample\n",
17
- "\n",
18
- "Run the cells below after adjusting paths and settings."
19
- ]
20
- },
21
- {
22
- "cell_type": "code",
23
- "execution_count": null,
24
- "metadata": {},
25
- "outputs": [],
26
- "source": [
27
- "%load_ext autoreload\n",
28
- "%autoreload 2\n",
29
- "\n",
30
- "import warnings\n",
31
- "warnings.filterwarnings('ignore')\n",
32
- "import sys\n",
33
- "import os\n",
34
- "import cv2\n",
35
- "cv2.utils.logging.setLogLevel(cv2.utils.logging.LOG_LEVEL_ERROR)\n",
36
- "\n",
37
- "cwd = os.getcwd()\n",
38
- "S2F_ROOT = cwd if os.path.exists(os.path.join(cwd, 'models')) else os.path.dirname(cwd)\n",
39
- "sys.path.insert(0, S2F_ROOT)\n",
40
- "\n",
41
- "import torch\n",
42
- "import matplotlib.pyplot as plt\n",
43
- "\n",
44
- "from data.cell_dataset import prepare_data, load_folder_data\n",
45
- "from models.s2f_model import create_s2f_model, compute_settings_normalization\n",
46
- "from utils.metrics import (\n",
47
- " evaluate_metrics_on_dataset,\n",
48
- " print_metrics_report,\n",
49
- " gen_prediction_plots,\n",
50
- " plot_predictions,\n",
51
- ")\n",
52
- "\n",
53
- "print(f\"S2F root: {S2F_ROOT}\")"
54
- ]
55
- },
56
- {
57
- "cell_type": "markdown",
58
- "metadata": {},
59
- "source": [
60
- "## Configuration"
61
- ]
62
- },
63
- {
64
- "cell_type": "code",
65
- "execution_count": null,
66
- "metadata": {},
67
- "outputs": [],
68
- "source": [
69
- "# --- Adjust these paths ---\n",
70
- "USE_SINGLE_CELL = True # True = single-cell (with substrate), False = spheroid\n",
71
- "DATA_FOLDER = os.path.join(S2F_ROOT, 'sample') # or path to your dataset\n",
72
- "CHECKPOINT_PATH = os.path.join(S2F_ROOT, 'ckp', 'best_checkpoint.pth')\n",
73
- "\n",
74
- "IMAGE_SIZE = 1024\n",
75
- "BATCH_SIZE = 2\n",
76
- "THRESHOLD = 0.0 # Threshold for heatmap metrics (0 = no threshold)\n",
77
- "DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
78
- "SUBSTRATE = 'fibroblasts_PDMS' # For single-cell mode"
79
- ]
80
- },
81
- {
82
- "cell_type": "markdown",
83
- "metadata": {},
84
- "source": [
85
- "## Load Data"
86
- ]
87
- },
88
- {
89
- "cell_type": "code",
90
- "execution_count": null,
91
- "metadata": {},
92
- "outputs": [],
93
- "source": [
94
- "# Option A: Load from folder with train/test subfolders\n",
95
- "train_folder = os.path.join(DATA_FOLDER, 'train')\n",
96
- "test_folder = os.path.join(DATA_FOLDER, 'test')\n",
97
- "\n",
98
- "if os.path.exists(train_folder) and os.path.exists(test_folder):\n",
99
- " train_loader, val_loader = prepare_data(\n",
100
- " DATA_FOLDER,\n",
101
- " batch_size=BATCH_SIZE,\n",
102
- " target_size=(IMAGE_SIZE, IMAGE_SIZE),\n",
103
- " use_augmentations=False,\n",
104
- " train_test_sep_folder=True,\n",
105
- " return_metadata=USE_SINGLE_CELL,\n",
106
- " substrate=SUBSTRATE if USE_SINGLE_CELL else None,\n",
107
- " )\n",
108
- " print(f\"Loaded train: {len(train_loader.dataset)} samples, val: {len(val_loader.dataset)} samples\")\n",
109
- "else:\n",
110
- " # Option B: Load from a single folder (e.g. test only)\n",
111
- " val_loader = load_folder_data(\n",
112
- " DATA_FOLDER,\n",
113
- " substrate=SUBSTRATE if USE_SINGLE_CELL else None,\n",
114
- " img_size=IMAGE_SIZE,\n",
115
- " batch_size=BATCH_SIZE,\n",
116
- " return_metadata=USE_SINGLE_CELL,\n",
117
- " )\n",
118
- " train_loader = None\n",
119
- " print(f\"Loaded {len(val_loader.dataset)} samples from {DATA_FOLDER}\")"
120
- ]
121
- },
122
- {
123
- "cell_type": "markdown",
124
- "metadata": {},
125
- "source": [
126
- "## Load Model"
127
- ]
128
- },
129
- {
130
- "cell_type": "code",
131
- "execution_count": null,
132
- "metadata": {},
133
- "outputs": [],
134
- "source": [
135
- "in_channels = 3 if USE_SINGLE_CELL else 1\n",
136
- "generator, _ = create_s2f_model(in_channels=in_channels)\n",
137
- "checkpoint = torch.load(CHECKPOINT_PATH, map_location='cpu', weights_only=False)\n",
138
- "generator.load_state_dict(checkpoint.get('generator_state_dict', checkpoint), strict=True)\n",
139
- "generator = generator.to(DEVICE)\n",
140
- "generator.eval()\n",
141
- "\n",
142
- "print(f\"Loaded checkpoint from {CHECKPOINT_PATH}\")\n",
143
- "if 'epoch' in checkpoint:\n",
144
- " print(f\" Epoch: {checkpoint['epoch']}\")"
145
- ]
146
- },
147
- {
148
- "cell_type": "markdown",
149
- "metadata": {},
150
- "source": [
151
- "## Run Evaluation"
152
- ]
153
- },
154
- {
155
- "cell_type": "code",
156
- "execution_count": null,
157
- "metadata": {},
158
- "outputs": [],
159
- "source": [
160
- "config_path = os.path.join(S2F_ROOT, 'config', 'substrate_settings.json')\n",
161
- "normalization_params = compute_settings_normalization(config_path=config_path) if USE_SINGLE_CELL else None\n",
162
- "\n",
163
- "val_results = evaluate_metrics_on_dataset(\n",
164
- " generator,\n",
165
- " val_loader,\n",
166
- " device=DEVICE,\n",
167
- " description=\"Evaluating\",\n",
168
- " save_predictions=True,\n",
169
- " threshold=THRESHOLD,\n",
170
- " use_settings=USE_SINGLE_CELL,\n",
171
- " normalization_params=normalization_params,\n",
172
- " config_path=config_path,\n",
173
- " substrate_override=SUBSTRATE,\n",
174
- ")\n",
175
- "\n",
176
- "report = {'validation': val_results}\n",
177
- "if train_loader is not None:\n",
178
- " train_results = evaluate_metrics_on_dataset(\n",
179
- " generator, train_loader, device=DEVICE, description=\"Training\",\n",
180
- " threshold=THRESHOLD, use_settings=USE_SINGLE_CELL,\n",
181
- " normalization_params=normalization_params, config_path=config_path,\n",
182
- " substrate_override=SUBSTRATE,\n",
183
- " )\n",
184
- " report = {'train': train_results, 'validation': val_results}\n",
185
- "\n",
186
- "print_metrics_report(report, threshold=THRESHOLD)"
187
- ]
188
- },
189
- {
190
- "cell_type": "markdown",
191
- "metadata": {},
192
- "source": [
193
- "## Visualize Predictions"
194
- ]
195
- },
196
- {
197
- "cell_type": "code",
198
- "execution_count": null,
199
- "metadata": {},
200
- "outputs": [],
201
- "source": [
202
- "# Quick preview: plot first few samples\n",
203
- "plot_predictions(\n",
204
- " val_loader,\n",
205
- " generator,\n",
206
- " n_samples=3,\n",
207
- " device=DEVICE,\n",
208
- " use_settings=USE_SINGLE_CELL,\n",
209
- " normalization_params=normalization_params,\n",
210
- " config_path=config_path,\n",
211
- " substrate_override=SUBSTRATE,\n",
212
- ")"
213
- ]
214
- },
215
- {
216
- "cell_type": "code",
217
- "execution_count": null,
218
- "metadata": {},
219
- "outputs": [],
220
- "source": [
221
- "# Save individual prediction plots (sorted by MSE, best first)\n",
222
- "plot_output_dir = os.path.join(S2F_ROOT, 'outputs', 'evaluation_plots')\n",
223
- "if 'individual_predictions' in val_results:\n",
224
- " gen_prediction_plots(\n",
225
- " val_results['individual_predictions'],\n",
226
- " save_dir=plot_output_dir,\n",
227
- " sort_by='mse',\n",
228
- " sort_order='asc',\n",
229
- " threshold=THRESHOLD,\n",
230
- " )\n",
231
- " print(f\"Saved plots to {plot_output_dir}\")"
232
- ]
233
- }
234
- ],
235
- "metadata": {
236
- "kernelspec": {
237
- "display_name": "Python 3",
238
- "language": "python",
239
- "name": "python3"
240
- },
241
- "language_info": {
242
- "name": "python",
243
- "version": "3.10.0"
244
- }
245
- },
246
- "nbformat": 4,
247
- "nbformat_minor": 4
248
- }