File size: 3,972 Bytes
de93bc1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
from cil.optimisation.algorithms import Algorithm
import numpy as np
import logging


class ProxSkip(Algorithm):

    r"""Proximal Skip (ProxSkip) algorithm, see "ProxSkip: Yes! Local Gradient Steps
    Provably Lead to Communication Acceleration! Finally!" (Mishchenko et al., 2022).

    Minimises :math:`f(x) + g(x)` where :math:`f` is smooth and :math:`g` is
    proximable, evaluating the (possibly expensive) proximal of :math:`g` only
    with probability ``prob`` at each iteration.

        Parameters
        ----------

        initial : sequence of two DataContainers
                  ``initial[0]`` is the starting point :math:`x_0`;
                  ``initial[1]`` is the starting control variate :math:`h_0`
                  (e.g. ``f.gradient(x_0)``).
        f : Function
            A smooth function with Lipschitz continuous gradient.
        g : Function
            A convex function with a "simple" proximal.
        prob : :obj:`float` in (0, 1]
             Probability of taking the proximal step. If :code:`prob=1`, the proximal step is used in every iteration.
        step_size : positive :obj:`float`
            Step size for the ProxSkip algorithm and is equal to 1./L where L is the Lipschitz constant for the gradient of f.
        seed : optional
            Seed for the random generator that decides whether to skip the proximal step.

     """


    def __init__(self, initial, f, g, step_size, prob, seed=None, **kwargs):
        """ Set up of the algorithm
        """

        super(ProxSkip, self).__init__(**kwargs)

        # prob=0 would divide by zero in update(); a non-positive step size
        # makes the gradient step meaningless.
        if not 0.0 < prob <= 1.0:
            raise ValueError("prob must be in (0, 1], got {}".format(prob))
        if step_size <= 0:
            raise ValueError("step_size must be positive, got {}".format(step_size))

        self.f = f # smooth function
        self.g = g # proximable
        self.step_size = step_size
        self.prob = prob
        self.rng = np.random.default_rng(seed)
        # history of the Bernoulli draws (truthy = proximal step taken)
        self.thetas = []
        self.prox_iterates = []
        self.set_up(initial, f, g, step_size, prob, **kwargs)


    def set_up(self, initial, f, g, step_size, prob, **kwargs):
        """ Initialise the iterates from ``initial = (x0, h0)``. """

        logging.info("{} setting up".format(self.__class__.__name__, ))

        self.initial = initial[0]

        self.x = initial[0].copy()        # current iterate x_t
        self.xhat_new = initial[0].copy() # gradient-step point \hat{x}_{t+1}
        self.x_new = initial[0].copy()    # next iterate x_{t+1}
        self.ht = initial[1].copy()       # control variate h_t

        # count iterations in which the proximal step was actually taken
        self.use_prox = 0

        self.configured = True

        logging.info("{} configured".format(self.__class__.__name__, ))


    def update(self):
        r""" Performs a single iteration of the ProxSkip algorithm
        """

        # \hat{x}_{t+1} = x_t - step_size * (grad f(x_t) - h_t)
        self.f.gradient(self.x, out=self.xhat_new)
        self.xhat_new -= self.ht
        self.x.sapyb(1., self.xhat_new, -self.step_size, out=self.xhat_new)

        # Biased coin flip; the proximal step is always taken on the first
        # iteration so the control variate is anchored to a proximal point.
        theta = self.rng.random() < self.prob
        if self.iteration == 0:
            theta = 1
        self.thetas.append(theta)

        if theta == 1:
            # Proximal step:
            #   x_{t+1} = prox_{(gamma/p) g}( \hat{x}_{t+1} - (gamma/p) h_t )
            self.g.proximal(self.xhat_new - (self.step_size / self.prob) * self.ht,
                            self.step_size / self.prob, out=self.x_new)
            # Control variate update:
            #   h_{t+1} = h_t + (p/gamma) (x_{t+1} - \hat{x}_{t+1})
            self.ht.sapyb(1., (self.x_new - self.xhat_new),
                          (self.prob / self.step_size), out=self.ht)
            self.use_prox += 1
        else:
            # Proximal step skipped: x_{t+1} = \hat{x}_{t+1} (h_t unchanged)
            self.x_new.fill(self.xhat_new)

    def _update_previous_solution(self):
        """ Swaps the references to current and previous solution based on the :func:`~Algorithm.update_previous_solution` of the base class :class:`Algorithm`.

        Fixed: the original assigned ``self.x = self.x_new`` and then
        ``self.x = tmp`` (with ``tmp`` also pointing at ``x_new``), which
        aliased the two containers instead of swapping them.
        """
        tmp = self.x_new
        self.x_new = self.x
        self.x = tmp

    def get_output(self):
        " Returns the current solution. "
        return self.x


    def update_objective(self):

        """ Updates the objective

        .. math:: f(x) + g(x)

        """

        fun_g = self.g(self.x)
        fun_f = self.f(self.x)
        p1 = fun_f + fun_g
        self.loss.append( p1 )

    def proximal_evaluations(self):
        """ Returns the per-iteration history of proximal usage:
        1 where the proximal step was taken, 0 where it was skipped.

        Fixed: the original iterated ``self.iterations`` (never set by this
        class) and re-drew fresh random numbers, which both mutated the RNG
        state and produced a sequence unrelated to the actual run; the real
        history is recorded in ``self.thetas`` during :meth:`update`.
        """
        return [1 if theta else 0 for theta in self.thetas]