from .base import BaseAWQForCausalLM
from transformers.models.mpt.modeling_mpt import MptBlock as OldMptBlock, MptForCausalLM


class MptAWQForCausalLM(BaseAWQForCausalLM):
    layer_type = "MPTBlock"
    max_seq_len_key = "max_seq_len"
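
    # `layer_type` names the decoder-block class this model maps to, and
    # `max_seq_len_key` is the config attribute AWQ reads for the maximum
    # sequence length (both are consumed by BaseAWQForCausalLM).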

    @staticmethod
    def fuse_layers(model: MptForCausalLM):
        fuser = MptFuser(model)
        fuser.fuse_transformer()

    @staticmethod
    def get_model_layers(model: MptForCausalLM):
        return model.transformer.blocks

    @staticmethod
    def get_act_for_scaling(module: OldMptBlock):
        return dict(
            is_scalable=True,
            scale_name="ffn.act",
            scale_layer=module.ffn.act,
            scale_shape=module.ffn.up_proj.out_features,
        )
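
    # The dict above marks the FFN activation as scalable: one scale per
    # channel of the activation output, hence `scale_shape` matching
    # `up_proj.out_features` (a reading of the generic AWQ scaling schema,
    # not MPT-specific behavior).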

    @staticmethod
    def move_embed(model: MptForCausalLM, device: str):
        model.transformer.wte = model.transformer.wte.to(device)
        model.transformer.emb_drop = model.transformer.emb_drop.to(device)
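
    # Embeddings are moved to the target device ahead of quantization so the
    # first block receives calibration inputs on the same device the blocks
    # are processed on (`emb_drop` is the embedding dropout module in the MPT
    # implementation used here).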

    @staticmethod
    def get_layers_for_scaling(module: OldMptBlock, input_feat, module_kwargs):
        layers = []

        # the attention submodule inspected below is re-run without
        # `output_attentions`, so drop it from the forwarded kwargs
        if module_kwargs.get("output_attentions") is not None:
            module_kwargs.pop("output_attentions")

        # attention input
        layers.append(
            dict(
                prev_op=module.norm_1,
                layers=[module.attn.Wqkv],
                inp=input_feat["attn.Wqkv"],
                module2inspect=module.attn,
                kwargs=module_kwargs,
            )
        )

        # attention output
        layers.append(
            dict(
                prev_op=module.attn.Wqkv,
                layers=[module.attn.out_proj],
                inp=input_feat["attn.out_proj"],
            )
        )

        # linear 1 (ffn.up_proj)
        layers.append(
            dict(
                prev_op=module.norm_2,
                layers=[module.ffn.up_proj],
                inp=input_feat["ffn.up_proj"],
                module2inspect=module.ffn,
            )
        )

        # linear 2 (ffn.down_proj)
        layers.append(
            dict(
                prev_op=module.ffn.act,
                layers=[module.ffn.down_proj],
                inp=input_feat["ffn.down_proj"],
            )
        )

        return layers
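

# For reference, each entry returned by `get_layers_for_scaling` follows the
# schema consumed by AWQ's scale search (a sketch inferred from the dicts
# built above, not a formal spec):
#
#     dict(
#         prev_op=...,         # preceding op whose weights absorb the scales
#         layers=[...],        # linear layers that share one set of scales
#         inp=...,             # cached calibration inputs for those layers
#         module2inspect=...,  # optional parent module re-run to score scales
#         kwargs=...,          # optional forward kwargs for module2inspect
#     )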


from typing import List, Tuple
from awq.utils.utils import set_module_name
from awq.modules.fused.block import MPTBlock
from awq.modules.fused.model import MPTModel


class MptFuser:
    def __init__(self, model: MptForCausalLM):
        self.model = model

        # collect every (name, block) pair in the model; matching on the class
        # name keeps this robust to the exact MptBlock import path
        self.mpt_blocks: List[Tuple[str, OldMptBlock]] = [
            (name, module)
            for name, module in self.model.named_modules()
            if "mptblock" in module.__class__.__name__.lower()
        ]

    def fuse_transformer(self):
        blocks = []

        # rebuild each HF MptBlock as a fused AWQ MPTBlock, reusing the
        # original projection, FFN, and norm modules
        module: OldMptBlock
        for module in self.model.transformer.blocks:
            blocks.append(
                MPTBlock(
                    self.model.config.d_model,
                    self.model.config.n_heads,
                    module.attn.Wqkv,
                    module.attn.out_proj,
                    module.ffn,
                    module.norm_1,
                    module.norm_2,
                    next(iter(module.state_dict().values())).device,
                    self.model.config.max_seq_len,
                )
            )

        # swap the HF transformer for the fused MPTModel wrapper
        self.model.transformer = MPTModel(
            self.model.config.vocab_size,
            blocks,
            self.model.transformer.wte,
            self.model.transformer.norm_f,
        )
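

# Minimal usage sketch (hypothetical checkpoint path; `AutoAWQForCausalLM`
# is AutoAWQ's public entry point, and `fuse_layers=True` is what routes
# through MptAWQForCausalLM.fuse_layers -> MptFuser above):
#
#     from awq import AutoAWQForCausalLM
#
#     model = AutoAWQForCausalLM.from_quantized(
#         "path/to/mpt-awq-checkpoint",  # placeholder path
#         fuse_layers=True,
#     )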