# pi0 / src/openpi/models/lora.py
import math
import re
import flax.linen as nn
import flax.struct as struct
import jax.numpy as jnp
import openpi.shared.array_typing as at
@struct.dataclass
class LoRAConfig:
"""Configuration for LoRA."""
# LoRA rank.
rank: int
# LoRA scaling factor.
alpha: float = 1.0
# Initialization function for LoRA parameters.
init_fn: nn.initializers.Initializer = nn.initializers.normal(stddev=0.01)
# Enable rank-stabilized LoRA: https://arxiv.org/pdf/2312.03732
rslora: bool = False
# Axes in the weight to apply LoRA to. Should typically be the last two axes.
axes: tuple[int, int] = (-2, -1)
    # Axis label used for the LoRA rank dimension in the derived einsum equations. Must not appear in the original equation.
label: str = "L"
@property
def scaling_value(self) -> float:
return self.alpha / math.sqrt(self.rank) if self.rslora else self.alpha / self.rank
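
# Illustrative scaling values (an editorial example, not from the original file):
# with rank=16 and alpha=16.0, the standard scaling is alpha / rank = 1.0, whereas
# rslora gives alpha / sqrt(rank) = 4.0, which keeps the low-rank update from
# shrinking as the rank grows.
#
#   LoRAConfig(rank=16, alpha=16.0).scaling_value               # 1.0
#   LoRAConfig(rank=16, alpha=16.0, rslora=True).scaling_value  # 4.0
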
class Einsum(nn.Module):
"""Einsum with LoRA support. Can be used as a drop-in replacement for the Gemma Einsum."""
# Shape of the weight.
shape: tuple[int, ...]
# Initialization function for the weight.
init_fn: nn.initializers.Initializer = nn.initializers.zeros
# If not None, apply LoRA to the weight.
lora_config: LoRAConfig | None = None
def setup(self):
self.w = self.param("w", self.init_fn, self.shape)
if config := self.lora_config:
# Setup LoRA parameters.
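            # lora_a replaces axes[1] of the weight shape with `rank` (down-projection);
            # lora_b replaces axes[0] with `rank` (up-projection), so contracting them
            # over the rank axis recovers the shape of `w` along the adapted axes.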
shape_a, shape_b = list(self.shape), list(self.shape)
shape_a[config.axes[1]] = config.rank
shape_b[config.axes[0]] = config.rank
self.w_a = self.param("lora_a", config.init_fn, shape_a)
self.w_b = self.param("lora_b", config.init_fn, shape_b)
@nn.compact
def __call__(self, eqn: str, x):
dtype = x.dtype # original dtype, could be half-precision
result = jnp.einsum(eqn, x, self.w.astype(dtype))
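        # Low-rank update: project x down with lora_a, back up with lora_b, and scale by
        # alpha/rank (or alpha/sqrt(rank) when rslora is enabled) before adding it in.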
if config := self.lora_config:
eqn_a, eqn_b = self._make_lora_eqns(eqn)
lora = jnp.einsum(eqn_a, x, self.w_a.astype(dtype))
lora = jnp.einsum(eqn_b, lora, self.w_b.astype(dtype))
result = result + lora * config.scaling_value
return result
def _make_lora_eqns(self, eqn: str) -> tuple[str, str]:
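        # Derive the einsum equations for the two low-rank factors by relabeling the
        # adapted weight axes with the rank label. For example (illustrative, not from
        # the original file): "BTD,DH->BTH" with the default axes yields
        # eqn_a = "BTD,DL->BTL" and eqn_b = "BTL,LH->BTH".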
        assert self.lora_config is not None
        label = self.lora_config.label
        if label in eqn:
            raise ValueError(f"{label} already in eqn: {eqn}")
if not (m := re.match("(.*),(.*)->(.*)", eqn)):
raise ValueError(f"Unsupported einsum eqn: {eqn}")
lhs, rhs, out = m.groups()
        a_label, b_label = (rhs[x] for x in self.lora_config.axes)
a_rhs = rhs.replace(b_label, label)
a_out = out.replace(b_label, label)
eqn_a = f"{lhs},{a_rhs}->{a_out}"
b_rhs = rhs.replace(a_label, label)
eqn_b = f"{a_out},{b_rhs}->{out}"
return eqn_a, eqn_b
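
# A minimal usage sketch for Einsum with LoRA (editorial illustration only; the
# equation, shapes, and rank are example values, and `import jax` is assumed):
#
#   layer = Einsum(
#       shape=(512, 256),
#       init_fn=nn.initializers.lecun_normal(),
#       lora_config=LoRAConfig(rank=8, alpha=8.0),
#   )
#   x = jnp.ones((1, 16, 512))
#   params = layer.init(jax.random.key(0), "BTD,DH->BTH", x)
#   out = layer.apply(params, "BTD,DH->BTH", x)  # shape (1, 16, 256)
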
class FeedForward(nn.Module):
    """Gated (GeGLU-style) feed-forward module with optional LoRA applied to each projection."""
features: int
hidden_dim: int
# If not None, apply LoRA to the weight.
lora_config: LoRAConfig | None = None
def setup(self):
self.w_gating = self.param(
"gating_einsum",
nn.initializers.lecun_normal(in_axis=-2, out_axis=-1, batch_axis=(0,)),
(2, self.features, self.hidden_dim),
)
self.w_linear = self.param(
"linear",
nn.initializers.lecun_normal(in_axis=-2, out_axis=-1),
(self.hidden_dim, self.features),
)
self.w_gating_lora = None
self.w_linear_lora = None
if self.lora_config:
# Setup LoRA parameters.
# TODO: follow up with a simplified init_fn api.
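            # Each LoRA pair is (lora_a, lora_b). The gating adapters keep the leading
            # axis of 2 (one low-rank factor pair per gating matrix); the linear adapter
            # factorizes (hidden_dim, features) into (hidden_dim, rank) @ (rank, features).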
self.w_gating_lora = (
self.param("gating_einsum_lora_a", self.lora_config.init_fn, (2, self.features, self.lora_config.rank)),
self.param(
"gating_einsum_lora_b", self.lora_config.init_fn, (2, self.lora_config.rank, self.hidden_dim)
),
)
self.w_linear_lora = (
self.param("linear_lora_a", self.lora_config.init_fn, (self.hidden_dim, self.lora_config.rank)),
self.param("linear_lora_b", self.lora_config.init_fn, (self.lora_config.rank, self.features)),
)
@nn.compact
def __call__(self, x):
dtype = x.dtype # original dtype, could be half-precision
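        # GeGLU-style gating: gelu(x @ w_gating[0]) * (x @ w_gating[1]), followed by the
        # w_linear output projection; each matmul optionally adds its LoRA term via _dot.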
ff_gate = self._dot(
x,
self.w_gating[0],
None if self.w_gating_lora is None else (self.w_gating_lora[0][0], self.w_gating_lora[1][0]),
)
gate_value = nn.gelu(ff_gate)
ff1 = self._dot(
x,
self.w_gating[1],
None if self.w_gating_lora is None else (self.w_gating_lora[0][1], self.w_gating_lora[1][1]),
)
activations = gate_value * ff1
outputs = self._dot(activations, self.w_linear, self.w_linear_lora)
assert outputs.dtype == dtype
return outputs
def _dot(self, x: at.Array, w: at.Array, lora_weights: tuple[at.Array, at.Array] | None) -> at.Array:
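        # Plain dot product with an optional low-rank correction: x @ w + (x @ a) @ b.
        # Note that, unlike Einsum.__call__, this helper does not scale the low-rank
        # term by LoRAConfig.scaling_value.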
base = jnp.dot(x, w.astype(x.dtype))
if lora_weights is None:
return base
return base + jnp.dot(jnp.dot(x, lora_weights[0].astype(x.dtype)), lora_weights[1].astype(x.dtype))
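

if __name__ == "__main__":
    # Minimal smoke test (editorial addition, not part of the original module); the
    # shapes and rank below are arbitrary example values.
    import jax

    ff = FeedForward(features=64, hidden_dim=256, lora_config=LoRAConfig(rank=4, alpha=4.0))
    x = jnp.ones((2, 8, 64), dtype=jnp.float32)
    variables = ff.init(jax.random.key(0), x)
    out = ff.apply(variables, x)
    print(out.shape)  # (2, 8, 64)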