import pytest
from torch import nn

from nemo.lightning.io.state import StateDictTransform, TransformCTX, state_transform


class TestStateDictTransform:
    """
    Tests for the StateDictTransform functionality.
    """

    @pytest.fixture
    def mock_ctx(self):
        """
        Provides a mock transformation context with predefined source and target states.

        Returns
        -------
        TransformCTX: A context object with source and target states.
        """
        source_state = {
            "model.layers.0.self_attn.q_proj.weight": 1,
            "model.layers.0.self_attn.k_proj.weight": 2,
            "model.layers.0.self_attn.v_proj.weight": 3,
            "model.layers.0.mlp.experts.0.gate_proj.weight": 4,
            "model.layers.0.mlp.experts.0.up_proj.weight": 5,
            "model.layers.0.mlp.experts.0.down_proj.weight": 8,
            "model.layers.0.mlp.experts.1.gate_proj.weight": 6,
            "model.layers.0.mlp.experts.1.up_proj.weight": 7,
            "model.layers.0.mlp.experts.1.down_proj.weight": 8,
            "model.layers.1.self_attn.q_proj.weight": 2,
            "model.layers.1.self_attn.k_proj.weight": 3,
            "model.layers.1.self_attn.v_proj.weight": 4,
            "model.layers.1.mlp.experts.0.gate_proj.weight": 5,
            "model.layers.1.mlp.experts.0.up_proj.weight": 6,
            "model.layers.1.mlp.experts.0.down_proj.weight": 9,
            "model.layers.1.mlp.experts.1.gate_proj.weight": 7,
            "model.layers.1.mlp.experts.1.up_proj.weight": 8,
            "model.layers.1.mlp.experts.1.down_proj.weight": 9,
        }
        target_state = {
            "decoder.layers.0.self_attention.linear_qkv.weight": -1,
            "decoder.layers.0.self_attention.linear_proj.weight": -1,
            "decoder.layers.0.mlp.experts.linear_fc1.weight0": -1,
            "decoder.layers.0.mlp.experts.linear_fc1.weight1": -1,
            "decoder.layers.0.mlp.experts.linear_fc2.weight": -1,
            "decoder.layers.1.self_attention.linear_qkv.weight": -1,
            "decoder.layers.1.self_attention.linear_proj.weight": -1,
            "decoder.layers.1.mlp.experts.linear_fc1.weight0": -1,
            "decoder.layers.1.mlp.experts.linear_fc1.weight1": -1,
            "decoder.layers.1.mlp.experts.linear_fc2.weight": -1,
        }
        ctx = TransformCTX(
            source=nn.Module(), source_state=source_state, target=nn.Module(), target_state=target_state
        )
        return ctx

    @pytest.fixture
    def mock_multi_target_ctx(self):
        """
        Provides a mock transformation context with a source state that matches the expected source_key
        and a target state prepared with initial values for the expected target_keys.
        """
        source_state = {
            "decoder.layers.0.mlp.linear_fc1.weight": 1,
            "decoder.layers.1.mlp.linear_fc1.weight": 2,
            "decoder.layers.2.mlp.experts.linear_fc1.weight0": 23,
            "decoder.layers.2.mlp.experts.linear_fc1.weight1": 45,
            "decoder.layers.3.mlp.experts.linear_fc1.weight0": 34,
            "decoder.layers.3.mlp.experts.linear_fc1.weight1": 56,
        }
        target_state = {
            "model.layers.0.mlp.gate_proj.weight": -1,
            "model.layers.0.mlp.up_proj.weight": -1,
            "model.layers.1.mlp.gate_proj.weight": -1,
            "model.layers.1.mlp.up_proj.weight": -1,
            "model.layers.2.mlp.experts.0.gate_proj.weight": -1,
            "model.layers.2.mlp.experts.0.up_proj.weight": -1,
            "model.layers.2.mlp.experts.1.gate_proj.weight": -1,
            "model.layers.2.mlp.experts.1.up_proj.weight": -1,
            "model.layers.3.mlp.experts.0.gate_proj.weight": -1,
            "model.layers.3.mlp.experts.0.up_proj.weight": -1,
            "model.layers.3.mlp.experts.1.gate_proj.weight": -1,
            "model.layers.3.mlp.experts.1.up_proj.weight": -1,
        }
        ctx = TransformCTX(
            source=nn.Module(), source_state=source_state, target=nn.Module(), target_state=target_state
        )
        return ctx

    def test_transform_with_single_source_single_target(self, mock_ctx):
        """
        Test transformation when a single source and target key is specified.
        """
        transform = StateDictTransform(
            source_key="model.layers.*.mlp.experts.0.down_proj.weight",
            target_key="decoder.layers.*.mlp.experts.linear_fc2.weight",
            transform=lambda ctx, x: x * 100,
        )
        transform(mock_ctx)
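        # Expert-0 down_proj is 8 in layer 0 and 9 in layer 1; each is scaled by 100.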
        assert mock_ctx.target_state["decoder.layers.0.mlp.experts.linear_fc2.weight"] == 800
        assert mock_ctx.target_state["decoder.layers.1.mlp.experts.linear_fc2.weight"] == 900

    def test_transform_with_multiple_sources(self, mock_ctx):
        """
        Test transformation when multiple source keys are specified.
        """
        transform = StateDictTransform(
            source_key=(
                "model.layers.*.self_attn.q_proj.weight",
                "model.layers.*.self_attn.k_proj.weight",
                "model.layers.*.self_attn.v_proj.weight",
            ),
            target_key="decoder.layers.*.self_attention.linear_qkv.weight",
            transform=lambda ctx, q, k, v: q * 100 + k * 10 + v,
        )
        transform(mock_ctx)
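        # (q, k, v) is (1, 2, 3) for layer 0 and (2, 3, 4) for layer 1, so the
        # packed digits read 123 and 234.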
        assert mock_ctx.target_state["decoder.layers.0.self_attention.linear_qkv.weight"] == 123
        assert mock_ctx.target_state["decoder.layers.1.self_attention.linear_qkv.weight"] == 234

    def test_transform_with_multiple_mapped_sources(self, mock_ctx):
        """
        Test transformation with a dictionary mapping for source keys.
        """
        transform = StateDictTransform(
            source_key={
                "q": "model.layers.*.self_attn.q_proj.weight",
                "k": "model.layers.*.self_attn.k_proj.weight",
                "v": "model.layers.*.self_attn.v_proj.weight",
            },
            target_key="decoder.layers.*.self_attention.linear_qkv.weight",
            transform=lambda ctx, q, k, v: q * 100 + k * 10 + v,
        )
        transform(mock_ctx)
        assert mock_ctx.target_state["decoder.layers.0.self_attention.linear_qkv.weight"] == 123
        assert mock_ctx.target_state["decoder.layers.1.self_attention.linear_qkv.weight"] == 234

    def test_transform_with_variable_arguments(self, mock_ctx):
        """
        Test transformation with a wildcard pattern and variable arguments.
        """
        transform = StateDictTransform(
            source_key="model.layers.*.self_attn.*_proj.weight",
            target_key="decoder.layers.*.self_attention.linear_qkv.weight",
            transform=lambda ctx, *args: sum(args),
        )
        transform(mock_ctx)
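        # The second wildcard gathers q/k/v per layer: 1 + 2 + 3 = 6 and 2 + 3 + 4 = 9.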
        assert mock_ctx.target_state["decoder.layers.0.self_attention.linear_qkv.weight"] == 6
        assert mock_ctx.target_state["decoder.layers.1.self_attention.linear_qkv.weight"] == 9

    def test_transform_with_no_matching_source(self, mock_ctx):
        """
        Test transformation when no source keys match the pattern.
        """
        transform = StateDictTransform(
            source_key="non.existent.pattern",
            target_key="decoder.layers.*.self_attention.linear_qkv.weight",
            transform=lambda ctx, *args: sum(args),
        )
        with pytest.raises(ValueError):
            transform(mock_ctx)

    def test_transform_with_multiple_targets(self, mock_multi_target_ctx):
        """
        Test transformation where the target_key is a tuple and the transform function
        returns multiple values that are unpacked into those target keys.
        """

        def split_transform(ctx, x):
            return x - 1, x + 1

        transform = StateDictTransform(
            source_key="decoder.layers.*.mlp.linear_fc1.weight",
            target_key=(
                "model.layers.*.mlp.gate_proj.weight",
                "model.layers.*.mlp.up_proj.weight",
            ),
            transform=split_transform,
        )
        transform(mock_multi_target_ctx)
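        # linear_fc1 is 1 (layer 0) and 2 (layer 1); each value splits into (x - 1, x + 1).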
        assert mock_multi_target_ctx.target_state["model.layers.0.mlp.gate_proj.weight"] == 0
        assert mock_multi_target_ctx.target_state["model.layers.0.mlp.up_proj.weight"] == 2
        assert mock_multi_target_ctx.target_state["model.layers.1.mlp.gate_proj.weight"] == 1
        assert mock_multi_target_ctx.target_state["model.layers.1.mlp.up_proj.weight"] == 3

    def test_transform_with_multiple_sources_multiple_wildcards(self, mock_ctx):
        """
        Test transformation when multiple source keys are specified, each with more than one wildcard.
        """
        transform = StateDictTransform(
            source_key=(
                "model.layers.*.mlp.experts.*.gate_proj.weight",
                "model.layers.*.mlp.experts.*.up_proj.weight",
            ),
            target_key="decoder.layers.*.mlp.experts.linear_fc1.weight*",
            transform=lambda ctx, gate, up: gate * 10 + up,
        )
        transform(mock_ctx)
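        # Combined per (layer, expert) as gate * 10 + up, e.g. layer 0, expert 0: 4 * 10 + 5 = 45.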
        assert mock_ctx.target_state["decoder.layers.0.mlp.experts.linear_fc1.weight0"] == 45
        assert mock_ctx.target_state["decoder.layers.0.mlp.experts.linear_fc1.weight1"] == 67
        assert mock_ctx.target_state["decoder.layers.1.mlp.experts.linear_fc1.weight0"] == 56
        assert mock_ctx.target_state["decoder.layers.1.mlp.experts.linear_fc1.weight1"] == 78

    def test_transform_with_multiple_targets_multiple_wildcards(self, mock_multi_target_ctx):
        """
        Test transformation when multiple target keys are specified, each with more than one wildcard.
        """

        def split_transform(ctx, x):
            return x // 10, x % 10

        transform = StateDictTransform(
            source_key="decoder.layers.*.mlp.experts.linear_fc1.weight*",
            target_key=(
                "model.layers.*.mlp.experts.*.gate_proj.weight",
                "model.layers.*.mlp.experts.*.up_proj.weight",
            ),
            transform=split_transform,
        )
        transform(mock_multi_target_ctx)
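        # Each two-digit source value splits into its tens digit (gate_proj)
        # and its ones digit (up_proj).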
        assert mock_multi_target_ctx.target_state["model.layers.2.mlp.experts.0.gate_proj.weight"] == 2
        assert mock_multi_target_ctx.target_state["model.layers.2.mlp.experts.0.up_proj.weight"] == 3
        assert mock_multi_target_ctx.target_state["model.layers.2.mlp.experts.1.gate_proj.weight"] == 4
        assert mock_multi_target_ctx.target_state["model.layers.2.mlp.experts.1.up_proj.weight"] == 5
        assert mock_multi_target_ctx.target_state["model.layers.3.mlp.experts.0.gate_proj.weight"] == 3
        assert mock_multi_target_ctx.target_state["model.layers.3.mlp.experts.0.up_proj.weight"] == 4
        assert mock_multi_target_ctx.target_state["model.layers.3.mlp.experts.1.gate_proj.weight"] == 5
        assert mock_multi_target_ctx.target_state["model.layers.3.mlp.experts.1.up_proj.weight"] == 6

    def test_transform_with_no_matching_target(self, mock_ctx):
        """
        Test transformation when no target keys match the pattern.
        """
        transform = StateDictTransform(
            source_key="model.layers.*.mlp.experts.0.down_proj.weight",
            target_key="non.existent.pattern",
            transform=lambda ctx, *args: sum(args),
        )
        with pytest.raises(ValueError):
            transform(mock_ctx)

    def test_transform_with_invalid_transform_function(self, mock_ctx):
        """
        Test transformation with a transform function that does not match the expected signature.
        """
        transform = StateDictTransform(
            source_key="model.layers.*.self_attn.q_proj.weight",
            target_key="decoder.layers.*.self_attention.linear_qkv.weight",
            transform=lambda ctx: 0,
        )
        with pytest.raises(ValueError):
            transform(mock_ctx)


class TestStateTransformDecorator:
    """
    Tests for the @state_transform decorator functionality.
    """

    @pytest.fixture
    def mock_ctx(self):
        """
        Provides a mock transformation context with predefined source and target states.
        """
        source_state = {
            "model.layers.1.self_attn.q_proj.weight": 1,
            "model.layers.1.self_attn.k_proj.weight": 2,
            "model.layers.1.self_attn.v_proj.weight": 3,
        }
        target_state = {
            "decoder.layers.1.self_attention.linear_q.weight": 0,
            "decoder.layers.1.self_attention.linear_k.weight": 0,
            "decoder.layers.1.self_attention.linear_v.weight": 0,
        }
        ctx = TransformCTX(
            source=nn.Module(), source_state=source_state, target=nn.Module(), target_state=target_state
        )
        return ctx

    def test_single_transform(self, mock_ctx):
        """
        Test the @state_transform decorator with a single source and target key.
        """
        single_transform(mock_ctx)
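        # single_transform (defined at module level below) adds 10 to the
        # layer-1 q_proj value of 1.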
        assert mock_ctx.target_state["decoder.layers.1.self_attention.linear_q.weight"] == 11

    def test_multiple_outputs_transform(self, mock_ctx):
        """
        Test the @state_transform decorator with a single source key and multiple target keys.
        """
        multiple_outputs_transform(mock_ctx)
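        # multiple_outputs_transform returns its inputs unchanged, so each
        # *_proj value lands in the matching linear_* target key.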
        assert mock_ctx.target_state["decoder.layers.1.self_attention.linear_q.weight"] == 1
        assert mock_ctx.target_state["decoder.layers.1.self_attention.linear_k.weight"] == 2
        assert mock_ctx.target_state["decoder.layers.1.self_attention.linear_v.weight"] == 3

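# Module-level transforms exercised by TestStateTransformDecorator above.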
@state_transform(
    source_key="model.layers.*.self_attn.q_proj.weight", target_key="decoder.layers.1.self_attention.linear_q.weight"
)
def single_transform(ctx, x):
    """
    A single transformation function that adds 10 to the input value.
    """
    return x + 10


@state_transform(
    source_key="model.layers.1.self_attn.*_proj.weight",
    target_key=("decoder.layers.1.self_attention.linear_*.weight",),
)
def multiple_outputs_transform(ctx, *args):
    """
    A transformation function that returns multiple values for multiple target keys.
    """
    return args