from cil.optimisation.algorithms import Algorithm
import numpy as np
import logging


class ProxSkip(Algorithm):
    r"""Proximal Skip (ProxSkip) algorithm, see "ProxSkip: Yes! Local Gradient
    Steps Provably Lead to Communication Acceleration! Finally!"

    Each iteration performs the gradient step

    .. math::

        \hat{x}_{t+1} = x_{t} - \gamma (\nabla f(x_{t}) - h_{t})

    and then, with probability :math:`p`, the proximal step

    .. math::

        x_{t+1} = \mathrm{prox}_{(\gamma/p) g}(\hat{x}_{t+1} - (\gamma/p) h_{t}),
        \quad h_{t+1} = h_{t} + (p/\gamma)(x_{t+1} - \hat{x}_{t+1}).

    Otherwise it sets :math:`x_{t+1} = \hat{x}_{t+1}` and :math:`h_{t+1} = h_{t}`.
    The proximal step is always applied in the first iteration.

    Parameters
    ----------
    initial : sequence of two DataContainer
        ``initial[0]`` is the initial point :math:`x_0` and ``initial[1]`` is
        the initial value :math:`h_0` of the shift variable ``ht``. Both are
        copied.
    f : Function
        A smooth function with Lipschitz continuous gradient.
    g : Function
        A convex function with a "simple" proximal.
    step_size : positive :obj:`float`
        Step size :math:`\gamma` for the ProxSkip algorithm, typically ``1./L``
        where ``L`` is the Lipschitz constant of the gradient of ``f``.
    prob : :obj:`float` in (0, 1]
        Probability of applying the proximal step in an iteration. If
        :code:`prob=1`, the proximal step is used in every iteration.
    seed : int, optional
        Seed for the random generator that decides whether the proximal step
        is applied.

    Raises
    ------
    ValueError
        If ``prob`` is not in (0, 1] or ``step_size`` is not positive.
    """

    def __init__(self, initial, f, g, step_size, prob, seed=None, **kwargs):
        """Set up the algorithm; see the class docstring for the parameters."""
        # prob is used as a divisor (step_size / prob), and prob > 1 is not a
        # probability; a non-positive step size makes no sense either.
        if not 0 < prob <= 1:
            raise ValueError("prob must be in (0, 1], got {}".format(prob))
        if not step_size > 0:
            raise ValueError(
                "step_size must be positive, got {}".format(step_size))

        super(ProxSkip, self).__init__(**kwargs)

        self.f = f  # smooth function
        self.g = g  # function with an inexpensive proximal
        self.step_size = step_size
        self.prob = prob
        self.rng = np.random.default_rng(seed)
        self.set_up(initial, f, g, step_size, prob, **kwargs)

        # One entry per call to update(): whether the proximal step was applied.
        self.thetas = []
        # Kept for backward compatibility; this class does not fill it.
        self.prox_iterates = []

    def set_up(self, initial, f, g, step_size, prob, **kwargs):
        """Allocate the iterate buffers from ``initial = (x0, h0)``.

        ``f``, ``g``, ``step_size`` and ``prob`` are not used here;
        :meth:`update` reads the attributes stored by ``__init__``.
        """
        logging.info("%s setting up", self.__class__.__name__)

        x0, h0 = initial[0], initial[1]
        self.initial = x0
        self.x = x0.copy()         # current iterate x_t
        self.xhat_new = x0.copy()  # buffer for the gradient step \hat{x}_{t+1}
        self.x_new = x0.copy()     # buffer for the next iterate x_{t+1}
        self.ht = h0.copy()        # shift variable h_t
        self.configured = True

        # Number of iterations in which the proximal step was applied.
        self.use_prox = 0

        logging.info("%s configured", self.__class__.__name__)

    def update(self):
        r"""Perform a single iteration of the ProxSkip algorithm."""
        # Gradient step: \hat{x}_{t+1} = x_t - step_size * (grad f(x_t) - h_t)
        self.f.gradient(self.x, out=self.xhat_new)
        self.xhat_new -= self.ht
        self.x.sapyb(1., self.xhat_new, -self.step_size, out=self.xhat_new)

        # Draw on every iteration, including the first, so the random stream
        # consumed for a given seed does not depend on the forced first step.
        theta = self.rng.random() < self.prob
        # NOTE(review): assumes the base class reports iteration 0 during the
        # first update() call -- confirm for the CIL version in use.
        if self.iteration == 0:
            theta = True
        self.thetas.append(theta)

        if theta:
            # Proximal step with the enlarged step size step_size / prob.
            tau = self.step_size / self.prob
            self.g.proximal(self.xhat_new - tau * self.ht, tau, out=self.x_new)
            # h_{t+1} = h_t + (prob / step_size) * (x_{t+1} - \hat{x}_{t+1})
            self.ht.sapyb(1., self.x_new - self.xhat_new,
                          self.prob / self.step_size, out=self.ht)
            self.use_prox += 1
        else:
            # Proximal step skipped: x_{t+1} = \hat{x}_{t+1}, h_{t+1} = h_t.
            self.x_new.fill(self.xhat_new)

    def _update_previous_solution(self):
        """Swap the references to the current and next iterate.

        After :meth:`update` has written :math:`x_{t+1}` into ``self.x_new``,
        this makes it the current solution ``self.x`` and reuses the old
        ``self.x`` buffer as ``self.x_new`` for the next iteration, without
        copying. Replaces the :func:`~Algorithm.update_previous_solution` hook
        of the base class :class:`Algorithm`.
        """
        self.x, self.x_new = self.x_new, self.x

    def get_output(self):
        """Return the current solution."""
        return self.x

    def update_objective(self):
        r"""Append the objective at the current iterate to ``self.loss``.

        .. math:: f(x) + g(x)
        """
        fun_g = self.g(self.x)
        fun_f = self.f(self.x)
        self.loss.append(fun_f + fun_g)

    def proximal_evaluations(self):
        """Return which iterations applied the proximal step.

        Returns
        -------
        list of int
            One entry per completed call to :meth:`update`: 1 if the proximal
            step was applied in that iteration, 0 if it was skipped. It is
            read from the recorded decisions, so it does not consume the
            random generator.
        """
        return [int(theta) for theta in self.thetas]