Cyril666 commited on
Commit
6691920
·
verified ·
1 Parent(s): 3ce90ef

Upload folder using huggingface_hub

Browse files
.gitattributes CHANGED
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ tokenizer.json filter=lfs diff=lfs merge=lfs -text
37
+ trainer_state.json filter=lfs diff=lfs merge=lfs -text
added_tokens.json ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "</think>": 151668,
3
+ "</tool_call>": 151658,
4
+ "</tool_response>": 151666,
5
+ "<image>": 151669,
6
+ "<think>": 151667,
7
+ "<tool_call>": 151657,
8
+ "<tool_response>": 151665,
9
+ "<|audio_end|>": 151674,
10
+ "<|audio_start|>": 151673,
11
+ "<|audio|>": 151672,
12
+ "<|box_end|>": 151649,
13
+ "<|box_start|>": 151648,
14
+ "<|endoftext|>": 151643,
15
+ "<|file_sep|>": 151664,
16
+ "<|fim_middle|>": 151660,
17
+ "<|fim_pad|>": 151662,
18
+ "<|fim_prefix|>": 151659,
19
+ "<|fim_suffix|>": 151661,
20
+ "<|im_end|>": 151645,
21
+ "<|im_start|>": 151644,
22
+ "<|image_pad|>": 151655,
23
+ "<|object_ref_end|>": 151647,
24
+ "<|object_ref_start|>": 151646,
25
+ "<|quad_end|>": 151651,
26
+ "<|quad_start|>": 151650,
27
+ "<|repo_name|>": 151663,
28
+ "<|stream_end|>": 151671,
29
+ "<|stream_start|>": 151670,
30
+ "<|video_pad|>": 151656,
31
+ "<|vision_end|>": 151653,
32
+ "<|vision_pad|>": 151654,
33
+ "<|vision_start|>": 151652
34
+ }
chat_template.json ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ {
2
+ "chat_template": "\n{%- set identifier = 'im' %}\n{% for message in messages %}\n {% if message['role'] == 'stream' %}\n {% set identifier = 'stream' %}\n {% else %}\n {% set identifier = 'im' %}\n {% endif %}\n {% if message['role'] is not none %}\n {{- '<|' + identifier + '_start|>' + message['role'] + '\n' -}}\n {% endif %}\n {% if message['content'] is string %}\n {{- message['content'] + '<|' + identifier + '_end|>\n' -}}\n {% else %}\n {% for content in message['content'] %}\n {% if content['type'] == 'image' or 'image' in content or 'image_url' in content %}\n {% if 'time' in content %}\n {{- 'Time ' + content['time'] | round(1) | string + 's: ' -}}\n {% endif %}\n {{- image_token + '\n' -}}\n {% elif content['type'] == 'video' or 'video' in content or 'video_url' in content %}\n {% for i in range(content['num_frames']) %}\n {% if 'timestamps' in content %}\n {{- 'Time ' + content['timestamps'][i] | round(1) | string + 's:' -}}\n {% endif %}\n {% if i < content['num_frames'] - 1 %}\n {{- image_token + ',' -}}\n {% if 'audio_split' in content and content['audio_split'][i] > 0 %}\n {{- '<|audio_start|>' + audio_token * content['audio_split'][i] + '<|audio_end|>,' -}}\n {% endif %}\n {% else %}\n {{- image_token -}}\n {% if 'audio_split' in content and content['audio_split'][i] > 0 %}\n {{- ',<|audio_start|>' + audio_token * content['audio_split'][i] + '<|audio_end|>\n' -}}\n {% else %}\n {{- '\n' -}}\n {% endif %}\n {% endif %}\n {% endfor %}\n {% elif content['type'] == 'audio' or 'audio' in content or 'audio_url' in content %}\n {% for i in range(content['num_frames']) %}\n {% if 'timestamps' in content %}\n {{- 'Time ' + content['timestamps'][i] | round(1) | string + 's:' -}}\n {% endif %}\n {% if i < content['num_frames'] - 1 %}\n {{- '<|audio_start|>' + audio_token + '<|audio_end|>,' -}}\n {% else %}\n {{- '<|audio_start|>' + audio_token + '<|audio_end|>\n' -}}\n {% endif %}\n {% endfor %}\n {% elif content['type'] == 'text' or 'text' in content %}\n 
{{- content['text'] -}}\n {% endif %}\n {% endfor %}\n {% if message['role'] is not none %}\n {{- '<|' + identifier + '_end|>\n' -}}\n {% endif %}\n {% endif %}\n{% endfor %}\n{% if add_generation_prompt %}\n {{- '<|im_start|>assistant\n' -}}\n {% if not add_think_prompt %}\n {{- '<think>\n\n</think>\n\n' -}}\n {% endif %}\n{% endif %}\n"
3
+ }
config.json ADDED
@@ -0,0 +1,87 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "architectures": [
3
+ "Videollama3Qwen3ForCausalLM"
4
+ ],
5
+ "auto_map": {
6
+ "AutoConfig": "configuration_videollama3.Videollama3Qwen3Config",
7
+ "AutoModelForCausalLM": "modeling_videollama3_qwen3.Videollama3Qwen3ForCausalLM"
8
+ },
9
+ "attention_bias": false,
10
+ "attention_dropout": 0.0,
11
+ "audio_encoder": "Cyril666/whisper-large-v3-encoder",
12
+ "audio_encoder_lr": 5e-06,
13
+ "audio_hidden_size": 1280,
14
+ "audio_projector_lr": 0.001,
15
+ "audio_projector_type": "dmlp2x_gelu",
16
+ "audio_token_index": 151672,
17
+ "bos_token_id": 151643,
18
+ "embedding_lr": null,
19
+ "eos_token_id": 151645,
20
+ "head_dim": 128,
21
+ "hidden_act": "silu",
22
+ "hidden_size": 2048,
23
+ "image_aspect_ratio": "square",
24
+ "image_size": -1,
25
+ "image_token_index": 151669,
26
+ "image_token_length": 1,
27
+ "initializer_range": 0.02,
28
+ "intermediate_size": 6144,
29
+ "is_alignment": false,
30
+ "llm_lr": null,
31
+ "loss_reduction_scope": "batch",
32
+ "max_frames": 180,
33
+ "max_position_embeddings": 40960,
34
+ "max_window_layers": 28,
35
+ "mm_hidden_size": 1024,
36
+ "mm_projector_lr": 1e-05,
37
+ "mm_projector_type": "mlp2x_gelu",
38
+ "mm_vision_select_feature": "patch",
39
+ "mm_vision_select_layer": -1,
40
+ "model_type": "videollama3_qwen3",
41
+ "num_attention_heads": 16,
42
+ "num_hidden_layers": 28,
43
+ "num_key_value_heads": 8,
44
+ "rms_norm_eps": 1e-06,
45
+ "rope_scaling": null,
46
+ "rope_theta": 1000000,
47
+ "sliding_window": null,
48
+ "tie_word_embeddings": true,
49
+ "tokenizer_model_max_length": 16384,
50
+ "tokenizer_padding_side": "right",
51
+ "torch_dtype": "bfloat16",
52
+ "transformers_version": "4.51.3",
53
+ "use_cache": true,
54
+ "use_mm_proj": true,
55
+ "use_reconstruct": false,
56
+ "use_sliding_window": false,
57
+ "use_token_compression": false,
58
+ "use_vision_teacher": false,
59
+ "use_visual_expert": false,
60
+ "vision_encoder": null,
61
+ "vision_encoder_lr": null,
62
+ "vision_encoder_teacher": null,
63
+ "vision_hidden_size": 1024,
64
+ "vision_projector_lr": null,
65
+ "vision_projector_type": "mlp2x_gelu",
66
+ "visual_expert_lr": null,
67
+ "vocab_size": 151936,
68
+ "vision_encoder_config": {
69
+ "head_dim": 128,
70
+ "hidden_act": "silu",
71
+ "hidden_size": 1024,
72
+ "initializer_range": 0.02,
73
+ "intermediate_size": 3072,
74
+ "layer_norm_eps": 1e-06,
75
+ "max_window_layers": 28,
76
+ "num_attention_heads": 16,
77
+ "num_channels": 3,
78
+ "num_hidden_layers": 28,
79
+ "num_key_value_heads": 8,
80
+ "patch_size": 14,
81
+ "rms_norm_eps": 1e-06,
82
+ "rope_scaling": null,
83
+ "rope_theta": 1000000,
84
+ "sliding_window": null,
85
+ "torch_dtype": "bfloat16"
86
+ }
87
+ }
configuration_sfl_encoder.py ADDED
@@ -0,0 +1,79 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Adopted from https://github.com/huggingface/transformers/blob/main/src/transformers/models/siglip/configuration_siglip.py.
2
+ # Below is the original copyright:
3
+ # coding=utf-8
4
+ # Copyright 2024 The HuggingFace Inc. team. All rights reserved.
5
+ #
6
+ # Licensed under the Apache License, Version 2.0 (the "License");
7
+ # you may not use this file except in compliance with the License.
8
+ # You may obtain a copy of the License at
9
+ #
10
+ # http://www.apache.org/licenses/LICENSE-2.0
11
+ #
12
+ # Unless required by applicable law or agreed to in writing, software
13
+ # distributed under the License is distributed on an "AS IS" BASIS,
14
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
+ # See the License for the specific language governing permissions and
16
+ # limitations under the License.
17
+ """VideoLLaMA3 vision encoder model configuration."""
18
+
19
+ from transformers import Qwen2Config, Qwen3Config
20
+
21
+
22
+ class SFLVisionEncoderConfigFromQwen2(Qwen2Config):
23
+
24
+ model_type = "sfl_vision_encoder_qwen2"
25
+
26
+ def __init__(
27
+ self,
28
+ hidden_size=1536,
29
+ intermediate_size=8960,
30
+ num_hidden_layers=12,
31
+ num_attention_heads=12,
32
+ num_channels=3,
33
+ patch_size=14,
34
+ layer_norm_eps=1e-6,
35
+ attention_dropout=0.0,
36
+ num_key_value_heads=2,
37
+ **kwargs,
38
+ ):
39
+ super().__init__(**kwargs)
40
+
41
+ self.hidden_size = hidden_size
42
+ self.intermediate_size = intermediate_size
43
+ self.num_hidden_layers = num_hidden_layers
44
+ self.num_attention_heads = num_attention_heads
45
+ self.num_channels = num_channels
46
+ self.patch_size = patch_size
47
+ self.attention_dropout = attention_dropout
48
+ self.num_key_value_heads = num_key_value_heads
49
+ self.layer_norm_eps = layer_norm_eps
50
+
51
+
52
+ class SFLVisionEncoderConfigFromQwen3(Qwen3Config):
53
+
54
+ model_type = "sfl_vision_encoder_qwen3"
55
+
56
+ def __init__(
57
+ self,
58
+ hidden_size=1536,
59
+ intermediate_size=8960,
60
+ num_hidden_layers=12,
61
+ num_attention_heads=12,
62
+ num_channels=3,
63
+ patch_size=14,
64
+ layer_norm_eps=1e-6,
65
+ attention_dropout=0.0,
66
+ num_key_value_heads=2,
67
+ **kwargs,
68
+ ):
69
+ super().__init__(**kwargs)
70
+
71
+ self.hidden_size = hidden_size
72
+ self.intermediate_size = intermediate_size
73
+ self.num_hidden_layers = num_hidden_layers
74
+ self.num_attention_heads = num_attention_heads
75
+ self.num_channels = num_channels
76
+ self.patch_size = patch_size
77
+ self.attention_dropout = attention_dropout
78
+ self.num_key_value_heads = num_key_value_heads
79
+ self.layer_norm_eps = layer_norm_eps
configuration_videollama3.py ADDED
@@ -0,0 +1,76 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """VideoLLaMA3 model configuration."""
2
+
3
+ import importlib.util
4
+ import os.path as osp
5
+ from typing import Optional, Dict, Any
6
+
7
+ from transformers import AutoConfig, AutoModel, PretrainedConfig, Qwen2Config, Qwen3Config, Qwen2AudioEncoderConfig
8
+
9
+ try:
10
+ from .configuration_sfl_encoder import SFLVisionEncoderConfigFromQwen3
11
+ except ModuleNotFoundError:
12
+ spec = importlib.util.spec_from_file_location(
13
+ "configuration_sfl_encoder",
14
+ osp.join(osp.dirname(__file__), "configuration_sfl_encoder.py"),
15
+ )
16
+ configuration_sfl_encoder = importlib.util.module_from_spec(spec)
17
+ spec.loader.exec_module(configuration_sfl_encoder)
18
+ SFLVisionEncoderConfigFromQwen3 = getattr(
19
+ configuration_sfl_encoder,
20
+ "SFLVisionEncoderConfigFromQwen3",
21
+ )
22
+
23
+ try:
24
+ from .modeling_sfl_encoder_qwen3 import SFLVisionEncoderModelFromQwen3
25
+ except ModuleNotFoundError:
26
+ spec = importlib.util.spec_from_file_location(
27
+ "modeling_sfl_encoder_qwen3",
28
+ osp.join(osp.dirname(__file__), "modeling_sfl_encoder_qwen3.py"),
29
+ )
30
+ modeling_sfl_encoder_qwen3 = importlib.util.module_from_spec(spec)
31
+ spec.loader.exec_module(modeling_sfl_encoder_qwen3)
32
+ SFLVisionEncoderModelFromQwen3 = getattr(
33
+ modeling_sfl_encoder_qwen3,
34
+ "SFLVisionEncoderModelFromQwen3",
35
+ )
36
+
37
+ AutoConfig.register("sfl_vision_encoder_qwen3", SFLVisionEncoderConfigFromQwen3)
38
+ AutoModel.register(SFLVisionEncoderConfigFromQwen3, SFLVisionEncoderModelFromQwen3)
39
+
40
+
41
+ class Videollama3Qwen3Config(Qwen3Config):
42
+
43
+ model_type = "videollama3_qwen3"
44
+ sub_configs = {"vision_encoder_config": SFLVisionEncoderConfigFromQwen3, "audio_encoder_config": Qwen2AudioEncoderConfig}
45
+
46
+ def __init__(
47
+ self,
48
+ vision_encoder: Optional[str] = None,
49
+ audio_encoder: Optional[str] = None,
50
+ vision_encoder_config: Dict[str, Any] = {},
51
+ audio_encoder_config: Dict[str, Any] = {},
52
+ vision_projector_type: str = "mlp2x_gelu",
53
+ audio_projector_type: str = "mlp2x_gelu",
54
+ use_token_compression: bool = True,
55
+ image_token_index: int = -1,
56
+ audio_token_index: int = -1,
57
+ **kwargs,
58
+ ):
59
+ super().__init__(**kwargs)
60
+ self.model_type = "videollama3_qwen3"
61
+
62
+ self.vision_encoder = vision_encoder
63
+ if vision_encoder_config is not None and not isinstance(vision_encoder_config, PretrainedConfig):
64
+ vision_encoder_config = SFLVisionEncoderConfigFromQwen3(**vision_encoder_config)
65
+ self.vision_encoder_config = vision_encoder_config
66
+
67
+ self.audio_encoder = audio_encoder
68
+ if audio_encoder_config is not None and not isinstance(audio_encoder_config, PretrainedConfig):
69
+ audio_encoder_config = Qwen2AudioEncoderConfig(**audio_encoder_config)
70
+ self.audio_encoder_config = audio_encoder_config
71
+
72
+ self.vision_projector_type = vision_projector_type
73
+ self.audio_projector_type = audio_projector_type
74
+ self.use_token_compression = use_token_compression
75
+ self.image_token_index = image_token_index
76
+ self.audio_token_index = audio_token_index
generation_config.json ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "bos_token_id": 151643,
3
+ "do_sample": true,
4
+ "eos_token_id": [
5
+ 151645,
6
+ 151643
7
+ ],
8
+ "pad_token_id": 151643,
9
+ "temperature": 0.6,
10
+ "top_k": 20,
11
+ "top_p": 0.95,
12
+ "transformers_version": "4.51.3"
13
+ }
image_processing_sfl.py ADDED
@@ -0,0 +1,476 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Adopted from https://github.com/huggingface/transformers/blob/main/src/transformers/models/qwen2_vl/image_processing_qwen2_vl.py.
2
+ # Below is the original copyright:
3
+ # Copyright 2024 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved.
4
+ #
5
+ # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
6
+ # and OPT implementations in this library. It has been modified from its
7
+ # original forms to accommodate minor architectural differences compared
8
+ # to GPT-NeoX and OPT used by the Meta AI team that trained the model.
9
+ #
10
+ # Licensed under the Apache License, Version 2.0 (the "License");
11
+ # you may not use this file except in compliance with the License.
12
+ # You may obtain a copy of the License at
13
+ #
14
+ # http://www.apache.org/licenses/LICENSE-2.0
15
+ #
16
+ # Unless required by applicable law or agreed to in writing, software
17
+ # distributed under the License is distributed on an "AS IS" BASIS,
18
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
19
+ # See the License for the specific language governing permissions and
20
+ # limitations under the License.
21
+ """Image processor class for VideoLLaMA3."""
22
+
23
+ import math
24
+ from typing import Dict, List, Optional, Union
25
+
26
+ import numpy as np
27
+
28
+ import torch
29
+ from transformers.image_processing_utils import BaseImageProcessor, BatchFeature
30
+ from transformers.image_utils import ImageInput
31
+ from transformers.image_transforms import (
32
+ convert_to_rgb,
33
+ resize,
34
+ to_channel_dimension_format,
35
+ )
36
+ from transformers.image_utils import (
37
+ OPENAI_CLIP_MEAN,
38
+ OPENAI_CLIP_STD,
39
+ ChannelDimension,
40
+ ImageInput,
41
+ PILImageResampling,
42
+ get_image_size,
43
+ infer_channel_dimension_format,
44
+ is_scaled_image,
45
+ is_valid_image,
46
+ make_list_of_images,
47
+ to_numpy_array,
48
+ )
49
+ try:
50
+ from transformers.image_utils import VideoInput
51
+ except:
52
+ from transformers.video_utils import VideoInput
53
+ from transformers.utils import TensorType, is_vision_available, logging
54
+
55
+
56
+ logger = logging.get_logger(__name__)
57
+
58
+
59
+ if is_vision_available():
60
+ from PIL import Image
61
+
62
+
63
+ def is_valid_video(video) -> bool:
64
+ if isinstance(video, (list, tuple)):
65
+ return all(is_valid_image(frame) for frame in video)
66
+ elif isinstance(video, np.ndarray):
67
+ return video.ndim == 4
68
+ elif isinstance(video, torch.Tensor):
69
+ return video.ndim == 4
70
+ return False
71
+
72
+
73
+ def make_batched_images(images) -> List[List[ImageInput]]:
74
+ """
75
+ Accepts images in list or nested list format, and makes a list of images for preprocessing.
76
+
77
+ Args:
78
+ images (`Union[List[List[ImageInput]], List[ImageInput], ImageInput]`):
79
+ The input image.
80
+
81
+ Returns:
82
+ list: A list of images.
83
+ """
84
+ if isinstance(images, (list, tuple)):
85
+ # list of images/videos
86
+ if not all(is_valid_video(image) or is_valid_image(image) for image in images):
87
+ raise ValueError(f"Could not make batched images from {images}")
88
+ return images
89
+ elif is_valid_video(images) or is_valid_image(images):
90
+ # single image/video
91
+ return [images]
92
+
93
+ raise ValueError(f"Could not make batched images from {images}")
94
+
95
+
96
+ def simple_batched_resize(
97
+ images, factor: int = 28, min_tokens: int = 4 * 4, max_tokens: int = 16384, input_data_format: str = None
98
+ ):
99
+ min_pixels = min_tokens * factor * factor
100
+ max_pixels = max_tokens * factor * factor
101
+
102
+ num_images = 0
103
+ for image in images:
104
+ if is_valid_video(image):
105
+ num_images += len(image)
106
+ else:
107
+ num_images += 1
108
+
109
+ image_sizes = []
110
+ for image in images:
111
+ if is_valid_video(image):
112
+ image = image[0]
113
+ if isinstance(image, Image.Image):
114
+ width, height = image.size
115
+ else:
116
+ height, width = get_image_size(image, channel_dim=input_data_format)
117
+ image_sizes.append([height, width])
118
+
119
+ tmp_image_sizes = []
120
+ for height, width in image_sizes:
121
+ h_bar = round(height / factor) * factor
122
+ w_bar = round(width / factor) * factor
123
+ if h_bar * w_bar > (max_pixels // num_images):
124
+ beta = math.sqrt((height * width) / (max_pixels // num_images))
125
+ h_bar = math.floor(height / beta / factor) * factor
126
+ w_bar = math.floor(width / beta / factor) * factor
127
+ # per image min_pixels
128
+ if h_bar * w_bar < min_pixels:
129
+ beta = math.sqrt(min_pixels / (height * width))
130
+ h_bar = math.ceil(height * beta / factor) * factor
131
+ w_bar = math.ceil(width * beta / factor) * factor
132
+ tmp_image_sizes.append((h_bar, w_bar))
133
+ image_sizes = tmp_image_sizes
134
+ return image_sizes
135
+
136
+
137
+ def batched_resize(
138
+ images, factors: List[int], min_tokens: int = 4 * 4, max_tokens: int = 16384, input_data_format: str = None
139
+ ):
140
+ image_sizes = []
141
+ for image in images:
142
+ if is_valid_video(image):
143
+ num_frame = len(image)
144
+ image = image[0]
145
+ else:
146
+ num_frame = 1
147
+ if isinstance(image, Image.Image):
148
+ width, height = image.size
149
+ else:
150
+ height, width = get_image_size(image, channel_dim=input_data_format)
151
+ image_sizes.append([num_frame, height, width])
152
+
153
+ # global max_pixels
154
+ smart_scale_factors = 1.0
155
+ total_tokens = 0
156
+ for (num_frame, height, width), factor in zip(image_sizes, factors):
157
+ total_tokens += num_frame * math.ceil(height / factor) * math.ceil(width / factor)
158
+
159
+ # TODO: add min_pixels
160
+ if total_tokens > max_tokens:
161
+ beta = math.sqrt(total_tokens / max_tokens)
162
+ tmp_image_sizes = []
163
+ for (_, height, width), factor in zip(image_sizes, factors):
164
+ h_bar = math.floor(height / beta / factor) * factor
165
+ w_bar = math.floor(width / beta / factor) * factor
166
+ tmp_image_sizes.append((h_bar, w_bar))
167
+ image_sizes = tmp_image_sizes
168
+ else:
169
+ tmp_image_sizes = []
170
+ for (_, height, width), factor in zip(image_sizes, factors):
171
+ height = round(height / factor) * factor
172
+ width = round(width / factor) * factor
173
+ tmp_image_sizes.append((height, width))
174
+ image_sizes = tmp_image_sizes
175
+
176
+ return image_sizes
177
+
178
+
179
+ class SFLImageProcessor(BaseImageProcessor):
180
+ r"""
181
+ Constructs a DAMOVL image processor that dynamically resizes images based on the original images.
182
+
183
+ Args:
184
+ do_resize (`bool`, *optional*, defaults to `True`):
185
+ Whether to resize the image's (height, width) dimensions.
186
+ resample (`PILImageResampling`, *optional*, defaults to `Resampling.BICUBIC`):
187
+ Resampling filter to use when resizing the image.
188
+ do_rescale (`bool`, *optional*, defaults to `True`):
189
+ Whether to rescale the image by the specified scale `rescale_factor`.
190
+ rescale_factor (`int` or `float`, *optional*, defaults to `1/255`):
191
+ Scale factor to use if rescaling the image.
192
+ do_normalize (`bool`, *optional*, defaults to `True`):
193
+ Whether to normalize the image.
194
+ image_mean (`float` or `List[float]`, *optional*, defaults to `[0.48145466, 0.4578275, 0.40821073]`):
195
+ Mean to use if normalizing the image. This is a float or list of floats for each channel in the image.
196
+ image_std (`float` or `List[float]`, *optional*, defaults to `[0.26862954, 0.26130258, 0.27577711]`):
197
+ Standard deviation to use if normalizing the image. This is a float or list of floats for each channel in the image.
198
+ do_convert_rgb (`bool`, *optional*, defaults to `True`):
199
+ Whether to convert the image to RGB.
200
+ min_pixels (`int`, *optional*, defaults to `56 * 56`):
201
+ The min pixels of the image to resize the image.
202
+ max_pixels (`int`, *optional*, defaults to `28 * 28 * 1280`):
203
+ The max pixels of the image to resize the image.
204
+ patch_size (`int`, *optional*, defaults to 14):
205
+ The spacial patch size of the vision encoder.
206
+ """
207
+
208
+ model_input_names = ["pixel_values", "grid_sizes", "merge_sizes"]
209
+
210
+ def __init__(
211
+ self,
212
+ do_resize: bool = True,
213
+ resample: PILImageResampling = PILImageResampling.BICUBIC,
214
+ do_rescale: bool = True,
215
+ rescale_factor: Union[int, float] = 1 / 255,
216
+ do_normalize: bool = True,
217
+ image_mean: Optional[Union[float, List[float]]] = None,
218
+ image_std: Optional[Union[float, List[float]]] = None,
219
+ do_convert_rgb: bool = True,
220
+ min_tokens: int = 4 * 4,
221
+ max_tokens: int = 16384,
222
+ patch_size: int = 14,
223
+ **kwargs,
224
+ ) -> None:
225
+ super().__init__(**kwargs)
226
+ self.do_resize = do_resize
227
+ self.resample = resample
228
+ self.do_rescale = do_rescale
229
+ self.rescale_factor = rescale_factor
230
+ self.do_normalize = do_normalize
231
+ self.image_mean = image_mean if image_mean is not None else OPENAI_CLIP_MEAN
232
+ self.image_std = image_std if image_std is not None else OPENAI_CLIP_STD
233
+ self.min_tokens = min_tokens
234
+ self.max_tokens = max_tokens
235
+ self.patch_size = patch_size
236
+ self.do_convert_rgb = do_convert_rgb
237
+
238
+ def _preprocess(
239
+ self,
240
+ images: Union[ImageInput, VideoInput],
241
+ target_size: List[int],
242
+ merge_size: int = 1,
243
+ do_resize: bool = None,
244
+ resample: PILImageResampling = None,
245
+ do_rescale: bool = None,
246
+ rescale_factor: float = None,
247
+ do_normalize: bool = None,
248
+ image_mean: Optional[Union[float, List[float]]] = None,
249
+ image_std: Optional[Union[float, List[float]]] = None,
250
+ do_convert_rgb: bool = None,
251
+ data_format: Optional[ChannelDimension] = ChannelDimension.FIRST,
252
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
253
+ ):
254
+ """
255
+ Preprocess an image or batch of images. Copy of the `preprocess` method from `CLIPImageProcessor`.
256
+
257
+ Args:
258
+ images (`ImageInput`):
259
+ Image or batch of images to preprocess. Expects pixel values ranging from 0 to 255. If pixel values range from 0 to 1, set `do_rescale=False`.
260
+ target_size (`List[int]`):
261
+ The target size to resize the image to. Should be a list of two integers: [target_height, target_width].
262
+ merge_size (`int`, *optional*, defaults to `1`):
263
+ The merge size after the vision encoder.
264
+ do_resize (`bool`, *optional*, defaults to `self.do_resize`):
265
+ Whether to resize the image.
266
+ resample (`PILImageResampling`, *optional*, defaults to `self.resample`):
267
+ Resampling filter to use if resizing the image. This can be one of the `PILImageResampling` enums.
268
+ do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
269
+ Whether to rescale the image.
270
+ rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
271
+ Scale factor to use if rescaling the image.
272
+ do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
273
+ Whether to normalize the image.
274
+ image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`):
275
+ Mean to use if normalizing the image. Can be a float or a list of floats corresponding to the number of channels in the image.
276
+ image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`):
277
+ Standard deviation to use if normalizing the image. Can be a float or a list of floats corresponding to the number of channels in the image.
278
+ do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`):
279
+ Whether to convert the image to RGB.
280
+ data_format (`ChannelDimension`, *optional*, defaults to `ChannelDimension.FIRST`):
281
+ The channel dimension format for the output image. Can be one of:
282
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
283
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
284
+ - Unset: Use the channel dimension format of the input image.
285
+ input_data_format (`ChannelDimension` or `str`, *optional*):
286
+ The channel dimension format for the input image. Can be one of:
287
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
288
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
289
+ - `"none"` or `ChannelDimension.NONE`: image in (height, width) format. - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
290
+ """
291
+ images = make_list_of_images(images)
292
+
293
+ if do_convert_rgb:
294
+ images = [convert_to_rgb(image) for image in images]
295
+
296
+ # All transformations expect numpy arrays.
297
+ images = [to_numpy_array(image) for image in images]
298
+
299
+ if is_scaled_image(images[0]) and do_rescale:
300
+ logger.warning_once(
301
+ "It looks like you are trying to rescale already rescaled images. If the input"
302
+ " images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again."
303
+ )
304
+ if input_data_format is None:
305
+ # We assume that all images have the same channel dimension format.
306
+ input_data_format = infer_channel_dimension_format(images[0])
307
+
308
+ height, width = get_image_size(images[0], channel_dim=input_data_format)
309
+ resized_height, resized_width = height, width
310
+ processed_images = []
311
+ for image in images:
312
+ if do_resize:
313
+ resized_height, resized_width = target_size
314
+ image = resize(
315
+ image, size=(resized_height, resized_width), resample=resample, input_data_format=input_data_format
316
+ )
317
+
318
+ if do_rescale:
319
+ image = self.rescale(image, scale=rescale_factor, input_data_format=input_data_format)
320
+
321
+ if do_normalize:
322
+ image = self.normalize(
323
+ image=image, mean=image_mean, std=image_std, input_data_format=input_data_format
324
+ )
325
+
326
+ image = to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format)
327
+ processed_images.append(image)
328
+
329
+ patches = np.array(processed_images)
330
+ if data_format == ChannelDimension.LAST:
331
+ patches = patches.transpose(0, 3, 1, 2)
332
+ t = patches.shape[0]
333
+ channel = patches.shape[1]
334
+ grid_h, grid_w = resized_height // self.patch_size, resized_width // self.patch_size
335
+ patches = patches.reshape(
336
+ t,
337
+ channel,
338
+ grid_h // merge_size,
339
+ merge_size,
340
+ self.patch_size,
341
+ grid_w // merge_size,
342
+ merge_size,
343
+ self.patch_size,
344
+ )
345
+ patches = patches.transpose(0, 2, 5, 3, 6, 1, 4, 7)
346
+ flatten_patches = patches.reshape(
347
+ t * grid_h * grid_w, channel * self.patch_size * self.patch_size
348
+ )
349
+
350
+ return flatten_patches, (t, grid_h, grid_w)
351
+
352
+ def preprocess(
353
+ self,
354
+ images: ImageInput,
355
+ do_resize: bool = None,
356
+ resample: PILImageResampling = None,
357
+ do_rescale: bool = None,
358
+ rescale_factor: float = None,
359
+ do_normalize: bool = None,
360
+ image_mean: Optional[Union[float, List[float]]] = None,
361
+ image_std: Optional[Union[float, List[float]]] = None,
362
+ do_convert_rgb: bool = None,
363
+ merge_size: Optional[Union[int, List[int]]] = None,
364
+ return_tensors: Optional[Union[str, TensorType]] = None,
365
+ data_format: Optional[ChannelDimension] = ChannelDimension.FIRST,
366
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
367
+ ):
368
+ """
369
+ Args:
370
+ images (`ImageInput`):
371
+ Image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If
372
+ passing in images with pixel values between 0 and 1, set `do_rescale=False`.
373
+ do_resize (`bool`, *optional*, defaults to `self.do_resize`):
374
+ Whether to resize the image.
375
+ resample (`int`, *optional*, defaults to `self.resample`):
376
+ Resampling filter to use if resizing the image. This can be one of the enum `PILImageResampling`. Only
377
+ has an effect if `do_resize` is set to `True`.
378
+ do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
379
+ Whether to rescale the image.
380
+ rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
381
+ Rescale factor to rescale the image by if `do_rescale` is set to `True`.
382
+ do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
383
+ Whether to normalize the image.
384
+ image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`):
385
+ Image mean to use for normalization. Only has an effect if `do_normalize` is set to `True`.
386
+ image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`):
387
+ Image standard deviation to use for normalization. Only has an effect if `do_normalize` is set to
388
+ `True`.
389
+ do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`):
390
+ Whether to convert the image to RGB.
391
+ return_tensors (`str` or `TensorType`, *optional*):
392
+ The type of tensors to return. Can be one of:
393
+ - Unset: Return a list of `np.ndarray`.
394
+ - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
395
+ - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
396
+ - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
397
+ - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
398
+ data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
399
+ The channel dimension format for the output image. Can be one of:
400
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
401
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
402
+ - Unset: Use the channel dimension format of the input image.
403
+ input_data_format (`ChannelDimension` or `str`, *optional*):
404
+ The channel dimension format for the input image. If unset, the channel dimension format is inferred
405
+ from the input image. Can be one of:
406
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
407
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
408
+ - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
409
+
410
+ """
411
+ do_resize = do_resize if do_resize is not None else self.do_resize
412
+ resample = resample if resample is not None else self.resample
413
+ do_rescale = do_rescale if do_rescale is not None else self.do_rescale
414
+ rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor
415
+ do_normalize = do_normalize if do_normalize is not None else self.do_normalize
416
+ image_mean = image_mean if image_mean is not None else self.image_mean
417
+ image_std = image_std if image_std is not None else self.image_std
418
+ merge_size = merge_size if merge_size is not None else self.merge_size
419
+ do_convert_rgb = do_convert_rgb if do_convert_rgb is not None else self.do_convert_rgb
420
+
421
+ images = make_batched_images(images)
422
+
423
+ if isinstance(merge_size, (list, tuple)):
424
+ assert len(merge_size) == len(images), "Merge size must be the same length as images."
425
+ merge_sizes = merge_size
426
+ else:
427
+ merge_sizes = [merge_size for _ in images]
428
+
429
+ if all(merge_size == merge_sizes[0] for merge_size in merge_sizes):
430
+ target_sizes = simple_batched_resize(
431
+ images,
432
+ factor=self.patch_size * merge_sizes[0],
433
+ min_tokens=self.min_tokens,
434
+ max_tokens=self.max_tokens,
435
+ input_data_format=input_data_format,
436
+ )
437
+ else:
438
+ target_sizes = batched_resize(
439
+ images,
440
+ factors=[self.patch_size * merge_size for merge_size in merge_sizes],
441
+ min_tokens=self.min_tokens,
442
+ max_tokens=self.max_tokens,
443
+ input_data_format=input_data_format,
444
+ )
445
+
446
+ pixel_values, grid_sizes = [], []
447
+ for image, merge_size, target_size in zip(images, merge_sizes, target_sizes):
448
+ patches, grid_size = self._preprocess(
449
+ image,
450
+ target_size=target_size,
451
+ merge_size=merge_size,
452
+ do_resize=do_resize,
453
+ resample=resample,
454
+ do_rescale=do_rescale,
455
+ rescale_factor=rescale_factor,
456
+ do_normalize=do_normalize,
457
+ image_mean=image_mean,
458
+ image_std=image_std,
459
+ data_format=data_format,
460
+ do_convert_rgb=do_convert_rgb,
461
+ input_data_format=input_data_format,
462
+ )
463
+ pixel_values.append(patches)
464
+ grid_sizes.append(grid_size)
465
+
466
+ pixel_values = np.concatenate(pixel_values, axis=0)
467
+ grid_sizes = np.array(grid_sizes)
468
+ merge_sizes = np.array(merge_sizes)
469
+
470
+ data = {
471
+ "pixel_values": pixel_values,
472
+ "grid_sizes": grid_sizes,
473
+ "merge_sizes": merge_sizes,
474
+ }
475
+
476
+ return BatchFeature(data=data, tensor_type=return_tensors)
model-00001-of-00002.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:cb961d8369bc4582a7e4ee1ed830e157ade0d181696a7343208d82dd3e95c832
3
+ size 4993416376
model-00002-of-00002.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:6b665473c46a773ede1c33e86fbe0146dbb87a7034ac422a3420a8ee030e08a5
3
+ size 1278741680
model.safetensors.index.json ADDED
The diff for this file is too large to render. See raw diff
 
modeling_qwen2_audio_encoder.py ADDED
@@ -0,0 +1,106 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import Qwen2AudioEncoder
2
+ import torch
3
+ from torch import nn
4
+ from transformers.modeling_outputs import BaseModelOutput
5
+ import torch.nn.functional as F
6
+
7
class Qwen2AudioEncoderModel(Qwen2AudioEncoder):
    """Qwen2 audio encoder with the final average-pooling stage disabled.

    Identical to the upstream `Qwen2AudioEncoder.forward` except that the
    `avg_pooler` / max-pool step is commented out, so the output keeps the
    full post-conv sequence length.
    """

    def forward(
        self,
        input_features,
        attention_mask=None,
        head_mask=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        r"""
        Args:
            attention_mask (`torch.Tensor`)`, *optional*):
                Qwen2Audio does not support masking of the `input_features`, this argument is preserved for compatibility,
                but it is not used. By default the silence in the input log mel spectrogram are ignored.
            head_mask (`torch.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*):
                Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`:

                - 1 indicates the head is **not masked**,
                - 0 indicates the head is **masked**.
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
            output_hidden_states (`bool`, *optional*):
                Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
                for more detail.
            return_dict (`bool`, *optional*):
                Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
        """

        # NOTE(review): computed but never used below — the upstream input-length
        # validation that consumed this value appears to have been removed.
        # Confirm that skipping the check is intentional.
        expected_seq_length = self.config.max_source_positions * self.conv1.stride[0] * self.conv2.stride[0]

        # Resolve output flags from the config when not passed explicitly.
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # Ignore copy
        # Move/cast the mel features to match the conv weights.
        input_features = input_features.to(dtype=self.conv1.weight.dtype, device=self.conv1.weight.device)

        # Two strided 1D convolutions with GELU downsample the mel spectrogram.
        inputs_embeds = nn.functional.gelu(self.conv1(input_features))
        inputs_embeds = nn.functional.gelu(self.conv2(inputs_embeds))

        # (batch, channels, time) -> (batch, time, channels) for the transformer.
        inputs_embeds = inputs_embeds.permute(0, 2, 1)
        embed_pos = self.embed_positions.weight

        # Add learned absolute positions, truncated to the actual sequence length.
        hidden_states = inputs_embeds + embed_pos[: inputs_embeds.shape[1], :]
        hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)

        encoder_states = () if output_hidden_states else None
        all_attentions = () if output_attentions else None

        # check if head_mask has a correct number of layers specified if desired
        if head_mask is not None:
            assert head_mask.size()[0] == (len(self.layers)), (
                f"The head_mask should be specified for {len(self.layers)} layers, but it is for {head_mask.size()[0]}."
            )

        for idx, encoder_layer in enumerate(self.layers):
            if output_hidden_states:
                encoder_states = encoder_states + (hidden_states,)
            # add LayerDrop (see https://huggingface.co/papers/1909.11556 for description)
            to_drop = False
            if self.training:
                dropout_probability = torch.rand([])
                if dropout_probability < self.layerdrop:  # skip the layer
                    to_drop = True

            # Ignore copy
            if to_drop:
                # Dropped layer: hidden_states passes through unchanged.
                layer_outputs = (None, None)
            else:
                layer_outputs = encoder_layer(
                    hidden_states,
                    attention_mask,
                    layer_head_mask=(head_mask[idx] if head_mask is not None else None),
                    output_attentions=output_attentions,
                )

                hidden_states = layer_outputs[0]

            if output_attentions:
                all_attentions = all_attentions + (layer_outputs[1],)

        # Ignore copy
        # NOTE(review): the upstream 2x temporal pooling is intentionally disabled
        # here, so the output sequence length is NOT halved as in stock Qwen2Audio.
        # hidden_states = hidden_states.permute(0, 2, 1)
        # hidden_states = self.avg_pooler(hidden_states)
        # hidden_states = F.max_pool1d(hidden_states, kernel_size=2)
        # hidden_states = hidden_states.permute(0, 2, 1)

        hidden_states = self.layer_norm(hidden_states)
        if output_hidden_states:
            encoder_states = encoder_states + (hidden_states,)

        if not return_dict:
            return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
        return BaseModelOutput(
            last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions
        )
modeling_sfl_encoder_qwen3.py ADDED
@@ -0,0 +1,720 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from torch import nn
2
+ import torch
3
+ import math
4
+ import warnings
5
+ from functools import partial
6
+ from .configuration_sfl_encoder import SFLVisionEncoderConfigFromQwen3
7
+ from transformers.modeling_utils import PreTrainedModel
8
+ from transformers.models.qwen3.modeling_qwen3 import Qwen3Model, Qwen3Attention, rotate_half, Qwen3DecoderLayer
9
+ from typing import List, Optional, Tuple, Union
10
+ from transformers.modeling_outputs import BaseModelOutputWithPast
11
+ from transformers.processing_utils import Unpack
12
+ from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
13
+ from transformers.cache_utils import Cache, DynamicCache
14
+ from transformers.utils import logging, is_flash_attn_greater_or_equal_2_10, is_flash_attn_2_available
15
+ from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS
16
+ from torch.nn.init import _calculate_fan_in_and_fan_out
17
+ import torch.nn.functional as F
18
+ if is_flash_attn_2_available():
19
+ from transformers.modeling_flash_attention_utils import _flash_attention_forward
20
+ from flash_attn import flash_attn_varlen_func
21
+
22
+ logger = logging.get_logger(__name__)
23
+
24
+ def _trunc_normal_(tensor, mean, std, a, b):
25
+ # Cut & paste from PyTorch official master until it's in a few official releases - RW
26
+ # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
27
+ def norm_cdf(x):
28
+ # Computes standard normal cumulative distribution function
29
+ return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0
30
+
31
+ if (mean < a - 2 * std) or (mean > b + 2 * std):
32
+ warnings.warn(
33
+ "mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
34
+ "The distribution of values may be incorrect.",
35
+ stacklevel=2,
36
+ )
37
+
38
+ # Values are generated by using a truncated uniform distribution and
39
+ # then using the inverse CDF for the normal distribution.
40
+ # Get upper and lower cdf values
41
+ l = norm_cdf((a - mean) / std)
42
+ u = norm_cdf((b - mean) / std)
43
+
44
+ # Uniformly fill tensor with values from [l, u], then translate to
45
+ # [2l-1, 2u-1].
46
+ tensor.uniform_(2 * l - 1, 2 * u - 1)
47
+
48
+ # Use inverse cdf transform for normal distribution to get truncated
49
+ # standard normal
50
+ tensor.erfinv_()
51
+
52
+ # Transform to proper mean, std
53
+ tensor.mul_(std * math.sqrt(2.0))
54
+ tensor.add_(mean)
55
+
56
+ # Clamp to ensure it's in the proper range
57
+ tensor.clamp_(min=a, max=b)
58
+
59
+
60
def trunc_normal_tf_(
    tensor: torch.Tensor, mean: float = 0.0, std: float = 1.0, a: float = -2.0, b: float = 2.0
) -> torch.Tensor:
    r"""Fills the input Tensor with values drawn from a truncated
    normal distribution. The values are effectively drawn from the
    normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`
    with values outside :math:`[a, b]` redrawn until they are within
    the bounds. The method used for generating the random values works
    best when :math:`a \leq \text{mean} \leq b`.

    NOTE: this 'tf' variant behaves closer to Tensorflow / JAX impl where the
    bounds [a, b] are applied when sampling the normal distribution with mean=0, std=1.0
    and the result is subsequently scaled and shifted by the mean and std args.
    (The raw-string docstring avoids the original's unescaped ``\t`` sequences,
    which were silently rendered as tab characters.)

    Args:
        tensor: an n-dimensional `torch.Tensor`
        mean: the mean of the normal distribution
        std: the standard deviation of the normal distribution
        a: the minimum cutoff value
        b: the maximum cutoff value
    """
    with torch.no_grad():
        # Sample a standard truncated normal on [a, b], then rescale.
        _trunc_normal_(tensor, 0, 1.0, a, b)
        tensor.mul_(std).add_(mean)
84
+
85
+
86
def variance_scaling_(tensor, scale=1.0, mode="fan_in", distribution="normal"):
    """Initialize `tensor` in place with variance scaled by its fan count.

    Args:
        tensor: n-dimensional tensor to fill (modified in place).
        scale: multiplicative factor applied to the variance.
        mode: which fan the variance is normalized by — "fan_in",
            "fan_out", or "fan_avg".
        distribution: sampling distribution — "truncated_normal",
            "normal", or "uniform".

    Raises:
        ValueError: if `mode` or `distribution` is not one of the values above.
    """
    fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor)
    if mode == "fan_in":
        denom = fan_in
    elif mode == "fan_out":
        denom = fan_out
    elif mode == "fan_avg":
        denom = (fan_in + fan_out) / 2
    else:
        # FIX: an unknown mode previously fell through and crashed later with
        # an unrelated NameError on `denom`; fail fast with a clear message.
        raise ValueError(f"invalid mode {mode}")

    variance = scale / denom

    if distribution == "truncated_normal":
        # constant is stddev of standard normal truncated to (-2, 2)
        trunc_normal_tf_(tensor, std=math.sqrt(variance) / 0.87962566103423978)
    elif distribution == "normal":
        with torch.no_grad():
            tensor.normal_(std=math.sqrt(variance))
    elif distribution == "uniform":
        # Uniform on [-bound, bound] has variance bound^2 / 3.
        bound = math.sqrt(3 * variance)
        with torch.no_grad():
            tensor.uniform_(-bound, bound)
    else:
        raise ValueError(f"invalid distribution {distribution}")
109
+
110
+
111
def lecun_normal_(tensor):
    # LeCun-normal initialization: truncated normal with variance 1/fan_in,
    # filled in place.
    variance_scaling_(tensor, mode="fan_in", distribution="truncated_normal")
113
+
114
+
115
class SFLVisionEncoderEmbeddings(nn.Module):
    """Patch-embedding front end: projects flattened image patches into the
    encoder hidden size with a non-overlapping convolution."""

    def __init__(self, config: SFLVisionEncoderConfigFromQwen3):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.patch_size = config.patch_size

        # kernel_size == stride == patch_size, so each patch maps to exactly
        # one output position.
        self.patch_embedding = nn.Conv2d(
            in_channels=config.num_channels,
            out_channels=self.embed_dim,
            kernel_size=self.patch_size,
            stride=self.patch_size,
            padding="valid",
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # Input arrives as flattened patches; restore (N, C, P, P) for the conv.
        patches = hidden_states.view(
            -1, self.config.num_channels, self.patch_size, self.patch_size
        )
        # Each patch collapses to a single spatial position, so the conv output
        # flattens straight to (N, embed_dim).
        projected = self.patch_embedding(patches)
        return projected.view(-1, self.embed_dim)
140
+
141
+
142
class VisualRotaryEmbedding(nn.Module):
    """Rotary position embedding producing separate angle tables for the two
    spatial axes (leading dim 2 of `position_ids`), adapted from the
    Qwen2-VL rotary embedding implementation in transformers."""

    def __init__(
        self,
        dim=None,
        max_position_embeddings=2048,
        base=10000,
        device=None,
        scaling_factor=1.0,
        rope_type="default",
        config = None,
    ):
        super().__init__()
        # TODO (joao): remove the `if` below, only used for BC
        self.rope_kwargs = {}
        if config is None:
            # Legacy path: parameters passed individually instead of via config.
            logger.warning_once(
                "`Qwen2VLRotaryEmbedding` can now be fully parameterized by passing the model config through the "
                "`config` argument. All other arguments will be removed in v4.46"
            )
            self.rope_kwargs = {
                "rope_type": rope_type,
                "factor": scaling_factor,
                "dim": dim,
                "base": base,
                "max_position_embeddings": max_position_embeddings,
            }
            self.rope_type = rope_type
            self.max_seq_len_cached = max_position_embeddings
            self.original_max_seq_len = max_position_embeddings
        else:
            # BC: "rope_type" was originally "type"
            if config.rope_scaling is not None:
                self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
            else:
                self.rope_type = "default"
            self.max_seq_len_cached = config.max_position_embeddings
            self.original_max_seq_len = config.max_position_embeddings

        self.config = config
        # Look up the frequency-initialization function for this RoPE variant.
        self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]

        inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs)
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        # Kept so dynamic variants can restore the original table after growth.
        self.original_inv_freq = self.inv_freq

    def _dynamic_frequency_update(self, position_ids, device):
        """
        dynamic RoPE layers should recompute `inv_freq` in the following situations:
        1 - growing beyond the cached sequence length (allow scaling)
        2 - the current sequence length is in the original scale (avoid losing precision with small sequences)
        """
        seq_len = torch.max(position_ids) + 1
        if seq_len > self.max_seq_len_cached:  # growth
            inv_freq, self.attention_scaling = self.rope_init_fn(
                self.config, device, seq_len=seq_len, **self.rope_kwargs
            )
            self.register_buffer("inv_freq", inv_freq, persistent=False)  # TODO joao: may break with compilation
            self.max_seq_len_cached = seq_len

        if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len:  # reset
            self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
            self.max_seq_len_cached = self.original_max_seq_len

    @torch.no_grad()
    def forward(self, x, position_ids):
        # position_ids is expected to carry a leading axis of size 2 (the two
        # spatial coordinates) — see the `expand(2, ...)` below; confirm at callers.
        if "dynamic" in self.rope_type:
            self._dynamic_frequency_update(position_ids, device=x.device)

        # (2, bs, dim/2, 1) x (2, bs, 1, positions) -> per-axis angle tables.
        inv_freq_expanded = self.inv_freq[None, None, :, None].float().expand(2, position_ids.shape[1], -1, 1)
        position_ids_expanded = position_ids[:, :, None, :].float()  # shape (2, bs, 1, positions)
        # Force float32 (see https://github.com/huggingface/transformers/pull/29285)
        device_type = x.device.type
        device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
        with torch.autocast(device_type=device_type, enabled=False):
            freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(2, 3)
            emb = torch.cat((freqs, freqs), dim=-1)
            cos = emb.cos()
            sin = emb.sin()

        # Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention
        cos = cos * self.attention_scaling
        sin = sin * self.attention_scaling

        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
