'''
Script that prunes TRT-LLM checkpoints by replacing prunable weights with
empty placeholder tensors, shrinking the checkpoint on disk while preserving
its structure and config.
'''
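# Usage sketch (the script filename and paths below are illustrative
# assumptions, not taken from the source):
#
#   python prune_checkpoint.py --checkpoint_dir ./ckpt
#
# writes a pruned copy of ./ckpt to ./ckpt.pruned (see `main` below).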

import argparse
import json
import os
from pathlib import Path
from typing import Any, Dict

import safetensors
import torch
from safetensors.torch import save_file

from tensorrt_llm.logger import logger
from tensorrt_llm.models import MODEL_MAP, PretrainedConfig

SUPPORTED_MODELS = list(MODEL_MAP.keys())
PRUNABLE_WEIGHTS = [
    'attention.qkv.weight',
    'attention.proj.weight',
    'mlp.fc.weight',
    'mlp.proj.weight',
    'mlp.gate.weight',
]
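# Entries in PRUNABLE_WEIGHTS are matched as substrings of full checkpoint
# keys (typically of the form 'transformer.layers.<N>.attention.qkv.weight'
# in TRT-LLM checkpoints; the exact prefix is an assumption here), so
# per-layer weights are pruned without enumerating layer indices.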


def can_prune(key: str) -> bool:
    '''Return True if `key` names a weight that may be pruned.'''
    return any(w in key for w in PRUNABLE_WEIGHTS)


def load_config(config_path: Path) -> Dict[str, Any]:
    '''Load a JSON config file, returning an empty dict if it is missing.'''
    if not config_path.exists():
        return {}

    with open(config_path, 'r') as f:
        return json.load(f)


def prune_and_save(ckpt_dir: str, out_dir: str, prune_all: bool):
    '''Prune the checkpoint in `ckpt_dir` and write the result to `out_dir`.'''
    logger.info(f'Checkpoint Dir: {ckpt_dir}, Out Dir: {out_dir}')
    model_config = PretrainedConfig.from_json_file(
        os.path.join(ckpt_dir, 'config.json'))

    architecture = model_config.architecture
    if architecture not in MODEL_MAP:
        raise RuntimeError(f'Unsupported model architecture: {architecture}')

    os.makedirs(out_dir, exist_ok=True)

    # Replace each prunable tensor with an empty placeholder of the same
    # dtype; all other tensors are copied through unchanged.
    for rank in range(model_config.mapping.world_size):
        pruned_weights = {}
        with safetensors.safe_open(os.path.join(ckpt_dir,
                                                f'rank{rank}.safetensors'),
                                   framework='pt',
                                   device='cpu') as f:
            for key in f.keys():
                tensor = f.get_tensor(key)
                if prune_all or can_prune(key):
                    pruned_weights[key] = torch.tensor([], dtype=tensor.dtype)
                else:
                    pruned_weights[key] = tensor

        save_file(pruned_weights,
                  os.path.join(out_dir, f'rank{rank}.safetensors'))

    # Copy the config alongside the pruned weights, marking it as pruned.
    config = load_config(Path(ckpt_dir, 'config.json'))
    config['is_pruned'] = True
    with open(Path(out_dir, 'config.json'), 'w') as f:
        json.dump(config, f)
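# Downstream tooling can check the flag written above before trying to use a
# pruned checkpoint. A minimal sketch (the refit step is an assumption, not
# part of this script):
#
#     config = load_config(Path(out_dir, 'config.json'))
#     if config.get('is_pruned', False):
#         ...  # weights are empty placeholders; refit real weights first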


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--checkpoint_dir', type=str, default=None)
    parser.add_argument('--prune_all',
                        default=False,
                        action='store_true',
                        help='Remove all weights in the checkpoint')
    parser.add_argument(
        '--out_dir',
        type=str,
        default=None,
        help=
        'Path to write the pruned checkpoint. Defaults to the checkpoint directory with `.pruned` appended'
    )
    args = parser.parse_args()

    if args.checkpoint_dir is None:
        raise RuntimeError(
            "No `--checkpoint_dir` supplied to checkpoint pruner.")

    # Default the output directory to `<checkpoint_dir>.pruned`.
    if args.out_dir is None:
        ckpt_path = Path(args.checkpoint_dir)
        args.out_dir = str(ckpt_path.with_name(ckpt_path.name + '.pruned'))

    prune_and_save(os.path.abspath(args.checkpoint_dir),
                   os.path.abspath(args.out_dir), args.prune_all)


if __name__ == '__main__':
    main()