Upload 14 files
Browse files- ovis_image/__init__.py +44 -0
- ovis_image/dataset/image_util.py +125 -0
- ovis_image/model/args.py +29 -0
- ovis_image/model/autoencoder.py +402 -0
- ovis_image/model/hf_embedder.py +50 -0
- ovis_image/model/layers.py +407 -0
- ovis_image/model/model.py +121 -0
- ovis_image/model/ops.py +102 -0
- ovis_image/model/ovis/configuration_ovis2_5.py +96 -0
- ovis_image/model/ovis/modeling_ovis2_5.py +1000 -0
- ovis_image/model/tokenizer.py +82 -0
- ovis_image/sampling.py +224 -0
- ovis_image/test.py +92 -0
- ovis_image/utils.py +112 -0
ovis_image/__init__.py
ADDED
|
@@ -0,0 +1,44 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (C) 2025 AIDC-AI
|
| 2 |
+
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License.
|
| 3 |
+
# You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
|
| 4 |
+
# Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.
|
| 5 |
+
from ovis_image.model.args import OvisImageModelArgs
|
| 6 |
+
from ovis_image.model.autoencoder import AutoEncoderParams
|
| 7 |
+
from ovis_image.model.model import OvisImageModel
|
| 8 |
+
|
| 9 |
+
__all__ = [
|
| 10 |
+
"OvisImageModelArgs",
|
| 11 |
+
"OvisImageModel",
|
| 12 |
+
"ovis_image_configs",
|
| 13 |
+
]
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
ovis_image_configs = {
|
| 17 |
+
"ovis-image-7b": OvisImageModelArgs(
|
| 18 |
+
in_channels=64,
|
| 19 |
+
out_channels=64,
|
| 20 |
+
context_in_dim=2048,
|
| 21 |
+
hidden_size=3072,
|
| 22 |
+
mlp_ratio=4.0,
|
| 23 |
+
num_heads=24,
|
| 24 |
+
depth=6,
|
| 25 |
+
double_block_type="DoubleStreamBlock",
|
| 26 |
+
depth_single_blocks=27,
|
| 27 |
+
axes_dim=(16, 56, 56),
|
| 28 |
+
theta=10_000,
|
| 29 |
+
qkv_bias=True,
|
| 30 |
+
activation = "swiglu",
|
| 31 |
+
autoencoder_params=AutoEncoderParams(
|
| 32 |
+
resolution=256,
|
| 33 |
+
in_channels=3,
|
| 34 |
+
ch=128,
|
| 35 |
+
out_ch=3,
|
| 36 |
+
ch_mult=(1, 2, 4, 4),
|
| 37 |
+
num_res_blocks=2,
|
| 38 |
+
z_channels=16,
|
| 39 |
+
scale_factor=0.3611,
|
| 40 |
+
shift_factor=0.1159,
|
| 41 |
+
),
|
| 42 |
+
),
|
| 43 |
+
}
|
| 44 |
+
|
ovis_image/dataset/image_util.py
ADDED
|
@@ -0,0 +1,125 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (C) 2025 AIDC-AI
|
| 2 |
+
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License.
|
| 3 |
+
# You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
|
| 4 |
+
# Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.
|
| 5 |
+
import math
|
| 6 |
+
import torch
|
| 7 |
+
import torchvision
|
| 8 |
+
from torchvision import transforms
|
| 9 |
+
from einops import rearrange, repeat
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
def ceil_to(x, factor=16):
|
| 13 |
+
return math.ceil(float(x) / factor) * factor
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
def build_img_ids(
|
| 17 |
+
latent_height,
|
| 18 |
+
latent_width,
|
| 19 |
+
latent_crop_height = None,
|
| 20 |
+
latent_crop_width = None,
|
| 21 |
+
time = 0,
|
| 22 |
+
):
|
| 23 |
+
if latent_crop_height is None:
|
| 24 |
+
latent_crop_height = latent_height
|
| 25 |
+
if latent_crop_width is None:
|
| 26 |
+
latent_crop_width = latent_width
|
| 27 |
+
img_ids = torch.zeros(latent_height, latent_width, 3)
|
| 28 |
+
img_ids[..., 1] = img_ids[..., 1] + torch.arange(latent_height)[:, None]
|
| 29 |
+
img_ids[..., 2] = img_ids[..., 2] + torch.arange(latent_width)[None, :]
|
| 30 |
+
# crop
|
| 31 |
+
crop_h = (latent_height - latent_crop_height) // 2
|
| 32 |
+
crop_w = (latent_width - latent_crop_width) // 2
|
| 33 |
+
img_ids = img_ids[crop_h:crop_h+latent_crop_height, crop_w:crop_w+latent_crop_width]
|
| 34 |
+
img_ids[..., 0] = time
|
| 35 |
+
h, w, c = img_ids.shape
|
| 36 |
+
img_ids = img_ids.reshape(h * w, c)
|
| 37 |
+
return img_ids
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
def process_pil_img_to_tensor(
|
| 41 |
+
pil_img,
|
| 42 |
+
output_size: int | None = 256,
|
| 43 |
+
output_width: int | None = None,
|
| 44 |
+
output_height: int | None = None,
|
| 45 |
+
with_position_ids: bool = False,
|
| 46 |
+
position_ids_time: int = 0,
|
| 47 |
+
):
|
| 48 |
+
width, height = pil_img.size
|
| 49 |
+
if output_width is None or output_height is None:
|
| 50 |
+
output_width = output_size
|
| 51 |
+
output_height = output_size
|
| 52 |
+
assert output_height % 16 == 0
|
| 53 |
+
assert output_width % 16 == 0
|
| 54 |
+
resize_ratio = max(
|
| 55 |
+
float(output_width)/width,
|
| 56 |
+
float(output_height)/height
|
| 57 |
+
)
|
| 58 |
+
resize_size = (
|
| 59 |
+
ceil_to(resize_ratio * height, 16),
|
| 60 |
+
ceil_to(resize_ratio * width, 16)
|
| 61 |
+
)
|
| 62 |
+
pil_resize_img = torchvision.transforms.functional.resize(
|
| 63 |
+
pil_img, resize_size, interpolation=transforms.InterpolationMode.BICUBIC
|
| 64 |
+
)
|
| 65 |
+
pil_crop_img = torchvision.transforms.functional.center_crop(
|
| 66 |
+
pil_resize_img, (output_height, output_width)
|
| 67 |
+
)
|
| 68 |
+
image_tensor = torchvision.transforms.functional.to_tensor(pil_crop_img)
|
| 69 |
+
image_tensor = torchvision.transforms.functional.normalize(
|
| 70 |
+
image_tensor, mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]
|
| 71 |
+
)
|
| 72 |
+
if with_position_ids:
|
| 73 |
+
img_ids = build_img_ids(
|
| 74 |
+
latent_height = resize_size[0] // 16,
|
| 75 |
+
latent_width = resize_size[1] // 16,
|
| 76 |
+
latent_crop_height = output_height // 16,
|
| 77 |
+
latent_crop_width = output_width // 16,
|
| 78 |
+
time = position_ids_time,
|
| 79 |
+
)
|
| 80 |
+
else:
|
| 81 |
+
img_ids = None
|
| 82 |
+
return pil_crop_img, image_tensor, img_ids
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
def pack_latent_to_token(
|
| 86 |
+
latent,
|
| 87 |
+
):
|
| 88 |
+
token = rearrange(
|
| 89 |
+
latent,
|
| 90 |
+
"b c (h ph) (w pw) -> b (h w) (c ph pw)",
|
| 91 |
+
ph=2,
|
| 92 |
+
pw=2
|
| 93 |
+
)
|
| 94 |
+
return token
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
def unpack_token_to_latent(
|
| 98 |
+
token,
|
| 99 |
+
image_height: int | None = None,
|
| 100 |
+
latent_height: int | None = None,
|
| 101 |
+
image_width: int | None = None,
|
| 102 |
+
latent_width: int | None = None,
|
| 103 |
+
):
|
| 104 |
+
if image_height is not None:
|
| 105 |
+
h = math.ceil(image_height / 16)
|
| 106 |
+
elif latent_height is not None:
|
| 107 |
+
h = latent_height // 2
|
| 108 |
+
else:
|
| 109 |
+
raise ValueError(f"both {image_height} and {latent_height} are None")
|
| 110 |
+
if image_width is not None:
|
| 111 |
+
w = math.ceil(image_width / 16)
|
| 112 |
+
elif latent_width is not None:
|
| 113 |
+
w = latent_width // 2
|
| 114 |
+
else:
|
| 115 |
+
raise ValueError(f"both {image_width} and {latent_width} are None")
|
| 116 |
+
return rearrange(
|
| 117 |
+
token,
|
| 118 |
+
"b (h w) (c ph pw) -> b c (h ph) (w pw)",
|
| 119 |
+
h=h,
|
| 120 |
+
w=w,
|
| 121 |
+
ph=2,
|
| 122 |
+
pw=2,
|
| 123 |
+
)
|
| 124 |
+
|
| 125 |
+
|
ovis_image/model/args.py
ADDED
|
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (C) 2025 AIDC-AI
|
| 2 |
+
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License.
|
| 3 |
+
# You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
|
| 4 |
+
# Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.
|
| 5 |
+
|
| 6 |
+
from dataclasses import dataclass, field
|
| 7 |
+
|
| 8 |
+
from ovis_image.model.autoencoder import AutoEncoderParams
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
@dataclass
|
| 12 |
+
class OvisImageModelArgs:
|
| 13 |
+
in_channels: int = 64
|
| 14 |
+
out_channels: int = 64
|
| 15 |
+
context_in_dim: int = 512
|
| 16 |
+
hidden_size: int = 3072
|
| 17 |
+
mlp_ratio: float = 4.0
|
| 18 |
+
num_heads: int = 24
|
| 19 |
+
depth: int = 19
|
| 20 |
+
double_block_type: str = "DoubleStreamBlock"
|
| 21 |
+
depth_single_blocks: int = 38
|
| 22 |
+
axes_dim: tuple = (16, 56, 56)
|
| 23 |
+
theta: int = 10_000
|
| 24 |
+
qkv_bias: bool = True
|
| 25 |
+
activation: str = "gelu_tanh"
|
| 26 |
+
"""activation: gelu_tanh or swiglu"""
|
| 27 |
+
norm: str = "layernorm"
|
| 28 |
+
"""norm: layernorm or rmsnorm"""
|
| 29 |
+
autoencoder_params: AutoEncoderParams = field(default_factory=AutoEncoderParams)
|
ovis_image/model/autoencoder.py
ADDED
|
@@ -0,0 +1,402 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (C) 2025 AIDC-AI
|
| 2 |
+
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License.
|
| 3 |
+
# You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
|
| 4 |
+
# Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.
|
| 5 |
+
|
| 6 |
+
import os
|
| 7 |
+
from dataclasses import dataclass
|
| 8 |
+
|
| 9 |
+
import torch
|
| 10 |
+
from einops import rearrange
|
| 11 |
+
from safetensors.torch import load_file as load_sft
|
| 12 |
+
from torch import nn, Tensor
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
@dataclass
|
| 16 |
+
class AutoEncoderParams:
|
| 17 |
+
resolution: int = 256
|
| 18 |
+
in_channels: int = 3
|
| 19 |
+
ch: int = 128
|
| 20 |
+
out_ch: int = 3
|
| 21 |
+
ch_mult: tuple[int] = (1, 2, 4, 4)
|
| 22 |
+
num_res_blocks: int = 2
|
| 23 |
+
z_channels: int = 16
|
| 24 |
+
scale_factor: float = 0.3611
|
| 25 |
+
shift_factor: float = 0.1159
|
| 26 |
+
use_quant_conv: bool = False
|
| 27 |
+
use_post_quant_conv: bool = False
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
def swish(x: Tensor) -> Tensor:
|
| 31 |
+
return x * torch.sigmoid(x)
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
class AttnBlock(nn.Module):
|
| 35 |
+
def __init__(self, in_channels: int):
|
| 36 |
+
super().__init__()
|
| 37 |
+
self.in_channels = in_channels
|
| 38 |
+
|
| 39 |
+
self.norm = nn.GroupNorm(
|
| 40 |
+
num_groups=32, num_channels=in_channels, eps=1e-6, affine=True
|
| 41 |
+
)
|
| 42 |
+
|
| 43 |
+
self.q = nn.Conv2d(in_channels, in_channels, kernel_size=1)
|
| 44 |
+
self.k = nn.Conv2d(in_channels, in_channels, kernel_size=1)
|
| 45 |
+
self.v = nn.Conv2d(in_channels, in_channels, kernel_size=1)
|
| 46 |
+
self.proj_out = nn.Conv2d(in_channels, in_channels, kernel_size=1)
|
| 47 |
+
|
| 48 |
+
def attention(self, h_: Tensor) -> Tensor:
|
| 49 |
+
h_ = self.norm(h_)
|
| 50 |
+
q = self.q(h_)
|
| 51 |
+
k = self.k(h_)
|
| 52 |
+
v = self.v(h_)
|
| 53 |
+
|
| 54 |
+
b, c, h, w = q.shape
|
| 55 |
+
q = rearrange(q, "b c h w -> b 1 (h w) c").contiguous()
|
| 56 |
+
k = rearrange(k, "b c h w -> b 1 (h w) c").contiguous()
|
| 57 |
+
v = rearrange(v, "b c h w -> b 1 (h w) c").contiguous()
|
| 58 |
+
h_ = nn.functional.scaled_dot_product_attention(q, k, v)
|
| 59 |
+
return rearrange(h_, "b 1 (h w) c -> b c h w", h=h, w=w, c=c, b=b)
|
| 60 |
+
|
| 61 |
+
def forward(self, x: Tensor) -> Tensor:
|
| 62 |
+
return x + self.proj_out(self.attention(x))
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
class ResnetBlock(nn.Module):
|
| 66 |
+
def __init__(self, in_channels: int, out_channels: int):
|
| 67 |
+
super().__init__()
|
| 68 |
+
self.in_channels = in_channels
|
| 69 |
+
out_channels = in_channels if out_channels is None else out_channels
|
| 70 |
+
self.out_channels = out_channels
|
| 71 |
+
|
| 72 |
+
self.norm1 = nn.GroupNorm(
|
| 73 |
+
num_groups=32, num_channels=in_channels, eps=1e-6, affine=True
|
| 74 |
+
)
|
| 75 |
+
self.conv1 = nn.Conv2d(
|
| 76 |
+
in_channels, out_channels, kernel_size=3, stride=1, padding=1
|
| 77 |
+
)
|
| 78 |
+
self.norm2 = nn.GroupNorm(
|
| 79 |
+
num_groups=32, num_channels=out_channels, eps=1e-6, affine=True
|
| 80 |
+
)
|
| 81 |
+
self.conv2 = nn.Conv2d(
|
| 82 |
+
out_channels, out_channels, kernel_size=3, stride=1, padding=1
|
| 83 |
+
)
|
| 84 |
+
if self.in_channels != self.out_channels:
|
| 85 |
+
self.nin_shortcut = nn.Conv2d(
|
| 86 |
+
in_channels, out_channels, kernel_size=1, stride=1, padding=0
|
| 87 |
+
)
|
| 88 |
+
|
| 89 |
+
def forward(self, x):
|
| 90 |
+
h = x
|
| 91 |
+
h = self.norm1(h)
|
| 92 |
+
h = swish(h)
|
| 93 |
+
h = self.conv1(h)
|
| 94 |
+
|
| 95 |
+
h = self.norm2(h)
|
| 96 |
+
h = swish(h)
|
| 97 |
+
h = self.conv2(h)
|
| 98 |
+
|
| 99 |
+
if self.in_channels != self.out_channels:
|
| 100 |
+
x = self.nin_shortcut(x)
|
| 101 |
+
|
| 102 |
+
return x + h
|
| 103 |
+
|
| 104 |
+
|
| 105 |
+
class Downsample(nn.Module):
|
| 106 |
+
def __init__(self, in_channels: int):
|
| 107 |
+
super().__init__()
|
| 108 |
+
# no asymmetric padding in torch conv, must do it ourselves
|
| 109 |
+
self.conv = nn.Conv2d(
|
| 110 |
+
in_channels, in_channels, kernel_size=3, stride=2, padding=0
|
| 111 |
+
)
|
| 112 |
+
|
| 113 |
+
def forward(self, x: Tensor):
|
| 114 |
+
pad = (0, 1, 0, 1)
|
| 115 |
+
x = nn.functional.pad(x, pad, mode="constant", value=0)
|
| 116 |
+
x = self.conv(x)
|
| 117 |
+
return x
|
| 118 |
+
|
| 119 |
+
|
| 120 |
+
class Upsample(nn.Module):
|
| 121 |
+
def __init__(self, in_channels: int):
|
| 122 |
+
super().__init__()
|
| 123 |
+
self.conv = nn.Conv2d(
|
| 124 |
+
in_channels, in_channels, kernel_size=3, stride=1, padding=1
|
| 125 |
+
)
|
| 126 |
+
|
| 127 |
+
def forward(self, x: Tensor):
|
| 128 |
+
x = nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
|
| 129 |
+
x = self.conv(x)
|
| 130 |
+
return x
|
| 131 |
+
|
| 132 |
+
|
| 133 |
+
class Encoder(nn.Module):
|
| 134 |
+
def __init__(
|
| 135 |
+
self,
|
| 136 |
+
resolution: int,
|
| 137 |
+
in_channels: int,
|
| 138 |
+
ch: int,
|
| 139 |
+
ch_mult: list[int],
|
| 140 |
+
num_res_blocks: int,
|
| 141 |
+
z_channels: int,
|
| 142 |
+
):
|
| 143 |
+
super().__init__()
|
| 144 |
+
self.ch = ch
|
| 145 |
+
self.num_resolutions = len(ch_mult)
|
| 146 |
+
self.num_res_blocks = num_res_blocks
|
| 147 |
+
self.resolution = resolution
|
| 148 |
+
self.in_channels = in_channels
|
| 149 |
+
# downsampling
|
| 150 |
+
self.conv_in = nn.Conv2d(
|
| 151 |
+
in_channels, self.ch, kernel_size=3, stride=1, padding=1
|
| 152 |
+
)
|
| 153 |
+
|
| 154 |
+
curr_res = resolution
|
| 155 |
+
in_ch_mult = (1,) + tuple(ch_mult)
|
| 156 |
+
self.in_ch_mult = in_ch_mult
|
| 157 |
+
self.down = nn.ModuleList()
|
| 158 |
+
block_in = self.ch
|
| 159 |
+
for i_level in range(self.num_resolutions):
|
| 160 |
+
block = nn.ModuleList()
|
| 161 |
+
attn = nn.ModuleList()
|
| 162 |
+
block_in = ch * in_ch_mult[i_level]
|
| 163 |
+
block_out = ch * ch_mult[i_level]
|
| 164 |
+
for _ in range(self.num_res_blocks):
|
| 165 |
+
block.append(ResnetBlock(in_channels=block_in, out_channels=block_out))
|
| 166 |
+
block_in = block_out
|
| 167 |
+
down = nn.Module()
|
| 168 |
+
down.block = block
|
| 169 |
+
down.attn = attn
|
| 170 |
+
if i_level != self.num_resolutions - 1:
|
| 171 |
+
down.downsample = Downsample(block_in)
|
| 172 |
+
curr_res = curr_res // 2
|
| 173 |
+
self.down.append(down)
|
| 174 |
+
|
| 175 |
+
# middle
|
| 176 |
+
self.mid = nn.Module()
|
| 177 |
+
self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in)
|
| 178 |
+
self.mid.attn_1 = AttnBlock(block_in)
|
| 179 |
+
self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in)
|
| 180 |
+
|
| 181 |
+
# end
|
| 182 |
+
self.norm_out = nn.GroupNorm(
|
| 183 |
+
num_groups=32, num_channels=block_in, eps=1e-6, affine=True
|
| 184 |
+
)
|
| 185 |
+
self.conv_out = nn.Conv2d(
|
| 186 |
+
block_in, 2 * z_channels, kernel_size=3, stride=1, padding=1
|
| 187 |
+
)
|
| 188 |
+
|
| 189 |
+
def forward(self, x: Tensor) -> Tensor:
|
| 190 |
+
# downsampling
|
| 191 |
+
hs = [self.conv_in(x)]
|
| 192 |
+
for i_level in range(self.num_resolutions):
|
| 193 |
+
for i_block in range(self.num_res_blocks):
|
| 194 |
+
h = self.down[i_level].block[i_block](hs[-1])
|
| 195 |
+
if len(self.down[i_level].attn) > 0:
|
| 196 |
+
h = self.down[i_level].attn[i_block](h)
|
| 197 |
+
hs.append(h)
|
| 198 |
+
if i_level != self.num_resolutions - 1:
|
| 199 |
+
hs.append(self.down[i_level].downsample(hs[-1]))
|
| 200 |
+
|
| 201 |
+
# middle
|
| 202 |
+
h = hs[-1]
|
| 203 |
+
h = self.mid.block_1(h)
|
| 204 |
+
h = self.mid.attn_1(h)
|
| 205 |
+
h = self.mid.block_2(h)
|
| 206 |
+
# end
|
| 207 |
+
h = self.norm_out(h)
|
| 208 |
+
h = swish(h)
|
| 209 |
+
h = self.conv_out(h)
|
| 210 |
+
return h
|
| 211 |
+
|
| 212 |
+
|
| 213 |
+
class Decoder(nn.Module):
|
| 214 |
+
def __init__(
|
| 215 |
+
self,
|
| 216 |
+
ch: int,
|
| 217 |
+
out_ch: int,
|
| 218 |
+
ch_mult: list[int],
|
| 219 |
+
num_res_blocks: int,
|
| 220 |
+
in_channels: int,
|
| 221 |
+
resolution: int,
|
| 222 |
+
z_channels: int,
|
| 223 |
+
):
|
| 224 |
+
super().__init__()
|
| 225 |
+
self.ch = ch
|
| 226 |
+
self.num_resolutions = len(ch_mult)
|
| 227 |
+
self.num_res_blocks = num_res_blocks
|
| 228 |
+
self.resolution = resolution
|
| 229 |
+
self.in_channels = in_channels
|
| 230 |
+
self.ffactor = 2 ** (self.num_resolutions - 1)
|
| 231 |
+
|
| 232 |
+
# compute in_ch_mult, block_in and curr_res at lowest res
|
| 233 |
+
block_in = ch * ch_mult[self.num_resolutions - 1]
|
| 234 |
+
curr_res = resolution // 2 ** (self.num_resolutions - 1)
|
| 235 |
+
self.z_shape = (1, z_channels, curr_res, curr_res)
|
| 236 |
+
|
| 237 |
+
# z to block_in
|
| 238 |
+
self.conv_in = nn.Conv2d(
|
| 239 |
+
z_channels, block_in, kernel_size=3, stride=1, padding=1
|
| 240 |
+
)
|
| 241 |
+
|
| 242 |
+
# middle
|
| 243 |
+
self.mid = nn.Module()
|
| 244 |
+
self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in)
|
| 245 |
+
self.mid.attn_1 = AttnBlock(block_in)
|
| 246 |
+
self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in)
|
| 247 |
+
|
| 248 |
+
# upsampling
|
| 249 |
+
self.up = nn.ModuleList()
|
| 250 |
+
for i_level in reversed(range(self.num_resolutions)):
|
| 251 |
+
block = nn.ModuleList()
|
| 252 |
+
attn = nn.ModuleList()
|
| 253 |
+
block_out = ch * ch_mult[i_level]
|
| 254 |
+
for _ in range(self.num_res_blocks + 1):
|
| 255 |
+
block.append(ResnetBlock(in_channels=block_in, out_channels=block_out))
|
| 256 |
+
block_in = block_out
|
| 257 |
+
up = nn.Module()
|
| 258 |
+
up.block = block
|
| 259 |
+
up.attn = attn
|
| 260 |
+
if i_level != 0:
|
| 261 |
+
up.upsample = Upsample(block_in)
|
| 262 |
+
curr_res = curr_res * 2
|
| 263 |
+
self.up.insert(0, up) # prepend to get consistent order
|
| 264 |
+
|
| 265 |
+
# end
|
| 266 |
+
self.norm_out = nn.GroupNorm(
|
| 267 |
+
num_groups=32, num_channels=block_in, eps=1e-6, affine=True
|
| 268 |
+
)
|
| 269 |
+
self.conv_out = nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1)
|
| 270 |
+
|
| 271 |
+
def forward(self, z: Tensor) -> Tensor:
|
| 272 |
+
# get dtype for proper tracing
|
| 273 |
+
upscale_dtype = next(self.up.parameters()).dtype
|
| 274 |
+
|
| 275 |
+
# z to block_in
|
| 276 |
+
h = self.conv_in(z)
|
| 277 |
+
|
| 278 |
+
# middle
|
| 279 |
+
h = self.mid.block_1(h)
|
| 280 |
+
h = self.mid.attn_1(h)
|
| 281 |
+
h = self.mid.block_2(h)
|
| 282 |
+
|
| 283 |
+
# cast to proper dtype
|
| 284 |
+
h = h.to(upscale_dtype)
|
| 285 |
+
# upsampling
|
| 286 |
+
for i_level in reversed(range(self.num_resolutions)):
|
| 287 |
+
for i_block in range(self.num_res_blocks + 1):
|
| 288 |
+
h = self.up[i_level].block[i_block](h)
|
| 289 |
+
if len(self.up[i_level].attn) > 0:
|
| 290 |
+
h = self.up[i_level].attn[i_block](h)
|
| 291 |
+
if i_level != 0:
|
| 292 |
+
h = self.up[i_level].upsample(h)
|
| 293 |
+
|
| 294 |
+
# end
|
| 295 |
+
h = self.norm_out(h)
|
| 296 |
+
h = swish(h)
|
| 297 |
+
h = self.conv_out(h)
|
| 298 |
+
return h
|
| 299 |
+
|
| 300 |
+
|
| 301 |
+
class DiagonalGaussian(nn.Module):
|
| 302 |
+
def __init__(self, sample: bool = True, chunk_dim: int = 1):
|
| 303 |
+
super().__init__()
|
| 304 |
+
self.sample = sample
|
| 305 |
+
self.chunk_dim = chunk_dim
|
| 306 |
+
|
| 307 |
+
def forward(self, z: Tensor) -> Tensor:
|
| 308 |
+
mean, logvar = torch.chunk(z, 2, dim=self.chunk_dim)
|
| 309 |
+
if self.sample:
|
| 310 |
+
std = torch.exp(0.5 * logvar)
|
| 311 |
+
return mean + std * torch.randn_like(mean)
|
| 312 |
+
else:
|
| 313 |
+
return mean
|
| 314 |
+
|
| 315 |
+
|
| 316 |
+
class AutoEncoder(nn.Module):
|
| 317 |
+
def __init__(self, params: AutoEncoderParams):
|
| 318 |
+
super().__init__()
|
| 319 |
+
self.params = params
|
| 320 |
+
self.encoder = Encoder(
|
| 321 |
+
resolution=params.resolution,
|
| 322 |
+
in_channels=params.in_channels,
|
| 323 |
+
ch=params.ch,
|
| 324 |
+
ch_mult=params.ch_mult,
|
| 325 |
+
num_res_blocks=params.num_res_blocks,
|
| 326 |
+
z_channels=params.z_channels,
|
| 327 |
+
)
|
| 328 |
+
self.decoder = Decoder(
|
| 329 |
+
resolution=params.resolution,
|
| 330 |
+
in_channels=params.in_channels,
|
| 331 |
+
ch=params.ch,
|
| 332 |
+
out_ch=params.out_ch,
|
| 333 |
+
ch_mult=params.ch_mult,
|
| 334 |
+
num_res_blocks=params.num_res_blocks,
|
| 335 |
+
z_channels=params.z_channels,
|
| 336 |
+
)
|
| 337 |
+
self.reg = DiagonalGaussian()
|
| 338 |
+
|
| 339 |
+
self.scale_factor = params.scale_factor
|
| 340 |
+
self.shift_factor = params.shift_factor
|
| 341 |
+
|
| 342 |
+
self.quant_conv = nn.Conv2d(2 * params.z_channels, 2 * params.z_channels, 1) if params.use_quant_conv else None
|
| 343 |
+
self.post_quant_conv = nn.Conv2d(params.z_channels, params.z_channels, 1) if params.use_post_quant_conv else None
|
| 344 |
+
|
| 345 |
+
def encode(self, x: Tensor) -> Tensor:
|
| 346 |
+
x = self.encoder(x)
|
| 347 |
+
if self.quant_conv is not None:
|
| 348 |
+
x = self.quant_conv(x)
|
| 349 |
+
z = self.reg(x)
|
| 350 |
+
z = self.scale_factor * (z - self.shift_factor)
|
| 351 |
+
return z
|
| 352 |
+
|
| 353 |
+
def decode(self, z: Tensor) -> Tensor:
|
| 354 |
+
z = z / self.scale_factor + self.shift_factor
|
| 355 |
+
if self.post_quant_conv is not None:
|
| 356 |
+
z = self.post_quant_conv(z)
|
| 357 |
+
return self.decoder(z)
|
| 358 |
+
|
| 359 |
+
def forward(self, x: Tensor) -> Tensor:
|
| 360 |
+
return self.decode(self.encode(x))
|
| 361 |
+
|
| 362 |
+
|
| 363 |
+
def load_ae(
|
| 364 |
+
ckpt_path: str,
|
| 365 |
+
autoencoder_params: AutoEncoderParams,
|
| 366 |
+
device: str | torch.device = "cuda",
|
| 367 |
+
dtype=torch.bfloat16,
|
| 368 |
+
random_init=False,
|
| 369 |
+
) -> AutoEncoder:
|
| 370 |
+
"""
|
| 371 |
+
Load the autoencoder from the given model name.
|
| 372 |
+
Args:
|
| 373 |
+
name (str): The name of the autoencoder.
|
| 374 |
+
device (str or torch.device): The device to load the autoencoder to.
|
| 375 |
+
Returns:
|
| 376 |
+
AutoEncoder: The loaded autoencoder.
|
| 377 |
+
"""
|
| 378 |
+
# Loading the autoencoder
|
| 379 |
+
with torch.device(device):
|
| 380 |
+
ae = AutoEncoder(autoencoder_params)
|
| 381 |
+
|
| 382 |
+
if random_init:
|
| 383 |
+
print(f"Random Init VAE")
|
| 384 |
+
return ae.to(dtype=dtype)
|
| 385 |
+
|
| 386 |
+
if not os.path.exists(ckpt_path):
|
| 387 |
+
raise ValueError(
|
| 388 |
+
f"Autoencoder path {ckpt_path} does not exist. Please download it first."
|
| 389 |
+
)
|
| 390 |
+
|
| 391 |
+
if ckpt_path is not None:
|
| 392 |
+
print(f"Loading {ckpt_path}")
|
| 393 |
+
sd = load_sft(ckpt_path, device=str(device))
|
| 394 |
+
missing, unexpected = ae.load_state_dict(sd, strict=False, assign=True)
|
| 395 |
+
if len(missing) > 0:
|
| 396 |
+
print(f"Got {len(missing)} missing keys:\n\t" + "\n\t".join(missing))
|
| 397 |
+
if len(unexpected) > 0:
|
| 398 |
+
print(
|
| 399 |
+
f"Got {len(unexpected)} unexpected keys:\n\t" + "\n\t".join(unexpected)
|
| 400 |
+
)
|
| 401 |
+
return ae.to(dtype=dtype)
|
| 402 |
+
|
ovis_image/model/hf_embedder.py
ADDED
|
@@ -0,0 +1,50 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (C) 2025 AIDC-AI
|
| 2 |
+
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License.
|
| 3 |
+
# You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
|
| 4 |
+
# Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
from torch import nn, Tensor
|
| 8 |
+
|
| 9 |
+
from ovis_image.model.ovis.modeling_ovis2_5 import Ovis2_5, Ovis2_5_Config
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class OvisEmbedder(nn.Module):
|
| 13 |
+
def __init__(
|
| 14 |
+
self,
|
| 15 |
+
model_path: str,
|
| 16 |
+
random_init=False,
|
| 17 |
+
**hf_kwargs
|
| 18 |
+
):
|
| 19 |
+
super().__init__()
|
| 20 |
+
if random_init:
|
| 21 |
+
# Initialize Ovis model with random weights for test purpose only
|
| 22 |
+
config = Ovis2_5_Config.from_pretrained(model_path)
|
| 23 |
+
config.name_or_path = model_path
|
| 24 |
+
self.hf_module = Ovis2_5._from_config(config, **hf_kwargs)
|
| 25 |
+
else:
|
| 26 |
+
self.hf_module = Ovis2_5.from_pretrained(
|
| 27 |
+
model_path, **hf_kwargs
|
| 28 |
+
)
|
| 29 |
+
self.pad_token_id = self.hf_module.text_tokenizer.pad_token_id
|
| 30 |
+
self.user_prompt_begin_id = 28
|
| 31 |
+
# get Qwen3
|
| 32 |
+
self.hf_module = self.hf_module.llm.model
|
| 33 |
+
self.hf_module = self.hf_module.eval().requires_grad_(False)
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
def forward(self, batch_tokens: Tensor, attention_mask = None) -> Tensor:
|
| 37 |
+
if attention_mask is None:
|
| 38 |
+
attention_mask = torch.ne(
|
| 39 |
+
batch_tokens, self.pad_token_id
|
| 40 |
+
).to(device=batch_tokens.device)
|
| 41 |
+
outputs = self.hf_module(
|
| 42 |
+
input_ids=batch_tokens,
|
| 43 |
+
attention_mask=attention_mask,
|
| 44 |
+
)
|
| 45 |
+
txt_semantic_embed = outputs.last_hidden_state
|
| 46 |
+
txt_semantic_embed = txt_semantic_embed * attention_mask[..., None]
|
| 47 |
+
txt_semantic_embed = txt_semantic_embed[:, self.user_prompt_begin_id:, :]
|
| 48 |
+
return txt_semantic_embed
|
| 49 |
+
|
| 50 |
+
|
ovis_image/model/layers.py
ADDED
|
@@ -0,0 +1,407 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (C) 2025 AIDC-AI
|
| 2 |
+
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License.
|
| 3 |
+
# You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
|
| 4 |
+
# Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.
|
| 5 |
+
|
| 6 |
+
import math
|
| 7 |
+
from dataclasses import dataclass
|
| 8 |
+
|
| 9 |
+
import torch
|
| 10 |
+
from einops import rearrange
|
| 11 |
+
from torch import nn, Tensor
|
| 12 |
+
|
| 13 |
+
from ovis_image.model.ops import attention, rope
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
class EmbedND(nn.Module):
|
| 17 |
+
|
| 18 |
+
def __init__(self, dim: int, theta: int, axes_dim: list[int]):
|
| 19 |
+
super().__init__()
|
| 20 |
+
self.dim = dim
|
| 21 |
+
self.theta = theta
|
| 22 |
+
self.axes_dim = axes_dim
|
| 23 |
+
|
| 24 |
+
@torch.no_grad()
|
| 25 |
+
def forward(self, ids: Tensor) -> Tensor:
|
| 26 |
+
n_axes = ids.shape[-1]
|
| 27 |
+
emb = torch.cat(
|
| 28 |
+
[rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(n_axes)],
|
| 29 |
+
dim=-3,
|
| 30 |
+
)
|
| 31 |
+
# bs x 1 x 512 x 64 x 2 x 2
|
| 32 |
+
return emb.unsqueeze(1)
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def timestep_embedding(t: Tensor, dim, max_period=10000, time_factor: float = 1000.0):
|
| 36 |
+
"""
|
| 37 |
+
Create sinusoidal timestep embeddings.
|
| 38 |
+
:param t: a 1-D Tensor of N indices, one per batch element.
|
| 39 |
+
These may be fractional.
|
| 40 |
+
:param dim: the dimension of the output.
|
| 41 |
+
:param max_period: controls the minimum frequency of the embeddings.
|
| 42 |
+
:return: an (N, D) Tensor of positional embeddings.
|
| 43 |
+
"""
|
| 44 |
+
t = time_factor * t
|
| 45 |
+
half = dim // 2
|
| 46 |
+
with torch.device(t.device):
|
| 47 |
+
freqs = torch.exp(
|
| 48 |
+
-math.log(max_period)
|
| 49 |
+
* torch.arange(start=0, end=half, dtype=torch.float32)
|
| 50 |
+
/ half
|
| 51 |
+
)
|
| 52 |
+
|
| 53 |
+
args = t[:, None].float() * freqs[None]
|
| 54 |
+
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
|
| 55 |
+
if dim % 2:
|
| 56 |
+
embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
|
| 57 |
+
if torch.is_floating_point(t):
|
| 58 |
+
embedding = embedding.to(t)
|
| 59 |
+
return embedding
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
class MLPEmbedder(nn.Module):
|
| 63 |
+
def __init__(self, in_dim: int, hidden_dim: int):
|
| 64 |
+
super().__init__()
|
| 65 |
+
self.in_layer = nn.Linear(in_dim, hidden_dim, bias=True)
|
| 66 |
+
self.silu = nn.SiLU()
|
| 67 |
+
self.out_layer = nn.Linear(hidden_dim, hidden_dim, bias=True)
|
| 68 |
+
|
| 69 |
+
def init_weights(self, init_std: float = 0.02):
|
| 70 |
+
nn.init.normal_(self.in_layer.weight, std=init_std)
|
| 71 |
+
nn.init.constant_(self.in_layer.bias, 0)
|
| 72 |
+
nn.init.normal_(self.out_layer.weight, std=init_std)
|
| 73 |
+
nn.init.constant_(self.out_layer.bias, 0)
|
| 74 |
+
|
| 75 |
+
def forward(self, x: Tensor) -> Tensor:
|
| 76 |
+
return self.out_layer(self.silu(self.in_layer(x)))
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
class QKNorm(torch.nn.Module):
|
| 80 |
+
def __init__(self, dim: int):
|
| 81 |
+
super().__init__()
|
| 82 |
+
self.query_norm = nn.RMSNorm(dim)
|
| 83 |
+
self.key_norm = nn.RMSNorm(dim)
|
| 84 |
+
|
| 85 |
+
def init_weights(self):
|
| 86 |
+
self.query_norm.reset_parameters()
|
| 87 |
+
self.key_norm.reset_parameters()
|
| 88 |
+
|
| 89 |
+
def forward(self, q: Tensor, k: Tensor, v: Tensor) -> tuple[Tensor, Tensor]:
|
| 90 |
+
q = self.query_norm(q)
|
| 91 |
+
k = self.key_norm(k)
|
| 92 |
+
return q.to(v), k.to(v)
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
class SelfAttention(nn.Module):
|
| 96 |
+
def __init__(self, dim: int, num_heads: int = 8, qkv_bias: bool = False):
|
| 97 |
+
super().__init__()
|
| 98 |
+
self.num_heads = num_heads
|
| 99 |
+
head_dim = dim // num_heads
|
| 100 |
+
|
| 101 |
+
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
|
| 102 |
+
self.norm = QKNorm(head_dim)
|
| 103 |
+
self.proj = nn.Linear(dim, dim)
|
| 104 |
+
|
| 105 |
+
def init_weights(self):
|
| 106 |
+
for layer in (self.qkv, self.proj):
|
| 107 |
+
nn.init.xavier_uniform_(layer.weight)
|
| 108 |
+
if layer.bias is not None:
|
| 109 |
+
nn.init.constant_(layer.bias, 0)
|
| 110 |
+
self.norm.init_weights()
|
| 111 |
+
|
| 112 |
+
def forward(self, x: Tensor, pe: Tensor) -> Tensor:
|
| 113 |
+
qkv = self.qkv(x)
|
| 114 |
+
q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
|
| 115 |
+
q, k = self.norm(q, k, v)
|
| 116 |
+
x = attention(q, k, v, pe=pe)
|
| 117 |
+
x = self.proj(x)
|
| 118 |
+
return x
|
| 119 |
+
|
| 120 |
+
|
| 121 |
+
class YakMLP(nn.Module):
|
| 122 |
+
# Use SwiGLU
|
| 123 |
+
def __init__(self, hidden_size: int, intermediate_size: int):
|
| 124 |
+
super().__init__()
|
| 125 |
+
self.hidden_size = hidden_size
|
| 126 |
+
self.intermediate_size = intermediate_size
|
| 127 |
+
self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=True)
|
| 128 |
+
self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=True)
|
| 129 |
+
self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=True)
|
| 130 |
+
self.act_fn = nn.SiLU()
|
| 131 |
+
|
| 132 |
+
def init_weights(self):
|
| 133 |
+
for layer in (self.gate_proj, self.up_proj, self.down_proj):
|
| 134 |
+
nn.init.xavier_uniform_(layer.weight)
|
| 135 |
+
nn.init.constant_(layer.bias, 0)
|
| 136 |
+
|
| 137 |
+
def forward(self, x: Tensor) -> Tensor:
|
| 138 |
+
down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
|
| 139 |
+
return down_proj
|
| 140 |
+
|
| 141 |
+
|
| 142 |
+
def build_mlp(hidden_size, intermediate_size, activation = "gelu_tanh"):
|
| 143 |
+
if activation == "gelu_tanh":
|
| 144 |
+
mlp = nn.Sequential(
|
| 145 |
+
nn.Linear(hidden_size, intermediate_size, bias=True),
|
| 146 |
+
nn.GELU(approximate="tanh"),
|
| 147 |
+
nn.Linear(intermediate_size, hidden_size, bias=True),
|
| 148 |
+
)
|
| 149 |
+
else:
|
| 150 |
+
mlp = YakMLP(hidden_size, intermediate_size)
|
| 151 |
+
return mlp
|
| 152 |
+
|
| 153 |
+
|
| 154 |
+
def init_mlp(mlp, activation = "gelu_tanh"):
|
| 155 |
+
if activation == "gelu_tanh":
|
| 156 |
+
for layer in (mlp[0], mlp[2]):
|
| 157 |
+
nn.init.xavier_uniform_(layer.weight)
|
| 158 |
+
nn.init.constant_(layer.bias, 0)
|
| 159 |
+
else:
|
| 160 |
+
mlp.init_weights()
|
| 161 |
+
|
| 162 |
+
|
| 163 |
+
@dataclass
|
| 164 |
+
class ModulationOut:
|
| 165 |
+
shift: Tensor
|
| 166 |
+
scale: Tensor
|
| 167 |
+
gate: Tensor
|
| 168 |
+
|
| 169 |
+
|
| 170 |
+
class Modulation(nn.Module):
|
| 171 |
+
def __init__(self, dim: int, multiples: int = 1):
|
| 172 |
+
super().__init__()
|
| 173 |
+
assert multiples in [1, 2, 3]
|
| 174 |
+
self.multiples = multiples
|
| 175 |
+
self.multiplier = 3 * multiples
|
| 176 |
+
self.lin = nn.Linear(dim, self.multiplier * dim, bias=True)
|
| 177 |
+
self.act = nn.SiLU()
|
| 178 |
+
|
| 179 |
+
def init_weights(self):
|
| 180 |
+
nn.init.constant_(self.lin.weight, 0)
|
| 181 |
+
nn.init.constant_(self.lin.bias, 0)
|
| 182 |
+
|
| 183 |
+
def forward(self, vec: Tensor):
|
| 184 |
+
out = self.lin(self.act(vec))[:, None, :].chunk(
|
| 185 |
+
self.multiplier, dim=-1
|
| 186 |
+
)
|
| 187 |
+
if self.multiples == 1:
|
| 188 |
+
return ModulationOut(*out[:3])
|
| 189 |
+
elif self.multiples == 2:
|
| 190 |
+
return (
|
| 191 |
+
ModulationOut(*out[:3]),
|
| 192 |
+
ModulationOut(*out[3:]),
|
| 193 |
+
)
|
| 194 |
+
elif self.multiples == 3:
|
| 195 |
+
return (
|
| 196 |
+
ModulationOut(*out[:3]),
|
| 197 |
+
ModulationOut(*out[3:6]),
|
| 198 |
+
ModulationOut(*out[6:]),
|
| 199 |
+
)
|
| 200 |
+
|
| 201 |
+
|
| 202 |
+
class DoubleStreamBlock(nn.Module):
|
| 203 |
+
def __init__(
|
| 204 |
+
self,
|
| 205 |
+
hidden_size: int,
|
| 206 |
+
num_heads: int,
|
| 207 |
+
mlp_ratio: float,
|
| 208 |
+
qkv_bias: bool = False,
|
| 209 |
+
activation: str = "gelu_tanh",
|
| 210 |
+
norm_layer: nn.Module = nn.LayerNorm,
|
| 211 |
+
):
|
| 212 |
+
super().__init__()
|
| 213 |
+
|
| 214 |
+
mlp_hidden_dim = int(hidden_size * mlp_ratio)
|
| 215 |
+
self.num_heads = num_heads
|
| 216 |
+
self.hidden_size = hidden_size
|
| 217 |
+
self.activation = activation
|
| 218 |
+
self.img_mod = Modulation(hidden_size, multiples=2)
|
| 219 |
+
self.img_norm1 = norm_layer(hidden_size, elementwise_affine=False, eps=1e-6)
|
| 220 |
+
self.img_attn = SelfAttention(
|
| 221 |
+
dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias
|
| 222 |
+
)
|
| 223 |
+
|
| 224 |
+
self.img_norm2 = norm_layer(hidden_size, elementwise_affine=False, eps=1e-6)
|
| 225 |
+
self.img_mlp = build_mlp(hidden_size, mlp_hidden_dim, activation)
|
| 226 |
+
|
| 227 |
+
self.txt_mod = Modulation(hidden_size, multiples=2)
|
| 228 |
+
self.txt_norm1 = norm_layer(hidden_size, elementwise_affine=False, eps=1e-6)
|
| 229 |
+
self.txt_attn = SelfAttention(
|
| 230 |
+
dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias
|
| 231 |
+
)
|
| 232 |
+
|
| 233 |
+
self.txt_norm2 = norm_layer(hidden_size, elementwise_affine=False, eps=1e-6)
|
| 234 |
+
self.txt_mlp = build_mlp(hidden_size, mlp_hidden_dim, activation)
|
| 235 |
+
|
| 236 |
+
def init_weights(self):
|
| 237 |
+
# initialize all the nn.Linear submodules
|
| 238 |
+
init_mlp(self.img_mlp, self.activation)
|
| 239 |
+
init_mlp(self.txt_mlp, self.activation)
|
| 240 |
+
|
| 241 |
+
# initialize Modulation layers, SelfAttention layers
|
| 242 |
+
for layer in (self.img_attn, self.img_mod, self.txt_attn, self.txt_mod):
|
| 243 |
+
layer.init_weights()
|
| 244 |
+
|
| 245 |
+
# Reset parameters for Normalization layers
|
| 246 |
+
for norm in (self.txt_norm1, self.txt_norm2, self.img_norm1, self.img_norm2):
|
| 247 |
+
norm.reset_parameters()
|
| 248 |
+
|
| 249 |
+
def forward(
|
| 250 |
+
self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor
|
| 251 |
+
) -> tuple[Tensor, Tensor]:
|
| 252 |
+
img_mod1, img_mod2 = self.img_mod(vec)
|
| 253 |
+
txt_mod1, txt_mod2 = self.txt_mod(vec)
|
| 254 |
+
|
| 255 |
+
# prepare image for attention
|
| 256 |
+
img_modulated = self.img_norm1(img)
|
| 257 |
+
img_modulated = (1 + img_mod1.scale) * img_modulated + img_mod1.shift
|
| 258 |
+
img_qkv = self.img_attn.qkv(img_modulated)
|
| 259 |
+
img_q, img_k, img_v = rearrange(
|
| 260 |
+
img_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads
|
| 261 |
+
)
|
| 262 |
+
img_q, img_k = self.img_attn.norm(img_q, img_k, img_v)
|
| 263 |
+
|
| 264 |
+
# prepare txt for attention
|
| 265 |
+
txt_modulated = self.txt_norm1(txt)
|
| 266 |
+
txt_modulated = (1 + txt_mod1.scale) * txt_modulated + txt_mod1.shift
|
| 267 |
+
txt_qkv = self.txt_attn.qkv(txt_modulated)
|
| 268 |
+
txt_q, txt_k, txt_v = rearrange(
|
| 269 |
+
txt_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads
|
| 270 |
+
)
|
| 271 |
+
txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v)
|
| 272 |
+
|
| 273 |
+
# run actual attention
|
| 274 |
+
q = torch.cat((txt_q, img_q), dim=2)
|
| 275 |
+
k = torch.cat((txt_k, img_k), dim=2)
|
| 276 |
+
v = torch.cat((txt_v, img_v), dim=2)
|
| 277 |
+
|
| 278 |
+
attn = attention(q, k, v, pe=pe)
|
| 279 |
+
txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1] :]
|
| 280 |
+
|
| 281 |
+
# calculate the img bloks
|
| 282 |
+
img = img + img_mod1.gate * self.img_attn.proj(img_attn)
|
| 283 |
+
img = img + img_mod2.gate * self.img_mlp(
|
| 284 |
+
(1 + img_mod2.scale) * self.img_norm2(img) + img_mod2.shift
|
| 285 |
+
)
|
| 286 |
+
|
| 287 |
+
# calculate the txt bloks
|
| 288 |
+
txt = txt + txt_mod1.gate * self.txt_attn.proj(txt_attn)
|
| 289 |
+
txt = txt + txt_mod2.gate * self.txt_mlp(
|
| 290 |
+
(1 + txt_mod2.scale) * self.txt_norm2(txt) + txt_mod2.shift
|
| 291 |
+
)
|
| 292 |
+
return img, txt
|
| 293 |
+
|
| 294 |
+
|
| 295 |
+
class SingleStreamBlock(nn.Module):
|
| 296 |
+
"""
|
| 297 |
+
A DiT block with parallel linear layers as described in
|
| 298 |
+
https://arxiv.org/abs/2302.05442 and adapted modulation interface.
|
| 299 |
+
"""
|
| 300 |
+
|
| 301 |
+
def __init__(
|
| 302 |
+
self,
|
| 303 |
+
hidden_size: int,
|
| 304 |
+
num_heads: int,
|
| 305 |
+
mlp_ratio: float = 4.0,
|
| 306 |
+
qkv_bias: bool = False,
|
| 307 |
+
qk_scale: float | None = None,
|
| 308 |
+
activation: str = "gelu_tanh",
|
| 309 |
+
norm_layer: nn.Module = nn.LayerNorm,
|
| 310 |
+
):
|
| 311 |
+
super().__init__()
|
| 312 |
+
self.hidden_dim = hidden_size
|
| 313 |
+
self.num_heads = num_heads
|
| 314 |
+
head_dim = hidden_size // num_heads
|
| 315 |
+
self.scale = qk_scale or head_dim**-0.5
|
| 316 |
+
self.activation = activation
|
| 317 |
+
|
| 318 |
+
self.mlp_hidden_dim = int(hidden_size * mlp_ratio)
|
| 319 |
+
if activation == "gelu_tanh":
|
| 320 |
+
# qkv and mlp_in
|
| 321 |
+
self.linear1 = nn.Linear(hidden_size, hidden_size * 3 + self.mlp_hidden_dim, bias=qkv_bias)
|
| 322 |
+
else:
|
| 323 |
+
# qkv and mlp_in and mlp_gate
|
| 324 |
+
self.linear1 = nn.Linear(hidden_size, hidden_size * 3 + self.mlp_hidden_dim * 2, bias=qkv_bias)
|
| 325 |
+
# proj and mlp_out
|
| 326 |
+
self.linear2 = nn.Linear(hidden_size + self.mlp_hidden_dim, hidden_size)
|
| 327 |
+
|
| 328 |
+
self.norm = QKNorm(head_dim)
|
| 329 |
+
|
| 330 |
+
self.hidden_size = hidden_size
|
| 331 |
+
self.pre_norm = norm_layer(hidden_size, elementwise_affine=False, eps=1e-6)
|
| 332 |
+
|
| 333 |
+
if activation == "gelu_tanh":
|
| 334 |
+
self.mlp_act = nn.GELU(approximate="tanh")
|
| 335 |
+
else:
|
| 336 |
+
self.mlp_act = nn.SiLU()
|
| 337 |
+
self.modulation = Modulation(hidden_size, multiples=1)
|
| 338 |
+
|
| 339 |
+
def init_weights(self):
|
| 340 |
+
for layer in (self.linear1, self.linear2):
|
| 341 |
+
nn.init.xavier_uniform_(layer.weight)
|
| 342 |
+
if layer.bias is not None:
|
| 343 |
+
nn.init.constant_(layer.bias, 0)
|
| 344 |
+
self.norm.init_weights()
|
| 345 |
+
self.pre_norm.reset_parameters()
|
| 346 |
+
self.modulation.init_weights()
|
| 347 |
+
|
| 348 |
+
def forward(self, x: Tensor, vec: Tensor, pe: Tensor) -> Tensor:
|
| 349 |
+
mod = self.modulation(vec)
|
| 350 |
+
x_mod = (1 + mod.scale) * self.pre_norm(x) + mod.shift
|
| 351 |
+
if self.activation == "gelu_tanh":
|
| 352 |
+
qkv, mlp = torch.split(
|
| 353 |
+
self.linear1(x_mod),
|
| 354 |
+
[3 * self.hidden_size, self.mlp_hidden_dim],
|
| 355 |
+
dim=-1
|
| 356 |
+
)
|
| 357 |
+
else:
|
| 358 |
+
qkv, mlp, mlp_gate = torch.split(
|
| 359 |
+
self.linear1(x_mod),
|
| 360 |
+
[3 * self.hidden_size, self.mlp_hidden_dim, self.mlp_hidden_dim],
|
| 361 |
+
dim=-1
|
| 362 |
+
)
|
| 363 |
+
|
| 364 |
+
q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
|
| 365 |
+
q, k = self.norm(q, k, v)
|
| 366 |
+
# compute attention
|
| 367 |
+
attn = attention(q, k, v, pe=pe)
|
| 368 |
+
|
| 369 |
+
if self.activation == "gelu_tanh":
|
| 370 |
+
# compute activation in mlp stream, cat again and run second linear layer
|
| 371 |
+
x = x + mod.gate * self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2))
|
| 372 |
+
else:
|
| 373 |
+
x = x + mod.gate * self.linear2(
|
| 374 |
+
torch.cat((attn, self.mlp_act(mlp_gate) * mlp), 2)
|
| 375 |
+
)
|
| 376 |
+
return x
|
| 377 |
+
|
| 378 |
+
|
| 379 |
+
class LastLayer(nn.Module):
|
| 380 |
+
def __init__(
|
| 381 |
+
self,
|
| 382 |
+
hidden_size: int,
|
| 383 |
+
patch_size: int,
|
| 384 |
+
out_channels: int,
|
| 385 |
+
norm_layer: nn.Module = nn.LayerNorm,
|
| 386 |
+
):
|
| 387 |
+
super().__init__()
|
| 388 |
+
self.norm_final = norm_layer(hidden_size, elementwise_affine=False, eps=1e-6)
|
| 389 |
+
self.linear = nn.Linear(
|
| 390 |
+
hidden_size, patch_size * patch_size * out_channels, bias=True
|
| 391 |
+
)
|
| 392 |
+
self.adaLN_modulation = nn.Sequential(
|
| 393 |
+
nn.SiLU(), nn.Linear(hidden_size, 2 * hidden_size, bias=True)
|
| 394 |
+
)
|
| 395 |
+
|
| 396 |
+
def init_weights(self):
|
| 397 |
+
nn.init.constant_(self.adaLN_modulation[-1].weight, 0)
|
| 398 |
+
nn.init.constant_(self.adaLN_modulation[-1].bias, 0)
|
| 399 |
+
nn.init.constant_(self.linear.weight, 0)
|
| 400 |
+
nn.init.constant_(self.linear.bias, 0)
|
| 401 |
+
self.norm_final.reset_parameters()
|
| 402 |
+
|
| 403 |
+
def forward(self, x: Tensor, vec: Tensor) -> Tensor:
|
| 404 |
+
shift, scale = self.adaLN_modulation(vec).chunk(2, dim=1)
|
| 405 |
+
x = (1 + scale[:, None, :]) * self.norm_final(x) + shift[:, None, :]
|
| 406 |
+
x = self.linear(x)
|
| 407 |
+
return x
|
ovis_image/model/model.py
ADDED
|
@@ -0,0 +1,121 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (C) 2025 AIDC-AI
|
| 2 |
+
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License.
|
| 3 |
+
# You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
|
| 4 |
+
# Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
from torch import nn, Tensor
|
| 8 |
+
|
| 9 |
+
from ovis_image.model.layers import (
|
| 10 |
+
DoubleStreamBlock,
|
| 11 |
+
EmbedND,
|
| 12 |
+
LastLayer,
|
| 13 |
+
MLPEmbedder,
|
| 14 |
+
SingleStreamBlock,
|
| 15 |
+
timestep_embedding,
|
| 16 |
+
)
|
| 17 |
+
|
| 18 |
+
from ovis_image.model.args import OvisImageModelArgs
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
class OvisImageModel(nn.Module):
|
| 22 |
+
|
| 23 |
+
def __init__(self, model_args: OvisImageModelArgs):
|
| 24 |
+
super().__init__()
|
| 25 |
+
|
| 26 |
+
self.model_args = model_args
|
| 27 |
+
|
| 28 |
+
self.in_channels = model_args.in_channels
|
| 29 |
+
self.out_channels = model_args.out_channels
|
| 30 |
+
if model_args.hidden_size % model_args.num_heads != 0:
|
| 31 |
+
raise ValueError(
|
| 32 |
+
f"Hidden size {model_args.hidden_size} must be divisible by num_heads {model_args.num_heads}"
|
| 33 |
+
)
|
| 34 |
+
pe_dim = model_args.hidden_size // model_args.num_heads
|
| 35 |
+
if sum(model_args.axes_dim) != pe_dim:
|
| 36 |
+
raise ValueError(
|
| 37 |
+
f"Got {model_args.axes_dim} but expected positional dim {pe_dim}"
|
| 38 |
+
)
|
| 39 |
+
self.hidden_size = model_args.hidden_size
|
| 40 |
+
self.num_heads = model_args.num_heads
|
| 41 |
+
self.pe_embedder = EmbedND(
|
| 42 |
+
dim=pe_dim, theta=model_args.theta, axes_dim=model_args.axes_dim
|
| 43 |
+
)
|
| 44 |
+
self.img_in = nn.Linear(self.in_channels, self.hidden_size, bias=True)
|
| 45 |
+
self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size)
|
| 46 |
+
self.semantic_txt_norm = nn.RMSNorm(model_args.context_in_dim, eps=1e-6)
|
| 47 |
+
self.semantic_txt_in = nn.Linear(model_args.context_in_dim, self.hidden_size, bias=True)
|
| 48 |
+
|
| 49 |
+
if model_args.norm == "layernorm":
|
| 50 |
+
norm_layer = nn.LayerNorm
|
| 51 |
+
else:
|
| 52 |
+
norm_layer = nn.RMSNorm
|
| 53 |
+
|
| 54 |
+
DoubleBlock = DoubleStreamBlock
|
| 55 |
+
|
| 56 |
+
self.double_blocks = nn.ModuleList(
|
| 57 |
+
[
|
| 58 |
+
DoubleBlock(
|
| 59 |
+
self.hidden_size,
|
| 60 |
+
self.num_heads,
|
| 61 |
+
mlp_ratio=model_args.mlp_ratio,
|
| 62 |
+
qkv_bias=model_args.qkv_bias,
|
| 63 |
+
activation=model_args.activation,
|
| 64 |
+
norm_layer=norm_layer,
|
| 65 |
+
)
|
| 66 |
+
for _ in range(model_args.depth)
|
| 67 |
+
]
|
| 68 |
+
)
|
| 69 |
+
|
| 70 |
+
self.single_blocks = nn.ModuleList(
|
| 71 |
+
[
|
| 72 |
+
SingleStreamBlock(
|
| 73 |
+
self.hidden_size,
|
| 74 |
+
self.num_heads,
|
| 75 |
+
mlp_ratio=model_args.mlp_ratio,
|
| 76 |
+
qkv_bias=model_args.qkv_bias,
|
| 77 |
+
activation=model_args.activation,
|
| 78 |
+
norm_layer=norm_layer,
|
| 79 |
+
)
|
| 80 |
+
for _ in range(model_args.depth_single_blocks)
|
| 81 |
+
]
|
| 82 |
+
)
|
| 83 |
+
|
| 84 |
+
self.final_layer = LastLayer(
|
| 85 |
+
self.hidden_size,
|
| 86 |
+
1,
|
| 87 |
+
self.out_channels,
|
| 88 |
+
norm_layer=norm_layer,
|
| 89 |
+
)
|
| 90 |
+
|
| 91 |
+
def forward(
|
| 92 |
+
self,
|
| 93 |
+
img: Tensor,
|
| 94 |
+
img_ids: Tensor,
|
| 95 |
+
txt: Tensor,
|
| 96 |
+
txt_ids: Tensor,
|
| 97 |
+
timesteps: Tensor,
|
| 98 |
+
) -> Tensor:
|
| 99 |
+
if img.ndim != 3 or txt.ndim != 3:
|
| 100 |
+
raise ValueError("Input img and txt tensors must have 3 dimensions.")
|
| 101 |
+
|
| 102 |
+
# running on sequences img
|
| 103 |
+
img = self.img_in(img)
|
| 104 |
+
vec = self.time_in(timestep_embedding(timesteps, 256))
|
| 105 |
+
txt = self.semantic_txt_norm(txt)
|
| 106 |
+
txt = self.semantic_txt_in(txt)
|
| 107 |
+
ids = torch.cat((txt_ids, img_ids), dim=1)
|
| 108 |
+
pe = self.pe_embedder(ids)
|
| 109 |
+
|
| 110 |
+
for block in self.double_blocks:
|
| 111 |
+
img, txt = block(img=img, txt=txt, vec=vec, pe=pe)
|
| 112 |
+
|
| 113 |
+
img = torch.cat((txt, img), 1)
|
| 114 |
+
for block in self.single_blocks:
|
| 115 |
+
img = block(img, vec=vec, pe=pe)
|
| 116 |
+
img = img[:, txt.shape[1] :, ...]
|
| 117 |
+
|
| 118 |
+
img = self.final_layer(img, vec) # (N, T, patch_size ** 2 * out_channels)
|
| 119 |
+
return img
|
| 120 |
+
|
| 121 |
+
|
ovis_image/model/ops.py
ADDED
|
@@ -0,0 +1,102 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (C) 2025 AIDC-AI
|
| 2 |
+
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License.
|
| 3 |
+
# You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
|
| 4 |
+
# Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
from einops import rearrange
|
| 8 |
+
from torch import Tensor
|
| 9 |
+
from torch.nn.attention import sdpa_kernel, SDPBackend
|
| 10 |
+
flash_attn_func = None
|
| 11 |
+
try:
|
| 12 |
+
from flash_attn_interface import flash_attn_func
|
| 13 |
+
print("find flash attn 3")
|
| 14 |
+
except:
|
| 15 |
+
flash_attn_func = None
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
def check_attention_type(attn_implementation):
|
| 19 |
+
if torch.__version__ >= "2.7.0":
|
| 20 |
+
if attn_implementation != "sdpa":
|
| 21 |
+
print("please set attn_implementation as sdpa for torch271")
|
| 22 |
+
elif flash_attn_func is not None:
|
| 23 |
+
if attn_implementation != "flash_attention_3":
|
| 24 |
+
print("please set attn_implementation as flash_attention_3 for H100")
|
| 25 |
+
|
| 26 |
+
def get_attention_type_by_system():
|
| 27 |
+
if torch.__version__ >= "2.7.0":
|
| 28 |
+
return "sdpa"
|
| 29 |
+
elif flash_attn_func is not None:
|
| 30 |
+
return "flash_attention_3"
|
| 31 |
+
else:
|
| 32 |
+
return "eager"
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor) -> Tensor:
|
| 36 |
+
if torch.__version__ >= "2.7.0":
|
| 37 |
+
return attention_sdpa(q, k, v, pe)
|
| 38 |
+
elif flash_attn_func is not None:
|
| 39 |
+
return attention_fa3(q, k, v, pe)
|
| 40 |
+
else:
|
| 41 |
+
return attention_eager(q, k, v, pe)
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
def attention_eager(q: Tensor, k: Tensor, v: Tensor, pe: Tensor) -> Tensor:
|
| 45 |
+
q, k = apply_rope(q, k, pe)
|
| 46 |
+
# https://docs.pytorch.org/docs/2.6/generated/torch.nn.functional.scaled_dot_product_attention.html#torch.nn.functional.scaled_dot_product_attention
|
| 47 |
+
x = torch.nn.functional.scaled_dot_product_attention(q, k, v)
|
| 48 |
+
x = rearrange(x, "B H L D -> B L (H D)")
|
| 49 |
+
return x
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
def attention_sdpa(q: Tensor, k: Tensor, v: Tensor, pe: Tensor) -> Tensor:
|
| 53 |
+
q, k = apply_rope(q, k, pe)
|
| 54 |
+
# B200用torch271镜像,用SDPA加速
|
| 55 |
+
with sdpa_kernel([SDPBackend.CUDNN_ATTENTION]):
|
| 56 |
+
x = torch.nn.functional.scaled_dot_product_attention(
|
| 57 |
+
q, k, v,
|
| 58 |
+
)
|
| 59 |
+
x = rearrange(x, "B H L D -> B L (H D)")
|
| 60 |
+
return x
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
def attention_fa3(q: Tensor, k: Tensor, v: Tensor, pe: Tensor) -> Tensor:
|
| 64 |
+
q, k = apply_rope(q, k, pe)
|
| 65 |
+
# H100上用flash_attn_3加速
|
| 66 |
+
q = rearrange(q, "B H L D -> B L H D")
|
| 67 |
+
k = rearrange(k, "B H L D -> B L H D")
|
| 68 |
+
v = rearrange(v, "B H L D -> B L H D")
|
| 69 |
+
x = flash_attn_func(q, k, v)[0]
|
| 70 |
+
x = rearrange(x, "B L H D -> B L (H D)")
|
| 71 |
+
return x
|
| 72 |
+
|
| 73 |
+
def get_attention_func(attn_implementation):
|
| 74 |
+
if attn_implementation == "eager":
|
| 75 |
+
return attention_eager
|
| 76 |
+
elif attn_implementation == "sdpa":
|
| 77 |
+
return attention_sdpa
|
| 78 |
+
elif attn_implementation == "flash_attention_3":
|
| 79 |
+
return attention_fa3
|
| 80 |
+
else:
|
| 81 |
+
return attention_eager
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
def rope(pos: Tensor, dim: int, theta: int) -> Tensor:
|
| 85 |
+
assert dim % 2 == 0
|
| 86 |
+
scale = torch.arange(0, dim, 2, dtype=pos.dtype, device=pos.device) / dim
|
| 87 |
+
omega = 1.0 / (theta**scale)
|
| 88 |
+
out = torch.einsum("...n,d->...nd", pos, omega)
|
| 89 |
+
out = torch.stack(
|
| 90 |
+
[torch.cos(out), -torch.sin(out), torch.sin(out), torch.cos(out)], dim=-1
|
| 91 |
+
)
|
| 92 |
+
out = rearrange(out, "b n d (i j) -> b n d i j", i=2, j=2)
|
| 93 |
+
return out.float()
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor) -> tuple[Tensor, Tensor]:
|
| 97 |
+
xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2)
|
| 98 |
+
xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2)
|
| 99 |
+
xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
|
| 100 |
+
xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
|
| 101 |
+
return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)
|
| 102 |
+
|
ovis_image/model/ovis/configuration_ovis2_5.py
ADDED
|
@@ -0,0 +1,96 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Any, Optional, List, Union
|
| 2 |
+
|
| 3 |
+
from transformers import Qwen3Config
|
| 4 |
+
from transformers.configuration_utils import PretrainedConfig
|
| 5 |
+
|
| 6 |
+
__all__ = ["Siglip2NavitConfig", "Ovis2_5_Config"]
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
class Siglip2NavitConfig(PretrainedConfig):
|
| 10 |
+
"""This is the configuration class to store the configuration of an [`AIMv2Model`].
|
| 11 |
+
|
| 12 |
+
Instantiating a configuration with the defaults will yield a similar configuration
|
| 13 |
+
to that of the [apple/aimv2-large-patch14-224](https://huggingface.co/apple/aimv2-large-patch14-224).
|
| 14 |
+
|
| 15 |
+
Args:
|
| 16 |
+
hidden_size: Dimension of the hidden representations.
|
| 17 |
+
intermediate_size: Dimension of the SwiGLU representations.
|
| 18 |
+
num_hidden_layers: Number of hidden layers in the Transformer.
|
| 19 |
+
num_attention_heads: Number of attention heads for each attention layer
|
| 20 |
+
in the Transformer.
|
| 21 |
+
num_channels: Number of input channels.
|
| 22 |
+
image_size: Image size.
|
| 23 |
+
patch_size: Patch size.
|
| 24 |
+
rms_norm_eps: Epsilon value used for the RMS normalization layer.
|
| 25 |
+
attention_dropout: Dropout ratio for attention probabilities.
|
| 26 |
+
projection_dropout: Dropout ratio for the projection layer after the attention.
|
| 27 |
+
qkv_bias: Whether to add a bias to the queries, keys and values.
|
| 28 |
+
use_bias: Whether to add a bias in the feed-forward and projection layers.
|
| 29 |
+
kwargs: Keyword arguments for the [`PretrainedConfig`].
|
| 30 |
+
"""
|
| 31 |
+
|
| 32 |
+
model_type: str = "siglip2_navit"
|
| 33 |
+
|
| 34 |
+
def __init__(
|
| 35 |
+
self,
|
| 36 |
+
hidden_size: int = 1024,
|
| 37 |
+
intermediate_size: int = 4096,
|
| 38 |
+
num_hidden_layers: int = 24,
|
| 39 |
+
num_attention_heads: int = 16,
|
| 40 |
+
num_channels: int = 3,
|
| 41 |
+
num_patches: int = -1,
|
| 42 |
+
image_size: int = 512,
|
| 43 |
+
patch_size: int = 16,
|
| 44 |
+
hidden_act: str="gelu_pytorch_tanh",
|
| 45 |
+
layer_norm_eps: float = 1e-6,
|
| 46 |
+
attention_dropout: float = 0.0,
|
| 47 |
+
hidden_stride: int = 2,
|
| 48 |
+
window_size: int = 112,
|
| 49 |
+
fullatt_block_indexes: Optional[list] = None,
|
| 50 |
+
temporal_patch_size: int = 1,
|
| 51 |
+
preserve_original_pe: bool = True,
|
| 52 |
+
use_rope: bool = True,
|
| 53 |
+
**kwargs: Any,
|
| 54 |
+
):
|
| 55 |
+
super().__init__(**kwargs)
|
| 56 |
+
self.hidden_size = hidden_size
|
| 57 |
+
self.intermediate_size = intermediate_size
|
| 58 |
+
self.num_hidden_layers = num_hidden_layers
|
| 59 |
+
self.num_attention_heads = num_attention_heads
|
| 60 |
+
self.num_channels = num_channels
|
| 61 |
+
self.num_patches = num_patches
|
| 62 |
+
self.patch_size = patch_size
|
| 63 |
+
self.image_size = image_size
|
| 64 |
+
self.hidden_act = hidden_act
|
| 65 |
+
self.attention_dropout = attention_dropout
|
| 66 |
+
self.layer_norm_eps = layer_norm_eps
|
| 67 |
+
self.hidden_stride = hidden_stride
|
| 68 |
+
self.window_size = window_size
|
| 69 |
+
self.fullatt_block_indexes = fullatt_block_indexes
|
| 70 |
+
self.temporal_patch_size = temporal_patch_size
|
| 71 |
+
self.preserve_original_pe = preserve_original_pe
|
| 72 |
+
self.use_rope = use_rope
|
| 73 |
+
|
| 74 |
+
class Ovis2_5_Config(PretrainedConfig):
|
| 75 |
+
model_type = "ovis2_5"
|
| 76 |
+
sub_configs = dict(llm_config=Qwen3Config, vit_config=Siglip2NavitConfig)
|
| 77 |
+
|
| 78 |
+
def __init__(self,
|
| 79 |
+
llm_config: Optional[Union[Qwen3Config, dict]] = None,
|
| 80 |
+
vit_config: Optional[Union[Siglip2NavitConfig, dict]] = None,
|
| 81 |
+
visual_vocab_size=65536,
|
| 82 |
+
hidden_size=None,
|
| 83 |
+
**kwargs
|
| 84 |
+
):
|
| 85 |
+
super().__init__(**kwargs)
|
| 86 |
+
if isinstance(llm_config, dict):
|
| 87 |
+
llm_config = Qwen3Config(**llm_config)
|
| 88 |
+
self.llm_config = llm_config
|
| 89 |
+
if isinstance(vit_config, dict):
|
| 90 |
+
vit_config = Siglip2NavitConfig(**vit_config)
|
| 91 |
+
self.vit_config = vit_config
|
| 92 |
+
self.visual_vocab_size = visual_vocab_size
|
| 93 |
+
self.hidden_size = hidden_size
|
| 94 |
+
if kwargs.get('attn_implementation'):
|
| 95 |
+
self.llm_config._attn_implementation = kwargs['attn_implementation']
|
| 96 |
+
self.vit_config._attn_implementation = kwargs['attn_implementation']
|
ovis_image/model/ovis/modeling_ovis2_5.py
ADDED
|
@@ -0,0 +1,1000 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (C) 2025 AIDC-AI
|
| 2 |
+
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License.
|
| 3 |
+
# You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
|
| 4 |
+
# Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.
|
| 5 |
+
import math
|
| 6 |
+
from typing import Dict, List, Optional, Tuple, Union
|
| 7 |
+
|
| 8 |
+
import PIL.Image
|
| 9 |
+
import numpy as np
|
| 10 |
+
import torch
|
| 11 |
+
# from flash_attn import flash_attn_varlen_func
|
| 12 |
+
# from flash_attn.layers.rotary import apply_rotary_emb
|
| 13 |
+
from torch import Tensor, nn
|
| 14 |
+
from torch.nn import functional as F
|
| 15 |
+
from transformers import (
|
| 16 |
+
AutoConfig,
|
| 17 |
+
AutoImageProcessor,
|
| 18 |
+
AutoModel,
|
| 19 |
+
AutoModelForCausalLM,
|
| 20 |
+
AutoTokenizer,
|
| 21 |
+
)
|
| 22 |
+
from transformers.activations import ACT2FN
|
| 23 |
+
from transformers.generation.utils import GenerateOutput
|
| 24 |
+
from transformers.modeling_outputs import BaseModelOutputWithNoAttention
|
| 25 |
+
from transformers.modeling_utils import PreTrainedModel
|
| 26 |
+
|
| 27 |
+
from ovis_image.model.ovis.configuration_ovis2_5 import Siglip2NavitConfig, Ovis2_5_Config
|
| 28 |
+
from ovis_image.model.ops import get_attention_type_by_system
|
| 29 |
+
|
| 30 |
+
IMAGE_PLACEHOLDER = "<image>"
|
| 31 |
+
IMAGE_PLACEHOLDER_ID = -200
|
| 32 |
+
VIDEO_PLACEHOLDER = "<video>"
|
| 33 |
+
VIDEO_PLACEHOLDER_ID = -201
|
| 34 |
+
|
| 35 |
+
VISUAL_ATOM_ID = -300
|
| 36 |
+
INDICATOR_IDS = [-301, -302, -303, -304]
|
| 37 |
+
|
| 38 |
+
# copied from qwen2.5-vl
|
| 39 |
+
class VisionRotaryEmbedding(nn.Module):
|
| 40 |
+
def __init__(self, dim: int, theta: float = 10000.0) -> None:
|
| 41 |
+
super().__init__()
|
| 42 |
+
inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim))
|
| 43 |
+
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
| 44 |
+
|
| 45 |
+
def forward(self, seqlen: int) -> torch.Tensor:
|
| 46 |
+
seq = torch.arange(seqlen, device=self.inv_freq.device, dtype=self.inv_freq.dtype)
|
| 47 |
+
freqs = torch.outer(seq, self.inv_freq)
|
| 48 |
+
return freqs
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
class Siglip2VisionEmbeddings(nn.Module):
|
| 52 |
+
def __init__(self, config: Siglip2NavitConfig):
|
| 53 |
+
super().__init__()
|
| 54 |
+
self.config = config
|
| 55 |
+
self.embed_dim = config.hidden_size
|
| 56 |
+
self.patch_size = config.patch_size
|
| 57 |
+
self.image_size = config.image_size
|
| 58 |
+
self.num_patches = config.num_patches
|
| 59 |
+
self.preserve_original_pe = config.preserve_original_pe
|
| 60 |
+
self.hidden_stride = config.hidden_stride
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
# siglip2 naflex
|
| 64 |
+
if self.num_patches > 0:
|
| 65 |
+
self.patch_embedding = nn.Linear(
|
| 66 |
+
in_features=config.num_channels * self.patch_size * self.patch_size,
|
| 67 |
+
out_features=self.embed_dim,
|
| 68 |
+
)
|
| 69 |
+
if self.preserve_original_pe:
|
| 70 |
+
self.position_embedding_size = int(self.num_patches**0.5)
|
| 71 |
+
self.position_embedding = nn.Embedding(self.num_patches, self.embed_dim)
|
| 72 |
+
|
| 73 |
+
else:
|
| 74 |
+
self.patch_embedding = nn.Conv2d(
|
| 75 |
+
in_channels=config.num_channels,
|
| 76 |
+
out_channels=self.embed_dim,
|
| 77 |
+
kernel_size=self.patch_size,
|
| 78 |
+
stride=self.patch_size,
|
| 79 |
+
padding="valid",
|
| 80 |
+
)
|
| 81 |
+
if self.preserve_original_pe:
|
| 82 |
+
self.num_patches = (self.image_size // self.patch_size) ** 2
|
| 83 |
+
self.position_embedding_size = self.image_size // self.patch_size
|
| 84 |
+
self.position_embedding = nn.Embedding(self.num_patches, self.embed_dim)
|
| 85 |
+
|
| 86 |
+
@staticmethod
|
| 87 |
+
def resize_positional_embeddings(
|
| 88 |
+
positional_embeddings: torch.Tensor,
|
| 89 |
+
spatial_shapes: torch.LongTensor,
|
| 90 |
+
max_length: int,
|
| 91 |
+
) -> torch.Tensor:
|
| 92 |
+
"""
|
| 93 |
+
Resize positional embeddings to image-specific size and pad to a fixed size.
|
| 94 |
+
|
| 95 |
+
Args:
|
| 96 |
+
positional_embeddings (`torch.Tensor`):
|
| 97 |
+
Position embeddings of shape (height, width, embed_dim)
|
| 98 |
+
spatial_shapes (`torch.LongTensor`):
|
| 99 |
+
Spatial shapes of shape (batch_size, 2) to resize the positional embeddings to
|
| 100 |
+
max_length (`int`):
|
| 101 |
+
Maximum length of the positional embeddings to pad resized positional embeddings to
|
| 102 |
+
|
| 103 |
+
Returns:
|
| 104 |
+
`torch.Tensor`: Embeddings of shape (batch_size, max_length, embed_dim)
|
| 105 |
+
"""
|
| 106 |
+
batch_size = spatial_shapes.shape[0]
|
| 107 |
+
embed_dim = positional_embeddings.shape[-1]
|
| 108 |
+
source_dtype = positional_embeddings.dtype
|
| 109 |
+
|
| 110 |
+
resulted_positional_embeddings = torch.empty(
|
| 111 |
+
(batch_size, max_length, embed_dim),
|
| 112 |
+
device=positional_embeddings.device,
|
| 113 |
+
dtype=source_dtype,
|
| 114 |
+
)
|
| 115 |
+
|
| 116 |
+
# (height, width, embed_dim) -> (1, embed_dim, height, width) for interpolation
|
| 117 |
+
positional_embeddings = positional_embeddings.permute(2, 0, 1).unsqueeze(0)
|
| 118 |
+
|
| 119 |
+
# Upcast to float32 on CPU because antialias is not supported for bfloat16/float16 on CPU
|
| 120 |
+
if positional_embeddings.device.type == "cpu":
|
| 121 |
+
positional_embeddings = positional_embeddings.to(torch.float32)
|
| 122 |
+
|
| 123 |
+
for i in range(batch_size):
|
| 124 |
+
# (1, dim, height, width) -> (1, dim, target_height, target_width)
|
| 125 |
+
height, width = spatial_shapes[i]
|
| 126 |
+
resized_embeddings = F.interpolate(
|
| 127 |
+
positional_embeddings,
|
| 128 |
+
size=(height, width),
|
| 129 |
+
mode="bilinear",
|
| 130 |
+
align_corners=False,
|
| 131 |
+
antialias=True,
|
| 132 |
+
)
|
| 133 |
+
|
| 134 |
+
# (1, dim, target_height, target_width) -> (target_height * target_width, dim)
|
| 135 |
+
resized_embeddings = resized_embeddings.reshape(embed_dim, height * width).transpose(0, 1)
|
| 136 |
+
|
| 137 |
+
# Cast to original dtype
|
| 138 |
+
resized_embeddings = resized_embeddings.to(source_dtype)
|
| 139 |
+
|
| 140 |
+
resulted_positional_embeddings[i, : height * width] = resized_embeddings
|
| 141 |
+
resulted_positional_embeddings[i, height * width :] = resized_embeddings[0]
|
| 142 |
+
|
| 143 |
+
return resulted_positional_embeddings
|
| 144 |
+
|
| 145 |
+
def forward(self, pixel_values: torch.FloatTensor,
|
| 146 |
+
grid_thws: Optional[torch.LongTensor] = None) -> torch.Tensor:
|
| 147 |
+
"""
|
| 148 |
+
Args:
|
| 149 |
+
pixel_values (`torch.FloatTensor`):
|
| 150 |
+
Pixel values of shape (num_patches, num_channels * temporal_patch_size * patch_size * patch_size)
|
| 151 |
+
grid_thws: (`torch.LongTensor`):
|
| 152 |
+
grid shape (num_patches, 3)
|
| 153 |
+
"""
|
| 154 |
+
|
| 155 |
+
# Apply patch embeddings to already patchified pixel values
|
| 156 |
+
target_dtype = self.patch_embedding.weight.dtype
|
| 157 |
+
if isinstance(self.patch_embedding, nn.Linear):
|
| 158 |
+
patch_embeds = self.patch_embedding(pixel_values.to(dtype=target_dtype))
|
| 159 |
+
elif isinstance(self.patch_embedding, nn.Conv2d):
|
| 160 |
+
pixel_values = pixel_values.view(-1, self.config.num_channels * self.config.temporal_patch_size, self.patch_size,
|
| 161 |
+
self.patch_size)
|
| 162 |
+
patch_embeds = self.patch_embedding(pixel_values.to(dtype=target_dtype))
|
| 163 |
+
patch_embeds = patch_embeds.reshape(-1, self.embed_dim)
|
| 164 |
+
|
| 165 |
+
|
| 166 |
+
if self.preserve_original_pe:
|
| 167 |
+
assert grid_thws is not None
|
| 168 |
+
pos_embed_new = torch.zeros_like(patch_embeds)
|
| 169 |
+
ori_h = ori_w = self.position_embedding_size
|
| 170 |
+
positional_embeddings = self.position_embedding.weight.reshape(
|
| 171 |
+
self.position_embedding_size, self.position_embedding_size, -1
|
| 172 |
+
).unsqueeze(0).permute(0,3,1,2)
|
| 173 |
+
# pos_embed = self.pos_embed.reshape(1, ori_h, ori_w, -1).permute(0, 3, 1, 2)
|
| 174 |
+
cnt = 0
|
| 175 |
+
for t, h, w in grid_thws:
|
| 176 |
+
thw = t * h * w
|
| 177 |
+
pe = F.interpolate(positional_embeddings, size=(h, w), mode='bicubic', align_corners=False)
|
| 178 |
+
pe = pe.permute(0, 2, 3, 1).reshape(1, h * w, -1)
|
| 179 |
+
pe = pe[0].repeat(t, 1)
|
| 180 |
+
pe = pe.reshape(t, h // self.hidden_stride, self.hidden_stride, w // self.hidden_stride,
|
| 181 |
+
self.hidden_stride, -1)
|
| 182 |
+
pe = pe.permute(0, 1, 3, 2, 4, 5).reshape(thw, -1)
|
| 183 |
+
pos_embed_new[cnt:cnt + thw] = pe
|
| 184 |
+
cnt += thw
|
| 185 |
+
patch_embeds = patch_embeds + pos_embed_new
|
| 186 |
+
|
| 187 |
+
return patch_embeds
|
| 188 |
+
|
| 189 |
+
|
| 190 |
+
|
| 191 |
+
def rotate_half(x):
|
| 192 |
+
x1 = x[..., : x.shape[-1] // 2]
|
| 193 |
+
x2 = x[..., x.shape[-1] // 2 :]
|
| 194 |
+
return torch.cat((-x2, x1), dim=-1)
|
| 195 |
+
|
| 196 |
+
def apply_rotary_pos_emb_flashatt(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
|
| 197 |
+
cos = cos.unsqueeze(unsqueeze_dim)
|
| 198 |
+
sin = sin.unsqueeze(unsqueeze_dim)
|
| 199 |
+
q_embed = (q * cos) + (rotate_half(q) * sin)
|
| 200 |
+
k_embed = (k * cos) + (rotate_half(k) * sin)
|
| 201 |
+
return q_embed, k_embed
|
| 202 |
+
|
| 203 |
+
|
| 204 |
+
class Siglip2Attention(nn.Module):
|
| 205 |
+
"""Multi-headed attention from 'Attention Is All You Need' paper"""
|
| 206 |
+
|
| 207 |
+
def __init__(self, config):
|
| 208 |
+
super().__init__()
|
| 209 |
+
self.config = config
|
| 210 |
+
self.embed_dim = config.hidden_size
|
| 211 |
+
self.num_heads = config.num_attention_heads
|
| 212 |
+
self.head_dim = self.embed_dim // self.num_heads
|
| 213 |
+
if self.head_dim * self.num_heads != self.embed_dim:
|
| 214 |
+
raise ValueError(
|
| 215 |
+
f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
|
| 216 |
+
f" {self.num_heads})."
|
| 217 |
+
)
|
| 218 |
+
self.scale = self.head_dim**-0.5
|
| 219 |
+
self.dropout = config.attention_dropout
|
| 220 |
+
self.is_causal = False
|
| 221 |
+
|
| 222 |
+
self.k_proj = nn.Linear(self.embed_dim, self.embed_dim)
|
| 223 |
+
self.v_proj = nn.Linear(self.embed_dim, self.embed_dim)
|
| 224 |
+
self.q_proj = nn.Linear(self.embed_dim, self.embed_dim)
|
| 225 |
+
self.out_proj = nn.Linear(self.embed_dim, self.embed_dim)
|
| 226 |
+
|
| 227 |
+
self.use_rope = config.use_rope
|
| 228 |
+
|
| 229 |
+
def forward(
|
| 230 |
+
self,
|
| 231 |
+
hidden_states: torch.Tensor,
|
| 232 |
+
cu_seqlens: torch.Tensor,
|
| 233 |
+
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
| 234 |
+
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
|
| 235 |
+
"""Input shape: Batch x Time x Channel"""
|
| 236 |
+
|
| 237 |
+
seq_length, embed_dim = hidden_states.shape
|
| 238 |
+
|
| 239 |
+
queries = self.q_proj(hidden_states)
|
| 240 |
+
keys = self.k_proj(hidden_states)
|
| 241 |
+
values = self.v_proj(hidden_states)
|
| 242 |
+
|
| 243 |
+
queries = queries.view(seq_length, self.num_heads, self.head_dim)
|
| 244 |
+
keys = keys.view(seq_length, self.num_heads, self.head_dim)
|
| 245 |
+
values = values.view(seq_length, self.num_heads, self.head_dim)
|
| 246 |
+
|
| 247 |
+
if self.use_rope:
|
| 248 |
+
cos, sin = position_embeddings
|
| 249 |
+
queries, keys = apply_rotary_pos_emb_flashatt(queries.unsqueeze(0), keys.unsqueeze(0), cos, sin)
|
| 250 |
+
queries = queries.squeeze(0)
|
| 251 |
+
keys = keys.squeeze(0)
|
| 252 |
+
|
| 253 |
+
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
|
| 254 |
+
# attn_output = flash_attn_varlen_func(queries, keys, values, cu_seqlens, cu_seqlens, max_seqlen, max_seqlen).reshape(
|
| 255 |
+
# seq_length, -1
|
| 256 |
+
# )
|
| 257 |
+
batch_size = cu_seqlens.shape[0] - 1
|
| 258 |
+
q_list, k_list, v_list = [], [], []
|
| 259 |
+
for i in range(batch_size):
|
| 260 |
+
start, end = cu_seqlens[i].item(), cu_seqlens[i+1].item()
|
| 261 |
+
q_list.append(queries[start:end])
|
| 262 |
+
k_list.append(keys[start:end])
|
| 263 |
+
v_list.append(values[start:end])
|
| 264 |
+
|
| 265 |
+
def pad_to_max(t, max_len):
|
| 266 |
+
pad = (0, 0, 0, 0, 0, max_len - t.shape[0]) # [seqlen, num_heads, head_dim]
|
| 267 |
+
return torch.nn.functional.pad(t, pad)
|
| 268 |
+
|
| 269 |
+
q_batched = torch.stack([pad_to_max(x, max_seqlen) for x in q_list]) # (batch, seqlen, nhead, dim)
|
| 270 |
+
k_batched = torch.stack([pad_to_max(x, max_seqlen) for x in k_list])
|
| 271 |
+
v_batched = torch.stack([pad_to_max(x, max_seqlen) for x in v_list])
|
| 272 |
+
|
| 273 |
+
mask = torch.zeros((batch_size, max_seqlen), dtype=torch.bool, device=queries.device)
|
| 274 |
+
for i in range(batch_size):
|
| 275 |
+
mask[i, :q_list[i].shape[0]] = True # [batch, seqlen]
|
| 276 |
+
|
| 277 |
+
# (batch, nhead, seqlen, head_dim)
|
| 278 |
+
q_batched = q_batched.transpose(1, 2) # (batch, nhead, seqlen, head_dim)
|
| 279 |
+
k_batched = k_batched.transpose(1, 2)
|
| 280 |
+
v_batched = v_batched.transpose(1, 2)
|
| 281 |
+
if torch.__version__ >= "2.7.0":
|
| 282 |
+
from torch.nn.attention import sdpa_kernel, SDPBackend
|
| 283 |
+
with sdpa_kernel([SDPBackend.CUDNN_ATTENTION]):
|
| 284 |
+
attn_output = torch.nn.functional.scaled_dot_product_attention(
|
| 285 |
+
q_batched, k_batched, v_batched,
|
| 286 |
+
attn_mask=mask.unsqueeze(1).unsqueeze(2)
|
| 287 |
+
).permute(0, 2, 1, 3).reshape(seq_length, -1)
|
| 288 |
+
else:
|
| 289 |
+
attn_output = torch.nn.functional.scaled_dot_product_attention(
|
| 290 |
+
q_batched, k_batched, v_batched,
|
| 291 |
+
attn_mask = mask.unsqueeze(1).unsqueeze(2) # broadcast到 (batch, 1, 1, seqlen)
|
| 292 |
+
).permute(0, 2, 1, 3).reshape(seq_length, -1)
|
| 293 |
+
|
| 294 |
+
attn_output = self.out_proj(attn_output)
|
| 295 |
+
return attn_output
|
| 296 |
+
|
| 297 |
+
class Siglip2MLP(nn.Module):
|
| 298 |
+
def __init__(self, config):
|
| 299 |
+
super().__init__()
|
| 300 |
+
self.config = config
|
| 301 |
+
self.activation_fn = ACT2FN[config.hidden_act]
|
| 302 |
+
self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
|
| 303 |
+
self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)
|
| 304 |
+
|
| 305 |
+
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
| 306 |
+
hidden_states = self.fc1(hidden_states)
|
| 307 |
+
hidden_states = self.activation_fn(hidden_states)
|
| 308 |
+
hidden_states = self.fc2(hidden_states)
|
| 309 |
+
return hidden_states
|
| 310 |
+
|
| 311 |
+
|
| 312 |
+
class Siglip2EncoderLayer(nn.Module):
|
| 313 |
+
def __init__(self, config: Siglip2NavitConfig):
|
| 314 |
+
super().__init__()
|
| 315 |
+
self.embed_dim = config.hidden_size
|
| 316 |
+
self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
|
| 317 |
+
self.self_attn = Siglip2Attention(config)
|
| 318 |
+
self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
|
| 319 |
+
self.mlp = Siglip2MLP(config)
|
| 320 |
+
|
| 321 |
+
def forward(
|
| 322 |
+
self,
|
| 323 |
+
hidden_states: torch.Tensor,
|
| 324 |
+
cu_seqlens: torch.Tensor,
|
| 325 |
+
position_embeddings: torch.Tensor
|
| 326 |
+
) -> tuple[torch.FloatTensor]:
|
| 327 |
+
"""
|
| 328 |
+
Args:
|
| 329 |
+
hidden_states (`torch.FloatTensor`):
|
| 330 |
+
Input to the layer of shape `(batch, seq_len, embed_dim)`.
|
| 331 |
+
attention_mask (`torch.FloatTensor`):
|
| 332 |
+
Attention mask of shape `(batch, 1, q_len, k_v_seq_len)` where padding elements are indicated by very large negative values.
|
| 333 |
+
output_attentions (`bool`, *optional*, defaults to `False`):
|
| 334 |
+
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
|
| 335 |
+
returned tensors for more detail.
|
| 336 |
+
"""
|
| 337 |
+
residual = hidden_states
|
| 338 |
+
|
| 339 |
+
hidden_states = self.layer_norm1(hidden_states)
|
| 340 |
+
hidden_states = self.self_attn(
|
| 341 |
+
hidden_states=hidden_states,
|
| 342 |
+
cu_seqlens=cu_seqlens,
|
| 343 |
+
position_embeddings=position_embeddings
|
| 344 |
+
)
|
| 345 |
+
hidden_states = residual + hidden_states
|
| 346 |
+
|
| 347 |
+
residual = hidden_states
|
| 348 |
+
hidden_states = self.layer_norm2(hidden_states)
|
| 349 |
+
hidden_states = self.mlp(hidden_states)
|
| 350 |
+
hidden_states = residual + hidden_states
|
| 351 |
+
|
| 352 |
+
return hidden_states
|
| 353 |
+
|
| 354 |
+
class Siglip2Encoder(nn.Module):
|
| 355 |
+
"""
|
| 356 |
+
Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a
|
| 357 |
+
[`Siglip2EncoderLayer`].
|
| 358 |
+
|
| 359 |
+
Args:
|
| 360 |
+
config: Siglip2NavitConfig
|
| 361 |
+
"""
|
| 362 |
+
|
| 363 |
+
def __init__(self, config: Siglip2NavitConfig):
|
| 364 |
+
super().__init__()
|
| 365 |
+
self.config = config
|
| 366 |
+
self.layers = nn.ModuleList([Siglip2EncoderLayer(config) for _ in range(config.num_hidden_layers)])
|
| 367 |
+
self.gradient_checkpointing = False
|
| 368 |
+
|
| 369 |
+
self.rotary_pos_emb = VisionRotaryEmbedding(config.hidden_size // config.num_attention_heads // 2)
|
| 370 |
+
self.patch_size = config.patch_size
|
| 371 |
+
self.hidden_stride = config.hidden_stride
|
| 372 |
+
self.window_size = config.window_size
|
| 373 |
+
self.spatial_merge_unit = config.hidden_stride * config.hidden_stride
|
| 374 |
+
self.fullatt_block_indexes = None if config.fullatt_block_indexes is None else [int(i) for i in config.fullatt_block_indexes.split('|')]
|
| 375 |
+
|
| 376 |
+
|
| 377 |
+
# copied from qwen2.5_vl
|
| 378 |
+
def rot_pos_emb(self, grid_thw):
|
| 379 |
+
pos_ids = []
|
| 380 |
+
for t, h, w in grid_thw:
|
| 381 |
+
hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w)
|
| 382 |
+
hpos_ids = hpos_ids.reshape(
|
| 383 |
+
h // self.hidden_stride,
|
| 384 |
+
self.hidden_stride,
|
| 385 |
+
w // self.hidden_stride,
|
| 386 |
+
self.hidden_stride,
|
| 387 |
+
)
|
| 388 |
+
hpos_ids = hpos_ids.permute(0, 2, 1, 3)
|
| 389 |
+
hpos_ids = hpos_ids.flatten()
|
| 390 |
+
|
| 391 |
+
wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1)
|
| 392 |
+
wpos_ids = wpos_ids.reshape(
|
| 393 |
+
h // self.hidden_stride,
|
| 394 |
+
self.hidden_stride,
|
| 395 |
+
w // self.hidden_stride,
|
| 396 |
+
self.hidden_stride,
|
| 397 |
+
)
|
| 398 |
+
wpos_ids = wpos_ids.permute(0, 2, 1, 3)
|
| 399 |
+
wpos_ids = wpos_ids.flatten()
|
| 400 |
+
pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1))
|
| 401 |
+
pos_ids = torch.cat(pos_ids, dim=0)
|
| 402 |
+
max_grid_size = grid_thw[:, 1:].max()
|
| 403 |
+
rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size)
|
| 404 |
+
rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1)
|
| 405 |
+
return rotary_pos_emb
|
| 406 |
+
|
| 407 |
+
def get_window_index(self, grid_thw):
|
| 408 |
+
window_index: list = []
|
| 409 |
+
cu_window_seqlens: list = [0]
|
| 410 |
+
window_index_id = 0
|
| 411 |
+
vit_merger_window_size = self.window_size // self.hidden_stride // self.patch_size # patch (after merge) number in each window
|
| 412 |
+
|
| 413 |
+
for grid_t, grid_h, grid_w in grid_thw:
|
| 414 |
+
llm_grid_h, llm_grid_w = (
|
| 415 |
+
grid_h // self.hidden_stride, # number of patch after merge
|
| 416 |
+
grid_w // self.hidden_stride,
|
| 417 |
+
)
|
| 418 |
+
index = torch.arange(grid_t * llm_grid_h * llm_grid_w).reshape(grid_t, llm_grid_h, llm_grid_w)
|
| 419 |
+
pad_h = vit_merger_window_size - llm_grid_h % vit_merger_window_size
|
| 420 |
+
pad_w = vit_merger_window_size - llm_grid_w % vit_merger_window_size
|
| 421 |
+
num_windows_h = (llm_grid_h + pad_h) // vit_merger_window_size
|
| 422 |
+
num_windows_w = (llm_grid_w + pad_w) // vit_merger_window_size
|
| 423 |
+
index_padded = F.pad(index, (0, pad_w, 0, pad_h), "constant", -100)
|
| 424 |
+
index_padded = index_padded.reshape(
|
| 425 |
+
grid_t,
|
| 426 |
+
num_windows_h,
|
| 427 |
+
vit_merger_window_size,
|
| 428 |
+
num_windows_w,
|
| 429 |
+
vit_merger_window_size,
|
| 430 |
+
)
|
| 431 |
+
index_padded = index_padded.permute(0, 1, 3, 2, 4).reshape(
|
| 432 |
+
grid_t,
|
| 433 |
+
num_windows_h * num_windows_w,
|
| 434 |
+
vit_merger_window_size,
|
| 435 |
+
vit_merger_window_size,
|
| 436 |
+
)
|
| 437 |
+
seqlens = (index_padded != -100).sum([2, 3]).reshape(-1)
|
| 438 |
+
index_padded = index_padded.reshape(-1)
|
| 439 |
+
index_new = index_padded[index_padded != -100]
|
| 440 |
+
window_index.append(index_new + window_index_id)
|
| 441 |
+
cu_seqlens_tmp = seqlens.cumsum(0) * self.spatial_merge_unit + cu_window_seqlens[-1]
|
| 442 |
+
cu_window_seqlens.extend(cu_seqlens_tmp.tolist())
|
| 443 |
+
window_index_id += (grid_t * llm_grid_h * llm_grid_w).item()
|
| 444 |
+
window_index = torch.cat(window_index, dim=0)
|
| 445 |
+
|
| 446 |
+
return window_index, cu_window_seqlens
|
| 447 |
+
|
| 448 |
+
# Ignore copy
|
| 449 |
+
def forward(
|
| 450 |
+
self,
|
| 451 |
+
inputs_embeds,
|
| 452 |
+
grid_thws: torch.Tensor,
|
| 453 |
+
output_hidden_states: bool = False,
|
| 454 |
+
) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, ...]]]:
|
| 455 |
+
r"""
|
| 456 |
+
Args:
|
| 457 |
+
inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
|
| 458 |
+
Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
|
| 459 |
+
This is useful if you want more control over how to convert `input_ids` indices into associated vectors
|
| 460 |
+
than the model's internal embedding lookup matrix.
|
| 461 |
+
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
|
| 462 |
+
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
|
| 463 |
+
|
| 464 |
+
- 1 for tokens that are **not masked**,
|
| 465 |
+
- 0 for tokens that are **masked**.
|
| 466 |
+
|
| 467 |
+
[What are attention masks?](../glossary#attention-mask)
|
| 468 |
+
output_attentions (`bool`, *optional*):
|
| 469 |
+
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
|
| 470 |
+
returned tensors for more detail.
|
| 471 |
+
output_hidden_states (`bool`, *optional*):
|
| 472 |
+
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
|
| 473 |
+
for more detail.
|
| 474 |
+
return_dict (`bool`, *optional*):
|
| 475 |
+
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
| 476 |
+
"""
|
| 477 |
+
|
| 478 |
+
rotary_pos_emb = self.rot_pos_emb(grid_thws)
|
| 479 |
+
window_index, cu_window_seqlens = self.get_window_index(grid_thws)
|
| 480 |
+
cu_window_seqlens = torch.tensor(
|
| 481 |
+
cu_window_seqlens,
|
| 482 |
+
device=inputs_embeds.device,
|
| 483 |
+
dtype=grid_thws.dtype if torch.jit.is_tracing() else torch.int32,
|
| 484 |
+
)
|
| 485 |
+
cu_window_seqlens = torch.unique_consecutive(cu_window_seqlens)
|
| 486 |
+
|
| 487 |
+
seq_len, _ = inputs_embeds.size()
|
| 488 |
+
inputs_embeds = inputs_embeds.reshape(seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1)
|
| 489 |
+
inputs_embeds = inputs_embeds[window_index, :, :]
|
| 490 |
+
inputs_embeds = inputs_embeds.reshape(seq_len, -1)
|
| 491 |
+
rotary_pos_emb = rotary_pos_emb.reshape(seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1)
|
| 492 |
+
rotary_pos_emb = rotary_pos_emb[window_index, :, :]
|
| 493 |
+
rotary_pos_emb = rotary_pos_emb.reshape(seq_len, -1)
|
| 494 |
+
emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
|
| 495 |
+
position_embeddings = (emb.cos(), emb.sin())
|
| 496 |
+
|
| 497 |
+
cu_seqlens = torch.repeat_interleave(grid_thws[:, 1] * grid_thws[:, 2], grid_thws[:, 0]).cumsum(
|
| 498 |
+
dim=0,
|
| 499 |
+
# Select dtype based on the following factors:
|
| 500 |
+
# - FA2 requires that cu_seqlens_q must have dtype int32
|
| 501 |
+
# - torch.onnx.export requires that cu_seqlens_q must have same dtype as grid_thw
|
| 502 |
+
# See https://github.com/huggingface/transformers/pull/34852 for more information
|
| 503 |
+
dtype=grid_thws.dtype if torch.jit.is_tracing() else torch.int32,
|
| 504 |
+
)
|
| 505 |
+
cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)
|
| 506 |
+
|
| 507 |
+
reverse_indices = torch.argsort(window_index)
|
| 508 |
+
encoder_states = () if output_hidden_states else None
|
| 509 |
+
|
| 510 |
+
hidden_states = inputs_embeds
|
| 511 |
+
for index, block in enumerate(self.layers):
|
| 512 |
+
if self.fullatt_block_indexes is None or index in self.fullatt_block_indexes:
|
| 513 |
+
cu_seqlens_tmp = cu_seqlens
|
| 514 |
+
else:
|
| 515 |
+
cu_seqlens_tmp = cu_window_seqlens
|
| 516 |
+
if self.gradient_checkpointing and self.training:
|
| 517 |
+
hidden_states = self._gradient_checkpointing_func(block.__call__, hidden_states, cu_seqlens_tmp, position_embeddings)
|
| 518 |
+
else:
|
| 519 |
+
hidden_states = block(hidden_states, cu_seqlens_tmp, position_embeddings)
|
| 520 |
+
if output_hidden_states:
|
| 521 |
+
hidden_states_ = hidden_states.reshape(seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1)
|
| 522 |
+
encoder_states += (hidden_states_[reverse_indices, :].reshape(seq_len, -1),)
|
| 523 |
+
# tokens = self.post_trunk_norm(tokens)
|
| 524 |
+
hidden_states = hidden_states.reshape(seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1)
|
| 525 |
+
hidden_states = hidden_states[reverse_indices, :].reshape(seq_len, -1)
|
| 526 |
+
|
| 527 |
+
return hidden_states, encoder_states
|
| 528 |
+
|
| 529 |
+
class Siglip2VisionTransformer(nn.Module):
|
| 530 |
+
def __init__(self, config: Siglip2NavitConfig):
|
| 531 |
+
super().__init__()
|
| 532 |
+
self.config = config
|
| 533 |
+
embed_dim = config.hidden_size
|
| 534 |
+
|
| 535 |
+
self.embeddings = Siglip2VisionEmbeddings(config)
|
| 536 |
+
self.encoder = Siglip2Encoder(config)
|
| 537 |
+
self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
|
| 538 |
+
self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
|
| 539 |
+
|
| 540 |
+
def forward(
|
| 541 |
+
self,
|
| 542 |
+
pixel_values: torch.FloatTensor,
|
| 543 |
+
grid_thws: torch.LongTensor,
|
| 544 |
+
output_hidden_states: Optional[bool] = True,
|
| 545 |
+
return_dict: Optional[bool] = True,
|
| 546 |
+
) -> Union[
|
| 547 |
+
Tuple[torch.Tensor],
|
| 548 |
+
Tuple[torch.Tensor, Tuple[torch.Tensor, ...]],
|
| 549 |
+
BaseModelOutputWithNoAttention,
|
| 550 |
+
]:
|
| 551 |
+
r"""
|
| 552 |
+
spatial_shapes (`torch.LongTensor` of shape `(batch_size, 2)`):
|
| 553 |
+
Tensor containing the spatial dimensions (height, width) of the input images.
|
| 554 |
+
"""
|
| 555 |
+
# output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
| 556 |
+
# output_hidden_states = (
|
| 557 |
+
# output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
| 558 |
+
# )
|
| 559 |
+
|
| 560 |
+
hidden_states = self.embeddings(pixel_values, grid_thws)
|
| 561 |
+
|
| 562 |
+
last_hidden_state, hidden_states = self.encoder(hidden_states, grid_thws, output_hidden_states)
|
| 563 |
+
last_hidden_state = self.post_layernorm(last_hidden_state)
|
| 564 |
+
|
| 565 |
+
if not return_dict:
|
| 566 |
+
output = (last_hidden_state,)
|
| 567 |
+
output += (hidden_states,) if output_hidden_states else ()
|
| 568 |
+
return output
|
| 569 |
+
|
| 570 |
+
return BaseModelOutputWithNoAttention(
|
| 571 |
+
last_hidden_state=last_hidden_state,
|
| 572 |
+
hidden_states=hidden_states
|
| 573 |
+
)
|
| 574 |
+
|
| 575 |
+
class Siglip2PreTrainedModel(PreTrainedModel):
|
| 576 |
+
config_class = Siglip2NavitConfig
|
| 577 |
+
base_model_prefix = "siglip2_navit"
|
| 578 |
+
supports_gradient_checkpointing = True
|
| 579 |
+
|
| 580 |
+
_no_split_modules = [
|
| 581 |
+
"Siglip2VisionEmbeddings",
|
| 582 |
+
"Siglip2EncoderLayer",
|
| 583 |
+
]
|
| 584 |
+
_supports_flash_attn_2 = True
|
| 585 |
+
_supports_sdpa = False
|
| 586 |
+
_supports_flex_attn = False
|
| 587 |
+
_supports_attention_backend = True
|
| 588 |
+
|
| 589 |
+
|
| 590 |
+
class Siglip2NavitModel(Siglip2PreTrainedModel):
|
| 591 |
+
config_class = Siglip2NavitConfig
|
| 592 |
+
main_input_name = "pixel_values"
|
| 593 |
+
|
| 594 |
+
def __init__(self, config: Siglip2NavitConfig):
|
| 595 |
+
super().__init__(config)
|
| 596 |
+
|
| 597 |
+
self.vision_model = Siglip2VisionTransformer(config)
|
| 598 |
+
|
| 599 |
+
def get_input_embeddings(self) -> nn.Module:
|
| 600 |
+
return self.vision_model.embeddings.patch_embedding
|
| 601 |
+
|
| 602 |
+
def forward(
|
| 603 |
+
self,
|
| 604 |
+
pixel_values: torch.FloatTensor,
|
| 605 |
+
grid_thws: torch.LongTensor,
|
| 606 |
+
output_hidden_states: Optional[bool] = None,
|
| 607 |
+
return_dict: Optional[bool] = None,
|
| 608 |
+
) -> Union[
|
| 609 |
+
Tuple[torch.Tensor],
|
| 610 |
+
Tuple[torch.Tensor, Tuple[torch.Tensor, ...]],
|
| 611 |
+
BaseModelOutputWithNoAttention,
|
| 612 |
+
]:
|
| 613 |
+
|
| 614 |
+
if output_hidden_states is None:
|
| 615 |
+
output_hidden_states = self.config.output_hidden_states
|
| 616 |
+
if return_dict is None:
|
| 617 |
+
return_dict = self.config.use_return_dict
|
| 618 |
+
|
| 619 |
+
return self.vision_model(
|
| 620 |
+
pixel_values=pixel_values,
|
| 621 |
+
grid_thws=grid_thws,
|
| 622 |
+
output_hidden_states=output_hidden_states,
|
| 623 |
+
return_dict=return_dict,
|
| 624 |
+
)
|
| 625 |
+
|
| 626 |
+
class VisualEmbedding(torch.nn.Embedding):
|
| 627 |
+
"""
|
| 628 |
+
A visual embedding layer that can handle both discrete token IDs (long) and continuous
|
| 629 |
+
soft-token probabilities (float).
|
| 630 |
+
"""
|
| 631 |
+
|
| 632 |
+
def forward(self, visual_tokens: Tensor) -> Tensor:
|
| 633 |
+
if visual_tokens.dtype in [torch.int8, torch.int16, torch.int32, torch.int64, torch.long]:
|
| 634 |
+
return super().forward(visual_tokens)
|
| 635 |
+
# Handle soft tokens (probabilities) by matrix multiplication with the embedding weight
|
| 636 |
+
return torch.matmul(visual_tokens, self.weight)
|
| 637 |
+
|
| 638 |
+
|
| 639 |
+
class VisualTokenizer(torch.nn.Module):
|
| 640 |
+
"""
|
| 641 |
+
Tokenizes images or videos into a sequence of continuous visual tokens.
|
| 642 |
+
"""
|
| 643 |
+
|
| 644 |
+
def __init__(self, vit, visual_vocab_size, image_processor_name_or_path, *args, **kwargs):
|
| 645 |
+
super().__init__(*args, **kwargs)
|
| 646 |
+
self.vit = vit
|
| 647 |
+
self.image_processor = AutoImageProcessor.from_pretrained(image_processor_name_or_path, do_center_crop=False)
|
| 648 |
+
head_dim = visual_vocab_size - len(INDICATOR_IDS)
|
| 649 |
+
self.head = torch.nn.Sequential(
|
| 650 |
+
torch.nn.Linear(self.vit.config.hidden_size * self.vit.config.hidden_stride ** 2, head_dim, bias=False),
|
| 651 |
+
torch.nn.LayerNorm(head_dim)
|
| 652 |
+
)
|
| 653 |
+
|
| 654 |
+
def _encode(self, pixel_values, grid_thws):
|
| 655 |
+
output = self.vit(pixel_values, grid_thws, output_hidden_states=True, return_dict=True)
|
| 656 |
+
features = output.hidden_states[-1]
|
| 657 |
+
seq_len, _ = features.shape
|
| 658 |
+
features = features.reshape(seq_len // (self.vit.config.hidden_stride ** 2), -1)
|
| 659 |
+
return features
|
| 660 |
+
|
| 661 |
+
# Adapted from qwen2_vl
|
| 662 |
+
@staticmethod
|
| 663 |
+
def smart_resize(
|
| 664 |
+
height: int, width: int, factor: int = 28, min_pixels: int = 448 * 448, max_pixels: int = 1344 * 1792
|
| 665 |
+
):
|
| 666 |
+
"""Rescales the image so that the following conditions are met:
|
| 667 |
+
1. Both dimensions are divisible by 'factor'.
|
| 668 |
+
2. The total number of pixels is within ['min_pixels', 'max_pixels'].
|
| 669 |
+
3. The aspect ratio is maintained as closely as possible.
|
| 670 |
+
"""
|
| 671 |
+
if height < factor or width < factor:
|
| 672 |
+
if height < width:
|
| 673 |
+
width = round(factor / height * width)
|
| 674 |
+
height = factor
|
| 675 |
+
else:
|
| 676 |
+
height = round(factor / width * height)
|
| 677 |
+
width = factor
|
| 678 |
+
|
| 679 |
+
elif max(height, width) / min(height, width) > 200:
|
| 680 |
+
if height > width:
|
| 681 |
+
height = 200 * width
|
| 682 |
+
else:
|
| 683 |
+
width = 200 * height
|
| 684 |
+
|
| 685 |
+
h_bar = round(height / factor) * factor
|
| 686 |
+
w_bar = round(width / factor) * factor
|
| 687 |
+
if h_bar * w_bar > max_pixels:
|
| 688 |
+
beta = math.sqrt((height * width) / max_pixels)
|
| 689 |
+
h_bar = math.floor(height / beta / factor) * factor
|
| 690 |
+
w_bar = math.floor(width / beta / factor) * factor
|
| 691 |
+
elif h_bar * w_bar < min_pixels:
|
| 692 |
+
beta = math.sqrt(min_pixels / (height * width))
|
| 693 |
+
h_bar = math.ceil(height * beta / factor) * factor
|
| 694 |
+
w_bar = math.ceil(width * beta / factor) * factor
|
| 695 |
+
return h_bar, w_bar
|
| 696 |
+
|
| 697 |
+
def preprocess(
|
| 698 |
+
self,
|
| 699 |
+
image: Optional[PIL.Image.Image] = None,
|
| 700 |
+
video: Optional[List[PIL.Image.Image]] = None,
|
| 701 |
+
min_pixels: Optional[int] = None,
|
| 702 |
+
max_pixels: Optional[int] = None
|
| 703 |
+
):
|
| 704 |
+
patch_size = self.vit.config.patch_size
|
| 705 |
+
temporal_patch_size = self.vit.config.temporal_patch_size
|
| 706 |
+
hidden_stride = self.vit.config.hidden_stride
|
| 707 |
+
assert (image is None) ^ (video is None), "Invalid input: expect either image or video"
|
| 708 |
+
if image is not None:
|
| 709 |
+
images = [image]
|
| 710 |
+
else:
|
| 711 |
+
images = video
|
| 712 |
+
images = [image.convert("RGB") if image.mode != 'RGB' else image for image in images]
|
| 713 |
+
width, height = images[0].size
|
| 714 |
+
processed_images = []
|
| 715 |
+
for image in images:
|
| 716 |
+
resized_height, resized_width = self.smart_resize(
|
| 717 |
+
height,
|
| 718 |
+
width,
|
| 719 |
+
factor=patch_size * hidden_stride,
|
| 720 |
+
min_pixels=min_pixels,
|
| 721 |
+
max_pixels=max_pixels,
|
| 722 |
+
)
|
| 723 |
+
new_size = dict(height=resized_height, width=resized_width)
|
| 724 |
+
new_image = self.image_processor.preprocess(image, size=new_size, return_tensors="np")['pixel_values'][0]
|
| 725 |
+
processed_images.append(new_image)
|
| 726 |
+
|
| 727 |
+
patches = np.array(processed_images)
|
| 728 |
+
if patches.shape[0] % temporal_patch_size != 0:
|
| 729 |
+
repeats = np.repeat(patches[-1][np.newaxis], temporal_patch_size - 1, axis=0)
|
| 730 |
+
patches = np.concatenate([patches, repeats], axis=0)
|
| 731 |
+
channel = patches.shape[1]
|
| 732 |
+
grid_t = patches.shape[0] // temporal_patch_size
|
| 733 |
+
grid_h, grid_w = resized_height // patch_size, resized_width // patch_size
|
| 734 |
+
grid_thw = torch.tensor([[grid_t, grid_h, grid_w]])
|
| 735 |
+
|
| 736 |
+
patches = patches.reshape(
|
| 737 |
+
grid_t, temporal_patch_size, channel,
|
| 738 |
+
grid_h // hidden_stride, hidden_stride, patch_size,
|
| 739 |
+
grid_w // hidden_stride, hidden_stride, patch_size,
|
| 740 |
+
)
|
| 741 |
+
patches = patches.transpose(0, 3, 6, 4, 7, 2, 1, 5, 8)
|
| 742 |
+
flatten_patches = patches.reshape(
|
| 743 |
+
grid_t * grid_h * grid_w, channel * temporal_patch_size * patch_size * patch_size
|
| 744 |
+
)
|
| 745 |
+
flatten_patches = torch.tensor(flatten_patches)
|
| 746 |
+
|
| 747 |
+
return flatten_patches, grid_thw
|
| 748 |
+
|
| 749 |
+
def forward(
|
| 750 |
+
self, pixel_values, grid_thws
|
| 751 |
+
) -> torch.Tensor: # [BatchSize, ImageShape] -> [BatchSize, #Token, VocabSize]
|
| 752 |
+
features = self._encode(pixel_values, grid_thws)
|
| 753 |
+
logits = self.head(features)
|
| 754 |
+
tokens = torch.softmax(logits, dim=-1, dtype=torch.float32).to(logits.dtype)
|
| 755 |
+
|
| 756 |
+
token_len, _ = tokens.shape
|
| 757 |
+
padding_tensor = torch.zeros(size=(token_len, len(INDICATOR_IDS)),
|
| 758 |
+
dtype=tokens.dtype,
|
| 759 |
+
device=tokens.device,
|
| 760 |
+
layout=tokens.layout,
|
| 761 |
+
requires_grad=False)
|
| 762 |
+
tokens = torch.cat((tokens, padding_tensor), dim=1)
|
| 763 |
+
return tokens
|
| 764 |
+
|
| 765 |
+
|
| 766 |
+
class OvisPreTrainedModel(PreTrainedModel):
|
| 767 |
+
config_class = Ovis2_5_Config
|
| 768 |
+
base_model_prefix = "ovis2_5"
|
| 769 |
+
|
| 770 |
+
|
| 771 |
+
class Ovis2_5(OvisPreTrainedModel):
|
| 772 |
+
_supports_flash_attn_2 = True
|
| 773 |
+
_supports_flash_attn_3 = True
|
| 774 |
+
|
| 775 |
+
def __init__(self, config: Ovis2_5_Config, *inputs, **kwargs):
|
| 776 |
+
super().__init__(config, *inputs, **kwargs)
|
| 777 |
+
attn_implementation = get_attention_type_by_system()
|
| 778 |
+
print(f"Use {attn_implementation} for LLM!")
|
| 779 |
+
self.llm = AutoModelForCausalLM.from_config(
|
| 780 |
+
self.config.llm_config,
|
| 781 |
+
attn_implementation=attn_implementation,
|
| 782 |
+
)
|
| 783 |
+
assert self.config.hidden_size == self.llm.config.hidden_size, "hidden size mismatch"
|
| 784 |
+
self.text_tokenizer = AutoTokenizer.from_pretrained(self.config.name_or_path)
|
| 785 |
+
self.visual_tokenizer = VisualTokenizer(vit=AutoModel.from_config(self.config.vit_config),
|
| 786 |
+
visual_vocab_size=self.config.visual_vocab_size,
|
| 787 |
+
image_processor_name_or_path=self.config.name_or_path)
|
| 788 |
+
|
| 789 |
+
self.vte = VisualEmbedding(self.config.visual_vocab_size, self.config.hidden_size,
|
| 790 |
+
device=self.visual_tokenizer.vit.device, dtype=self.visual_tokenizer.vit.dtype)
|
| 791 |
+
indicator_token_indices = torch.arange(
|
| 792 |
+
self.config.visual_vocab_size - len(INDICATOR_IDS),
|
| 793 |
+
self.config.visual_vocab_size,
|
| 794 |
+
dtype=torch.long
|
| 795 |
+
)
|
| 796 |
+
self.register_buffer("indicator_token_indices", indicator_token_indices, persistent=False)
|
| 797 |
+
|
| 798 |
+
def _merge_modules(modules_list: tuple):
|
| 799 |
+
merged_modules = []
|
| 800 |
+
for modules in modules_list:
|
| 801 |
+
merged_modules.extend(modules if modules else [])
|
| 802 |
+
return merged_modules
|
| 803 |
+
|
| 804 |
+
# Standard model configurations for parallelism and device placement
|
| 805 |
+
self._no_split_modules = _merge_modules(
|
| 806 |
+
(self.llm._no_split_modules, self.visual_tokenizer.vit._no_split_modules))
|
| 807 |
+
self._skip_keys_device_placement = self.llm._skip_keys_device_placement
|
| 808 |
+
self._keep_in_fp32_modules = _merge_modules(
|
| 809 |
+
(self.llm._keep_in_fp32_modules, self.visual_tokenizer.vit._keep_in_fp32_modules))
|
| 810 |
+
self.is_parallelizable = all((self.llm.is_parallelizable, self.visual_tokenizer.vit.is_parallelizable))
|
| 811 |
+
# self.supports_gradient_checkpointing = True
|
| 812 |
+
self.supports_gradient_checkpointing = False
|
| 813 |
+
|
| 814 |
+
def tie_weights(self):
|
| 815 |
+
self.llm.tie_weights()
|
| 816 |
+
|
| 817 |
+
def get_wte(self):
|
| 818 |
+
return self.llm.get_input_embeddings()
|
| 819 |
+
|
| 820 |
+
def forward(
|
| 821 |
+
self,
|
| 822 |
+
input_ids: torch.Tensor,
|
| 823 |
+
attention_mask: torch.Tensor,
|
| 824 |
+
pixel_values: Optional[torch.Tensor],
|
| 825 |
+
grid_thws: Optional[torch.Tensor],
|
| 826 |
+
labels: Optional[torch.Tensor] = None,
|
| 827 |
+
**kwargs
|
| 828 |
+
):
|
| 829 |
+
inputs_embeds = self.merge_multimodal(
|
| 830 |
+
input_ids=input_ids,
|
| 831 |
+
pixel_values=pixel_values,
|
| 832 |
+
grid_thws=grid_thws,
|
| 833 |
+
)
|
| 834 |
+
return self.llm(inputs_embeds=inputs_embeds, attention_mask=attention_mask, labels=labels, **kwargs)
|
| 835 |
+
|
| 836 |
+
def merge_multimodal(
|
| 837 |
+
self,
|
| 838 |
+
input_ids: torch.Tensor,
|
| 839 |
+
pixel_values: Optional[torch.Tensor],
|
| 840 |
+
grid_thws: Optional[torch.Tensor],
|
| 841 |
+
):
|
| 842 |
+
placeholder_token_mask = torch.lt(input_ids, 0)
|
| 843 |
+
multimodal_embeds = self.get_wte()(torch.masked_fill(input_ids, placeholder_token_mask, 0))
|
| 844 |
+
|
| 845 |
+
if pixel_values is not None:
|
| 846 |
+
visual_indicator_embeds = self.vte(self.indicator_token_indices).to(
|
| 847 |
+
dtype=multimodal_embeds.dtype, device=multimodal_embeds.device
|
| 848 |
+
)
|
| 849 |
+
visual_tokens = self.visual_tokenizer(pixel_values, grid_thws)
|
| 850 |
+
visual_embeds = self.vte(visual_tokens).to(dtype=multimodal_embeds.dtype, device=multimodal_embeds.device)
|
| 851 |
+
|
| 852 |
+
for i, indicator_id in enumerate(INDICATOR_IDS):
|
| 853 |
+
multimodal_embeds[input_ids == indicator_id] = visual_indicator_embeds[i]
|
| 854 |
+
multimodal_embeds[input_ids == VISUAL_ATOM_ID] = visual_embeds
|
| 855 |
+
|
| 856 |
+
return multimodal_embeds
|
| 857 |
+
|
| 858 |
+
def _merge_inputs(
|
| 859 |
+
self, raw_input_ids, placeholder_id, grid_thws, indicator_begin_id, indicator_end_id
|
| 860 |
+
):
|
| 861 |
+
input_ids = []
|
| 862 |
+
prev_index = 0
|
| 863 |
+
placeholder_indexes = [i for i, v in enumerate(raw_input_ids) if v == placeholder_id]
|
| 864 |
+
for placeholder_index, grid_thw in zip(placeholder_indexes, grid_thws):
|
| 865 |
+
input_ids.extend(raw_input_ids[prev_index:placeholder_index])
|
| 866 |
+
num_image_atoms = grid_thw.prod().item()
|
| 867 |
+
num_image_atoms //= self.visual_tokenizer.vit.config.hidden_stride ** 2
|
| 868 |
+
num_image_atoms //= self.visual_tokenizer.vit.config.temporal_patch_size
|
| 869 |
+
input_ids.extend([indicator_begin_id] + [VISUAL_ATOM_ID] * num_image_atoms + [indicator_end_id])
|
| 870 |
+
prev_index = placeholder_index + 1
|
| 871 |
+
input_ids.extend(raw_input_ids[prev_index:])
|
| 872 |
+
return input_ids
|
| 873 |
+
|
| 874 |
+
def _tokenize_with_visual_placeholder(self, text):
|
| 875 |
+
placeholder = VIDEO_PLACEHOLDER if VIDEO_PLACEHOLDER in text else IMAGE_PLACEHOLDER
|
| 876 |
+
placeholder_id = VIDEO_PLACEHOLDER_ID if VIDEO_PLACEHOLDER in text else IMAGE_PLACEHOLDER_ID
|
| 877 |
+
chunks = [self.text_tokenizer(chunk, add_special_tokens=False).input_ids for chunk in text.split(placeholder)]
|
| 878 |
+
input_ids = chunks[0]
|
| 879 |
+
for chunk in chunks[1:]:
|
| 880 |
+
input_ids.append(placeholder_id)
|
| 881 |
+
input_ids.extend(chunk)
|
| 882 |
+
return input_ids
|
| 883 |
+
|
| 884 |
+
def preprocess_inputs(
|
| 885 |
+
self,
|
| 886 |
+
messages: List[Union[str, Dict]],
|
| 887 |
+
min_pixels=448 * 448,
|
| 888 |
+
max_pixels=1344 * 1792,
|
| 889 |
+
add_generation_prompt=True,
|
| 890 |
+
enable_thinking=False
|
| 891 |
+
):
|
| 892 |
+
text = self.text_tokenizer.apply_chat_template(
|
| 893 |
+
messages,
|
| 894 |
+
tokenize=False,
|
| 895 |
+
add_generation_prompt=add_generation_prompt,
|
| 896 |
+
enable_thinking=enable_thinking
|
| 897 |
+
)
|
| 898 |
+
input_ids = self._tokenize_with_visual_placeholder(text)
|
| 899 |
+
images = []
|
| 900 |
+
videos = []
|
| 901 |
+
for message in messages:
|
| 902 |
+
content = message["content"]
|
| 903 |
+
if isinstance(content, list):
|
| 904 |
+
images.extend([item["image"] for item in content if item.get("image") is not None])
|
| 905 |
+
videos.extend([item["video"] for item in content if item.get("video") is not None])
|
| 906 |
+
if images and videos:
|
| 907 |
+
raise ValueError(
|
| 908 |
+
"Multiple visual input data types detected (both image and video provided). "
|
| 909 |
+
"This model supports only one type of visual input data at a time. "
|
| 910 |
+
"Please provide either image or video, but not both."
|
| 911 |
+
)
|
| 912 |
+
|
| 913 |
+
pixel_values, grid_thws = None, None
|
| 914 |
+
if images:
|
| 915 |
+
pixel_values, grid_thws = zip(
|
| 916 |
+
*(self.visual_tokenizer.preprocess(image=image, min_pixels=min_pixels, max_pixels=max_pixels)
|
| 917 |
+
for image in images)
|
| 918 |
+
)
|
| 919 |
+
input_ids = self._merge_inputs(
|
| 920 |
+
input_ids, IMAGE_PLACEHOLDER_ID, grid_thws, INDICATOR_IDS[0], INDICATOR_IDS[1]
|
| 921 |
+
)
|
| 922 |
+
pixel_values = torch.cat(pixel_values, dim=0)
|
| 923 |
+
grid_thws = torch.cat(grid_thws, dim=0)
|
| 924 |
+
elif videos:
|
| 925 |
+
assert len(videos) == 1, "only support single video"
|
| 926 |
+
pixel_values, grid_thws = self.visual_tokenizer.preprocess(
|
| 927 |
+
video=videos[0], min_pixels=min_pixels, max_pixels=max_pixels
|
| 928 |
+
)
|
| 929 |
+
input_ids = self._merge_inputs(
|
| 930 |
+
input_ids, VIDEO_PLACEHOLDER_ID, grid_thws, INDICATOR_IDS[2], INDICATOR_IDS[3]
|
| 931 |
+
)
|
| 932 |
+
|
| 933 |
+
input_ids = torch.tensor(input_ids, dtype=torch.long).unsqueeze(0)
|
| 934 |
+
|
| 935 |
+
return input_ids, pixel_values, grid_thws
|
| 936 |
+
|
| 937 |
+
def generate(
|
| 938 |
+
self,
|
| 939 |
+
inputs: Optional[torch.Tensor] = None,
|
| 940 |
+
**kwargs,
|
| 941 |
+
) -> Union[GenerateOutput, torch.LongTensor]:
|
| 942 |
+
attention_mask = torch.ne(inputs, self.text_tokenizer.pad_token_id).to(device=inputs.device)
|
| 943 |
+
inputs_embeds = self.merge_multimodal(
|
| 944 |
+
input_ids=inputs,
|
| 945 |
+
pixel_values=kwargs.pop('pixel_values', None),
|
| 946 |
+
grid_thws=kwargs.pop('grid_thws', None)
|
| 947 |
+
)
|
| 948 |
+
enable_thinking = kwargs.pop('enable_thinking', False)
|
| 949 |
+
enable_thinking_budget = kwargs.pop('enable_thinking_budget', False)
|
| 950 |
+
thinking_budget = kwargs.pop('thinking_budget', 1024)
|
| 951 |
+
|
| 952 |
+
if enable_thinking and enable_thinking_budget:
|
| 953 |
+
actual_max_new_tokens = kwargs['max_new_tokens']
|
| 954 |
+
kwargs['max_new_tokens'] = thinking_budget
|
| 955 |
+
generated_ids = self.llm.generate(inputs=None, inputs_embeds=inputs_embeds, attention_mask=attention_mask, **kwargs)
|
| 956 |
+
output_ids = generated_ids
|
| 957 |
+
output_ids_list = generated_ids[0]
|
| 958 |
+
|
| 959 |
+
# check if the generation has already finished (151645 is <|im_end|>)
|
| 960 |
+
if 151645 not in output_ids_list:
|
| 961 |
+
# check if the thinking process has finished (151668 is </think>)
|
| 962 |
+
# and prepare the second model input
|
| 963 |
+
if 151668 not in output_ids_list:
|
| 964 |
+
early_stopping_text = "\n\nConsidering the limited time by the user, I have to give the solution based on the thinking directly now.\n</think>\n\n"
|
| 965 |
+
early_stopping_ids = self.text_tokenizer(early_stopping_text, return_tensors="pt", return_attention_mask=False).input_ids.to(inputs.device)
|
| 966 |
+
input_ids_appendent = torch.cat([output_ids, early_stopping_ids], dim=-1)
|
| 967 |
+
kwargs['streamer'].put(early_stopping_ids) if 'streamer' in kwargs else None
|
| 968 |
+
else:
|
| 969 |
+
input_ids_appendent = output_ids
|
| 970 |
+
|
| 971 |
+
|
| 972 |
+
# second generation
|
| 973 |
+
new_inputs = torch.cat([inputs, input_ids_appendent], dim=-1)
|
| 974 |
+
attention_mask = torch.ne(new_inputs, self.text_tokenizer.pad_token_id).to(device=inputs.device)
|
| 975 |
+
inputs_embeds_appendent = self.merge_multimodal(
|
| 976 |
+
input_ids=input_ids_appendent,
|
| 977 |
+
pixel_values=None,
|
| 978 |
+
grid_thws=None
|
| 979 |
+
)
|
| 980 |
+
new_inputs_embeds = torch.cat([inputs_embeds, inputs_embeds_appendent], dim=-2)
|
| 981 |
+
|
| 982 |
+
kwargs['max_new_tokens'] = inputs_embeds.size(-2) + actual_max_new_tokens - new_inputs_embeds.size(-2)
|
| 983 |
+
generated_ids2 = self.llm.generate(inputs=None, inputs_embeds=new_inputs_embeds, attention_mask=attention_mask, **kwargs)
|
| 984 |
+
kwargs['streamer'].manual_end() if 'streamer' in kwargs else None
|
| 985 |
+
return torch.cat([input_ids_appendent, generated_ids2], dim=-1)
|
| 986 |
+
|
| 987 |
+
else:
|
| 988 |
+
kwargs['streamer'].manual_end() if 'streamer' in kwargs else None
|
| 989 |
+
return generated_ids
|
| 990 |
+
|
| 991 |
+
else:
|
| 992 |
+
generated_ids = self.llm.generate(inputs=None, inputs_embeds=inputs_embeds, attention_mask=attention_mask, **kwargs)
|
| 993 |
+
kwargs['streamer'].manual_end() if 'streamer' in kwargs else None
|
| 994 |
+
return generated_ids
|
| 995 |
+
|
| 996 |
+
|
| 997 |
+
AutoConfig.register('siglip2_navit', Siglip2NavitConfig)
|
| 998 |
+
AutoModel.register(Siglip2NavitConfig, Siglip2NavitModel)
|
| 999 |
+
AutoConfig.register("ovis2_5", Ovis2_5_Config)
|
| 1000 |
+
AutoModelForCausalLM.register(Ovis2_5_Config, Ovis2_5)
|
ovis_image/model/tokenizer.py
ADDED
|
@@ -0,0 +1,82 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (C) 2025 AIDC-AI
|
| 2 |
+
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License.
|
| 3 |
+
# You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
|
| 4 |
+
# Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.
|
| 5 |
+
|
| 6 |
+
from typing import List
|
| 7 |
+
|
| 8 |
+
import torch
|
| 9 |
+
from transformers import AutoTokenizer
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class OvisTokenizer:
|
| 13 |
+
"""
|
| 14 |
+
Tokenizing and encoding/decoding text using the Ovis tokenizer.
|
| 15 |
+
|
| 16 |
+
Args:
|
| 17 |
+
model_path (str): Path to the tokenzier from hugging face.
|
| 18 |
+
|
| 19 |
+
"""
|
| 20 |
+
|
| 21 |
+
def __init__(
|
| 22 |
+
self,
|
| 23 |
+
model_path: str = "Ovis2.5-2B",
|
| 24 |
+
max_length: int = 256,
|
| 25 |
+
**hf_kwargs
|
| 26 |
+
):
|
| 27 |
+
super().__init__()
|
| 28 |
+
self._tokenizer = AutoTokenizer.from_pretrained(model_path)
|
| 29 |
+
self.system_prompt = "Describe the image by detailing the color, quantity, text, shape, size, texture, spatial relationships of the objects and background: "
|
| 30 |
+
self.user_prompt_begin_id = 28
|
| 31 |
+
self._max_length = max_length + self.user_prompt_begin_id
|
| 32 |
+
|
| 33 |
+
def encode(
|
| 34 |
+
self,
|
| 35 |
+
s: str,
|
| 36 |
+
system_prompt = ""
|
| 37 |
+
) -> torch.Tensor:
|
| 38 |
+
"""
|
| 39 |
+
Encode the prompt text into tokens.
|
| 40 |
+
"""
|
| 41 |
+
if len(system_prompt) == 0:
|
| 42 |
+
system_prompt = self.system_prompt
|
| 43 |
+
messages = [{
|
| 44 |
+
"role": "user",
|
| 45 |
+
"content": system_prompt + s,
|
| 46 |
+
}]
|
| 47 |
+
text = self._tokenizer.apply_chat_template(
|
| 48 |
+
messages,
|
| 49 |
+
tokenize=False,
|
| 50 |
+
add_generation_prompt=True,
|
| 51 |
+
enable_thinking=False
|
| 52 |
+
)
|
| 53 |
+
tokens = self._tokenizer(
|
| 54 |
+
text,
|
| 55 |
+
padding="max_length",
|
| 56 |
+
truncation=True,
|
| 57 |
+
max_length=self._max_length,
|
| 58 |
+
return_tensors="pt",
|
| 59 |
+
add_special_tokens=False,
|
| 60 |
+
)
|
| 61 |
+
return tokens.input_ids, tokens.attention_mask
|
| 62 |
+
|
| 63 |
+
def decode(self, t: List[int]) -> str:
|
| 64 |
+
return self._tokenizer.decode(t, skip_special_tokens=False)
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
def build_ovis_tokenizer(tokenizer_path):
|
| 68 |
+
max_ovis_encoding_len = 256
|
| 69 |
+
ovis_tokenizer = OvisTokenizer(
|
| 70 |
+
tokenizer_path,
|
| 71 |
+
max_length=max_ovis_encoding_len,
|
| 72 |
+
)
|
| 73 |
+
return ovis_tokenizer
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
if __name__ == "__main__":
|
| 77 |
+
ovis_path = "/mnt/workspace/cv_multimodal/aigc/huggingface/Ovis2.5-2B"
|
| 78 |
+
text = "a cute cat"
|
| 79 |
+
ovis_tokenizer = OvisTokenizer(ovis_path, max_length=256)
|
| 80 |
+
ovis_token = ovis_tokenizer.encode(text)
|
| 81 |
+
import pdb
|
| 82 |
+
pdb.set_trace()
|
ovis_image/sampling.py
ADDED
|
@@ -0,0 +1,224 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (C) 2025 AIDC-AI
|
| 2 |
+
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License.
|
| 3 |
+
# You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
|
| 4 |
+
# Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.
|
| 5 |
+
import math
|
| 6 |
+
import os
|
| 7 |
+
from typing import Callable, Optional
|
| 8 |
+
from PIL import ExifTags, Image
|
| 9 |
+
import torch
|
| 10 |
+
from torch import Tensor
|
| 11 |
+
from einops import rearrange, repeat
|
| 12 |
+
|
| 13 |
+
from ovis_image.dataset.image_util import build_img_ids
|
| 14 |
+
from ovis_image.model.autoencoder import AutoEncoder
|
| 15 |
+
from ovis_image.model.hf_embedder import OvisEmbedder
|
| 16 |
+
from ovis_image.model.model import OvisImageModel
|
| 17 |
+
from ovis_image.utils import (
|
| 18 |
+
generate_noise_latent,
|
| 19 |
+
pack_latents,
|
| 20 |
+
unpack_latents,
|
| 21 |
+
generate_txt_ids,
|
| 22 |
+
)
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
def time_shift(mu: float, sigma: float, t: Tensor):
|
| 27 |
+
return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
def get_lin_function(
|
| 31 |
+
x1: float = 256, y1: float = 0.5, x2: float = 4096, y2: float = 1.15
|
| 32 |
+
) -> Callable[[float], float]:
|
| 33 |
+
m = (y2 - y1) / (x2 - x1)
|
| 34 |
+
b = y1 - m * x1
|
| 35 |
+
return lambda x: m * x + b
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
def sample_timesteps(batch_size, image_seq_len=None, base_shift=None, max_shift=None):
|
| 39 |
+
if image_seq_len is None or base_shift is None or max_shift is None:
|
| 40 |
+
logit_mean = 0
|
| 41 |
+
else:
|
| 42 |
+
logit_mean = get_lin_function(y1=base_shift, y2=max_shift)(image_seq_len)
|
| 43 |
+
logit_std = 1.0
|
| 44 |
+
timesteps = torch.normal(
|
| 45 |
+
mean=logit_mean, std=logit_std, size=(batch_size,)
|
| 46 |
+
)
|
| 47 |
+
timesteps = torch.nn.functional.sigmoid(timesteps)
|
| 48 |
+
return timesteps
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
def get_schedule(
|
| 52 |
+
num_steps: int,
|
| 53 |
+
image_seq_len: int,
|
| 54 |
+
base_shift: float = 0.5,
|
| 55 |
+
max_shift: float = 1.15,
|
| 56 |
+
shift: bool = True,
|
| 57 |
+
) -> list[float]:
|
| 58 |
+
# extra step for zero
|
| 59 |
+
timesteps = torch.linspace(1, 0, num_steps + 1)
|
| 60 |
+
|
| 61 |
+
# shifting the schedule to favor high timesteps for higher signal images
|
| 62 |
+
if shift:
|
| 63 |
+
# estimate mu based on linear estimation between two points
|
| 64 |
+
mu = get_lin_function(y1=base_shift, y2=max_shift)(image_seq_len)
|
| 65 |
+
timesteps = time_shift(mu, 1.0, timesteps)
|
| 66 |
+
|
| 67 |
+
return timesteps.tolist()
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
def generate_image(
|
| 71 |
+
device: torch.device,
|
| 72 |
+
dtype: torch.dtype,
|
| 73 |
+
model: OvisImageModel,
|
| 74 |
+
prompt: str,
|
| 75 |
+
autoencoder: AutoEncoder,
|
| 76 |
+
ovis_tokenizer,
|
| 77 |
+
ovis_encoder: OvisEmbedder,
|
| 78 |
+
img_height: int = 256,
|
| 79 |
+
img_width: int = 256,
|
| 80 |
+
denoising_steps: int = 50,
|
| 81 |
+
cfg_scale: float = 5.0,
|
| 82 |
+
seed: int = 42,
|
| 83 |
+
) -> torch.Tensor:
|
| 84 |
+
"""
|
| 85 |
+
Sampling and save a single images from noise using a given prompt.
|
| 86 |
+
For randomized noise generation, the random seend should already be set at the begining of training.
|
| 87 |
+
Since we will always use the local random seed on this rank, we don't need to pass in the seed again.
|
| 88 |
+
"""
|
| 89 |
+
|
| 90 |
+
# allow for packing and conversion to latent space. Use the same resolution as training time.
|
| 91 |
+
img_height = 16 * (img_height // 16)
|
| 92 |
+
img_width = 16 * (img_width // 16)
|
| 93 |
+
|
| 94 |
+
enable_classifier_free_guidance = True
|
| 95 |
+
|
| 96 |
+
# Tokenize the prompt. Unsqueeze to add a batch dimension.
|
| 97 |
+
ovis_token_ids, ovis_token_mask = ovis_tokenizer.encode(prompt)
|
| 98 |
+
ovis_encodings = ovis_encoder(
|
| 99 |
+
ovis_token_ids.to(device=device), ovis_token_mask.to(device=device)
|
| 100 |
+
)
|
| 101 |
+
|
| 102 |
+
if enable_classifier_free_guidance:
|
| 103 |
+
empty_ovis_token_ids, empty_ovis_token_mask = ovis_tokenizer.encode("")
|
| 104 |
+
empty_ovis_encodings = ovis_encoder(
|
| 105 |
+
empty_ovis_token_ids.to(device=device), empty_ovis_token_mask.to(device=device)
|
| 106 |
+
)
|
| 107 |
+
|
| 108 |
+
latents = generate_noise_latent(
|
| 109 |
+
ovis_token_ids.shape[0],
|
| 110 |
+
img_height, img_width, device, dtype, seed=seed,
|
| 111 |
+
latent_channel=autoencoder.params.z_channels)
|
| 112 |
+
|
| 113 |
+
img = denoise(
|
| 114 |
+
device=device,
|
| 115 |
+
dtype=dtype,
|
| 116 |
+
model=model,
|
| 117 |
+
latents=latents,
|
| 118 |
+
denoising_steps=denoising_steps,
|
| 119 |
+
ovis_encodings=ovis_encodings,
|
| 120 |
+
enable_classifier_free_guidance=enable_classifier_free_guidance,
|
| 121 |
+
empty_ovis_encodings=(
|
| 122 |
+
empty_ovis_encodings if enable_classifier_free_guidance else None
|
| 123 |
+
),
|
| 124 |
+
classifier_free_guidance_scale=cfg_scale,
|
| 125 |
+
)
|
| 126 |
+
|
| 127 |
+
img = autoencoder.decode(img)
|
| 128 |
+
return img
|
| 129 |
+
|
| 130 |
+
|
| 131 |
+
def denoise(
|
| 132 |
+
device: torch.device,
|
| 133 |
+
dtype: torch.dtype,
|
| 134 |
+
model: OvisImageModel,
|
| 135 |
+
latents: torch.Tensor,
|
| 136 |
+
denoising_steps: int,
|
| 137 |
+
ovis_encodings: torch.Tensor,
|
| 138 |
+
enable_classifier_free_guidance: bool = False,
|
| 139 |
+
empty_ovis_encodings: torch.Tensor | None = None,
|
| 140 |
+
classifier_free_guidance_scale: float | None = None,
|
| 141 |
+
) -> torch.Tensor:
|
| 142 |
+
"""
|
| 143 |
+
Sampling images from noise using a given prompt, by running inference with trained model.
|
| 144 |
+
Save the generated images to the given output path.
|
| 145 |
+
"""
|
| 146 |
+
bsz = ovis_encodings.shape[0]
|
| 147 |
+
_, latent_channels, latent_height, latent_width = latents.shape
|
| 148 |
+
|
| 149 |
+
# create denoising schedule
|
| 150 |
+
timesteps = get_schedule(denoising_steps, latent_height * latent_width, shift=True)
|
| 151 |
+
|
| 152 |
+
# create positional encodings
|
| 153 |
+
|
| 154 |
+
latent_pos_enc = build_img_ids(
|
| 155 |
+
latent_height // 2, latent_width // 2,
|
| 156 |
+
).to(latents)
|
| 157 |
+
latent_pos_enc = repeat(latent_pos_enc, 'l c -> bsz l c', bsz=bsz)
|
| 158 |
+
ovis_txt_ids = generate_txt_ids(ovis_encodings, time_id=0).to(latents)
|
| 159 |
+
|
| 160 |
+
if enable_classifier_free_guidance:
|
| 161 |
+
ovis_encodings = torch.cat([empty_ovis_encodings, ovis_encodings], dim=0)
|
| 162 |
+
latent_pos_enc = torch.cat([latent_pos_enc, latent_pos_enc], dim=0)
|
| 163 |
+
ovis_txt_ids = torch.cat([ovis_txt_ids, ovis_txt_ids], dim=0)
|
| 164 |
+
|
| 165 |
+
# convert img-like latents into sequences of patches
|
| 166 |
+
latents = pack_latents(latents)
|
| 167 |
+
|
| 168 |
+
# this is ignored for schnell
|
| 169 |
+
for t_curr, t_prev in zip(timesteps[:-1], timesteps[1:]):
|
| 170 |
+
if enable_classifier_free_guidance:
|
| 171 |
+
img = torch.cat([latents, latents], dim=0)
|
| 172 |
+
t_vec = torch.full((bsz * 2,), t_curr, dtype=dtype, device=device)
|
| 173 |
+
else:
|
| 174 |
+
img = latents
|
| 175 |
+
t_vec = torch.full((bsz,), t_curr, dtype=dtype, device=device)
|
| 176 |
+
model_pred = model(
|
| 177 |
+
img=img,
|
| 178 |
+
img_ids=latent_pos_enc,
|
| 179 |
+
txt=ovis_encodings,
|
| 180 |
+
txt_ids=ovis_txt_ids,
|
| 181 |
+
timesteps=t_vec,
|
| 182 |
+
)
|
| 183 |
+
if enable_classifier_free_guidance:
|
| 184 |
+
pred_u, pred_c = model_pred.chunk(2)
|
| 185 |
+
pred = pred_u + classifier_free_guidance_scale * (pred_c - pred_u)
|
| 186 |
+
else:
|
| 187 |
+
pred = model_pred
|
| 188 |
+
|
| 189 |
+
latents = latents + (t_prev - t_curr) * pred
|
| 190 |
+
|
| 191 |
+
# convert sequences of patches into img-like latents
|
| 192 |
+
latents = unpack_latents(latents, latent_height, latent_width)
|
| 193 |
+
|
| 194 |
+
return latents
|
| 195 |
+
|
| 196 |
+
|
| 197 |
+
|
| 198 |
+
def save_image(
|
| 199 |
+
name: str,
|
| 200 |
+
output_dir: str,
|
| 201 |
+
x: torch.Tensor,
|
| 202 |
+
add_sampling_metadata: bool,
|
| 203 |
+
prompt: str,
|
| 204 |
+
verbose = True,
|
| 205 |
+
):
|
| 206 |
+
if verbose:
|
| 207 |
+
print(f"Saving image to {output_dir}/{name}")
|
| 208 |
+
os.makedirs(output_dir, exist_ok=True)
|
| 209 |
+
output_name = os.path.join(output_dir, name)
|
| 210 |
+
|
| 211 |
+
# bring into PIL format and save
|
| 212 |
+
x = x.clamp(-1, 1)
|
| 213 |
+
x = rearrange(x[0], "c h w -> h w c")
|
| 214 |
+
|
| 215 |
+
img = Image.fromarray((127.5 * (x + 1.0)).cpu().byte().numpy())
|
| 216 |
+
|
| 217 |
+
exif_data = Image.Exif()
|
| 218 |
+
exif_data[ExifTags.Base.Software] = "AI generated;txt2img"
|
| 219 |
+
exif_data[ExifTags.Base.Make] = "Ovis"
|
| 220 |
+
exif_data[ExifTags.Base.Model] = name
|
| 221 |
+
if add_sampling_metadata:
|
| 222 |
+
exif_data[ExifTags.Base.ImageDescription] = prompt
|
| 223 |
+
img.save(output_name, exif=exif_data, quality=95, subsampling=0)
|
| 224 |
+
|
ovis_image/test.py
ADDED
|
@@ -0,0 +1,92 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (C) 2025 AIDC-AI
|
| 2 |
+
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License.
|
| 3 |
+
# You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
|
| 4 |
+
# Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.
|
| 5 |
+
|
| 6 |
+
import argparse
|
| 7 |
+
|
| 8 |
+
import torch
|
| 9 |
+
from safetensors.torch import load_file
|
| 10 |
+
|
| 11 |
+
from ovis_image.model.tokenizer import build_ovis_tokenizer
|
| 12 |
+
from ovis_image.model.autoencoder import load_ae
|
| 13 |
+
from ovis_image.model.hf_embedder import OvisEmbedder
|
| 14 |
+
from ovis_image.model.model import OvisImageModel
|
| 15 |
+
from ovis_image.sampling import generate_image, save_image
|
| 16 |
+
from ovis_image import ovis_image_configs
|
| 17 |
+
|
| 18 |
+
def parse_args():
|
| 19 |
+
parser = argparse.ArgumentParser()
|
| 20 |
+
parser.add_argument('--model_path', type=str, required=True)
|
| 21 |
+
parser.add_argument('--ovis_path', type=str, default="")
|
| 22 |
+
parser.add_argument('--vae_path', type=str, default="")
|
| 23 |
+
parser.add_argument('--prompt', type=str, default="")
|
| 24 |
+
parser.add_argument('--image_size', type=int, default=1024)
|
| 25 |
+
parser.add_argument('--denoising_steps', type=int, default=50)
|
| 26 |
+
parser.add_argument('--cfg_scale', type=float, default=5.0)
|
| 27 |
+
args = parser.parse_args()
|
| 28 |
+
return args
|
| 29 |
+
|
| 30 |
+
def load_model_weight(model, model_path):
|
| 31 |
+
model_state_dict = load_file(model_path)
|
| 32 |
+
missing_keys, unexpected_keys = model.load_state_dict(model_state_dict)
|
| 33 |
+
print(f"Load Missing Keys {missing_keys}")
|
| 34 |
+
print(f"Load Unexpected Keys {unexpected_keys}")
|
| 35 |
+
return model
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
def main():
|
| 39 |
+
args = parse_args()
|
| 40 |
+
model_config = ovis_image_configs["ovis-image-7b"]
|
| 41 |
+
device = "cuda"
|
| 42 |
+
_dtype = torch.bfloat16
|
| 43 |
+
print(f"dtype: {_dtype}")
|
| 44 |
+
ovis_image = OvisImageModel(model_config)
|
| 45 |
+
ovis_image = load_model_weight(ovis_image, args.model_path)
|
| 46 |
+
ovis_image = ovis_image.to(device=device, dtype=_dtype)
|
| 47 |
+
ovis_image.eval()
|
| 48 |
+
|
| 49 |
+
ovis_tokenizer = build_ovis_tokenizer(args.ovis_path)
|
| 50 |
+
autoencoder = load_ae(
|
| 51 |
+
args.vae_path,
|
| 52 |
+
model_config.autoencoder_params,
|
| 53 |
+
device=device,
|
| 54 |
+
dtype=_dtype,
|
| 55 |
+
random_init=False,
|
| 56 |
+
)
|
| 57 |
+
autoencoder.eval()
|
| 58 |
+
ovis_encoder = OvisEmbedder(
|
| 59 |
+
model_path=args.ovis_path,
|
| 60 |
+
random_init=False,
|
| 61 |
+
low_cpu_mem_usage=True,
|
| 62 |
+
torch_dtype=torch.bfloat16,
|
| 63 |
+
).to(device=device, dtype=_dtype)
|
| 64 |
+
|
| 65 |
+
with torch.no_grad():
|
| 66 |
+
image = generate_image(
|
| 67 |
+
device=device,
|
| 68 |
+
dtype=_dtype,
|
| 69 |
+
model=ovis_image,
|
| 70 |
+
prompt=args.prompt,
|
| 71 |
+
autoencoder=autoencoder,
|
| 72 |
+
ovis_tokenizer=ovis_tokenizer,
|
| 73 |
+
ovis_encoder=ovis_encoder,
|
| 74 |
+
img_height=args.image_size,
|
| 75 |
+
img_width=args.image_size,
|
| 76 |
+
denoising_steps=args.denoising_steps,
|
| 77 |
+
cfg_scale=args.cfg_scale,
|
| 78 |
+
seed=42,
|
| 79 |
+
)
|
| 80 |
+
image_name = f"ovis_image.png"
|
| 81 |
+
save_image(
|
| 82 |
+
name=image_name,
|
| 83 |
+
output_dir="outputs",
|
| 84 |
+
x=image,
|
| 85 |
+
add_sampling_metadata=True,
|
| 86 |
+
prompt=args.prompt,
|
| 87 |
+
verbose=False,
|
| 88 |
+
)
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
if __name__ == "__main__":
|
| 92 |
+
main()
|
ovis_image/utils.py
ADDED
|
@@ -0,0 +1,112 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (C) 2025 AIDC-AI
|
| 2 |
+
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License.
|
| 3 |
+
# You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
|
| 4 |
+
# Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
from torch import Tensor
|
| 8 |
+
|
| 9 |
+
def generate_txt_ids(encodings, time_id=0):
|
| 10 |
+
txt_ids = torch.zeros(encodings.shape[0], encodings.shape[1], 3)
|
| 11 |
+
txt_ids[..., 1] = txt_ids[..., 1] + torch.arange(encodings.shape[1])[None, :]
|
| 12 |
+
txt_ids[..., 2] = txt_ids[..., 2] + torch.arange(encodings.shape[1])[None, :]
|
| 13 |
+
txt_ids[..., 0] = time_id
|
| 14 |
+
return txt_ids
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
def generate_noise_latent(
|
| 18 |
+
bsz: int,
|
| 19 |
+
height: int,
|
| 20 |
+
width: int,
|
| 21 |
+
device: str | torch.device,
|
| 22 |
+
dtype: torch.dtype,
|
| 23 |
+
seed: int | None = None,
|
| 24 |
+
latent_channel = None,
|
| 25 |
+
) -> Tensor:
|
| 26 |
+
"""Generate noise latents for the flow model. The random seed will be set at the begining of training.
|
| 27 |
+
|
| 28 |
+
Args:
|
| 29 |
+
bsz (int): batch_size.
|
| 30 |
+
height (int): The height of the image.
|
| 31 |
+
width (int): The width of the image.
|
| 32 |
+
device (str | torch.device): The device to use.
|
| 33 |
+
dtype (torch.dtype): The dtype to use.
|
| 34 |
+
|
| 35 |
+
Returns:
|
| 36 |
+
Tensor: The noise latents.
|
| 37 |
+
Shape: [num_samples, LATENT_CHANNELS, height // IMG_LATENT_SIZE_RATIO, width // IMG_LATENT_SIZE_RATIO]
|
| 38 |
+
|
| 39 |
+
"""
|
| 40 |
+
LATENT_CHANNELS, IMAGE_LATENT_SIZE_RATIO = 16, 8
|
| 41 |
+
if latent_channel is not None:
|
| 42 |
+
LATENT_CHANNELS = latent_channel
|
| 43 |
+
return torch.randn(
|
| 44 |
+
bsz,
|
| 45 |
+
LATENT_CHANNELS,
|
| 46 |
+
height // IMAGE_LATENT_SIZE_RATIO,
|
| 47 |
+
width // IMAGE_LATENT_SIZE_RATIO,
|
| 48 |
+
dtype=dtype,
|
| 49 |
+
generator=torch.Generator().manual_seed(seed),
|
| 50 |
+
).to(device)
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
def pack_latents(x: Tensor) -> Tensor:
|
| 54 |
+
"""
|
| 55 |
+
Rearrange latents from an image-like format into a sequence of patches.
|
| 56 |
+
Equivalent to `einops.rearrange("b c (h ph) (w pw) -> b (h w) (c ph pw)")`.
|
| 57 |
+
|
| 58 |
+
Args:
|
| 59 |
+
x (Tensor): The unpacked latents.
|
| 60 |
+
Shape: [bsz, ch, latent height, latent width]
|
| 61 |
+
|
| 62 |
+
Returns:
|
| 63 |
+
Tensor: The packed latents.
|
| 64 |
+
Shape: (bsz, (latent_height // ph) * (latent_width // pw), ch * ph * pw)
|
| 65 |
+
"""
|
| 66 |
+
PATCH_HEIGHT, PATCH_WIDTH = 2, 2
|
| 67 |
+
|
| 68 |
+
b, c, latent_height, latent_width = x.shape
|
| 69 |
+
h = latent_height // PATCH_HEIGHT
|
| 70 |
+
w = latent_width // PATCH_WIDTH
|
| 71 |
+
|
| 72 |
+
# [b, c, h*ph, w*ph] -> [b, c, h, w, ph, pw]
|
| 73 |
+
x = x.unfold(2, PATCH_HEIGHT, PATCH_HEIGHT).unfold(3, PATCH_WIDTH, PATCH_WIDTH)
|
| 74 |
+
|
| 75 |
+
# [b, c, h, w, ph, PW] -> [b, h, w, c, ph, PW]
|
| 76 |
+
x = x.permute(0, 2, 3, 1, 4, 5)
|
| 77 |
+
|
| 78 |
+
# [b, h, w, c, ph, PW] -> [b, h*w, c*ph*PW]
|
| 79 |
+
return x.reshape(b, h * w, c * PATCH_HEIGHT * PATCH_WIDTH)
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
def unpack_latents(x: Tensor, latent_height: int, latent_width: int) -> Tensor:
|
| 83 |
+
"""
|
| 84 |
+
Rearrange latents from a sequence of patches into an image-like format.
|
| 85 |
+
Equivalent to `einops.rearrange("b (h w) (c ph pw) -> b c (h ph) (w pw)")`.
|
| 86 |
+
|
| 87 |
+
Args:
|
| 88 |
+
x (Tensor): The packed latents.
|
| 89 |
+
Shape: (bsz, (latent_height // ph) * (latent_width // pw), ch * ph * pw)
|
| 90 |
+
latent_height (int): The height of the unpacked latents.
|
| 91 |
+
latent_width (int): The width of the unpacked latents.
|
| 92 |
+
|
| 93 |
+
Returns:
|
| 94 |
+
Tensor: The unpacked latents.
|
| 95 |
+
Shape: [bsz, ch, latent height, latent width]
|
| 96 |
+
"""
|
| 97 |
+
PATCH_HEIGHT, PATCH_WIDTH = 2, 2
|
| 98 |
+
|
| 99 |
+
b, _, c_ph_pw = x.shape
|
| 100 |
+
h = latent_height // PATCH_HEIGHT
|
| 101 |
+
w = latent_width // PATCH_WIDTH
|
| 102 |
+
c = c_ph_pw // (PATCH_HEIGHT * PATCH_WIDTH)
|
| 103 |
+
|
| 104 |
+
# [b, h*w, c*ph*pw] -> [b, h, w, c, ph, pw]
|
| 105 |
+
x = x.reshape(b, h, w, c, PATCH_HEIGHT, PATCH_WIDTH)
|
| 106 |
+
|
| 107 |
+
# [b, h, w, c, ph, pw] -> [b, c, h, ph, w, pw]
|
| 108 |
+
x = x.permute(0, 3, 1, 4, 2, 5)
|
| 109 |
+
|
| 110 |
+
# [b, c, h, ph, w, pw] -> [b, c, h*ph, w*pw]
|
| 111 |
+
return x.reshape(b, c, h * PATCH_HEIGHT, w * PATCH_WIDTH)
|
| 112 |
+
|