# MiniCPM-V-4.6 / python / infer_torch.py
# yongqiang
# Add AX650 MiniCPM-V 4.6 release package
# aa969ca
import argparse
import torch
from transformers import AutoModelForImageTextToText, AutoProcessor
# Checkpoint directory used as the default value of the --model-path option.
DEFAULT_MODEL_PATH = "/data/tmp/yongqiang/nfs/auto_model_deployment/Minicpm-V-4.6-hf-original/MiniCPM-V-4.6"
def resolve_dtype(name: str):
    """Translate a dtype option string into the value used for model loading.

    Args:
        name: ``"auto"``, ``"float32"``, ``"float16"`` or ``"bfloat16"``.

    Returns:
        The string ``"auto"`` unchanged, otherwise the matching ``torch.dtype``.

    Raises:
        KeyError: If ``name`` is none of the names listed above.
    """
    named_dtypes = {
        "float32": torch.float32,
        "float16": torch.float16,
        "bfloat16": torch.bfloat16,
    }
    return "auto" if name == "auto" else named_dtypes[name]
def resolve_device(name: str):
    """Return the device string to load the model on.

    Args:
        name: An explicit device string, or ``"auto"`` to choose one here.

    Returns:
        ``name`` itself when it is not ``"auto"``; otherwise ``"cuda:0"`` if
        ``torch.cuda.is_available()`` reports CUDA, else ``"cpu"``.
    """
    if name == "auto":
        if torch.cuda.is_available():
            return "cuda:0"
        return "cpu"
    # Any explicit choice is passed through untouched.
    return name
def main():
    """Run one text-only generation with the model and print diagnostics.

    Parses the command line, loads the processor and the model from
    ``--model-path``, renders ``--prompt`` as a single user turn through the
    processor's chat template, generates at most ``--max-new-tokens`` new
    tokens, and prints the input ids, the rendered prompt, the newly generated
    ids and the decoded reply.
    """
    cli = argparse.ArgumentParser(description="MiniCPM-V-4.6 official torch text-only inference")
    cli.add_argument("--model-path", default=DEFAULT_MODEL_PATH)
    cli.add_argument("--prompt", default="你好,请做一个简短自我介绍。")
    cli.add_argument("--device", default="auto")
    cli.add_argument("--dtype", default="auto", choices=["auto", "float32", "float16", "bfloat16"])
    cli.add_argument("--attn-implementation", default="eager")
    cli.add_argument("--max-new-tokens", type=int, default=64)
    cli.add_argument("--do-sample", action="store_true")
    opts = cli.parse_args()

    target_device = resolve_device(opts.device)
    resolved_dtype = resolve_dtype(opts.dtype)

    processor = AutoProcessor.from_pretrained(opts.model_path)

    model_kwargs = {
        "attn_implementation": opts.attn_implementation,
        "device_map": target_device,
    }
    # When the resolved dtype is "auto", no dtype keyword is passed at all.
    if resolved_dtype != "auto":
        model_kwargs["dtype"] = resolved_dtype
    model = AutoModelForImageTextToText.from_pretrained(opts.model_path, **model_kwargs)
    model = model.eval()

    # A single user turn whose only content item is the text prompt.
    user_turn = {
        "role": "user",
        "content": [
            {"type": "text", "text": opts.prompt},
        ],
    }
    encoded = processor.apply_chat_template(
        [user_turn],
        tokenize=True,
        add_generation_prompt=True,
        return_dict=True,
        return_tensors="pt",
    )
    encoded = encoded.to(model.device)
    prompt_ids = encoded.input_ids

    # Decode with special tokens kept so the exact rendered template is visible.
    rendered_prompt = processor.tokenizer.decode(prompt_ids[0], skip_special_tokens=False)

    print("model_path:", opts.model_path)
    print("device:", model.device)
    print("input_ids.shape:", tuple(prompt_ids.shape))
    print("input_ids:", prompt_ids[0].tolist())
    print("prompt_template_repr:", rendered_prompt.encode("unicode_escape").decode())

    with torch.inference_mode():
        generated = model.generate(
            **encoded,
            max_new_tokens=opts.max_new_tokens,
            do_sample=opts.do_sample,
        )

    # Drop the leading prompt-length positions so only new tokens remain.
    prompt_len = prompt_ids.shape[-1]
    new_ids = generated[0, prompt_len:]
    reply = processor.decode(new_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)

    print("generated_ids:", new_ids.tolist())
    print("output_text:", reply)
# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()