koichi12 commited on
Commit
a1956b1
·
verified ·
1 Parent(s): 2bdf65d

Add files using upload-large-folder tool

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .venv/lib/python3.11/site-packages/vllm/assets/__init__.py +0 -0
  2. .venv/lib/python3.11/site-packages/vllm/assets/__pycache__/__init__.cpython-311.pyc +0 -0
  3. .venv/lib/python3.11/site-packages/vllm/assets/__pycache__/audio.cpython-311.pyc +0 -0
  4. .venv/lib/python3.11/site-packages/vllm/assets/__pycache__/base.cpython-311.pyc +0 -0
  5. .venv/lib/python3.11/site-packages/vllm/assets/__pycache__/image.cpython-311.pyc +0 -0
  6. .venv/lib/python3.11/site-packages/vllm/assets/__pycache__/video.cpython-311.pyc +0 -0
  7. .venv/lib/python3.11/site-packages/vllm/assets/audio.py +33 -0
  8. .venv/lib/python3.11/site-packages/vllm/assets/base.py +40 -0
  9. .venv/lib/python3.11/site-packages/vllm/assets/image.py +31 -0
  10. .venv/lib/python3.11/site-packages/vllm/assets/video.py +84 -0
  11. .venv/lib/python3.11/site-packages/vllm/multimodal/__init__.py +33 -0
  12. .venv/lib/python3.11/site-packages/vllm/multimodal/hasher.py +102 -0
  13. .venv/lib/python3.11/site-packages/vllm/multimodal/image.py +139 -0
  14. .venv/lib/python3.11/site-packages/vllm/multimodal/inputs.py +741 -0
  15. .venv/lib/python3.11/site-packages/vllm/multimodal/parse.py +368 -0
  16. .venv/lib/python3.11/site-packages/vllm/multimodal/processing.py +1295 -0
  17. .venv/lib/python3.11/site-packages/vllm/multimodal/profiling.py +209 -0
  18. .venv/lib/python3.11/site-packages/vllm/multimodal/registry.py +458 -0
  19. .venv/lib/python3.11/site-packages/vllm/multimodal/utils.py +518 -0
  20. .venv/lib/python3.11/site-packages/vllm/multimodal/video.py +191 -0
  21. .venv/lib/python3.11/site-packages/vllm/triton_utils/__init__.py +12 -0
  22. .venv/lib/python3.11/site-packages/vllm/triton_utils/__pycache__/__init__.cpython-311.pyc +0 -0
  23. .venv/lib/python3.11/site-packages/vllm/triton_utils/__pycache__/custom_cache_manager.cpython-311.pyc +0 -0
  24. .venv/lib/python3.11/site-packages/vllm/triton_utils/__pycache__/importing.cpython-311.pyc +0 -0
  25. .venv/lib/python3.11/site-packages/vllm/worker/__init__.py +0 -0
  26. .venv/lib/python3.11/site-packages/vllm/worker/__pycache__/__init__.cpython-311.pyc +0 -0
  27. .venv/lib/python3.11/site-packages/vllm/worker/__pycache__/cache_engine.cpython-311.pyc +0 -0
  28. .venv/lib/python3.11/site-packages/vllm/worker/__pycache__/cpu_enc_dec_model_runner.cpython-311.pyc +0 -0
  29. .venv/lib/python3.11/site-packages/vllm/worker/__pycache__/cpu_model_runner.cpython-311.pyc +0 -0
  30. .venv/lib/python3.11/site-packages/vllm/worker/__pycache__/cpu_pooling_model_runner.cpython-311.pyc +0 -0
  31. .venv/lib/python3.11/site-packages/vllm/worker/__pycache__/cpu_worker.cpython-311.pyc +0 -0
  32. .venv/lib/python3.11/site-packages/vllm/worker/__pycache__/enc_dec_model_runner.cpython-311.pyc +0 -0
  33. .venv/lib/python3.11/site-packages/vllm/worker/__pycache__/hpu_worker.cpython-311.pyc +0 -0
  34. .venv/lib/python3.11/site-packages/vllm/worker/__pycache__/model_runner.cpython-311.pyc +0 -0
  35. .venv/lib/python3.11/site-packages/vllm/worker/__pycache__/model_runner_base.cpython-311.pyc +0 -0
  36. .venv/lib/python3.11/site-packages/vllm/worker/__pycache__/multi_step_model_runner.cpython-311.pyc +0 -0
  37. .venv/lib/python3.11/site-packages/vllm/worker/__pycache__/multi_step_tpu_worker.cpython-311.pyc +0 -0
  38. .venv/lib/python3.11/site-packages/vllm/worker/__pycache__/multi_step_worker.cpython-311.pyc +0 -0
  39. .venv/lib/python3.11/site-packages/vllm/worker/__pycache__/neuron_model_runner.cpython-311.pyc +0 -0
  40. .venv/lib/python3.11/site-packages/vllm/worker/__pycache__/neuron_worker.cpython-311.pyc +0 -0
  41. .venv/lib/python3.11/site-packages/vllm/worker/__pycache__/openvino_model_runner.cpython-311.pyc +0 -0
  42. .venv/lib/python3.11/site-packages/vllm/worker/__pycache__/openvino_worker.cpython-311.pyc +0 -0
  43. .venv/lib/python3.11/site-packages/vllm/worker/__pycache__/pooling_model_runner.cpython-311.pyc +0 -0
  44. .venv/lib/python3.11/site-packages/vllm/worker/__pycache__/tpu_model_runner.cpython-311.pyc +0 -0
  45. .venv/lib/python3.11/site-packages/vllm/worker/__pycache__/tpu_worker.cpython-311.pyc +0 -0
  46. .venv/lib/python3.11/site-packages/vllm/worker/__pycache__/utils.cpython-311.pyc +0 -0
  47. .venv/lib/python3.11/site-packages/vllm/worker/__pycache__/worker.cpython-311.pyc +0 -0
  48. .venv/lib/python3.11/site-packages/vllm/worker/__pycache__/worker_base.cpython-311.pyc +0 -0
  49. .venv/lib/python3.11/site-packages/vllm/worker/__pycache__/xpu_model_runner.cpython-311.pyc +0 -0
  50. .venv/lib/python3.11/site-packages/vllm/worker/__pycache__/xpu_worker.cpython-311.pyc +0 -0
.venv/lib/python3.11/site-packages/vllm/assets/__init__.py ADDED
File without changes
.venv/lib/python3.11/site-packages/vllm/assets/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (184 Bytes). View file
 
.venv/lib/python3.11/site-packages/vllm/assets/__pycache__/audio.cpython-311.pyc ADDED
Binary file (2.03 kB). View file
 
.venv/lib/python3.11/site-packages/vllm/assets/__pycache__/base.cpython-311.pyc ADDED
Binary file (1.87 kB). View file
 
.venv/lib/python3.11/site-packages/vllm/assets/__pycache__/image.cpython-311.pyc ADDED
Binary file (1.87 kB). View file
 
.venv/lib/python3.11/site-packages/vllm/assets/__pycache__/video.cpython-311.pyc ADDED
Binary file (4.79 kB). View file
 
.venv/lib/python3.11/site-packages/vllm/assets/audio.py ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-License-Identifier: Apache-2.0
2
+
3
+ from dataclasses import dataclass
4
+ from typing import Literal
5
+ from urllib.parse import urljoin
6
+
7
+ import numpy.typing as npt
8
+
9
+ from vllm.utils import PlaceholderModule
10
+
11
+ from .base import VLLM_S3_BUCKET_URL, get_vllm_public_assets
12
+
13
+ try:
14
+ import librosa
15
+ except ImportError:
16
+ librosa = PlaceholderModule("librosa") # type: ignore[assignment]
17
+
18
+ ASSET_DIR = "multimodal_asset"
19
+
20
+
21
@dataclass(frozen=True)
class AudioAsset:
    """Test audio asset served from the vLLM public S3 bucket."""

    name: Literal["winning_call", "mary_had_lamb"]

    @property
    def audio_and_sample_rate(self) -> tuple[npt.NDArray, float]:
        """Download (if needed) and decode the asset as ``(samples, rate)``."""
        local_path = get_vllm_public_assets(filename=f"{self.name}.ogg",
                                            s3_prefix=ASSET_DIR)
        # sr=None keeps the file's native sampling rate.
        return librosa.load(local_path, sr=None)

    @property
    def url(self) -> str:
        """Public S3 URL of this asset."""
        return urljoin(VLLM_S3_BUCKET_URL, f"{ASSET_DIR}/{self.name}.ogg")
.venv/lib/python3.11/site-packages/vllm/assets/base.py ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-License-Identifier: Apache-2.0
2
+
3
+ from functools import lru_cache
4
+ from pathlib import Path
5
+ from typing import Optional
6
+
7
+ import vllm.envs as envs
8
+ from vllm.connections import global_http_connection
9
+
10
+ VLLM_S3_BUCKET_URL = "https://vllm-public-assets.s3.us-west-2.amazonaws.com"
11
+
12
+
13
def get_cache_dir() -> Path:
    """Get the path to the cache for storing downloaded assets."""
    cache_dir = Path(envs.VLLM_ASSETS_CACHE)
    # Create it eagerly so callers can write into it right away.
    cache_dir.mkdir(parents=True, exist_ok=True)
    return cache_dir
19
+
20
+
21
@lru_cache
def get_vllm_public_assets(filename: str,
                           s3_prefix: Optional[str] = None) -> Path:
    """
    Download an asset file from ``s3://vllm-public-assets``
    and return the path to the downloaded file.

    Args:
        filename: Name of the object (and of the cached local file).
        s3_prefix: Optional key prefix (directory) within the bucket.

    The result is cached both on disk (under the assets cache dir) and
    in-process via ``lru_cache``.
    """
    asset_directory = get_cache_dir() / "vllm_public_assets"
    asset_directory.mkdir(parents=True, exist_ok=True)

    asset_path = asset_directory / filename
    if not asset_path.exists():
        if s3_prefix is not None:
            filename = s3_prefix + "/" + filename
        # FIX: the download URL must target the requested object key
        # (bucket URL + "/" + [prefix/]filename); the key portion was
        # previously malformed.
        global_http_connection.download_file(
            f"{VLLM_S3_BUCKET_URL}/{filename}",
            asset_path,
            timeout=envs.VLLM_IMAGE_FETCH_TIMEOUT)

    return asset_path
.venv/lib/python3.11/site-packages/vllm/assets/image.py ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-License-Identifier: Apache-2.0
2
+
3
+ from dataclasses import dataclass
4
+ from typing import Literal
5
+
6
+ import torch
7
+ from PIL import Image
8
+
9
+ from .base import get_vllm_public_assets
10
+
11
+ VLM_IMAGES_DIR = "vision_model_images"
12
+
13
+
14
+ @dataclass(frozen=True)
15
+ class ImageAsset:
16
+ name: Literal["stop_sign", "cherry_blossom"]
17
+
18
+ @property
19
+ def pil_image(self) -> Image.Image:
20
+ image_path = get_vllm_public_assets(filename=f"{self.name}.jpg",
21
+ s3_prefix=VLM_IMAGES_DIR)
22
+ return Image.open(image_path)
23
+
24
+ @property
25
+ def image_embeds(self) -> torch.Tensor:
26
+ """
27
+ Image embeddings, only used for testing purposes with llava 1.5.
28
+ """
29
+ image_path = get_vllm_public_assets(filename=f"{self.name}.pt",
30
+ s3_prefix=VLM_IMAGES_DIR)
31
+ return torch.load(image_path, map_location="cpu", weights_only=True)
.venv/lib/python3.11/site-packages/vllm/assets/video.py ADDED
@@ -0,0 +1,84 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-License-Identifier: Apache-2.0
2
+
3
+ from dataclasses import dataclass
4
+ from functools import lru_cache
5
+ from typing import List, Literal
6
+
7
+ import cv2
8
+ import numpy as np
9
+ import numpy.typing as npt
10
+ from huggingface_hub import hf_hub_download
11
+ from PIL import Image
12
+
13
+ from vllm.multimodal.video import sample_frames_from_video
14
+
15
+ from .base import get_cache_dir
16
+
17
+
18
@lru_cache
def download_video_asset(filename: str) -> str:
    """
    Download and open an image from huggingface
    repo: raushan-testing-hf/videos-test
    """
    target_dir = get_cache_dir() / "video-example-data"
    target_dir.mkdir(parents=True, exist_ok=True)

    local_path = target_dir / filename
    if local_path.exists():
        # Already cached locally from a previous run.
        return str(local_path)

    # Not cached: fetch via the HF hub, which returns its own path.
    return hf_hub_download(
        repo_id="raushan-testing-hf/videos-test",
        filename=filename,
        repo_type="dataset",
        cache_dir=target_dir,
    )
37
+
38
+
39
def video_to_ndarrays(path: str, num_frames: int = -1) -> npt.NDArray:
    """
    Decode a video file into a stacked ndarray of BGR frames.

    Args:
        path: Path of the video file to decode.
        num_frames: Number of frames to sample (``-1`` keeps all frames).

    Raises:
        ValueError: If the file cannot be opened, no frame can be decoded,
            or fewer than ``num_frames`` frames are available.
    """
    cap = cv2.VideoCapture(path)
    if not cap.isOpened():
        raise ValueError(f"Could not open video file {path}")

    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    frames = []
    for _ in range(total_frames):
        ret, frame = cap.read()
        if not ret:
            # FIX: stop at the first failed read instead of silently
            # skipping frames; OpenCV yields no more frames past this point.
            break
        frames.append(frame)
    cap.release()

    if not frames:
        # FIX: previously np.stack([]) raised an opaque error here.
        raise ValueError(f"Could not read any frames from video file {path}")

    frames = np.stack(frames)
    frames = sample_frames_from_video(frames, num_frames)
    if len(frames) < num_frames:
        raise ValueError(f"Could not read enough frames from video file {path}"
                         f" (expected {num_frames} frames, got {len(frames)})")
    return frames
58
+
59
+
60
def video_to_pil_images_list(path: str,
                             num_frames: int = -1) -> List[Image.Image]:
    """Decode a video into a list of RGB PIL images."""
    bgr_frames = video_to_ndarrays(path, num_frames)
    images: List[Image.Image] = []
    for bgr in bgr_frames:
        # OpenCV decodes in BGR order; PIL expects RGB.
        rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
        images.append(Image.fromarray(rgb))
    return images
67
+
68
+
69
@dataclass(frozen=True)
class VideoAsset:
    """Test video asset downloaded from the HF dataset repo."""

    name: Literal["sample_demo_1.mp4"]
    num_frames: int = -1

    @property
    def pil_images(self) -> List[Image.Image]:
        """Frames of the video as RGB PIL images."""
        local_path = download_video_asset(self.name)
        return video_to_pil_images_list(local_path, self.num_frames)

    @property
    def np_ndarrays(self) -> npt.NDArray:
        """Frames of the video as a stacked ndarray (BGR)."""
        local_path = download_video_asset(self.name)
        return video_to_ndarrays(local_path, self.num_frames)
.venv/lib/python3.11/site-packages/vllm/multimodal/__init__.py ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-License-Identifier: Apache-2.0
2
+
3
+ from .base import MultiModalPlaceholderMap, MultiModalPlugin
4
+ from .hasher import MultiModalHashDict, MultiModalHasher
5
+ from .inputs import (BatchedTensorInputs, ModalityData, MultiModalDataBuiltins,
6
+ MultiModalDataDict, MultiModalKwargs,
7
+ MultiModalPlaceholderDict, NestedTensors)
8
+ from .registry import MultiModalRegistry
9
+
10
+ MULTIMODAL_REGISTRY = MultiModalRegistry()
11
+ """
12
+ The global :class:`~MultiModalRegistry` is used by model runners to
13
+ dispatch data processing according to the target model.
14
+
15
+ See also:
16
+ :ref:`mm-processing`
17
+ """
18
+
19
+ __all__ = [
20
+ "BatchedTensorInputs",
21
+ "ModalityData",
22
+ "MultiModalDataBuiltins",
23
+ "MultiModalDataDict",
24
+ "MultiModalHashDict",
25
+ "MultiModalHasher",
26
+ "MultiModalKwargs",
27
+ "MultiModalPlaceholderDict",
28
+ "MultiModalPlaceholderMap",
29
+ "MultiModalPlugin",
30
+ "NestedTensors",
31
+ "MULTIMODAL_REGISTRY",
32
+ "MultiModalRegistry",
33
+ ]
.venv/lib/python3.11/site-packages/vllm/multimodal/hasher.py ADDED
@@ -0,0 +1,102 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-License-Identifier: Apache-2.0
2
+
3
+ import pickle
4
+ from typing import TYPE_CHECKING, Iterable, Mapping, Optional
5
+
6
+ import numpy as np
7
+ import torch
8
+ from blake3 import blake3
9
+ from PIL import Image
10
+
11
+ from vllm.logger import init_logger
12
+
13
+ if TYPE_CHECKING:
14
+ from vllm.inputs import TokensPrompt
15
+
16
+ logger = init_logger(__name__)
17
+
18
+ MultiModalHashDict = Mapping[str, list[str]]
19
+ """
20
+ A dictionary containing hashes for items in each modality.
21
+ """
22
+
23
+
24
class MultiModalHasher:
    """Computes deterministic content hashes for multi-modal inputs."""

    @classmethod
    def serialize_item(cls, obj: object) -> bytes:
        """Convert a single item to a byte string for hashing."""
        # Items with a direct byte representation.
        if isinstance(obj, str):
            return obj.encode("utf-8")
        if isinstance(obj, bytes):
            return obj
        if isinstance(obj, Image.Image):
            return obj.tobytes()

        # Items convertible to a NumPy array: hash the raw buffer.
        if isinstance(obj, torch.Tensor):
            obj = obj.numpy()
        if isinstance(obj, (int, float)):
            obj = np.array(obj)
        if isinstance(obj, np.ndarray):
            return obj.tobytes()

        logger.warning(
            "No serialization method found for %s. "
            "Falling back to pickle.", type(obj))

        return pickle.dumps(obj)

    @classmethod
    def item_to_bytes(
        cls,
        key: str,
        obj: object,
    ) -> Iterable[tuple[bytes, bytes]]:
        """Yield ``(key, value)`` byte pairs, recursing into containers."""
        if isinstance(obj, (list, tuple)):
            for idx, child in enumerate(obj):
                yield from cls.item_to_bytes(f"{key}.{idx}", child)
        elif isinstance(obj, dict):
            for child_key, child in obj.items():
                yield from cls.item_to_bytes(f"{key}.{child_key}", child)
        else:
            yield cls.serialize_item(key), cls.serialize_item(obj)

    @classmethod
    def hash_kwargs(cls, **kwargs: object) -> str:
        """Return a blake3 hex digest over all keyword arguments."""
        hasher = blake3()
        for name, value in kwargs.items():
            for key_bytes, value_bytes in cls.item_to_bytes(name, value):
                hasher.update(key_bytes)
                hasher.update(value_bytes)
        return hasher.hexdigest()

    @classmethod
    def hash_prompt_mm_data(
            cls, prompt: "TokensPrompt") -> Optional["MultiModalHashDict"]:
        """Hash multimodal data in the user input prompt if they exist."""
        if "multi_modal_data" not in prompt:
            return None

        mm_data = prompt["multi_modal_data"]
        if not mm_data:
            # mm_data can be None or an empty dict.
            return None

        # Normalize each modality's value to a list, then hash per item.
        return {
            modality: [
                cls.hash_kwargs(**{modality: item})
                for item in (items if isinstance(items, list) else [items])
            ]
            for modality, items in mm_data.items()
        }
