File size: 4,025 Bytes
de93bc1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
from cil.optimisation.utilities import AlgorithmDiagnostics
import numpy as np
from bm3d import bm3d, BM3DStages
from cil.optimisation.functions import Function

class StoppingCriterionTime(AlgorithmDiagnostics):
    """Stop an iterative algorithm once its accumulated run time exceeds a budget.

    On the first call (``algo.iteration == 0``) the callback replaces
    ``algo.should_stop`` with its own ``_should_stop`` predicate. From then on,
    the algorithm's stop decision is driven by this object's ``should_stop``
    flag.

    Parameters
    ----------
    time : float
        Budget compared against ``np.sum(algo.timing)``. The unit is whatever
        ``algo.timing`` records (presumably seconds; confirm).
    """

    def __init__(self, time):
        self.time = time
        super().__init__(verbose=0)
        # Flag polled by the algorithm through ``_should_stop``.
        self.should_stop = False

    def _should_stop(self):
        """Return the current stop flag (installed as ``algo.should_stop``)."""
        return self.should_stop

    def __call__(self, algo):
        """Check the elapsed-time budget of ``algo`` and raise the stop flag if it is exceeded."""
        if algo.iteration == 0:
            # Hook our predicate into the algorithm once, at the start of the run.
            algo.should_stop = self._should_stop

        elapsed = np.sum(algo.timing)
        # Written as "not >" (rather than "<=") so a NaN total never triggers a stop.
        if not elapsed > self.time:
            return

        self.should_stop = True
        print("Stop at {} time {}".format(algo.iteration, elapsed))


class BM3DFunction(Function):
    """Plug-and-play "regulariser" whose proximal step is a BM3D denoiser.

    Only ``proximal`` does real work. ``__call__`` and ``convex_conjugate``
    simply report 0.0.

    The BM3D strength ``sigma`` is fixed: it does not depend on the step size
    ``tau`` handed to ``proximal``. A relaxation weight ``gamma`` blends the
    input with the denoised image as ``(1 - gamma) * z + gamma * BM3D(z)``.
    Values of gamma below 1 can help when the iteration oscillates.

    Parameters
    ----------
    sigma : float
        Noise standard deviation passed to ``bm3d`` as ``sigma_psd``, in the
        same units as the image values.
    gamma : float, default 1.0
        Relaxation weight. It must satisfy ``0 < gamma <= 1``.
    profile : default "np"
        BM3D profile, forwarded unchanged to ``bm3d``.
    stage_arg : default BM3DStages.ALL_STAGES
        BM3D stage selection, forwarded unchanged to ``bm3d``.
    positivity : bool, default True
        If true, negative values of the result are clipped to 0.

    Raises
    ------
    ValueError
        If ``gamma`` is not in ``(0, 1]``.
    """

    def __init__(self, sigma, gamma=1.0, profile="np",
                 stage_arg=BM3DStages.ALL_STAGES, positivity=True):
        self.sigma = float(sigma)
        self.gamma = float(gamma)
        # Chained comparison also rejects NaN.
        if not 0.0 < self.gamma <= 1.0:
            raise ValueError("gamma must be in (0,1].")
        self.profile = profile
        self.stage_arg = stage_arg
        self.positivity = positivity
        super().__init__()

    def __call__(self, x):
        """Return 0.0; the implicit BM3D prior has no explicit value here."""
        return 0.0

    def convex_conjugate(self, x):
        """Return 0.0 (placeholder; no conjugate is actually computed)."""
        return 0.0

    def _denoise(self, znp: np.ndarray) -> np.ndarray:
        """Run BM3D on ``znp`` cast to float32 and return a float32 array."""
        noisy = np.asarray(znp, dtype=np.float32)
        denoised = bm3d(noisy,
                        sigma_psd=self.sigma,
                        profile=self.profile,
                        stage_arg=self.stage_arg)
        return denoised.astype(np.float32)

    def proximal(self, x, tau, out=None):
        """Apply the (optionally relaxed) BM3D denoiser to ``x``.

        ``tau`` is accepted for interface compatibility but is not used.
        The result is written into ``out`` via ``fill``; when ``out`` is None,
        a new container is created as ``x * 0.0``. That container is returned.
        """
        # copy=False: no copy is made when the array is already float32.
        current = x.array.astype(np.float32, copy=False)
        cleaned = self._denoise(current)

        blended = (1.0 - self.gamma) * current + self.gamma * cleaned
        if self.positivity:
            blended = np.maximum(blended, 0.0)

        target = x * 0.0 if out is None else out
        target.fill(blended)
        return target


