import logging
import os

from transformers import AutoConfig

# The star import provides the per-architecture *AWQForCausalLM classes
# referenced in AWQ_CAUSAL_LM_MODEL_MAP below.
from awq.models import *
from awq.models.base import BaseAWQForCausalLM
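# Keys are Hugging Face `config.model_type` strings; "RefinedWeb" and
# "RefinedWebModel" are the model types reported by early Falcon checkpoints.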
AWQ_CAUSAL_LM_MODEL_MAP = {
    "mpt": MptAWQForCausalLM,
    "llama": LlamaAWQForCausalLM,
    "opt": OptAWQForCausalLM,
    "RefinedWeb": FalconAWQForCausalLM,
    "RefinedWebModel": FalconAWQForCausalLM,
    "falcon": FalconAWQForCausalLM,
    "bloom": BloomAWQForCausalLM,
    "gptj": GPTJAWQForCausalLM,
    "gpt_bigcode": GptBigCodeAWQForCausalLM,
    "mistral": MistralAWQForCausalLM,
    "mixtral": MixtralAWQForCausalLM,
    "gpt_neox": GPTNeoXAWQForCausalLM,
    "aquila": AquilaAWQForCausalLM,
    "Yi": YiAWQForCausalLM,
    "qwen": QwenAWQForCausalLM,
    "baichuan": BaichuanAWQForCausalLM,
    "llava": LlavaAWQForCausalLM,
    "qwen2": Qwen2AWQForCausalLM,
    "gemma": GemmaAWQForCausalLM,
    "stablelm": StableLmAWQForCausalLM,
    "deepseek": DeepseekAWQForCausalLM,
}


def check_and_get_model_type(model_dir, trust_remote_code=True, **model_init_kwargs):
    # Load only the config (no weights) to identify the architecture.
    config = AutoConfig.from_pretrained(
        model_dir, trust_remote_code=trust_remote_code, **model_init_kwargs
    )
    if config.model_type not in AWQ_CAUSAL_LM_MODEL_MAP:
        raise TypeError(f"{config.model_type} isn't supported yet.")
    return config.model_type


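# Mirrors the transformers Auto* pattern: the class is never instantiated
# directly; its classmethods resolve the architecture from the checkpoint
# config and return an instance of the matching AWQ model class.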
class AutoAWQForCausalLM:
    def __init__(self):
        # Direct construction is not supported; use the factory classmethods.
        raise EnvironmentError(
            "You must instantiate AutoAWQForCausalLM with\n"
            "AutoAWQForCausalLM.from_quantized or AutoAWQForCausalLM.from_pretrained"
        )

    @classmethod
    def from_pretrained(
        cls,
        model_path,
        trust_remote_code=True,
        safetensors=True,
        device_map=None,
        download_kwargs=None,
        **model_init_kwargs,
    ) -> BaseAWQForCausalLM:
        # Resolve the architecture, then delegate to the matching AWQ class.
        model_type = check_and_get_model_type(
            model_path, trust_remote_code, **model_init_kwargs
        )

        return AWQ_CAUSAL_LM_MODEL_MAP[model_type].from_pretrained(
            model_path,
            model_type,
            trust_remote_code=trust_remote_code,
            safetensors=safetensors,
            device_map=device_map,
            download_kwargs=download_kwargs,
            **model_init_kwargs,
        )

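    # Example (illustrative only, the path is a placeholder):
    #     model = AutoAWQForCausalLM.from_pretrained("path/to/fp16-model")
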
    @classmethod
    def from_quantized(
        cls,
        quant_path,
        quant_filename="",
        max_seq_len=2048,
        trust_remote_code=True,
        fuse_layers=True,
        use_exllama=False,
        use_exllama_v2=False,
        batch_size=1,
        safetensors=True,
        device_map="balanced",
        offload_folder=None,
        download_kwargs=None,
        **config_kwargs,
    ) -> BaseAWQForCausalLM:
        os.environ["AWQ_BATCH_SIZE"] = str(batch_size)
        model_type = check_and_get_model_type(quant_path, trust_remote_code)

        # Accept the deprecated max_new_tokens argument as an alias for
        # max_seq_len.
        if config_kwargs.get("max_new_tokens") is not None:
            max_seq_len = config_kwargs["max_new_tokens"]
            logging.warning(
                "max_new_tokens argument is deprecated... gracefully "
                "setting max_seq_len=max_new_tokens."
            )

        return AWQ_CAUSAL_LM_MODEL_MAP[model_type].from_quantized(
            quant_path,
            model_type,
            quant_filename,
            max_seq_len,
            trust_remote_code=trust_remote_code,
            fuse_layers=fuse_layers,
            use_exllama=use_exllama,
            use_exllama_v2=use_exllama_v2,
            safetensors=safetensors,
            device_map=device_map,
            offload_folder=offload_folder,
            download_kwargs=download_kwargs,
            **config_kwargs,
        )
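
# Example (illustrative only, the path is a placeholder): loading an
# already-quantized AWQ checkpoint for inference.
#
#     model = AutoAWQForCausalLM.from_quantized(
#         "path/to/awq-quantized-model",
#         fuse_layers=True,
#         max_seq_len=4096,
#     )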