File size: 2,807 Bytes
957e2dc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
import torch
import torch.nn as nn

################################################################################
# Mixin class for handling selective/orthogonal PGD variants
################################################################################


class SelectiveOrthogonalPGDMixin(object):
    """
    Mixin providing vector-geometry and gradient-flattening helpers for
    selective/orthogonal PGD variants.

    The geometry helpers operate on batched vectors: the last dimension of a
    tensor holds the vector components and leading dimensions broadcast. The
    gradient helpers convert between a module's per-parameter ``.grad``
    tensors and a single flat vector.
    """

    def __init__(self, **kwargs):
        # Accepts and ignores arbitrary keyword arguments.
        # NOTE(review): this does not call super().__init__(**kwargs), so in a
        # cooperative multiple-inheritance chain, bases that follow this mixin
        # in the MRO are not initialized through it — confirm this is intended
        # by the classes that use the mixin.
        pass

    @staticmethod
    def _dot(x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor:
        """
        Compute the batched dot product along the final dimension.

        :param x1: tensor whose last dimension holds vector components
        :param x2: tensor broadcastable against ``x1``
        :return: dot products, with the final dimension kept as size 1
        """
        return (x1 * x2).sum(-1, keepdim=True)

    def _project_orthogonal(self,
                            x1: torch.Tensor,
                            x2: torch.Tensor) -> torch.Tensor:
        """
        Compute the projection of x1 onto x2 (the component of x1 parallel
        to x2), batched along the final dimension.

        Where x2 is the zero vector, the projection is the zero vector.

        :param x1: vectors to project
        :param x2: vectors to project onto
        :return: tensor with the broadcast shape of ``x1`` and ``x2``
        """
        denom = self._dot(x2, x2)
        # Replace the divisor only where it is exactly zero. A fixed clamp
        # (e.g. 1e-12) would under-scale projections onto small-but-nonzero
        # vectors. Where x2 is zero, the numerator and x2 are zero too, so
        # the result is the zero vector.
        safe_denom = torch.where(denom > 0, denom, torch.ones_like(denom))
        return x2 * (self._dot(x1, x2) / safe_denom)

    def _component_orthogonal(self,
                              x1: torch.Tensor,
                              x2: torch.Tensor,
                              x3: torch.Tensor) -> torch.Tensor:
        """
        Compute the component of x1 orthogonal to both x2 and x3 (up to
        floating-point error), batched along the final dimension.

        x2 is first orthogonalized against x3 (one Gram-Schmidt step). The
        two projections removed from x1 are therefore along mutually
        orthogonal directions spanning the same space as x2 and x3.
        """
        x2_perp = x2 - self._project_orthogonal(x2, x3)
        return (x1
                - self._project_orthogonal(x1, x2_perp)
                - self._project_orthogonal(x1, x3))

    @staticmethod
    def _retrieve_parameter_gradients(m: nn.Module) -> torch.Tensor:
        """
        Concatenate the gradients of all trainable parameters of a module
        into one flat vector.

        Parameters with ``requires_grad`` set are visited in
        ``m.parameters()`` order. A parameter whose ``.grad`` is ``None``
        contributes zeros. The layout matches the one expected by
        ``_set_parameter_gradients``.

        :param m: module whose parameter gradients are read
        :return: detached tensor of shape (n_parameters,); empty if the
            module has no trainable parameters
        """
        flattened_grad = [
            torch.zeros_like(param).flatten() if param.grad is None
            else param.grad.detach().flatten()
            for param in m.parameters()
            if param.requires_grad
        ]

        if not flattened_grad:
            # torch.cat rejects an empty list. An empty vector round-trips
            # through _set_parameter_gradients, which expects 0 elements here.
            return torch.zeros(0)

        return torch.cat(flattened_grad, dim=-1)

    @staticmethod
    def _set_parameter_gradients(flattened_grad: torch.Tensor, m: nn.Module):
        """
        Write a flat gradient vector into the ``.grad`` attributes of a
        module's trainable parameters.

        Slices are assigned in ``m.parameters()`` order, to parameters with
        ``requires_grad`` only. This is the inverse of
        ``_retrieve_parameter_gradients``. Each ``.grad`` is a detached
        slice of ``flattened_grad`` and may share its storage, so in-place
        changes to one can be visible in the other.

        :param flattened_grad: tensor of shape (n_parameters,); a 0-d tensor
            is accepted when there is exactly one trainable element
        :param m: module whose parameter gradients are overwritten
        :raises ValueError: if ``flattened_grad`` has more than one dimension
            or its element count differs from the total number of trainable
            parameter elements
        """
        trainable = [p for p in m.parameters() if p.requires_grad]
        n_expected = sum(p.numel() for p in trainable)

        # Explicit checks instead of `assert`, which is stripped under -O.
        if flattened_grad.ndim > 1:
            raise ValueError(
                "flattened_grad must be 0- or 1-dimensional, got shape "
                f"{tuple(flattened_grad.shape)}"
            )
        if flattened_grad.numel() != n_expected:
            raise ValueError(
                f"flattened_grad has {flattened_grad.numel()} elements but "
                f"the module has {n_expected} trainable parameter elements"
            )

        # reshape(-1) lets a 0-d tensor be sliced like a 1-d one.
        flat = flattened_grad.reshape(-1)
        idx = 0
        for param in trainable:
            param_length = param.numel()
            param.grad = flat[idx:idx + param_length].reshape(
                param.shape
            ).detach()
            idx += param_length