226
+
227
+
228
def apply_multimodal_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1):
    # `cos`/`sin` carry a leading axis of size 2 (one angle table per spatial
    # coordinate). The feature dimension is split into two halves and each
    # half draws its angles from the matching coordinate axis (h then w).
    half = cos.shape[-1] // 2
    sections = [half, half]

    def interleave_axes(angles):
        parts = [chunk[i % 2] for i, chunk in enumerate(angles.split(sections, dim=-1))]
        return torch.cat(parts, dim=-1).unsqueeze(unsqueeze_dim)

    cos = interleave_axes(cos)
    sin = interleave_axes(sin)

    # Standard rotary embedding application to queries and keys.
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed
236
+
237
+
238
class SFLQwen3Attention(Qwen3Attention):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

    # Non-causal Qwen3 attention running flash-attn's varlen kernel over
    # packed (batch-size-1) sequences delimited by `cu_seqlens`.

    # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Vision patches attend bidirectionally.
        self.is_causal = False

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: Tuple[torch.Tensor, torch.Tensor],
        attention_mask: Optional[torch.Tensor],
        past_key_value: Optional[Cache] = None,
        cache_position: Optional[torch.LongTensor] = None,
        cu_seqlens: Optional[torch.Tensor] = None,
        **kwargs: Unpack[FlashAttentionKwargs],
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        input_shape = hidden_states.shape[:-1]
        hidden_shape = (*input_shape, -1, self.head_dim)

        # Project and per-head-normalize Q/K (Qwen3 applies RMSNorm per head).
        query_states = self.q_norm(self.q_proj(hidden_states).view(hidden_shape)).transpose(1, 2)
        key_states = self.k_norm(self.k_proj(hidden_states).view(hidden_shape)).transpose(1, 2)
        value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)

        # Apply the 2-axis rotary embedding to Q and K.
        cos, sin = position_embeddings
        query_states, key_states = apply_multimodal_rotary_pos_emb(query_states, key_states, cos, sin)

        if past_key_value is not None:
            # sin and cos are specific to RoPE models; cache_position needed for the static cache
            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

        # This is before the transpose
        # NOTE(review): `seq_len` is never used below — leftover from the copied
        # FlashAttention2 template; confirm it can be removed.
        seq_len = query_states.shape[2]

        # In PEFT, usually we cast the layer norms in float32 for training stability reasons
        # therefore the input hidden states gets silently casted in float32. Hence, we need
        # cast them back in the correct dtype just to be sure everything works as expected.
        # This might slowdown training & inference so it is recommended to not cast the LayerNorms
        # in fp32. (usually our RMSNorm modules handle it correctly)
        # NOTE(review): `target_dtype` is computed but never applied — the cast
        # to `target_dtype` present in the upstream template was dropped here;
        # flash_attn_varlen_func will fail on fp32 inputs, verify intent.
        target_dtype = None
        if query_states.dtype == torch.float32:
            if torch.is_autocast_enabled():
                target_dtype = torch.get_autocast_gpu_dtype()
            # Handle the case where the model is quantized
            elif hasattr(self.config, "_pre_quantization_dtype"):
                target_dtype = self.config._pre_quantization_dtype
            else:
                target_dtype = next(layer for layer in self.modules() if isinstance(layer, torch.nn.Linear)).weight.dtype

        # FA2 always relies on the value set in the module, so remove it if present in kwargs to avoid passing it twice
        kwargs.pop("is_causal", None)

        # Reshape to the packed (total_tokens, heads, head_dim) layout expected
        # by flash_attn_varlen_func; squeeze(0) assumes batch dimension of 1.
        query_states = query_states.transpose(1, 2).squeeze(0)
        key_states = key_states.transpose(1, 2).squeeze(0)
        value_states = value_states.transpose(1, 2).squeeze(0)

        # Longest packed segment, derived from consecutive cu_seqlens offsets.
        max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
        attn_output = flash_attn_varlen_func(
            query_states,
            key_states,
            value_states,
            cu_seqlens_q=cu_seqlens,
            cu_seqlens_k=cu_seqlens,
            max_seqlen_q=max_seqlen,
            max_seqlen_k=max_seqlen,
            dropout_p=0.0 if not self.training else self.attention_dropout,
            causal=self.is_causal
        )

        # attn_output = _flash_attention_forward(
        #     query_states,
        #     key_states,
        #     value_states,
        #     attention_mask,
        #     q_len,
        #     position_ids=position_ids,
        #     dropout=dropout_rate,
        #     sliding_window=sliding_window,
        #     is_causal=self.is_causal,
        #     use_top_left_mask=self._flash_attn_uses_top_left_mask,
        # )

        # Restore (batch, seq, hidden) and project out; attention weights are
        # not materialized by the varlen kernel, hence the `None`.
        attn_output = attn_output.reshape(*input_shape, -1).contiguous()
        attn_output = self.o_proj(attn_output)
        return attn_output, None
