import argparse

import torch
from PIL import Image
from transformers import AutoTokenizer

from earthdial.model.internvl_chat import InternVLChatModel
from earthdial.train.dataset import build_transform


def run_single_inference(args):
    print(f"Loading model and tokenizer from Hugging Face: {args.checkpoint}")
    tokenizer = AutoTokenizer.from_pretrained(
        args.checkpoint, trust_remote_code=True, use_fast=False
    )
    model = InternVLChatModel.from_pretrained(
        args.checkpoint,
        low_cpu_mem_usage=True,
        torch_dtype=torch.bfloat16,
        device_map="auto" if args.auto else None,
        load_in_8bit=args.load_in_8bit,
        load_in_4bit=args.load_in_4bit,
    ).eval()
    # Quantized (8-bit/4-bit) and device-mapped models are already placed on
    # the GPU(s) by from_pretrained; otherwise move the model to CUDA explicitly.
    if not args.load_in_8bit and not args.load_in_4bit and not args.auto:
        model = model.cuda()

    # Preprocess the input image with the evaluation transform at the
    # resolution the checkpoint was trained with.
    image = Image.open(args.image_path).convert("RGB")
    image_size = model.config.force_image_size or model.config.vision_config.image_size
    transform = build_transform(is_train=False, input_size=image_size, normalize_type='imagenet')
    pixel_values = transform(image).unsqueeze(0).cuda().to(torch.bfloat16)

    # Greedy decoding when temperature is 0; sampling otherwise.
    generation_config = {
        "num_beams": args.num_beams,
        "max_new_tokens": 100,
        "min_new_tokens": 1,
        "do_sample": args.temperature > 0,
        "temperature": args.temperature,
    }

    answer = model.chat(
        tokenizer=tokenizer,
        pixel_values=pixel_values,
        question=args.question,
        generation_config=generation_config,
        verbose=True,
    )

    print("\n=== Inference Result ===")
    print(f"Question: {args.question}")
    print(f"Answer: {answer}")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--checkpoint', type=str, required=True, help='Model repo ID on Hugging Face Hub')
    parser.add_argument('--image-path', type=str, required=True, help='Path to input image')
    parser.add_argument('--question', type=str, required=True, help='Visual question to ask')
    parser.add_argument('--num-beams', type=int, default=5)
    parser.add_argument('--temperature', type=float, default=0.0)
    parser.add_argument('--load-in-8bit', action='store_true')
    parser.add_argument('--load-in-4bit', action='store_true')
    parser.add_argument('--auto', action='store_true')
    args = parser.parse_args()
    run_single_inference(args)
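
# Example invocation (a minimal sketch; the script filename, checkpoint repo ID,
# and image path below are placeholders, not values taken from this repository):
#
#   python inference_single.py \
#       --checkpoint <hf-username>/<earthdial-checkpoint> \
#       --image-path ./data/sample_rgb.png \
#       --question "Describe the land cover visible in this image." \
#       --temperature 0.0
#
# Pass --auto to shard the model across available GPUs via device_map="auto",
# or --load-in-8bit / --load-in-4bit to load a quantized variant.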