# LLaVA-OneVision-2-8B-Instruct / demo_inference.py
# yiyexy's picture
# Upload folder using huggingface_hub
# 0379b48 verified
"""End-to-end inference demo for LlavaOnevision2 (image + video).
This script shows the two canonical inference paths supported by the model:
* Image captioning (``--mode image``, default)
* Video captioning (``--mode video``)
Both modes share the same loading pattern:
from transformers import AutoProcessor, AutoModelForImageTextToText
processor = AutoProcessor.from_pretrained(model_dir, trust_remote_code=True)
model = AutoModelForImageTextToText.from_pretrained(
model_dir, trust_remote_code=True, dtype=torch.bfloat16, device_map="cuda",
)
Examples
--------
# Image (default sample image from the web)
python demo_inference.py
# Image with a local file and a custom prompt
python demo_inference.py --mode image --media /path/to/cat.jpg \
--prompt "What is the cat doing?"
# Video
# - ``--num-frames`` selects exactly N frames (uniform sampling).
# - ``--max-pixels`` caps each frame's pixel budget. Lower it to fit smaller
# GPUs; 200704 (=448*448) is a safe default for a single ~80GB card.
python demo_inference.py --mode video --media /path/to/clip.mp4 \
--num-frames 16 --max-pixels 200704 \
--prompt "Describe what happens in this video."
Tested with:
transformers == 5.7.0
torch >= 2.4
decord, Pillow, requests
"""
from __future__ import annotations
import argparse
import io
import os
import sys
import torch
# Fallback values used when the matching command-line flag is omitted.

# Checkpoint to load; ``--model /local/path`` swaps in a local checkpoint.
DEFAULT_MODEL: str = "lmms-lab-encoder/LLaVA-OneVision2-8B-Instruct"

# Sample media. The image URL is public (no auth required); the video path
# is only a placeholder and has to be pointed at a real file.
DEFAULT_IMAGE_URL: str = "https://www.ilankelman.org/stopsigns/australia.jpg"
DEFAULT_VIDEO_PATH: str = "/path/to/your/video.mp4"  # <-- replace me

# Prompts sent alongside the media when ``--prompt`` is not given.
DEFAULT_IMAGE_PROMPT: str = "Describe this image in detail."
DEFAULT_VIDEO_PROMPT: str = "Describe what happens in this video in detail."
def load_image(source: str):
    """Load an image from a local path or an http(s) URL and return it as RGB.

    Args:
        source: Filesystem path, or a URL starting with ``http://`` or
            ``https://`` (URLs are fetched with a 30 second timeout).

    Returns:
        A new PIL image in ``"RGB"`` mode.

    Raises:
        requests.HTTPError: If the server answers with an error status.
        Any error raised by ``requests.get`` or ``PIL.Image.open`` propagates
        unchanged (e.g. a missing local file).
    """
    from PIL import Image

    if source.startswith(("http://", "https://")):
        import requests

        # The context manager hands the connection back even when
        # raise_for_status() raises before the body has been read.
        with requests.get(source, stream=True, timeout=30) as resp:
            resp.raise_for_status()
            payload = io.BytesIO(resp.content)
    else:
        payload = source

    # convert() returns a separate image object, so the source handle opened
    # by Image.open can be closed as soon as the conversion is done.
    with Image.open(payload) as img:
        return img.convert("RGB")