326
+
327
+
328
class SFLQwen3DecoderLayer(Qwen3DecoderLayer):
    """Qwen3 decoder layer wired to the non-causal varlen attention module."""

    def __init__(self, config: SFLVisionEncoderConfigFromQwen3, layer_idx: int):
        super().__init__(config, layer_idx)
        # Swap the stock attention for the flash-attn varlen variant.
        self.self_attn = SFLQwen3Attention(config, layer_idx)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = False,
        cache_position: Optional[torch.LongTensor] = None,
        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # necessary, but kept here for BC
        cu_seqlens: Optional[torch.Tensor] = None,
        **kwargs: Unpack[FlashAttentionKwargs],
    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
        # Pre-norm self-attention sub-block with residual connection.
        attn_input = self.input_layernorm(hidden_states)
        attn_output, attn_weights = self.self_attn(
            hidden_states=attn_input,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            output_attentions=output_attentions,
            use_cache=use_cache,
            cache_position=cache_position,
            position_embeddings=position_embeddings,
            cu_seqlens=cu_seqlens,
            **kwargs,
        )
        hidden_states = hidden_states + attn_output

        # Pre-norm MLP sub-block with residual connection.
        mlp_input = self.post_attention_layernorm(hidden_states)
        hidden_states = hidden_states + self.mlp(mlp_input)

        outputs = (hidden_states,)
        if output_attentions:
            outputs += (attn_weights,)

        return outputs
