import math
import re

import flax.linen as nn
import flax.struct as struct
import jax.numpy as jnp

import openpi.shared.array_typing as at


@struct.dataclass
class LoRAConfig:
    """Configuration for LoRA."""

    # Rank of the low-rank update.
    rank: int
    # Scaling factor for the low-rank update.
    alpha: float = 1.0
    # Initializer for the LoRA parameters.
    init_fn: nn.initializers.Initializer = nn.initializers.normal(stddev=0.01)
    # If True, scale by alpha / sqrt(rank) (rank-stabilized LoRA) instead of alpha / rank.
    rslora: bool = False
    # The pair of weight axes to factorize: (input axis, output axis).
    axes: tuple[int, int] = (-2, -1)
    # Einsum label used for the rank dimension; must not already appear in the equation.
    label: str = "L"

    @property
    def scaling_value(self) -> float:
        return self.alpha / math.sqrt(self.rank) if self.rslora else self.alpha / self.rank
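
# Illustration of `scaling_value` (assumed numbers, not from the original source):
# with rank=16 and alpha=16.0, the default scaling is 16.0 / 16 = 1.0, while
# rslora=True gives 16.0 / math.sqrt(16) = 4.0.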


class Einsum(nn.Module):
    """Einsum with LoRA support. Can be used as a drop-in replacement for the Gemma Einsum."""

    shape: tuple[int, ...]
    init_fn: nn.initializers.Initializer = nn.initializers.zeros
    lora_config: LoRAConfig | None = None

    def setup(self):
        self.w = self.param("w", self.init_fn, self.shape)
        if config := self.lora_config:
            # Factorize the weight along the configured axes: `w_a` replaces the
            # output axis with the rank, `w_b` replaces the input axis with the rank.
            shape_a, shape_b = list(self.shape), list(self.shape)
            shape_a[config.axes[1]] = config.rank
            shape_b[config.axes[0]] = config.rank
            self.w_a = self.param("lora_a", config.init_fn, shape_a)
            self.w_b = self.param("lora_b", config.init_fn, shape_b)

    @nn.compact
    def __call__(self, eqn: str, x):
        dtype = x.dtype  # keep the input dtype, which may be half-precision
        result = jnp.einsum(eqn, x, self.w.astype(dtype))

        if config := self.lora_config:
            # Low-rank update: project x down to the rank dimension, project back up
            # to the output features, and add the scaled result, i.e. y = xW + scaling * xAB.
            eqn_a, eqn_b = self._make_lora_eqns(eqn)
            lora = jnp.einsum(eqn_a, x, self.w_a.astype(dtype))
            lora = jnp.einsum(eqn_b, lora, self.w_b.astype(dtype))
            result = result + lora * config.scaling_value

        return result

    def _make_lora_eqns(self, eqn: str) -> tuple[str, str]:
        assert self.lora_config is not None
        label = self.lora_config.label
        # The rank label must not collide with a label already used in the equation.
        if label in eqn:
            raise ValueError(f"{label} already in eqn: {eqn}")
        if not (m := re.match("(.*),(.*)->(.*)", eqn)):
            raise ValueError(f"Unsupported einsum eqn: {eqn}")
        lhs, rhs, out = m.groups()

        a_label, b_label = (rhs[x] for x in self.lora_config.axes)

        # eqn_a projects onto the rank dimension: the output label in the weight
        # (and in the result) is replaced by the rank label.
        a_rhs = rhs.replace(b_label, label)
        a_out = out.replace(b_label, label)
        eqn_a = f"{lhs},{a_rhs}->{a_out}"

        # eqn_b projects back up: the input label in the weight is replaced by the
        # rank label, restoring the original output.
        b_rhs = rhs.replace(a_label, label)
        eqn_b = f"{a_out},{b_rhs}->{out}"

        return eqn_a, eqn_b
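
# Worked example (illustrative equation, not from the original source): for
# eqn = "BTD,DF->BTF" with the default axes=(-2, -1) and label "L",
# _make_lora_eqns returns
#   eqn_a = "BTD,DL->BTL"  (down-project the input onto the rank dimension)
#   eqn_b = "BTL,LF->BTF"  (up-project back to the output features)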


class FeedForward(nn.Module):
    """Feed forward module."""

    features: int
    hidden_dim: int

    lora_config: LoRAConfig | None = None

    def setup(self):
        self.w_gating = self.param(
            "gating_einsum",
            nn.initializers.lecun_normal(in_axis=-2, out_axis=-1, batch_axis=(0,)),
            (2, self.features, self.hidden_dim),
        )
        self.w_linear = self.param(
            "linear",
            nn.initializers.lecun_normal(in_axis=-2, out_axis=-1),
            (self.hidden_dim, self.features),
        )
        self.w_gating_lora = None
        self.w_linear_lora = None
        if self.lora_config:
            # Each LoRA factor pair mirrors the shape of the weight it adapts; the
            # leading axis of 2 covers the gate and up-projection weights.
            self.w_gating_lora = (
                self.param("gating_einsum_lora_a", self.lora_config.init_fn, (2, self.features, self.lora_config.rank)),
                self.param(
                    "gating_einsum_lora_b", self.lora_config.init_fn, (2, self.lora_config.rank, self.hidden_dim)
                ),
            )
            self.w_linear_lora = (
                self.param("linear_lora_a", self.lora_config.init_fn, (self.hidden_dim, self.lora_config.rank)),
                self.param("linear_lora_b", self.lora_config.init_fn, (self.lora_config.rank, self.features)),
            )

    @nn.compact
    def __call__(self, x):
        dtype = x.dtype
        # Gate branch: features -> hidden_dim, followed by GELU.
        ff_gate = self._dot(
            x,
            self.w_gating[0],
            None if self.w_gating_lora is None else (self.w_gating_lora[0][0], self.w_gating_lora[1][0]),
        )
        gate_value = nn.gelu(ff_gate)

        # Up-projection branch: features -> hidden_dim.
        ff1 = self._dot(
            x,
            self.w_gating[1],
            None if self.w_gating_lora is None else (self.w_gating_lora[0][1], self.w_gating_lora[1][1]),
        )
        activations = gate_value * ff1

        # Down-projection: hidden_dim -> features.
        outputs = self._dot(activations, self.w_linear, self.w_linear_lora)
        assert outputs.dtype == dtype
        return outputs

    def _dot(self, x: at.Array, w: at.Array, lora_weights: tuple[at.Array, at.Array] | None) -> at.Array:
        base = jnp.dot(x, w.astype(x.dtype))
        if lora_weights is None:
            return base
        # Low-rank update x @ A @ B; note that, unlike Einsum, it is added here
        # without the alpha-based scaling factor.
        return base + jnp.dot(jnp.dot(x, lora_weights[0].astype(x.dtype)), lora_weights[1].astype(x.dtype))
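

if __name__ == "__main__":
    # Minimal smoke test; a sketch, not part of the original module. The shapes,
    # equation string, and rank below are illustrative assumptions.
    import jax

    cfg = LoRAConfig(rank=8)

    # LoRA-augmented einsum: a 512 -> 1024 projection over a (batch, time) input.
    layer = Einsum(shape=(512, 1024), init_fn=nn.initializers.lecun_normal(), lora_config=cfg)
    x = jnp.ones((2, 16, 512))
    params = layer.init(jax.random.PRNGKey(0), "btd,df->btf", x)
    y = layer.apply(params, "btd,df->btf", x)
    assert y.shape == (2, 16, 1024)

    # LoRA-augmented feed-forward block: output features match the input features.
    ff = FeedForward(features=512, hidden_dim=2048, lora_config=cfg)
    ff_params = ff.init(jax.random.PRNGKey(1), x)
    z = ff.apply(ff_params, x)
    assert z.shape == (2, 16, 512)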