epaps commited on
Commit
de93bc1
·
1 Parent(s): 9c3903c
.DS_Store ADDED
Binary file (6.15 kB). View file
 
Dockerfile ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ FROM mambaorg/micromamba:1.5.8
2
+ SHELL ["/bin/bash", "-lc"]
3
+
4
+ USER root
5
+ RUN mkdir -p /var/lib/apt/lists/partial && \
6
+ apt-get update && \
7
+ apt-get install -y --no-install-recommends git cmake make g++ && \
8
+ rm -rf /var/lib/apt/lists/*
9
+
10
+ WORKDIR /work
11
+
12
+ # Your conda env file is in binder/environment.yml inside the repo
13
+ COPY binder/environment.yml /tmp/environment.yml
14
+ RUN micromamba env create -f /tmp/environment.yml && micromamba clean -a -y
15
+
16
+ # Build StochasticCIL
17
+ RUN git clone https://github.com/epapoutsellis/StochasticCIL.git && \
18
+ cd StochasticCIL && \
19
+ git fetch --all --tags && \
20
+ git checkout svrg && \
21
+ git config user.email "hf@local" && \
22
+ git config user.name "HF Build" && \
23
+ (git tag -a v1.0 -m "Version 1.0" || true) && \
24
+ mkdir -p build && cd build && \
25
+ PREFIX="$(micromamba run -n ssp python -c 'import sys, os; print(os.path.dirname(os.path.dirname(sys.executable)))')" && \
26
+ cmake .. -DCMAKE_POLICY_VERSION_MINIMUM=3.5 \
27
+ -DCONDA_BUILD=OFF \
28
+ -DCMAKE_BUILD_TYPE=Release \
29
+ -DLIBRARY_LIB="${PREFIX}/lib" \
30
+ -DLIBRARY_INC="${PREFIX}" \
31
+ -DCMAKE_INSTALL_PREFIX="${PREFIX}" \
32
+ -DPython_EXECUTABLE="${PREFIX}/bin/python" && \
33
+ make -j"$(nproc)" && \
34
+ make install
35
+
36
+ COPY start.sh /start.sh
37
+ RUN chmod +x /start.sh
38
+
39
+ EXPOSE 7860
40
+ CMD ["/start.sh"]
LICENSE ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ MIT License
2
+
3
+ Copyright (c) 2026 Vaggelis Papoutsellis
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
PnP_BM3D_VR_Skip.ipynb ADDED
@@ -0,0 +1,687 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": null,
6
+ "id": "e570205c-8d15-409a-a80a-2deabe8093db",
7
+ "metadata": {},
8
+ "outputs": [],
9
+ "source": [
10
+ "from cil.optimisation.functions import LeastSquares, SVRGFunction, SGFunction\n",
11
+ "from cil.framework import AcquisitionGeometry\n",
12
+ "from cil.plugins.astra import ProjectionOperator\n",
13
+ "from cil.plugins.astra import FBP\n",
14
+ "from cil.optimisation.utilities import MetricsDiagnostics, RandomSampling\n",
15
+ "from cil.optimisation.algorithms import FISTA, ISTA\n",
16
+ "\n",
17
+ "from ProxSkip import ProxSkip\n",
18
+ "from utils import StoppingCriterionTime, BM3DFunction\n",
19
+ "from skimage.metrics import peak_signal_noise_ratio as PSNR\n",
20
+ "from skimage.metrics import structural_similarity as SSIM\n",
21
+ "from xdesign import Foam, discrete_phantom\n",
22
+ "\n",
23
+ "import numpy as np\n",
24
+ "import matplotlib.pyplot as plt\n",
25
+ "plt.rcParams[\"image.cmap\"] = \"inferno\""
26
+ ]
27
+ },
28
+ {
29
+ "cell_type": "markdown",
30
+ "id": "95eac1e1-ad82-4977-9576-1b8cc99f8d57",
31
+ "metadata": {},
32
+ "source": [
33
+ "### PnP Variance Reduced ProxSkip on a simulated dataset"
34
+ ]
35
+ },
36
+ {
37
+ "cell_type": "markdown",
38
+ "id": "3e1dd614-7f15-4894-b13d-a1b32c82c969",
39
+ "metadata": {},
40
+ "source": [
41
+ "In this notebook, we play the following game. We compare algorithms by **time-to-quality**: that is, we measure how quickly each method achieves high reconstruction quality—quantified by **PSNR** and **SSIM**—with respect to the ground truth. Here, the ground truth is available, and we impose a fixed computational budget. The winner is the algorithm that attains the highest PSNR/SSIM within this time limit."
42
+ ]
43
+ },
44
+ {
45
+ "cell_type": "markdown",
46
+ "id": "d25fd695-7af3-4992-bf6e-0ce6fe299870",
47
+ "metadata": {},
48
+ "source": [
49
+ "### Generate Foam Phantom"
50
+ ]
51
+ },
52
+ {
53
+ "cell_type": "code",
54
+ "execution_count": null,
55
+ "id": "7154574f-40e1-4b60-828b-ddc4b1be8d69",
56
+ "metadata": {},
57
+ "outputs": [],
58
+ "source": [
59
+ "N = 347 # image size\n",
60
+ "density = 0.025 # attenuation density of the image\n",
61
+ "np.random.seed(1234)\n",
62
+ "gt = discrete_phantom(Foam(size_range=[0.075, 0.005], gap=2e-3, porosity=1.0), size=N - 10)\n",
63
+ "gt = gt / np.max(gt) * density\n",
64
+ "gt = np.pad(gt, 5)\n",
65
+ "gt[gt < 0] = 0"
66
+ ]
67
+ },
68
+ {
69
+ "cell_type": "markdown",
70
+ "id": "7963e21e-24c7-46e0-9cda-3c9c6aaf9678",
71
+ "metadata": {},
72
+ "source": [
73
+ "### Generate Acquisition Data with simulated noise"
74
+ ]
75
+ },
76
+ {
77
+ "cell_type": "code",
78
+ "execution_count": null,
79
+ "id": "226ded3c-a540-49d0-b0a3-3c560de87c96",
80
+ "metadata": {},
81
+ "outputs": [],
82
+ "source": [
83
+ "num_angles = 400 \n",
84
+ "\n",
85
+ "angles = np.linspace(0, 180, num_angles, endpoint=False, dtype=np.float32)\n",
86
+ "ag2D = AcquisitionGeometry.create_Parallel2D()\\\n",
87
+ " .set_panel(num_pixels=N)\\\n",
88
+ " .set_angles(angles=angles)\n",
89
+ "ig2D = ag2D.get_ImageGeometry()\n",
90
+ "x_cil = ig2D.allocate()\n",
91
+ "x_cil.fill(gt)\n",
92
+ "A = ProjectionOperator(ig2D, ag2D, device=\"cpu\")\n",
93
+ "\n",
94
+ "# noiseless sinogram\n",
95
+ "noiseless = A.direct(x_cil)\n",
96
+ "\n",
97
+ "## simulated noise\n",
98
+ "max_intensity = 2000\n",
99
+ "expected_counts = max_intensity * np.exp(-noiseless.array)\n",
100
+ "noisy_counts = np.random.poisson(expected_counts).astype(np.float32)\n",
101
+ "noisy_counts[noisy_counts == 0] = 1 # deal with 0s\n",
102
+ "y_np = -np.log(noisy_counts / max_intensity)\n",
103
+ "noisy = ag2D.allocate()\n",
104
+ "noisy.fill(y_np)"
105
+ ]
106
+ },
107
+ {
108
+ "cell_type": "code",
109
+ "execution_count": null,
110
+ "id": "faa9bb7a-9454-4c6e-802e-80c991cc0da3",
111
+ "metadata": {},
112
+ "outputs": [],
113
+ "source": [
114
+ "import numpy as np\n",
115
+ "import matplotlib.pyplot as plt\n",
116
+ "\n",
117
+ "fig, axs = plt.subplots(1, 3, figsize=(12, 10), constrained_layout=True)\n",
118
+ "\n",
119
+ "axs[0].imshow(gt)\n",
120
+ "axs[0].set_title(\"Ground Truth\")\n",
121
+ "\n",
122
+ "axs[1].imshow(noiseless.array)\n",
123
+ "axs[1].set_title(\"Noiseless Sinogram\")\n",
124
+ "\n",
125
+ "axs[2].imshow(noisy.array)\n",
126
+ "axs[2].set_title(\"Noisy Sinogram\")\n",
127
+ "\n",
128
+ "plt.show()"
129
+ ]
130
+ },
131
+ {
132
+ "cell_type": "markdown",
133
+ "id": "b138b305-26db-4c0e-a40f-687cff75d883",
134
+ "metadata": {},
135
+ "source": [
136
+ "### FBP reconstruction"
137
+ ]
138
+ },
139
+ {
140
+ "cell_type": "code",
141
+ "execution_count": null,
142
+ "id": "78e70056-aea9-4780-89e2-52351f51b9ac",
143
+ "metadata": {},
144
+ "outputs": [],
145
+ "source": [
146
+ "fbp = FBP(ig2D, ag2D, device=\"cpu\")(noisy)\n",
147
+ "fbp.array[fbp.array<0] = 0\n",
148
+ "plt.imshow(fbp.array)"
149
+ ]
150
+ },
151
+ {
152
+ "cell_type": "markdown",
153
+ "id": "498ad7aa-540b-426b-8810-991a72312e5c",
154
+ "metadata": {},
155
+ "source": [
156
+ "### Family of considered algorithms (ProxSkip template)\n",
157
+ "\n",
158
+ "**Parameters:** $\\gamma>0$, probability $p\\in(0,1]$, data subsets $N$\n",
159
+ "**Initialize:** $x_0, h_0 \\in \\mathbb{R}^n$\n",
160
+ "\n",
161
+ "For $k=0,1,\\dots,K-1$:\n",
162
+ "\n",
163
+ "1. Compute $G_k$ (an unbiased estimator of $\\nabla f(x_k)$).\n",
164
+ "2. $$\\hat x_{k+1} = x_k - \\gamma\\big(G_k(x_k) - h_k\\big).$$\n",
165
+ "3. Sample $\\theta_k\\sim\\mathrm{Bernoulli}(p)$, $\\theta_k\\in{0,1}$.\n",
166
+ "4. If $\\theta_k=1$:\n",
167
+ " $$x_{k+1}=\\mathrm{prox}_{\\frac{\\gamma}{p}g}\\left(\\hat x_{k+1}-\\frac{\\gamma}{p}h_k\\right),$$\n",
168
+ " else:\n",
169
+ " $$x_{k+1}=\\hat x_{k+1}.$$\n",
170
+ "5. Update:\n",
171
+ " $$h_{k+1}=h_k+\\frac{p}{\\gamma}\\big(x_{k+1}-\\hat x_{k+1}\\big).$$\n",
172
+ "\n",
173
+ "---\n",
174
+ "\n",
175
+ "| $p=1$ | $0<p<1$ | $G_k$ |\n",
176
+ "| --------- | ------------- | -------------------------------------------------------------------- |\n",
177
+ "| ISTA | ProxSkip | $\\nabla f(x_k)$ |\n",
178
+ "| ProxSGD | ProxSGDSkip | $N\\nabla f_{i_k}(x_k)$ |\n",
179
+ "| ProxSAGA | ProxSAGASkip | $N(\\nabla f_{i_k}(x_k)-v_k^{,i_k})+\\bar v_k$ |\n",
180
+ "| ProxSVRG | ProxSVRGSkip | $N(\\nabla f_{i_k}(x_k)-\\nabla f_{i_k}(\\tilde x))+\\nabla f(\\tilde x)$ |\n",
181
+ "| ProxLSVRG | ProxLSVRGSkip | as above, updated with $p=1/N$ |\n",
182
+ "\n",
183
+ "*Note:* FISTA is ISTA with an acceleration step (Beck & Teboulle).\n"
184
+ ]
185
+ },
186
+ {
187
+ "cell_type": "markdown",
188
+ "id": "bfd1a9e7-f591-4391-8c4b-71228b3c49e1",
189
+ "metadata": {},
190
+ "source": [
191
+ "### Define Stopping Criteria and Metrics (PSNR/SSIM)"
192
+ ]
193
+ },
194
+ {
195
+ "cell_type": "code",
196
+ "execution_count": null,
197
+ "id": "25b4287b-06a1-4d96-baa5-db52f3b8e4b0",
198
+ "metadata": {},
199
+ "outputs": [],
200
+ "source": [
201
+ "gt_cil = ig2D.allocate()\n",
202
+ "gt_cil.fill(gt)\n",
203
+ "\n",
204
+ "SSIM_ = lambda x, y: SSIM(x, y, data_range=x.max()-x.min())\n",
205
+ "PSNR_ = lambda x, y: PSNR(x, y, data_range=x.max()-x.min())\n",
206
+ "\n",
207
+ "time_stop = 3*60 # 3 mins\n",
208
+ "max_iteration = 100_000\n",
209
+ "\n",
210
+ "cb_metrics = MetricsDiagnostics(reference_image=gt_cil, \n",
211
+ " metrics_dict={\"psnr\":PSNR_,\"ssim\":SSIM_}) \n"
212
+ ]
213
+ },
214
+ {
215
+ "cell_type": "code",
216
+ "execution_count": null,
217
+ "id": "1e03af7b-468a-43c1-9344-e66dcea5ce58",
218
+ "metadata": {},
219
+ "outputs": [],
220
+ "source": [
221
+ "### Avoid computing objectives, not needed for this demo\n",
222
+ "def update_objective(self):\n",
223
+ " return 0.\n",
224
+ "\n",
225
+ "ISTA.update_objective = update_objective\n",
226
+ "ProxSkip.update_objective = update_objective\n",
227
+ "FISTA.update_objective = update_objective"
228
+ ]
229
+ },
230
+ {
231
+ "cell_type": "markdown",
232
+ "id": "1f22e499-849f-47b5-b959-b2b8254af0e6",
233
+ "metadata": {},
234
+ "source": [
235
+ "### Run deterministic algorithms: No data splitting, no skipping"
236
+ ]
237
+ },
238
+ {
239
+ "cell_type": "code",
240
+ "execution_count": null,
241
+ "id": "e67a015c-1e41-40fe-84c2-383e0558d4f3",
242
+ "metadata": {},
243
+ "outputs": [],
244
+ "source": [
245
+ "sigma_no_skip = 0.0002 \n",
246
+ "F = LeastSquares(A=A, b=noisy, c=0.5)"
247
+ ]
248
+ },
249
+ {
250
+ "cell_type": "markdown",
251
+ "id": "070838d2-505b-4265-846b-cee8a2a96450",
252
+ "metadata": {},
253
+ "source": [
254
+ "### FISTA"
255
+ ]
256
+ },
257
+ {
258
+ "cell_type": "code",
259
+ "execution_count": null,
260
+ "id": "02a9d83c-3a9a-4935-8dd0-c6c6ff3dea2f",
261
+ "metadata": {},
262
+ "outputs": [],
263
+ "source": [
264
+ "initial = ig2D.allocate()\n",
265
+ "step_size = 1./F.L\n",
266
+ "G = BM3DFunction(sigma_no_skip)\n",
267
+ "cb_time = StoppingCriterionTime(time_stop)\n",
268
+ "fista = FISTA(initial = initial, f = F, step_size = step_size, g=G, \n",
269
+ " update_objective_interval = 1,\n",
270
+ " max_iteration = max_iteration) \n",
271
+ "fista.run(verbose=0, callback=[cb_metrics, cb_time])"
272
+ ]
273
+ },
274
+ {
275
+ "cell_type": "markdown",
276
+ "id": "9a7833e5-095e-4524-b7f6-d3f4d5be473f",
277
+ "metadata": {},
278
+ "source": [
279
+ "### ISTA"
280
+ ]
281
+ },
282
+ {
283
+ "cell_type": "code",
284
+ "execution_count": null,
285
+ "id": "f1b02973-d57c-4f69-b6d1-4af18fa6b160",
286
+ "metadata": {},
287
+ "outputs": [],
288
+ "source": [
289
+ "G = BM3DFunction(sigma_no_skip)\n",
290
+ "cb_time = StoppingCriterionTime(time_stop)\n",
291
+ "ista = ISTA(initial = initial, f = F, step_size = step_size, g=G, \n",
292
+ " update_objective_interval = 1,\n",
293
+ " max_iteration = max_iteration) \n",
294
+ "ista.run(verbose=0, callback=[cb_metrics, cb_time])"
295
+ ]
296
+ },
297
+ {
298
+ "cell_type": "markdown",
299
+ "id": "57bd7320-c2cb-4e4c-a3f3-f382a0769d6a",
300
+ "metadata": {},
301
+ "source": [
302
+ "### Run ProxSkip: Skip the regulariser"
303
+ ]
304
+ },
305
+ {
306
+ "cell_type": "code",
307
+ "execution_count": null,
308
+ "id": "1a9b36d3-dd8a-43df-999c-acb414c3ba51",
309
+ "metadata": {},
310
+ "outputs": [],
311
+ "source": [
312
+ "prob = 0.05\n",
313
+ "sigma_skip = sigma_no_skip/np.sqrt(prob)"
314
+ ]
315
+ },
316
+ {
317
+ "cell_type": "code",
318
+ "execution_count": null,
319
+ "id": "cd6d1c7b-77c8-44e4-a71c-8169a097567e",
320
+ "metadata": {},
321
+ "outputs": [],
322
+ "source": [
323
+ "G = BM3DFunction(sigma_skip)\n",
324
+ "\n",
325
+ "cb_time = StoppingCriterionTime(time_stop)\n",
326
+ "proxskip = ProxSkip(initial = [initial, initial], f = F, step_size = step_size, g=G, \n",
327
+ " update_objective_interval = 1, prob=prob, seed = 42, \n",
328
+ " max_iteration = max_iteration) \n",
329
+ "proxskip.run(verbose=0, callback=[cb_metrics, cb_time])"
330
+ ]
331
+ },
332
+ {
333
+ "cell_type": "markdown",
334
+ "id": "84f96bea-e56d-4fc7-a8b9-7b288430f079",
335
+ "metadata": {},
336
+ "source": [
337
+ "### Run stochastic algorithms: Data splitting, no skipping"
338
+ ]
339
+ },
340
+ {
341
+ "cell_type": "code",
342
+ "execution_count": null,
343
+ "id": "750978d9-af9e-4366-aeaf-810155d22f1e",
344
+ "metadata": {},
345
+ "outputs": [],
346
+ "source": [
347
+ "def list_of_functions(data):\n",
348
+ " \n",
349
+ " list_funcs = []\n",
350
+ " ig = data[0].geometry.get_ImageGeometry()\n",
351
+ " \n",
352
+ " for d in data:\n",
353
+ " ageom_subset = d.geometry \n",
354
+ " Ai = ProjectionOperator(ig, ageom_subset, device = 'cpu') \n",
355
+ " fi = LeastSquares(Ai, b = d, c = 0.5)\n",
356
+ " list_funcs.append(fi) \n",
357
+ " \n",
358
+ " return list_funcs"
359
+ ]
360
+ },
361
+ {
362
+ "cell_type": "code",
363
+ "execution_count": null,
364
+ "id": "ce7998a1-9c98-49fe-b35a-89be65498fc3",
365
+ "metadata": {},
366
+ "outputs": [],
367
+ "source": [
368
+ "# number of subsets\n",
369
+ "nsub = 50\n",
370
+ "data_split, method = noisy.split_to_subsets(nsub, method= \"ordered\", info=True)\n",
371
+ "\n",
372
+ "# list of fis for finite sum \n",
373
+ "list_func = list_of_functions(data_split) "
374
+ ]
375
+ },
376
+ {
377
+ "cell_type": "markdown",
378
+ "id": "5fd9e598-bafd-4bdf-b9b0-33f1646ce191",
379
+ "metadata": {},
380
+ "source": [
381
+ "### ProxSVRG and ProxSGD"
382
+ ]
383
+ },
384
+ {
385
+ "cell_type": "code",
386
+ "execution_count": null,
387
+ "id": "aeb2e2d8-6054-4759-bd84-d5cc7d2bbc08",
388
+ "metadata": {},
389
+ "outputs": [],
390
+ "source": [
391
+ "selection = RandomSampling(len(list_func), nsub, seed=42)\n",
392
+ "Fsvrg = SVRGFunction(list_func, selection = selection, update_frequency=len(list_func))\n",
393
+ "Fsvrg.initial = initial\n",
394
+ "step_size = 1./(Fsvrg.L)\n",
395
+ "\n",
396
+ "G = BM3DFunction(sigma_no_skip)\n",
397
+ "cb_time = StoppingCriterionTime(time_stop)\n",
398
+ "prox_svrg = ISTA(initial = initial, f = Fsvrg, step_size = step_size, g=G, \n",
399
+ " update_objective_interval = 1, \n",
400
+ " max_iteration = max_iteration) \n",
401
+ "prox_svrg.run(verbose=0, callback=[cb_metrics, cb_time]) \n"
402
+ ]
403
+ },
404
+ {
405
+ "cell_type": "code",
406
+ "execution_count": null,
407
+ "id": "b1565f4c-162c-41bd-919a-bfd61a4a57cb",
408
+ "metadata": {},
409
+ "outputs": [],
410
+ "source": [
411
+ "selection = RandomSampling(len(list_func), nsub, seed=42)\n",
412
+ "Fsgd = SGFunction(list_func, selection = selection) \n",
413
+ "step_size = 1./(Fsgd.L)\n",
414
+ "\n",
415
+ "G = BM3DFunction(sigma_no_skip)\n",
416
+ "cb_time = StoppingCriterionTime(time_stop)\n",
417
+ "prox_sgd = ISTA(initial = initial, f = Fsgd, step_size = step_size, g=G, \n",
418
+ " update_objective_interval = 1, \n",
419
+ " max_iteration = max_iteration) \n",
420
+ "prox_sgd.run(verbose=0, callback=[cb_metrics, cb_time]) \n"
421
+ ]
422
+ },
423
+ {
424
+ "cell_type": "markdown",
425
+ "id": "e35518e7-70d8-4c7d-955b-f062c133371e",
426
+ "metadata": {},
427
+ "source": [
428
+ "### Run stochastic algorithms: Data splitting and skipping"
429
+ ]
430
+ },
431
+ {
432
+ "cell_type": "code",
433
+ "execution_count": null,
434
+ "id": "94d2a7e7-b0d3-4418-8d9d-fa35f506c5f3",
435
+ "metadata": {},
436
+ "outputs": [],
437
+ "source": [
438
+ "selection = RandomSampling(len(list_func), nsub, seed=42)\n",
439
+ "Fsvrg_skip = SVRGFunction(list_func, selection = selection, update_frequency=len(list_func))\n",
440
+ "Fsvrg_skip.initial = initial\n",
441
+ "step_size = 1./(Fsvrg_skip.L)\n",
442
+ "\n",
443
+ "G = BM3DFunction(sigma_skip)\n",
444
+ "cb_time = StoppingCriterionTime(time_stop)\n",
445
+ "prox_svrg_skip = ProxSkip(initial = [initial, initial], f = Fsvrg_skip, step_size = step_size, g=G, \n",
446
+ " update_objective_interval = 1, prob=prob, seed=42,\n",
447
+ " max_iteration = max_iteration) \n",
448
+ "prox_svrg_skip.run(verbose=0, callback=[cb_metrics, cb_time]) "
449
+ ]
450
+ },
451
+ {
452
+ "cell_type": "code",
453
+ "execution_count": null,
454
+ "id": "fbfc107d-a6f7-45b8-8974-c1e766148bc0",
455
+ "metadata": {},
456
+ "outputs": [],
457
+ "source": [
458
+ "selection = RandomSampling(len(list_func), nsub, seed=42)\n",
459
+ "Fsgd_skip = SGFunction(list_func, selection = selection)\n",
460
+ "step_size = 1./(Fsgd_skip.L)\n",
461
+ "\n",
462
+ "G = BM3DFunction(sigma_skip)\n",
463
+ "cb_time = StoppingCriterionTime(time_stop)\n",
464
+ "prox_sgd_skip = ProxSkip(initial = [initial, initial], f = Fsgd_skip, step_size = step_size, g=G, \n",
465
+ " update_objective_interval = 1, prob=prob, seed=42,\n",
466
+ " max_iteration = max_iteration) \n",
467
+ "prox_sgd_skip.run(verbose=0, callback=[cb_metrics, cb_time]) "
468
+ ]
469
+ },
470
+ {
471
+ "cell_type": "markdown",
472
+ "id": "48181800-6930-454c-bc6c-e4fb9d4e2d42",
473
+ "metadata": {},
474
+ "source": [
475
+ "### Plot PSNR/SSIM progress\n",
476
+ "- with respect to iteration\n",
477
+ "- with respect to time"
478
+ ]
479
+ },
480
+ {
481
+ "cell_type": "code",
482
+ "execution_count": null,
483
+ "id": "2a76c9c9-f023-419b-9f7c-fdc840c92301",
484
+ "metadata": {},
485
+ "outputs": [],
486
+ "source": [
487
+ "t_ista = np.cumsum(ista.timing)\n",
488
+ "t_fista = np.cumsum(fista.timing)\n",
489
+ "t_proxskip = np.cumsum(proxskip.timing)\n",
490
+ "t_prox_svrg = np.cumsum(prox_svrg.timing)\n",
491
+ "t_prox_svrg_skip = np.cumsum(prox_svrg_skip.timing)\n",
492
+ "t_prox_sgd_skip = np.cumsum(prox_sgd_skip.timing)"
493
+ ]
494
+ },
495
+ {
496
+ "cell_type": "code",
497
+ "execution_count": null,
498
+ "id": "ca6e8583-5487-414e-ba5d-3573ae290595",
499
+ "metadata": {},
500
+ "outputs": [],
501
+ "source": [
502
+ "import numpy as np\n",
503
+ "import matplotlib.pyplot as plt\n",
504
+ "plt.rcParams['lines.linewidth'] = 5\n",
505
+ "plt.rcParams['lines.markersize'] = 20\n",
506
+ "plt.rcParams['font.size'] = 40\n",
507
+ "\n",
508
+ "fig, ax = plt.subplots(2, 2, figsize=(35, 25), constrained_layout=True)\n",
509
+ "\n",
510
+ "ax[0,0].plot(ista.psnr[:-1], label=f\"ISTA, K={ista.iteration}, #BM3D={ista.iteration}\")\n",
511
+ "ax[0,0].plot(fista.psnr[:-1], label=f\"FISTA, K={fista.iteration}, #BM3D={fista.iteration}\")\n",
512
+ "ax[0,0].plot(proxskip.psnr[:-1], label=f\"ProxSkip, K={proxskip.iteration}, #BM3D={np.sum(proxskip.thetas)}\")\n",
513
+ "ax[0,0].plot(prox_svrg.psnr[:-1], label=f\"ProxSVRG, K={prox_svrg.iteration}, #BM3D={prox_svrg.iteration}\")\n",
514
+ "ax[0,0].plot(prox_sgd_skip.psnr[:-1],label=f\"ProxSGDSkip, K={prox_svrg_skip.iteration}, #BM3D={np.sum(prox_svrg_skip.thetas)}\", alpha=0.65)\n",
515
+ "ax[0,0].plot(prox_svrg_skip.psnr[:-1],label=f\"ProxSVRGSkip, K={prox_svrg_skip.iteration}, #BM3D={np.sum(prox_svrg_skip.thetas)}\")\n",
516
+ "ax[0,0].grid(which=\"major\")\n",
517
+ "ax[0,0].set_xlabel(\"Iteration\")\n",
518
+ "ax[0,0].set_ylabel(\"PSNR\")\n",
519
+ "\n",
520
+ "ax[0,1].plot(ista.ssim[:-1], label=f\"ISTA, K={ista.iteration}, #BM3D={ista.iteration}\")\n",
521
+ "ax[0,1].plot(fista.ssim[:-1], label=f\"FISTA, K={fista.iteration}, #BM3D={fista.iteration}\")\n",
522
+ "ax[0,1].plot(proxskip.ssim[:-1], label=f\"ProxSkip, p={prob}, #prox={np.sum(proxskip.thetas)}\")\n",
523
+ "ax[0,1].plot(prox_svrg.ssim[:-1], label=\"ProxSVRG\")\n",
524
+ "ax[0,1].semilogy(prox_sgd_skip.ssim[:-1],label=f\"ProxSGDSkip, K={prox_svrg_skip.iteration}, N={nsub}, p={prob}, #BM3D={np.sum(prox_svrg_skip.thetas)}\", alpha=0.65)\n",
525
+ "ax[0,1].plot(prox_svrg_skip.ssim[:-1],label=f\"ProxSVRGSkip, p={prob}, #prox={np.sum(prox_svrg_skip.thetas)}\")\n",
526
+ "ax[0,1].grid(which=\"major\")\n",
527
+ "ax[0,1].set_xlabel(\"Iteration\")\n",
528
+ "ax[0,1].set_ylabel(\"SSIM\")\n",
529
+ "\n",
530
+ "ax[1,0].plot(t_ista, ista.psnr[:-1], label=f\"ISTA, K={ista.iteration}, #BM3D={ista.iteration}\")\n",
531
+ "ax[1,0].plot(t_fista, fista.psnr[:-1], label=f\"FISTA, K={fista.iteration}, #BM3D={fista.iteration}\")\n",
532
+ "ax[1,0].plot(t_proxskip, proxskip.psnr[:-1], label=f\"ProxSkip, K={proxskip.iteration}, #BM3D={np.sum(proxskip.thetas)}\")\n",
533
+ "ax[1,0].plot(t_prox_svrg, prox_svrg.psnr[:-1], label=f\"ProxSVRG, K={prox_svrg.iteration}, #BM3D={prox_svrg.iteration}\")\n",
534
+ "ax[1,0].plot(t_prox_sgd_skip, prox_sgd_skip.psnr[:-1],\n",
535
+ " label=f\"ProxSGDSkip, K={prox_svrg_skip.iteration}, #BM3D={np.sum(prox_svrg_skip.thetas)}\", alpha=0.65)\n",
536
+ "ax[1,0].plot(t_prox_svrg_skip, prox_svrg_skip.psnr[:-1],\n",
537
+ " label=f\"ProxSVRGSkip, K={prox_svrg_skip.iteration}, #BM3D={np.sum(prox_svrg_skip.thetas)}\")\n",
538
+ "ax[1,0].grid(which=\"major\")\n",
539
+ "ax[1,0].set_xlabel(\"Time (sec)\")\n",
540
+ "ax[1,0].set_ylabel(\"PSNR\")\n",
541
+ "\n",
542
+ "ax[1,1].plot(t_ista, ista.ssim[:-1], label=f\"ISTA, K={ista.iteration}, #BM3D={ista.iteration}\")\n",
543
+ "ax[1,1].plot(t_fista, fista.ssim[:-1], label=f\"FISTA, K={fista.iteration}, #BM3D={fista.iteration}\")\n",
544
+ "ax[1,1].plot(t_proxskip, proxskip.ssim[:-1], label=f\"ProxSkip, p={prob}, #prox={np.sum(proxskip.thetas)}\")\n",
545
+ "ax[1,1].plot(t_prox_svrg, prox_svrg.ssim[:-1], label=\"ProxSVRG\")\n",
546
+ "ax[1,1].semilogy(t_prox_sgd_skip, prox_sgd_skip.ssim[:-1],\n",
547
+ " label=f\"ProxSGDSkip, K={prox_svrg_skip.iteration}, N={nsub}, p={prob}, #BM3D={np.sum(prox_svrg_skip.thetas)}\", alpha=0.65)\n",
548
+ "ax[1,1].plot(t_prox_svrg_skip, prox_svrg_skip.ssim[:-1],\n",
549
+ " label=f\"ProxSVRGSkip, p={prob}, #prox={np.sum(prox_svrg_skip.thetas)}\")\n",
550
+ "ax[1,1].grid(which=\"major\")\n",
551
+ "ax[1,1].set_xlabel(\"Time (sec)\")\n",
552
+ "ax[1,1].set_ylabel(\"SSIM\")\n",
553
+ "\n",
554
+ "handles, labels = ax[0,1].get_legend_handles_labels()\n",
555
+ "fig.legend(handles, labels, loc=\"upper center\", bbox_to_anchor=(0.53, 1.14),\n",
556
+ " ncols=2, frameon=True)\n",
557
+ "\n",
558
+ "\n",
559
+ "plt.show()\n"
560
+ ]
561
+ },
562
+ {
563
+ "cell_type": "code",
564
+ "execution_count": null,
565
+ "id": "af5e1f06-ec65-4944-a84a-1681462e321e",
566
+ "metadata": {},
567
+ "outputs": [],
568
+ "source": [
569
+ "import numpy as np\n",
570
+ "import matplotlib.pyplot as plt\n",
571
+ "from mpl_toolkits.axes_grid1.inset_locator import inset_axes\n",
572
+ "from matplotlib.patches import Rectangle\n",
573
+ "\n",
574
+ "imgs = [\n",
575
+ " x_cil.array,\n",
576
+ " fbp.array, \n",
577
+ " ista.solution.array,\n",
578
+ " fista.solution.array,\n",
579
+ " proxskip.solution.array,\n",
580
+ " prox_svrg.solution.array,\n",
581
+ " prox_sgd_skip.solution.array,\n",
582
+ " prox_svrg_skip.solution.array,\n",
583
+ "]\n",
584
+ "\n",
585
+ "labels = [\n",
586
+ " \"Ground-Truth\",\n",
587
+ " f\"FBP\",\n",
588
+ " f\"ISTA\\nPSNR/SSIM = {ista.psnr[-1]:.2f}/{ista.ssim[-1]:.3f}\",\n",
589
+ " f\"FISTA\\nPSNR/SSIM = {fista.psnr[-1]:.2f}/{fista.ssim[-1]:.3f}\",\n",
590
+ " f\"ProxSkip\\nPSNR/SSIM = {proxskip.psnr[-1]:.2f}/{proxskip.ssim[-1]:.3f}\",\n",
591
+ " f\"ProxSVRG\\nPSNR/SSIM = {prox_svrg.psnr[-1]:.2f}/{prox_svrg.ssim[-1]:.3f}\",\n",
592
+ " f\"ProxSGDSkip\\nPSNR/SSIM = {prox_sgd_skip.psnr[-1]:.2f}/{prox_sgd_skip.ssim[-1]:.3f}\",\n",
593
+ " f\"ProxSVRGSkip\\nPSNR/SSIM = {prox_svrg_skip.psnr[-1]:.2f}/{prox_svrg_skip.ssim[-1]:.3f}\",\n",
594
+ "]\n",
595
+ "\n",
596
+ "vmax = max(np.max(im) for im in imgs)\n",
597
+ "vmin = 0.01\n",
598
+ "vmax = 0.03\n",
599
+ "\n",
600
+ "h, w = imgs[0].shape[:2]\n",
601
+ "img_aspect = w / h\n",
602
+ "nrows, ncols = 2, 4\n",
603
+ "\n",
604
+ "fig_h = 15\n",
605
+ "fig_w = fig_h * (ncols / nrows) * img_aspect\n",
606
+ "\n",
607
+ "fig = plt.figure(figsize=(fig_w, fig_h))\n",
608
+ "gs = fig.add_gridspec(nrows, ncols, wspace=0, hspace=0)\n",
609
+ "axes = np.array([fig.add_subplot(gs[i, j]) for i in range(nrows) for j in range(ncols)])\n",
610
+ "fig.subplots_adjust(left=0, right=1, bottom=0, top=1)\n",
611
+ "\n",
612
+ "H, W = imgs[0].shape[:2]\n",
613
+ "roi_size = 80\n",
614
+ "cx, cy = 75, 120 \n",
615
+ "\n",
616
+ "x0 = cx - roi_size // 2\n",
617
+ "x1 = x0 + roi_size\n",
618
+ "y0 = cy - roi_size // 2\n",
619
+ "y1 = y0 + roi_size\n",
620
+ "\n",
621
+ "x0 = max(0, min(x0, W - roi_size)); x1 = x0 + roi_size\n",
622
+ "y0 = max(0, min(y0, H - roi_size)); y1 = y0 + roi_size\n",
623
+ "\n",
624
+ "inset_size = \"32%\"\n",
625
+ "inset_borderpad = 0.6\n",
626
+ "\n",
627
+ "for ax, im, txt in zip(axes, imgs, labels):\n",
628
+ " ax.imshow(im, cmap=\"inferno\", vmin=vmin, vmax=vmax)\n",
629
+ " ax.axis(\"off\")\n",
630
+ "\n",
631
+ " ax.text(\n",
632
+ " 0.02, 0.98, txt,\n",
633
+ " transform=ax.transAxes,\n",
634
+ " ha=\"left\", va=\"top\",\n",
635
+ " color=\"white\", fontsize=32,\n",
636
+ " bbox=dict(boxstyle=\"round,pad=0.3\", facecolor=\"black\", alpha=0.25, edgecolor=\"none\")\n",
637
+ " )\n",
638
+ "\n",
639
+ " rect = Rectangle((x0, y0), roi_size, roi_size,\n",
640
+ " linewidth=2.5, edgecolor=\"yellow\", facecolor=\"none\")\n",
641
+ " ax.add_patch(rect)\n",
642
+ "\n",
643
+ " axins = inset_axes(ax, width=inset_size, height=inset_size,\n",
644
+ " loc=\"lower right\", borderpad=inset_borderpad)\n",
645
+ " axins.imshow(im, cmap=\"inferno\", vmin=vmin, vmax=vmax)\n",
646
+ " axins.set_xlim(x0, x1)\n",
647
+ " axins.set_ylim(y1, y0) \n",
648
+ " axins.set_xticks([])\n",
649
+ " axins.set_yticks([])\n",
650
+ " for spine in axins.spines.values():\n",
651
+ " spine.set_linewidth(2)\n",
652
+ " spine.set_edgecolor(\"yellow\")\n",
653
+ "\n",
654
+ "plt.show()\n"
655
+ ]
656
+ },
657
+ {
658
+ "cell_type": "code",
659
+ "execution_count": null,
660
+ "id": "fa9861c4-aeaa-46db-92b7-3198ca2eaab3",
661
+ "metadata": {},
662
+ "outputs": [],
663
+ "source": []
664
+ }
665
+ ],
666
+ "metadata": {
667
+ "kernelspec": {
668
+ "display_name": "Python [conda env:ssp]",
669
+ "language": "python",
670
+ "name": "conda-env-ssp-py"
671
+ },
672
+ "language_info": {
673
+ "codemirror_mode": {
674
+ "name": "ipython",
675
+ "version": 3
676
+ },
677
+ "file_extension": ".py",
678
+ "mimetype": "text/x-python",
679
+ "name": "python",
680
+ "nbconvert_exporter": "python",
681
+ "pygments_lexer": "ipython3",
682
+ "version": "3.12.12"
683
+ }
684
+ },
685
+ "nbformat": 4,
686
+ "nbformat_minor": 5
687
+ }
ProxSkip.py ADDED
@@ -0,0 +1,124 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from cil.optimisation.algorithms import Algorithm
2
+ import numpy as np
3
+ import logging
4
+
5
+
6
class ProxSkip(Algorithm):

    r"""Proximal Skip (ProxSkip) algorithm.

    See "ProxSkip: Yes! Local Gradient Steps Provably Lead to Communication
    Acceleration! Finally!" (Mishchenko et al., 2022).

    Parameters
    ----------
    initial : list of DataContainer
        Two-element list ``[x0, h0]``: the initial iterate and the initial
        control variate ``h``.
    f : Function
        A smooth function with Lipschitz continuous gradient.
    g : Function
        A convex function with a "simple" proximal operator.
    step_size : positive :obj:`float`
        Step size for the ProxSkip algorithm, typically ``1./L`` where ``L``
        is the Lipschitz constant of the gradient of ``f``.
    prob : :obj:`float` in (0, 1]
        Probability of *applying* the proximal step at each iteration.
        If ``prob=1`` the proximal step is used in every iteration
        (recovering proximal gradient / ISTA).
    seed : int, optional
        Seed for the Bernoulli draws deciding whether to apply the prox.
    """

    def __init__(self, initial, f, g, step_size, prob, seed=None, **kwargs):
        """Set up the algorithm."""
        super(ProxSkip, self).__init__(**kwargs)

        self.f = f              # smooth function
        self.g = g              # proximable function
        self.step_size = step_size
        self.prob = prob
        self.rng = np.random.default_rng(seed)
        self.set_up(initial, f, g, step_size, prob, **kwargs)
        self.thetas = []        # record of prox-applied (1) / skipped (0) per iteration
        self.prox_iterates = []

    def set_up(self, initial, f, g, step_size, prob, **kwargs):
        """Allocate the internal state from the initial pair ``[x0, h0]``."""
        logging.info("{} setting up".format(self.__class__.__name__))

        self.initial = initial[0]

        self.x = initial[0].copy()
        self.xhat_new = initial[0].copy()
        self.x_new = initial[0].copy()
        self.ht = initial[1].copy()  # control variate h_0

        self.configured = True

        # count of iterations in which the proximal step was actually applied
        self.use_prox = 0

        logging.info("{} configured".format(self.__class__.__name__))

    def update(self):
        r"""Perform a single iteration of the ProxSkip algorithm."""

        # xhat_{k+1} = x_k - step_size * (grad f(x_k) - h_k)
        self.f.gradient(self.x, out=self.xhat_new)
        self.xhat_new -= self.ht
        self.x.sapyb(1., self.xhat_new, -self.step_size, out=self.xhat_new)

        # theta_k ~ Bernoulli(prob); theta_k == 1 means the prox IS applied
        theta = 1 if self.rng.random() < self.prob else 0
        if self.iteration == 0:
            # always regularise on the very first iteration
            theta = 1
        self.thetas.append(theta)

        if theta == 1:
            # Proximal step is used:
            # x_{k+1} = prox_{(gamma/p) g}(xhat_{k+1} - (gamma/p) h_k)
            self.g.proximal(self.xhat_new - (self.step_size / self.prob) * self.ht,
                            self.step_size / self.prob, out=self.x_new)
            # h_{k+1} = h_k + (p/gamma) * (x_{k+1} - xhat_{k+1})
            self.ht.sapyb(1., (self.x_new - self.xhat_new),
                          (self.prob / self.step_size), out=self.ht)
            self.use_prox += 1
        else:
            # Proximal step is skipped: x_{k+1} = xhat_{k+1} (h unchanged)
            self.x_new.fill(self.xhat_new)

    def _update_previous_solution(self):
        """Swap the references to the current and previous solution.

        BUG FIX: the original assigned ``self.x = self.x_new`` and then
        ``self.x = tmp`` with ``tmp`` also bound to ``self.x_new``, so
        ``x`` and ``x_new`` ended up aliasing the SAME container and
        subsequent in-place writes corrupted the iterate. A proper swap
        keeps the two buffers distinct.
        """
        tmp = self.x
        self.x = self.x_new
        self.x_new = tmp

    def get_output(self):
        """Return the current solution."""
        return self.x

    def update_objective(self):
        r"""Update the objective

        .. math:: f(x) + g(x)

        """
        self.loss.append(self.f(self.x) + self.g(self.x))

    def proximal_evaluations(self):
        """Return the per-iteration record of proximal applications (1) / skips (0).

        BUG FIX: the original iterated a nonexistent ``self.iterations`` and
        re-drew fresh random numbers from ``self.rng`` — perturbing the
        generator state and NOT reflecting the decisions actually taken.
        The real decisions are recorded in ``self.thetas`` during ``update``.
        """
        return list(self.thetas)
README.md CHANGED
@@ -1,10 +1,74 @@
1
  ---
2
- title: Split Skip And Play
3
- emoji: 🌍
4
- colorFrom: pink
5
- colorTo: purple
6
  sdk: docker
7
- pinned: false
8
  ---
9
 
10
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  ---
2
+ title: Split-Skip-and-Play
 
 
 
3
  sdk: docker
4
+ app_port: 7860
5
  ---
6
 
7
+
8
+ # Split-Skip-and-Play
9
+
10
+ Code to reproduce the results of the paper ["SPLIT, SKIP AND PLAY: VARIANCE-REDUCED PROXSKIP FOR TOMOGRAPHY
11
+ RECONSTRUCTION IS EXTREMELY FAST"](https://arxiv.org/abs/2602.09527) by Evangelos Papoutsellis, Zeljko Kereta, Kostas Papafitsoros. This is an extension of the paper ["Why do we regularise in every iteration for imaging inverse problems?"](https://arxiv.org/abs/2411.00688).
12
+
13
+ ![](PnP_BM3D_1min_reconstruction.gif)
14
+
15
+ [![Binder](https://mybinder.org/badge_logo.svg)](https://mybinder.org/v2/gh/epapoutsellis/Split-Skip-and-Play/HEAD)
16
+
17
+ ### Abstract
18
+ Many modern iterative solvers for large-scale tomographic reconstruction incur two major computational costs per iteration: expensive forward/adjoint projections to update the data fidelity term and costly proximal computations for the regulariser, often done via inner iterations. This paper studies for the first time the application of methods that couple randomised skipping of the proximal step with variance-reduced, subset-based optimisation of the data-fidelity term, to simultaneously reduce both costs in challenging tomographic reconstruction tasks. We provide a series of experiments using both synthetic and real data, demonstrating striking speed-ups of the order 5x--20x compared to the non-skipped counterparts, which have been so far the standard approach for efficiently solving these problems. Our work lays the groundwork for broader adoption of these methods in inverse problems.
19
+
20
+
21
+ ### Installation
22
+
23
+ We use the [Core Imaging Library (CIL)](https://github.com/TomographicImaging/CIL) with some additional [new functionalities](#Appendix). Code and installation tested on macOS (Apple M2 Pro), Linux, and Windows 10.
24
+
25
+ ```
26
+ conda create --name ssp -c conda-forge python=3.12 "numpy<2.0" cmake scipy six cython numba pillow jupyterlab scikit-learn dask "zarr<3" pywavelets astra-toolbox tqdm nb_conda_kernels
27
+
28
+ conda activate ssp
29
+ pip install bm3d xdesign scikit-image
30
+ pip install "setuptools<82" --upgrade
31
+
32
+ git clone https://github.com/epapoutsellis/StochasticCIL.git
33
+ cd StochasticCIL
34
+ git checkout svrg
35
+ git tag -a v1.0 -m "Version 1.0"
36
+ mkdir build
37
+ cd build
38
+ cmake ../ -DCONDA_BUILD=OFF -DCMAKE_BUILD_TYPE="Release" -DLIBRARY_LIB=$CONDA_PREFIX/lib -DLIBRARY_INC=$CONDA_PREFIX -DCMAKE_INSTALL_PREFIX=$CONDA_PREFIX
39
+ make install
40
+ ```
41
+ For windows: `cmake ../ -DCONDA_BUILD=OFF`, `cmake --build . --target install`
42
+
43
+ ### Citation
44
+
45
+ ```
46
+ @article{2602.09527,
47
+ Author = {Evangelos Papoutsellis and Zeljko Kereta and Kostas Papafitsoros},
48
+ Title = {Split, Skip and Play: Variance-Reduced ProxSkip for Tomography Reconstruction is Extremely Fast},
49
+ Year = {2026},
50
+ Eprint = {arXiv:2602.09527},
51
+ }
52
+ ```
53
+
54
+ ### New CIL Functionalities
55
+
56
+ 1) Splitting methods for Acquisition Data compatible with CIL and [SIRF](https://github.com/SyneRBI/SIRF).
57
+ 2) Sampling methods for CIL Stochastic Functions (used also by SPDHG).
58
+ 3) ApproximateGradientSumFunction (Base class for CIL Stochastic Functions)
59
+ 4) SGFunction
60
+ 5) SAGFunction
61
+ 6) SAGAFunction
62
+ 7) SVRGFunction
63
+ 8) LSVRGFunction
64
+ 9) PGA (Proximal Gradient Algorithm), base class for GD, ISTA, FISTA.  Designed for fixed and/or adaptive Preconditioners and step sizes.
65
+ 10) Preconditioner (base class). An instance of Preconditioner can be passed to GD, ISTA, FISTA.
66
+ 11) StepSizeMethod (base class for step size search), including standard Armijo/Backtracking.
67
+ 12) Callback utilities, including standard metrics (compared to a reference) and statistics of iterates. Any function (CIL), or callables from other libraries can be used.
68
+ 13) PD3O (Primal Dual Three Operator Splitting Algorithm) which can be combined with a Stochastic CIL function.
69
+ 14) SIRF Priors classes can be used in the CIL Optimisation Framework for free. In addition, SIRF ObjectiveFunction Classes can be used. This allows more flexibility for SIRF users to have control and monitor a CIL algorithm.
70
+
71
+ ### Acknowledgements
72
+
73
+ E. Papoutsellis acknowledges funding through the Innovate UK Analysis for Innovators (A4i) program: "Denoising of chemical imaging and tomography data" under which the experiments were initially conducted, CCPi (EPSRC grant EP/T026677/1), CCP SyneRBI (EPSRC grant EP/T026693/1).
74
+
TV_Tomography_Reconstruction_VR_Skip.ipynb ADDED
@@ -0,0 +1,556 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": null,
6
+ "id": "e82fbdab-f56e-471c-a314-63c97d08d197",
7
+ "metadata": {},
8
+ "outputs": [],
9
+ "source": [
10
+ "from cil.framework import AcquisitionGeometry\n",
11
+ "from cil.optimisation.algorithms import ISTA, FISTA\n",
12
+ "from cil.optimisation.functions import SVRGFunction, LSVRGFunction, LeastSquares\n",
13
+ "from cil.optimisation.utilities import RandomSampling, MetricsDiagnostics\n",
14
+ "from cil.plugins.astra import FBP, ProjectionOperator\n",
15
+ "from TotalVariation import TotalVariationNew\n",
16
+ "from ProxSkip import ProxSkip\n",
17
+ "from utils import create_circular_mask, StoppingCriterion\n",
18
+ "\n",
19
+ "import numpy as np\n",
20
+ "import zarr\n",
21
+ "import matplotlib.pyplot as plt\n",
22
+ "\n",
23
+ "plt.rcParams['lines.linewidth'] = 5\n",
24
+ "plt.rcParams['lines.markersize'] = 20\n",
25
+ "plt.rcParams['font.size'] = 40\n",
26
+ "plt.rcParams[\"image.cmap\"] = \"inferno\""
27
+ ]
28
+ },
29
+ {
30
+ "cell_type": "markdown",
31
+ "id": "e5e68083-d6ff-4cc8-9f4e-6e7fbcd38d43",
32
+ "metadata": {},
33
+ "source": [
34
+ "### Total Variation Reconstruction on a real dataset\n",
35
+ "\n",
36
+ "In this notebook, we play the following game. We start from a high-accuracy reference solution of\n",
37
+ "\n",
38
+ "$$ \n",
39
+ "\\begin{equation}\n",
40
+ "\\min_{x\\in\\mathbb{R}^n} \\; \\frac{1}{2}\\|Ax - b\\|_{2}^{2}\n",
41
+ "+ \\alpha \\mathrm{TV}(x)\n",
42
+ "+ \\mathbb{I}_{\\{x \\ge 0\\}}(x),\n",
43
+ "\\end{equation}\n",
44
+ "$$\n",
45
+ "\n",
46
+ "and we fix a target accuracy threshold $\\varepsilon>0$. The goal is to compare algorithms by **time-to-accuracy**: namely, determine which algorithm reaches $\\frac{\\|x_k - x^*\\|_2}{\\|x^*\\|_2}$ below $\\varepsilon$ in the **shortest runtime**.\n"
47
+ ]
48
+ },
49
+ {
50
+ "cell_type": "markdown",
51
+ "id": "1d21a834-36d9-4a3c-b036-c9fdf5c3dfbe",
52
+ "metadata": {},
53
+ "source": [
54
+ "### Load real dataset"
55
+ ]
56
+ },
57
+ {
58
+ "cell_type": "code",
59
+ "execution_count": null,
60
+ "id": "ee55e26e-95ac-48d0-91fa-8ffbe38825fc",
61
+ "metadata": {},
62
+ "outputs": [],
63
+ "source": [
64
+ "sino = zarr.load(\"data/NiPd_spent_01_microct_rings_removed_2D.zarr\")\n",
65
+ "\n",
66
+ "print(f\"sinogram shape is {sino.shape}\")\n",
67
+ "\n",
68
+ "_, horizontal = sino.shape\n",
69
+ "\n",
70
+ "angles_list = np.linspace(0, np.pi, 800)[::2]\n",
71
+ "ag2D = AcquisitionGeometry.create_Parallel2D().\\\n",
72
+ " set_panel(horizontal).\\\n",
73
+ " set_angles(angles_list, angle_unit=\"radian\").\\\n",
74
+ " set_labels(['angle','horizontal'])\n",
75
+ "ig2D = ag2D.get_ImageGeometry()\n",
76
+ "\n",
77
+ "data2D = ag2D.allocate()\n",
78
+ "data2D.fill(sino)"
79
+ ]
80
+ },
81
+ {
82
+ "cell_type": "markdown",
83
+ "id": "41b24ec4-f012-4fa1-a2de-95b162228de1",
84
+ "metadata": {},
85
+ "source": [
86
+ "### FBP reconstruction"
87
+ ]
88
+ },
89
+ {
90
+ "cell_type": "code",
91
+ "execution_count": null,
92
+ "id": "f350f391-d445-432e-a772-24810e3a6266",
93
+ "metadata": {},
94
+ "outputs": [],
95
+ "source": [
96
+ "fbp = FBP(ig2D, ag2D, device=\"cpu\")(data2D)\n",
97
+ "fbp.array[fbp.array<0] = 0\n",
98
+ "plt.imshow(fbp.array)"
99
+ ]
100
+ },
101
+ {
102
+ "cell_type": "markdown",
103
+ "id": "b48aee65-6bae-4520-ae24-35d14555adbf",
104
+ "metadata": {},
105
+ "source": [
106
+ "### Load high accuracy solution "
107
+ ]
108
+ },
109
+ {
110
+ "cell_type": "code",
111
+ "execution_count": null,
112
+ "id": "79240538-ca62-4f0b-b1e6-df38b8d48f4a",
113
+ "metadata": {},
114
+ "outputs": [],
115
+ "source": [
116
+ "alpha = 0.1"
117
+ ]
118
+ },
119
+ {
120
+ "cell_type": "code",
121
+ "execution_count": null,
122
+ "id": "7f251eeb-3298-46ab-a8f3-bb8782d9ffd9",
123
+ "metadata": {},
124
+ "outputs": [],
125
+ "source": [
126
+ "pdhg_optimal_info = zarr.open_group(\"data/pdhg_optimal_finden_tv_alpha_{}_explicit_precond_maxiterations_200000.zarr\".format(alpha))\n",
127
+ "\n",
128
+ "pdhg_optimal_np = pdhg_optimal_info[\"solution\"][:]\n",
129
+ "pdhg_optimal_cil = ig2D.allocate()\n",
130
+ "pdhg_optimal_cil.fill(pdhg_optimal_np)"
131
+ ]
132
+ },
133
+ {
134
+ "cell_type": "markdown",
135
+ "id": "44e0cb98-199d-46b1-9250-d876446a36f1",
136
+ "metadata": {},
137
+ "source": [
138
+ "### Family of considered algorithms (ProxSkip template)\n",
139
+ "\n",
140
+ "**Parameters:** $\\gamma>0$, probability $p\\in(0,1]$, data subsets $N$\n",
141
+ "**Initialize:** $x_0, h_0 \\in \\mathbb{R}^n$\n",
142
+ "\n",
143
+ "For $k=0,1,\\dots,K-1$:\n",
144
+ "\n",
145
+ "1. Compute $G_k$ (an unbiased estimator of $\\nabla f(x_k)$).\n",
146
+ "2. $$\\hat x_{k+1} = x_k - \\gamma\\big(G_k(x_k) - h_k\\big).$$\n",
147
+ "3. Sample $\\theta_k\\sim\\mathrm{Bernoulli}(p)$, $\\theta_k\\in{0,1}$.\n",
148
+ "4. If $\\theta_k=1$:\n",
149
+ " $$x_{k+1}=\\mathrm{prox}_{\\frac{\\gamma}{p}g}\\left(\\hat x_{k+1}-\\frac{\\gamma}{p}h_k\\right),$$\n",
150
+ " else:\n",
151
+ " $$x_{k+1}=\\hat x_{k+1}.$$\n",
152
+ "5. Update:\n",
153
+ " $$h_{k+1}=h_k+\\frac{p}{\\gamma}\\big(x_{k+1}-\\hat x_{k+1}\\big).$$\n",
154
+ "\n",
155
+ "---\n",
156
+ "\n",
157
+ "| $p=1$ | $0<p<1$ | $G_k$ |\n",
158
+ "| --------- | ------------- | -------------------------------------------------------------------- |\n",
159
+ "| ISTA | ProxSkip | $\\nabla f(x_k)$ |\n",
160
+ "| ProxSGD | ProxSGDSkip | $N\\nabla f_{i_k}(x_k)$ |\n",
161
+ "| ProxSAGA | ProxSAGASkip | $N(\\nabla f_{i_k}(x_k)-v_k^{,i_k})+\\bar v_k$ |\n",
162
+ "| ProxSVRG | ProxSVRGSkip | $N(\\nabla f_{i_k}(x_k)-\\nabla f_{i_k}(\\tilde x))+\\nabla f(\\tilde x)$ |\n",
163
+ "| ProxLSVRG | ProxLSVRGSkip | as above, updated with $p=1/N$ |\n",
164
+ "\n",
165
+ "*Note:* FISTA is ISTA with an acceleration step (Beck & Teboulle).\n"
166
+ ]
167
+ },
168
+ {
169
+ "cell_type": "markdown",
170
+ "id": "6079afd7-9f2b-43bf-9a9a-59a589c52b8e",
171
+ "metadata": {},
172
+ "source": [
173
+ "### Define Stopping Criteria and Metrics (PSNR/SSIM)"
174
+ ]
175
+ },
176
+ {
177
+ "cell_type": "code",
178
+ "execution_count": null,
179
+ "id": "baa0f235-7dab-430a-87c6-4b6d5a72f31a",
180
+ "metadata": {},
181
+ "outputs": [],
182
+ "source": [
183
+ "h, w = pdhg_optimal_cil.shape\n",
184
+ "mask = create_circular_mask(h, w, center=(170,165), radius=150)\n",
185
+ "\n",
186
+ "def NRSE(x, y, **kwargs):\n",
187
+ " return np.sqrt(np.sum(np.abs(x - y*mask.ravel())**2))/np.sqrt(np.sum(x**2))\n",
188
+ "\n",
189
+ "cb_metrics = MetricsDiagnostics(reference_image=pdhg_optimal_cil*mask, \n",
190
+ " metrics_dict={\"rse\":NRSE}) \n",
191
+ "\n",
192
+ "epsilon = 0.99e-5\n",
193
+ "inner_its = 10\n",
194
+ "epochs = 100\n",
195
+ "max_iteration = 5000\n",
196
+ "seed_skip = 42\n",
197
+ "seed_sampling = 42"
198
+ ]
199
+ },
200
+ {
201
+ "cell_type": "code",
202
+ "execution_count": null,
203
+ "id": "f4146831-c967-407d-a1f0-aa0ff7dca979",
204
+ "metadata": {},
205
+ "outputs": [],
206
+ "source": [
207
+ "### Avoid computing objectives, not needed for this demo\n",
208
+ "def update_objective(self):\n",
209
+ " return 0.\n",
210
+ "\n",
211
+ "ISTA.update_objective = update_objective\n",
212
+ "ProxSkip.update_objective = update_objective\n",
213
+ "FISTA.update_objective = update_objective"
214
+ ]
215
+ },
216
+ {
217
+ "cell_type": "markdown",
218
+ "id": "743c77b3-49a0-46c7-bf43-b322f78b1753",
219
+ "metadata": {},
220
+ "source": [
221
+ "### Run deterministic algorithms: No data splitting, no skipping"
222
+ ]
223
+ },
224
+ {
225
+ "cell_type": "code",
226
+ "execution_count": null,
227
+ "id": "86570298-273c-4ab5-a04b-96e70df402ac",
228
+ "metadata": {},
229
+ "outputs": [],
230
+ "source": [
231
+ "K = ProjectionOperator(ig2D, ag2D, device=\"cpu\")\n",
232
+ "F = LeastSquares(A=K, b=data2D, c=0.5)\n",
233
+ "G = alpha * TotalVariationNew(max_iteration = inner_its, tolerance=None, correlation='Space',\n",
234
+ " backend='c', lower=0, upper=np.infty, isotropic=True, \n",
235
+ " split=False, info=False, strong_convexity_constant=0,\n",
236
+ " warm_start=True)\n",
237
+ "initial = ig2D.allocate()"
238
+ ]
239
+ },
240
+ {
241
+ "cell_type": "markdown",
242
+ "id": "71eca541-0265-4af8-bbbf-c7b791bbe362",
243
+ "metadata": {},
244
+ "source": [
245
+ "### FISTA"
246
+ ]
247
+ },
248
+ {
249
+ "cell_type": "code",
250
+ "execution_count": null,
251
+ "id": "f5a3e483-7bdc-45da-b35b-c8147f1c6ce2",
252
+ "metadata": {},
253
+ "outputs": [],
254
+ "source": [
255
+ "step_size = 1./F.L\n",
256
+ "cb_error = StoppingCriterion(epsilon=epsilon) \n",
257
+ "fista = FISTA(initial = initial, f = F, step_size = step_size, g=G, \n",
258
+ " update_objective_interval = 1,\n",
259
+ " max_iteration = max_iteration) \n",
260
+ "fista.run(verbose=0, callback=[cb_metrics, cb_error])"
261
+ ]
262
+ },
263
+ {
264
+ "cell_type": "markdown",
265
+ "id": "47e90863-2b50-4d61-a9b3-e0918f3bf316",
266
+ "metadata": {},
267
+ "source": [
268
+ "### ISTA"
269
+ ]
270
+ },
271
+ {
272
+ "cell_type": "code",
273
+ "execution_count": null,
274
+ "id": "f11fc38d-cca0-4fad-bd5c-86f3e3402e41",
275
+ "metadata": {},
276
+ "outputs": [],
277
+ "source": [
278
+ "step_size = 1.99/F.L\n",
279
+ "cb_error = StoppingCriterion(epsilon=epsilon) \n",
280
+ "ista = ISTA(initial = initial, f = F, step_size = step_size, g=G, \n",
281
+ " update_objective_interval = 1,\n",
282
+ " max_iteration = max_iteration) \n",
283
+ "ista.run(verbose=0, callback=[cb_metrics, cb_error])"
284
+ ]
285
+ },
286
+ {
287
+ "cell_type": "markdown",
288
+ "id": "9dd6d3f8-7418-44b4-9a6a-f831f116cced",
289
+ "metadata": {},
290
+ "source": [
291
+ "### Run ProxSkip: Skip the regulariser"
292
+ ]
293
+ },
294
+ {
295
+ "cell_type": "code",
296
+ "execution_count": null,
297
+ "id": "72f458c9-6f9c-4a2d-8877-19a9dc5a638f",
298
+ "metadata": {},
299
+ "outputs": [],
300
+ "source": [
301
+ "prob = 0.1"
302
+ ]
303
+ },
304
+ {
305
+ "cell_type": "code",
306
+ "execution_count": null,
307
+ "id": "805fd84c-9f20-45ba-8188-957f044cee89",
308
+ "metadata": {},
309
+ "outputs": [],
310
+ "source": [
311
+ "step_size = 1.99/F.L\n",
312
+ "cb_error = StoppingCriterion(epsilon=epsilon) \n",
313
+ "proxskip = ProxSkip(initial = [initial, initial], f = F, step_size = step_size, g=G, \n",
314
+ " update_objective_interval = 1, prob=prob, seed = seed_skip,\n",
315
+ " max_iteration = max_iteration) \n",
316
+ "proxskip.run(verbose=0, callback=[cb_metrics, cb_error])\n"
317
+ ]
318
+ },
319
+ {
320
+ "cell_type": "markdown",
321
+ "id": "f6d2c3a3-ec14-4b9c-ad87-76e62f1161fe",
322
+ "metadata": {},
323
+ "source": [
324
+ "### Run stochastic algorithms: Data splitting, no skipping"
325
+ ]
326
+ },
327
+ {
328
+ "cell_type": "code",
329
+ "execution_count": null,
330
+ "id": "8fbc2e91-2a5d-4f50-9f60-c2ced3b41038",
331
+ "metadata": {},
332
+ "outputs": [],
333
+ "source": [
334
+ "def list_of_functions(data):\n",
335
+ " \n",
336
+ " list_funcs = []\n",
337
+ " ig = data[0].geometry.get_ImageGeometry()\n",
338
+ " \n",
339
+ " for d in data:\n",
340
+ " ageom_subset = d.geometry \n",
341
+ " Ai = ProjectionOperator(ig, ageom_subset, device = 'cpu') \n",
342
+ " fi = LeastSquares(Ai, b = d, c = 0.5)\n",
343
+ " list_funcs.append(fi) \n",
344
+ " \n",
345
+ " return list_funcs"
346
+ ]
347
+ },
348
+ {
349
+ "cell_type": "code",
350
+ "execution_count": null,
351
+ "id": "2364fecd-aab6-4653-8ba2-cf855a9ebda9",
352
+ "metadata": {},
353
+ "outputs": [],
354
+ "source": [
355
+ "nsub = 200\n",
356
+ "data_split, method = data2D.split_to_subsets(nsub, method= \"ordered\", info=True)\n",
357
+ "list_func = list_of_functions(data_split) "
358
+ ]
359
+ },
360
+ {
361
+ "cell_type": "code",
362
+ "execution_count": null,
363
+ "id": "a25d5716-cba8-4c78-9327-043d351ed2b1",
364
+ "metadata": {},
365
+ "outputs": [],
366
+ "source": [
367
+ "selection = RandomSampling(len(list_func), nsub, seed=seed_sampling)\n",
368
+ "Fsvrg = SVRGFunction(list_func, selection = selection, update_frequency=len(list_func))\n",
369
+ "Fsvrg.initial = initial\n",
370
+ "step_size = 1./(Fsvrg.L)\n",
371
+ "cb_error = StoppingCriterion(epsilon=epsilon, epochs = 100) \n",
372
+ "prox_svrg = ISTA(initial = initial, f = Fsvrg, step_size = step_size, g=G, \n",
373
+ " update_objective_interval = 1, \n",
374
+ " max_iteration = max_iteration) \n",
375
+ "prox_svrg.run(verbose=0, callback=[cb_metrics, cb_error])"
376
+ ]
377
+ },
378
+ {
379
+ "cell_type": "code",
380
+ "execution_count": null,
381
+ "id": "1dce5d0d-4873-40f9-8339-d35fd5eac23f",
382
+ "metadata": {},
383
+ "outputs": [],
384
+ "source": [
385
+ "selection = RandomSampling(len(list_func), nsub, seed=seed_sampling)\n",
386
+ "Flsvrg = LSVRGFunction(list_func, selection = selection, update_prob=1./len(list_func))\n",
387
+ "Flsvrg.initial = initial\n",
388
+ "step_size = 1./(Fsvrg.L)\n",
389
+ "cb_error = StoppingCriterion(epsilon=epsilon, epochs = 100) \n",
390
+ "prox_lsvrg = ISTA(initial = initial, f = Flsvrg, step_size = step_size, g=G, \n",
391
+ " update_objective_interval = 1, \n",
392
+ " max_iteration = max_iteration) \n",
393
+ "prox_lsvrg.run(verbose=0, callback=[cb_metrics, cb_error])"
394
+ ]
395
+ },
396
+ {
397
+ "cell_type": "markdown",
398
+ "id": "59988152-5f09-4950-a770-a8b4d435ee44",
399
+ "metadata": {},
400
+ "source": [
401
+ "### Run stochastic algorithms: Data splitting and skipping"
402
+ ]
403
+ },
404
+ {
405
+ "cell_type": "code",
406
+ "execution_count": null,
407
+ "id": "4d74a1ae-4f52-4e04-a4a4-03b4488f5fef",
408
+ "metadata": {},
409
+ "outputs": [],
410
+ "source": [
411
+ "selection = RandomSampling(len(list_func), nsub, seed=seed_sampling)\n",
412
+ "Fsvrg_skip = SVRGFunction(list_func, selection = selection, update_frequency=len(list_func))\n",
413
+ "Fsvrg_skip.initial = initial\n",
414
+ "step_size = 1./(Fsvrg_skip.L)\n",
415
+ "cb_error = StoppingCriterion(epsilon=epsilon, epochs = 100) \n",
416
+ "prox_svrg_skip = ProxSkip(initial = [initial, initial], f = Fsvrg_skip, step_size = step_size, g=G, \n",
417
+ " update_objective_interval = 1, prob = prob,\n",
418
+ " max_iteration = max_iteration) \n",
419
+ "prox_svrg_skip.run(verbose=0, callback=[cb_metrics, cb_error])"
420
+ ]
421
+ },
422
+ {
423
+ "cell_type": "code",
424
+ "execution_count": null,
425
+ "id": "2244d1db-1338-46ca-8355-bfcc03ce7e56",
426
+ "metadata": {},
427
+ "outputs": [],
428
+ "source": [
429
+ "selection = RandomSampling(len(list_func), nsub, seed=seed_sampling)\n",
430
+ "Flsvrg_skip = LSVRGFunction(list_func, selection = selection, update_prob=1./len(list_func))\n",
431
+ "Flsvrg_skip.initial = initial\n",
432
+ "step_size = 1./(Flsvrg_skip.L)\n",
433
+ "cb_error = StoppingCriterion(epsilon=epsilon, epochs = 100) \n",
434
+ "prox_lsvrg_skip = ProxSkip(initial = [initial, initial], f = Flsvrg_skip, step_size = step_size, g=G, \n",
435
+ " update_objective_interval = 1, prob = prob,\n",
436
+ " max_iteration = max_iteration) \n",
437
+ "prox_lsvrg_skip.run(verbose=0, callback=[cb_metrics, cb_error])"
438
+ ]
439
+ },
440
+ {
441
+ "cell_type": "markdown",
442
+ "id": "f0d1a611-d16c-432e-a6e5-56cae3b7e640",
443
+ "metadata": {},
444
+ "source": [
445
+ "### Plot PSNR/SSIM progress\n",
446
+ "- with respect to iteration\n",
447
+ "- with respect to time"
448
+ ]
449
+ },
450
+ {
451
+ "cell_type": "code",
452
+ "execution_count": null,
453
+ "id": "ee4475aa-9018-427e-82c4-63d0f487e899",
454
+ "metadata": {},
455
+ "outputs": [],
456
+ "source": [
457
+ "def find_iter_to_error(algo, error):\n",
458
+ " rse = np.asarray(algo.rse)\n",
459
+ " idx = np.where(rse < error)[0]\n",
460
+ " return np.nan if idx.size == 0 else int(idx[0])\n",
461
+ "\n",
462
+ "def time_to_iter(algo, it):\n",
463
+ " \"\"\"Time up to and including iteration it.\"\"\"\n",
464
+ " if it is None or (isinstance(it, float) and np.isnan(it)):\n",
465
+ " return np.nan\n",
466
+ " it = int(it)\n",
467
+ " return float(np.sum(np.asarray(algo.timing)[:it+1]))\n",
468
+ "\n",
469
+ "def time_to_error(algo, error):\n",
470
+ " it = find_iter_to_error(algo, error)\n",
471
+ " return time_to_iter(algo, it)"
472
+ ]
473
+ },
474
+ {
475
+ "cell_type": "code",
476
+ "execution_count": null,
477
+ "id": "a74d852a-744f-46ee-927b-2152c130bd6d",
478
+ "metadata": {},
479
+ "outputs": [],
480
+ "source": [
481
+ "t_ista = time_to_error(ista, epsilon)\n",
482
+ "t_proxskip = time_to_error(proxskip, epsilon)\n",
483
+ "t_prox_svrg =time_to_error(prox_svrg, epsilon)\n",
484
+ "t_prox_lsvrg = time_to_error(prox_lsvrg, epsilon)\n",
485
+ "t_prox_svrg_skip = time_to_error(prox_svrg_skip, epsilon)\n",
486
+ "t_prox_lsvrg_skip = time_to_error(prox_lsvrg_skip, epsilon)"
487
+ ]
488
+ },
489
+ {
490
+ "cell_type": "code",
491
+ "execution_count": null,
492
+ "id": "f28efd88-80de-494b-a991-207df0444ad8",
493
+ "metadata": {},
494
+ "outputs": [],
495
+ "source": [
496
+ "fig, ax = plt.subplots(2,1,figsize=(30, 25))\n",
497
+ "\n",
498
+ "fig.subplots_adjust(top=0.80)\n",
499
+ "\n",
500
+ "ax[0].semilogy(prox_svrg.rse[:-1],label=f\"ProxSVRG (N=200)\")\n",
501
+ "ax[0].semilogy(prox_lsvrg.rse[:-1],label=f\"ProxLSVRG (N=200)\")\n",
502
+ "ax[0].semilogy(proxskip.rse[:-1],label=f\"ProxSkip (p=0.1)\")\n",
503
+ "ax[0].semilogy(prox_svrg_skip.rse[:-1],label=f\"ProxSVRGSkip (N=200, p=0.05)\")\n",
504
+ "ax[0].semilogy(prox_lsvrg_skip.rse[:-1],label=f\"ProxLSVRGSkip (N=200, p=0.05)\")\n",
505
+ "ax[0].semilogy(ista.rse[:-1],label=f\"ISTA\")\n",
506
+ "ax[0].semilogy(fista.rse[:-1],label=f\"FISTA\")\n",
507
+ "ax[0].set_xlabel(\"Iteration\")\n",
508
+ "ax[0].set_ylabel(r\"$\\frac{\\|x_{k} - x^{*}\\|_{2}}{\\|x^{*}\\|_{2}}$\")\n",
509
+ "ax[0].grid(True, which=\"major\")\n",
510
+ "\n",
511
+ "ax[0].text(\n",
512
+ " 0.25, 0.9, \"Proximal-TV with 10 iterations\",\n",
513
+ " transform=ax[0].transAxes,\n",
514
+ " ha=\"left\", va=\"top\",\n",
515
+ " fontsize=45,\n",
516
+ " bbox=dict(boxstyle=\"round,pad=0.35\", facecolor=\"white\", edgecolor=\"none\")\n",
517
+ ")\n",
518
+ "ax[0].legend(ncols=2, loc=\"lower center\", bbox_to_anchor=(0.5, 1.00), frameon=True)\n",
519
+ "\n",
520
+ "ax[1].semilogy(np.cumsum(prox_svrg.timing), prox_svrg.rse[:-1],label=f\"ProxSVRG (N=200)\")\n",
521
+ "ax[1].semilogy(np.cumsum(prox_lsvrg.timing), prox_lsvrg.rse[:-1],label=f\"ProxLSVRG (N=200)\")\n",
522
+ "ax[1].semilogy(np.cumsum(proxskip.timing), proxskip.rse[:-1],label=f\"ProxSkip (p=0.1)\")\n",
523
+ "ax[1].semilogy(np.cumsum(prox_svrg_skip.timing), prox_svrg_skip.rse[:-1],label=f\"ProxSVRGSkip (N=200, p=0.05)\")\n",
524
+ "ax[1].semilogy(np.cumsum(prox_lsvrg_skip.timing), prox_lsvrg_skip.rse[:-1],label=f\"ProxLSVRGSkip (N=200, p=0.05)\")\n",
525
+ "ax[1].semilogy(np.cumsum(ista.timing), ista.rse[:-1],label=f\"ISTA\")\n",
526
+ "ax[1].semilogy(np.cumsum(fista.timing), fista.rse[:-1],label=f\"FISTA\")\n",
527
+ "ax[1].set_xlabel(\"Time (sec)\")\n",
528
+ "ax[1].set_ylabel(r\"$\\frac{\\|x_{k} - x^{*}\\|_{2}}{\\|x^{*}\\|_{2}}$\")\n",
529
+ "ax[1].grid(True, which=\"major\")\n",
530
+ "\n",
531
+ "plt.show()"
532
+ ]
533
+ }
534
+ ],
535
+ "metadata": {
536
+ "kernelspec": {
537
+ "display_name": "Python [conda env:ssp]",
538
+ "language": "python",
539
+ "name": "conda-env-ssp-py"
540
+ },
541
+ "language_info": {
542
+ "codemirror_mode": {
543
+ "name": "ipython",
544
+ "version": 3
545
+ },
546
+ "file_extension": ".py",
547
+ "mimetype": "text/x-python",
548
+ "name": "python",
549
+ "nbconvert_exporter": "python",
550
+ "pygments_lexer": "ipython3",
551
+ "version": "3.12.12"
552
+ }
553
+ },
554
+ "nbformat": 4,
555
+ "nbformat_minor": 5
556
+ }
TotalVariation.py ADDED
@@ -0,0 +1,399 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright 2020 United Kingdom Research and Innovation
3
+ # Copyright 2020 The University of Manchester
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ #
17
+ # Authors:
18
+ # CIL Developers, listed at: https://github.com/TomographicImaging/CIL/blob/master/NOTICE.txt
19
+ # Claire Delplancke (University of Bath)
20
+
21
+
22
+ from cil.optimisation.functions import Function, IndicatorBox, MixedL21Norm, MixedL11Norm
23
+ from cil.optimisation.operators import GradientOperator
24
+ import numpy as np
25
+ from numbers import Number
26
+ import warnings
27
+ import logging
28
+
29
+
30
+ class TotalVariationNew(Function):
31
+
32
+ r""" Total variation Function
33
+
34
+ .. math:: \mathrm{TV}(u) := \|\nabla u\|_{2,1} = \sum \|\nabla u\|_{2},\, (\mbox{isotropic})
35
+
36
+ .. math:: \mathrm{TV}(u) := \|\nabla u\|_{1,1} = \sum \|\nabla u\|_{1}\, (\mbox{anisotropic})
37
+
38
+ Notes
39
+ -----
40
+
41
+ The :code:`TotalVariation` (TV) :code:`Function` acts as a composite function, i.e.,
42
+ the composition of the :class:`.MixedL21Norm` function and the :class:`.GradientOperator` operator,
43
+
44
+ .. math:: f(u) = \|u\|_{2,1}, \Rightarrow (f\circ\nabla)(u) = f(\nabla x) = \mathrm{TV}(u)
45
+
46
+ In that case, the proximal operator of TV does not have an exact solution and we use an iterative
47
+ algorithm to solve:
48
+
49
+ .. math:: \mathrm{prox}_{\tau \mathrm{TV}}(b) := \underset{u}{\mathrm{argmin}} \frac{1}{2\tau}\|u - b\|^{2} + \mathrm{TV}(u)
50
+
51
+ The algorithm used for the proximal operator of TV is the Fast Gradient Projection algorithm (or FISTA)
52
+ applied to the _dual problem_ of the above problem, see :cite:`BeckTeboulle_b`, :cite:`BeckTeboulle_a`, :cite:`Zhu2010`.
53
+
54
+ See also "Multicontrast MRI Reconstruction with Structure-Guided Total Variation", Ehrhardt, Betcke, 2016.
55
+
56
+
57
+ Parameters
58
+ ----------
59
+
60
+ max_iteration : :obj:`int`, default = 5
61
+ Maximum number of iterations for the FGP algorithm to solve to solve the dual problem
62
+ of the Total Variation Denoising problem (ROF). If warm_start=False, this should be around 100,
63
+ or larger, with a set tolerance.
64
+ tolerance : :obj:`float`, default = None
65
+ Stopping criterion for the FGP algorithm used to to solve the dual problem
66
+ of the Total Variation Denoising problem (ROF). If the difference between iterates in the FGP algorithm is less than the tolerance
67
+ the iterations end before the max_iteration number.
68
+
69
+ .. math:: \|x^{k+1} - x^{k}\|_{2} < \mathrm{tolerance}
70
+
71
+ correlation : :obj:`str`, default = `Space`
72
+ Correlation between `Space` and/or `SpaceChannels` for the :class:`.GradientOperator`.
73
+ backend : :obj:`str`, default = `c`
74
+ Backend to compute the :class:`.GradientOperator`
75
+ lower : :obj:`'float`, default = None
76
+ A constraint is enforced using the :class:`.IndicatorBox` function, e.g., :code:`IndicatorBox(lower, upper)`.
77
+ upper : :obj:`'float`, default = None
78
+ A constraint is enforced using the :class:`.IndicatorBox` function, e.g., :code:`IndicatorBox(lower, upper)`.
79
+ isotropic : :obj:`boolean`, default = True
80
+ Use either isotropic or anisotropic definition of TV.
81
+
82
+ .. math:: |x|_{2} = \sqrt{x_{1}^{2} + x_{2}^{2}},\, (\mbox{isotropic})
83
+
84
+ .. math:: |x|_{1} = |x_{1}| + |x_{2}|\, (\mbox{anisotropic})
85
+
86
+ split : :obj:`boolean`, default = False
87
+ Splits the Gradient into spatial gradient and spectral or temporal gradient for multichannel data.
88
+
89
+ info : :obj:`boolean`, default = False
90
+ Information is printed for the stopping criterion of the FGP algorithm used to solve the dual problem
91
+ of the Total Variation Denoising problem (ROF).
92
+
93
+ strong_convexity_constant : :obj:`float`, default = 0
94
+ A strongly convex term weighted by the :code:`strong_convexity_constant` (:math:`\gamma`) parameter is added to the Total variation.
95
+ Now the :code:`TotalVariation` function is :math:`\gamma` - strongly convex and the proximal operator is
96
+
97
+ .. math:: \underset{u}{\mathrm{argmin}} \frac{1}{2\tau}\|u - b\|^{2} + \mathrm{TV}(u) + \frac{\gamma}{2}\|u\|^{2} \Leftrightarrow
98
+
99
+ .. math:: \underset{u}{\mathrm{argmin}} \frac{1}{2\frac{\tau}{1+\gamma\tau}}\|u - \frac{b}{1+\gamma\tau}\|^{2} + \mathrm{TV}(u)
100
+
101
+ warm_start : :obj:`boolean`, default = True
102
+ If set to true, the FGP algorithm used to solve the dual problem of the Total Variation Denoising problem (ROF) is initiated by the final value from the previous iteration and not at zero.
103
+ This allows the max_iteration value to be reduced to 5-10 iterations.
104
+
105
+
106
+ Note
107
+ ----
108
+
109
+ With warm_start set to the default, True, the TV function will keep in memory the range of the gradient of the image to be denoised, i.e. N times the dimensionality of the image. This increases the memory requirements.
110
+ However, during the evaluation of `proximal` the memory requirements will be unchanged as the same amount of memory will need to be allocated and deallocated.
111
+
112
+ Note
113
+ ----
114
+
115
+ In the case where the Total variation becomes a :math:`\gamma` - strongly convex function, i.e.,
116
+
117
+ .. math:: \mathrm{TV}(u) + \frac{\gamma}{2}\|u\|^{2}
118
+
119
+ :math:`\gamma` should be relatively small, so as the second term above will not act as an additional regulariser.
120
+ For more information, see :cite:`Rasch2020`, :cite:`CP2011`.
121
+
122
+
123
+
124
+
125
+ Examples
126
+ --------
127
+
128
+ .. math:: \underset{u}{\mathrm{argmin}} \frac{1}{2}\|u - b\|^{2} + \alpha\|\nabla u\|_{2,1}
129
+
130
+ >>> alpha = 2.0
131
+ >>> TV = TotalVariation()
132
+ >>> sol = TV.proximal(b, tau = alpha)
133
+
134
+ Examples
135
+ --------
136
+
137
+ .. math:: \underset{u}{\mathrm{argmin}} \frac{1}{2}\|u - b\|^{2} + \alpha\|\nabla u\|_{1,1} + \mathbb{I}_{C}(u)
138
+
139
+ where :math:`C = \{1.0\leq u\leq 2.0\}`.
140
+
141
+ >>> alpha = 2.0
142
+ >>> TV = TotalVariation(isotropic=False, lower=1.0, upper=2.0)
143
+ >>> sol = TV.proximal(b, tau = alpha)
144
+
145
+
146
+ Examples
147
+ --------
148
+
149
+ .. math:: \underset{u}{\mathrm{argmin}} \frac{1}{2}\|u - b\|^{2} + (\alpha\|\nabla u\|_{2,1} + \frac{\gamma}{2}\|u\|^{2})
150
+
151
+ >>> alpha = 2.0
152
+ >>> gamma = 1e-3
153
+ >>> TV = alpha * TotalVariation(isotropic=False, strong_convexity_constant=gamma)
154
+ >>> sol = TV.proximal(b, tau = 1.0)
155
+
156
+ """
157
+
158
+ def __init__(self,
159
+ max_iteration=10,
160
+ tolerance=None,
161
+ correlation="Space",
162
+ backend="c",
163
+ lower=None,
164
+ upper=None,
165
+ isotropic=True,
166
+ split=False,
167
+ info=False,
168
+ strong_convexity_constant=0,
169
+ warm_start=True):
170
+
171
+ super(TotalVariationNew, self).__init__(L=None)
172
+
173
+ # Regularising parameter = alpha
174
+ self.regularisation_parameter = 1.
175
+
176
+ self.iterations = max_iteration
177
+
178
+ self.tolerance = tolerance
179
+
180
+ # Total variation correlation (isotropic=Default)
181
+ self.isotropic = isotropic
182
+
183
+ # correlation space or spacechannels
184
+ self.correlation = correlation
185
+ self.backend = backend
186
+
187
+ # Define orthogonal projection onto the convex set C
188
+ if lower is None:
189
+ lower = -np.inf
190
+ if upper is None:
191
+ upper = np.inf
192
+ self.lower = lower
193
+ self.upper = upper
194
+ self.projection_C = IndicatorBox(lower, upper).proximal
195
+
196
+ # Setup GradientOperator as None. This is to avoid domain argument in the __init__
197
+ self._gradient = None
198
+ self._domain = None
199
+
200
+ self.info = info
201
+ if self.info:
202
+ warnings.warn(" `info` is deprecate. Please use logging instead.")
203
+
204
+ # splitting Gradient
205
+ self.split = split
206
+
207
+ # For the warm_start functionality
208
+ self.warm_start = warm_start
209
+ self._p2 = None
210
+
211
+ # Strong convexity for TV
212
+ self.strong_convexity_constant = strong_convexity_constant
213
+
214
+ # Define Total variation norm
215
+ if self.isotropic:
216
+ self.func = MixedL21Norm()
217
+ else:
218
+ self.func = MixedL11Norm()
219
+
220
+ def _get_p2(self):
221
+ r"""The initial value for the dual in the proximal calculation - allocated to zero in the case of warm_start=False
222
+ or initialised as the last iterate seen in the proximal calculation in the case warm_start=True ."""
223
+
224
+ if self._p2 is None:
225
+ return self.gradient.range_geometry().allocate(0)
226
+ else:
227
+ return self._p2
228
+
229
+ @property
230
+ def regularisation_parameter(self):
231
+ return self._regularisation_parameter
232
+
233
+ @regularisation_parameter.setter
234
+ def regularisation_parameter(self, value):
235
+ if not isinstance(value, Number):
236
+ raise TypeError(
237
+ "regularisation_parameter: expected a number, got {}".format(type(value)))
238
+ self._regularisation_parameter = value
239
+
240
+ def __call__(self, x):
241
+ r""" Returns the value of the TotalVariation function at :code:`x` ."""
242
+
243
+ try:
244
+ self._domain = x.geometry
245
+ except:
246
+ self._domain = x
247
+
248
+ # Compute Lipschitz constant provided that domain is not None.
249
+ # Lipschitz constant dependes on the GradientOperator, which is configured only if domain is not None
250
+ if self._L is None:
251
+ self.calculate_Lipschitz()
252
+
253
+ if self.strong_convexity_constant > 0:
254
+ strongly_convex_term = (
255
+ self.strong_convexity_constant/2)*x.squared_norm()
256
+ else:
257
+ strongly_convex_term = 0
258
+
259
+ return self.regularisation_parameter * self.func(self.gradient.direct(x)) + strongly_convex_term
260
+
261
+ def proximal(self, x, tau, out=None):
262
+ r""" Returns the proximal operator of the TotalVariation function at :code:`x` ."""
263
+
264
+ if self.strong_convexity_constant > 0:
265
+
266
+ strongly_convex_factor = (1 + tau * self.strong_convexity_constant)
267
+ x /= strongly_convex_factor
268
+ tau /= strongly_convex_factor
269
+
270
+ if out is None:
271
+ solution = self._fista_on_dual_rof(x, tau)
272
+ else:
273
+ self._fista_on_dual_rof(x, tau, out=out)
274
+
275
+ if self.strong_convexity_constant > 0:
276
+ x *= strongly_convex_factor
277
+ tau *= strongly_convex_factor
278
+
279
+ if out is None:
280
+ return solution
281
+
282
+ def _fista_on_dual_rof(self, x, tau, out=None):
283
+ r""" Runs the Fast Gradient Projection (FGP) algorithm to solve the dual problem
284
+ of the Total Variation Denoising problem (ROF).
285
+
286
+ .. math: \max_{\|y\|_{\infty}<=1.} \frac{1}{2}\|\nabla^{*} y + x \|^{2} - \frac{1}{2}\|x\|^{2}
287
+
288
+ """
289
+ try:
290
+ self._domain = x.geometry
291
+ except:
292
+ self._domain = x
293
+
294
+ # Compute Lipschitz constant provided that domain is not None.
295
+ # Lipschitz constant depends on the GradientOperator, which is configured only if domain is not None
296
+ if self._L is None:
297
+ self.calculate_Lipschitz()
298
+
299
+ # initialise
300
+ t = 1
301
+
302
+ # dual variable - its content is overwritten during iterations
303
+ p1 = self.gradient.range_geometry().allocate(None)
304
+ p2 = self._get_p2()
305
+ tmp_q = p2.copy()
306
+
307
+ # multiply tau by -1 * regularisation_parameter here so it's not recomputed every iteration
308
+ # when tau is an array this is done inplace so reverted at the end
309
+ if isinstance(tau, Number):
310
+ tau_reg_neg = -self.regularisation_parameter * tau
311
+ else:
312
+ tau_reg_neg = tau
313
+ tau.multiply(-self.regularisation_parameter, out=tau_reg_neg)
314
+
315
+ should_return = False
316
+ if out is None:
317
+ should_return = True
318
+ out = self.gradient.domain_geometry().allocate(0)
319
+
320
+ for k in range(self.iterations):
321
+
322
+ t0 = t
323
+ self.gradient.adjoint(tmp_q, out=out)
324
+ out.sapyb(tau_reg_neg, x, 1.0, out=out)
325
+ self.projection_C(out, tau=None, out=out)
326
+
327
+ self.gradient.direct(out, out=p1)
328
+
329
+ multip = (-self.L)/tau_reg_neg
330
+
331
+ tmp_q.sapyb(1., p1, multip, out=tmp_q)
332
+
333
+ if self.tolerance is not None and k % 5 == 0:
334
+ p1 *= multip
335
+ error = p1.norm()
336
+ error /= tmp_q.norm()
337
+ if error < self.tolerance:
338
+ break
339
+
340
+ self.func.proximal_conjugate(tmp_q, 1.0, out=p1)
341
+
342
+ t = (1 + np.sqrt(1 + 4 * t0 ** 2)) / 2
343
+ p1.subtract(p2, out=tmp_q)
344
+ tmp_q *= (t0-1)/t
345
+ tmp_q += p1
346
+
347
+ # switch p1 and p2 references
348
+ tmp = p1
349
+ p1 = p2
350
+ p2 = tmp
351
+ if self.warm_start:
352
+ self._p2 = p2
353
+
354
+ if self.info:
355
+ if self.tolerance is not None:
356
+ logging.info(
357
+ "Stop at {} iterations with tolerance {} .".format(k, error))
358
+ else:
359
+ logging.info("Stop at {} iterations.".format(k))
360
+
361
+ # return tau to its original state if it was modified
362
+ if id(tau_reg_neg) == id(tau):
363
+ tau_reg_neg.divide(-self.regularisation_parameter, out=tau)
364
+
365
+ if should_return:
366
+ return out
367
+
368
+ def convex_conjugate(self, x):
369
+ r""" Returns the value of convex conjugate of the TotalVariation function at :code:`x` ."""
370
+ return 0.0
371
+
372
+ def calculate_Lipschitz(self):
373
+ r""" Default value for the Lipschitz constant."""
374
+
375
+ # Compute the Lipschitz parameter from the operator if possible
376
+ # Leave it initialised to None otherwise
377
+ self._L = (1./self.gradient.norm())**2
378
+
379
+ @property
380
+ def gradient(self):
381
+ r""" GradientOperator is created if it is not instantiated yet. The domain of the `_gradient`,
382
+ is created in the `__call__` and `proximal` methods.
383
+
384
+ """
385
+ if self._domain is not None:
386
+ self._gradient = GradientOperator(
387
+ self._domain, correlation=self.correlation, backend=self.backend)
388
+ else:
389
+ raise ValueError(
390
+ " The domain of the TotalVariation is {}. Please use the __call__ or proximal methods first before calling gradient.".format(self._domain))
391
+
392
+ return self._gradient
393
+
394
+ def __rmul__(self, scalar):
395
+ if not isinstance(scalar, Number):
396
+ raise TypeError(
397
+ "scalar: Expected a number, got {}".format(type(scalar)))
398
+ self.regularisation_parameter *= scalar
399
+ return self
binder/.ipynb_checkpoints/environment-checkpoint.yml ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# NOTE(review): Jupyter .ipynb_checkpoints copy of binder/environment.yml
# (minus jupyterlab/ipykernel) -- editor residue; consider removing from VCS.
name: ssp
channels:
  - conda-forge
dependencies:
  - python=3.12
  - "numpy<2.0"
  - "zarr<3"

  - cmake
  - scipy
  - six
  - cython
  - numba
  - pillow
  - pywavelets
  - astra-toolbox
  - tqdm
  - "setuptools<82"
  - git

  # pip deps
  - pip
  - wheel
  - pip:
    - bm3d
    - xdesign
    - scikit-image
binder/.ipynb_checkpoints/postBuild-checkpoint ADDED
@@ -0,0 +1,46 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#!/bin/bash
# NOTE(review): Jupyter .ipynb_checkpoints copy of binder/postBuild --
# editor residue; consider removing from VCS.
# Binder post-build hook: register the kernel, then clone and build
# StochasticCIL (svrg branch) into the active environment prefix.
set -euxo pipefail

python -m ipykernel install --user --name ssp --display-name "Python (ssp)"

# Pick the active environment prefix in Binder
PREFIX="${CONDA_PREFIX:-${NB_PYTHON_PREFIX:-}}"
if [ -z "${PREFIX}" ]; then
    # Fallback: derive the prefix from the running interpreter's location.
    PREFIX="$(python -c 'import sys, os; print(os.path.dirname(os.path.dirname(sys.executable)))')"
fi
echo "Using PREFIX=${PREFIX}"

if [ ! -d "StochasticCIL" ]; then
    git clone https://github.com/epapoutsellis/StochasticCIL.git
fi

cd StochasticCIL
git fetch --all --tags
git checkout svrg

# Ensure annotated tag for `git describe`
git config user.email "binder@local"
git config user.name "Binder Build"

# Replace a lightweight v1.0 tag with an annotated one if needed.
if git rev-parse -q --verify refs/tags/v1.0 >/dev/null; then
    if [ "$(git cat-file -t v1.0)" != "tag" ]; then
        git tag -d v1.0
        git tag -a v1.0 -m "Version 1.0"
    fi
else
    git tag -a v1.0 -m "Version 1.0"
fi

mkdir -p build
cd build

# Out-of-source CMake build installing into the environment prefix.
cmake ../ -DCMAKE_POLICY_VERSION_MINIMUM=3.5 \
      -DCONDA_BUILD=OFF \
      -DCMAKE_BUILD_TYPE=Release \
      -DLIBRARY_LIB="${PREFIX}/lib" \
      -DLIBRARY_INC="${PREFIX}" \
      -DCMAKE_INSTALL_PREFIX="${PREFIX}" \
      -DPython_EXECUTABLE="${PREFIX}/bin/python"

make -j"$(nproc)"
make install
binder/environment.yml ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Conda environment for the Binder / Docker image of this Space.
# Also consumed by the Dockerfile (COPY binder/environment.yml).
name: ssp
channels:
  - conda-forge
dependencies:
  - python=3.12
  - "numpy<2.0"
  - "zarr<3"
  - jupyterlab
  - ipykernel
  - cmake
  - scipy
  - six
  - cython
  - numba
  - pillow
  - pywavelets
  - astra-toolbox
  - tqdm
  - "setuptools<82"
  - git

  # pip deps
  - pip
  - wheel
  - pip:
    - bm3d
    - xdesign
    - scikit-image
binder/postBuild ADDED
@@ -0,0 +1,46 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#!/bin/bash
# Binder post-build hook: register the kernel, then clone and build
# StochasticCIL (svrg branch) into the active environment prefix.
set -euxo pipefail

python -m ipykernel install --user --name ssp --display-name "Python (ssp)"

# Pick the active environment prefix in Binder
PREFIX="${CONDA_PREFIX:-${NB_PYTHON_PREFIX:-}}"
if [ -z "${PREFIX}" ]; then
    # Fallback: derive the prefix from the running interpreter's location.
    PREFIX="$(python -c 'import sys, os; print(os.path.dirname(os.path.dirname(sys.executable)))')"
fi
echo "Using PREFIX=${PREFIX}"

if [ ! -d "StochasticCIL" ]; then
    git clone https://github.com/epapoutsellis/StochasticCIL.git
fi

cd StochasticCIL
git fetch --all --tags
git checkout svrg

# Ensure annotated tag for `git describe`
git config user.email "binder@local"
git config user.name "Binder Build"

# Replace a lightweight v1.0 tag with an annotated one if needed.
if git rev-parse -q --verify refs/tags/v1.0 >/dev/null; then
    if [ "$(git cat-file -t v1.0)" != "tag" ]; then
        git tag -d v1.0
        git tag -a v1.0 -m "Version 1.0"
    fi
else
    git tag -a v1.0 -m "Version 1.0"
fi

mkdir -p build
cd build

# Out-of-source CMake build installing into the environment prefix.
cmake ../ -DCMAKE_POLICY_VERSION_MINIMUM=3.5 \
      -DCONDA_BUILD=OFF \
      -DCMAKE_BUILD_TYPE=Release \
      -DLIBRARY_LIB="${PREFIX}/lib" \
      -DLIBRARY_INC="${PREFIX}" \
      -DCMAKE_INSTALL_PREFIX="${PREFIX}" \
      -DPython_EXECUTABLE="${PREFIX}/bin/python"

make -j"$(nproc)"
make install
environment.yml ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Top-level conda environment (same as binder/environment.yml minus
# jupyterlab/ipykernel) -- NOTE(review): consider deduplicating with
# the binder copy to avoid drift.
name: ssp
channels:
  - conda-forge
dependencies:
  - python=3.12
  - "numpy<2.0"
  - "zarr<3"

  - cmake
  - scipy
  - six
  - cython
  - numba
  - pillow
  - pywavelets
  - astra-toolbox
  - tqdm
  - "setuptools<82"
  - git

  # pip deps
  - pip
  - wheel
  - pip:
    - bm3d
    - xdesign
    - scikit-image
start.sh ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#!/usr/bin/env bash
# Container entrypoint: launch JupyterLab from the `ssp` micromamba
# environment on port 7860 (the port EXPOSEd by the Dockerfile).
set -euo pipefail

# Optional login token; an empty JUPYTER_TOKEN disables token auth.
TOKEN="${JUPYTER_TOKEN:-}"

exec micromamba run -n ssp jupyter lab \
    --allow-root \
    --ip=0.0.0.0 --port=7860 \
    --no-browser \
    --ServerApp.allow_remote_access=True \
    --ServerApp.root_dir=/work \
    --IdentityProvider.token="${TOKEN}"
utils.py ADDED
@@ -0,0 +1,128 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from cil.optimisation.utilities import AlgorithmDiagnostics
2
+ import numpy as np
3
+ from bm3d import bm3d, BM3DStages
4
+ from cil.optimisation.functions import Function
5
+
6
class StoppingCriterionTime(AlgorithmDiagnostics):
    """Stop a CIL algorithm once its cumulative run time exceeds a budget."""

    def __init__(self, time):
        # Time budget, compared against the sum of per-iteration timings.
        self.time = time
        super(StoppingCriterionTime, self).__init__(verbose=0)
        self.should_stop = False

    def _should_stop(self):
        return self.should_stop

    def __call__(self, algo):
        # On the first callback, replace the algorithm's stopping test
        # with this criterion's flag.
        if algo.iteration == 0:
            algo.should_stop = self._should_stop

        elapsed = np.sum(algo.timing)
        if elapsed > self.time:
            self.should_stop = True
            print("Stop at {} time {}".format(algo.iteration, elapsed))
28
+
29
+
30
class BM3DFunction(Function):
    """
    PnP 'regulariser' whose proximal applies BM3D denoising.

    As usual in PnP-ISTA/FISTA, a FIXED BM3D sigma (regularization strength)
    is used, independent of the gradient step-size tau.
    Optionally a damped update (1-gamma) z + gamma * BM3D(z) is applied.
    """

    def __init__(self, sigma, gamma=1.0, profile="np",
                 stage_arg=BM3DStages.ALL_STAGES, positivity=True):
        self.sigma = float(sigma)   # BM3D noise parameter (std, image units)
        self.gamma = float(gamma)   # damping factor in (0, 1]
        if not (0.0 < self.gamma <= 1.0):
            raise ValueError("gamma must be in (0,1].")
        self.profile = profile
        self.stage_arg = stage_arg
        self.positivity = positivity
        super().__init__()

    def __call__(self, x):
        # A PnP prior has no explicit functional value.
        return 0.0

    def convex_conjugate(self, x):
        return 0.0

    def _denoise(self, znp: np.ndarray) -> np.ndarray:
        # BM3D expects sigma as noise std (same units as the image).
        img = np.asarray(znp, dtype=np.float32)
        denoised = bm3d(img, sigma_psd=self.sigma,
                        profile=self.profile,
                        stage_arg=self.stage_arg)
        return denoised.astype(np.float32)

    def proximal(self, x, tau, out=None):
        z = x.array.astype(np.float32, copy=False)
        denoised = self._denoise(z)

        # Damping/relaxation (recommended if oscillations appear).
        u = (1.0 - self.gamma) * z + self.gamma * denoised

        if self.positivity:
            u = np.maximum(u, 0.0)

        if out is None:
            out = x * 0.0
        out.fill(u)
        return out
77
+
78
+
79
def create_circular_mask(h, w, center=None, radius=None):
    """Return a boolean (h, w) array that is True inside a circle.

    Parameters
    ----------
    h, w : int
        Height and width of the mask in pixels.
    center : tuple, optional
        (x, y) centre of the circle; defaults to the image centre.
    radius : float, optional
        Circle radius; defaults to the largest radius fully inside the image.
    """
    if center is None:
        center = (int(w / 2), int(h / 2))
    if radius is None:
        radius = min(center[0], center[1], w - center[0], h - center[1])

    rows, cols = np.ogrid[:h, :w]
    dist_from_center = np.sqrt((cols - center[0]) ** 2 + (rows - center[1]) ** 2)

    return dist_from_center <= radius
91
+
92
+
93
class StoppingCriterion(AlgorithmDiagnostics):
    """Stop when the NRSE drops below ``epsilon`` or a data-pass budget is spent."""

    def __init__(self, epsilon, epochs=None):
        self.epsilon = epsilon      # NRSE accuracy target
        self.epochs = epochs        # optional budget of data passes (epochs)
        super().__init__(verbose=0)
        self.should_stop = False
        self.rse_reached = False    # True when the accuracy target caused the stop

    def _should_stop(self):
        return self.should_stop

    def __call__(self, algo):
        # On the first callback, replace the algorithm's stopping test
        # with this criterion's flag.
        if algo.iteration == 0:
            algo.should_stop = self._should_stop

        accuracy_hit = algo.rse[-1] <= self.epsilon

        budget_hit = False
        if self.epochs is not None:
            try:
                passes = algo.f.data_passes
                latest = passes[-1] if hasattr(passes, "__len__") else passes
                budget_hit = latest > self.epochs
            except AttributeError:
                budget_hit = False

        stop_now = accuracy_hit or budget_hit

        if algo.iteration < algo.max_iteration:
            if stop_now:
                self.rse_reached = accuracy_hit
                self.should_stop = True
                print(f"Accuracy reached at {algo.iteration}, time = {np.sum(algo.timing):.4f}, NRSE = {algo.rse[-1]:.4e}")
        else:
            print(f"Failed to reach accuracy. Stop at {algo.iteration}, time = {np.sum(algo.timing):.4f}, NRSE = {algo.rse[-1]:.4e}")