.venv/lib/python3.11/site-packages/vllm/multimodal/image.py ADDED
@@ -0,0 +1,139 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-License-Identifier: Apache-2.0
2
+
3
+ import base64
4
+ from functools import lru_cache
5
+ from io import BytesIO
6
+ from pathlib import Path
7
+ from typing import TYPE_CHECKING, Any, Dict, Optional
8
+
9
+ import torch
10
+ from PIL import Image
11
+
12
+ from vllm.inputs.registry import InputContext
13
+ from vllm.logger import init_logger
14
+ from vllm.transformers_utils.processor import get_image_processor
15
+ from vllm.utils import is_list_of
16
+
17
+ from .base import MediaIO, MultiModalPlugin
18
+ from .inputs import ImageItem, ModalityData, MultiModalKwargs
19
+
20
+ if TYPE_CHECKING:
21
+ from vllm.config import ModelConfig
22
+
23
+ logger = init_logger(__name__)
24
+
25
+ cached_get_image_processor = lru_cache(get_image_processor)
26
+
27
+
28
class ImagePlugin(MultiModalPlugin):
    """Plugin for image data."""

    def get_data_key(self) -> str:
        return "image"

    def _get_hf_image_processor(
        self,
        model_config: "ModelConfig",
        mm_processor_kwargs: Optional[Dict[str, Any]] = None,
    ):
        # Normalize None to an empty kwargs dict before expanding.
        extra_kwargs = mm_processor_kwargs if mm_processor_kwargs else {}
        return cached_get_image_processor(
            model_config.model,
            trust_remote_code=model_config.trust_remote_code,
            **extra_kwargs)

    def _default_input_mapper(
        self,
        ctx: InputContext,
        data: ModalityData[ImageItem],
        **mm_processor_kwargs,
    ) -> MultiModalKwargs:
        model_config = ctx.model_config

        # Tensor input is treated as precomputed image embeddings and
        # bypasses HF processing entirely.
        if isinstance(data, torch.Tensor) or is_list_of(data, torch.Tensor):
            return MultiModalKwargs({"image_embeds": data})

        # PIL image input goes through the HF image processor.
        if isinstance(data, Image.Image) or is_list_of(data, Image.Image):
            image_processor = self._get_hf_image_processor(
                model_config,
                mm_processor_kwargs,
            )
            if image_processor is None:
                raise RuntimeError("No HuggingFace processor is available "
                                   "to process the image object")
            try:
                # NOTE: mm_processor_kwargs are applied only when the
                # processor is constructed, not forwarded to preprocess(),
                # since the two signatures may not match.
                batch_data = image_processor.preprocess(
                    data, return_tensors="pt").data
            except Exception:
                logger.error(
                    "Failed to process image (%s) with the default mapper. "
                    "This is most likely an edge-case with this model's image "
                    "processor in transformers (type: %s), and not vLLM.",
                    data,
                    type(image_processor).__name__)
                raise
            return MultiModalKwargs(batch_data)

        raise TypeError(f"Invalid image type: {type(data)}")

    def _default_max_multimodal_tokens(self, ctx: InputContext) -> int:
        # Fixed upper bound used when the model does not override it.
        return 3000
92
+
93
+
94
def rescale_image_size(image: Image.Image,
                       size_factor: float,
                       transpose: int = -1) -> Image.Image:
    """Rescale the dimensions of an image by a constant factor."""
    scaled = image.resize((int(image.width * size_factor),
                           int(image.height * size_factor)))
    # A non-negative value selects a PIL transpose operation to apply.
    if transpose >= 0:
        scaled = scaled.transpose(Image.Transpose(transpose))
    return scaled
104
+
105
+
106
class ImageMediaIO(MediaIO[Image.Image]):
    """Load and encode PIL images in a fixed color mode (default RGB)."""

    def __init__(self, *, image_mode: str = "RGB") -> None:
        super().__init__()
        self.image_mode = image_mode

    def _open_and_convert(self, source) -> Image.Image:
        # Force decoding now, before the underlying source goes away.
        img = Image.open(source)
        img.load()
        return img.convert(self.image_mode)

    def load_bytes(self, data: bytes) -> Image.Image:
        return self._open_and_convert(BytesIO(data))

    def load_base64(self, media_type: str, data: str) -> Image.Image:
        return self.load_bytes(base64.b64decode(data))

    def load_file(self, filepath: Path) -> Image.Image:
        return self._open_and_convert(filepath)

    def encode_base64(
        self,
        media: Image.Image,
        *,
        image_format: str = "JPEG",
    ) -> str:
        """Encode an image into a base64 string in the given format."""
        with BytesIO() as buffer:
            media.convert(self.image_mode).save(buffer, image_format)
            payload = buffer.getvalue()
        return base64.b64encode(payload).decode('utf-8')
.venv/lib/python3.11/site-packages/vllm/multimodal/inputs.py ADDED
@@ -0,0 +1,741 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-License-Identifier: Apache-2.0
2
+
3
+ from abc import ABC, abstractmethod
4
+ from collections import UserDict, defaultdict
5
+ from collections.abc import Mapping, Sequence
6
+ from dataclasses import dataclass
7
+ from functools import partial
8
+ from itertools import accumulate
9
+ from typing import (TYPE_CHECKING, Any, Literal, Optional, TypedDict, TypeVar,
10
+ Union, cast, final)
11
+
12
+ import numpy as np
13
+ import torch
14
+ import torch.types
15
+ from PIL.Image import Image
16
+ from transformers import BatchFeature
17
+ from typing_extensions import NotRequired, TypeAlias
18
+
19
+ from vllm.utils import JSONTree, full_groupby, is_list_of, json_map_leaves
20
+
21
+ if TYPE_CHECKING:
22
+ from .hasher import MultiModalHashDict
23
+
24
+ _T = TypeVar("_T")
25
+
26
+ HfImageItem: TypeAlias = Union[Image, np.ndarray, torch.Tensor]
27
+ """
28
+ A :class:`transformers.image_utils.ImageInput` representing a single image
29
+ item, which can be passed to a HuggingFace :code:`ImageProcessor`.
30
+ """
31
+
32
+ HfVideoItem: TypeAlias = Union[list[Image], np.ndarray, torch.Tensor,
33
+ list[np.ndarray], list[torch.Tensor]]
34
+ """
35
+ A :class:`transformers.image_utils.VideoInput` representing a single video
36
+ item, which can be passed to a HuggingFace :code:`VideoProcessor`.
37
+ """
38
+
39
+ HfAudioItem: TypeAlias = Union[list[float], np.ndarray, torch.Tensor]
40
+ """
41
+ Represents a single audio
42
+ item, which can be passed to a HuggingFace :code:`AudioProcessor`.
43
+ """
44
+
45
+ ImageItem: TypeAlias = Union[HfImageItem, torch.Tensor]
46
+ """
47
+ A :class:`transformers.image_utils.ImageInput` representing a single image
48
+ item, which can be passed to a HuggingFace :code:`ImageProcessor`.
49
+
50
+ Alternatively, a 3-D tensor or batch of 2-D tensors,
51
+ which are treated as image embeddings;
52
+ these are directly passed to the model without HF processing.
53
+ """
54
+
55
+ VideoItem: TypeAlias = Union[HfVideoItem, torch.Tensor]
56
+ """
57
+ A :class:`transformers.image_utils.VideoInput` representing a single video
58
+ item, which can be passed to a HuggingFace :code:`VideoProcessor`.
59
+
60
+ Alternatively, a 3-D tensor or batch of 2-D tensors,
61
+ which are treated as video embeddings;
62
+ these are directly passed to the model without HF processing.
63
+ """
64
+
65
+ AudioItem: TypeAlias = Union[HfAudioItem, tuple[np.ndarray, float],
66
+ torch.Tensor]
67
+ """
68
+ Represents a single audio
69
+ item, which can be passed to a HuggingFace :code:`AudioProcessor`.
70
+
71
+ Alternatively, a tuple `(audio, sampling_rate)`, where the sampling rate
72
+ is different from that expected by the model;
73
+ these are resampled to the model's sampling rate before being processed by HF.
74
+
75
+ Alternatively, a 3-D tensor or batch of 2-D tensors,
76
+ which are treated as audio embeddings;
77
+ these are directly passed to the model without HF processing.
78
+ """
79
+
80
+ ModalityData: TypeAlias = Union[_T, list[_T]]
81
+ """
82
+ Either a single data item, or a list of data items.
83
+
84
+ The number of data items allowed per modality is restricted by
85
+ :code:`--limit-mm-per-prompt`.
86
+ """
87
+
88
+
89
+ @final
90
+ class MultiModalDataBuiltins(TypedDict, total=False):
91
+ """Type annotations for modality types predefined by vLLM."""
92
+
93
+ image: ModalityData[ImageItem]
94
+ """The input image(s)."""
95
+
96
+ video: ModalityData[VideoItem]
97
+ """The input video(s)."""
98
+
99
+ audio: ModalityData[AudioItem]
100
+ """The input audio(s)."""
101
+
102
+
103
+ MultiModalDataDict: TypeAlias = Mapping[str, ModalityData[Any]]
104
+ """
105
+ A dictionary containing an entry for each modality type to input.
106
+
107
+ The built-in modalities are defined by :class:`MultiModalDataBuiltins`.
108
+ """
109
+
110
+
111
class PlaceholderRange(TypedDict):
    """
    Placeholder location information for multi-modal data.

    Example:

        Prompt: :code:`AAAA BBBB What is in these images?`

        Images A and B will have:

        .. code-block::

            A: { "offset": 0, "length": 4 }
            B: { "offset": 5, "length": 4 }
    """

    # Start index of the placeholder in the prompt.
    offset: int

    # Number of positions the placeholder occupies.
    length: int
132
+
133
+
134
NestedTensors = Union[list["NestedTensors"], list[torch.Tensor], torch.Tensor,
                      tuple[torch.Tensor, ...]]
"""
Uses a list instead of a tensor if the dimensions of each element do not match.
"""


def nested_tensors_equal(a: NestedTensors, b: NestedTensors) -> bool:
    """Equality check between :data:`NestedTensors` objects."""
    # Tensor leaves: equal only if both are tensors with identical
    # shape and values.
    if isinstance(a, torch.Tensor) or isinstance(b, torch.Tensor):
        return (isinstance(a, torch.Tensor) and isinstance(b, torch.Tensor)
                and torch.equal(a, b))

    # Nested lists: compare elementwise (zip semantics, i.e. only up to
    # the shorter length — matches the original behavior).
    if isinstance(a, list) or isinstance(b, list):
        return (isinstance(a, list) and isinstance(b, list)
                and all(
                    nested_tensors_equal(x, y) for x, y in zip(a, b)))

    # Both a and b are scalars
    return a == b
157
+
158
+
159
+ BatchedTensorInputs: TypeAlias = Mapping[str, NestedTensors]
160
+ """
161
+ A dictionary containing nested tensors which have been batched via
162
+ :meth:`MultiModalKwargs.batch`.
163
+ """
164
+
165
+
166
@dataclass(frozen=True)
class MultiModalFieldElem:
    """
    Represents a keyword argument corresponding to a multi-modal item
    in :class:`MultiModalKwargs`.
    """

    # Modality of the corresponding multi-modal item; one item may
    # consist of multiple keyword arguments.
    modality: str

    # Keyword-argument name under which this field is passed to the model.
    key: str

    # Tensor payload passed to the model under ``key``.
    data: NestedTensors

    # Strategy describing how to batch this field with its siblings.
    field: "BaseMultiModalField"

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, self.__class__):
            return False
        same_identity = (self.modality, self.key) == (other.modality,
                                                      other.key)
        return (same_identity
                and nested_tensors_equal(self.data, other.data)
                and type(self.field) == type(other.field))  # noqa: E721
204
+
205
+
206
@dataclass(frozen=True)
class BaseMultiModalField(ABC):
    """
    Defines how to interpret tensor data belonging to a keyword argument in
    :class:`MultiModalKwargs` for multiple multi-modal items, and vice versa.
    """

    def _field_factory(self, *, modality: str, key: str):
        # Pre-bind everything except the tensor payload itself.
        make_elem = partial(
            MultiModalFieldElem,
            modality=modality,
            key=key,
            field=self,
        )

        def factory(data: NestedTensors) -> MultiModalFieldElem:
            # Convenience wrapper so the payload can be passed positionally.
            return make_elem(data=data)

        return factory

    @abstractmethod
    def build_elems(
        self,
        modality: str,
        key: str,
        data: NestedTensors,
    ) -> Sequence[MultiModalFieldElem]:
        """
        Construct :class:`MultiModalFieldElem` instances to represent
        the provided data.

        This is the inverse of :meth:`reduce_data`.
        """
        raise NotImplementedError

    @abstractmethod
    def _reduce_data(self, batch: list[NestedTensors]) -> NestedTensors:
        raise NotImplementedError

    def reduce_data(self, elems: list[MultiModalFieldElem]) -> NestedTensors:
        """
        Merge the data from multiple instances of :class:`MultiModalFieldElem`.

        This is the inverse of :meth:`build_elems`.
        """
        # All elements must share the same field strategy to be mergeable.
        field_types = [type(item.field) for item in elems]
        if len(set(field_types)) > 1:
            raise ValueError(f"Cannot merge different {field_types=}")
        return self._reduce_data([item.data for item in elems])
257
+
258
+
259
@dataclass(frozen=True)
class MultiModalBatchedField(BaseMultiModalField):
    """
    See also:
        :func:`MultiModalFieldConfig.batched`
    """

    def build_elems(
        self,
        modality: str,
        key: str,
        data: NestedTensors,
    ) -> Sequence[MultiModalFieldElem]:
        make_elem = self._field_factory(modality=modality, key=key)
        # One element per entry along the leading dimension.
        return [make_elem(item) for item in data]

    def _reduce_data(self, batch: list[NestedTensors]) -> NestedTensors:
        if batch and is_list_of(batch, torch.Tensor, check="all"):
            if len(batch) == 1:
                # Single-tensor fast path: identical result to
                # torch.stack(batch), zero-copy when contiguous.
                return batch[0].unsqueeze(0).contiguous()
            ref_shape = batch[0].shape
            if all(t.shape == ref_shape for t in batch):
                return torch.stack(batch)
        # Ragged shapes: keep as a plain list.
        return batch
287
+
288
+
289
@dataclass(frozen=True)
class MultiModalFlatField(BaseMultiModalField):
    """
    See also:
        :func:`MultiModalFieldConfig.flat`
        :func:`MultiModalFieldConfig.flat_from_sizes`
    """
    slices: Sequence[slice]

    def build_elems(
        self,
        modality: str,
        key: str,
        data: NestedTensors,
    ) -> Sequence[MultiModalFieldElem]:
        make_elem = self._field_factory(modality=modality, key=key)
        # One element per configured slice of the flat data.
        return [make_elem(data[sl]) for sl in self.slices]

    def _reduce_data(self, batch: list[NestedTensors]) -> NestedTensors:
        if batch and is_list_of(batch, torch.Tensor, check="all"):
            if len(batch) == 1:
                # Single-tensor fast path: identical result to
                # torch.concat(batch), zero-copy when contiguous.
                return batch[0].contiguous()
            ref_shape = batch[0].shape
            if all(t.shape[1:] == ref_shape[1:] for t in batch):
                return torch.concat(batch)
        # Mismatched trailing dims: flatten one nesting level instead.
        return [item for chunk in batch for item in chunk]
319
+
320
+
321
@dataclass(frozen=True)
class MultiModalSharedField(BaseMultiModalField):
    """
    A field whose data is shared (identical) across all items in the batch.

    See also:
        :func:`MultiModalFieldConfig.shared`
    """
    batch_size: int

    def build_elems(
        self,
        modality: str,
        key: str,
        data: NestedTensors,
    ) -> Sequence[MultiModalFieldElem]:
        field_factory = self._field_factory(modality=modality, key=key)
        # Every element references the same underlying data object.
        return [field_factory(data)] * self.batch_size

    def _reduce_data(self, batch: list[NestedTensors]) -> NestedTensors:
        # All entries are the same object, so any one of them suffices.
        return batch[0]
340
+
341
+
342
class MultiModalFieldConfig:
    """
    Pairs a :class:`BaseMultiModalField` with the modality it applies to.

    Use the static factory methods (:meth:`batched`, :meth:`flat`,
    :meth:`flat_from_sizes`, :meth:`shared`) to construct instances.
    """

    @staticmethod
    def batched(modality: str):
        """
        Defines a field where an element in the batch is obtained by
        indexing into the first dimension of the underlying data.

        Args:
            modality: The modality of the multi-modal item that uses this
                keyword argument.

        Example:

            .. code-block::

                Input:
                    Data: [[AAAA]
                        [BBBB]
                        [CCCC]]

                Output:
                    Element 1: [AAAA]
                    Element 2: [BBBB]
                    Element 3: [CCCC]
        """
        return MultiModalFieldConfig(
            field=MultiModalBatchedField(),
            modality=modality,
        )

    @staticmethod
    def flat(modality: str, slices: Sequence[slice]):
        """
        Defines a field where an element in the batch is obtained by
        slicing along the first dimension of the underlying data.

        Args:
            modality: The modality of the multi-modal item that uses this
                keyword argument.
            slices: For each multi-modal item, a slice that is used to extract
                the data corresponding to it.

        Example:

            .. code-block::

                Given:
                    slices: [slice(0, 3), slice(3, 7), slice(7, 9)]

                Input:
                    Data: [AAABBBBCC]

                Output:
                    Element 1: [AAA]
                    Element 2: [BBBB]
                    Element 3: [CC]
        """
        return MultiModalFieldConfig(
            field=MultiModalFlatField(slices=slices),
            modality=modality,
        )

    @staticmethod
    def flat_from_sizes(modality: str, size_per_item: torch.Tensor):
        """
        Defines a field where an element in the batch is obtained by
        slicing along the first dimension of the underlying data.

        Args:
            modality: The modality of the multi-modal item that uses this
                keyword argument.
            size_per_item: For each multi-modal item, the size of the slice
                that is used to extract the data corresponding to it.

        Example:

            .. code-block::

                Given:
                    size_per_item: [3, 4, 2]

                Input:
                    Data: [AAABBBBCC]

                Output:
                    Element 1: [AAA]
                    Element 2: [BBBB]
                    Element 3: [CC]

        See also:
            :func:`MultiModalFieldConfig.flat`
        """

        # Convert cumulative sizes into contiguous [start, end) slices.
        slice_idxs = [0, *accumulate(size_per_item)]
        slices = [
            slice(slice_idxs[i], slice_idxs[i + 1])
            for i in range(len(size_per_item))
        ]

        return MultiModalFieldConfig.flat(modality, slices)

    @staticmethod
    def shared(modality: str, batch_size: int):
        """
        Defines a field where an element in the batch is obtained by
        taking the entirety of the underlying data.

        This means that the data is the same for each element in the batch.

        Args:
            modality: The modality of the multi-modal item that uses this
                keyword argument.
            batch_size: The number of multi-modal items which share this data.

        Example:

            .. code-block::

                Given:
                    batch_size: 4

                Input:
                    Data: [XYZ]

                Output:
                    Element 1: [XYZ]
                    Element 2: [XYZ]
                    Element 3: [XYZ]
                    Element 4: [XYZ]
        """
        return MultiModalFieldConfig(
            field=MultiModalSharedField(batch_size),
            modality=modality,
        )

    def __init__(self, field: BaseMultiModalField, modality: str) -> None:
        super().__init__()

        self.field = field
        self.modality = modality

    def build_elems(
        self,
        key: str,
        batch: NestedTensors,
    ) -> Sequence[MultiModalFieldElem]:
        """Build one :class:`MultiModalFieldElem` per item under ``key``."""
        return self.field.build_elems(self.modality, key, batch)
490
+
491
+
492
class MultiModalKwargsItem(UserDict[str, MultiModalFieldElem]):
    """
    A collection of :class:`MultiModalFieldElem`
    corresponding to a data item in :class:`MultiModalDataItems`.
    """

    @staticmethod
    def from_elems(elems: Sequence[MultiModalFieldElem]):
        """Build an item keyed by each element's keyword-argument name."""
        return MultiModalKwargsItem({elem.key: elem for elem in elems})

    @property
    def modality(self) -> str:
        """The single modality shared by every element in this item."""
        modalities = {elem.modality for elem in self.data.values()}
        assert len(modalities) == 1, f"Found different modalities={modalities}"
        return next(iter(modalities))
