Spaces:
Running
Running
| from cil.optimisation.utilities import AlgorithmDiagnostics | |
| import numpy as np | |
| from bm3d import bm3d, BM3DStages | |
| from cil.optimisation.functions import Function | |
class StoppingCriterionTime(AlgorithmDiagnostics):
    """Callback that stops an algorithm once its accumulated run time exceeds a budget.

    On the first call of a run (``algo.iteration == 0``) the callback replaces
    ``algo.should_stop`` with its own flag-returning method. It raises that
    flag once ``np.sum(algo.timing)`` exceeds ``time``.

    Parameters
    ----------
    time : float
        Budget compared against ``np.sum(algo.timing)``. Presumably in seconds,
        i.e. the same units as ``algo.timing`` (TODO confirm).
    """

    def __init__(self, time):
        self.time = time
        super().__init__(verbose=0)
        self.should_stop = False

    def _should_stop(self):
        """Return the stop flag polled by the hooked algorithm."""
        return self.should_stop

    def __call__(self, algo):
        if algo.iteration == 0:
            # A new run starts: clear any flag left over from a previous run
            # that reused this instance. Otherwise the algorithm would stop at
            # once. Then route the algorithm's stop check through this callback.
            self.should_stop = False
            algo.should_stop = self._should_stop
        elapsed = np.sum(algo.timing)
        # Only act (and report) on the first crossing of the budget.
        if elapsed > self.time and not self.should_stop:
            self.should_stop = True
            print("Stop at {} time {}".format(algo.iteration, elapsed))
class BM3DFunction(Function):
    """Plug-and-play "regulariser" whose proximal map is a BM3D denoiser.

    The denoising strength ``sigma`` is fixed. It does not depend on the
    step size ``tau`` that PnP-ISTA/FISTA passes to :meth:`proximal`.
    An optional damping factor blends the input with the denoised image:
    ``(1 - gamma) * z + gamma * BM3D(z)``.

    Parameters
    ----------
    sigma : float
        Noise level handed to ``bm3d`` as ``sigma_psd``.
    gamma : float, default 1.0
        Damping/relaxation weight. Must lie in ``(0, 1]``.
    profile : str, default "np"
        Forwarded unchanged to ``bm3d``.
    stage_arg : BM3DStages, default BM3DStages.ALL_STAGES
        Forwarded unchanged to ``bm3d``.
    positivity : bool, default True
        If true, negative values of the result are clipped to zero.

    Raises
    ------
    ValueError
        If ``gamma`` is not in ``(0, 1]``.
    """

    def __init__(self, sigma, gamma=1.0, profile="np",
                 stage_arg=BM3DStages.ALL_STAGES, positivity=True):
        strength = float(sigma)
        damping = float(gamma)
        # Written as a negated chained comparison so a NaN gamma is rejected too.
        if not (0.0 < damping <= 1.0):
            raise ValueError("gamma must be in (0,1].")
        self.sigma = strength
        self.gamma = damping
        self.profile = profile
        self.stage_arg = stage_arg
        self.positivity = positivity
        super().__init__()

    def __call__(self, x):
        """Placeholder value: the implicit BM3D prior has no closed form, so 0.0 is returned."""
        return 0.0

    def convex_conjugate(self, x):
        """Placeholder value: always 0.0."""
        return 0.0

    def _denoise(self, znp: np.ndarray) -> np.ndarray:
        """Run BM3D on ``znp`` (cast to float32) and return a float32 array."""
        noisy = np.asarray(znp, dtype=np.float32)
        # sigma_psd is the noise standard deviation, in the image's own units.
        denoised = bm3d(
            noisy,
            sigma_psd=self.sigma,
            profile=self.profile,
            stage_arg=self.stage_arg,
        )
        return denoised.astype(np.float32)

    def proximal(self, x, tau, out=None):
        """Apply the (optionally damped) BM3D denoiser to ``x``.

        ``tau`` is accepted for interface compatibility but deliberately
        ignored. The result is written into ``out`` if given, otherwise into a
        new container shaped like ``x``.
        """
        current = x.array.astype(np.float32, copy=False)
        denoised = self._denoise(current)
        # Relaxation step; useful when the PnP iteration oscillates.
        blended = (1.0 - self.gamma) * current + self.gamma * denoised
        if self.positivity:
            blended = np.maximum(blended, 0.0)
        target = x * 0.0 if out is None else out
        target.fill(blended)
        return target
