Flourish committed on
Commit 44cb7bc · verified · 1 Parent(s): f2c6f36

Upload 14 files

ovis_image/__init__.py ADDED
@@ -0,0 +1,44 @@
# Copyright (C) 2025 AIDC-AI
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License.
# You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.
from ovis_image.model.args import OvisImageModelArgs
from ovis_image.model.autoencoder import AutoEncoderParams
from ovis_image.model.model import OvisImageModel

__all__ = [
    "OvisImageModelArgs",
    "OvisImageModel",
    "ovis_image_configs",
]


ovis_image_configs = {
    "ovis-image-7b": OvisImageModelArgs(
        in_channels=64,
        out_channels=64,
        context_in_dim=2048,
        hidden_size=3072,
        mlp_ratio=4.0,
        num_heads=24,
        depth=6,
        double_block_type="DoubleStreamBlock",
        depth_single_blocks=27,
        axes_dim=(16, 56, 56),
        theta=10_000,
        qkv_bias=True,
        activation="swiglu",
        autoencoder_params=AutoEncoderParams(
            resolution=256,
            in_channels=3,
            ch=128,
            out_ch=3,
            ch_mult=(1, 2, 4, 4),
            num_res_blocks=2,
            z_channels=16,
            scale_factor=0.3611,
            shift_factor=0.1159,
        ),
    ),
}
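A minimal usage sketch for the registry above (an illustration, not part of the committed files): look up the released hyperparameters by name and build the DiT backbone. Note this constructs the full ~7B-parameter network with random weights, so in float32 it needs on the order of 30 GB of host memory; checkpoints are loaded separately.

from ovis_image import OvisImageModel, ovis_image_configs

args = ovis_image_configs["ovis-image-7b"]
model = OvisImageModel(args)                        # random weights, architecture only
print(sum(p.numel() for p in model.parameters()))   # roughly 7e9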
ovis_image/dataset/image_util.py ADDED
@@ -0,0 +1,125 @@
# Copyright (C) 2025 AIDC-AI
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License.
# You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.
import math
import torch
import torchvision
from torchvision import transforms
from einops import rearrange, repeat


def ceil_to(x, factor=16):
    return math.ceil(float(x) / factor) * factor


def build_img_ids(
    latent_height,
    latent_width,
    latent_crop_height=None,
    latent_crop_width=None,
    time=0,
):
    if latent_crop_height is None:
        latent_crop_height = latent_height
    if latent_crop_width is None:
        latent_crop_width = latent_width
    img_ids = torch.zeros(latent_height, latent_width, 3)
    img_ids[..., 1] = img_ids[..., 1] + torch.arange(latent_height)[:, None]
    img_ids[..., 2] = img_ids[..., 2] + torch.arange(latent_width)[None, :]
    # center-crop the id grid
    crop_h = (latent_height - latent_crop_height) // 2
    crop_w = (latent_width - latent_crop_width) // 2
    img_ids = img_ids[crop_h:crop_h + latent_crop_height, crop_w:crop_w + latent_crop_width]
    img_ids[..., 0] = time
    h, w, c = img_ids.shape
    img_ids = img_ids.reshape(h * w, c)
    return img_ids


def process_pil_img_to_tensor(
    pil_img,
    output_size: int | None = 256,
    output_width: int | None = None,
    output_height: int | None = None,
    with_position_ids: bool = False,
    position_ids_time: int = 0,
):
    width, height = pil_img.size
    if output_width is None or output_height is None:
        output_width = output_size
        output_height = output_size
    assert output_height % 16 == 0
    assert output_width % 16 == 0
    resize_ratio = max(
        float(output_width) / width,
        float(output_height) / height,
    )
    resize_size = (
        ceil_to(resize_ratio * height, 16),
        ceil_to(resize_ratio * width, 16),
    )
    pil_resize_img = torchvision.transforms.functional.resize(
        pil_img, resize_size, interpolation=transforms.InterpolationMode.BICUBIC
    )
    pil_crop_img = torchvision.transforms.functional.center_crop(
        pil_resize_img, (output_height, output_width)
    )
    image_tensor = torchvision.transforms.functional.to_tensor(pil_crop_img)
    image_tensor = torchvision.transforms.functional.normalize(
        image_tensor, mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]
    )
    if with_position_ids:
        img_ids = build_img_ids(
            latent_height=resize_size[0] // 16,
            latent_width=resize_size[1] // 16,
            latent_crop_height=output_height // 16,
            latent_crop_width=output_width // 16,
            time=position_ids_time,
        )
    else:
        img_ids = None
    return pil_crop_img, image_tensor, img_ids


def pack_latent_to_token(
    latent,
):
    token = rearrange(
        latent,
        "b c (h ph) (w pw) -> b (h w) (c ph pw)",
        ph=2,
        pw=2,
    )
    return token


def unpack_token_to_latent(
    token,
    image_height: int | None = None,
    latent_height: int | None = None,
    image_width: int | None = None,
    latent_width: int | None = None,
):
    if image_height is not None:
        h = math.ceil(image_height / 16)
    elif latent_height is not None:
        h = latent_height // 2
    else:
        raise ValueError("both image_height and latent_height are None")
    if image_width is not None:
        w = math.ceil(image_width / 16)
    elif latent_width is not None:
        w = latent_width // 2
    else:
        raise ValueError("both image_width and latent_width are None")
    return rearrange(
        token,
        "b (h w) (c ph pw) -> b c (h ph) (w pw)",
        h=h,
        w=w,
        ph=2,
        pw=2,
    )
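pack_latent_to_token groups each 2x2 patch of the VAE latent into one token, and unpack_token_to_latent is its exact inverse. A small round-trip sketch (illustrative only):

import torch
from ovis_image.dataset.image_util import pack_latent_to_token, unpack_token_to_latent

latent = torch.randn(1, 16, 32, 32)        # latent of a 256x256 image (the VAE downsamples by 8)
token = pack_latent_to_token(latent)       # -> (1, 16*16, 16*2*2) = (1, 256, 64)
restored = unpack_token_to_latent(token, latent_height=32, latent_width=32)
assert torch.equal(latent, restored)       # 2x2 packing is lossless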
ovis_image/model/args.py ADDED
@@ -0,0 +1,29 @@
# Copyright (C) 2025 AIDC-AI
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License.
# You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.

from dataclasses import dataclass, field

from ovis_image.model.autoencoder import AutoEncoderParams


@dataclass
class OvisImageModelArgs:
    in_channels: int = 64
    out_channels: int = 64
    context_in_dim: int = 512
    hidden_size: int = 3072
    mlp_ratio: float = 4.0
    num_heads: int = 24
    depth: int = 19
    double_block_type: str = "DoubleStreamBlock"
    depth_single_blocks: int = 38
    axes_dim: tuple = (16, 56, 56)
    theta: int = 10_000
    qkv_bias: bool = True
    activation: str = "gelu_tanh"
    """activation: gelu_tanh or swiglu"""
    norm: str = "layernorm"
    """norm: layernorm or rmsnorm"""
    autoencoder_params: AutoEncoderParams = field(default_factory=AutoEncoderParams)
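The defaults above are internally consistent with the rotary embedding setup: OvisImageModel.__init__ (in model.py below) requires sum(axes_dim) to equal the per-head dimension. A quick check of that arithmetic:

from ovis_image.model.args import OvisImageModelArgs

args = OvisImageModelArgs()
# 16 + 56 + 56 == 128 == 3072 // 24
assert sum(args.axes_dim) == args.hidden_size // args.num_heads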
ovis_image/model/autoencoder.py ADDED
@@ -0,0 +1,402 @@
# Copyright (C) 2025 AIDC-AI
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License.
# You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.

import os
from dataclasses import dataclass

import torch
from einops import rearrange
from safetensors.torch import load_file as load_sft
from torch import nn, Tensor


@dataclass
class AutoEncoderParams:
    resolution: int = 256
    in_channels: int = 3
    ch: int = 128
    out_ch: int = 3
    ch_mult: tuple[int, ...] = (1, 2, 4, 4)
    num_res_blocks: int = 2
    z_channels: int = 16
    scale_factor: float = 0.3611
    shift_factor: float = 0.1159
    use_quant_conv: bool = False
    use_post_quant_conv: bool = False


def swish(x: Tensor) -> Tensor:
    return x * torch.sigmoid(x)


class AttnBlock(nn.Module):
    def __init__(self, in_channels: int):
        super().__init__()
        self.in_channels = in_channels

        self.norm = nn.GroupNorm(
            num_groups=32, num_channels=in_channels, eps=1e-6, affine=True
        )

        self.q = nn.Conv2d(in_channels, in_channels, kernel_size=1)
        self.k = nn.Conv2d(in_channels, in_channels, kernel_size=1)
        self.v = nn.Conv2d(in_channels, in_channels, kernel_size=1)
        self.proj_out = nn.Conv2d(in_channels, in_channels, kernel_size=1)

    def attention(self, h_: Tensor) -> Tensor:
        h_ = self.norm(h_)
        q = self.q(h_)
        k = self.k(h_)
        v = self.v(h_)

        b, c, h, w = q.shape
        q = rearrange(q, "b c h w -> b 1 (h w) c").contiguous()
        k = rearrange(k, "b c h w -> b 1 (h w) c").contiguous()
        v = rearrange(v, "b c h w -> b 1 (h w) c").contiguous()
        h_ = nn.functional.scaled_dot_product_attention(q, k, v)
        return rearrange(h_, "b 1 (h w) c -> b c h w", h=h, w=w, c=c, b=b)

    def forward(self, x: Tensor) -> Tensor:
        return x + self.proj_out(self.attention(x))


class ResnetBlock(nn.Module):
    def __init__(self, in_channels: int, out_channels: int):
        super().__init__()
        self.in_channels = in_channels
        out_channels = in_channels if out_channels is None else out_channels
        self.out_channels = out_channels

        self.norm1 = nn.GroupNorm(
            num_groups=32, num_channels=in_channels, eps=1e-6, affine=True
        )
        self.conv1 = nn.Conv2d(
            in_channels, out_channels, kernel_size=3, stride=1, padding=1
        )
        self.norm2 = nn.GroupNorm(
            num_groups=32, num_channels=out_channels, eps=1e-6, affine=True
        )
        self.conv2 = nn.Conv2d(
            out_channels, out_channels, kernel_size=3, stride=1, padding=1
        )
        if self.in_channels != self.out_channels:
            self.nin_shortcut = nn.Conv2d(
                in_channels, out_channels, kernel_size=1, stride=1, padding=0
            )

    def forward(self, x):
        h = x
        h = self.norm1(h)
        h = swish(h)
        h = self.conv1(h)

        h = self.norm2(h)
        h = swish(h)
        h = self.conv2(h)

        if self.in_channels != self.out_channels:
            x = self.nin_shortcut(x)

        return x + h


class Downsample(nn.Module):
    def __init__(self, in_channels: int):
        super().__init__()
        # no asymmetric padding in torch conv, must do it ourselves
        self.conv = nn.Conv2d(
            in_channels, in_channels, kernel_size=3, stride=2, padding=0
        )

    def forward(self, x: Tensor):
        pad = (0, 1, 0, 1)
        x = nn.functional.pad(x, pad, mode="constant", value=0)
        x = self.conv(x)
        return x


class Upsample(nn.Module):
    def __init__(self, in_channels: int):
        super().__init__()
        self.conv = nn.Conv2d(
            in_channels, in_channels, kernel_size=3, stride=1, padding=1
        )

    def forward(self, x: Tensor):
        x = nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
        x = self.conv(x)
        return x


class Encoder(nn.Module):
    def __init__(
        self,
        resolution: int,
        in_channels: int,
        ch: int,
        ch_mult: list[int],
        num_res_blocks: int,
        z_channels: int,
    ):
        super().__init__()
        self.ch = ch
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution
        self.in_channels = in_channels
        # downsampling
        self.conv_in = nn.Conv2d(
            in_channels, self.ch, kernel_size=3, stride=1, padding=1
        )

        curr_res = resolution
        in_ch_mult = (1,) + tuple(ch_mult)
        self.in_ch_mult = in_ch_mult
        self.down = nn.ModuleList()
        block_in = self.ch
        for i_level in range(self.num_resolutions):
            block = nn.ModuleList()
            attn = nn.ModuleList()
            block_in = ch * in_ch_mult[i_level]
            block_out = ch * ch_mult[i_level]
            for _ in range(self.num_res_blocks):
                block.append(ResnetBlock(in_channels=block_in, out_channels=block_out))
                block_in = block_out
            down = nn.Module()
            down.block = block
            down.attn = attn
            if i_level != self.num_resolutions - 1:
                down.downsample = Downsample(block_in)
                curr_res = curr_res // 2
            self.down.append(down)

        # middle
        self.mid = nn.Module()
        self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in)
        self.mid.attn_1 = AttnBlock(block_in)
        self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in)

        # end
        self.norm_out = nn.GroupNorm(
            num_groups=32, num_channels=block_in, eps=1e-6, affine=True
        )
        self.conv_out = nn.Conv2d(
            block_in, 2 * z_channels, kernel_size=3, stride=1, padding=1
        )

    def forward(self, x: Tensor) -> Tensor:
        # downsampling
        hs = [self.conv_in(x)]
        for i_level in range(self.num_resolutions):
            for i_block in range(self.num_res_blocks):
                h = self.down[i_level].block[i_block](hs[-1])
                if len(self.down[i_level].attn) > 0:
                    h = self.down[i_level].attn[i_block](h)
                hs.append(h)
            if i_level != self.num_resolutions - 1:
                hs.append(self.down[i_level].downsample(hs[-1]))

        # middle
        h = hs[-1]
        h = self.mid.block_1(h)
        h = self.mid.attn_1(h)
        h = self.mid.block_2(h)
        # end
        h = self.norm_out(h)
        h = swish(h)
        h = self.conv_out(h)
        return h


class Decoder(nn.Module):
    def __init__(
        self,
        ch: int,
        out_ch: int,
        ch_mult: list[int],
        num_res_blocks: int,
        in_channels: int,
        resolution: int,
        z_channels: int,
    ):
        super().__init__()
        self.ch = ch
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution
        self.in_channels = in_channels
        self.ffactor = 2 ** (self.num_resolutions - 1)

        # compute in_ch_mult, block_in and curr_res at lowest res
        block_in = ch * ch_mult[self.num_resolutions - 1]
        curr_res = resolution // 2 ** (self.num_resolutions - 1)
        self.z_shape = (1, z_channels, curr_res, curr_res)

        # z to block_in
        self.conv_in = nn.Conv2d(
            z_channels, block_in, kernel_size=3, stride=1, padding=1
        )

        # middle
        self.mid = nn.Module()
        self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in)
        self.mid.attn_1 = AttnBlock(block_in)
        self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in)

        # upsampling
        self.up = nn.ModuleList()
        for i_level in reversed(range(self.num_resolutions)):
            block = nn.ModuleList()
            attn = nn.ModuleList()
            block_out = ch * ch_mult[i_level]
            for _ in range(self.num_res_blocks + 1):
                block.append(ResnetBlock(in_channels=block_in, out_channels=block_out))
                block_in = block_out
            up = nn.Module()
            up.block = block
            up.attn = attn
            if i_level != 0:
                up.upsample = Upsample(block_in)
                curr_res = curr_res * 2
            self.up.insert(0, up)  # prepend to get consistent order

        # end
        self.norm_out = nn.GroupNorm(
            num_groups=32, num_channels=block_in, eps=1e-6, affine=True
        )
        self.conv_out = nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1)

    def forward(self, z: Tensor) -> Tensor:
        # get dtype for proper tracing
        upscale_dtype = next(self.up.parameters()).dtype

        # z to block_in
        h = self.conv_in(z)

        # middle
        h = self.mid.block_1(h)
        h = self.mid.attn_1(h)
        h = self.mid.block_2(h)

        # cast to proper dtype
        h = h.to(upscale_dtype)
        # upsampling
        for i_level in reversed(range(self.num_resolutions)):
            for i_block in range(self.num_res_blocks + 1):
                h = self.up[i_level].block[i_block](h)
                if len(self.up[i_level].attn) > 0:
                    h = self.up[i_level].attn[i_block](h)
            if i_level != 0:
                h = self.up[i_level].upsample(h)

        # end
        h = self.norm_out(h)
        h = swish(h)
        h = self.conv_out(h)
        return h


class DiagonalGaussian(nn.Module):
    def __init__(self, sample: bool = True, chunk_dim: int = 1):
        super().__init__()
        self.sample = sample
        self.chunk_dim = chunk_dim

    def forward(self, z: Tensor) -> Tensor:
        mean, logvar = torch.chunk(z, 2, dim=self.chunk_dim)
        if self.sample:
            std = torch.exp(0.5 * logvar)
            return mean + std * torch.randn_like(mean)
        else:
            return mean


class AutoEncoder(nn.Module):
    def __init__(self, params: AutoEncoderParams):
        super().__init__()
        self.params = params
        self.encoder = Encoder(
            resolution=params.resolution,
            in_channels=params.in_channels,
            ch=params.ch,
            ch_mult=params.ch_mult,
            num_res_blocks=params.num_res_blocks,
            z_channels=params.z_channels,
        )
        self.decoder = Decoder(
            resolution=params.resolution,
            in_channels=params.in_channels,
            ch=params.ch,
            out_ch=params.out_ch,
            ch_mult=params.ch_mult,
            num_res_blocks=params.num_res_blocks,
            z_channels=params.z_channels,
        )
        self.reg = DiagonalGaussian()

        self.scale_factor = params.scale_factor
        self.shift_factor = params.shift_factor

        self.quant_conv = nn.Conv2d(2 * params.z_channels, 2 * params.z_channels, 1) if params.use_quant_conv else None
        self.post_quant_conv = nn.Conv2d(params.z_channels, params.z_channels, 1) if params.use_post_quant_conv else None

    def encode(self, x: Tensor) -> Tensor:
        x = self.encoder(x)
        if self.quant_conv is not None:
            x = self.quant_conv(x)
        z = self.reg(x)
        z = self.scale_factor * (z - self.shift_factor)
        return z

    def decode(self, z: Tensor) -> Tensor:
        z = z / self.scale_factor + self.shift_factor
        if self.post_quant_conv is not None:
            z = self.post_quant_conv(z)
        return self.decoder(z)

    def forward(self, x: Tensor) -> Tensor:
        return self.decode(self.encode(x))


def load_ae(
    ckpt_path: str,
    autoencoder_params: AutoEncoderParams,
    device: str | torch.device = "cuda",
    dtype=torch.bfloat16,
    random_init=False,
) -> AutoEncoder:
    """
    Load the autoencoder from the given checkpoint.
    Args:
        ckpt_path (str): Path to the autoencoder checkpoint (safetensors).
        autoencoder_params (AutoEncoderParams): Architecture hyperparameters.
        device (str or torch.device): The device to load the autoencoder to.
        dtype: Target parameter dtype.
        random_init (bool): Skip weight loading and return a randomly initialized model.
    Returns:
        AutoEncoder: The loaded autoencoder.
    """
    with torch.device(device):
        ae = AutoEncoder(autoencoder_params)

    if random_init:
        print("Random Init VAE")
        return ae.to(dtype=dtype)

    if not os.path.exists(ckpt_path):
        raise ValueError(
            f"Autoencoder path {ckpt_path} does not exist. Please download it first."
        )

    print(f"Loading {ckpt_path}")
    sd = load_sft(ckpt_path, device=str(device))
    missing, unexpected = ae.load_state_dict(sd, strict=False, assign=True)
    if len(missing) > 0:
        print(f"Got {len(missing)} missing keys:\n\t" + "\n\t".join(missing))
    if len(unexpected) > 0:
        print(
            f"Got {len(unexpected)} unexpected keys:\n\t" + "\n\t".join(unexpected)
        )
    return ae.to(dtype=dtype)
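A shape sketch for the autoencoder above, using random weights (illustrative; real use goes through load_ae with a checkpoint). With ch_mult = (1, 2, 4, 4) the encoder downsamples three times, so a 256x256 RGB image maps to a 16-channel 32x32 latent:

import torch
from ovis_image.model.autoencoder import AutoEncoder, AutoEncoderParams

ae = AutoEncoder(AutoEncoderParams())
img = torch.randn(1, 3, 256, 256)
z = ae.encode(img)              # DiagonalGaussian halves the 2 * z_channels encoder output
print(z.shape)                  # torch.Size([1, 16, 32, 32])
print(ae.decode(z).shape)       # torch.Size([1, 3, 256, 256])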
ovis_image/model/hf_embedder.py ADDED
@@ -0,0 +1,50 @@
# Copyright (C) 2025 AIDC-AI
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License.
# You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.

import torch
from torch import nn, Tensor

from ovis_image.model.ovis.modeling_ovis2_5 import Ovis2_5, Ovis2_5_Config


class OvisEmbedder(nn.Module):
    def __init__(
        self,
        model_path: str,
        random_init=False,
        **hf_kwargs,
    ):
        super().__init__()
        if random_init:
            # Initialize Ovis model with random weights, for test purposes only
            config = Ovis2_5_Config.from_pretrained(model_path)
            config.name_or_path = model_path
            self.hf_module = Ovis2_5._from_config(config, **hf_kwargs)
        else:
            self.hf_module = Ovis2_5.from_pretrained(
                model_path, **hf_kwargs
            )
        self.pad_token_id = self.hf_module.text_tokenizer.pad_token_id
        # embeddings before this index (the fixed prompt prefix) are dropped in forward
        self.user_prompt_begin_id = 28
        # keep only the underlying Qwen3 text backbone
        self.hf_module = self.hf_module.llm.model
        self.hf_module = self.hf_module.eval().requires_grad_(False)

    def forward(self, batch_tokens: Tensor, attention_mask=None) -> Tensor:
        if attention_mask is None:
            attention_mask = torch.ne(
                batch_tokens, self.pad_token_id
            ).to(device=batch_tokens.device)
        outputs = self.hf_module(
            input_ids=batch_tokens,
            attention_mask=attention_mask,
        )
        txt_semantic_embed = outputs.last_hidden_state
        txt_semantic_embed = txt_semantic_embed * attention_mask[..., None]
        txt_semantic_embed = txt_semantic_embed[:, self.user_prompt_begin_id:, :]
        return txt_semantic_embed
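A usage sketch for the embedder (hypothetical: "path/to/Ovis2.5" stands in for a local Ovis2.5 checkpoint directory, which is required even with random_init=True because the config and tokenizer are read from it):

import torch
from ovis_image.model.hf_embedder import OvisEmbedder

embedder = OvisEmbedder("path/to/Ovis2.5", random_init=True, torch_dtype=torch.bfloat16)
tokens = torch.randint(0, 1000, (1, 64))    # dummy prompt token ids
txt = embedder(tokens)                       # (1, 64 - 28, hidden): the prefix tokens are dropped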
ovis_image/model/layers.py ADDED
@@ -0,0 +1,407 @@
# Copyright (C) 2025 AIDC-AI
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License.
# You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.

import math
from dataclasses import dataclass

import torch
from einops import rearrange
from torch import nn, Tensor

from ovis_image.model.ops import attention, rope


class EmbedND(nn.Module):
    def __init__(self, dim: int, theta: int, axes_dim: list[int]):
        super().__init__()
        self.dim = dim
        self.theta = theta
        self.axes_dim = axes_dim

    @torch.no_grad()
    def forward(self, ids: Tensor) -> Tensor:
        n_axes = ids.shape[-1]
        emb = torch.cat(
            [rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(n_axes)],
            dim=-3,
        )
        # bs x 1 x 512 x 64 x 2 x 2
        return emb.unsqueeze(1)


def timestep_embedding(t: Tensor, dim, max_period=10000, time_factor: float = 1000.0):
    """
    Create sinusoidal timestep embeddings.
    :param t: a 1-D Tensor of N indices, one per batch element.
        These may be fractional.
    :param dim: the dimension of the output.
    :param max_period: controls the minimum frequency of the embeddings.
    :return: an (N, D) Tensor of positional embeddings.
    """
    t = time_factor * t
    half = dim // 2
    with torch.device(t.device):
        freqs = torch.exp(
            -math.log(max_period)
            * torch.arange(start=0, end=half, dtype=torch.float32)
            / half
        )

    args = t[:, None].float() * freqs[None]
    embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
    if dim % 2:
        embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
    if torch.is_floating_point(t):
        embedding = embedding.to(t)
    return embedding


class MLPEmbedder(nn.Module):
    def __init__(self, in_dim: int, hidden_dim: int):
        super().__init__()
        self.in_layer = nn.Linear(in_dim, hidden_dim, bias=True)
        self.silu = nn.SiLU()
        self.out_layer = nn.Linear(hidden_dim, hidden_dim, bias=True)

    def init_weights(self, init_std: float = 0.02):
        nn.init.normal_(self.in_layer.weight, std=init_std)
        nn.init.constant_(self.in_layer.bias, 0)
        nn.init.normal_(self.out_layer.weight, std=init_std)
        nn.init.constant_(self.out_layer.bias, 0)

    def forward(self, x: Tensor) -> Tensor:
        return self.out_layer(self.silu(self.in_layer(x)))


class QKNorm(torch.nn.Module):
    def __init__(self, dim: int):
        super().__init__()
        self.query_norm = nn.RMSNorm(dim)
        self.key_norm = nn.RMSNorm(dim)

    def init_weights(self):
        self.query_norm.reset_parameters()
        self.key_norm.reset_parameters()

    def forward(self, q: Tensor, k: Tensor, v: Tensor) -> tuple[Tensor, Tensor]:
        q = self.query_norm(q)
        k = self.key_norm(k)
        return q.to(v), k.to(v)


class SelfAttention(nn.Module):
    def __init__(self, dim: int, num_heads: int = 8, qkv_bias: bool = False):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.norm = QKNorm(head_dim)
        self.proj = nn.Linear(dim, dim)

    def init_weights(self):
        for layer in (self.qkv, self.proj):
            nn.init.xavier_uniform_(layer.weight)
            if layer.bias is not None:
                nn.init.constant_(layer.bias, 0)
        self.norm.init_weights()

    def forward(self, x: Tensor, pe: Tensor) -> Tensor:
        qkv = self.qkv(x)
        q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
        q, k = self.norm(q, k, v)
        x = attention(q, k, v, pe=pe)
        x = self.proj(x)
        return x


class YakMLP(nn.Module):
    # Use SwiGLU
    def __init__(self, hidden_size: int, intermediate_size: int):
        super().__init__()
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=True)
        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=True)
        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=True)
        self.act_fn = nn.SiLU()

    def init_weights(self):
        for layer in (self.gate_proj, self.up_proj, self.down_proj):
            nn.init.xavier_uniform_(layer.weight)
            nn.init.constant_(layer.bias, 0)

    def forward(self, x: Tensor) -> Tensor:
        down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
        return down_proj


def build_mlp(hidden_size, intermediate_size, activation="gelu_tanh"):
    if activation == "gelu_tanh":
        mlp = nn.Sequential(
            nn.Linear(hidden_size, intermediate_size, bias=True),
            nn.GELU(approximate="tanh"),
            nn.Linear(intermediate_size, hidden_size, bias=True),
        )
    else:
        mlp = YakMLP(hidden_size, intermediate_size)
    return mlp


def init_mlp(mlp, activation="gelu_tanh"):
    if activation == "gelu_tanh":
        for layer in (mlp[0], mlp[2]):
            nn.init.xavier_uniform_(layer.weight)
            nn.init.constant_(layer.bias, 0)
    else:
        mlp.init_weights()


@dataclass
class ModulationOut:
    shift: Tensor
    scale: Tensor
    gate: Tensor


class Modulation(nn.Module):
    def __init__(self, dim: int, multiples: int = 1):
        super().__init__()
        assert multiples in [1, 2, 3]
        self.multiples = multiples
        self.multiplier = 3 * multiples
        self.lin = nn.Linear(dim, self.multiplier * dim, bias=True)
        self.act = nn.SiLU()

    def init_weights(self):
        nn.init.constant_(self.lin.weight, 0)
        nn.init.constant_(self.lin.bias, 0)

    def forward(self, vec: Tensor):
        out = self.lin(self.act(vec))[:, None, :].chunk(
            self.multiplier, dim=-1
        )
        if self.multiples == 1:
            return ModulationOut(*out[:3])
        elif self.multiples == 2:
            return (
                ModulationOut(*out[:3]),
                ModulationOut(*out[3:]),
            )
        elif self.multiples == 3:
            return (
                ModulationOut(*out[:3]),
                ModulationOut(*out[3:6]),
                ModulationOut(*out[6:]),
            )


class DoubleStreamBlock(nn.Module):
    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        mlp_ratio: float,
        qkv_bias: bool = False,
        activation: str = "gelu_tanh",
        norm_layer: nn.Module = nn.LayerNorm,
    ):
        super().__init__()

        mlp_hidden_dim = int(hidden_size * mlp_ratio)
        self.num_heads = num_heads
        self.hidden_size = hidden_size
        self.activation = activation
        self.img_mod = Modulation(hidden_size, multiples=2)
        self.img_norm1 = norm_layer(hidden_size, elementwise_affine=False, eps=1e-6)
        self.img_attn = SelfAttention(
            dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias
        )

        self.img_norm2 = norm_layer(hidden_size, elementwise_affine=False, eps=1e-6)
        self.img_mlp = build_mlp(hidden_size, mlp_hidden_dim, activation)

        self.txt_mod = Modulation(hidden_size, multiples=2)
        self.txt_norm1 = norm_layer(hidden_size, elementwise_affine=False, eps=1e-6)
        self.txt_attn = SelfAttention(
            dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias
        )

        self.txt_norm2 = norm_layer(hidden_size, elementwise_affine=False, eps=1e-6)
        self.txt_mlp = build_mlp(hidden_size, mlp_hidden_dim, activation)

    def init_weights(self):
        # initialize all the nn.Linear submodules
        init_mlp(self.img_mlp, self.activation)
        init_mlp(self.txt_mlp, self.activation)

        # initialize Modulation layers, SelfAttention layers
        for layer in (self.img_attn, self.img_mod, self.txt_attn, self.txt_mod):
            layer.init_weights()

        # Reset parameters for Normalization layers
        for norm in (self.txt_norm1, self.txt_norm2, self.img_norm1, self.img_norm2):
            norm.reset_parameters()

    def forward(
        self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor
    ) -> tuple[Tensor, Tensor]:
        img_mod1, img_mod2 = self.img_mod(vec)
        txt_mod1, txt_mod2 = self.txt_mod(vec)

        # prepare image for attention
        img_modulated = self.img_norm1(img)
        img_modulated = (1 + img_mod1.scale) * img_modulated + img_mod1.shift
        img_qkv = self.img_attn.qkv(img_modulated)
        img_q, img_k, img_v = rearrange(
            img_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads
        )
        img_q, img_k = self.img_attn.norm(img_q, img_k, img_v)

        # prepare txt for attention
        txt_modulated = self.txt_norm1(txt)
        txt_modulated = (1 + txt_mod1.scale) * txt_modulated + txt_mod1.shift
        txt_qkv = self.txt_attn.qkv(txt_modulated)
        txt_q, txt_k, txt_v = rearrange(
            txt_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads
        )
        txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v)

        # run actual attention
        q = torch.cat((txt_q, img_q), dim=2)
        k = torch.cat((txt_k, img_k), dim=2)
        v = torch.cat((txt_v, img_v), dim=2)

        attn = attention(q, k, v, pe=pe)
        txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1] :]

        # calculate the img blocks
        img = img + img_mod1.gate * self.img_attn.proj(img_attn)
        img = img + img_mod2.gate * self.img_mlp(
            (1 + img_mod2.scale) * self.img_norm2(img) + img_mod2.shift
        )

        # calculate the txt blocks
        txt = txt + txt_mod1.gate * self.txt_attn.proj(txt_attn)
        txt = txt + txt_mod2.gate * self.txt_mlp(
            (1 + txt_mod2.scale) * self.txt_norm2(txt) + txt_mod2.shift
        )
        return img, txt


class SingleStreamBlock(nn.Module):
    """
    A DiT block with parallel linear layers as described in
    https://arxiv.org/abs/2302.05442 and adapted modulation interface.
    """

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
        qkv_bias: bool = False,
        qk_scale: float | None = None,
        activation: str = "gelu_tanh",
        norm_layer: nn.Module = nn.LayerNorm,
    ):
        super().__init__()
        self.hidden_dim = hidden_size
        self.num_heads = num_heads
        head_dim = hidden_size // num_heads
        self.scale = qk_scale or head_dim**-0.5
        self.activation = activation

        self.mlp_hidden_dim = int(hidden_size * mlp_ratio)
        if activation == "gelu_tanh":
            # qkv and mlp_in
            self.linear1 = nn.Linear(hidden_size, hidden_size * 3 + self.mlp_hidden_dim, bias=qkv_bias)
        else:
            # qkv and mlp_in and mlp_gate
            self.linear1 = nn.Linear(hidden_size, hidden_size * 3 + self.mlp_hidden_dim * 2, bias=qkv_bias)
        # proj and mlp_out
        self.linear2 = nn.Linear(hidden_size + self.mlp_hidden_dim, hidden_size)

        self.norm = QKNorm(head_dim)

        self.hidden_size = hidden_size
        self.pre_norm = norm_layer(hidden_size, elementwise_affine=False, eps=1e-6)

        if activation == "gelu_tanh":
            self.mlp_act = nn.GELU(approximate="tanh")
        else:
            self.mlp_act = nn.SiLU()
        self.modulation = Modulation(hidden_size, multiples=1)

    def init_weights(self):
        for layer in (self.linear1, self.linear2):
            nn.init.xavier_uniform_(layer.weight)
            if layer.bias is not None:
                nn.init.constant_(layer.bias, 0)
        self.norm.init_weights()
        self.pre_norm.reset_parameters()
        self.modulation.init_weights()

    def forward(self, x: Tensor, vec: Tensor, pe: Tensor) -> Tensor:
        mod = self.modulation(vec)
        x_mod = (1 + mod.scale) * self.pre_norm(x) + mod.shift
        if self.activation == "gelu_tanh":
            qkv, mlp = torch.split(
                self.linear1(x_mod),
                [3 * self.hidden_size, self.mlp_hidden_dim],
                dim=-1,
            )
        else:
            qkv, mlp, mlp_gate = torch.split(
                self.linear1(x_mod),
                [3 * self.hidden_size, self.mlp_hidden_dim, self.mlp_hidden_dim],
                dim=-1,
            )

        q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
        q, k = self.norm(q, k, v)
        # compute attention
        attn = attention(q, k, v, pe=pe)

        if self.activation == "gelu_tanh":
            # compute activation in mlp stream, cat again and run second linear layer
            x = x + mod.gate * self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2))
        else:
            x = x + mod.gate * self.linear2(
                torch.cat((attn, self.mlp_act(mlp_gate) * mlp), 2)
            )
        return x


class LastLayer(nn.Module):
    def __init__(
        self,
        hidden_size: int,
        patch_size: int,
        out_channels: int,
        norm_layer: nn.Module = nn.LayerNorm,
    ):
        super().__init__()
        self.norm_final = norm_layer(hidden_size, elementwise_affine=False, eps=1e-6)
        self.linear = nn.Linear(
            hidden_size, patch_size * patch_size * out_channels, bias=True
        )
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(), nn.Linear(hidden_size, 2 * hidden_size, bias=True)
        )

    def init_weights(self):
        nn.init.constant_(self.adaLN_modulation[-1].weight, 0)
        nn.init.constant_(self.adaLN_modulation[-1].bias, 0)
        nn.init.constant_(self.linear.weight, 0)
        nn.init.constant_(self.linear.bias, 0)
        self.norm_final.reset_parameters()

    def forward(self, x: Tensor, vec: Tensor) -> Tensor:
        shift, scale = self.adaLN_modulation(vec).chunk(2, dim=1)
        x = (1 + scale[:, None, :]) * self.norm_final(x) + shift[:, None, :]
        x = self.linear(x)
        return x
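A small sketch of the adaLN machinery above (illustrative): a timestep is embedded, mapped to a conditioning vector, and Modulation turns it into the (shift, scale, gate) triples that the blocks apply around attention and the MLP.

import torch
from ovis_image.model.layers import MLPEmbedder, Modulation, timestep_embedding

vec = MLPEmbedder(256, 128)(timestep_embedding(torch.tensor([0.5]), 256))  # (1, 128)
mod1, mod2 = Modulation(128, multiples=2)(vec)   # two ModulationOut triples
x = torch.randn(1, 10, 128)
x_mod = (1 + mod1.scale) * x + mod1.shift        # the modulation pattern used in the blocks
print(x_mod.shape)                               # torch.Size([1, 10, 128])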
ovis_image/model/model.py ADDED
@@ -0,0 +1,121 @@
# Copyright (C) 2025 AIDC-AI
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License.
# You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.

import torch
from torch import nn, Tensor

from ovis_image.model.layers import (
    DoubleStreamBlock,
    EmbedND,
    LastLayer,
    MLPEmbedder,
    SingleStreamBlock,
    timestep_embedding,
)

from ovis_image.model.args import OvisImageModelArgs


class OvisImageModel(nn.Module):

    def __init__(self, model_args: OvisImageModelArgs):
        super().__init__()

        self.model_args = model_args

        self.in_channels = model_args.in_channels
        self.out_channels = model_args.out_channels
        if model_args.hidden_size % model_args.num_heads != 0:
            raise ValueError(
                f"Hidden size {model_args.hidden_size} must be divisible by num_heads {model_args.num_heads}"
            )
        pe_dim = model_args.hidden_size // model_args.num_heads
        if sum(model_args.axes_dim) != pe_dim:
            raise ValueError(
                f"Got {model_args.axes_dim} but expected positional dim {pe_dim}"
            )
        self.hidden_size = model_args.hidden_size
        self.num_heads = model_args.num_heads
        self.pe_embedder = EmbedND(
            dim=pe_dim, theta=model_args.theta, axes_dim=model_args.axes_dim
        )
        self.img_in = nn.Linear(self.in_channels, self.hidden_size, bias=True)
        self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size)
        self.semantic_txt_norm = nn.RMSNorm(model_args.context_in_dim, eps=1e-6)
        self.semantic_txt_in = nn.Linear(model_args.context_in_dim, self.hidden_size, bias=True)

        if model_args.norm == "layernorm":
            norm_layer = nn.LayerNorm
        else:
            norm_layer = nn.RMSNorm

        # only "DoubleStreamBlock" is implemented for model_args.double_block_type
        DoubleBlock = DoubleStreamBlock

        self.double_blocks = nn.ModuleList(
            [
                DoubleBlock(
                    self.hidden_size,
                    self.num_heads,
                    mlp_ratio=model_args.mlp_ratio,
                    qkv_bias=model_args.qkv_bias,
                    activation=model_args.activation,
                    norm_layer=norm_layer,
                )
                for _ in range(model_args.depth)
            ]
        )

        self.single_blocks = nn.ModuleList(
            [
                SingleStreamBlock(
                    self.hidden_size,
                    self.num_heads,
                    mlp_ratio=model_args.mlp_ratio,
                    qkv_bias=model_args.qkv_bias,
                    activation=model_args.activation,
                    norm_layer=norm_layer,
                )
                for _ in range(model_args.depth_single_blocks)
            ]
        )

        self.final_layer = LastLayer(
            self.hidden_size,
            1,
            self.out_channels,
            norm_layer=norm_layer,
        )

    def forward(
        self,
        img: Tensor,
        img_ids: Tensor,
        txt: Tensor,
        txt_ids: Tensor,
        timesteps: Tensor,
    ) -> Tensor:
        if img.ndim != 3 or txt.ndim != 3:
            raise ValueError("Input img and txt tensors must have 3 dimensions.")

        # running on sequences img
        img = self.img_in(img)
        vec = self.time_in(timestep_embedding(timesteps, 256))
        txt = self.semantic_txt_norm(txt)
        txt = self.semantic_txt_in(txt)
        ids = torch.cat((txt_ids, img_ids), dim=1)
        pe = self.pe_embedder(ids)

        for block in self.double_blocks:
            img, txt = block(img=img, txt=txt, vec=vec, pe=pe)

        img = torch.cat((txt, img), 1)
        for block in self.single_blocks:
            img = block(img, vec=vec, pe=pe)
        img = img[:, txt.shape[1]:, ...]

        img = self.final_layer(img, vec)  # (N, T, patch_size ** 2 * out_channels)
        return img
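A forward-pass shape sketch for the model above, using deliberately tiny toy hyperparameters (chosen only so the shapes check quickly; they are not a released configuration). It also assumes the attention helper resolves to a backend available on your device: on torch >= 2.7.0 it pins the cuDNN SDPA backend, which is CUDA-only.

import torch
from ovis_image.model.args import OvisImageModelArgs
from ovis_image.model.model import OvisImageModel

args = OvisImageModelArgs(
    context_in_dim=32, hidden_size=128, num_heads=2,
    depth=1, depth_single_blocks=1, axes_dim=(8, 28, 28),  # 8 + 28 + 28 == 128 // 2
)
model = OvisImageModel(args)
B, N, T = 1, 64, 8                                  # batch, image tokens, text tokens
img = torch.randn(B, N, args.in_channels)           # packed latent tokens
txt = torch.randn(B, T, args.context_in_dim)        # text embeddings
img_ids, txt_ids = torch.zeros(B, N, 3), torch.zeros(B, T, 3)
out = model(img, img_ids, txt, txt_ids, timesteps=torch.tensor([0.5]))
print(out.shape)                                    # torch.Size([1, 64, 64])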
ovis_image/model/ops.py ADDED
@@ -0,0 +1,102 @@
# Copyright (C) 2025 AIDC-AI
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License.
# You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.

import torch
from einops import rearrange
from torch import Tensor
from torch.nn.attention import sdpa_kernel, SDPBackend

try:
    from flash_attn_interface import flash_attn_func
    print("found flash_attn 3")
except ImportError:
    flash_attn_func = None


def check_attention_type(attn_implementation):
    if torch.__version__ >= "2.7.0":
        if attn_implementation != "sdpa":
            print("please set attn_implementation to sdpa for torch >= 2.7.0")
    elif flash_attn_func is not None:
        if attn_implementation != "flash_attention_3":
            print("please set attn_implementation to flash_attention_3 for H100")


def get_attention_type_by_system():
    if torch.__version__ >= "2.7.0":
        return "sdpa"
    elif flash_attn_func is not None:
        return "flash_attention_3"
    else:
        return "eager"


def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor) -> Tensor:
    if torch.__version__ >= "2.7.0":
        return attention_sdpa(q, k, v, pe)
    elif flash_attn_func is not None:
        return attention_fa3(q, k, v, pe)
    else:
        return attention_eager(q, k, v, pe)


def attention_eager(q: Tensor, k: Tensor, v: Tensor, pe: Tensor) -> Tensor:
    q, k = apply_rope(q, k, pe)
    # https://docs.pytorch.org/docs/2.6/generated/torch.nn.functional.scaled_dot_product_attention.html#torch.nn.functional.scaled_dot_product_attention
    x = torch.nn.functional.scaled_dot_product_attention(q, k, v)
    x = rearrange(x, "B H L D -> B L (H D)")
    return x


def attention_sdpa(q: Tensor, k: Tensor, v: Tensor, pe: Tensor) -> Tensor:
    q, k = apply_rope(q, k, pe)
    # on B200, use the torch 2.7.1 image and accelerate with the cuDNN SDPA backend
    with sdpa_kernel([SDPBackend.CUDNN_ATTENTION]):
        x = torch.nn.functional.scaled_dot_product_attention(
            q, k, v,
        )
    x = rearrange(x, "B H L D -> B L (H D)")
    return x


def attention_fa3(q: Tensor, k: Tensor, v: Tensor, pe: Tensor) -> Tensor:
    q, k = apply_rope(q, k, pe)
    # on H100, accelerate with flash-attn 3
    q = rearrange(q, "B H L D -> B L H D")
    k = rearrange(k, "B H L D -> B L H D")
    v = rearrange(v, "B H L D -> B L H D")
    x = flash_attn_func(q, k, v)[0]
    x = rearrange(x, "B L H D -> B L (H D)")
    return x


def get_attention_func(attn_implementation):
    if attn_implementation == "eager":
        return attention_eager
    elif attn_implementation == "sdpa":
        return attention_sdpa
    elif attn_implementation == "flash_attention_3":
        return attention_fa3
    else:
        return attention_eager


def rope(pos: Tensor, dim: int, theta: int) -> Tensor:
    assert dim % 2 == 0
    scale = torch.arange(0, dim, 2, dtype=pos.dtype, device=pos.device) / dim
    omega = 1.0 / (theta**scale)
    out = torch.einsum("...n,d->...nd", pos, omega)
    out = torch.stack(
        [torch.cos(out), -torch.sin(out), torch.sin(out), torch.cos(out)], dim=-1
    )
    out = rearrange(out, "b n d (i j) -> b n d i j", i=2, j=2)
    return out.float()


def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor) -> tuple[Tensor, Tensor]:
    xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2)
    xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2)
    xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
    xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
    return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)
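rope builds per-position 2x2 rotation matrices and apply_rope rotates query/key channel pairs with them, so vector norms are preserved. A quick numerical check (illustrative):

import torch
from ovis_image.model.ops import rope, apply_rope

pos = torch.arange(8)[None].float()            # (batch, seq) positions for one axis
pe = rope(pos, dim=16, theta=10_000)[:, None]  # (1, 1, 8, 8, 2, 2), head axis added as in EmbedND
q = torch.randn(1, 1, 8, 16)                   # (B, H, L, D)
k = torch.randn(1, 1, 8, 16)
q_rot, k_rot = apply_rope(q, k, pe)
print(torch.allclose(q.norm(dim=-1), q_rot.norm(dim=-1), atol=1e-5))  # True: rotations preserve norm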
ovis_image/model/ovis/configuration_ovis2_5.py ADDED
@@ -0,0 +1,96 @@
from typing import Any, Optional, List, Union

from transformers import Qwen3Config
from transformers.configuration_utils import PretrainedConfig

__all__ = ["Siglip2NavitConfig", "Ovis2_5_Config"]


class Siglip2NavitConfig(PretrainedConfig):
    """This is the configuration class to store the configuration of the Siglip2Navit vision tower.

    Instantiating a configuration with the defaults will yield a similar configuration
    to that of [apple/aimv2-large-patch14-224](https://huggingface.co/apple/aimv2-large-patch14-224).

    Args:
        hidden_size: Dimension of the hidden representations.
        intermediate_size: Dimension of the SwiGLU representations.
        num_hidden_layers: Number of hidden layers in the Transformer.
        num_attention_heads: Number of attention heads for each attention layer
            in the Transformer.
        num_channels: Number of input channels.
        image_size: Image size.
        patch_size: Patch size.
        layer_norm_eps: Epsilon value used for the normalization layers.
        attention_dropout: Dropout ratio for attention probabilities.
        kwargs: Keyword arguments for the [`PretrainedConfig`].
    """

    model_type: str = "siglip2_navit"

    def __init__(
        self,
        hidden_size: int = 1024,
        intermediate_size: int = 4096,
        num_hidden_layers: int = 24,
        num_attention_heads: int = 16,
        num_channels: int = 3,
        num_patches: int = -1,
        image_size: int = 512,
        patch_size: int = 16,
        hidden_act: str = "gelu_pytorch_tanh",
        layer_norm_eps: float = 1e-6,
        attention_dropout: float = 0.0,
        hidden_stride: int = 2,
        window_size: int = 112,
        fullatt_block_indexes: Optional[list] = None,
        temporal_patch_size: int = 1,
        preserve_original_pe: bool = True,
        use_rope: bool = True,
        **kwargs: Any,
    ):
        super().__init__(**kwargs)
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.num_channels = num_channels
        self.num_patches = num_patches
        self.patch_size = patch_size
        self.image_size = image_size
        self.hidden_act = hidden_act
        self.attention_dropout = attention_dropout
        self.layer_norm_eps = layer_norm_eps
        self.hidden_stride = hidden_stride
        self.window_size = window_size
        self.fullatt_block_indexes = fullatt_block_indexes
        self.temporal_patch_size = temporal_patch_size
        self.preserve_original_pe = preserve_original_pe
        self.use_rope = use_rope


class Ovis2_5_Config(PretrainedConfig):
    model_type = "ovis2_5"
    sub_configs = dict(llm_config=Qwen3Config, vit_config=Siglip2NavitConfig)

    def __init__(
        self,
        llm_config: Optional[Union[Qwen3Config, dict]] = None,
        vit_config: Optional[Union[Siglip2NavitConfig, dict]] = None,
        visual_vocab_size=65536,
        hidden_size=None,
        **kwargs,
    ):
        super().__init__(**kwargs)
        if isinstance(llm_config, dict):
            llm_config = Qwen3Config(**llm_config)
        self.llm_config = llm_config
        if isinstance(vit_config, dict):
            vit_config = Siglip2NavitConfig(**vit_config)
        self.vit_config = vit_config
        self.visual_vocab_size = visual_vocab_size
        self.hidden_size = hidden_size
        if kwargs.get('attn_implementation'):
            self.llm_config._attn_implementation = kwargs['attn_implementation']
            self.vit_config._attn_implementation = kwargs['attn_implementation']
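A small sketch of the config plumbing above (illustrative; it assumes a transformers version that ships Qwen3Config): dict sub-configs are coerced into typed config objects, and attn_implementation is propagated to both children.

from ovis_image.model.ovis.configuration_ovis2_5 import Ovis2_5_Config

cfg = Ovis2_5_Config(
    llm_config={"hidden_size": 1024, "num_hidden_layers": 2},
    vit_config={"hidden_size": 256},
    attn_implementation="sdpa",
)
print(type(cfg.llm_config).__name__)        # Qwen3Config
print(cfg.vit_config._attn_implementation)  # sdpa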
ovis_image/model/ovis/modeling_ovis2_5.py ADDED
@@ -0,0 +1,1000 @@
# Copyright (C) 2025 AIDC-AI
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License.
# You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.
import math
from typing import Dict, List, Optional, Tuple, Union

import PIL.Image
import numpy as np
import torch
# from flash_attn import flash_attn_varlen_func
# from flash_attn.layers.rotary import apply_rotary_emb
from torch import Tensor, nn
from torch.nn import functional as F
from transformers import (
    AutoConfig,
    AutoImageProcessor,
    AutoModel,
    AutoModelForCausalLM,
    AutoTokenizer,
)
from transformers.activations import ACT2FN
from transformers.generation.utils import GenerateOutput
from transformers.modeling_outputs import BaseModelOutputWithNoAttention
from transformers.modeling_utils import PreTrainedModel

from ovis_image.model.ovis.configuration_ovis2_5 import Siglip2NavitConfig, Ovis2_5_Config
from ovis_image.model.ops import get_attention_type_by_system

IMAGE_PLACEHOLDER = "<image>"
IMAGE_PLACEHOLDER_ID = -200
VIDEO_PLACEHOLDER = "<video>"
VIDEO_PLACEHOLDER_ID = -201

VISUAL_ATOM_ID = -300
INDICATOR_IDS = [-301, -302, -303, -304]


# copied from qwen2.5-vl
class VisionRotaryEmbedding(nn.Module):
    def __init__(self, dim: int, theta: float = 10000.0) -> None:
        super().__init__()
        inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)

    def forward(self, seqlen: int) -> torch.Tensor:
        seq = torch.arange(seqlen, device=self.inv_freq.device, dtype=self.inv_freq.dtype)
        freqs = torch.outer(seq, self.inv_freq)
        return freqs


class Siglip2VisionEmbeddings(nn.Module):
    def __init__(self, config: Siglip2NavitConfig):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.patch_size = config.patch_size
        self.image_size = config.image_size
        self.num_patches = config.num_patches
        self.preserve_original_pe = config.preserve_original_pe
        self.hidden_stride = config.hidden_stride

        # siglip2 naflex
        if self.num_patches > 0:
            self.patch_embedding = nn.Linear(
                in_features=config.num_channels * self.patch_size * self.patch_size,
                out_features=self.embed_dim,
            )
            if self.preserve_original_pe:
                self.position_embedding_size = int(self.num_patches**0.5)
                self.position_embedding = nn.Embedding(self.num_patches, self.embed_dim)
        else:
            self.patch_embedding = nn.Conv2d(
                in_channels=config.num_channels,
                out_channels=self.embed_dim,
                kernel_size=self.patch_size,
                stride=self.patch_size,
                padding="valid",
            )
            if self.preserve_original_pe:
                self.num_patches = (self.image_size // self.patch_size) ** 2
                self.position_embedding_size = self.image_size // self.patch_size
                self.position_embedding = nn.Embedding(self.num_patches, self.embed_dim)

    @staticmethod
    def resize_positional_embeddings(
        positional_embeddings: torch.Tensor,
        spatial_shapes: torch.LongTensor,
        max_length: int,
    ) -> torch.Tensor:
        """
        Resize positional embeddings to image-specific size and pad to a fixed size.

        Args:
            positional_embeddings (`torch.Tensor`):
                Position embeddings of shape (height, width, embed_dim)
            spatial_shapes (`torch.LongTensor`):
                Spatial shapes of shape (batch_size, 2) to resize the positional embeddings to
            max_length (`int`):
                Maximum length of the positional embeddings to pad resized positional embeddings to

        Returns:
            `torch.Tensor`: Embeddings of shape (batch_size, max_length, embed_dim)
        """
        batch_size = spatial_shapes.shape[0]
        embed_dim = positional_embeddings.shape[-1]
        source_dtype = positional_embeddings.dtype

        resulted_positional_embeddings = torch.empty(
            (batch_size, max_length, embed_dim),
            device=positional_embeddings.device,
            dtype=source_dtype,
        )

        # (height, width, embed_dim) -> (1, embed_dim, height, width) for interpolation
        positional_embeddings = positional_embeddings.permute(2, 0, 1).unsqueeze(0)

        # Upcast to float32 on CPU because antialias is not supported for bfloat16/float16 on CPU
        if positional_embeddings.device.type == "cpu":
            positional_embeddings = positional_embeddings.to(torch.float32)

        for i in range(batch_size):
            # (1, dim, height, width) -> (1, dim, target_height, target_width)
            height, width = spatial_shapes[i]
            resized_embeddings = F.interpolate(
                positional_embeddings,
                size=(height, width),
                mode="bilinear",
                align_corners=False,
                antialias=True,
            )

            # (1, dim, target_height, target_width) -> (target_height * target_width, dim)
            resized_embeddings = resized_embeddings.reshape(embed_dim, height * width).transpose(0, 1)

            # Cast to original dtype
            resized_embeddings = resized_embeddings.to(source_dtype)

            resulted_positional_embeddings[i, : height * width] = resized_embeddings
            resulted_positional_embeddings[i, height * width :] = resized_embeddings[0]

        return resulted_positional_embeddings

    def forward(self, pixel_values: torch.FloatTensor,
                grid_thws: Optional[torch.LongTensor] = None) -> torch.Tensor:
        """
        Args:
            pixel_values (`torch.FloatTensor`):
                Pixel values of shape (num_patches, num_channels * temporal_patch_size * patch_size * patch_size)
            grid_thws (`torch.LongTensor`):
                grid shape (num_patches, 3)
        """
        # Apply patch embeddings to already patchified pixel values
        target_dtype = self.patch_embedding.weight.dtype
        if isinstance(self.patch_embedding, nn.Linear):
            patch_embeds = self.patch_embedding(pixel_values.to(dtype=target_dtype))
        elif isinstance(self.patch_embedding, nn.Conv2d):
            pixel_values = pixel_values.view(
                -1, self.config.num_channels * self.config.temporal_patch_size,
                self.patch_size, self.patch_size
            )
            patch_embeds = self.patch_embedding(pixel_values.to(dtype=target_dtype))
            patch_embeds = patch_embeds.reshape(-1, self.embed_dim)

        if self.preserve_original_pe:
            assert grid_thws is not None
            pos_embed_new = torch.zeros_like(patch_embeds)
            ori_h = ori_w = self.position_embedding_size
            positional_embeddings = self.position_embedding.weight.reshape(
                self.position_embedding_size, self.position_embedding_size, -1
172
+ ).unsqueeze(0).permute(0,3,1,2)
173
+ # pos_embed = self.pos_embed.reshape(1, ori_h, ori_w, -1).permute(0, 3, 1, 2)
174
+ cnt = 0
175
+ for t, h, w in grid_thws:
176
+ thw = t * h * w
177
+ pe = F.interpolate(positional_embeddings, size=(h, w), mode='bicubic', align_corners=False)
178
+ pe = pe.permute(0, 2, 3, 1).reshape(1, h * w, -1)
179
+ pe = pe[0].repeat(t, 1)
180
+ pe = pe.reshape(t, h // self.hidden_stride, self.hidden_stride, w // self.hidden_stride,
181
+ self.hidden_stride, -1)
182
+ pe = pe.permute(0, 1, 3, 2, 4, 5).reshape(thw, -1)
183
+ pos_embed_new[cnt:cnt + thw] = pe
184
+ cnt += thw
185
+ patch_embeds = patch_embeds + pos_embed_new
186
+
187
+ return patch_embeds
188
+
189
+
190
+
191
+ def rotate_half(x):
192
+ x1 = x[..., : x.shape[-1] // 2]
193
+ x2 = x[..., x.shape[-1] // 2 :]
194
+ return torch.cat((-x2, x1), dim=-1)
195
+
196
+ def apply_rotary_pos_emb_flashatt(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
197
+ cos = cos.unsqueeze(unsqueeze_dim)
198
+ sin = sin.unsqueeze(unsqueeze_dim)
199
+ q_embed = (q * cos) + (rotate_half(q) * sin)
200
+ k_embed = (k * cos) + (rotate_half(k) * sin)
201
+ return q_embed, k_embed
202
+
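+ # apply_rotary_pos_emb_flashatt shape sketch: with q, k of shape
+ # (1, seq_len, num_heads, head_dim) and cos/sin of shape (seq_len, head_dim),
+ # unsqueeze_dim=1 broadcasts the rotation across the head axis, so every head shares
+ # the same per-position rotation.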
203
+
204
+ class Siglip2Attention(nn.Module):
205
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
206
+
207
+ def __init__(self, config):
208
+ super().__init__()
209
+ self.config = config
210
+ self.embed_dim = config.hidden_size
211
+ self.num_heads = config.num_attention_heads
212
+ self.head_dim = self.embed_dim // self.num_heads
213
+ if self.head_dim * self.num_heads != self.embed_dim:
214
+ raise ValueError(
215
+ f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
216
+ f" {self.num_heads})."
217
+ )
218
+ self.scale = self.head_dim**-0.5
219
+ self.dropout = config.attention_dropout
220
+ self.is_causal = False
221
+
222
+ self.k_proj = nn.Linear(self.embed_dim, self.embed_dim)
223
+ self.v_proj = nn.Linear(self.embed_dim, self.embed_dim)
224
+ self.q_proj = nn.Linear(self.embed_dim, self.embed_dim)
225
+ self.out_proj = nn.Linear(self.embed_dim, self.embed_dim)
226
+
227
+ self.use_rope = config.use_rope
228
+
229
+ def forward(
230
+ self,
231
+ hidden_states: torch.Tensor,
232
+ cu_seqlens: torch.Tensor,
233
+ position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
234
+ ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
235
+ """Input shape: Batch x Time x Channel"""
236
+
237
+ seq_length, embed_dim = hidden_states.shape
238
+
239
+ queries = self.q_proj(hidden_states)
240
+ keys = self.k_proj(hidden_states)
241
+ values = self.v_proj(hidden_states)
242
+
243
+ queries = queries.view(seq_length, self.num_heads, self.head_dim)
244
+ keys = keys.view(seq_length, self.num_heads, self.head_dim)
245
+ values = values.view(seq_length, self.num_heads, self.head_dim)
246
+
247
+ if self.use_rope:
248
+ cos, sin = position_embeddings
249
+ queries, keys = apply_rotary_pos_emb_flashatt(queries.unsqueeze(0), keys.unsqueeze(0), cos, sin)
250
+ queries = queries.squeeze(0)
251
+ keys = keys.squeeze(0)
252
+
253
+ max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
254
+ # attn_output = flash_attn_varlen_func(queries, keys, values, cu_seqlens, cu_seqlens, max_seqlen, max_seqlen).reshape(
255
+ # seq_length, -1
256
+ # )
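+ # Fallback used here: since flash_attn_varlen_func is not imported, the packed (varlen)
+ # sequence is split at cu_seqlens, right-padded to the longest sample, and run through
+ # scaled_dot_product_attention with a boolean key-padding mask instead.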
257
+ batch_size = cu_seqlens.shape[0] - 1
258
+ q_list, k_list, v_list = [], [], []
259
+ for i in range(batch_size):
260
+ start, end = cu_seqlens[i].item(), cu_seqlens[i+1].item()
261
+ q_list.append(queries[start:end])
262
+ k_list.append(keys[start:end])
263
+ v_list.append(values[start:end])
264
+
265
+ def pad_to_max(t, max_len):
266
+ pad = (0, 0, 0, 0, 0, max_len - t.shape[0]) # [seqlen, num_heads, head_dim]
267
+ return torch.nn.functional.pad(t, pad)
268
+
269
+ q_batched = torch.stack([pad_to_max(x, max_seqlen) for x in q_list]) # (batch, seqlen, nhead, dim)
270
+ k_batched = torch.stack([pad_to_max(x, max_seqlen) for x in k_list])
271
+ v_batched = torch.stack([pad_to_max(x, max_seqlen) for x in v_list])
272
+
273
+ mask = torch.zeros((batch_size, max_seqlen), dtype=torch.bool, device=queries.device)
274
+ for i in range(batch_size):
275
+ mask[i, :q_list[i].shape[0]] = True # [batch, seqlen]
276
+
277
+ # (batch, nhead, seqlen, head_dim)
278
+ q_batched = q_batched.transpose(1, 2) # (batch, nhead, seqlen, head_dim)
279
+ k_batched = k_batched.transpose(1, 2)
280
+ v_batched = v_batched.transpose(1, 2)
281
+ # Parse the version; plain string comparison would misorder e.g. "2.10" vs "2.7".
+ if tuple(int(v) for v in torch.__version__.split("+")[0].split(".")[:2]) >= (2, 7):
282
+ from torch.nn.attention import sdpa_kernel, SDPBackend
283
+ with sdpa_kernel([SDPBackend.CUDNN_ATTENTION]):
284
+ attn_output = torch.nn.functional.scaled_dot_product_attention(
285
+ q_batched, k_batched, v_batched,
286
+ attn_mask=mask.unsqueeze(1).unsqueeze(2)
287
+ ).permute(0, 2, 1, 3).reshape(seq_length, -1)
288
+ else:
289
+ attn_output = torch.nn.functional.scaled_dot_product_attention(
290
+ q_batched, k_batched, v_batched,
291
+ attn_mask=mask.unsqueeze(1).unsqueeze(2)  # broadcast to (batch, 1, 1, seqlen)
292
+ ).permute(0, 2, 1, 3).reshape(seq_length, -1)
293
+
294
+ attn_output = self.out_proj(attn_output)
295
+ return attn_output
296
+
297
+ class Siglip2MLP(nn.Module):
298
+ def __init__(self, config):
299
+ super().__init__()
300
+ self.config = config
301
+ self.activation_fn = ACT2FN[config.hidden_act]
302
+ self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
303
+ self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)
304
+
305
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
306
+ hidden_states = self.fc1(hidden_states)
307
+ hidden_states = self.activation_fn(hidden_states)
308
+ hidden_states = self.fc2(hidden_states)
309
+ return hidden_states
310
+
311
+
312
+ class Siglip2EncoderLayer(nn.Module):
313
+ def __init__(self, config: Siglip2NavitConfig):
314
+ super().__init__()
315
+ self.embed_dim = config.hidden_size
316
+ self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
317
+ self.self_attn = Siglip2Attention(config)
318
+ self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
319
+ self.mlp = Siglip2MLP(config)
320
+
321
+ def forward(
322
+ self,
323
+ hidden_states: torch.Tensor,
324
+ cu_seqlens: torch.Tensor,
325
+ position_embeddings: torch.Tensor
326
+ ) -> tuple[torch.FloatTensor]:
327
+ """
328
+ Args:
329
+ hidden_states (`torch.FloatTensor`):
330
+ Packed input to the layer of shape `(seq_len, embed_dim)`.
331
+ cu_seqlens (`torch.Tensor`):
332
+ Cumulative sequence lengths delimiting the samples (or windows) in the packed sequence.
333
+ position_embeddings (`torch.Tensor`):
334
+ Precomputed rotary cos/sin embeddings for the packed sequence.
336
+ """
337
+ residual = hidden_states
338
+
339
+ hidden_states = self.layer_norm1(hidden_states)
340
+ hidden_states = self.self_attn(
341
+ hidden_states=hidden_states,
342
+ cu_seqlens=cu_seqlens,
343
+ position_embeddings=position_embeddings
344
+ )
345
+ hidden_states = residual + hidden_states
346
+
347
+ residual = hidden_states
348
+ hidden_states = self.layer_norm2(hidden_states)
349
+ hidden_states = self.mlp(hidden_states)
350
+ hidden_states = residual + hidden_states
351
+
352
+ return hidden_states
353
+
354
+ class Siglip2Encoder(nn.Module):
355
+ """
356
+ Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a
357
+ [`Siglip2EncoderLayer`].
358
+
359
+ Args:
360
+ config: Siglip2NavitConfig
361
+ """
362
+
363
+ def __init__(self, config: Siglip2NavitConfig):
364
+ super().__init__()
365
+ self.config = config
366
+ self.layers = nn.ModuleList([Siglip2EncoderLayer(config) for _ in range(config.num_hidden_layers)])
367
+ self.gradient_checkpointing = False
368
+
369
+ self.rotary_pos_emb = VisionRotaryEmbedding(config.hidden_size // config.num_attention_heads // 2)
370
+ self.patch_size = config.patch_size
371
+ self.hidden_stride = config.hidden_stride
372
+ self.window_size = config.window_size
373
+ self.spatial_merge_unit = config.hidden_stride * config.hidden_stride
374
+ self.fullatt_block_indexes = None if config.fullatt_block_indexes is None else [int(i) for i in config.fullatt_block_indexes.split('|')]
375
+
376
+
377
+ # copied from qwen2.5_vl
378
+ def rot_pos_emb(self, grid_thw):
379
+ pos_ids = []
380
+ for t, h, w in grid_thw:
381
+ hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w)
382
+ hpos_ids = hpos_ids.reshape(
383
+ h // self.hidden_stride,
384
+ self.hidden_stride,
385
+ w // self.hidden_stride,
386
+ self.hidden_stride,
387
+ )
388
+ hpos_ids = hpos_ids.permute(0, 2, 1, 3)
389
+ hpos_ids = hpos_ids.flatten()
390
+
391
+ wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1)
392
+ wpos_ids = wpos_ids.reshape(
393
+ h // self.hidden_stride,
394
+ self.hidden_stride,
395
+ w // self.hidden_stride,
396
+ self.hidden_stride,
397
+ )
398
+ wpos_ids = wpos_ids.permute(0, 2, 1, 3)
399
+ wpos_ids = wpos_ids.flatten()
400
+ pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1))
401
+ pos_ids = torch.cat(pos_ids, dim=0)
402
+ max_grid_size = grid_thw[:, 1:].max()
403
+ rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size)
404
+ rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1)
405
+ return rotary_pos_emb
406
+
407
+ def get_window_index(self, grid_thw):
408
+ window_index: list = []
409
+ cu_window_seqlens: list = [0]
410
+ window_index_id = 0
411
+ vit_merger_window_size = self.window_size // self.hidden_stride // self.patch_size # patch (after merge) number in each window
412
+
413
+ for grid_t, grid_h, grid_w in grid_thw:
414
+ llm_grid_h, llm_grid_w = (
415
+ grid_h // self.hidden_stride, # number of patch after merge
416
+ grid_w // self.hidden_stride,
417
+ )
418
+ index = torch.arange(grid_t * llm_grid_h * llm_grid_w).reshape(grid_t, llm_grid_h, llm_grid_w)
419
+ pad_h = vit_merger_window_size - llm_grid_h % vit_merger_window_size
420
+ pad_w = vit_merger_window_size - llm_grid_w % vit_merger_window_size
421
+ num_windows_h = (llm_grid_h + pad_h) // vit_merger_window_size
422
+ num_windows_w = (llm_grid_w + pad_w) // vit_merger_window_size
423
+ index_padded = F.pad(index, (0, pad_w, 0, pad_h), "constant", -100)
424
+ index_padded = index_padded.reshape(
425
+ grid_t,
426
+ num_windows_h,
427
+ vit_merger_window_size,
428
+ num_windows_w,
429
+ vit_merger_window_size,
430
+ )
431
+ index_padded = index_padded.permute(0, 1, 3, 2, 4).reshape(
432
+ grid_t,
433
+ num_windows_h * num_windows_w,
434
+ vit_merger_window_size,
435
+ vit_merger_window_size,
436
+ )
437
+ seqlens = (index_padded != -100).sum([2, 3]).reshape(-1)
438
+ index_padded = index_padded.reshape(-1)
439
+ index_new = index_padded[index_padded != -100]
440
+ window_index.append(index_new + window_index_id)
441
+ cu_seqlens_tmp = seqlens.cumsum(0) * self.spatial_merge_unit + cu_window_seqlens[-1]
442
+ cu_window_seqlens.extend(cu_seqlens_tmp.tolist())
443
+ window_index_id += (grid_t * llm_grid_h * llm_grid_w).item()
444
+ window_index = torch.cat(window_index, dim=0)
445
+
446
+ return window_index, cu_window_seqlens
447
+
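+ # get_window_index numbers (illustrative, assuming window_size=112, hidden_stride=2,
+ # patch_size=14): vit_merger_window_size = 112 // 2 // 14 = 4, i.e. each window covers
+ # a 4x4 block of merged tokens (8x8 raw patches); edge windows padded with -100 end up smaller.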
448
+ # Ignore copy
449
+ def forward(
450
+ self,
451
+ inputs_embeds,
452
+ grid_thws: torch.Tensor,
453
+ output_hidden_states: bool = False,
454
+ ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, ...]]]:
455
+ r"""
456
+ Args:
457
+ inputs_embeds (`torch.FloatTensor` of shape `(seq_len, hidden_size)`):
458
+ Packed patch embeddings produced by `Siglip2VisionEmbeddings`.
459
+ grid_thws (`torch.Tensor` of shape `(num_images, 3)`):
460
+ Temporal, height and width patch-grid sizes for each input image.
461
+ output_hidden_states (`bool`, *optional*, defaults to `False`):
462
+ Whether or not to return the hidden states of all layers. See `hidden_states` under
463
+ returned tensors for more detail.
476
+ """
477
+
478
+ rotary_pos_emb = self.rot_pos_emb(grid_thws)
479
+ window_index, cu_window_seqlens = self.get_window_index(grid_thws)
480
+ cu_window_seqlens = torch.tensor(
481
+ cu_window_seqlens,
482
+ device=inputs_embeds.device,
483
+ dtype=grid_thws.dtype if torch.jit.is_tracing() else torch.int32,
484
+ )
485
+ cu_window_seqlens = torch.unique_consecutive(cu_window_seqlens)
486
+
487
+ seq_len, _ = inputs_embeds.size()
488
+ inputs_embeds = inputs_embeds.reshape(seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1)
489
+ inputs_embeds = inputs_embeds[window_index, :, :]
490
+ inputs_embeds = inputs_embeds.reshape(seq_len, -1)
491
+ rotary_pos_emb = rotary_pos_emb.reshape(seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1)
492
+ rotary_pos_emb = rotary_pos_emb[window_index, :, :]
493
+ rotary_pos_emb = rotary_pos_emb.reshape(seq_len, -1)
494
+ emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
495
+ position_embeddings = (emb.cos(), emb.sin())
496
+
497
+ cu_seqlens = torch.repeat_interleave(grid_thws[:, 1] * grid_thws[:, 2], grid_thws[:, 0]).cumsum(
498
+ dim=0,
499
+ # Select dtype based on the following factors:
500
+ # - FA2 requires that cu_seqlens_q must have dtype int32
501
+ # - torch.onnx.export requires that cu_seqlens_q must have same dtype as grid_thw
502
+ # See https://github.com/huggingface/transformers/pull/34852 for more information
503
+ dtype=grid_thws.dtype if torch.jit.is_tracing() else torch.int32,
504
+ )
505
+ cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)
506
+
507
+ reverse_indices = torch.argsort(window_index)
508
+ encoder_states = () if output_hidden_states else None
509
+
510
+ hidden_states = inputs_embeds
511
+ for index, block in enumerate(self.layers):
512
+ if self.fullatt_block_indexes is None or index in self.fullatt_block_indexes:
513
+ cu_seqlens_tmp = cu_seqlens
514
+ else:
515
+ cu_seqlens_tmp = cu_window_seqlens
516
+ if self.gradient_checkpointing and self.training:
517
+ hidden_states = self._gradient_checkpointing_func(block.__call__, hidden_states, cu_seqlens_tmp, position_embeddings)
518
+ else:
519
+ hidden_states = block(hidden_states, cu_seqlens_tmp, position_embeddings)
520
+ if output_hidden_states:
521
+ hidden_states_ = hidden_states.reshape(seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1)
522
+ encoder_states += (hidden_states_[reverse_indices, :].reshape(seq_len, -1),)
523
+ # tokens = self.post_trunk_norm(tokens)
524
+ hidden_states = hidden_states.reshape(seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1)
525
+ hidden_states = hidden_states[reverse_indices, :].reshape(seq_len, -1)
526
+
527
+ return hidden_states, encoder_states
528
+
529
+ class Siglip2VisionTransformer(nn.Module):
530
+ def __init__(self, config: Siglip2NavitConfig):
531
+ super().__init__()
532
+ self.config = config
533
+ embed_dim = config.hidden_size
534
+
535
+ self.embeddings = Siglip2VisionEmbeddings(config)
536
+ self.encoder = Siglip2Encoder(config)
537
+ self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
538
+ self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
539
+
540
+ def forward(
541
+ self,
542
+ pixel_values: torch.FloatTensor,
543
+ grid_thws: torch.LongTensor,
544
+ output_hidden_states: Optional[bool] = True,
545
+ return_dict: Optional[bool] = True,
546
+ ) -> Union[
547
+ Tuple[torch.Tensor],
548
+ Tuple[torch.Tensor, Tuple[torch.Tensor, ...]],
549
+ BaseModelOutputWithNoAttention,
550
+ ]:
551
+ r"""
552
+ spatial_shapes (`torch.LongTensor` of shape `(batch_size, 2)`):
553
+ Tensor containing the spatial dimensions (height, width) of the input images.
554
+ """
555
+ # output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
556
+ # output_hidden_states = (
557
+ # output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
558
+ # )
559
+
560
+ hidden_states = self.embeddings(pixel_values, grid_thws)
561
+
562
+ last_hidden_state, hidden_states = self.encoder(hidden_states, grid_thws, output_hidden_states)
563
+ last_hidden_state = self.post_layernorm(last_hidden_state)
564
+
565
+ if not return_dict:
566
+ output = (last_hidden_state,)
567
+ output += (hidden_states,) if output_hidden_states else ()
568
+ return output
569
+
570
+ return BaseModelOutputWithNoAttention(
571
+ last_hidden_state=last_hidden_state,
572
+ hidden_states=hidden_states
573
+ )
574
+
575
+ class Siglip2PreTrainedModel(PreTrainedModel):
576
+ config_class = Siglip2NavitConfig
577
+ base_model_prefix = "siglip2_navit"
578
+ supports_gradient_checkpointing = True
579
+
580
+ _no_split_modules = [
581
+ "Siglip2VisionEmbeddings",
582
+ "Siglip2EncoderLayer",
583
+ ]
584
+ _supports_flash_attn_2 = True
585
+ _supports_sdpa = False
586
+ _supports_flex_attn = False
587
+ _supports_attention_backend = True
588
+
589
+
590
+ class Siglip2NavitModel(Siglip2PreTrainedModel):
591
+ config_class = Siglip2NavitConfig
592
+ main_input_name = "pixel_values"
593
+
594
+ def __init__(self, config: Siglip2NavitConfig):
595
+ super().__init__(config)
596
+
597
+ self.vision_model = Siglip2VisionTransformer(config)
598
+
599
+ def get_input_embeddings(self) -> nn.Module:
600
+ return self.vision_model.embeddings.patch_embedding
601
+
602
+ def forward(
603
+ self,
604
+ pixel_values: torch.FloatTensor,
605
+ grid_thws: torch.LongTensor,
606
+ output_hidden_states: Optional[bool] = None,
607
+ return_dict: Optional[bool] = None,
608
+ ) -> Union[
609
+ Tuple[torch.Tensor],
610
+ Tuple[torch.Tensor, Tuple[torch.Tensor, ...]],
611
+ BaseModelOutputWithNoAttention,
612
+ ]:
613
+
614
+ if output_hidden_states is None:
615
+ output_hidden_states = self.config.output_hidden_states
616
+ if return_dict is None:
617
+ return_dict = self.config.use_return_dict
618
+
619
+ return self.vision_model(
620
+ pixel_values=pixel_values,
621
+ grid_thws=grid_thws,
622
+ output_hidden_states=output_hidden_states,
623
+ return_dict=return_dict,
624
+ )
625
+
626
+ class VisualEmbedding(torch.nn.Embedding):
627
+ """
628
+ A visual embedding layer that can handle both discrete token IDs (long) and continuous
629
+ soft-token probabilities (float).
630
+ """
631
+
632
+ def forward(self, visual_tokens: Tensor) -> Tensor:
633
+ if visual_tokens.dtype in [torch.int8, torch.int16, torch.int32, torch.int64, torch.long]:
634
+ return super().forward(visual_tokens)
635
+ # Handle soft tokens (probabilities) by matrix multiplication with the embedding weight
636
+ return torch.matmul(visual_tokens, self.weight)
637
+
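+ # VisualEmbedding soft-token example: for probabilities p of shape (n, vocab_size),
+ # p @ weight is the probability-weighted mixture of embedding rows; a one-hot p
+ # reduces to a plain lookup.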
638
+
639
+ class VisualTokenizer(torch.nn.Module):
640
+ """
641
+ Tokenizes images or videos into a sequence of continuous visual tokens.
642
+ """
643
+
644
+ def __init__(self, vit, visual_vocab_size, image_processor_name_or_path, *args, **kwargs):
645
+ super().__init__(*args, **kwargs)
646
+ self.vit = vit
647
+ self.image_processor = AutoImageProcessor.from_pretrained(image_processor_name_or_path, do_center_crop=False)
648
+ head_dim = visual_vocab_size - len(INDICATOR_IDS)
649
+ self.head = torch.nn.Sequential(
650
+ torch.nn.Linear(self.vit.config.hidden_size * self.vit.config.hidden_stride ** 2, head_dim, bias=False),
651
+ torch.nn.LayerNorm(head_dim)
652
+ )
653
+
654
+ def _encode(self, pixel_values, grid_thws):
655
+ output = self.vit(pixel_values, grid_thws, output_hidden_states=True, return_dict=True)
656
+ features = output.hidden_states[-1]
657
+ seq_len, _ = features.shape
658
+ features = features.reshape(seq_len // (self.vit.config.hidden_stride ** 2), -1)
659
+ return features
660
+
661
+ # Adapted from qwen2_vl
662
+ @staticmethod
663
+ def smart_resize(
664
+ height: int, width: int, factor: int = 28, min_pixels: int = 448 * 448, max_pixels: int = 1344 * 1792
665
+ ):
666
+ """Rescales the image so that the following conditions are met:
667
+ 1. Both dimensions are divisible by 'factor'.
668
+ 2. The total number of pixels is within ['min_pixels', 'max_pixels'].
669
+ 3. The aspect ratio is maintained as closely as possible.
670
+ """
671
+ if height < factor or width < factor:
672
+ if height < width:
673
+ width = round(factor / height * width)
674
+ height = factor
675
+ else:
676
+ height = round(factor / width * height)
677
+ width = factor
678
+
679
+ elif max(height, width) / min(height, width) > 200:
680
+ if height > width:
681
+ height = 200 * width
682
+ else:
683
+ width = 200 * height
684
+
685
+ h_bar = round(height / factor) * factor
686
+ w_bar = round(width / factor) * factor
687
+ if h_bar * w_bar > max_pixels:
688
+ beta = math.sqrt((height * width) / max_pixels)
689
+ h_bar = math.floor(height / beta / factor) * factor
690
+ w_bar = math.floor(width / beta / factor) * factor
691
+ elif h_bar * w_bar < min_pixels:
692
+ beta = math.sqrt(min_pixels / (height * width))
693
+ h_bar = math.ceil(height * beta / factor) * factor
694
+ w_bar = math.ceil(width * beta / factor) * factor
695
+ return h_bar, w_bar
696
+
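+ # smart_resize worked example with the defaults: smart_resize(1000, 700, factor=28)
+ # rounds to (1008, 700); 1008 * 700 = 705,600 pixels lies inside [448*448, 1344*1792],
+ # so no further rescaling is needed.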
697
+ def preprocess(
698
+ self,
699
+ image: Optional[PIL.Image.Image] = None,
700
+ video: Optional[List[PIL.Image.Image]] = None,
701
+ min_pixels: Optional[int] = None,
702
+ max_pixels: Optional[int] = None
703
+ ):
704
+ patch_size = self.vit.config.patch_size
705
+ temporal_patch_size = self.vit.config.temporal_patch_size
706
+ hidden_stride = self.vit.config.hidden_stride
707
+ assert (image is None) ^ (video is None), "Invalid input: expect either image or video"
708
+ if image is not None:
709
+ images = [image]
710
+ else:
711
+ images = video
712
+ images = [image.convert("RGB") if image.mode != 'RGB' else image for image in images]
713
+ width, height = images[0].size
714
+ processed_images = []
715
+ for image in images:
716
+ resized_height, resized_width = self.smart_resize(
717
+ height,
718
+ width,
719
+ factor=patch_size * hidden_stride,
720
+ min_pixels=min_pixels,
721
+ max_pixels=max_pixels,
722
+ )
723
+ new_size = dict(height=resized_height, width=resized_width)
724
+ new_image = self.image_processor.preprocess(image, size=new_size, return_tensors="np")['pixel_values'][0]
725
+ processed_images.append(new_image)
726
+
727
+ patches = np.array(processed_images)
728
+ if patches.shape[0] % temporal_patch_size != 0:
729
+ repeats = np.repeat(patches[-1][np.newaxis], temporal_patch_size - 1, axis=0)
730
+ patches = np.concatenate([patches, repeats], axis=0)
731
+ channel = patches.shape[1]
732
+ grid_t = patches.shape[0] // temporal_patch_size
733
+ grid_h, grid_w = resized_height // patch_size, resized_width // patch_size
734
+ grid_thw = torch.tensor([[grid_t, grid_h, grid_w]])
735
+
736
+ patches = patches.reshape(
737
+ grid_t, temporal_patch_size, channel,
738
+ grid_h // hidden_stride, hidden_stride, patch_size,
739
+ grid_w // hidden_stride, hidden_stride, patch_size,
740
+ )
741
+ patches = patches.transpose(0, 3, 6, 4, 7, 2, 1, 5, 8)
742
+ flatten_patches = patches.reshape(
743
+ grid_t * grid_h * grid_w, channel * temporal_patch_size * patch_size * patch_size
744
+ )
745
+ flatten_patches = torch.tensor(flatten_patches)
746
+
747
+ return flatten_patches, grid_thw
748
+
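+ # preprocess illustrative shapes (assuming patch_size=14, hidden_stride=2,
+ # temporal_patch_size=1): a 1008x700 image yields grid_thw = [[1, 72, 50]] and
+ # flatten_patches of shape (3600, 3 * 1 * 14 * 14) -- one flattened row per 14x14 patch.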
749
+ def forward(
750
+ self, pixel_values, grid_thws
751
+ ) -> torch.Tensor: # [BatchSize, ImageShape] -> [BatchSize, #Token, VocabSize]
752
+ features = self._encode(pixel_values, grid_thws)
753
+ logits = self.head(features)
754
+ tokens = torch.softmax(logits, dim=-1, dtype=torch.float32).to(logits.dtype)
755
+
756
+ token_len, _ = tokens.shape
757
+ padding_tensor = torch.zeros(size=(token_len, len(INDICATOR_IDS)),
758
+ dtype=tokens.dtype,
759
+ device=tokens.device,
760
+ layout=tokens.layout,
761
+ requires_grad=False)
762
+ tokens = torch.cat((tokens, padding_tensor), dim=1)
763
+ return tokens
764
+
765
+
766
+ class OvisPreTrainedModel(PreTrainedModel):
767
+ config_class = Ovis2_5_Config
768
+ base_model_prefix = "ovis2_5"
769
+
770
+
771
+ class Ovis2_5(OvisPreTrainedModel):
772
+ _supports_flash_attn_2 = True
773
+ _supports_flash_attn_3 = True
774
+
775
+ def __init__(self, config: Ovis2_5_Config, *inputs, **kwargs):
776
+ super().__init__(config, *inputs, **kwargs)
777
+ attn_implementation = get_attention_type_by_system()
778
+ print(f"Use {attn_implementation} for LLM!")
779
+ self.llm = AutoModelForCausalLM.from_config(
780
+ self.config.llm_config,
781
+ attn_implementation=attn_implementation,
782
+ )
783
+ assert self.config.hidden_size == self.llm.config.hidden_size, "hidden size mismatch"
784
+ self.text_tokenizer = AutoTokenizer.from_pretrained(self.config.name_or_path)
785
+ self.visual_tokenizer = VisualTokenizer(vit=AutoModel.from_config(self.config.vit_config),
786
+ visual_vocab_size=self.config.visual_vocab_size,
787
+ image_processor_name_or_path=self.config.name_or_path)
788
+
789
+ self.vte = VisualEmbedding(self.config.visual_vocab_size, self.config.hidden_size,
790
+ device=self.visual_tokenizer.vit.device, dtype=self.visual_tokenizer.vit.dtype)
791
+ indicator_token_indices = torch.arange(
792
+ self.config.visual_vocab_size - len(INDICATOR_IDS),
793
+ self.config.visual_vocab_size,
794
+ dtype=torch.long
795
+ )
796
+ self.register_buffer("indicator_token_indices", indicator_token_indices, persistent=False)
797
+
798
+ def _merge_modules(modules_list: tuple):
799
+ merged_modules = []
800
+ for modules in modules_list:
801
+ merged_modules.extend(modules if modules else [])
802
+ return merged_modules
803
+
804
+ # Standard model configurations for parallelism and device placement
805
+ self._no_split_modules = _merge_modules(
806
+ (self.llm._no_split_modules, self.visual_tokenizer.vit._no_split_modules))
807
+ self._skip_keys_device_placement = self.llm._skip_keys_device_placement
808
+ self._keep_in_fp32_modules = _merge_modules(
809
+ (self.llm._keep_in_fp32_modules, self.visual_tokenizer.vit._keep_in_fp32_modules))
810
+ self.is_parallelizable = all((self.llm.is_parallelizable, self.visual_tokenizer.vit.is_parallelizable))
811
+ # self.supports_gradient_checkpointing = True
812
+ self.supports_gradient_checkpointing = False
813
+
814
+ def tie_weights(self):
815
+ self.llm.tie_weights()
816
+
817
+ def get_wte(self):
818
+ return self.llm.get_input_embeddings()
819
+
820
+ def forward(
821
+ self,
822
+ input_ids: torch.Tensor,
823
+ attention_mask: torch.Tensor,
824
+ pixel_values: Optional[torch.Tensor],
825
+ grid_thws: Optional[torch.Tensor],
826
+ labels: Optional[torch.Tensor] = None,
827
+ **kwargs
828
+ ):
829
+ inputs_embeds = self.merge_multimodal(
830
+ input_ids=input_ids,
831
+ pixel_values=pixel_values,
832
+ grid_thws=grid_thws,
833
+ )
834
+ return self.llm(inputs_embeds=inputs_embeds, attention_mask=attention_mask, labels=labels, **kwargs)
835
+
836
+ def merge_multimodal(
837
+ self,
838
+ input_ids: torch.Tensor,
839
+ pixel_values: Optional[torch.Tensor],
840
+ grid_thws: Optional[torch.Tensor],
841
+ ):
842
+ placeholder_token_mask = torch.lt(input_ids, 0)
843
+ multimodal_embeds = self.get_wte()(torch.masked_fill(input_ids, placeholder_token_mask, 0))
844
+
845
+ if pixel_values is not None:
846
+ visual_indicator_embeds = self.vte(self.indicator_token_indices).to(
847
+ dtype=multimodal_embeds.dtype, device=multimodal_embeds.device
848
+ )
849
+ visual_tokens = self.visual_tokenizer(pixel_values, grid_thws)
850
+ visual_embeds = self.vte(visual_tokens).to(dtype=multimodal_embeds.dtype, device=multimodal_embeds.device)
851
+
852
+ for i, indicator_id in enumerate(INDICATOR_IDS):
853
+ multimodal_embeds[input_ids == indicator_id] = visual_indicator_embeds[i]
854
+ multimodal_embeds[input_ids == VISUAL_ATOM_ID] = visual_embeds
855
+
856
+ return multimodal_embeds
857
+
858
+ def _merge_inputs(
859
+ self, raw_input_ids, placeholder_id, grid_thws, indicator_begin_id, indicator_end_id
860
+ ):
861
+ input_ids = []
862
+ prev_index = 0
863
+ placeholder_indexes = [i for i, v in enumerate(raw_input_ids) if v == placeholder_id]
864
+ for placeholder_index, grid_thw in zip(placeholder_indexes, grid_thws):
865
+ input_ids.extend(raw_input_ids[prev_index:placeholder_index])
866
+ num_image_atoms = grid_thw.prod().item()
867
+ num_image_atoms //= self.visual_tokenizer.vit.config.hidden_stride ** 2
868
+ num_image_atoms //= self.visual_tokenizer.vit.config.temporal_patch_size
869
+ input_ids.extend([indicator_begin_id] + [VISUAL_ATOM_ID] * num_image_atoms + [indicator_end_id])
870
+ prev_index = placeholder_index + 1
871
+ input_ids.extend(raw_input_ids[prev_index:])
872
+ return input_ids
873
+
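+ # _merge_inputs example (assuming hidden_stride=2, temporal_patch_size=1):
+ # grid_thw = (1, 28, 28) gives 1*28*28 // 2**2 // 1 = 196 visual atoms, so one
+ # placeholder expands to [indicator_begin_id] + 196 * [VISUAL_ATOM_ID] + [indicator_end_id].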
874
+ def _tokenize_with_visual_placeholder(self, text):
875
+ placeholder = VIDEO_PLACEHOLDER if VIDEO_PLACEHOLDER in text else IMAGE_PLACEHOLDER
876
+ placeholder_id = VIDEO_PLACEHOLDER_ID if VIDEO_PLACEHOLDER in text else IMAGE_PLACEHOLDER_ID
877
+ chunks = [self.text_tokenizer(chunk, add_special_tokens=False).input_ids for chunk in text.split(placeholder)]
878
+ input_ids = chunks[0]
879
+ for chunk in chunks[1:]:
880
+ input_ids.append(placeholder_id)
881
+ input_ids.extend(chunk)
882
+ return input_ids
883
+
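+ # _tokenize_with_visual_placeholder example: "<image>\nDescribe it" becomes
+ # [IMAGE_PLACEHOLDER_ID, *ids("\nDescribe it")]; the negative placeholder id survives
+ # tokenization and is expanded by _merge_inputs.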
884
+ def preprocess_inputs(
885
+ self,
886
+ messages: List[Union[str, Dict]],
887
+ min_pixels=448 * 448,
888
+ max_pixels=1344 * 1792,
889
+ add_generation_prompt=True,
890
+ enable_thinking=False
891
+ ):
892
+ text = self.text_tokenizer.apply_chat_template(
893
+ messages,
894
+ tokenize=False,
895
+ add_generation_prompt=add_generation_prompt,
896
+ enable_thinking=enable_thinking
897
+ )
898
+ input_ids = self._tokenize_with_visual_placeholder(text)
899
+ images = []
900
+ videos = []
901
+ for message in messages:
902
+ content = message["content"]
903
+ if isinstance(content, list):
904
+ images.extend([item["image"] for item in content if item.get("image") is not None])
905
+ videos.extend([item["video"] for item in content if item.get("video") is not None])
906
+ if images and videos:
907
+ raise ValueError(
908
+ "Multiple visual input data types detected (both image and video provided). "
909
+ "This model supports only one type of visual input data at a time. "
910
+ "Please provide either image or video, but not both."
911
+ )
912
+
913
+ pixel_values, grid_thws = None, None
914
+ if images:
915
+ pixel_values, grid_thws = zip(
916
+ *(self.visual_tokenizer.preprocess(image=image, min_pixels=min_pixels, max_pixels=max_pixels)
917
+ for image in images)
918
+ )
919
+ input_ids = self._merge_inputs(
920
+ input_ids, IMAGE_PLACEHOLDER_ID, grid_thws, INDICATOR_IDS[0], INDICATOR_IDS[1]
921
+ )
922
+ pixel_values = torch.cat(pixel_values, dim=0)
923
+ grid_thws = torch.cat(grid_thws, dim=0)
924
+ elif videos:
925
+ assert len(videos) == 1, "only support single video"
926
+ pixel_values, grid_thws = self.visual_tokenizer.preprocess(
927
+ video=videos[0], min_pixels=min_pixels, max_pixels=max_pixels
928
+ )
929
+ input_ids = self._merge_inputs(
930
+ input_ids, VIDEO_PLACEHOLDER_ID, grid_thws, INDICATOR_IDS[2], INDICATOR_IDS[3]
931
+ )
932
+
933
+ input_ids = torch.tensor(input_ids, dtype=torch.long).unsqueeze(0)
934
+
935
+ return input_ids, pixel_values, grid_thws
936
+
937
+ def generate(
938
+ self,
939
+ inputs: Optional[torch.Tensor] = None,
940
+ **kwargs,
941
+ ) -> Union[GenerateOutput, torch.LongTensor]:
942
+ attention_mask = torch.ne(inputs, self.text_tokenizer.pad_token_id).to(device=inputs.device)
943
+ inputs_embeds = self.merge_multimodal(
944
+ input_ids=inputs,
945
+ pixel_values=kwargs.pop('pixel_values', None),
946
+ grid_thws=kwargs.pop('grid_thws', None)
947
+ )
948
+ enable_thinking = kwargs.pop('enable_thinking', False)
949
+ enable_thinking_budget = kwargs.pop('enable_thinking_budget', False)
950
+ thinking_budget = kwargs.pop('thinking_budget', 1024)
951
+
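+ # Two-phase "thinking budget": generate up to thinking_budget tokens first; if the
+ # think phase has not closed, force-append an early-stopping </think> segment and
+ # continue decoding the answer with whatever max_new_tokens budget remains.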
952
+ if enable_thinking and enable_thinking_budget:
953
+ actual_max_new_tokens = kwargs['max_new_tokens']
954
+ kwargs['max_new_tokens'] = thinking_budget
955
+ generated_ids = self.llm.generate(inputs=None, inputs_embeds=inputs_embeds, attention_mask=attention_mask, **kwargs)
956
+ output_ids = generated_ids
957
+ output_ids_list = generated_ids[0]
958
+
959
+ # check if the generation has already finished (151645 is <|im_end|>)
960
+ if 151645 not in output_ids_list:
961
+ # check if the thinking process has finished (151668 is </think>)
962
+ # and prepare the second model input
963
+ if 151668 not in output_ids_list:
964
+ early_stopping_text = "\n\nConsidering the limited time by the user, I have to give the solution based on the thinking directly now.\n</think>\n\n"
965
+ early_stopping_ids = self.text_tokenizer(early_stopping_text, return_tensors="pt", return_attention_mask=False).input_ids.to(inputs.device)
966
+ input_ids_appendent = torch.cat([output_ids, early_stopping_ids], dim=-1)
967
+ if 'streamer' in kwargs: kwargs['streamer'].put(early_stopping_ids)
968
+ else:
969
+ input_ids_appendent = output_ids
970
+
971
+
972
+ # second generation
973
+ new_inputs = torch.cat([inputs, input_ids_appendent], dim=-1)
974
+ attention_mask = torch.ne(new_inputs, self.text_tokenizer.pad_token_id).to(device=inputs.device)
975
+ inputs_embeds_appendent = self.merge_multimodal(
976
+ input_ids=input_ids_appendent,
977
+ pixel_values=None,
978
+ grid_thws=None
979
+ )
980
+ new_inputs_embeds = torch.cat([inputs_embeds, inputs_embeds_appendent], dim=-2)
981
+
982
+ kwargs['max_new_tokens'] = inputs_embeds.size(-2) + actual_max_new_tokens - new_inputs_embeds.size(-2)
983
+ generated_ids2 = self.llm.generate(inputs=None, inputs_embeds=new_inputs_embeds, attention_mask=attention_mask, **kwargs)
984
+ if 'streamer' in kwargs: kwargs['streamer'].manual_end()
985
+ return torch.cat([input_ids_appendent, generated_ids2], dim=-1)
986
+
987
+ else:
988
+ if 'streamer' in kwargs: kwargs['streamer'].manual_end()
989
+ return generated_ids
990
+
991
+ else:
992
+ generated_ids = self.llm.generate(inputs=None, inputs_embeds=inputs_embeds, attention_mask=attention_mask, **kwargs)
993
+ if 'streamer' in kwargs: kwargs['streamer'].manual_end()
994
+ return generated_ids
995
+
996
+
997
+ AutoConfig.register('siglip2_navit', Siglip2NavitConfig)
998
+ AutoModel.register(Siglip2NavitConfig, Siglip2NavitModel)
999
+ AutoConfig.register("ovis2_5", Ovis2_5_Config)
1000
+ AutoModelForCausalLM.register(Ovis2_5_Config, Ovis2_5)
ovis_image/model/tokenizer.py ADDED
@@ -0,0 +1,82 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (C) 2025 AIDC-AI
2
+ # Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License.
3
+ # You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
4
+ # Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.
5
+
6
+ from typing import List
7
+
8
+ import torch
9
+ from transformers import AutoTokenizer
10
+
11
+
12
+ class OvisTokenizer:
13
+ """
14
+ Tokenizes and encodes/decodes text using the Ovis tokenizer.
15
+
16
+ Args:
17
+ model_path (str): Path to the tokenizer on Hugging Face.
18
+
19
+ """
20
+
21
+ def __init__(
22
+ self,
23
+ model_path: str = "Ovis2.5-2B",
24
+ max_length: int = 256,
25
+ **hf_kwargs
26
+ ):
27
+ super().__init__()
28
+ self._tokenizer = AutoTokenizer.from_pretrained(model_path)
29
+ self.system_prompt = "Describe the image by detailing the color, quantity, text, shape, size, texture, spatial relationships of the objects and background: "
30
+ self.user_prompt_begin_id = 28
31
+ self._max_length = max_length + self.user_prompt_begin_id
32
+
33
+ def encode(
34
+ self,
35
+ s: str,
36
+ system_prompt = ""
37
+ ) -> torch.Tensor:
38
+ """
39
+ Encode the prompt text into tokens.
40
+ """
41
+ if len(system_prompt) == 0:
42
+ system_prompt = self.system_prompt
43
+ messages = [{
44
+ "role": "user",
45
+ "content": system_prompt + s,
46
+ }]
47
+ text = self._tokenizer.apply_chat_template(
48
+ messages,
49
+ tokenize=False,
50
+ add_generation_prompt=True,
51
+ enable_thinking=False
52
+ )
53
+ tokens = self._tokenizer(
54
+ text,
55
+ padding="max_length",
56
+ truncation=True,
57
+ max_length=self._max_length,
58
+ return_tensors="pt",
59
+ add_special_tokens=False,
60
+ )
61
+ return tokens.input_ids, tokens.attention_mask
62
+
63
+ def decode(self, t: List[int]) -> str:
64
+ return self._tokenizer.decode(t, skip_special_tokens=False)
65
+
66
+
67
+ def build_ovis_tokenizer(tokenizer_path):
68
+ max_ovis_encoding_len = 256
69
+ ovis_tokenizer = OvisTokenizer(
70
+ tokenizer_path,
71
+ max_length=max_ovis_encoding_len,
72
+ )
73
+ return ovis_tokenizer
74
+
75
+
76
+ if __name__ == "__main__":
77
+ ovis_path = "/mnt/workspace/cv_multimodal/aigc/huggingface/Ovis2.5-2B"
78
+ text = "a cute cat"
79
+ ovis_tokenizer = OvisTokenizer(ovis_path, max_length=256)
80
+ ovis_token = ovis_tokenizer.encode(text)
81
+ input_ids, attention_mask = ovis_token
82
+ print(input_ids.shape, attention_mask.shape)
ovis_image/sampling.py ADDED
@@ -0,0 +1,224 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (C) 2025 AIDC-AI
2
+ # Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License.
3
+ # You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
4
+ # Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.
5
+ import math
6
+ import os
7
+ from typing import Callable, Optional
8
+ from PIL import ExifTags, Image
9
+ import torch
10
+ from torch import Tensor
11
+ from einops import rearrange, repeat
12
+
13
+ from ovis_image.dataset.image_util import build_img_ids
14
+ from ovis_image.model.autoencoder import AutoEncoder
15
+ from ovis_image.model.hf_embedder import OvisEmbedder
16
+ from ovis_image.model.model import OvisImageModel
17
+ from ovis_image.utils import (
18
+ generate_noise_latent,
19
+ pack_latents,
20
+ unpack_latents,
21
+ generate_txt_ids,
22
+ )
23
+
24
+
25
+
26
+ def time_shift(mu: float, sigma: float, t: Tensor):
27
+ return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)
28
+
29
+
30
+ def get_lin_function(
31
+ x1: float = 256, y1: float = 0.5, x2: float = 4096, y2: float = 1.15
32
+ ) -> Callable[[float], float]:
33
+ m = (y2 - y1) / (x2 - x1)
34
+ b = y1 - m * x1
35
+ return lambda x: m * x + b
36
+
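+ # With the defaults this interpolates mu linearly from 0.5 at 256 image tokens to 1.15
+ # at 4096 tokens: slope m = (1.15 - 0.5) / (4096 - 256) ≈ 1.69e-4.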
37
+
38
+ def sample_timesteps(batch_size, image_seq_len=None, base_shift=None, max_shift=None):
39
+ if image_seq_len is None or base_shift is None or max_shift is None:
40
+ logit_mean = 0
41
+ else:
42
+ logit_mean = get_lin_function(y1=base_shift, y2=max_shift)(image_seq_len)
43
+ logit_std = 1.0
44
+ timesteps = torch.normal(
45
+ mean=logit_mean, std=logit_std, size=(batch_size,)
46
+ )
47
+ timesteps = torch.nn.functional.sigmoid(timesteps)
48
+ return timesteps
49
+
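+ # Training-time sampler: draws t from a logit-normal distribution by pushing
+ # N(logit_mean, 1) samples through a sigmoid; logit_mean is shifted with the image
+ # sequence length via get_lin_function when shift parameters are provided.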
50
+
51
+ def get_schedule(
52
+ num_steps: int,
53
+ image_seq_len: int,
54
+ base_shift: float = 0.5,
55
+ max_shift: float = 1.15,
56
+ shift: bool = True,
57
+ ) -> list[float]:
58
+ # extra step for zero
59
+ timesteps = torch.linspace(1, 0, num_steps + 1)
60
+
61
+ # shifting the schedule to favor high timesteps for higher signal images
62
+ if shift:
63
+ # estimate mu based on linear estimation between two points
64
+ mu = get_lin_function(y1=base_shift, y2=max_shift)(image_seq_len)
65
+ timesteps = time_shift(mu, 1.0, timesteps)
66
+
67
+ return timesteps.tolist()
68
+
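+ # get_schedule illustration: get_schedule(4, 4096) uses mu = 1.15 and shifts the
+ # uniform grid [1.0, 0.75, 0.5, 0.25, 0.0] to roughly [1.0, 0.90, 0.76, 0.51, 0.0],
+ # spending more of the step budget at high noise levels.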
69
+
70
+ def generate_image(
71
+ device: torch.device,
72
+ dtype: torch.dtype,
73
+ model: OvisImageModel,
74
+ prompt: str,
75
+ autoencoder: AutoEncoder,
76
+ ovis_tokenizer,
77
+ ovis_encoder: OvisEmbedder,
78
+ img_height: int = 256,
79
+ img_width: int = 256,
80
+ denoising_steps: int = 50,
81
+ cfg_scale: float = 5.0,
82
+ seed: int = 42,
83
+ ) -> torch.Tensor:
84
+ """
85
+ Sample a single image from noise using the given prompt and return the decoded image tensor.
86
+ For randomized noise generation, the random seed should already be set at the beginning of training.
87
+ Since we always use the local random seed on this rank, we don't need to pass in the seed again.
88
+ """
89
+
90
+ # Round the resolution to a multiple of 16 so it is compatible with packing and latent-space conversion, matching training time.
91
+ img_height = 16 * (img_height // 16)
92
+ img_width = 16 * (img_width // 16)
93
+
94
+ enable_classifier_free_guidance = True
95
+
96
+ # Tokenize the prompt. Unsqueeze to add a batch dimension.
97
+ ovis_token_ids, ovis_token_mask = ovis_tokenizer.encode(prompt)
98
+ ovis_encodings = ovis_encoder(
99
+ ovis_token_ids.to(device=device), ovis_token_mask.to(device=device)
100
+ )
101
+
102
+ if enable_classifier_free_guidance:
103
+ empty_ovis_token_ids, empty_ovis_token_mask = ovis_tokenizer.encode("")
104
+ empty_ovis_encodings = ovis_encoder(
105
+ empty_ovis_token_ids.to(device=device), empty_ovis_token_mask.to(device=device)
106
+ )
107
+
108
+ latents = generate_noise_latent(
109
+ ovis_token_ids.shape[0],
110
+ img_height, img_width, device, dtype, seed=seed,
111
+ latent_channel=autoencoder.params.z_channels)
112
+
113
+ img = denoise(
114
+ device=device,
115
+ dtype=dtype,
116
+ model=model,
117
+ latents=latents,
118
+ denoising_steps=denoising_steps,
119
+ ovis_encodings=ovis_encodings,
120
+ enable_classifier_free_guidance=enable_classifier_free_guidance,
121
+ empty_ovis_encodings=(
122
+ empty_ovis_encodings if enable_classifier_free_guidance else None
123
+ ),
124
+ classifier_free_guidance_scale=cfg_scale,
125
+ )
126
+
127
+ img = autoencoder.decode(img)
128
+ return img
129
+
130
+
131
+ def denoise(
132
+ device: torch.device,
133
+ dtype: torch.dtype,
134
+ model: OvisImageModel,
135
+ latents: torch.Tensor,
136
+ denoising_steps: int,
137
+ ovis_encodings: torch.Tensor,
138
+ enable_classifier_free_guidance: bool = False,
139
+ empty_ovis_encodings: torch.Tensor | None = None,
140
+ classifier_free_guidance_scale: float | None = None,
141
+ ) -> torch.Tensor:
142
+ """
143
+ Sampling images from noise using a given prompt, by running inference with trained model.
144
+ Save the generated images to the given output path.
145
+ """
146
+ bsz = ovis_encodings.shape[0]
147
+ _, latent_channels, latent_height, latent_width = latents.shape
148
+
149
+ # create denoising schedule
150
+ timesteps = get_schedule(denoising_steps, latent_height * latent_width, shift=True)
151
+
152
+ # create positional encodings
153
+
154
+ latent_pos_enc = build_img_ids(
155
+ latent_height // 2, latent_width // 2,
156
+ ).to(latents)
157
+ latent_pos_enc = repeat(latent_pos_enc, 'l c -> bsz l c', bsz=bsz)
158
+ ovis_txt_ids = generate_txt_ids(ovis_encodings, time_id=0).to(latents)
159
+
160
+ if enable_classifier_free_guidance:
161
+ ovis_encodings = torch.cat([empty_ovis_encodings, ovis_encodings], dim=0)
162
+ latent_pos_enc = torch.cat([latent_pos_enc, latent_pos_enc], dim=0)
163
+ ovis_txt_ids = torch.cat([ovis_txt_ids, ovis_txt_ids], dim=0)
164
+
165
+ # convert img-like latents into sequences of patches
166
+ latents = pack_latents(latents)
167
+
168
+ # Euler integration over consecutive (t_curr, t_prev) pairs of the shifted schedule
169
+ for t_curr, t_prev in zip(timesteps[:-1], timesteps[1:]):
170
+ if enable_classifier_free_guidance:
171
+ img = torch.cat([latents, latents], dim=0)
172
+ t_vec = torch.full((bsz * 2,), t_curr, dtype=dtype, device=device)
173
+ else:
174
+ img = latents
175
+ t_vec = torch.full((bsz,), t_curr, dtype=dtype, device=device)
176
+ model_pred = model(
177
+ img=img,
178
+ img_ids=latent_pos_enc,
179
+ txt=ovis_encodings,
180
+ txt_ids=ovis_txt_ids,
181
+ timesteps=t_vec,
182
+ )
183
+ if enable_classifier_free_guidance:
184
+ pred_u, pred_c = model_pred.chunk(2)
185
+ pred = pred_u + classifier_free_guidance_scale * (pred_c - pred_u)
186
+ else:
187
+ pred = model_pred
188
+
189
+ latents = latents + (t_prev - t_curr) * pred
190
+
191
+ # convert sequences of patches into img-like latents
192
+ latents = unpack_latents(latents, latent_height, latent_width)
193
+
194
+ return latents
195
+
196
+
197
+
198
+ def save_image(
199
+ name: str,
200
+ output_dir: str,
201
+ x: torch.Tensor,
202
+ add_sampling_metadata: bool,
203
+ prompt: str,
204
+ verbose = True,
205
+ ):
206
+ if verbose:
207
+ print(f"Saving image to {output_dir}/{name}")
208
+ os.makedirs(output_dir, exist_ok=True)
209
+ output_name = os.path.join(output_dir, name)
210
+
211
+ # bring into PIL format and save
212
+ x = x.clamp(-1, 1)
213
+ x = rearrange(x[0], "c h w -> h w c")
214
+
215
+ img = Image.fromarray((127.5 * (x + 1.0)).cpu().byte().numpy())
216
+
217
+ exif_data = Image.Exif()
218
+ exif_data[ExifTags.Base.Software] = "AI generated;txt2img"
219
+ exif_data[ExifTags.Base.Make] = "Ovis"
220
+ exif_data[ExifTags.Base.Model] = name
221
+ if add_sampling_metadata:
222
+ exif_data[ExifTags.Base.ImageDescription] = prompt
223
+ img.save(output_name, exif=exif_data, quality=95, subsampling=0)
224
+
ovis_image/test.py ADDED
@@ -0,0 +1,92 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (C) 2025 AIDC-AI
2
+ # Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License.
3
+ # You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
4
+ # Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.
5
+
6
+ import argparse
7
+
8
+ import torch
9
+ from safetensors.torch import load_file
10
+
11
+ from ovis_image.model.tokenizer import build_ovis_tokenizer
12
+ from ovis_image.model.autoencoder import load_ae
13
+ from ovis_image.model.hf_embedder import OvisEmbedder
14
+ from ovis_image.model.model import OvisImageModel
15
+ from ovis_image.sampling import generate_image, save_image
16
+ from ovis_image import ovis_image_configs
17
+
18
+ def parse_args():
19
+ parser = argparse.ArgumentParser()
20
+ parser.add_argument('--model_path', type=str, required=True)
21
+ parser.add_argument('--ovis_path', type=str, default="")
22
+ parser.add_argument('--vae_path', type=str, default="")
23
+ parser.add_argument('--prompt', type=str, default="")
24
+ parser.add_argument('--image_size', type=int, default=1024)
25
+ parser.add_argument('--denoising_steps', type=int, default=50)
26
+ parser.add_argument('--cfg_scale', type=float, default=5.0)
27
+ args = parser.parse_args()
28
+ return args
29
+
30
+ def load_model_weight(model, model_path):
31
+ model_state_dict = load_file(model_path)
32
+ missing_keys, unexpected_keys = model.load_state_dict(model_state_dict)
33
+ print(f"Load Missing Keys {missing_keys}")
34
+ print(f"Load Unexpected Keys {unexpected_keys}")
35
+ return model
36
+
37
+
38
+ def main():
39
+ args = parse_args()
40
+ model_config = ovis_image_configs["ovis-image-7b"]
41
+ device = "cuda"
42
+ _dtype = torch.bfloat16
43
+ print(f"dtype: {_dtype}")
44
+ ovis_image = OvisImageModel(model_config)
45
+ ovis_image = load_model_weight(ovis_image, args.model_path)
46
+ ovis_image = ovis_image.to(device=device, dtype=_dtype)
47
+ ovis_image.eval()
48
+
49
+ ovis_tokenizer = build_ovis_tokenizer(args.ovis_path)
50
+ autoencoder = load_ae(
51
+ args.vae_path,
52
+ model_config.autoencoder_params,
53
+ device=device,
54
+ dtype=_dtype,
55
+ random_init=False,
56
+ )
57
+ autoencoder.eval()
58
+ ovis_encoder = OvisEmbedder(
59
+ model_path=args.ovis_path,
60
+ random_init=False,
61
+ low_cpu_mem_usage=True,
62
+ torch_dtype=torch.bfloat16,
63
+ ).to(device=device, dtype=_dtype)
64
+
65
+ with torch.no_grad():
66
+ image = generate_image(
67
+ device=device,
68
+ dtype=_dtype,
69
+ model=ovis_image,
70
+ prompt=args.prompt,
71
+ autoencoder=autoencoder,
72
+ ovis_tokenizer=ovis_tokenizer,
73
+ ovis_encoder=ovis_encoder,
74
+ img_height=args.image_size,
75
+ img_width=args.image_size,
76
+ denoising_steps=args.denoising_steps,
77
+ cfg_scale=args.cfg_scale,
78
+ seed=42,
79
+ )
80
+ image_name = f"ovis_image.png"
81
+ save_image(
82
+ name=image_name,
83
+ output_dir="outputs",
84
+ x=image,
85
+ add_sampling_metadata=True,
86
+ prompt=args.prompt,
87
+ verbose=False,
88
+ )
89
+
90
+
91
+ if __name__ == "__main__":
92
+ main()
ovis_image/utils.py ADDED
@@ -0,0 +1,112 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (C) 2025 AIDC-AI
2
+ # Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License.
3
+ # You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
4
+ # Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.
5
+
6
+ import torch
7
+ from torch import Tensor
8
+
9
+ def generate_txt_ids(encodings, time_id=0):
10
+ txt_ids = torch.zeros(encodings.shape[0], encodings.shape[1], 3)
11
+ txt_ids[..., 1] = txt_ids[..., 1] + torch.arange(encodings.shape[1])[None, :]
12
+ txt_ids[..., 2] = txt_ids[..., 2] + torch.arange(encodings.shape[1])[None, :]
13
+ txt_ids[..., 0] = time_id
14
+ return txt_ids
15
+
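+ # generate_txt_ids shapes: for encodings of shape (B, L, D) this returns (B, L, 3)
+ # position ids whose axes 1 and 2 both carry arange(L) and whose axis 0 is the
+ # constant time_id.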
16
+
17
+ def generate_noise_latent(
18
+ bsz: int,
19
+ height: int,
20
+ width: int,
21
+ device: str | torch.device,
22
+ dtype: torch.dtype,
23
+ seed: int | None = None,
24
+ latent_channel = None,
25
+ ) -> Tensor:
26
+ """Generate noise latents for the flow model. The random seed will be set at the begining of training.
27
+
28
+ Args:
29
+ bsz (int): batch_size.
30
+ height (int): The height of the image.
31
+ width (int): The width of the image.
32
+ device (str | torch.device): The device to use.
33
+ dtype (torch.dtype): The dtype to use.
34
+
35
+ Returns:
36
+ Tensor: The noise latents.
37
+ Shape: [num_samples, LATENT_CHANNELS, height // IMG_LATENT_SIZE_RATIO, width // IMG_LATENT_SIZE_RATIO]
38
+
39
+ """
40
+ LATENT_CHANNELS, IMAGE_LATENT_SIZE_RATIO = 16, 8
41
+ if latent_channel is not None:
42
+ LATENT_CHANNELS = latent_channel
43
+ return torch.randn(
44
+ bsz,
45
+ LATENT_CHANNELS,
46
+ height // IMAGE_LATENT_SIZE_RATIO,
47
+ width // IMAGE_LATENT_SIZE_RATIO,
48
+ dtype=dtype,
49
+ generator=torch.Generator().manual_seed(seed) if seed is not None else None,
50
+ ).to(device)
51
+
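+ # e.g. generate_noise_latent(1, 1024, 1024, "cuda", torch.bfloat16, seed=42) has shape
+ # [1, 16, 128, 128]; seeding a CPU Generator keeps the draw reproducible across devices.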
52
+
53
+ def pack_latents(x: Tensor) -> Tensor:
54
+ """
55
+ Rearrange latents from an image-like format into a sequence of patches.
56
+ Equivalent to `einops.rearrange("b c (h ph) (w pw) -> b (h w) (c ph pw)")`.
57
+
58
+ Args:
59
+ x (Tensor): The unpacked latents.
60
+ Shape: [bsz, ch, latent height, latent width]
61
+
62
+ Returns:
63
+ Tensor: The packed latents.
64
+ Shape: (bsz, (latent_height // ph) * (latent_width // pw), ch * ph * pw)
65
+ """
66
+ PATCH_HEIGHT, PATCH_WIDTH = 2, 2
67
+
68
+ b, c, latent_height, latent_width = x.shape
69
+ h = latent_height // PATCH_HEIGHT
70
+ w = latent_width // PATCH_WIDTH
71
+
72
+ # [b, c, h*ph, w*pw] -> [b, c, h, w, ph, pw]
73
+ x = x.unfold(2, PATCH_HEIGHT, PATCH_HEIGHT).unfold(3, PATCH_WIDTH, PATCH_WIDTH)
74
+
75
+ # [b, c, h, w, ph, pw] -> [b, h, w, c, ph, pw]
76
+ x = x.permute(0, 2, 3, 1, 4, 5)
77
+
78
+ # [b, h, w, c, ph, pw] -> [b, h*w, c*ph*pw]
79
+ return x.reshape(b, h * w, c * PATCH_HEIGHT * PATCH_WIDTH)
80
+
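+ # pack_latents shape example: [1, 16, 128, 128] gives [1, 64*64, 16*2*2] = [1, 4096, 64].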
81
+
82
+ def unpack_latents(x: Tensor, latent_height: int, latent_width: int) -> Tensor:
83
+ """
84
+ Rearrange latents from a sequence of patches into an image-like format.
85
+ Equivalent to `einops.rearrange("b (h w) (c ph pw) -> b c (h ph) (w pw)")`.
86
+
87
+ Args:
88
+ x (Tensor): The packed latents.
89
+ Shape: (bsz, (latent_height // ph) * (latent_width // pw), ch * ph * pw)
90
+ latent_height (int): The height of the unpacked latents.
91
+ latent_width (int): The width of the unpacked latents.
92
+
93
+ Returns:
94
+ Tensor: The unpacked latents.
95
+ Shape: [bsz, ch, latent height, latent width]
96
+ """
97
+ PATCH_HEIGHT, PATCH_WIDTH = 2, 2
98
+
99
+ b, _, c_ph_pw = x.shape
100
+ h = latent_height // PATCH_HEIGHT
101
+ w = latent_width // PATCH_WIDTH
102
+ c = c_ph_pw // (PATCH_HEIGHT * PATCH_WIDTH)
103
+
104
+ # [b, h*w, c*ph*pw] -> [b, h, w, c, ph, pw]
105
+ x = x.reshape(b, h, w, c, PATCH_HEIGHT, PATCH_WIDTH)
106
+
107
+ # [b, h, w, c, ph, pw] -> [b, c, h, ph, w, pw]
108
+ x = x.permute(0, 3, 1, 4, 2, 5)
109
+
110
+ # [b, c, h, ph, w, pw] -> [b, c, h*ph, w*pw]
111
+ return x.reshape(b, c, h * PATCH_HEIGHT, w * PATCH_WIDTH)
112
+
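+ # unpack_latents inverts pack_latents: unpack_latents(pack_latents(x), H, W) recovers x
+ # whenever H and W are multiples of the 2x2 patch size.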