376
+
377
+
378
class SFLVisionEncoderFromQwen3Model(Qwen3Model):
    """Qwen3 transformer stack repurposed as a full-attention vision encoder.

    Differences from stock Qwen3Model: layers use the varlen flash-attention
    decoder layer, positions come from a 2-axis (height/width) rotary
    embedding, the token embedding table is deleted (patch embeddings are fed
    in via `inputs_embeds`), and no causal mask is applied.
    """
    config_class = SFLVisionEncoderConfigFromQwen3
    def __init__(self, config: SFLVisionEncoderConfigFromQwen3):
        super().__init__(config)
        # Replace the stock decoder layers with the varlen-attention variant.
        self.layers = nn.ModuleList(
            [SFLQwen3DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
        )
        self.rotary_emb = VisualRotaryEmbedding(config=config)
        # Patch embeddings are supplied externally; no token embedding needed.
        del self.embed_tokens

    @staticmethod
    def _prepare_4d_causal_attention_mask_with_cache_position(
        attention_mask: torch.Tensor,
        sequence_length: int,
        target_length: int,
        dtype: torch.dtype,
        device: torch.device,
        cache_position: torch.Tensor,
        batch_size: int,
        config: SFLVisionEncoderConfigFromQwen3,
        past_key_values: Cache,
    ):
        """
        Override the original causal mask method to create full attention mask instead.
        Creates a full attention 4D mask of shape `(batch_size, 1, query_length, key_value_length)`
        from a 2D mask of shape `(batch_size, key_value_length)`.

        For vision encoding, we want full attention between all patches, not causal attention.
        """
        if attention_mask is not None and attention_mask.dim() == 4:
            # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
            full_attention_mask = attention_mask
        else:
            # Create full attention mask (all zeros, meaning attend to all positions)
            # We only mask based on the provided attention_mask for padding
            if attention_mask is not None:
                # Use the provided attention_mask to handle padding
                min_dtype = torch.finfo(dtype).min
                full_attention_mask = torch.zeros(
                    (sequence_length, target_length), dtype=dtype, device=device
                )
                # Expand to 4D
                full_attention_mask = full_attention_mask[None, None, :, :].expand(batch_size, 1, -1, -1)

                # Apply padding mask if provided
                full_attention_mask = full_attention_mask.clone()  # copy to contiguous memory for in-place edit
                if attention_mask.shape[-1] > target_length:
                    attention_mask = attention_mask[:, :target_length]
                mask_length = attention_mask.shape[-1]
                padding_mask = attention_mask[:, None, None, :] == 0
                full_attention_mask[:, :, :, :mask_length] = full_attention_mask[:, :, :, :mask_length].masked_fill(
                    padding_mask, min_dtype
                )
            else:
                # No attention mask provided, create all-zeros mask (full attention)
                full_attention_mask = torch.zeros(
                    (batch_size, 1, sequence_length, target_length), dtype=dtype, device=device
                )
        return full_attention_mask

    def get_rope_index(self, grid_sizes, merge_sizes, position_ids):
        position_ids = position_ids.contiguous()
        # NOTE(review): the triple-quoted string below is placed after the first
        # statement, so it is a no-op expression, not a real docstring.
        """
        Generate position indices for RoPE:
        - Vision tokens (vision_mask=True): use 2D position encoding like (0,0), (0,1), (0,2), (1,0), (1,1), (1,2)
        - Text tokens (vision_mask=False): use 1D position encoding like (3,3), (4,4), (5,5)
        """
        batch_size = grid_sizes.shape[0]

        # Vision Part: Generate 2D position indices for vision tokens
        vision_pos_ids = []
        for (t, h, w), merge_size in zip(grid_sizes, merge_sizes):
            # Generate height position indices, grouped so that the
            # merge_size x merge_size patches that get merged are adjacent.
            hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w).to(position_ids.device)
            hpos_ids = hpos_ids.reshape(
                h // merge_size,
                merge_size,
                w // merge_size,
                merge_size,
            )
            hpos_ids = hpos_ids.permute(0, 2, 1, 3)
            hpos_ids = hpos_ids.flatten()

            # Generate width position indices
            wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1).to(position_ids.device)
            wpos_ids = wpos_ids.reshape(
                h // merge_size,
                merge_size,
                w // merge_size,
                merge_size,
            )
            wpos_ids = wpos_ids.permute(0, 2, 1, 3)
            wpos_ids = wpos_ids.flatten()

            # Stack height and width to create 2D positions
            vision_pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1))

        # Write each sample's (h, w) indices into the packed position_ids tensor.
        # NOTE(review): writes always target batch index 0 (`position_ids[:, 0, ...]`)
        # — consistent with packing all sequences into one batch row; confirm callers.
        num_start_idx = 0
        for batch_idx in range(batch_size):
            pos_len = vision_pos_ids[batch_idx].shape[0]
            position_ids[:, 0, num_start_idx: num_start_idx+pos_len] = vision_pos_ids[batch_idx].permute(1, 0)
            num_start_idx += pos_len

        return position_ids  # shape: (2, batch_size, seq_len)

    # def get_rope_index(self, grid_sizes, merge_sizes, position_ids):
    #     position_ids = position_ids.contiguous()
    #     """
    #     Generate polar (r, φ) position indices for RoPE:
    #     - Vision tokens (vision_mask=True): use 2D polar coordinates
    #         r = sqrt((h - c_h)^2 + (w - c_w)^2)
    #         φ = atan2(h - c_h, w - c_w)
    #     - Text tokens (vision_mask=False): keep 1D indices unchanged
    #     """
    #     batch_size = grid_sizes.shape[0]

    #     vision_pos_ids = []
    #     for (t, h, w), merge_size in zip(grid_sizes, merge_sizes):
    #         device = position_ids.device

    #         h_idx = torch.arange(h, device=device)
    #         w_idx = torch.arange(w, device=device)
    #         hh, ww = torch.meshgrid(h_idx, w_idx, indexing='ij')

    #         hh = hh.reshape(h // merge_size, merge_size, w // merge_size, merge_size)
    #         ww = ww.reshape(h // merge_size, merge_size, w // merge_size, merge_size)
    #         hh = hh.permute(0, 2, 1, 3).flatten()
    #         ww = ww.permute(0, 2, 1, 3).flatten()

    #         center_h = (h - 1) / 2
    #         center_w = (w - 1) / 2

    #         rh = hh.float() - center_h
    #         rw = ww.float() - center_w
    #         r = torch.sqrt(rh ** 2 + rw ** 2)
    #         phi = torch.atan2(rh, rw)  # [-pi, pi]

    #         # r_norm = r / r.max()
    #         # phi_norm = (phi + math.pi) / (2 * math.pi)

    #         vision_pos_ids.append(torch.stack([r, phi], dim=-1).repeat(t, 1))

    #     num_start_idx = 0
    #     for batch_idx in range(batch_size):
    #         pos_len = vision_pos_ids[batch_idx].shape[0]
    #         position_ids[:, 0, num_start_idx:num_start_idx + pos_len] = vision_pos_ids[batch_idx].permute(1, 0)
    #         num_start_idx += pos_len

    #     return position_ids  # shape: (2, batch_size, seq_len)


    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        grid_sizes: Optional[torch.Tensor] = None,
        merge_sizes: Optional[torch.Tensor] = None,
        **flash_attn_kwargs: Unpack[FlashAttentionKwargs],
    ) -> BaseModelOutputWithPast:
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache

        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

        if self.gradient_checkpointing and self.training and use_cache:
            logger.warning_once(
                "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
            )
            use_cache = False

        # TODO (joao): remove this exception in v4.56 -- it exists for users that try to pass a legacy cache
        if not isinstance(past_key_values, (type(None), Cache)):
            raise ValueError("The `past_key_values` should be either a `Cache` object or `None`.")

        # NOTE(review): `embed_tokens` is deleted in __init__, so the
        # input_ids path below would fail — callers must pass inputs_embeds.
        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        if use_cache and past_key_values is None:
            past_key_values = DynamicCache()

        if cache_position is None:
            past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
            cache_position = torch.arange(
                past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
            )

        # the hard coded `2` is for temporal, height and width.
        if position_ids is None:
            position_ids = cache_position.view(1, 1, -1).expand(2, inputs_embeds.shape[0], -1)
        elif position_ids.dim() == 2:
            position_ids = position_ids[None, ...].expand(2, position_ids.shape[0], -1)
        # Overwrite vision-token positions with 2D (h, w) indices.
        position_ids = self.get_rope_index(grid_sizes, merge_sizes, position_ids)

        # Full attention: masking is handled by cu_seqlens in the varlen kernel.
        causal_mask = None

        hidden_states = inputs_embeds

        # create position embeddings to be shared across the decoder layers
        position_embeddings = self.rotary_emb(hidden_states, position_ids)

        # decoder layers
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None

        # Calculate cumulative sequence lengths for the grid sizes
        # (one segment of h*w tokens per temporal frame of each image/video).
        cu_seqlens = torch.repeat_interleave(grid_sizes[:, 1] * grid_sizes[:, 2], grid_sizes[:, 0]).cumsum(dim=0, dtype=torch.int32)
        cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)

        for decoder_layer in self.layers[: self.config.num_hidden_layers]:
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            if self.gradient_checkpointing and self.training:
                # Positional call order must match SFLQwen3DecoderLayer.forward.
                layer_outputs = self._gradient_checkpointing_func(
                    partial(decoder_layer.__call__, **flash_attn_kwargs),
                    hidden_states,
                    causal_mask,
                    position_ids,
                    past_key_values,
                    output_attentions,
                    use_cache,
                    cache_position,
                    position_embeddings,
                    cu_seqlens,
                )
            else:
                layer_outputs = decoder_layer(
                    hidden_states,
                    attention_mask=causal_mask,
                    position_ids=position_ids,
                    past_key_value=past_key_values,
                    output_attentions=output_attentions,
                    use_cache=use_cache,
                    cache_position=cache_position,
                    position_embeddings=position_embeddings,
                    cu_seqlens=cu_seqlens,
                    **flash_attn_kwargs,
                )

            hidden_states = layer_outputs[0]

            if output_attentions:
                all_self_attns += (layer_outputs[1],)

        hidden_states = self.norm(hidden_states)

        # add hidden states from the last decoder layer
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=past_key_values if use_cache else None,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
        )
