separated notebooks for clarity
Browse files
notebooks/Singlecell_evaluation.ipynb
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
notebooks/Singlecell_inference.ipynb
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
notebooks/Spheroid_evaluation.ipynb
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
notebooks/Spheroid_inference.ipynb
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
notebooks/demo.ipynb
DELETED
|
@@ -1,248 +0,0 @@
|
|
| 1 |
-
{
|
| 2 |
-
"cells": [
|
| 3 |
-
{
|
| 4 |
-
"cell_type": "markdown",
|
| 5 |
-
"metadata": {},
|
| 6 |
-
"source": [
|
| 7 |
-
"# S2F Model Evaluation\n",
|
| 8 |
-
"\n",
|
| 9 |
-
"This notebook shows how to evaluate a trained Shape2Force (S2F) model on your dataset.\n",
|
| 10 |
-
"\n",
|
| 11 |
-
"**Metrics computed:**\n",
|
| 12 |
-
"- **MSE** – Mean Squared Error\n",
|
| 13 |
-
"- **MS-SSIM** – Multi-Scale Structural Similarity\n",
|
| 14 |
-
"- **Pixel Correlation** – Pearson correlation between predicted and ground-truth heatmaps\n",
|
| 15 |
-
"- **Relative Magnitude Error** – WFM-style weighted relative error\n",
|
| 16 |
-
"- **Force Sum/Mean Correlation** – Correlation of total force per sample\n",
|
| 17 |
-
"\n",
|
| 18 |
-
"Run the cells below after adjusting paths and settings."
|
| 19 |
-
]
|
| 20 |
-
},
|
| 21 |
-
{
|
| 22 |
-
"cell_type": "code",
|
| 23 |
-
"execution_count": null,
|
| 24 |
-
"metadata": {},
|
| 25 |
-
"outputs": [],
|
| 26 |
-
"source": [
|
| 27 |
-
"%load_ext autoreload\n",
|
| 28 |
-
"%autoreload 2\n",
|
| 29 |
-
"\n",
|
| 30 |
-
"import warnings\n",
|
| 31 |
-
"warnings.filterwarnings('ignore')\n",
|
| 32 |
-
"import sys\n",
|
| 33 |
-
"import os\n",
|
| 34 |
-
"import cv2\n",
|
| 35 |
-
"cv2.utils.logging.setLogLevel(cv2.utils.logging.LOG_LEVEL_ERROR)\n",
|
| 36 |
-
"\n",
|
| 37 |
-
"cwd = os.getcwd()\n",
|
| 38 |
-
"S2F_ROOT = cwd if os.path.exists(os.path.join(cwd, 'models')) else os.path.dirname(cwd)\n",
|
| 39 |
-
"sys.path.insert(0, S2F_ROOT)\n",
|
| 40 |
-
"\n",
|
| 41 |
-
"import torch\n",
|
| 42 |
-
"import matplotlib.pyplot as plt\n",
|
| 43 |
-
"\n",
|
| 44 |
-
"from data.cell_dataset import prepare_data, load_folder_data\n",
|
| 45 |
-
"from models.s2f_model import create_s2f_model, compute_settings_normalization\n",
|
| 46 |
-
"from utils.metrics import (\n",
|
| 47 |
-
" evaluate_metrics_on_dataset,\n",
|
| 48 |
-
" print_metrics_report,\n",
|
| 49 |
-
" gen_prediction_plots,\n",
|
| 50 |
-
" plot_predictions,\n",
|
| 51 |
-
")\n",
|
| 52 |
-
"\n",
|
| 53 |
-
"print(f\"S2F root: {S2F_ROOT}\")"
|
| 54 |
-
]
|
| 55 |
-
},
|
| 56 |
-
{
|
| 57 |
-
"cell_type": "markdown",
|
| 58 |
-
"metadata": {},
|
| 59 |
-
"source": [
|
| 60 |
-
"## Configuration"
|
| 61 |
-
]
|
| 62 |
-
},
|
| 63 |
-
{
|
| 64 |
-
"cell_type": "code",
|
| 65 |
-
"execution_count": null,
|
| 66 |
-
"metadata": {},
|
| 67 |
-
"outputs": [],
|
| 68 |
-
"source": [
|
| 69 |
-
"# --- Adjust these paths ---\n",
|
| 70 |
-
"USE_SINGLE_CELL = True # True = single-cell (with substrate), False = spheroid\n",
|
| 71 |
-
"DATA_FOLDER = os.path.join(S2F_ROOT, 'sample') # or path to your dataset\n",
|
| 72 |
-
"CHECKPOINT_PATH = os.path.join(S2F_ROOT, 'ckp', 'best_checkpoint.pth')\n",
|
| 73 |
-
"\n",
|
| 74 |
-
"IMAGE_SIZE = 1024\n",
|
| 75 |
-
"BATCH_SIZE = 2\n",
|
| 76 |
-
"THRESHOLD = 0.0 # Threshold for heatmap metrics (0 = no threshold)\n",
|
| 77 |
-
"DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
|
| 78 |
-
"SUBSTRATE = 'fibroblasts_PDMS' # For single-cell mode"
|
| 79 |
-
]
|
| 80 |
-
},
|
| 81 |
-
{
|
| 82 |
-
"cell_type": "markdown",
|
| 83 |
-
"metadata": {},
|
| 84 |
-
"source": [
|
| 85 |
-
"## Load Data"
|
| 86 |
-
]
|
| 87 |
-
},
|
| 88 |
-
{
|
| 89 |
-
"cell_type": "code",
|
| 90 |
-
"execution_count": null,
|
| 91 |
-
"metadata": {},
|
| 92 |
-
"outputs": [],
|
| 93 |
-
"source": [
|
| 94 |
-
"# Option A: Load from folder with train/test subfolders\n",
|
| 95 |
-
"train_folder = os.path.join(DATA_FOLDER, 'train')\n",
|
| 96 |
-
"test_folder = os.path.join(DATA_FOLDER, 'test')\n",
|
| 97 |
-
"\n",
|
| 98 |
-
"if os.path.exists(train_folder) and os.path.exists(test_folder):\n",
|
| 99 |
-
" train_loader, val_loader = prepare_data(\n",
|
| 100 |
-
" DATA_FOLDER,\n",
|
| 101 |
-
" batch_size=BATCH_SIZE,\n",
|
| 102 |
-
" target_size=(IMAGE_SIZE, IMAGE_SIZE),\n",
|
| 103 |
-
" use_augmentations=False,\n",
|
| 104 |
-
" train_test_sep_folder=True,\n",
|
| 105 |
-
" return_metadata=USE_SINGLE_CELL,\n",
|
| 106 |
-
" substrate=SUBSTRATE if USE_SINGLE_CELL else None,\n",
|
| 107 |
-
" )\n",
|
| 108 |
-
" print(f\"Loaded train: {len(train_loader.dataset)} samples, val: {len(val_loader.dataset)} samples\")\n",
|
| 109 |
-
"else:\n",
|
| 110 |
-
" # Option B: Load from a single folder (e.g. test only)\n",
|
| 111 |
-
" val_loader = load_folder_data(\n",
|
| 112 |
-
" DATA_FOLDER,\n",
|
| 113 |
-
" substrate=SUBSTRATE if USE_SINGLE_CELL else None,\n",
|
| 114 |
-
" img_size=IMAGE_SIZE,\n",
|
| 115 |
-
" batch_size=BATCH_SIZE,\n",
|
| 116 |
-
" return_metadata=USE_SINGLE_CELL,\n",
|
| 117 |
-
" )\n",
|
| 118 |
-
" train_loader = None\n",
|
| 119 |
-
" print(f\"Loaded {len(val_loader.dataset)} samples from {DATA_FOLDER}\")"
|
| 120 |
-
]
|
| 121 |
-
},
|
| 122 |
-
{
|
| 123 |
-
"cell_type": "markdown",
|
| 124 |
-
"metadata": {},
|
| 125 |
-
"source": [
|
| 126 |
-
"## Load Model"
|
| 127 |
-
]
|
| 128 |
-
},
|
| 129 |
-
{
|
| 130 |
-
"cell_type": "code",
|
| 131 |
-
"execution_count": null,
|
| 132 |
-
"metadata": {},
|
| 133 |
-
"outputs": [],
|
| 134 |
-
"source": [
|
| 135 |
-
"in_channels = 3 if USE_SINGLE_CELL else 1\n",
|
| 136 |
-
"generator, _ = create_s2f_model(in_channels=in_channels)\n",
|
| 137 |
-
"checkpoint = torch.load(CHECKPOINT_PATH, map_location='cpu', weights_only=False)\n",
|
| 138 |
-
"generator.load_state_dict(checkpoint.get('generator_state_dict', checkpoint), strict=True)\n",
|
| 139 |
-
"generator = generator.to(DEVICE)\n",
|
| 140 |
-
"generator.eval()\n",
|
| 141 |
-
"\n",
|
| 142 |
-
"print(f\"Loaded checkpoint from {CHECKPOINT_PATH}\")\n",
|
| 143 |
-
"if 'epoch' in checkpoint:\n",
|
| 144 |
-
" print(f\" Epoch: {checkpoint['epoch']}\")"
|
| 145 |
-
]
|
| 146 |
-
},
|
| 147 |
-
{
|
| 148 |
-
"cell_type": "markdown",
|
| 149 |
-
"metadata": {},
|
| 150 |
-
"source": [
|
| 151 |
-
"## Run Evaluation"
|
| 152 |
-
]
|
| 153 |
-
},
|
| 154 |
-
{
|
| 155 |
-
"cell_type": "code",
|
| 156 |
-
"execution_count": null,
|
| 157 |
-
"metadata": {},
|
| 158 |
-
"outputs": [],
|
| 159 |
-
"source": [
|
| 160 |
-
"config_path = os.path.join(S2F_ROOT, 'config', 'substrate_settings.json')\n",
|
| 161 |
-
"normalization_params = compute_settings_normalization(config_path=config_path) if USE_SINGLE_CELL else None\n",
|
| 162 |
-
"\n",
|
| 163 |
-
"val_results = evaluate_metrics_on_dataset(\n",
|
| 164 |
-
" generator,\n",
|
| 165 |
-
" val_loader,\n",
|
| 166 |
-
" device=DEVICE,\n",
|
| 167 |
-
" description=\"Evaluating\",\n",
|
| 168 |
-
" save_predictions=True,\n",
|
| 169 |
-
" threshold=THRESHOLD,\n",
|
| 170 |
-
" use_settings=USE_SINGLE_CELL,\n",
|
| 171 |
-
" normalization_params=normalization_params,\n",
|
| 172 |
-
" config_path=config_path,\n",
|
| 173 |
-
" substrate_override=SUBSTRATE,\n",
|
| 174 |
-
")\n",
|
| 175 |
-
"\n",
|
| 176 |
-
"report = {'validation': val_results}\n",
|
| 177 |
-
"if train_loader is not None:\n",
|
| 178 |
-
" train_results = evaluate_metrics_on_dataset(\n",
|
| 179 |
-
" generator, train_loader, device=DEVICE, description=\"Training\",\n",
|
| 180 |
-
" threshold=THRESHOLD, use_settings=USE_SINGLE_CELL,\n",
|
| 181 |
-
" normalization_params=normalization_params, config_path=config_path,\n",
|
| 182 |
-
" substrate_override=SUBSTRATE,\n",
|
| 183 |
-
" )\n",
|
| 184 |
-
" report = {'train': train_results, 'validation': val_results}\n",
|
| 185 |
-
"\n",
|
| 186 |
-
"print_metrics_report(report, threshold=THRESHOLD)"
|
| 187 |
-
]
|
| 188 |
-
},
|
| 189 |
-
{
|
| 190 |
-
"cell_type": "markdown",
|
| 191 |
-
"metadata": {},
|
| 192 |
-
"source": [
|
| 193 |
-
"## Visualize Predictions"
|
| 194 |
-
]
|
| 195 |
-
},
|
| 196 |
-
{
|
| 197 |
-
"cell_type": "code",
|
| 198 |
-
"execution_count": null,
|
| 199 |
-
"metadata": {},
|
| 200 |
-
"outputs": [],
|
| 201 |
-
"source": [
|
| 202 |
-
"# Quick preview: plot first few samples\n",
|
| 203 |
-
"plot_predictions(\n",
|
| 204 |
-
" val_loader,\n",
|
| 205 |
-
" generator,\n",
|
| 206 |
-
" n_samples=3,\n",
|
| 207 |
-
" device=DEVICE,\n",
|
| 208 |
-
" use_settings=USE_SINGLE_CELL,\n",
|
| 209 |
-
" normalization_params=normalization_params,\n",
|
| 210 |
-
" config_path=config_path,\n",
|
| 211 |
-
" substrate_override=SUBSTRATE,\n",
|
| 212 |
-
")"
|
| 213 |
-
]
|
| 214 |
-
},
|
| 215 |
-
{
|
| 216 |
-
"cell_type": "code",
|
| 217 |
-
"execution_count": null,
|
| 218 |
-
"metadata": {},
|
| 219 |
-
"outputs": [],
|
| 220 |
-
"source": [
|
| 221 |
-
"# Save individual prediction plots (sorted by MSE, best first)\n",
|
| 222 |
-
"plot_output_dir = os.path.join(S2F_ROOT, 'outputs', 'evaluation_plots')\n",
|
| 223 |
-
"if 'individual_predictions' in val_results:\n",
|
| 224 |
-
" gen_prediction_plots(\n",
|
| 225 |
-
" val_results['individual_predictions'],\n",
|
| 226 |
-
" save_dir=plot_output_dir,\n",
|
| 227 |
-
" sort_by='mse',\n",
|
| 228 |
-
" sort_order='asc',\n",
|
| 229 |
-
" threshold=THRESHOLD,\n",
|
| 230 |
-
" )\n",
|
| 231 |
-
" print(f\"Saved plots to {plot_output_dir}\")"
|
| 232 |
-
]
|
| 233 |
-
}
|
| 234 |
-
],
|
| 235 |
-
"metadata": {
|
| 236 |
-
"kernelspec": {
|
| 237 |
-
"display_name": "Python 3",
|
| 238 |
-
"language": "python",
|
| 239 |
-
"name": "python3"
|
| 240 |
-
},
|
| 241 |
-
"language_info": {
|
| 242 |
-
"name": "python",
|
| 243 |
-
"version": "3.10.0"
|
| 244 |
-
}
|
| 245 |
-
},
|
| 246 |
-
"nbformat": 4,
|
| 247 |
-
"nbformat_minor": 4
|
| 248 |
-
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|