|
|
import torch |
|
|
import json |
|
|
import os |
|
|
from transformers import AutoConfig, Qwen3ForCausalLM, AutoTokenizer |
|
|
|
|
|
from rkllm.api import RKLLM |
|
|
|
|
|
import argparse |
|
|
import shutil |
|
|
from pathlib import Path |
|
|
from typing import Dict |
|
|
|
|
|
import torch |
|
|
from safetensors.torch import load_file |
|
|
from transformers import AutoConfig, AutoModelForCausalLM |
|
|
|
|
|
# Tokenizer artifacts copied verbatim from the source checkpoint into the
# exported model directory.  Files absent from the source are silently
# skipped (see copy_tokenizer_files).
TOKENIZER_FILES = [
    "tokenizer.json",
    "tokenizer_config.json",
    "special_tokens_map.json",
    "added_tokens.json",
    "vocab.json",
    "merges.txt",
    "chat_template.jinja",
]
|
|
|
|
|
|
|
|
def parse_args() -> argparse.Namespace:
    """Parse command-line options for the InternVL -> Qwen3 export script.

    Returns:
        argparse.Namespace with ``source`` (Path), ``output`` (Path) and
        ``safe_serialization`` (bool) attributes.
    """
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument(
        "--source",
        type=Path,
        default=".",
        help="Path to the InternVL (HF-format) checkpoint directory, e.g. /path/to/InternVL3_5-2B-HF",
    )
    parser.add_argument(
        "--output",
        type=Path,
        default="llm/",
        help="Directory where the extracted Qwen3 checkpoint will be written",
    )
    # BUG FIX: the original declared --safe-serialization with
    # action="store_true" AND default=True, so passing the flag was a no-op
    # and safetensors output could never be disabled.  Keep the existing flag
    # and default for backward compatibility, and add a paired
    # --no-safe-serialization switch to actually turn it off.
    parser.add_argument(
        "--safe-serialization",
        dest="safe_serialization",
        action="store_true",
        default=True,
        help="Save the exported model using safetensors instead of PyTorch binaries.",
    )
    parser.add_argument(
        "--no-safe-serialization",
        dest="safe_serialization",
        action="store_false",
        help="Save the exported model as PyTorch .bin files instead of safetensors.",
    )
    return parser.parse_args()
|
|
|
|
|
|
|
|
def extract_text_state_dict(full_state: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
    """Pull the language-model weights out of a full InternVL state dict.

    Keys under ``language_model.model.`` / ``language_model.lm_head.`` are
    remapped to the standalone ``model.`` / ``lm_head.`` namespaces; all
    other entries (e.g. vision tower weights) are dropped.

    Raises:
        ValueError: if no language-model weights were found at all.
    """
    # (source prefix, replacement prefix) pairs applied in order.
    renames = (
        ("language_model.model.", "model."),
        ("language_model.lm_head.", "lm_head."),
    )

    text_state: Dict[str, torch.Tensor] = {}
    for key, tensor in full_state.items():
        for old_prefix, new_prefix in renames:
            if key.startswith(old_prefix):
                text_state[new_prefix + key[len(old_prefix):]] = tensor
                break

    if not text_state:
        raise ValueError("Did not find any language_model weights in checkpoint; is this an InternVL model?")

    return text_state
|
|
|
|
|
|
|
|
def copy_tokenizer_files(source_dir: Path, output_dir: Path) -> None:
    """Copy whichever TOKENIZER_FILES exist in *source_dir* into *output_dir*."""
    for name in TOKENIZER_FILES:
        candidate = source_dir / name
        if not candidate.exists():
            # Not every checkpoint ships every artifact; skip quietly.
            continue
        shutil.copyfile(candidate, output_dir / name)
|
|
|
|
|
|
|
|
def main() -> None:
    """Extract the Qwen3 text model from an InternVL checkpoint and convert it to RKLLM.

    Pipeline:
      1. Load the HF config and the single-file safetensors checkpoint
         from ``--source``.
      2. Strip the ``language_model.`` prefix from the weights and save a
         standalone text checkpoint (config, generation config, weights,
         tokenizer files) to ``--output``.
      3. Feed the exported checkpoint to the RKLLM toolkit and emit a w8a8
         model targeting the rk3588 NPU (3 cores).
    """
    args = parse_args()
    source_dir = args.source.expanduser().resolve()
    output_dir = args.output.expanduser().resolve()
    output_dir.mkdir(parents=True, exist_ok=True)

    config = AutoConfig.from_pretrained(source_dir, trust_remote_code=True)
    text_config = config.text_config

    # NOTE(review): only a single-shard checkpoint is supported here; a
    # sharded checkpoint (model.safetensors.index.json) would need extra
    # handling — confirm source models always ship a single shard.
    weights_path = source_dir / "model.safetensors"
    if not weights_path.exists():
        raise FileNotFoundError(f"Could not find {weights_path}; expected a safetensors checkpoint")

    all_weights = load_file(weights_path)
    text_state = extract_text_state_dict(all_weights)

    # Match the checkpoint's stored dtype so the export does not silently
    # upcast (e.g. bfloat16 -> float32).
    sample_tensor = next(iter(text_state.values()))
    target_dtype = sample_tensor.dtype

    text_model = AutoModelForCausalLM.from_config(text_config)
    text_model = text_model.to(dtype=target_dtype, device=torch.device("cpu"))
    # strict=False so we can report BOTH missing and unexpected keys in one
    # error instead of failing on the first mismatch.
    missing, unexpected = text_model.load_state_dict(text_state, strict=False)
    if missing or unexpected:
        raise RuntimeError(
            "State dict mismatch when loading text weights: "
            f"missing={missing}, unexpected={unexpected}"
        )

    text_config.save_pretrained(output_dir)
    text_model.generation_config.save_pretrained(output_dir)
    text_model.save_pretrained(output_dir, safe_serialization=args.safe_serialization)

    copy_tokenizer_files(source_dir, output_dir)
    print(f"Exported Qwen3 model saved to {output_dir}")

    # ---- RKLLM conversion of the freshly exported checkpoint ----
    modelpath = output_dir
    llm = RKLLM()

    ret = llm.load_huggingface(model=modelpath, model_lora=None, device='cpu')
    if ret != 0:
        print('Load model failed!')
        # BUG FIX: was builtin exit(), which is injected by the `site` module
        # and not guaranteed outside interactive use; SystemExit always works.
        raise SystemExit(ret)

    qparams = None  # no extra quantization parameters
    ret = llm.build(do_quantization=True, optimization_level=1, quantized_dtype='w8a8',
                    quantized_algorithm='normal', target_platform='rk3588', num_npu_core=3, extra_qparams=qparams)

    if ret != 0:
        print('Build model failed!')
        raise SystemExit(ret)

    ret = llm.export_rkllm("./language_model_w8a8.rkllm")
    if ret != 0:
        print('Export model failed!')
        raise SystemExit(ret)
|
|
|
|
|
|
|
|
|
|
|
# Script entry point: run the export + RKLLM conversion pipeline when
# invoked directly (importing this module triggers nothing).
if __name__ == "__main__":
    main()
|
|
|
|
|
|