645
+
646
+
647
class SFLVisionEncoderModelFromQwen3(PreTrainedModel):
    """Vision encoder: patch embeddings followed by a Qwen3-style transformer.

    Consumes flattened image/video patches (``pixel_values``) together with
    per-sample ``grid_sizes`` (temporal, height, width in patches) and returns
    one feature row per spatially-merged patch, concatenated over all samples.
    """

    config_class = SFLVisionEncoderConfigFromQwen3
    base_model_prefix = "sfl_vision_encoder_qwen3"
    # pixel_values is the primary tensor input for this model
    main_input_name = "pixel_values"
    supports_gradient_checkpointing = True
    _no_split_modules = [
        "SFLVisionEncoderEmbeddings",
    ]
    _supports_flash_attn_2 = True
    _supports_sdpa = True

    def __init__(self, config: SFLVisionEncoderConfigFromQwen3):
        super().__init__(config=config)
        # Patch embedding module and the transformer encoder it feeds.
        self.embeddings = SFLVisionEncoderEmbeddings(config)
        self.encoder = SFLVisionEncoderFromQwen3Model(config)

        self.post_init()


    def forward(self, pixel_values, grid_sizes, merge_sizes=None) -> torch.Tensor:
        """Encode patches and spatially downsample each sample by its merge size.

        Returns a 2-D tensor: (total merged tokens across samples, hidden dim).
        """
        hidden_states = self.embeddings(pixel_values)
        encoder_output = self.encoder(
            # add a batch dimension: the encoder expects (1, seq, dim)
            inputs_embeds=hidden_states[None, ...],
            grid_sizes=grid_sizes,
            merge_sizes=merge_sizes,
            output_hidden_states=True,
        )
        hidden_states = encoder_output.hidden_states
        # hidden_states = torch.cat([
        #     hidden_states[7],
        #     hidden_states[14],
        #     hidden_states[21],
        #     hidden_states[28],
        # ], dim=-1).squeeze(0)
        # Use the final layer only and drop the batch dimension.
        hidden_states = hidden_states[-1].squeeze(0)

        # Split the flat token sequence back into one chunk per sample
        # (each sample contributes t*h*w patch tokens).
        hidden_states_chunks = hidden_states.split(grid_sizes.prod(dim=1).tolist(), dim=0)
        outputs = []

        # NOTE(review): merge_sizes defaults to None but is iterated directly
        # below — callers appear to always pass it; confirm before relying on
        # the default.
        for hidden_states, grid_size, merge_size in zip(hidden_states_chunks, grid_sizes, merge_sizes):
            # NOTE: previous implementation, which supports downsampling with any factor
            c = hidden_states.shape[-1]
            # Un-merge tokens back to a (t, h, w, c) spatial layout, then
            # bilinearly resize to the merged resolution.
            hidden_states = hidden_states.view(
                grid_size[0], grid_size[1] // merge_size, grid_size[2] // merge_size, merge_size, merge_size, c
            ).permute(0, 1, 3, 2, 4, 5)
            hidden_states = hidden_states.reshape(
                grid_size[0], grid_size[1], grid_size[2], c
            ).permute(0, 3, 1, 2)
            hidden_states = torch.nn.functional.interpolate(
                hidden_states,
                size=(grid_size[1] // merge_size, grid_size[2] // merge_size),
                mode='bilinear'
            )
            # Back to a flat (tokens, channels) layout.
            hidden_states = hidden_states.permute(0, 2, 3, 1).view(-1, c)

            # NOTE: simplified implementation, which only supports downsampling with integer factor
            # NOTE: this implementation is mathematically equivalent to the previous one when merge_size is 1 or 2 but may cause slightly different results
            # hidden_states = hidden_states.view(-1, merge_size * merge_size, hidden_states.size(-1))
            # hidden_states = hidden_states.mean(dim=1)

            outputs.append(hidden_states)
        return torch.cat(outputs, dim=0)


    def _init_weights(self, module):
        """Initialize the weights"""
        # Linear/Conv2d layers use LeCun-normal init (lecun_normal_ is defined
        # elsewhere in this file); LayerNorm resets to identity (bias 0, weight 1).
        if isinstance(module, (nn.Linear, nn.Conv2d)):
            lecun_normal_(module.weight)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
modeling_videollama3_qwen3.py ADDED
@@ -0,0 +1,647 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Adopted from https://github.com/haotian-liu/LLaVA.
2
+ # Below is the original copyright:
3
+ # Copyright 2023 Haotian Liu
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ """PyTorch VideoLLaMA3 model."""
17
+
18
+ import importlib.util
19
+ import os.path as osp
20
+ import re
21
+ from abc import ABC, abstractmethod
22
+ from typing import List, Optional, Tuple, Union
23
+
24
+ import torch
25
+ import torch.nn as nn
26
+ import torch.utils.checkpoint
27
+ import math
28
+
29
+ from transformers import AutoModel, Qwen3ForCausalLM, Qwen3Model
30
+ from transformers.generation.utils import GenerateOutput
31
+ from transformers.modeling_outputs import CausalLMOutputWithPast
32
+
33
+ from .modeling_qwen2_audio_encoder import Qwen2AudioEncoderModel
34
+
35
+ try:
36
+ from .configuration_videollama3 import Videollama3Qwen3Config
37
+ except ModuleNotFoundError:
38
+ spec = importlib.util.spec_from_file_location(
39
+ "configuration_videollama3",
40
+ osp.join(osp.dirname(__file__), "configuration_videollama3.py"),
41
+ )
42
+ configuration_videollama3 = importlib.util.module_from_spec(spec)
43
+ spec.loader.exec_module(configuration_videollama3)
44
+ Videollama3Qwen3Config = getattr(
45
+ configuration_videollama3,
46
+ "Videollama3Qwen3Config",
47
+ )
48
+
49
+
50
def build_mlp(depth, hidden_size, output_hidden_size):
    """Build an MLP: one input projection followed by (depth - 1) GELU+Linear stages.

    depth=1 yields a single ``nn.Linear(hidden_size, output_hidden_size)``.
    """
    layers = [nn.Linear(hidden_size, output_hidden_size)]
    layers += [
        layer
        for _ in range(depth - 1)
        for layer in (nn.GELU(), nn.Linear(output_hidden_size, output_hidden_size))
    ]
    return nn.Sequential(*layers)
56
+
57
+
58
def build_vision_projector(config, delay_load=False, **kwargs):
    """Create the vision-to-LLM projector selected by ``config.vision_projector_type``.

    Supported types: "linear" (single nn.Linear) and "mlpNx_gelu" (MlpGeluProjector).
    Operates image-wise only — no temporal aggregation happens here.
    """
    # videollama3 projector only support image-wise operation now, i.e., prohibit the temporal aggregation
    projector_type = getattr(config, 'vision_projector_type', 'linear')
    if projector_type == "linear":
        # NOTE: for both linear and mlp2x_gelu projector type, mean pooling is adopted to aggreate video features
        return nn.Linear(config.mm_hidden_size, config.hidden_size)
    if projector_type.startswith("mlp"):
        return MlpGeluProjector(
            config.vision_encoder_config.hidden_size, config.hidden_size, projector_type
        )
    raise ValueError(f'Unknown projector type: {projector_type}')
68
+
69
def build_audio_projector(config, delay_load=False, **kwargs):
    """Create the audio-to-LLM projector selected by ``config.audio_projector_type``.

    Supported types: "linear" (single nn.Linear), "mlpNx_gelu" (MlpGeluProjector)
    and "dmlpNx_gelu" (MlpGeluDownsampleProjector with 8x temporal downsampling).
    """
    projector_type = getattr(config, 'audio_projector_type', 'linear')
    if projector_type == "linear":
        # NOTE: for both linear and mlp2x_gelu projector type, mean pooling is adopted to aggreate video features
        return nn.Linear(config.mm_hidden_size, config.hidden_size)
    # check "dmlp" before "mlp": "dmlp..." does not match startswith("mlp")
    if projector_type.startswith("dmlp"):
        return MlpGeluDownsampleProjector(config.audio_encoder_config.d_model, config.hidden_size, projector_type)
    if projector_type.startswith("mlp"):
        return MlpGeluProjector(config.audio_encoder_config.d_model, config.hidden_size, projector_type)
    raise ValueError(f'Unknown projector type: {projector_type}')
80
+
81
+
82
class MlpGeluProjector(nn.Module):
    """MLP projector from encoder features to LLM hidden size.

    ``projector_type`` must match ``mlpNx_gelu``; N is the MLP depth passed to
    ``build_mlp``.
    """

    def __init__(self, mm_hidden_size, hidden_size, projector_type):
        super().__init__()
        # Parse the depth N out of "mlpNx_gelu".
        depth = int(re.match(r"^mlp(\d+)x_gelu$", projector_type).group(1))
        self.readout = build_mlp(depth, mm_hidden_size, hidden_size)

    def forward(self, x):
        return self.readout(x)
95
+
96
+
97
class MlpGeluDownsampleProjector(nn.Module):
    """8x sequence downsampler + MLP readout for audio features.

    Concatenates groups of 8 consecutive frames along the channel axis, projects
    them back to ``mm_hidden_size``, then applies an MLP ("dmlpNx_gelu") to the
    LLM hidden size. Trailing frames that do not fill a complete group of 8 are
    dropped.
    """

    def __init__(self, mm_hidden_size, hidden_size, projector_type):
        super().__init__()
        self.downsample = nn.Linear(mm_hidden_size * 8, mm_hidden_size)

        # Parse the depth N out of "dmlpNx_gelu".
        depth = int(re.match(r"^dmlp(\d+)x_gelu$", projector_type).group(1))
        self.readout = build_mlp(depth, mm_hidden_size, hidden_size)

    def forward(self, x):
        batch, seq_len, dim = x.shape

        group = 8
        usable = (seq_len // group) * group  # drop the incomplete trailing group
        grouped = x[:, :usable, :].reshape(batch, usable // group, group * dim)
        return self.readout(self.downsample(grouped))
117
+
118
+
119
class Videollama3MetaModel:
    """Mixin that attaches vision/audio encoders and projectors to the base LM.

    Expected to be combined with a transformers model class whose __init__
    accepts ``config`` (see ``Videollama3Qwen3Model``).
    """

    def __init__(self, config):
        super(Videollama3MetaModel, self).__init__(config)
        if config.vision_encoder is not None:
            # First load: instantiate the encoder from a pretrained path, then
            # cache its resolved config and clear the path so later loads go
            # through the `vision_encoder_config` branch.
            self.vision_encoder = AutoModel.from_pretrained(
                config.vision_encoder,
                attn_implementation=self.config._attn_implementation,
                torch_dtype=self.dtype,
            )
            self.config.vision_encoder_config = self.vision_encoder.config
            self.config.vision_encoder = None
        elif config.vision_encoder_config is not None:
            # Reload: build the encoder purely from the stored config.
            self.vision_encoder = AutoModel.from_config(
                self.config.vision_encoder_config,
                attn_implementation=self.config._attn_implementation,
                torch_dtype=self.dtype,
            )
        else:
            raise ValueError("Vision encoder is not provided in config")

        # Same two-phase scheme for the audio encoder.
        if config.audio_encoder is not None:
            self.audio_encoder = Qwen2AudioEncoderModel.from_pretrained(
                config.audio_encoder,
                attn_implementation=self.config._attn_implementation,
                torch_dtype=self.dtype,
            )
            self.config.audio_encoder_config = self.audio_encoder.config
            self.config.audio_encoder = None
        elif config.audio_encoder_config is not None:
            self.audio_encoder = Qwen2AudioEncoderModel.from_config(
                self.config.audio_encoder_config,
                attn_implementation=self.config._attn_implementation,
                torch_dtype=self.dtype,
            )
        else:
            raise ValueError("Audio encoder is not provided in config")

        # Projectors mapping encoder features into the LM embedding space.
        self.vision_projector = build_vision_projector(config)
        self.audio_projector = build_audio_projector(config)

    def get_vision_encoder(self):
        return self.vision_encoder

    def get_audio_encoder(self):
        return self.audio_encoder

    def get_vision_projector(self):
        return self.vision_projector

    def get_audio_projector(self):
        return self.audio_projector
171
+
172
+
173
class Videollama3Qwen3Model(Videollama3MetaModel, Qwen3Model):
    """Qwen3 backbone extended with the VideoLLaMA3 multimodal components.

    Cooperative __init__: ``Videollama3MetaModel`` builds the encoders and
    projectors, then delegates to ``Qwen3Model`` through the MRO.
    """

    config_class = Videollama3Qwen3Config

    def __init__(self, config: Videollama3Qwen3Config):
        super().__init__(config)
179
+
180
+
181
class Videollama3MetaForCausalLM(ABC):
    """Mixin implementing the multimodal (vision + audio) input pipeline.

    Subclasses must provide ``get_model()`` returning a model that owns the
    encoders and projectors (see ``Videollama3MetaModel``). The main entry
    point is ``prepare_inputs_labels_for_multimodal``, which replaces image /
    audio placeholder tokens in ``input_ids`` with encoder features.
    """

    @abstractmethod
    def get_model(self):
        pass

    def get_vision_encoder(self):
        return self.get_model().get_vision_encoder()

    def get_audio_encoder(self):
        return self.get_model().get_audio_encoder()

    def get_vision_projector(self):
        return self.get_model().get_vision_projector()

    def get_audio_projector(self):
        return self.get_model().get_audio_projector()

    def encode_images(
        self,
        pixel_values: torch.FloatTensor,
        grid_sizes: torch.LongTensor,
        merge_sizes: torch.LongTensor,
    ) -> torch.FloatTensor:
        """Run the vision encoder + projector; returns one row per visual token."""
        mm_features = self.get_model().get_vision_encoder()(
            pixel_values=pixel_values,
            grid_sizes=grid_sizes,
            merge_sizes=merge_sizes,
        )
        mm_features = self.get_model().vision_projector(mm_features)
        return mm_features

    def encode_audios(
        self,
        input_features: torch.FloatTensor,
        audio_attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.FloatTensor:
        """Run the audio encoder + projector and trim each clip to its valid length.

        Returns the per-clip valid features concatenated along dim 0.
        """
        mm_features = self.get_model().get_audio_encoder()(input_features).last_hidden_state
        mm_features_projector = self.get_model().audio_projector(mm_features)
        features = []
        for f, m in zip(mm_features_projector, audio_attention_mask):
            # NOTE(review): assumes every 200 attention-mask frames correspond to
            # 12 projected tokens — verify against the encoder/projector
            # downsampling factors.
            valid_length = math.ceil(m.sum() / 200) * 12
            features.append(f[:valid_length])
        mm_features_projector = torch.cat(features, dim=0)
        return mm_features_projector

    def _get_valid_visual_tokens(
        self,
        mm_features: torch.FloatTensor,
        batched_num_patches: torch.LongTensor,
        modals: List[str],
    ):
        """Drop rows belonging to "text" (pseudo-image) samples from mm_features."""
        valid_masks = []
        for num_patches, modal in zip(batched_num_patches, modals):
            valid_mask = torch.full((num_patches, ), modal != "text", dtype=torch.bool, device=mm_features.device)
            valid_masks.append(valid_mask)
        mm_features = mm_features[torch.cat(valid_masks)]
        return mm_features

    def _maybe_truncate_visual_tokens(
        self,
        mm_features: torch.FloatTensor,
        compression_mask: torch.BoolTensor,
        batched_num_patches: torch.LongTensor,
        modals: List[str],
        input_ids: torch.LongTensor,
        position_ids: Optional[torch.LongTensor] = None,
    ):
        """Truncate per-sample visual features to the number of image tokens
        actually present in each packed sequence of ``input_ids``.

        No-op when position_ids are absent or the counts already match.
        """
        if position_ids is None or mm_features.shape[0] == input_ids.eq(self.config.image_token_index).sum():
            return mm_features, compression_mask

        truncation_mask = []
        for num_patches, modal in zip(batched_num_patches, modals):
            if modal == "text":
                truncation_mask.append(torch.ones((0,), dtype=torch.bool, device=input_ids.device))
            else:
                truncation_mask.append(torch.ones((num_patches,), dtype=torch.bool, device=input_ids.device))

        # Packed sequences are delimited by position_ids resetting to 0.
        seq_end_indices = torch.nonzero(position_ids == 0)[:, 0]
        seq_end_indices = seq_end_indices[seq_end_indices > 0].tolist() + [len(input_ids)]
        seq_start_indices = [0] + seq_end_indices[:-1]
        num_visual_tokens = [
            input_ids[start:end].eq(self.config.image_token_index).sum()
            for start, end in zip(seq_start_indices, seq_end_indices)
        ]

        # Keep only the first n features of each sample, where n is the number
        # of image tokens its (possibly truncated) sequence still contains.
        for n, mask in zip(num_visual_tokens, truncation_mask):
            if len(mask) > 0:
                mask[n:] = False
        truncation_mask = torch.cat(truncation_mask)

        return mm_features[truncation_mask], compression_mask[truncation_mask]

    def _get_compression_mask(
        self,
        pixel_values: torch.FloatTensor,
        batched_num_patches: torch.LongTensor,
        grid_sizes: torch.LongTensor,
        merge_sizes: torch.LongTensor,
        modals: List[str],
        threshold: float = 0.1,
        min_tokens: int = 1,
    ) -> torch.BoolTensor:
        """Build a boolean keep-mask over visual tokens.

        Images keep all tokens. For videos, a token is kept only if its patch
        changed by more than ``threshold`` (mean abs pixel diff, 0-255 scale)
        from the previous frame; each frame keeps at least ``min_tokens``.
        """
        batched_images = pixel_values.split(grid_sizes.prod(dim=1).tolist(), dim=0)
        compression_masks = []

        for images, num_patches, grid_size, merge_size, modal in zip(
            batched_images, batched_num_patches, grid_sizes, merge_sizes, modals
        ):
            t, h, w = grid_size
            if modal == "image" or (modal == "video" and t == 1):
                compression_masks.append(torch.ones((num_patches,), dtype=torch.bool, device=images.device))

            elif modal == "video":
                # NOTE: video token compressor
                images = images.view(t, (h // merge_size) * (w // merge_size), -1)

                pixel_diff = images[1:] - images[:-1]
                pixel_diff = torch.abs(pixel_diff).mean(dim=-1) * 255
                # First frame always exceeds the threshold (fully kept).
                pixel_diff = torch.cat([torch.full_like(pixel_diff[0:1], threshold + 1), pixel_diff], dim=0)
                mask = pixel_diff > threshold
                # Frames with fewer than min_tokens survivors keep their first
                # min_tokens positions.
                padding_ids = torch.nonzero(mask.sum(dim=1) < min_tokens)[:, 0]
                # mask[padding_ids, torch.randperm(min_tokens)] = 1
                mask[padding_ids, :min_tokens] = 1
                compression_masks.append(mask.flatten())

            else:
                # in case of psuedo image
                compression_masks.append(torch.ones((0,), dtype=torch.bool, device=images.device))

        return torch.cat(compression_masks)

    def _compress_visual_tokens(
        self,
        compression_mask: torch.BoolTensor,
        mm_features: torch.FloatTensor,
        input_ids: torch.LongTensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        labels: Optional[torch.LongTensor] = None,
    ):
        """Drop compressed-away image tokens from features and all flat text
        tensors, then rebuild position_ids per packed sequence.

        Returns (mm_features, input_ids, attention_mask, position_ids, labels).
        """
        mm_features = mm_features[compression_mask]
        image_selected = (input_ids == self.config.image_token_index)

        # Keep all non-image tokens, and only the surviving image tokens.
        text_masks = torch.logical_not(image_selected)
        text_masks[image_selected] = compression_mask
        input_ids = input_ids[text_masks]

        if attention_mask is not None:
            attention_mask = attention_mask[text_masks]
        if labels is not None:
            labels = labels[text_masks]
        if position_ids is not None:
            # FIXME: assume the first position_id is always 0
            position_ids = position_ids[text_masks]
            # Re-number positions from 0 within each packed sequence.
            pos_start = [0] + torch.nonzero(position_ids == 0)[:, 0].tolist()
            pos_end = pos_start[1:] + [len(input_ids)]
            position_ids = torch.cat([torch.arange(end - start, device=input_ids.device) for start, end in zip(pos_start, pos_end)])

        return mm_features, input_ids, attention_mask, position_ids, labels

    def prepare_inputs_labels_for_multimodal(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        labels: Optional[torch.LongTensor] = None,
        pixel_values: Optional[torch.FloatTensor] = None,
        grid_sizes: Optional[torch.LongTensor] = None,
        merge_sizes: Optional[torch.LongTensor] = None,
        audio_input_features: Optional[torch.FloatTensor] = None,
        audio_attention_mask: Optional[torch.Tensor] = None,
        modals: Optional[List[str]] = None,
    ):
        """Replace image/audio placeholder tokens with encoder features.

        Returns (input_ids, attention_mask, position_ids, past_key_values,
        inputs_embeds, labels); input_ids is None when embeddings were built.
        """
        vision_encoder = self.get_vision_encoder()
        # NOTE: text-only situation
        if vision_encoder is None or (pixel_values is None and audio_input_features is None) or input_ids.shape[1] == 1:
            return input_ids, attention_mask, position_ids, past_key_values, None, labels

        # 1. flatten text inputs
        B, N = input_ids.shape
        input_ids = input_ids.view(B * N)
        if attention_mask is not None:
            attention_mask = attention_mask.view(B * N)
        if position_ids is not None:
            position_ids = position_ids.view(B * N)
        if labels is not None:
            labels = labels.view(B * N)

        # 2. embed visual tokens
        image_selected, audio_selected = None, None
        if pixel_values is not None:
            batched_num_patches = grid_sizes.prod(dim=1).div(merge_sizes ** 2).long()
            mm_features = self.encode_images(pixel_values, grid_sizes, merge_sizes)
            mm_features = self._get_valid_visual_tokens(mm_features, batched_num_patches, modals)

            compression_mask = self._get_compression_mask(
                pixel_values, batched_num_patches, grid_sizes, merge_sizes, modals
            )
            mm_features, compression_mask = self._maybe_truncate_visual_tokens(
                mm_features, compression_mask, batched_num_patches, modals, input_ids, position_ids
            )

            # 2.1 compress visual tokens
            if self.config.use_token_compression:
                assert B == 1, "Token compression is only supported for batch_size=1"
                # BUGFIX: the previous call passed `labels` and `position_ids`
                # positionally in swapped slots of `_compress_visual_tokens`
                # (and unpacked the result in the wrong order), so labels went
                # through the position-id rebuild and positions were never
                # recomputed. Use keywords and the callee's return order.
                mm_features, input_ids, attention_mask, position_ids, labels = self._compress_visual_tokens(
                    compression_mask,
                    mm_features,
                    input_ids,
                    attention_mask=attention_mask,
                    position_ids=position_ids,
                    labels=labels,
                )
            # 2.2 replace multimodal tokens with features
            image_selected = (input_ids == self.config.image_token_index)
            input_ids[image_selected] = 0

            num_vision_tokens = image_selected.sum()
            if mm_features.size(0) > num_vision_tokens:
                print(f"Number of vision_features ({mm_features.size(0)}) exceeds the number of image tokens ({num_vision_tokens}). Automatically truncated.")
                mm_features = mm_features[:num_vision_tokens]
        # 3. embed audio tokens
        if audio_input_features is not None:
            audio_features = self.encode_audios(audio_input_features, audio_attention_mask)
            audio_selected = (input_ids == self.config.audio_token_index)
            input_ids[audio_selected] = 0

            num_audio_tokens = audio_selected.sum()
            if audio_features.size(0) > num_audio_tokens:
                print(f"Number of audio_features ({audio_features.size(0)}) exceeds the number of audio tokens ({num_audio_tokens}). Automatically truncated.")
                audio_features = audio_features[:num_audio_tokens]

        # 4. embed text tokens (clone so the scatter below doesn't alias the
        # embedding weight's autograd graph)
        inputs_embeds = self.get_model().embed_tokens(input_ids).clone()
        if image_selected is not None:
            # "* 0.0 +" keeps the embedding layer in the autograd graph.
            inputs_embeds[image_selected] = inputs_embeds[image_selected] * 0.0 + mm_features
        if audio_selected is not None:
            inputs_embeds[audio_selected] = inputs_embeds[audio_selected] * 0.0 + audio_features

        # 5. reshape back to batched format
        C = inputs_embeds.shape[-1]
        inputs_embeds = inputs_embeds.reshape(B, -1, C)
        if attention_mask is not None:
            attention_mask = attention_mask.view(B, -1)
        if labels is not None:
            labels = labels.view(B, -1)
        if position_ids is not None:
            position_ids = position_ids.view(B, -1)

        return None, attention_mask, position_ids, past_key_values, inputs_embeds, labels
429
+
430
+
431
class Videollama3Qwen3ForCausalLM(Qwen3ForCausalLM, Videollama3MetaForCausalLM):
    """Causal LM head over the VideoLLaMA3-extended Qwen3 backbone."""

    config_class = Videollama3Qwen3Config

    def __init__(self, config, **kwargs):
        # super(Qwen3ForCausalLM, self) deliberately skips Qwen3ForCausalLM's
        # own __init__ (which would build a stock Qwen3Model); the multimodal
        # model is assigned below instead.
        super(Qwen3ForCausalLM, self).__init__(config)
        self.model = Videollama3Qwen3Model(config)
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    def get_model(self):
        return self.model

    @classmethod
    def _load_pretrained_model(
        cls,
        model,
        state_dict,
        checkpoint_files,
        pretrained_model_name_or_path,
        ignore_mismatched_sizes=False,
        sharded_metadata=None,
        device_map=None,
        disk_offload_folder=None,
        offload_state_dict=None,
        dtype=None,
        hf_quantizer=None,
        keep_in_fp32_regex=None,
        device_mesh=None,
        key_mapping=None,
        weights_only=True,
    ):
        """
        Override to handle nested vision_encoder and audio_encoder keys before calling parent's load method.
        Remaps keys from 'model.vision_encoder.vision_encoder.*' to 'model.vision_encoder.*'
        and 'model.audio_encoder.audio_encoder.*' to 'model.audio_encoder.*'

        NOTE(review): this signature mirrors a specific transformers version's
        internal `_load_pretrained_model`; confirm it matches the pinned
        transformers release before upgrading.
        """
        # If state_dict is provided and needs remapping, do it here
        if state_dict is not None:
            needs_remapping = any(k.startswith('model.vision_encoder.vision_encoder.') or k.startswith("model.audio_encoder.audio_encoder.") for k in state_dict.keys())
            if needs_remapping:
                print("Detected nested encoder keys, remapping 'model.vision_encoder.vision_encoder.*' -> 'model.vision_encoder.*' and 'model.audio_encoder.audio_encoder.*' -> 'model.audio_encoder.*'")
                new_state_dict = {}
                for k, v in state_dict.items():
                    if k.startswith('model.vision_encoder.vision_encoder.'):
                        # Remap: model.vision_encoder.vision_encoder.xxx -> model.vision_encoder.xxx
                        new_key = k.replace('model.vision_encoder.vision_encoder.', 'model.vision_encoder.')
                        new_state_dict[new_key] = v
                    elif k.startswith('model.audio_encoder.audio_encoder.'):
                        # Remap: model.audio_encoder.audio_encoder.xxx -> model.audio_encoder.xxx
                        new_key = k.replace('model.audio_encoder.audio_encoder.', 'model.audio_encoder.')
                        new_state_dict[new_key] = v
                    else:
                        new_state_dict[k] = v
                state_dict = new_state_dict

        # For checkpoint files, we need to add key_mapping to remap the keys during loading
        if checkpoint_files is not None and key_mapping is None:
            # Detect nested keys by loading every checkpoint shard into memory.
            # NOTE(review): this reads all shards up front just to build the key
            # mapping — potentially expensive for large sharded checkpoints.
            from transformers.modeling_utils import load_state_dict
            checkpoint = {}
            checkpoint_files_list = checkpoint_files if isinstance(checkpoint_files, list) else [checkpoint_files]
            for ckpt_file in checkpoint_files_list:
                ckpt = load_state_dict(ckpt_file, map_location="cpu", weights_only=weights_only)
                checkpoint.update(ckpt)
            needs_remapping = any(k.startswith('model.vision_encoder.vision_encoder.') or k.startswith("model.audio_encoder.audio_encoder.") for k in checkpoint.keys())

            if needs_remapping:
                print("Detected nested encoder keys in checkpoint, adding key mapping for vision_encoder and audio_encoder")
                key_mapping = {}
                for k in checkpoint.keys():
                    if k.startswith('model.vision_encoder.vision_encoder.'):
                        new_key = k.replace('model.vision_encoder.vision_encoder.', 'model.vision_encoder.')
                        key_mapping[k] = new_key
                    elif k.startswith('model.audio_encoder.audio_encoder.'):
                        new_key = k.replace('model.audio_encoder.audio_encoder.', 'model.audio_encoder.')
                        key_mapping[k] = new_key
            del checkpoint

        return super()._load_pretrained_model(
            model=model,
            state_dict=state_dict,
            checkpoint_files=checkpoint_files,
            pretrained_model_name_or_path=pretrained_model_name_or_path,
            ignore_mismatched_sizes=ignore_mismatched_sizes,
            sharded_metadata=sharded_metadata,
            device_map=device_map,
            disk_offload_folder=disk_offload_folder,
            offload_state_dict=offload_state_dict,
            dtype=dtype,
            hf_quantizer=hf_quantizer,
            keep_in_fp32_regex=keep_in_fp32_regex,
            device_mesh=device_mesh,
            key_mapping=key_mapping,
            weights_only=weights_only,
        )

    # NOTE: arguments are copied from transformers==4.46.3
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        num_logits_to_keep: int = 0,
        # multimodal inputs
        pixel_values: Optional[torch.FloatTensor] = None,
        grid_sizes: Optional[torch.LongTensor] = None,
        merge_sizes: Optional[torch.LongTensor] = None,
        modals: Optional[List[str]] = None,
        **loss_kwargs,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        """Standard causal-LM forward; if no precomputed embeddings are given,
        multimodal placeholders in input_ids are replaced by encoder features.

        NOTE(review): unlike generate(), this forward does not accept
        audio_input_features/audio_attention_mask — audio is not wired through
        the training forward here; confirm whether that is intended.
        """
        if inputs_embeds is None:
            (
                input_ids,
                attention_mask,
                position_ids,
                past_key_values,
                inputs_embeds,
                labels,
            ) = self.prepare_inputs_labels_for_multimodal(
                input_ids=input_ids,
                attention_mask=attention_mask,
                position_ids=position_ids,
                past_key_values=past_key_values,
                labels=labels,
                pixel_values=pixel_values,
                grid_sizes=grid_sizes,
                merge_sizes=merge_sizes,
                modals=modals,
            )

        return super().forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            labels=labels,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            cache_position=cache_position,
            num_logits_to_keep=num_logits_to_keep,
            **loss_kwargs,
        )

    @torch.no_grad()
    def generate(
        self,
        # multimodal inputs
        pixel_values: Optional[torch.FloatTensor] = None,
        grid_sizes: Optional[torch.LongTensor] = None,
        merge_sizes: Optional[torch.LongTensor] = None,
        audio_input_features: Optional[torch.FloatTensor] = None,
        audio_attention_mask: Optional[torch.Tensor] = None,
        modals: Optional[List[str]] = None,
        **kwargs,
    ) -> Union[GenerateOutput, torch.LongTensor]:
        """Generation entry point: builds inputs_embeds from multimodal inputs
        (when present) and delegates to the standard HF generate()."""
        input_ids = kwargs.pop("input_ids", None)
        attention_mask = kwargs.pop("attention_mask", None)
        position_ids = kwargs.pop("position_ids", None)
        past_key_values = kwargs.pop("past_key_values", None)

        if "inputs_embeds" in kwargs:
            raise NotImplementedError("`inputs_embeds` is not supported")

        if pixel_values is not None or audio_input_features is not None:
            (
                input_ids,
                attention_mask,
                position_ids,
                past_key_values,
                inputs_embeds,
                labels,
            ) = self.prepare_inputs_labels_for_multimodal(
                input_ids=input_ids,
                attention_mask=attention_mask,
                position_ids=position_ids,
                past_key_values=past_key_values,
                labels=None,
                pixel_values=pixel_values,
                grid_sizes=grid_sizes,
                merge_sizes=merge_sizes,
                audio_input_features=audio_input_features,
                audio_attention_mask=audio_attention_mask,
                modals=modals,
            )
        else:
            # Text-only prompt: plain token embeddings.
            inputs_embeds = self.get_model().embed_tokens(input_ids)

        return super().generate(
            position_ids=position_ids,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            **kwargs
        )

    def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs):
        """Forward any `images` kwarg through HF's generation input prep."""
        images = kwargs.pop("images", None)
        _inputs = super().prepare_inputs_for_generation(
            input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, **kwargs
        )
        if images is not None:
            _inputs['images'] = images
        return _inputs
647
+ return _inputs
preprocessor_config.json ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "auto_map": {
3
+ "AutoImageProcessor": "image_processing_sfl.SFLImageProcessor",
4
+ "AutoProcessor": "processing_videollama3_qwen3.Videollama3Qwen3Processor"
5
+ },
6
+ "chunk_length": 30,
7
+ "dither": 0.0,
8
+ "feature_extractor_type": "Qwen2AudioEncoderProcessor",
9
+ "feature_size": 128,
10
+ "hop_length": 160,
11
+ "n_fft": 400,
12
+ "n_samples": 480000,
13
+ "nb_max_frames": 3000,
14
+ "padding_side": "right",
15
+ "padding_value": 0.0,
16
+ "processor_class": "Videollama3Qwen3Processor",
17
+ "return_attention_mask": true,
18
+ "sampling_rate": 16000,
19
+ "do_convert_rgb": true,
20
+ "do_normalize": true,
21
+ "do_rescale": true,
22
+ "do_resize": true,
23
+ "image_mean": [
24
+ 0.5,
25
+ 0.5,
26
+ 0.5
27
+ ],
28
+ "image_processor_type": "SFLImageProcessor",
29
+ "image_std": [
30
+ 0.5,
31
+ 0.5,
32
+ 0.5
33
+ ],
34
+ "max_tokens": 10240,
35
+ "min_tokens": 16,
36
+ "patch_size": 14,
37
+ "resample": 3,
38
+ "rescale_factor": 0.00392156862745098
39
+ }
processing_videollama3_qwen3.py ADDED
@@ -0,0 +1,1004 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Processor class for VideoLLaMA3."""
2
+
3
+ import copy
4
+ import importlib.util
5
+ import os
6
+ import os.path as osp
7
+ import warnings
8
+ from collections import defaultdict
9
+ from typing import Any, List, Union, Dict, Optional, Tuple, TypedDict
10
+
11
+ import cv2
12
+ import ffmpeg
13
+ import imageio
14
+ import json
15
+ import math
16
+ import numpy as np
17
+ import torch
18
+ import transformers
19
+ from decord import VideoReader, cpu
20
+ from PIL import Image
21
+ from transformers.feature_extraction_utils import BatchFeature
22
+ from transformers.image_utils import ImageInput
23
+ from transformers.processing_utils import ProcessingKwargs, ProcessorMixin, Unpack
24
+ from transformers.tokenization_utils_base import PreTokenizedInput, TextInput
25
+
26
+ try:
27
+ from . import image_processing_sfl
28
+ from .image_processing_sfl import (
29
+ is_valid_image, is_valid_video,
30
+ )
31
+ # from . import audio_processing_qwen2_audio
32
+ from transformers import WhisperFeatureExtractor
33
+ except ModuleNotFoundError:
34
+ spec = importlib.util.spec_from_file_location(
35
+ "image_processing_sfl",
36
+ osp.join(osp.dirname(__file__), "image_processing_sfl.py"),
37
+ )
38
+ image_processing_sfl = importlib.util.module_from_spec(spec)
39
+ spec.loader.exec_module(image_processing_sfl)
40
+ is_valid_image = getattr(image_processing_sfl, "is_valid_image")
41
+ is_valid_video = getattr(image_processing_sfl, "is_valid_video")
42
+
43
+ # constants
44
+ DEFAULT_IMAGE_TOKEN = "<image>"
45
+ DEFAULT_AUDIO_TOKEN = "<|audio|>"
46
+ IGNORE_INDEX = -100
47
+
48
+ # Type aliases
49
+ Conversation = List[Dict[str, Any]]
50
+ SingleImage = Union[Image.Image, np.ndarray, torch.Tensor]
51
+ SingleVideo = Union[List[SingleImage], np.ndarray, torch.Tensor]
52
+ BatchedImage = List[Union[SingleImage, SingleVideo]]
53
+ BatchedNamedImage = List[Tuple[str, Union[SingleImage, SingleVideo]]]
54
+
55
+
56
def _custom_import(class_name: str):
    """Resolve a processor-component class by name.

    Looks the class up on ``transformers`` first, falling back to the local
    ``image_processing_sfl`` module for image-processor classes.

    Args:
        class_name: Name of the class to resolve (e.g. "SFLImageProcessor").

    Returns:
        The resolved class object.

    Raises:
        AttributeError: If the name is neither in ``transformers`` nor an
            image-processor class provided by ``image_processing_sfl``.
    """
    try:
        return getattr(transformers, class_name)
    except AttributeError:
        if "image" in class_name.lower():
            return getattr(image_processing_sfl, class_name)
        # BUGFIX: the original fell through here for non-image classes and then
        # hit UnboundLocalError on `return attribute_class`; re-raise the real
        # lookup failure instead.
        raise
63
+
64
+
65
def is_named_image(image) -> bool:
    """Return True if `image` is a 2-element ("image"|"video", payload) pair
    whose payload is a valid image or video."""
    if not (isinstance(image, (list, tuple)) and len(image) == 2):
        return False
    modal, payload = image
    if not isinstance(modal, str) or modal not in ["image", "video"]:
        return False
    return is_valid_image(payload) or is_valid_video(payload)
71
+
72
+
73
def make_batched_images(images) -> Tuple[List[str], List[Any]]:
    """Normalize image inputs into parallel lists of modalities and data.

    Accepts a single image/video, a single ("image"|"video", data) pair, a
    list of images/videos, or a list of named pairs.

    Returns:
        (modals, data): `modals` is a list of "image"/"video" strings and
        `data` the corresponding image/video payloads, index-aligned.

    Raises:
        ValueError: If the input cannot be interpreted as any of the above.
    """
    if isinstance(images, (list, tuple)) and all(is_named_image(image) for image in images):
        # list of named images
        return [image[0] for image in images], [image[1] for image in images]
    elif isinstance(images, (list, tuple)) and all(is_valid_image(image) or is_valid_video(image) for image in images):
        # list of images/videos: infer each item's modality
        batch = []
        for image in images:
            if is_valid_video(image):
                batch.append(("video", image))
            elif is_valid_image(image):
                batch.append(("image", image))
            else:
                raise ValueError(f"Could not make batched images from {images}")
        return [x[0] for x in batch], [x[1] for x in batch]
    elif is_named_image(images):
        # single named image
        # BUGFIX: was `[image[1]]` — `image` is undefined in this branch and
        # raised NameError; the payload is `images[1]`.
        return [images[0]], [images[1]]
    elif is_valid_video(images):
        # single video
        return ["video"], [images]
    elif is_valid_image(images):
        # single image
        return ["image"], [images]

    raise ValueError(f"Could not make batched images from {images}")
99
+
100
+
101
def frame_sample(duration, mode='uniform', num_frames=None, vid_fps=None, fps=None):
    """Select frame indices from a clip of `duration` candidate frames.

    Args:
        duration: Total number of candidate frames.
        mode: 'uniform' to pick `num_frames` evenly spaced indices, or 'fps'
            to pick roughly one frame per 1/`fps` seconds of video.
        num_frames: Required for 'uniform' mode.
        vid_fps: Native frame rate of the video; required for 'fps' mode.
        fps: Target sampling rate; required for 'fps' mode.

    Returns:
        np.ndarray of integer frame indices in ascending order.

    Raises:
        ValueError: If `mode` is not 'uniform' or 'fps'.
    """
    if mode == 'uniform':
        assert num_frames is not None, "Number of frames must be provided for uniform sampling."
        # Fewer candidates than requested: keep them all.
        if duration <= num_frames:
            return np.arange(duration).astype(int)
        # Evenly spaced indices including both endpoints ("v0" scheme; the
        # commented-out segment-midpoint "v1" variant was removed).
        return np.linspace(0, duration - 1, num_frames, dtype=int)
    elif mode == 'fps':
        assert vid_fps is not None, "FPS must be provided for FPS sampling."
        assert fps is not None, "FPS must be provided for FPS sampling."
        # One frame per `segment_len` native frames, starting mid-segment.
        segment_len = min(vid_fps // fps, duration)
        return np.arange(segment_len // 2, duration, segment_len, dtype=int)
    else:
        # BUGFIX: was `raise ImportError(...)`, which misrepresents a bad
        # argument as a missing module.
        raise ValueError(f'Unsupported frame sampling mode: {mode}')
130
+
131
+
132
def load_video_from_ids(video_path, s=None, e=None, fps=None, max_frames=128, temporal_factor=1):
    """Load frames from a video file, a .gif, or a directory of frame images.

    Args:
        video_path: Path to a video file, a .gif, or a directory of frames.
        s: Optional start time in seconds.
        e: Optional end time in seconds.
        fps: Target sampling rate; used when the clip fits within `max_frames`,
            otherwise frames are sampled uniformly.
        max_frames: Upper bound on the number of returned frames.
        temporal_factor: Pad the frame count to a multiple of this value by
            repeating the last frame.

    Returns:
        (frames, timestamps): list of (C, H, W) uint8 arrays and their
        timestamps in seconds.
    """
    # Normalize the requested window: clamp negatives, order s <= e, and
    # widen zero-length windows to one second.
    if s is not None and e is not None:
        s = s if s >= 0. else 0.
        e = e if e >= 0. else 0.
        if s > e:
            s, e = e, s
        elif s == e:
            e = s + 1

    # 1. Loading Video
    if os.path.isdir(video_path):
        frame_files = sorted(os.listdir(video_path))

        vid_fps = 3  # assumed rate for pre-extracted frame dirs — TODO confirm
        num_frames_of_video = len(frame_files)
    elif video_path.endswith('.gif'):
        gif_reader = imageio.get_reader(video_path)

        vid_fps = 25  # assumed GIF frame rate — TODO confirm
        num_frames_of_video = len(gif_reader)
    else:
        vreader = VideoReader(video_path, ctx=cpu(0), num_threads=2)

        vid_fps = vreader.get_avg_fps()
        num_frames_of_video = len(vreader)

    # 2. Determine frame range & Calculate frame indices
    f_start = 0 if s is None else max(int(s * vid_fps) - 1, 0)
    f_end = num_frames_of_video - 1 if e is None else min(int(e * vid_fps) - 1, num_frames_of_video - 1)
    frame_indices = list(range(f_start, f_end + 1))

    duration = len(frame_indices)
    # 3. Sampling frame indices: fps-based when the clip is short enough,
    # otherwise uniform down-sampling to max_frames.
    if fps is not None and duration / vid_fps < max_frames:
        sampled_frame_indices = [frame_indices[i] for i in frame_sample(duration, mode='fps', vid_fps=vid_fps, fps=fps)]
    else:
        sampled_frame_indices = [frame_indices[i] for i in frame_sample(duration, mode='uniform', num_frames=max_frames)]

    # 4. Acquire frame data (decoded to RGB in all three cases)
    if os.path.isdir(video_path):
        frames = np.array([cv2.cvtColor(cv2.imread(os.path.join(video_path, frame_files[frame_idx])), cv2.COLOR_BGR2RGB) for frame_idx in sampled_frame_indices])
    elif video_path.endswith('.gif'):
        frames = np.array([cv2.cvtColor(frame, cv2.COLOR_RGBA2RGB) for idx, frame in enumerate(gif_reader) if idx in sampled_frame_indices])
    else:
        frames = vreader.get_batch(sampled_frame_indices).asnumpy()

    # (T, H, W, C) -> (T, C, H, W)
    frames = frames.transpose(0, 3, 1, 2)
    timestamps = [x / vid_fps for x in sampled_frame_indices]

    # Pad to a multiple of temporal_factor by repeating the last frame.
    # NOTE(review): when fps is None the `1 / fps` below raises TypeError, and
    # when len(frames) is already a multiple a full extra group is appended —
    # confirm both are acceptable to callers.
    if temporal_factor > 1:
        pad_length = temporal_factor - len(frames) % temporal_factor
        frames = np.concatenate([frames, frames[-1:].repeat(pad_length, axis=0)])
        [timestamps.append(timestamps[-1] + 1 / fps) for _ in range(pad_length)]

    frames = [frame for frame in frames]

    return frames, timestamps
190
+
191
+
192
class ChatTemplateKwargs(TypedDict, total=False):
    """Optional keyword arguments routed to `apply_chat_template`."""

    # Jinja template string overriding the tokenizer's default chat template.
    chat_template: Optional[str]
    # Whether to prepend the default system prompt.
    add_system_prompt: Optional[bool]
    # Whether to append the assistant generation prompt.
    add_generation_prompt: Optional[bool]
197
+
198
+
199
class Videollama3Qwen3ProcessorKwargs(ProcessingKwargs, ChatTemplateKwargs, total=False):
    """Keyword-argument schema (with defaults) for Videollama3Qwen3Processor."""

    # Declare chat-template kwargs as their own group so `_merge_kwargs`
    # collects them under output_kwargs["chat_template_kwargs"].
    chat_template_kwargs: ChatTemplateKwargs = {
        **ChatTemplateKwargs.__annotations__,
    }

    _defaults = {
        "text_kwargs": {
            "padding": False,
        },
        # NOTE(review): this key is "image_kwargs" but call sites read
        # output_kwargs["images_kwargs"] — confirm which key `_merge_kwargs`
        # expects; as written this merge_size default may never be applied.
        "image_kwargs": {
            "merge_size": None,
        },
        "chat_template_kwargs": {
            "chat_template": None,
            "add_system_prompt": False,
            "add_generation_prompt": False,
        },
    }
218
+
219
+
220
+ class Videollama3Qwen3Processor(ProcessorMixin):
221
+
222
+ attributes = ["image_processor", "audio_processor", "tokenizer"]
223
+ image_processor_class = "SFLImageProcessor"
224
+ audio_processor_class = "WhisperFeatureExtractor"
225
+ tokenizer_class = ("Qwen2Tokenizer", "Qwen2TokenizerFast")
226
+ valid_kwargs = ["chat_template", "image_merge_size", "video_merge_size", "fps", "max_frames"]
227
+
228
    def __init__(
        self,
        image_processor=None,
        audio_processor=None,
        tokenizer=None,
        chat_template: str = None,
        image_merge_size: int = 1,
        video_merge_size: int = 2,
        fps: Optional[int] = 1,
        max_frames: Optional[int] = 128,
    ):
        """Wire up sub-processors and precompute generation-prompt metadata.

        Args:
            image_processor: Image processor (SFLImageProcessor).
            audio_processor: Audio feature extractor (WhisperFeatureExtractor).
            tokenizer: Text tokenizer; also supplies the default chat template.
            chat_template: Optional override for the tokenizer's chat template.
            image_merge_size: Patch-merge factor for still images.
            video_merge_size: Patch-merge factor for video frames.
            fps: Default video sampling rate.
            max_frames: Default cap on sampled video frames.

        NOTE(review): ProcessorMixin.__init__ is not called here — confirm
        this is intentional (attributes are assigned manually instead).
        """
        self.image_processor = image_processor
        self.audio_processor = audio_processor
        self.tokenizer = tokenizer
        # Fall back to the tokenizer's built-in chat template.
        if chat_template is None:
            chat_template = self.tokenizer.chat_template
        self.chat_template = chat_template

        self.image_merge_size = image_merge_size
        self.video_merge_size = video_merge_size
        self.fps = fps
        self.max_frames = max_frames

        # Cache the rendered generation prompt and its token ids so labels can
        # be masked later without re-rendering the template per message.
        self.generation_prompt = self._infer_generation_prompt()
        self.generation_prompt_ids = self.tokenizer.encode(self.generation_prompt, return_tensors="pt")
        self.generation_prompt_length = len(self.generation_prompt_ids[0])
        self.image_token_id = self.tokenizer.convert_tokens_to_ids(DEFAULT_IMAGE_TOKEN)
        self.eos_token_id = self.tokenizer.eos_token_id
256
+
257
    @classmethod
    def _get_arguments_from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
        """Instantiate each component listed in `cls.attributes` from a checkpoint.

        Overrides ProcessorMixin's loader so class names can be resolved via
        `_custom_import` (transformers first, local SFL module as fallback).
        """
        args = []
        for attribute_name in cls.attributes:
            class_name = getattr(cls, f"{attribute_name}_class")
            if isinstance(class_name, tuple):
                # (slow, fast) tokenizer pair: prefer the fast variant when
                # available and not disabled via use_fast=False.
                classes = tuple(_custom_import(n) if n is not None else None for n in class_name)
                use_fast = kwargs.get("use_fast", True)
                if use_fast and classes[1] is not None:
                    attribute_class = classes[1]
                else:
                    attribute_class = classes[0]
            else:
                attribute_class = _custom_import(class_name)

            args.append(attribute_class.from_pretrained(pretrained_model_name_or_path, **kwargs))
        return args
274
+
275
    def get_generation_prompt(self):
        """Return the generation-prompt string inferred from the chat template."""
        return self.generation_prompt
277
+
278
    def get_generation_prompt_ids(self):
        """Return the pre-encoded generation-prompt token ids (shape (1, L))."""
        return self.generation_prompt_ids
280
+
281
+ def _infer_generation_prompt(self):
282
+ pseudo_message = [{"role": "user", "content": ""}]
283
+ instruction = self.apply_chat_template(pseudo_message, tokenize=False, add_generation_prompt=True)
284
+ conversation = self.apply_chat_template(pseudo_message, tokenize=False, add_generation_prompt=False)
285
+ return instruction.replace(conversation, "")
286
+
287
+ def _get_downsampled_grid_sizes(self, image_inputs: Dict[str, Any]):
288
+ grid_sizes = []
289
+ for grid_size, merge_size in zip(image_inputs.get("grid_sizes", []), image_inputs.get("merge_sizes", [])):
290
+ if not torch.all(grid_size[1:] % merge_size == 0):
291
+ warnings.warn(f"Grid size {grid_size} is not divisible by merge size. Some undesired errors may occur.")
292
+ if grid_size[0] == 1:
293
+ grid_sizes.append(grid_size[1:] / merge_size)
294
+ elif grid_size[0] > 1:
295
+ grid_sizes.extend([grid_size[1:] / merge_size] * grid_size[0])
296
+ return grid_sizes
297
+
298
+ def _get_visual_seq_len(self, grid_size: torch.Tensor):
299
+ num_tokens = int(grid_size.prod().item())
300
+ return num_tokens
301
+
302
    def load_images(self, image_path: Union[str, List[str], Image.Image, List[Image.Image]]):
        """Load one or more RGB images from a path, a directory, a list of
        paths, or PIL image(s).

        Returns:
            A list of images. NOTE(review): path inputs yield PIL.Image
            objects while PIL inputs yield np.ndarray — confirm downstream
            code handles both representations.

        Raises:
            ValueError: If `image_path` is none of the supported forms.
        """
        if isinstance(image_path, str) and os.path.isfile(image_path):
            images = [Image.open(image_path).convert('RGB')]
        elif isinstance(image_path, str) and os.path.isdir(image_path):
            # Directory of frames, read in sorted filename order.
            images = [Image.open(os.path.join(image_path, f)).convert('RGB') for f in sorted(os.listdir(image_path))]
        elif isinstance(image_path, list) and isinstance(image_path[0], str):
            images = [Image.open(f).convert('RGB') for f in image_path]
        elif isinstance(image_path, list) and isinstance(image_path[0], Image.Image):
            images = [np.array(x) for x in image_path]
        elif isinstance(image_path, Image.Image):
            images = [np.array(image_path)]
        else:
            raise ValueError(f"Unsupported image path type: {type(image_path)}")
        return images
319
+
320
    def load_video(
        self,
        video_path: str,
        start_time: Optional[float] = None,
        end_time: Optional[float] = None,
        fps: Optional[float] = None,
        max_frames: Optional[float] = None,
        size: Optional[int] = None,
        size_divisible: int = 1,
        precise_time: bool = False,
        verbose: bool = False,
        temporal_factor: int = 1
    ):
        """
        Load and process a video file and return the frames and the timestamps of each frame.

        Args:
            video_path (str): Path to the video file.
            start_time (float, optional): Start time in seconds. Defaults to None.
            end_time (float, optional): End time in seconds. Defaults to None.
            fps (float, optional): Frames per second. Defaults to None.
            max_frames (float, optional): Maximum number of frames to sample. Defaults to None.
            size (int, optional): Size of the shortest side. Defaults to None.
            size_divisible (int, optional): Size divisible by this number. Defaults to 1.
            precise_time (bool, optional): Whether to use precise (output-side) trimming. Defaults to False.
            verbose (bool, optional): Print ffmpeg output. Defaults to False.
            temporal_factor (int, optional): Pad frames to a multiple of this value. Defaults to 1.

        Returns:
            frames (List[np.ndarray]): List of (C, H, W) uint8 frames.
            timestamps (List[float]): List of timestamps in seconds.
        """
        fps = self.fps if fps is None else fps
        max_frames = self.max_frames if max_frames is None else max_frames

        # Delegate sub-second clips, frame directories, and GIFs to the
        # decord/imageio-based loader.
        if start_time is not None and end_time is not None and end_time - start_time < 1:
            return load_video_from_ids(video_path, start_time, end_time, fps=fps, max_frames=max_frames)
        if os.path.isdir(video_path):
            return load_video_from_ids(video_path, start_time, end_time, fps=fps, max_frames=max_frames)
        if video_path.endswith('.gif'):
            return load_video_from_ids(video_path, start_time, end_time, fps=fps, max_frames=max_frames)
        probe = ffmpeg.probe(video_path)
        duration = float(probe['format']['duration'])
        video_stream = next((stream for stream in probe['streams'] if stream['codec_type'] == 'video'), None)
        w, h = int(video_stream['width']), int(video_stream['height'])

        kwargs, input_kwargs, output_kwargs = {}, {}, {}
        do_trim = start_time is not None or end_time is not None
        # Clamp the requested window to the stream's actual start time and
        # shrink the decode duration accordingly.
        if start_time is not None:
            new_start_time = max(float(video_stream['start_time']), start_time)
            duration -= new_start_time - start_time
            start_time = new_start_time
        else:
            start_time = float(video_stream['start_time'])
        if end_time is not None:
            duration = min(duration, end_time - start_time)
        else:
            duration = duration
        if do_trim:
            kwargs = {'ss': start_time, 't': duration}
        # Output-side trimming cuts exactly but decodes from 0; input-side
        # seeking is faster but keyframe-aligned.
        if precise_time:
            output_kwargs.update(kwargs)
        else:
            input_kwargs.update(kwargs)

        # Scale so the short side equals `size`, then snap both sides down to
        # a multiple of `size_divisible`.
        if size is not None:
            scale_factor = size / min(w, h)
            new_w, new_h = round(w * scale_factor), round(h * scale_factor)
        else:
            new_w, new_h = w, h
        new_w = new_w // size_divisible * size_divisible
        new_h = new_h // size_divisible * size_divisible

        # NOTE: It may result in unexpected number of frames in ffmpeg
        # if calculate the fps directly according to max_frames

        stream = ffmpeg.input(video_path, **input_kwargs)
        if fps is not None:
            stream = ffmpeg.filter(stream, "fps", fps=fps, round="down")
        if new_w != w or new_h != h:
            stream = ffmpeg.filter(stream, 'scale', new_w, new_h)
        stream = ffmpeg.output(stream, "pipe:", format="rawvideo", pix_fmt="rgb24", **output_kwargs)
        out, _ = ffmpeg.run(stream, capture_stdout=True, quiet=not verbose)

        # Raw RGB24 bytes -> (T, C, H, W) uint8.
        frames = np.frombuffer(out, np.uint8).reshape([-1, new_h, new_w, 3]).transpose([0, 3, 1, 2])

        if fps is not None:
            timestamps = np.arange(start_time, start_time + duration + 1 / fps, 1 / fps)[:len(frames)]
        else:
            timestamps = np.linspace(start_time, start_time + duration, len(frames))

        # Uniformly subsample if the decode produced more than max_frames.
        if max_frames is not None and len(frames) > max_frames:
            indices = np.linspace(0, len(frames) - 1, max_frames, dtype=int)
            frames = frames[indices]
            timestamps = timestamps[indices]

        # Pad to a multiple of temporal_factor by repeating the last frame,
        # extrapolating timestamps by 1/fps per padded frame.
        if temporal_factor > 1:
            pad_length = temporal_factor - len(frames) % temporal_factor
            frames = np.concatenate([frames, frames[-1:].repeat(pad_length, axis=0)])
            timestamps = np.concatenate([timestamps, timestamps[-1:].repeat(pad_length) + np.arange(1, pad_length + 1) / fps])

        frames = [frame for frame in frames]
        timestamps = [timestamp for timestamp in timestamps]

        return frames, timestamps
426
+
427
    def load_audio(
        self,
        audio_path: str,
        start_time: Optional[float] = None,
        end_time: Optional[float] = None,
        verbose: bool = False,
        sample_rate: int = 16000,
    ):
        """
        Load an audio file as mono float32 waveform chunks of at most 30 seconds.

        Args:
            audio_path (str): Path to the audio file.
            start_time (float, optional): Start time in seconds. NOTE(review):
                currently unused — the full file is always decoded; confirm.
            end_time (float, optional): End time in seconds. Currently unused.
            verbose (bool, optional): Print ffmpeg output. Defaults to False.
            sample_rate (int, optional): Target sampling rate in Hz.

        Returns:
            audio (List[np.ndarray]): Waveform chunks, each at most 30 s long.
            timestamps (List[float]): Timestamps in seconds, one every 2 s
                across the concatenated chunks.
        """

        # Decode to raw 16-bit mono PCM at the target sample rate via ffmpeg.
        audio_stream_ff = (
            ffmpeg
            .input(audio_path)
            .output(
                "pipe:",
                format="s16le",
                acodec="pcm_s16le",
                ac=1,
                ar=sample_rate,
            )
        )
        audio_out, audio_err = ffmpeg.run(audio_stream_ff, capture_stdout=True, quiet=not verbose)
        # Normalize int16 PCM to float32 in [-1, 1).
        audio = np.frombuffer(audio_out, dtype=np.int16).astype(np.float32) / 32768.0
        duration = len(audio) / sample_rate
        # Split into 30-second chunks.
        # NOTE(review): when duration is an exact multiple of 30 the final
        # chunk is empty — confirm the feature extractor tolerates that.
        if duration > 30:
            audio = [audio[i*30*sample_rate: (i+1)*30*sample_rate] for i in range(int(duration // 30) + 1)]
        else:
            audio = [audio]
        # One timestamp every 2 seconds within each chunk's span.
        timestamps = [t for n, chunk in enumerate(audio) for t in range(n*30, n*30 + math.ceil(len(chunk) / sample_rate), 2)]
        return audio, timestamps
469
+
470
    def _load_multimodal_data(self, conversation: Conversation):
        """Resolve path-style image/video/audio specs in `conversation`.

        Contents whose payload is a dict of load arguments (e.g.
        {"type": "video", "video": {"video_path": ...}}) are loaded from disk.
        Identical sources (same load args ignoring start/end time) are grouped
        so each file is decoded once and then sliced per reference.

        Returns:
            A deep-copied conversation with loaded data substituted in place
            of the load-argument dicts.
        """
        multimodal_info = defaultdict(list)
        new_conversation = []
        # First pass: copy the conversation and index loadable contents by a
        # key derived from their load arguments (minus the time window).
        for message in conversation:
            new_message = {"role": message["role"]}
            if not isinstance(message["content"], (list, tuple)):
                new_message["content"] = message["content"]
                new_conversation.append(new_message)
                continue

            new_contents = []
            for content in message["content"]:
                if not isinstance(content, dict):
                    new_contents.append(content)
                    continue
                assert "type" in content, "Content must have 'type' field."
                if content["type"] in ["image", "video", "audio"] and content["type"] in content and isinstance(content[content["type"]], dict):
                    # TODO: support other types which are not compatible with json
                    load_args = content[content["type"]]
                    data_id = json.dumps({k: v for k, v in load_args.items() if not k in ["start_time", "end_time"]})
                    new_content = copy.deepcopy(content)
                    multimodal_info[data_id].append(new_content)
                    new_contents.append(new_content)
                else:
                    new_contents.append(content)

            new_message["content"] = new_contents
            new_conversation.append(new_message)

        # Second pass: load each unique source once and fan the data out to
        # every content dict that referenced it (mutated in place).
        for data_id, contents in multimodal_info.items():
            data_type = contents[0]["type"]
            if data_type == "image":
                image = self.load_images(contents[0][data_type]["image_path"])[0]
                for content in contents:
                    content["image"] = [image.copy()]

            elif data_type == "video":
                # Decode the union of all requested windows, then slice each
                # reference's frames back out by timestamp.
                # TODO: start_time is None?
                start_times = [content["video"].get("start_time", 0.) for content in contents]
                end_times = [content["video"].get("end_time", float("inf")) for content in contents]

                load_args = contents[0][data_type]
                start_time, end_time = min(start_times), max(end_times)
                if start_time > 0:
                    load_args["start_time"] = start_time
                if end_time < float("inf"):
                    load_args["end_time"] = end_time
                images, timestamps = self.load_video(**load_args)

                for content, start_time, end_time in zip(contents, start_times, end_times):
                    cur_images, cur_timestamps = [], []
                    for image, timestamp in zip(images, timestamps):
                        if start_time <= timestamp <= end_time:
                            cur_images.append(image.copy())
                            cur_timestamps.append(timestamp)

                    content[data_type] = cur_images
                    content["num_frames"] = len(cur_images)
                    content["timestamps"] = cur_timestamps
                    # Optionally attach the soundtrack; audio_split[i] counts
                    # the audio chunks starting before frame i+1.
                    if contents[0].get("with_audio", False):
                        _ = content.pop("with_audio")
                        waves, audio_timestamps = self.load_audio(load_args["video_path"])
                        content["audio"] = [wave.copy() for wave in waves]
                        audio_split = [0] * len(timestamps)
                        temp_count = 0
                        for t in audio_timestamps:
                            while temp_count < len(timestamps) - 1 and t >= timestamps[temp_count+1]:
                                temp_count += 1
                            audio_split[temp_count] += 1
                        content["audio_split"] = audio_split
            elif data_type == "audio":
                waves, timestamps = self.load_audio(contents[0][data_type]["audio_path"])
                for content in contents:
                    content["audio"] = [wave.copy() for wave in waves]
                    content["num_frames"] = len(timestamps)
                    content["timestamps"] = timestamps

        return new_conversation
548
+
549
+ def _gather_multimodal_data(self, conversation: Conversation):
550
+ images = []
551
+ audios = []
552
+ for message in conversation:
553
+ if not isinstance(message["content"], (list, tuple)):
554
+ continue
555
+ for content in message["content"]:
556
+ if not isinstance(content, dict):
557
+ continue
558
+ if content["type"] == "video":
559
+ video = content["video"]
560
+ assert is_valid_video(video), f"Invalid video data: {video}."
561
+ images.append(("video", video))
562
+ if "audio" in content:
563
+ audio = content["audio"]
564
+ audios.append(("audio", audio))
565
+ if content["type"] == "image":
566
+ image = content["image"]
567
+ images.append(("image", image))
568
+ if content["type"] == "audio":
569
+ audio = content["audio"]
570
+ audios.append(("audio", audio))
571
+ images = images if len(images) > 0 else None
572
+ audios = audios if len(audios) > 0 else None
573
+ return images, audios
574
+
575
    def _process_conversation_with_label(
        self,
        conversation: Conversation,
        image_inputs: Dict[str, Any],
        audio_inputs: Dict[str, Any],
        **kwargs,
    ):
        """Tokenize `conversation` and build training labels.

        Assistant spans (minus the generation-prompt prefix and final token)
        keep their token ids as labels; everything else is IGNORE_INDEX. The
        first token after a streaming message is also supervised and tracked
        by token id so over-represented types can be negatively subsampled.

        Returns:
            Dict with "input_ids" and "labels" 1-D tensors concatenated over
            all messages.
        """
        assert kwargs.pop("return_tensors", "pt") == "pt", "Only PyTorch tensors are supported when return_labels=True."
        assert not "add_generation_prompt" in kwargs, "'add_generation_prompt' argument is not supported when return_labels=True."

        output_kwargs = self._merge_kwargs(
            Videollama3Qwen3ProcessorKwargs,
            tokenizer_init_kwargs=self.tokenizer.init_kwargs,
            **kwargs,
        )
        # add_generation_prompt is passed explicitly below; drop the default.
        output_kwargs["chat_template_kwargs"].pop("add_generation_prompt")

        grid_sizes = self._get_downsampled_grid_sizes(image_inputs)
        text_inputs = {"input_ids": [], "labels": []}
        sample_types_list = []
        image_idx = 0

        # Render one message at a time so label spans can be computed per role.
        for message_idx, message in enumerate(conversation):
            prompt = self.apply_chat_template(
                [message],
                tokenize=False,
                add_generation_prompt=False,
                **output_kwargs["chat_template_kwargs"],
            )
            # Expand each <image> token to the number of vision tokens its
            # downsampled grid occupies.
            prompt_chunks = prompt.split(DEFAULT_IMAGE_TOKEN)
            prompt = []
            for chunk_idx in range(len(prompt_chunks) - 1):
                prompt.append(prompt_chunks[chunk_idx])
                num_tokens = self._get_visual_seq_len(grid_sizes[image_idx])
                prompt.append(DEFAULT_IMAGE_TOKEN * num_tokens)
                image_idx += 1
            prompt.append(prompt_chunks[-1])
            prompt = "".join(prompt)

            # TODO: support attention_mask, position_ids, etc.
            input_ids = self.tokenizer.encode(prompt, return_tensors="pt", **output_kwargs["text_kwargs"])[0]
            text_inputs["input_ids"].append(input_ids)

            # Start fully masked; selectively unmask supervised positions.
            targets = torch.full_like(input_ids, IGNORE_INDEX)
            sample_types = torch.full_like(input_ids, IGNORE_INDEX)
            if message["role"] == "assistant":
                # Supervise the reply, excluding the generation-prompt prefix
                # and the trailing token.
                targets[self.generation_prompt_length:-1] = input_ids[self.generation_prompt_length:-1].clone()

            # The first token after a streaming message is supervised and its
            # token id recorded for type-balanced subsampling below.
            if message_idx > 0 and conversation[message_idx - 1]["role"] == "stream":
                targets[0] = input_ids[0]
                # TODO: consider non-special tokens
                sample_types[0] = input_ids[0]

            text_inputs["labels"].append(targets)
            sample_types_list.append(sample_types)

        # Negative sampling for streaming data: cap every supervised token
        # type at the count of the rarest type, masking the surplus at random.
        text_inputs = {k: torch.cat(v) for k, v in text_inputs.items()}
        sample_types = torch.cat(sample_types_list)
        types, counts = torch.unique(sample_types[sample_types > -1], return_counts=True)

        if len(types) > 0:
            target_num_samples = counts.amin()
            for type_id, type_count in zip(types, counts):
                if type_count > target_num_samples:
                    indices = torch.nonzero(sample_types == type_id)[:, 0]
                    random_selector = torch.randperm(indices.size(0))[:-target_num_samples]
                    text_inputs["labels"][indices[random_selector]] = IGNORE_INDEX

        assert len(grid_sizes) == image_idx, "Number of images does not match the number of image tokens in the text."

        return text_inputs
655
+
656
+ def _process_conversation_without_label(
657
+ self,
658
+ conversation: Conversation,
659
+ image_inputs: Dict[str, Any],
660
+ audio_inputs: Dict[str, Any],
661
+ add_think_prompt: bool = False,
662
+ **kwargs,
663
+ ):
664
+ output_kwargs = self._merge_kwargs(
665
+ Videollama3Qwen3ProcessorKwargs,
666
+ tokenizer_init_kwargs=self.tokenizer.init_kwargs,
667
+ **kwargs,
668
+ )
669
+ prompt = self.apply_chat_template(
670
+ conversation,
671
+ tokenize=False,
672
+ add_think_prompt=add_think_prompt,
673
+ **output_kwargs["chat_template_kwargs"],
674
+ )
675
+ return self.process_text(prompt, image_inputs, audio_inputs, **output_kwargs["text_kwargs"])
676
+
677
    def _process_conversation(
        self,
        conversation: Conversation,
        images: Optional[Union[BatchedImage, BatchedNamedImage]] = None,
        audios = None,
        return_labels: bool = False,
        add_think_prompt: bool = False,
        **kwargs: Unpack[Videollama3Qwen3ProcessorKwargs],
    ) -> BatchFeature:
        """Process a full chat conversation into model-ready features.

        Loads path-style multimodal contents when `images`/`audios` are not
        supplied, runs the image/audio processors, then tokenizes (optionally
        producing training labels).
        """
        assert isinstance(conversation, list), "Conversation must be a list of messages."

        if images is None or audios is None:
            conversation = self._load_multimodal_data(conversation)
            images, audios = self._gather_multimodal_data(conversation)

        output_kwargs = self._merge_kwargs(
            Videollama3Qwen3ProcessorKwargs,
            tokenizer_init_kwargs=self.tokenizer.init_kwargs,
            **kwargs,
        )

        if images is not None:
            # NOTE(review): reads "images_kwargs" while _defaults declares
            # "image_kwargs" — confirm which key _merge_kwargs produces.
            image_inputs = self.process_images(images, **output_kwargs["images_kwargs"])
        else:
            image_inputs = {}

        if audios is not None:
            # Prefix the extractor outputs so they don't clash with text keys.
            audio_inputs = self.process_audios(audios)
            audio_inputs["audio_input_features"] = audio_inputs.pop("input_features")
            audio_inputs["audio_attention_mask"] = audio_inputs.pop("attention_mask")
        else:
            audio_inputs = {}

        if return_labels:
            text_inputs = self._process_conversation_with_label(conversation, image_inputs, audio_inputs, **kwargs)
        else:
            text_inputs = self._process_conversation_without_label(conversation, image_inputs, audio_inputs, add_think_prompt, **kwargs)

        return BatchFeature(data={**text_inputs, **image_inputs, **audio_inputs})
717
+
718
+ def _process_plain(
719
+ self,
720
+ text: Union[TextInput, PreTokenizedInput] = None,
721
+ images: Optional[Union[BatchedImage, BatchedNamedImage]] = None,
722
+ audios = None,
723
+ return_labels: bool = False,
724
+ **kwargs: Unpack[Videollama3Qwen3ProcessorKwargs],
725
+ ) -> BatchFeature:
726
+ if text is None:
727
+ raise ValueError("You must provide 'text' or 'message'.")
728
+ if return_labels:
729
+ raise ValueError("return_labels is not supported for plain text processing.")
730
+
731
+ output_kwargs = self._merge_kwargs(
732
+ Videollama3Qwen3ProcessorKwargs,
733
+ tokenizer_init_kwargs=self.tokenizer.init_kwargs,
734
+ **kwargs,
735
+ )
736
+
737
+ if images is not None:
738
+ image_inputs = self.process_images(images, **output_kwargs["images_kwargs"])
739
+ else:
740
+ image_inputs = {}
741
+
742
+ text_inputs = self.process_text(text, image_inputs, **output_kwargs["text_kwargs"])
743
+
744
+ return BatchFeature(data={**text_inputs, **image_inputs})
745
+
746
+ def process_images(self, images: Union[BatchedImage, BatchedNamedImage], **kwargs):
747
+ modals, images = make_batched_images(images)
748
+ if not "merge_size" in kwargs:
749
+ kwargs["merge_size"] = [
750
+ self.image_merge_size if modal == "image" else self.video_merge_size
751
+ for modal in modals
752
+ ]
753
+ image_inputs = self.image_processor(images=images, **kwargs)
754
+ image_inputs["modals"] = modals
755
+ return image_inputs
756
+
757
+ def process_audios(
758
+ self,
759
+ audios = None,
760
+ **kwargs,
761
+ ):
762
+ if audios is None:
763
+ return {}
764
+ audios = [a[1] for a in audios]
765
+ audios = sum(audios, [])
766
+ audio_inputs = self.audio_processor(raw_speech=audios, sampling_rate=16000, truncation=False, return_attention_mask=True, return_tensors='pt', **kwargs)
767
+ return audio_inputs
768
+
769
+ def process_text(
770
+ self,
771
+ text: TextInput,
772
+ image_inputs: Dict[str, Any],
773
+ audio_inputs: Dict[str, Any],
774
+ **kwargs,
775
+ ):
776
+ grid_sizes = self._get_downsampled_grid_sizes(image_inputs)
777
+
778
+ kwargs.pop("padding")
779
+ kwargs.pop("padding_side")
780
+
781
+ image_idx = 0
782
+ while DEFAULT_IMAGE_TOKEN in text:
783
+ num_tokens = self._get_visual_seq_len(grid_sizes[image_idx])
784
+ text = text.replace(DEFAULT_IMAGE_TOKEN, "<placeholder>" * num_tokens, 1)
785
+ image_idx += 1
786
+ text = text.replace("<placeholder>", DEFAULT_IMAGE_TOKEN)
787
+
788
+ assert len(grid_sizes) == image_idx, "Number of images does not match the number of image tokens in the text."
789
+
790
+ text = text.replace(DEFAULT_AUDIO_TOKEN, DEFAULT_AUDIO_TOKEN*12)
791
+ # print(text)
792
+
793
+ text_inputs = self.tokenizer(text, **kwargs)
794
+ return text_inputs
795
+
796
+ def __call__(
797
+ self,
798
+ text: Optional[TextInput] = None,
799
+ conversation: Optional[Conversation] = None,
800
+ images: Optional[Union[BatchedImage, BatchedNamedImage]] = None,
801
+ audios = None,
802
+ return_labels: bool = False,
803
+ add_think_prompt: bool = False,
804
+ **kwargs: Unpack[Videollama3Qwen3ProcessorKwargs],
805
+ ) -> BatchFeature:
806
+ if conversation is not None:
807
+ if text is not None:
808
+ raise ValueError("You cannot provide 'message' with 'text'.")
809
+ return self._process_conversation(conversation, images, audios, return_labels, add_think_prompt, **kwargs)
810
+ return self._process_plain(text, images, audios, return_labels, **kwargs)
811
+
812
    def batch_decode(self, *args, **kwargs):
        """Delegate to the underlying tokenizer's ``batch_decode``.

        All arguments are forwarded unchanged; see
        ``PreTrainedTokenizer.batch_decode`` for details.
        """
        return self.tokenizer.batch_decode(*args, **kwargs)
814
+
815
    def decode(self, *args, **kwargs):
        """Delegate to the underlying tokenizer's ``decode``.

        All arguments are forwarded unchanged; see
        ``PreTrainedTokenizer.decode`` for details.
        """
        return self.tokenizer.decode(*args, **kwargs)
817
+
818
    def apply_chat_template(
        self,
        conversation: Conversation,
        chat_template: Optional[str] = None,
        tokenize: bool = False,
        add_system_prompt: bool = False,
        add_generation_prompt: bool = True,
        add_think_prompt: bool = False,
        image_token: Optional[str] = DEFAULT_IMAGE_TOKEN,
        audio_token: Optional[str] = DEFAULT_AUDIO_TOKEN,
        **kwargs,
    ) -> str:
        """
        Similar to the `apply_chat_template` method on tokenizers, this method applies a Jinja template to input
        conversations to turn them into a single tokenizable string.

        Args:
            conversation (`List[Dict, str, str]`):
                The conversation to format.
            chat_template (`Optional[str]`, *optional*):
                The Jinja template to use for formatting the conversation. If not provided, the tokenizer's
                chat template is used.
            tokenize (`bool`, *optional*, defaults to `False`):
                Whether to tokenize the output or not.
            add_system_prompt (`bool`, *optional*, defaults to `False`):
                Whether to add the system prompt to the output or not.
            add_generation_prompt (`bool`, *optional*, defaults to `True`):
                Whether to add the generation prompt to the output or not.
            add_think_prompt (`bool`, *optional*, defaults to `False`):
                Forwarded to the chat template; controls the "think" section of
                the generation prompt.
            image_token (`Optional[str]`, *optional*, defaults to `DEFAULT_IMAGE_TOKEN`):
                The token to use for indicating images in the conversation.
            audio_token (`Optional[str]`, *optional*, defaults to `DEFAULT_AUDIO_TOKEN`):
                The token to use for indicating audio segments in the conversation.
            **kwargs:
                Additional keyword arguments forwarded to the tokenizer's `apply_chat_template`.

        Raises:
            ValueError: If no chat template is provided and none is set on the processor.
        """

        if chat_template is None:
            if self.chat_template is not None:
                chat_template = self.chat_template
            else:
                raise ValueError(
                    "No chat template is set for this processor. Please either set the `chat_template` attribute, "
                    "or provide a chat template as an argument. See "
                    "https://huggingface.co/docs/transformers/main/en/chat_templating for more information."
                )
        return self.tokenizer.apply_chat_template(
            conversation,
            chat_template=chat_template,
            tokenize=tokenize,
            add_system_prompt=add_system_prompt,
            add_generation_prompt=add_generation_prompt,
            add_think_prompt=add_think_prompt,
            image_token=image_token,
            audio_token=audio_token,
            **kwargs
        )
872
+
873
+ @property
874
+ def model_input_names(self):
875
+ tokenizer_input_names = self.tokenizer.model_input_names
876
+ image_processor_input_names = self.image_processor.model_input_names
877
+ return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names)) + ["modals"]
878
+
879
+ # modified from transformers.ProcessorMixin
880
+ def _merge_kwargs(
881
+ self,
882
+ ModelProcessorKwargs: ProcessingKwargs,
883
+ tokenizer_init_kwargs: Optional[Dict] = None,
884
+ **kwargs,
885
+ ) -> Dict[str, Dict]:
886
+ """
887
+ Method to merge dictionaries of kwargs cleanly separated by modality within a Processor instance.
888
+ The order of operations is as follows:
889
+ 1) kwargs passed as before have highest priority to preserve BC.
890
+ ```python
891
+ high_priority_kwargs = {"crop_size" = {"height": 222, "width": 222}, "padding" = "max_length"}
892
+ processor(..., **high_priority_kwargs)
893
+ ```
894
+ 2) kwargs passed as modality-specific kwargs have second priority. This is the recommended API.
895
+ ```python
896
+ processor(..., text_kwargs={"padding": "max_length"}, images_kwargs={"crop_size": {"height": 222, "width": 222}}})
897
+ ```
898
+ 3) kwargs passed during instantiation of a modality processor have fourth priority.
899
+ ```python
900
+ tokenizer = tokenizer_class(..., {"padding": "max_length"})
901
+ image_processor = image_processor_class(...)
902
+ processor(tokenizer, image_processor) # will pass max_length unless overriden by kwargs at call
903
+ ```
904
+ 4) defaults kwargs specified at processor level have lowest priority.
905
+ ```python
906
+ class MyProcessingKwargs(ProcessingKwargs, CommonKwargs, TextKwargs, ImagesKwargs, total=False):
907
+ _defaults = {
908
+ "text_kwargs": {
909
+ "padding": "max_length",
910
+ "max_length": 64,
911
+ },
912
+ }
913
+ ```
914
+ Args:
915
+ ModelProcessorKwargs (`ProcessingKwargs`):
916
+ Typed dictionary of kwargs specifically required by the model passed.
917
+ tokenizer_init_kwargs (`Dict`, *optional*):
918
+ Dictionary of kwargs the tokenizer was instantiated with and need to take precedence over defaults.
919
+
920
+ Returns:
921
+ output_kwargs (`Dict`):
922
+ Dictionary of per-modality kwargs to be passed to each modality-specific processor.
923
+
924
+ """
925
+ # Initialize dictionaries
926
+ output_kwargs = {
927
+ "text_kwargs": {},
928
+ "images_kwargs": {},
929
+ "audio_kwargs": {},
930
+ "videos_kwargs": {},
931
+ "chat_template_kwargs": {},
932
+ "common_kwargs": {},
933
+ }
934
+
935
+ default_kwargs = {
936
+ "text_kwargs": {},
937
+ "images_kwargs": {},
938
+ "audio_kwargs": {},
939
+ "videos_kwargs": {},
940
+ "chat_template_kwargs": {},
941
+ "common_kwargs": {},
942
+ }
943
+
944
+ used_keys = set()
945
+
946
+ # get defaults from set model processor kwargs if they exist
947
+ for modality in default_kwargs:
948
+ default_kwargs[modality] = ModelProcessorKwargs._defaults.get(modality, {}).copy()
949
+ # update defaults with arguments from tokenizer init
950
+ for modality_key in ModelProcessorKwargs.__annotations__[modality].__annotations__.keys():
951
+ # init with tokenizer init kwargs if necessary
952
+ if modality_key in tokenizer_init_kwargs:
953
+ value = (
954
+ getattr(self.tokenizer, modality_key)
955
+ if hasattr(self.tokenizer, modality_key)
956
+ else tokenizer_init_kwargs[modality_key]
957
+ )
958
+ default_kwargs[modality][modality_key] = value
959
+ # now defaults kwargs are updated with the tokenizers defaults.
960
+ # pass defaults to output dictionary
961
+ output_kwargs.update(default_kwargs)
962
+
963
+ # update modality kwargs with passed kwargs
964
+ non_modality_kwargs = set(kwargs) - set(output_kwargs)
965
+ for modality in output_kwargs:
966
+ for modality_key in ModelProcessorKwargs.__annotations__[modality].__annotations__.keys():
967
+ # check if we received a structured kwarg dict or not to handle it correctly
968
+ if modality in kwargs:
969
+ kwarg_value = kwargs[modality].pop(modality_key, "__empty__")
970
+ # check if this key was passed as a flat kwarg.
971
+ if kwarg_value != "__empty__" and modality_key in non_modality_kwargs:
972
+ raise ValueError(
973
+ f"Keyword argument {modality_key} was passed two times:\n"
974
+ f"in a dictionary for {modality} and as a **kwarg."
975
+ )
976
+ elif modality_key in kwargs:
977
+ # we get a modality_key instead of popping it because modality-specific processors
978
+ # can have overlapping kwargs
979
+ kwarg_value = kwargs.get(modality_key, "__empty__")
980
+ else:
981
+ kwarg_value = "__empty__"
982
+ if kwarg_value != "__empty__":
983
+ output_kwargs[modality][modality_key] = kwarg_value
984
+ used_keys.add(modality_key)
985
+
986
+ # Determine if kwargs is a flat dictionary or contains nested dictionaries
987
+ if any(key in default_kwargs for key in kwargs):
988
+ # kwargs is dictionary-based, and some keys match modality names
989
+ for modality, subdict in kwargs.items():
990
+ if modality in default_kwargs:
991
+ for subkey, subvalue in subdict.items():
992
+ if subkey not in used_keys:
993
+ output_kwargs[modality][subkey] = subvalue
994
+ used_keys.add(subkey)
995
+ else:
996
+ # kwargs is a flat dictionary
997
+ for key in kwargs:
998
+ if key not in used_keys:
999
+ output_kwargs["common_kwargs"][key] = kwargs[key]
1000
+
1001
+ # all modality-specific kwargs are updated with common kwargs
1002
+ for modality in output_kwargs:
1003
+ output_kwargs[modality].update(output_kwargs["common_kwargs"])
1004
+ return output_kwargs
processor_config.json ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "auto_map": {
3
+ "AutoProcessor": "processing_videollama3_qwen3.Videollama3Qwen3Processor"
4
+ },
5
+ "fps": 1,
6
+ "image_merge_size": 1,
7
+ "max_frames": 180,
8
+ "processor_class": "Videollama3Qwen3Processor",
9
+ "video_merge_size": 2
10
+ }
special_tokens_map.json ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "additional_special_tokens": [
3
+ "<|im_start|>",
4
+ "<|im_end|>",
5
+ "<|object_ref_start|>",
6
+ "<|object_ref_end|>",
7
+ "<|box_start|>",
8
+ "<|box_end|>",
9
+ "<|quad_start|>",
10
+ "<|quad_end|>",
11
+ "<|vision_start|>",
12
+ "<|vision_end|>",
13
+ "<|vision_pad|>",
14
+ "<|image_pad|>",
15
+ "<|video_pad|>"
16
+ ],
17
+ "eos_token": {
18
+ "content": "<|im_end|>",
19
+ "lstrip": false,
20
+ "normalized": false,
21
+ "rstrip": false,
22
+ "single_word": false
23
+ },
24
+ "pad_token": {
25
+ "content": "<|endoftext|>",
26
+ "lstrip": false,
27
+ "normalized": false,
28
+ "rstrip": false,
29
+ "single_word": false
30
+ }
31
+ }
tokenizer.json ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3de4265d6c1499ee2f7f7c2ec71004f59d8676ce0373cd32cbad37d40b945cbd
3
+ size 11423788
tokenizer_config.json ADDED
@@ -0,0 +1,290 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "add_bos_token": false,
3
+ "add_prefix_space": false,
4
+ "added_tokens_decoder": {
5
+ "151643": {
6
+ "content": "<|endoftext|>",
7
+ "lstrip": false,
8
+ "normalized": false,
9
+ "rstrip": false,
10
+ "single_word": false,
11
+ "special": true
12
+ },
13
+ "151644": {
14
+ "content": "<|im_start|>",
15
+ "lstrip": false,
16
+ "normalized": false,
17
+ "rstrip": false,
18
+ "single_word": false,
19
+ "special": true
20
+ },
21
+ "151645": {
22
+ "content": "<|im_end|>",
23
+ "lstrip": false,
24
+ "normalized": false,
25
+ "rstrip": false,
26
+ "single_word": false,
27
+ "special": true
28
+ },
29
+ "151646": {
30
+ "content": "<|object_ref_start|>",
31
+ "lstrip": false,
32
+ "normalized": false,
33
+ "rstrip": false,
34
+ "single_word": false,
35
+ "special": true
36
+ },
37
+ "151647": {
38
+ "content": "<|object_ref_end|>",
39
+ "lstrip": false,
40
+ "normalized": false,
41
+ "rstrip": false,
42
+ "single_word": false,
43
+ "special": true
44
+ },
45
+ "151648": {
46
+ "content": "<|box_start|>",
47
+ "lstrip": false,
48
+ "normalized": false,
49
+ "rstrip": false,
50
+ "single_word": false,
51
+ "special": true
52
+ },
53
+ "151649": {
54
+ "content": "<|box_end|>",
55
+ "lstrip": false,
56
+ "normalized": false,
57
+ "rstrip": false,
58
+ "single_word": false,
59
+ "special": true
60
+ },
61
+ "151650": {
62
+ "content": "<|quad_start|>",
63
+ "lstrip": false,
64
+ "normalized": false,
65
+ "rstrip": false,
66
+ "single_word": false,
67
+ "special": true
68
+ },
69
+ "151651": {
70
+ "content": "<|quad_end|>",
71
+ "lstrip": false,
72
+ "normalized": false,
73
+ "rstrip": false,
74
+ "single_word": false,
75
+ "special": true
76
+ },
77
+ "151652": {
78
+ "content": "<|vision_start|>",
79
+ "lstrip": false,
80
+ "normalized": false,
81
+ "rstrip": false,
82
+ "single_word": false,
83
+ "special": true
84
+ },
85
+ "151653": {
86
+ "content": "<|vision_end|>",
87
+ "lstrip": false,
88
+ "normalized": false,
89
+ "rstrip": false,
90
+ "single_word": false,
91
+ "special": true
92
+ },
93
+ "151654": {
94
+ "content": "<|vision_pad|>",
95
+ "lstrip": false,
96
+ "normalized": false,
97
+ "rstrip": false,
98
+ "single_word": false,
99
+ "special": true
100
+ },
101
+ "151655": {
102
+ "content": "<|image_pad|>",
103
+ "lstrip": false,
104
+ "normalized": false,
105
+ "rstrip": false,
106
+ "single_word": false,
107
+ "special": true
108
+ },
109
+ "151656": {
110
+ "content": "<|video_pad|>",
111
+ "lstrip": false,
112
+ "normalized": false,
113
+ "rstrip": false,
114
+ "single_word": false,
115
+ "special": true
116
+ },
117
+ "151657": {
118
+ "content": "<tool_call>",
119
+ "lstrip": false,
120
+ "normalized": false,
121
+ "rstrip": false,
122
+ "single_word": false,
123
+ "special": false
124
+ },
125
+ "151658": {
126
+ "content": "</tool_call>",
127
+ "lstrip": false,
128
+ "normalized": false,
129
+ "rstrip": false,
130
+ "single_word": false,
131
+ "special": false
132
+ },
133
+ "151659": {
134
+ "content": "<|fim_prefix|>",
135
+ "lstrip": false,
136
+ "normalized": false,
137
+ "rstrip": false,
138
+ "single_word": false,
139
+ "special": false
140
+ },
141
+ "151660": {
142
+ "content": "<|fim_middle|>",
143
+ "lstrip": false,
144
+ "normalized": false,
145
+ "rstrip": false,
146
+ "single_word": false,
147
+ "special": false
148
+ },
149
+ "151661": {
150
+ "content": "<|fim_suffix|>",
151
+ "lstrip": false,
152
+ "normalized": false,
153
+ "rstrip": false,
154
+ "single_word": false,
155
+ "special": false
156
+ },
157
+ "151662": {
158
+ "content": "<|fim_pad|>",
159
+ "lstrip": false,
160
+ "normalized": false,
161
+ "rstrip": false,
162
+ "single_word": false,
163
+ "special": false
164
+ },
165
+ "151663": {
166
+ "content": "<|repo_name|>",
167
+ "lstrip": false,
168
+ "normalized": false,
169
+ "rstrip": false,
170
+ "single_word": false,
171
+ "special": false
172
+ },
173
+ "151664": {
174
+ "content": "<|file_sep|>",
175
+ "lstrip": false,
176
+ "normalized": false,
177
+ "rstrip": false,
178
+ "single_word": false,
179
+ "special": false
180
+ },
181
+ "151665": {
182
+ "content": "<tool_response>",
183
+ "lstrip": false,
184
+ "normalized": false,
185
+ "rstrip": false,
186
+ "single_word": false,
187
+ "special": false
188
+ },
189
+ "151666": {
190
+ "content": "</tool_response>",
191
+ "lstrip": false,
192
+ "normalized": false,
193
+ "rstrip": false,
194
+ "single_word": false,
195
+ "special": false
196
+ },
197
+ "151667": {
198
+ "content": "<think>",
199
+ "lstrip": false,
200
+ "normalized": true,
201
+ "rstrip": false,
202
+ "single_word": false,
203
+ "special": false
204
+ },
205
+ "151668": {
206
+ "content": "</think>",
207
+ "lstrip": false,
208
+ "normalized": true,
209
+ "rstrip": false,
210
+ "single_word": false,
211
+ "special": false
212
+ },
213
+ "151669": {
214
+ "content": "<image>",
215
+ "lstrip": false,
216
+ "normalized": false,
217
+ "rstrip": false,
218
+ "single_word": false,
219
+ "special": true
220
+ },
221
+ "151670": {
222
+ "content": "<|stream_start|>",
223
+ "lstrip": false,
224
+ "normalized": false,
225
+ "rstrip": false,
226
+ "single_word": false,
227
+ "special": true
228
+ },
229
+ "151671": {
230
+ "content": "<|stream_end|>",
231
+ "lstrip": false,
232
+ "normalized": false,
233
+ "rstrip": false,
234
+ "single_word": false,
235
+ "special": true
236
+ },
237
+ "151672": {
238
+ "content": "<|audio|>",
239
+ "lstrip": false,
240
+ "normalized": false,
241
+ "rstrip": false,
242
+ "single_word": false,
243
+ "special": true
244
+ },
245
+ "151673": {
246
+ "content": "<|audio_start|>",
247
+ "lstrip": false,
248
+ "normalized": false,
249
+ "rstrip": false,
250
+ "single_word": false,
251
+ "special": true
252
+ },
253
+ "151674": {
254
+ "content": "<|audio_end|>",
255
+ "lstrip": false,
256
+ "normalized": false,
257
+ "rstrip": false,
258
+ "single_word": false,
259
+ "special": true
260
+ }
261
+ },
262
+ "additional_special_tokens": [
263
+ "<|im_start|>",
264
+ "<|im_end|>",
265
+ "<|object_ref_start|>",
266
+ "<|object_ref_end|>",
267
+ "<|box_start|>",
268
+ "<|box_end|>",
269
+ "<|quad_start|>",
270
+ "<|quad_end|>",
271
+ "<|vision_start|>",
272
+ "<|vision_end|>",
273
+ "<|vision_pad|>",
274
+ "<|image_pad|>",
275
+ "<|video_pad|>"
276
+ ],
277
+ "bos_token": null,
278
+ "chat_template": "\n{%- set identifier = 'im' %}\n{% for message in messages %}\n {% if message['role'] == 'stream' %}\n {% set identifier = 'stream' %}\n {% else %}\n {% set identifier = 'im' %}\n {% endif %}\n {% if message['role'] is not none %}\n {{- '<|' + identifier + '_start|>' + message['role'] + '\n' -}}\n {% endif %}\n {% if message['content'] is string %}\n {{- message['content'] + '<|' + identifier + '_end|>\n' -}}\n {% else %}\n {% for content in message['content'] %}\n {% if content['type'] == 'image' or 'image' in content or 'image_url' in content %}\n {% if 'time' in content %}\n {{- 'Time ' + content['time'] | round(1) | string + 's: ' -}}\n {% endif %}\n {{- image_token + '\n' -}}\n {% elif content['type'] == 'video' or 'video' in content or 'video_url' in content %}\n {% for i in range(content['num_frames']) %}\n {% if 'timestamps' in content %}\n {{- 'Time ' + content['timestamps'][i] | round(1) | string + 's:' -}}\n {% endif %}\n {% if i < content['num_frames'] - 1 %}\n {{- image_token + ',' -}}\n {% if 'audio_split' in content and content['audio_split'][i] > 0 %}\n {{- '<|audio_start|>' + audio_token * content['audio_split'][i] + '<|audio_end|>,' -}}\n {% endif %}\n {% else %}\n {{- image_token -}}\n {% if 'audio_split' in content and content['audio_split'][i] > 0 %}\n {{- ',<|audio_start|>' + audio_token * content['audio_split'][i] + '<|audio_end|>\n' -}}\n {% else %}\n {{- '\n' -}}\n {% endif %}\n {% endif %}\n {% endfor %}\n {% elif content['type'] == 'audio' or 'audio' in content or 'audio_url' in content %}\n {% for i in range(content['num_frames']) %}\n {% if 'timestamps' in content %}\n {{- 'Time ' + content['timestamps'][i] | round(1) | string + 's:' -}}\n {% endif %}\n {% if i < content['num_frames'] - 1 %}\n {{- '<|audio_start|>' + audio_token + '<|audio_end|>,' -}}\n {% else %}\n {{- '<|audio_start|>' + audio_token + '<|audio_end|>\n' -}}\n {% endif %}\n {% endfor %}\n {% elif content['type'] == 'text' or 'text' in content %}\n 
{{- content['text'] -}}\n {% endif %}\n {% endfor %}\n {% if message['role'] is not none %}\n {{- '<|' + identifier + '_end|>\n' -}}\n {% endif %}\n {% endif %}\n{% endfor %}\n{% if add_generation_prompt %}\n {{- '<|im_start|>assistant\n' -}}\n {% if not add_think_prompt %}\n {{- '<think>\n\n</think>\n\n' -}}\n {% endif %}\n{% endif %}\n",
279
+ "clean_up_tokenization_spaces": false,
280
+ "eos_token": "<|im_end|>",
281
+ "errors": "replace",
282
+ "extra_special_tokens": {},
283
+ "model_max_length": 16384,
284
+ "pad_token": "<|endoftext|>",
285
+ "padding_side": "right",
286
+ "processor_class": "Videollama3Qwen3Processor",
287
+ "split_special_tokens": false,
288
+ "tokenizer_class": "Qwen2Tokenizer",
289
+ "unk_token": null
290
+ }
trainer_state.json ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:eeae0aa430b05cbcf2d1820d2fd7421aead1e6cbb402df12048617552171fa33
3
+ size 16188911
vocab.json ADDED
The diff for this file is too large to render. See raw diff