Spaces:
Running
Running
test
Browse files- .DS_Store +0 -0
- Dockerfile +40 -0
- LICENSE +21 -0
- PnP_BM3D_VR_Skip.ipynb +687 -0
- ProxSkip.py +124 -0
- README.md +70 -6
- TV_Tomography_Reconstruction_VR_Skip.ipynb +556 -0
- TotalVariation.py +399 -0
- binder/.ipynb_checkpoints/environment-checkpoint.yml +27 -0
- binder/.ipynb_checkpoints/postBuild-checkpoint +46 -0
- binder/environment.yml +28 -0
- binder/postBuild +46 -0
- environment.yml +27 -0
- start.sh +12 -0
- utils.py +128 -0
.DS_Store
ADDED
|
Binary file (6.15 kB). View file
|
|
|
Dockerfile
ADDED
|
@@ -0,0 +1,40 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
FROM mambaorg/micromamba:1.5.8
|
| 2 |
+
SHELL ["/bin/bash", "-lc"]
|
| 3 |
+
|
| 4 |
+
USER root
|
| 5 |
+
RUN mkdir -p /var/lib/apt/lists/partial && \
|
| 6 |
+
apt-get update && \
|
| 7 |
+
apt-get install -y --no-install-recommends git cmake make g++ && \
|
| 8 |
+
rm -rf /var/lib/apt/lists/*
|
| 9 |
+
|
| 10 |
+
WORKDIR /work
|
| 11 |
+
|
| 12 |
+
# Your conda env file is in binder/environment.yml inside the repo
|
| 13 |
+
COPY binder/environment.yml /tmp/environment.yml
|
| 14 |
+
RUN micromamba env create -f /tmp/environment.yml && micromamba clean -a -y
|
| 15 |
+
|
| 16 |
+
# Build StochasticCIL
|
| 17 |
+
RUN git clone https://github.com/epapoutsellis/StochasticCIL.git && \
|
| 18 |
+
cd StochasticCIL && \
|
| 19 |
+
git fetch --all --tags && \
|
| 20 |
+
git checkout svrg && \
|
| 21 |
+
git config user.email "hf@local" && \
|
| 22 |
+
git config user.name "HF Build" && \
|
| 23 |
+
(git tag -a v1.0 -m "Version 1.0" || true) && \
|
| 24 |
+
mkdir -p build && cd build && \
|
| 25 |
+
PREFIX="$(micromamba run -n ssp python -c 'import sys, os; print(os.path.dirname(os.path.dirname(sys.executable)))')" && \
|
| 26 |
+
cmake .. -DCMAKE_POLICY_VERSION_MINIMUM=3.5 \
|
| 27 |
+
-DCONDA_BUILD=OFF \
|
| 28 |
+
-DCMAKE_BUILD_TYPE=Release \
|
| 29 |
+
-DLIBRARY_LIB="${PREFIX}/lib" \
|
| 30 |
+
-DLIBRARY_INC="${PREFIX}" \
|
| 31 |
+
-DCMAKE_INSTALL_PREFIX="${PREFIX}" \
|
| 32 |
+
-DPython_EXECUTABLE="${PREFIX}/bin/python" && \
|
| 33 |
+
make -j"$(nproc)" && \
|
| 34 |
+
make install
|
| 35 |
+
|
| 36 |
+
COPY start.sh /start.sh
|
| 37 |
+
RUN chmod +x /start.sh
|
| 38 |
+
|
| 39 |
+
EXPOSE 7860
|
| 40 |
+
CMD ["/start.sh"]
|
LICENSE
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
MIT License
|
| 2 |
+
|
| 3 |
+
Copyright (c) 2026 Vaggelis Papoutsellis
|
| 4 |
+
|
| 5 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
| 6 |
+
of this software and associated documentation files (the "Software"), to deal
|
| 7 |
+
in the Software without restriction, including without limitation the rights
|
| 8 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
| 9 |
+
copies of the Software, and to permit persons to whom the Software is
|
| 10 |
+
furnished to do so, subject to the following conditions:
|
| 11 |
+
|
| 12 |
+
The above copyright notice and this permission notice shall be included in all
|
| 13 |
+
copies or substantial portions of the Software.
|
| 14 |
+
|
| 15 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
| 16 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
| 17 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
| 18 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
| 19 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
| 20 |
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
| 21 |
+
SOFTWARE.
|
PnP_BM3D_VR_Skip.ipynb
ADDED
|
@@ -0,0 +1,687 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"cells": [
|
| 3 |
+
{
|
| 4 |
+
"cell_type": "code",
|
| 5 |
+
"execution_count": null,
|
| 6 |
+
"id": "e570205c-8d15-409a-a80a-2deabe8093db",
|
| 7 |
+
"metadata": {},
|
| 8 |
+
"outputs": [],
|
| 9 |
+
"source": [
|
| 10 |
+
"from cil.optimisation.functions import LeastSquares, SVRGFunction, SGFunction\n",
|
| 11 |
+
"from cil.framework import AcquisitionGeometry\n",
|
| 12 |
+
"from cil.plugins.astra import ProjectionOperator\n",
|
| 13 |
+
"from cil.plugins.astra import FBP\n",
|
| 14 |
+
"from cil.optimisation.utilities import MetricsDiagnostics, RandomSampling\n",
|
| 15 |
+
"from cil.optimisation.algorithms import FISTA, ISTA\n",
|
| 16 |
+
"\n",
|
| 17 |
+
"from ProxSkip import ProxSkip\n",
|
| 18 |
+
"from utils import StoppingCriterionTime, BM3DFunction\n",
|
| 19 |
+
"from skimage.metrics import peak_signal_noise_ratio as PSNR\n",
|
| 20 |
+
"from skimage.metrics import structural_similarity as SSIM\n",
|
| 21 |
+
"from xdesign import Foam, discrete_phantom\n",
|
| 22 |
+
"\n",
|
| 23 |
+
"import numpy as np\n",
|
| 24 |
+
"import matplotlib.pyplot as plt\n",
|
| 25 |
+
"plt.rcParams[\"image.cmap\"] = \"inferno\""
|
| 26 |
+
]
|
| 27 |
+
},
|
| 28 |
+
{
|
| 29 |
+
"cell_type": "markdown",
|
| 30 |
+
"id": "95eac1e1-ad82-4977-9576-1b8cc99f8d57",
|
| 31 |
+
"metadata": {},
|
| 32 |
+
"source": [
|
| 33 |
+
"### PnP Variance Reduced ProxSkip on a simulated dataset"
|
| 34 |
+
]
|
| 35 |
+
},
|
| 36 |
+
{
|
| 37 |
+
"cell_type": "markdown",
|
| 38 |
+
"id": "3e1dd614-7f15-4894-b13d-a1b32c82c969",
|
| 39 |
+
"metadata": {},
|
| 40 |
+
"source": [
|
| 41 |
+
"In this notebook, we play the following game. We compare algorithms by **time-to-quality**: that is, we measure how quickly each method achieves high reconstruction quality—quantified by **PSNR** and **SSIM**—with respect to the ground truth. Here, the ground truth is available, and we impose a fixed computational budget. The winner is the algorithm that attains the highest PSNR/SSIM within this time limit."
|
| 42 |
+
]
|
| 43 |
+
},
|
| 44 |
+
{
|
| 45 |
+
"cell_type": "markdown",
|
| 46 |
+
"id": "d25fd695-7af3-4992-bf6e-0ce6fe299870",
|
| 47 |
+
"metadata": {},
|
| 48 |
+
"source": [
|
| 49 |
+
"### Generate Foam Phantom"
|
| 50 |
+
]
|
| 51 |
+
},
|
| 52 |
+
{
|
| 53 |
+
"cell_type": "code",
|
| 54 |
+
"execution_count": null,
|
| 55 |
+
"id": "7154574f-40e1-4b60-828b-ddc4b1be8d69",
|
| 56 |
+
"metadata": {},
|
| 57 |
+
"outputs": [],
|
| 58 |
+
"source": [
|
| 59 |
+
"N = 347 # image size\n",
|
| 60 |
+
"density = 0.025 # attenuation density of the image\n",
|
| 61 |
+
"np.random.seed(1234)\n",
|
| 62 |
+
"gt = discrete_phantom(Foam(size_range=[0.075, 0.005], gap=2e-3, porosity=1.0), size=N - 10)\n",
|
| 63 |
+
"gt = gt / np.max(gt) * density\n",
|
| 64 |
+
"gt = np.pad(gt, 5)\n",
|
| 65 |
+
"gt[gt < 0] = 0"
|
| 66 |
+
]
|
| 67 |
+
},
|
| 68 |
+
{
|
| 69 |
+
"cell_type": "markdown",
|
| 70 |
+
"id": "7963e21e-24c7-46e0-9cda-3c9c6aaf9678",
|
| 71 |
+
"metadata": {},
|
| 72 |
+
"source": [
|
| 73 |
+
"### Generate Acquisition Data with simulated noise"
|
| 74 |
+
]
|
| 75 |
+
},
|
| 76 |
+
{
|
| 77 |
+
"cell_type": "code",
|
| 78 |
+
"execution_count": null,
|
| 79 |
+
"id": "226ded3c-a540-49d0-b0a3-3c560de87c96",
|
| 80 |
+
"metadata": {},
|
| 81 |
+
"outputs": [],
|
| 82 |
+
"source": [
|
| 83 |
+
"num_angles = 400 \n",
|
| 84 |
+
"\n",
|
| 85 |
+
"angles = np.linspace(0, 180, num_angles, endpoint=False, dtype=np.float32)\n",
|
| 86 |
+
"ag2D = AcquisitionGeometry.create_Parallel2D()\\\n",
|
| 87 |
+
" .set_panel(num_pixels=N)\\\n",
|
| 88 |
+
" .set_angles(angles=angles)\n",
|
| 89 |
+
"ig2D = ag2D.get_ImageGeometry()\n",
|
| 90 |
+
"x_cil = ig2D.allocate()\n",
|
| 91 |
+
"x_cil.fill(gt)\n",
|
| 92 |
+
"A = ProjectionOperator(ig2D, ag2D, device=\"cpu\")\n",
|
| 93 |
+
"\n",
|
| 94 |
+
"# noiseless sinogram\n",
|
| 95 |
+
"noiseless = A.direct(x_cil)\n",
|
| 96 |
+
"\n",
|
| 97 |
+
"## simulated noise\n",
|
| 98 |
+
"max_intensity = 2000\n",
|
| 99 |
+
"expected_counts = max_intensity * np.exp(-noiseless.array)\n",
|
| 100 |
+
"noisy_counts = np.random.poisson(expected_counts).astype(np.float32)\n",
|
| 101 |
+
"noisy_counts[noisy_counts == 0] = 1 # deal with 0s\n",
|
| 102 |
+
"y_np = -np.log(noisy_counts / max_intensity)\n",
|
| 103 |
+
"noisy = ag2D.allocate()\n",
|
| 104 |
+
"noisy.fill(y_np)"
|
| 105 |
+
]
|
| 106 |
+
},
|
| 107 |
+
{
|
| 108 |
+
"cell_type": "code",
|
| 109 |
+
"execution_count": null,
|
| 110 |
+
"id": "faa9bb7a-9454-4c6e-802e-80c991cc0da3",
|
| 111 |
+
"metadata": {},
|
| 112 |
+
"outputs": [],
|
| 113 |
+
"source": [
|
| 114 |
+
"import numpy as np\n",
|
| 115 |
+
"import matplotlib.pyplot as plt\n",
|
| 116 |
+
"\n",
|
| 117 |
+
"fig, axs = plt.subplots(1, 3, figsize=(12, 10), constrained_layout=True)\n",
|
| 118 |
+
"\n",
|
| 119 |
+
"axs[0].imshow(gt)\n",
|
| 120 |
+
"axs[0].set_title(\"Ground Truth\")\n",
|
| 121 |
+
"\n",
|
| 122 |
+
"axs[1].imshow(noiseless.array)\n",
|
| 123 |
+
"axs[1].set_title(\"Noiseless Sinogram\")\n",
|
| 124 |
+
"\n",
|
| 125 |
+
"axs[2].imshow(noisy.array)\n",
|
| 126 |
+
"axs[2].set_title(\"Noisy Sinogram\")\n",
|
| 127 |
+
"\n",
|
| 128 |
+
"plt.show()"
|
| 129 |
+
]
|
| 130 |
+
},
|
| 131 |
+
{
|
| 132 |
+
"cell_type": "markdown",
|
| 133 |
+
"id": "b138b305-26db-4c0e-a40f-687cff75d883",
|
| 134 |
+
"metadata": {},
|
| 135 |
+
"source": [
|
| 136 |
+
"### FBP reconstruction"
|
| 137 |
+
]
|
| 138 |
+
},
|
| 139 |
+
{
|
| 140 |
+
"cell_type": "code",
|
| 141 |
+
"execution_count": null,
|
| 142 |
+
"id": "78e70056-aea9-4780-89e2-52351f51b9ac",
|
| 143 |
+
"metadata": {},
|
| 144 |
+
"outputs": [],
|
| 145 |
+
"source": [
|
| 146 |
+
"fbp = FBP(ig2D, ag2D, device=\"cpu\")(noisy)\n",
|
| 147 |
+
"fbp.array[fbp.array<0] = 0\n",
|
| 148 |
+
"plt.imshow(fbp.array)"
|
| 149 |
+
]
|
| 150 |
+
},
|
| 151 |
+
{
|
| 152 |
+
"cell_type": "markdown",
|
| 153 |
+
"id": "498ad7aa-540b-426b-8810-991a72312e5c",
|
| 154 |
+
"metadata": {},
|
| 155 |
+
"source": [
|
| 156 |
+
"### Family of considered algorithms (ProxSkip template)\n",
|
| 157 |
+
"\n",
|
| 158 |
+
"**Parameters:** $\\gamma>0$, probability $p\\in(0,1]$, data subsets $N$\n",
|
| 159 |
+
"**Initialize:** $x_0, h_0 \\in \\mathbb{R}^n$\n",
|
| 160 |
+
"\n",
|
| 161 |
+
"For $k=0,1,\\dots,K-1$:\n",
|
| 162 |
+
"\n",
|
| 163 |
+
"1. Compute $G_k$ (an unbiased estimator of $\\nabla f(x_k)$).\n",
|
| 164 |
+
"2. $$\\hat x_{k+1} = x_k - \\gamma\\big(G_k(x_k) - h_k\\big).$$\n",
|
| 165 |
+
"3. Sample $\\theta_k\\sim\\mathrm{Bernoulli}(p)$, $\\theta_k\\in{0,1}$.\n",
|
| 166 |
+
"4. If $\\theta_k=1$:\n",
|
| 167 |
+
" $$x_{k+1}=\\mathrm{prox}_{\\frac{\\gamma}{p}g}\\left(\\hat x_{k+1}-\\frac{\\gamma}{p}h_k\\right),$$\n",
|
| 168 |
+
" else:\n",
|
| 169 |
+
" $$x_{k+1}=\\hat x_{k+1}.$$\n",
|
| 170 |
+
"5. Update:\n",
|
| 171 |
+
" $$h_{k+1}=h_k+\\frac{p}{\\gamma}\\big(x_{k+1}-\\hat x_{k+1}\\big).$$\n",
|
| 172 |
+
"\n",
|
| 173 |
+
"---\n",
|
| 174 |
+
"\n",
|
| 175 |
+
"| $p=1$ | $0<p<1$ | $G_k$ |\n",
|
| 176 |
+
"| --------- | ------------- | -------------------------------------------------------------------- |\n",
|
| 177 |
+
"| ISTA | ProxSkip | $\\nabla f(x_k)$ |\n",
|
| 178 |
+
"| ProxSGD | ProxSGDSkip | $N\\nabla f_{i_k}(x_k)$ |\n",
|
| 179 |
+
"| ProxSAGA | ProxSAGASkip | $N(\\nabla f_{i_k}(x_k)-v_k^{,i_k})+\\bar v_k$ |\n",
|
| 180 |
+
"| ProxSVRG | ProxSVRGSkip | $N(\\nabla f_{i_k}(x_k)-\\nabla f_{i_k}(\\tilde x))+\\nabla f(\\tilde x)$ |\n",
|
| 181 |
+
"| ProxLSVRG | ProxLSVRGSkip | as above, updated with $p=1/N$ |\n",
|
| 182 |
+
"\n",
|
| 183 |
+
"*Note:* FISTA is ISTA with an acceleration step (Beck & Teboulle).\n"
|
| 184 |
+
]
|
| 185 |
+
},
|
| 186 |
+
{
|
| 187 |
+
"cell_type": "markdown",
|
| 188 |
+
"id": "bfd1a9e7-f591-4391-8c4b-71228b3c49e1",
|
| 189 |
+
"metadata": {},
|
| 190 |
+
"source": [
|
| 191 |
+
"### Define Stopping Criteria and Metrics (PSNR/SSIM)"
|
| 192 |
+
]
|
| 193 |
+
},
|
| 194 |
+
{
|
| 195 |
+
"cell_type": "code",
|
| 196 |
+
"execution_count": null,
|
| 197 |
+
"id": "25b4287b-06a1-4d96-baa5-db52f3b8e4b0",
|
| 198 |
+
"metadata": {},
|
| 199 |
+
"outputs": [],
|
| 200 |
+
"source": [
|
| 201 |
+
"gt_cil = ig2D.allocate()\n",
|
| 202 |
+
"gt_cil.fill(gt)\n",
|
| 203 |
+
"\n",
|
| 204 |
+
"SSIM_ = lambda x, y: SSIM(x, y, data_range=x.max()-x.min())\n",
|
| 205 |
+
"PSNR_ = lambda x, y: PSNR(x, y, data_range=x.max()-x.min())\n",
|
| 206 |
+
"\n",
|
| 207 |
+
"time_stop = 3*60 # 3 mins\n",
|
| 208 |
+
"max_iteration = 100_000\n",
|
| 209 |
+
"\n",
|
| 210 |
+
"cb_metrics = MetricsDiagnostics(reference_image=gt_cil, \n",
|
| 211 |
+
" metrics_dict={\"psnr\":PSNR_,\"ssim\":SSIM_}) \n"
|
| 212 |
+
]
|
| 213 |
+
},
|
| 214 |
+
{
|
| 215 |
+
"cell_type": "code",
|
| 216 |
+
"execution_count": null,
|
| 217 |
+
"id": "1e03af7b-468a-43c1-9344-e66dcea5ce58",
|
| 218 |
+
"metadata": {},
|
| 219 |
+
"outputs": [],
|
| 220 |
+
"source": [
|
| 221 |
+
"### Avoid computing objectives, not needed for this demo\n",
|
| 222 |
+
"def update_objective(self):\n",
|
| 223 |
+
" return 0.\n",
|
| 224 |
+
"\n",
|
| 225 |
+
"ISTA.update_objective = update_objective\n",
|
| 226 |
+
"ProxSkip.update_objective = update_objective\n",
|
| 227 |
+
"FISTA.update_objective = update_objective"
|
| 228 |
+
]
|
| 229 |
+
},
|
| 230 |
+
{
|
| 231 |
+
"cell_type": "markdown",
|
| 232 |
+
"id": "1f22e499-849f-47b5-b959-b2b8254af0e6",
|
| 233 |
+
"metadata": {},
|
| 234 |
+
"source": [
|
| 235 |
+
"### Run deterministic algorithms: No data splitting, no skipping"
|
| 236 |
+
]
|
| 237 |
+
},
|
| 238 |
+
{
|
| 239 |
+
"cell_type": "code",
|
| 240 |
+
"execution_count": null,
|
| 241 |
+
"id": "e67a015c-1e41-40fe-84c2-383e0558d4f3",
|
| 242 |
+
"metadata": {},
|
| 243 |
+
"outputs": [],
|
| 244 |
+
"source": [
|
| 245 |
+
"sigma_no_skip = 0.0002 \n",
|
| 246 |
+
"F = LeastSquares(A=A, b=noisy, c=0.5)"
|
| 247 |
+
]
|
| 248 |
+
},
|
| 249 |
+
{
|
| 250 |
+
"cell_type": "markdown",
|
| 251 |
+
"id": "070838d2-505b-4265-846b-cee8a2a96450",
|
| 252 |
+
"metadata": {},
|
| 253 |
+
"source": [
|
| 254 |
+
"### FISTA"
|
| 255 |
+
]
|
| 256 |
+
},
|
| 257 |
+
{
|
| 258 |
+
"cell_type": "code",
|
| 259 |
+
"execution_count": null,
|
| 260 |
+
"id": "02a9d83c-3a9a-4935-8dd0-c6c6ff3dea2f",
|
| 261 |
+
"metadata": {},
|
| 262 |
+
"outputs": [],
|
| 263 |
+
"source": [
|
| 264 |
+
"initial = ig2D.allocate()\n",
|
| 265 |
+
"step_size = 1./F.L\n",
|
| 266 |
+
"G = BM3DFunction(sigma_no_skip)\n",
|
| 267 |
+
"cb_time = StoppingCriterionTime(time_stop)\n",
|
| 268 |
+
"fista = FISTA(initial = initial, f = F, step_size = step_size, g=G, \n",
|
| 269 |
+
" update_objective_interval = 1,\n",
|
| 270 |
+
" max_iteration = max_iteration) \n",
|
| 271 |
+
"fista.run(verbose=0, callback=[cb_metrics, cb_time])"
|
| 272 |
+
]
|
| 273 |
+
},
|
| 274 |
+
{
|
| 275 |
+
"cell_type": "markdown",
|
| 276 |
+
"id": "9a7833e5-095e-4524-b7f6-d3f4d5be473f",
|
| 277 |
+
"metadata": {},
|
| 278 |
+
"source": [
|
| 279 |
+
"### ISTA"
|
| 280 |
+
]
|
| 281 |
+
},
|
| 282 |
+
{
|
| 283 |
+
"cell_type": "code",
|
| 284 |
+
"execution_count": null,
|
| 285 |
+
"id": "f1b02973-d57c-4f69-b6d1-4af18fa6b160",
|
| 286 |
+
"metadata": {},
|
| 287 |
+
"outputs": [],
|
| 288 |
+
"source": [
|
| 289 |
+
"G = BM3DFunction(sigma_no_skip)\n",
|
| 290 |
+
"cb_time = StoppingCriterionTime(time_stop)\n",
|
| 291 |
+
"ista = ISTA(initial = initial, f = F, step_size = step_size, g=G, \n",
|
| 292 |
+
" update_objective_interval = 1,\n",
|
| 293 |
+
" max_iteration = max_iteration) \n",
|
| 294 |
+
"ista.run(verbose=0, callback=[cb_metrics, cb_time])"
|
| 295 |
+
]
|
| 296 |
+
},
|
| 297 |
+
{
|
| 298 |
+
"cell_type": "markdown",
|
| 299 |
+
"id": "57bd7320-c2cb-4e4c-a3f3-f382a0769d6a",
|
| 300 |
+
"metadata": {},
|
| 301 |
+
"source": [
|
| 302 |
+
"### Run ProxSkip: Skip the regulariser"
|
| 303 |
+
]
|
| 304 |
+
},
|
| 305 |
+
{
|
| 306 |
+
"cell_type": "code",
|
| 307 |
+
"execution_count": null,
|
| 308 |
+
"id": "1a9b36d3-dd8a-43df-999c-acb414c3ba51",
|
| 309 |
+
"metadata": {},
|
| 310 |
+
"outputs": [],
|
| 311 |
+
"source": [
|
| 312 |
+
"prob = 0.05\n",
|
| 313 |
+
"sigma_skip = sigma_no_skip/np.sqrt(prob)"
|
| 314 |
+
]
|
| 315 |
+
},
|
| 316 |
+
{
|
| 317 |
+
"cell_type": "code",
|
| 318 |
+
"execution_count": null,
|
| 319 |
+
"id": "cd6d1c7b-77c8-44e4-a71c-8169a097567e",
|
| 320 |
+
"metadata": {},
|
| 321 |
+
"outputs": [],
|
| 322 |
+
"source": [
|
| 323 |
+
"G = BM3DFunction(sigma_skip)\n",
|
| 324 |
+
"\n",
|
| 325 |
+
"cb_time = StoppingCriterionTime(time_stop)\n",
|
| 326 |
+
"proxskip = ProxSkip(initial = [initial, initial], f = F, step_size = step_size, g=G, \n",
|
| 327 |
+
" update_objective_interval = 1, prob=prob, seed = 42, \n",
|
| 328 |
+
" max_iteration = max_iteration) \n",
|
| 329 |
+
"proxskip.run(verbose=0, callback=[cb_metrics, cb_time])"
|
| 330 |
+
]
|
| 331 |
+
},
|
| 332 |
+
{
|
| 333 |
+
"cell_type": "markdown",
|
| 334 |
+
"id": "84f96bea-e56d-4fc7-a8b9-7b288430f079",
|
| 335 |
+
"metadata": {},
|
| 336 |
+
"source": [
|
| 337 |
+
"### Run stochastic algorithms: Data splitting, no skipping"
|
| 338 |
+
]
|
| 339 |
+
},
|
| 340 |
+
{
|
| 341 |
+
"cell_type": "code",
|
| 342 |
+
"execution_count": null,
|
| 343 |
+
"id": "750978d9-af9e-4366-aeaf-810155d22f1e",
|
| 344 |
+
"metadata": {},
|
| 345 |
+
"outputs": [],
|
| 346 |
+
"source": [
|
| 347 |
+
"def list_of_functions(data):\n",
|
| 348 |
+
" \n",
|
| 349 |
+
" list_funcs = []\n",
|
| 350 |
+
" ig = data[0].geometry.get_ImageGeometry()\n",
|
| 351 |
+
" \n",
|
| 352 |
+
" for d in data:\n",
|
| 353 |
+
" ageom_subset = d.geometry \n",
|
| 354 |
+
" Ai = ProjectionOperator(ig, ageom_subset, device = 'cpu') \n",
|
| 355 |
+
" fi = LeastSquares(Ai, b = d, c = 0.5)\n",
|
| 356 |
+
" list_funcs.append(fi) \n",
|
| 357 |
+
" \n",
|
| 358 |
+
" return list_funcs"
|
| 359 |
+
]
|
| 360 |
+
},
|
| 361 |
+
{
|
| 362 |
+
"cell_type": "code",
|
| 363 |
+
"execution_count": null,
|
| 364 |
+
"id": "ce7998a1-9c98-49fe-b35a-89be65498fc3",
|
| 365 |
+
"metadata": {},
|
| 366 |
+
"outputs": [],
|
| 367 |
+
"source": [
|
| 368 |
+
"# number of subsets\n",
|
| 369 |
+
"nsub = 50\n",
|
| 370 |
+
"data_split, method = noisy.split_to_subsets(nsub, method= \"ordered\", info=True)\n",
|
| 371 |
+
"\n",
|
| 372 |
+
"# list of fis for finite sum \n",
|
| 373 |
+
"list_func = list_of_functions(data_split) "
|
| 374 |
+
]
|
| 375 |
+
},
|
| 376 |
+
{
|
| 377 |
+
"cell_type": "markdown",
|
| 378 |
+
"id": "5fd9e598-bafd-4bdf-b9b0-33f1646ce191",
|
| 379 |
+
"metadata": {},
|
| 380 |
+
"source": [
|
| 381 |
+
"### ProxSVRG and ProxSGD"
|
| 382 |
+
]
|
| 383 |
+
},
|
| 384 |
+
{
|
| 385 |
+
"cell_type": "code",
|
| 386 |
+
"execution_count": null,
|
| 387 |
+
"id": "aeb2e2d8-6054-4759-bd84-d5cc7d2bbc08",
|
| 388 |
+
"metadata": {},
|
| 389 |
+
"outputs": [],
|
| 390 |
+
"source": [
|
| 391 |
+
"selection = RandomSampling(len(list_func), nsub, seed=42)\n",
|
| 392 |
+
"Fsvrg = SVRGFunction(list_func, selection = selection, update_frequency=len(list_func))\n",
|
| 393 |
+
"Fsvrg.initial = initial\n",
|
| 394 |
+
"step_size = 1./(Fsvrg.L)\n",
|
| 395 |
+
"\n",
|
| 396 |
+
"G = BM3DFunction(sigma_no_skip)\n",
|
| 397 |
+
"cb_time = StoppingCriterionTime(time_stop)\n",
|
| 398 |
+
"prox_svrg = ISTA(initial = initial, f = Fsvrg, step_size = step_size, g=G, \n",
|
| 399 |
+
" update_objective_interval = 1, \n",
|
| 400 |
+
" max_iteration = max_iteration) \n",
|
| 401 |
+
"prox_svrg.run(verbose=0, callback=[cb_metrics, cb_time]) \n"
|
| 402 |
+
]
|
| 403 |
+
},
|
| 404 |
+
{
|
| 405 |
+
"cell_type": "code",
|
| 406 |
+
"execution_count": null,
|
| 407 |
+
"id": "b1565f4c-162c-41bd-919a-bfd61a4a57cb",
|
| 408 |
+
"metadata": {},
|
| 409 |
+
"outputs": [],
|
| 410 |
+
"source": [
|
| 411 |
+
"selection = RandomSampling(len(list_func), nsub, seed=42)\n",
|
| 412 |
+
"Fsgd = SGFunction(list_func, selection = selection) \n",
|
| 413 |
+
"step_size = 1./(Fsgd.L)\n",
|
| 414 |
+
"\n",
|
| 415 |
+
"G = BM3DFunction(sigma_no_skip)\n",
|
| 416 |
+
"cb_time = StoppingCriterionTime(time_stop)\n",
|
| 417 |
+
"prox_sgd = ISTA(initial = initial, f = Fsgd, step_size = step_size, g=G, \n",
|
| 418 |
+
" update_objective_interval = 1, \n",
|
| 419 |
+
" max_iteration = max_iteration) \n",
|
| 420 |
+
"prox_sgd.run(verbose=0, callback=[cb_metrics, cb_time]) \n"
|
| 421 |
+
]
|
| 422 |
+
},
|
| 423 |
+
{
|
| 424 |
+
"cell_type": "markdown",
|
| 425 |
+
"id": "e35518e7-70d8-4c7d-955b-f062c133371e",
|
| 426 |
+
"metadata": {},
|
| 427 |
+
"source": [
|
| 428 |
+
"### Run stochastic algorithms: Data splitting and skipping"
|
| 429 |
+
]
|
| 430 |
+
},
|
| 431 |
+
{
|
| 432 |
+
"cell_type": "code",
|
| 433 |
+
"execution_count": null,
|
| 434 |
+
"id": "94d2a7e7-b0d3-4418-8d9d-fa35f506c5f3",
|
| 435 |
+
"metadata": {},
|
| 436 |
+
"outputs": [],
|
| 437 |
+
"source": [
|
| 438 |
+
"selection = RandomSampling(len(list_func), nsub, seed=42)\n",
|
| 439 |
+
"Fsvrg_skip = SVRGFunction(list_func, selection = selection, update_frequency=len(list_func))\n",
|
| 440 |
+
"Fsvrg_skip.initial = initial\n",
|
| 441 |
+
"step_size = 1./(Fsvrg_skip.L)\n",
|
| 442 |
+
"\n",
|
| 443 |
+
"G = BM3DFunction(sigma_skip)\n",
|
| 444 |
+
"cb_time = StoppingCriterionTime(time_stop)\n",
|
| 445 |
+
"prox_svrg_skip = ProxSkip(initial = [initial, initial], f = Fsvrg_skip, step_size = step_size, g=G, \n",
|
| 446 |
+
" update_objective_interval = 1, prob=prob, seed=42,\n",
|
| 447 |
+
" max_iteration = max_iteration) \n",
|
| 448 |
+
"prox_svrg_skip.run(verbose=0, callback=[cb_metrics, cb_time]) "
|
| 449 |
+
]
|
| 450 |
+
},
|
| 451 |
+
{
|
| 452 |
+
"cell_type": "code",
|
| 453 |
+
"execution_count": null,
|
| 454 |
+
"id": "fbfc107d-a6f7-45b8-8974-c1e766148bc0",
|
| 455 |
+
"metadata": {},
|
| 456 |
+
"outputs": [],
|
| 457 |
+
"source": [
|
| 458 |
+
"selection = RandomSampling(len(list_func), nsub, seed=42)\n",
|
| 459 |
+
"Fsgd_skip = SGFunction(list_func, selection = selection)\n",
|
| 460 |
+
"step_size = 1./(Fsgd_skip.L)\n",
|
| 461 |
+
"\n",
|
| 462 |
+
"G = BM3DFunction(sigma_skip)\n",
|
| 463 |
+
"cb_time = StoppingCriterionTime(time_stop)\n",
|
| 464 |
+
"prox_sgd_skip = ProxSkip(initial = [initial, initial], f = Fsgd_skip, step_size = step_size, g=G, \n",
|
| 465 |
+
" update_objective_interval = 1, prob=prob, seed=42,\n",
|
| 466 |
+
" max_iteration = max_iteration) \n",
|
| 467 |
+
"prox_sgd_skip.run(verbose=0, callback=[cb_metrics, cb_time]) "
|
| 468 |
+
]
|
| 469 |
+
},
|
| 470 |
+
{
|
| 471 |
+
"cell_type": "markdown",
|
| 472 |
+
"id": "48181800-6930-454c-bc6c-e4fb9d4e2d42",
|
| 473 |
+
"metadata": {},
|
| 474 |
+
"source": [
|
| 475 |
+
"### Plot PSNR/SSIM progress\n",
|
| 476 |
+
"- with respect to iteration\n",
|
| 477 |
+
"- with respect to time"
|
| 478 |
+
]
|
| 479 |
+
},
|
| 480 |
+
{
|
| 481 |
+
"cell_type": "code",
|
| 482 |
+
"execution_count": null,
|
| 483 |
+
"id": "2a76c9c9-f023-419b-9f7c-fdc840c92301",
|
| 484 |
+
"metadata": {},
|
| 485 |
+
"outputs": [],
|
| 486 |
+
"source": [
|
| 487 |
+
"t_ista = np.cumsum(ista.timing)\n",
|
| 488 |
+
"t_fista = np.cumsum(fista.timing)\n",
|
| 489 |
+
"t_proxskip = np.cumsum(proxskip.timing)\n",
|
| 490 |
+
"t_prox_svrg = np.cumsum(prox_svrg.timing)\n",
|
| 491 |
+
"t_prox_svrg_skip = np.cumsum(prox_svrg_skip.timing)\n",
|
| 492 |
+
"t_prox_sgd_skip = np.cumsum(prox_sgd_skip.timing)"
|
| 493 |
+
]
|
| 494 |
+
},
|
| 495 |
+
{
|
| 496 |
+
"cell_type": "code",
|
| 497 |
+
"execution_count": null,
|
| 498 |
+
"id": "ca6e8583-5487-414e-ba5d-3573ae290595",
|
| 499 |
+
"metadata": {},
|
| 500 |
+
"outputs": [],
|
| 501 |
+
"source": [
|
| 502 |
+
"import numpy as np\n",
|
| 503 |
+
"import matplotlib.pyplot as plt\n",
|
| 504 |
+
"plt.rcParams['lines.linewidth'] = 5\n",
|
| 505 |
+
"plt.rcParams['lines.markersize'] = 20\n",
|
| 506 |
+
"plt.rcParams['font.size'] = 40\n",
|
| 507 |
+
"\n",
|
| 508 |
+
"fig, ax = plt.subplots(2, 2, figsize=(35, 25), constrained_layout=True)\n",
|
| 509 |
+
"\n",
|
| 510 |
+
"ax[0,0].plot(ista.psnr[:-1], label=f\"ISTA, K={ista.iteration}, #BM3D={ista.iteration}\")\n",
|
| 511 |
+
"ax[0,0].plot(fista.psnr[:-1], label=f\"FISTA, K={fista.iteration}, #BM3D={fista.iteration}\")\n",
|
| 512 |
+
"ax[0,0].plot(proxskip.psnr[:-1], label=f\"ProxSkip, K={proxskip.iteration}, #BM3D={np.sum(proxskip.thetas)}\")\n",
|
| 513 |
+
"ax[0,0].plot(prox_svrg.psnr[:-1], label=f\"ProxSVRG, K={prox_svrg.iteration}, #BM3D={prox_svrg.iteration}\")\n",
|
| 514 |
+
"ax[0,0].plot(prox_sgd_skip.psnr[:-1],label=f\"ProxSGDSkip, K={prox_svrg_skip.iteration}, #BM3D={np.sum(prox_svrg_skip.thetas)}\", alpha=0.65)\n",
|
| 515 |
+
"ax[0,0].plot(prox_svrg_skip.psnr[:-1],label=f\"ProxSVRGSkip, K={prox_svrg_skip.iteration}, #BM3D={np.sum(prox_svrg_skip.thetas)}\")\n",
|
| 516 |
+
"ax[0,0].grid(which=\"major\")\n",
|
| 517 |
+
"ax[0,0].set_xlabel(\"Iteration\")\n",
|
| 518 |
+
"ax[0,0].set_ylabel(\"PSNR\")\n",
|
| 519 |
+
"\n",
|
| 520 |
+
"ax[0,1].plot(ista.ssim[:-1], label=f\"ISTA, K={ista.iteration}, #BM3D={ista.iteration}\")\n",
|
| 521 |
+
"ax[0,1].plot(fista.ssim[:-1], label=f\"FISTA, K={fista.iteration}, #BM3D={fista.iteration}\")\n",
|
| 522 |
+
"ax[0,1].plot(proxskip.ssim[:-1], label=f\"ProxSkip, p={prob}, #prox={np.sum(proxskip.thetas)}\")\n",
|
| 523 |
+
"ax[0,1].plot(prox_svrg.ssim[:-1], label=\"ProxSVRG\")\n",
|
| 524 |
+
"ax[0,1].semilogy(prox_sgd_skip.ssim[:-1],label=f\"ProxSGDSkip, K={prox_svrg_skip.iteration}, N={nsub}, p={prob}, #BM3D={np.sum(prox_svrg_skip.thetas)}\", alpha=0.65)\n",
|
| 525 |
+
"ax[0,1].plot(prox_svrg_skip.ssim[:-1],label=f\"ProxSVRGSkip, p={prob}, #prox={np.sum(prox_svrg_skip.thetas)}\")\n",
|
| 526 |
+
"ax[0,1].grid(which=\"major\")\n",
|
| 527 |
+
"ax[0,1].set_xlabel(\"Iteration\")\n",
|
| 528 |
+
"ax[0,1].set_ylabel(\"SSIM\")\n",
|
| 529 |
+
"\n",
|
| 530 |
+
"ax[1,0].plot(t_ista, ista.psnr[:-1], label=f\"ISTA, K={ista.iteration}, #BM3D={ista.iteration}\")\n",
|
| 531 |
+
"ax[1,0].plot(t_fista, fista.psnr[:-1], label=f\"FISTA, K={fista.iteration}, #BM3D={fista.iteration}\")\n",
|
| 532 |
+
"ax[1,0].plot(t_proxskip, proxskip.psnr[:-1], label=f\"ProxSkip, K={proxskip.iteration}, #BM3D={np.sum(proxskip.thetas)}\")\n",
|
| 533 |
+
"ax[1,0].plot(t_prox_svrg, prox_svrg.psnr[:-1], label=f\"ProxSVRG, K={prox_svrg.iteration}, #BM3D={prox_svrg.iteration}\")\n",
|
| 534 |
+
"ax[1,0].plot(t_prox_sgd_skip, prox_sgd_skip.psnr[:-1],\n",
|
| 535 |
+
" label=f\"ProxSGDSkip, K={prox_svrg_skip.iteration}, #BM3D={np.sum(prox_svrg_skip.thetas)}\", alpha=0.65)\n",
|
| 536 |
+
"ax[1,0].plot(t_prox_svrg_skip, prox_svrg_skip.psnr[:-1],\n",
|
| 537 |
+
" label=f\"ProxSVRGSkip, K={prox_svrg_skip.iteration}, #BM3D={np.sum(prox_svrg_skip.thetas)}\")\n",
|
| 538 |
+
"ax[1,0].grid(which=\"major\")\n",
|
| 539 |
+
"ax[1,0].set_xlabel(\"Time (sec)\")\n",
|
| 540 |
+
"ax[1,0].set_ylabel(\"PSNR\")\n",
|
| 541 |
+
"\n",
|
| 542 |
+
"ax[1,1].plot(t_ista, ista.ssim[:-1], label=f\"ISTA, K={ista.iteration}, #BM3D={ista.iteration}\")\n",
|
| 543 |
+
"ax[1,1].plot(t_fista, fista.ssim[:-1], label=f\"FISTA, K={fista.iteration}, #BM3D={fista.iteration}\")\n",
|
| 544 |
+
"ax[1,1].plot(t_proxskip, proxskip.ssim[:-1], label=f\"ProxSkip, p={prob}, #prox={np.sum(proxskip.thetas)}\")\n",
|
| 545 |
+
"ax[1,1].plot(t_prox_svrg, prox_svrg.ssim[:-1], label=\"ProxSVRG\")\n",
|
| 546 |
+
"ax[1,1].semilogy(t_prox_sgd_skip, prox_sgd_skip.ssim[:-1],\n",
|
| 547 |
+
" label=f\"ProxSGDSkip, K={prox_svrg_skip.iteration}, N={nsub}, p={prob}, #BM3D={np.sum(prox_svrg_skip.thetas)}\", alpha=0.65)\n",
|
| 548 |
+
"ax[1,1].plot(t_prox_svrg_skip, prox_svrg_skip.ssim[:-1],\n",
|
| 549 |
+
" label=f\"ProxSVRGSkip, p={prob}, #prox={np.sum(prox_svrg_skip.thetas)}\")\n",
|
| 550 |
+
"ax[1,1].grid(which=\"major\")\n",
|
| 551 |
+
"ax[1,1].set_xlabel(\"Time (sec)\")\n",
|
| 552 |
+
"ax[1,1].set_ylabel(\"SSIM\")\n",
|
| 553 |
+
"\n",
|
| 554 |
+
"handles, labels = ax[0,1].get_legend_handles_labels()\n",
|
| 555 |
+
"fig.legend(handles, labels, loc=\"upper center\", bbox_to_anchor=(0.53, 1.14),\n",
|
| 556 |
+
" ncols=2, frameon=True)\n",
|
| 557 |
+
"\n",
|
| 558 |
+
"\n",
|
| 559 |
+
"plt.show()\n"
|
| 560 |
+
]
|
| 561 |
+
},
|
| 562 |
+
{
|
| 563 |
+
"cell_type": "code",
|
| 564 |
+
"execution_count": null,
|
| 565 |
+
"id": "af5e1f06-ec65-4944-a84a-1681462e321e",
|
| 566 |
+
"metadata": {},
|
| 567 |
+
"outputs": [],
|
| 568 |
+
"source": [
|
| 569 |
+
"import numpy as np\n",
|
| 570 |
+
"import matplotlib.pyplot as plt\n",
|
| 571 |
+
"from mpl_toolkits.axes_grid1.inset_locator import inset_axes\n",
|
| 572 |
+
"from matplotlib.patches import Rectangle\n",
|
| 573 |
+
"\n",
|
| 574 |
+
"imgs = [\n",
|
| 575 |
+
" x_cil.array,\n",
|
| 576 |
+
" fbp.array, \n",
|
| 577 |
+
" ista.solution.array,\n",
|
| 578 |
+
" fista.solution.array,\n",
|
| 579 |
+
" proxskip.solution.array,\n",
|
| 580 |
+
" prox_svrg.solution.array,\n",
|
| 581 |
+
" prox_sgd_skip.solution.array,\n",
|
| 582 |
+
" prox_svrg_skip.solution.array,\n",
|
| 583 |
+
"]\n",
|
| 584 |
+
"\n",
|
| 585 |
+
"labels = [\n",
|
| 586 |
+
" \"Ground-Truth\",\n",
|
| 587 |
+
" f\"FBP\",\n",
|
| 588 |
+
" f\"ISTA\\nPSNR/SSIM = {ista.psnr[-1]:.2f}/{ista.ssim[-1]:.3f}\",\n",
|
| 589 |
+
" f\"FISTA\\nPSNR/SSIM = {fista.psnr[-1]:.2f}/{fista.ssim[-1]:.3f}\",\n",
|
| 590 |
+
" f\"ProxSkip\\nPSNR/SSIM = {proxskip.psnr[-1]:.2f}/{proxskip.ssim[-1]:.3f}\",\n",
|
| 591 |
+
" f\"ProxSVRG\\nPSNR/SSIM = {prox_svrg.psnr[-1]:.2f}/{prox_svrg.ssim[-1]:.3f}\",\n",
|
| 592 |
+
" f\"ProxSGDSkip\\nPSNR/SSIM = {prox_sgd_skip.psnr[-1]:.2f}/{prox_sgd_skip.ssim[-1]:.3f}\",\n",
|
| 593 |
+
" f\"ProxSVRGSkip\\nPSNR/SSIM = {prox_svrg_skip.psnr[-1]:.2f}/{prox_svrg_skip.ssim[-1]:.3f}\",\n",
|
| 594 |
+
"]\n",
|
| 595 |
+
"\n",
|
| 596 |
+
"vmax = max(np.max(im) for im in imgs)\n",
|
| 597 |
+
"vmin = 0.01\n",
|
| 598 |
+
"vmax = 0.03\n",
|
| 599 |
+
"\n",
|
| 600 |
+
"h, w = imgs[0].shape[:2]\n",
|
| 601 |
+
"img_aspect = w / h\n",
|
| 602 |
+
"nrows, ncols = 2, 4\n",
|
| 603 |
+
"\n",
|
| 604 |
+
"fig_h = 15\n",
|
| 605 |
+
"fig_w = fig_h * (ncols / nrows) * img_aspect\n",
|
| 606 |
+
"\n",
|
| 607 |
+
"fig = plt.figure(figsize=(fig_w, fig_h))\n",
|
| 608 |
+
"gs = fig.add_gridspec(nrows, ncols, wspace=0, hspace=0)\n",
|
| 609 |
+
"axes = np.array([fig.add_subplot(gs[i, j]) for i in range(nrows) for j in range(ncols)])\n",
|
| 610 |
+
"fig.subplots_adjust(left=0, right=1, bottom=0, top=1)\n",
|
| 611 |
+
"\n",
|
| 612 |
+
"H, W = imgs[0].shape[:2]\n",
|
| 613 |
+
"roi_size = 80\n",
|
| 614 |
+
"cx, cy = 75, 120 \n",
|
| 615 |
+
"\n",
|
| 616 |
+
"x0 = cx - roi_size // 2\n",
|
| 617 |
+
"x1 = x0 + roi_size\n",
|
| 618 |
+
"y0 = cy - roi_size // 2\n",
|
| 619 |
+
"y1 = y0 + roi_size\n",
|
| 620 |
+
"\n",
|
| 621 |
+
"x0 = max(0, min(x0, W - roi_size)); x1 = x0 + roi_size\n",
|
| 622 |
+
"y0 = max(0, min(y0, H - roi_size)); y1 = y0 + roi_size\n",
|
| 623 |
+
"\n",
|
| 624 |
+
"inset_size = \"32%\"\n",
|
| 625 |
+
"inset_borderpad = 0.6\n",
|
| 626 |
+
"\n",
|
| 627 |
+
"for ax, im, txt in zip(axes, imgs, labels):\n",
|
| 628 |
+
" ax.imshow(im, cmap=\"inferno\", vmin=vmin, vmax=vmax)\n",
|
| 629 |
+
" ax.axis(\"off\")\n",
|
| 630 |
+
"\n",
|
| 631 |
+
" ax.text(\n",
|
| 632 |
+
" 0.02, 0.98, txt,\n",
|
| 633 |
+
" transform=ax.transAxes,\n",
|
| 634 |
+
" ha=\"left\", va=\"top\",\n",
|
| 635 |
+
" color=\"white\", fontsize=32,\n",
|
| 636 |
+
" bbox=dict(boxstyle=\"round,pad=0.3\", facecolor=\"black\", alpha=0.25, edgecolor=\"none\")\n",
|
| 637 |
+
" )\n",
|
| 638 |
+
"\n",
|
| 639 |
+
" rect = Rectangle((x0, y0), roi_size, roi_size,\n",
|
| 640 |
+
" linewidth=2.5, edgecolor=\"yellow\", facecolor=\"none\")\n",
|
| 641 |
+
" ax.add_patch(rect)\n",
|
| 642 |
+
"\n",
|
| 643 |
+
" axins = inset_axes(ax, width=inset_size, height=inset_size,\n",
|
| 644 |
+
" loc=\"lower right\", borderpad=inset_borderpad)\n",
|
| 645 |
+
" axins.imshow(im, cmap=\"inferno\", vmin=vmin, vmax=vmax)\n",
|
| 646 |
+
" axins.set_xlim(x0, x1)\n",
|
| 647 |
+
" axins.set_ylim(y1, y0) \n",
|
| 648 |
+
" axins.set_xticks([])\n",
|
| 649 |
+
" axins.set_yticks([])\n",
|
| 650 |
+
" for spine in axins.spines.values():\n",
|
| 651 |
+
" spine.set_linewidth(2)\n",
|
| 652 |
+
" spine.set_edgecolor(\"yellow\")\n",
|
| 653 |
+
"\n",
|
| 654 |
+
"plt.show()\n"
|
| 655 |
+
]
|
| 656 |
+
},
|
| 657 |
+
{
|
| 658 |
+
"cell_type": "code",
|
| 659 |
+
"execution_count": null,
|
| 660 |
+
"id": "fa9861c4-aeaa-46db-92b7-3198ca2eaab3",
|
| 661 |
+
"metadata": {},
|
| 662 |
+
"outputs": [],
|
| 663 |
+
"source": []
|
| 664 |
+
}
|
| 665 |
+
],
|
| 666 |
+
"metadata": {
|
| 667 |
+
"kernelspec": {
|
| 668 |
+
"display_name": "Python [conda env:ssp]",
|
| 669 |
+
"language": "python",
|
| 670 |
+
"name": "conda-env-ssp-py"
|
| 671 |
+
},
|
| 672 |
+
"language_info": {
|
| 673 |
+
"codemirror_mode": {
|
| 674 |
+
"name": "ipython",
|
| 675 |
+
"version": 3
|
| 676 |
+
},
|
| 677 |
+
"file_extension": ".py",
|
| 678 |
+
"mimetype": "text/x-python",
|
| 679 |
+
"name": "python",
|
| 680 |
+
"nbconvert_exporter": "python",
|
| 681 |
+
"pygments_lexer": "ipython3",
|
| 682 |
+
"version": "3.12.12"
|
| 683 |
+
}
|
| 684 |
+
},
|
| 685 |
+
"nbformat": 4,
|
| 686 |
+
"nbformat_minor": 5
|
| 687 |
+
}
|
ProxSkip.py
ADDED
|
@@ -0,0 +1,124 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from cil.optimisation.algorithms import Algorithm
|
| 2 |
+
import numpy as np
|
| 3 |
+
import logging
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
class ProxSkip(Algorithm):
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
r"""Proximal Skip (ProxSkip) algorithm, see "ProxSkip: Yes! Local Gradient Steps Provably Lead to Communication Acceleration! Finally!†"
|
| 10 |
+
|
| 11 |
+
Parameters
|
| 12 |
+
----------
|
| 13 |
+
|
| 14 |
+
initial : DataContainer
|
| 15 |
+
Initial point for the ProxSkip algorithm.
|
| 16 |
+
f : Function
|
| 17 |
+
A smooth function with Lipschitz continuous gradient.
|
| 18 |
+
g : Function
|
| 19 |
+
A convex function with a "simple" proximal.
|
| 20 |
+
prob : positive :obj:`float`
|
| 21 |
+
Probability to skip the proximal step. If :code:`prob=1`, proximal step is used in every iteration.
|
| 22 |
+
step_size : positive :obj:`float`
|
| 23 |
+
Step size for the ProxSkip algorithm and is equal to 1./L where L is the Lipschitz constant for the gradient of f.
|
| 24 |
+
|
| 25 |
+
"""
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
def __init__(self, initial, f, g, step_size, prob, seed=None, **kwargs):
|
| 29 |
+
""" Set up of the algorithm
|
| 30 |
+
"""
|
| 31 |
+
|
| 32 |
+
super(ProxSkip, self).__init__(**kwargs)
|
| 33 |
+
|
| 34 |
+
self.f = f # smooth function
|
| 35 |
+
self.g = g # proximable
|
| 36 |
+
self.step_size = step_size
|
| 37 |
+
self.prob = prob
|
| 38 |
+
self.rng = np.random.default_rng(seed)
|
| 39 |
+
self.set_up(initial, f, g, step_size, prob, **kwargs)
|
| 40 |
+
self.thetas = []
|
| 41 |
+
self.prox_iterates = []
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
def set_up(self, initial, f, g, step_size, prob, **kwargs):
|
| 45 |
+
|
| 46 |
+
logging.info("{} setting up".format(self.__class__.__name__, ))
|
| 47 |
+
|
| 48 |
+
self.initial = initial[0]
|
| 49 |
+
|
| 50 |
+
self.x = initial[0].copy()
|
| 51 |
+
self.xhat_new = initial[0].copy()
|
| 52 |
+
self.x_new = initial[0].copy()
|
| 53 |
+
self.ht = initial[1].copy() #self.f.gradient(initial)
|
| 54 |
+
# self.ht = self.f.gradient(initial)
|
| 55 |
+
|
| 56 |
+
self.configured = True
|
| 57 |
+
|
| 58 |
+
# count proximal and non proximal steps
|
| 59 |
+
self.use_prox = 0
|
| 60 |
+
# self.no_use_prox = 0
|
| 61 |
+
|
| 62 |
+
logging.info("{} configured".format(self.__class__.__name__, ))
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
def update(self):
|
| 66 |
+
r""" Performs a single iteration of the ProxSkip algorithm
|
| 67 |
+
"""
|
| 68 |
+
|
| 69 |
+
self.f.gradient(self.x, out=self.xhat_new)
|
| 70 |
+
self.xhat_new -= self.ht
|
| 71 |
+
self.x.sapyb(1., self.xhat_new, -self.step_size, out=self.xhat_new)
|
| 72 |
+
|
| 73 |
+
theta = self.rng.random() < self.prob
|
| 74 |
+
if self.iteration==0:
|
| 75 |
+
theta = 1
|
| 76 |
+
self.thetas.append(theta)
|
| 77 |
+
|
| 78 |
+
if theta==1:
|
| 79 |
+
# print("here")
|
| 80 |
+
# Proximal step is used
|
| 81 |
+
self.g.proximal(self.xhat_new - (self.step_size/self.prob)*self.ht, self.step_size/self.prob, out=self.x_new)
|
| 82 |
+
self.ht.sapyb(1., (self.x_new - self.xhat_new), (self.prob/self.step_size), out=self.ht)
|
| 83 |
+
self.use_prox+=1
|
| 84 |
+
# self.prox_iterates.append(self.x.copy())
|
| 85 |
+
else:
|
| 86 |
+
# Proximal step is skipped
|
| 87 |
+
# print("here1")
|
| 88 |
+
self.x_new.fill(self.xhat_new)
|
| 89 |
+
|
| 90 |
+
def _update_previous_solution(self):
|
| 91 |
+
""" Swaps the references to current and previous solution based on the :func:`~Algorithm.update_previous_solution` of the base class :class:`Algorithm`.
|
| 92 |
+
"""
|
| 93 |
+
tmp = self.x_new
|
| 94 |
+
self.x = self.x_new
|
| 95 |
+
self.x = tmp
|
| 96 |
+
|
| 97 |
+
def get_output(self):
|
| 98 |
+
" Returns the current solution. "
|
| 99 |
+
return self.x
|
| 100 |
+
|
| 101 |
+
|
| 102 |
+
def update_objective(self):
|
| 103 |
+
|
| 104 |
+
""" Updates the objective
|
| 105 |
+
|
| 106 |
+
.. math:: f(x) + g(x)
|
| 107 |
+
|
| 108 |
+
"""
|
| 109 |
+
|
| 110 |
+
fun_g = self.g(self.x)
|
| 111 |
+
fun_f = self.f(self.x)
|
| 112 |
+
p1 = fun_f + fun_g
|
| 113 |
+
self.loss.append( p1 )
|
| 114 |
+
|
| 115 |
+
def proximal_evaluations(self):
|
| 116 |
+
|
| 117 |
+
prox_evals = []
|
| 118 |
+
|
| 119 |
+
for i in self.iterations:
|
| 120 |
+
if self.rng.random() < self.prob:
|
| 121 |
+
prox_evals.append(1)
|
| 122 |
+
else:
|
| 123 |
+
prox_evals.append(0)
|
| 124 |
+
return prox_evals
|
README.md
CHANGED
|
@@ -1,10 +1,74 @@
|
|
| 1 |
---
|
| 2 |
-
title: Split
|
| 3 |
-
emoji: 🌍
|
| 4 |
-
colorFrom: pink
|
| 5 |
-
colorTo: purple
|
| 6 |
sdk: docker
|
| 7 |
-
|
| 8 |
---
|
| 9 |
|
| 10 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
---
|
| 2 |
+
title: Split-Skip-and-Play
|
|
|
|
|
|
|
|
|
|
| 3 |
sdk: docker
|
| 4 |
+
app_port: 7860
|
| 5 |
---
|
| 6 |
|
| 7 |
+
|
| 8 |
+
# Split-Skip-and-Play
|
| 9 |
+
|
| 10 |
+
Code to reproduce results for paper ["SPLIT, SKIP AND PLAY: VARIANCE-REDUCED PROXSKIP FOR TOMOGRAPHY
|
| 11 |
+
RECONSTRUCTION IS EXTREMELY FAST"](https://arxiv.org/abs/2602.09527) by Evangelos Papoutsellis, Zeljko Kereta, Kostas Papafitsoros. This is an extension of the paper ["Why do we regularise in every iteration for imaging inverse problems?"](https://arxiv.org/abs/2411.00688).
|
| 12 |
+
|
| 13 |
+

|
| 14 |
+
|
| 15 |
+
[](https://mybinder.org/v2/gh/epapoutsellis/Split-Skip-and-Play/HEAD)
|
| 16 |
+
|
| 17 |
+
### Abstract
|
| 18 |
+
Many modern iterative solvers for large-scale tomographic reconstruction incur two major computational costs per iteration: expensive forward/adjoint projections to update the data fidelity term and costly proximal computations for the regulariser, often done via inner iterations. This paper studies for the first time the application of methods that couple randomised skipping of the proximal with variance-reduced subset-based optimisation of data-fit term, to simultaneously reduce both costs in challenging tomographic reconstruction tasks. We provide a series of experiments using both synthetic and real data, demonstrating striking speed-ups of the order 5x--20x compared to the non-skipped counterparts which have been so far the standard approach for efficiently solving these problems. Our work lays the groundwork for broader adoption of these methods in inverse problems.
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
### Installation
|
| 22 |
+
|
| 23 |
+
We use the [Core Imaging Library (CIL)](https://github.com/TomographicImaging/CIL) with some additional [new functionalities](#Appendix). Code and installation tested on macOS (Apple M2 Pro), Linux, and Windows 10.
|
| 24 |
+
|
| 25 |
+
```
|
| 26 |
+
conda create --name ssp -c conda-forge python=3.12 "numpy<2.0" cmake scipy six cython numba pillow jupyterlab scikit-learn dask "zarr<3" pywavelets astra-toolbox tqdm nb_conda_kernels
|
| 27 |
+
|
| 28 |
+
conda activate ssp
|
| 29 |
+
pip install bm3d xdesign scikit-image
|
| 30 |
+
pip install "setuptools<82" --upgrade
|
| 31 |
+
|
| 32 |
+
git clone https://github.com/epapoutsellis/StochasticCIL.git
|
| 33 |
+
cd StochasticCIL
|
| 34 |
+
git checkout svrg
|
| 35 |
+
git tag -a v1.0 -m "Version 1.0"
|
| 36 |
+
mkdir build
|
| 37 |
+
cd build
|
| 38 |
+
cmake ../ -DCONDA_BUILD=OFF -DCMAKE_BUILD_TYPE="Release" -DLIBRARY_LIB=$CONDA_PREFIX/lib -DLIBRARY_INC=$CONDA_PREFIX -DCMAKE_INSTALL_PREFIX=$CONDA_PREFIX
|
| 39 |
+
make install
|
| 40 |
+
```
|
| 41 |
+
For windows: `cmake ../ -DCONDA_BUILD=OFF`, `cmake --build . --target install`
|
| 42 |
+
|
| 43 |
+
### Citation
|
| 44 |
+
|
| 45 |
+
```
|
| 46 |
+
@article{2602.09527,
|
| 47 |
+
Author = {Evangelos Papoutsellis and Zeljko Kereta and Kostas Papafitsoros},
|
| 48 |
+
Title = {Split, Skip and Play: Variance-Reduced ProxSkip for Tomography Reconstruction is Extremely Fast},
|
| 49 |
+
Year = {2026},
|
| 50 |
+
Eprint = {arXiv:2602.09527},
|
| 51 |
+
}
|
| 52 |
+
```
|
| 53 |
+
|
| 54 |
+
### New CIL Functionalities
|
| 55 |
+
|
| 56 |
+
1) Splitting methods for Acquisition Data compatible with CIL and [SIRF](https://github.com/SyneRBI/SIRF).
|
| 57 |
+
2) Sampling methods for CIL Stochastic Functions (used also by SPDHG).
|
| 58 |
+
3) ApproximateGradientSumFunction (Base class for CIL Stochastic Functions)
|
| 59 |
+
4) SGFunction
|
| 60 |
+
5) SAGFunction
|
| 61 |
+
6) SAGAFunction
|
| 62 |
+
7) SVRGFunction
|
| 63 |
+
8) LSVRGFunction
|
| 64 |
+
9) PGA (Proximal Gradient Algorithm), base class for GD, ISTA, FISTA. Designed for fixed and/or adaptive Preconditioners and step sizes.
|
| 65 |
+
10) Preconditioner (base class). An instance of Preconditioner can be passed to GD, ISTA, FISTA.
|
| 66 |
+
11) StepSizeMethod (base class for step size search), including standard Armijo/Backtracking.
|
| 67 |
+
12) Callback utilities, including standard metrics (compared to a reference) and statistics of iterates. Any function (CIL), or callables from other libraries can be used.
|
| 68 |
+
13) PD3O (Primal Dual Three Operator Splitting Algorithm) which can be combined with a Stochastic CIL function.
|
| 69 |
+
14) SIRF Priors classes can be used in the CIL Optimisation Framework for free. In addition, SIRF ObjectiveFunction Classes can be used. This allows more flexibility for SIRF users to have control and monitor a CIL algorithm.
|
| 70 |
+
|
| 71 |
+
### Acknowledgements
|
| 72 |
+
|
| 73 |
+
E. Papoutsellis acknowledges funding through the Innovate UK Analysis for Innovators (A4i) program: "Denoising of chemical imaging and tomography data" under which the experiments were initially conducted, CCPi (EPSRC grant EP/T026677/1), CCP SyneRBI (EPSRC grant EP/T026693/1).
|
| 74 |
+
|
TV_Tomography_Reconstruction_VR_Skip.ipynb
ADDED
|
@@ -0,0 +1,556 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"cells": [
|
| 3 |
+
{
|
| 4 |
+
"cell_type": "code",
|
| 5 |
+
"execution_count": null,
|
| 6 |
+
"id": "e82fbdab-f56e-471c-a314-63c97d08d197",
|
| 7 |
+
"metadata": {},
|
| 8 |
+
"outputs": [],
|
| 9 |
+
"source": [
|
| 10 |
+
"from cil.framework import AcquisitionGeometry\n",
|
| 11 |
+
"from cil.optimisation.algorithms import ISTA, FISTA\n",
|
| 12 |
+
"from cil.optimisation.functions import SVRGFunction, LSVRGFunction, LeastSquares\n",
|
| 13 |
+
"from cil.optimisation.utilities import RandomSampling, MetricsDiagnostics\n",
|
| 14 |
+
"from cil.plugins.astra import FBP, ProjectionOperator\n",
|
| 15 |
+
"from TotalVariation import TotalVariationNew\n",
|
| 16 |
+
"from ProxSkip import ProxSkip\n",
|
| 17 |
+
"from utils import create_circular_mask, StoppingCriterion\n",
|
| 18 |
+
"\n",
|
| 19 |
+
"import numpy as np\n",
|
| 20 |
+
"import zarr\n",
|
| 21 |
+
"import matplotlib.pyplot as plt\n",
|
| 22 |
+
"\n",
|
| 23 |
+
"plt.rcParams['lines.linewidth'] = 5\n",
|
| 24 |
+
"plt.rcParams['lines.markersize'] = 20\n",
|
| 25 |
+
"plt.rcParams['font.size'] = 40\n",
|
| 26 |
+
"plt.rcParams[\"image.cmap\"] = \"inferno\""
|
| 27 |
+
]
|
| 28 |
+
},
|
| 29 |
+
{
|
| 30 |
+
"cell_type": "markdown",
|
| 31 |
+
"id": "e5e68083-d6ff-4cc8-9f4e-6e7fbcd38d43",
|
| 32 |
+
"metadata": {},
|
| 33 |
+
"source": [
|
| 34 |
+
"### Total Variation Reconstruction on a real dataset\n",
|
| 35 |
+
"\n",
|
| 36 |
+
"In this notebook, we play the following game. We start from a high-accuracy reference solution of\n",
|
| 37 |
+
"\n",
|
| 38 |
+
"$$ \n",
|
| 39 |
+
"\\begin{equation}\n",
|
| 40 |
+
"\\min_{x\\in\\mathbb{R}^n} \\; \\frac{1}{2}\\|Ax - b\\|_{2}^{2}\n",
|
| 41 |
+
"+ \\alpha \\mathrm{TV}(x)\n",
|
| 42 |
+
"+ \\mathbb{I}_{\\{x \\ge 0\\}}(x),\n",
|
| 43 |
+
"\\end{equation}\n",
|
| 44 |
+
"$$\n",
|
| 45 |
+
"\n",
|
| 46 |
+
"and we fix a target accuracy threshold $\\varepsilon>0$. The goal is to compare algorithms by **time-to-accuracy**: namely, determine which algorithm reaches $\\frac{\\|x_k - x^*\\|_2}{\\|x^*\\|_2}$ below $\\varepsilon$ in the **shortest runtime**.\n"
|
| 47 |
+
]
|
| 48 |
+
},
|
| 49 |
+
{
|
| 50 |
+
"cell_type": "markdown",
|
| 51 |
+
"id": "1d21a834-36d9-4a3c-b036-c9fdf5c3dfbe",
|
| 52 |
+
"metadata": {},
|
| 53 |
+
"source": [
|
| 54 |
+
"### Load real dataset"
|
| 55 |
+
]
|
| 56 |
+
},
|
| 57 |
+
{
|
| 58 |
+
"cell_type": "code",
|
| 59 |
+
"execution_count": null,
|
| 60 |
+
"id": "ee55e26e-95ac-48d0-91fa-8ffbe38825fc",
|
| 61 |
+
"metadata": {},
|
| 62 |
+
"outputs": [],
|
| 63 |
+
"source": [
|
| 64 |
+
"sino = zarr.load(\"data/NiPd_spent_01_microct_rings_removed_2D.zarr\")\n",
|
| 65 |
+
"\n",
|
| 66 |
+
"print(f\"sinogram shape is {sino.shape}\")\n",
|
| 67 |
+
"\n",
|
| 68 |
+
"_, horizontal = sino.shape\n",
|
| 69 |
+
"\n",
|
| 70 |
+
"angles_list = np.linspace(0, np.pi, 800)[::2]\n",
|
| 71 |
+
"ag2D = AcquisitionGeometry.create_Parallel2D().\\\n",
|
| 72 |
+
" set_panel(horizontal).\\\n",
|
| 73 |
+
" set_angles(angles_list, angle_unit=\"radian\").\\\n",
|
| 74 |
+
" set_labels(['angle','horizontal'])\n",
|
| 75 |
+
"ig2D = ag2D.get_ImageGeometry()\n",
|
| 76 |
+
"\n",
|
| 77 |
+
"data2D = ag2D.allocate()\n",
|
| 78 |
+
"data2D.fill(sino)"
|
| 79 |
+
]
|
| 80 |
+
},
|
| 81 |
+
{
|
| 82 |
+
"cell_type": "markdown",
|
| 83 |
+
"id": "41b24ec4-f012-4fa1-a2de-95b162228de1",
|
| 84 |
+
"metadata": {},
|
| 85 |
+
"source": [
|
| 86 |
+
"### FBP reconstruction"
|
| 87 |
+
]
|
| 88 |
+
},
|
| 89 |
+
{
|
| 90 |
+
"cell_type": "code",
|
| 91 |
+
"execution_count": null,
|
| 92 |
+
"id": "f350f391-d445-432e-a772-24810e3a6266",
|
| 93 |
+
"metadata": {},
|
| 94 |
+
"outputs": [],
|
| 95 |
+
"source": [
|
| 96 |
+
"fbp = FBP(ig2D, ag2D, device=\"cpu\")(data2D)\n",
|
| 97 |
+
"fbp.array[fbp.array<0] = 0\n",
|
| 98 |
+
"plt.imshow(fbp.array)"
|
| 99 |
+
]
|
| 100 |
+
},
|
| 101 |
+
{
|
| 102 |
+
"cell_type": "markdown",
|
| 103 |
+
"id": "b48aee65-6bae-4520-ae24-35d14555adbf",
|
| 104 |
+
"metadata": {},
|
| 105 |
+
"source": [
|
| 106 |
+
"### Load high accuracy solution "
|
| 107 |
+
]
|
| 108 |
+
},
|
| 109 |
+
{
|
| 110 |
+
"cell_type": "code",
|
| 111 |
+
"execution_count": null,
|
| 112 |
+
"id": "79240538-ca62-4f0b-b1e6-df38b8d48f4a",
|
| 113 |
+
"metadata": {},
|
| 114 |
+
"outputs": [],
|
| 115 |
+
"source": [
|
| 116 |
+
"alpha = 0.1"
|
| 117 |
+
]
|
| 118 |
+
},
|
| 119 |
+
{
|
| 120 |
+
"cell_type": "code",
|
| 121 |
+
"execution_count": null,
|
| 122 |
+
"id": "7f251eeb-3298-46ab-a8f3-bb8782d9ffd9",
|
| 123 |
+
"metadata": {},
|
| 124 |
+
"outputs": [],
|
| 125 |
+
"source": [
|
| 126 |
+
"pdhg_optimal_info = zarr.open_group(\"data/pdhg_optimal_finden_tv_alpha_{}_explicit_precond_maxiterations_200000.zarr\".format(alpha))\n",
|
| 127 |
+
"\n",
|
| 128 |
+
"pdhg_optimal_np = pdhg_optimal_info[\"solution\"][:]\n",
|
| 129 |
+
"pdhg_optimal_cil = ig2D.allocate()\n",
|
| 130 |
+
"pdhg_optimal_cil.fill(pdhg_optimal_np)"
|
| 131 |
+
]
|
| 132 |
+
},
|
| 133 |
+
{
|
| 134 |
+
"cell_type": "markdown",
|
| 135 |
+
"id": "44e0cb98-199d-46b1-9250-d876446a36f1",
|
| 136 |
+
"metadata": {},
|
| 137 |
+
"source": [
|
| 138 |
+
"### Family of considered algorithms (ProxSkip template)\n",
|
| 139 |
+
"\n",
|
| 140 |
+
"**Parameters:** $\\gamma>0$, probability $p\\in(0,1]$, data subsets $N$\n",
|
| 141 |
+
"**Initialize:** $x_0, h_0 \\in \\mathbb{R}^n$\n",
|
| 142 |
+
"\n",
|
| 143 |
+
"For $k=0,1,\\dots,K-1$:\n",
|
| 144 |
+
"\n",
|
| 145 |
+
"1. Compute $G_k$ (an unbiased estimator of $\\nabla f(x_k)$).\n",
|
| 146 |
+
"2. $$\\hat x_{k+1} = x_k - \\gamma\\big(G_k(x_k) - h_k\\big).$$\n",
|
| 147 |
+
"3. Sample $\\theta_k\\sim\\mathrm{Bernoulli}(p)$, $\\theta_k\\in{0,1}$.\n",
|
| 148 |
+
"4. If $\\theta_k=1$:\n",
|
| 149 |
+
" $$x_{k+1}=\\mathrm{prox}_{\\frac{\\gamma}{p}g}\\left(\\hat x_{k+1}-\\frac{\\gamma}{p}h_k\\right),$$\n",
|
| 150 |
+
" else:\n",
|
| 151 |
+
" $$x_{k+1}=\\hat x_{k+1}.$$\n",
|
| 152 |
+
"5. Update:\n",
|
| 153 |
+
" $$h_{k+1}=h_k+\\frac{p}{\\gamma}\\big(x_{k+1}-\\hat x_{k+1}\\big).$$\n",
|
| 154 |
+
"\n",
|
| 155 |
+
"---\n",
|
| 156 |
+
"\n",
|
| 157 |
+
"| $p=1$ | $0<p<1$ | $G_k$ |\n",
|
| 158 |
+
"| --------- | ------------- | -------------------------------------------------------------------- |\n",
|
| 159 |
+
"| ISTA | ProxSkip | $\\nabla f(x_k)$ |\n",
|
| 160 |
+
"| ProxSGD | ProxSGDSkip | $N\\nabla f_{i_k}(x_k)$ |\n",
|
| 161 |
+
"| ProxSAGA | ProxSAGASkip | $N(\\nabla f_{i_k}(x_k)-v_k^{,i_k})+\\bar v_k$ |\n",
|
| 162 |
+
"| ProxSVRG | ProxSVRGSkip | $N(\\nabla f_{i_k}(x_k)-\\nabla f_{i_k}(\\tilde x))+\\nabla f(\\tilde x)$ |\n",
|
| 163 |
+
"| ProxLSVRG | ProxLSVRGSkip | as above, updated with $p=1/N$ |\n",
|
| 164 |
+
"\n",
|
| 165 |
+
"*Note:* FISTA is ISTA with an acceleration step (Beck & Teboulle).\n"
|
| 166 |
+
]
|
| 167 |
+
},
|
| 168 |
+
{
|
| 169 |
+
"cell_type": "markdown",
|
| 170 |
+
"id": "6079afd7-9f2b-43bf-9a9a-59a589c52b8e",
|
| 171 |
+
"metadata": {},
|
| 172 |
+
"source": [
|
| 173 |
+
"### Define Stopping Criteria and Metrics (PSNR/SSIM)"
|
| 174 |
+
]
|
| 175 |
+
},
|
| 176 |
+
{
|
| 177 |
+
"cell_type": "code",
|
| 178 |
+
"execution_count": null,
|
| 179 |
+
"id": "baa0f235-7dab-430a-87c6-4b6d5a72f31a",
|
| 180 |
+
"metadata": {},
|
| 181 |
+
"outputs": [],
|
| 182 |
+
"source": [
|
| 183 |
+
"h, w = pdhg_optimal_cil.shape\n",
|
| 184 |
+
"mask = create_circular_mask(h, w, center=(170,165), radius=150)\n",
|
| 185 |
+
"\n",
|
| 186 |
+
"def NRSE(x, y, **kwargs):\n",
|
| 187 |
+
" return np.sqrt(np.sum(np.abs(x - y*mask.ravel())**2))/np.sqrt(np.sum(x**2))\n",
|
| 188 |
+
"\n",
|
| 189 |
+
"cb_metrics = MetricsDiagnostics(reference_image=pdhg_optimal_cil*mask, \n",
|
| 190 |
+
" metrics_dict={\"rse\":NRSE}) \n",
|
| 191 |
+
"\n",
|
| 192 |
+
"epsilon = 0.99e-5\n",
|
| 193 |
+
"inner_its = 10\n",
|
| 194 |
+
"epochs = 100\n",
|
| 195 |
+
"max_iteration = 5000\n",
|
| 196 |
+
"seed_skip = 42\n",
|
| 197 |
+
"seed_sampling = 42"
|
| 198 |
+
]
|
| 199 |
+
},
|
| 200 |
+
{
|
| 201 |
+
"cell_type": "code",
|
| 202 |
+
"execution_count": null,
|
| 203 |
+
"id": "f4146831-c967-407d-a1f0-aa0ff7dca979",
|
| 204 |
+
"metadata": {},
|
| 205 |
+
"outputs": [],
|
| 206 |
+
"source": [
|
| 207 |
+
"### Avoid computing objectives, not needed for this demo\n",
|
| 208 |
+
"def update_objective(self):\n",
|
| 209 |
+
" return 0.\n",
|
| 210 |
+
"\n",
|
| 211 |
+
"ISTA.update_objective = update_objective\n",
|
| 212 |
+
"ProxSkip.update_objective = update_objective\n",
|
| 213 |
+
"FISTA.update_objective = update_objective"
|
| 214 |
+
]
|
| 215 |
+
},
|
| 216 |
+
{
|
| 217 |
+
"cell_type": "markdown",
|
| 218 |
+
"id": "743c77b3-49a0-46c7-bf43-b322f78b1753",
|
| 219 |
+
"metadata": {},
|
| 220 |
+
"source": [
|
| 221 |
+
"### Run deterministic algorithms: No data splitting, no skipping"
|
| 222 |
+
]
|
| 223 |
+
},
|
| 224 |
+
{
|
| 225 |
+
"cell_type": "code",
|
| 226 |
+
"execution_count": null,
|
| 227 |
+
"id": "86570298-273c-4ab5-a04b-96e70df402ac",
|
| 228 |
+
"metadata": {},
|
| 229 |
+
"outputs": [],
|
| 230 |
+
"source": [
|
| 231 |
+
"K = ProjectionOperator(ig2D, ag2D, device=\"cpu\")\n",
|
| 232 |
+
"F = LeastSquares(A=K, b=data2D, c=0.5)\n",
|
| 233 |
+
"G = alpha * TotalVariationNew(max_iteration = inner_its, tolerance=None, correlation='Space',\n",
|
| 234 |
+
" backend='c', lower=0, upper=np.infty, isotropic=True, \n",
|
| 235 |
+
" split=False, info=False, strong_convexity_constant=0,\n",
|
| 236 |
+
" warm_start=True)\n",
|
| 237 |
+
"initial = ig2D.allocate()"
|
| 238 |
+
]
|
| 239 |
+
},
|
| 240 |
+
{
|
| 241 |
+
"cell_type": "markdown",
|
| 242 |
+
"id": "71eca541-0265-4af8-bbbf-c7b791bbe362",
|
| 243 |
+
"metadata": {},
|
| 244 |
+
"source": [
|
| 245 |
+
"### FISTA"
|
| 246 |
+
]
|
| 247 |
+
},
|
| 248 |
+
{
|
| 249 |
+
"cell_type": "code",
|
| 250 |
+
"execution_count": null,
|
| 251 |
+
"id": "f5a3e483-7bdc-45da-b35b-c8147f1c6ce2",
|
| 252 |
+
"metadata": {},
|
| 253 |
+
"outputs": [],
|
| 254 |
+
"source": [
|
| 255 |
+
"step_size = 1./F.L\n",
|
| 256 |
+
"cb_error = StoppingCriterion(epsilon=epsilon) \n",
|
| 257 |
+
"fista = FISTA(initial = initial, f = F, step_size = step_size, g=G, \n",
|
| 258 |
+
" update_objective_interval = 1,\n",
|
| 259 |
+
" max_iteration = max_iteration) \n",
|
| 260 |
+
"fista.run(verbose=0, callback=[cb_metrics, cb_error])"
|
| 261 |
+
]
|
| 262 |
+
},
|
| 263 |
+
{
|
| 264 |
+
"cell_type": "markdown",
|
| 265 |
+
"id": "47e90863-2b50-4d61-a9b3-e0918f3bf316",
|
| 266 |
+
"metadata": {},
|
| 267 |
+
"source": [
|
| 268 |
+
"### ISTA"
|
| 269 |
+
]
|
| 270 |
+
},
|
| 271 |
+
{
|
| 272 |
+
"cell_type": "code",
|
| 273 |
+
"execution_count": null,
|
| 274 |
+
"id": "f11fc38d-cca0-4fad-bd5c-86f3e3402e41",
|
| 275 |
+
"metadata": {},
|
| 276 |
+
"outputs": [],
|
| 277 |
+
"source": [
|
| 278 |
+
"step_size = 1.99/F.L\n",
|
| 279 |
+
"cb_error = StoppingCriterion(epsilon=epsilon) \n",
|
| 280 |
+
"ista = ISTA(initial = initial, f = F, step_size = step_size, g=G, \n",
|
| 281 |
+
" update_objective_interval = 1,\n",
|
| 282 |
+
" max_iteration = max_iteration) \n",
|
| 283 |
+
"ista.run(verbose=0, callback=[cb_metrics, cb_error])"
|
| 284 |
+
]
|
| 285 |
+
},
|
| 286 |
+
{
|
| 287 |
+
"cell_type": "markdown",
|
| 288 |
+
"id": "9dd6d3f8-7418-44b4-9a6a-f831f116cced",
|
| 289 |
+
"metadata": {},
|
| 290 |
+
"source": [
|
| 291 |
+
"### Run ProxSkip: Skip the regulariser"
|
| 292 |
+
]
|
| 293 |
+
},
|
| 294 |
+
{
|
| 295 |
+
"cell_type": "code",
|
| 296 |
+
"execution_count": null,
|
| 297 |
+
"id": "72f458c9-6f9c-4a2d-8877-19a9dc5a638f",
|
| 298 |
+
"metadata": {},
|
| 299 |
+
"outputs": [],
|
| 300 |
+
"source": [
|
| 301 |
+
"prob = 0.1"
|
| 302 |
+
]
|
| 303 |
+
},
|
| 304 |
+
{
|
| 305 |
+
"cell_type": "code",
|
| 306 |
+
"execution_count": null,
|
| 307 |
+
"id": "805fd84c-9f20-45ba-8188-957f044cee89",
|
| 308 |
+
"metadata": {},
|
| 309 |
+
"outputs": [],
|
| 310 |
+
"source": [
|
| 311 |
+
"step_size = 1.99/F.L\n",
|
| 312 |
+
"cb_error = StoppingCriterion(epsilon=epsilon) \n",
|
| 313 |
+
"proxskip = ProxSkip(initial = [initial, initial], f = F, step_size = step_size, g=G, \n",
|
| 314 |
+
" update_objective_interval = 1, prob=prob, seed = seed_skip,\n",
|
| 315 |
+
" max_iteration = max_iteration) \n",
|
| 316 |
+
"proxskip.run(verbose=0, callback=[cb_metrics, cb_error])\n"
|
| 317 |
+
]
|
| 318 |
+
},
|
| 319 |
+
{
|
| 320 |
+
"cell_type": "markdown",
|
| 321 |
+
"id": "f6d2c3a3-ec14-4b9c-ad87-76e62f1161fe",
|
| 322 |
+
"metadata": {},
|
| 323 |
+
"source": [
|
| 324 |
+
"### Run stochastic algorithms: Data splitting, no skipping"
|
| 325 |
+
]
|
| 326 |
+
},
|
| 327 |
+
{
|
| 328 |
+
"cell_type": "code",
|
| 329 |
+
"execution_count": null,
|
| 330 |
+
"id": "8fbc2e91-2a5d-4f50-9f60-c2ced3b41038",
|
| 331 |
+
"metadata": {},
|
| 332 |
+
"outputs": [],
|
| 333 |
+
"source": [
|
| 334 |
+
"def list_of_functions(data):\n",
|
| 335 |
+
" \n",
|
| 336 |
+
" list_funcs = []\n",
|
| 337 |
+
" ig = data[0].geometry.get_ImageGeometry()\n",
|
| 338 |
+
" \n",
|
| 339 |
+
" for d in data:\n",
|
| 340 |
+
" ageom_subset = d.geometry \n",
|
| 341 |
+
" Ai = ProjectionOperator(ig, ageom_subset, device = 'cpu') \n",
|
| 342 |
+
" fi = LeastSquares(Ai, b = d, c = 0.5)\n",
|
| 343 |
+
" list_funcs.append(fi) \n",
|
| 344 |
+
" \n",
|
| 345 |
+
" return list_funcs"
|
| 346 |
+
]
|
| 347 |
+
},
|
| 348 |
+
{
|
| 349 |
+
"cell_type": "code",
|
| 350 |
+
"execution_count": null,
|
| 351 |
+
"id": "2364fecd-aab6-4653-8ba2-cf855a9ebda9",
|
| 352 |
+
"metadata": {},
|
| 353 |
+
"outputs": [],
|
| 354 |
+
"source": [
|
| 355 |
+
"nsub = 200\n",
|
| 356 |
+
"data_split, method = data2D.split_to_subsets(nsub, method= \"ordered\", info=True)\n",
|
| 357 |
+
"list_func = list_of_functions(data_split) "
|
| 358 |
+
]
|
| 359 |
+
},
|
| 360 |
+
{
|
| 361 |
+
"cell_type": "code",
|
| 362 |
+
"execution_count": null,
|
| 363 |
+
"id": "a25d5716-cba8-4c78-9327-043d351ed2b1",
|
| 364 |
+
"metadata": {},
|
| 365 |
+
"outputs": [],
|
| 366 |
+
"source": [
|
| 367 |
+
"selection = RandomSampling(len(list_func), nsub, seed=seed_sampling)\n",
|
| 368 |
+
"Fsvrg = SVRGFunction(list_func, selection = selection, update_frequency=len(list_func))\n",
|
| 369 |
+
"Fsvrg.initial = initial\n",
|
| 370 |
+
"step_size = 1./(Fsvrg.L)\n",
|
| 371 |
+
"cb_error = StoppingCriterion(epsilon=epsilon, epochs = 100) \n",
|
| 372 |
+
"prox_svrg = ISTA(initial = initial, f = Fsvrg, step_size = step_size, g=G, \n",
|
| 373 |
+
" update_objective_interval = 1, \n",
|
| 374 |
+
" max_iteration = max_iteration) \n",
|
| 375 |
+
"prox_svrg.run(verbose=0, callback=[cb_metrics, cb_error])"
|
| 376 |
+
]
|
| 377 |
+
},
|
| 378 |
+
{
|
| 379 |
+
"cell_type": "code",
|
| 380 |
+
"execution_count": null,
|
| 381 |
+
"id": "1dce5d0d-4873-40f9-8339-d35fd5eac23f",
|
| 382 |
+
"metadata": {},
|
| 383 |
+
"outputs": [],
|
| 384 |
+
"source": [
|
| 385 |
+
"selection = RandomSampling(len(list_func), nsub, seed=seed_sampling)\n",
|
| 386 |
+
"Flsvrg = LSVRGFunction(list_func, selection = selection, update_prob=1./len(list_func))\n",
|
| 387 |
+
"Flsvrg.initial = initial\n",
|
| 388 |
+
"step_size = 1./(Fsvrg.L)\n",
|
| 389 |
+
"cb_error = StoppingCriterion(epsilon=epsilon, epochs = 100) \n",
|
| 390 |
+
"prox_lsvrg = ISTA(initial = initial, f = Flsvrg, step_size = step_size, g=G, \n",
|
| 391 |
+
" update_objective_interval = 1, \n",
|
| 392 |
+
" max_iteration = max_iteration) \n",
|
| 393 |
+
"prox_lsvrg.run(verbose=0, callback=[cb_metrics, cb_error])"
|
| 394 |
+
]
|
| 395 |
+
},
|
| 396 |
+
{
|
| 397 |
+
"cell_type": "markdown",
|
| 398 |
+
"id": "59988152-5f09-4950-a770-a8b4d435ee44",
|
| 399 |
+
"metadata": {},
|
| 400 |
+
"source": [
|
| 401 |
+
"### Run stochastic algorithms: Data splitting and skipping"
|
| 402 |
+
]
|
| 403 |
+
},
|
| 404 |
+
{
|
| 405 |
+
"cell_type": "code",
|
| 406 |
+
"execution_count": null,
|
| 407 |
+
"id": "4d74a1ae-4f52-4e04-a4a4-03b4488f5fef",
|
| 408 |
+
"metadata": {},
|
| 409 |
+
"outputs": [],
|
| 410 |
+
"source": [
|
| 411 |
+
"selection = RandomSampling(len(list_func), nsub, seed=seed_sampling)\n",
|
| 412 |
+
"Fsvrg_skip = SVRGFunction(list_func, selection = selection, update_frequency=len(list_func))\n",
|
| 413 |
+
"Fsvrg_skip.initial = initial\n",
|
| 414 |
+
"step_size = 1./(Fsvrg_skip.L)\n",
|
| 415 |
+
"cb_error = StoppingCriterion(epsilon=epsilon, epochs = 100) \n",
|
| 416 |
+
"prox_svrg_skip = ProxSkip(initial = [initial, initial], f = Fsvrg_skip, step_size = step_size, g=G, \n",
|
| 417 |
+
" update_objective_interval = 1, prob = prob,\n",
|
| 418 |
+
" max_iteration = max_iteration) \n",
|
| 419 |
+
"prox_svrg_skip.run(verbose=0, callback=[cb_metrics, cb_error])"
|
| 420 |
+
]
|
| 421 |
+
},
|
| 422 |
+
{
|
| 423 |
+
"cell_type": "code",
|
| 424 |
+
"execution_count": null,
|
| 425 |
+
"id": "2244d1db-1338-46ca-8355-bfcc03ce7e56",
|
| 426 |
+
"metadata": {},
|
| 427 |
+
"outputs": [],
|
| 428 |
+
"source": [
|
| 429 |
+
"selection = RandomSampling(len(list_func), nsub, seed=seed_sampling)\n",
|
| 430 |
+
"Flsvrg_skip = LSVRGFunction(list_func, selection = selection, update_prob=1./len(list_func))\n",
|
| 431 |
+
"Flsvrg_skip.initial = initial\n",
|
| 432 |
+
"step_size = 1./(Flsvrg_skip.L)\n",
|
| 433 |
+
"cb_error = StoppingCriterion(epsilon=epsilon, epochs = 100) \n",
|
| 434 |
+
"prox_lsvrg_skip = ProxSkip(initial = [initial, initial], f = Flsvrg_skip, step_size = step_size, g=G, \n",
|
| 435 |
+
" update_objective_interval = 1, prob = prob,\n",
|
| 436 |
+
" max_iteration = max_iteration) \n",
|
| 437 |
+
"prox_lsvrg_skip.run(verbose=0, callback=[cb_metrics, cb_error])"
|
| 438 |
+
]
|
| 439 |
+
},
|
| 440 |
+
{
|
| 441 |
+
"cell_type": "markdown",
|
| 442 |
+
"id": "f0d1a611-d16c-432e-a6e5-56cae3b7e640",
|
| 443 |
+
"metadata": {},
|
| 444 |
+
"source": [
|
| 445 |
+
"### Plot PSNR/SSIM progress\n",
|
| 446 |
+
"- with respect to iteration\n",
|
| 447 |
+
"- with respect to time"
|
| 448 |
+
]
|
| 449 |
+
},
|
| 450 |
+
{
|
| 451 |
+
"cell_type": "code",
|
| 452 |
+
"execution_count": null,
|
| 453 |
+
"id": "ee4475aa-9018-427e-82c4-63d0f487e899",
|
| 454 |
+
"metadata": {},
|
| 455 |
+
"outputs": [],
|
| 456 |
+
"source": [
|
| 457 |
+
"def find_iter_to_error(algo, error):\n",
|
| 458 |
+
" rse = np.asarray(algo.rse)\n",
|
| 459 |
+
" idx = np.where(rse < error)[0]\n",
|
| 460 |
+
" return np.nan if idx.size == 0 else int(idx[0])\n",
|
| 461 |
+
"\n",
|
| 462 |
+
"def time_to_iter(algo, it):\n",
|
| 463 |
+
" \"\"\"Time up to and including iteration it.\"\"\"\n",
|
| 464 |
+
" if it is None or (isinstance(it, float) and np.isnan(it)):\n",
|
| 465 |
+
" return np.nan\n",
|
| 466 |
+
" it = int(it)\n",
|
| 467 |
+
" return float(np.sum(np.asarray(algo.timing)[:it+1]))\n",
|
| 468 |
+
"\n",
|
| 469 |
+
"def time_to_error(algo, error):\n",
|
| 470 |
+
" it = find_iter_to_error(algo, error)\n",
|
| 471 |
+
" return time_to_iter(algo, it)"
|
| 472 |
+
]
|
| 473 |
+
},
|
| 474 |
+
{
|
| 475 |
+
"cell_type": "code",
|
| 476 |
+
"execution_count": null,
|
| 477 |
+
"id": "a74d852a-744f-46ee-927b-2152c130bd6d",
|
| 478 |
+
"metadata": {},
|
| 479 |
+
"outputs": [],
|
| 480 |
+
"source": [
|
| 481 |
+
"t_ista = time_to_error(ista, epsilon)\n",
|
| 482 |
+
"t_proxskip = time_to_error(proxskip, epsilon)\n",
|
| 483 |
+
"t_prox_svrg =time_to_error(prox_svrg, epsilon)\n",
|
| 484 |
+
"t_prox_lsvrg = time_to_error(prox_lsvrg, epsilon)\n",
|
| 485 |
+
"t_prox_svrg_skip = time_to_error(prox_svrg_skip, epsilon)\n",
|
| 486 |
+
"t_prox_lsvrg_skip = time_to_error(prox_lsvrg_skip, epsilon)"
|
| 487 |
+
]
|
| 488 |
+
},
|
| 489 |
+
{
|
| 490 |
+
"cell_type": "code",
|
| 491 |
+
"execution_count": null,
|
| 492 |
+
"id": "f28efd88-80de-494b-a991-207df0444ad8",
|
| 493 |
+
"metadata": {},
|
| 494 |
+
"outputs": [],
|
| 495 |
+
"source": [
|
| 496 |
+
"fig, ax = plt.subplots(2,1,figsize=(30, 25))\n",
|
| 497 |
+
"\n",
|
| 498 |
+
"fig.subplots_adjust(top=0.80)\n",
|
| 499 |
+
"\n",
|
| 500 |
+
"ax[0].semilogy(prox_svrg.rse[:-1],label=f\"ProxSVRG (N=200)\")\n",
|
| 501 |
+
"ax[0].semilogy(prox_lsvrg.rse[:-1],label=f\"ProxLSVRG (N=200)\")\n",
|
| 502 |
+
"ax[0].semilogy(proxskip.rse[:-1],label=f\"ProxSkip (p=0.1)\")\n",
|
| 503 |
+
"ax[0].semilogy(prox_svrg_skip.rse[:-1],label=f\"ProxSVRGSkip (N=200, p=0.05)\")\n",
|
| 504 |
+
"ax[0].semilogy(prox_lsvrg_skip.rse[:-1],label=f\"ProxLSVRGSkip (N=200, p=0.05)\")\n",
|
| 505 |
+
"ax[0].semilogy(ista.rse[:-1],label=f\"ISTA\")\n",
|
| 506 |
+
"ax[0].semilogy(fista.rse[:-1],label=f\"FISTA\")\n",
|
| 507 |
+
"ax[0].set_xlabel(\"Iteration\")\n",
|
| 508 |
+
"ax[0].set_ylabel(r\"$\\frac{\\|x_{k} - x^{*}\\|_{2}}{\\|x^{*}\\|_{2}}$\")\n",
|
| 509 |
+
"ax[0].grid(True, which=\"major\")\n",
|
| 510 |
+
"\n",
|
| 511 |
+
"ax[0].text(\n",
|
| 512 |
+
" 0.25, 0.9, \"Proximal-TV with 10 iterations\",\n",
|
| 513 |
+
" transform=ax[0].transAxes,\n",
|
| 514 |
+
" ha=\"left\", va=\"top\",\n",
|
| 515 |
+
" fontsize=45,\n",
|
| 516 |
+
" bbox=dict(boxstyle=\"round,pad=0.35\", facecolor=\"white\", edgecolor=\"none\")\n",
|
| 517 |
+
")\n",
|
| 518 |
+
"ax[0].legend(ncols=2, loc=\"lower center\", bbox_to_anchor=(0.5, 1.00), frameon=True)\n",
|
| 519 |
+
"\n",
|
| 520 |
+
"ax[1].semilogy(np.cumsum(prox_svrg.timing), prox_svrg.rse[:-1],label=f\"ProxSVRG (N=200)\")\n",
|
| 521 |
+
"ax[1].semilogy(np.cumsum(prox_lsvrg.timing), prox_lsvrg.rse[:-1],label=f\"ProxLSVRG (N=200)\")\n",
|
| 522 |
+
"ax[1].semilogy(np.cumsum(proxskip.timing), proxskip.rse[:-1],label=f\"ProxSkip (p=0.1)\")\n",
|
| 523 |
+
"ax[1].semilogy(np.cumsum(prox_svrg_skip.timing), prox_svrg_skip.rse[:-1],label=f\"ProxSVRGSkip (N=200, p=0.05)\")\n",
|
| 524 |
+
"ax[1].semilogy(np.cumsum(prox_lsvrg_skip.timing), prox_lsvrg_skip.rse[:-1],label=f\"ProxLSVRGSkip (N=200, p=0.05)\")\n",
|
| 525 |
+
"ax[1].semilogy(np.cumsum(ista.timing), ista.rse[:-1],label=f\"ISTA\")\n",
|
| 526 |
+
"ax[1].semilogy(np.cumsum(fista.timing), fista.rse[:-1],label=f\"FISTA\")\n",
|
| 527 |
+
"ax[1].set_xlabel(\"Time (sec)\")\n",
|
| 528 |
+
"ax[1].set_ylabel(r\"$\\frac{\\|x_{k} - x^{*}\\|_{2}}{\\|x^{*}\\|_{2}}$\")\n",
|
| 529 |
+
"ax[1].grid(True, which=\"major\")\n",
|
| 530 |
+
"\n",
|
| 531 |
+
"plt.show()"
|
| 532 |
+
]
|
| 533 |
+
}
|
| 534 |
+
],
|
| 535 |
+
"metadata": {
|
| 536 |
+
"kernelspec": {
|
| 537 |
+
"display_name": "Python [conda env:ssp]",
|
| 538 |
+
"language": "python",
|
| 539 |
+
"name": "conda-env-ssp-py"
|
| 540 |
+
},
|
| 541 |
+
"language_info": {
|
| 542 |
+
"codemirror_mode": {
|
| 543 |
+
"name": "ipython",
|
| 544 |
+
"version": 3
|
| 545 |
+
},
|
| 546 |
+
"file_extension": ".py",
|
| 547 |
+
"mimetype": "text/x-python",
|
| 548 |
+
"name": "python",
|
| 549 |
+
"nbconvert_exporter": "python",
|
| 550 |
+
"pygments_lexer": "ipython3",
|
| 551 |
+
"version": "3.12.12"
|
| 552 |
+
}
|
| 553 |
+
},
|
| 554 |
+
"nbformat": 4,
|
| 555 |
+
"nbformat_minor": 5
|
| 556 |
+
}
|
TotalVariation.py
ADDED
|
@@ -0,0 +1,399 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# -*- coding: utf-8 -*-
|
| 2 |
+
# Copyright 2020 United Kingdom Research and Innovation
|
| 3 |
+
# Copyright 2020 The University of Manchester
|
| 4 |
+
#
|
| 5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 6 |
+
# you may not use this file except in compliance with the License.
|
| 7 |
+
# You may obtain a copy of the License at
|
| 8 |
+
#
|
| 9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 10 |
+
#
|
| 11 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 14 |
+
# See the License for the specific language governing permissions and
|
| 15 |
+
# limitations under the License.
|
| 16 |
+
#
|
| 17 |
+
# Authors:
|
| 18 |
+
# CIL Developers, listed at: https://github.com/TomographicImaging/CIL/blob/master/NOTICE.txt
|
| 19 |
+
# Claire Delplancke (University of Bath)
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
from cil.optimisation.functions import Function, IndicatorBox, MixedL21Norm, MixedL11Norm
|
| 23 |
+
from cil.optimisation.operators import GradientOperator
|
| 24 |
+
import numpy as np
|
| 25 |
+
from numbers import Number
|
| 26 |
+
import warnings
|
| 27 |
+
import logging
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
class TotalVariationNew(Function):
|
| 31 |
+
|
| 32 |
+
r""" Total variation Function
|
| 33 |
+
|
| 34 |
+
.. math:: \mathrm{TV}(u) := \|\nabla u\|_{2,1} = \sum \|\nabla u\|_{2},\, (\mbox{isotropic})
|
| 35 |
+
|
| 36 |
+
.. math:: \mathrm{TV}(u) := \|\nabla u\|_{1,1} = \sum \|\nabla u\|_{1}\, (\mbox{anisotropic})
|
| 37 |
+
|
| 38 |
+
Notes
|
| 39 |
+
-----
|
| 40 |
+
|
| 41 |
+
The :code:`TotalVariation` (TV) :code:`Function` acts as a composite function, i.e.,
|
| 42 |
+
the composition of the :class:`.MixedL21Norm` function and the :class:`.GradientOperator` operator,
|
| 43 |
+
|
| 44 |
+
.. math:: f(u) = \|u\|_{2,1}, \Rightarrow (f\circ\nabla)(u) = f(\nabla x) = \mathrm{TV}(u)
|
| 45 |
+
|
| 46 |
+
In that case, the proximal operator of TV does not have an exact solution and we use an iterative
|
| 47 |
+
algorithm to solve:
|
| 48 |
+
|
| 49 |
+
.. math:: \mathrm{prox}_{\tau \mathrm{TV}}(b) := \underset{u}{\mathrm{argmin}} \frac{1}{2\tau}\|u - b\|^{2} + \mathrm{TV}(u)
|
| 50 |
+
|
| 51 |
+
The algorithm used for the proximal operator of TV is the Fast Gradient Projection algorithm (or FISTA)
|
| 52 |
+
applied to the _dual problem_ of the above problem, see :cite:`BeckTeboulle_b`, :cite:`BeckTeboulle_a`, :cite:`Zhu2010`.
|
| 53 |
+
|
| 54 |
+
See also "Multicontrast MRI Reconstruction with Structure-Guided Total Variation", Ehrhardt, Betcke, 2016.
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
Parameters
|
| 58 |
+
----------
|
| 59 |
+
|
| 60 |
+
max_iteration : :obj:`int`, default = 5
|
| 61 |
+
Maximum number of iterations for the FGP algorithm to solve to solve the dual problem
|
| 62 |
+
of the Total Variation Denoising problem (ROF). If warm_start=False, this should be around 100,
|
| 63 |
+
or larger, with a set tolerance.
|
| 64 |
+
tolerance : :obj:`float`, default = None
|
| 65 |
+
Stopping criterion for the FGP algorithm used to to solve the dual problem
|
| 66 |
+
of the Total Variation Denoising problem (ROF). If the difference between iterates in the FGP algorithm is less than the tolerance
|
| 67 |
+
the iterations end before the max_iteration number.
|
| 68 |
+
|
| 69 |
+
.. math:: \|x^{k+1} - x^{k}\|_{2} < \mathrm{tolerance}
|
| 70 |
+
|
| 71 |
+
correlation : :obj:`str`, default = `Space`
|
| 72 |
+
Correlation between `Space` and/or `SpaceChannels` for the :class:`.GradientOperator`.
|
| 73 |
+
backend : :obj:`str`, default = `c`
|
| 74 |
+
Backend to compute the :class:`.GradientOperator`
|
| 75 |
+
lower : :obj:`'float`, default = None
|
| 76 |
+
A constraint is enforced using the :class:`.IndicatorBox` function, e.g., :code:`IndicatorBox(lower, upper)`.
|
| 77 |
+
upper : :obj:`'float`, default = None
|
| 78 |
+
A constraint is enforced using the :class:`.IndicatorBox` function, e.g., :code:`IndicatorBox(lower, upper)`.
|
| 79 |
+
isotropic : :obj:`boolean`, default = True
|
| 80 |
+
Use either isotropic or anisotropic definition of TV.
|
| 81 |
+
|
| 82 |
+
.. math:: |x|_{2} = \sqrt{x_{1}^{2} + x_{2}^{2}},\, (\mbox{isotropic})
|
| 83 |
+
|
| 84 |
+
.. math:: |x|_{1} = |x_{1}| + |x_{2}|\, (\mbox{anisotropic})
|
| 85 |
+
|
| 86 |
+
split : :obj:`boolean`, default = False
|
| 87 |
+
Splits the Gradient into spatial gradient and spectral or temporal gradient for multichannel data.
|
| 88 |
+
|
| 89 |
+
info : :obj:`boolean`, default = False
|
| 90 |
+
Information is printed for the stopping criterion of the FGP algorithm used to solve the dual problem
|
| 91 |
+
of the Total Variation Denoising problem (ROF).
|
| 92 |
+
|
| 93 |
+
strong_convexity_constant : :obj:`float`, default = 0
|
| 94 |
+
A strongly convex term weighted by the :code:`strong_convexity_constant` (:math:`\gamma`) parameter is added to the Total variation.
|
| 95 |
+
Now the :code:`TotalVariation` function is :math:`\gamma` - strongly convex and the proximal operator is
|
| 96 |
+
|
| 97 |
+
.. math:: \underset{u}{\mathrm{argmin}} \frac{1}{2\tau}\|u - b\|^{2} + \mathrm{TV}(u) + \frac{\gamma}{2}\|u\|^{2} \Leftrightarrow
|
| 98 |
+
|
| 99 |
+
.. math:: \underset{u}{\mathrm{argmin}} \frac{1}{2\frac{\tau}{1+\gamma\tau}}\|u - \frac{b}{1+\gamma\tau}\|^{2} + \mathrm{TV}(u)
|
| 100 |
+
|
| 101 |
+
warm_start : :obj:`boolean`, default = True
|
| 102 |
+
If set to true, the FGP algorithm used to solve the dual problem of the Total Variation Denoising problem (ROF) is initiated by the final value from the previous iteration and not at zero.
|
| 103 |
+
This allows the max_iteration value to be reduced to 5-10 iterations.
|
| 104 |
+
|
| 105 |
+
|
| 106 |
+
Note
|
| 107 |
+
----
|
| 108 |
+
|
| 109 |
+
With warm_start set to the default, True, the TV function will keep in memory the range of the gradient of the image to be denoised, i.e. N times the dimensionality of the image. This increases the memory requirements.
|
| 110 |
+
However, during the evaluation of `proximal` the memory requirements will be unchanged as the same amount of memory will need to be allocated and deallocated.
|
| 111 |
+
|
| 112 |
+
Note
|
| 113 |
+
----
|
| 114 |
+
|
| 115 |
+
In the case where the Total variation becomes a :math:`\gamma` - strongly convex function, i.e.,
|
| 116 |
+
|
| 117 |
+
.. math:: \mathrm{TV}(u) + \frac{\gamma}{2}\|u\|^{2}
|
| 118 |
+
|
| 119 |
+
:math:`\gamma` should be relatively small, so as the second term above will not act as an additional regulariser.
|
| 120 |
+
For more information, see :cite:`Rasch2020`, :cite:`CP2011`.
|
| 121 |
+
|
| 122 |
+
|
| 123 |
+
|
| 124 |
+
|
| 125 |
+
Examples
|
| 126 |
+
--------
|
| 127 |
+
|
| 128 |
+
.. math:: \underset{u}{\mathrm{argmin}} \frac{1}{2}\|u - b\|^{2} + \alpha\|\nabla u\|_{2,1}
|
| 129 |
+
|
| 130 |
+
>>> alpha = 2.0
|
| 131 |
+
>>> TV = TotalVariation()
|
| 132 |
+
>>> sol = TV.proximal(b, tau = alpha)
|
| 133 |
+
|
| 134 |
+
Examples
|
| 135 |
+
--------
|
| 136 |
+
|
| 137 |
+
.. math:: \underset{u}{\mathrm{argmin}} \frac{1}{2}\|u - b\|^{2} + \alpha\|\nabla u\|_{1,1} + \mathbb{I}_{C}(u)
|
| 138 |
+
|
| 139 |
+
where :math:`C = \{1.0\leq u\leq 2.0\}`.
|
| 140 |
+
|
| 141 |
+
>>> alpha = 2.0
|
| 142 |
+
>>> TV = TotalVariation(isotropic=False, lower=1.0, upper=2.0)
|
| 143 |
+
>>> sol = TV.proximal(b, tau = alpha)
|
| 144 |
+
|
| 145 |
+
|
| 146 |
+
Examples
|
| 147 |
+
--------
|
| 148 |
+
|
| 149 |
+
.. math:: \underset{u}{\mathrm{argmin}} \frac{1}{2}\|u - b\|^{2} + (\alpha\|\nabla u\|_{2,1} + \frac{\gamma}{2}\|u\|^{2})
|
| 150 |
+
|
| 151 |
+
>>> alpha = 2.0
|
| 152 |
+
>>> gamma = 1e-3
|
| 153 |
+
>>> TV = alpha * TotalVariation(isotropic=False, strong_convexity_constant=gamma)
|
| 154 |
+
>>> sol = TV.proximal(b, tau = 1.0)
|
| 155 |
+
|
| 156 |
+
"""
|
| 157 |
+
|
| 158 |
+
def __init__(self,
|
| 159 |
+
max_iteration=10,
|
| 160 |
+
tolerance=None,
|
| 161 |
+
correlation="Space",
|
| 162 |
+
backend="c",
|
| 163 |
+
lower=None,
|
| 164 |
+
upper=None,
|
| 165 |
+
isotropic=True,
|
| 166 |
+
split=False,
|
| 167 |
+
info=False,
|
| 168 |
+
strong_convexity_constant=0,
|
| 169 |
+
warm_start=True):
|
| 170 |
+
|
| 171 |
+
super(TotalVariationNew, self).__init__(L=None)
|
| 172 |
+
|
| 173 |
+
# Regularising parameter = alpha
|
| 174 |
+
self.regularisation_parameter = 1.
|
| 175 |
+
|
| 176 |
+
self.iterations = max_iteration
|
| 177 |
+
|
| 178 |
+
self.tolerance = tolerance
|
| 179 |
+
|
| 180 |
+
# Total variation correlation (isotropic=Default)
|
| 181 |
+
self.isotropic = isotropic
|
| 182 |
+
|
| 183 |
+
# correlation space or spacechannels
|
| 184 |
+
self.correlation = correlation
|
| 185 |
+
self.backend = backend
|
| 186 |
+
|
| 187 |
+
# Define orthogonal projection onto the convex set C
|
| 188 |
+
if lower is None:
|
| 189 |
+
lower = -np.inf
|
| 190 |
+
if upper is None:
|
| 191 |
+
upper = np.inf
|
| 192 |
+
self.lower = lower
|
| 193 |
+
self.upper = upper
|
| 194 |
+
self.projection_C = IndicatorBox(lower, upper).proximal
|
| 195 |
+
|
| 196 |
+
# Setup GradientOperator as None. This is to avoid domain argument in the __init__
|
| 197 |
+
self._gradient = None
|
| 198 |
+
self._domain = None
|
| 199 |
+
|
| 200 |
+
self.info = info
|
| 201 |
+
if self.info:
|
| 202 |
+
warnings.warn(" `info` is deprecate. Please use logging instead.")
|
| 203 |
+
|
| 204 |
+
# splitting Gradient
|
| 205 |
+
self.split = split
|
| 206 |
+
|
| 207 |
+
# For the warm_start functionality
|
| 208 |
+
self.warm_start = warm_start
|
| 209 |
+
self._p2 = None
|
| 210 |
+
|
| 211 |
+
# Strong convexity for TV
|
| 212 |
+
self.strong_convexity_constant = strong_convexity_constant
|
| 213 |
+
|
| 214 |
+
# Define Total variation norm
|
| 215 |
+
if self.isotropic:
|
| 216 |
+
self.func = MixedL21Norm()
|
| 217 |
+
else:
|
| 218 |
+
self.func = MixedL11Norm()
|
| 219 |
+
|
| 220 |
+
def _get_p2(self):
|
| 221 |
+
r"""The initial value for the dual in the proximal calculation - allocated to zero in the case of warm_start=False
|
| 222 |
+
or initialised as the last iterate seen in the proximal calculation in the case warm_start=True ."""
|
| 223 |
+
|
| 224 |
+
if self._p2 is None:
|
| 225 |
+
return self.gradient.range_geometry().allocate(0)
|
| 226 |
+
else:
|
| 227 |
+
return self._p2
|
| 228 |
+
|
| 229 |
+
@property
|
| 230 |
+
def regularisation_parameter(self):
|
| 231 |
+
return self._regularisation_parameter
|
| 232 |
+
|
| 233 |
+
@regularisation_parameter.setter
|
| 234 |
+
def regularisation_parameter(self, value):
|
| 235 |
+
if not isinstance(value, Number):
|
| 236 |
+
raise TypeError(
|
| 237 |
+
"regularisation_parameter: expected a number, got {}".format(type(value)))
|
| 238 |
+
self._regularisation_parameter = value
|
| 239 |
+
|
| 240 |
+
def __call__(self, x):
|
| 241 |
+
r""" Returns the value of the TotalVariation function at :code:`x` ."""
|
| 242 |
+
|
| 243 |
+
try:
|
| 244 |
+
self._domain = x.geometry
|
| 245 |
+
except:
|
| 246 |
+
self._domain = x
|
| 247 |
+
|
| 248 |
+
# Compute Lipschitz constant provided that domain is not None.
|
| 249 |
+
# Lipschitz constant dependes on the GradientOperator, which is configured only if domain is not None
|
| 250 |
+
if self._L is None:
|
| 251 |
+
self.calculate_Lipschitz()
|
| 252 |
+
|
| 253 |
+
if self.strong_convexity_constant > 0:
|
| 254 |
+
strongly_convex_term = (
|
| 255 |
+
self.strong_convexity_constant/2)*x.squared_norm()
|
| 256 |
+
else:
|
| 257 |
+
strongly_convex_term = 0
|
| 258 |
+
|
| 259 |
+
return self.regularisation_parameter * self.func(self.gradient.direct(x)) + strongly_convex_term
|
| 260 |
+
|
| 261 |
+
def proximal(self, x, tau, out=None):
|
| 262 |
+
r""" Returns the proximal operator of the TotalVariation function at :code:`x` ."""
|
| 263 |
+
|
| 264 |
+
if self.strong_convexity_constant > 0:
|
| 265 |
+
|
| 266 |
+
strongly_convex_factor = (1 + tau * self.strong_convexity_constant)
|
| 267 |
+
x /= strongly_convex_factor
|
| 268 |
+
tau /= strongly_convex_factor
|
| 269 |
+
|
| 270 |
+
if out is None:
|
| 271 |
+
solution = self._fista_on_dual_rof(x, tau)
|
| 272 |
+
else:
|
| 273 |
+
self._fista_on_dual_rof(x, tau, out=out)
|
| 274 |
+
|
| 275 |
+
if self.strong_convexity_constant > 0:
|
| 276 |
+
x *= strongly_convex_factor
|
| 277 |
+
tau *= strongly_convex_factor
|
| 278 |
+
|
| 279 |
+
if out is None:
|
| 280 |
+
return solution
|
| 281 |
+
|
| 282 |
+
def _fista_on_dual_rof(self, x, tau, out=None):
|
| 283 |
+
r""" Runs the Fast Gradient Projection (FGP) algorithm to solve the dual problem
|
| 284 |
+
of the Total Variation Denoising problem (ROF).
|
| 285 |
+
|
| 286 |
+
.. math: \max_{\|y\|_{\infty}<=1.} \frac{1}{2}\|\nabla^{*} y + x \|^{2} - \frac{1}{2}\|x\|^{2}
|
| 287 |
+
|
| 288 |
+
"""
|
| 289 |
+
try:
|
| 290 |
+
self._domain = x.geometry
|
| 291 |
+
except:
|
| 292 |
+
self._domain = x
|
| 293 |
+
|
| 294 |
+
# Compute Lipschitz constant provided that domain is not None.
|
| 295 |
+
# Lipschitz constant depends on the GradientOperator, which is configured only if domain is not None
|
| 296 |
+
if self._L is None:
|
| 297 |
+
self.calculate_Lipschitz()
|
| 298 |
+
|
| 299 |
+
# initialise
|
| 300 |
+
t = 1
|
| 301 |
+
|
| 302 |
+
# dual variable - its content is overwritten during iterations
|
| 303 |
+
p1 = self.gradient.range_geometry().allocate(None)
|
| 304 |
+
p2 = self._get_p2()
|
| 305 |
+
tmp_q = p2.copy()
|
| 306 |
+
|
| 307 |
+
# multiply tau by -1 * regularisation_parameter here so it's not recomputed every iteration
|
| 308 |
+
# when tau is an array this is done inplace so reverted at the end
|
| 309 |
+
if isinstance(tau, Number):
|
| 310 |
+
tau_reg_neg = -self.regularisation_parameter * tau
|
| 311 |
+
else:
|
| 312 |
+
tau_reg_neg = tau
|
| 313 |
+
tau.multiply(-self.regularisation_parameter, out=tau_reg_neg)
|
| 314 |
+
|
| 315 |
+
should_return = False
|
| 316 |
+
if out is None:
|
| 317 |
+
should_return = True
|
| 318 |
+
out = self.gradient.domain_geometry().allocate(0)
|
| 319 |
+
|
| 320 |
+
for k in range(self.iterations):
|
| 321 |
+
|
| 322 |
+
t0 = t
|
| 323 |
+
self.gradient.adjoint(tmp_q, out=out)
|
| 324 |
+
out.sapyb(tau_reg_neg, x, 1.0, out=out)
|
| 325 |
+
self.projection_C(out, tau=None, out=out)
|
| 326 |
+
|
| 327 |
+
self.gradient.direct(out, out=p1)
|
| 328 |
+
|
| 329 |
+
multip = (-self.L)/tau_reg_neg
|
| 330 |
+
|
| 331 |
+
tmp_q.sapyb(1., p1, multip, out=tmp_q)
|
| 332 |
+
|
| 333 |
+
if self.tolerance is not None and k % 5 == 0:
|
| 334 |
+
p1 *= multip
|
| 335 |
+
error = p1.norm()
|
| 336 |
+
error /= tmp_q.norm()
|
| 337 |
+
if error < self.tolerance:
|
| 338 |
+
break
|
| 339 |
+
|
| 340 |
+
self.func.proximal_conjugate(tmp_q, 1.0, out=p1)
|
| 341 |
+
|
| 342 |
+
t = (1 + np.sqrt(1 + 4 * t0 ** 2)) / 2
|
| 343 |
+
p1.subtract(p2, out=tmp_q)
|
| 344 |
+
tmp_q *= (t0-1)/t
|
| 345 |
+
tmp_q += p1
|
| 346 |
+
|
| 347 |
+
# switch p1 and p2 references
|
| 348 |
+
tmp = p1
|
| 349 |
+
p1 = p2
|
| 350 |
+
p2 = tmp
|
| 351 |
+
if self.warm_start:
|
| 352 |
+
self._p2 = p2
|
| 353 |
+
|
| 354 |
+
if self.info:
|
| 355 |
+
if self.tolerance is not None:
|
| 356 |
+
logging.info(
|
| 357 |
+
"Stop at {} iterations with tolerance {} .".format(k, error))
|
| 358 |
+
else:
|
| 359 |
+
logging.info("Stop at {} iterations.".format(k))
|
| 360 |
+
|
| 361 |
+
# return tau to its original state if it was modified
|
| 362 |
+
if id(tau_reg_neg) == id(tau):
|
| 363 |
+
tau_reg_neg.divide(-self.regularisation_parameter, out=tau)
|
| 364 |
+
|
| 365 |
+
if should_return:
|
| 366 |
+
return out
|
| 367 |
+
|
| 368 |
+
def convex_conjugate(self, x):
|
| 369 |
+
r""" Returns the value of convex conjugate of the TotalVariation function at :code:`x` ."""
|
| 370 |
+
return 0.0
|
| 371 |
+
|
| 372 |
+
def calculate_Lipschitz(self):
|
| 373 |
+
r""" Default value for the Lipschitz constant."""
|
| 374 |
+
|
| 375 |
+
# Compute the Lipschitz parameter from the operator if possible
|
| 376 |
+
# Leave it initialised to None otherwise
|
| 377 |
+
self._L = (1./self.gradient.norm())**2
|
| 378 |
+
|
| 379 |
+
@property
|
| 380 |
+
def gradient(self):
|
| 381 |
+
r""" GradientOperator is created if it is not instantiated yet. The domain of the `_gradient`,
|
| 382 |
+
is created in the `__call__` and `proximal` methods.
|
| 383 |
+
|
| 384 |
+
"""
|
| 385 |
+
if self._domain is not None:
|
| 386 |
+
self._gradient = GradientOperator(
|
| 387 |
+
self._domain, correlation=self.correlation, backend=self.backend)
|
| 388 |
+
else:
|
| 389 |
+
raise ValueError(
|
| 390 |
+
" The domain of the TotalVariation is {}. Please use the __call__ or proximal methods first before calling gradient.".format(self._domain))
|
| 391 |
+
|
| 392 |
+
return self._gradient
|
| 393 |
+
|
| 394 |
+
def __rmul__(self, scalar):
|
| 395 |
+
if not isinstance(scalar, Number):
|
| 396 |
+
raise TypeError(
|
| 397 |
+
"scalar: Expected a number, got {}".format(type(scalar)))
|
| 398 |
+
self.regularisation_parameter *= scalar
|
| 399 |
+
return self
|
binder/.ipynb_checkpoints/environment-checkpoint.yml
ADDED
|
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
name: ssp
|
| 2 |
+
channels:
|
| 3 |
+
- conda-forge
|
| 4 |
+
dependencies:
|
| 5 |
+
- python=3.12
|
| 6 |
+
- "numpy<2.0"
|
| 7 |
+
- "zarr<3"
|
| 8 |
+
|
| 9 |
+
- cmake
|
| 10 |
+
- scipy
|
| 11 |
+
- six
|
| 12 |
+
- cython
|
| 13 |
+
- numba
|
| 14 |
+
- pillow
|
| 15 |
+
- pywavelets
|
| 16 |
+
- astra-toolbox
|
| 17 |
+
- tqdm
|
| 18 |
+
- "setuptools<82"
|
| 19 |
+
- git
|
| 20 |
+
|
| 21 |
+
# pip deps
|
| 22 |
+
- pip
|
| 23 |
+
- wheel
|
| 24 |
+
- pip:
|
| 25 |
+
- bm3d
|
| 26 |
+
- xdesign
|
| 27 |
+
- scikit-image
|
binder/.ipynb_checkpoints/postBuild-checkpoint
ADDED
|
@@ -0,0 +1,46 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
|
| 2 |
+
set -euxo pipefail
|
| 3 |
+
|
| 4 |
+
python -m ipykernel install --user --name ssp --display-name "Python (ssp)"
|
| 5 |
+
|
| 6 |
+
# Pick the active environment prefix in Binder
|
| 7 |
+
PREFIX="${CONDA_PREFIX:-${NB_PYTHON_PREFIX:-}}"
|
| 8 |
+
if [ -z "${PREFIX}" ]; then
|
| 9 |
+
PREFIX="$(python -c 'import sys, os; print(os.path.dirname(os.path.dirname(sys.executable)))')"
|
| 10 |
+
fi
|
| 11 |
+
echo "Using PREFIX=${PREFIX}"
|
| 12 |
+
|
| 13 |
+
if [ ! -d "StochasticCIL" ]; then
|
| 14 |
+
git clone https://github.com/epapoutsellis/StochasticCIL.git
|
| 15 |
+
fi
|
| 16 |
+
|
| 17 |
+
cd StochasticCIL
|
| 18 |
+
git fetch --all --tags
|
| 19 |
+
git checkout svrg
|
| 20 |
+
|
| 21 |
+
# Ensure annotated tag for `git describe`
|
| 22 |
+
git config user.email "binder@local"
|
| 23 |
+
git config user.name "Binder Build"
|
| 24 |
+
|
| 25 |
+
if git rev-parse -q --verify refs/tags/v1.0 >/dev/null; then
|
| 26 |
+
if [ "$(git cat-file -t v1.0)" != "tag" ]; then
|
| 27 |
+
git tag -d v1.0
|
| 28 |
+
git tag -a v1.0 -m "Version 1.0"
|
| 29 |
+
fi
|
| 30 |
+
else
|
| 31 |
+
git tag -a v1.0 -m "Version 1.0"
|
| 32 |
+
fi
|
| 33 |
+
|
| 34 |
+
mkdir -p build
|
| 35 |
+
cd build
|
| 36 |
+
|
| 37 |
+
cmake ../ -DCMAKE_POLICY_VERSION_MINIMUM=3.5 \
|
| 38 |
+
-DCONDA_BUILD=OFF \
|
| 39 |
+
-DCMAKE_BUILD_TYPE=Release \
|
| 40 |
+
-DLIBRARY_LIB="${PREFIX}/lib" \
|
| 41 |
+
-DLIBRARY_INC="${PREFIX}" \
|
| 42 |
+
-DCMAKE_INSTALL_PREFIX="${PREFIX}" \
|
| 43 |
+
-DPython_EXECUTABLE="${PREFIX}/bin/python"
|
| 44 |
+
|
| 45 |
+
make -j"$(nproc)"
|
| 46 |
+
make install
|
binder/environment.yml
ADDED
|
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
name: ssp
|
| 2 |
+
channels:
|
| 3 |
+
- conda-forge
|
| 4 |
+
dependencies:
|
| 5 |
+
- python=3.12
|
| 6 |
+
- "numpy<2.0"
|
| 7 |
+
- "zarr<3"
|
| 8 |
+
- jupyterlab
|
| 9 |
+
- ipykernel
|
| 10 |
+
- cmake
|
| 11 |
+
- scipy
|
| 12 |
+
- six
|
| 13 |
+
- cython
|
| 14 |
+
- numba
|
| 15 |
+
- pillow
|
| 16 |
+
- pywavelets
|
| 17 |
+
- astra-toolbox
|
| 18 |
+
- tqdm
|
| 19 |
+
- "setuptools<82"
|
| 20 |
+
- git
|
| 21 |
+
|
| 22 |
+
# pip deps
|
| 23 |
+
- pip
|
| 24 |
+
- wheel
|
| 25 |
+
- pip:
|
| 26 |
+
- bm3d
|
| 27 |
+
- xdesign
|
| 28 |
+
- scikit-image
|
binder/postBuild
ADDED
|
@@ -0,0 +1,46 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
|
| 2 |
+
set -euxo pipefail
|
| 3 |
+
|
| 4 |
+
python -m ipykernel install --user --name ssp --display-name "Python (ssp)"
|
| 5 |
+
|
| 6 |
+
# Pick the active environment prefix in Binder
|
| 7 |
+
PREFIX="${CONDA_PREFIX:-${NB_PYTHON_PREFIX:-}}"
|
| 8 |
+
if [ -z "${PREFIX}" ]; then
|
| 9 |
+
PREFIX="$(python -c 'import sys, os; print(os.path.dirname(os.path.dirname(sys.executable)))')"
|
| 10 |
+
fi
|
| 11 |
+
echo "Using PREFIX=${PREFIX}"
|
| 12 |
+
|
| 13 |
+
if [ ! -d "StochasticCIL" ]; then
|
| 14 |
+
git clone https://github.com/epapoutsellis/StochasticCIL.git
|
| 15 |
+
fi
|
| 16 |
+
|
| 17 |
+
cd StochasticCIL
|
| 18 |
+
git fetch --all --tags
|
| 19 |
+
git checkout svrg
|
| 20 |
+
|
| 21 |
+
# Ensure annotated tag for `git describe`
|
| 22 |
+
git config user.email "binder@local"
|
| 23 |
+
git config user.name "Binder Build"
|
| 24 |
+
|
| 25 |
+
if git rev-parse -q --verify refs/tags/v1.0 >/dev/null; then
|
| 26 |
+
if [ "$(git cat-file -t v1.0)" != "tag" ]; then
|
| 27 |
+
git tag -d v1.0
|
| 28 |
+
git tag -a v1.0 -m "Version 1.0"
|
| 29 |
+
fi
|
| 30 |
+
else
|
| 31 |
+
git tag -a v1.0 -m "Version 1.0"
|
| 32 |
+
fi
|
| 33 |
+
|
| 34 |
+
mkdir -p build
|
| 35 |
+
cd build
|
| 36 |
+
|
| 37 |
+
cmake ../ -DCMAKE_POLICY_VERSION_MINIMUM=3.5 \
|
| 38 |
+
-DCONDA_BUILD=OFF \
|
| 39 |
+
-DCMAKE_BUILD_TYPE=Release \
|
| 40 |
+
-DLIBRARY_LIB="${PREFIX}/lib" \
|
| 41 |
+
-DLIBRARY_INC="${PREFIX}" \
|
| 42 |
+
-DCMAKE_INSTALL_PREFIX="${PREFIX}" \
|
| 43 |
+
-DPython_EXECUTABLE="${PREFIX}/bin/python"
|
| 44 |
+
|
| 45 |
+
make -j"$(nproc)"
|
| 46 |
+
make install
|
environment.yml
ADDED
|
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
name: ssp
|
| 2 |
+
channels:
|
| 3 |
+
- conda-forge
|
| 4 |
+
dependencies:
|
| 5 |
+
- python=3.12
|
| 6 |
+
- "numpy<2.0"
|
| 7 |
+
- "zarr<3"
|
| 8 |
+
|
| 9 |
+
- cmake
|
| 10 |
+
- scipy
|
| 11 |
+
- six
|
| 12 |
+
- cython
|
| 13 |
+
- numba
|
| 14 |
+
- pillow
|
| 15 |
+
- pywavelets
|
| 16 |
+
- astra-toolbox
|
| 17 |
+
- tqdm
|
| 18 |
+
- "setuptools<82"
|
| 19 |
+
- git
|
| 20 |
+
|
| 21 |
+
# pip deps
|
| 22 |
+
- pip
|
| 23 |
+
- wheel
|
| 24 |
+
- pip:
|
| 25 |
+
- bm3d
|
| 26 |
+
- xdesign
|
| 27 |
+
- scikit-image
|
start.sh
ADDED
|
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env bash
|
| 2 |
+
set -euo pipefail
|
| 3 |
+
|
| 4 |
+
TOKEN="${JUPYTER_TOKEN:-}"
|
| 5 |
+
|
| 6 |
+
exec micromamba run -n ssp jupyter lab \
|
| 7 |
+
--allow-root \
|
| 8 |
+
--ip=0.0.0.0 --port=7860 \
|
| 9 |
+
--no-browser \
|
| 10 |
+
--ServerApp.allow_remote_access=True \
|
| 11 |
+
--ServerApp.root_dir=/work \
|
| 12 |
+
--IdentityProvider.token="${TOKEN}"
|
utils.py
ADDED
|
@@ -0,0 +1,128 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from cil.optimisation.utilities import AlgorithmDiagnostics
|
| 2 |
+
import numpy as np
|
| 3 |
+
from bm3d import bm3d, BM3DStages
|
| 4 |
+
from cil.optimisation.functions import Function
|
| 5 |
+
|
| 6 |
+
class StoppingCriterionTime(AlgorithmDiagnostics):
|
| 7 |
+
|
| 8 |
+
def __init__(self, time):
|
| 9 |
+
|
| 10 |
+
self.time = time
|
| 11 |
+
super(StoppingCriterionTime, self).__init__(verbose=0)
|
| 12 |
+
|
| 13 |
+
self.should_stop = False
|
| 14 |
+
|
| 15 |
+
def _should_stop(self):
|
| 16 |
+
|
| 17 |
+
return self.should_stop
|
| 18 |
+
|
| 19 |
+
def __call__(self, algo):
|
| 20 |
+
|
| 21 |
+
if algo.iteration==0:
|
| 22 |
+
algo.should_stop = self._should_stop
|
| 23 |
+
|
| 24 |
+
stop_crit = np.sum(algo.timing)>self.time
|
| 25 |
+
if stop_crit:
|
| 26 |
+
self.should_stop = True
|
| 27 |
+
print("Stop at {} time {}".format(algo.iteration, np.sum(algo.timing)))
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
class BM3DFunction(Function):
|
| 31 |
+
"""
|
| 32 |
+
PnP 'regulariser' whose proximal applies BM3D denoising.
|
| 33 |
+
|
| 34 |
+
In PnP-ISTA/FISTA we typically use a FIXED BM3D sigma (regularization strength),
|
| 35 |
+
independent of the gradient step-size tau.
|
| 36 |
+
Optionally apply damping: (1-gamma) z + gamma * BM3D(z).
|
| 37 |
+
"""
|
| 38 |
+
|
| 39 |
+
def __init__(self, sigma, gamma=1.0, profile="np",
|
| 40 |
+
stage_arg=BM3DStages.ALL_STAGES, positivity=True):
|
| 41 |
+
self.sigma = float(sigma) # BM3D noise parameter
|
| 42 |
+
self.gamma = float(gamma) # damping in (0,1]
|
| 43 |
+
if not (0.0 < self.gamma <= 1.0):
|
| 44 |
+
raise ValueError("gamma must be in (0,1].")
|
| 45 |
+
self.profile = profile
|
| 46 |
+
self.stage_arg = stage_arg
|
| 47 |
+
self.positivity = positivity
|
| 48 |
+
super().__init__()
|
| 49 |
+
|
| 50 |
+
def __call__(self, x):
|
| 51 |
+
return 0.0
|
| 52 |
+
|
| 53 |
+
def convex_conjugate(self, x):
|
| 54 |
+
return 0.0
|
| 55 |
+
|
| 56 |
+
def _denoise(self, znp: np.ndarray) -> np.ndarray:
|
| 57 |
+
z = np.asarray(znp, dtype=np.float32)
|
| 58 |
+
# BM3D expects sigma as noise std (same units as the image)
|
| 59 |
+
return bm3d(z, sigma_psd=self.sigma,
|
| 60 |
+
profile=self.profile,
|
| 61 |
+
stage_arg=self.stage_arg).astype(np.float32)
|
| 62 |
+
|
| 63 |
+
def proximal(self, x, tau, out=None):
|
| 64 |
+
z = x.array.astype(np.float32, copy=False)
|
| 65 |
+
d = self._denoise(z)
|
| 66 |
+
|
| 67 |
+
# damping/relaxation (recommended if you see oscillations)
|
| 68 |
+
u = (1.0 - self.gamma) * z + self.gamma * d
|
| 69 |
+
|
| 70 |
+
if self.positivity:
|
| 71 |
+
u = np.maximum(u, 0.0)
|
| 72 |
+
|
| 73 |
+
if out is None:
|
| 74 |
+
out = x * 0.0
|
| 75 |
+
out.fill(u)
|
| 76 |
+
return out
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
def create_circular_mask(h, w, center=None, radius=None):
|
| 80 |
+
|
| 81 |
+
if center is None:
|
| 82 |
+
center = (int(w/2), int(h/2))
|
| 83 |
+
if radius is None:
|
| 84 |
+
radius = min(center[0], center[1], w-center[0], h-center[1])
|
| 85 |
+
|
| 86 |
+
Y, X = np.ogrid[:h, :w]
|
| 87 |
+
dist_from_center = np.sqrt((X - center[0])**2 + (Y-center[1])**2)
|
| 88 |
+
|
| 89 |
+
mask = dist_from_center <= radius
|
| 90 |
+
return mask
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
class StoppingCriterion(AlgorithmDiagnostics):
|
| 94 |
+
def __init__(self, epsilon, epochs=None):
|
| 95 |
+
self.epsilon = epsilon
|
| 96 |
+
self.epochs = epochs
|
| 97 |
+
super().__init__(verbose=0)
|
| 98 |
+
self.should_stop = False
|
| 99 |
+
self.rse_reached = False
|
| 100 |
+
|
| 101 |
+
def _should_stop(self):
|
| 102 |
+
return self.should_stop
|
| 103 |
+
|
| 104 |
+
def __call__(self, algo):
|
| 105 |
+
|
| 106 |
+
if algo.iteration == 0:
|
| 107 |
+
algo.should_stop = self._should_stop
|
| 108 |
+
|
| 109 |
+
stop_rse = (algo.rse[-1] <= self.epsilon)
|
| 110 |
+
|
| 111 |
+
stop_epochs = False
|
| 112 |
+
if self.epochs is not None:
|
| 113 |
+
try:
|
| 114 |
+
dp = algo.f.data_passes
|
| 115 |
+
dp_last = dp[-1] if hasattr(dp, "__len__") else dp
|
| 116 |
+
stop_epochs = (dp_last > self.epochs)
|
| 117 |
+
except AttributeError:
|
| 118 |
+
stop_epochs = False
|
| 119 |
+
|
| 120 |
+
stop = stop_rse or stop_epochs
|
| 121 |
+
|
| 122 |
+
if algo.iteration < algo.max_iteration:
|
| 123 |
+
if stop:
|
| 124 |
+
self.rse_reached = stop_rse
|
| 125 |
+
self.should_stop = True
|
| 126 |
+
print(f"Accuracy reached at {algo.iteration}, time = {np.sum(algo.timing):.4f}, NRSE = {algo.rse[-1]:.4e}")
|
| 127 |
+
else:
|
| 128 |
+
print(f"Failed to reach accuracy. Stop at {algo.iteration}, time = {np.sum(algo.timing):.4f}, NRSE = {algo.rse[-1]:.4e}")
|