# InternVL3_5-2B-RKLLM / rkllm-convert.py
# Uploaded by happyme531 ("Upload 11 files", commit 7fc4eb4, verified).
# Extracts the Qwen3 language model from an InternVL HF checkpoint and
# converts it to a quantized RKLLM model for the RK3588 NPU.
import torch
import json
import os
from transformers import AutoConfig, Qwen3ForCausalLM, AutoTokenizer
from rkllm.api import RKLLM
import argparse
import shutil
from pathlib import Path
from typing import Dict
import torch
from safetensors.torch import load_file
from transformers import AutoConfig, AutoModelForCausalLM
# Tokenizer artifacts that may accompany an HF checkpoint. Not every tokenizer
# ships every file; copy_tokenizer_files() copies whichever of these exist.
TOKENIZER_FILES: list = [
    "tokenizer.json",
    "tokenizer_config.json",
    "special_tokens_map.json",
    "added_tokens.json",
    "vocab.json",
    "merges.txt",
    "chat_template.jinja",
]
def parse_args() -> argparse.Namespace:
    """Parse command-line arguments for the conversion script.

    Returns:
        argparse.Namespace with ``source`` (Path), ``output`` (Path) and
        ``safe_serialization`` (bool) attributes.
    """
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument(
        "--source",
        type=Path,
        default=".",
        help="Path to the InternVL (HF-format) checkpoint directory, e.g. /path/to/InternVL3_5-2B-HF",
    )
    parser.add_argument(
        "--output",
        type=Path,
        default="llm/",
        help="Directory where the extracted Qwen3 checkpoint will be written",
    )
    # BUG FIX: the original used action="store_true" together with default=True,
    # which made the option a no-op (it could never be switched off).
    # BooleanOptionalAction keeps --safe-serialization working and adds a
    # functional --no-safe-serialization, with the same default of True.
    parser.add_argument(
        "--safe-serialization",
        action=argparse.BooleanOptionalAction,
        default=True,
        help="Save the exported model using safetensors instead of PyTorch binaries.",
    )
    return parser.parse_args()
def extract_text_state_dict(full_state: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
    """Pull the language-model weights out of a full InternVL state dict.

    Keys under ``language_model.model.`` are remapped to ``model.`` and keys
    under ``language_model.lm_head.`` to ``lm_head.``; everything else
    (e.g. vision-tower weights) is dropped.

    Raises:
        ValueError: if no language_model weights are found at all.
    """
    remaps = (
        ("language_model.model.", "model."),
        ("language_model.lm_head.", "lm_head."),
    )
    extracted: Dict[str, torch.Tensor] = {}
    for name, weight in full_state.items():
        for old_prefix, new_prefix in remaps:
            if name.startswith(old_prefix):
                extracted[new_prefix + name[len(old_prefix):]] = weight
                break
    if not extracted:
        raise ValueError("Did not find any language_model weights in checkpoint; is this an InternVL model?")
    return extracted
def copy_tokenizer_files(source_dir: Path, output_dir: Path) -> None:
    """Copy tokenizer artifacts from *source_dir* into *output_dir*.

    Iterates TOKENIZER_FILES and copies each file that actually exists in the
    source checkpoint; missing files are silently skipped.
    """
    present = (name for name in TOKENIZER_FILES if (source_dir / name).exists())
    for name in present:
        shutil.copyfile(source_dir / name, output_dir / name)
def main() -> None:
    """Extract the Qwen3 text model from an InternVL checkpoint and convert it to RKLLM.

    Pipeline:
      1. Load the full InternVL safetensors checkpoint from --source.
      2. Keep only the language-model weights and save them as a standalone
         HF checkpoint (config, generation config, weights, tokenizer files)
         under --output.
      3. Feed that checkpoint to the RKLLM toolkit to build and export a
         w8a8-quantized model targeting the RK3588 NPU.

    Raises:
        FileNotFoundError: if --source has no single-file model.safetensors.
        RuntimeError: if the extracted state dict does not match the text config.
        SystemExit: with the RKLLM return code if load/build/export fails.
    """
    args = parse_args()
    source_dir = args.source.expanduser().resolve()
    output_dir = args.output.expanduser().resolve()
    output_dir.mkdir(parents=True, exist_ok=True)

    config = AutoConfig.from_pretrained(source_dir, trust_remote_code=True)
    text_config = config.text_config

    # NOTE(review): only a single-file checkpoint is handled; a sharded
    # model-0000x-of-0000y.safetensors layout would need extra logic.
    weights_path = source_dir / "model.safetensors"
    if not weights_path.exists():
        raise FileNotFoundError(f"Could not find {weights_path}; expected a safetensors checkpoint")
    all_weights = load_file(weights_path)
    text_state = extract_text_state_dict(all_weights)

    # Instantiate the text model in the same dtype the checkpoint was stored in.
    sample_tensor = next(iter(text_state.values()))
    target_dtype = sample_tensor.dtype
    text_model = AutoModelForCausalLM.from_config(text_config)
    text_model = text_model.to(dtype=target_dtype, device=torch.device("cpu"))

    # strict=False lets us collect BOTH missing and unexpected keys, then
    # fail with a single message listing everything that mismatched.
    missing, unexpected = text_model.load_state_dict(text_state, strict=False)
    if missing or unexpected:
        raise RuntimeError(
            "State dict mismatch when loading text weights: "
            f"missing={missing}, unexpected={unexpected}"
        )

    text_config.save_pretrained(output_dir)
    text_model.generation_config.save_pretrained(output_dir)
    text_model.save_pretrained(output_dir, safe_serialization=args.safe_serialization)
    copy_tokenizer_files(source_dir, output_dir)
    print(f"Exported Qwen3 model saved to {output_dir}")

    # --- RKLLM conversion -------------------------------------------------
    # IDIOM FIX: the original used exit(ret); exit() is the site-module
    # interactive helper and is discouraged in scripts — raise SystemExit.
    modelpath = output_dir
    llm = RKLLM()
    ret = llm.load_huggingface(model=modelpath, model_lora=None, device='cpu')
    if ret != 0:
        print('Load model failed!')
        raise SystemExit(ret)

    qparams = None  # no extra quantization parameters
    ret = llm.build(do_quantization=True, optimization_level=1, quantized_dtype='w8a8',
                    quantized_algorithm='normal', target_platform='rk3588', num_npu_core=3, extra_qparams=qparams)
    if ret != 0:
        print('Build model failed!')
        raise SystemExit(ret)

    # Export rkllm model
    ret = llm.export_rkllm("./language_model_w8a8.rkllm")
    if ret != 0:
        print('Export model failed!')
        raise SystemExit(ret)
# Run the full conversion pipeline when executed as a script.
if __name__ == "__main__":
    main()