def create_circular_mask(h, w, center=None, radius=None):
    """Return a boolean ``(h, w)`` array that is True inside a disc.

    Parameters
    ----------
    h, w : int
        Number of rows and columns of the mask.
    center : sequence, optional
        Disc centre given as ``(x, y)``, i.e. (column, row).
        Defaults to ``(int(w/2), int(h/2))``.
    radius : float, optional
        Disc radius in pixels. Defaults to the distance from ``center`` to the
        nearest image border, computed as
        ``min(cx, cy, w - cx, h - cy)``.

    Returns
    -------
    numpy.ndarray of bool
        ``mask[row, col]`` is True where the Euclidean distance from
        ``(col, row)`` to the centre is at most ``radius``.
    """
    cx, cy = (int(w / 2), int(h / 2)) if center is None else (center[0], center[1])
    if radius is None:
        radius = min(cx, cy, w - cx, h - cy)

    # Open grids: rows has shape (h, 1) and cols has shape (1, w); they broadcast to (h, w).
    rows, cols = np.ogrid[:h, :w]
    distance = np.sqrt((cols - cx) ** 2 + (rows - cy) ** 2)
    return distance <= radius


class StoppingCriterion(AlgorithmDiagnostics):
    """Stop an algorithm once its error drops below a tolerance, or an epoch budget is exceeded.

    The epoch budget is optional. On the first call (``algo.iteration == 0``)
    the callback installs its own ``_should_stop`` predicate as
    ``algo.should_stop``. Afterwards, the algorithm's stop decision is driven
    by this object's ``should_stop`` flag.

    Parameters
    ----------
    epsilon : float
        Tolerance. The run stops when ``algo.rse[-1] <= epsilon``.
    epochs : number, optional
        If given, the run also stops once the latest value of
        ``algo.f.data_passes`` exceeds ``epochs``. This criterion is skipped
        when ``algo.f`` has no ``data_passes`` attribute or the record is
        empty.

    Attributes
    ----------
    should_stop : bool
        Flag polled by the algorithm.
    rse_reached : bool
        True if the stop was triggered by the error tolerance, rather than
        only by the epoch budget.
    """

    def __init__(self, epsilon, epochs=None):
        self.epsilon = epsilon
        self.epochs = epochs
        super().__init__(verbose=0)
        self.should_stop = False
        self.rse_reached = False

    def _should_stop(self):
        """Return the current stop flag (installed as ``algo.should_stop``)."""
        return self.should_stop

    def _epochs_exceeded(self, algo):
        """Return True if the data-pass (epoch) budget has been exceeded."""
        if self.epochs is None:
            return False
        try:
            dp = algo.f.data_passes
            # data_passes may be a history (use its last entry) or a scalar.
            dp_last = dp[-1] if hasattr(dp, "__len__") else dp
        except (AttributeError, IndexError):
            # Best effort: with no (or an empty) data-pass record, the epoch criterion is skipped.
            return False
        return dp_last > self.epochs

    def __call__(self, algo):
        """Check the stopping criteria for ``algo`` after an iteration."""
        if algo.iteration == 0:
            algo.should_stop = self._should_stop

        nrse = algo.rse[-1]
        stop_rse = nrse <= self.epsilon
        stop_epochs = self._epochs_exceeded(algo)

        # Evaluate the stop condition before the max-iteration check. Otherwise,
        # reaching the tolerance on the final iteration would be reported as a
        # failure and leave rse_reached False.
        if stop_rse or stop_epochs:
            self.rse_reached = stop_rse
            self.should_stop = True
            if stop_rse:
                print(f"Accuracy reached at {algo.iteration}, time = {np.sum(algo.timing):.4f}, NRSE = {nrse:.4e}")
            else:
                # Epoch budget exhausted without meeting the tolerance.
                print(f"Epoch limit reached at {algo.iteration}, time = {np.sum(algo.timing):.4f}, NRSE = {nrse:.4e}")
        elif algo.iteration >= algo.max_iteration:
            print(f"Failed to reach accuracy. Stop at {algo.iteration}, time = {np.sum(algo.timing):.4f}, NRSE = {nrse:.4e}")