Add files using upload-large-folder tool
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .venv/lib/python3.11/site-packages/vllm/assets/__init__.py +0 -0
- .venv/lib/python3.11/site-packages/vllm/assets/__pycache__/__init__.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/vllm/assets/__pycache__/audio.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/vllm/assets/__pycache__/base.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/vllm/assets/__pycache__/image.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/vllm/assets/__pycache__/video.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/vllm/assets/audio.py +33 -0
- .venv/lib/python3.11/site-packages/vllm/assets/base.py +40 -0
- .venv/lib/python3.11/site-packages/vllm/assets/image.py +31 -0
- .venv/lib/python3.11/site-packages/vllm/assets/video.py +84 -0
- .venv/lib/python3.11/site-packages/vllm/multimodal/__init__.py +33 -0
- .venv/lib/python3.11/site-packages/vllm/multimodal/hasher.py +102 -0
- .venv/lib/python3.11/site-packages/vllm/multimodal/image.py +139 -0
- .venv/lib/python3.11/site-packages/vllm/multimodal/inputs.py +741 -0
- .venv/lib/python3.11/site-packages/vllm/multimodal/parse.py +368 -0
- .venv/lib/python3.11/site-packages/vllm/multimodal/processing.py +1295 -0
- .venv/lib/python3.11/site-packages/vllm/multimodal/profiling.py +209 -0
- .venv/lib/python3.11/site-packages/vllm/multimodal/registry.py +458 -0
- .venv/lib/python3.11/site-packages/vllm/multimodal/utils.py +518 -0
- .venv/lib/python3.11/site-packages/vllm/multimodal/video.py +191 -0
- .venv/lib/python3.11/site-packages/vllm/triton_utils/__init__.py +12 -0
- .venv/lib/python3.11/site-packages/vllm/triton_utils/__pycache__/__init__.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/vllm/triton_utils/__pycache__/custom_cache_manager.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/vllm/triton_utils/__pycache__/importing.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/vllm/worker/__init__.py +0 -0
- .venv/lib/python3.11/site-packages/vllm/worker/__pycache__/__init__.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/vllm/worker/__pycache__/cache_engine.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/vllm/worker/__pycache__/cpu_enc_dec_model_runner.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/vllm/worker/__pycache__/cpu_model_runner.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/vllm/worker/__pycache__/cpu_pooling_model_runner.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/vllm/worker/__pycache__/cpu_worker.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/vllm/worker/__pycache__/enc_dec_model_runner.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/vllm/worker/__pycache__/hpu_worker.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/vllm/worker/__pycache__/model_runner.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/vllm/worker/__pycache__/model_runner_base.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/vllm/worker/__pycache__/multi_step_model_runner.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/vllm/worker/__pycache__/multi_step_tpu_worker.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/vllm/worker/__pycache__/multi_step_worker.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/vllm/worker/__pycache__/neuron_model_runner.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/vllm/worker/__pycache__/neuron_worker.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/vllm/worker/__pycache__/openvino_model_runner.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/vllm/worker/__pycache__/openvino_worker.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/vllm/worker/__pycache__/pooling_model_runner.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/vllm/worker/__pycache__/tpu_model_runner.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/vllm/worker/__pycache__/tpu_worker.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/vllm/worker/__pycache__/utils.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/vllm/worker/__pycache__/worker.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/vllm/worker/__pycache__/worker_base.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/vllm/worker/__pycache__/xpu_model_runner.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/vllm/worker/__pycache__/xpu_worker.cpython-311.pyc +0 -0
.venv/lib/python3.11/site-packages/vllm/assets/__init__.py
ADDED
|
File without changes
|
.venv/lib/python3.11/site-packages/vllm/assets/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (184 Bytes). View file
|
|
|
.venv/lib/python3.11/site-packages/vllm/assets/__pycache__/audio.cpython-311.pyc
ADDED
|
Binary file (2.03 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/vllm/assets/__pycache__/base.cpython-311.pyc
ADDED
|
Binary file (1.87 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/vllm/assets/__pycache__/image.cpython-311.pyc
ADDED
|
Binary file (1.87 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/vllm/assets/__pycache__/video.cpython-311.pyc
ADDED
|
Binary file (4.79 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/vllm/assets/audio.py
ADDED
|
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 2 |
+
|
| 3 |
+
from dataclasses import dataclass
|
| 4 |
+
from typing import Literal
|
| 5 |
+
from urllib.parse import urljoin
|
| 6 |
+
|
| 7 |
+
import numpy.typing as npt
|
| 8 |
+
|
| 9 |
+
from vllm.utils import PlaceholderModule
|
| 10 |
+
|
| 11 |
+
from .base import VLLM_S3_BUCKET_URL, get_vllm_public_assets
|
| 12 |
+
|
| 13 |
+
try:
|
| 14 |
+
import librosa
|
| 15 |
+
except ImportError:
|
| 16 |
+
librosa = PlaceholderModule("librosa") # type: ignore[assignment]
|
| 17 |
+
|
| 18 |
+
ASSET_DIR = "multimodal_asset"
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
@dataclass(frozen=True)
|
| 22 |
+
class AudioAsset:
|
| 23 |
+
name: Literal["winning_call", "mary_had_lamb"]
|
| 24 |
+
|
| 25 |
+
@property
|
| 26 |
+
def audio_and_sample_rate(self) -> tuple[npt.NDArray, float]:
|
| 27 |
+
audio_path = get_vllm_public_assets(filename=f"{self.name}.ogg",
|
| 28 |
+
s3_prefix=ASSET_DIR)
|
| 29 |
+
return librosa.load(audio_path, sr=None)
|
| 30 |
+
|
| 31 |
+
@property
|
| 32 |
+
def url(self) -> str:
|
| 33 |
+
return urljoin(VLLM_S3_BUCKET_URL, f"{ASSET_DIR}/{self.name}.ogg")
|
.venv/lib/python3.11/site-packages/vllm/assets/base.py
ADDED
|
@@ -0,0 +1,40 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 2 |
+
|
| 3 |
+
from functools import lru_cache
|
| 4 |
+
from pathlib import Path
|
| 5 |
+
from typing import Optional
|
| 6 |
+
|
| 7 |
+
import vllm.envs as envs
|
| 8 |
+
from vllm.connections import global_http_connection
|
| 9 |
+
|
| 10 |
+
VLLM_S3_BUCKET_URL = "https://vllm-public-assets.s3.us-west-2.amazonaws.com"
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
def get_cache_dir() -> Path:
|
| 14 |
+
"""Get the path to the cache for storing downloaded assets."""
|
| 15 |
+
path = Path(envs.VLLM_ASSETS_CACHE)
|
| 16 |
+
path.mkdir(parents=True, exist_ok=True)
|
| 17 |
+
|
| 18 |
+
return path
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
@lru_cache
|
| 22 |
+
def get_vllm_public_assets(filename: str,
|
| 23 |
+
s3_prefix: Optional[str] = None) -> Path:
|
| 24 |
+
"""
|
| 25 |
+
Download an asset file from ``s3://vllm-public-assets``
|
| 26 |
+
and return the path to the downloaded file.
|
| 27 |
+
"""
|
| 28 |
+
asset_directory = get_cache_dir() / "vllm_public_assets"
|
| 29 |
+
asset_directory.mkdir(parents=True, exist_ok=True)
|
| 30 |
+
|
| 31 |
+
asset_path = asset_directory / filename
|
| 32 |
+
if not asset_path.exists():
|
| 33 |
+
if s3_prefix is not None:
|
| 34 |
+
filename = s3_prefix + "/" + filename
|
| 35 |
+
global_http_connection.download_file(
|
| 36 |
+
f"{VLLM_S3_BUCKET_URL}/{filename}",
|
| 37 |
+
asset_path,
|
| 38 |
+
timeout=envs.VLLM_IMAGE_FETCH_TIMEOUT)
|
| 39 |
+
|
| 40 |
+
return asset_path
|
.venv/lib/python3.11/site-packages/vllm/assets/image.py
ADDED
|
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 2 |
+
|
| 3 |
+
from dataclasses import dataclass
|
| 4 |
+
from typing import Literal
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
from PIL import Image
|
| 8 |
+
|
| 9 |
+
from .base import get_vllm_public_assets
|
| 10 |
+
|
| 11 |
+
VLM_IMAGES_DIR = "vision_model_images"
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
@dataclass(frozen=True)
|
| 15 |
+
class ImageAsset:
|
| 16 |
+
name: Literal["stop_sign", "cherry_blossom"]
|
| 17 |
+
|
| 18 |
+
@property
|
| 19 |
+
def pil_image(self) -> Image.Image:
|
| 20 |
+
image_path = get_vllm_public_assets(filename=f"{self.name}.jpg",
|
| 21 |
+
s3_prefix=VLM_IMAGES_DIR)
|
| 22 |
+
return Image.open(image_path)
|
| 23 |
+
|
| 24 |
+
@property
|
| 25 |
+
def image_embeds(self) -> torch.Tensor:
|
| 26 |
+
"""
|
| 27 |
+
Image embeddings, only used for testing purposes with llava 1.5.
|
| 28 |
+
"""
|
| 29 |
+
image_path = get_vllm_public_assets(filename=f"{self.name}.pt",
|
| 30 |
+
s3_prefix=VLM_IMAGES_DIR)
|
| 31 |
+
return torch.load(image_path, map_location="cpu", weights_only=True)
|
.venv/lib/python3.11/site-packages/vllm/assets/video.py
ADDED
|
@@ -0,0 +1,84 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 2 |
+
|
| 3 |
+
from dataclasses import dataclass
|
| 4 |
+
from functools import lru_cache
|
| 5 |
+
from typing import List, Literal
|
| 6 |
+
|
| 7 |
+
import cv2
|
| 8 |
+
import numpy as np
|
| 9 |
+
import numpy.typing as npt
|
| 10 |
+
from huggingface_hub import hf_hub_download
|
| 11 |
+
from PIL import Image
|
| 12 |
+
|
| 13 |
+
from vllm.multimodal.video import sample_frames_from_video
|
| 14 |
+
|
| 15 |
+
from .base import get_cache_dir
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
@lru_cache
|
| 19 |
+
def download_video_asset(filename: str) -> str:
|
| 20 |
+
"""
|
| 21 |
+
Download and open an image from huggingface
|
| 22 |
+
repo: raushan-testing-hf/videos-test
|
| 23 |
+
"""
|
| 24 |
+
video_directory = get_cache_dir() / "video-example-data"
|
| 25 |
+
video_directory.mkdir(parents=True, exist_ok=True)
|
| 26 |
+
|
| 27 |
+
video_path = video_directory / filename
|
| 28 |
+
video_path_str = str(video_path)
|
| 29 |
+
if not video_path.exists():
|
| 30 |
+
video_path_str = hf_hub_download(
|
| 31 |
+
repo_id="raushan-testing-hf/videos-test",
|
| 32 |
+
filename=filename,
|
| 33 |
+
repo_type="dataset",
|
| 34 |
+
cache_dir=video_directory,
|
| 35 |
+
)
|
| 36 |
+
return video_path_str
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
def video_to_ndarrays(path: str, num_frames: int = -1) -> npt.NDArray:
|
| 40 |
+
cap = cv2.VideoCapture(path)
|
| 41 |
+
if not cap.isOpened():
|
| 42 |
+
raise ValueError(f"Could not open video file {path}")
|
| 43 |
+
|
| 44 |
+
total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
|
| 45 |
+
frames = []
|
| 46 |
+
for i in range(total_frames):
|
| 47 |
+
ret, frame = cap.read()
|
| 48 |
+
if ret:
|
| 49 |
+
frames.append(frame)
|
| 50 |
+
cap.release()
|
| 51 |
+
|
| 52 |
+
frames = np.stack(frames)
|
| 53 |
+
frames = sample_frames_from_video(frames, num_frames)
|
| 54 |
+
if len(frames) < num_frames:
|
| 55 |
+
raise ValueError(f"Could not read enough frames from video file {path}"
|
| 56 |
+
f" (expected {num_frames} frames, got {len(frames)})")
|
| 57 |
+
return frames
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
def video_to_pil_images_list(path: str,
|
| 61 |
+
num_frames: int = -1) -> List[Image.Image]:
|
| 62 |
+
frames = video_to_ndarrays(path, num_frames)
|
| 63 |
+
return [
|
| 64 |
+
Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
|
| 65 |
+
for frame in frames
|
| 66 |
+
]
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
@dataclass(frozen=True)
|
| 70 |
+
class VideoAsset:
|
| 71 |
+
name: Literal["sample_demo_1.mp4"]
|
| 72 |
+
num_frames: int = -1
|
| 73 |
+
|
| 74 |
+
@property
|
| 75 |
+
def pil_images(self) -> List[Image.Image]:
|
| 76 |
+
video_path = download_video_asset(self.name)
|
| 77 |
+
ret = video_to_pil_images_list(video_path, self.num_frames)
|
| 78 |
+
return ret
|
| 79 |
+
|
| 80 |
+
@property
|
| 81 |
+
def np_ndarrays(self) -> npt.NDArray:
|
| 82 |
+
video_path = download_video_asset(self.name)
|
| 83 |
+
ret = video_to_ndarrays(video_path, self.num_frames)
|
| 84 |
+
return ret
|
.venv/lib/python3.11/site-packages/vllm/multimodal/__init__.py
ADDED
|
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 2 |
+
|
| 3 |
+
from .base import MultiModalPlaceholderMap, MultiModalPlugin
|
| 4 |
+
from .hasher import MultiModalHashDict, MultiModalHasher
|
| 5 |
+
from .inputs import (BatchedTensorInputs, ModalityData, MultiModalDataBuiltins,
|
| 6 |
+
MultiModalDataDict, MultiModalKwargs,
|
| 7 |
+
MultiModalPlaceholderDict, NestedTensors)
|
| 8 |
+
from .registry import MultiModalRegistry
|
| 9 |
+
|
| 10 |
+
MULTIMODAL_REGISTRY = MultiModalRegistry()
|
| 11 |
+
"""
|
| 12 |
+
The global :class:`~MultiModalRegistry` is used by model runners to
|
| 13 |
+
dispatch data processing according to the target model.
|
| 14 |
+
|
| 15 |
+
See also:
|
| 16 |
+
:ref:`mm-processing`
|
| 17 |
+
"""
|
| 18 |
+
|
| 19 |
+
__all__ = [
|
| 20 |
+
"BatchedTensorInputs",
|
| 21 |
+
"ModalityData",
|
| 22 |
+
"MultiModalDataBuiltins",
|
| 23 |
+
"MultiModalDataDict",
|
| 24 |
+
"MultiModalHashDict",
|
| 25 |
+
"MultiModalHasher",
|
| 26 |
+
"MultiModalKwargs",
|
| 27 |
+
"MultiModalPlaceholderDict",
|
| 28 |
+
"MultiModalPlaceholderMap",
|
| 29 |
+
"MultiModalPlugin",
|
| 30 |
+
"NestedTensors",
|
| 31 |
+
"MULTIMODAL_REGISTRY",
|
| 32 |
+
"MultiModalRegistry",
|
| 33 |
+
]
|
.venv/lib/python3.11/site-packages/vllm/multimodal/hasher.py
ADDED
|
@@ -0,0 +1,102 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 2 |
+
|
| 3 |
+
import pickle
|
| 4 |
+
from typing import TYPE_CHECKING, Iterable, Mapping, Optional
|
| 5 |
+
|
| 6 |
+
import numpy as np
|
| 7 |
+
import torch
|
| 8 |
+
from blake3 import blake3
|
| 9 |
+
from PIL import Image
|
| 10 |
+
|
| 11 |
+
from vllm.logger import init_logger
|
| 12 |
+
|
| 13 |
+
if TYPE_CHECKING:
|
| 14 |
+
from vllm.inputs import TokensPrompt
|
| 15 |
+
|
| 16 |
+
logger = init_logger(__name__)
|
| 17 |
+
|
| 18 |
+
MultiModalHashDict = Mapping[str, list[str]]
|
| 19 |
+
"""
|
| 20 |
+
A dictionary containing hashes for items in each modality.
|
| 21 |
+
"""
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
class MultiModalHasher:
|
| 25 |
+
|
| 26 |
+
@classmethod
|
| 27 |
+
def serialize_item(cls, obj: object) -> bytes:
|
| 28 |
+
# Simple cases
|
| 29 |
+
if isinstance(obj, str):
|
| 30 |
+
return obj.encode("utf-8")
|
| 31 |
+
if isinstance(obj, bytes):
|
| 32 |
+
return obj
|
| 33 |
+
if isinstance(obj, Image.Image):
|
| 34 |
+
return obj.tobytes()
|
| 35 |
+
|
| 36 |
+
# Convertible to NumPy arrays
|
| 37 |
+
if isinstance(obj, torch.Tensor):
|
| 38 |
+
obj = obj.numpy()
|
| 39 |
+
if isinstance(obj, (int, float)):
|
| 40 |
+
obj = np.array(obj)
|
| 41 |
+
if isinstance(obj, np.ndarray):
|
| 42 |
+
return obj.tobytes()
|
| 43 |
+
|
| 44 |
+
logger.warning(
|
| 45 |
+
"No serialization method found for %s. "
|
| 46 |
+
"Falling back to pickle.", type(obj))
|
| 47 |
+
|
| 48 |
+
return pickle.dumps(obj)
|
| 49 |
+
|
| 50 |
+
@classmethod
|
| 51 |
+
def item_to_bytes(
|
| 52 |
+
cls,
|
| 53 |
+
key: str,
|
| 54 |
+
obj: object,
|
| 55 |
+
) -> Iterable[tuple[bytes, bytes]]:
|
| 56 |
+
# Recursive cases
|
| 57 |
+
if isinstance(obj, (list, tuple)):
|
| 58 |
+
for i, elem in enumerate(obj):
|
| 59 |
+
yield from cls.item_to_bytes(f"{key}.{i}", elem)
|
| 60 |
+
elif isinstance(obj, dict):
|
| 61 |
+
for k, v in obj.items():
|
| 62 |
+
yield from cls.item_to_bytes(f"{key}.{k}", v)
|
| 63 |
+
else:
|
| 64 |
+
key_bytes = cls.serialize_item(key)
|
| 65 |
+
value_bytes = cls.serialize_item(obj)
|
| 66 |
+
yield key_bytes, value_bytes
|
| 67 |
+
|
| 68 |
+
@classmethod
|
| 69 |
+
def hash_kwargs(cls, **kwargs: object) -> str:
|
| 70 |
+
hasher = blake3()
|
| 71 |
+
|
| 72 |
+
for k, v in kwargs.items():
|
| 73 |
+
for k_bytes, v_bytes in cls.item_to_bytes(k, v):
|
| 74 |
+
hasher.update(k_bytes)
|
| 75 |
+
hasher.update(v_bytes)
|
| 76 |
+
|
| 77 |
+
return hasher.hexdigest()
|
| 78 |
+
|
| 79 |
+
@classmethod
|
| 80 |
+
def hash_prompt_mm_data(
|
| 81 |
+
cls, prompt: "TokensPrompt") -> Optional["MultiModalHashDict"]:
|
| 82 |
+
"""Hash multimodal data in the user input prompt if they exist."""
|
| 83 |
+
|
| 84 |
+
if "multi_modal_data" not in prompt:
|
| 85 |
+
return None
|
| 86 |
+
|
| 87 |
+
mm_data = prompt["multi_modal_data"]
|
| 88 |
+
if not mm_data:
|
| 89 |
+
# mm_data can be None or an empty dict.
|
| 90 |
+
return None
|
| 91 |
+
|
| 92 |
+
mm_items = {
|
| 93 |
+
modality: items if isinstance(items, list) else [items]
|
| 94 |
+
for modality, items in mm_data.items()
|
| 95 |
+
}
|
| 96 |
+
|
| 97 |
+
mm_hashes = {
|
| 98 |
+
modality: [cls.hash_kwargs(**{modality: item}) for item in items]
|
| 99 |
+
for modality, items in mm_items.items()
|
| 100 |
+
}
|
| 101 |
+
|
| 102 |
+
return mm_hashes
|
.venv/lib/python3.11/site-packages/vllm/multimodal/image.py
ADDED
|
@@ -0,0 +1,139 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 2 |
+
|
| 3 |
+
import base64
|
| 4 |
+
from functools import lru_cache
|
| 5 |
+
from io import BytesIO
|
| 6 |
+
from pathlib import Path
|
| 7 |
+
from typing import TYPE_CHECKING, Any, Dict, Optional
|
| 8 |
+
|
| 9 |
+
import torch
|
| 10 |
+
from PIL import Image
|
| 11 |
+
|
| 12 |
+
from vllm.inputs.registry import InputContext
|
| 13 |
+
from vllm.logger import init_logger
|
| 14 |
+
from vllm.transformers_utils.processor import get_image_processor
|
| 15 |
+
from vllm.utils import is_list_of
|
| 16 |
+
|
| 17 |
+
from .base import MediaIO, MultiModalPlugin
|
| 18 |
+
from .inputs import ImageItem, ModalityData, MultiModalKwargs
|
| 19 |
+
|
| 20 |
+
if TYPE_CHECKING:
|
| 21 |
+
from vllm.config import ModelConfig
|
| 22 |
+
|
| 23 |
+
logger = init_logger(__name__)
|
| 24 |
+
|
| 25 |
+
cached_get_image_processor = lru_cache(get_image_processor)
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
class ImagePlugin(MultiModalPlugin):
|
| 29 |
+
"""Plugin for image data."""
|
| 30 |
+
|
| 31 |
+
def get_data_key(self) -> str:
|
| 32 |
+
return "image"
|
| 33 |
+
|
| 34 |
+
def _get_hf_image_processor(
|
| 35 |
+
self,
|
| 36 |
+
model_config: "ModelConfig",
|
| 37 |
+
mm_processor_kwargs: Optional[Dict[str, Any]] = None,
|
| 38 |
+
):
|
| 39 |
+
if mm_processor_kwargs is None:
|
| 40 |
+
mm_processor_kwargs = {}
|
| 41 |
+
return cached_get_image_processor(
|
| 42 |
+
model_config.model,
|
| 43 |
+
trust_remote_code=model_config.trust_remote_code,
|
| 44 |
+
**mm_processor_kwargs)
|
| 45 |
+
|
| 46 |
+
def _default_input_mapper(
|
| 47 |
+
self,
|
| 48 |
+
ctx: InputContext,
|
| 49 |
+
data: ModalityData[ImageItem],
|
| 50 |
+
**mm_processor_kwargs,
|
| 51 |
+
) -> MultiModalKwargs:
|
| 52 |
+
model_config = ctx.model_config
|
| 53 |
+
|
| 54 |
+
# PIL image
|
| 55 |
+
if isinstance(data, Image.Image) or is_list_of(data, Image.Image):
|
| 56 |
+
image_processor = self._get_hf_image_processor(
|
| 57 |
+
model_config,
|
| 58 |
+
mm_processor_kwargs,
|
| 59 |
+
)
|
| 60 |
+
|
| 61 |
+
if image_processor is None:
|
| 62 |
+
raise RuntimeError("No HuggingFace processor is available "
|
| 63 |
+
"to process the image object")
|
| 64 |
+
try:
|
| 65 |
+
# NOTE: It may make sense to forward the mm_processor_kwargs
|
| 66 |
+
# here too. For now, to keep it simple, we only allow it be
|
| 67 |
+
# used for the initialization call though, just in case the
|
| 68 |
+
# signatures of the preprocessor initializer don't match
|
| 69 |
+
# preprocess()
|
| 70 |
+
batch_data = image_processor \
|
| 71 |
+
.preprocess(data, return_tensors="pt") \
|
| 72 |
+
.data
|
| 73 |
+
except Exception:
|
| 74 |
+
logger.error(
|
| 75 |
+
"Failed to process image (%s) with the default mapper. "
|
| 76 |
+
"This is most likely an edge-case with this model's image "
|
| 77 |
+
"processor in transformers (type: %s), and not vLLM.",
|
| 78 |
+
data,
|
| 79 |
+
type(image_processor).__name__)
|
| 80 |
+
raise
|
| 81 |
+
|
| 82 |
+
return MultiModalKwargs(batch_data)
|
| 83 |
+
|
| 84 |
+
# Image embedding
|
| 85 |
+
elif isinstance(data, torch.Tensor) or is_list_of(data, torch.Tensor):
|
| 86 |
+
return MultiModalKwargs({"image_embeds": data})
|
| 87 |
+
|
| 88 |
+
raise TypeError(f"Invalid image type: {type(data)}")
|
| 89 |
+
|
| 90 |
+
def _default_max_multimodal_tokens(self, ctx: InputContext) -> int:
|
| 91 |
+
return 3000
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
def rescale_image_size(image: Image.Image,
|
| 95 |
+
size_factor: float,
|
| 96 |
+
transpose: int = -1) -> Image.Image:
|
| 97 |
+
"""Rescale the dimensions of an image by a constant factor."""
|
| 98 |
+
new_width = int(image.width * size_factor)
|
| 99 |
+
new_height = int(image.height * size_factor)
|
| 100 |
+
image = image.resize((new_width, new_height))
|
| 101 |
+
if transpose >= 0:
|
| 102 |
+
image = image.transpose(Image.Transpose(transpose))
|
| 103 |
+
return image
|
| 104 |
+
|
| 105 |
+
|
| 106 |
+
class ImageMediaIO(MediaIO[Image.Image]):
|
| 107 |
+
|
| 108 |
+
def __init__(self, *, image_mode: str = "RGB") -> None:
|
| 109 |
+
super().__init__()
|
| 110 |
+
|
| 111 |
+
self.image_mode = image_mode
|
| 112 |
+
|
| 113 |
+
def load_bytes(self, data: bytes) -> Image.Image:
|
| 114 |
+
image = Image.open(BytesIO(data))
|
| 115 |
+
image.load()
|
| 116 |
+
return image.convert(self.image_mode)
|
| 117 |
+
|
| 118 |
+
def load_base64(self, media_type: str, data: str) -> Image.Image:
|
| 119 |
+
return self.load_bytes(base64.b64decode(data))
|
| 120 |
+
|
| 121 |
+
def load_file(self, filepath: Path) -> Image.Image:
|
| 122 |
+
image = Image.open(filepath)
|
| 123 |
+
image.load()
|
| 124 |
+
return image.convert(self.image_mode)
|
| 125 |
+
|
| 126 |
+
def encode_base64(
|
| 127 |
+
self,
|
| 128 |
+
media: Image.Image,
|
| 129 |
+
*,
|
| 130 |
+
image_format: str = "JPEG",
|
| 131 |
+
) -> str:
|
| 132 |
+
image = media
|
| 133 |
+
|
| 134 |
+
with BytesIO() as buffer:
|
| 135 |
+
image = image.convert(self.image_mode)
|
| 136 |
+
image.save(buffer, image_format)
|
| 137 |
+
data = buffer.getvalue()
|
| 138 |
+
|
| 139 |
+
return base64.b64encode(data).decode('utf-8')
|
.venv/lib/python3.11/site-packages/vllm/multimodal/inputs.py
ADDED
|
@@ -0,0 +1,741 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 2 |
+
|
| 3 |
+
from abc import ABC, abstractmethod
|
| 4 |
+
from collections import UserDict, defaultdict
|
| 5 |
+
from collections.abc import Mapping, Sequence
|
| 6 |
+
from dataclasses import dataclass
|
| 7 |
+
from functools import partial
|
| 8 |
+
from itertools import accumulate
|
| 9 |
+
from typing import (TYPE_CHECKING, Any, Literal, Optional, TypedDict, TypeVar,
|
| 10 |
+
Union, cast, final)
|
| 11 |
+
|
| 12 |
+
import numpy as np
|
| 13 |
+
import torch
|
| 14 |
+
import torch.types
|
| 15 |
+
from PIL.Image import Image
|
| 16 |
+
from transformers import BatchFeature
|
| 17 |
+
from typing_extensions import NotRequired, TypeAlias
|
| 18 |
+
|
| 19 |
+
from vllm.utils import JSONTree, full_groupby, is_list_of, json_map_leaves
|
| 20 |
+
|
| 21 |
+
if TYPE_CHECKING:
|
| 22 |
+
from .hasher import MultiModalHashDict
|
| 23 |
+
|
| 24 |
+
_T = TypeVar("_T")
|
| 25 |
+
|
| 26 |
+
HfImageItem: TypeAlias = Union[Image, np.ndarray, torch.Tensor]
|
| 27 |
+
"""
|
| 28 |
+
A :class:`transformers.image_utils.ImageInput` representing a single image
|
| 29 |
+
item, which can be passed to a HuggingFace :code:`ImageProcessor`.
|
| 30 |
+
"""
|
| 31 |
+
|
| 32 |
+
HfVideoItem: TypeAlias = Union[list[Image], np.ndarray, torch.Tensor,
|
| 33 |
+
list[np.ndarray], list[torch.Tensor]]
|
| 34 |
+
"""
|
| 35 |
+
A :class:`transformers.image_utils.VideoInput` representing a single video
|
| 36 |
+
item, which can be passed to a HuggingFace :code:`VideoProcessor`.
|
| 37 |
+
"""
|
| 38 |
+
|
| 39 |
+
HfAudioItem: TypeAlias = Union[list[float], np.ndarray, torch.Tensor]
|
| 40 |
+
"""
|
| 41 |
+
Represents a single audio
|
| 42 |
+
item, which can be passed to a HuggingFace :code:`AudioProcessor`.
|
| 43 |
+
"""
|
| 44 |
+
|
| 45 |
+
ImageItem: TypeAlias = Union[HfImageItem, torch.Tensor]
|
| 46 |
+
"""
|
| 47 |
+
A :class:`transformers.image_utils.ImageInput` representing a single image
|
| 48 |
+
item, which can be passed to a HuggingFace :code:`ImageProcessor`.
|
| 49 |
+
|
| 50 |
+
Alternatively, a 3-D tensor or batch of 2-D tensors,
|
| 51 |
+
which are treated as image embeddings;
|
| 52 |
+
these are directly passed to the model without HF processing.
|
| 53 |
+
"""
|
| 54 |
+
|
| 55 |
+
VideoItem: TypeAlias = Union[HfVideoItem, torch.Tensor]
|
| 56 |
+
"""
|
| 57 |
+
A :class:`transformers.image_utils.VideoInput` representing a single video
|
| 58 |
+
item, which can be passed to a HuggingFace :code:`VideoProcessor`.
|
| 59 |
+
|
| 60 |
+
Alternatively, a 3-D tensor or batch of 2-D tensors,
|
| 61 |
+
which are treated as video embeddings;
|
| 62 |
+
these are directly passed to the model without HF processing.
|
| 63 |
+
"""
|
| 64 |
+
|
| 65 |
+
AudioItem: TypeAlias = Union[HfAudioItem, tuple[np.ndarray, float],
|
| 66 |
+
torch.Tensor]
|
| 67 |
+
"""
|
| 68 |
+
Represents a single audio
|
| 69 |
+
item, which can be passed to a HuggingFace :code:`AudioProcessor`.
|
| 70 |
+
|
| 71 |
+
Alternatively, a tuple `(audio, sampling_rate)`, where the sampling rate
|
| 72 |
+
is different from that expected by the model;
|
| 73 |
+
these are resampled to the model's sampling rate before being processed by HF.
|
| 74 |
+
|
| 75 |
+
Alternatively, a 3-D tensor or batch of 2-D tensors,
|
| 76 |
+
which are treated as audio embeddings;
|
| 77 |
+
these are directly passed to the model without HF processing.
|
| 78 |
+
"""
|
| 79 |
+
|
| 80 |
+
ModalityData: TypeAlias = Union[_T, list[_T]]
|
| 81 |
+
"""
|
| 82 |
+
Either a single data item, or a list of data items.
|
| 83 |
+
|
| 84 |
+
The number of data items allowed per modality is restricted by
|
| 85 |
+
:code:`--limit-mm-per-prompt`.
|
| 86 |
+
"""
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
@final
|
| 90 |
+
class MultiModalDataBuiltins(TypedDict, total=False):
|
| 91 |
+
"""Type annotations for modality types predefined by vLLM."""
|
| 92 |
+
|
| 93 |
+
image: ModalityData[ImageItem]
|
| 94 |
+
"""The input image(s)."""
|
| 95 |
+
|
| 96 |
+
video: ModalityData[VideoItem]
|
| 97 |
+
"""The input video(s)."""
|
| 98 |
+
|
| 99 |
+
audio: ModalityData[AudioItem]
|
| 100 |
+
"""The input audio(s)."""
|
| 101 |
+
|
| 102 |
+
|
| 103 |
+
MultiModalDataDict: TypeAlias = Mapping[str, ModalityData[Any]]
|
| 104 |
+
"""
|
| 105 |
+
A dictionary containing an entry for each modality type to input.
|
| 106 |
+
|
| 107 |
+
The built-in modalities are defined by :class:`MultiModalDataBuiltins`.
|
| 108 |
+
"""
|
| 109 |
+
|
| 110 |
+
|
| 111 |
+
class PlaceholderRange(TypedDict):
|
| 112 |
+
"""
|
| 113 |
+
Placeholder location information for multi-modal data.
|
| 114 |
+
|
| 115 |
+
Example:
|
| 116 |
+
|
| 117 |
+
Prompt: :code:`AAAA BBBB What is in these images?`
|
| 118 |
+
|
| 119 |
+
Images A and B will have:
|
| 120 |
+
|
| 121 |
+
.. code-block::
|
| 122 |
+
|
| 123 |
+
A: { "offset": 0, "length": 4 }
|
| 124 |
+
B: { "offset": 5, "length": 4 }
|
| 125 |
+
"""
|
| 126 |
+
|
| 127 |
+
offset: int
|
| 128 |
+
"""The start index of the placeholder in the prompt."""
|
| 129 |
+
|
| 130 |
+
length: int
|
| 131 |
+
"""The length of the placeholder."""
|
| 132 |
+
|
| 133 |
+
|
| 134 |
+
NestedTensors = Union[list["NestedTensors"], list[torch.Tensor], torch.Tensor,
|
| 135 |
+
tuple[torch.Tensor, ...]]
|
| 136 |
+
"""
|
| 137 |
+
Uses a list instead of a tensor if the dimensions of each element do not match.
|
| 138 |
+
"""
|
| 139 |
+
|
| 140 |
+
|
| 141 |
+
def nested_tensors_equal(a: NestedTensors, b: NestedTensors) -> bool:
|
| 142 |
+
"""Equality check between :data:`NestedTensors` objects."""
|
| 143 |
+
if isinstance(a, torch.Tensor):
|
| 144 |
+
return isinstance(b, torch.Tensor) and torch.equal(a, b)
|
| 145 |
+
elif isinstance(b, torch.Tensor):
|
| 146 |
+
return isinstance(a, torch.Tensor) and torch.equal(b, a)
|
| 147 |
+
|
| 148 |
+
if isinstance(a, list):
|
| 149 |
+
return (isinstance(b, list)
|
| 150 |
+
and all(nested_tensors_equal(a_, b_) for a_, b_ in zip(a, b)))
|
| 151 |
+
if isinstance(b, list):
|
| 152 |
+
return (isinstance(a, list)
|
| 153 |
+
and all(nested_tensors_equal(b_, a_) for b_, a_ in zip(b, a)))
|
| 154 |
+
|
| 155 |
+
# Both a and b are scalars
|
| 156 |
+
return a == b
|
| 157 |
+
|
| 158 |
+
|
| 159 |
+
BatchedTensorInputs: TypeAlias = Mapping[str, NestedTensors]
|
| 160 |
+
"""
|
| 161 |
+
A dictionary containing nested tensors which have been batched via
|
| 162 |
+
:meth:`MultiModalKwargs.batch`.
|
| 163 |
+
"""
|
| 164 |
+
|
| 165 |
+
|
| 166 |
+
@dataclass(frozen=True)
|
| 167 |
+
class MultiModalFieldElem:
|
| 168 |
+
"""
|
| 169 |
+
Represents a keyword argument corresponding to a multi-modal item
|
| 170 |
+
in :class:`MultiModalKwargs`.
|
| 171 |
+
"""
|
| 172 |
+
|
| 173 |
+
modality: str
|
| 174 |
+
"""
|
| 175 |
+
The modality of the corresponding multi-modal item.
|
| 176 |
+
Each multi-modal item can consist of multiple keyword arguments.
|
| 177 |
+
"""
|
| 178 |
+
|
| 179 |
+
key: str
|
| 180 |
+
"""
|
| 181 |
+
The key of this field in :class:`MultiModalKwargs`,
|
| 182 |
+
i.e. the name of the keyword argument to be passed to the model.
|
| 183 |
+
"""
|
| 184 |
+
|
| 185 |
+
data: NestedTensors
|
| 186 |
+
"""
|
| 187 |
+
The tensor data of this field in :class:`MultiModalKwargs`,
|
| 188 |
+
i.e. the value of the keyword argument to be passed to the model.
|
| 189 |
+
"""
|
| 190 |
+
|
| 191 |
+
field: "BaseMultiModalField"
|
| 192 |
+
"""
|
| 193 |
+
Defines how to combine the tensor data of this field with others
|
| 194 |
+
in order to batch multi-modal items together for model inference.
|
| 195 |
+
"""
|
| 196 |
+
|
| 197 |
+
def __eq__(self, other: object) -> bool:
|
| 198 |
+
if not isinstance(other, self.__class__):
|
| 199 |
+
return False
|
| 200 |
+
|
| 201 |
+
return ((self.modality, self.key) == (other.modality, other.key)
|
| 202 |
+
and nested_tensors_equal(self.data, other.data)
|
| 203 |
+
and type(self.field) == type(other.field)) # noqa: E721
|
| 204 |
+
|
| 205 |
+
|
| 206 |
+
@dataclass(frozen=True)
|
| 207 |
+
class BaseMultiModalField(ABC):
|
| 208 |
+
"""
|
| 209 |
+
Defines how to interpret tensor data belonging to a keyword argument in
|
| 210 |
+
:class:`MultiModalKwargs` for multiple multi-modal items, and vice versa.
|
| 211 |
+
"""
|
| 212 |
+
|
| 213 |
+
def _field_factory(self, *, modality: str, key: str):
|
| 214 |
+
f = partial(
|
| 215 |
+
MultiModalFieldElem,
|
| 216 |
+
modality=modality,
|
| 217 |
+
key=key,
|
| 218 |
+
field=self,
|
| 219 |
+
)
|
| 220 |
+
|
| 221 |
+
# Allow passing data as positional argument
|
| 222 |
+
def factory(data: NestedTensors) -> MultiModalFieldElem:
|
| 223 |
+
return f(data=data)
|
| 224 |
+
|
| 225 |
+
return factory
|
| 226 |
+
|
| 227 |
+
@abstractmethod
|
| 228 |
+
def build_elems(
|
| 229 |
+
self,
|
| 230 |
+
modality: str,
|
| 231 |
+
key: str,
|
| 232 |
+
data: NestedTensors,
|
| 233 |
+
) -> Sequence[MultiModalFieldElem]:
|
| 234 |
+
"""
|
| 235 |
+
Construct :class:`MultiModalFieldElem` instances to represent
|
| 236 |
+
the provided data.
|
| 237 |
+
|
| 238 |
+
This is the inverse of :meth:`reduce_data`.
|
| 239 |
+
"""
|
| 240 |
+
raise NotImplementedError
|
| 241 |
+
|
| 242 |
+
@abstractmethod
|
| 243 |
+
def _reduce_data(self, batch: list[NestedTensors]) -> NestedTensors:
|
| 244 |
+
raise NotImplementedError
|
| 245 |
+
|
| 246 |
+
def reduce_data(self, elems: list[MultiModalFieldElem]) -> NestedTensors:
|
| 247 |
+
"""
|
| 248 |
+
Merge the data from multiple instances of :class:`MultiModalFieldElem`.
|
| 249 |
+
|
| 250 |
+
This is the inverse of :meth:`build_elems`.
|
| 251 |
+
"""
|
| 252 |
+
field_types = [type(item.field) for item in elems]
|
| 253 |
+
if len(set(field_types)) > 1:
|
| 254 |
+
raise ValueError(f"Cannot merge different {field_types=}")
|
| 255 |
+
|
| 256 |
+
return self._reduce_data([item.data for item in elems])
|
| 257 |
+
|
| 258 |
+
|
| 259 |
+
@dataclass(frozen=True)
|
| 260 |
+
class MultiModalBatchedField(BaseMultiModalField):
|
| 261 |
+
"""
|
| 262 |
+
See also:
|
| 263 |
+
:func:`MultiModalFieldConfig.batched`
|
| 264 |
+
"""
|
| 265 |
+
|
| 266 |
+
def build_elems(
|
| 267 |
+
self,
|
| 268 |
+
modality: str,
|
| 269 |
+
key: str,
|
| 270 |
+
data: NestedTensors,
|
| 271 |
+
) -> Sequence[MultiModalFieldElem]:
|
| 272 |
+
field_factory = self._field_factory(modality=modality, key=key)
|
| 273 |
+
return [field_factory(item) for item in data]
|
| 274 |
+
|
| 275 |
+
def _reduce_data(self, batch: list[NestedTensors]) -> NestedTensors:
|
| 276 |
+
if len(batch) > 0 and is_list_of(batch, torch.Tensor, check="all"):
|
| 277 |
+
if len(batch) == 1:
|
| 278 |
+
# An optimization when `batch` contains only one tensor:
|
| 279 |
+
# - produce exactly same result as `torch.stack(batch)`
|
| 280 |
+
# - will achieve zero-copy if the tensor is contiguous
|
| 281 |
+
return batch[0].unsqueeze(0).contiguous()
|
| 282 |
+
first_shape = batch[0].shape
|
| 283 |
+
if all(elem.shape == first_shape for elem in batch):
|
| 284 |
+
return torch.stack(batch)
|
| 285 |
+
|
| 286 |
+
return batch
|
| 287 |
+
|
| 288 |
+
|
| 289 |
+
@dataclass(frozen=True)
|
| 290 |
+
class MultiModalFlatField(BaseMultiModalField):
|
| 291 |
+
"""
|
| 292 |
+
See also:
|
| 293 |
+
:func:`MultiModalFieldConfig.flat`
|
| 294 |
+
:func:`MultiModalFieldConfig.flat_from_sizes`
|
| 295 |
+
"""
|
| 296 |
+
slices: Sequence[slice]
|
| 297 |
+
|
| 298 |
+
def build_elems(
|
| 299 |
+
self,
|
| 300 |
+
modality: str,
|
| 301 |
+
key: str,
|
| 302 |
+
data: NestedTensors,
|
| 303 |
+
) -> Sequence[MultiModalFieldElem]:
|
| 304 |
+
field_factory = self._field_factory(modality=modality, key=key)
|
| 305 |
+
return [field_factory(data[s]) for s in self.slices]
|
| 306 |
+
|
| 307 |
+
def _reduce_data(self, batch: list[NestedTensors]) -> NestedTensors:
|
| 308 |
+
if len(batch) > 0 and is_list_of(batch, torch.Tensor, check="all"):
|
| 309 |
+
if len(batch) == 1:
|
| 310 |
+
# An optimization when `batch` contains only one tensor:
|
| 311 |
+
# - produce exactly same result as `torch.concat(batch)`
|
| 312 |
+
# - will achieve zero-copy if the tensor is contiguous
|
| 313 |
+
return batch[0].contiguous()
|
| 314 |
+
first_shape = batch[0].shape
|
| 315 |
+
if all(elem.shape[1:] == first_shape[1:] for elem in batch):
|
| 316 |
+
return torch.concat(batch)
|
| 317 |
+
|
| 318 |
+
return [e for elem in batch for e in elem]
|
| 319 |
+
|
| 320 |
+
|
| 321 |
+
@dataclass(frozen=True)
|
| 322 |
+
class MultiModalSharedField(BaseMultiModalField):
|
| 323 |
+
"""
|
| 324 |
+
See also:
|
| 325 |
+
:func:`MultiModalFieldConfig.shared`
|
| 326 |
+
"""
|
| 327 |
+
batch_size: int
|
| 328 |
+
|
| 329 |
+
def build_elems(
|
| 330 |
+
self,
|
| 331 |
+
modality: str,
|
| 332 |
+
key: str,
|
| 333 |
+
data: NestedTensors,
|
| 334 |
+
) -> Sequence[MultiModalFieldElem]:
|
| 335 |
+
field_factory = self._field_factory(modality=modality, key=key)
|
| 336 |
+
return [field_factory(data)] * self.batch_size
|
| 337 |
+
|
| 338 |
+
def _reduce_data(self, batch: list[NestedTensors]) -> NestedTensors:
|
| 339 |
+
return batch[0]
|
| 340 |
+
|
| 341 |
+
|
| 342 |
+
class MultiModalFieldConfig:
|
| 343 |
+
|
| 344 |
+
@staticmethod
|
| 345 |
+
def batched(modality: str):
|
| 346 |
+
"""
|
| 347 |
+
Defines a field where an element in the batch is obtained by
|
| 348 |
+
indexing into the first dimension of the underlying data.
|
| 349 |
+
|
| 350 |
+
Args:
|
| 351 |
+
modality: The modality of the multi-modal item that uses this
|
| 352 |
+
keyword argument.
|
| 353 |
+
|
| 354 |
+
Example:
|
| 355 |
+
|
| 356 |
+
.. code-block::
|
| 357 |
+
|
| 358 |
+
Input:
|
| 359 |
+
Data: [[AAAA]
|
| 360 |
+
[BBBB]
|
| 361 |
+
[CCCC]]
|
| 362 |
+
|
| 363 |
+
Output:
|
| 364 |
+
Element 1: [AAAA]
|
| 365 |
+
Element 2: [BBBB]
|
| 366 |
+
Element 3: [CCCC]
|
| 367 |
+
"""
|
| 368 |
+
return MultiModalFieldConfig(
|
| 369 |
+
field=MultiModalBatchedField(),
|
| 370 |
+
modality=modality,
|
| 371 |
+
)
|
| 372 |
+
|
| 373 |
+
@staticmethod
|
| 374 |
+
def flat(modality: str, slices: Sequence[slice]):
|
| 375 |
+
"""
|
| 376 |
+
Defines a field where an element in the batch is obtained by
|
| 377 |
+
slicing along the first dimension of the underlying data.
|
| 378 |
+
|
| 379 |
+
Args:
|
| 380 |
+
modality: The modality of the multi-modal item that uses this
|
| 381 |
+
keyword argument.
|
| 382 |
+
slices: For each multi-modal item, a slice that is used to extract
|
| 383 |
+
the data corresponding to it.
|
| 384 |
+
|
| 385 |
+
Example:
|
| 386 |
+
|
| 387 |
+
.. code-block::
|
| 388 |
+
|
| 389 |
+
Given:
|
| 390 |
+
slices: [slice(0, 3), slice(3, 7), slice(7, 9)]
|
| 391 |
+
|
| 392 |
+
Input:
|
| 393 |
+
Data: [AAABBBBCC]
|
| 394 |
+
|
| 395 |
+
Output:
|
| 396 |
+
Element 1: [AAA]
|
| 397 |
+
Element 2: [BBBB]
|
| 398 |
+
Element 3: [CC]
|
| 399 |
+
"""
|
| 400 |
+
return MultiModalFieldConfig(
|
| 401 |
+
field=MultiModalFlatField(slices=slices),
|
| 402 |
+
modality=modality,
|
| 403 |
+
)
|
| 404 |
+
|
| 405 |
+
@staticmethod
|
| 406 |
+
def flat_from_sizes(modality: str, size_per_item: torch.Tensor):
|
| 407 |
+
"""
|
| 408 |
+
Defines a field where an element in the batch is obtained by
|
| 409 |
+
slicing along the first dimension of the underlying data.
|
| 410 |
+
|
| 411 |
+
Args:
|
| 412 |
+
modality: The modality of the multi-modal item that uses this
|
| 413 |
+
keyword argument.
|
| 414 |
+
slices: For each multi-modal item, the size of the slice that
|
| 415 |
+
is used to extract the data corresponding to it.
|
| 416 |
+
|
| 417 |
+
Example:
|
| 418 |
+
|
| 419 |
+
.. code-block::
|
| 420 |
+
|
| 421 |
+
Given:
|
| 422 |
+
size_per_item: [3, 4, 2]
|
| 423 |
+
|
| 424 |
+
Input:
|
| 425 |
+
Data: [AAABBBBCC]
|
| 426 |
+
|
| 427 |
+
Output:
|
| 428 |
+
Element 1: [AAA]
|
| 429 |
+
Element 2: [BBBB]
|
| 430 |
+
Element 3: [CC]
|
| 431 |
+
|
| 432 |
+
See also:
|
| 433 |
+
:func:`MultiModalFieldConfig.flat`
|
| 434 |
+
"""
|
| 435 |
+
|
| 436 |
+
slice_idxs = [0, *accumulate(size_per_item)]
|
| 437 |
+
slices = [
|
| 438 |
+
slice(slice_idxs[i], slice_idxs[i + 1])
|
| 439 |
+
for i in range(len(size_per_item))
|
| 440 |
+
]
|
| 441 |
+
|
| 442 |
+
return MultiModalFieldConfig.flat(modality, slices)
|
| 443 |
+
|
| 444 |
+
@staticmethod
|
| 445 |
+
def shared(modality: str, batch_size: int):
|
| 446 |
+
"""
|
| 447 |
+
Defines a field where an element in the batch is obtained by
|
| 448 |
+
taking the entirety of the underlying data.
|
| 449 |
+
|
| 450 |
+
This means that the data is the same for each element in the batch.
|
| 451 |
+
|
| 452 |
+
Args:
|
| 453 |
+
modality: The modality of the multi-modal item that uses this
|
| 454 |
+
keyword argument.
|
| 455 |
+
batch_size: The number of multi-modal items which share this data.
|
| 456 |
+
|
| 457 |
+
Example:
|
| 458 |
+
|
| 459 |
+
.. code-block::
|
| 460 |
+
|
| 461 |
+
Given:
|
| 462 |
+
batch_size: 4
|
| 463 |
+
|
| 464 |
+
Input:
|
| 465 |
+
Data: [XYZ]
|
| 466 |
+
|
| 467 |
+
Output:
|
| 468 |
+
Element 1: [XYZ]
|
| 469 |
+
Element 2: [XYZ]
|
| 470 |
+
Element 3: [XYZ]
|
| 471 |
+
Element 4: [XYZ]
|
| 472 |
+
"""
|
| 473 |
+
return MultiModalFieldConfig(
|
| 474 |
+
field=MultiModalSharedField(batch_size),
|
| 475 |
+
modality=modality,
|
| 476 |
+
)
|
| 477 |
+
|
| 478 |
+
def __init__(self, field: BaseMultiModalField, modality: str) -> None:
|
| 479 |
+
super().__init__()
|
| 480 |
+
|
| 481 |
+
self.field = field
|
| 482 |
+
self.modality = modality
|
| 483 |
+
|
| 484 |
+
def build_elems(
|
| 485 |
+
self,
|
| 486 |
+
key: str,
|
| 487 |
+
batch: NestedTensors,
|
| 488 |
+
) -> Sequence[MultiModalFieldElem]:
|
| 489 |
+
return self.field.build_elems(self.modality, key, batch)
|
| 490 |
+
|
| 491 |
+
|
| 492 |
+
class MultiModalKwargsItem(UserDict[str, MultiModalFieldElem]):
|
| 493 |
+
"""
|
| 494 |
+
A collection of :class:`MultiModalFieldElem`
|
| 495 |
+
corresponding to a data item in :class:`MultiModalDataItems`.
|
| 496 |
+
"""
|
| 497 |
+
|
| 498 |
+
@staticmethod
|
| 499 |
+
def from_elems(elems: Sequence[MultiModalFieldElem]):
|
| 500 |
+
return MultiModalKwargsItem({elem.key: elem for elem in elems})
|
| 501 |
+
|
| 502 |
+
@property
|
| 503 |
+
def modality(self) -> str:
|
| 504 |
+
modalities = {elem.modality for elem in self.data.values()}
|
| 505 |
+
assert len(modalities) == 1, f"Found different modalities={modalities}"
|
| 506 |
+
return next(iter(modalities))
|
| 507 |
+
|
| 508 |
+
|
| 509 |
+
# NOTE: UserDict is for V0 compatibility.
|
| 510 |
+
# V1 should access individual items via `get_item`.
|
| 511 |
+
class MultiModalKwargs(UserDict[str, NestedTensors]):
|
| 512 |
+
"""
|
| 513 |
+
A dictionary that represents the keyword arguments to
|
| 514 |
+
:meth:`~torch.nn.Module.forward`.
|
| 515 |
+
|
| 516 |
+
The metadata :code:`items` enables us to obtain the keyword arguments
|
| 517 |
+
corresponding to each data item in :class:`MultiModalDataItems`, via
|
| 518 |
+
:meth:`get_item` and :meth:`get_items`.
|
| 519 |
+
"""
|
| 520 |
+
|
| 521 |
+
@staticmethod
|
| 522 |
+
def from_hf_inputs(
|
| 523 |
+
hf_inputs: BatchFeature,
|
| 524 |
+
config_by_key: Mapping[str, MultiModalFieldConfig],
|
| 525 |
+
):
|
| 526 |
+
# NOTE: This skips fields in `hf_inputs` that are not in `config_by_key`
|
| 527 |
+
# We assume that those fields are not used in vLLM
|
| 528 |
+
elems_by_key = dict[str, Sequence[MultiModalFieldElem]]()
|
| 529 |
+
keys_by_modality = defaultdict[str, set[str]](set)
|
| 530 |
+
for key, config in config_by_key.items():
|
| 531 |
+
batch = hf_inputs.get(key)
|
| 532 |
+
if batch is not None:
|
| 533 |
+
elems = config.build_elems(key, batch)
|
| 534 |
+
if len(elems) > 0:
|
| 535 |
+
elems_by_key[key] = elems
|
| 536 |
+
keys_by_modality[config.modality].add(key)
|
| 537 |
+
|
| 538 |
+
items = list[MultiModalKwargsItem]()
|
| 539 |
+
for modality, keys in keys_by_modality.items():
|
| 540 |
+
elems_in_modality = {k: elems_by_key[k] for k in keys}
|
| 541 |
+
batch_sizes = {k: len(v) for k, v in elems_in_modality.items()}
|
| 542 |
+
|
| 543 |
+
if len(set(batch_sizes.values())) > 1:
|
| 544 |
+
raise ValueError(
|
| 545 |
+
f"Cannot merge different batch sizes for {modality=}! "
|
| 546 |
+
f"Found: {batch_sizes=}")
|
| 547 |
+
|
| 548 |
+
batch_size = next(iter(batch_sizes.values()))
|
| 549 |
+
for item_idx in range(batch_size):
|
| 550 |
+
elems = [v[item_idx] for v in elems_in_modality.values()]
|
| 551 |
+
items.append(MultiModalKwargsItem.from_elems(elems))
|
| 552 |
+
|
| 553 |
+
return MultiModalKwargs.from_items(items)
|
| 554 |
+
|
| 555 |
+
@staticmethod
|
| 556 |
+
def from_items(items: Sequence[MultiModalKwargsItem]):
|
| 557 |
+
"""Construct a new :class:`MultiModalKwargs` from multiple items."""
|
| 558 |
+
elems_by_key = defaultdict[str, list[MultiModalFieldElem]](list)
|
| 559 |
+
for item in items:
|
| 560 |
+
for key, elem in item.items():
|
| 561 |
+
elems_by_key[key].append(elem)
|
| 562 |
+
|
| 563 |
+
data = {
|
| 564 |
+
key: elems[0].field.reduce_data(elems)
|
| 565 |
+
for key, elems in elems_by_key.items() if len(elems) > 0
|
| 566 |
+
}
|
| 567 |
+
|
| 568 |
+
return MultiModalKwargs(data, items=items)
|
| 569 |
+
|
| 570 |
+
def __init__(
|
| 571 |
+
self,
|
| 572 |
+
data: Mapping[str, NestedTensors],
|
| 573 |
+
*,
|
| 574 |
+
items: Optional[Sequence[MultiModalKwargsItem]] = None,
|
| 575 |
+
) -> None:
|
| 576 |
+
super().__init__(data)
|
| 577 |
+
|
| 578 |
+
items_by_modality = full_groupby(items or [], key=lambda x: x.modality)
|
| 579 |
+
self._items_by_modality = dict(items_by_modality)
|
| 580 |
+
|
| 581 |
+
@property
|
| 582 |
+
def modalities(self):
|
| 583 |
+
return self._items_by_modality.keys()
|
| 584 |
+
|
| 585 |
+
@staticmethod
|
| 586 |
+
def _try_stack(nested_tensors: NestedTensors) -> NestedTensors:
|
| 587 |
+
"""
|
| 588 |
+
Stack the inner dimensions that have the same shape in
|
| 589 |
+
a nested list of tensors.
|
| 590 |
+
|
| 591 |
+
Thus, a dimension represented by a list means that the inner
|
| 592 |
+
dimensions are different for each element along that dimension.
|
| 593 |
+
"""
|
| 594 |
+
if isinstance(nested_tensors, torch.Tensor):
|
| 595 |
+
return nested_tensors
|
| 596 |
+
|
| 597 |
+
# TODO: Remove these once all models have been migrated
|
| 598 |
+
if isinstance(nested_tensors, np.ndarray):
|
| 599 |
+
return torch.from_numpy(nested_tensors)
|
| 600 |
+
if isinstance(nested_tensors, (int, float)):
|
| 601 |
+
return torch.tensor(nested_tensors)
|
| 602 |
+
|
| 603 |
+
stacked = [MultiModalKwargs._try_stack(t) for t in nested_tensors]
|
| 604 |
+
if not is_list_of(stacked, torch.Tensor, check="all"):
|
| 605 |
+
# Only tensors (not lists) can be stacked.
|
| 606 |
+
return stacked
|
| 607 |
+
|
| 608 |
+
tensors_ = cast(list[torch.Tensor], stacked)
|
| 609 |
+
if len(tensors_) == 1:
|
| 610 |
+
# An optimization when `tensors_` contains only one tensor:
|
| 611 |
+
# - produce exactly same result as `torch.stack(tensors_)`
|
| 612 |
+
# - will achieve zero-copy if the tensor is contiguous
|
| 613 |
+
return tensors_[0].unsqueeze(0).contiguous()
|
| 614 |
+
|
| 615 |
+
if any(t.shape != tensors_[0].shape for t in tensors_):
|
| 616 |
+
# The tensors have incompatible shapes and can't be stacked.
|
| 617 |
+
return tensors_
|
| 618 |
+
|
| 619 |
+
return torch.stack(tensors_)
|
| 620 |
+
|
| 621 |
+
@staticmethod
|
| 622 |
+
def batch(inputs_list: list["MultiModalKwargs"]) -> BatchedTensorInputs:
|
| 623 |
+
"""
|
| 624 |
+
Batch multiple inputs together into a dictionary.
|
| 625 |
+
|
| 626 |
+
The resulting dictionary has the same keys as the inputs.
|
| 627 |
+
If the corresponding value from each input is a tensor and they all
|
| 628 |
+
share the same shape, the output value is a single batched tensor;
|
| 629 |
+
otherwise, the output value is a list containing the original value
|
| 630 |
+
from each input.
|
| 631 |
+
"""
|
| 632 |
+
if len(inputs_list) == 0:
|
| 633 |
+
return {}
|
| 634 |
+
|
| 635 |
+
# We need to consider the case where each item in the batch
|
| 636 |
+
# contains different modalities (i.e. different keys).
|
| 637 |
+
item_lists = defaultdict[str, list[NestedTensors]](list)
|
| 638 |
+
|
| 639 |
+
for inputs in inputs_list:
|
| 640 |
+
for k, v in inputs.items():
|
| 641 |
+
item_lists[k].append(v)
|
| 642 |
+
|
| 643 |
+
return {
|
| 644 |
+
k: MultiModalKwargs._try_stack(item_list)
|
| 645 |
+
for k, item_list in item_lists.items()
|
| 646 |
+
}
|
| 647 |
+
|
| 648 |
+
@staticmethod
|
| 649 |
+
def as_kwargs(
|
| 650 |
+
batched_inputs: BatchedTensorInputs,
|
| 651 |
+
*,
|
| 652 |
+
device: torch.types.Device,
|
| 653 |
+
) -> BatchedTensorInputs:
|
| 654 |
+
json_inputs = cast(JSONTree[torch.Tensor], batched_inputs)
|
| 655 |
+
|
| 656 |
+
json_mapped = json_map_leaves(
|
| 657 |
+
lambda x: x.to(device, non_blocking=True),
|
| 658 |
+
json_inputs,
|
| 659 |
+
)
|
| 660 |
+
|
| 661 |
+
return cast(BatchedTensorInputs, json_mapped)
|
| 662 |
+
|
| 663 |
+
def __eq__(self, other: object) -> bool:
|
| 664 |
+
if not isinstance(other, self.__class__):
|
| 665 |
+
return False
|
| 666 |
+
if self._items_by_modality != other._items_by_modality:
|
| 667 |
+
return False
|
| 668 |
+
|
| 669 |
+
ks = self.keys()
|
| 670 |
+
return (ks == other.keys()
|
| 671 |
+
and all(nested_tensors_equal(self[k], other[k]) for k in ks))
|
| 672 |
+
|
| 673 |
+
def _validate_modality(self, method_name: str, modality: str) -> None:
|
| 674 |
+
if not self._items_by_modality:
|
| 675 |
+
raise RuntimeError(
|
| 676 |
+
f"`{method_name}` is not supported when "
|
| 677 |
+
"MultiModalKwargs is not initialized with `items`")
|
| 678 |
+
|
| 679 |
+
if modality not in self._items_by_modality:
|
| 680 |
+
available_modalities = set(self._items_by_modality.keys())
|
| 681 |
+
raise KeyError(f"Modality {modality!r} not found. "
|
| 682 |
+
f"Available modalities: {available_modalities}")
|
| 683 |
+
|
| 684 |
+
def get_item_count(self, modality: str) -> int:
|
| 685 |
+
"""Get the number of items belonging to a modality."""
|
| 686 |
+
self._validate_modality("get_item_count", modality)
|
| 687 |
+
return len(self._items_by_modality[modality])
|
| 688 |
+
|
| 689 |
+
def get_item(self, modality: str, item_index: int) -> MultiModalKwargsItem:
|
| 690 |
+
"""
|
| 691 |
+
Get the keyword arguments corresponding to an item identified by
|
| 692 |
+
its modality and index.
|
| 693 |
+
"""
|
| 694 |
+
self._validate_modality("get_item", modality)
|
| 695 |
+
return self._items_by_modality[modality][item_index]
|
| 696 |
+
|
| 697 |
+
def get_items(self, modality: str) -> Sequence[MultiModalKwargsItem]:
|
| 698 |
+
"""
|
| 699 |
+
Get the keyword arguments corresponding to each item belonging to
|
| 700 |
+
a modality.
|
| 701 |
+
"""
|
| 702 |
+
self._validate_modality("get_items", modality)
|
| 703 |
+
return self._items_by_modality[modality]
|
| 704 |
+
|
| 705 |
+
|
| 706 |
+
MultiModalPlaceholderDict = Mapping[str, Sequence[PlaceholderRange]]
|
| 707 |
+
"""
|
| 708 |
+
A dictionary containing placeholder ranges for each modality.
|
| 709 |
+
"""
|
| 710 |
+
|
| 711 |
+
|
| 712 |
+
class MultiModalInputs(TypedDict):
|
| 713 |
+
"""
|
| 714 |
+
Represents the outputs of
|
| 715 |
+
:class:`vllm.multimodal.processing.BaseMultiModalProcessor`,
|
| 716 |
+
ready to be passed to vLLM internals.
|
| 717 |
+
"""
|
| 718 |
+
|
| 719 |
+
type: Literal["multimodal"]
|
| 720 |
+
"""The type of inputs."""
|
| 721 |
+
|
| 722 |
+
prompt: str
|
| 723 |
+
"""The processed prompt text."""
|
| 724 |
+
|
| 725 |
+
prompt_token_ids: list[int]
|
| 726 |
+
"""The processed token IDs which includes placeholder tokens."""
|
| 727 |
+
|
| 728 |
+
token_type_ids: NotRequired[list[int]]
|
| 729 |
+
"""The token type IDs of the prompt."""
|
| 730 |
+
|
| 731 |
+
mm_kwargs: MultiModalKwargs
|
| 732 |
+
"""Keyword arguments to be directly passed to the model after batching."""
|
| 733 |
+
|
| 734 |
+
mm_hashes: NotRequired[Optional["MultiModalHashDict"]]
|
| 735 |
+
"""The hashes of the multi-modal data."""
|
| 736 |
+
|
| 737 |
+
mm_placeholders: MultiModalPlaceholderDict
|
| 738 |
+
"""
|
| 739 |
+
For each modality, information about the placeholder tokens in
|
| 740 |
+
:code:`prompt_token_ids`.
|
| 741 |
+
"""
|
.venv/lib/python3.11/site-packages/vllm/multimodal/parse.py
ADDED
|
@@ -0,0 +1,368 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 2 |
+
|
| 3 |
+
from abc import ABC, abstractmethod
|
| 4 |
+
from collections import UserDict
|
| 5 |
+
from collections.abc import Callable, Iterator, Mapping, Sequence
|
| 6 |
+
from typing import (TYPE_CHECKING, Any, Generic, NamedTuple, Optional, TypeVar,
|
| 7 |
+
Union)
|
| 8 |
+
|
| 9 |
+
import numpy as np
|
| 10 |
+
import torch
|
| 11 |
+
from PIL.Image import Image
|
| 12 |
+
from typing_extensions import TypeAlias, TypeGuard, assert_never
|
| 13 |
+
|
| 14 |
+
from vllm.utils import is_list_of
|
| 15 |
+
|
| 16 |
+
from .audio import resample_audio
|
| 17 |
+
from .inputs import (AudioItem, HfAudioItem, HfImageItem, HfVideoItem,
|
| 18 |
+
ImageItem, ModalityData, MultiModalDataDict, VideoItem)
|
| 19 |
+
|
| 20 |
+
_T = TypeVar("_T")
|
| 21 |
+
_I = TypeVar("_I")
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
class ModalityDataItems(ABC, Generic[_T, _I]):
|
| 25 |
+
"""
|
| 26 |
+
Represents data items for a modality in :class:`MultiModalDataItems`.
|
| 27 |
+
"""
|
| 28 |
+
|
| 29 |
+
def __init__(self, data: _T, modality: str) -> None:
|
| 30 |
+
super().__init__()
|
| 31 |
+
|
| 32 |
+
self.data = data
|
| 33 |
+
self.modality = modality
|
| 34 |
+
|
| 35 |
+
def __repr__(self) -> str:
|
| 36 |
+
return (f"{type(self).__name__}(modality={self.modality!r}, "
|
| 37 |
+
f"len={len(self)})")
|
| 38 |
+
|
| 39 |
+
def __len__(self) -> int:
|
| 40 |
+
return self.get_count()
|
| 41 |
+
|
| 42 |
+
def __getitem__(self, index: int) -> _I:
|
| 43 |
+
return self.get(index)
|
| 44 |
+
|
| 45 |
+
if TYPE_CHECKING:
|
| 46 |
+
# Auto-generated
|
| 47 |
+
def __iter__(self) -> Iterator[_I]:
|
| 48 |
+
...
|
| 49 |
+
|
| 50 |
+
@abstractmethod
|
| 51 |
+
def get_count(self) -> int:
|
| 52 |
+
"""Get the number of data items."""
|
| 53 |
+
raise NotImplementedError
|
| 54 |
+
|
| 55 |
+
@abstractmethod
|
| 56 |
+
def get(self, index: int) -> _I:
|
| 57 |
+
"""Get a data item by its index."""
|
| 58 |
+
raise NotImplementedError
|
| 59 |
+
|
| 60 |
+
def get_all(self) -> list[_I]:
|
| 61 |
+
"""Get all data items."""
|
| 62 |
+
return [self.get(idx) for idx in range(self.get_count())]
|
| 63 |
+
|
| 64 |
+
@abstractmethod
|
| 65 |
+
def get_processor_data(self) -> Mapping[str, object]:
|
| 66 |
+
"""Get the data to pass to the HF processor."""
|
| 67 |
+
raise NotImplementedError
|
| 68 |
+
|
| 69 |
+
@abstractmethod
|
| 70 |
+
def get_passthrough_data(self) -> Mapping[str, object]:
|
| 71 |
+
"""Get the data to pass directly to the model."""
|
| 72 |
+
raise NotImplementedError
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
class ProcessorBatchItems(ModalityDataItems[Sequence[_T], _T]):
|
| 76 |
+
"""Base class for data items that are arranged in a list."""
|
| 77 |
+
|
| 78 |
+
def get_count(self) -> int:
|
| 79 |
+
return len(self.data)
|
| 80 |
+
|
| 81 |
+
def get(self, index: int) -> _T:
|
| 82 |
+
return self.data[index]
|
| 83 |
+
|
| 84 |
+
def get_processor_data(self) -> Mapping[str, object]:
|
| 85 |
+
return {f"{self.modality}s": self.data}
|
| 86 |
+
|
| 87 |
+
def get_passthrough_data(self) -> Mapping[str, object]:
|
| 88 |
+
return {}
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
class EmbeddingItems(ModalityDataItems[Union[torch.Tensor, list[torch.Tensor]],
|
| 92 |
+
torch.Tensor]):
|
| 93 |
+
"""
|
| 94 |
+
Base class for data items that are expressed as a batched embedding tensor,
|
| 95 |
+
or a list of embedding tensors (one per item).
|
| 96 |
+
"""
|
| 97 |
+
|
| 98 |
+
def get_count(self) -> int:
|
| 99 |
+
return len(self.data)
|
| 100 |
+
|
| 101 |
+
def get(self, index: int) -> torch.Tensor:
|
| 102 |
+
return self.data[index]
|
| 103 |
+
|
| 104 |
+
def get_processor_data(self) -> Mapping[str, object]:
|
| 105 |
+
return {}
|
| 106 |
+
|
| 107 |
+
def get_passthrough_data(self) -> Mapping[str, object]:
|
| 108 |
+
return {f"{self.modality}_embeds": self.data}
|
| 109 |
+
|
| 110 |
+
def get_feature_size(self, item_idx: int) -> int:
|
| 111 |
+
return len(self.get(item_idx))
|
| 112 |
+
|
| 113 |
+
|
| 114 |
+
class AudioProcessorItems(ProcessorBatchItems[HfAudioItem]):
|
| 115 |
+
|
| 116 |
+
def __init__(self, data: Sequence[HfAudioItem]) -> None:
|
| 117 |
+
super().__init__(data, "audio")
|
| 118 |
+
|
| 119 |
+
|
| 120 |
+
class AudioEmbeddingItems(EmbeddingItems):
|
| 121 |
+
|
| 122 |
+
def __init__(self, data: Union[torch.Tensor, list[torch.Tensor]]) -> None:
|
| 123 |
+
super().__init__(data, "audio")
|
| 124 |
+
|
| 125 |
+
|
| 126 |
+
class ImageSize(NamedTuple):
|
| 127 |
+
width: int
|
| 128 |
+
height: int
|
| 129 |
+
|
| 130 |
+
|
| 131 |
+
class ImageProcessorItems(ProcessorBatchItems[HfImageItem]):
|
| 132 |
+
|
| 133 |
+
def __init__(self, data: Sequence[HfImageItem]) -> None:
|
| 134 |
+
super().__init__(data, "image")
|
| 135 |
+
|
| 136 |
+
def get_image_size(self, item_idx: int) -> ImageSize:
|
| 137 |
+
image = self.get(item_idx)
|
| 138 |
+
|
| 139 |
+
if isinstance(image, Image):
|
| 140 |
+
return ImageSize(*image.size)
|
| 141 |
+
if isinstance(image, (np.ndarray, torch.Tensor)):
|
| 142 |
+
_, h, w = image.shape
|
| 143 |
+
return ImageSize(w, h)
|
| 144 |
+
|
| 145 |
+
assert_never(image)
|
| 146 |
+
|
| 147 |
+
|
| 148 |
+
class ImageEmbeddingItems(EmbeddingItems):
|
| 149 |
+
|
| 150 |
+
def __init__(self, data: Union[torch.Tensor, list[torch.Tensor]]) -> None:
|
| 151 |
+
super().__init__(data, "image")
|
| 152 |
+
|
| 153 |
+
|
| 154 |
+
class VideoProcessorItems(ProcessorBatchItems[HfVideoItem]):
|
| 155 |
+
|
| 156 |
+
def __init__(self, data: Sequence[HfVideoItem]) -> None:
|
| 157 |
+
super().__init__(data, "video")
|
| 158 |
+
|
| 159 |
+
def get_num_frames(self, item_idx: int) -> int:
|
| 160 |
+
return len(self.get(item_idx))
|
| 161 |
+
|
| 162 |
+
def get_frame_size(self, item_idx: int) -> ImageSize:
|
| 163 |
+
image = self.get(item_idx)[0] # Assume that the video isn't empty
|
| 164 |
+
|
| 165 |
+
if isinstance(image, Image):
|
| 166 |
+
return ImageSize(*image.size)
|
| 167 |
+
if isinstance(image, (np.ndarray, torch.Tensor)):
|
| 168 |
+
_, h, w = image.shape
|
| 169 |
+
return ImageSize(w, h)
|
| 170 |
+
|
| 171 |
+
assert_never(image)
|
| 172 |
+
|
| 173 |
+
|
| 174 |
+
class VideoEmbeddingItems(EmbeddingItems):
|
| 175 |
+
|
| 176 |
+
def __init__(self, data: Union[torch.Tensor, list[torch.Tensor]]) -> None:
|
| 177 |
+
super().__init__(data, "video")
|
| 178 |
+
|
| 179 |
+
|
| 180 |
+
_D = TypeVar("_D", bound=ModalityDataItems[Any, Any])
|
| 181 |
+
|
| 182 |
+
|
| 183 |
+
class MultiModalDataItems(UserDict[str, ModalityDataItems[Any, Any]]):
|
| 184 |
+
"""
|
| 185 |
+
As :data:`~vllm.multimodal.inputs.MultiModalDataDict`, but normalized
|
| 186 |
+
such that each entry corresponds to a list.
|
| 187 |
+
"""
|
| 188 |
+
|
| 189 |
+
def get_count(self, modality: str, *, strict: bool = True) -> int:
|
| 190 |
+
"""
|
| 191 |
+
Get the number of data items belonging to a modality.
|
| 192 |
+
|
| 193 |
+
If `strict=False`, return `0` instead of raising :exc:`KeyError`
|
| 194 |
+
even if the modality is not found.
|
| 195 |
+
"""
|
| 196 |
+
if modality not in self:
|
| 197 |
+
if strict:
|
| 198 |
+
available_modalities = set(self.keys())
|
| 199 |
+
raise KeyError(f"Modality {modality!r} not found. "
|
| 200 |
+
f"Available modalities: {available_modalities}")
|
| 201 |
+
|
| 202 |
+
return 0
|
| 203 |
+
|
| 204 |
+
return self[modality].get_count()
|
| 205 |
+
|
| 206 |
+
def get_all_counts(self) -> Mapping[str, int]:
|
| 207 |
+
"""Get the number of items belonging to each modality."""
|
| 208 |
+
return {m: items.get_count() for m, items in self.items()}
|
| 209 |
+
|
| 210 |
+
def get_items(
|
| 211 |
+
self,
|
| 212 |
+
modality: str,
|
| 213 |
+
typ: Union[type[_D], tuple[type[_D], ...]],
|
| 214 |
+
) -> _D:
|
| 215 |
+
"""
|
| 216 |
+
Get the data items belonging to a modality,
|
| 217 |
+
requiring that they belong to a certain type.
|
| 218 |
+
"""
|
| 219 |
+
if modality not in self:
|
| 220 |
+
available_modalities = set(self.keys())
|
| 221 |
+
raise KeyError(f"Modality {modality!r} not found. "
|
| 222 |
+
f"Available modalities: {available_modalities}")
|
| 223 |
+
|
| 224 |
+
items = self[modality]
|
| 225 |
+
if not isinstance(items, typ):
|
| 226 |
+
raise TypeError(f"Invalid type of data items for {modality=}. "
|
| 227 |
+
f"Expected type: {typ}, but "
|
| 228 |
+
f"found type: {type(items)}")
|
| 229 |
+
|
| 230 |
+
return items # type: ignore[return-value]
|
| 231 |
+
|
| 232 |
+
|
| 233 |
+
ModalityDataParser: TypeAlias = Callable[[ModalityData[Any]],
|
| 234 |
+
ModalityDataItems[Any, Any]]
|
| 235 |
+
|
| 236 |
+
|
| 237 |
+
class MultiModalDataParser:
|
| 238 |
+
"""
|
| 239 |
+
Parses :data:`~vllm.multimodal.inputs.MultiModalDataDict` into
|
| 240 |
+
:class:`MultiModalDataItems`.
|
| 241 |
+
|
| 242 |
+
Args:
|
| 243 |
+
target_sr (float, optional): Enables automatic resampling of audio
|
| 244 |
+
items to the model's expected sampling rate.
|
| 245 |
+
"""
|
| 246 |
+
|
| 247 |
+
def __init__(self, *, target_sr: Optional[float] = None) -> None:
|
| 248 |
+
super().__init__()
|
| 249 |
+
|
| 250 |
+
self.target_sr = target_sr
|
| 251 |
+
|
| 252 |
+
def _is_embeddings(
|
| 253 |
+
self, data: object
|
| 254 |
+
) -> TypeGuard[Union[torch.Tensor, list[torch.Tensor]]]:
|
| 255 |
+
if isinstance(data, torch.Tensor):
|
| 256 |
+
return data.ndim == 3
|
| 257 |
+
if is_list_of(data, torch.Tensor):
|
| 258 |
+
return len(data) == 0 or data[0].ndim == 2
|
| 259 |
+
|
| 260 |
+
return False
|
| 261 |
+
|
| 262 |
+
def _get_audio_with_sr(
|
| 263 |
+
self,
|
| 264 |
+
audio: AudioItem,
|
| 265 |
+
) -> tuple[np.ndarray, Optional[float]]:
|
| 266 |
+
if isinstance(audio, tuple):
|
| 267 |
+
return audio
|
| 268 |
+
if isinstance(audio, list):
|
| 269 |
+
return np.array(audio), None
|
| 270 |
+
if isinstance(audio, np.ndarray):
|
| 271 |
+
return audio, None
|
| 272 |
+
if isinstance(audio, torch.Tensor):
|
| 273 |
+
return audio.numpy(), None
|
| 274 |
+
|
| 275 |
+
assert_never(audio)
|
| 276 |
+
|
| 277 |
+
def _parse_audio_data(
|
| 278 |
+
self,
|
| 279 |
+
data: ModalityData[AudioItem],
|
| 280 |
+
) -> ModalityDataItems[Any, Any]:
|
| 281 |
+
if self._is_embeddings(data):
|
| 282 |
+
return AudioEmbeddingItems(data)
|
| 283 |
+
|
| 284 |
+
if (is_list_of(data, float)
|
| 285 |
+
or isinstance(data,
|
| 286 |
+
(np.ndarray, torch.Tensor)) and data.ndim == 1
|
| 287 |
+
or isinstance(data, tuple)):
|
| 288 |
+
data_items = [data]
|
| 289 |
+
elif isinstance(data, (np.ndarray, torch.Tensor)):
|
| 290 |
+
data_items = [elem for elem in data]
|
| 291 |
+
else:
|
| 292 |
+
data_items = data
|
| 293 |
+
|
| 294 |
+
new_audios = list[np.ndarray]()
|
| 295 |
+
for data_item in data_items:
|
| 296 |
+
audio, orig_sr = self._get_audio_with_sr(data_item)
|
| 297 |
+
if orig_sr is None:
|
| 298 |
+
new_audio = audio
|
| 299 |
+
else:
|
| 300 |
+
target_sr = self.target_sr
|
| 301 |
+
if target_sr is None:
|
| 302 |
+
raise RuntimeError(
|
| 303 |
+
"Audio resampling is not supported when "
|
| 304 |
+
"`target_sr` is not provided")
|
| 305 |
+
|
| 306 |
+
new_audio = resample_audio(audio,
|
| 307 |
+
orig_sr=orig_sr,
|
| 308 |
+
target_sr=target_sr)
|
| 309 |
+
|
| 310 |
+
new_audios.append(new_audio)
|
| 311 |
+
|
| 312 |
+
return AudioProcessorItems(new_audios)
|
| 313 |
+
|
| 314 |
+
def _parse_image_data(
|
| 315 |
+
self,
|
| 316 |
+
data: ModalityData[ImageItem],
|
| 317 |
+
) -> ModalityDataItems[Any, Any]:
|
| 318 |
+
if self._is_embeddings(data):
|
| 319 |
+
return ImageEmbeddingItems(data)
|
| 320 |
+
|
| 321 |
+
if (isinstance(data, Image)
|
| 322 |
+
or isinstance(data,
|
| 323 |
+
(np.ndarray, torch.Tensor)) and data.ndim == 3):
|
| 324 |
+
data_items = [data]
|
| 325 |
+
elif isinstance(data, (np.ndarray, torch.Tensor)):
|
| 326 |
+
data_items = [elem for elem in data]
|
| 327 |
+
else:
|
| 328 |
+
data_items = data
|
| 329 |
+
|
| 330 |
+
return ImageProcessorItems(data_items)
|
| 331 |
+
|
| 332 |
+
def _parse_video_data(
|
| 333 |
+
self,
|
| 334 |
+
data: ModalityData[VideoItem],
|
| 335 |
+
) -> ModalityDataItems[Any, Any]:
|
| 336 |
+
if self._is_embeddings(data):
|
| 337 |
+
return VideoEmbeddingItems(data)
|
| 338 |
+
|
| 339 |
+
if (is_list_of(data, Image)
|
| 340 |
+
or isinstance(data,
|
| 341 |
+
(np.ndarray, torch.Tensor)) and data.ndim == 4):
|
| 342 |
+
data_items = [data]
|
| 343 |
+
elif isinstance(data, (np.ndarray, torch.Tensor)):
|
| 344 |
+
data_items = [elem for elem in data]
|
| 345 |
+
else:
|
| 346 |
+
data_items = data
|
| 347 |
+
|
| 348 |
+
return VideoProcessorItems(data_items)
|
| 349 |
+
|
| 350 |
+
def _get_subparsers(self) -> Mapping[str, ModalityDataParser]:
|
| 351 |
+
return {
|
| 352 |
+
"audio": self._parse_audio_data,
|
| 353 |
+
"image": self._parse_image_data,
|
| 354 |
+
"video": self._parse_video_data,
|
| 355 |
+
}
|
| 356 |
+
|
| 357 |
+
def parse_mm_data(self,
|
| 358 |
+
mm_data: MultiModalDataDict) -> MultiModalDataItems:
|
| 359 |
+
subparsers = self._get_subparsers()
|
| 360 |
+
|
| 361 |
+
mm_items = MultiModalDataItems()
|
| 362 |
+
for k, v in mm_data.items():
|
| 363 |
+
if k not in subparsers:
|
| 364 |
+
raise ValueError(f"Unsupported modality: {k}")
|
| 365 |
+
|
| 366 |
+
mm_items[k] = subparsers[k](v)
|
| 367 |
+
|
| 368 |
+
return mm_items
|
.venv/lib/python3.11/site-packages/vllm/multimodal/processing.py
ADDED
|
@@ -0,0 +1,1295 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 2 |
+
|
| 3 |
+
import re
|
| 4 |
+
from abc import ABC, abstractmethod
|
| 5 |
+
from collections import defaultdict
|
| 6 |
+
from collections.abc import (Callable, Generator, ItemsView, Iterable, Mapping,
|
| 7 |
+
Sequence)
|
| 8 |
+
from dataclasses import dataclass, field
|
| 9 |
+
from functools import lru_cache
|
| 10 |
+
from typing import (TYPE_CHECKING, Generic, NamedTuple, Optional, Protocol,
|
| 11 |
+
TypeVar, Union)
|
| 12 |
+
|
| 13 |
+
from transformers import BatchFeature, PretrainedConfig, ProcessorMixin
|
| 14 |
+
|
| 15 |
+
import vllm.envs as envs
|
| 16 |
+
from vllm.inputs import InputProcessingContext
|
| 17 |
+
from vllm.logger import init_logger
|
| 18 |
+
from vllm.transformers_utils.tokenizer import (AnyTokenizer, decode_tokens,
|
| 19 |
+
encode_tokens)
|
| 20 |
+
from vllm.utils import LRUCache, flatten_2d_lists, full_groupby
|
| 21 |
+
|
| 22 |
+
from .hasher import MultiModalHasher
|
| 23 |
+
from .inputs import (MultiModalDataDict, MultiModalFieldConfig,
|
| 24 |
+
MultiModalInputs, MultiModalKwargs, MultiModalKwargsItem,
|
| 25 |
+
PlaceholderRange)
|
| 26 |
+
from .parse import MultiModalDataItems, MultiModalDataParser
|
| 27 |
+
|
| 28 |
+
if TYPE_CHECKING:
|
| 29 |
+
from .profiling import BaseDummyInputsBuilder
|
| 30 |
+
|
| 31 |
+
logger = init_logger(__name__)
|
| 32 |
+
|
| 33 |
+
_S = TypeVar("_S", str, list[int])
|
| 34 |
+
|
| 35 |
+
PromptSeq = Union[str, list[int]]
|
| 36 |
+
"""A token sequence (list of token IDs) or text."""
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
@dataclass
|
| 40 |
+
class PromptReplacementDetails:
|
| 41 |
+
"""Details about the replacement token sequence or text."""
|
| 42 |
+
|
| 43 |
+
full: PromptSeq
|
| 44 |
+
"""The full replacement."""
|
| 45 |
+
|
| 46 |
+
features: PromptSeq
|
| 47 |
+
"""
|
| 48 |
+
The part of the replacement that corresponds to feature placeholders;
|
| 49 |
+
this will be replaced by the output of the vision encoder during model
|
| 50 |
+
inference.
|
| 51 |
+
"""
|
| 52 |
+
|
| 53 |
+
@staticmethod
|
| 54 |
+
def from_seq(seq: PromptSeq) -> "PromptReplacementDetails":
|
| 55 |
+
return PromptReplacementDetails(full=seq, features=seq)
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
PromptRepl = Union[PromptSeq, PromptReplacementDetails]
|
| 59 |
+
"""
|
| 60 |
+
The replacement token sequence or text.
|
| 61 |
+
|
| 62 |
+
If only part of the replacement corresponds to feature placeholders, you can
|
| 63 |
+
use :class:`PromptReplacementDetails` to specify which part.
|
| 64 |
+
"""
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
@dataclass
|
| 68 |
+
class PromptReplacement:
|
| 69 |
+
"""
|
| 70 |
+
Defines how to replace portions of an input prompt with placeholder tokens.
|
| 71 |
+
|
| 72 |
+
Example:
|
| 73 |
+
|
| 74 |
+
For each image, replace one ``<image>`` input placeholder in the prompt
|
| 75 |
+
with a number of ``<image>`` feature placeholders
|
| 76 |
+
equal to the feature size of the vision encoder:
|
| 77 |
+
|
| 78 |
+
.. code-block:: python
|
| 79 |
+
|
| 80 |
+
PromptReplacement(
|
| 81 |
+
modality="image",
|
| 82 |
+
target="<image>",
|
| 83 |
+
replacement="<image>" * image_feature_size,
|
| 84 |
+
)
|
| 85 |
+
|
| 86 |
+
As above, but further pad the feature placeholders with ``<image_bos>``
|
| 87 |
+
and `<image_eos>``, which are not supposed to be passed to the vision
|
| 88 |
+
encoder:
|
| 89 |
+
|
| 90 |
+
.. code-block:: python
|
| 91 |
+
|
| 92 |
+
PromptReplacement(
|
| 93 |
+
modality="image",
|
| 94 |
+
target="<image>",
|
| 95 |
+
replacement=PromptReplacementDetails(
|
| 96 |
+
full="".join([
|
| 97 |
+
"<image_bos>",
|
| 98 |
+
"<image>" * image_feature_size,
|
| 99 |
+
"<image_eos>",
|
| 100 |
+
]),
|
| 101 |
+
features="<image>" * image_feature_size,
|
| 102 |
+
),
|
| 103 |
+
)
|
| 104 |
+
|
| 105 |
+
To avoid unnecessary tokenization during prompt replacement,
|
| 106 |
+
we recommended passing token sequences instead of text:
|
| 107 |
+
|
| 108 |
+
.. code-block:: python
|
| 109 |
+
|
| 110 |
+
PromptReplacement(
|
| 111 |
+
modality="image",
|
| 112 |
+
target=[image_token_id],
|
| 113 |
+
replacement=PromptReplacementDetails(
|
| 114 |
+
full=([image_bos_id] + [image_token_id] * image_feature_size
|
| 115 |
+
+ [image_eos_id]),
|
| 116 |
+
features=[image_token_id] * image_feature_size,
|
| 117 |
+
),
|
| 118 |
+
)
|
| 119 |
+
"""
|
| 120 |
+
|
| 121 |
+
modality: str
|
| 122 |
+
"""The modality for which the replacement is made."""
|
| 123 |
+
|
| 124 |
+
target: PromptSeq
|
| 125 |
+
"""The token sequence (or text) to find and replace."""
|
| 126 |
+
|
| 127 |
+
replacement: Union[Callable[[int], PromptRepl],
|
| 128 |
+
PromptRepl] = field(repr=False)
|
| 129 |
+
"""
|
| 130 |
+
Given the index of the processed item within :attr:`modality`,
|
| 131 |
+
output the replacement token sequence (or text).
|
| 132 |
+
|
| 133 |
+
For convenience, you can directly pass in the replacement token sequence
|
| 134 |
+
(or text) instead of a function if it does not depend on the input.
|
| 135 |
+
"""
|
| 136 |
+
|
| 137 |
+
def bind(self, tokenizer: AnyTokenizer) -> "BoundPromptReplacement":
|
| 138 |
+
return BoundPromptReplacement(
|
| 139 |
+
tokenizer=tokenizer,
|
| 140 |
+
modality=self.modality,
|
| 141 |
+
_target=self.target,
|
| 142 |
+
_replacement=self.replacement,
|
| 143 |
+
)
|
| 144 |
+
|
| 145 |
+
|
| 146 |
+
@lru_cache(maxsize=2048)
|
| 147 |
+
def _cached_encode(
|
| 148 |
+
tokenizer: AnyTokenizer,
|
| 149 |
+
text: str,
|
| 150 |
+
*,
|
| 151 |
+
add_special_tokens: bool = False,
|
| 152 |
+
) -> list[int]:
|
| 153 |
+
return encode_tokens(tokenizer,
|
| 154 |
+
text,
|
| 155 |
+
add_special_tokens=add_special_tokens)
|
| 156 |
+
|
| 157 |
+
|
| 158 |
+
@lru_cache(maxsize=2048)
|
| 159 |
+
def _cached_decode(
|
| 160 |
+
tokenizer: AnyTokenizer,
|
| 161 |
+
token_ids: tuple[int, ...],
|
| 162 |
+
*,
|
| 163 |
+
skip_special_tokens: bool = False,
|
| 164 |
+
) -> str:
|
| 165 |
+
return decode_tokens(tokenizer,
|
| 166 |
+
list(token_ids),
|
| 167 |
+
skip_special_tokens=skip_special_tokens)
|
| 168 |
+
|
| 169 |
+
|
| 170 |
+
class _HasModalityAttr(Protocol):
|
| 171 |
+
modality: str
|
| 172 |
+
|
| 173 |
+
|
| 174 |
+
class _HasModalityProp(Protocol):
|
| 175 |
+
|
| 176 |
+
@property
|
| 177 |
+
def modality(self) -> str:
|
| 178 |
+
...
|
| 179 |
+
|
| 180 |
+
|
| 181 |
+
_M = TypeVar("_M", bound=Union[_HasModalityAttr, _HasModalityProp])
|
| 182 |
+
|
| 183 |
+
|
| 184 |
+
def full_groupby_modality(values: Iterable[_M]) -> ItemsView[str, list[_M]]:
|
| 185 |
+
"""Convenience function to apply :func:`full_groupby` based on modality."""
|
| 186 |
+
return full_groupby(values, key=lambda x: x.modality)
|
| 187 |
+
|
| 188 |
+
|
| 189 |
+
@dataclass
|
| 190 |
+
class _BoundPromptSequence:
|
| 191 |
+
"""
|
| 192 |
+
A :data:`_PromptSeq` bound to a tokenizer to automatically
|
| 193 |
+
convert between token sequence and text representations.
|
| 194 |
+
"""
|
| 195 |
+
tokenizer: AnyTokenizer = field(repr=False)
|
| 196 |
+
|
| 197 |
+
_text: Optional[str]
|
| 198 |
+
_token_ids: Optional[list[int]]
|
| 199 |
+
|
| 200 |
+
@staticmethod
|
| 201 |
+
def from_seq(
|
| 202 |
+
tokenizer: AnyTokenizer,
|
| 203 |
+
seq: PromptSeq,
|
| 204 |
+
) -> "_BoundPromptSequence":
|
| 205 |
+
return _BoundPromptSequence(
|
| 206 |
+
tokenizer=tokenizer,
|
| 207 |
+
_text=seq if isinstance(seq, str) else None,
|
| 208 |
+
_token_ids=seq if isinstance(seq, list) else None,
|
| 209 |
+
)
|
| 210 |
+
|
| 211 |
+
def __post_init__(self) -> None:
|
| 212 |
+
if self._text is None and self._token_ids is None:
|
| 213 |
+
raise ValueError("At least one of 'text' and 'token_ids' must be "
|
| 214 |
+
"specified")
|
| 215 |
+
|
| 216 |
+
@property
|
| 217 |
+
def text(self) -> str:
|
| 218 |
+
if self._text is None:
|
| 219 |
+
assert self._token_ids is not None
|
| 220 |
+
self._text = _cached_decode(self.tokenizer, tuple(self._token_ids))
|
| 221 |
+
|
| 222 |
+
return self._text
|
| 223 |
+
|
| 224 |
+
@property
|
| 225 |
+
def token_ids(self) -> list[int]:
|
| 226 |
+
if self._token_ids is None:
|
| 227 |
+
assert self._text is not None
|
| 228 |
+
self._token_ids = _cached_encode(self.tokenizer, self._text)
|
| 229 |
+
|
| 230 |
+
return self._token_ids
|
| 231 |
+
|
| 232 |
+
|
| 233 |
+
@dataclass
|
| 234 |
+
class _BoundPromptReplacementGroup:
|
| 235 |
+
full: _BoundPromptSequence
|
| 236 |
+
features: _BoundPromptSequence
|
| 237 |
+
|
| 238 |
+
|
| 239 |
+
@dataclass
|
| 240 |
+
class BoundPromptReplacement:
|
| 241 |
+
"""
|
| 242 |
+
A :class:`PromptReplacement` bound to a tokenizer to automatically
|
| 243 |
+
convert :attr:`target` and the result of :meth:`get_replacement` between
|
| 244 |
+
token sequence and text representations.
|
| 245 |
+
"""
|
| 246 |
+
tokenizer: AnyTokenizer = field(repr=False)
|
| 247 |
+
modality: str
|
| 248 |
+
|
| 249 |
+
_target: PromptSeq
|
| 250 |
+
_replacement: Union[Callable[[int], PromptRepl],
|
| 251 |
+
PromptRepl] = field(repr=False)
|
| 252 |
+
|
| 253 |
+
def __post_init__(self) -> None:
|
| 254 |
+
self._replacement_cache = dict[int, _BoundPromptReplacementGroup]()
|
| 255 |
+
|
| 256 |
+
@property
|
| 257 |
+
def target(self) -> _BoundPromptSequence:
|
| 258 |
+
"""The token sequence (or text) to find and replace."""
|
| 259 |
+
return _BoundPromptSequence.from_seq(self.tokenizer, self._target)
|
| 260 |
+
|
| 261 |
+
def get_replacement(self, item_idx: int) -> _BoundPromptReplacementGroup:
|
| 262 |
+
"""
|
| 263 |
+
Given the index of the processed item within :attr:`modality`,
|
| 264 |
+
output the replacement token sequence (or text).
|
| 265 |
+
"""
|
| 266 |
+
replacement = self._replacement
|
| 267 |
+
if callable(replacement):
|
| 268 |
+
cache_key = item_idx
|
| 269 |
+
if cache_key in self._replacement_cache:
|
| 270 |
+
return self._replacement_cache[cache_key]
|
| 271 |
+
|
| 272 |
+
replacement = replacement(item_idx)
|
| 273 |
+
else:
|
| 274 |
+
cache_key = None
|
| 275 |
+
|
| 276 |
+
if not isinstance(replacement, PromptReplacementDetails):
|
| 277 |
+
replacement = PromptReplacementDetails.from_seq(replacement)
|
| 278 |
+
|
| 279 |
+
bound_full = _BoundPromptSequence.from_seq(self.tokenizer,
|
| 280 |
+
replacement.full)
|
| 281 |
+
bound_features = _BoundPromptSequence.from_seq(self.tokenizer,
|
| 282 |
+
replacement.features)
|
| 283 |
+
bound_replacement = _BoundPromptReplacementGroup(
|
| 284 |
+
full=bound_full,
|
| 285 |
+
features=bound_features,
|
| 286 |
+
)
|
| 287 |
+
|
| 288 |
+
if cache_key is not None:
|
| 289 |
+
self._replacement_cache[cache_key] = bound_replacement
|
| 290 |
+
|
| 291 |
+
return bound_replacement
|
| 292 |
+
|
| 293 |
+
|
| 294 |
+
class _TokenMatch(NamedTuple):
|
| 295 |
+
start_idx: int
|
| 296 |
+
end_idx: int
|
| 297 |
+
|
| 298 |
+
|
| 299 |
+
def iter_token_matches(
|
| 300 |
+
token_ids: list[int],
|
| 301 |
+
match_ids: list[int],
|
| 302 |
+
) -> Generator[_TokenMatch]:
|
| 303 |
+
"""
|
| 304 |
+
Yield each occurrence of :code:`match_ids` in :code:`token_ids`.
|
| 305 |
+
|
| 306 |
+
Note that empty matches are ignored.
|
| 307 |
+
"""
|
| 308 |
+
prompt_len = len(token_ids)
|
| 309 |
+
match_len = len(match_ids)
|
| 310 |
+
|
| 311 |
+
if match_len == 0:
|
| 312 |
+
return
|
| 313 |
+
|
| 314 |
+
start_idx = 0
|
| 315 |
+
while start_idx < prompt_len - match_len + 1:
|
| 316 |
+
end_idx = start_idx + match_len
|
| 317 |
+
|
| 318 |
+
if token_ids[start_idx:end_idx] == match_ids:
|
| 319 |
+
yield _TokenMatch(start_idx=start_idx, end_idx=end_idx)
|
| 320 |
+
|
| 321 |
+
# Exclude overlapping matches
|
| 322 |
+
start_idx = end_idx
|
| 323 |
+
else:
|
| 324 |
+
start_idx += 1
|
| 325 |
+
|
| 326 |
+
|
| 327 |
+
@dataclass(repr=False)
|
| 328 |
+
class _PromptReplacementMatch(ABC):
|
| 329 |
+
prompt_repl: BoundPromptReplacement
|
| 330 |
+
|
| 331 |
+
@property
|
| 332 |
+
def modality(self) -> str:
|
| 333 |
+
return self.prompt_repl.modality
|
| 334 |
+
|
| 335 |
+
@property
|
| 336 |
+
@abstractmethod
|
| 337 |
+
def start_idx(self) -> int:
|
| 338 |
+
raise NotImplementedError
|
| 339 |
+
|
| 340 |
+
@property
|
| 341 |
+
@abstractmethod
|
| 342 |
+
def end_idx(self) -> int:
|
| 343 |
+
raise NotImplementedError
|
| 344 |
+
|
| 345 |
+
def __repr__(self) -> str:
|
| 346 |
+
return (f"{type(self).__name__}(modality={self.modality!r}, "
|
| 347 |
+
f"start_idx={self.start_idx!r}, end_idx={self.end_idx!r})")
|
| 348 |
+
|
| 349 |
+
|
| 350 |
+
@dataclass(repr=False)
|
| 351 |
+
class _PromptReplacementTokenMatch(_PromptReplacementMatch):
|
| 352 |
+
match: _TokenMatch
|
| 353 |
+
|
| 354 |
+
@property
|
| 355 |
+
def start_idx(self) -> int:
|
| 356 |
+
return self.match.start_idx
|
| 357 |
+
|
| 358 |
+
@property
|
| 359 |
+
def end_idx(self) -> int:
|
| 360 |
+
return self.match.end_idx
|
| 361 |
+
|
| 362 |
+
|
| 363 |
+
@dataclass(repr=False)
|
| 364 |
+
class _PromptReplacementTextMatch(_PromptReplacementMatch):
|
| 365 |
+
match: re.Match[str]
|
| 366 |
+
|
| 367 |
+
@property
|
| 368 |
+
def start_idx(self) -> int:
|
| 369 |
+
return self.match.start()
|
| 370 |
+
|
| 371 |
+
@property
|
| 372 |
+
def end_idx(self) -> int:
|
| 373 |
+
return self.match.end()
|
| 374 |
+
|
| 375 |
+
|
| 376 |
+
@dataclass
|
| 377 |
+
class PlaceholderFeaturesInfo:
|
| 378 |
+
modality: str
|
| 379 |
+
item_idx: int
|
| 380 |
+
start_idx: int
|
| 381 |
+
tokens: list[int]
|
| 382 |
+
|
| 383 |
+
@property
|
| 384 |
+
def length(self) -> int:
|
| 385 |
+
return len(self.tokens)
|
| 386 |
+
|
| 387 |
+
def to_range(self) -> PlaceholderRange:
|
| 388 |
+
return PlaceholderRange(
|
| 389 |
+
offset=self.start_idx,
|
| 390 |
+
length=self.length,
|
| 391 |
+
)
|
| 392 |
+
|
| 393 |
+
|
| 394 |
+
def find_token_matches(
|
| 395 |
+
prompt: list[int],
|
| 396 |
+
prompt_repls: Sequence[BoundPromptReplacement],
|
| 397 |
+
) -> list[_PromptReplacementTokenMatch]:
|
| 398 |
+
"""Return each target of :code:`prompt_repls` found in :code:`prompt`."""
|
| 399 |
+
return [
|
| 400 |
+
_PromptReplacementTokenMatch(prompt_repl, match)
|
| 401 |
+
for prompt_repl in prompt_repls
|
| 402 |
+
for match in iter_token_matches(prompt, prompt_repl.target.token_ids)
|
| 403 |
+
]
|
| 404 |
+
|
| 405 |
+
|
| 406 |
+
def find_text_matches(
|
| 407 |
+
prompt: str,
|
| 408 |
+
prompt_repls: Sequence[BoundPromptReplacement],
|
| 409 |
+
) -> list[_PromptReplacementTextMatch]:
|
| 410 |
+
"""Return each target of :code:`prompt_repls` found in :code:`prompt`."""
|
| 411 |
+
return [
|
| 412 |
+
_PromptReplacementTextMatch(prompt_repl, match)
|
| 413 |
+
for prompt_repl in prompt_repls
|
| 414 |
+
for match in re.finditer(re.escape(prompt_repl.target.text), prompt)
|
| 415 |
+
]
|
| 416 |
+
|
| 417 |
+
|
| 418 |
+
def _resolve_matches(
|
| 419 |
+
prompt: PromptSeq,
|
| 420 |
+
mm_matches: Mapping[str, Sequence[_PromptReplacementMatch]],
|
| 421 |
+
) -> list[_PromptReplacementMatch]:
|
| 422 |
+
"""
|
| 423 |
+
Resolve :code:`mm_matches` to ensure that there are no overlapping matches,
|
| 424 |
+
and sort them such that earlier matches take priority over later ones.
|
| 425 |
+
"""
|
| 426 |
+
matches = [m for matches in mm_matches.values() for m in matches]
|
| 427 |
+
|
| 428 |
+
seen_matches: list[Optional[_PromptReplacementMatch]] = [None
|
| 429 |
+
] * len(prompt)
|
| 430 |
+
|
| 431 |
+
for match in matches:
|
| 432 |
+
for idx in range(match.start_idx, match.end_idx):
|
| 433 |
+
if seen_matches[idx] is not None:
|
| 434 |
+
raise ValueError("Found overlapping matches "
|
| 435 |
+
f"({seen_matches[idx]} and {match}) "
|
| 436 |
+
f"at index={idx} of prompt={prompt}")
|
| 437 |
+
|
| 438 |
+
seen_matches[idx] = match
|
| 439 |
+
|
| 440 |
+
return sorted(matches, key=lambda x: x.start_idx)
|
| 441 |
+
|
| 442 |
+
|
| 443 |
+
def _replace_matches(
|
| 444 |
+
prompt: _S,
|
| 445 |
+
mm_matches: Mapping[str, Sequence[_PromptReplacementMatch]],
|
| 446 |
+
mm_item_counts: Mapping[str, int],
|
| 447 |
+
) -> list[_S]:
|
| 448 |
+
"""Apply the replacements in :code:`mm_matches` to :code:`prompt`."""
|
| 449 |
+
out_seqs = list[_S]()
|
| 450 |
+
prev_end_idx = 0
|
| 451 |
+
next_idx_by_modality = defaultdict[str, int](lambda: 0)
|
| 452 |
+
|
| 453 |
+
for match in _resolve_matches(prompt, mm_matches):
|
| 454 |
+
modality = match.modality
|
| 455 |
+
|
| 456 |
+
item_idx = next_idx_by_modality[modality]
|
| 457 |
+
if item_idx >= mm_item_counts.get(modality, 0):
|
| 458 |
+
continue
|
| 459 |
+
|
| 460 |
+
start_idx = match.start_idx
|
| 461 |
+
end_idx = match.end_idx
|
| 462 |
+
|
| 463 |
+
repl_info = match.prompt_repl
|
| 464 |
+
replacement = repl_info.get_replacement(item_idx)
|
| 465 |
+
|
| 466 |
+
if isinstance(prompt, str):
|
| 467 |
+
repl_seq = replacement.full.text
|
| 468 |
+
out_seqs.append(prompt[prev_end_idx:start_idx] + repl_seq)
|
| 469 |
+
else:
|
| 470 |
+
repl_seq = replacement.full.token_ids
|
| 471 |
+
out_seqs.append(prompt[prev_end_idx:start_idx] + repl_seq)
|
| 472 |
+
|
| 473 |
+
prev_end_idx = end_idx
|
| 474 |
+
next_idx_by_modality[modality] += 1
|
| 475 |
+
|
| 476 |
+
out_seqs.append(prompt[prev_end_idx:])
|
| 477 |
+
|
| 478 |
+
return out_seqs
|
| 479 |
+
|
| 480 |
+
|
| 481 |
+
def replace_token_matches(
|
| 482 |
+
prompt: list[int],
|
| 483 |
+
mm_matches: Mapping[str, Sequence[_PromptReplacementTokenMatch]],
|
| 484 |
+
mm_item_counts: Mapping[str, int],
|
| 485 |
+
) -> list[int]:
|
| 486 |
+
"""Apply the replacements in :code:`mm_matches` to :code:`prompt`."""
|
| 487 |
+
if not mm_matches:
|
| 488 |
+
return prompt
|
| 489 |
+
|
| 490 |
+
token_id_seqs = _replace_matches(prompt, mm_matches, mm_item_counts)
|
| 491 |
+
|
| 492 |
+
return flatten_2d_lists(token_id_seqs)
|
| 493 |
+
|
| 494 |
+
|
| 495 |
+
def replace_text_matches(
|
| 496 |
+
prompt: str,
|
| 497 |
+
mm_matches: Mapping[str, Sequence[_PromptReplacementTextMatch]],
|
| 498 |
+
mm_item_counts: Mapping[str, int],
|
| 499 |
+
) -> str:
|
| 500 |
+
"""Apply the replacements in :code:`mm_matches` to :code:`prompt`."""
|
| 501 |
+
if not mm_matches:
|
| 502 |
+
return prompt
|
| 503 |
+
|
| 504 |
+
texts = _replace_matches(prompt, mm_matches, mm_item_counts)
|
| 505 |
+
|
| 506 |
+
return "".join(texts)
|
| 507 |
+
|
| 508 |
+
|
| 509 |
+
def _iter_placeholders(
|
| 510 |
+
mm_prompt_repls: Mapping[str, Sequence[BoundPromptReplacement]],
|
| 511 |
+
prompt: list[int],
|
| 512 |
+
mm_item_counts: Mapping[str, int],
|
| 513 |
+
) -> Iterable[PlaceholderFeaturesInfo]:
|
| 514 |
+
"""
|
| 515 |
+
Yield each set of placeholder tokens found in :code:`prompt`.
|
| 516 |
+
|
| 517 |
+
Matches are exclusive even when multiple modalities share
|
| 518 |
+
the same placeholder tokens. In that case, the modality that
|
| 519 |
+
appears earlier in `mm_prompt_repls` takes priority.
|
| 520 |
+
|
| 521 |
+
Note that empty matches are ignored.
|
| 522 |
+
"""
|
| 523 |
+
prompt_len = len(prompt)
|
| 524 |
+
item_idx_by_modality = defaultdict[str, int](lambda: 0)
|
| 525 |
+
|
| 526 |
+
start_idx = 0
|
| 527 |
+
while start_idx < prompt_len:
|
| 528 |
+
found = False
|
| 529 |
+
|
| 530 |
+
for modality, modality_repls in mm_prompt_repls.items():
|
| 531 |
+
item_idx = item_idx_by_modality[modality]
|
| 532 |
+
if item_idx >= mm_item_counts.get(modality, 0):
|
| 533 |
+
continue
|
| 534 |
+
|
| 535 |
+
for repl_info in modality_repls:
|
| 536 |
+
replacement = repl_info.get_replacement(item_idx)
|
| 537 |
+
repl_tokens_full = replacement.full.token_ids
|
| 538 |
+
repl_len_full = len(repl_tokens_full)
|
| 539 |
+
end_idx_full = start_idx + repl_len_full
|
| 540 |
+
|
| 541 |
+
if repl_len_full == 0 or end_idx_full > prompt_len:
|
| 542 |
+
continue
|
| 543 |
+
|
| 544 |
+
if prompt[start_idx:end_idx_full] == repl_tokens_full:
|
| 545 |
+
repl_tokens_feat = replacement.features.token_ids
|
| 546 |
+
|
| 547 |
+
try:
|
| 548 |
+
match = next(
|
| 549 |
+
iter_token_matches(repl_tokens_full,
|
| 550 |
+
repl_tokens_feat))
|
| 551 |
+
yield PlaceholderFeaturesInfo(
|
| 552 |
+
modality=modality,
|
| 553 |
+
item_idx=item_idx,
|
| 554 |
+
start_idx=start_idx + match.start_idx,
|
| 555 |
+
tokens=repl_tokens_feat,
|
| 556 |
+
)
|
| 557 |
+
except StopIteration:
|
| 558 |
+
raise AssertionError(
|
| 559 |
+
f"{repl_tokens_feat=} should be a "
|
| 560 |
+
f"subsequence of {repl_tokens_full=}") from None
|
| 561 |
+
|
| 562 |
+
# Exclude overlapping matches
|
| 563 |
+
start_idx = end_idx_full
|
| 564 |
+
item_idx_by_modality[modality] += 1
|
| 565 |
+
found = True
|
| 566 |
+
break
|
| 567 |
+
|
| 568 |
+
if found:
|
| 569 |
+
break # Go back to the outer while loop
|
| 570 |
+
|
| 571 |
+
if not found:
|
| 572 |
+
start_idx += 1
|
| 573 |
+
|
| 574 |
+
|
| 575 |
+
def find_mm_placeholders(
|
| 576 |
+
mm_prompt_repls: Mapping[str, Sequence[BoundPromptReplacement]],
|
| 577 |
+
prompt: list[int],
|
| 578 |
+
mm_item_counts: Mapping[str, int],
|
| 579 |
+
) -> Mapping[str, list[PlaceholderFeaturesInfo]]:
|
| 580 |
+
it = _iter_placeholders(mm_prompt_repls, prompt, mm_item_counts)
|
| 581 |
+
return dict(full_groupby_modality(it))
|
| 582 |
+
|
| 583 |
+
|
| 584 |
+
class ProcessingCache:
|
| 585 |
+
|
| 586 |
+
def __init__(self, capacity: int) -> None:
|
| 587 |
+
super().__init__()
|
| 588 |
+
|
| 589 |
+
# DEBUG: Set to None to disable
|
| 590 |
+
self.debug_cache_hit_ratio_steps: Optional[int] = None
|
| 591 |
+
|
| 592 |
+
self._cache = LRUCache[str, MultiModalKwargsItem](capacity)
|
| 593 |
+
|
| 594 |
+
def _maybe_log_cache_stats(self) -> None:
|
| 595 |
+
steps = self.debug_cache_hit_ratio_steps
|
| 596 |
+
if not steps:
|
| 597 |
+
return
|
| 598 |
+
|
| 599 |
+
cache_stats = self._cache.stat()
|
| 600 |
+
if cache_stats.total % steps == 0:
|
| 601 |
+
logger.debug("ProcessingCache: hit_ratio = %.2f",
|
| 602 |
+
cache_stats.hit_ratio)
|
| 603 |
+
|
| 604 |
+
def get(
|
| 605 |
+
self,
|
| 606 |
+
model_id: str,
|
| 607 |
+
modality: str,
|
| 608 |
+
input_item: object,
|
| 609 |
+
input_kwargs: Mapping[str, object],
|
| 610 |
+
) -> Optional[MultiModalKwargsItem]:
|
| 611 |
+
"""
|
| 612 |
+
Get a processed multi-modal item from the cache
|
| 613 |
+
according to its dependencies, including:
|
| 614 |
+
|
| 615 |
+
- The model ID
|
| 616 |
+
- The modality of the item
|
| 617 |
+
- The original data item passed to the HF processor
|
| 618 |
+
- The configuration options of the HF processor
|
| 619 |
+
"""
|
| 620 |
+
self._maybe_log_cache_stats()
|
| 621 |
+
|
| 622 |
+
cache_key = MultiModalHasher.hash_kwargs(model_id=model_id,
|
| 623 |
+
**{modality: input_item},
|
| 624 |
+
**input_kwargs)
|
| 625 |
+
return self._cache.get(cache_key)
|
| 626 |
+
|
| 627 |
+
def put(
|
| 628 |
+
self,
|
| 629 |
+
model_id: str,
|
| 630 |
+
modality: str,
|
| 631 |
+
input_item: object,
|
| 632 |
+
input_kwargs: Mapping[str, object],
|
| 633 |
+
output_kwargs: MultiModalKwargsItem,
|
| 634 |
+
) -> None:
|
| 635 |
+
"""
|
| 636 |
+
Put a processed multi-modal item into the cache
|
| 637 |
+
according to its dependencies (see :meth:`get`).
|
| 638 |
+
"""
|
| 639 |
+
cache_key = MultiModalHasher.hash_kwargs(model_id=model_id,
|
| 640 |
+
**{modality: input_item},
|
| 641 |
+
**input_kwargs)
|
| 642 |
+
self._cache.put(cache_key, output_kwargs)
|
| 643 |
+
|
| 644 |
+
|
| 645 |
+
class BaseProcessingInfo:
|
| 646 |
+
"""Base class to provide the information necessary for data processing."""
|
| 647 |
+
|
| 648 |
+
def __init__(self, ctx: InputProcessingContext) -> None:
|
| 649 |
+
super().__init__()
|
| 650 |
+
|
| 651 |
+
self.ctx = ctx
|
| 652 |
+
|
| 653 |
+
@property
|
| 654 |
+
def model_id(self) -> str:
|
| 655 |
+
return self.ctx.model_config.model
|
| 656 |
+
|
| 657 |
+
def get_tokenizer(self) -> AnyTokenizer:
|
| 658 |
+
return self.ctx.tokenizer
|
| 659 |
+
|
| 660 |
+
def get_hf_config(self) -> PretrainedConfig:
|
| 661 |
+
return self.ctx.get_hf_config()
|
| 662 |
+
|
| 663 |
+
def get_hf_processor(self, **kwargs: object) -> ProcessorMixin:
|
| 664 |
+
"""
|
| 665 |
+
Subclasses can override this method to handle
|
| 666 |
+
specific kwargs from model config or user inputs.
|
| 667 |
+
"""
|
| 668 |
+
return self.ctx.get_hf_processor(**kwargs)
|
| 669 |
+
|
| 670 |
+
@abstractmethod
|
| 671 |
+
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
|
| 672 |
+
"""
|
| 673 |
+
Return the maximum supported number of items for each modality.
|
| 674 |
+
|
| 675 |
+
A value of `None` means unlimited number of items.
|
| 676 |
+
|
| 677 |
+
Omitting a modality from the returned dictionary means that
|
| 678 |
+
it is not supported at all.
|
| 679 |
+
"""
|
| 680 |
+
raise NotImplementedError
|
| 681 |
+
|
| 682 |
+
@abstractmethod
|
| 683 |
+
def get_mm_max_tokens_per_item(
|
| 684 |
+
self,
|
| 685 |
+
seq_len: int,
|
| 686 |
+
mm_counts: Mapping[str, int],
|
| 687 |
+
) -> Mapping[str, int]:
|
| 688 |
+
"""
|
| 689 |
+
Get the maximum possible number of tokens per data item
|
| 690 |
+
for each modality.
|
| 691 |
+
|
| 692 |
+
The dictionary returned by this method should have the same
|
| 693 |
+
keys as that returned by :meth:`get_supported_mm_limits`.
|
| 694 |
+
"""
|
| 695 |
+
raise NotImplementedError
|
| 696 |
+
|
| 697 |
+
|
| 698 |
+
_I = TypeVar("_I", bound=BaseProcessingInfo)
|
| 699 |
+
|
| 700 |
+
|
| 701 |
+
class BaseMultiModalProcessor(ABC, Generic[_I]):
|
| 702 |
+
"""
|
| 703 |
+
Abstract base class to process multi-modal inputs to be used in vLLM.
|
| 704 |
+
|
| 705 |
+
Not to be confused with :class:`transformers.ProcessorMixin`.
|
| 706 |
+
"""
|
| 707 |
+
|
| 708 |
+
def __init__(self,
|
| 709 |
+
info: _I,
|
| 710 |
+
dummy_inputs: "BaseDummyInputsBuilder[_I]",
|
| 711 |
+
*,
|
| 712 |
+
cache: Optional[ProcessingCache] = None,
|
| 713 |
+
enable_sanity_checks: bool = True) -> None:
|
| 714 |
+
super().__init__()
|
| 715 |
+
|
| 716 |
+
self.info = info
|
| 717 |
+
self.dummy_inputs = dummy_inputs
|
| 718 |
+
self.cache = cache
|
| 719 |
+
self.enable_sanity_checks = enable_sanity_checks
|
| 720 |
+
|
| 721 |
+
self.data_parser = self._get_data_parser()
|
| 722 |
+
|
| 723 |
+
def __call__(
|
| 724 |
+
self,
|
| 725 |
+
prompt: str,
|
| 726 |
+
mm_data: MultiModalDataDict,
|
| 727 |
+
hf_processor_mm_kwargs: Mapping[str, object],
|
| 728 |
+
) -> MultiModalInputs:
|
| 729 |
+
return self.apply(prompt, mm_data, hf_processor_mm_kwargs)
|
| 730 |
+
|
| 731 |
+
def _get_data_parser(self) -> MultiModalDataParser:
|
| 732 |
+
"""
|
| 733 |
+
Construct a parser to preprocess multi-modal data items
|
| 734 |
+
before passing them to :meth:`_get_hf_mm_data`.
|
| 735 |
+
|
| 736 |
+
You can support additional modalities by creating a subclass
|
| 737 |
+
of :class:`MultiModalDataParser` that has additional subparsers.
|
| 738 |
+
"""
|
| 739 |
+
return MultiModalDataParser()
|
| 740 |
+
|
| 741 |
+
def _to_mm_items(
|
| 742 |
+
self,
|
| 743 |
+
mm_data: MultiModalDataDict,
|
| 744 |
+
) -> MultiModalDataItems:
|
| 745 |
+
"""
|
| 746 |
+
Normalize :class:`MultiModalDataDict` to :class:`MultiModalDataItems`
|
| 747 |
+
before passing them to :meth:`_get_hf_mm_data`.
|
| 748 |
+
"""
|
| 749 |
+
mm_items = self.data_parser.parse_mm_data(mm_data)
|
| 750 |
+
|
| 751 |
+
mm_limits = self.info.ctx.get_mm_config().limit_per_prompt
|
| 752 |
+
for modality, items in mm_items.items():
|
| 753 |
+
limit = mm_limits.get(modality, 1)
|
| 754 |
+
if len(items) > limit:
|
| 755 |
+
raise ValueError(
|
| 756 |
+
f"You set {modality}={limit} (or defaulted to 1) in "
|
| 757 |
+
f"`--limit-mm-per-prompt`, but passed {len(items)} "
|
| 758 |
+
f"{modality} items in the same prompt.")
|
| 759 |
+
|
| 760 |
+
return mm_items
|
| 761 |
+
|
| 762 |
+
@abstractmethod
|
| 763 |
+
def _get_mm_fields_config(
|
| 764 |
+
self,
|
| 765 |
+
hf_inputs: BatchFeature,
|
| 766 |
+
hf_processor_mm_kwargs: Mapping[str, object],
|
| 767 |
+
) -> Mapping[str, MultiModalFieldConfig]:
|
| 768 |
+
"""Given the HF-processed data, output the metadata of each field."""
|
| 769 |
+
raise NotImplementedError
|
| 770 |
+
|
| 771 |
+
@abstractmethod
|
| 772 |
+
def _get_prompt_replacements(
|
| 773 |
+
self,
|
| 774 |
+
mm_items: MultiModalDataItems,
|
| 775 |
+
hf_processor_mm_kwargs: Mapping[str, object],
|
| 776 |
+
out_mm_kwargs: MultiModalKwargs,
|
| 777 |
+
) -> list[PromptReplacement]:
|
| 778 |
+
"""
|
| 779 |
+
Given the original multi-modal items for this modality
|
| 780 |
+
and HF-processed data, output the replacements to perform.
|
| 781 |
+
|
| 782 |
+
Notes:
|
| 783 |
+
- You should not assume that HF processor always performs prompt
|
| 784 |
+
replacement: in :meth:`_apply_hf_processor_missing`, this method
|
| 785 |
+
is called on text-only and multimodal-only inputs separately,
|
| 786 |
+
instead of passing them in the same call.
|
| 787 |
+
- The replacement information returned by this method is also used
|
| 788 |
+
to determine the placeholder token positions for each multi-modal
|
| 789 |
+
item.
|
| 790 |
+
"""
|
| 791 |
+
raise NotImplementedError
|
| 792 |
+
|
| 793 |
+
def _find_mm_placeholders(
|
| 794 |
+
self,
|
| 795 |
+
mm_prompt_repls: Mapping[str, Sequence[BoundPromptReplacement]],
|
| 796 |
+
new_token_ids: list[int],
|
| 797 |
+
mm_item_counts: Mapping[str, int],
|
| 798 |
+
) -> Mapping[str, list[PlaceholderFeaturesInfo]]:
|
| 799 |
+
return find_mm_placeholders(mm_prompt_repls, new_token_ids,
|
| 800 |
+
mm_item_counts)
|
| 801 |
+
|
| 802 |
+
def _get_hf_mm_data(
|
| 803 |
+
self,
|
| 804 |
+
mm_items: MultiModalDataItems,
|
| 805 |
+
) -> tuple[Mapping[str, object], Mapping[str, object]]:
|
| 806 |
+
processor_data = dict[str, object]()
|
| 807 |
+
passthrough_data = dict[str, object]()
|
| 808 |
+
|
| 809 |
+
for items in mm_items.values():
|
| 810 |
+
processor_data.update(items.get_processor_data())
|
| 811 |
+
passthrough_data.update(items.get_passthrough_data())
|
| 812 |
+
|
| 813 |
+
return processor_data, passthrough_data
|
| 814 |
+
|
| 815 |
+
def _call_hf_processor(
|
| 816 |
+
self,
|
| 817 |
+
prompt: str,
|
| 818 |
+
# Not to be confused with `mm_data` in `self.apply`.
|
| 819 |
+
# This refers to the data to be passed to HF processor.
|
| 820 |
+
mm_data: Mapping[str, object],
|
| 821 |
+
mm_kwargs: Mapping[str, object],
|
| 822 |
+
) -> BatchFeature:
|
| 823 |
+
"""
|
| 824 |
+
Call the HF processor on the prompt text and
|
| 825 |
+
associated multi-modal data.
|
| 826 |
+
"""
|
| 827 |
+
return self.info.ctx.call_hf_processor(
|
| 828 |
+
self.info.get_hf_processor(**mm_kwargs),
|
| 829 |
+
dict(text=prompt, **mm_data),
|
| 830 |
+
mm_kwargs,
|
| 831 |
+
)
|
| 832 |
+
|
| 833 |
+
def _apply_hf_processor_text_mm(
|
| 834 |
+
self,
|
| 835 |
+
prompt_text: str,
|
| 836 |
+
mm_items: MultiModalDataItems,
|
| 837 |
+
hf_processor_mm_kwargs: Mapping[str, object],
|
| 838 |
+
) -> tuple[list[int], MultiModalKwargs]:
|
| 839 |
+
"""
|
| 840 |
+
Apply the HF processor on the prompt text and multi-modal data
|
| 841 |
+
together.
|
| 842 |
+
"""
|
| 843 |
+
processor_data, passthrough_data = self._get_hf_mm_data(mm_items)
|
| 844 |
+
|
| 845 |
+
processed_data = self._call_hf_processor(
|
| 846 |
+
prompt=prompt_text,
|
| 847 |
+
mm_data=processor_data,
|
| 848 |
+
mm_kwargs=hf_processor_mm_kwargs,
|
| 849 |
+
)
|
| 850 |
+
processed_data.update(passthrough_data)
|
| 851 |
+
|
| 852 |
+
prompt_ids, = processed_data.pop("input_ids").tolist()
|
| 853 |
+
|
| 854 |
+
mm_kwargs = MultiModalKwargs.from_hf_inputs(
|
| 855 |
+
processed_data,
|
| 856 |
+
self._get_mm_fields_config(processed_data, hf_processor_mm_kwargs),
|
| 857 |
+
)
|
| 858 |
+
|
| 859 |
+
return prompt_ids, mm_kwargs
|
| 860 |
+
|
| 861 |
+
def _apply_hf_processor_text_only(self, prompt_text: str) -> list[int]:
|
| 862 |
+
"""
|
| 863 |
+
Apply the HF processor on the prompt text only.
|
| 864 |
+
|
| 865 |
+
Since HF processor requires that text and multi-modal items
|
| 866 |
+
correspond to each other, we create dummy multi-modal items
|
| 867 |
+
to go along with the text.
|
| 868 |
+
"""
|
| 869 |
+
prompt_ids, _ = self._apply_hf_processor_text_mm(
|
| 870 |
+
prompt_text=prompt_text,
|
| 871 |
+
mm_items=MultiModalDataItems({}),
|
| 872 |
+
hf_processor_mm_kwargs={},
|
| 873 |
+
)
|
| 874 |
+
|
| 875 |
+
return prompt_ids
|
| 876 |
+
|
| 877 |
+
def _apply_hf_processor_tokens_only(
|
| 878 |
+
self,
|
| 879 |
+
prompt_tokens: list[int],
|
| 880 |
+
) -> list[int]:
|
| 881 |
+
"""
|
| 882 |
+
Apply the HF processor on the prompt tokens only.
|
| 883 |
+
|
| 884 |
+
Most HF processors accept prompt text but not prompt tokens.
|
| 885 |
+
If the HF processor adds or removes tokens that are not related to
|
| 886 |
+
multi-modal data, you should override this method so it is consistent
|
| 887 |
+
with the output of :meth:`_apply_hf_processor_text_only` on the
|
| 888 |
+
corresponding text.
|
| 889 |
+
"""
|
| 890 |
+
return prompt_tokens
|
| 891 |
+
|
| 892 |
+
def _apply_hf_processor_mm_only(
|
| 893 |
+
self,
|
| 894 |
+
mm_items: MultiModalDataItems,
|
| 895 |
+
hf_processor_mm_kwargs: Mapping[str, object],
|
| 896 |
+
) -> MultiModalKwargs:
|
| 897 |
+
"""
|
| 898 |
+
Apply the HF processor on the multi-modal data only.
|
| 899 |
+
|
| 900 |
+
Since HF processor requires that text and multi-modal items
|
| 901 |
+
correspond to each other, we generate dummy text using
|
| 902 |
+
:class:`DummyInputsBuilder` to go along with the multi-modal data.
|
| 903 |
+
"""
|
| 904 |
+
mm_counts = mm_items.get_all_counts()
|
| 905 |
+
|
| 906 |
+
dummy_inputs = self.dummy_inputs.get_dummy_processor_inputs(
|
| 907 |
+
self.info.ctx.model_config.max_model_len,
|
| 908 |
+
mm_counts,
|
| 909 |
+
)
|
| 910 |
+
|
| 911 |
+
_, mm_kwargs = self._apply_hf_processor_text_mm(
|
| 912 |
+
prompt_text=dummy_inputs.prompt_text,
|
| 913 |
+
mm_items=mm_items,
|
| 914 |
+
hf_processor_mm_kwargs=hf_processor_mm_kwargs,
|
| 915 |
+
)
|
| 916 |
+
|
| 917 |
+
return mm_kwargs
|
| 918 |
+
|
| 919 |
+
def _apply_hf_processor_main(
|
| 920 |
+
self,
|
| 921 |
+
prompt: Union[str, list[int]],
|
| 922 |
+
mm_items: MultiModalDataItems,
|
| 923 |
+
hf_processor_mm_kwargs: Mapping[str, object],
|
| 924 |
+
*,
|
| 925 |
+
enable_hf_prompt_replacement: bool,
|
| 926 |
+
) -> tuple[list[int], MultiModalKwargs]:
|
| 927 |
+
"""
|
| 928 |
+
Apply the HF processor on the prompt text and multi-modal data.
|
| 929 |
+
|
| 930 |
+
Note:
|
| 931 |
+
If :code:`enable_hf_prompt_replacement=False`, the prompt should
|
| 932 |
+
correspond to the multi-modal items.
|
| 933 |
+
"""
|
| 934 |
+
if isinstance(prompt, str):
|
| 935 |
+
if enable_hf_prompt_replacement:
|
| 936 |
+
return self._apply_hf_processor_text_mm(
|
| 937 |
+
prompt_text=prompt,
|
| 938 |
+
mm_items=mm_items,
|
| 939 |
+
hf_processor_mm_kwargs=hf_processor_mm_kwargs,
|
| 940 |
+
)
|
| 941 |
+
|
| 942 |
+
prompt_ids = self._apply_hf_processor_text_only(prompt)
|
| 943 |
+
else:
|
| 944 |
+
prompt_ids = self._apply_hf_processor_tokens_only(prompt)
|
| 945 |
+
|
| 946 |
+
mm_missing_kwargs = self._apply_hf_processor_mm_only(
|
| 947 |
+
mm_items=mm_items,
|
| 948 |
+
hf_processor_mm_kwargs=hf_processor_mm_kwargs,
|
| 949 |
+
)
|
| 950 |
+
|
| 951 |
+
return prompt_ids, mm_missing_kwargs
|
| 952 |
+
|
| 953 |
+
def _cached_apply_hf_processor(
|
| 954 |
+
self,
|
| 955 |
+
prompt: Union[str, list[int]],
|
| 956 |
+
mm_data_items: MultiModalDataItems,
|
| 957 |
+
hf_processor_mm_kwargs: Mapping[str, object],
|
| 958 |
+
) -> tuple[list[int], MultiModalKwargs]:
|
| 959 |
+
"""
|
| 960 |
+
Apply the HF processor on the full prompt text,
|
| 961 |
+
caching the results and reusing cached results.
|
| 962 |
+
"""
|
| 963 |
+
cache = self.cache
|
| 964 |
+
model_id = self.info.model_id
|
| 965 |
+
|
| 966 |
+
_, passthrough_data = self._get_hf_mm_data(mm_data_items)
|
| 967 |
+
if cache is None or passthrough_data:
|
| 968 |
+
return self._apply_hf_processor_main(
|
| 969 |
+
prompt=prompt,
|
| 970 |
+
mm_items=mm_data_items,
|
| 971 |
+
hf_processor_mm_kwargs=hf_processor_mm_kwargs,
|
| 972 |
+
enable_hf_prompt_replacement=True,
|
| 973 |
+
)
|
| 974 |
+
|
| 975 |
+
mm_maybe_cached_kw_items = {
|
| 976 |
+
modality: [
|
| 977 |
+
cache.get(model_id, modality, item, hf_processor_mm_kwargs)
|
| 978 |
+
for item in items
|
| 979 |
+
]
|
| 980 |
+
for modality, items in mm_data_items.items()
|
| 981 |
+
}
|
| 982 |
+
|
| 983 |
+
mm_missing_idxs = {
|
| 984 |
+
modality:
|
| 985 |
+
[idx for idx, item in enumerate(kw_items) if item is None]
|
| 986 |
+
for modality, kw_items in mm_maybe_cached_kw_items.items()
|
| 987 |
+
}
|
| 988 |
+
mm_missing_data = {
|
| 989 |
+
modality: [mm_data_items[modality][idx] for idx in idxs]
|
| 990 |
+
for modality, idxs in mm_missing_idxs.items()
|
| 991 |
+
}
|
| 992 |
+
mm_missing_data_items = self._to_mm_items(mm_missing_data)
|
| 993 |
+
|
| 994 |
+
# NOTE: `prompt` does not correspond to `mm_missing_data_items`,
|
| 995 |
+
# so we need to pass `enable_hf_prompt_replacement=False`
|
| 996 |
+
prompt_ids, mm_missing_kwargs = self._apply_hf_processor_main(
|
| 997 |
+
prompt=prompt,
|
| 998 |
+
mm_items=mm_missing_data_items,
|
| 999 |
+
hf_processor_mm_kwargs=hf_processor_mm_kwargs,
|
| 1000 |
+
enable_hf_prompt_replacement=False,
|
| 1001 |
+
)
|
| 1002 |
+
|
| 1003 |
+
mm_missing_next_idx = {
|
| 1004 |
+
modality: 0
|
| 1005 |
+
for modality in mm_missing_data_items
|
| 1006 |
+
}
|
| 1007 |
+
|
| 1008 |
+
merged_kw_items = list[MultiModalKwargsItem]()
|
| 1009 |
+
for modality, kw_items in mm_maybe_cached_kw_items.items():
|
| 1010 |
+
for idx, kw_item in enumerate(kw_items):
|
| 1011 |
+
if kw_item is None:
|
| 1012 |
+
kw_item = mm_missing_kwargs.get_item(
|
| 1013 |
+
modality,
|
| 1014 |
+
mm_missing_next_idx[modality],
|
| 1015 |
+
)
|
| 1016 |
+
|
| 1017 |
+
cache.put(
|
| 1018 |
+
model_id,
|
| 1019 |
+
modality,
|
| 1020 |
+
mm_data_items[modality][idx],
|
| 1021 |
+
hf_processor_mm_kwargs,
|
| 1022 |
+
kw_item,
|
| 1023 |
+
)
|
| 1024 |
+
|
| 1025 |
+
mm_missing_next_idx[modality] += 1
|
| 1026 |
+
|
| 1027 |
+
merged_kw_items.append(kw_item)
|
| 1028 |
+
|
| 1029 |
+
if self.enable_sanity_checks:
|
| 1030 |
+
mm_missing_counts = mm_missing_data_items.get_all_counts()
|
| 1031 |
+
assert all(
|
| 1032 |
+
item_count == mm_missing_counts[modality]
|
| 1033 |
+
for modality, item_count in mm_missing_next_idx.items()), dict(
|
| 1034 |
+
mm_missing_next_idx=mm_missing_next_idx,
|
| 1035 |
+
mm_missing_counts=mm_missing_counts)
|
| 1036 |
+
|
| 1037 |
+
mm_kwargs = MultiModalKwargs.from_items(merged_kw_items)
|
| 1038 |
+
|
| 1039 |
+
return prompt_ids, mm_kwargs
|
| 1040 |
+
|
| 1041 |
+
def _bind_and_group_repls(
|
| 1042 |
+
self,
|
| 1043 |
+
prompt_repls: list[PromptReplacement],
|
| 1044 |
+
) -> dict[str, list[BoundPromptReplacement]]:
|
| 1045 |
+
tokenizer = self.info.get_tokenizer()
|
| 1046 |
+
|
| 1047 |
+
it = (prompt_repl.bind(tokenizer) for prompt_repl in prompt_repls)
|
| 1048 |
+
return dict(full_groupby_modality(it))
|
| 1049 |
+
|
| 1050 |
+
def _always_apply_prompt_replacements(self) -> bool:
|
| 1051 |
+
"""
|
| 1052 |
+
A flag which can be overridden so that
|
| 1053 |
+
:meth:`_apply_prompt_replacements` is always called even if we
|
| 1054 |
+
detect that HF has performed processing via
|
| 1055 |
+
:meth:`_find_placeholders_by_modality`.
|
| 1056 |
+
|
| 1057 |
+
This is useful in cases where :meth:`_find_placeholders_by_modality`
|
| 1058 |
+
cannot be reliably used to detect whether HF has performed processing.
|
| 1059 |
+
"""
|
| 1060 |
+
return False
|
| 1061 |
+
|
| 1062 |
+
def _apply_prompt_replacements(
|
| 1063 |
+
self,
|
| 1064 |
+
token_ids: list[int],
|
| 1065 |
+
mm_prompt_repls: Mapping[str, Sequence[BoundPromptReplacement]],
|
| 1066 |
+
mm_item_counts: Mapping[str, int],
|
| 1067 |
+
) -> tuple[list[int], str, Mapping[str, list[PlaceholderFeaturesInfo]]]:
|
| 1068 |
+
tokenizer = self.info.get_tokenizer()
|
| 1069 |
+
|
| 1070 |
+
mm_token_matches = {
|
| 1071 |
+
modality: find_token_matches(token_ids, prompt_repls)
|
| 1072 |
+
for modality, prompt_repls in mm_prompt_repls.items()
|
| 1073 |
+
}
|
| 1074 |
+
mm_match_counts = {
|
| 1075 |
+
modality: len(matches)
|
| 1076 |
+
for modality, matches in mm_token_matches.items()
|
| 1077 |
+
}
|
| 1078 |
+
|
| 1079 |
+
# If the search text does not represent a special token,
|
| 1080 |
+
# it may have different token IDs in the prompt, because
|
| 1081 |
+
# the tokens may go across the boundaries of the search text.
|
| 1082 |
+
# ----
|
| 1083 |
+
# e.g. when searching for "foo" in "food", if "food" itself makes
|
| 1084 |
+
# up a token, then the token ID of "foo" will not appear at all
|
| 1085 |
+
# ----
|
| 1086 |
+
# Since it is inefficient to search for all possible tokenizations
|
| 1087 |
+
# of the search text in the prompt, we instead perform string
|
| 1088 |
+
# replacement on the decoded token IDs, then encode them back.
|
| 1089 |
+
if all(
|
| 1090 |
+
mm_match_counts.get(modality, 0) >= item_count
|
| 1091 |
+
for modality, item_count in mm_item_counts.items()
|
| 1092 |
+
): # yapf: disable
|
| 1093 |
+
token_ids = replace_token_matches(
|
| 1094 |
+
token_ids,
|
| 1095 |
+
mm_token_matches,
|
| 1096 |
+
mm_item_counts,
|
| 1097 |
+
)
|
| 1098 |
+
|
| 1099 |
+
text = decode_tokens(tokenizer, token_ids)
|
| 1100 |
+
matched_repls = {
|
| 1101 |
+
modality: [match.prompt_repl for match in token_matches]
|
| 1102 |
+
for modality, token_matches in mm_token_matches.items()
|
| 1103 |
+
}
|
| 1104 |
+
else:
|
| 1105 |
+
text = decode_tokens(tokenizer, token_ids)
|
| 1106 |
+
|
| 1107 |
+
mm_text_matches = {
|
| 1108 |
+
modality: find_text_matches(text, prompt_repls)
|
| 1109 |
+
for modality, prompt_repls in mm_prompt_repls.items()
|
| 1110 |
+
}
|
| 1111 |
+
text = replace_text_matches(
|
| 1112 |
+
text,
|
| 1113 |
+
mm_text_matches,
|
| 1114 |
+
mm_item_counts,
|
| 1115 |
+
)
|
| 1116 |
+
|
| 1117 |
+
token_ids = encode_tokens(tokenizer,
|
| 1118 |
+
text,
|
| 1119 |
+
add_special_tokens=False)
|
| 1120 |
+
matched_repls = {
|
| 1121 |
+
modality: [match.prompt_repl for match in token_matches]
|
| 1122 |
+
for modality, token_matches in mm_text_matches.items()
|
| 1123 |
+
}
|
| 1124 |
+
|
| 1125 |
+
placeholders = self._find_mm_placeholders(
|
| 1126 |
+
matched_repls,
|
| 1127 |
+
token_ids,
|
| 1128 |
+
mm_item_counts,
|
| 1129 |
+
)
|
| 1130 |
+
|
| 1131 |
+
return token_ids, text, placeholders
|
| 1132 |
+
|
| 1133 |
+
def _validate_mm_kwargs(
|
| 1134 |
+
self,
|
| 1135 |
+
mm_kwargs: MultiModalKwargs,
|
| 1136 |
+
mm_item_counts: Mapping[str, int],
|
| 1137 |
+
) -> None:
|
| 1138 |
+
for modality, item_count in mm_item_counts.items():
|
| 1139 |
+
if modality in mm_kwargs.modalities:
|
| 1140 |
+
items = mm_kwargs.get_items(modality)
|
| 1141 |
+
else:
|
| 1142 |
+
items = []
|
| 1143 |
+
|
| 1144 |
+
if len(items) != item_count:
|
| 1145 |
+
raise RuntimeError(
|
| 1146 |
+
f"Expected there to be {item_count} {modality} items in "
|
| 1147 |
+
f"keyword arguments corresponding to {item_count} "
|
| 1148 |
+
f"{modality} data items, but only found {len(items)}! "
|
| 1149 |
+
"There is likely a problem with your "
|
| 1150 |
+
"implementation of merged multi-modal processor for this "
|
| 1151 |
+
"model (usually arising from an inconsistency between "
|
| 1152 |
+
"`_call_hf_processor` and `_get_mm_fields_config`).")
|
| 1153 |
+
|
| 1154 |
+
def _validate_mm_placeholders(
|
| 1155 |
+
self,
|
| 1156 |
+
mm_placeholders: Mapping[str, list[PlaceholderFeaturesInfo]],
|
| 1157 |
+
mm_item_counts: Mapping[str, int],
|
| 1158 |
+
*,
|
| 1159 |
+
allow_missing: bool = False,
|
| 1160 |
+
) -> Mapping[str, int]:
|
| 1161 |
+
missing_repl_counts = dict[str, int]()
|
| 1162 |
+
|
| 1163 |
+
for modality, item_count in mm_item_counts.items():
|
| 1164 |
+
placeholders = mm_placeholders.get(modality, [])
|
| 1165 |
+
|
| 1166 |
+
if len(placeholders) != item_count and not allow_missing:
|
| 1167 |
+
raise RuntimeError(
|
| 1168 |
+
f"Expected there to be {item_count} prompt replacements "
|
| 1169 |
+
f"corresponding to {item_count} {modality} items, but only "
|
| 1170 |
+
f"found {len(placeholders)} prompt replacements! Either "
|
| 1171 |
+
"the prompt text has missing/incorrect tokens for "
|
| 1172 |
+
"multi-modal inputs, or there is a problem with your "
|
| 1173 |
+
"implementation of merged multi-modal processor for this "
|
| 1174 |
+
"model (usually arising from an inconsistency between "
|
| 1175 |
+
"`_call_hf_processor` and `_get_prompt_replacements`).")
|
| 1176 |
+
|
| 1177 |
+
missing_repl_counts[modality] = item_count - len(placeholders)
|
| 1178 |
+
|
| 1179 |
+
return missing_repl_counts
|
| 1180 |
+
|
| 1181 |
+
def apply(
|
| 1182 |
+
self,
|
| 1183 |
+
prompt: Union[str, list[int]],
|
| 1184 |
+
mm_data: MultiModalDataDict,
|
| 1185 |
+
hf_processor_mm_kwargs: Mapping[str, object],
|
| 1186 |
+
) -> MultiModalInputs:
|
| 1187 |
+
"""
|
| 1188 |
+
Process multi-modal inputs to be used in vLLM.
|
| 1189 |
+
|
| 1190 |
+
The main steps are:
|
| 1191 |
+
|
| 1192 |
+
1. Apply HF Processor on prompt text and multi-modal data together,
|
| 1193 |
+
outputting token IDs and processed tensors.
|
| 1194 |
+
2. Find and replace sequences in the token IDs with placeholder tokens.
|
| 1195 |
+
The number of placeholder tokens equals the feature size of the
|
| 1196 |
+
multi-modal data outputted by the multi-modal encoder.
|
| 1197 |
+
3. Extract information about the placeholder tokens from the
|
| 1198 |
+
processed token IDs.
|
| 1199 |
+
"""
|
| 1200 |
+
mm_items = self._to_mm_items(mm_data)
|
| 1201 |
+
|
| 1202 |
+
# Create MM hashes (only used in V1)
|
| 1203 |
+
# TODO: Use these hash keys for caching operations in apply_hf_processor
|
| 1204 |
+
# instead of rehashing.
|
| 1205 |
+
|
| 1206 |
+
if envs.VLLM_USE_V1:
|
| 1207 |
+
model_id = self.info.model_id
|
| 1208 |
+
mm_hashes = {
|
| 1209 |
+
modality: [
|
| 1210 |
+
MultiModalHasher.hash_kwargs(model_id=model_id,
|
| 1211 |
+
**{modality: item},
|
| 1212 |
+
**hf_processor_mm_kwargs)
|
| 1213 |
+
for item in items
|
| 1214 |
+
]
|
| 1215 |
+
for modality, items in mm_items.items()
|
| 1216 |
+
}
|
| 1217 |
+
else:
|
| 1218 |
+
mm_hashes = None
|
| 1219 |
+
|
| 1220 |
+
prompt_ids, mm_kwargs = self._cached_apply_hf_processor(
|
| 1221 |
+
prompt,
|
| 1222 |
+
mm_items,
|
| 1223 |
+
hf_processor_mm_kwargs,
|
| 1224 |
+
)
|
| 1225 |
+
|
| 1226 |
+
unbound_prompt_repls = self._get_prompt_replacements(
|
| 1227 |
+
mm_items,
|
| 1228 |
+
hf_processor_mm_kwargs,
|
| 1229 |
+
mm_kwargs,
|
| 1230 |
+
)
|
| 1231 |
+
mm_prompt_repls = self._bind_and_group_repls(unbound_prompt_repls)
|
| 1232 |
+
|
| 1233 |
+
mm_item_counts = mm_items.get_all_counts()
|
| 1234 |
+
self._validate_mm_kwargs(mm_kwargs, mm_item_counts)
|
| 1235 |
+
|
| 1236 |
+
hf_mm_placeholders = self._find_mm_placeholders(
|
| 1237 |
+
mm_prompt_repls,
|
| 1238 |
+
prompt_ids,
|
| 1239 |
+
mm_item_counts,
|
| 1240 |
+
)
|
| 1241 |
+
|
| 1242 |
+
if self._always_apply_prompt_replacements():
|
| 1243 |
+
mm_missing_repl_counts = mm_item_counts
|
| 1244 |
+
mm_missing_repls = dict(mm_prompt_repls)
|
| 1245 |
+
else:
|
| 1246 |
+
mm_missing_repl_counts = self._validate_mm_placeholders(
|
| 1247 |
+
hf_mm_placeholders,
|
| 1248 |
+
mm_item_counts,
|
| 1249 |
+
allow_missing=True,
|
| 1250 |
+
)
|
| 1251 |
+
|
| 1252 |
+
mm_missing_repls = dict[str, list[BoundPromptReplacement]]()
|
| 1253 |
+
for modality, missing_repl_count in mm_missing_repl_counts.items():
|
| 1254 |
+
if missing_repl_count == 0:
|
| 1255 |
+
mm_missing_repls[modality] = []
|
| 1256 |
+
elif missing_repl_count == mm_item_counts.get(modality, 0):
|
| 1257 |
+
mm_missing_repls[modality] = mm_prompt_repls[modality]
|
| 1258 |
+
else:
|
| 1259 |
+
raise ValueError("Partial prompt replacement within "
|
| 1260 |
+
f"{modality=} is not supported")
|
| 1261 |
+
|
| 1262 |
+
# If HF processor already inserts placeholder tokens,
|
| 1263 |
+
# there is no need for us to insert them
|
| 1264 |
+
if all(len(repls) == 0 for repls in mm_missing_repls.values()):
|
| 1265 |
+
tokenizer = self.info.get_tokenizer()
|
| 1266 |
+
prompt = decode_tokens(tokenizer, prompt_ids)
|
| 1267 |
+
mm_placeholders = hf_mm_placeholders
|
| 1268 |
+
else:
|
| 1269 |
+
(
|
| 1270 |
+
prompt_ids,
|
| 1271 |
+
prompt,
|
| 1272 |
+
missing_mm_placeholders,
|
| 1273 |
+
) = self._apply_prompt_replacements(
|
| 1274 |
+
prompt_ids,
|
| 1275 |
+
mm_missing_repls,
|
| 1276 |
+
mm_missing_repl_counts,
|
| 1277 |
+
)
|
| 1278 |
+
|
| 1279 |
+
mm_placeholders = {**hf_mm_placeholders, **missing_mm_placeholders}
|
| 1280 |
+
|
| 1281 |
+
self._validate_mm_placeholders(mm_placeholders, mm_item_counts)
|
| 1282 |
+
|
| 1283 |
+
mm_placeholder_ranges = {
|
| 1284 |
+
modality: [item.to_range() for item in placeholders]
|
| 1285 |
+
for modality, placeholders in mm_placeholders.items()
|
| 1286 |
+
}
|
| 1287 |
+
|
| 1288 |
+
return MultiModalInputs(
|
| 1289 |
+
type="multimodal",
|
| 1290 |
+
prompt=prompt,
|
| 1291 |
+
prompt_token_ids=prompt_ids,
|
| 1292 |
+
mm_kwargs=mm_kwargs,
|
| 1293 |
+
mm_hashes=mm_hashes,
|
| 1294 |
+
mm_placeholders=mm_placeholder_ranges,
|
| 1295 |
+
)
|
.venv/lib/python3.11/site-packages/vllm/multimodal/profiling.py
ADDED
|
@@ -0,0 +1,209 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 2 |
+
|
| 3 |
+
from abc import ABC, abstractmethod
|
| 4 |
+
from collections.abc import Mapping
|
| 5 |
+
from dataclasses import dataclass, field
|
| 6 |
+
from typing import Generic, TypeVar
|
| 7 |
+
|
| 8 |
+
import numpy as np
|
| 9 |
+
import numpy.typing as npt
|
| 10 |
+
from PIL import Image
|
| 11 |
+
|
| 12 |
+
import vllm.envs as envs
|
| 13 |
+
from vllm.inputs import DummyData
|
| 14 |
+
from vllm.logger import init_logger
|
| 15 |
+
|
| 16 |
+
from .inputs import MultiModalDataDict, MultiModalInputs
|
| 17 |
+
from .processing import BaseMultiModalProcessor, BaseProcessingInfo
|
| 18 |
+
|
| 19 |
+
logger = init_logger(__name__)
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
@dataclass
|
| 23 |
+
class ProcessorInputs:
|
| 24 |
+
"""
|
| 25 |
+
Represents the keyword arguments to
|
| 26 |
+
:meth:`vllm.multimodal.processing.BaseMultiModalProcessor.apply`.
|
| 27 |
+
"""
|
| 28 |
+
prompt_text: str
|
| 29 |
+
mm_data: MultiModalDataDict
|
| 30 |
+
hf_processor_mm_kwargs: Mapping[str, object] = field(default_factory=dict)
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
_I = TypeVar("_I", bound=BaseProcessingInfo)
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
class BaseDummyInputsBuilder(ABC, Generic[_I]):
|
| 37 |
+
"""
|
| 38 |
+
Abstract base class that constructs the dummy data to profile
|
| 39 |
+
multi-modal models.
|
| 40 |
+
"""
|
| 41 |
+
|
| 42 |
+
def __init__(self, info: _I) -> None:
|
| 43 |
+
super().__init__()
|
| 44 |
+
|
| 45 |
+
self.info = info
|
| 46 |
+
|
| 47 |
+
@abstractmethod
|
| 48 |
+
def get_dummy_processor_inputs(
|
| 49 |
+
self,
|
| 50 |
+
seq_len: int,
|
| 51 |
+
mm_counts: Mapping[str, int],
|
| 52 |
+
) -> ProcessorInputs:
|
| 53 |
+
"""
|
| 54 |
+
Build the input which, after processing, results in
|
| 55 |
+
:code:`self.info.get_mm_max_tokens_per_item()` placeholder tokens.
|
| 56 |
+
"""
|
| 57 |
+
raise NotImplementedError
|
| 58 |
+
|
| 59 |
+
def _get_dummy_audios(
|
| 60 |
+
self,
|
| 61 |
+
*,
|
| 62 |
+
length: int,
|
| 63 |
+
num_audios: int,
|
| 64 |
+
) -> list[npt.NDArray]:
|
| 65 |
+
audio = np.zeros((length, ))
|
| 66 |
+
return [audio] * num_audios
|
| 67 |
+
|
| 68 |
+
def _get_dummy_images(
|
| 69 |
+
self,
|
| 70 |
+
*,
|
| 71 |
+
width: int,
|
| 72 |
+
height: int,
|
| 73 |
+
num_images: int,
|
| 74 |
+
) -> list[Image.Image]:
|
| 75 |
+
image = Image.new("RGB", (width, height), color=0)
|
| 76 |
+
return [image] * num_images
|
| 77 |
+
|
| 78 |
+
def _get_dummy_videos(
|
| 79 |
+
self,
|
| 80 |
+
*,
|
| 81 |
+
width: int,
|
| 82 |
+
height: int,
|
| 83 |
+
num_frames: int,
|
| 84 |
+
num_videos: int,
|
| 85 |
+
) -> list[npt.NDArray]:
|
| 86 |
+
video = np.zeros((num_frames, width, height, 3))
|
| 87 |
+
return [video] * num_videos
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
class MultiModalProfiler(Generic[_I]):
|
| 91 |
+
"""
|
| 92 |
+
Contains code for running memory profiling for multi-modal models.
|
| 93 |
+
"""
|
| 94 |
+
|
| 95 |
+
def __init__(
|
| 96 |
+
self,
|
| 97 |
+
processor: BaseMultiModalProcessor[_I],
|
| 98 |
+
) -> None:
|
| 99 |
+
super().__init__()
|
| 100 |
+
|
| 101 |
+
self.processor = processor
|
| 102 |
+
|
| 103 |
+
@property
|
| 104 |
+
def processing_info(self) -> BaseProcessingInfo:
|
| 105 |
+
return self.processor.info
|
| 106 |
+
|
| 107 |
+
@property
|
| 108 |
+
def dummy_inputs(self) -> BaseDummyInputsBuilder[_I]:
|
| 109 |
+
return self.processor.dummy_inputs
|
| 110 |
+
|
| 111 |
+
def get_mm_limits(self) -> Mapping[str, int]:
|
| 112 |
+
mm_config = self.processing_info.ctx.get_mm_config()
|
| 113 |
+
mm_limit_per_prompt = mm_config.limit_per_prompt
|
| 114 |
+
|
| 115 |
+
supported_mm_limits = self.processing_info.get_supported_mm_limits()
|
| 116 |
+
|
| 117 |
+
mm_limits = {
|
| 118 |
+
modality: mm_limit_per_prompt.get(modality, 1)
|
| 119 |
+
for modality in supported_mm_limits
|
| 120 |
+
}
|
| 121 |
+
|
| 122 |
+
for modality, supported_limit in supported_mm_limits.items():
|
| 123 |
+
limit = mm_limits[modality]
|
| 124 |
+
if supported_limit is not None and supported_limit < limit:
|
| 125 |
+
raise ValueError(
|
| 126 |
+
f"You set {modality}={limit} (or defaulted to 1) in "
|
| 127 |
+
f"`--limit-mm-per-prompt`, but this model only supports "
|
| 128 |
+
f"at most {supported_limit} {modality} items.")
|
| 129 |
+
|
| 130 |
+
return mm_limits
|
| 131 |
+
|
| 132 |
+
def _get_dummy_mm_inputs(
|
| 133 |
+
self,
|
| 134 |
+
seq_len: int,
|
| 135 |
+
mm_counts: Mapping[str, int],
|
| 136 |
+
) -> MultiModalInputs:
|
| 137 |
+
factory = self.dummy_inputs
|
| 138 |
+
processor_inputs = factory.get_dummy_processor_inputs(
|
| 139 |
+
seq_len, mm_counts)
|
| 140 |
+
|
| 141 |
+
return self.processor.apply(
|
| 142 |
+
prompt=processor_inputs.prompt_text,
|
| 143 |
+
mm_data=processor_inputs.mm_data,
|
| 144 |
+
hf_processor_mm_kwargs=processor_inputs.hf_processor_mm_kwargs,
|
| 145 |
+
)
|
| 146 |
+
|
| 147 |
+
def get_dummy_data(self, seq_len: int) -> DummyData:
|
| 148 |
+
# Avoid circular import
|
| 149 |
+
from vllm.sequence import SequenceData
|
| 150 |
+
|
| 151 |
+
mm_counts = self.get_mm_limits()
|
| 152 |
+
|
| 153 |
+
info = self.processing_info
|
| 154 |
+
mm_max_tokens_per_item = info.get_mm_max_tokens_per_item(
|
| 155 |
+
seq_len, mm_counts)
|
| 156 |
+
|
| 157 |
+
if mm_counts.keys() != mm_max_tokens_per_item.keys():
|
| 158 |
+
raise AssertionError(
|
| 159 |
+
"The keys returned by `get_supported_mm_limits`"
|
| 160 |
+
f"({set(mm_counts.keys())}) should be the same as those "
|
| 161 |
+
"returned by `get_mm_max_tokens_per_item` "
|
| 162 |
+
f"({set(mm_max_tokens_per_item.keys())})")
|
| 163 |
+
|
| 164 |
+
mm_inputs = self._get_dummy_mm_inputs(seq_len, mm_counts)
|
| 165 |
+
prompt_token_ids = mm_inputs["prompt_token_ids"]
|
| 166 |
+
placeholders_by_modality = mm_inputs["mm_placeholders"]
|
| 167 |
+
|
| 168 |
+
total_placeholders_by_modality = {
|
| 169 |
+
modality: sum(item["length"] for item in placeholders)
|
| 170 |
+
for modality, placeholders in placeholders_by_modality.items()
|
| 171 |
+
}
|
| 172 |
+
expected_placeholders_by_modality = {
|
| 173 |
+
modality: mm_max_tokens_per_item[modality] * mm_counts[modality]
|
| 174 |
+
for modality in placeholders_by_modality
|
| 175 |
+
}
|
| 176 |
+
if total_placeholders_by_modality != expected_placeholders_by_modality:
|
| 177 |
+
raise AssertionError(
|
| 178 |
+
f"The processed dummy data has a total of "
|
| 179 |
+
f"{total_placeholders_by_modality} placeholder tokens, which "
|
| 180 |
+
f"is not the expected {expected_placeholders_by_modality} "
|
| 181 |
+
"tokens.")
|
| 182 |
+
|
| 183 |
+
total_len = len(prompt_token_ids)
|
| 184 |
+
|
| 185 |
+
# V0 does not support chunked prefill.
|
| 186 |
+
if total_len > seq_len and not envs.VLLM_USE_V1:
|
| 187 |
+
logger.warning(
|
| 188 |
+
"The context length (%d) of the model is too short "
|
| 189 |
+
"to hold the multi-modal embeddings in the worst case "
|
| 190 |
+
"(%d tokens in total, out of which %s are reserved for "
|
| 191 |
+
"multi-modal embeddings). This may cause certain multi-modal "
|
| 192 |
+
"inputs to fail during inference, even when the input text is "
|
| 193 |
+
"short. To avoid this, you should increase `max_model_len`, "
|
| 194 |
+
"reduce `max_num_seqs`, and/or reduce `mm_counts`.", seq_len,
|
| 195 |
+
total_len, total_placeholders_by_modality)
|
| 196 |
+
|
| 197 |
+
return DummyData(
|
| 198 |
+
seq_data=SequenceData.from_prompt_token_counts((0, seq_len)),
|
| 199 |
+
multi_modal_data=None,
|
| 200 |
+
multi_modal_placeholders=None,
|
| 201 |
+
)
|
| 202 |
+
|
| 203 |
+
prompt_token_ids.extend([0] * (seq_len - len(prompt_token_ids)))
|
| 204 |
+
|
| 205 |
+
return DummyData(
|
| 206 |
+
seq_data=SequenceData.from_seqs(prompt_token_ids),
|
| 207 |
+
multi_modal_data=mm_inputs["mm_kwargs"],
|
| 208 |
+
multi_modal_placeholders=placeholders_by_modality,
|
| 209 |
+
)
|
.venv/lib/python3.11/site-packages/vllm/multimodal/registry.py
ADDED
|
@@ -0,0 +1,458 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 2 |
+
|
| 3 |
+
import functools
|
| 4 |
+
from collections import UserDict
|
| 5 |
+
from dataclasses import dataclass
|
| 6 |
+
from typing import (TYPE_CHECKING, Any, Dict, Generic, Mapping, Optional,
|
| 7 |
+
Protocol, Sequence, Type, TypeVar)
|
| 8 |
+
|
| 9 |
+
import torch.nn as nn
|
| 10 |
+
|
| 11 |
+
from vllm.inputs import InputProcessingContext
|
| 12 |
+
from vllm.logger import init_logger
|
| 13 |
+
from vllm.transformers_utils.tokenizer import AnyTokenizer
|
| 14 |
+
from vllm.utils import ClassRegistry
|
| 15 |
+
|
| 16 |
+
from .audio import AudioPlugin
|
| 17 |
+
from .base import MultiModalInputMapper, MultiModalPlugin, MultiModalTokensCalc
|
| 18 |
+
from .image import ImagePlugin
|
| 19 |
+
from .inputs import MultiModalDataDict, MultiModalKwargs, NestedTensors
|
| 20 |
+
from .processing import (BaseMultiModalProcessor, BaseProcessingInfo,
|
| 21 |
+
ProcessingCache)
|
| 22 |
+
from .profiling import BaseDummyInputsBuilder, MultiModalProfiler
|
| 23 |
+
from .utils import cached_get_tokenizer
|
| 24 |
+
from .video import VideoPlugin
|
| 25 |
+
|
| 26 |
+
if TYPE_CHECKING:
|
| 27 |
+
from vllm.config import ModelConfig
|
| 28 |
+
|
| 29 |
+
logger = init_logger(__name__)
|
| 30 |
+
|
| 31 |
+
# TODO: Tune the MM cache size
|
| 32 |
+
MM_CACHE_SIZE = 256
|
| 33 |
+
|
| 34 |
+
N = TypeVar("N", bound=Type[nn.Module])
|
| 35 |
+
_I = TypeVar("_I", bound=BaseProcessingInfo)
|
| 36 |
+
_I_co = TypeVar("_I_co", bound=BaseProcessingInfo, covariant=True)
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
class ProcessingInfoFactory(Protocol[_I_co]):
|
| 40 |
+
"""Constructs a :class:`MultiModalProcessor` instance from the context."""
|
| 41 |
+
|
| 42 |
+
def __call__(
|
| 43 |
+
self,
|
| 44 |
+
ctx: InputProcessingContext,
|
| 45 |
+
) -> _I_co:
|
| 46 |
+
...
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
class DummyInputsBuilderFactory(Protocol[_I]):
|
| 50 |
+
"""
|
| 51 |
+
Constructs a :class:`BaseDummyInputsBuilder` instance from the context.
|
| 52 |
+
"""
|
| 53 |
+
|
| 54 |
+
def __call__(self, info: _I) -> BaseDummyInputsBuilder[_I]:
|
| 55 |
+
...
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
class MultiModalProcessorFactory(Protocol[_I]):
|
| 59 |
+
"""Constructs a :class:`MultiModalProcessor` instance from the context."""
|
| 60 |
+
|
| 61 |
+
def __call__(
|
| 62 |
+
self,
|
| 63 |
+
info: _I,
|
| 64 |
+
dummy_inputs: BaseDummyInputsBuilder[_I],
|
| 65 |
+
*,
|
| 66 |
+
cache: Optional[ProcessingCache] = None,
|
| 67 |
+
) -> BaseMultiModalProcessor[_I]:
|
| 68 |
+
...
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
@dataclass(frozen=True)
|
| 72 |
+
class _ProcessorFactories(Generic[_I]):
|
| 73 |
+
info: ProcessingInfoFactory[_I]
|
| 74 |
+
processor: MultiModalProcessorFactory[_I]
|
| 75 |
+
dummy_inputs: DummyInputsBuilderFactory[_I]
|
| 76 |
+
|
| 77 |
+
def build_processor(
|
| 78 |
+
self,
|
| 79 |
+
ctx: InputProcessingContext,
|
| 80 |
+
*,
|
| 81 |
+
cache: Optional[ProcessingCache] = None,
|
| 82 |
+
):
|
| 83 |
+
info = self.info(ctx)
|
| 84 |
+
dummy_inputs_builder = self.dummy_inputs(info)
|
| 85 |
+
return self.processor(info, dummy_inputs_builder, cache=cache)
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
class _MultiModalLimits(UserDict["ModelConfig", Dict[str, int]]):
|
| 89 |
+
"""
|
| 90 |
+
Wraps `_limits_by_model` for a more informative error message
|
| 91 |
+
when attempting to access a model that does not exist.
|
| 92 |
+
"""
|
| 93 |
+
|
| 94 |
+
def __getitem__(self, key: "ModelConfig") -> Dict[str, int]:
|
| 95 |
+
try:
|
| 96 |
+
return super().__getitem__(key)
|
| 97 |
+
except KeyError as exc:
|
| 98 |
+
msg = (f"Cannot find `mm_limits` for model={key.model}. Did you "
|
| 99 |
+
"forget to call `init_mm_limits_per_prompt`?")
|
| 100 |
+
raise KeyError(msg) from exc
|
| 101 |
+
|
| 102 |
+
|
| 103 |
+
class MultiModalRegistry:
|
| 104 |
+
"""
|
| 105 |
+
A registry that dispatches data processing according to the model.
|
| 106 |
+
"""
|
| 107 |
+
|
| 108 |
+
DEFAULT_PLUGINS = (ImagePlugin(), AudioPlugin(), VideoPlugin())
|
| 109 |
+
|
| 110 |
+
def __init__(
|
| 111 |
+
self,
|
| 112 |
+
*,
|
| 113 |
+
plugins: Sequence[MultiModalPlugin] = DEFAULT_PLUGINS) -> None:
|
| 114 |
+
self._plugins = {p.get_data_key(): p for p in plugins}
|
| 115 |
+
|
| 116 |
+
self._processor_factories = ClassRegistry[nn.Module,
|
| 117 |
+
_ProcessorFactories]()
|
| 118 |
+
|
| 119 |
+
# This is used for non-multimodal models
|
| 120 |
+
self._disabled_limits_per_plugin = {k: 0 for k in self._plugins}
|
| 121 |
+
|
| 122 |
+
self._limits_by_model = _MultiModalLimits()
|
| 123 |
+
|
| 124 |
+
self._processing_cache = ProcessingCache(MM_CACHE_SIZE)
|
| 125 |
+
|
| 126 |
+
def register_plugin(self, plugin: MultiModalPlugin) -> None:
|
| 127 |
+
"""
|
| 128 |
+
Register a multi-modal plugin so it can be recognized by vLLM.
|
| 129 |
+
"""
|
| 130 |
+
data_type_key = plugin.get_data_key()
|
| 131 |
+
|
| 132 |
+
if data_type_key in self._plugins:
|
| 133 |
+
logger.warning(
|
| 134 |
+
"A plugin is already registered for data type %s, "
|
| 135 |
+
"and will be overwritten by the new plugin %s.", data_type_key,
|
| 136 |
+
plugin)
|
| 137 |
+
|
| 138 |
+
self._plugins[data_type_key] = plugin
|
| 139 |
+
|
| 140 |
+
def _get_plugin(self, data_type_key: str):
|
| 141 |
+
plugin = self._plugins.get(data_type_key)
|
| 142 |
+
if plugin is not None:
|
| 143 |
+
return plugin
|
| 144 |
+
|
| 145 |
+
msg = f"Unknown multi-modal data type: {data_type_key}"
|
| 146 |
+
raise NotImplementedError(msg)
|
| 147 |
+
|
| 148 |
+
def register_input_mapper(
|
| 149 |
+
self,
|
| 150 |
+
data_type_key: str,
|
| 151 |
+
mapper: Optional[MultiModalInputMapper] = None,
|
| 152 |
+
):
|
| 153 |
+
"""
|
| 154 |
+
Register an input mapper for a specific modality to a model class.
|
| 155 |
+
|
| 156 |
+
See :meth:`MultiModalPlugin.register_input_mapper` for more details.
|
| 157 |
+
"""
|
| 158 |
+
return self._get_plugin(data_type_key).register_input_mapper(mapper)
|
| 159 |
+
|
| 160 |
+
def register_image_input_mapper(
|
| 161 |
+
self,
|
| 162 |
+
mapper: Optional[MultiModalInputMapper] = None,
|
| 163 |
+
):
|
| 164 |
+
"""
|
| 165 |
+
Register an input mapper for image data to a model class.
|
| 166 |
+
|
| 167 |
+
See :meth:`MultiModalPlugin.register_input_mapper` for more details.
|
| 168 |
+
"""
|
| 169 |
+
return self.register_input_mapper("image", mapper)
|
| 170 |
+
|
| 171 |
+
def map_input(
|
| 172 |
+
self,
|
| 173 |
+
model_config: "ModelConfig",
|
| 174 |
+
data: MultiModalDataDict,
|
| 175 |
+
mm_processor_kwargs: Optional[Dict[str, Any]] = None,
|
| 176 |
+
) -> MultiModalKwargs:
|
| 177 |
+
"""
|
| 178 |
+
Apply an input mapper to the data passed to the model.
|
| 179 |
+
|
| 180 |
+
The data belonging to each modality is passed to the corresponding
|
| 181 |
+
plugin which in turn converts the data into into keyword arguments
|
| 182 |
+
via the input mapper registered for that model.
|
| 183 |
+
|
| 184 |
+
See :meth:`MultiModalPlugin.map_input` for more details.
|
| 185 |
+
|
| 186 |
+
Note:
|
| 187 |
+
This should be called after :meth:`init_mm_limits_per_prompt`.
|
| 188 |
+
"""
|
| 189 |
+
merged_dict: Dict[str, NestedTensors] = {}
|
| 190 |
+
|
| 191 |
+
for data_key, data_value in data.items():
|
| 192 |
+
plugin = self._get_plugin(data_key)
|
| 193 |
+
|
| 194 |
+
num_items = len(data_value) if isinstance(data_value, list) else 1
|
| 195 |
+
max_items = self._limits_by_model[model_config][data_key]
|
| 196 |
+
if num_items > max_items:
|
| 197 |
+
raise ValueError(
|
| 198 |
+
f"You set {data_key}={max_items} (or defaulted to 1) in "
|
| 199 |
+
f"`--limit-mm-per-prompt`, but found {num_items} items "
|
| 200 |
+
"in the same prompt.")
|
| 201 |
+
|
| 202 |
+
input_dict = plugin.map_input(model_config, data_value,
|
| 203 |
+
mm_processor_kwargs)
|
| 204 |
+
for input_key, input_tensor in input_dict.items():
|
| 205 |
+
if input_key in merged_dict:
|
| 206 |
+
raise ValueError(f"The input mappers (keys={set(data)}) "
|
| 207 |
+
f"resulted in a conflicting keyword "
|
| 208 |
+
f"argument to `forward()`: {input_key}")
|
| 209 |
+
|
| 210 |
+
merged_dict[input_key] = input_tensor
|
| 211 |
+
|
| 212 |
+
return MultiModalKwargs(merged_dict)
|
| 213 |
+
|
| 214 |
+
def create_input_mapper(self, model_config: "ModelConfig"):
|
| 215 |
+
"""
|
| 216 |
+
Create an input mapper (see :meth:`map_input`) for a specific model.
|
| 217 |
+
"""
|
| 218 |
+
# NOTE - we currently make the assumption that if a model has multiple
|
| 219 |
+
# supported modalities, they take the same kwargs. For the default,
|
| 220 |
+
# this could be an issue in the future if it falls back to two HF
|
| 221 |
+
# resources and we can't inspect the signature easily since it's
|
| 222 |
+
# getting initialized through the autoclass.
|
| 223 |
+
#
|
| 224 |
+
# If this is a problem in the future, we should revisit it, but since
|
| 225 |
+
# it potentially introduces a lot of complexity for a currently
|
| 226 |
+
# uncommon case, we do not for simplicity of both use & implementation
|
| 227 |
+
return functools.partial(self.map_input, model_config)
|
| 228 |
+
|
| 229 |
+
def register_max_multimodal_tokens(
|
| 230 |
+
self,
|
| 231 |
+
data_type_key: str,
|
| 232 |
+
max_mm_tokens: Optional[MultiModalTokensCalc] = None,
|
| 233 |
+
):
|
| 234 |
+
"""
|
| 235 |
+
Register the maximum number of tokens, corresponding to a single
|
| 236 |
+
instance of multimodal data belonging to a specific modality, that are
|
| 237 |
+
passed to the language model for a model class.
|
| 238 |
+
"""
|
| 239 |
+
return self._get_plugin(data_type_key) \
|
| 240 |
+
.register_max_multimodal_tokens(max_mm_tokens)
|
| 241 |
+
|
| 242 |
+
def register_max_image_tokens(
|
| 243 |
+
self,
|
| 244 |
+
max_mm_tokens: Optional[MultiModalTokensCalc] = None,
|
| 245 |
+
):
|
| 246 |
+
"""
|
| 247 |
+
Register the maximum number of image tokens, corresponding to a single
|
| 248 |
+
image, that are passed to the language model for a model class.
|
| 249 |
+
"""
|
| 250 |
+
return self.register_max_multimodal_tokens("image", max_mm_tokens)
|
| 251 |
+
|
| 252 |
+
def get_max_tokens_per_item_by_modality(
|
| 253 |
+
self,
|
| 254 |
+
model_config: "ModelConfig",
|
| 255 |
+
) -> Mapping[str, int]:
|
| 256 |
+
"""
|
| 257 |
+
Get the maximum number of tokens per data item from each modality based
|
| 258 |
+
on underlying model configuration.
|
| 259 |
+
"""
|
| 260 |
+
if self.has_processor(model_config):
|
| 261 |
+
tokenizer = cached_get_tokenizer(
|
| 262 |
+
model_config.tokenizer,
|
| 263 |
+
trust_remote_code=model_config.trust_remote_code,
|
| 264 |
+
)
|
| 265 |
+
processor = self.create_processor(model_config, tokenizer)
|
| 266 |
+
seq_len = model_config.max_model_len
|
| 267 |
+
mm_limits = self.get_mm_limits_per_prompt(model_config)
|
| 268 |
+
return processor.info.get_mm_max_tokens_per_item(
|
| 269 |
+
seq_len, mm_limits)
|
| 270 |
+
|
| 271 |
+
return {
|
| 272 |
+
key: plugin.get_max_multimodal_tokens(model_config)
|
| 273 |
+
for key, plugin in self._plugins.items()
|
| 274 |
+
}
|
| 275 |
+
|
| 276 |
+
def get_max_tokens_per_item_by_nonzero_modality(
|
| 277 |
+
self,
|
| 278 |
+
model_config: "ModelConfig",
|
| 279 |
+
) -> Mapping[str, int]:
|
| 280 |
+
"""
|
| 281 |
+
Get the maximum number of tokens per data item from each modality based
|
| 282 |
+
on underlying model configuration, excluding modalities that user
|
| 283 |
+
explicitly disabled via `limit_mm_per_prompt`.
|
| 284 |
+
|
| 285 |
+
Note:
|
| 286 |
+
This is currently directly used only in V1 for profiling the memory
|
| 287 |
+
usage of a model.
|
| 288 |
+
"""
|
| 289 |
+
mm_limits = self.get_mm_limits_per_prompt(model_config)
|
| 290 |
+
|
| 291 |
+
return {
|
| 292 |
+
key: max_tokens_per_mm_item
|
| 293 |
+
for key, max_tokens_per_mm_item in
|
| 294 |
+
self.get_max_tokens_per_item_by_modality(model_config).items()
|
| 295 |
+
if mm_limits[key] > 0
|
| 296 |
+
}
|
| 297 |
+
|
| 298 |
+
def get_max_tokens_by_modality(
|
| 299 |
+
self,
|
| 300 |
+
model_config: "ModelConfig",
|
| 301 |
+
) -> Mapping[str, int]:
|
| 302 |
+
"""
|
| 303 |
+
Get the maximum number of tokens from each modality
|
| 304 |
+
for profiling the memory usage of a model.
|
| 305 |
+
|
| 306 |
+
See :meth:`MultiModalPlugin.get_max_multimodal_tokens` for more details.
|
| 307 |
+
|
| 308 |
+
Note:
|
| 309 |
+
This should be called after :meth:`init_mm_limits_per_prompt`.
|
| 310 |
+
"""
|
| 311 |
+
mm_limits = self.get_mm_limits_per_prompt(model_config)
|
| 312 |
+
|
| 313 |
+
return {
|
| 314 |
+
key: mm_limits[key] * max_tokens_per_mm_item
|
| 315 |
+
for key, max_tokens_per_mm_item in
|
| 316 |
+
self.get_max_tokens_per_item_by_modality(model_config).items()
|
| 317 |
+
}
|
| 318 |
+
|
| 319 |
+
def get_max_multimodal_tokens(self, model_config: "ModelConfig") -> int:
|
| 320 |
+
"""
|
| 321 |
+
Get the maximum number of multi-modal tokens
|
| 322 |
+
for profiling the memory usage of a model.
|
| 323 |
+
|
| 324 |
+
See :meth:`MultiModalPlugin.get_max_multimodal_tokens` for more details.
|
| 325 |
+
|
| 326 |
+
Note:
|
| 327 |
+
This should be called after :meth:`init_mm_limits_per_prompt`.
|
| 328 |
+
"""
|
| 329 |
+
return sum(self.get_max_tokens_by_modality(model_config).values())
|
| 330 |
+
|
| 331 |
+
def init_mm_limits_per_prompt(
|
| 332 |
+
self,
|
| 333 |
+
model_config: "ModelConfig",
|
| 334 |
+
) -> None:
|
| 335 |
+
"""
|
| 336 |
+
Initialize the maximum number of multi-modal input instances for each
|
| 337 |
+
modality that are allowed per prompt for a model class.
|
| 338 |
+
"""
|
| 339 |
+
if model_config in self._limits_by_model:
|
| 340 |
+
logger.warning(
|
| 341 |
+
"`mm_limits` has already been set for model=%s, and will "
|
| 342 |
+
"be overwritten by the new values.", model_config.model)
|
| 343 |
+
|
| 344 |
+
multimodal_config = model_config.multimodal_config
|
| 345 |
+
if multimodal_config is None:
|
| 346 |
+
limits_per_plugin = self._disabled_limits_per_plugin
|
| 347 |
+
else:
|
| 348 |
+
config_limits_per_plugin = multimodal_config.limit_per_prompt
|
| 349 |
+
|
| 350 |
+
extra_keys = config_limits_per_plugin.keys() - self._plugins.keys()
|
| 351 |
+
if extra_keys:
|
| 352 |
+
logger.warning(
|
| 353 |
+
"Detected extra keys in `--limit-mm-per-prompt` which "
|
| 354 |
+
"are not registered as multi-modal plugins: %s. "
|
| 355 |
+
"They will be ignored.", extra_keys)
|
| 356 |
+
|
| 357 |
+
# NOTE: Currently the default is set to 1 for each plugin
|
| 358 |
+
# TODO: Automatically determine the limits based on budget
|
| 359 |
+
# once more models support multi-image inputs
|
| 360 |
+
limits_per_plugin = {
|
| 361 |
+
key: config_limits_per_plugin.get(key, 1)
|
| 362 |
+
for key in self._plugins
|
| 363 |
+
}
|
| 364 |
+
|
| 365 |
+
self._limits_by_model[model_config] = limits_per_plugin
|
| 366 |
+
|
| 367 |
+
def get_mm_limits_per_prompt(
|
| 368 |
+
self,
|
| 369 |
+
model_config: "ModelConfig",
|
| 370 |
+
) -> Mapping[str, int]:
|
| 371 |
+
"""
|
| 372 |
+
Get the maximum number of multi-modal input instances for each modality
|
| 373 |
+
that are allowed per prompt for a model class.
|
| 374 |
+
|
| 375 |
+
Note:
|
| 376 |
+
This should be called after :meth:`init_mm_limits_per_prompt`.
|
| 377 |
+
"""
|
| 378 |
+
if self.has_processor(model_config):
|
| 379 |
+
tokenizer = cached_get_tokenizer(
|
| 380 |
+
model_config.tokenizer,
|
| 381 |
+
trust_remote_code=model_config.trust_remote_code,
|
| 382 |
+
)
|
| 383 |
+
processor = self.create_processor(model_config, tokenizer)
|
| 384 |
+
profiler = MultiModalProfiler(processor)
|
| 385 |
+
return profiler.get_mm_limits()
|
| 386 |
+
|
| 387 |
+
return self._limits_by_model[model_config]
|
| 388 |
+
|
| 389 |
+
def register_processor(
|
| 390 |
+
self,
|
| 391 |
+
processor: MultiModalProcessorFactory[_I],
|
| 392 |
+
*,
|
| 393 |
+
info: ProcessingInfoFactory[_I],
|
| 394 |
+
dummy_inputs: DummyInputsBuilderFactory[_I],
|
| 395 |
+
):
|
| 396 |
+
"""
|
| 397 |
+
Register a multi-modal processor to a model class. The processor
|
| 398 |
+
is constructed lazily, hence a factory method should be passed.
|
| 399 |
+
|
| 400 |
+
When the model receives multi-modal data, the provided function is
|
| 401 |
+
invoked to transform the data into a dictionary of model inputs.
|
| 402 |
+
|
| 403 |
+
See also:
|
| 404 |
+
:ref:`mm-processing`
|
| 405 |
+
"""
|
| 406 |
+
|
| 407 |
+
def wrapper(model_cls: N) -> N:
|
| 408 |
+
if self._processor_factories.contains(model_cls, strict=True):
|
| 409 |
+
logger.warning(
|
| 410 |
+
"Model class %s already has a multi-modal processor "
|
| 411 |
+
"registered to %s. It is overwritten by the new one.",
|
| 412 |
+
model_cls, self)
|
| 413 |
+
|
| 414 |
+
self._processor_factories[model_cls] = _ProcessorFactories(
|
| 415 |
+
info=info,
|
| 416 |
+
dummy_inputs=dummy_inputs,
|
| 417 |
+
processor=processor,
|
| 418 |
+
)
|
| 419 |
+
|
| 420 |
+
return model_cls
|
| 421 |
+
|
| 422 |
+
return wrapper
|
| 423 |
+
|
| 424 |
+
def _get_model_cls(self, model_config: "ModelConfig"):
|
| 425 |
+
# Avoid circular import
|
| 426 |
+
from vllm.model_executor.model_loader import get_model_architecture
|
| 427 |
+
|
| 428 |
+
model_cls, _ = get_model_architecture(model_config)
|
| 429 |
+
return model_cls
|
| 430 |
+
|
| 431 |
+
def has_processor(self, model_config: "ModelConfig") -> bool:
|
| 432 |
+
"""
|
| 433 |
+
Test whether a multi-modal processor is defined for a specific model.
|
| 434 |
+
|
| 435 |
+
See also:
|
| 436 |
+
:ref:`mm-processing`
|
| 437 |
+
"""
|
| 438 |
+
return self._get_model_cls(model_config) in self._processor_factories
|
| 439 |
+
|
| 440 |
+
def create_processor(
|
| 441 |
+
self,
|
| 442 |
+
model_config: "ModelConfig",
|
| 443 |
+
tokenizer: AnyTokenizer,
|
| 444 |
+
) -> BaseMultiModalProcessor[BaseProcessingInfo]:
|
| 445 |
+
"""
|
| 446 |
+
Create a multi-modal processor for a specific model and tokenizer.
|
| 447 |
+
|
| 448 |
+
See also:
|
| 449 |
+
:ref:`mm-processing`
|
| 450 |
+
"""
|
| 451 |
+
model_cls = self._get_model_cls(model_config)
|
| 452 |
+
factories = self._processor_factories[model_cls]
|
| 453 |
+
|
| 454 |
+
ctx = InputProcessingContext(model_config, tokenizer)
|
| 455 |
+
cache = (None if model_config.disable_mm_preprocessor_cache else
|
| 456 |
+
self._processing_cache)
|
| 457 |
+
|
| 458 |
+
return factories.build_processor(ctx, cache=cache)
|
.venv/lib/python3.11/site-packages/vllm/multimodal/utils.py
ADDED
|
@@ -0,0 +1,518 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 2 |
+
|
| 3 |
+
from functools import lru_cache
|
| 4 |
+
from itertools import groupby
|
| 5 |
+
from pathlib import Path
|
| 6 |
+
from typing import TYPE_CHECKING, Optional, TypeVar, Union
|
| 7 |
+
from urllib.parse import ParseResult, urlparse
|
| 8 |
+
|
| 9 |
+
import numpy as np
|
| 10 |
+
import numpy.typing as npt
|
| 11 |
+
from PIL import Image
|
| 12 |
+
|
| 13 |
+
import vllm.envs as envs
|
| 14 |
+
from vllm.connections import HTTPConnection, global_http_connection
|
| 15 |
+
from vllm.logger import init_logger
|
| 16 |
+
from vllm.transformers_utils.tokenizer import AnyTokenizer, get_tokenizer
|
| 17 |
+
|
| 18 |
+
from .audio import AudioMediaIO
|
| 19 |
+
from .base import MediaIO
|
| 20 |
+
from .image import ImageMediaIO
|
| 21 |
+
from .inputs import PlaceholderRange
|
| 22 |
+
from .video import VideoMediaIO
|
| 23 |
+
|
| 24 |
+
logger = init_logger(__name__)
|
| 25 |
+
|
| 26 |
+
cached_get_tokenizer = lru_cache(get_tokenizer)
|
| 27 |
+
|
| 28 |
+
_M = TypeVar("_M")
|
| 29 |
+
|
| 30 |
+
if TYPE_CHECKING:
|
| 31 |
+
from .hasher import MultiModalHashDict
|
| 32 |
+
from .inputs import MultiModalKwargs, MultiModalPlaceholderDict
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
class MediaConnector:
|
| 36 |
+
|
| 37 |
+
def __init__(
|
| 38 |
+
self,
|
| 39 |
+
connection: HTTPConnection = global_http_connection,
|
| 40 |
+
*,
|
| 41 |
+
allowed_local_media_path: str = "",
|
| 42 |
+
) -> None:
|
| 43 |
+
super().__init__()
|
| 44 |
+
|
| 45 |
+
self.connection = connection
|
| 46 |
+
|
| 47 |
+
if allowed_local_media_path:
|
| 48 |
+
allowed_local_media_path_ = Path(allowed_local_media_path)
|
| 49 |
+
|
| 50 |
+
if not allowed_local_media_path_.exists():
|
| 51 |
+
raise ValueError(
|
| 52 |
+
"Invalid `--allowed-local-media-path`: The path "
|
| 53 |
+
f"{allowed_local_media_path_} does not exist.")
|
| 54 |
+
if not allowed_local_media_path_.is_dir():
|
| 55 |
+
raise ValueError(
|
| 56 |
+
"Invalid `--allowed-local-media-path`: The path "
|
| 57 |
+
f"{allowed_local_media_path_} must be a directory.")
|
| 58 |
+
else:
|
| 59 |
+
allowed_local_media_path_ = None
|
| 60 |
+
|
| 61 |
+
self.allowed_local_media_path = allowed_local_media_path_
|
| 62 |
+
|
| 63 |
+
def _load_data_url(
|
| 64 |
+
self,
|
| 65 |
+
url_spec: ParseResult,
|
| 66 |
+
media_io: MediaIO[_M],
|
| 67 |
+
) -> _M:
|
| 68 |
+
data_spec, data = url_spec.path.split(",", 1)
|
| 69 |
+
media_type, data_type = data_spec.split(";", 1)
|
| 70 |
+
|
| 71 |
+
if data_type != "base64":
|
| 72 |
+
msg = "Only base64 data URLs are supported for now."
|
| 73 |
+
raise NotImplementedError(msg)
|
| 74 |
+
|
| 75 |
+
return media_io.load_base64(media_type, data)
|
| 76 |
+
|
| 77 |
+
def _load_file_url(
|
| 78 |
+
self,
|
| 79 |
+
url_spec: ParseResult,
|
| 80 |
+
media_io: MediaIO[_M],
|
| 81 |
+
) -> _M:
|
| 82 |
+
allowed_local_media_path = self.allowed_local_media_path
|
| 83 |
+
if allowed_local_media_path is None:
|
| 84 |
+
raise RuntimeError("Cannot load local files without "
|
| 85 |
+
"`--allowed-local-media-path`.")
|
| 86 |
+
|
| 87 |
+
filepath = Path(url_spec.path)
|
| 88 |
+
if allowed_local_media_path not in filepath.resolve().parents:
|
| 89 |
+
raise ValueError(
|
| 90 |
+
f"The file path {filepath} must be a subpath "
|
| 91 |
+
f"of `--allowed-local-media-path` {allowed_local_media_path}.")
|
| 92 |
+
|
| 93 |
+
return media_io.load_file(filepath)
|
| 94 |
+
|
| 95 |
+
def load_from_url(
|
| 96 |
+
self,
|
| 97 |
+
url: str,
|
| 98 |
+
media_io: MediaIO[_M],
|
| 99 |
+
*,
|
| 100 |
+
fetch_timeout: Optional[int] = None,
|
| 101 |
+
) -> _M:
|
| 102 |
+
url_spec = urlparse(url)
|
| 103 |
+
|
| 104 |
+
if url_spec.scheme.startswith("http"):
|
| 105 |
+
connection = self.connection
|
| 106 |
+
data = connection.get_bytes(url, timeout=fetch_timeout)
|
| 107 |
+
|
| 108 |
+
return media_io.load_bytes(data)
|
| 109 |
+
|
| 110 |
+
if url_spec.scheme == "data":
|
| 111 |
+
return self._load_data_url(url_spec, media_io)
|
| 112 |
+
|
| 113 |
+
if url_spec.scheme == "file":
|
| 114 |
+
return self._load_file_url(url_spec, media_io)
|
| 115 |
+
|
| 116 |
+
msg = "The URL must be either a HTTP, data or file URL."
|
| 117 |
+
raise ValueError(msg)
|
| 118 |
+
|
| 119 |
+
async def load_from_url_async(
|
| 120 |
+
self,
|
| 121 |
+
url: str,
|
| 122 |
+
media_io: MediaIO[_M],
|
| 123 |
+
*,
|
| 124 |
+
fetch_timeout: Optional[int] = None,
|
| 125 |
+
) -> _M:
|
| 126 |
+
url_spec = urlparse(url)
|
| 127 |
+
|
| 128 |
+
if url_spec.scheme.startswith("http"):
|
| 129 |
+
connection = self.connection
|
| 130 |
+
data = await connection.async_get_bytes(url, timeout=fetch_timeout)
|
| 131 |
+
|
| 132 |
+
return media_io.load_bytes(data)
|
| 133 |
+
|
| 134 |
+
if url_spec.scheme == "data":
|
| 135 |
+
return self._load_data_url(url_spec, media_io)
|
| 136 |
+
|
| 137 |
+
if url_spec.scheme == "file":
|
| 138 |
+
return self._load_file_url(url_spec, media_io)
|
| 139 |
+
|
| 140 |
+
msg = "The URL must be either a HTTP, data or file URL."
|
| 141 |
+
raise ValueError(msg)
|
| 142 |
+
|
| 143 |
+
def fetch_audio(
|
| 144 |
+
self,
|
| 145 |
+
audio_url: str,
|
| 146 |
+
) -> tuple[np.ndarray, Union[int, float]]:
|
| 147 |
+
"""
|
| 148 |
+
Load audio from a URL.
|
| 149 |
+
"""
|
| 150 |
+
audio_io = AudioMediaIO()
|
| 151 |
+
|
| 152 |
+
return self.load_from_url(
|
| 153 |
+
audio_url,
|
| 154 |
+
audio_io,
|
| 155 |
+
fetch_timeout=envs.VLLM_AUDIO_FETCH_TIMEOUT,
|
| 156 |
+
)
|
| 157 |
+
|
| 158 |
+
async def fetch_audio_async(
|
| 159 |
+
self,
|
| 160 |
+
audio_url: str,
|
| 161 |
+
) -> tuple[np.ndarray, Union[int, float]]:
|
| 162 |
+
"""
|
| 163 |
+
Asynchronously fetch audio from a URL.
|
| 164 |
+
"""
|
| 165 |
+
audio_io = AudioMediaIO()
|
| 166 |
+
|
| 167 |
+
return await self.load_from_url_async(
|
| 168 |
+
audio_url,
|
| 169 |
+
audio_io,
|
| 170 |
+
fetch_timeout=envs.VLLM_AUDIO_FETCH_TIMEOUT,
|
| 171 |
+
)
|
| 172 |
+
|
| 173 |
+
def fetch_image(
|
| 174 |
+
self,
|
| 175 |
+
image_url: str,
|
| 176 |
+
*,
|
| 177 |
+
image_mode: str = "RGB",
|
| 178 |
+
) -> Image.Image:
|
| 179 |
+
"""
|
| 180 |
+
Load a PIL image from a HTTP or base64 data URL.
|
| 181 |
+
|
| 182 |
+
By default, the image is converted into RGB format.
|
| 183 |
+
"""
|
| 184 |
+
image_io = ImageMediaIO(image_mode=image_mode)
|
| 185 |
+
|
| 186 |
+
return self.load_from_url(
|
| 187 |
+
image_url,
|
| 188 |
+
image_io,
|
| 189 |
+
fetch_timeout=envs.VLLM_IMAGE_FETCH_TIMEOUT,
|
| 190 |
+
)
|
| 191 |
+
|
| 192 |
+
async def fetch_image_async(
|
| 193 |
+
self,
|
| 194 |
+
image_url: str,
|
| 195 |
+
*,
|
| 196 |
+
image_mode: str = "RGB",
|
| 197 |
+
) -> Image.Image:
|
| 198 |
+
"""
|
| 199 |
+
Asynchronously load a PIL image from a HTTP or base64 data URL.
|
| 200 |
+
|
| 201 |
+
By default, the image is converted into RGB format.
|
| 202 |
+
"""
|
| 203 |
+
image_io = ImageMediaIO(image_mode=image_mode)
|
| 204 |
+
|
| 205 |
+
return await self.load_from_url_async(
|
| 206 |
+
image_url,
|
| 207 |
+
image_io,
|
| 208 |
+
fetch_timeout=envs.VLLM_IMAGE_FETCH_TIMEOUT,
|
| 209 |
+
)
|
| 210 |
+
|
| 211 |
+
def fetch_video(
|
| 212 |
+
self,
|
| 213 |
+
video_url: str,
|
| 214 |
+
*,
|
| 215 |
+
image_mode: str = "RGB",
|
| 216 |
+
num_frames: int = 32,
|
| 217 |
+
) -> npt.NDArray:
|
| 218 |
+
"""
|
| 219 |
+
Load video from a HTTP or base64 data URL.
|
| 220 |
+
"""
|
| 221 |
+
image_io = ImageMediaIO(image_mode=image_mode)
|
| 222 |
+
video_io = VideoMediaIO(image_io, num_frames=num_frames)
|
| 223 |
+
|
| 224 |
+
return self.load_from_url(
|
| 225 |
+
video_url,
|
| 226 |
+
video_io,
|
| 227 |
+
fetch_timeout=envs.VLLM_VIDEO_FETCH_TIMEOUT,
|
| 228 |
+
)
|
| 229 |
+
|
| 230 |
+
async def fetch_video_async(
|
| 231 |
+
self,
|
| 232 |
+
video_url: str,
|
| 233 |
+
*,
|
| 234 |
+
image_mode: str = "RGB",
|
| 235 |
+
num_frames: int = 32,
|
| 236 |
+
) -> npt.NDArray:
|
| 237 |
+
"""
|
| 238 |
+
Asynchronously load video from a HTTP or base64 data URL.
|
| 239 |
+
|
| 240 |
+
By default, the image is converted into RGB format.
|
| 241 |
+
"""
|
| 242 |
+
image_io = ImageMediaIO(image_mode=image_mode)
|
| 243 |
+
video_io = VideoMediaIO(image_io, num_frames=num_frames)
|
| 244 |
+
|
| 245 |
+
return await self.load_from_url_async(
|
| 246 |
+
video_url,
|
| 247 |
+
video_io,
|
| 248 |
+
fetch_timeout=envs.VLLM_VIDEO_FETCH_TIMEOUT,
|
| 249 |
+
)
|
| 250 |
+
|
| 251 |
+
|
| 252 |
+
global_media_connector = MediaConnector()
|
| 253 |
+
"""The global :class:`MediaConnector` instance used by vLLM."""
|
| 254 |
+
|
| 255 |
+
fetch_audio = global_media_connector.fetch_audio
|
| 256 |
+
fetch_image = global_media_connector.fetch_image
|
| 257 |
+
fetch_video = global_media_connector.fetch_video
|
| 258 |
+
|
| 259 |
+
|
| 260 |
+
def encode_audio_base64(
|
| 261 |
+
audio: np.ndarray,
|
| 262 |
+
sampling_rate: int,
|
| 263 |
+
) -> str:
|
| 264 |
+
"""Encode audio as base64."""
|
| 265 |
+
audio_io = AudioMediaIO()
|
| 266 |
+
return audio_io.encode_base64((audio, sampling_rate))
|
| 267 |
+
|
| 268 |
+
|
| 269 |
+
def encode_image_base64(
|
| 270 |
+
image: Image.Image,
|
| 271 |
+
*,
|
| 272 |
+
image_mode: str = "RGB",
|
| 273 |
+
format: str = "JPEG",
|
| 274 |
+
) -> str:
|
| 275 |
+
"""
|
| 276 |
+
Encode a pillow image to base64 format.
|
| 277 |
+
|
| 278 |
+
By default, the image is converted into RGB format before being encoded.
|
| 279 |
+
"""
|
| 280 |
+
image_io = ImageMediaIO(image_mode=image_mode)
|
| 281 |
+
return image_io.encode_base64(image, image_format=format)
|
| 282 |
+
|
| 283 |
+
|
| 284 |
+
def encode_video_base64(frames: npt.NDArray) -> str:
|
| 285 |
+
image_io = ImageMediaIO()
|
| 286 |
+
video_io = VideoMediaIO(image_io)
|
| 287 |
+
return video_io.encode_base64(frames)
|
| 288 |
+
|
| 289 |
+
|
| 290 |
+
# Utilities for input processors
|
| 291 |
+
_T = TypeVar("_T", str, int)
|
| 292 |
+
|
| 293 |
+
|
| 294 |
+
def repeat_and_pad_token(
|
| 295 |
+
token: _T,
|
| 296 |
+
*,
|
| 297 |
+
repeat_count: int = 1,
|
| 298 |
+
pad_token_left: Optional[_T] = None,
|
| 299 |
+
pad_token_right: Optional[_T] = None,
|
| 300 |
+
) -> list[_T]:
|
| 301 |
+
replacement = [token] * repeat_count
|
| 302 |
+
if pad_token_left is not None:
|
| 303 |
+
replacement = [pad_token_left] + replacement
|
| 304 |
+
if pad_token_right is not None:
|
| 305 |
+
replacement = replacement + [pad_token_right]
|
| 306 |
+
|
| 307 |
+
return replacement
|
| 308 |
+
|
| 309 |
+
|
| 310 |
+
def repeat_and_pad_placeholder_tokens(
|
| 311 |
+
tokenizer: AnyTokenizer,
|
| 312 |
+
prompt: Optional[str],
|
| 313 |
+
prompt_token_ids: list[int],
|
| 314 |
+
*,
|
| 315 |
+
placeholder_token_id: int,
|
| 316 |
+
repeat_count: Union[int, list[int]],
|
| 317 |
+
pad_token_left: Optional[int] = None,
|
| 318 |
+
pad_token_right: Optional[int] = None,
|
| 319 |
+
) -> tuple[Optional[str], list[int], list[PlaceholderRange]]:
|
| 320 |
+
if isinstance(repeat_count, int):
|
| 321 |
+
repeat_count = [repeat_count]
|
| 322 |
+
|
| 323 |
+
if prompt is None:
|
| 324 |
+
new_prompt = None
|
| 325 |
+
else:
|
| 326 |
+
placeholder_token_str = tokenizer.decode(placeholder_token_id)
|
| 327 |
+
pad_token_str_left = (None if pad_token_left is None else
|
| 328 |
+
tokenizer.decode(pad_token_left))
|
| 329 |
+
pad_token_str_right = (None if pad_token_right is None else
|
| 330 |
+
tokenizer.decode(pad_token_right))
|
| 331 |
+
|
| 332 |
+
placeholder_token_count = prompt.count(placeholder_token_str)
|
| 333 |
+
# This is an arbitrary number to distinguish between the two cases
|
| 334 |
+
if placeholder_token_count > 16:
|
| 335 |
+
logger.warning(
|
| 336 |
+
"Please follow the prompt format that is "
|
| 337 |
+
"documented on HuggingFace which does not involve "
|
| 338 |
+
"repeating %s tokens.", placeholder_token_str)
|
| 339 |
+
if placeholder_token_count < len(repeat_count):
|
| 340 |
+
logger.warning(
|
| 341 |
+
"The number of multi-modal placeholder tokens in the prompt "
|
| 342 |
+
"is less than the number of multi-modal inputs. Extra "
|
| 343 |
+
"placeholder tokens will be treated as plain text")
|
| 344 |
+
repeat_count = repeat_count[:placeholder_token_count]
|
| 345 |
+
|
| 346 |
+
prompt_parts = prompt.split(placeholder_token_str,
|
| 347 |
+
maxsplit=len(repeat_count))
|
| 348 |
+
new_prompt = ""
|
| 349 |
+
for i, repeat_count_item in enumerate(repeat_count):
|
| 350 |
+
replacement_str = "".join(
|
| 351 |
+
repeat_and_pad_token(
|
| 352 |
+
placeholder_token_str,
|
| 353 |
+
repeat_count=repeat_count_item,
|
| 354 |
+
pad_token_left=pad_token_str_left,
|
| 355 |
+
pad_token_right=pad_token_str_right,
|
| 356 |
+
))
|
| 357 |
+
# The image tokens are removed to be consistent with HuggingFace
|
| 358 |
+
new_prompt += prompt_parts[i] + replacement_str
|
| 359 |
+
new_prompt += prompt_parts[-1]
|
| 360 |
+
|
| 361 |
+
new_token_ids = list[int]()
|
| 362 |
+
placeholder_ranges = list[PlaceholderRange]()
|
| 363 |
+
placeholder_token_idx = 0
|
| 364 |
+
for i, token in enumerate(prompt_token_ids):
|
| 365 |
+
if token == placeholder_token_id:
|
| 366 |
+
curr_repeat_count = repeat_count[placeholder_token_idx]
|
| 367 |
+
replacement_ids = repeat_and_pad_token(
|
| 368 |
+
placeholder_token_id,
|
| 369 |
+
repeat_count=curr_repeat_count,
|
| 370 |
+
pad_token_left=pad_token_left,
|
| 371 |
+
pad_token_right=pad_token_right,
|
| 372 |
+
)
|
| 373 |
+
offset = len(new_token_ids)
|
| 374 |
+
if pad_token_left is not None:
|
| 375 |
+
offset += 1
|
| 376 |
+
placeholder_ranges.append({
|
| 377 |
+
"offset": offset,
|
| 378 |
+
"length": curr_repeat_count,
|
| 379 |
+
})
|
| 380 |
+
new_token_ids.extend(replacement_ids)
|
| 381 |
+
placeholder_token_idx += 1
|
| 382 |
+
|
| 383 |
+
# No need to further scan the list since we replaced all tokens
|
| 384 |
+
if placeholder_token_idx >= len(repeat_count):
|
| 385 |
+
new_token_ids.extend(prompt_token_ids[i + 1:])
|
| 386 |
+
break
|
| 387 |
+
else:
|
| 388 |
+
new_token_ids.append(token)
|
| 389 |
+
|
| 390 |
+
return new_prompt, new_token_ids, placeholder_ranges
|
| 391 |
+
|
| 392 |
+
|
| 393 |
+
def consecutive_placeholder_ranges(
|
| 394 |
+
num_items: int,
|
| 395 |
+
item_size: int,
|
| 396 |
+
initial_offset: int = 0) -> list[PlaceholderRange]:
|
| 397 |
+
"""Returns a list of consecutive PlaceholderRanges of a fixed size"""
|
| 398 |
+
|
| 399 |
+
return [
|
| 400 |
+
PlaceholderRange(offset=initial_offset + i * item_size,
|
| 401 |
+
length=item_size) for i in range(num_items)
|
| 402 |
+
]
|
| 403 |
+
|
| 404 |
+
|
| 405 |
+
def merge_and_sort_multimodal_metadata(
|
| 406 |
+
mm_positions: "MultiModalPlaceholderDict",
|
| 407 |
+
mm_hashes: Optional["MultiModalHashDict"],
|
| 408 |
+
) -> tuple[list[str], list[PlaceholderRange], Optional[list[str]]]:
|
| 409 |
+
"""Given a MultiModalPlaceholderDict, merge all PlaceholderRange
|
| 410 |
+
objects from all available modalities into a single list of
|
| 411 |
+
PlaceholderRange, sorted by their offset (starting index in the input
|
| 412 |
+
sequence) in the ascending order.
|
| 413 |
+
|
| 414 |
+
Optionally if a MultiModalHashDict is given, same operation will be
|
| 415 |
+
applied to the object and the sorted list of hashes will be returned.
|
| 416 |
+
|
| 417 |
+
Raises:
|
| 418 |
+
ValueError: If the input prompt has interleaved placeholders from
|
| 419 |
+
different modalities (e.g, "<image><audio><image> Describe the
|
| 420 |
+
content.")
|
| 421 |
+
|
| 422 |
+
Returns:
|
| 423 |
+
list[str]: Sorted list of involved modalities.
|
| 424 |
+
list[PlaceholderRange]: Sorted list of all PlaceholdeRanges from
|
| 425 |
+
mm_positions.
|
| 426 |
+
Optional[list[str]]: Sorted list of all hashes from mm_hashes if
|
| 427 |
+
given, None otherwise.
|
| 428 |
+
"""
|
| 429 |
+
|
| 430 |
+
modalities = list(mm_positions.keys())
|
| 431 |
+
|
| 432 |
+
assert len(modalities) > 0, "No modalities found in the mm_positions."
|
| 433 |
+
|
| 434 |
+
# For single modality, placeholder ranges and hashes are already sorted
|
| 435 |
+
# so we can return the list directly.
|
| 436 |
+
if len(modalities) == 1:
|
| 437 |
+
if mm_hashes is None:
|
| 438 |
+
return modalities, list(mm_positions[modalities[0]]), None
|
| 439 |
+
else:
|
| 440 |
+
return modalities, list(mm_positions[modalities[0]]), list(
|
| 441 |
+
mm_hashes[modalities[0]])
|
| 442 |
+
|
| 443 |
+
placeholder_lists_with_modality = [(modality, mm_positions[modality])
|
| 444 |
+
for modality in modalities]
|
| 445 |
+
|
| 446 |
+
if mm_hashes is None:
|
| 447 |
+
sorted_placeholder_lists = sorted(placeholder_lists_with_modality,
|
| 448 |
+
key=lambda x: x[1][0]['offset'])
|
| 449 |
+
sorted_hash_lists = None
|
| 450 |
+
else:
|
| 451 |
+
hashes_lists = [
|
| 452 |
+
mm_hashes[modality] for modality in modalities
|
| 453 |
+
if modality in mm_hashes
|
| 454 |
+
]
|
| 455 |
+
sorted_pairs = sorted(zip(placeholder_lists_with_modality,
|
| 456 |
+
hashes_lists),
|
| 457 |
+
key=lambda x: x[0][1][0]['offset'])
|
| 458 |
+
sorted_placeholder_tuple, sorted_hash_tuple = zip(*sorted_pairs)
|
| 459 |
+
sorted_placeholder_lists = list(sorted_placeholder_tuple)
|
| 460 |
+
sorted_hash_lists = list(sorted_hash_tuple)
|
| 461 |
+
|
| 462 |
+
sorted_modalities = [modality for modality, _ in sorted_placeholder_lists]
|
| 463 |
+
|
| 464 |
+
# Flatten sorted list of lists to a single list and verify there is no
|
| 465 |
+
# interleaving of placeholders from different modalities.
|
| 466 |
+
merged_placeholders: list[PlaceholderRange] = []
|
| 467 |
+
for modality, placeholder_list in sorted_placeholder_lists:
|
| 468 |
+
if merged_placeholders and placeholder_list[0][
|
| 469 |
+
'offset'] < merged_placeholders[-1]['offset']:
|
| 470 |
+
raise ValueError(
|
| 471 |
+
"Interleaved mixed-modality inference is currently not "
|
| 472 |
+
"supported.")
|
| 473 |
+
merged_placeholders.extend(placeholder_list)
|
| 474 |
+
|
| 475 |
+
if sorted_hash_lists is not None:
|
| 476 |
+
merged_hashes = []
|
| 477 |
+
for hash_list in sorted_hash_lists:
|
| 478 |
+
merged_hashes.extend(hash_list)
|
| 479 |
+
else:
|
| 480 |
+
merged_hashes = None
|
| 481 |
+
|
| 482 |
+
return sorted_modalities, merged_placeholders, merged_hashes
|
| 483 |
+
|
| 484 |
+
|
| 485 |
+
def group_mm_inputs_by_modality(
|
| 486 |
+
mm_inputs: list["MultiModalKwargs"]) -> list[list["MultiModalKwargs"]]:
|
| 487 |
+
"""Group consecutive MultiModalKwargs from mm_inputs with the same modality
|
| 488 |
+
together into the same list for batching purpose. For MultiModalKwargs with
|
| 489 |
+
multiple modalities, put them into their own list.
|
| 490 |
+
|
| 491 |
+
Args:
|
| 492 |
+
mm_inputs: List of MultiModalKwargs.
|
| 493 |
+
|
| 494 |
+
Returns:
|
| 495 |
+
list[list[MultiModalKwargs]]: List of list of MultiModalKwargs, each
|
| 496 |
+
inner list contains consecutive MultiModalKwargs with same modality, or
|
| 497 |
+
one with multimodal modalities.
|
| 498 |
+
"""
|
| 499 |
+
if not mm_inputs:
|
| 500 |
+
return []
|
| 501 |
+
|
| 502 |
+
def modality_group_func(mm_input: "MultiModalKwargs") -> Union[str, int]:
|
| 503 |
+
# If the input has multiple modalities, return a id as the unique key
|
| 504 |
+
# for the mm_input input.
|
| 505 |
+
if len(mm_input.modalities) > 1:
|
| 506 |
+
return id(mm_input)
|
| 507 |
+
|
| 508 |
+
elif len(mm_input.modalities) == 1:
|
| 509 |
+
return list(mm_input.modalities)[0]
|
| 510 |
+
|
| 511 |
+
# FIXME(Isotr0py): Modality of mm_input from legacy pipeline is empty,
|
| 512 |
+
# this is used to make InternVL with legacy pipeline still work with v1.
|
| 513 |
+
else:
|
| 514 |
+
return ""
|
| 515 |
+
|
| 516 |
+
return [
|
| 517 |
+
list(group) for _, group in groupby(mm_inputs, key=modality_group_func)
|
| 518 |
+
]
|
.venv/lib/python3.11/site-packages/vllm/multimodal/video.py
ADDED
|
@@ -0,0 +1,191 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 2 |
+
|
| 3 |
+
import base64
|
| 4 |
+
from functools import lru_cache, partial
|
| 5 |
+
from io import BytesIO
|
| 6 |
+
from pathlib import Path
|
| 7 |
+
from typing import TYPE_CHECKING, Any, Dict, Optional
|
| 8 |
+
|
| 9 |
+
import numpy as np
|
| 10 |
+
import numpy.typing as npt
|
| 11 |
+
from PIL import Image
|
| 12 |
+
|
| 13 |
+
from vllm.inputs.registry import InputContext
|
| 14 |
+
from vllm.logger import init_logger
|
| 15 |
+
from vllm.transformers_utils.processor import get_video_processor
|
| 16 |
+
from vllm.transformers_utils.tokenizer import get_tokenizer
|
| 17 |
+
from vllm.utils import PlaceholderModule, is_list_of
|
| 18 |
+
|
| 19 |
+
from .base import MediaIO, ModalityData
|
| 20 |
+
from .image import ImageMediaIO, ImagePlugin
|
| 21 |
+
from .inputs import MultiModalKwargs, VideoItem
|
| 22 |
+
|
| 23 |
+
if TYPE_CHECKING:
|
| 24 |
+
from vllm.config import ModelConfig
|
| 25 |
+
|
| 26 |
+
try:
|
| 27 |
+
import decord
|
| 28 |
+
except ImportError:
|
| 29 |
+
decord = PlaceholderModule("decord") # type: ignore[assignment]
|
| 30 |
+
|
| 31 |
+
logger = init_logger(__name__)
|
| 32 |
+
|
| 33 |
+
cached_get_video_processor = lru_cache(get_video_processor)
|
| 34 |
+
cached_get_tokenizer = lru_cache(get_tokenizer)
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
class VideoPlugin(ImagePlugin):
|
| 38 |
+
"""Plugin for video data."""
|
| 39 |
+
|
| 40 |
+
def get_data_key(self) -> str:
|
| 41 |
+
return "video"
|
| 42 |
+
|
| 43 |
+
def _get_hf_video_processor(
|
| 44 |
+
self,
|
| 45 |
+
model_config: "ModelConfig",
|
| 46 |
+
mm_processor_kwargs: Optional[Dict[str, Any]] = None,
|
| 47 |
+
):
|
| 48 |
+
if mm_processor_kwargs is None:
|
| 49 |
+
mm_processor_kwargs = {}
|
| 50 |
+
return cached_get_video_processor(
|
| 51 |
+
model_config.model,
|
| 52 |
+
trust_remote_code=model_config.trust_remote_code,
|
| 53 |
+
**mm_processor_kwargs)
|
| 54 |
+
|
| 55 |
+
def _default_input_mapper(
|
| 56 |
+
self,
|
| 57 |
+
ctx: InputContext,
|
| 58 |
+
data: ModalityData[VideoItem],
|
| 59 |
+
**mm_processor_kwargs,
|
| 60 |
+
) -> MultiModalKwargs:
|
| 61 |
+
model_config = ctx.model_config
|
| 62 |
+
|
| 63 |
+
if isinstance(data, list) and len(data) == 1:
|
| 64 |
+
data = data[0] # type: ignore
|
| 65 |
+
|
| 66 |
+
if isinstance(data, np.ndarray) or is_list_of(data, np.ndarray):
|
| 67 |
+
video_processor = self._get_hf_video_processor(
|
| 68 |
+
model_config,
|
| 69 |
+
mm_processor_kwargs,
|
| 70 |
+
)
|
| 71 |
+
if video_processor is None:
|
| 72 |
+
raise RuntimeError("No HuggingFace processor is available "
|
| 73 |
+
"to process the video object")
|
| 74 |
+
try:
|
| 75 |
+
# NOTE: Similar to image; it may be a good idea to filter and
|
| 76 |
+
# pass mm_processor_kwargs here too, but for now we don't to
|
| 77 |
+
# avoid extra complexity if the initializer and preprocess
|
| 78 |
+
# signatures of the processor don't align
|
| 79 |
+
batch_data = video_processor(data, return_tensors="pt").data
|
| 80 |
+
except Exception:
|
| 81 |
+
logger.error("Failed to process video (%s)", data)
|
| 82 |
+
raise
|
| 83 |
+
|
| 84 |
+
return MultiModalKwargs(batch_data)
|
| 85 |
+
|
| 86 |
+
raise TypeError(f"Invalid video type: {type(data)}")
|
| 87 |
+
|
| 88 |
+
def _default_max_multimodal_tokens(self, ctx: InputContext) -> int:
|
| 89 |
+
return 4096
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
def resize_video(frames: npt.NDArray, size: tuple[int, int]) -> npt.NDArray:
|
| 93 |
+
num_frames, _, _, channels = frames.shape
|
| 94 |
+
new_height, new_width = size
|
| 95 |
+
resized_frames = np.empty((num_frames, new_height, new_width, channels),
|
| 96 |
+
dtype=frames.dtype)
|
| 97 |
+
# lazy import cv2 to avoid bothering users who only use text models
|
| 98 |
+
import cv2
|
| 99 |
+
for i, frame in enumerate(frames):
|
| 100 |
+
resized_frame = cv2.resize(frame, (new_width, new_height))
|
| 101 |
+
resized_frames[i] = resized_frame
|
| 102 |
+
return resized_frames
|
| 103 |
+
|
| 104 |
+
|
| 105 |
+
def rescale_video_size(frames: npt.NDArray, size_factor: float) -> npt.NDArray:
|
| 106 |
+
_, height, width, _ = frames.shape
|
| 107 |
+
new_height = int(height * size_factor)
|
| 108 |
+
new_width = int(width * size_factor)
|
| 109 |
+
|
| 110 |
+
return resize_video(frames, (new_height, new_width))
|
| 111 |
+
|
| 112 |
+
|
| 113 |
+
def sample_frames_from_video(frames: npt.NDArray,
|
| 114 |
+
num_frames: int) -> npt.NDArray:
|
| 115 |
+
total_frames = frames.shape[0]
|
| 116 |
+
if num_frames == -1:
|
| 117 |
+
return frames
|
| 118 |
+
|
| 119 |
+
frame_indices = np.linspace(0, total_frames - 1, num_frames, dtype=int)
|
| 120 |
+
sampled_frames = frames[frame_indices, ...]
|
| 121 |
+
return sampled_frames
|
| 122 |
+
|
| 123 |
+
|
| 124 |
+
class VideoMediaIO(MediaIO[npt.NDArray]):
|
| 125 |
+
|
| 126 |
+
def __init__(
|
| 127 |
+
self,
|
| 128 |
+
image_io: ImageMediaIO,
|
| 129 |
+
*,
|
| 130 |
+
num_frames: int = 32,
|
| 131 |
+
) -> None:
|
| 132 |
+
super().__init__()
|
| 133 |
+
|
| 134 |
+
self.image_io = image_io
|
| 135 |
+
self.num_frames = num_frames
|
| 136 |
+
|
| 137 |
+
def load_bytes(self, data: bytes) -> npt.NDArray:
|
| 138 |
+
vr = decord.VideoReader(BytesIO(data), num_threads=1)
|
| 139 |
+
total_frame_num = len(vr)
|
| 140 |
+
|
| 141 |
+
num_frames = self.num_frames
|
| 142 |
+
if total_frame_num > num_frames:
|
| 143 |
+
uniform_sampled_frames = np.linspace(0,
|
| 144 |
+
total_frame_num - 1,
|
| 145 |
+
num_frames,
|
| 146 |
+
dtype=int)
|
| 147 |
+
frame_idx = uniform_sampled_frames.tolist()
|
| 148 |
+
else:
|
| 149 |
+
frame_idx = list(range(0, total_frame_num))
|
| 150 |
+
|
| 151 |
+
return vr.get_batch(frame_idx).asnumpy()
|
| 152 |
+
|
| 153 |
+
def load_base64(self, media_type: str, data: str) -> npt.NDArray:
|
| 154 |
+
if media_type.lower() == "video/jpeg":
|
| 155 |
+
load_frame = partial(
|
| 156 |
+
self.image_io.load_base64,
|
| 157 |
+
"image/jpeg",
|
| 158 |
+
)
|
| 159 |
+
|
| 160 |
+
return np.stack([
|
| 161 |
+
np.array(load_frame(frame_data))
|
| 162 |
+
for frame_data in data.split(",")
|
| 163 |
+
])
|
| 164 |
+
|
| 165 |
+
return self.load_bytes(base64.b64decode(data))
|
| 166 |
+
|
| 167 |
+
def load_file(self, filepath: Path) -> npt.NDArray:
|
| 168 |
+
with filepath.open("rb") as f:
|
| 169 |
+
data = f.read()
|
| 170 |
+
|
| 171 |
+
return self.load_bytes(data)
|
| 172 |
+
|
| 173 |
+
def encode_base64(
|
| 174 |
+
self,
|
| 175 |
+
media: npt.NDArray,
|
| 176 |
+
*,
|
| 177 |
+
video_format: str = "JPEG",
|
| 178 |
+
) -> str:
|
| 179 |
+
video = media
|
| 180 |
+
|
| 181 |
+
if video_format == "JPEG":
|
| 182 |
+
encode_frame = partial(
|
| 183 |
+
self.image_io.encode_base64,
|
| 184 |
+
image_format=video_format,
|
| 185 |
+
)
|
| 186 |
+
|
| 187 |
+
return ",".join(
|
| 188 |
+
encode_frame(Image.fromarray(frame)) for frame in video)
|
| 189 |
+
|
| 190 |
+
msg = "Only JPEG format is supported for now."
|
| 191 |
+
raise NotImplementedError(msg)
|
.venv/lib/python3.11/site-packages/vllm/triton_utils/__init__.py
ADDED
|
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 2 |
+
|
| 3 |
+
from vllm.triton_utils.importing import HAS_TRITON
|
| 4 |
+
|
| 5 |
+
__all__ = ["HAS_TRITON"]
|
| 6 |
+
|
| 7 |
+
if HAS_TRITON:
|
| 8 |
+
|
| 9 |
+
from vllm.triton_utils.custom_cache_manager import (
|
| 10 |
+
maybe_set_triton_cache_manager)
|
| 11 |
+
|
| 12 |
+
__all__ += ["maybe_set_triton_cache_manager"]
|
.venv/lib/python3.11/site-packages/vllm/triton_utils/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (473 Bytes). View file
|
|
|
.venv/lib/python3.11/site-packages/vllm/triton_utils/__pycache__/custom_cache_manager.cpython-311.pyc
ADDED
|
Binary file (3.53 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/vllm/triton_utils/__pycache__/importing.cpython-311.pyc
ADDED
|
Binary file (779 Bytes). View file
|
|
|
.venv/lib/python3.11/site-packages/vllm/worker/__init__.py
ADDED
|
File without changes
|
.venv/lib/python3.11/site-packages/vllm/worker/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (184 Bytes). View file
|
|
|
.venv/lib/python3.11/site-packages/vllm/worker/__pycache__/cache_engine.cpython-311.pyc
ADDED
|
Binary file (7.42 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/vllm/worker/__pycache__/cpu_enc_dec_model_runner.cpython-311.pyc
ADDED
|
Binary file (13.9 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/vllm/worker/__pycache__/cpu_model_runner.cpython-311.pyc
ADDED
|
Binary file (33.2 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/vllm/worker/__pycache__/cpu_pooling_model_runner.cpython-311.pyc
ADDED
|
Binary file (6.72 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/vllm/worker/__pycache__/cpu_worker.cpython-311.pyc
ADDED
|
Binary file (20.6 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/vllm/worker/__pycache__/enc_dec_model_runner.cpython-311.pyc
ADDED
|
Binary file (22.7 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/vllm/worker/__pycache__/hpu_worker.cpython-311.pyc
ADDED
|
Binary file (27.2 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/vllm/worker/__pycache__/model_runner.cpython-311.pyc
ADDED
|
Binary file (87.4 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/vllm/worker/__pycache__/model_runner_base.cpython-311.pyc
ADDED
|
Binary file (12.3 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/vllm/worker/__pycache__/multi_step_model_runner.cpython-311.pyc
ADDED
|
Binary file (36.8 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/vllm/worker/__pycache__/multi_step_tpu_worker.cpython-311.pyc
ADDED
|
Binary file (4.61 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/vllm/worker/__pycache__/multi_step_worker.cpython-311.pyc
ADDED
|
Binary file (7.57 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/vllm/worker/__pycache__/neuron_model_runner.cpython-311.pyc
ADDED
|
Binary file (16 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/vllm/worker/__pycache__/neuron_worker.cpython-311.pyc
ADDED
|
Binary file (6.76 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/vllm/worker/__pycache__/openvino_model_runner.cpython-311.pyc
ADDED
|
Binary file (14.1 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/vllm/worker/__pycache__/openvino_worker.cpython-311.pyc
ADDED
|
Binary file (27 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/vllm/worker/__pycache__/pooling_model_runner.cpython-311.pyc
ADDED
|
Binary file (9.95 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/vllm/worker/__pycache__/tpu_model_runner.cpython-311.pyc
ADDED
|
Binary file (40.6 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/vllm/worker/__pycache__/tpu_worker.cpython-311.pyc
ADDED
|
Binary file (15.2 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/vllm/worker/__pycache__/utils.cpython-311.pyc
ADDED
|
Binary file (2.26 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/vllm/worker/__pycache__/worker.cpython-311.pyc
ADDED
|
Binary file (29.1 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/vllm/worker/__pycache__/worker_base.cpython-311.pyc
ADDED
|
Binary file (30.1 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/vllm/worker/__pycache__/xpu_model_runner.cpython-311.pyc
ADDED
|
Binary file (26.8 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/vllm/worker/__pycache__/xpu_worker.cpython-311.pyc
ADDED
|
Binary file (9.11 kB). View file
|
|
|