# Source: abpt/src/model/qwen_motif_patch.py
# Revision c4cddd0 - "auto: sync run_qwen_motif_protocol.py"
from __future__ import annotations
from collections.abc import Callable, Mapping, Sequence
import torch
import torch.nn as nn
from src.model.qwen_motif_attention import QwenMotifAttentionAdapter
from src.model.qwen_motif_config import (
QwenMotifAttentionPatchConfig,
QwenMotifFullConfig,
QwenMotifPatchConfig,
build_layer_motif_indices,
)
from src.model.qwen_motif_ffn import QwenMotifSplitMLP
from src.model.qwen_motif_lora import QwenMotifSplitLoRAMLP
from src.model.qwen_motif_router import BaseMotifRouter, build_motif_router
# MLP wrapper classes installed by the patchers in this module; the isinstance
# checks below use this tuple to recognise a decoder layer whose `mlp` has
# already been replaced.
QWEN_MOTIF_MLP_TYPES = (QwenMotifSplitMLP, QwenMotifSplitLoRAMLP)
def get_qwen_decoder_layers(module: nn.Module) -> nn.ModuleList:
    """Locate the decoder layer list of a (possibly wrapped) model.

    The lookup returns ``module.model.layers`` when that attribute exists and
    is an ``nn.ModuleList``. Otherwise it follows the ``base_model`` attribute
    (as long as it is an ``nn.Module``) and repeats the lookup there.

    Args:
        module: The model, or a wrapper that exposes the model via
            ``base_model``.

    Returns:
        The ``nn.ModuleList`` found at ``<something>.model.layers``.

    Raises:
        ValueError: If no ``model.layers`` module list can be reached.
    """
    current = module
    # Some modules expose ``base_model`` as a reference to themselves (for
    # instance a bare Hugging Face backbone). Tracking visited objects turns
    # that cycle into the regular ValueError instead of unbounded recursion.
    visited: set[int] = set()
    while True:
        layers = getattr(getattr(current, "model", None), "layers", None)
        if isinstance(layers, nn.ModuleList):
            return layers
        visited.add(id(current))
        wrapped = getattr(current, "base_model", None)
        if not isinstance(wrapped, nn.Module) or id(wrapped) in visited:
            break
        current = wrapped
    raise ValueError("could not resolve Qwen decoder layers from module")
def patch_qwen_ffn_layers(
    module: nn.Module,
    layer_ids: Sequence[int],
    motif_index_by_layer: Mapping[int, torch.Tensor] | torch.Tensor,
    router_factory: Callable[[int], BaseMotifRouter],
    freeze_base: bool = True,
) -> dict[int, QwenMotifSplitMLP]:
    """Replace the ``mlp`` of the selected decoder layers with ``QwenMotifSplitMLP``.

    All layer ids are validated, and their motif indices resolved, before any
    layer is modified, so a bad id cannot leave the model partially patched.

    Args:
        module: Model whose decoder layers are found via
            ``get_qwen_decoder_layers``.
        layer_ids: Decoder layer positions to patch (non-negative, unique).
        motif_index_by_layer: Either one tensor used for every layer, or a
            mapping from layer position to that layer's tensor.
        router_factory: Called with the layer position to build the router
            handed to each ``QwenMotifSplitMLP``.
        freeze_base: Forwarded to ``QwenMotifSplitMLP``.

    Returns:
        Mapping from layer position to the wrapper installed there, in the
        order given by ``layer_ids``.

    Raises:
        IndexError: If a layer id is negative or beyond the last layer.
        ValueError: If a layer id is repeated or the layer already holds a
            Qwen motif MLP.
        KeyError: If ``motif_index_by_layer`` is a mapping lacking a layer id.
    """
    layers = get_qwen_decoder_layers(module)
    shared_index = motif_index_by_layer if isinstance(motif_index_by_layer, torch.Tensor) else None

    # Phase 1: validate everything and resolve motif indices without mutating.
    motif_index_for: dict[int, torch.Tensor] = {}
    for layer_id in layer_ids:
        index = int(layer_id)
        if index < 0 or index >= len(layers):
            raise IndexError(f"layer_id {index} is out of range")
        if index in motif_index_for:
            raise ValueError(f"layer_id {index} is listed more than once")
        if isinstance(layers[index].mlp, QWEN_MOTIF_MLP_TYPES):
            raise ValueError(f"layer {index} is already patched with a Qwen motif MLP")
        motif_index_for[index] = shared_index if shared_index is not None else motif_index_by_layer[index]

    # Phase 2: build the wrappers and install them.
    patched_layers: dict[int, QwenMotifSplitMLP] = {}
    for index, motif_index in motif_index_for.items():
        layer = layers[index]
        base_mlp = layer.mlp
        patched = QwenMotifSplitMLP(
            base_mlp=base_mlp,
            # Each wrapper gets its own copy, detached from any autograd graph.
            motif_index=motif_index.detach().clone(),
            router=router_factory(index),
            freeze_base=freeze_base,
        )
        # Place the wrapper on the device/dtype of the wrapped gate projection.
        reference_weight = base_mlp.gate_proj.weight
        patched.to(device=reference_weight.device, dtype=reference_weight.dtype)
        layer.mlp = patched
        patched_layers[index] = patched
    return patched_layers