def run_image(model, processor, media: str, prompt: str, max_new_tokens: int, device: str) -> str:
    """Caption a single image.

    Args:
        model: Object exposing ``generate(**inputs, ...)``.
        processor: Object exposing ``apply_chat_template``, ``__call__`` and a
            ``tokenizer`` attribute.
        media: Local path or http(s) URL of the image (see ``load_image``).
        prompt: User text sent together with the image.
        max_new_tokens: Generation budget forwarded to ``model.generate``.
        device: Device every tensor input is moved to before generation.

    Returns:
        The decoded text of the first sequence, with special tokens skipped
        and surrounding whitespace stripped.
    """
    image = load_image(media)

    # Single-turn conversation: one image placeholder followed by the prompt.
    messages = [
        {
            "role": "user",
            "content": [
                {"type": "image"},
                {"type": "text", "text": prompt},
            ],
        }
    ]
    text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    inputs = processor(
        text=[text],
        images=[image],
        return_tensors="pt",
        padding=True,
    )
    # Only tensors are moved; any non-tensor entry is passed through as-is.
    inputs = {k: (v.to(device) if isinstance(v, torch.Tensor) else v) for k, v in inputs.items()}

    tok = processor.tokenizer
    # Fall back to EOS only when no pad token is defined. A plain ``or`` would
    # also discard a legitimate pad id of 0.
    pad_id = tok.pad_token_id if tok.pad_token_id is not None else tok.eos_token_id

    with torch.inference_mode():
        out_ids = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            do_sample=False,
            num_beams=1,
            use_cache=True,
            eos_token_id=tok.eos_token_id,
            pad_token_id=pad_id,
        )

    # Drop the first prompt_len positions so only the continuation is decoded
    # (assumes generate() returns the prompt followed by the new tokens).
    prompt_len = inputs["input_ids"].shape[-1]
    new_ids = out_ids[:, prompt_len:]
    return tok.batch_decode(new_ids, skip_special_tokens=True)[0].strip()
def run_video(
    model,
    processor,
    media: str,
    prompt: str,
    max_new_tokens: int,
    device: str,
    num_frames: int,
    max_pixels: int,
) -> str:
    """Caption an mp4/avi/... video file.

    Key processor knobs (all passed through ``__call__``):

    * ``num_frames`` : force exactly N uniformly-sampled frames.
    * ``max_frames`` : cap on auto-selected frame count (used when num_frames is None).
    * ``target_fps`` : sample at this FPS, capped by ``max_frames``.

    For memory control, lower the per-frame resolution by overriding
    ``processor.video_processor.max_pixels`` before calling the processor.

    Args:
        model: Object exposing ``generate(**inputs, ...)``.
        processor: Object exposing ``apply_chat_template``, ``__call__``, a
            ``tokenizer`` attribute and a ``video_processor`` attribute.
        media: Path of the video; it must exist on disk.
        prompt: User text sent together with the video.
        max_new_tokens: Generation budget forwarded to ``model.generate``.
        device: Device every tensor input is moved to before generation.
        num_frames: Frame count forwarded to the processor.
        max_pixels: Value written to ``processor.video_processor.max_pixels``.
            The assignment is not undone when the function returns.

    Returns:
        The decoded text of the first sequence, with special tokens skipped
        and surrounding whitespace stripped.

    Raises:
        FileNotFoundError: If ``media`` does not exist.
    """
    if not os.path.exists(media):
        raise FileNotFoundError(
            f"Video file not found: {media!r}. Pass --media <path/to/video.mp4>."
        )

    # Constrain per-frame pixel budget (memory-friendly default for a single ~80GB GPU).
    processor.video_processor.max_pixels = max_pixels

    # Single-turn conversation: one video placeholder followed by the prompt.
    messages = [
        {
            "role": "user",
            "content": [
                {"type": "video"},
                {"type": "text", "text": prompt},
            ],
        }
    ]
    text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    inputs = processor(
        text=[text],
        videos=[media],
        return_tensors="pt",
        padding=True,
        num_frames=num_frames,  # force exactly N frames
    )
    # Only tensors are moved; any non-tensor entry is passed through as-is.
    inputs = {k: (v.to(device) if isinstance(v, torch.Tensor) else v) for k, v in inputs.items()}

    tok = processor.tokenizer
    # Fall back to EOS only when no pad token is defined. A plain ``or`` would
    # also discard a legitimate pad id of 0.
    pad_id = tok.pad_token_id if tok.pad_token_id is not None else tok.eos_token_id

    with torch.inference_mode():
        out_ids = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            do_sample=False,
            num_beams=1,
            use_cache=True,
            eos_token_id=tok.eos_token_id,
            pad_token_id=pad_id,
        )

    # Drop the first prompt_len positions so only the continuation is decoded
    # (assumes generate() returns the prompt followed by the new tokens).
    prompt_len = inputs["input_ids"].shape[-1]
    new_ids = out_ids[:, prompt_len:]
    return tok.batch_decode(new_ids, skip_special_tokens=True)[0].strip()
