fix: type annotation bugs in media_utils, processor and config

#35
Files changed (3) hide show
  1. configuration_kimi_k25.py +39 -36
  2. kimi_k25_processor.py +168 -165
  3. media_utils.py +362 -368
configuration_kimi_k25.py CHANGED
@@ -1,3 +1,5 @@
 
 
1
  from transformers.configuration_utils import PretrainedConfig
2
 
3
  try:
@@ -7,36 +9,35 @@ except ImportError:
7
 
8
 
9
  class KimiK25VisionConfig(PretrainedConfig):
10
-
11
  def __init__(
12
- self,
13
- patch_size: int = 14,
14
- init_pos_emb_height: int = 64,
15
- init_pos_emb_width: int = 64,
16
- init_pos_emb_time: int = 4,
17
- pos_emb_type: str = 'divided_fixed',
18
- vt_num_attention_heads: int = 16,
19
- vt_num_hidden_layers: int = 27,
20
- vt_hidden_size: int = 1152,
21
- vt_intermediate_size: int = 4304,
22
- merge_kernel_size: tuple = (2, 2),
23
- video_attn_type: str = 'spatial_temporal',
24
- merge_type: str = 'sd2_tpool',
25
- _attn_implementation: str = 'flash_attention_2',
26
- # MM Projector parameters
27
- mm_projector_type: str = 'patchmerger',
28
- mm_hidden_size: int | None = None,
29
- projector_hidden_act: str = "gelu",
30
- projector_ln_eps: float = 1e-5,
31
- # Other parameters
32
- ignore_index: int = -100,
33
- media_placeholder_token_id: int = 163605,
34
- pad_token_id: int = 0,
35
- use_unified_vision_chunk: bool = True,
36
- video_placeholder="<|kimi_k25_video_placeholder|>",
37
- text_hidden_size=7168,
38
- **vision_config_kwargs):
39
-
40
  self.patch_size = patch_size
41
  self.init_pos_emb_height = init_pos_emb_height
42
  self.init_pos_emb_width = init_pos_emb_width
@@ -53,7 +54,9 @@ class KimiK25VisionConfig(PretrainedConfig):
53
 
54
  # MM Projector config
55
  self.mm_projector_type = mm_projector_type
56
- self.mm_hidden_size = mm_hidden_size if mm_hidden_size is not None else vt_hidden_size
 
 
57
  self.projector_hidden_act = projector_hidden_act
58
  self.projector_ln_eps = projector_ln_eps
59
  self.text_hidden_size = text_hidden_size
@@ -64,7 +67,7 @@ class KimiK25Config(PretrainedConfig):
64
 
65
  Args:
66
  text_config (dict | DeepseekV3Config): Configuration for the text model.
67
-
68
  Vision Tower Parameters (from MoonViT3dConfig):
69
  patch_size (int): Patch size for vision tower.
70
  init_pos_emb_height (int): Initial position embedding height.
@@ -79,13 +82,13 @@ class KimiK25Config(PretrainedConfig):
79
  video_attn_type (str): Type of video attention.
80
  merge_type (str): Type of merge operation.
81
  _attn_implementation (str): Attention implementation type.
82
-
83
  MM Projector Parameters (from MultiModalProjectorConfig):
84
  mm_projector_type (str): Type of multimodal projector.
85
  mm_hidden_size (int): Hidden size from vision tower (should match vt_hidden_size).
86
  projector_hidden_act (str): Activation function for projector.
87
  projector_ln_eps (float): Layer norm epsilon for projector.
88
-
89
  Other Parameters:
90
  ignore_index (int): The ignore index for the loss function.
91
  media_placeholder_token_id (int): The token ID to use for media placeholders.
@@ -96,14 +99,14 @@ class KimiK25Config(PretrainedConfig):
96
 
97
  def __init__(
98
  self,
99
- text_config: dict | DeepseekV3Config = None,
100
- vision_config: dict | KimiK25VisionConfig = None,
101
  # Other parameters
102
  ignore_index: int = -100,
103
  media_placeholder_token_id: int = 163605,
104
  pad_token_id: int = 0,
105
  use_unified_vision_chunk: bool = True,
106
- video_placeholder="<|kimi_k25_video_placeholder|>",
107
  **kwargs,
108
  ):
109
  if isinstance(text_config, dict):
 
1
+ from typing import Optional, Union
2
+
3
  from transformers.configuration_utils import PretrainedConfig
4
 
5
  try:
 
9
 
10
 
11
  class KimiK25VisionConfig(PretrainedConfig):
 
12
  def __init__(
13
+ self,
14
+ patch_size: int = 14,
15
+ init_pos_emb_height: int = 64,
16
+ init_pos_emb_width: int = 64,
17
+ init_pos_emb_time: int = 4,
18
+ pos_emb_type: str = "divided_fixed",
19
+ vt_num_attention_heads: int = 16,
20
+ vt_num_hidden_layers: int = 27,
21
+ vt_hidden_size: int = 1152,
22
+ vt_intermediate_size: int = 4304,
23
+ merge_kernel_size: tuple = (2, 2),
24
+ video_attn_type: str = "spatial_temporal",
25
+ merge_type: str = "sd2_tpool",
26
+ _attn_implementation: str = "flash_attention_2",
27
+ # MM Projector parameters
28
+ mm_projector_type: str = "patchmerger",
29
+ mm_hidden_size: int | None = None,
30
+ projector_hidden_act: str = "gelu",
31
+ projector_ln_eps: float = 1e-5,
32
+ # Other parameters
33
+ ignore_index: int = -100,
34
+ media_placeholder_token_id: int = 163605,
35
+ pad_token_id: int = 0,
36
+ use_unified_vision_chunk: bool = True,
37
+ video_placeholder="<|kimi_k25_video_placeholder|>",
38
+ text_hidden_size=7168,
39
+ **vision_config_kwargs,
40
+ ):
41
  self.patch_size = patch_size
42
  self.init_pos_emb_height = init_pos_emb_height
43
  self.init_pos_emb_width = init_pos_emb_width
 
54
 
55
  # MM Projector config
56
  self.mm_projector_type = mm_projector_type
57
+ self.mm_hidden_size = (
58
+ mm_hidden_size if mm_hidden_size is not None else vt_hidden_size
59
+ )
60
  self.projector_hidden_act = projector_hidden_act
61
  self.projector_ln_eps = projector_ln_eps
62
  self.text_hidden_size = text_hidden_size
 
67
 
68
  Args:
69
  text_config (dict | DeepseekV3Config): Configuration for the text model.
70
+
71
  Vision Tower Parameters (from MoonViT3dConfig):
72
  patch_size (int): Patch size for vision tower.
73
  init_pos_emb_height (int): Initial position embedding height.
 