507
+
508
+
509
# NOTE: UserDict is for V0 compatibility.
# V1 should access individual items via `get_item`.
class MultiModalKwargs(UserDict[str, NestedTensors]):
    """
    A dictionary that represents the keyword arguments to
    :meth:`~torch.nn.Module.forward`.

    The metadata :code:`items` enables us to obtain the keyword arguments
    corresponding to each data item in :class:`MultiModalDataItems`, via
    :meth:`get_item` and :meth:`get_items`.
    """

    @staticmethod
    def from_hf_inputs(
        hf_inputs: BatchFeature,
        config_by_key: Mapping[str, MultiModalFieldConfig],
    ):
        """
        Build from HF processor outputs using per-key field configs.

        Raises:
            ValueError: If keys of the same modality disagree on batch size.
        """
        # NOTE: This skips fields in `hf_inputs` that are not in `config_by_key`
        # We assume that those fields are not used in vLLM
        elems_by_key = dict[str, Sequence[MultiModalFieldElem]]()
        keys_by_modality = defaultdict[str, set[str]](set)
        for key, config in config_by_key.items():
            batch = hf_inputs.get(key)
            if batch is not None:
                elems = config.build_elems(key, batch)
                if len(elems) > 0:
                    elems_by_key[key] = elems
                    keys_by_modality[config.modality].add(key)

        items = list[MultiModalKwargsItem]()
        for modality, keys in keys_by_modality.items():
            elems_in_modality = {k: elems_by_key[k] for k in keys}
            batch_sizes = {k: len(v) for k, v in elems_in_modality.items()}

            if len(set(batch_sizes.values())) > 1:
                raise ValueError(
                    f"Cannot merge different batch sizes for {modality=}! "
                    f"Found: {batch_sizes=}")

            batch_size = next(iter(batch_sizes.values()))
            for item_idx in range(batch_size):
                elems = [v[item_idx] for v in elems_in_modality.values()]
                items.append(MultiModalKwargsItem.from_elems(elems))

        return MultiModalKwargs.from_items(items)

    @staticmethod
    def from_items(items: Sequence[MultiModalKwargsItem]):
        """Construct a new :class:`MultiModalKwargs` from multiple items."""
        elems_by_key = defaultdict[str, list[MultiModalFieldElem]](list)
        for item in items:
            for key, elem in item.items():
                elems_by_key[key].append(elem)

        data = {
            key: elems[0].field.reduce_data(elems)
            for key, elems in elems_by_key.items() if len(elems) > 0
        }

        return MultiModalKwargs(data, items=items)

    def __init__(
        self,
        data: Mapping[str, NestedTensors],
        *,
        items: Optional[Sequence[MultiModalKwargsItem]] = None,
    ) -> None:
        super().__init__(data)

        # Group items by modality so per-modality lookups are O(1).
        items_by_modality = full_groupby(items or [], key=lambda x: x.modality)
        self._items_by_modality = dict(items_by_modality)

    @property
    def modalities(self):
        return self._items_by_modality.keys()

    @staticmethod
    def _try_stack(nested_tensors: NestedTensors) -> NestedTensors:
        """
        Stack the inner dimensions that have the same shape in
        a nested list of tensors.

        Thus, a dimension represented by a list means that the inner
        dimensions are different for each element along that dimension.
        """
        if isinstance(nested_tensors, torch.Tensor):
            return nested_tensors

        # TODO: Remove these once all models have been migrated
        if isinstance(nested_tensors, np.ndarray):
            return torch.from_numpy(nested_tensors)
        if isinstance(nested_tensors, (int, float)):
            return torch.tensor(nested_tensors)

        stacked = [MultiModalKwargs._try_stack(t) for t in nested_tensors]
        if not is_list_of(stacked, torch.Tensor, check="all"):
            # Only tensors (not lists) can be stacked.
            return stacked

        tensors_ = cast(list[torch.Tensor], stacked)
        if len(tensors_) == 1:
            # An optimization when `tensors_` contains only one tensor:
            # - produce exactly same result as `torch.stack(tensors_)`
            # - will achieve zero-copy if the tensor is contiguous
            return tensors_[0].unsqueeze(0).contiguous()

        if any(t.shape != tensors_[0].shape for t in tensors_):
            # The tensors have incompatible shapes and can't be stacked.
            return tensors_

        return torch.stack(tensors_)

    @staticmethod
    def batch(inputs_list: list["MultiModalKwargs"]) -> BatchedTensorInputs:
        """
        Batch multiple inputs together into a dictionary.

        The resulting dictionary has the same keys as the inputs.
        If the corresponding value from each input is a tensor and they all
        share the same shape, the output value is a single batched tensor;
        otherwise, the output value is a list containing the original value
        from each input.
        """
        if len(inputs_list) == 0:
            return {}

        # We need to consider the case where each item in the batch
        # contains different modalities (i.e. different keys).
        item_lists = defaultdict[str, list[NestedTensors]](list)

        for inputs in inputs_list:
            for k, v in inputs.items():
                item_lists[k].append(v)

        return {
            k: MultiModalKwargs._try_stack(item_list)
            for k, item_list in item_lists.items()
        }

    @staticmethod
    def as_kwargs(
        batched_inputs: BatchedTensorInputs,
        *,
        device: torch.types.Device,
    ) -> BatchedTensorInputs:
        """Move every tensor leaf of `batched_inputs` onto `device`."""
        json_inputs = cast(JSONTree[torch.Tensor], batched_inputs)

        json_mapped = json_map_leaves(
            lambda x: x.to(device, non_blocking=True),
            json_inputs,
        )

        return cast(BatchedTensorInputs, json_mapped)

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, self.__class__):
            return False
        if self._items_by_modality != other._items_by_modality:
            return False

        ks = self.keys()
        return (ks == other.keys()
                and all(nested_tensors_equal(self[k], other[k]) for k in ks))

    def _validate_modality(self, method_name: str, modality: str) -> None:
        # `items` must have been supplied at construction for item access.
        if not self._items_by_modality:
            raise RuntimeError(
                f"`{method_name}` is not supported when "
                "MultiModalKwargs is not initialized with `items`")

        if modality not in self._items_by_modality:
            available_modalities = set(self._items_by_modality.keys())
            raise KeyError(f"Modality {modality!r} not found. "
                           f"Available modalities: {available_modalities}")

    def get_item_count(self, modality: str) -> int:
        """Get the number of items belonging to a modality."""
        self._validate_modality("get_item_count", modality)
        return len(self._items_by_modality[modality])

    def get_item(self, modality: str, item_index: int) -> MultiModalKwargsItem:
        """
        Get the keyword arguments corresponding to an item identified by
        its modality and index.
        """
        self._validate_modality("get_item", modality)
        return self._items_by_modality[modality][item_index]

    def get_items(self, modality: str) -> Sequence[MultiModalKwargsItem]:
        """
        Get the keyword arguments corresponding to each item belonging to
        a modality.
        """
        self._validate_modality("get_items", modality)
        return self._items_by_modality[modality]
704
+
705
+
706
MultiModalPlaceholderDict = Mapping[str, Sequence[PlaceholderRange]]
"""
A dictionary containing placeholder ranges for each modality.
"""
710
+
711
+
712
class MultiModalInputs(TypedDict):
    """
    Represents the outputs of
    :class:`vllm.multimodal.processing.BaseMultiModalProcessor`,
    ready to be passed to vLLM internals.
    """

    type: Literal["multimodal"]
    """The type of inputs."""

    prompt: str
    """The processed prompt text."""

    prompt_token_ids: list[int]
    """The processed token IDs which includes placeholder tokens."""

    token_type_ids: NotRequired[list[int]]
    """The token type IDs of the prompt."""

    mm_kwargs: MultiModalKwargs
    """Keyword arguments to be directly passed to the model after batching."""

    mm_hashes: NotRequired[Optional["MultiModalHashDict"]]
    """The hashes of the multi-modal data."""

    mm_placeholders: MultiModalPlaceholderDict
    """
    For each modality, information about the placeholder tokens in
    :code:`prompt_token_ids`.
    """
.venv/lib/python3.11/site-packages/vllm/multimodal/parse.py ADDED
@@ -0,0 +1,368 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-License-Identifier: Apache-2.0
2
+
3
+ from abc import ABC, abstractmethod
4
+ from collections import UserDict
5
+ from collections.abc import Callable, Iterator, Mapping, Sequence
6
+ from typing import (TYPE_CHECKING, Any, Generic, NamedTuple, Optional, TypeVar,
7
+ Union)
8
+
9
+ import numpy as np
10
+ import torch
11
+ from PIL.Image import Image
12
+ from typing_extensions import TypeAlias, TypeGuard, assert_never
13
+
14
+ from vllm.utils import is_list_of
15
+
16
+ from .audio import resample_audio
17
+ from .inputs import (AudioItem, HfAudioItem, HfImageItem, HfVideoItem,
18
+ ImageItem, ModalityData, MultiModalDataDict, VideoItem)
19
+
20
_T = TypeVar("_T")
_I = TypeVar("_I")


class ModalityDataItems(ABC, Generic[_T, _I]):
    """
    Represents data items for a modality in :class:`MultiModalDataItems`.

    Type parameters:
        ``_T``: the container type holding the raw data.
        ``_I``: the type of an individual item.
    """

    def __init__(self, data: _T, modality: str) -> None:
        super().__init__()

        self.data = data
        self.modality = modality

    def __repr__(self) -> str:
        return (f"{type(self).__name__}(modality={self.modality!r}, "
                f"len={len(self)})")

    def __len__(self) -> int:
        return self.get_count()

    def __getitem__(self, index: int) -> _I:
        return self.get(index)

    if TYPE_CHECKING:
        # Auto-generated
        def __iter__(self) -> Iterator[_I]:
            ...

    @abstractmethod
    def get_count(self) -> int:
        """Get the number of data items."""
        raise NotImplementedError

    @abstractmethod
    def get(self, index: int) -> _I:
        """Get a data item by its index."""
        raise NotImplementedError

    def get_all(self) -> list[_I]:
        """Get all data items."""
        return [self.get(idx) for idx in range(self.get_count())]

    @abstractmethod
    def get_processor_data(self) -> Mapping[str, object]:
        """Get the data to pass to the HF processor."""
        raise NotImplementedError

    @abstractmethod
    def get_passthrough_data(self) -> Mapping[str, object]:
        """Get the data to pass directly to the model."""
        raise NotImplementedError
73
+
74
+
75
class ProcessorBatchItems(ModalityDataItems[Sequence[_T], _T]):
    """Base class for data items that are arranged in a list."""

    def get_count(self) -> int:
        return len(self.data)

    def get(self, index: int) -> _T:
        return self.data[index]

    def get_processor_data(self) -> Mapping[str, object]:
        # HF processors expect the pluralized modality name as the key.
        return {f"{self.modality}s": self.data}

    def get_passthrough_data(self) -> Mapping[str, object]:
        return {}
89
+
90
+
91
class EmbeddingItems(ModalityDataItems[Union[torch.Tensor, list[torch.Tensor]],
                                       torch.Tensor]):
    """
    Base class for data items that are expressed as a batched embedding tensor,
    or a list of embedding tensors (one per item).
    """

    def get_count(self) -> int:
        # Works for both cases: first dim of a batched tensor,
        # or the length of a list of per-item tensors.
        return len(self.data)

    def get(self, index: int) -> torch.Tensor:
        return self.data[index]

    def get_processor_data(self) -> Mapping[str, object]:
        # Embeddings bypass the HF processor entirely.
        return {}

    def get_passthrough_data(self) -> Mapping[str, object]:
        return {f"{self.modality}_embeds": self.data}

    def get_feature_size(self, item_idx: int) -> int:
        """Number of feature positions occupied by the item's embedding."""
        return len(self.get(item_idx))
112
+
113
+
114
class AudioProcessorItems(ProcessorBatchItems[HfAudioItem]):
    """Audio items to be fed through the HF processor."""

    def __init__(self, data: Sequence[HfAudioItem]) -> None:
        super().__init__(data, "audio")
118
+
119
+
120
class AudioEmbeddingItems(EmbeddingItems):
    """Precomputed audio embeddings passed directly to the model."""

    def __init__(self, data: Union[torch.Tensor, list[torch.Tensor]]) -> None:
        super().__init__(data, "audio")
124
+
125
+
126
class ImageSize(NamedTuple):
    """Image dimensions as ``(width, height)``, matching PIL's convention."""
    width: int
    height: int
129
+
130
+
131
class ImageProcessorItems(ProcessorBatchItems[HfImageItem]):
    """Image items to be fed through the HF processor."""

    def __init__(self, data: Sequence[HfImageItem]) -> None:
        super().__init__(data, "image")

    def get_image_size(self, item_idx: int) -> ImageSize:
        """Return the ``(width, height)`` of the image at ``item_idx``."""
        image = self.get(item_idx)

        if isinstance(image, Image):
            return ImageSize(*image.size)
        if isinstance(image, (np.ndarray, torch.Tensor)):
            # Arrays/tensors are channels-first: (C, H, W).
            _, h, w = image.shape
            return ImageSize(w, h)

        assert_never(image)
146
+
147
+
148
class ImageEmbeddingItems(EmbeddingItems):
    """Precomputed image embeddings passed directly to the model."""

    def __init__(self, data: Union[torch.Tensor, list[torch.Tensor]]) -> None:
        super().__init__(data, "image")
152
+
153
+
154
class VideoProcessorItems(ProcessorBatchItems[HfVideoItem]):
    """Video items to be fed through the HF processor."""

    def __init__(self, data: Sequence[HfVideoItem]) -> None:
        super().__init__(data, "video")

    def get_num_frames(self, item_idx: int) -> int:
        """Return the number of frames in the video at ``item_idx``."""
        return len(self.get(item_idx))

    def get_frame_size(self, item_idx: int) -> ImageSize:
        """Return the ``(width, height)`` of the video's frames."""
        image = self.get(item_idx)[0]  # Assume that the video isn't empty

        if isinstance(image, Image):
            return ImageSize(*image.size)
        if isinstance(image, (np.ndarray, torch.Tensor)):
            # Frames are channels-first: (C, H, W).
            _, h, w = image.shape
            return ImageSize(w, h)

        assert_never(image)
172
+
173
+
174
class VideoEmbeddingItems(EmbeddingItems):
    """Precomputed video embeddings passed directly to the model."""

    def __init__(self, data: Union[torch.Tensor, list[torch.Tensor]]) -> None:
        super().__init__(data, "video")
178
+
179
+
180
_D = TypeVar("_D", bound=ModalityDataItems[Any, Any])


class MultiModalDataItems(UserDict[str, ModalityDataItems[Any, Any]]):
    """
    As :data:`~vllm.multimodal.inputs.MultiModalDataDict`, but normalized
    such that each entry corresponds to a list.
    """

    def get_count(self, modality: str, *, strict: bool = True) -> int:
        """
        Get the number of data items belonging to a modality.

        If `strict=False`, return `0` instead of raising :exc:`KeyError`
        even if the modality is not found.
        """
        if modality not in self:
            if strict:
                available_modalities = set(self.keys())
                raise KeyError(f"Modality {modality!r} not found. "
                               f"Available modalities: {available_modalities}")

            return 0

        return self[modality].get_count()

    def get_all_counts(self) -> Mapping[str, int]:
        """Get the number of items belonging to each modality."""
        return {m: items.get_count() for m, items in self.items()}

    def get_items(
        self,
        modality: str,
        typ: Union[type[_D], tuple[type[_D], ...]],
    ) -> _D:
        """
        Get the data items belonging to a modality,
        requiring that they belong to a certain type.

        Raises:
            KeyError: If the modality is not present.
            TypeError: If the items are not an instance of ``typ``.
        """
        if modality not in self:
            available_modalities = set(self.keys())
            raise KeyError(f"Modality {modality!r} not found. "
                           f"Available modalities: {available_modalities}")

        items = self[modality]
        if not isinstance(items, typ):
            raise TypeError(f"Invalid type of data items for {modality=}. "
                            f"Expected type: {typ}, but "
                            f"found type: {type(items)}")

        return items  # type: ignore[return-value]
231
+
232
+
233
# Callable that converts raw per-modality data into normalized items.
ModalityDataParser: TypeAlias = Callable[[ModalityData[Any]],
                                         ModalityDataItems[Any, Any]]
235
+
236
+
237
class MultiModalDataParser:
    """
    Parses :data:`~vllm.multimodal.inputs.MultiModalDataDict` into
    :class:`MultiModalDataItems`.

    Args:
        target_sr (float, optional): Enables automatic resampling of audio
            items to the model's expected sampling rate.
    """

    def __init__(self, *, target_sr: Optional[float] = None) -> None:
        super().__init__()

        self.target_sr = target_sr

    def _is_embeddings(
        self, data: object
    ) -> TypeGuard[Union[torch.Tensor, list[torch.Tensor]]]:
        # A batched embedding tensor is 3-D; a list holds 2-D tensors.
        if isinstance(data, torch.Tensor):
            return data.ndim == 3
        if is_list_of(data, torch.Tensor):
            return len(data) == 0 or data[0].ndim == 2

        return False

    def _get_audio_with_sr(
        self,
        audio: AudioItem,
    ) -> tuple[np.ndarray, Optional[float]]:
        """Normalize an audio item to ``(waveform, sampling_rate_or_None)``."""
        if isinstance(audio, tuple):
            return audio
        if isinstance(audio, list):
            return np.array(audio), None
        if isinstance(audio, np.ndarray):
            return audio, None
        if isinstance(audio, torch.Tensor):
            return audio.numpy(), None

        assert_never(audio)

    def _parse_audio_data(
        self,
        data: ModalityData[AudioItem],
    ) -> ModalityDataItems[Any, Any]:
        if self._is_embeddings(data):
            return AudioEmbeddingItems(data)

        # A single item: 1-D waveform, list of floats, or (waveform, sr) tuple.
        if (is_list_of(data, float)
                or isinstance(data,
                              (np.ndarray, torch.Tensor)) and data.ndim == 1
                or isinstance(data, tuple)):
            data_items = [data]
        elif isinstance(data, (np.ndarray, torch.Tensor)):
            data_items = [elem for elem in data]
        else:
            data_items = data

        new_audios = list[np.ndarray]()
        for data_item in data_items:
            audio, orig_sr = self._get_audio_with_sr(data_item)
            if orig_sr is None:
                new_audio = audio
            else:
                target_sr = self.target_sr
                if target_sr is None:
                    raise RuntimeError(
                        "Audio resampling is not supported when "
                        "`target_sr` is not provided")

                new_audio = resample_audio(audio,
                                           orig_sr=orig_sr,
                                           target_sr=target_sr)

            new_audios.append(new_audio)

        return AudioProcessorItems(new_audios)

    def _parse_image_data(
        self,
        data: ModalityData[ImageItem],
    ) -> ModalityDataItems[Any, Any]:
        if self._is_embeddings(data):
            return ImageEmbeddingItems(data)

        # A single item: PIL image or a 3-D (C, H, W) array/tensor.
        if (isinstance(data, Image)
                or isinstance(data,
                              (np.ndarray, torch.Tensor)) and data.ndim == 3):
            data_items = [data]
        elif isinstance(data, (np.ndarray, torch.Tensor)):
            data_items = [elem for elem in data]
        else:
            data_items = data

        return ImageProcessorItems(data_items)

    def _parse_video_data(
        self,
        data: ModalityData[VideoItem],
    ) -> ModalityDataItems[Any, Any]:
        if self._is_embeddings(data):
            return VideoEmbeddingItems(data)

        # A single item: list of frames or a 4-D (T, C, H, W) array/tensor.
        if (is_list_of(data, Image)
                or isinstance(data,
                              (np.ndarray, torch.Tensor)) and data.ndim == 4):
            data_items = [data]
        elif isinstance(data, (np.ndarray, torch.Tensor)):
            data_items = [elem for elem in data]
        else:
            data_items = data

        return VideoProcessorItems(data_items)

    def _get_subparsers(self) -> Mapping[str, ModalityDataParser]:
        """Map each supported modality name to its parsing function."""
        return {
            "audio": self._parse_audio_data,
            "image": self._parse_image_data,
            "video": self._parse_video_data,
        }

    def parse_mm_data(self,
                      mm_data: MultiModalDataDict) -> MultiModalDataItems:
        """
        Parse raw multi-modal data into normalized items.

        Raises:
            ValueError: If an unsupported modality key is encountered.
        """
        subparsers = self._get_subparsers()

        mm_items = MultiModalDataItems()
        for k, v in mm_data.items():
            if k not in subparsers:
                raise ValueError(f"Unsupported modality: {k}")

            mm_items[k] = subparsers[k](v)

        return mm_items