def build_and_patch_qwen_ffn_lora_layers(
    module: nn.Module,
    config: QwenMotifPatchConfig,
    motif_index_by_layer: Mapping[int, torch.Tensor] | None = None,
    hidden_size: int | None = None,
) -> dict[int, QwenMotifSplitLoRAMLP]:
    """Replace the ``mlp`` of the configured layers with ``QwenMotifSplitLoRAMLP``.

    Layer ids are checked with the same rules and messages as
    ``patch_qwen_ffn_layers`` and all checks run before any layer is modified.

    Args:
        module: Model whose decoder layers are found via
            ``get_qwen_decoder_layers``.
        config: Patch configuration; ``layer_ids``, ``motif_names``,
            ``assignment``, ``motif_proportions``, ``random_seed``, ``router``,
            ``expert_lora``, ``freeze_base`` and ``freeze_model`` are read.
        motif_index_by_layer: Optional mapping from layer position to motif
            index tensor. When missing or empty, indices are generated with
            ``build_layer_motif_indices``.
        hidden_size: Router input width. Defaults to ``gate_proj.in_features``
            of the first configured layer.

    Returns:
        Mapping from layer position to the installed wrapper; empty when
        ``config.layer_ids`` is empty (in which case nothing is frozen either).

    Raises:
        IndexError: If a layer id is negative or beyond the last layer.
        ValueError: If a layer id is repeated or the layer already holds a
            Qwen motif MLP.
        KeyError: If the motif index mapping lacks a configured layer id.
    """
    layers = get_qwen_decoder_layers(module)
    if not config.layer_ids:
        return {}

    # Validate every id before reading layer attributes or mutating anything.
    indices: list[int] = []
    seen: set[int] = set()
    for layer_id in config.layer_ids:
        index = int(layer_id)
        if index < 0 or index >= len(layers):
            raise IndexError(f"layer_id {index} is out of range")
        if index in seen:
            raise ValueError(f"layer_id {index} is listed more than once")
        if isinstance(layers[index].mlp, QWEN_MOTIF_MLP_TYPES):
            raise ValueError(f"layer {index} is already patched with a Qwen motif MLP")
        seen.add(index)
        indices.append(index)

    # Sizes are taken from the first configured layer's gate projection.
    reference_proj = layers[indices[0]].mlp.gate_proj
    model_hidden_size = int(reference_proj.in_features) if hidden_size is None else int(hidden_size)
    intermediate_size = int(reference_proj.out_features)
    num_motifs = len(config.motif_names)

    layer_motif_indices = motif_index_by_layer or build_layer_motif_indices(
        layer_ids=config.layer_ids,
        intermediate_size=intermediate_size,
        num_motifs=num_motifs,
        assignment=config.assignment,
        motif_proportions=config.motif_proportions,
        seed=config.random_seed,
    )
    # Resolve per-layer indices now so a missing key fails before any patching.
    motif_index_for = {index: layer_motif_indices[index] for index in indices}

    patched_layers: dict[int, QwenMotifSplitLoRAMLP] = {}
    for index in indices:
        layer = layers[index]
        base_mlp = layer.mlp
        # A fresh router is built for every patched layer.
        router = build_motif_router(
            config=config.router,
            model_hidden_size=model_hidden_size,
            num_motifs=num_motifs,
        ).module
        patched = QwenMotifSplitLoRAMLP(
            base_mlp=base_mlp,
            motif_index=motif_index_for[index].detach().clone(),
            router=router,
            expert_configs=config.expert_lora or {},
            motif_names=config.motif_names,
            freeze_base=config.freeze_base,
        )
        # Place the wrapper on the device/dtype of the wrapped gate projection.
        reference_weight = base_mlp.gate_proj.weight
        patched.to(device=reference_weight.device, dtype=reference_weight.dtype)
        layer.mlp = patched
        patched_layers[index] = patched

    if config.freeze_model:
        freeze_model_except_qwen_motif_trainables(module)
    return patched_layers