82
  video_attn_type (str): Type of video attention.
83
  merge_type (str): Type of merge operation.
84
  _attn_implementation (str): Attention implementation type.
85
+
86
  MM Projector Parameters (from MultiModalProjectorConfig):
87
  mm_projector_type (str): Type of multimodal projector.
88
  mm_hidden_size (int): Hidden size from vision tower (should match vt_hidden_size).
89
  projector_hidden_act (str): Activation function for projector.
90
  projector_ln_eps (float): Layer norm epsilon for projector.
91
+
92
  Other Parameters:
93
  ignore_index (int): The ignore index for the loss function.
94
  media_placeholder_token_id (int): The token ID to use for media placeholders.
 
99
 
100
  def __init__(
101
  self,
102
+ text_config: Optional[Union[dict, DeepseekV3Config]] = None,
103
+ vision_config: Optional[Union[dict, KimiK25VisionConfig]] = None,
104
  # Other parameters
105
  ignore_index: int = -100,
106
  media_placeholder_token_id: int = 163605,
107
  pad_token_id: int = 0,
108
  use_unified_vision_chunk: bool = True,
109
+ video_placeholder: str = "<|kimi_k25_video_placeholder|>",
110
  **kwargs,
111
  ):
112
  if isinstance(text_config, dict):
