# NOTE: the module docstring is a raw string so the shell "\" line continuations
# in the examples survive verbatim into ``--help`` (argparse uses __doc__).
r"""End-to-end inference demo for LlavaOnevision2 (image + video).

This script shows the two canonical inference paths supported by the model:

  * Image captioning (``--mode image``, default)
  * Video captioning (``--mode video``)

Both modes share the same loading pattern:

    from transformers import AutoProcessor, AutoModelForImageTextToText

    processor = AutoProcessor.from_pretrained(model_dir, trust_remote_code=True)
    model = AutoModelForImageTextToText.from_pretrained(
        model_dir,
        trust_remote_code=True,
        dtype=torch.bfloat16,
        device_map="cuda",
    )

Examples
--------

    # Image (default sample image from the web)
    python demo_inference.py

    # Image with a local file and a custom prompt
    python demo_inference.py --mode image --media /path/to/cat.jpg \
        --prompt "What is the cat doing?"

    # Video
    #   - ``--num-frames`` selects exactly N frames (uniform sampling).
    #   - ``--max-pixels`` caps each frame's pixel budget. Lower it to fit smaller
    #     GPUs; 200704 (=448*448) is a safe default for a single ~80GB card.
    python demo_inference.py --mode video --media /path/to/clip.mp4 \
        --num-frames 16 --max-pixels 200704 \
        --prompt "Describe what happens in this video."

Tested with:
    transformers == 5.7.0
    torch >= 2.4
    decord, Pillow, requests
"""

from __future__ import annotations

import argparse
import io
import os
import sys

import torch

# Placeholder constants so the user can swap their own media in easily.
# (Public sample image, hosted on a third-party site and commonly used in
# Hugging Face transformers examples; no auth required.)
DEFAULT_IMAGE_URL = "https://www.ilankelman.org/stopsigns/australia.jpg"
DEFAULT_VIDEO_PATH = "/path/to/your/video.mp4"  # <-- replace me
DEFAULT_IMAGE_PROMPT = "Describe this image in detail."
DEFAULT_VIDEO_PROMPT = "Describe what happens in this video in detail."

# Default model. Override with ``--model /local/path`` to use a local checkpoint.
DEFAULT_MODEL = "lmms-lab-encoder/LLaVA-OneVision2-8B-Instruct"


def load_image(source: str):
    """Load an RGB PIL image from a local path or an http(s) URL.

    Args:
        source: Filesystem path, or a URL beginning with ``http://`` or ``https://``.

    Returns:
        A ``PIL.Image.Image`` converted to RGB mode.

    Raises:
        requests.HTTPError: If downloading ``source`` returns an error status.
    """
    from PIL import Image

    if source.startswith(("http://", "https://")):
        import requests

        # The whole body is read anyway, so streaming buys nothing; the ``with``
        # block releases the connection even if ``raise_for_status`` throws.
        with requests.get(source, timeout=30) as resp:
            resp.raise_for_status()
            payload = resp.content
        with Image.open(io.BytesIO(payload)) as img:
            return img.convert("RGB")

    # ``convert`` returns a converted copy, so the handle opened here can be
    # closed as soon as we return.
    with Image.open(source) as img:
        return img.convert("RGB")


def _move_to_device(batch, device: str) -> dict:
    """Return a plain dict of ``batch`` with every ``torch.Tensor`` moved to ``device``.

    Non-tensor entries are passed through unchanged.
    """
    return {key: (value.to(device) if isinstance(value, torch.Tensor) else value) for key, value in batch.items()}


def _single_turn_messages(media_type: str, prompt: str) -> list:
    """Build a one-turn user conversation: a ``media_type`` placeholder followed by ``prompt``."""
    return [
        {
            "role": "user",
            "content": [
                {"type": media_type},
                {"type": "text", "text": prompt},
            ],
        }
    ]


def _generate_and_decode(model, processor, inputs: dict, max_new_tokens: int) -> str:
    """Greedy-decode up to ``max_new_tokens`` tokens and return only the new text.

    ``inputs`` is the device-placed processor output; it must contain ``input_ids``
    so the prompt tokens can be sliced off the generated sequence before decoding.
    """
    tok = processor.tokenizer
    # Compare against None explicitly: a pad id of 0 is valid, and ``or`` would
    # silently replace it with the EOS id.
    pad_id = tok.pad_token_id if tok.pad_token_id is not None else tok.eos_token_id
    with torch.inference_mode():
        out_ids = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            do_sample=False,
            num_beams=1,
            use_cache=True,
            eos_token_id=tok.eos_token_id,
            pad_token_id=pad_id,
        )
    # Drop the echoed prompt tokens; keep only the continuation.
    prompt_len = inputs["input_ids"].shape[-1]
    new_ids = out_ids[:, prompt_len:]
    return tok.batch_decode(new_ids, skip_special_tokens=True)[0].strip()


def run_image(model, processor, media: str, prompt: str, max_new_tokens: int, device: str) -> str:
    """Caption a single image.

    Args:
        model: Loaded image-text-to-text model.
        processor: Matching processor (chat template + tokenizer + image processor).
        media: Local image path or http(s) URL.
        prompt: User text sent after the image placeholder.
        max_new_tokens: Generation budget.
        device: Device the tensor inputs are moved to.

    Returns:
        The decoded response (prompt excluded), stripped of surrounding whitespace.
    """
    image = load_image(media)
    messages = _single_turn_messages("image", prompt)
    text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    inputs = processor(
        text=[text],
        images=[image],
        return_tensors="pt",
        padding=True,
    )
    inputs = _move_to_device(inputs, device)
    return _generate_and_decode(model, processor, inputs, max_new_tokens)