.venv/lib/python3.11/site-packages/vllm/multimodal/processing.py ADDED
@@ -0,0 +1,1295 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-License-Identifier: Apache-2.0
2
+
3
+ import re
4
+ from abc import ABC, abstractmethod
5
+ from collections import defaultdict
6
+ from collections.abc import (Callable, Generator, ItemsView, Iterable, Mapping,
7
+ Sequence)
8
+ from dataclasses import dataclass, field
9
+ from functools import lru_cache
10
+ from typing import (TYPE_CHECKING, Generic, NamedTuple, Optional, Protocol,
11
+ TypeVar, Union)
12
+
13
+ from transformers import BatchFeature, PretrainedConfig, ProcessorMixin
14
+
15
+ import vllm.envs as envs
16
+ from vllm.inputs import InputProcessingContext
17
+ from vllm.logger import init_logger
18
+ from vllm.transformers_utils.tokenizer import (AnyTokenizer, decode_tokens,
19
+ encode_tokens)
20
+ from vllm.utils import LRUCache, flatten_2d_lists, full_groupby
21
+
22
+ from .hasher import MultiModalHasher
23
+ from .inputs import (MultiModalDataDict, MultiModalFieldConfig,
24
+ MultiModalInputs, MultiModalKwargs, MultiModalKwargsItem,
25
+ PlaceholderRange)
26
+ from .parse import MultiModalDataItems, MultiModalDataParser
27
+
28
+ if TYPE_CHECKING:
29
+ from .profiling import BaseDummyInputsBuilder
30
+
31
+ logger = init_logger(__name__)
32
+
33
+ _S = TypeVar("_S", str, list[int])
34
+
35
+ PromptSeq = Union[str, list[int]]
36
+ """A token sequence (list of token IDs) or text."""
37
+
38
+
39
@dataclass
class PromptReplacementDetails:
    """Details about the replacement token sequence or text."""

    full: PromptSeq
    """The full replacement."""

    features: PromptSeq
    """
    The part of the replacement that corresponds to feature placeholders;
    this will be replaced by the output of the vision encoder during model
    inference.
    """

    @staticmethod
    def from_seq(seq: PromptSeq) -> "PromptReplacementDetails":
        """Build details where the whole sequence is the feature part."""
        return PromptReplacementDetails(full=seq, features=seq)
56
+
57
+
58
PromptRepl = Union[PromptSeq, PromptReplacementDetails]
"""
The replacement token sequence or text.

If only part of the replacement corresponds to feature placeholders, you can
use :class:`PromptReplacementDetails` to specify which part.
"""
65
+
66
+
67
@dataclass
class PromptReplacement:
    """
    Defines how to replace portions of an input prompt with placeholder tokens.

    Example:

        For each image, replace one ``<image>`` input placeholder in the prompt
        with a number of ``<image>`` feature placeholders
        equal to the feature size of the vision encoder:

        .. code-block:: python

            PromptReplacement(
                modality="image",
                target="<image>",
                replacement="<image>" * image_feature_size,
            )

        As above, but further pad the feature placeholders with ``<image_bos>``
        and ``<image_eos>``, which are not supposed to be passed to the vision
        encoder:

        .. code-block:: python

            PromptReplacement(
                modality="image",
                target="<image>",
                replacement=PromptReplacementDetails(
                    full="".join([
                        "<image_bos>",
                        "<image>" * image_feature_size,
                        "<image_eos>",
                    ]),
                    features="<image>" * image_feature_size,
                ),
            )

        To avoid unnecessary tokenization during prompt replacement,
        we recommended passing token sequences instead of text:

        .. code-block:: python

            PromptReplacement(
                modality="image",
                target=[image_token_id],
                replacement=PromptReplacementDetails(
                    full=([image_bos_id] + [image_token_id] * image_feature_size
                          + [image_eos_id]),
                    features=[image_token_id] * image_feature_size,
                ),
            )
    """

    modality: str
    """The modality for which the replacement is made."""

    target: PromptSeq
    """The token sequence (or text) to find and replace."""

    replacement: Union[Callable[[int], PromptRepl],
                       PromptRepl] = field(repr=False)
    """
    Given the index of the processed item within :attr:`modality`,
    output the replacement token sequence (or text).

    For convenience, you can directly pass in the replacement token sequence
    (or text) instead of a function if it does not depend on the input.
    """

    def bind(self, tokenizer: AnyTokenizer) -> "BoundPromptReplacement":
        """Attach a tokenizer so text/token conversions can be done lazily."""
        return BoundPromptReplacement(
            tokenizer=tokenizer,
            modality=self.modality,
            _target=self.target,
            _replacement=self.replacement,
        )
144
+
145
+
146
@lru_cache(maxsize=2048)
def _cached_encode(
    tokenizer: AnyTokenizer,
    text: str,
    *,
    add_special_tokens: bool = False,
) -> list[int]:
    # Memoized wrapper around `encode_tokens`. The cache is keyed on the
    # tokenizer instance itself, so entries keep that tokenizer object alive
    # for the lifetime of the cache.
    return encode_tokens(tokenizer,
                         text,
                         add_special_tokens=add_special_tokens)
156
+
157
+
158
@lru_cache(maxsize=2048)
def _cached_decode(
    tokenizer: AnyTokenizer,
    token_ids: tuple[int, ...],
    *,
    skip_special_tokens: bool = False,
) -> str:
    # Memoized wrapper around `decode_tokens`. Token IDs are received as a
    # tuple (hashable) and converted back to the list that the tokenizer
    # API expects.
    return decode_tokens(tokenizer,
                         list(token_ids),
                         skip_special_tokens=skip_special_tokens)
168
+
169
+
170
class _HasModalityAttr(Protocol):
    # Structural type: any object exposing a ``modality`` attribute.
    modality: str


class _HasModalityProp(Protocol):
    # Structural type: any object exposing a read-only ``modality`` property.

    @property
    def modality(self) -> str:
        ...


# Anything carrying a `modality`, whether as an attribute or a property.
_M = TypeVar("_M", bound=Union[_HasModalityAttr, _HasModalityProp])
182
+
183
+
184
def full_groupby_modality(values: Iterable[_M]) -> ItemsView[str, list[_M]]:
    """Convenience function to apply :func:`full_groupby` based on modality."""

    def _modality_key(value: _M) -> str:
        return value.modality

    return full_groupby(values, key=_modality_key)
187
+
188
+
189
@dataclass
class _BoundPromptSequence:
    """
    A :data:`_PromptSeq` bound to a tokenizer to automatically
    convert between token sequence and text representations.

    Whichever representation was not supplied at construction time is
    derived lazily (and cached) on first access.
    """
    tokenizer: AnyTokenizer = field(repr=False)

    _text: Optional[str]
    _token_ids: Optional[list[int]]

    @staticmethod
    def from_seq(
        tokenizer: AnyTokenizer,
        seq: PromptSeq,
    ) -> "_BoundPromptSequence":
        # A plain string populates `_text`; a token-ID list populates
        # `_token_ids`. Anything else leaves both unset and is rejected
        # by `__post_init__`.
        text = seq if isinstance(seq, str) else None
        token_ids = seq if isinstance(seq, list) else None

        return _BoundPromptSequence(
            tokenizer=tokenizer,
            _text=text,
            _token_ids=token_ids,
        )

    def __post_init__(self) -> None:
        if self._text is None and self._token_ids is None:
            raise ValueError("At least one of 'text' and 'token_ids' must be "
                             "specified")

    @property
    def text(self) -> str:
        """Text form, decoded (and cached) from the token IDs on first use."""
        text = self._text
        if text is None:
            assert self._token_ids is not None
            text = _cached_decode(self.tokenizer, tuple(self._token_ids))
            self._text = text

        return text

    @property
    def token_ids(self) -> list[int]:
        """Token-ID form, encoded (and cached) from the text on first use."""
        token_ids = self._token_ids
        if token_ids is None:
            assert self._text is not None
            token_ids = _cached_encode(self.tokenizer, self._text)
            self._token_ids = token_ids

        return token_ids
231
+
232
+
233
@dataclass
class _BoundPromptReplacementGroup:
    # The full replacement sequence, including any surrounding tokens that
    # are not fed to the multi-modal encoder.
    full: _BoundPromptSequence
    # The sub-sequence of `full` that feature placeholders occupy.
    features: _BoundPromptSequence
237
+
238
+
239
@dataclass
class BoundPromptReplacement:
    """
    A :class:`PromptReplacement` bound to a tokenizer to automatically
    convert :attr:`target` and the result of :meth:`get_replacement` between
    token sequence and text representations.
    """
    tokenizer: AnyTokenizer = field(repr=False)
    modality: str

    _target: PromptSeq
    _replacement: Union[Callable[[int], PromptRepl],
                        PromptRepl] = field(repr=False)

    def __post_init__(self) -> None:
        # Memoizes bound replacements per item index. Only populated when
        # `_replacement` is a callable; static replacements are cheap to
        # rebuild on every call.
        self._replacement_cache = dict[int, _BoundPromptReplacementGroup]()

    @property
    def target(self) -> _BoundPromptSequence:
        """The token sequence (or text) to find and replace."""
        return _BoundPromptSequence.from_seq(self.tokenizer, self._target)

    def get_replacement(self, item_idx: int) -> _BoundPromptReplacementGroup:
        """
        Given the index of the processed item within :attr:`modality`,
        output the replacement token sequence (or text).
        """
        repl = self._replacement

        if callable(repl):
            cached = self._replacement_cache.get(item_idx)
            if cached is not None:
                return cached

            repl = repl(item_idx)
            should_cache = True
        else:
            should_cache = False

        if not isinstance(repl, PromptReplacementDetails):
            repl = PromptReplacementDetails.from_seq(repl)

        bound_repl = _BoundPromptReplacementGroup(
            full=_BoundPromptSequence.from_seq(self.tokenizer, repl.full),
            features=_BoundPromptSequence.from_seq(self.tokenizer,
                                                   repl.features),
        )

        if should_cache:
            self._replacement_cache[item_idx] = bound_repl

        return bound_repl
292
+
293
+
294
class _TokenMatch(NamedTuple):
    # Half-open span [start_idx, end_idx) of one match within the prompt.
    start_idx: int
    end_idx: int


def iter_token_matches(
    token_ids: list[int],
    match_ids: list[int],
) -> Generator[_TokenMatch]:
    """
    Yield each occurrence of :code:`match_ids` in :code:`token_ids`.

    Matches are found left to right and never overlap.
    Note that empty matches are ignored.
    """
    pattern_len = len(match_ids)
    if pattern_len == 0:
        return

    last_start = len(token_ids) - pattern_len
    idx = 0
    while idx <= last_start:
        if token_ids[idx:idx + pattern_len] == match_ids:
            yield _TokenMatch(start_idx=idx, end_idx=idx + pattern_len)
            # Jump past this match so that matches never overlap
            idx += pattern_len
        else:
            idx += 1
325
+
326
+
327
@dataclass(repr=False)
class _PromptReplacementMatch(ABC):
    """One occurrence of a replacement target located in the prompt."""
    prompt_repl: BoundPromptReplacement

    @property
    def modality(self) -> str:
        return self.prompt_repl.modality

    @property
    @abstractmethod
    def start_idx(self) -> int:
        """Start of the matched span (inclusive)."""
        raise NotImplementedError

    @property
    @abstractmethod
    def end_idx(self) -> int:
        """End of the matched span (exclusive)."""
        raise NotImplementedError

    def __repr__(self) -> str:
        return (f"{type(self).__name__}(modality={self.modality!r}, "
                f"start_idx={self.start_idx!r}, end_idx={self.end_idx!r})")
348
+
349
+
350
@dataclass(repr=False)
class _PromptReplacementTokenMatch(_PromptReplacementMatch):
    """A match found by token-ID comparison (indices are token positions)."""
    match: _TokenMatch

    @property
    def start_idx(self) -> int:
        return self.match.start_idx

    @property
    def end_idx(self) -> int:
        return self.match.end_idx
361
+
362
+
363
@dataclass(repr=False)
class _PromptReplacementTextMatch(_PromptReplacementMatch):
    """A match found by text search (indices are character positions)."""
    match: re.Match[str]

    @property
    def start_idx(self) -> int:
        return self.match.start()

    @property
    def end_idx(self) -> int:
        return self.match.end()
374
+
375
+
376
@dataclass
class PlaceholderFeaturesInfo:
    """Location of one item's feature placeholder tokens in the prompt."""
    modality: str
    item_idx: int
    # Token index in the prompt where the feature placeholders begin
    start_idx: int
    # The feature placeholder token IDs themselves
    tokens: list[int]

    @property
    def length(self) -> int:
        return len(self.tokens)

    def to_range(self) -> PlaceholderRange:
        # Convert to the offset/length form used elsewhere in vLLM.
        return PlaceholderRange(
            offset=self.start_idx,
            length=self.length,
        )
392
+
393
+
394
def find_token_matches(
    prompt: list[int],
    prompt_repls: Sequence[BoundPromptReplacement],
) -> list[_PromptReplacementTokenMatch]:
    """Return each target of :code:`prompt_repls` found in :code:`prompt`."""
    matches = list[_PromptReplacementTokenMatch]()

    for prompt_repl in prompt_repls:
        target_ids = prompt_repl.target.token_ids
        for token_match in iter_token_matches(prompt, target_ids):
            matches.append(
                _PromptReplacementTokenMatch(prompt_repl, token_match))

    return matches
404
+
405
+
406
def find_text_matches(
    prompt: str,
    prompt_repls: Sequence[BoundPromptReplacement],
) -> list[_PromptReplacementTextMatch]:
    """Return each target of :code:`prompt_repls` found in :code:`prompt`."""
    matches = list[_PromptReplacementTextMatch]()

    for prompt_repl in prompt_repls:
        # Escape the target so it is matched literally, not as a regex
        pattern = re.escape(prompt_repl.target.text)
        for text_match in re.finditer(pattern, prompt):
            matches.append(_PromptReplacementTextMatch(prompt_repl,
                                                       text_match))

    return matches
416
+
417
+
418
def _resolve_matches(
    prompt: PromptSeq,
    mm_matches: Mapping[str, Sequence[_PromptReplacementMatch]],
) -> list[_PromptReplacementMatch]:
    """
    Resolve :code:`mm_matches` to ensure that there are no overlapping matches,
    and sort them such that earlier matches take priority over later ones.
    """
    all_matches: list[_PromptReplacementMatch] = []
    for modality_matches in mm_matches.values():
        all_matches.extend(modality_matches)

    # Each prompt position may be claimed by at most one match
    claimed_by: list[Optional[_PromptReplacementMatch]] = [None] * len(prompt)

    for match in all_matches:
        for idx in range(match.start_idx, match.end_idx):
            prev = claimed_by[idx]
            if prev is not None:
                raise ValueError("Found overlapping matches "
                                 f"({prev} and {match}) "
                                 f"at index={idx} of prompt={prompt}")

            claimed_by[idx] = match

    return sorted(all_matches, key=lambda m: m.start_idx)
441
+
442
+
443
def _replace_matches(
    prompt: _S,
    mm_matches: Mapping[str, Sequence[_PromptReplacementMatch]],
    mm_item_counts: Mapping[str, int],
) -> list[_S]:
    """
    Apply the replacements in :code:`mm_matches` to :code:`prompt`.

    Works generically on either a text prompt (``str``) or a token-ID
    prompt (``list[int]``); returns the resulting segments, which the
    caller is expected to concatenate.
    """
    out_seqs = list[_S]()
    prev_end_idx = 0
    # Tracks how many items of each modality have been consumed so far
    next_idx_by_modality = defaultdict[str, int](lambda: 0)

    for match in _resolve_matches(prompt, mm_matches):
        modality = match.modality

        item_idx = next_idx_by_modality[modality]
        # Ignore surplus matches beyond the number of actual data items
        if item_idx >= mm_item_counts.get(modality, 0):
            continue

        start_idx = match.start_idx
        end_idx = match.end_idx

        repl_info = match.prompt_repl
        replacement = repl_info.get_replacement(item_idx)

        # Emit the unreplaced prefix plus the (full) replacement, in the
        # same representation as the input prompt
        if isinstance(prompt, str):
            repl_seq = replacement.full.text
            out_seqs.append(prompt[prev_end_idx:start_idx] + repl_seq)
        else:
            repl_seq = replacement.full.token_ids
            out_seqs.append(prompt[prev_end_idx:start_idx] + repl_seq)

        prev_end_idx = end_idx
        next_idx_by_modality[modality] += 1

    # Remaining tail after the last replacement
    out_seqs.append(prompt[prev_end_idx:])

    return out_seqs
479
+
480
+
481
def replace_token_matches(
    prompt: list[int],
    mm_matches: Mapping[str, Sequence[_PromptReplacementTokenMatch]],
    mm_item_counts: Mapping[str, int],
) -> list[int]:
    """Apply the replacements in :code:`mm_matches` to :code:`prompt`."""
    # Fast path: nothing to replace
    if not mm_matches:
        return prompt

    replaced_segments = _replace_matches(prompt, mm_matches, mm_item_counts)
    return flatten_2d_lists(replaced_segments)
493
+
494
+
495
def replace_text_matches(
    prompt: str,
    mm_matches: Mapping[str, Sequence[_PromptReplacementTextMatch]],
    mm_item_counts: Mapping[str, int],
) -> str:
    """Apply the replacements in :code:`mm_matches` to :code:`prompt`."""
    # Fast path: nothing to replace
    if not mm_matches:
        return prompt

    replaced_segments = _replace_matches(prompt, mm_matches, mm_item_counts)
    return "".join(replaced_segments)
507
+
508
+
509
def _iter_placeholders(
    mm_prompt_repls: Mapping[str, Sequence[BoundPromptReplacement]],
    prompt: list[int],
    mm_item_counts: Mapping[str, int],
) -> Iterable[PlaceholderFeaturesInfo]:
    """
    Yield each set of placeholder tokens found in :code:`prompt`.

    Matches are exclusive even when multiple modalities share
    the same placeholder tokens. In that case, the modality that
    appears earlier in `mm_prompt_repls` takes priority.

    Note that empty matches are ignored.
    """
    prompt_len = len(prompt)
    item_idx_by_modality = defaultdict[str, int](lambda: 0)

    # Linear scan: at each position, try every modality's replacements
    # until one matches, then skip past that match.
    start_idx = 0
    while start_idx < prompt_len:
        found = False

        for modality, modality_repls in mm_prompt_repls.items():
            item_idx = item_idx_by_modality[modality]
            # Stop matching this modality once all its items are accounted for
            if item_idx >= mm_item_counts.get(modality, 0):
                continue

            for repl_info in modality_repls:
                replacement = repl_info.get_replacement(item_idx)
                repl_tokens_full = replacement.full.token_ids
                repl_len_full = len(repl_tokens_full)
                end_idx_full = start_idx + repl_len_full

                # Skip empty replacements and those overrunning the prompt
                if repl_len_full == 0 or end_idx_full > prompt_len:
                    continue

                if prompt[start_idx:end_idx_full] == repl_tokens_full:
                    repl_tokens_feat = replacement.features.token_ids

                    # Locate the feature sub-sequence within the full
                    # replacement to compute its absolute offset
                    try:
                        match = next(
                            iter_token_matches(repl_tokens_full,
                                               repl_tokens_feat))
                        yield PlaceholderFeaturesInfo(
                            modality=modality,
                            item_idx=item_idx,
                            start_idx=start_idx + match.start_idx,
                            tokens=repl_tokens_feat,
                        )
                    except StopIteration:
                        raise AssertionError(
                            f"{repl_tokens_feat=} should be a "
                            f"subsequence of {repl_tokens_full=}") from None

                    # Exclude overlapping matches
                    start_idx = end_idx_full
                    item_idx_by_modality[modality] += 1
                    found = True
                    break

            if found:
                break  # Go back to the outer while loop

        if not found:
            start_idx += 1