def build_and_patch_qwen_ffn_layers(
    module: nn.Module,
    config: QwenMotifPatchConfig,
    motif_index_by_layer: Mapping[int, torch.Tensor] | None = None,
) -> dict[int, nn.Module]:
    """Patch the configured FFN layers, choosing the LoRA or plain wrapper.

    Sizes come from ``gate_proj`` of the first configured layer. Motif indices
    are taken from ``motif_index_by_layer`` when it is non-empty and generated
    with ``build_layer_motif_indices`` otherwise. If ``config.expert_lora`` is
    truthy the work is delegated to ``build_and_patch_qwen_ffn_lora_layers``;
    otherwise ``patch_qwen_ffn_layers`` is used and, when
    ``config.freeze_model`` is set, everything but the motif trainables is
    frozen afterwards.

    Returns:
        Mapping from layer position to installed wrapper; empty when
        ``config.layer_ids`` is empty.
    """
    decoder_layers = get_qwen_decoder_layers(module)
    if not config.layer_ids:
        return {}

    reference_mlp = decoder_layers[int(config.layer_ids[0])].mlp
    hidden_size = int(reference_mlp.gate_proj.in_features)
    intermediate_size = int(reference_mlp.gate_proj.out_features)

    if motif_index_by_layer:
        motif_indices = motif_index_by_layer
    else:
        motif_indices = build_layer_motif_indices(
            layer_ids=config.layer_ids,
            intermediate_size=intermediate_size,
            num_motifs=len(config.motif_names),
            assignment=config.assignment,
            motif_proportions=config.motif_proportions,
            seed=config.random_seed,
        )

    if config.expert_lora:
        # The LoRA variant handles its own freezing.
        return build_and_patch_qwen_ffn_lora_layers(
            module=module,
            config=config,
            motif_index_by_layer=motif_indices,
            hidden_size=hidden_size,
        )

    def make_router(_layer_id: int) -> BaseMotifRouter:
        # Every layer gets an independently built router of the same shape.
        built = build_motif_router(
            config=config.router,
            model_hidden_size=hidden_size,
            num_motifs=len(config.motif_names),
        )
        return built.module

    installed = patch_qwen_ffn_layers(
        module=module,
        layer_ids=config.layer_ids,
        motif_index_by_layer=motif_indices,
        router_factory=make_router,
        freeze_base=config.freeze_base,
    )
    if config.freeze_model:
        freeze_model_except_qwen_motif_trainables(module)
    return installed
def collect_qwen_motif_mlps(module: nn.Module) -> dict[int, nn.Module]:
    """Return the motif MLP wrappers currently installed, keyed by layer position.

    A decoder layer is included when its ``mlp`` is an instance of one of
    ``QWEN_MOTIF_MLP_TYPES``.
    """
    return {
        position: decoder_layer.mlp
        for position, decoder_layer in enumerate(get_qwen_decoder_layers(module))
        if isinstance(decoder_layer.mlp, QWEN_MOTIF_MLP_TYPES)
    }
