File size: 1,749 Bytes
ab18129 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 | #!/usr/bin/env python3
from __future__ import annotations
import argparse
from pathlib import Path
import torch
from PIL import Image
from transformers import AutoProcessor, Qwen2_5_VLForConditionalGeneration
def _positive_int(text: str) -> int:
    """Parse *text* as a strictly positive integer (argparse ``type=`` hook).

    Raises:
        argparse.ArgumentTypeError: if *text* is not an integer or is < 1;
            argparse turns this into a usage error (exit status 2).
    """
    try:
        value = int(text)
    except ValueError:
        raise argparse.ArgumentTypeError(f"invalid int value: {text!r}") from None
    if value < 1:
        raise argparse.ArgumentTypeError(f"must be a positive integer, got {value}")
    return value


def main(argv: list[str] | None = None) -> None:
    """Run a Qwen2.5-VL checkpoint on one image + prompt and print the reply.

    Args:
        argv: Arguments to parse instead of ``sys.argv[1:]``. ``None`` (the
            default) keeps the normal command-line behaviour.

    Invalid arguments (missing ``--image``/``--prompt``, non-positive
    ``--max-new-tokens``) make argparse exit with status 2 before any image
    or model weights are loaded.
    """
    parser = argparse.ArgumentParser(
        description="Run a local Qwen2.5-VL/Jarvis VLP world-knowledge checkpoint on one image."
    )
    parser.add_argument(
        "--model",
        default="/data/zianguan/shared_vlp_world_knowledge_clean_best_20260520/best_vqa_world_knowledge_ckpt",
    )
    parser.add_argument("--image", required=True)
    parser.add_argument("--prompt", required=True)
    # Validated at parse time so a bad value is reported immediately rather
    # than after the (slow) checkpoint load below.
    parser.add_argument("--max-new-tokens", type=_positive_int, default=128)
    args = parser.parse_args(argv)

    # Load the image before the model so a bad path fails fast. convert()
    # returns a separate copy, so the file can be closed right away; the
    # context manager also releases the handle for multi-frame formats.
    with Image.open(Path(args.image)) as source:
        image = source.convert("RGB")

    processor = AutoProcessor.from_pretrained(args.model, trust_remote_code=True)
    model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
        args.model,
        torch_dtype=torch.bfloat16,
        device_map="auto",
        trust_remote_code=True,
    )

    messages = [
        {
            "role": "user",
            "content": [
                {"type": "image", "image": image},
                {"type": "text", "text": args.prompt},
            ],
        }
    ]
    text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    inputs = processor(text=[text], images=[image], return_tensors="pt").to(model.device)

    with torch.inference_mode():
        output_ids = model.generate(**inputs, max_new_tokens=args.max_new_tokens)

    # The returned sequences start with the prompt tokens; drop them so only
    # the newly generated reply is decoded.
    generated = output_ids[:, inputs.input_ids.shape[1] :]
    print(processor.batch_decode(generated, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0])
# Script entry point: run the CLI only when executed directly, not on import.
if __name__ == "__main__":
    main()
|