573
+
574
+
575
def find_mm_placeholders(
    mm_prompt_repls: Mapping[str, Sequence[BoundPromptReplacement]],
    prompt: list[int],
    mm_item_counts: Mapping[str, int],
) -> Mapping[str, list[PlaceholderFeaturesInfo]]:
    """Locate placeholder tokens in ``prompt``, grouped by modality."""
    placeholders = _iter_placeholders(mm_prompt_repls, prompt,
                                      mm_item_counts)
    return dict(full_groupby_modality(placeholders))
582
+
583
+
584
class ProcessingCache:
    """LRU cache of processed multi-modal items, keyed by their inputs."""

    def __init__(self, capacity: int) -> None:
        super().__init__()

        # DEBUG: Set to None to disable
        self.debug_cache_hit_ratio_steps: Optional[int] = None

        self._cache = LRUCache[str, MultiModalKwargsItem](capacity)

    @staticmethod
    def _hash_item(
        model_id: str,
        modality: str,
        input_item: object,
        input_kwargs: Mapping[str, object],
    ) -> str:
        # The key covers everything that can affect the processed output:
        # the model, the modality, the raw data item, and the HF processor
        # configuration options.
        return MultiModalHasher.hash_kwargs(model_id=model_id,
                                            **{modality: input_item},
                                            **input_kwargs)

    def _maybe_log_cache_stats(self) -> None:
        steps = self.debug_cache_hit_ratio_steps
        if not steps:
            return

        cache_stats = self._cache.stat()
        if cache_stats.total % steps == 0:
            logger.debug("ProcessingCache: hit_ratio = %.2f",
                         cache_stats.hit_ratio)

    def get(
        self,
        model_id: str,
        modality: str,
        input_item: object,
        input_kwargs: Mapping[str, object],
    ) -> Optional[MultiModalKwargsItem]:
        """
        Get a processed multi-modal item from the cache
        according to its dependencies, including:

        - The model ID
        - The modality of the item
        - The original data item passed to the HF processor
        - The configuration options of the HF processor
        """
        self._maybe_log_cache_stats()

        return self._cache.get(
            self._hash_item(model_id, modality, input_item, input_kwargs))

    def put(
        self,
        model_id: str,
        modality: str,
        input_item: object,
        input_kwargs: Mapping[str, object],
        output_kwargs: MultiModalKwargsItem,
    ) -> None:
        """
        Put a processed multi-modal item into the cache
        according to its dependencies (see :meth:`get`).
        """
        self._cache.put(
            self._hash_item(model_id, modality, input_item, input_kwargs),
            output_kwargs)