def patch_qwen_attention_layers(
    module: nn.Module,
    layer_ids: Sequence[int],
    router_factory: Callable[[int], BaseMotifRouter],
    config: QwenMotifAttentionPatchConfig,
) -> dict[int, QwenMotifAttentionAdapter]:
    """Replace ``self_attn`` of the selected layers with ``QwenMotifAttentionAdapter``.

    All layer ids are validated before any layer is modified, so a bad id
    cannot leave the model partially patched.

    Args:
        module: Model whose decoder layers are found via
            ``get_qwen_decoder_layers``.
        layer_ids: Decoder layer positions to patch (non-negative, unique).
        router_factory: Called with the layer position to build the router
            handed to each adapter.
        config: Forwarded unchanged to ``QwenMotifAttentionAdapter``.

    Returns:
        Mapping from layer position to the adapter installed there, in the
        order given by ``layer_ids``.

    Raises:
        IndexError: If a layer id is negative or beyond the last layer.
        ValueError: If a layer id is repeated or the layer already holds a
            ``QwenMotifAttentionAdapter``.
    """
    layers = get_qwen_decoder_layers(module)

    # Phase 1: validate everything without mutating.
    targets: list[int] = []
    seen: set[int] = set()
    for layer_id in layer_ids:
        index = int(layer_id)
        if index < 0 or index >= len(layers):
            raise IndexError(f"layer_id {index} is out of range")
        if index in seen:
            raise ValueError(f"layer_id {index} is listed more than once")
        if isinstance(layers[index].self_attn, QwenMotifAttentionAdapter):
            raise ValueError(f"layer {index} is already patched with QwenMotifAttentionAdapter")
        seen.add(index)
        targets.append(index)

    # Phase 2: build the adapters and install them.
    patched_layers: dict[int, QwenMotifAttentionAdapter] = {}
    for index in targets:
        layer = layers[index]
        base_attention = layer.self_attn
        # Device and dtype are read from the query projection before the
        # adapter is constructed around the attention module.
        attention_device = base_attention.q_proj.weight.device
        attention_dtype = base_attention.q_proj.weight.dtype
        patched = QwenMotifAttentionAdapter(
            base_attention=base_attention,
            router=router_factory(index),
            config=config,
        )
        patched.to(device=attention_device, dtype=attention_dtype)
        layer.self_attn = patched
        patched_layers[index] = patched
    return patched_layers
def build_and_patch_qwen_attention_layers(
    module: nn.Module,
    config: QwenMotifAttentionPatchConfig,
) -> dict[int, QwenMotifAttentionAdapter]:
    """Build routers from ``config`` and patch the configured attention layers.

    The router input width is ``q_proj.in_features`` of the first configured
    layer. When ``config.freeze_model`` is set, everything except the motif
    trainables is frozen after patching.

    Returns:
        Mapping from layer position to installed adapter; empty when
        ``config.layer_ids`` is empty.
    """
    decoder_layers = get_qwen_decoder_layers(module)
    if not config.layer_ids:
        return {}

    reference_layer = decoder_layers[int(config.layer_ids[0])]
    hidden_size = int(reference_layer.self_attn.q_proj.in_features)

    def make_router(_layer_id: int) -> BaseMotifRouter:
        # Every layer gets an independently built router of the same shape.
        built = build_motif_router(
            config=config.router,
            model_hidden_size=hidden_size,
            num_motifs=len(config.motif_names),
        )
        return built.module

    installed = patch_qwen_attention_layers(
        module=module,
        layer_ids=config.layer_ids,
        router_factory=make_router,
        config=config,
    )
    if config.freeze_model:
        freeze_model_except_qwen_motif_trainables(module)
    return installed
