import torch
import argparse
import json
import os
import shutil
import sys
from pathlib import Path
from typing import Dict

import torch
from safetensors.torch import load_file
from transformers import (
    AutoConfig,
    AutoModelForCausalLM,
    AutoTokenizer,
    Qwen3ForCausalLM,
)

from rkllm.api import RKLLM
# Tokenizer artifacts to carry over from the source checkpoint into the
# exported model directory. Not every tokenizer ships every file;
# copy_tokenizer_files skips names that are absent.
TOKENIZER_FILES = [
    "tokenizer.json",
    "tokenizer_config.json",
    "special_tokens_map.json",
    "added_tokens.json",
    "vocab.json",
    "merges.txt",
    "chat_template.jinja",
]
def parse_args() -> argparse.Namespace:
    """Parse command-line arguments.

    Returns:
        Namespace with ``source`` (Path), ``output`` (Path) and
        ``safe_serialization`` (bool).
    """
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument(
        "--source",
        type=Path,
        default=".",
        help="Path to the InternVL (HF-format) checkpoint directory, e.g. /path/to/InternVL3_5-2B-HF",
    )
    parser.add_argument(
        "--output",
        type=Path,
        default="llm/",
        help="Directory where the extracted Qwen3 checkpoint will be written",
    )
    # BUG FIX: the original combined action="store_true" with default=True,
    # so safe serialization could never be turned off. BooleanOptionalAction
    # keeps --safe-serialization working and adds --no-safe-serialization.
    parser.add_argument(
        "--safe-serialization",
        action=argparse.BooleanOptionalAction,
        default=True,
        help="Save the exported model using safetensors instead of PyTorch binaries.",
    )
    return parser.parse_args()
def extract_text_state_dict(full_state: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
    """Select the language-model tensors from a full InternVL state dict.

    Keys under ``language_model.model.`` are remapped to ``model.`` and keys
    under ``language_model.lm_head.`` to ``lm_head.``; every other entry
    (vision tower, projector, ...) is dropped.

    Raises:
        ValueError: if no language-model weights were found at all.
    """
    remap = {
        "language_model.model.": "model.",
        "language_model.lm_head.": "lm_head.",
    }
    text_state: Dict[str, torch.Tensor] = {}
    for old_key, tensor in full_state.items():
        for old_prefix, new_prefix in remap.items():
            if old_key.startswith(old_prefix):
                text_state[new_prefix + old_key[len(old_prefix):]] = tensor
                break
    if not text_state:
        raise ValueError("Did not find any language_model weights in checkpoint; is this an InternVL model?")
    return text_state
def copy_tokenizer_files(source_dir: Path, output_dir: Path, filenames=None) -> None:
    """Copy tokenizer artifacts from ``source_dir`` into ``output_dir``.

    Args:
        source_dir: HF checkpoint directory to copy from.
        output_dir: destination directory (must already exist).
        filenames: iterable of file names to copy; defaults to the module-level
            ``TOKENIZER_FILES`` list. Names missing from ``source_dir`` are
            skipped silently, since not every tokenizer ships every file.
    """
    if filenames is None:
        filenames = TOKENIZER_FILES
    for filename in filenames:
        src = source_dir / filename
        if src.exists():
            shutil.copyfile(src, output_dir / filename)
def main() -> None:
    """Extract the Qwen3 text model from an InternVL HF checkpoint, save it as
    a standalone HF model, then convert it to an RKLLM artifact for RK3588."""
    args = parse_args()
    source_dir = args.source.expanduser().resolve()
    output_dir = args.output.expanduser().resolve()
    output_dir.mkdir(parents=True, exist_ok=True)
    _export_text_model(source_dir, output_dir, args.safe_serialization)
    print(f"Exported Qwen3 model saved to {output_dir}")
    _convert_to_rkllm(output_dir)


def _export_text_model(source_dir: Path, output_dir: Path, safe_serialization: bool) -> None:
    """Pull the language-model weights out of the InternVL checkpoint and write
    a standalone HF causal-LM (config, generation config, weights, tokenizer
    files) to ``output_dir``.

    Raises:
        FileNotFoundError: if model.safetensors is missing from source_dir.
        RuntimeError: if the extracted state dict does not match the text config.
    """
    config = AutoConfig.from_pretrained(source_dir, trust_remote_code=True)
    text_config = config.text_config
    # NOTE(review): only a single-file checkpoint is supported; a sharded
    # checkpoint (model.safetensors.index.json) would need extra handling.
    weights_path = source_dir / "model.safetensors"
    if not weights_path.exists():
        raise FileNotFoundError(f"Could not find {weights_path}; expected a safetensors checkpoint")
    all_weights = load_file(weights_path)
    text_state = extract_text_state_dict(all_weights)
    # Instantiate a bare text model in the checkpoint's own dtype (sampled
    # from an arbitrary tensor) so weights load without an implicit cast.
    sample_tensor = next(iter(text_state.values()))
    target_dtype = sample_tensor.dtype
    text_model = AutoModelForCausalLM.from_config(text_config)
    text_model = text_model.to(dtype=target_dtype, device=torch.device("cpu"))
    # strict=False so we can report BOTH missing and unexpected keys at once.
    missing, unexpected = text_model.load_state_dict(text_state, strict=False)
    if missing or unexpected:
        raise RuntimeError(
            "State dict mismatch when loading text weights: "
            f"missing={missing}, unexpected={unexpected}"
        )
    text_config.save_pretrained(output_dir)
    text_model.generation_config.save_pretrained(output_dir)
    text_model.save_pretrained(output_dir, safe_serialization=safe_serialization)
    copy_tokenizer_files(source_dir, output_dir)


def _convert_to_rkllm(model_dir: Path) -> None:
    """Quantize the exported HF model with the RKLLM toolkit (w8a8, RK3588,
    3 NPU cores) and write ./language_model_w8a8.rkllm.

    Exits the process with the toolkit's nonzero return code on failure.
    """
    llm = RKLLM()
    ret = llm.load_huggingface(model=model_dir, model_lora=None, device='cpu')
    if ret != 0:
        print('Load model failed!')
        sys.exit(ret)
    ret = llm.build(do_quantization=True, optimization_level=1, quantized_dtype='w8a8',
                    quantized_algorithm='normal', target_platform='rk3588',
                    num_npu_core=3, extra_qparams=None)
    if ret != 0:
        print('Build model failed!')
        sys.exit(ret)
    # Export rkllm model
    ret = llm.export_rkllm("./language_model_w8a8.rkllm")
    if ret != 0:
        print('Export model failed!')
        sys.exit(ret)


if __name__ == "__main__":
    main()
|