def _positive_int(value: str) -> int:
    """argparse ``type=`` callback that accepts only integers >= 1."""
    try:
        number = int(value)
    except ValueError:
        raise argparse.ArgumentTypeError(f"expected an integer, got {value!r}") from None
    if number < 1:
        raise argparse.ArgumentTypeError(f"expected a positive integer, got {number}")
    return number


def main(argv: list[str] | None = None) -> int:
    """Parse CLI arguments, load the processor and model, and print one caption.

    Args:
        argv: Argument list to parse. ``None`` (the default) lets argparse
            read the process command line.

    Returns:
        ``0`` once the caption has been printed.

    Invalid arguments, including a video path that does not exist, are
    reported through ``parser.error`` (usage message, exit status 2) before
    any checkpoint is loaded.
    """
    parser = argparse.ArgumentParser(description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter)
    parser.add_argument(
        "--model",
        default=DEFAULT_MODEL,
        help=f"HF repo id or local path to the model checkpoint (default: {DEFAULT_MODEL}).",
    )
    parser.add_argument(
        "--mode",
        choices=["image", "video"],
        default="image",
        help="Inference mode (default: image).",
    )
    parser.add_argument(
        "--media",
        default=None,
        help=(
            "Image path/URL (image mode) or video path (video mode). "
            f"Defaults: image={DEFAULT_IMAGE_URL!r}, video={DEFAULT_VIDEO_PATH!r}."
        ),
    )
    parser.add_argument("--prompt", default=None, help="User prompt sent alongside the media.")
    parser.add_argument("--max-new-tokens", type=_positive_int, default=256)
    parser.add_argument(
        "--device",
        default="cuda" if torch.cuda.is_available() else "cpu",
        help="Device to load the model on.",
    )
    parser.add_argument(
        "--dtype",
        default="bfloat16",
        choices=["bfloat16", "float16", "float32"],
        help="Model dtype.",
    )
    # Video-only knobs (ignored in image mode).
    parser.add_argument(
        "--num-frames",
        type=_positive_int,
        default=16,
        help="[video] Number of frames to sample (default: 16).",
    )
    parser.add_argument(
        "--max-pixels",
        type=_positive_int,
        default=200704,
        help="[video] Per-frame max pixel count (default: 200704 = 448*448).",
    )
    args = parser.parse_args(argv)

    # Defaults that depend on mode.
    if args.media is None:
        args.media = DEFAULT_IMAGE_URL if args.mode == "image" else DEFAULT_VIDEO_PATH
    if args.prompt is None:
        args.prompt = DEFAULT_IMAGE_PROMPT if args.mode == "image" else DEFAULT_VIDEO_PROMPT

    # Fail fast: without this check a wrong (or placeholder) video path is
    # only noticed after the whole checkpoint has been loaded.
    if args.mode == "video" and not os.path.exists(args.media):
        parser.error(f"Video file not found: {args.media!r}. Pass --media <path/to/video.mp4>.")

    dtype = getattr(torch, args.dtype)

    # Imported here so that argument parsing and --help do not import transformers.
    from transformers import AutoModelForImageTextToText, AutoProcessor

    print(f"[demo_inference] Loading processor from: {args.model}", flush=True)
    processor = AutoProcessor.from_pretrained(args.model, trust_remote_code=True)

    print(f"[demo_inference] Loading model on {args.device} ({args.dtype})...", flush=True)
    model = AutoModelForImageTextToText.from_pretrained(
        args.model,
        trust_remote_code=True,
        dtype=dtype,
        device_map=args.device,
    )
    model.eval()

    print(f"[demo_inference] Mode={args.mode} media={args.media}", flush=True)
    if args.mode == "image":
        caption = run_image(
            model, processor, args.media, args.prompt, args.max_new_tokens, args.device,
        )
    else:
        caption = run_video(
            model, processor, args.media, args.prompt, args.max_new_tokens, args.device,
            num_frames=args.num_frames, max_pixels=args.max_pixels,
        )

    print("\n========== OUTPUT ==========")
    print(caption)
    print("============================")
    return 0
# Script entry point: main()'s return value becomes the process exit status.
if __name__ == "__main__":
    raise SystemExit(main())