def collect_qwen_motif_attention_adapters(module: nn.Module) -> dict[int, QwenMotifAttentionAdapter]:
    """Return the attention adapters currently installed, keyed by layer position.

    A decoder layer is included when its ``self_attn`` is a
    ``QwenMotifAttentionAdapter``.
    """
    return {
        position: decoder_layer.self_attn
        for position, decoder_layer in enumerate(get_qwen_decoder_layers(module))
        if isinstance(decoder_layer.self_attn, QwenMotifAttentionAdapter)
    }
def freeze_model_except_motif_routers(module: nn.Module) -> None:
    """Disable gradients everywhere except on motif router parameters.

    Every parameter of ``module`` is frozen first; then the ``router``
    parameters of each installed motif MLP, followed by each installed
    attention adapter, are re-enabled.
    """
    for weight in module.parameters():
        weight.requires_grad = False

    # MLP wrappers are handled before attention adapters.
    for collect in (collect_qwen_motif_mlps, collect_qwen_motif_attention_adapters):
        for routed in collect(module).values():
            for weight in routed.router.parameters():
                weight.requires_grad = True
def freeze_model_except_qwen_motif_trainables(module: nn.Module) -> None:
    """Disable gradients everywhere except on the motif-specific parameters.

    After freezing every parameter of ``module`` the following are re-enabled:

    * the ``router`` of each installed motif MLP, plus every value of its
      ``experts`` attribute when it has one;
    * the ``router`` of each installed attention adapter, plus the ``adapter``
      attached to its ``q_proj``/``k_proj``/``v_proj``/``o_proj`` when that
      attribute exists and is not ``None``.
    """

    def enable_grad(parameters) -> None:
        for weight in parameters:
            weight.requires_grad = True

    for weight in module.parameters():
        weight.requires_grad = False

    for mlp in collect_qwen_motif_mlps(module).values():
        enable_grad(mlp.router.parameters())
        # Wrappers without an `experts` attribute contribute only their router.
        for expert in getattr(mlp, "experts", {}).values():
            enable_grad(expert.parameters())

    for adapter in collect_qwen_motif_attention_adapters(module).values():
        enable_grad(adapter.router.parameters())
        projections = (adapter.q_proj, adapter.k_proj, adapter.v_proj, adapter.o_proj)
        for projection in projections:
            if getattr(projection, "adapter", None) is None:
                continue
            enable_grad(projection.adapter.parameters())
def collect_qwen_motif_trainable_names(module: nn.Module) -> list[str]:
    """Return the names of all parameters of ``module`` that require gradients.

    Names are reported in the order produced by ``module.named_parameters()``.
    """
    trainable_names: list[str] = []
    for param_name, tensor in module.named_parameters():
        if not tensor.requires_grad:
            continue
        trainable_names.append(param_name)
    return trainable_names
def partial_reinit_qwen_motif_modules(module: nn.Module, fraction: float = 1.0) -> None:
    """Call ``partial_reinit_(fraction=...)`` on the installed motif modules.

    Motif MLP wrappers that do not define ``partial_reinit_`` are skipped;
    attention adapters are always called.
    """
    for mlp in collect_qwen_motif_mlps(module).values():
        if not hasattr(mlp, "partial_reinit_"):
            continue
        mlp.partial_reinit_(fraction=fraction)

    for adapter in collect_qwen_motif_attention_adapters(module).values():
        adapter.partial_reinit_(fraction=fraction)
def apply_qwen_motif_pipeline(module: nn.Module, config: QwenMotifFullConfig) -> dict[str, dict[int, nn.Module]]:
    """Apply the FFN patch and then the attention patch described by ``config``.

    Returns:
        A dict holding the patched-layer mapping under ``"ffn"`` and/or
        ``"attention"``; a key is present only when the matching section of
        ``config`` is not ``None``.
    """
    patched: dict[str, dict[int, nn.Module]] = {}

    ffn_config = config.ffn
    if ffn_config is not None:
        patched["ffn"] = build_and_patch_qwen_ffn_layers(module, ffn_config)

    attention_config = config.attention
    if attention_config is not None:
        patched["attention"] = build_and_patch_qwen_attention_layers(module, attention_config)

    return patched