| import argparse | |
| import torch | |
| from PIL import Image | |
| from transformers import AutoTokenizer | |
| from earthdial.model.internvl_chat import InternVLChatModel | |
| from earthdial.train.dataset import build_transform | |
def run_single_inference(args):
    """Run one visual question-answering inference and print the answer.

    Loads the tokenizer and InternVL chat model from ``args.checkpoint``,
    preprocesses the image at ``args.image_path``, asks ``args.question``,
    and prints the generated answer.

    Args:
        args: Parsed CLI namespace with attributes ``checkpoint``,
            ``image_path``, ``question``, ``num_beams``, ``temperature``,
            ``load_in_8bit``, ``load_in_4bit``, and ``auto``.
    """
    print(f"Loading model and tokenizer from Hugging Face: {args.checkpoint}")
    tokenizer = AutoTokenizer.from_pretrained(
        args.checkpoint, trust_remote_code=True, use_fast=False
    )
    model = InternVLChatModel.from_pretrained(
        args.checkpoint,
        low_cpu_mem_usage=True,
        torch_dtype=torch.bfloat16,
        device_map="auto" if args.auto else None,
        load_in_8bit=args.load_in_8bit,
        load_in_4bit=args.load_in_4bit,
    ).eval()
    # Move the model manually only when neither quantization nor
    # device_map="auto" has already decided its placement.
    if not args.load_in_8bit and not args.load_in_4bit and not args.auto:
        model = model.cuda()
    # Place the image tensor on the model's own device instead of
    # hard-coding .cuda(): with device_map="auto" the model may not sit
    # (entirely) on the default CUDA device, and CPU-only runs would crash.
    device = next(model.parameters()).device

    image = Image.open(args.image_path).convert("RGB")
    image_size = model.config.force_image_size or model.config.vision_config.image_size
    transform = build_transform(is_train=False, input_size=image_size, normalize_type='imagenet')
    pixel_values = transform(image).unsqueeze(0).to(device=device, dtype=torch.bfloat16)

    generation_config = {
        "num_beams": args.num_beams,
        "max_new_tokens": 100,
        "min_new_tokens": 1,
        # Greedy decoding when temperature is 0; sampling otherwise.
        "do_sample": args.temperature > 0,
        "temperature": args.temperature,
    }
    answer = model.chat(
        tokenizer=tokenizer,
        pixel_values=pixel_values,
        question=args.question,
        generation_config=generation_config,
        verbose=True,
    )
    print("\n=== Inference Result ===")
    print(f"Question: {args.question}")
    print(f"Answer: {answer}")
| if __name__ == "__main__": | |
| parser = argparse.ArgumentParser() | |
| parser.add_argument('--checkpoint', type=str, required=True, help='Model repo ID on Hugging Face Hub') | |
| parser.add_argument('--image-path', type=str, required=True, help='Path to input image') | |
| parser.add_argument('--question', type=str, required=True, help='Visual question to ask') | |
| parser.add_argument('--num-beams', type=int, default=5) | |
| parser.add_argument('--temperature', type=float, default=0.0) | |
| parser.add_argument('--load-in-8bit', action='store_true') | |
| parser.add_argument('--load-in-4bit', action='store_true') | |
| parser.add_argument('--auto', action='store_true') | |
| args = parser.parse_args() | |
| run_single_inference(args) | |