def run_video(
    model,
    processor,
    media: str,
    prompt: str,
    max_new_tokens: int,
    device: str,
    num_frames: int,
    max_pixels: int,
) -> str:
    """Caption an mp4/avi/... video file.

    Key processor knobs (all passed through ``__call__``):
      * ``num_frames`` : force exactly N uniformly-sampled frames.
      * ``max_frames`` : cap on auto-selected frame count (used when num_frames is None).
      * ``target_fps`` : sample at this FPS, capped by ``max_frames``.

    For memory control, lower the per-frame resolution by overriding
    ``processor.video_processor.max_pixels`` before calling the processor.

    Args:
        model: Loaded image-text-to-text model.
        processor: Matching processor; its ``video_processor.max_pixels`` is overwritten.
        media: Path to an existing local video file.
        prompt: User text sent after the video placeholder.
        max_new_tokens: Generation budget.
        device: Device the tensor inputs are moved to.
        num_frames: Number of frames forwarded to the processor.
        max_pixels: Per-frame pixel budget assigned to the video processor.

    Returns:
        The decoded response (prompt excluded), stripped of surrounding whitespace.

    Raises:
        FileNotFoundError: If ``media`` is not an existing regular file.
    """
    # isfile (not exists) so a directory path is rejected up front.
    if not os.path.isfile(media):
        raise FileNotFoundError(
            f"Video file not found: {media!r}. Pass --media /path/to/video.mp4."
        )

    # Constrain per-frame pixel budget (memory-friendly default for a single ~80GB GPU).
    # NOTE: this mutates the shared processor, so the cap persists for later calls.
    processor.video_processor.max_pixels = max_pixels

    messages = _single_turn_messages("video", prompt)
    text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    inputs = processor(
        text=[text],
        videos=[media],
        return_tensors="pt",
        padding=True,
        num_frames=num_frames,  # force exactly N frames
    )
    inputs = _move_to_device(inputs, device)
    return _generate_and_decode(model, processor, inputs, max_new_tokens)


def _positive_int(text: str) -> int:
    """argparse ``type=`` callback: parse ``text`` as an int and require it to be > 0."""
    try:
        value = int(text)
    except ValueError:
        raise argparse.ArgumentTypeError(f"expected a positive integer, got {text!r}") from None
    if value <= 0:
        raise argparse.ArgumentTypeError(f"expected a positive integer, got {value}")
    return value


def main() -> int:
    """Parse CLI arguments, load processor + model, run one inference and print it.

    Returns:
        Process exit status (0 on success); failures propagate as exceptions.
    """
    parser = argparse.ArgumentParser(description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter)
    parser.add_argument(
        "--model",
        default=DEFAULT_MODEL,
        help=f"HF repo id or local path to the model checkpoint (default: {DEFAULT_MODEL}).",
    )
    parser.add_argument(
        "--mode",
        choices=["image", "video"],
        default="image",
        help="Inference mode (default: image).",
    )
    parser.add_argument(
        "--media",
        default=None,
        help=(
            "Image path/URL (image mode) or video path (video mode). "
            f"Defaults: image={DEFAULT_IMAGE_URL!r}, video={DEFAULT_VIDEO_PATH!r}."
        ),
    )
    parser.add_argument("--prompt", default=None, help="User prompt sent alongside the media.")
    parser.add_argument(
        "--max-new-tokens",
        type=_positive_int,
        default=256,
        help="Maximum number of new tokens to generate (default: 256).",
    )
    parser.add_argument(
        "--device",
        default="cuda" if torch.cuda.is_available() else "cpu",
        help="Device to load the model on.",
    )
    parser.add_argument(
        "--dtype",
        default="bfloat16",
        choices=["bfloat16", "float16", "float32"],
        help="Model dtype.",
    )
    # Video-only knobs (ignored in image mode).
    parser.add_argument(
        "--num-frames",
        type=_positive_int,
        default=16,
        help="[video] Number of frames to sample (default: 16).",
    )
    parser.add_argument(
        "--max-pixels",
        type=_positive_int,
        default=200704,
        help="[video] Per-frame max pixel count (default: 200704 = 448*448).",
    )
    args = parser.parse_args()

    # Defaults that depend on mode.
    if args.media is None:
        args.media = DEFAULT_IMAGE_URL if args.mode == "image" else DEFAULT_VIDEO_PATH
    if args.prompt is None:
        args.prompt = DEFAULT_IMAGE_PROMPT if args.mode == "image" else DEFAULT_VIDEO_PROMPT

    dtype = getattr(torch, args.dtype)

    from transformers import AutoModelForImageTextToText, AutoProcessor

    print(f"[demo_inference] Loading processor from: {args.model}", flush=True)
    processor = AutoProcessor.from_pretrained(args.model, trust_remote_code=True)

    print(f"[demo_inference] Loading model on {args.device} ({args.dtype})...", flush=True)
    model = AutoModelForImageTextToText.from_pretrained(
        args.model,
        trust_remote_code=True,
        dtype=dtype,
        device_map=args.device,
    )
    model.eval()

    print(f"[demo_inference] Mode={args.mode} media={args.media}", flush=True)
    if args.mode == "image":
        caption = run_image(
            model,
            processor,
            args.media,
            args.prompt,
            args.max_new_tokens,
            args.device,
        )
    else:
        caption = run_video(
            model,
            processor,
            args.media,
            args.prompt,
            args.max_new_tokens,
            args.device,
            num_frames=args.num_frames,
            max_pixels=args.max_pixels,
        )

    print("\n========== OUTPUT ==========")
    print(caption)
    print("============================")
    return 0


if __name__ == "__main__":
    sys.exit(main())