def create_circular_mask(h, w, center=None, radius=None):
    """Build a boolean ``(h, w)`` array that is True inside a disc.

    Parameters
    ----------
    h, w : int
        Number of rows and columns of the mask.
    center : sequence, optional
        ``(x, y)`` = (column, row) coordinates of the disc centre. Defaults to
        ``(int(w / 2), int(h / 2))``.
    radius : float, optional
        Disc radius in pixels. Defaults to the smallest of ``x``, ``y``,
        ``w - x`` and ``h - y`` for the centre ``(x, y)``.

    Returns
    -------
    numpy.ndarray of bool
        ``mask[row, col]`` is True where the Euclidean distance from
        ``(col, row)`` to the centre is at most ``radius``.
    """
    if center is None:
        center = (int(w / 2), int(h / 2))
    cx, cy = center[0], center[1]
    if radius is None:
        radius = min(cx, cy, w - cx, h - cy)
    # Open grids broadcast to (h, w) without materialising full index arrays.
    rows, cols = np.ogrid[:h, :w]
    dx = cols - cx
    dy = rows - cy
    return np.sqrt(dx ** 2 + dy ** 2) <= radius
class StoppingCriterion(AlgorithmDiagnostics):
    """Callback that stops an algorithm on a target NRSE or an epoch budget.

    On the first call of a run (``algo.iteration == 0``) the callback replaces
    ``algo.should_stop`` with its own flag-returning method. It raises that
    flag once either criterion is met.

    Parameters
    ----------
    epsilon : float
        Stop once ``algo.rse[-1] <= epsilon``. ``algo.rse`` is assumed to be a
        list-like NRSE history maintained by the caller (TODO confirm).
    epochs : float, optional
        Stop once the latest value of ``algo.f.data_passes`` exceeds this
        number. Ignored when ``None``, or when ``algo.f`` has no
        ``data_passes`` / no entries yet.

    Attributes
    ----------
    should_stop : bool
        Flag polled by the hooked algorithm.
    rse_reached : bool
        True when the stop was triggered by the NRSE criterion.
    """

    def __init__(self, epsilon, epochs=None):
        self.epsilon = epsilon
        self.epochs = epochs
        super().__init__(verbose=0)
        self.should_stop = False
        self.rse_reached = False

    def _should_stop(self):
        """Return the stop flag polled by the hooked algorithm."""
        return self.should_stop

    def _epochs_exceeded(self, algo):
        """Return True once the latest data-pass count exceeds ``self.epochs``."""
        if self.epochs is None:
            return False
        try:
            passes = algo.f.data_passes
            latest = passes[-1] if hasattr(passes, "__len__") else passes
        except (AttributeError, IndexError):
            # No data-pass bookkeeping (or nothing recorded yet): the epoch
            # criterion cannot fire.
            return False
        return latest > self.epochs

    def __call__(self, algo):
        if algo.iteration == 0:
            # A new run starts: drop state left over from a previous run that
            # reused this instance. Otherwise the algorithm would stop at once.
            self.should_stop = False
            self.rse_reached = False
            algo.should_stop = self._should_stop

        nrse = algo.rse[-1]
        stop_rse = nrse <= self.epsilon
        stop_epochs = self._epochs_exceeded(algo)

        if stop_rse or stop_epochs:
            # Checked before the max-iteration branch, so a tolerance met on the
            # final iteration is reported as success rather than failure.
            if not self.should_stop:
                self.rse_reached = stop_rse
                self.should_stop = True
                reason = "Accuracy reached" if stop_rse else "Epoch budget reached"
                print(f"{reason} at {algo.iteration}, time = {np.sum(algo.timing):.4f}, NRSE = {nrse:.4e}")
        elif algo.iteration >= algo.max_iteration:
            print(f"Failed to reach accuracy. Stop at {algo.iteration}, time = {np.sum(algo.timing):.4f}, NRSE = {nrse:.4e}")