643
+
644
+
645
class BaseProcessingInfo:
    """Base class to provide the information necessary for data processing."""

    def __init__(self, ctx: InputProcessingContext) -> None:
        super().__init__()

        self.ctx = ctx

    @property
    def model_id(self) -> str:
        # The model name/path identifying the served model.
        return self.ctx.model_config.model

    def get_tokenizer(self) -> AnyTokenizer:
        return self.ctx.tokenizer

    def get_hf_config(self) -> PretrainedConfig:
        return self.ctx.get_hf_config()

    def get_hf_processor(self, **kwargs: object) -> ProcessorMixin:
        """
        Subclasses can override this method to handle
        specific kwargs from model config or user inputs.
        """
        return self.ctx.get_hf_processor(**kwargs)

    @abstractmethod
    def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
        """
        Return the maximum supported number of items for each modality.

        A value of `None` means unlimited number of items.

        Omitting a modality from the returned dictionary means that
        it is not supported at all.
        """
        raise NotImplementedError

    @abstractmethod
    def get_mm_max_tokens_per_item(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> Mapping[str, int]:
        """
        Get the maximum possible number of tokens per data item
        for each modality.

        The dictionary returned by this method should have the same
        keys as that returned by :meth:`get_supported_mm_limits`.
        """
        raise NotImplementedError
696
+
697
+
698
# Type variable for the processing-info object a processor is generic over.
_I = TypeVar("_I", bound=BaseProcessingInfo)
699
+
700
+
701
class BaseMultiModalProcessor(ABC, Generic[_I]):
    """
    Abstract base class to process multi-modal inputs to be used in vLLM.

    Not to be confused with :class:`transformers.ProcessorMixin`.
    """

    def __init__(self,
                 info: _I,
                 dummy_inputs: "BaseDummyInputsBuilder[_I]",
                 *,
                 cache: Optional[ProcessingCache] = None,
                 enable_sanity_checks: bool = True) -> None:
        super().__init__()

        self.info = info
        self.dummy_inputs = dummy_inputs
        # Optional per-item cache of HF processor outputs
        self.cache = cache
        self.enable_sanity_checks = enable_sanity_checks

        self.data_parser = self._get_data_parser()
722
+
723
    def __call__(
        self,
        prompt: str,
        mm_data: MultiModalDataDict,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> MultiModalInputs:
        # Convenience alias: calling the processor is the same as `apply`.
        return self.apply(prompt, mm_data, hf_processor_mm_kwargs)
730
+
731
    def _get_data_parser(self) -> MultiModalDataParser:
        """
        Construct a parser to preprocess multi-modal data items
        before passing them to :meth:`_get_hf_mm_data`.

        You can support additional modalities by creating a subclass
        of :class:`MultiModalDataParser` that has additional subparsers.
        """
        return MultiModalDataParser()
740
+
741
    def _to_mm_items(
        self,
        mm_data: MultiModalDataDict,
    ) -> MultiModalDataItems:
        """
        Normalize :class:`MultiModalDataDict` to :class:`MultiModalDataItems`
        before passing them to :meth:`_get_hf_mm_data`.

        Raises:
            ValueError: If the number of items of any modality exceeds the
                per-prompt limit configured via ``--limit-mm-per-prompt``.
        """
        mm_items = self.data_parser.parse_mm_data(mm_data)

        # Enforce per-prompt item limits (each modality defaults to 1)
        mm_limits = self.info.ctx.get_mm_config().limit_per_prompt
        for modality, items in mm_items.items():
            limit = mm_limits.get(modality, 1)
            if len(items) > limit:
                raise ValueError(
                    f"You set {modality}={limit} (or defaulted to 1) in "
                    f"`--limit-mm-per-prompt`, but passed {len(items)} "
                    f"{modality} items in the same prompt.")

        return mm_items
761
+
762
    @abstractmethod
    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        """Given the HF-processed data, output the metadata of each field."""
        raise NotImplementedError
770
+
771
    @abstractmethod
    def _get_prompt_replacements(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        out_mm_kwargs: MultiModalKwargs,
    ) -> list[PromptReplacement]:
        """
        Given the original multi-modal items for this modality
        and HF-processed data, output the replacements to perform.

        Notes:
            - You should not assume that HF processor always performs prompt
              replacement: in :meth:`_apply_hf_processor_missing`, this method
              is called on text-only and multimodal-only inputs separately,
              instead of passing them in the same call.
            - The replacement information returned by this method is also used
              to determine the placeholder token positions for each multi-modal
              item.
        """
        raise NotImplementedError
792
+
793
    def _find_mm_placeholders(
        self,
        mm_prompt_repls: Mapping[str, Sequence[BoundPromptReplacement]],
        new_token_ids: list[int],
        mm_item_counts: Mapping[str, int],
    ) -> Mapping[str, list[PlaceholderFeaturesInfo]]:
        # Overridable hook; by default delegates to the module-level helper.
        return find_mm_placeholders(mm_prompt_repls, new_token_ids,
                                    mm_item_counts)
801
+
802
+ def _get_hf_mm_data(
803
+ self,
804
+ mm_items: MultiModalDataItems,
805
+ ) -> tuple[Mapping[str, object], Mapping[str, object]]:
806
+ processor_data = dict[str, object]()
807
+ passthrough_data = dict[str, object]()
808
+
809
+ for items in mm_items.values():
810
+ processor_data.update(items.get_processor_data())
811
+ passthrough_data.update(items.get_passthrough_data())
812
+
813
+ return processor_data, passthrough_data
814
+
815
    def _call_hf_processor(
        self,
        prompt: str,
        # Not to be confused with `mm_data` in `self.apply`.
        # This refers to the data to be passed to HF processor.
        mm_data: Mapping[str, object],
        mm_kwargs: Mapping[str, object],
    ) -> BatchFeature:
        """
        Call the HF processor on the prompt text and
        associated multi-modal data.
        """
        return self.info.ctx.call_hf_processor(
            self.info.get_hf_processor(**mm_kwargs),
            dict(text=prompt, **mm_data),
            mm_kwargs,
        )
832
+
833
    def _apply_hf_processor_text_mm(
        self,
        prompt_text: str,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> tuple[list[int], MultiModalKwargs]:
        """
        Apply the HF processor on the prompt text and multi-modal data
        together.

        Returns the tokenized prompt and the multi-modal keyword arguments.
        """
        processor_data, passthrough_data = self._get_hf_mm_data(mm_items)

        processed_data = self._call_hf_processor(
            prompt=prompt_text,
            mm_data=processor_data,
            mm_kwargs=hf_processor_mm_kwargs,
        )
        # Data that bypasses the HF processor is merged back in unchanged
        processed_data.update(passthrough_data)

        # `input_ids` is a batch of size 1; unpack the single sequence
        prompt_ids, = processed_data.pop("input_ids").tolist()

        mm_kwargs = MultiModalKwargs.from_hf_inputs(
            processed_data,
            self._get_mm_fields_config(processed_data, hf_processor_mm_kwargs),
        )

        return prompt_ids, mm_kwargs
860
+
861
    def _apply_hf_processor_text_only(self, prompt_text: str) -> list[int]:
        """
        Apply the HF processor on the prompt text only.

        Since HF processor requires that text and multi-modal items
        correspond to each other, we create dummy multi-modal items
        to go along with the text.
        """
        # Empty mm_items means no multi-modal data is passed at all;
        # only the tokenized prompt is kept.
        prompt_ids, _ = self._apply_hf_processor_text_mm(
            prompt_text=prompt_text,
            mm_items=MultiModalDataItems({}),
            hf_processor_mm_kwargs={},
        )

        return prompt_ids
876
+
877
    def _apply_hf_processor_tokens_only(
        self,
        prompt_tokens: list[int],
    ) -> list[int]:
        """
        Apply the HF processor on the prompt tokens only.

        Most HF processors accept prompt text but not prompt tokens.
        If the HF processor adds or removes tokens that are not related to
        multi-modal data, you should override this method so it is consistent
        with the output of :meth:`_apply_hf_processor_text_only` on the
        corresponding text.
        """
        # Default: assume the processor makes no token-level changes.
        return prompt_tokens
891
+
892
    def _apply_hf_processor_mm_only(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> MultiModalKwargs:
        """
        Apply the HF processor on the multi-modal data only.

        Since HF processor requires that text and multi-modal items
        correspond to each other, we generate dummy text using
        :class:`DummyInputsBuilder` to go along with the multi-modal data.
        """
        mm_counts = mm_items.get_all_counts()

        dummy_inputs = self.dummy_inputs.get_dummy_processor_inputs(
            self.info.ctx.model_config.max_model_len,
            mm_counts,
        )

        # The dummy prompt's token IDs are discarded; only the multi-modal
        # keyword arguments are kept.
        _, mm_kwargs = self._apply_hf_processor_text_mm(
            prompt_text=dummy_inputs.prompt_text,
            mm_items=mm_items,
            hf_processor_mm_kwargs=hf_processor_mm_kwargs,
        )

        return mm_kwargs
918
+
919
    def _apply_hf_processor_main(
        self,
        prompt: Union[str, list[int]],
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        *,
        enable_hf_prompt_replacement: bool,
    ) -> tuple[list[int], MultiModalKwargs]:
        """
        Apply the HF processor on the prompt text and multi-modal data.

        Note:
            If :code:`enable_hf_prompt_replacement=False`, the prompt should
            correspond to the multi-modal items.
        """
        if isinstance(prompt, str):
            # Text + data together lets the HF processor perform its own
            # prompt replacement in a single pass
            if enable_hf_prompt_replacement:
                return self._apply_hf_processor_text_mm(
                    prompt_text=prompt,
                    mm_items=mm_items,
                    hf_processor_mm_kwargs=hf_processor_mm_kwargs,
                )

            prompt_ids = self._apply_hf_processor_text_only(prompt)
        else:
            prompt_ids = self._apply_hf_processor_tokens_only(prompt)

        # Process the multi-modal data separately from the prompt
        mm_missing_kwargs = self._apply_hf_processor_mm_only(
            mm_items=mm_items,
            hf_processor_mm_kwargs=hf_processor_mm_kwargs,
        )

        return prompt_ids, mm_missing_kwargs
952
+
953
    def _cached_apply_hf_processor(
        self,
        prompt: Union[str, list[int]],
        mm_data_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> tuple[list[int], MultiModalKwargs]:
        """
        Apply the HF processor on the full prompt text,
        caching the results and reusing cached results.

        Only the items missing from the cache are passed through the HF
        processor; their outputs are then merged with the cached ones in
        the original item order.
        """
        cache = self.cache
        model_id = self.info.model_id

        # Passthrough data cannot be cached per item, so fall back to
        # uncached processing in that case
        _, passthrough_data = self._get_hf_mm_data(mm_data_items)
        if cache is None or passthrough_data:
            return self._apply_hf_processor_main(
                prompt=prompt,
                mm_items=mm_data_items,
                hf_processor_mm_kwargs=hf_processor_mm_kwargs,
                enable_hf_prompt_replacement=True,
            )

        # Per-item cache lookup; `None` marks a cache miss
        mm_maybe_cached_kw_items = {
            modality: [
                cache.get(model_id, modality, item, hf_processor_mm_kwargs)
                for item in items
            ]
            for modality, items in mm_data_items.items()
        }

        mm_missing_idxs = {
            modality:
            [idx for idx, item in enumerate(kw_items) if item is None]
            for modality, kw_items in mm_maybe_cached_kw_items.items()
        }
        mm_missing_data = {
            modality: [mm_data_items[modality][idx] for idx in idxs]
            for modality, idxs in mm_missing_idxs.items()
        }
        mm_missing_data_items = self._to_mm_items(mm_missing_data)

        # NOTE: `prompt` does not correspond to `mm_missing_data_items`,
        # so we need to pass `enable_hf_prompt_replacement=False`
        prompt_ids, mm_missing_kwargs = self._apply_hf_processor_main(
            prompt=prompt,
            mm_items=mm_missing_data_items,
            hf_processor_mm_kwargs=hf_processor_mm_kwargs,
            enable_hf_prompt_replacement=False,
        )

        # Cursor into `mm_missing_kwargs` for each modality
        mm_missing_next_idx = {
            modality: 0
            for modality in mm_missing_data_items
        }

        # Merge freshly processed items back into position, populating the
        # cache as we go
        merged_kw_items = list[MultiModalKwargsItem]()
        for modality, kw_items in mm_maybe_cached_kw_items.items():
            for idx, kw_item in enumerate(kw_items):
                if kw_item is None:
                    kw_item = mm_missing_kwargs.get_item(
                        modality,
                        mm_missing_next_idx[modality],
                    )

                    cache.put(
                        model_id,
                        modality,
                        mm_data_items[modality][idx],
                        hf_processor_mm_kwargs,
                        kw_item,
                    )

                    mm_missing_next_idx[modality] += 1

                merged_kw_items.append(kw_item)

        if self.enable_sanity_checks:
            # Every missing item must have been consumed exactly once
            mm_missing_counts = mm_missing_data_items.get_all_counts()
            assert all(
                item_count == mm_missing_counts[modality]
                for modality, item_count in mm_missing_next_idx.items()), dict(
                    mm_missing_next_idx=mm_missing_next_idx,
                    mm_missing_counts=mm_missing_counts)

        mm_kwargs = MultiModalKwargs.from_items(merged_kw_items)

        return prompt_ids, mm_kwargs
1040
+
1041
+ def _bind_and_group_repls(
1042
+ self,
1043
+ prompt_repls: list[PromptReplacement],
1044
+ ) -> dict[str, list[BoundPromptReplacement]]:
1045
+ tokenizer = self.info.get_tokenizer()
1046
+
1047
+ it = (prompt_repl.bind(tokenizer) for prompt_repl in prompt_repls)
1048
+ return dict(full_groupby_modality(it))
1049
+
1050
    def _always_apply_prompt_replacements(self) -> bool:
        """
        A flag which can be overridden so that
        :meth:`_apply_prompt_replacements` is always called even if we
        detect that HF has performed processing via
        :meth:`_find_placeholders_by_modality`.

        This is useful in cases where :meth:`_find_placeholders_by_modality`
        cannot be reliably used to detect whether HF has performed processing.
        """
        return False
1061
+
1062
    def _apply_prompt_replacements(
        self,
        token_ids: list[int],
        mm_prompt_repls: Mapping[str, Sequence[BoundPromptReplacement]],
        mm_item_counts: Mapping[str, int],
    ) -> tuple[list[int], str, Mapping[str, list[PlaceholderFeaturesInfo]]]:
        """
        Replace the placeholder targets in ``token_ids`` and locate the
        resulting feature placeholders.

        Returns the new token IDs, the corresponding decoded text, and the
        placeholder positions found for each modality.
        """
        tokenizer = self.info.get_tokenizer()

        mm_token_matches = {
            modality: find_token_matches(token_ids, prompt_repls)
            for modality, prompt_repls in mm_prompt_repls.items()
        }
        mm_match_counts = {
            modality: len(matches)
            for modality, matches in mm_token_matches.items()
        }

        # If the search text does not represent a special token,
        # it may have different token IDs in the prompt, because
        # the tokens may go across the boundaries of the search text.
        # ----
        # e.g. when searching for "foo" in "food", if "food" itself makes
        # up a token, then the token ID of "foo" will not appear at all
        # ----
        # Since it is inefficient to search for all possible tokenizations
        # of the search text in the prompt, we instead perform string
        # replacement on the decoded token IDs, then encode them back.
        if all(
            mm_match_counts.get(modality, 0) >= item_count
            for modality, item_count in mm_item_counts.items()
        ):  # yapf: disable
            # Enough token-level matches for every item: replace directly
            # on the token IDs
            token_ids = replace_token_matches(
                token_ids,
                mm_token_matches,
                mm_item_counts,
            )

            text = decode_tokens(tokenizer, token_ids)
            matched_repls = {
                modality: [match.prompt_repl for match in token_matches]
                for modality, token_matches in mm_token_matches.items()
            }
        else:
            # Fall back to text-level replacement, then re-encode
            text = decode_tokens(tokenizer, token_ids)

            mm_text_matches = {
                modality: find_text_matches(text, prompt_repls)
                for modality, prompt_repls in mm_prompt_repls.items()
            }
            text = replace_text_matches(
                text,
                mm_text_matches,
                mm_item_counts,
            )

            token_ids = encode_tokens(tokenizer,
                                      text,
                                      add_special_tokens=False)
            matched_repls = {
                modality: [match.prompt_repl for match in token_matches]
                for modality, token_matches in mm_text_matches.items()
            }

        # Locate the feature placeholder positions in the updated prompt
        placeholders = self._find_mm_placeholders(
            matched_repls,
            token_ids,
            mm_item_counts,
        )

        return token_ids, text, placeholders
1132
+
1133
    def _validate_mm_kwargs(
        self,
        mm_kwargs: MultiModalKwargs,
        mm_item_counts: Mapping[str, int],
    ) -> None:
        """
        Check that ``mm_kwargs`` contains exactly one entry per data item
        for each modality.

        Raises:
            RuntimeError: If the item counts do not match, which usually
                indicates an inconsistency in the processor implementation.
        """
        for modality, item_count in mm_item_counts.items():
            if modality in mm_kwargs.modalities:
                items = mm_kwargs.get_items(modality)
            else:
                items = []

            if len(items) != item_count:
                raise RuntimeError(
                    f"Expected there to be {item_count} {modality} items in "
                    f"keyword arguments corresponding to {item_count} "
                    f"{modality} data items, but only found {len(items)}! "
                    "There is likely a problem with your "
                    "implementation of merged multi-modal processor for this "
                    "model (usually arising from an inconsistency between "
                    "`_call_hf_processor` and `_get_mm_fields_config`).")
1153
+
1154
+ def _validate_mm_placeholders(
1155
+ self,
1156
+ mm_placeholders: Mapping[str, list[PlaceholderFeaturesInfo]],
1157
+ mm_item_counts: Mapping[str, int],
1158
+ *,
1159
+ allow_missing: bool = False,
1160
+ ) -> Mapping[str, int]:
1161
+ missing_repl_counts = dict[str, int]()
1162
+
1163
+ for modality, item_count in mm_item_counts.items():
1164
+ placeholders = mm_placeholders.get(modality, [])
1165
+
1166
+ if len(placeholders) != item_count and not allow_missing:
1167
+ raise RuntimeError(
1168
+ f"Expected there to be {item_count} prompt replacements "
1169
+ f"corresponding to {item_count} {modality} items, but only "
1170
+ f"found {len(placeholders)} prompt replacements! Either "
1171
+ "the prompt text has missing/incorrect tokens for "
1172
+ "multi-modal inputs, or there is a problem with your "
1173
+ "implementation of merged multi-modal processor for this "
1174
+ "model (usually arising from an inconsistency between "
1175
+ "`_call_hf_processor` and `_get_prompt_replacements`).")
1176
+
1177
+ missing_repl_counts[modality] = item_count - len(placeholders)
1178
+
1179
+ return missing_repl_counts
1180
+
1181
    def apply(
        self,
        prompt: Union[str, list[int]],
        mm_data: MultiModalDataDict,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> MultiModalInputs:
        """
        Process multi-modal inputs to be used in vLLM.

        The main steps are:

        1. Apply HF Processor on prompt text and multi-modal data together,
        outputting token IDs and processed tensors.
        2. Find and replace sequences in the token IDs with placeholder tokens.
        The number of placeholder tokens equals the feature size of the
        multi-modal data outputted by the multi-modal encoder.
        3. Extract information about the placeholder tokens from the
        processed token IDs.
        """
        mm_items = self._to_mm_items(mm_data)

        # Create MM hashes (only used in V1)
        # TODO: Use these hash keys for caching operations in apply_hf_processor
        # instead of rehashing.

        if envs.VLLM_USE_V1:
            model_id = self.info.model_id
            # One hash per item, keyed by modality; the hash covers the model,
            # the item itself, and the processor kwargs so any change busts it.
            mm_hashes = {
                modality: [
                    MultiModalHasher.hash_kwargs(model_id=model_id,
                                                 **{modality: item},
                                                 **hf_processor_mm_kwargs)
                    for item in items
                ]
                for modality, items in mm_items.items()
            }
        else:
            mm_hashes = None

        # Step 1: run the (cached) HF processor on text + media together.
        prompt_ids, mm_kwargs = self._cached_apply_hf_processor(
            prompt,
            mm_items,
            hf_processor_mm_kwargs,
        )

        unbound_prompt_repls = self._get_prompt_replacements(
            mm_items,
            hf_processor_mm_kwargs,
            mm_kwargs,
        )
        mm_prompt_repls = self._bind_and_group_repls(unbound_prompt_repls)

        mm_item_counts = mm_items.get_all_counts()
        self._validate_mm_kwargs(mm_kwargs, mm_item_counts)

        # Placeholders that the HF processor may have already inserted.
        hf_mm_placeholders = self._find_mm_placeholders(
            mm_prompt_repls,
            prompt_ids,
            mm_item_counts,
        )

        if self._always_apply_prompt_replacements():
            # Force-apply all replacements regardless of what HF inserted.
            mm_missing_repl_counts = mm_item_counts
            mm_missing_repls = dict(mm_prompt_repls)
        else:
            # Only apply replacements for items HF did not already handle.
            mm_missing_repl_counts = self._validate_mm_placeholders(
                hf_mm_placeholders,
                mm_item_counts,
                allow_missing=True,
            )

            mm_missing_repls = dict[str, list[BoundPromptReplacement]]()
            for modality, missing_repl_count in mm_missing_repl_counts.items():
                if missing_repl_count == 0:
                    mm_missing_repls[modality] = []
                elif missing_repl_count == mm_item_counts.get(modality, 0):
                    mm_missing_repls[modality] = mm_prompt_repls[modality]
                else:
                    # Either all items of a modality were handled by HF or
                    # none were; a mix cannot be resolved unambiguously.
                    raise ValueError("Partial prompt replacement within "
                                     f"{modality=} is not supported")

        # If HF processor already inserts placeholder tokens,
        # there is no need for us to insert them
        if all(len(repls) == 0 for repls in mm_missing_repls.values()):
            tokenizer = self.info.get_tokenizer()
            prompt = decode_tokens(tokenizer, prompt_ids)
            mm_placeholders = hf_mm_placeholders
        else:
            # Step 2: insert the remaining placeholders ourselves.
            (
                prompt_ids,
                prompt,
                missing_mm_placeholders,
            ) = self._apply_prompt_replacements(
                prompt_ids,
                mm_missing_repls,
                mm_missing_repl_counts,
            )

            mm_placeholders = {**hf_mm_placeholders, **missing_mm_placeholders}

        # Final strict check: every data item must now have a placeholder.
        self._validate_mm_placeholders(mm_placeholders, mm_item_counts)

        # Step 3: convert placeholder info into offset/length ranges.
        mm_placeholder_ranges = {
            modality: [item.to_range() for item in placeholders]
            for modality, placeholders in mm_placeholders.items()
        }

        return MultiModalInputs(
            type="multimodal",
            prompt=prompt,
            prompt_token_ids=prompt_ids,
            mm_kwargs=mm_kwargs,
            mm_hashes=mm_hashes,
            mm_placeholders=mm_placeholder_ranges,
        )
.venv/lib/python3.11/site-packages/vllm/multimodal/profiling.py ADDED
@@ -0,0 +1,209 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-License-Identifier: Apache-2.0
2
+
3
+ from abc import ABC, abstractmethod
4
+ from collections.abc import Mapping
5
+ from dataclasses import dataclass, field
6
+ from typing import Generic, TypeVar
7
+
8
+ import numpy as np
9
+ import numpy.typing as npt
10
+ from PIL import Image
11
+
12
+ import vllm.envs as envs
13
+ from vllm.inputs import DummyData
14
+ from vllm.logger import init_logger
15
+
16
+ from .inputs import MultiModalDataDict, MultiModalInputs
17
+ from .processing import BaseMultiModalProcessor, BaseProcessingInfo
18
+
19
+ logger = init_logger(__name__)
20
+
21
+
22
@dataclass
class ProcessorInputs:
    """
    Represents the keyword arguments to
    :meth:`vllm.multimodal.processing.BaseMultiModalProcessor.apply`.
    """
    # Text prompt passed to the processor.
    prompt_text: str
    # Multi-modal data items, keyed by modality.
    mm_data: MultiModalDataDict
    # Extra keyword arguments forwarded to the HF processor call.
    hf_processor_mm_kwargs: Mapping[str, object] = field(default_factory=dict)
31
+
32
+
33
+ _I = TypeVar("_I", bound=BaseProcessingInfo)
34
+
35
+
36
class BaseDummyInputsBuilder(ABC, Generic[_I]):
    """
    Abstract base class that constructs the dummy data to profile
    multi-modal models.
    """

    def __init__(self, info: _I) -> None:
        super().__init__()

        self.info = info

    @abstractmethod
    def get_dummy_processor_inputs(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> ProcessorInputs:
        """
        Build the input which, after processing, results in
        :code:`self.info.get_mm_max_tokens_per_item()` placeholder tokens.
        """
        raise NotImplementedError

    def _get_dummy_audios(
        self,
        *,
        length: int,
        num_audios: int,
    ) -> list[npt.NDArray]:
        # A single silent waveform; the same array object is shared by
        # every list entry since the data is never mutated.
        silence = np.zeros((length, ))
        return [silence] * num_audios

    def _get_dummy_images(
        self,
        *,
        width: int,
        height: int,
        num_images: int,
    ) -> list[Image.Image]:
        # One black RGB canvas, shared across all requested items.
        blank = Image.new("RGB", (width, height), color=0)
        return [blank] * num_images

    def _get_dummy_videos(
        self,
        *,
        width: int,
        height: int,
        num_frames: int,
        num_videos: int,
    ) -> list[npt.NDArray]:
        # One all-zero clip, shared across all requested items.
        # NOTE(review): frame shape is (width, height, 3) here, not the
        # more common (height, width, 3) — confirm against consumers.
        clip = np.zeros((num_frames, width, height, 3))
        return [clip] * num_videos
88
+
89
+
90
class MultiModalProfiler(Generic[_I]):
    """
    Contains code for running memory profiling for multi-modal models.
    """

    def __init__(
        self,
        processor: BaseMultiModalProcessor[_I],
    ) -> None:
        super().__init__()

        self.processor = processor

    @property
    def processing_info(self) -> BaseProcessingInfo:
        """The processing info object of the wrapped processor."""
        return self.processor.info

    @property
    def dummy_inputs(self) -> BaseDummyInputsBuilder[_I]:
        """The dummy-input builder of the wrapped processor."""
        return self.processor.dummy_inputs

    def get_mm_limits(self) -> Mapping[str, int]:
        """Resolve the per-prompt item limit for each supported modality.

        User-configured limits (``--limit-mm-per-prompt``) default to 1 per
        modality and must not exceed the model's own supported maximum.

        Raises:
            ValueError: If a configured limit exceeds the model's maximum.
        """
        mm_config = self.processing_info.ctx.get_mm_config()
        mm_limit_per_prompt = mm_config.limit_per_prompt

        supported_mm_limits = self.processing_info.get_supported_mm_limits()

        mm_limits = {
            modality: mm_limit_per_prompt.get(modality, 1)
            for modality in supported_mm_limits
        }

        for modality, supported_limit in supported_mm_limits.items():
            limit = mm_limits[modality]
            # supported_limit of None means "no model-imposed maximum".
            if supported_limit is not None and supported_limit < limit:
                raise ValueError(
                    f"You set {modality}={limit} (or defaulted to 1) in "
                    f"`--limit-mm-per-prompt`, but this model only supports "
                    f"at most {supported_limit} {modality} items.")

        return mm_limits

    def _get_dummy_mm_inputs(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> MultiModalInputs:
        """Build dummy inputs and run them through the full processor."""
        factory = self.dummy_inputs
        processor_inputs = factory.get_dummy_processor_inputs(
            seq_len, mm_counts)

        return self.processor.apply(
            prompt=processor_inputs.prompt_text,
            mm_data=processor_inputs.mm_data,
            hf_processor_mm_kwargs=processor_inputs.hf_processor_mm_kwargs,
        )

    def get_dummy_data(self, seq_len: int) -> DummyData:
        """Produce worst-case dummy data of length ``seq_len`` for profiling.

        Validates that the dummy processor output contains exactly the
        expected number of placeholder tokens per modality, then pads the
        prompt with zeros up to ``seq_len``.

        Raises:
            AssertionError: If the processor's reported modalities or
                placeholder counts are inconsistent with the dummy output.
        """
        # Avoid circular import
        from vllm.sequence import SequenceData

        mm_counts = self.get_mm_limits()

        info = self.processing_info
        mm_max_tokens_per_item = info.get_mm_max_tokens_per_item(
            seq_len, mm_counts)

        if mm_counts.keys() != mm_max_tokens_per_item.keys():
            raise AssertionError(
                "The keys returned by `get_supported_mm_limits`"
                f"({set(mm_counts.keys())}) should be the same as those "
                "returned by `get_mm_max_tokens_per_item` "
                f"({set(mm_max_tokens_per_item.keys())})")

        mm_inputs = self._get_dummy_mm_inputs(seq_len, mm_counts)
        prompt_token_ids = mm_inputs["prompt_token_ids"]
        placeholders_by_modality = mm_inputs["mm_placeholders"]

        total_placeholders_by_modality = {
            modality: sum(item["length"] for item in placeholders)
            for modality, placeholders in placeholders_by_modality.items()
        }
        expected_placeholders_by_modality = {
            modality: mm_max_tokens_per_item[modality] * mm_counts[modality]
            for modality in placeholders_by_modality
        }
        if total_placeholders_by_modality != expected_placeholders_by_modality:
            raise AssertionError(
                f"The processed dummy data has a total of "
                f"{total_placeholders_by_modality} placeholder tokens, which "
                f"is not the expected {expected_placeholders_by_modality} "
                "tokens.")

        total_len = len(prompt_token_ids)

        # V0 does not support chunked prefill.
        if total_len > seq_len and not envs.VLLM_USE_V1:
            # Warn and fall back to text-only dummy data rather than
            # failing; the multi-modal part cannot fit in the context.
            logger.warning(
                "The context length (%d) of the model is too short "
                "to hold the multi-modal embeddings in the worst case "
                "(%d tokens in total, out of which %s are reserved for "
                "multi-modal embeddings). This may cause certain multi-modal "
                "inputs to fail during inference, even when the input text is "
                "short. To avoid this, you should increase `max_model_len`, "
                "reduce `max_num_seqs`, and/or reduce `mm_counts`.", seq_len,
                total_len, total_placeholders_by_modality)

            return DummyData(
                seq_data=SequenceData.from_prompt_token_counts((0, seq_len)),
                multi_modal_data=None,
                multi_modal_placeholders=None,
            )

        # Right-pad the prompt with zeros up to seq_len (no-op if the
        # prompt already reaches or exceeds seq_len, since the repeat
        # count is then <= 0).
        prompt_token_ids.extend([0] * (seq_len - len(prompt_token_ids)))

        return DummyData(
            seq_data=SequenceData.from_seqs(prompt_token_ids),
            multi_modal_data=mm_inputs["mm_kwargs"],
            multi_modal_placeholders=placeholders_by_modality,
        )
.venv/lib/python3.11/site-packages/vllm/multimodal/registry.py ADDED
@@ -0,0 +1,458 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-License-Identifier: Apache-2.0
2
+
3
+ import functools
4
+ from collections import UserDict
5
+ from dataclasses import dataclass
6
+ from typing import (TYPE_CHECKING, Any, Dict, Generic, Mapping, Optional,
7
+ Protocol, Sequence, Type, TypeVar)
8
+
9
+ import torch.nn as nn
10
+
11
+ from vllm.inputs import InputProcessingContext
12
+ from vllm.logger import init_logger
13
+ from vllm.transformers_utils.tokenizer import AnyTokenizer
14
+ from vllm.utils import ClassRegistry
15
+
16
+ from .audio import AudioPlugin
17
+ from .base import MultiModalInputMapper, MultiModalPlugin, MultiModalTokensCalc
18
+ from .image import ImagePlugin
19
+ from .inputs import MultiModalDataDict, MultiModalKwargs, NestedTensors
20
+ from .processing import (BaseMultiModalProcessor, BaseProcessingInfo,
21
+ ProcessingCache)
22
+ from .profiling import BaseDummyInputsBuilder, MultiModalProfiler
23
+ from .utils import cached_get_tokenizer
24
+ from .video import VideoPlugin
25
+
26
+ if TYPE_CHECKING:
27
+ from vllm.config import ModelConfig
28
+
29
+ logger = init_logger(__name__)
30
+
31
+ # TODO: Tune the MM cache size
32
+ MM_CACHE_SIZE = 256
33
+
34
+ N = TypeVar("N", bound=Type[nn.Module])
35
+ _I = TypeVar("_I", bound=BaseProcessingInfo)
36
+ _I_co = TypeVar("_I_co", bound=BaseProcessingInfo, covariant=True)
37
+
38
+
39
+ class ProcessingInfoFactory(Protocol[_I_co]):
40
+ """Constructs a :class:`MultiModalProcessor` instance from the context."""
41
+
42
+ def __call__(
43
+ self,
44
+ ctx: InputProcessingContext,
45
+ ) -> _I_co:
46
+ ...
47
+
48
+
49
+ class DummyInputsBuilderFactory(Protocol[_I]):
50
+ """
51
+ Constructs a :class:`BaseDummyInputsBuilder` instance from the context.
52
+ """
53
+
54
+ def __call__(self, info: _I) -> BaseDummyInputsBuilder[_I]:
55
+ ...
56
+
57
+
58
+ class MultiModalProcessorFactory(Protocol[_I]):
59
+ """Constructs a :class:`MultiModalProcessor` instance from the context."""
60
+
61
+ def __call__(
62
+ self,
63
+ info: _I,
64
+ dummy_inputs: BaseDummyInputsBuilder[_I],
65
+ *,
66
+ cache: Optional[ProcessingCache] = None,
67
+ ) -> BaseMultiModalProcessor[_I]:
68
+ ...
69
+
70
+
71
+ @dataclass(frozen=True)
72
+ class _ProcessorFactories(Generic[_I]):
73
+ info: ProcessingInfoFactory[_I]
74
+ processor: MultiModalProcessorFactory[_I]
75
+ dummy_inputs: DummyInputsBuilderFactory[_I]
76
+
77
+ def build_processor(
78
+ self,
79
+ ctx: InputProcessingContext,
80
+ *,
81
+ cache: Optional[ProcessingCache] = None,
82
+ ):
83
+ info = self.info(ctx)
84
+ dummy_inputs_builder = self.dummy_inputs(info)
85
+ return self.processor(info, dummy_inputs_builder, cache=cache)
86
+
87
+
88
+ class _MultiModalLimits(UserDict["ModelConfig", Dict[str, int]]):
89
+ """
90
+ Wraps `_limits_by_model` for a more informative error message
91
+ when attempting to access a model that does not exist.
92
+ """
93
+
94
+ def __getitem__(self, key: "ModelConfig") -> Dict[str, int]:
95
+ try:
96
+ return super().__getitem__(key)
97
+ except KeyError as exc:
98
+ msg = (f"Cannot find `mm_limits` for model={key.model}. Did you "
99
+ "forget to call `init_mm_limits_per_prompt`?")
100
+ raise KeyError(msg) from exc
101
+
102
+
103
+ class MultiModalRegistry:
104
+ """
105
+ A registry that dispatches data processing according to the model.
106
+ """
107
+
108
+ DEFAULT_PLUGINS = (ImagePlugin(), AudioPlugin(), VideoPlugin())
109
+
110
+ def __init__(
111
+ self,
112
+ *,
113
+ plugins: Sequence[MultiModalPlugin] = DEFAULT_PLUGINS) -> None:
114
+ self._plugins = {p.get_data_key(): p for p in plugins}
115
+
116
+ self._processor_factories = ClassRegistry[nn.Module,
117
+ _ProcessorFactories]()
118
+
119
+ # This is used for non-multimodal models
120
+ self._disabled_limits_per_plugin = {k: 0 for k in self._plugins}
121
+
122
+ self._limits_by_model = _MultiModalLimits()
123
+
124
+ self._processing_cache = ProcessingCache(MM_CACHE_SIZE)
125
+
126
+ def register_plugin(self, plugin: MultiModalPlugin) -> None:
127
+ """
128
+ Register a multi-modal plugin so it can be recognized by vLLM.
129
+ """
130
+ data_type_key = plugin.get_data_key()
131
+
132
+ if data_type_key in self._plugins:
133
+ logger.warning(
134
+ "A plugin is already registered for data type %s, "
135
+ "and will be overwritten by the new plugin %s.", data_type_key,
136
+ plugin)
137
+
138
+ self._plugins[data_type_key] = plugin
139
+
140
+ def _get_plugin(self, data_type_key: str):
141
+ plugin = self._plugins.get(data_type_key)
142
+ if plugin is not None:
143
+ return plugin
144
+
145
+ msg = f"Unknown multi-modal data type: {data_type_key}"
146
+ raise NotImplementedError(msg)
147
+
148
+ def register_input_mapper(
149
+ self,
150
+ data_type_key: str,
151
+ mapper: Optional[MultiModalInputMapper] = None,
152
+ ):
153
+ """
154
+ Register an input mapper for a specific modality to a model class.
155
+
156
+ See :meth:`MultiModalPlugin.register_input_mapper` for more details.
157
+ """
158
+ return self._get_plugin(data_type_key).register_input_mapper(mapper)
159
+
160
+ def register_image_input_mapper(
161
+ self,
162
+ mapper: Optional[MultiModalInputMapper] = None,
163
+ ):
164
+ """
165
+ Register an input mapper for image data to a model class.
166
+
167
+ See :meth:`MultiModalPlugin.register_input_mapper` for more details.
168
+ """
169
+ return self.register_input_mapper("image", mapper)
170
+
171
+ def map_input(
172
+ self,
173
+ model_config: "ModelConfig",
174
+ data: MultiModalDataDict,
175
+ mm_processor_kwargs: Optional[Dict[str, Any]] = None,
176
+ ) -> MultiModalKwargs:
177
+ """
178
+ Apply an input mapper to the data passed to the model.
179
+
180
+ The data belonging to each modality is passed to the corresponding
181
+ plugin which in turn converts the data into into keyword arguments
182
+ via the input mapper registered for that model.
183
+
184
+ See :meth:`MultiModalPlugin.map_input` for more details.
185
+
186
+ Note:
187
+ This should be called after :meth:`init_mm_limits_per_prompt`.
188
+ """
189
+ merged_dict: Dict[str, NestedTensors] = {}
190
+
191
+ for data_key, data_value in data.items():
192
+ plugin = self._get_plugin(data_key)
193
+
194
+ num_items = len(data_value) if isinstance(data_value, list) else 1
195
+ max_items = self._limits_by_model[model_config][data_key]
196
+ if num_items > max_items:
197
+ raise ValueError(
198
+ f"You set {data_key}={max_items} (or defaulted to 1) in "
199
+ f"`--limit-mm-per-prompt`, but found {num_items} items "
200
+ "in the same prompt.")
201
+
202
+ input_dict = plugin.map_input(model_config, data_value,
203
+ mm_processor_kwargs)
204
+ for input_key, input_tensor in input_dict.items():
205
+ if input_key in merged_dict:
206
+ raise ValueError(f"The input mappers (keys={set(data)}) "
207
+ f"resulted in a conflicting keyword "
208
+ f"argument to `forward()`: {input_key}")
209
+
210
+ merged_dict[input_key] = input_tensor
211
+
212
+ return MultiModalKwargs(merged_dict)
213
+
214
+ def create_input_mapper(self, model_config: "ModelConfig"):
215
+ """
216
+ Create an input mapper (see :meth:`map_input`) for a specific model.
217
+ """
218
+ # NOTE - we currently make the assumption that if a model has multiple
219
+ # supported modalities, they take the same kwargs. For the default,
220
+ # this could be an issue in the future if it falls back to two HF
221
+ # resources and we can't inspect the signature easily since it's
222
+ # getting initialized through the autoclass.
223
+ #
224
+ # If this is a problem in the future, we should revisit it, but since
225
+ # it potentially introduces a lot of complexity for a currently
226
+ # uncommon case, we do not for simplicity of both use & implementation
227
+ return functools.partial(self.map_input, model_config)
228
+
229
+ def register_max_multimodal_tokens(
230
+ self,
231
+ data_type_key: str,
232
+ max_mm_tokens: Optional[MultiModalTokensCalc] = None,
233
+ ):
234
+ """
235
+ Register the maximum number of tokens, corresponding to a single
236
+ instance of multimodal data belonging to a specific modality, that are
237
+ passed to the language model for a model class.
238
+ """
239
+ return self._get_plugin(data_type_key) \
240
+ .register_max_multimodal_tokens(max_mm_tokens)
241
+
242
+ def register_max_image_tokens(
243
+ self,
244
+ max_mm_tokens: Optional[MultiModalTokensCalc] = None,
245
+ ):
246
+ """
247
+ Register the maximum number of image tokens, corresponding to a single
248
+ image, that are passed to the language model for a model class.
249
+ """
250
+ return self.register_max_multimodal_tokens("image", max_mm_tokens)
251
+
252
+ def get_max_tokens_per_item_by_modality(
253
+ self,
254
+ model_config: "ModelConfig",
255
+ ) -> Mapping[str, int]:
256
+ """
257
+ Get the maximum number of tokens per data item from each modality based
258
+ on underlying model configuration.
259
+ """
260
+ if self.has_processor(model_config):
261
+ tokenizer = cached_get_tokenizer(
262
+ model_config.tokenizer,
263
+ trust_remote_code=model_config.trust_remote_code,
264
+ )
265
+ processor = self.create_processor(model_config, tokenizer)
266
+ seq_len = model_config.max_model_len
267
+ mm_limits = self.get_mm_limits_per_prompt(model_config)
268
+ return processor.info.get_mm_max_tokens_per_item(
269
+ seq_len, mm_limits)
270
+
271
+ return {
272
+ key: plugin.get_max_multimodal_tokens(model_config)
273
+ for key, plugin in self._plugins.items()
274
+ }
275
+
276
+ def get_max_tokens_per_item_by_nonzero_modality(
277
+ self,
278
+ model_config: "ModelConfig",
279
+ ) -> Mapping[str, int]:
280
+ """
281
+ Get the maximum number of tokens per data item from each modality based
282
+ on underlying model configuration, excluding modalities that user
283
+ explicitly disabled via `limit_mm_per_prompt`.
284
+
285
+ Note:
286
+ This is currently directly used only in V1 for profiling the memory
287
+ usage of a model.
288
+ """
289
+ mm_limits = self.get_mm_limits_per_prompt(model_config)
290
+
291
+ return {
292
+ key: max_tokens_per_mm_item
293
+ for key, max_tokens_per_mm_item in
294
+ self.get_max_tokens_per_item_by_modality(model_config).items()
295
+ if mm_limits[key] > 0
296
+ }
297
+
298
+ def get_max_tokens_by_modality(
299
+ self,
300
+ model_config: "ModelConfig",
301
+ ) -> Mapping[str, int]:
302
+ """
303
+ Get the maximum number of tokens from each modality
304
+ for profiling the memory usage of a model.
305
+
306
+ See :meth:`MultiModalPlugin.get_max_multimodal_tokens` for more details.
307
+
308
+ Note:
309
+ This should be called after :meth:`init_mm_limits_per_prompt`.
310
+ """
311
+ mm_limits = self.get_mm_limits_per_prompt(model_config)
312
+
313
+ return {
314
+ key: mm_limits[key] * max_tokens_per_mm_item
315
+ for key, max_tokens_per_mm_item in
316
+ self.get_max_tokens_per_item_by_modality(model_config).items()
317
+ }
318
+
319
+ def get_max_multimodal_tokens(self, model_config: "ModelConfig") -> int:
320
+ """
321
+ Get the maximum number of multi-modal tokens
322
+ for profiling the memory usage of a model.
323
+
324
+ See :meth:`MultiModalPlugin.get_max_multimodal_tokens` for more details.
325
+
326
+ Note:
327
+ This should be called after :meth:`init_mm_limits_per_prompt`.
328
+ """
329
+ return sum(self.get_max_tokens_by_modality(model_config).values())
330
+
331
+ def init_mm_limits_per_prompt(
332
+ self,
333
+ model_config: "ModelConfig",
334
+ ) -> None:
335
+ """
336
+ Initialize the maximum number of multi-modal input instances for each
337
+ modality that are allowed per prompt for a model class.
338
+ """
339
+ if model_config in self._limits_by_model:
340
+ logger.warning(
341
+ "`mm_limits` has already been set for model=%s, and will "
342
+ "be overwritten by the new values.", model_config.model)
343
+
344
+ multimodal_config = model_config.multimodal_config
345
+ if multimodal_config is None:
346
+ limits_per_plugin = self._disabled_limits_per_plugin
347
+ else:
348
+ config_limits_per_plugin = multimodal_config.limit_per_prompt
349
+
350
+ extra_keys = config_limits_per_plugin.keys() - self._plugins.keys()
351
+ if extra_keys:
352
+ logger.warning(
353
+ "Detected extra keys in `--limit-mm-per-prompt` which "
354
+ "are not registered as multi-modal plugins: %s. "
355
+ "They will be ignored.", extra_keys)
356
+
357
+ # NOTE: Currently the default is set to 1 for each plugin
358
+ # TODO: Automatically determine the limits based on budget
359
+ # once more models support multi-image inputs
360
+ limits_per_plugin = {
361
+ key: config_limits_per_plugin.get(key, 1)
362
+ for key in self._plugins
363
+ }
364
+
365
+ self._limits_by_model[model_config] = limits_per_plugin
366
+
367
+ def get_mm_limits_per_prompt(
368
+ self,
369
+ model_config: "ModelConfig",
370
+ ) -> Mapping[str, int]:
371
+ """
372
+ Get the maximum number of multi-modal input instances for each modality
373
+ that are allowed per prompt for a model class.
374
+
375
+ Note:
376
+ This should be called after :meth:`init_mm_limits_per_prompt`.
377
+ """
378
+ if self.has_processor(model_config):
379
+ tokenizer = cached_get_tokenizer(
380
+ model_config.tokenizer,
381
+ trust_remote_code=model_config.trust_remote_code,
382
+ )
383
+ processor = self.create_processor(model_config, tokenizer)
384
+ profiler = MultiModalProfiler(processor)
385
+ return profiler.get_mm_limits()
386
+
387
+ return self._limits_by_model[model_config]
388
+
389
+ def register_processor(
390
+ self,
391
+ processor: MultiModalProcessorFactory[_I],
392
+ *,
393
+ info: ProcessingInfoFactory[_I],
394
+ dummy_inputs: DummyInputsBuilderFactory[_I],
395
+ ):
396
+ """
397
+ Register a multi-modal processor to a model class. The processor
398
+ is constructed lazily, hence a factory method should be passed.
399
+
400
+ When the model receives multi-modal data, the provided function is
401
+ invoked to transform the data into a dictionary of model inputs.
402
+
403
+ See also:
404
+ :ref:`mm-processing`
405
+ """
406
+
407
+ def wrapper(model_cls: N) -> N:
408
+ if self._processor_factories.contains(model_cls, strict=True):
409
+ logger.warning(
410
+ "Model class %s already has a multi-modal processor "
411
+ "registered to %s. It is overwritten by the new one.",
412
+ model_cls, self)
413
+
414
+ self._processor_factories[model_cls] = _ProcessorFactories(
415
+ info=info,
416
+ dummy_inputs=dummy_inputs,
417
+ processor=processor,
418
+ )
419
+
420
+ return model_cls
421
+
422
+ return wrapper
423
+
424
+ def _get_model_cls(self, model_config: "ModelConfig"):
425
+ # Avoid circular import
426
+ from vllm.model_executor.model_loader import get_model_architecture
427
+
428
+ model_cls, _ = get_model_architecture(model_config)
429
+ return model_cls
430
+
431
+ def has_processor(self, model_config: "ModelConfig") -> bool:
432
+ """
433
+ Test whether a multi-modal processor is defined for a specific model.
434
+
435
+ See also:
436
+ :ref:`mm-processing`
437
+ """
438
+ return self._get_model_cls(model_config) in self._processor_factories
439
+
440
+ def create_processor(
441
+ self,
442
+ model_config: "ModelConfig",
443
+ tokenizer: AnyTokenizer,
444
+ ) -> BaseMultiModalProcessor[BaseProcessingInfo]:
445
+ """
446
+ Create a multi-modal processor for a specific model and tokenizer.
447
+
448
+ See also:
449
+ :ref:`mm-processing`
450
+ """
451
+ model_cls = self._get_model_cls(model_config)
452
+ factories = self._processor_factories[model_cls]
453
+
454
+ ctx = InputProcessingContext(model_config, tokenizer)
455
+ cache = (None if model_config.disable_mm_preprocessor_cache else
456
+ self._processing_cache)
457
+
458
+ return factories.build_processor(ctx, cache=cache)
.venv/lib/python3.11/site-packages/vllm/multimodal/utils.py ADDED
@@ -0,0 +1,518 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-License-Identifier: Apache-2.0
2
+
3
+ from functools import lru_cache
4
+ from itertools import groupby
5
+ from pathlib import Path
6
+ from typing import TYPE_CHECKING, Optional, TypeVar, Union
7
+ from urllib.parse import ParseResult, urlparse
8
+
9
+ import numpy as np
10
+ import numpy.typing as npt
11
+ from PIL import Image
12
+
13
+ import vllm.envs as envs
14
+ from vllm.connections import HTTPConnection, global_http_connection
15
+ from vllm.logger import init_logger
16
+ from vllm.transformers_utils.tokenizer import AnyTokenizer, get_tokenizer
17
+
18
+ from .audio import AudioMediaIO
19
+ from .base import MediaIO
20
+ from .image import ImageMediaIO
21
+ from .inputs import PlaceholderRange
22
+ from .video import VideoMediaIO
23
+
24
+ logger = init_logger(__name__)
25
+
26
+ cached_get_tokenizer = lru_cache(get_tokenizer)
27
+
28
+ _M = TypeVar("_M")
29
+
30
+ if TYPE_CHECKING:
31
+ from .hasher import MultiModalHashDict
32
+ from .inputs import MultiModalKwargs, MultiModalPlaceholderDict
33
+
34
+
35
class MediaConnector:
    """Resolves HTTP(S), ``data:`` and ``file:`` URLs into decoded media.

    The actual decoding is delegated to the :class:`MediaIO` instance passed
    by the caller, so a single connector serves audio, image and video URLs.
    """

    def __init__(
        self,
        connection: HTTPConnection = global_http_connection,
        *,
        allowed_local_media_path: str = "",
    ) -> None:
        """
        Args:
            connection: HTTP client used to fetch http(s) URLs.
            allowed_local_media_path: If non-empty, the only directory from
                which ``file:`` URLs may be served. Must exist and be a
                directory; otherwise a :exc:`ValueError` is raised.
        """
        super().__init__()

        self.connection = connection

        if allowed_local_media_path:
            media_root = Path(allowed_local_media_path)

            if not media_root.exists():
                raise ValueError(
                    "Invalid `--allowed-local-media-path`: The path "
                    f"{media_root} does not exist.")
            if not media_root.is_dir():
                raise ValueError(
                    "Invalid `--allowed-local-media-path`: The path "
                    f"{media_root} must be a directory.")

            self.allowed_local_media_path: Optional[Path] = media_root
        else:
            self.allowed_local_media_path = None

    def _load_data_url(
        self,
        url_spec: ParseResult,
        media_io: MediaIO[_M],
    ) -> _M:
        # A data URL path looks like "<media-type>;base64,<payload>".
        header, payload = url_spec.path.split(",", 1)
        media_type, encoding = header.split(";", 1)

        if encoding != "base64":
            raise NotImplementedError(
                "Only base64 data URLs are supported for now.")

        return media_io.load_base64(media_type, payload)

    def _load_file_url(
        self,
        url_spec: ParseResult,
        media_io: MediaIO[_M],
    ) -> _M:
        media_root = self.allowed_local_media_path
        if media_root is None:
            raise RuntimeError("Cannot load local files without "
                               "`--allowed-local-media-path`.")

        filepath = Path(url_spec.path)
        # Resolve symlinks/".." first so paths cannot escape the allowed root.
        if media_root not in filepath.resolve().parents:
            raise ValueError(
                f"The file path {filepath} must be a subpath "
                f"of `--allowed-local-media-path` {media_root}.")

        return media_io.load_file(filepath)

    def load_from_url(
        self,
        url: str,
        media_io: MediaIO[_M],
        *,
        fetch_timeout: Optional[int] = None,
    ) -> _M:
        """Synchronously fetch and decode the media behind ``url``."""
        url_spec = urlparse(url)
        scheme = url_spec.scheme

        if scheme.startswith("http"):
            raw = self.connection.get_bytes(url, timeout=fetch_timeout)
            return media_io.load_bytes(raw)

        if scheme == "data":
            return self._load_data_url(url_spec, media_io)

        if scheme == "file":
            return self._load_file_url(url_spec, media_io)

        raise ValueError("The URL must be either a HTTP, data or file URL.")

    async def load_from_url_async(
        self,
        url: str,
        media_io: MediaIO[_M],
        *,
        fetch_timeout: Optional[int] = None,
    ) -> _M:
        """Asynchronously fetch and decode the media behind ``url``."""
        url_spec = urlparse(url)
        scheme = url_spec.scheme

        if scheme.startswith("http"):
            raw = await self.connection.async_get_bytes(url,
                                                        timeout=fetch_timeout)
            return media_io.load_bytes(raw)

        if scheme == "data":
            return self._load_data_url(url_spec, media_io)

        if scheme == "file":
            return self._load_file_url(url_spec, media_io)

        raise ValueError("The URL must be either a HTTP, data or file URL.")

    def fetch_audio(
        self,
        audio_url: str,
    ) -> tuple[np.ndarray, Union[int, float]]:
        """
        Load audio from a URL.
        """
        return self.load_from_url(
            audio_url,
            AudioMediaIO(),
            fetch_timeout=envs.VLLM_AUDIO_FETCH_TIMEOUT,
        )

    async def fetch_audio_async(
        self,
        audio_url: str,
    ) -> tuple[np.ndarray, Union[int, float]]:
        """
        Asynchronously fetch audio from a URL.
        """
        return await self.load_from_url_async(
            audio_url,
            AudioMediaIO(),
            fetch_timeout=envs.VLLM_AUDIO_FETCH_TIMEOUT,
        )

    def fetch_image(
        self,
        image_url: str,
        *,
        image_mode: str = "RGB",
    ) -> Image.Image:
        """
        Load a PIL image from a HTTP or base64 data URL.

        By default, the image is converted into RGB format.
        """
        return self.load_from_url(
            image_url,
            ImageMediaIO(image_mode=image_mode),
            fetch_timeout=envs.VLLM_IMAGE_FETCH_TIMEOUT,
        )

    async def fetch_image_async(
        self,
        image_url: str,
        *,
        image_mode: str = "RGB",
    ) -> Image.Image:
        """
        Asynchronously load a PIL image from a HTTP or base64 data URL.

        By default, the image is converted into RGB format.
        """
        return await self.load_from_url_async(
            image_url,
            ImageMediaIO(image_mode=image_mode),
            fetch_timeout=envs.VLLM_IMAGE_FETCH_TIMEOUT,
        )

    def fetch_video(
        self,
        video_url: str,
        *,
        image_mode: str = "RGB",
        num_frames: int = 32,
    ) -> npt.NDArray:
        """
        Load video from a HTTP or base64 data URL.
        """
        frame_io = ImageMediaIO(image_mode=image_mode)
        return self.load_from_url(
            video_url,
            VideoMediaIO(frame_io, num_frames=num_frames),
            fetch_timeout=envs.VLLM_VIDEO_FETCH_TIMEOUT,
        )

    async def fetch_video_async(
        self,
        video_url: str,
        *,
        image_mode: str = "RGB",
        num_frames: int = 32,
    ) -> npt.NDArray:
        """
        Asynchronously load video from a HTTP or base64 data URL.

        By default, the image is converted into RGB format.
        """
        frame_io = ImageMediaIO(image_mode=image_mode)
        return await self.load_from_url_async(
            video_url,
            VideoMediaIO(frame_io, num_frames=num_frames),
            fetch_timeout=envs.VLLM_VIDEO_FETCH_TIMEOUT,
        )
250
+
251
+
252
global_media_connector = MediaConnector()
"""The global :class:`MediaConnector` instance used by vLLM."""

# Module-level convenience aliases bound to the global connector's methods.
fetch_audio = global_media_connector.fetch_audio
fetch_image = global_media_connector.fetch_image
fetch_video = global_media_connector.fetch_video
258
+
259
+
260
def encode_audio_base64(
    audio: np.ndarray,
    sampling_rate: int,
) -> str:
    """Encode an (audio, sampling_rate) pair as base64."""
    return AudioMediaIO().encode_base64((audio, sampling_rate))
267
+
268
+
269
def encode_image_base64(
    image: Image.Image,
    *,
    image_mode: str = "RGB",
    format: str = "JPEG",
) -> str:
    """
    Encode a pillow image to base64 format.

    By default, the image is converted into RGB format before being encoded.
    """
    io_helper = ImageMediaIO(image_mode=image_mode)
    return io_helper.encode_base64(image, image_format=format)
282
+
283
+
284
def encode_video_base64(frames: npt.NDArray) -> str:
    """Encode a stack of video frames as base64 via :class:`VideoMediaIO`."""
    return VideoMediaIO(ImageMediaIO()).encode_base64(frames)
288
+
289
+
290
# Utilities for input processors
_T = TypeVar("_T", str, int)


def repeat_and_pad_token(
    token: _T,
    *,
    repeat_count: int = 1,
    pad_token_left: Optional[_T] = None,
    pad_token_right: Optional[_T] = None,
) -> list[_T]:
    """Repeat ``token`` ``repeat_count`` times, with optional pad tokens
    prepended/appended on either side."""
    out: list[_T] = []
    if pad_token_left is not None:
        out.append(pad_token_left)
    out.extend(token for _ in range(repeat_count))
    if pad_token_right is not None:
        out.append(pad_token_right)
    return out
308
+
309
+
310
def repeat_and_pad_placeholder_tokens(
    tokenizer: AnyTokenizer,
    prompt: Optional[str],
    prompt_token_ids: list[int],
    *,
    placeholder_token_id: int,
    repeat_count: Union[int, list[int]],
    pad_token_left: Optional[int] = None,
    pad_token_right: Optional[int] = None,
) -> tuple[Optional[str], list[int], list[PlaceholderRange]]:
    """Expand each occurrence of ``placeholder_token_id`` into a repeated
    (and optionally padded) run, in both the text prompt and the token ids.

    The i-th placeholder occurrence is repeated ``repeat_count[i]`` times
    (a scalar ``repeat_count`` applies to a single occurrence). Returns the
    rewritten prompt (or ``None`` if ``prompt`` is ``None``), the rewritten
    token ids, and a ``PlaceholderRange`` per expanded occurrence; each
    range's length excludes the pad tokens.
    """
    if isinstance(repeat_count, int):
        repeat_count = [repeat_count]

    if prompt is None:
        new_prompt = None
    else:
        placeholder_token_str = tokenizer.decode(placeholder_token_id)
        pad_token_str_left = (None if pad_token_left is None else
                              tokenizer.decode(pad_token_left))
        pad_token_str_right = (None if pad_token_right is None else
                               tokenizer.decode(pad_token_right))

        placeholder_token_count = prompt.count(placeholder_token_str)
        # This is an arbitrary number to distinguish between the two cases
        if placeholder_token_count > 16:
            logger.warning(
                "Please follow the prompt format that is "
                "documented on HuggingFace which does not involve "
                "repeating %s tokens.", placeholder_token_str)
        if placeholder_token_count < len(repeat_count):
            logger.warning(
                "The number of multi-modal placeholder tokens in the prompt "
                "is less than the number of multi-modal inputs. Extra "
                "placeholder tokens will be treated as plain text")
            # Drop the repeat counts that have no matching placeholder.
            repeat_count = repeat_count[:placeholder_token_count]

        # maxsplit leaves any extra placeholders untouched in the tail part.
        prompt_parts = prompt.split(placeholder_token_str,
                                    maxsplit=len(repeat_count))
        new_prompt = ""
        for i, repeat_count_item in enumerate(repeat_count):
            replacement_str = "".join(
                repeat_and_pad_token(
                    placeholder_token_str,
                    repeat_count=repeat_count_item,
                    pad_token_left=pad_token_str_left,
                    pad_token_right=pad_token_str_right,
                ))
            # The image tokens are removed to be consistent with HuggingFace
            new_prompt += prompt_parts[i] + replacement_str
        new_prompt += prompt_parts[-1]

    # Second pass: mirror the same expansion over the token-id sequence,
    # recording the offset/length of each expanded run.
    new_token_ids = list[int]()
    placeholder_ranges = list[PlaceholderRange]()
    placeholder_token_idx = 0
    for i, token in enumerate(prompt_token_ids):
        if token == placeholder_token_id:
            curr_repeat_count = repeat_count[placeholder_token_idx]
            replacement_ids = repeat_and_pad_token(
                placeholder_token_id,
                repeat_count=curr_repeat_count,
                pad_token_left=pad_token_left,
                pad_token_right=pad_token_right,
            )
            offset = len(new_token_ids)
            # The range covers only the repeated tokens, not the left pad.
            if pad_token_left is not None:
                offset += 1
            placeholder_ranges.append({
                "offset": offset,
                "length": curr_repeat_count,
            })
            new_token_ids.extend(replacement_ids)
            placeholder_token_idx += 1

            # No need to further scan the list since we replaced all tokens
            if placeholder_token_idx >= len(repeat_count):
                new_token_ids.extend(prompt_token_ids[i + 1:])
                break
        else:
            new_token_ids.append(token)

    return new_prompt, new_token_ids, placeholder_ranges
391
+
392
+
393
def consecutive_placeholder_ranges(
        num_items: int,
        item_size: int,
        initial_offset: int = 0) -> list[PlaceholderRange]:
    """Returns a list of consecutive PlaceholderRanges of a fixed size"""

    starts = range(initial_offset,
                   initial_offset + num_items * item_size,
                   item_size)
    return [
        PlaceholderRange(offset=start, length=item_size) for start in starts
    ]
403
+
404
+
405
def merge_and_sort_multimodal_metadata(
    mm_positions: "MultiModalPlaceholderDict",
    mm_hashes: Optional["MultiModalHashDict"],
) -> tuple[list[str], list[PlaceholderRange], Optional[list[str]]]:
    """Merge per-modality placeholder ranges into one offset-sorted list.

    Modalities are ordered by the offset of their first placeholder; the
    per-modality range lists are then concatenated in that order. When
    ``mm_hashes`` is provided, its per-modality hash lists are reordered and
    concatenated the same way.

    Raises:
        ValueError: If placeholders of different modalities interleave in the
            input sequence (e.g. "<image><audio><image> Describe the
            content.").

    Returns:
        list[str]: Modalities sorted by first-placeholder offset.
        list[PlaceholderRange]: All ranges from ``mm_positions``, merged.
        Optional[list[str]]: Merged hashes if ``mm_hashes`` is given,
        ``None`` otherwise.
    """

    modalities = list(mm_positions.keys())

    assert len(modalities) > 0, "No modalities found in the mm_positions."

    # Single modality: per-modality lists are already in order.
    if len(modalities) == 1:
        only = modalities[0]
        merged_hashes = None if mm_hashes is None else list(mm_hashes[only])
        return modalities, list(mm_positions[only]), merged_hashes

    tagged_placeholders = [(modality, mm_positions[modality])
                           for modality in modalities]

    if mm_hashes is None:
        ordered = sorted(tagged_placeholders,
                         key=lambda entry: entry[1][0]['offset'])
        ordered_hashes = None
    else:
        # Keep hashes paired with their modality while sorting by offset.
        hash_groups = [
            mm_hashes[modality] for modality in modalities
            if modality in mm_hashes
        ]
        paired = sorted(zip(tagged_placeholders, hash_groups),
                        key=lambda pair: pair[0][1][0]['offset'])
        placeholder_part, hash_part = zip(*paired)
        ordered = list(placeholder_part)
        ordered_hashes = list(hash_part)

    sorted_modalities = [modality for modality, _ in ordered]

    # Concatenate and reject any cross-modality interleaving: every range of
    # a later modality must start at or after the previous modality's last.
    merged_placeholders: list[PlaceholderRange] = []
    for _, ranges in ordered:
        if merged_placeholders and ranges[0][
                'offset'] < merged_placeholders[-1]['offset']:
            raise ValueError(
                "Interleaved mixed-modality inference is currently not "
                "supported.")
        merged_placeholders.extend(ranges)

    if ordered_hashes is None:
        merged_hashes = None
    else:
        merged_hashes = []
        for group in ordered_hashes:
            merged_hashes.extend(group)

    return sorted_modalities, merged_placeholders, merged_hashes
483
+
484
+
485
def group_mm_inputs_by_modality(
        mm_inputs: list["MultiModalKwargs"]) -> list[list["MultiModalKwargs"]]:
    """Group consecutive MultiModalKwargs from mm_inputs with the same modality
    together into the same list for batching purpose. For MultiModalKwargs with
    multiple modalities, put them into their own list.

    Args:
        mm_inputs: List of MultiModalKwargs.

    Returns:
        list[list[MultiModalKwargs]]: List of list of MultiModalKwargs, each
        inner list contains consecutive MultiModalKwargs with same modality, or
        one with multimodal modalities.
    """
    if not mm_inputs:
        return []

    def batching_key(entry: "MultiModalKwargs") -> Union[str, int]:
        modalities = entry.modalities
        if len(modalities) > 1:
            # Mixed-modality inputs get a unique key so they are never
            # batched with anything else.
            return id(entry)
        if len(modalities) == 1:
            return next(iter(modalities))
        # FIXME(Isotr0py): Modality of mm_input from legacy pipeline is empty,
        # this is used to make InternVL with legacy pipeline still work with v1.
        return ""

    return [list(run) for _, run in groupby(mm_inputs, key=batching_key)]
.venv/lib/python3.11/site-packages/vllm/multimodal/video.py ADDED
@@ -0,0 +1,191 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-License-Identifier: Apache-2.0
2
+
3
+ import base64
4
+ from functools import lru_cache, partial
5
+ from io import BytesIO
6
+ from pathlib import Path
7
+ from typing import TYPE_CHECKING, Any, Dict, Optional
8
+
9
+ import numpy as np
10
+ import numpy.typing as npt
11
+ from PIL import Image
12
+
13
+ from vllm.inputs.registry import InputContext
14
+ from vllm.logger import init_logger
15
+ from vllm.transformers_utils.processor import get_video_processor
16
+ from vllm.transformers_utils.tokenizer import get_tokenizer
17
+ from vllm.utils import PlaceholderModule, is_list_of
18
+
19
+ from .base import MediaIO, ModalityData
20
+ from .image import ImageMediaIO, ImagePlugin
21
+ from .inputs import MultiModalKwargs, VideoItem
22
+
23
+ if TYPE_CHECKING:
24
+ from vllm.config import ModelConfig
25
+
26
+ try:
27
+ import decord
28
+ except ImportError:
29
+ decord = PlaceholderModule("decord") # type: ignore[assignment]
30
+
31
+ logger = init_logger(__name__)
32
+
33
+ cached_get_video_processor = lru_cache(get_video_processor)
34
+ cached_get_tokenizer = lru_cache(get_tokenizer)
35
+
36
+
37
class VideoPlugin(ImagePlugin):
    """Plugin for video data."""

    def get_data_key(self) -> str:
        return "video"

    def _get_hf_video_processor(
        self,
        model_config: "ModelConfig",
        mm_processor_kwargs: Optional[Dict[str, Any]] = None,
    ):
        # Fetch (and cache) the HF video processor for this model.
        kwargs = {} if mm_processor_kwargs is None else mm_processor_kwargs
        return cached_get_video_processor(
            model_config.model,
            trust_remote_code=model_config.trust_remote_code,
            **kwargs)

    def _default_input_mapper(
        self,
        ctx: InputContext,
        data: ModalityData[VideoItem],
        **mm_processor_kwargs,
    ) -> MultiModalKwargs:
        model_config = ctx.model_config

        # Unwrap a single-item list to the bare array.
        if isinstance(data, list) and len(data) == 1:
            data = data[0]  # type: ignore

        if not (isinstance(data, np.ndarray) or is_list_of(data, np.ndarray)):
            raise TypeError(f"Invalid video type: {type(data)}")

        video_processor = self._get_hf_video_processor(
            model_config,
            mm_processor_kwargs,
        )
        if video_processor is None:
            raise RuntimeError("No HuggingFace processor is available "
                               "to process the video object")
        try:
            # NOTE: Similar to image; it may be a good idea to filter and
            # pass mm_processor_kwargs here too, but for now we don't to
            # avoid extra complexity if the initializer and preprocess
            # signatures of the processor don't align
            batch_data = video_processor(data, return_tensors="pt").data
        except Exception:
            logger.error("Failed to process video (%s)", data)
            raise

        return MultiModalKwargs(batch_data)

    def _default_max_multimodal_tokens(self, ctx: InputContext) -> int:
        return 4096
90
+
91
+
92
def resize_video(frames: npt.NDArray, size: tuple[int, int]) -> npt.NDArray:
    """Resize every frame of a (num_frames, H, W, C) array to ``size``
    (given as ``(height, width)``) using OpenCV."""
    # lazy import cv2 to avoid bothering users who only use text models
    import cv2

    target_h, target_w = size
    num_frames = frames.shape[0]
    channels = frames.shape[3]

    out = np.empty((num_frames, target_h, target_w, channels),
                   dtype=frames.dtype)
    for idx in range(num_frames):
        # cv2.resize takes (width, height) order.
        out[idx] = cv2.resize(frames[idx], (target_w, target_h))
    return out
103
+
104
+
105
def rescale_video_size(frames: npt.NDArray, size_factor: float) -> npt.NDArray:
    """Scale the height and width of every frame by ``size_factor``
    (truncated to integers)."""
    _, height, width, _ = frames.shape
    scaled = (int(height * size_factor), int(width * size_factor))
    return resize_video(frames, scaled)
111
+
112
+
113
def sample_frames_from_video(frames: npt.NDArray,
                             num_frames: int) -> npt.NDArray:
    """Uniformly sample ``num_frames`` frames (by linearly spaced indices);
    ``num_frames == -1`` returns the input unchanged."""
    if num_frames == -1:
        return frames

    last = frames.shape[0] - 1
    indices = np.linspace(0, last, num_frames, dtype=int)
    return frames[indices, ...]
122
+
123
+
124
class VideoMediaIO(MediaIO[npt.NDArray]):
    """Decodes/encodes videos as (num_frames, H, W, C) numpy arrays,
    delegating per-frame work to an :class:`ImageMediaIO`."""

    def __init__(
        self,
        image_io: ImageMediaIO,
        *,
        num_frames: int = 32,
    ) -> None:
        """
        Args:
            image_io: Helper used to decode/encode individual frames.
            num_frames: Maximum number of frames sampled per video.
        """
        super().__init__()

        self.image_io = image_io
        self.num_frames = num_frames

    def load_bytes(self, data: bytes) -> npt.NDArray:
        reader = decord.VideoReader(BytesIO(data), num_threads=1)
        available = len(reader)
        limit = self.num_frames

        # Sample at most `limit` uniformly spaced frames.
        if available <= limit:
            indices = list(range(available))
        else:
            indices = np.linspace(0, available - 1, limit,
                                  dtype=int).tolist()

        return reader.get_batch(indices).asnumpy()

    def load_base64(self, media_type: str, data: str) -> npt.NDArray:
        # "video/jpeg" is a comma-separated list of base64 JPEG frames.
        if media_type.lower() == "video/jpeg":
            decode_frame = partial(
                self.image_io.load_base64,
                "image/jpeg",
            )
            frames = [
                np.array(decode_frame(chunk)) for chunk in data.split(",")
            ]
            return np.stack(frames)

        return self.load_bytes(base64.b64decode(data))

    def load_file(self, filepath: Path) -> npt.NDArray:
        with filepath.open("rb") as f:
            raw = f.read()

        return self.load_bytes(raw)

    def encode_base64(
        self,
        media: npt.NDArray,
        *,
        video_format: str = "JPEG",
    ) -> str:
        if video_format != "JPEG":
            raise NotImplementedError(
                "Only JPEG format is supported for now.")

        encode_frame = partial(
            self.image_io.encode_base64,
            image_format=video_format,
        )
        return ",".join(
            encode_frame(Image.fromarray(frame)) for frame in media)
.venv/lib/python3.11/site-packages/vllm/triton_utils/__init__.py ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-License-Identifier: Apache-2.0
2
+
3
from vllm.triton_utils.importing import HAS_TRITON

__all__ = ["HAS_TRITON"]

# Only import (and export) the cache-manager helper when Triton is actually
# available, so importing this package stays safe on Triton-less installs.
if HAS_TRITON:

    from vllm.triton_utils.custom_cache_manager import (
        maybe_set_triton_cache_manager)

    __all__ += ["maybe_set_triton_cache_manager"]
.venv/lib/python3.11/site-packages/vllm/triton_utils/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (473 Bytes). View file
 
.venv/lib/python3.11/site-packages/vllm/triton_utils/__pycache__/custom_cache_manager.cpython-311.pyc ADDED
Binary file (3.53 kB). View file
 
.venv/lib/python3.11/site-packages/vllm/triton_utils/__pycache__/importing.cpython-311.pyc ADDED
Binary file (779 Bytes). View file
 
.venv/lib/python3.11/site-packages/vllm/worker/__init__.py ADDED
File without changes
.venv/lib/python3.11/site-packages/vllm/worker/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (184 Bytes). View file
 
.venv/lib/python3.11/site-packages/vllm/worker/__pycache__/cache_engine.cpython-311.pyc ADDED
Binary file (7.42 kB). View file
 
.venv/lib/python3.11/site-packages/vllm/worker/__pycache__/cpu_enc_dec_model_runner.cpython-311.pyc ADDED
Binary file (13.9 kB). View file
 
.venv/lib/python3.11/site-packages/vllm/worker/__pycache__/cpu_model_runner.cpython-311.pyc ADDED
Binary file (33.2 kB). View file
 
.venv/lib/python3.11/site-packages/vllm/worker/__pycache__/cpu_pooling_model_runner.cpython-311.pyc ADDED
Binary file (6.72 kB). View file
 
.venv/lib/python3.11/site-packages/vllm/worker/__pycache__/cpu_worker.cpython-311.pyc ADDED
Binary file (20.6 kB). View file
 
.venv/lib/python3.11/site-packages/vllm/worker/__pycache__/enc_dec_model_runner.cpython-311.pyc ADDED
Binary file (22.7 kB). View file
 
.venv/lib/python3.11/site-packages/vllm/worker/__pycache__/hpu_worker.cpython-311.pyc ADDED
Binary file (27.2 kB). View file
 
.venv/lib/python3.11/site-packages/vllm/worker/__pycache__/model_runner.cpython-311.pyc ADDED
Binary file (87.4 kB). View file
 
.venv/lib/python3.11/site-packages/vllm/worker/__pycache__/model_runner_base.cpython-311.pyc ADDED
Binary file (12.3 kB). View file
 
.venv/lib/python3.11/site-packages/vllm/worker/__pycache__/multi_step_model_runner.cpython-311.pyc ADDED
Binary file (36.8 kB). View file
 
.venv/lib/python3.11/site-packages/vllm/worker/__pycache__/multi_step_tpu_worker.cpython-311.pyc ADDED
Binary file (4.61 kB). View file
 
.venv/lib/python3.11/site-packages/vllm/worker/__pycache__/multi_step_worker.cpython-311.pyc ADDED
Binary file (7.57 kB). View file
 
.venv/lib/python3.11/site-packages/vllm/worker/__pycache__/neuron_model_runner.cpython-311.pyc ADDED
Binary file (16 kB). View file
 
.venv/lib/python3.11/site-packages/vllm/worker/__pycache__/neuron_worker.cpython-311.pyc ADDED
Binary file (6.76 kB). View file
 
.venv/lib/python3.11/site-packages/vllm/worker/__pycache__/openvino_model_runner.cpython-311.pyc ADDED
Binary file (14.1 kB). View file
 
.venv/lib/python3.11/site-packages/vllm/worker/__pycache__/openvino_worker.cpython-311.pyc ADDED
Binary file (27 kB). View file
 
.venv/lib/python3.11/site-packages/vllm/worker/__pycache__/pooling_model_runner.cpython-311.pyc ADDED
Binary file (9.95 kB). View file
 
.venv/lib/python3.11/site-packages/vllm/worker/__pycache__/tpu_model_runner.cpython-311.pyc ADDED
Binary file (40.6 kB). View file
 
.venv/lib/python3.11/site-packages/vllm/worker/__pycache__/tpu_worker.cpython-311.pyc ADDED
Binary file (15.2 kB). View file
 
.venv/lib/python3.11/site-packages/vllm/worker/__pycache__/utils.cpython-311.pyc ADDED
Binary file (2.26 kB). View file
 
.venv/lib/python3.11/site-packages/vllm/worker/__pycache__/worker.cpython-311.pyc ADDED
Binary file (29.1 kB). View file
 
.venv/lib/python3.11/site-packages/vllm/worker/__pycache__/worker_base.cpython-311.pyc ADDED
Binary file (30.1 kB). View file
 
.venv/lib/python3.11/site-packages/vllm/worker/__pycache__/xpu_model_runner.cpython-311.pyc ADDED
Binary file (26.8 kB). View file
 
.venv/lib/python3.11/site-packages/vllm/worker/__pycache__/xpu_worker.cpython-311.pyc ADDED
Binary file (9.11 kB). View file