import math
import re
import flax.linen as nn
import flax.struct as struct
import jax.numpy as jnp
import openpi.shared.array_typing as at


@struct.dataclass
class LoRAConfig:
"""Configuration for LoRA."""
# LoRA rank.
rank: int
# LoRA scaling factor.
alpha: float = 1.0
# Initialization function for LoRA parameters.
init_fn: nn.initializers.Initializer = nn.initializers.normal(stddev=0.01)
# Enable rank-stabilized LoRA: https://arxiv.org/pdf/2312.03732
rslora: bool = False
# Axes in the weight to apply LoRA to. Should typically be the last two axes.
axes: tuple[int, int] = (-2, -1)
# Axis label which is used by LoRA in einsum equations. Must not be present in the original equation.
label: str = "L"

    @property
def scaling_value(self) -> float:
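        # Standard LoRA scales the low-rank update by alpha / rank; rank-stabilized
        # LoRA (rsLoRA) uses alpha / sqrt(rank). For example, rank=16 with
        # alpha=16.0 yields 1.0 by default and 16 / sqrt(16) = 4.0 with rslora.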
return self.alpha / math.sqrt(self.rank) if self.rslora else self.alpha / self.rank


class Einsum(nn.Module):
"""Einsum with LoRA support. Can be used as a drop-in replacement for the Gemma Einsum."""
# Shape of the weight.
shape: tuple[int, ...]
# Initialization function for the weight.
init_fn: nn.initializers.Initializer = nn.initializers.zeros
# If not None, apply LoRA to the weight.
lora_config: LoRAConfig | None = None

    def setup(self):
self.w = self.param("w", self.init_fn, self.shape)
if config := self.lora_config:
            # Set up the low-rank factors. For the default axes (-2, -1), `lora_a`
            # has shape (..., in_features, rank) and `lora_b` has shape
            # (..., rank, out_features).
shape_a, shape_b = list(self.shape), list(self.shape)
shape_a[config.axes[1]] = config.rank
shape_b[config.axes[0]] = config.rank
self.w_a = self.param("lora_a", config.init_fn, shape_a)
self.w_b = self.param("lora_b", config.init_fn, shape_b)

    @nn.compact
def __call__(self, eqn: str, x):
dtype = x.dtype # original dtype, could be half-precision
result = jnp.einsum(eqn, x, self.w.astype(dtype))
if config := self.lora_config:
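            # Low-rank update: compute (x @ A) @ B, then add it to the base result
            # with the configured scaling.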
eqn_a, eqn_b = self._make_lora_eqns(eqn)
lora = jnp.einsum(eqn_a, x, self.w_a.astype(dtype))
lora = jnp.einsum(eqn_b, lora, self.w_b.astype(dtype))
result = result + lora * config.scaling_value
return result

    def _make_lora_eqns(self, eqn: str) -> tuple[str, str]:
        """Derive the pair of einsum equations used for the low-rank update.

        For example, with the default axes (-2, -1) and label "L", the base
        equation "BTD,DF->BTF" splits into eqn_a = "BTD,DL->BTL" (x @ A) and
        eqn_b = "BTL,LF->BTF" ((x @ A) @ B).
        """
        assert self.lora_config is not None
        label = self.lora_config.label
        if label in eqn:
            raise ValueError(f"Label {label!r} is already present in eqn: {eqn}")
        if not (m := re.match("(.*),(.*)->(.*)", eqn)):
            raise ValueError(f"Unsupported einsum eqn: {eqn}")
        lhs, rhs, out = m.groups()
        a_label, b_label = (rhs[x] for x in self.lora_config.axes)
a_rhs = rhs.replace(b_label, label)
a_out = out.replace(b_label, label)
eqn_a = f"{lhs},{a_rhs}->{a_out}"
b_rhs = rhs.replace(a_label, label)
eqn_b = f"{a_out},{b_rhs}->{out}"
return eqn_a, eqn_b


class FeedForward(nn.Module):
    """Gated feed-forward module (as used in Gemma) with optional LoRA support."""
features: int
hidden_dim: int
# If not None, apply LoRA to the weight.
lora_config: LoRAConfig | None = None

    def setup(self):
self.w_gating = self.param(
"gating_einsum",
nn.initializers.lecun_normal(in_axis=-2, out_axis=-1, batch_axis=(0,)),
(2, self.features, self.hidden_dim),
)
self.w_linear = self.param(
"linear",
nn.initializers.lecun_normal(in_axis=-2, out_axis=-1),
(self.hidden_dim, self.features),
)
self.w_gating_lora = None
self.w_linear_lora = None
if self.lora_config:
            # Set up LoRA factors: one stacked (A, B) pair covering both gating
            # projections, and one pair for the output projection.
            # TODO: follow up with a simplified init_fn api.
self.w_gating_lora = (
self.param("gating_einsum_lora_a", self.lora_config.init_fn, (2, self.features, self.lora_config.rank)),
self.param(
"gating_einsum_lora_b", self.lora_config.init_fn, (2, self.lora_config.rank, self.hidden_dim)
),
)
self.w_linear_lora = (
self.param("linear_lora_a", self.lora_config.init_fn, (self.hidden_dim, self.lora_config.rank)),
self.param("linear_lora_b", self.lora_config.init_fn, (self.lora_config.rank, self.features)),
)

    @nn.compact
def __call__(self, x):
dtype = x.dtype # original dtype, could be half-precision
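        # Gating branch: x @ w_gating[0], passed through GELU below.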
ff_gate = self._dot(
x,
self.w_gating[0],
None if self.w_gating_lora is None else (self.w_gating_lora[0][0], self.w_gating_lora[1][0]),
)
gate_value = nn.gelu(ff_gate)
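        # Un-gated branch: x @ w_gating[1].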
ff1 = self._dot(
x,
self.w_gating[1],
None if self.w_gating_lora is None else (self.w_gating_lora[0][1], self.w_gating_lora[1][1]),
)
activations = gate_value * ff1
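        # Down-project back to `features`; any LoRA delta is added inside `_dot`.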
outputs = self._dot(activations, self.w_linear, self.w_linear_lora)
assert outputs.dtype == dtype
return outputs

    def _dot(self, x: at.Array, w: at.Array, lora_weights: tuple[at.Array, at.Array] | None) -> at.Array:
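        # Base matmul plus the low-rank update (x @ A) @ B. Note that, unlike
        # `Einsum`, no LoRAConfig.scaling_value is applied to the update here.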
base = jnp.dot(x, w.astype(x.dtype))
if lora_weights is None:
return base
return base + jnp.dot(jnp.dot(x, lora_weights[0].astype(x.dtype)), lora_weights[1].astype(x.dtype))
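

if __name__ == "__main__":
    # Minimal usage sketch, not part of the module: exercises Einsum and
    # FeedForward with LoRA enabled. The equation "BTD,DF->BTF" and all
    # shapes/sizes below are illustrative assumptions.
    import jax

    config = LoRAConfig(rank=4, alpha=4.0)

    einsum = Einsum(shape=(16, 32), init_fn=nn.initializers.normal(stddev=0.02), lora_config=config)
    x = jnp.ones((2, 3, 16))
    params = einsum.init(jax.random.PRNGKey(0), "BTD,DF->BTF", x)
    assert einsum.apply(params, "BTD,DF->BTF", x).shape == (2, 3, 32)

    ffn = FeedForward(features=16, hidden_dim=64, lora_config=config)
    ffn_params = ffn.init(jax.random.PRNGKey(0), x)
    assert ffn.apply(ffn_params, x).shape == (2, 3, 16)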