kimi_k25_processor.py CHANGED
@@ -1,165 +1,168 @@
1
- from transformers.feature_extraction_utils import BatchFeature
2
- from transformers.processing_utils import ProcessorMixin
3
- from transformers.utils import logging
4
-
5
- logger = logging.get_logger(__name__)
6
-
7
-
8
- class KimiK25Processor(ProcessorMixin):
9
- r"""
10
- Constructs a KimiK25 processor which wraps a KimiK25 image processor and a tokenizer into a single processor.
11
-
12
- [`KimiK25Processor`] offers all the functionalities of [`KimiK25ImageProcessor`] and [`TikTokenTokenizer`]. See the
13
- [`~KimiK25Processor.__call__`] and [`~KimiK25Processor.decode`] for more information.
14
-
15
- Args:
16
- image_processor ([`KimiK25ImageProcessor`], *optional*):
17
- The image processor is a required input.
18
- tokenizer ([`TikTokenTokenizer`], *optional*):
19
- The tokenizer is a required input.
20
- chat_template (`str`, *optional*): A Jinja template which will be used to convert lists of messages
21
- in a chat into a tokenizable string.
22
- """
23
-
24
- attributes = ["image_processor", "tokenizer"]
25
- valid_kwargs = ["chat_template"]
26
- image_processor_class = "AutoImageProcessor"
27
- tokenizer_class = "AutoTokenizer"
28
-
29
- def __init__(
30
- self,
31
- image_processor=None,
32
- tokenizer=None,
33
- chat_template=None,
34
- **kwargs,
35
- ):
36
- super().__init__(image_processor,
37
- tokenizer,
38
- chat_template=chat_template)
39
- self.media_processor = image_processor
40
- # A special temporal placeholder to be replaced by actual video placeholders
41
- self.video_placeholder = "<|kimi_k25_video_placeholder|>"
42
-
43
- def update_raw_text(self, text: str, video_prompts: list[str]) -> str:
44
- # replace video prompt in text with video chunk prompts
45
- video_count = text.count(self.video_placeholder)
46
- if video_count == 0:
47
- return text
48
- assert video_count == len(video_prompts)
49
- text_parts = text.split(self.video_placeholder)
50
- assert len(text_parts) == len(video_prompts) + 1
51
- text = "".join([
52
- text_parts[i] + video_prompts[i] for i in range(len(video_prompts))
53
- ])
54
- text += text_parts[-1]
55
- return text
56
-
57
- def preprocess_medias(self, medias: list[dict]) -> list[dict]:
58
- updated_medias = []
59
- video_prompts = []
60
- for media in medias:
61
- if media['type'] == 'image':
62
- updated_medias.append(media)
63
- elif media['type'] == 'video':
64
- video_chunks = self.media_processor.split_video_chunks(
65
- media['video'])
66
- updated_medias.extend(video_chunks)
67
- video_prompts.append("".join(
68
- [vc['prompt'] for vc in video_chunks]))
69
- else:
70
- raise ValueError(f"unsupported media type: {media['type']}")
71
- return updated_medias, video_prompts
72
-
73
- def __call__(self,
74
- messages: list[dict] = None,
75
- medias: list[dict] = None,
76
- text: str = None,
77
- return_tensors: str = "pt",
78
- **kwargs) -> BatchFeature:
79
- """
80
- Process multimodal inputs for Kimi-K2.5 model.
81
-
82
- This processor accepts ordered messages and extracts both media and text in a single pass.
83
- text will be automatically updated if video input detected in messages
84
-
85
- Args:
86
- messages: List of message dicts with 'role' and 'content' fields.
87
- If provided, medias and text will be extracted automatically.
88
- medias: Pre-extracted list of media dicts. If None, extracted from messages.
89
- text: Pre-formatted text string. If None, generated via apply_chat_template.
90
- return_tensors: Format of returned tensors ('pt', 'np', 'tf'). Default: 'pt'.
91
- **kwargs: Additional arguments passed to tokenizer.apply_chat_template.
92
-
93
- Returns:
94
- BatchFeature with fields: input_ids, attention_mask, pixel_values, grid_thws.
95
- """
96
- if messages is None and (medias is None or text is None):
97
- raise ValueError(
98
- "Provide either 'messages' or both 'medias' and 'text'")
99
-
100
- if medias is not None and text is not None:
101
- updated_medias, video_prompts = self.preprocess_medias(medias)
102
- preprocessed = self.media_processor.preprocess(
103
- updated_medias, return_tensors=return_tensors)
104
- text = self.update_raw_text(text, video_prompts)
105
- text_inputs = self.tokenizer(text, return_tensors=return_tensors)
106
- return BatchFeature(data={**text_inputs, **preprocessed.data})
107
-
108
- if medias is None:
109
- medias = self._extract_medias_from_messages(messages)
110
- updated_medias, video_prompts = self.preprocess_medias(medias)
111
- preprocessed = self.media_processor.preprocess(
112
- updated_medias, return_tensors=return_tensors)
113
-
114
- # Generate text if not provided
115
- if text is None:
116
- text = self.tokenizer.apply_chat_template(messages, **kwargs)
117
-
118
- text = self.update_raw_text(text, video_prompts)
119
-
120
- text_inputs = self.tokenizer(text, return_tensors=return_tensors)
121
- return BatchFeature(data={**text_inputs, **preprocessed.data})
122
-
123
- @staticmethod
124
- def _extract_medias_from_messages(messages: list[dict]) -> list[dict]:
125
- """
126
- Extract media items from messages in a single pass.
127
-
128
- This is an optimized version that processes messages only once.
129
- Kept as internal method since external callers should use __call__.
130
- """
131
- medias = []
132
- for msg in messages:
133
- if msg['role'] != 'user' or not msg.get('content'):
134
- continue
135
-
136
- for content_part in msg['content']:
137
- if not isinstance(content_part, dict):
138
- continue
139
-
140
- content_type = content_part.get('type')
141
- if content_type in ['video_url', 'video']:
142
- medias.append({
143
- 'type': 'video',
144
- 'video': content_part['video_url']['url'],
145
- 'first_frame_timestamp': 0.0
146
- })
147
- elif content_type in ['image_url', 'image']:
148
- medias.append({
149
- 'type': 'image',
150
- 'image': content_part['image_url'],
151
- })
152
- return medias
153
-
154
- def apply_chat_template(self, messages, **kwargs):
155
- return self.tokenizer.apply_chat_template(messages, **kwargs)
156
-
157
- def batch_decode(self, *args, **kwargs):
158
- return self.tokenizer.batch_decode(*args, **kwargs)
159
-
160
- def decode(self, *args, **kwargs):
161
- return self.tokenizer.decode(*args, **kwargs)
162
-
163
- @property
164
- def model_input_names(self):
165
- return ['input_ids', 'attention_mask', 'pixel_values', 'grid_thws']
 
 
 
 
1
+ from transformers.feature_extraction_utils import BatchFeature
2
+ from transformers.processing_utils import ProcessorMixin
3
+ from transformers.utils import logging
4
+
5
+ logger = logging.get_logger(__name__)
6
+
7
+
8
+ class KimiK25Processor(ProcessorMixin):
9
+ r"""
10
+ Constructs a KimiK25 processor which wraps a KimiK25 image processor and a tokenizer into a single processor.
11
+
12
+ [`KimiK25Processor`] offers all the functionalities of [`KimiK25ImageProcessor`] and [`TikTokenTokenizer`]. See the
13
+ [`~KimiK25Processor.__call__`] and [`~KimiK25Processor.decode`] for more information.
14
+
15
+ Args:
16
+ image_processor ([`KimiK25ImageProcessor`], *optional*):
17
+ The image processor is a required input.
18
+ tokenizer ([`TikTokenTokenizer`], *optional*):
19
+ The tokenizer is a required input.
20
+ chat_template (`str`, *optional*): A Jinja template which will be used to convert lists of messages
21
+ in a chat into a tokenizable string.
22
+ """
23
+
24
+ attributes = ["image_processor", "tokenizer"]
25
+ valid_kwargs = ["chat_template"]
26
+ image_processor_class = "AutoImageProcessor"
27
+ tokenizer_class = "AutoTokenizer"
28
+
29
+ def __init__(
30
+ self,
31
+ image_processor=None,
32
+ tokenizer=None,
33
+ chat_template=None,
34
+ **kwargs,
35
+ ):
36
+ super().__init__(image_processor, tokenizer, chat_template=chat_template)
37
+ self.media_processor = image_processor
38
+ # A temporary placeholder to be replaced by actual video placeholders
39
+ self.video_placeholder = "<|kimi_k25_video_placeholder|>"
40
+
41
+ def update_raw_text(self, text: str, video_prompts: list[str]) -> str:
42
+ # replace video prompt in text with video chunk prompts
43
+ video_count = text.count(self.video_placeholder)
44
+ if video_count == 0:
45
+ return text
46
+ assert video_count == len(video_prompts)
47
+ text_parts = text.split(self.video_placeholder)
48
+ assert len(text_parts) == len(video_prompts) + 1
49
+ text = "".join(
50
+ [text_parts[i] + video_prompts[i] for i in range(len(video_prompts))]
51
+ )
52
+ text += text_parts[-1]
53
+ return text
54
+
55
+ def preprocess_medias(self, medias: list[dict]) -> tuple[list[dict], list[str]]:
56
+ updated_medias = []
57
+ video_prompts = []
58
+ for media in medias:
59
+ if media["type"] == "image":
60
+ updated_medias.append(media)
61
+ elif media["type"] == "video":
62
+ video_chunks = self.media_processor.split_video_chunks(media["video"])
63
+ updated_medias.extend(video_chunks)
64
+ video_prompts.append("".join([vc["prompt"] for vc in video_chunks]))
65
+ else:
66
+ raise ValueError(f"unsupported media type: {media['type']}")
67
+ return updated_medias, video_prompts
68
+
69
+ def __call__(
70
+ self,
71
+ messages: list[dict] = None,
72
+ medias: list[dict] = None,
73
+ text: str = None,
74
+ return_tensors: str = "pt",
75
+ **kwargs,
76
+ ) -> BatchFeature:
77
+ """
78
+ Process multimodal inputs for Kimi-K2.5 model.
79
+
80
+ This processor accepts ordered messages and extracts both media and text in a single pass.
81
+ text will be automatically updated if video input is detected in messages
82
+
83
+ Args:
84
+ messages: List of message dicts with 'role' and 'content' fields.
85
+ If provided, medias and text will be extracted automatically.
86
+ medias: Pre-extracted list of media dicts. If None, extracted from messages.
87
+ text: Pre-formatted text string. If None, generated via apply_chat_template.
88
+ return_tensors: Format of returned tensors ('pt', 'np', 'tf'). Default: 'pt'.
89
+ **kwargs: Additional arguments passed to tokenizer.apply_chat_template.
90
+
91
+ Returns:
92
+ BatchFeature with fields: input_ids, attention_mask, pixel_values, grid_thws.
93
+ """
94
+ if messages is None and (medias is None or text is None):
95
+ raise ValueError("Provide either 'messages' or both 'medias' and 'text'")
96
+
97
+ if medias is not None and text is not None:
98
+ updated_medias, video_prompts = self.preprocess_medias(medias)
99
+ preprocessed = self.media_processor.preprocess(
100
+ updated_medias, return_tensors=return_tensors
101
+ )
102
+ text = self.update_raw_text(text, video_prompts)
103
+ text_inputs = self.tokenizer(text, return_tensors=return_tensors)
104
+ return BatchFeature(data={**text_inputs, **preprocessed.data})
105
+
106
+ if medias is None:
107
+ medias = self._extract_medias_from_messages(messages)
108
+ updated_medias, video_prompts = self.preprocess_medias(medias)
109
+ preprocessed = self.media_processor.preprocess(
110
+ updated_medias, return_tensors=return_tensors
111
+ )
112
+
113
+ # Generate text if not provided
114
+ if text is None:
115
+ text = self.tokenizer.apply_chat_template(messages, **kwargs)
116
+
117
+ text = self.update_raw_text(text, video_prompts)
118
+
119
+ text_inputs = self.tokenizer(text, return_tensors=return_tensors)
120
+ return BatchFeature(data={**text_inputs, **preprocessed.data})
121
+
122
+ @staticmethod
123
+ def _extract_medias_from_messages(messages: list[dict]) -> list[dict]:
124
+ """
125
+ Extract media items from messages in a single pass.
126
+
127
+ This is an optimized version that processes messages only once.
128
+ Kept as internal method since external callers should use __call__.
129
+ """
130
+ medias = []
131
+ for msg in messages:
132
+ if msg["role"] != "user" or not msg.get("content"):
133
+ continue
134
+
135
+ for content_part in msg["content"]:
136
+ if not isinstance(content_part, dict):
137
+ continue
138
+
139
+ content_type = content_part.get("type")
140
+ if content_type in ["video_url", "video"]:
141
+ medias.append(
142
+ {
143
+ "type": "video",
144
+ "video": content_part["video_url"]["url"],
145
+ "first_frame_timestamp": 0.0,
146
+ }
147
+ )
148
+ elif content_type in ["image_url", "image"]:
149
+ medias.append(
150
+ {
151
+ "type": "image",
152
+ "image": content_part["image_url"],
153
+ }
154
+ )
155
+ return medias
156
+
157
+ def apply_chat_template(self, messages, **kwargs):
158
+ return self.tokenizer.apply_chat_template(messages, **kwargs)
159
+
160
+ def batch_decode(self, *args, **kwargs):
161
+ return self.tokenizer.batch_decode(*args, **kwargs)
162
+
163
+ def decode(self, *args, **kwargs):
164
+ return self.tokenizer.decode(*args, **kwargs)
165
+
166
+ @property
167
+ def model_input_names(self):
168
+ return ["input_ids", "attention_mask", "pixel_values", "grid_thws"]
media_utils.py CHANGED
@@ -1,368 +1,362 @@
1
- import base64
2
- import io
3
- import math
4
- import os
5
- from datetime import datetime, timezone
6
- from typing import List, Literal, Optional, TypedDict
7
-
8
- import numpy as np
9
- from PIL import Image
10
- from pydantic import BaseModel, Field
11
-
12
- try:
13
- from mecord import VideoReader
14
- except ImportError:
15
- VideoReader = None
16
-
17
-
18
- class VideoSpec(BaseModel):
19
- media_type: str = Literal['video']
20
- height: int = Field(..., gt=0, description="video frame height")
21
- width: int = Field(..., gt=0, description="video frame width")
22
- num_frames: int = Field(..., gt=0, description="num frames")
23
- fps: float = Field(..., gt=0, description="average fps")
24
-
25
- # optional, help to accelerate video reading
26
- key_indices: list[int] = Field(None, description="key indices")
27
- frame_time_info: dict = Field(None, description="frame time info")
28
-
29
-
30
- class ImageInput(TypedDict):
31
- type: Literal['image']
32
- image: Image.Image
33
-
34
-
35
- class VideoChunkInput(TypedDict):
36
- type: Literal['video_chunk']
37
- video_chunk: List[Image.Image]
38
- prompt: Optional[str] = None
39
-
40
-
41
- MediaInput = ImageInput | VideoChunkInput
42
-
43
-
44
- def get_video_meta(video_src: bytes | str | os.PathLike,
45
- accurate: bool = True) -> dict:
46
- """Get the dimensions of a video."""
47
- if isinstance(video_src, os.PathLike):
48
- video_src = str(video_src)
49
- # if b64 string, decode to bytes
50
- if isinstance(video_src,
51
- str) and video_src.startswith('data:video/mp4;base64,'):
52
- video_src = base64.b64decode(video_src.split(',')[1])
53
- video = VideoReader(video_src, auto_init=accurate, num_threads=1)
54
- assert video.num_frames > 0, "Invalid video format."
55
- assert video.original_width > 0 and video.original_height > 0, (
56
- "Invalid video format.")
57
- assert video.avg_fps > 0, "Invalid video format."
58
- return VideoSpec(media_type='video',
59
- height=video.original_height,
60
- width=video.original_width,
61
- num_frames=video.num_frames,
62
- fps=video.avg_fps,
63
- key_indices=video.key_indices,
64
- frame_time_info=video.frame_time_info)
65
-
66
-
67
- def timestamp_as_str(timestamp: float,
68
- timestamp_mode: str = "hh:mm:ss.fff") -> str:
69
- """Convert a timestamp to a string in the format of HH:MM:SS.mmm."""
70
- if timestamp_mode == "hh:mm:ss.fff":
71
- return (datetime.fromtimestamp(timestamp,
72
- tz=timezone.utc).strftime("%H:%M:%S") +
73
- f".{int((timestamp % 1) * 1000):03d}")
74
- elif timestamp_mode == "mm:ss.fff":
75
- return (datetime.fromtimestamp(timestamp,
76
- tz=timezone.utc).strftime("%M:%S") +
77
- f".{int((timestamp % 1) * 1000):03d}")
78
- elif timestamp_mode == "mm:ss":
79
- return datetime.fromtimestamp(timestamp,
80
- tz=timezone.utc).strftime("%M:%S")
81
- else:
82
- raise ValueError(f"Invalid timestamp mode: {timestamp_mode}")
83
-
84
-
85
- def navit_resize_image(
86
- width: int,
87
- height: int,
88
- patch_size: int,
89
- merge_kernel_size: int,
90
- in_patch_limit: int,
91
- patch_limit_on_one_side: int,
92
- fixed_output_tokens: int | None,
93
- ):
94
- # Apply the patch limits.
95
- s1 = math.sqrt(
96
- in_patch_limit /
97
- (max(1.0, width // patch_size) * max(1.0, height // patch_size)))
98
- s2 = patch_limit_on_one_side * patch_size / width
99
- s3 = patch_limit_on_one_side * patch_size / height
100
- scale = min(1.0, s1, s2, s3)
101
- new_w, new_h = max(1, int(width * scale)), max(1, int(height * scale))
102
- new_w = min(new_w, patch_limit_on_one_side * patch_size)
103
- new_h = min(new_h, patch_limit_on_one_side * patch_size)
104
-
105
- # Calculate the padding to make the height and width divisible by the merge kernel size and patch size.
106
- factor = merge_kernel_size * patch_size
107
-
108
- pad_height = (factor - new_h % factor) % factor
109
- pad_width = (factor - new_w % factor) % factor
110
-
111
- if fixed_output_tokens is not None:
112
- num_tokens = fixed_output_tokens
113
- else:
114
- # Calculate new dimensions after padding and patching
115
- token_height = (new_h + pad_height) // factor
116
- token_width = (new_w + pad_width) // factor
117
-
118
- assert token_height * merge_kernel_size <= patch_limit_on_one_side, (
119
- f"token_height {token_height} * merge_kernel_size {merge_kernel_size} > patch_limit_on_one_side {patch_limit_on_one_side}"
120
- )
121
- assert token_width * merge_kernel_size <= patch_limit_on_one_side, (
122
- f"token_width {token_width} * merge_kernel_size {merge_kernel_size} > patch_limit_on_one_side {patch_limit_on_one_side}"
123
- )
124
-
125
- num_tokens = token_height * token_width
126
- return {
127
- "num_tokens": num_tokens,
128
- "new_width": new_w,
129
- "new_height": new_h,
130
- "pad_width": pad_width,
131
- "pad_height": pad_height,
132
- "sampled_nframes": 1,
133
- }
134
-
135
-
136
- def navit_resize_video(
137
- width: int,
138
- height: int,
139
- nframes: int,
140
- avg_fps: float,
141
- sample_fps: float,
142
- patch_size: int,
143
- merge_kernel_size: int,
144
- in_patch_limit_each_frame: int,
145
- patch_limit_on_one_side: int,
146
- in_patch_limit_total: int | None,
147
- max_num_frames_each_video: int | None,
148
- fixed_output_tokens_each_frame: int | None,
149
- ):
150
- sample_fps = min(sample_fps, avg_fps)
151
- # Calculate the number of frames to sample based on target FPS
152
- sampled_nframes = max(round(nframes * sample_fps / avg_fps), 1)
153
- if max_num_frames_each_video is not None:
154
- sampled_nframes = min(sampled_nframes, max_num_frames_each_video)
155
-
156
- if in_patch_limit_total is not None:
157
- in_patch_limit_each_frame = min(
158
- round(in_patch_limit_total / sampled_nframes),
159
- in_patch_limit_each_frame)
160
-
161
- ret = navit_resize_image(
162
- width,
163
- height,
164
- patch_size,
165
- merge_kernel_size,
166
- in_patch_limit_each_frame,
167
- patch_limit_on_one_side,
168
- fixed_output_tokens_each_frame,
169
- )
170
- ret["sampled_nframes"] = sampled_nframes
171
- return ret
172
-
173
-
174
- def real_sample_fps_and_max_num_frames(
175
- type_name: Literal["video", "video_chunk"],
176
- sample_fps: float,
177
- max_num_frames_each_video: int | None,
178
- ) -> tuple[int, int | None]:
179
- if type_name == "video":
180
- return sample_fps, max_num_frames_each_video
181
- elif type_name == "video_chunk":
182
- max_num_frames_each_video = None
183
- sample_fps = math.inf
184
- return sample_fps, max_num_frames_each_video
185
- else:
186
- return math.inf, None
187
-
188
-
189
- def _to_pil(data: str | bytes):
190
- if isinstance(data, Image.Image):
191
-
192
- return data.convert("RGB")
193
- elif isinstance(data, str):
194
- if data.startswith("data:"):
195
- raw_base64 = data.split(",")[1]
196
- return Image.open(io.BytesIO(
197
- base64.b64decode(raw_base64))).convert("RGB")
198
- else:
199
- return Image.open(data).convert("RGB")
200
- elif isinstance(data, bytes):
201
- return Image.open(io.BytesIO(data)).convert("RGB")
202
- else:
203
- raise ValueError(f"Unsupported data type: {type(data)}")
204
-
205
-
206
- def ensure_media_type(media: MediaInput) -> MediaInput:
207
- if media['type'] == 'image':
208
- media['image'] = _to_pil(media['image'])
209
- return media
210
- elif media['type'] == 'video_chunk':
211
- media['video_chunk'] = [
212
- _to_pil(frame) for frame in media['video_chunk']
213
- ]
214
- return media
215
- else:
216
- raise ValueError(f"Unsupported media type: {media['type']}")
217
-
218
-
219
- def image_to_np(
220
- image: Image.Image,
221
- resize_to: tuple[int, int] | None = None,
222
- mode: str = "resize",
223
- raise_error_for_ill_resize: bool = True,
224
- ) -> np.ndarray:
225
- """Convert an image to a numpy array.
226
-
227
- Args:
228
- content: The image to convert.
229
- resize_to: The size to resize the image to.
230
- mode: The mode to resize the image to.
231
- raise_error_for_ill_resize: Whether to raise an error for ill-sized resize.
232
-
233
- Returns:
234
- A numpy array.
235
- """
236
- assert isinstance(image, Image.Image), "image must be a PIL Image"
237
- if resize_to is not None:
238
- if mode == "resize":
239
- image = image.resize(resize_to, resample=Image.Resampling.BICUBIC)
240
-
241
- elif mode == "rescale_and_pad_to_center":
242
- scale = min(resize_to[0] / image.width,
243
- resize_to[1] / image.height, 1.0)
244
- new_width = round(image.width * scale)
245
- new_height = round(image.height * scale)
246
- if new_width == 0 or new_height == 0:
247
- if raise_error_for_ill_resize:
248
- raise ValueError(
249
- f"Invalid resize to: {resize_to}, from image size: {image.size}"
250
- )
251
- else:
252
- return np.zeros((resize_to[1], resize_to[0], 3),
253
- dtype=np.uint8)
254
-
255
- image = image.resize((new_width, new_height),
256
- resample=Image.Resampling.BICUBIC)
257
- padding_left = (resize_to[0] - new_width) // 2
258
- padding_right = resize_to[0] - new_width - padding_left
259
- padding_top = (resize_to[1] - new_height) // 2
260
- padding_bottom = resize_to[1] - new_height - padding_top
261
- image = np.asarray(image)
262
- image = np.pad(
263
- image,
264
- ((padding_top, padding_bottom), (padding_left, padding_right),
265
- (0, 0)),
266
- mode="constant",
267
- constant_values=0,
268
- )
269
- assert image.shape == (resize_to[1], resize_to[0], 3)
270
-
271
- elif mode == "rescale_and_pad_to_rightbottom":
272
- scale = min(resize_to[0] / image.width,
273
- resize_to[1] / image.height, 1.0)
274
- new_width = round(image.width * scale)
275
- new_height = round(image.height * scale)
276
- if new_width == 0 or new_height == 0:
277
- if raise_error_for_ill_resize:
278
- raise ValueError(
279
- f"Invalid resize to: {resize_to}, from image size: {image.size}"
280
- )
281
- else:
282
- return np.zeros((resize_to[1], resize_to[0], 3),
283
- dtype=np.uint8)
284
-
285
- image = image.resize((new_width, new_height),
286
- resample=Image.Resampling.BICUBIC)
287
- padding_right = resize_to[0] - new_width
288
- padding_bottom = resize_to[1] - new_height
289
- image = np.asarray(image)
290
- image = np.pad(
291
- image,
292
- ((0, padding_bottom), (0, padding_right), (0, 0)),
293
- mode="constant",
294
- constant_values=0,
295
- )
296
- assert image.shape == (resize_to[1], resize_to[0], 3)
297
-
298
- else:
299
- raise ValueError(f"Invalid mode: {mode}")
300
-
301
- if isinstance(image, Image.Image):
302
- return np.asarray(image)
303
- else:
304
- return image
305
-
306
-
307
- def navit_patchify(pixel_values: np.ndarray,
308
- patch_size: int) -> dict[str, np.ndarray]:
309
- """Reshape the pixel values to a navit shape.
310
-
311
- Args:
312
- pixel_values: np.ndarray, shape (t, h, w, c)
313
- patch_size: int
314
-
315
- Returns:
316
- dict[str, np.ndarray]
317
- - patches: np.ndarray, shape (t * h//patch_size * w//patch_size, c, patch_size, patch_size)
318
- - grid_thw: np.ndarray, (t, h//patch_size, w//patch_size)
319
- """
320
- T, H, W, C = pixel_values.shape
321
- assert C == 3, "pixel_values must have 3 channels"
322
-
323
- patches = pixel_values.reshape(T, H // patch_size, patch_size,
324
- W // patch_size, patch_size, C)
325
- # (T, H//patch_size, W//patch_size, C, patch_size, patch_size)
326
- patches = patches.transpose(0, 1, 3, 5, 2, 4)
327
- patches = patches.reshape(-1, C, patch_size, patch_size)
328
- grid_thw = np.array([T, H // patch_size, W // patch_size])
329
- return {"pixel_values": patches, "grid_thw": grid_thw}
330
-
331
-
332
- def normalize(x: np.ndarray,
333
- mean,
334
- std_inv,
335
- pixels_dtype: np.dtype = np.float32) -> np.ndarray:
336
- """Normalize the image.
337
-
338
- Args:
339
- x: The image to normalize. The shape is (..., 3). The dtype is uint8. The range is [0, 255].
340
- mean: The mean of the image.
341
- std_inv: The inverse of the std of the image.
342
- pixels_dtype: The dtype of the image.
343
- Returns:
344
- The normalized image. The shape is (..., 3). The dtype is determined by the pixels_dtype.
345
- """
346
- x = (x / 255.0).astype(pixels_dtype)
347
- x -= mean
348
- x *= std_inv
349
- return x
350
-
351
-
352
- def _to_tensor(data, **kwargs):
353
- import torch
354
-
355
- if isinstance(data, np.ndarray):
356
- return torch.from_numpy(data).to(**kwargs)
357
- elif isinstance(data, torch.Tensor):
358
- return data.to(**kwargs)
359
- elif isinstance(data, list):
360
- return [_to_tensor(item, **kwargs) for item in data]
361
- elif isinstance(data, tuple):
362
- return tuple(_to_tensor(item, **kwargs) for item in data)
363
- elif isinstance(data, dict):
364
- return {k: _to_tensor(v, **kwargs) for k, v in data.items()}
365
- elif data is None:
366
- return None
367
- else:
368
- raise ValueError(f"Unsupported data type: {type(data)}")
 
1
+ import base64
2
+ import io
3
+ import math
4
+ import os
5
+ from datetime import datetime, timezone
6
+ from typing import List, Literal, NotRequired, Optional, TypedDict
7
+
8
+ import numpy as np
9
+ from PIL import Image
10
+ from pydantic import BaseModel, Field
11
+
12
+ try:
13
+ from mecord import VideoReader
14
+ except ImportError:
15
+ VideoReader = None
16
+
17
+
18
class VideoSpec(BaseModel):
    """Static metadata describing a single video stream."""

    media_type: Literal["video"] = "video"
    height: int = Field(..., gt=0, description="video frame height")
    width: int = Field(..., gt=0, description="video frame width")
    num_frames: int = Field(..., gt=0, description="num frames")
    fps: float = Field(..., gt=0, description="average fps")

    # Optional decode hints that can accelerate video reading.
    # Annotated as Optional: the previous annotations (bare `list[int]` /
    # `dict`) did not admit the `None` default these fields actually carry.
    key_indices: Optional[list[int]] = Field(None, description="key indices")
    frame_time_info: Optional[dict] = Field(None, description="frame time info")
28
+
29
+
30
class ImageInput(TypedDict):
    """A single still-image media item; ``type`` discriminates the MediaInput union."""

    type: Literal["image"]
    # Expected to be a loaded PIL image; ensure_media_type coerces
    # str/bytes sources into PIL RGB beforehand.
    image: Image.Image
33
+
34
+
35
class VideoChunkInput(TypedDict):
    """A chunk of pre-decoded video frames with an optional per-chunk prompt."""

    type: Literal["video_chunk"]
    # Frames in display order; ensure_media_type coerces each entry to PIL RGB.
    video_chunk: List[Image.Image]
    prompt: NotRequired[str]


# Union of all media payloads, discriminated by the ``type`` key.
MediaInput = ImageInput | VideoChunkInput
42
+
43
+
44
def get_video_meta(
    video_src: bytes | str | os.PathLike, accurate: bool = True
) -> VideoSpec:
    """Probe a video source and return its metadata.

    Args:
        video_src: Filesystem path, raw encoded bytes, or a
            ``data:video/mp4;base64,`` URI.
        accurate: Fully initialize the reader (``auto_init``) so frame counts
            and timing info are exact.

    Returns:
        A VideoSpec with dimensions, frame count, average fps, and optional
        decode-acceleration hints. (The previous ``-> dict`` annotation was
        wrong: a VideoSpec model is returned, not a plain dict.)

    Raises:
        ImportError: if the optional ``mecord`` dependency is not installed.
        AssertionError: if the source does not decode to a valid video.
    """
    # Fail with a clear message instead of "NoneType is not callable" when
    # the optional dependency is missing.
    if VideoReader is None:
        raise ImportError("mecord is required for get_video_meta")
    if isinstance(video_src, os.PathLike):
        video_src = str(video_src)
    # A base64 data URI is decoded to raw bytes before handing to the reader.
    if isinstance(video_src, str) and video_src.startswith("data:video/mp4;base64,"):
        video_src = base64.b64decode(video_src.split(",")[1])
    video = VideoReader(video_src, auto_init=accurate, num_threads=1)
    assert video.num_frames > 0, "Invalid video format."
    assert video.original_width > 0 and video.original_height > 0, (
        "Invalid video format."
    )
    assert video.avg_fps > 0, "Invalid video format."
    return VideoSpec(
        media_type="video",
        height=video.original_height,
        width=video.original_width,
        num_frames=video.num_frames,
        fps=video.avg_fps,
        key_indices=video.key_indices,
        frame_time_info=video.frame_time_info,
    )
66
+
67
+
68
def timestamp_as_str(timestamp: float, timestamp_mode: str = "hh:mm:ss.fff") -> str:
    """Render a timestamp (seconds) as a wall-clock string.

    Supported modes: ``"hh:mm:ss.fff"``, ``"mm:ss.fff"``, ``"mm:ss"``.

    Raises:
        ValueError: for any other mode.
    """
    clock_formats = {
        "hh:mm:ss.fff": "%H:%M:%S",
        "mm:ss.fff": "%M:%S",
        "mm:ss": "%M:%S",
    }
    if timestamp_mode not in clock_formats:
        raise ValueError(f"Invalid timestamp mode: {timestamp_mode}")
    rendered = datetime.fromtimestamp(timestamp, tz=timezone.utc).strftime(
        clock_formats[timestamp_mode]
    )
    # strftime has no portable sub-second directive, so milliseconds are
    # appended by hand for the ".fff" modes.
    if timestamp_mode.endswith(".fff"):
        rendered += f".{int((timestamp % 1) * 1000):03d}"
    return rendered
84
+
85
+
86
+ def navit_resize_image(
87
+ width: int,
88
+ height: int,
89
+ patch_size: int,
90
+ merge_kernel_size: int,
91
+ in_patch_limit: int,
92
+ patch_limit_on_one_side: int,
93
+ fixed_output_tokens: int | None,
94
+ ):
95
+ # Apply the patch limits.
96
+ s1 = math.sqrt(
97
+ in_patch_limit
98
+ / (max(1.0, width // patch_size) * max(1.0, height // patch_size))
99
+ )
100
+ s2 = patch_limit_on_one_side * patch_size / width
101
+ s3 = patch_limit_on_one_side * patch_size / height
102
+ scale = min(1.0, s1, s2, s3)
103
+ new_w, new_h = max(1, int(width * scale)), max(1, int(height * scale))
104
+ new_w = min(new_w, patch_limit_on_one_side * patch_size)
105
+ new_h = min(new_h, patch_limit_on_one_side * patch_size)
106
+
107
+ # Calculate the padding to make the height and width divisible by the merge kernel size and patch size.
108
+ factor = merge_kernel_size * patch_size
109
+
110
+ pad_height = (factor - new_h % factor) % factor
111
+ pad_width = (factor - new_w % factor) % factor
112
+
113
+ if fixed_output_tokens is not None:
114
+ num_tokens = fixed_output_tokens
115
+ else:
116
+ # Calculate new dimensions after padding and patching
117
+ token_height = (new_h + pad_height) // factor
118
+ token_width = (new_w + pad_width) // factor
119
+
120
+ assert token_height * merge_kernel_size <= patch_limit_on_one_side, (
121
+ f"token_height {token_height} * merge_kernel_size {merge_kernel_size} > patch_limit_on_one_side {patch_limit_on_one_side}"
122
+ )
123
+ assert token_width * merge_kernel_size <= patch_limit_on_one_side, (
124
+ f"token_width {token_width} * merge_kernel_size {merge_kernel_size} > patch_limit_on_one_side {patch_limit_on_one_side}"
125
+ )
126
+
127
+ num_tokens = token_height * token_width
128
+ return {
129
+ "num_tokens": num_tokens,
130
+ "new_width": new_w,
131
+ "new_height": new_h,
132
+ "pad_width": pad_width,
133
+ "pad_height": pad_height,
134
+ "sampled_nframes": 1,
135
+ }
136
+
137
+
138
def navit_resize_video(
    width: int,
    height: int,
    nframes: int,
    avg_fps: float,
    sample_fps: float,
    patch_size: int,
    merge_kernel_size: int,
    in_patch_limit_each_frame: int,
    patch_limit_on_one_side: int,
    in_patch_limit_total: int | None,
    max_num_frames_each_video: int | None,
    fixed_output_tokens_each_frame: int | None,
):
    """Plan frame sampling and per-frame resizing for a video.

    Chooses how many frames to sample (bounded by the target fps and an
    optional hard frame cap), spreads any total patch budget evenly across
    those frames, and delegates the per-frame resize plan to
    ``navit_resize_image``. The returned dict carries the actual
    ``sampled_nframes``.
    """
    # Never sample faster than the source actually plays.
    effective_fps = min(sample_fps, avg_fps)
    sampled_nframes = max(round(nframes * effective_fps / avg_fps), 1)
    if max_num_frames_each_video is not None:
        sampled_nframes = min(sampled_nframes, max_num_frames_each_video)

    # A total budget, when given, tightens the per-frame patch limit.
    per_frame_limit = in_patch_limit_each_frame
    if in_patch_limit_total is not None:
        per_frame_limit = min(
            round(in_patch_limit_total / sampled_nframes), per_frame_limit
        )

    plan = navit_resize_image(
        width,
        height,
        patch_size,
        merge_kernel_size,
        per_frame_limit,
        patch_limit_on_one_side,
        fixed_output_tokens_each_frame,
    )
    plan["sampled_nframes"] = sampled_nframes
    return plan
174
+
175
+
176
+ def real_sample_fps_and_max_num_frames(
177
+ type_name: Literal["video", "video_chunk"],
178
+ sample_fps: float,
179
+ max_num_frames_each_video: int | None,
180
+ ) -> tuple[int, int | None]:
181
+ if type_name == "video":
182
+ return sample_fps, max_num_frames_each_video
183
+ elif type_name == "video_chunk":
184
+ max_num_frames_each_video = None
185
+ sample_fps = math.inf
186
+ return sample_fps, max_num_frames_each_video
187
+ else:
188
+ return math.inf, None
189
+
190
+
191
def _to_pil(data: Image.Image | str | bytes) -> Image.Image:
    """Coerce an image source to an RGB PIL image.

    Accepts an already-loaded PIL image, a filesystem path, a ``data:`` URI
    with a base64 payload, or raw encoded bytes. (The previous ``str | bytes``
    annotation omitted the ``Image.Image`` case the first branch handles.)

    Raises:
        ValueError: for any other input type.
    """
    if isinstance(data, Image.Image):
        return data.convert("RGB")
    elif isinstance(data, str):
        if data.startswith("data:"):
            # data URI: everything after the first comma is the base64 payload.
            raw_base64 = data.split(",")[1]
            return Image.open(io.BytesIO(base64.b64decode(raw_base64))).convert("RGB")
        else:
            # Plain string is treated as a filesystem path.
            return Image.open(data).convert("RGB")
    elif isinstance(data, bytes):
        return Image.open(io.BytesIO(data)).convert("RGB")
    else:
        raise ValueError(f"Unsupported data type: {type(data)}")
204
+
205
+
206
def ensure_media_type(media: MediaInput) -> MediaInput:
    """Coerce the payload of *media* to PIL RGB image(s), in place.

    Raises:
        ValueError: if ``media["type"]`` is not a supported media kind.
    """
    kind = media["type"]
    if kind == "image":
        media["image"] = _to_pil(media["image"])
    elif kind == "video_chunk":
        media["video_chunk"] = [_to_pil(frame) for frame in media["video_chunk"]]
    else:
        raise ValueError(f"Unsupported media type: {media['type']}")
    return media
215
+
216
+
217
def _rescale_and_pad(
    image: Image.Image,
    resize_to: tuple[int, int],
    anchor: str,
    raise_error_for_ill_resize: bool,
) -> np.ndarray:
    """Downscale to fit within resize_to (never upscale) and zero-pad.

    ``anchor == "center"`` pads symmetrically; anything else anchors the image
    at the top-left (padding only on the right/bottom).
    """
    scale = min(resize_to[0] / image.width, resize_to[1] / image.height, 1.0)
    new_width = round(image.width * scale)
    new_height = round(image.height * scale)
    if new_width == 0 or new_height == 0:
        if raise_error_for_ill_resize:
            raise ValueError(
                f"Invalid resize to: {resize_to}, from image size: {image.size}"
            )
        # Degenerate rescale: return an all-black canvas instead of failing.
        return np.zeros((resize_to[1], resize_to[0], 3), dtype=np.uint8)

    image = image.resize((new_width, new_height), resample=Image.Resampling.BICUBIC)
    if anchor == "center":
        padding_left = (resize_to[0] - new_width) // 2
        padding_top = (resize_to[1] - new_height) // 2
    else:
        padding_left = 0
        padding_top = 0
    padding_right = resize_to[0] - new_width - padding_left
    padding_bottom = resize_to[1] - new_height - padding_top
    padded = np.pad(
        np.asarray(image),
        ((padding_top, padding_bottom), (padding_left, padding_right), (0, 0)),
        mode="constant",
        constant_values=0,
    )
    assert padded.shape == (resize_to[1], resize_to[0], 3)
    return padded


def image_to_np(
    image: Image.Image,
    resize_to: tuple[int, int] | None = None,
    mode: str = "resize",
    raise_error_for_ill_resize: bool = True,
) -> np.ndarray:
    """Convert a PIL image to a numpy array, optionally resizing.

    Args:
        image: The image to convert.
        resize_to: Target (width, height); None keeps the original size.
        mode: "resize" (plain bicubic resize), "rescale_and_pad_to_center" or
            "rescale_and_pad_to_rightbottom" (aspect-preserving downscale plus
            zero padding).
        raise_error_for_ill_resize: Whether a degenerate rescale raises (True)
            or yields a black canvas (False).

    Returns:
        A numpy array — (height, width, 3) uint8 for RGB inputs.

    Raises:
        ValueError: for an unknown mode (when resize_to is given), or a
            degenerate rescale when raise_error_for_ill_resize is True.
    """
    assert isinstance(image, Image.Image), "image must be a PIL Image"
    if resize_to is None:
        # No resize requested: mode is ignored, matching the original behavior.
        return np.asarray(image)
    if mode == "resize":
        image = image.resize(resize_to, resample=Image.Resampling.BICUBIC)
        return np.asarray(image)
    # The two pad modes share all logic except the anchor position; the
    # duplicated branches were folded into _rescale_and_pad.
    if mode == "rescale_and_pad_to_center":
        return _rescale_and_pad(image, resize_to, "center", raise_error_for_ill_resize)
    if mode == "rescale_and_pad_to_rightbottom":
        return _rescale_and_pad(
            image, resize_to, "rightbottom", raise_error_for_ill_resize
        )
    raise ValueError(f"Invalid mode: {mode}")
300
+
301
+
302
def navit_patchify(pixel_values: np.ndarray, patch_size: int) -> dict[str, np.ndarray]:
    """Split frames into non-overlapping patches (navit layout).

    Args:
        pixel_values: Array of shape (t, h, w, c) with c == 3; h and w must be
            divisible by patch_size.
        patch_size: Side length of each square patch.

    Returns:
        Dict with:
        - "pixel_values": (t * h//patch_size * w//patch_size, c, patch_size, patch_size)
        - "grid_thw": array [t, h//patch_size, w//patch_size]
    """
    num_frames, height, width, channels = pixel_values.shape
    assert channels == 3, "pixel_values must have 3 channels"

    grid_h = height // patch_size
    grid_w = width // patch_size
    # (t, grid_h, p, grid_w, p, c) -> (t, grid_h, grid_w, c, p, p)
    tiled = pixel_values.reshape(
        num_frames, grid_h, patch_size, grid_w, patch_size, channels
    ).transpose(0, 1, 3, 5, 2, 4)
    return {
        "pixel_values": tiled.reshape(-1, channels, patch_size, patch_size),
        "grid_thw": np.array([num_frames, grid_h, grid_w]),
    }
325
+
326
+
327
def normalize(
    x: np.ndarray, mean, std_inv, pixels_dtype: np.dtype = np.float32
) -> np.ndarray:
    """Scale uint8 pixels to [0, 1], then apply mean/std normalization.

    Args:
        x: Image array of shape (..., 3), dtype uint8, values in [0, 255].
        mean: Per-channel (or scalar) mean to subtract.
        std_inv: Per-channel (or scalar) reciprocal std to multiply by.
        pixels_dtype: Output dtype.

    Returns:
        Normalized array of shape (..., 3) with dtype ``pixels_dtype``.
    """
    # Dividing by 255.0 produces a fresh array, so the in-place updates below
    # never touch the caller's buffer (and keep the requested dtype even when
    # mean/std_inv are float64).
    normalized = (x / 255.0).astype(pixels_dtype)
    normalized -= mean
    normalized *= std_inv
    return normalized
344
+
345
+
346
+ def _to_tensor(data, **kwargs):
347
+ import torch
348
+
349
+ if isinstance(data, np.ndarray):
350
+ return torch.from_numpy(data).to(**kwargs)
351
+ elif isinstance(data, torch.Tensor):
352
+ return data.to(**kwargs)
353
+ elif isinstance(data, list):
354
+ return [_to_tensor(item, **kwargs) for item in data]
355
+ elif isinstance(data, tuple):
356
+ return tuple(_to_tensor(item, **kwargs) for item in data)
357
+ elif isinstance(data, dict):
358
+ return {k: _to_tensor(v, **kwargs) for k, v in data.items()}
359
+ elif data is None:
360
+ return None
361
+ else:
362
+ raise ValueError(f"Unsupported data type: {type(data)}")