from __future__ import annotations from collections.abc import Callable, Mapping, Sequence import torch import torch.nn as nn from src.model.qwen_motif_attention import QwenMotifAttentionAdapter from src.model.qwen_motif_config import ( QwenMotifAttentionPatchConfig, QwenMotifFullConfig, QwenMotifPatchConfig, build_layer_motif_indices, ) from src.model.qwen_motif_ffn import QwenMotifSplitMLP from src.model.qwen_motif_lora import QwenMotifSplitLoRAMLP from src.model.qwen_motif_router import BaseMotifRouter, build_motif_router QWEN_MOTIF_MLP_TYPES = (QwenMotifSplitMLP, QwenMotifSplitLoRAMLP) def get_qwen_decoder_layers(module: nn.Module) -> nn.ModuleList: if hasattr(module, "model") and hasattr(module.model, "layers") and isinstance(module.model.layers, nn.ModuleList): return module.model.layers base_model = getattr(module, "base_model", None) if isinstance(base_model, nn.Module): return get_qwen_decoder_layers(base_model) raise ValueError("could not resolve Qwen decoder layers from module") def patch_qwen_ffn_layers( module: nn.Module, layer_ids: Sequence[int], motif_index_by_layer: Mapping[int, torch.Tensor] | torch.Tensor, router_factory: Callable[[int], BaseMotifRouter], freeze_base: bool = True, ) -> dict[int, QwenMotifSplitMLP]: layers = get_qwen_decoder_layers(module) patched_layers: dict[int, QwenMotifSplitMLP] = {} for layer_id in layer_ids: index = int(layer_id) if index < 0 or index >= len(layers): raise IndexError(f"layer_id {index} is out of range") layer = layers[index] base_mlp = layer.mlp if isinstance(base_mlp, QWEN_MOTIF_MLP_TYPES): raise ValueError(f"layer {index} is already patched with a Qwen motif MLP") motif_index = motif_index_by_layer if isinstance(motif_index_by_layer, torch.Tensor) else motif_index_by_layer[index] patched = QwenMotifSplitMLP( base_mlp=base_mlp, motif_index=motif_index.detach().clone(), router=router_factory(index), freeze_base=freeze_base, ) patched.to(device=base_mlp.gate_proj.weight.device, dtype=base_mlp.gate_proj.weight.dtype) layer.mlp = patched patched_layers[index] = patched return patched_layers def build_and_patch_qwen_ffn_lora_layers( module: nn.Module, config: QwenMotifPatchConfig, motif_index_by_layer: Mapping[int, torch.Tensor] | None = None, hidden_size: int | None = None, ) -> dict[int, QwenMotifSplitLoRAMLP]: layers = get_qwen_decoder_layers(module) if not config.layer_ids: return {} first_layer = layers[int(config.layer_ids[0])] model_hidden_size = int(first_layer.mlp.gate_proj.in_features) if hidden_size is None else int(hidden_size) intermediate_size = int(first_layer.mlp.gate_proj.out_features) layer_motif_indices = motif_index_by_layer or build_layer_motif_indices( layer_ids=config.layer_ids, intermediate_size=intermediate_size, num_motifs=len(config.motif_names), assignment=config.assignment, motif_proportions=config.motif_proportions, seed=config.random_seed, ) patched_layers: dict[int, QwenMotifSplitLoRAMLP] = {} for layer_id in config.layer_ids: index = int(layer_id) layer = layers[index] base_mlp = layer.mlp if isinstance(base_mlp, QWEN_MOTIF_MLP_TYPES): raise ValueError(f"layer {index} is already patched with a Qwen motif MLP") router = build_motif_router( config=config.router, model_hidden_size=model_hidden_size, num_motifs=len(config.motif_names), ).module patched = QwenMotifSplitLoRAMLP( base_mlp=base_mlp, motif_index=layer_motif_indices[index].detach().clone(), router=router, expert_configs=config.expert_lora or {}, motif_names=config.motif_names, freeze_base=config.freeze_base, ) patched.to(device=base_mlp.gate_proj.weight.device, dtype=base_mlp.gate_proj.weight.dtype) layer.mlp = patched patched_layers[index] = patched if config.freeze_model: freeze_model_except_qwen_motif_trainables(module) return patched_layers def build_and_patch_qwen_ffn_layers( module: nn.Module, config: QwenMotifPatchConfig, motif_index_by_layer: Mapping[int, torch.Tensor] | None = None, ) -> dict[int, nn.Module]: layers = get_qwen_decoder_layers(module) if not config.layer_ids: return {} first_layer = layers[int(config.layer_ids[0])] hidden_size = int(first_layer.mlp.gate_proj.in_features) intermediate_size = int(first_layer.mlp.gate_proj.out_features) resolved_motif_index_by_layer = motif_index_by_layer or build_layer_motif_indices( layer_ids=config.layer_ids, intermediate_size=intermediate_size, num_motifs=len(config.motif_names), assignment=config.assignment, motif_proportions=config.motif_proportions, seed=config.random_seed, ) if config.expert_lora: return build_and_patch_qwen_ffn_lora_layers( module=module, config=config, motif_index_by_layer=resolved_motif_index_by_layer, hidden_size=hidden_size, ) def router_factory(_layer_id: int) -> BaseMotifRouter: return build_motif_router( config=config.router, model_hidden_size=hidden_size, num_motifs=len(config.motif_names), ).module patched_layers = patch_qwen_ffn_layers( module=module, layer_ids=config.layer_ids, motif_index_by_layer=resolved_motif_index_by_layer, router_factory=router_factory, freeze_base=config.freeze_base, ) if config.freeze_model: freeze_model_except_qwen_motif_trainables(module) return patched_layers def collect_qwen_motif_mlps(module: nn.Module) -> dict[int, nn.Module]: layers = get_qwen_decoder_layers(module) collected: dict[int, nn.Module] = {} for layer_id, layer in enumerate(layers): if isinstance(layer.mlp, QWEN_MOTIF_MLP_TYPES): collected[int(layer_id)] = layer.mlp return collected def patch_qwen_attention_layers( module: nn.Module, layer_ids: Sequence[int], router_factory: Callable[[int], BaseMotifRouter], config: QwenMotifAttentionPatchConfig, ) -> dict[int, QwenMotifAttentionAdapter]: layers = get_qwen_decoder_layers(module) patched_layers: dict[int, QwenMotifAttentionAdapter] = {} for layer_id in layer_ids: index = int(layer_id) if index < 0 or index >= len(layers): raise IndexError(f"layer_id {index} is out of range") layer = layers[index] base_attention = layer.self_attn if isinstance(base_attention, QwenMotifAttentionAdapter): raise ValueError(f"layer {index} is already patched with QwenMotifAttentionAdapter") attention_device = base_attention.q_proj.weight.device attention_dtype = base_attention.q_proj.weight.dtype patched = QwenMotifAttentionAdapter( base_attention=base_attention, router=router_factory(index), config=config, ) patched.to(device=attention_device, dtype=attention_dtype) layer.self_attn = patched patched_layers[index] = patched return patched_layers def build_and_patch_qwen_attention_layers( module: nn.Module, config: QwenMotifAttentionPatchConfig, ) -> dict[int, QwenMotifAttentionAdapter]: layers = get_qwen_decoder_layers(module) if not config.layer_ids: return {} first_layer = layers[int(config.layer_ids[0])] hidden_size = int(first_layer.self_attn.q_proj.in_features) def router_factory(_layer_id: int) -> BaseMotifRouter: return build_motif_router( config=config.router, model_hidden_size=hidden_size, num_motifs=len(config.motif_names), ).module patched_layers = patch_qwen_attention_layers( module=module, layer_ids=config.layer_ids, router_factory=router_factory, config=config, ) if config.freeze_model: freeze_model_except_qwen_motif_trainables(module) return patched_layers def collect_qwen_motif_attention_adapters(module: nn.Module) -> dict[int, QwenMotifAttentionAdapter]: layers = get_qwen_decoder_layers(module) collected: dict[int, QwenMotifAttentionAdapter] = {} for layer_id, layer in enumerate(layers): if isinstance(layer.self_attn, QwenMotifAttentionAdapter): collected[int(layer_id)] = layer.self_attn return collected def freeze_model_except_motif_routers(module: nn.Module) -> None: for parameter in module.parameters(): parameter.requires_grad = False for motif_mlp in collect_qwen_motif_mlps(module).values(): for parameter in motif_mlp.router.parameters(): parameter.requires_grad = True for attn_adapter in collect_qwen_motif_attention_adapters(module).values(): for parameter in attn_adapter.router.parameters(): parameter.requires_grad = True def freeze_model_except_qwen_motif_trainables(module: nn.Module) -> None: for parameter in module.parameters(): parameter.requires_grad = False for motif_mlp in collect_qwen_motif_mlps(module).values(): for parameter in motif_mlp.router.parameters(): parameter.requires_grad = True for expert in getattr(motif_mlp, "experts", {}).values(): for parameter in expert.parameters(): parameter.requires_grad = True for attn_adapter in collect_qwen_motif_attention_adapters(module).values(): for parameter in attn_adapter.router.parameters(): parameter.requires_grad = True for projection in (attn_adapter.q_proj, attn_adapter.k_proj, attn_adapter.v_proj, attn_adapter.o_proj): if getattr(projection, "adapter", None) is not None: for parameter in projection.adapter.parameters(): parameter.requires_grad = True def collect_qwen_motif_trainable_names(module: nn.Module) -> list[str]: return [name for name, parameter in module.named_parameters() if parameter.requires_grad] def partial_reinit_qwen_motif_modules(module: nn.Module, fraction: float = 1.0) -> None: for motif_mlp in collect_qwen_motif_mlps(module).values(): if hasattr(motif_mlp, "partial_reinit_"): motif_mlp.partial_reinit_(fraction=fraction) for attn_adapter in collect_qwen_motif_attention_adapters(module).values(): attn_adapter.partial_reinit_(fraction=fraction) def apply_qwen_motif_pipeline(module: nn.Module, config: QwenMotifFullConfig) -> dict[str, dict[int, nn.Module]]: results: dict[str, dict[int, nn.Module]] = {} if config.ffn is not None: results["ffn"] = build_and_patch_qwen_ffn_layers(module, config.ffn) if config.attention is not None: results["attention"] = build_and_patch_qwen_attention_layers(module, config.attention) return results