# split_skip_and_play/ProxSkip.py
# Last commit: de93bc1 ("test"), author: epaps
from cil.optimisation.algorithms import Algorithm
import numpy as np
import logging
class ProxSkip(Algorithm):
    r"""Proximal Skip (ProxSkip) algorithm.

    See Mishchenko et al., "ProxSkip: Yes! Local Gradient Steps Provably Lead
    to Communication Acceleration! Finally!".

    Minimises :math:`f(x) + g(x)`, where :math:`f` is smooth and :math:`g` has
    a "simple" proximal operator. Every iteration takes a gradient step that is
    corrected by a control variate :math:`h_t`. The proximal operator of
    :math:`g` is applied only when a coin flip :math:`\theta_t` succeeds, which
    happens with probability ``prob`` (:math:`p`):

    .. math::

        \hat{x}_{t+1} = x_t - \gamma \left(\nabla f(x_t) - h_t\right)

        x_{t+1} = \begin{cases}
            \mathrm{prox}_{(\gamma/p) g}\left(\hat{x}_{t+1} - (\gamma/p) h_t\right)
                & \theta_t = 1 \\
            \hat{x}_{t+1} & \theta_t = 0
        \end{cases}

        h_{t+1} = \begin{cases}
            h_t + (p/\gamma)\left(x_{t+1} - \hat{x}_{t+1}\right) & \theta_t = 1 \\
            h_t & \theta_t = 0
        \end{cases}

    The proximal step is always applied when ``self.iteration == 0``.

    Parameters
    ----------
    initial : sequence of two DataContainer
        ``(x0, h0)``: the initial point and the initial control variate.
        Both are copied, so the caller's containers are not modified.
    f : Function
        A smooth function with Lipschitz continuous gradient. Its
        ``gradient(x, out=...)`` and ``__call__`` are used.
    g : Function
        A convex function with a "simple" proximal operator. Its
        ``proximal(x, tau, out=...)`` and ``__call__`` are used.
    step_size : positive :obj:`float`
        Step size :math:`\gamma`, typically ``1./L`` where ``L`` is the
        Lipschitz constant of the gradient of ``f``.
    prob : :obj:`float` in (0, 1]
        Probability of *applying* the proximal step in an iteration. With
        :code:`prob=1` the proximal step is used in every iteration.
    seed : optional
        Seed forwarded to :func:`numpy.random.default_rng` for the coin flips.
    **kwargs
        Forwarded to the base :class:`Algorithm`.

    Attributes
    ----------
    thetas : list
        One entry per call to :meth:`update`. The entry is truthy when the
        proximal step was applied in that iteration.
    use_prox : int
        Number of iterations in which the proximal step was applied.
    prox_iterates : list
        Currently never filled. Kept for backward compatibility.

    Raises
    ------
    ValueError
        If ``prob`` is not in (0, 1] or ``step_size`` is not positive.
    """

    def __init__(self, initial, f, g, step_size, prob, seed=None, **kwargs):
        """Set up the algorithm. See the class docstring for the parameters."""
        # Validate eagerly. Otherwise prob == 0 only fails later with a
        # ZeroDivisionError inside update(), and prob > 1 silently distorts
        # the gamma / prob scaling of the proximal step.
        if not 0 < prob <= 1:
            raise ValueError("prob must be in (0, 1], got {}".format(prob))
        if not step_size > 0:
            raise ValueError("step_size must be positive, got {}".format(step_size))
        super(ProxSkip, self).__init__(**kwargs)
        self.f = f  # smooth term
        self.g = g  # term with a cheap proximal operator
        self.step_size = step_size
        self.prob = prob
        self.rng = np.random.default_rng(seed)
        self.set_up(initial, f, g, step_size, prob, **kwargs)
        self.thetas = []
        self.prox_iterates = []

    def set_up(self, initial, f, g, step_size, prob, **kwargs):
        """Allocate the working containers from ``initial = (x0, h0)``.

        ``f``, ``g``, ``step_size``, ``prob`` and ``kwargs`` are accepted for
        signature compatibility but are not used here. ``__init__`` stores them.
        """
        logging.info("{} setting up".format(self.__class__.__name__, ))
        x0, h0 = initial[0], initial[1]
        self.initial = x0
        self.x = x0.copy()         # current iterate x_t
        self.xhat_new = x0.copy()  # gradient-step point \hat{x}_{t+1}
        self.x_new = x0.copy()     # next iterate x_{t+1}
        self.ht = h0.copy()        # control variate h_t
        self.configured = True
        # Number of iterations in which the proximal step was applied.
        self.use_prox = 0
        logging.info("{} configured".format(self.__class__.__name__, ))

    def update(self):
        r"""Perform a single ProxSkip iteration.

        Writes :math:`x_{t+1}` into ``self.x_new`` and updates ``self.ht`` in
        place when the proximal step is taken.
        """
        gamma = self.step_size
        p = self.prob

        # xhat_new = x - gamma * (grad f(x) - h)
        self.f.gradient(self.x, out=self.xhat_new)
        self.xhat_new -= self.ht
        self.x.sapyb(1., self.xhat_new, -gamma, out=self.xhat_new)

        # The coin is drawn on every iteration, including the first, so the
        # random stream does not depend on the forced first step.
        theta = self.rng.random() < p
        if self.iteration == 0:
            theta = 1
        self.thetas.append(theta)

        if theta == 1:
            # x_new = prox_{(gamma/p) g}(xhat_new - (gamma/p) h)
            self.g.proximal(self.xhat_new - (gamma / p) * self.ht, gamma / p, out=self.x_new)
            # h <- h + (p/gamma) (x_new - xhat_new). This must use the x_new
            # that was just computed.
            self.ht.sapyb(1., (self.x_new - self.xhat_new), (p / gamma), out=self.ht)
            self.use_prox += 1
        else:
            # Skip the proximal step: x_new = xhat_new, h stays unchanged.
            self.x_new.fill(self.xhat_new)

    def _update_previous_solution(self):
        """Swap the references to the current and the new iterate.

        This overrides the base-class hook (see
        :func:`~Algorithm.update_previous_solution`). After the swap,
        ``self.x`` holds the iterate produced by the last :meth:`update`, and
        ``self.x_new`` is a distinct spare buffer that the next :meth:`update`
        overwrites completely.
        """
        # The previous code ran `tmp = x_new; x = x_new; x = tmp`, which
        # aliased x and x_new instead of swapping them.
        self.x, self.x_new = self.x_new, self.x

    def get_output(self):
        """Return the current solution (the container itself, not a copy)."""
        return self.x

    def update_objective(self):
        r"""Append the objective :math:`f(x) + g(x)` at ``self.x`` to ``self.loss``."""
        fun_g = self.g(self.x)
        fun_f = self.f(self.x)
        self.loss.append(fun_f + fun_g)

    def proximal_evaluations(self):
        """Return 1 or 0 per call to :meth:`update`, depending on whether the proximal step was applied.

        The result is built from the coin flips recorded in ``self.thetas``, so
        it reflects what actually happened in the run. The previous version
        drew fresh random numbers instead. Those draws were unrelated to the
        run, and they advanced ``self.rng``, which perturbed any further
        iterations.
        """
        return [1 if theta else 0 for theta in self.thetas]