Upload 59 files
This view is limited to 50 files because the commit contains too many changes; the raw diff has the full list.
- .gitattributes +1 -0
- blip3o/__init__.py +0 -0
- blip3o/__pycache__/__init__.cpython-310.pyc +0 -0
- blip3o/__pycache__/__init__.cpython-311.pyc +0 -0
- blip3o/__pycache__/constants.cpython-310.pyc +0 -0
- blip3o/__pycache__/constants.cpython-311.pyc +0 -0
- blip3o/__pycache__/utils.cpython-310.pyc +0 -0
- blip3o/__pycache__/utils.cpython-311.pyc +0 -0
- blip3o/constants.py +7 -0
- blip3o/conversation.py +296 -0
- blip3o/data/__init__.py +1 -0
- blip3o/data/__pycache__/__init__.cpython-310.pyc +0 -0
- blip3o/data/__pycache__/__init__.cpython-311.pyc +0 -0
- blip3o/data/__pycache__/dataset.cpython-310.pyc +0 -0
- blip3o/data/__pycache__/dataset.cpython-311.pyc +0 -0
- blip3o/data/dataset.py +371 -0
- blip3o/mm_utils.py +65 -0
- blip3o/model/__init__.py +3 -0
- blip3o/model/__pycache__/__init__.cpython-310.pyc +0 -0
- blip3o/model/__pycache__/__init__.cpython-311.pyc +0 -0
- blip3o/model/__pycache__/blip3o_arch.cpython-310.pyc +0 -0
- blip3o/model/__pycache__/blip3o_arch.cpython-311.pyc +0 -0
- blip3o/model/__pycache__/llava_arch.cpython-310.pyc +0 -0
- blip3o/model/__pycache__/llava_arch.cpython-311.pyc +0 -0
- blip3o/model/blip3o_arch.py +400 -0
- blip3o/model/builder.py +44 -0
- blip3o/model/language_model/__pycache__/blip3o_qwen.cpython-310.pyc +0 -0
- blip3o/model/language_model/__pycache__/blip3o_qwen.cpython-311.pyc +0 -0
- blip3o/model/language_model/__pycache__/blip3o_qwen_grpo.cpython-310.pyc +0 -0
- blip3o/model/language_model/__pycache__/blip3o_qwen_grpo.cpython-311.pyc +0 -0
- blip3o/model/language_model/__pycache__/blip3o_qwen_inference.cpython-310.pyc +0 -0
- blip3o/model/language_model/__pycache__/blip3o_qwen_inference.cpython-311.pyc +0 -0
- blip3o/model/language_model/__pycache__/llava_qwen.cpython-310.pyc +0 -0
- blip3o/model/language_model/__pycache__/llava_qwen.cpython-311.pyc +0 -0
- blip3o/model/language_model/__pycache__/llava_qwen_grpo.cpython-310.pyc +0 -0
- blip3o/model/language_model/__pycache__/llava_qwen_grpo.cpython-311.pyc +0 -0
- blip3o/model/language_model/__pycache__/llava_qwen_inference.cpython-310.pyc +0 -0
- blip3o/model/language_model/__pycache__/llava_qwen_inference.cpython-311.pyc +0 -0
- blip3o/model/language_model/blip3o_qwen.py +215 -0
- blip3o/model/language_model/blip3o_qwen_grpo.py +255 -0
- blip3o/model/language_model/blip3o_qwen_inference.py +241 -0
- blip3o/model/multimodal_decoder/__pycache__/builder.cpython-310.pyc +0 -0
- blip3o/model/multimodal_decoder/__pycache__/builder.cpython-311.pyc +0 -0
- blip3o/model/multimodal_decoder/__pycache__/ta_tok_encoder.cpython-310.pyc +0 -0
- blip3o/model/multimodal_decoder/__pycache__/ta_tok_encoder.cpython-311.pyc +0 -0
- blip3o/model/multimodal_decoder/builder.py +14 -0
- blip3o/model/multimodal_encoder/__pycache__/builder.cpython-310.pyc +0 -0
- blip3o/model/multimodal_encoder/__pycache__/builder.cpython-311.pyc +0 -0
- blip3o/model/multimodal_encoder/__pycache__/ta_tok_encoder.cpython-310.pyc +0 -0
- blip3o/model/multimodal_encoder/__pycache__/ta_tok_encoder.cpython-311.pyc +0 -0
.gitattributes
CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+blip3o/train/__pycache__/grpo_trainer.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
blip3o/__init__.py
ADDED
File without changes

blip3o/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (145 Bytes)

blip3o/__pycache__/__init__.cpython-311.pyc
ADDED
Binary file (156 Bytes)

blip3o/__pycache__/constants.cpython-310.pyc
ADDED
Binary file (357 Bytes)

blip3o/__pycache__/constants.cpython-311.pyc
ADDED
Binary file (389 Bytes)

blip3o/__pycache__/utils.cpython-310.pyc
ADDED
Binary file (610 Bytes)

blip3o/__pycache__/utils.cpython-311.pyc
ADDED
Binary file (1.02 kB)
blip3o/constants.py
ADDED
@@ -0,0 +1,7 @@
# Model Constants
IGNORE_INDEX = -100
IMAGE_TOKEN_INDEX = -200
DEFAULT_IMAGE_TOKEN = "<image>"
DEFAULT_IMAGE_PATCH_TOKEN = "<im_patch>"
DEFAULT_IM_START_TOKEN = "<im_start>"
DEFAULT_IM_END_TOKEN = "<im_end>"
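As a quick orientation (not part of the diff): these constants are consumed by the dataset and model code later in this commit. A minimal sketch of the convention, under the assumption that `<image>` placeholder positions are marked with the out-of-vocabulary index IMAGE_TOKEN_INDEX and that label positions excluded from the loss use IGNORE_INDEX; the token ids shown are hypothetical.

    # Minimal sketch (not from the diff) of how the constants are typically used.
    from blip3o.constants import DEFAULT_IMAGE_TOKEN, IGNORE_INDEX, IMAGE_TOKEN_INDEX

    prompt = f"{DEFAULT_IMAGE_TOKEN}\nDescribe this picture."
    text_ids = [151644, 872, 198]  # hypothetical ids for the text around the placeholder

    # The placeholder position is marked with IMAGE_TOKEN_INDEX (-200), which the model
    # later swaps for vision features; masked label positions use IGNORE_INDEX (-100).
    input_ids = text_ids + [IMAGE_TOKEN_INDEX] + text_ids
    labels = [IGNORE_INDEX] * len(input_ids)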
blip3o/conversation.py
ADDED
@@ -0,0 +1,296 @@
import base64
import dataclasses
import re
from enum import Enum, auto
from io import BytesIO
from typing import Any, Dict, List, Tuple, Union

from PIL import Image


class SeparatorStyle(Enum):
    """Different separator style."""

    SINGLE = auto()
    TWO = auto()
    PLAIN = auto()
    CHATML = auto()
    QWEN = auto()


@dataclasses.dataclass
class Conversation:
    """A class that keeps all conversation history."""

    system: str
    roles: List[str]
    messages: List[List[str]]
    offset: int
    sep_style: SeparatorStyle = SeparatorStyle.SINGLE
    sep: str = "###"
    sep2: str = None
    version: str = "Unknown"

    tokenizer_id: str = ""
    tokenizer: Any = None
    # Stop criteria (the default one is EOS token)
    stop_str: Union[str, List[str]] = None
    # Stops generation if meeting any token in this list
    stop_token_ids: List[int] = None

    skip_next: bool = False

    def get_prompt(self):
        messages = self.messages
        if len(messages) > 0 and type(messages[0][1]) is tuple:
            messages = self.messages.copy()
            init_role, init_msg = messages[0].copy()
            init_msg = init_msg[0]
            if "mmtag" in self.version:
                init_msg = init_msg.replace("<image>", "").strip()
                messages[0] = (init_role, init_msg)
                messages.insert(0, (self.roles[0], "<Image><image></Image>"))
                messages.insert(1, (self.roles[1], "Received."))
            elif not init_msg.startswith("<image>"):
                init_msg = init_msg.replace("<image>", "").strip()
                messages[0] = (init_role, "<image>\n" + init_msg)
            else:
                messages[0] = (init_role, init_msg)

        if self.sep_style == SeparatorStyle.SINGLE:
            ret = self.system + self.sep
            for role, message in messages:
                if message:
                    if type(message) is tuple:
                        message, _, _ = message
                    ret += role + ": " + message + self.sep
                else:
                    ret += role + ":"

        elif self.sep_style == SeparatorStyle.TWO:
            seps = [self.sep, self.sep2]
            ret = self.system + seps[0]
            for i, (role, message) in enumerate(messages):
                if message:
                    if type(message) is tuple:
                        message, _, _ = message
                    ret += role + ": " + message + seps[i % 2]
                else:
                    ret += role + ":"

        elif self.sep_style == SeparatorStyle.CHATML:
            ret = "" if self.system == "" else self.system + self.sep + "\n"
            for role, message in messages:
                if message:
                    if type(message) is tuple:
                        message, images, _ = message
                        message = "<image>" * len(images) + message
                    ret += role + "\n" + message + self.sep + "\n"
                else:
                    ret += role + "\n"
            return ret

        elif self.sep_style == SeparatorStyle.PLAIN:
            seps = [self.sep, self.sep2]
            ret = self.system
            for i, (role, message) in enumerate(messages):
                if message:
                    if type(message) is tuple:
                        message, _, _ = message
                    ret += message + seps[i % 2]
                else:
                    ret += ""
        else:
            raise ValueError(f"Invalid style: {self.sep_style}")

        return ret

    def append_message(self, role, message):
        self.messages.append([role, message])

    def process_image(self, image, image_process_mode, return_pil=False, image_format="PNG"):
        if image_process_mode == "Pad":

            def expand2square(pil_img, background_color=(122, 116, 104)):
                width, height = pil_img.size
                if width == height:
                    return pil_img
                elif width > height:
                    result = Image.new(pil_img.mode, (width, width), background_color)
                    result.paste(pil_img, (0, (width - height) // 2))
                    return result
                else:
                    result = Image.new(pil_img.mode, (height, height), background_color)
                    result.paste(pil_img, ((height - width) // 2, 0))
                    return result

            image = expand2square(image)
        elif image_process_mode in ["Default", "Crop"]:
            pass
        elif image_process_mode == "Resize":
            image = image.resize((336, 336))
        else:
            raise ValueError(f"Invalid image_process_mode: {image_process_mode}")

        if type(image) is not Image.Image:
            image = Image.open(image).convert("RGB")

        max_hw, min_hw = max(image.size), min(image.size)
        aspect_ratio = max_hw / min_hw
        max_len, min_len = 672, 448
        shortest_edge = int(min(max_len / aspect_ratio, min_len, min_hw))
        longest_edge = int(shortest_edge * aspect_ratio)
        W, H = image.size
        if H > W:
            H, W = longest_edge, shortest_edge
        else:
            H, W = shortest_edge, longest_edge
        image = image.resize((W, H))
        if return_pil:
            return image
        else:
            buffered = BytesIO()
            image.save(buffered, format=image_format)
            img_b64_str = base64.b64encode(buffered.getvalue()).decode()
            return img_b64_str

    def get_images(self, return_pil=False, return_path=False):
        images = []
        for i, (role, msg) in enumerate(self.messages[self.offset :]):
            if i % 2 == 0:
                if type(msg) is tuple:
                    msg, image, image_process_mode = msg
                    if type(image) != list:
                        image = [image]
                    for img in image:
                        if not return_path and self.is_image_file(img):
                            img = self.process_image(img, image_process_mode, return_pil=return_pil)
                        else:
                            images.append(img)
        return images

    def is_image_file(self, filename):
        image_extensions = [".png", ".jpg", ".jpeg", ".gif", ".bmp", ".tiff", ".webp"]
        return any(filename.lower().endswith(ext) for ext in image_extensions)

    def is_video_file(self, filename):
        video_extensions = [".mp4", ".mov", ".avi", ".mkv", ".wmv", ".flv", ".mpeg", ".mpg"]
        return any(filename.lower().endswith(ext) for ext in video_extensions)

    def to_gradio_chatbot(self):
        ret = []
        for i, (role, msg) in enumerate(self.messages[self.offset :]):
            if i % 2 == 0:
                if type(msg) is tuple:
                    msg, image, image_process_mode = msg
                    if type(image) != list:
                        image = [image]
                    if len(image) == 1:
                        msg = "<image>\n" + msg.replace("<image>", "").strip()
                    else:
                        msg = re.sub(r"(<image>)\n(?=<image>)", r"\1 ", msg)

                    img_str_list = []
                    for img in image:
                        if self.is_image_file(img):
                            img_b64_str = self.process_image(img, "Default", return_pil=False, image_format="JPEG")
                            img_str = f'<img src="data:image/jpeg;base64,{img_b64_str}" style="max-width: 256px; max-height: 256px; width: auto; height: auto; object-fit: contain;"/>'
                            img_str_list.append(img_str)
                        elif self.is_video_file(img):
                            ret.append(((img,), None))

                    msg = msg.strip()
                    img_place_holder = ""
                    for img_str in img_str_list:
                        img_place_holder += f"{img_str}\n\n"

                    if len(img_str_list) > 0:
                        msg = f"{img_place_holder}\n\n{msg}"

                    if len(msg) > 0:
                        ret.append([msg, None])
                else:
                    ret.append([msg, None])
            else:
                ret[-1][-1] = msg
        return ret

    def copy(self):
        return Conversation(system=self.system, roles=self.roles, messages=[[x, y] for x, y in self.messages], offset=self.offset, sep_style=self.sep_style, sep=self.sep, sep2=self.sep2, version=self.version)

    def dict(self):
        if len(self.get_images()) > 0:
            return {
                "system": self.system,
                "roles": self.roles,
                "messages": [[x, y[0] if type(y) is tuple else y] for x, y in self.messages],
                "offset": self.offset,
                "sep": self.sep,
                "sep2": self.sep2,
            }
        return {
            "system": self.system,
            "roles": self.roles,
            "messages": self.messages,
            "offset": self.offset,
            "sep": self.sep,
            "sep2": self.sep2,
        }


conv_vicuna_v0 = Conversation(
    system="A chat between a curious human and an artificial intelligence assistant. " "The assistant gives helpful, detailed, and polite answers to the human's questions.",
    roles=("Human", "Assistant"),
    messages=[
        ["Human", "What are the key differences between renewable and non-renewable energy sources?"],
        [
            "Assistant",
            "Renewable energy sources are those that can be replenished naturally in a relatively "
            "short amount of time, such as solar, wind, hydro, geothermal, and biomass. "
            "Non-renewable energy sources, on the other hand, are finite and will eventually be "
            "depleted, such as coal, oil, and natural gas. Here are some key differences between "
            "renewable and non-renewable energy sources:\n"
            "1. Availability: Renewable energy sources are virtually inexhaustible, while non-renewable "
            "energy sources are finite and will eventually run out.\n"
            "2. Environmental impact: Renewable energy sources have a much lower environmental impact "
            "than non-renewable sources, which can lead to air and water pollution, greenhouse gas emissions, "
            "and other negative effects.\n"
            "3. Cost: Renewable energy sources can be more expensive to initially set up, but they typically "
            "have lower operational costs than non-renewable sources.\n"
            "4. Reliability: Renewable energy sources are often more reliable and can be used in more remote "
            "locations than non-renewable sources.\n"
            "5. Flexibility: Renewable energy sources are often more flexible and can be adapted to different "
            "situations and needs, while non-renewable sources are more rigid and inflexible.\n"
            "6. Sustainability: Renewable energy sources are more sustainable over the long term, while "
            "non-renewable sources are not, and their depletion can lead to economic and social instability.\n",
        ],
    ],
    offset=2,
    sep_style=SeparatorStyle.SINGLE,
    sep="###",
)


conv_qwen = Conversation(
    system="""<|im_start|>system
You are a helpful assistant.""",
    roles=("<|im_start|>user", "<|im_start|>assistant"),
    version="qwen",
    messages=[],
    offset=0,
    sep_style=SeparatorStyle.CHATML,
    sep="<|im_end|>",
)


default_conversation = conv_vicuna_v0
conv_templates = {
    "default": conv_vicuna_v0,
    "v0": conv_vicuna_v0,
    "qwen_1_5": conv_qwen,
    "qwen_2": conv_qwen,
}


if __name__ == "__main__":
    print(default_conversation.get_prompt())
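A short usage sketch (not part of the diff), built only from objects defined in the file above: copy a registered template, append a user turn, leave the assistant turn open, and render the ChatML prompt string.

    # Usage sketch using only names defined in blip3o/conversation.py.
    from blip3o.conversation import conv_templates

    conv = conv_templates["qwen_1_5"].copy()           # ChatML-style Qwen template
    conv.append_message(conv.roles[0], "<image>\nWhat is in this image?")
    conv.append_message(conv.roles[1], None)            # open assistant turn
    prompt = conv.get_prompt()
    # -> "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\n..."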
blip3o/data/__init__.py
ADDED
@@ -0,0 +1 @@
from .dataset import *
blip3o/data/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (174 Bytes)

blip3o/data/__pycache__/__init__.cpython-311.pyc
ADDED
Binary file (193 Bytes)

blip3o/data/__pycache__/dataset.cpython-310.pyc
ADDED
Binary file (11.4 kB)

blip3o/data/__pycache__/dataset.cpython-311.pyc
ADDED
Binary file (20.9 kB)
blip3o/data/dataset.py
ADDED
@@ -0,0 +1,371 @@
import copy
import glob
import io
import json
import math
import os
import random
import re
from dataclasses import dataclass
from typing import Dict, List, Optional, Sequence
import pyarrow.parquet as pq
import torch
import transformers
import yaml
from PIL import Image, ImageFile
from torch.utils.data import Dataset
from torchvision.transforms import v2
from torchvision import transforms
from datasets import load_dataset, concatenate_datasets
from blip3o.constants import (
    DEFAULT_IM_END_TOKEN,
    DEFAULT_IM_START_TOKEN,
    DEFAULT_IMAGE_TOKEN,
    IGNORE_INDEX,
    IMAGE_TOKEN_INDEX,
)
from blip3o.utils import rank0_print


ImageFile.LOAD_TRUNCATED_IMAGES = True


## target transform for sana
target_transform = v2.Compose(
    [
        v2.Resize(1024),
        v2.CenterCrop(1024),
        v2.ToImage(),
        v2.ToDtype(torch.float32, scale=True),
        v2.Normalize([0.5], [0.5]),
    ]
)


def expand2square(pil_img, background_color):
    width, height = pil_img.size
    if width == height:
        return pil_img
    elif width > height:
        result = Image.new(pil_img.mode, (width, width), background_color)
        result.paste(pil_img, (0, (width - height) // 2))
        return result
    else:
        result = Image.new(pil_img.mode, (height, height), background_color)
        result.paste(pil_img, ((height - width) // 2, 0))
        return result


def preprocess_multimodal(sources: Sequence[str], data_args) -> Dict:
    is_multimodal = data_args.is_multimodal
    if not is_multimodal:
        return sources

    for source in sources:
        for sentence in source:
            replace_token = DEFAULT_IMAGE_TOKEN
            # NOTE: only add im_start_end when image generation
            if data_args.mm_use_im_start_end and sentence['from'] == 'gpt':
                replace_token = DEFAULT_IM_START_TOKEN + replace_token + DEFAULT_IM_END_TOKEN
            sentence["value"] = sentence["value"].replace(DEFAULT_IMAGE_TOKEN, replace_token)

            # For videoInstruct-100k noisy_data. TODO: Ask Yuanhan to clean the data instead of leaving the noise code here.
            sentence["value"] = sentence["value"].replace("QA_GT_caption_based_noisy", "")

    return sources


def preprocess_qwen(sources, tokenizer: transformers.PreTrainedTokenizer, has_image: bool = False, max_len=2048, system_message: str = "You are a helpful assistant.") -> Dict:
    # roles = {"human": "<|im_start|>user", "gpt": "<|im_start|>assistant"}
    roles = {"human": "user", "gpt": "assistant"}

    #tokenizer = copy.deepcopy(tokenizer)
    # When there is actually an image, we add the image tokens as a special token
    if 'image_token_index' not in globals():
        tokenizer.add_tokens(["<image>"], special_tokens=True)
        global image_token_index
        image_token_index = tokenizer.convert_tokens_to_ids("<image>")
    # if has_image:
    #     tokenizer.add_tokens(["<image>"], special_tokens=True)

    #     image_token_index = tokenizer.convert_tokens_to_ids("<image>")
    im_start, im_end = tokenizer.additional_special_tokens_ids[:2]
    # unmask_tokens = ["<|im_start|>", "<|im_start|>", "\n"]
    unmask_tokens_idx = [198, im_start, im_end]
    # nl_tokens = tokenizer("\n").input_ids

    # Reset Qwen chat templates so that it won't include system message every time we apply
    chat_template = "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}"
    tokenizer.chat_template = chat_template

    # _system = tokenizer("system").input_ids + nl_tokens
    # _user = tokenizer("user").input_ids + nl_tokens
    # _assistant = tokenizer("assistant").input_ids + nl_tokens

    # Apply prompt templates
    input_ids, targets = [], []
    for i, source in enumerate(sources):
        if roles[source[0]["from"]] != roles["human"]:
            source = source[1:]

        input_id, target = [], []

        # New version, use apply chat template
        # Build system message for each sentence
        input_id += tokenizer.apply_chat_template([{"role" : "system", "content" : system_message}])

        # target += [IGNORE_INDEX] * len(input_id)
        target += input_id

        for conv in source:
            # Make sure blip3o data can load
            try:
                role = conv["role"]
                content = conv["content"]
            except:
                role = conv["from"]
                content = conv["value"]

            role = roles.get(role, role)

            conv = [{"role" : role, "content" : content}]
            encode_id = tokenizer.apply_chat_template(conv)
            input_id += encode_id
            if role in ["user", "system"]:
                # target += [IGNORE_INDEX] * len(encode_id)
                target += encode_id

            else:
                target += encode_id

        assert len(input_id) == len(target), f"{len(input_id)} != {len(target)}"
        for idx, encode_id in enumerate(input_id):
            if encode_id in unmask_tokens_idx:
                target[idx] = encode_id
            if encode_id == image_token_index:
                input_id[idx] = IMAGE_TOKEN_INDEX
        input_ids.append(input_id)
        targets.append(target)
    input_ids = torch.tensor(input_ids, dtype=torch.long)
    targets = torch.tensor(targets, dtype=torch.long)

    return dict(
        input_ids=input_ids,
        labels=targets,
    )



class LazySupervisedMixDataset(Dataset):
    """Dataset for supervised fine-tuning."""

    def __init__(
        self,
        tokenizer: transformers.PreTrainedTokenizer,
        data_path: str,
        data_args
    ):
        super(LazySupervisedMixDataset, self).__init__()

        self.data_args = data_args
        list_data_dict = []


        data_files = glob.glob('/fsx/sfr/data/jiuhai/hub/datasets--BLIP3o--BLIP3o-60k/snapshots/f7316b0aa446338ee1707484924aa59457b4bbf3/*.tar')
        data_files.sort()
        train_dataset = load_dataset("webdataset", data_files=data_files, split="train", num_proc=1, cache_dir='/fsx/sfr/data/jiuhai/webdataset')
        train_dataset = train_dataset.rename_column("jpg", "image")
        train_dataset = train_dataset.add_column('type', len(train_dataset) * ['T2I'])
        train_dataset = train_dataset.remove_columns([col for col in train_dataset.column_names if not col in (
            ["image", "txt", "type"])])
        print(f"finish loading image {len(train_dataset)}")
        list_data_dict.append(train_dataset)



        if len(list_data_dict) > 1:
            list_data_dict = concatenate_datasets(list_data_dict)
        else:
            list_data_dict = list_data_dict[0]
        list_data_dict = list_data_dict.shuffle(seed=42)


        rank0_print(f"Totoal number of training instance: {len(list_data_dict)}")
        self.tokenizer = tokenizer
        self.list_data_dict = list_data_dict
        self.modality = torch.tensor(0)  # 0 is for und task, 1 is for gen task


    def __len__(self):
        return len(self.list_data_dict)


    def process_image(self, image):
        processor = self.data_args.image_processor
        image_size = image.size
        image = processor.preprocess(image, return_tensors="pt")["pixel_values"][0]
        return image, image_size, self.modality


    def process_target_image(self, image):
        image = target_transform(image)
        return image


    @property
    def lengths(self):
        length_list = []
        for sample in self.list_data_dict:
            img_tokens = 128 if "image" in sample else 0
            length_list.append(sum(len(conv["value"].split()) for conv in sample["conversations"]) + img_tokens)
        return length_list

    @property
    def modality_lengths(self):
        length_list = []
        for sample in self.list_data_dict:
            cur_len = sum(len(conv["value"].split()) for conv in sample["conversations"])
            cur_len = cur_len if "image" in sample else -cur_len
            length_list.append(cur_len)
        return length_list

    def __getitem__(self, i) -> Dict[str, torch.Tensor]:

        while True:
            sources = self.list_data_dict[i]


            if sources["type"] == "T2I":

                sources["conversations"] = [
                    {"from": "human", "value": f"Please generate image based on the following caption: {sources['txt']}"},
                    {"from": "gpt", "value": "<image>"},
                ]


            elif sources["type"] == "I2I":
                sources["conversations"] = [
                    {
                        "from": "human",
                        "value": f"<image>\nPlease reconstruct the given image.",
                    },
                    {"from": "gpt", "value": ""},
                ]

            else:
                raise ValueError("Unknown source type. Please check the 'type' in 'sources'.")

            if "image" in sources:

                if sources["type"] == "T2I" or sources["type"] == "I2I":
                    image_files = self.list_data_dict[i]["image"]

                    if not isinstance(image_files, list):
                        image_files = [image_files]

                    images = []

                    for img in image_files:
                        try:
                            if sources["type"] == "T2I" or sources["type"] == "I2I":
                                img = img.convert("RGB")
                            else:
                                raise ValueError("Unknown source type. Please check the 'type' in 'sources'.")
                            images.append(img)
                        except Exception as e:
                            print(f"Error opening image {img}: {e}")
                            images = None
                            break  # Skip to the next image if there's an error


                ## test if can apply img_process
                if not images is None:
                    try:
                        process_images = [self.process_image(f) for f in images]
                    except Exception as e:
                        print(f"Error wrong number of channels: {e}")
                        images = None


                # If no valid images were found, randomly pick another item
                if images is None:
                    print(sources)
                    print(f"warning false image!!!!!!")
                    i = random.randint(0, len(self.list_data_dict) - 1)
                    continue

                sources = preprocess_multimodal(copy.deepcopy([sources["conversations"]]), self.data_args)
            else:
                sources = copy.deepcopy([sources["conversations"]])

            data_dict = preprocess_qwen(sources, self.tokenizer, has_image=("image" in self.list_data_dict[i]))
            if isinstance(i, int):
                data_dict = dict(input_ids=data_dict["input_ids"][0], labels=data_dict["labels"][0])


            # image exist in the data
            if "image" in self.list_data_dict[i]:
                data_dict["image"] = process_images
                data_dict["target_image"] = [self.process_target_image(f) for f in images]

            data_dict["ids"] = self.list_data_dict[i]["id"] if "id" in self.list_data_dict[i] else "unk"
            return data_dict



@dataclass
class DataCollatorForSupervisedDataset(object):
    """Collate examples for supervised fine-tuning."""

    tokenizer: transformers.PreTrainedTokenizer

    def pad_sequence(self, input_ids, batch_first, padding_value):
        if self.tokenizer.padding_side == "left":
            input_ids = [torch.flip(_input_ids, [0]) for _input_ids in input_ids]
        input_ids = torch.nn.utils.rnn.pad_sequence(input_ids, batch_first=batch_first, padding_value=padding_value)
        if self.tokenizer.padding_side == "left":
            input_ids = torch.flip(input_ids, [1])
        return input_ids

    def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
        input_ids, labels = tuple([instance[key] for instance in instances] for key in ("input_ids", "labels"))
        input_ids = [_input_ids[: self.tokenizer.model_max_length] for _input_ids in input_ids]
        labels = [_labels[: self.tokenizer.model_max_length] for _labels in labels]
        if self.tokenizer.pad_token_id is None:
            self.tokenizer.pad_token_id = 0  # This gets the best result. Don't know why.
        input_ids = self.pad_sequence(input_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id)
        labels = self.pad_sequence(labels, batch_first=True, padding_value=IGNORE_INDEX)
        batch = dict(input_ids=input_ids, labels=labels.long() if labels.dtype == torch.int32 else labels, attention_mask=input_ids.ne(self.tokenizer.pad_token_id))
        if "image" in instances[0]:
            images = [instance["image"] for instance in instances]

            batch["image_sizes"] = [im[1] for im_list in images for im in im_list]
            batch["modalities"] = [im[2] for im_list in images for im in im_list]
            images = [im[0] for im_list in images for im in im_list]

            batch["images"] = images

            target_images = [instance["target_image"][0] for instance in instances]
            target_images = torch.stack(target_images, dim=0) if target_images else None
            batch["target_images"] = target_images


        if "prompt" in instances[0]:
            batch["prompts"] = [instance["prompt"] for instance in instances]
        return batch

def get_dataset_cls(name):

    if name == 'mix':
        dataset_cls = LazySupervisedMixDataset
    else:
        raise ValueError(f'Unknown dataset class {name}')
    return dataset_cls

def make_supervised_data_module(tokenizer: transformers.PreTrainedTokenizer, data_args) -> Dict:
    """Make dataset and collator for supervised fine-tuning."""
    dataset_cls = get_dataset_cls(data_args.dataset_cls)
    train_dataset = dataset_cls(tokenizer=tokenizer, data_path=data_args.data_path, data_args=data_args)
    data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer)
    return dict(train_dataset=train_dataset, eval_dataset=None, data_collator=data_collator)
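A minimal wiring sketch (not part of the diff) for the data module above. The `data_args` namespace is hypothetical and carries only the attributes the code reads (dataset_cls, data_path, is_multimodal, mm_use_im_start_end, image_processor); the tokenizer checkpoint name is a placeholder.

    # Hypothetical wiring for make_supervised_data_module; attribute names follow
    # what dataset.py reads, but the concrete values are placeholders.
    from types import SimpleNamespace
    import transformers
    from blip3o.data.dataset import make_supervised_data_module

    tokenizer = transformers.AutoTokenizer.from_pretrained("Qwen/Qwen2.5-7B-Instruct")  # placeholder
    data_args = SimpleNamespace(
        dataset_cls="mix",          # selects LazySupervisedMixDataset
        data_path=None,             # unused by the mix dataset, kept for the signature
        is_multimodal=True,
        mm_use_im_start_end=False,
        image_processor=None,       # set to the vision tower's processor during training
    )

    data_module = make_supervised_data_module(tokenizer=tokenizer, data_args=data_args)
    # data_module has keys: train_dataset, eval_dataset (None), data_collator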
blip3o/mm_utils.py
ADDED
@@ -0,0 +1,65 @@
import torch
from transformers import StoppingCriteria

from blip3o.constants import IMAGE_TOKEN_INDEX


def process_images(images, image_processor, model_cfg):
    return image_processor.preprocess(images, return_tensors="pt")["pixel_values"]


def tokenizer_image_token(prompt, tokenizer, image_token_index=IMAGE_TOKEN_INDEX, return_tensors=None):
    prompt_chunks = [tokenizer(chunk).input_ids for chunk in prompt.split("<image>")]

    def insert_separator(X, sep):
        return [ele for sublist in zip(X, [sep] * len(X)) for ele in sublist][:-1]

    input_ids = []
    offset = 0
    if len(prompt_chunks) > 0 and len(prompt_chunks[0]) > 0 and prompt_chunks[0][0] == tokenizer.bos_token_id:
        offset = 1
        input_ids.append(prompt_chunks[0][0])

    for x in insert_separator(prompt_chunks, [image_token_index] * (offset + 1)):
        input_ids.extend(x[offset:])

    if return_tensors is not None:
        if return_tensors == "pt":
            return torch.tensor(input_ids, dtype=torch.long)
        raise ValueError(f"Unsupported tensor type: {return_tensors}")
    return input_ids


def get_model_name_from_path(model_path):
    model_path = model_path.strip("/")
    model_paths = model_path.split("/")
    if model_paths[-1].startswith("checkpoint-"):
        return model_paths[-2] + "_" + model_paths[-1]
    else:
        return model_paths[-1]


class KeywordsStoppingCriteria(StoppingCriteria):
    def __init__(self, keywords, tokenizer, input_ids):
        self.keywords = keywords
        self.keyword_ids = []
        for keyword in keywords:
            cur_keyword_ids = tokenizer(keyword).input_ids
            if len(cur_keyword_ids) > 1 and cur_keyword_ids[0] == tokenizer.bos_token_id:
                cur_keyword_ids = cur_keyword_ids[1:]
            self.keyword_ids.append(torch.tensor(cur_keyword_ids))
        self.tokenizer = tokenizer
        self.start_len = input_ids.shape[1]

    def __call__(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        assert output_ids.shape[0] == 1, "Only support batch size 1 (yet)"  # TODO
        offset = min(output_ids.shape[1] - self.start_len, 3)
        self.keyword_ids = [keyword_id.to(output_ids.device) for keyword_id in self.keyword_ids]
        for keyword_id in self.keyword_ids:
            if output_ids[0, -keyword_id.shape[0] :] == keyword_id:
                return True
        outputs = self.tokenizer.batch_decode(output_ids[:, -offset:], skip_special_tokens=True)[0]
        for keyword in self.keywords:
            if keyword in outputs:
                return True
        return False
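A usage sketch (not part of the diff) for tokenizer_image_token above: the prompt is split on the `<image>` placeholder, the text chunks are tokenized, and IMAGE_TOKEN_INDEX is spliced between them. The tokenizer checkpoint name is a placeholder.

    # Sketch: turning a prompt with an <image> placeholder into model input ids.
    import transformers
    from blip3o.constants import IMAGE_TOKEN_INDEX
    from blip3o.mm_utils import tokenizer_image_token

    tokenizer = transformers.AutoTokenizer.from_pretrained("Qwen/Qwen2.5-7B-Instruct")  # placeholder
    prompt = "<image>\nDescribe the image in one sentence."

    input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt")
    # input_ids is a 1-D LongTensor; the -200 entry marks where image features are spliced in,
    # so unsqueeze(0) before passing it to the model as a batch of one.
    input_ids = input_ids.unsqueeze(0)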
blip3o/model/__init__.py
ADDED
@@ -0,0 +1,3 @@
from blip3o.model.language_model.blip3o_qwen import blip3oQwenConfig, blip3oQwenForCausalLM
from blip3o.model.language_model.blip3o_qwen_inference import blip3oQwenForInferenceLM
from blip3o.model.language_model.blip3o_qwen_grpo import blip3oQwenForGRPOLM
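A minimal consumption sketch (not part of the diff), assuming blip3oQwenForCausalLM follows the standard transformers PreTrainedModel interface; the checkpoint path is a placeholder, and the repository's own loader in blip3o/model/builder.py may wrap additional setup.

    # Sketch only: the from_pretrained interface is assumed, the path is a placeholder.
    import torch
    from blip3o.model import blip3oQwenForCausalLM

    model = blip3oQwenForCausalLM.from_pretrained(
        "path/to/blip3o-checkpoint",     # placeholder, not a real checkpoint name
        torch_dtype=torch.bfloat16,
    )
    model.eval()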
blip3o/model/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (449 Bytes)

blip3o/model/__pycache__/__init__.cpython-311.pyc
ADDED
Binary file (512 Bytes)

blip3o/model/__pycache__/blip3o_arch.cpython-310.pyc
ADDED
Binary file (11.5 kB)

blip3o/model/__pycache__/blip3o_arch.cpython-311.pyc
ADDED
Binary file (26.1 kB)

blip3o/model/__pycache__/llava_arch.cpython-310.pyc
ADDED
Binary file (11.5 kB)

blip3o/model/__pycache__/llava_arch.cpython-311.pyc
ADDED
Binary file (26 kB)
blip3o/model/blip3o_arch.py
ADDED
@@ -0,0 +1,400 @@
import os
import random
from abc import ABC, abstractmethod

import torch
import torch.nn as nn

from blip3o.constants import (
    DEFAULT_IM_END_TOKEN,
    DEFAULT_IM_START_TOKEN,
    IGNORE_INDEX,
    IMAGE_TOKEN_INDEX,
)
from blip3o.utils import rank0_print
from .multimodal_encoder.builder import build_vision_tower
from .multimodal_decoder.builder import build_sana, build_vae
from diffusers.models.normalization import RMSNorm
from diffusers import AutoencoderDC, FlowMatchEulerDiscreteScheduler, SanaTransformer2DModel
import math

class blip3oMetaModel:

    def __init__(self, config):
        super(blip3oMetaModel, self).__init__(config)

        if hasattr(config, "mm_vision_tower"):
            delay_load = getattr(config, "delay_load", False)
            self.vision_tower = build_vision_tower(config, delay_load=delay_load)

            self.sana = build_sana(config)
            self.sana_vae = build_vae(config)
            norm = RMSNorm(2304, eps=1e-5, elementwise_affine=True)

            with torch.no_grad():
                norm.weight.fill_(math.sqrt(5.5))
            self.diffusion_connector = nn.Sequential(
                nn.Linear(config.hidden_size, 2304),
                nn.GELU(approximate="tanh"),
                nn.Linear(2304, 2304),
                norm,
            )
            self.noise_scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(config.diffusion_name_or_path, subfolder="scheduler")

            self.scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(config.diffusion_name_or_path, subfolder="scheduler")


    def get_vision_tower(self):
        vision_tower = getattr(self, "vision_tower", None)
        if type(vision_tower) is list:
            vision_tower = vision_tower[0]
        return vision_tower


    def get_sana(self):
        sana = getattr(self, 'sana', None)
        if type(sana) is list:
            sana = sana[0]
        if sana is not None:
            sana.to(self.device)
        return sana

    def get_sana_vae(self):
        sana_vae = getattr(self, 'sana_vae', None)
        if type(sana_vae) is list:
            sana_vae = sana_vae[0]
        if sana_vae is not None:
            sana_vae.to(self.device)
        return sana_vae

    def initialize_vision_modules(self, model_args, fsdp=None):
        vision_tower = model_args.vision_tower
        mm_vision_select_layer = model_args.mm_vision_select_layer
        mm_vision_select_feature = model_args.mm_vision_select_feature
        mm_patch_merge_type = model_args.mm_patch_merge_type

        self.config.mm_vision_tower = vision_tower
        self.config.vision_tower_pretrained = getattr(model_args, "vision_tower_pretrained", "")

        if self.get_vision_tower() is None:
            vision_tower = build_vision_tower(model_args)

            if fsdp is not None and len(fsdp) > 0:
                self.vision_tower = [vision_tower]
            else:
                self.vision_tower = vision_tower
        else:
            if fsdp is not None and len(fsdp) > 0:
                vision_tower = self.vision_tower[0]
            else:
                vision_tower = self.vision_tower
            vision_tower.load_model()


        if self.get_sana() is None:
            sana = build_sana(model_args)
            self.noise_scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(model_args.diffusion_name_or_path, subfolder="scheduler"
            )
            self.scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(model_args.diffusion_name_or_path, subfolder="scheduler")

            if fsdp is not None and len(fsdp) > 0:
                self.sana = [sana]
            else:
                self.sana = sana
        else:
            if fsdp is not None and len(fsdp) > 0:
                sana = self.sana[0]
            else:
                sana = self.sana


        if self.get_sana_vae() is None:
            sana_vae = build_vae(model_args)

            if fsdp is not None and len(fsdp) > 0:
                self.sana_vae = [sana_vae]
            else:
                self.sana_vae = sana_vae
        else:
            if fsdp is not None and len(fsdp) > 0:
                sana_vae = self.sana_vae[0]
            else:
                sana_vae = self.sana_vae


        if getattr(self, 'diffusion_connector', None) is None:
            norm = RMSNorm(2304, eps=1e-5, elementwise_affine=True)
            with torch.no_grad():
                norm.weight.fill_(math.sqrt(5.5))
            self.diffusion_connector = nn.Sequential(
                nn.Linear(self.config.hidden_size, 2304),
                nn.GELU(approximate="tanh"),
                nn.Linear(2304, 2304),
                norm,
            )
        else:
            for p in self.diffusion_connector.parameters():
                p.requires_grad = True

        self.config.use_mm_proj = True
        self.config.mm_hidden_size = vision_tower.hidden_size
        self.config.mm_vision_select_layer = mm_vision_select_layer
        self.config.mm_vision_select_feature = mm_vision_select_feature
        self.config.mm_patch_merge_type = mm_patch_merge_type


class blip3oMetaForCausalLM(ABC):

    @abstractmethod
    def get_model(self):
        pass

    def get_vision_tower(self):
        return self.get_model().get_vision_tower()

    def encode_images(self, images, modalities, pool_scale=None):
        image_features = self.get_model().get_vision_tower()(images, pool_scale=pool_scale)

        assert 'tokens' in image_features
        image_tokens = image_features['tokens']

        # discrete features for gen related tasks
        image_tokens = image_tokens + self.config.image_start_token_id
        image_features = self.get_model().embed_tokens(image_tokens)

        return {'image_features': image_features, 'image_tokens': image_tokens}

    def prepare_inputs_labels_for_multimodal(self, input_ids, position_ids, attention_mask, past_key_values, labels, images, modalities=None, image_sizes=None):
        vision_tower = self.get_vision_tower()

        if vision_tower is None or images is None or input_ids.shape[1] == 1:
            return input_ids, position_ids, attention_mask, past_key_values, None, labels

        if not isinstance(modalities, list):
            modalities = [modalities]

        # random scale for training, but scale 1 for understanding evaluation
        if self.training:
            pool_scale = random.choice(vision_tower.pool_scales)
        else:
            pool_scale = 1

        if type(images) is list or images.ndim == 5:
            if type(images) is list:
                images = [x.unsqueeze(0) if x.ndim == 3 else x for x in images]

            images_list = []
            for image in images:
                if image.ndim == 4:
                    images_list.append(image)
                else:
                    images_list.append(image.unsqueeze(0))

            concat_images = torch.cat([image for image in images_list], dim=0)
            split_sizes = [image.shape[0] for image in images_list]
            encoded_image_features = self.encode_images(concat_images, modalities, pool_scale=pool_scale)
            image_tokens = encoded_image_features['image_tokens']
            encoded_image_features = encoded_image_features['image_features']

            # This is a list, each element is [num_images, patch * patch, dim]
            encoded_image_features = torch.split(encoded_image_features, split_sizes)
            if image_tokens is not None:
                image_tokens = torch.split(image_tokens, split_sizes)
            image_features = []
            for idx, image_feat in enumerate(encoded_image_features):
                image_features.append(image_feat)

            mm_patch_merge_type = getattr(self.config, "mm_patch_merge_type", "flat")

            if mm_patch_merge_type == "flat":
                image_features = [x.flatten(0, 1) for x in image_features]
                if image_tokens is not None:
                    image_tokens = [x.flatten(0, 1) for x in image_tokens]
            else:
                raise ValueError(f"Unexpected mm_patch_merge_type: {self.config.mm_patch_merge_type}")
        else:
            image_features = self.encode_images(images, modalities, pool_scale=pool_scale)
            image_tokens = image_features['image_tokens']
            image_features = image_features['image_features']
        # Let's just add dummy tensors if they do not exist,
        # it is a headache to deal with None all the time.
        # But it is not ideal, and if you have a better idea,
        # please open an issue / submit a PR, thanks.
        breakpoint()
        _labels = labels
        _position_ids = position_ids
        _attention_mask = attention_mask
        if attention_mask is None:
            attention_mask = torch.ones_like(input_ids, dtype=torch.bool)
        else:
            attention_mask = attention_mask.bool()
        if position_ids is None:
            position_ids = torch.arange(0, input_ids.shape[1], dtype=torch.long, device=input_ids.device)
        if labels is None:
            labels = torch.full_like(input_ids, IGNORE_INDEX)

        # remove the padding using attention_mask -- FIXME
        _input_ids = input_ids
        input_ids = [cur_input_ids[cur_attention_mask] for cur_input_ids, cur_attention_mask in zip(input_ids, attention_mask)]
        labels = [cur_labels[cur_attention_mask] for cur_labels, cur_attention_mask in zip(labels, attention_mask)]

        new_input_embeds = []
        new_labels = []
        cur_image_idx = 0
        # rank_print("Inserting Images embedding")
        for batch_idx, cur_input_ids in enumerate(input_ids):
            num_images = (cur_input_ids == IMAGE_TOKEN_INDEX).sum()
            # rank0_print(num_images)
            if num_images == 0:
                # cur_image_features = image_features[cur_image_idx]
                cur_input_embeds_1 = self.get_model().embed_tokens(cur_input_ids)
                # cur_input_embeds = torch.cat([cur_input_embeds_1, cur_image_features[0:0]], dim=0)
                cur_input_embeds = torch.cat([cur_input_embeds_1, cur_input_embeds_1[0:0]], dim=0)
                new_input_embeds.append(cur_input_embeds)
                new_labels.append(labels[batch_idx])
                cur_image_idx += 1
                continue

            image_token_indices = [-1] + torch.where(cur_input_ids == IMAGE_TOKEN_INDEX)[0].tolist() + [cur_input_ids.shape[0]]
            cur_input_ids_noim = []
            cur_labels = labels[batch_idx]
            cur_labels_noim = []
            for i in range(len(image_token_indices) - 1):
                cur_input_ids_noim.append(cur_input_ids[image_token_indices[i] + 1 : image_token_indices[i + 1]])
                cur_labels_noim.append(cur_labels[image_token_indices[i] + 1 : image_token_indices[i + 1]])
            split_sizes = [x.shape[0] for x in cur_labels_noim]
            cur_input_embeds = self.get_model().embed_tokens(torch.cat(cur_input_ids_noim))
            cur_input_embeds_no_im = torch.split(cur_input_embeds, split_sizes, dim=0)
            cur_new_input_embeds = []
            cur_new_labels = []

            for i in range(num_images + 1):
                cur_new_input_embeds.append(cur_input_embeds_no_im[i])
                cur_new_labels.append(cur_labels_noim[i])
                if i < num_images:
                    try:
                        cur_image_features = image_features[cur_image_idx]
                    except IndexError:
                        rank0_print("Error image_features[cur_image_idx]!")
                        break
                    # [Assisant\n<start_image><image><end_image>]
                    if self.config.image_start_tag_id == cur_labels_noim[i][-1] and image_tokens is not None:
                        cur_image_tokens = image_tokens[cur_image_idx]
                        if pool_scale is not None:
                            pool_token = self.config.scale_start_token_id + pool_scale - 1
                            pool_token = torch.tensor([pool_token], dtype=torch.long, device=cur_image_tokens.device)
                            cur_image_tokens = torch.cat([pool_token, cur_image_tokens])
                            pool_embed = self.get_model().embed_tokens(pool_token)
                            cur_image_features = torch.cat([pool_embed, cur_image_features])
                    else:
                        cur_image_tokens = torch.full((cur_image_features.shape[0],), IGNORE_INDEX, device=cur_labels.device, dtype=cur_labels.dtype)
                    cur_image_idx += 1
                    cur_new_input_embeds.append(cur_image_features)
                    cur_new_labels.append(cur_image_tokens)
            cur_new_input_embeds = [x.to(self.device) for x in cur_new_input_embeds]

            cur_new_input_embeds = torch.cat(cur_new_input_embeds)
            cur_new_labels = torch.cat(cur_new_labels)

            new_input_embeds.append(cur_new_input_embeds)
            new_labels.append(cur_new_labels)

        # Truncate sequences to max length as image embeddings can make the sequence longer
        tokenizer_model_max_length = getattr(self.config, "tokenizer_model_max_length", None)

        new_input_embeds = [x[:tokenizer_model_max_length] for x, modality in zip(new_input_embeds, modalities)]
        new_labels = [x[:tokenizer_model_max_length] for x, modality in zip(new_labels, modalities)]

        # Combine them
        max_len = max(x.shape[0] for x in new_input_embeds)
        batch_size = len(new_input_embeds)

        new_input_embeds_padded = []
        new_labels_padded = torch.full((batch_size, max_len), IGNORE_INDEX, dtype=new_labels[0].dtype, device=new_labels[0].device)
        attention_mask = torch.zeros((batch_size, max_len), dtype=attention_mask.dtype, device=attention_mask.device)
        position_ids = torch.zeros((batch_size, max_len), dtype=position_ids.dtype, device=position_ids.device)

        for i, (cur_new_embed, cur_new_labels) in enumerate(zip(new_input_embeds, new_labels)):
            cur_len = cur_new_embed.shape[0]
            if getattr(self.config, "tokenizer_padding_side", "right") == "left":
                new_input_embeds_padded.append(torch.cat((torch.zeros((max_len - cur_len, cur_new_embed.shape[1]), dtype=cur_new_embed.dtype, device=cur_new_embed.device), cur_new_embed), dim=0))
                if cur_len > 0:
                    new_labels_padded[i, -cur_len:] = cur_new_labels
                    attention_mask[i, -cur_len:] = True
                    position_ids[i, -cur_len:] = torch.arange(0, cur_len, dtype=position_ids.dtype, device=position_ids.device)
            else:
                new_input_embeds_padded.append(torch.cat((cur_new_embed, torch.zeros((max_len - cur_len, cur_new_embed.shape[1]), dtype=cur_new_embed.dtype, device=cur_new_embed.device)), dim=0))
                if cur_len > 0:
                    new_labels_padded[i, :cur_len] = cur_new_labels
                    attention_mask[i, :cur_len] = True
                    position_ids[i, :cur_len] = torch.arange(0, cur_len, dtype=position_ids.dtype, device=position_ids.device)

        new_input_embeds = torch.stack(new_input_embeds_padded, dim=0)

        if _labels is None:
            new_labels = None
        else:
            new_labels = new_labels_padded

        if _attention_mask is None:
            attention_mask = None
        else:
            attention_mask = attention_mask.to(dtype=_attention_mask.dtype)

        if _position_ids is None:
|
| 345 |
+
position_ids = None
|
| 346 |
+
if getattr(self.config, "use_pos_skipping", False) and self.training:
|
| 347 |
+
position_ids = torch.arange(new_input_embeds.size(1), device=new_input_embeds.device).unsqueeze(0).to(new_input_embeds.device)
|
| 348 |
+
split_position = random.randint(0, new_input_embeds.size(1))
|
| 349 |
+
left_add = random.randint(0, self.config.pos_skipping_range)
|
| 350 |
+
right_add = random.randint(left_add, self.config.pos_skipping_range)
|
| 351 |
+
position_ids[:, :split_position] += left_add
|
| 352 |
+
position_ids[:, split_position:] += right_add
|
| 353 |
+
|
| 354 |
+
return None, position_ids, attention_mask, past_key_values, new_input_embeds, new_labels
|
| 355 |
+
|
| 356 |
+
def initialize_vision_tokenizer(self, model_args, tokenizer):
|
| 357 |
+
total_num_new_tokens = 0
|
| 358 |
+
vocab_size = len(tokenizer)
|
| 359 |
+
if model_args.mm_use_im_start_end:
|
| 360 |
+
num_new_tokens = tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True)
|
| 361 |
+
self.config.image_start_tag_id = tokenizer.convert_tokens_to_ids(DEFAULT_IM_START_TOKEN)
|
| 362 |
+
self.config.image_end_tag_id = tokenizer.convert_tokens_to_ids(DEFAULT_IM_END_TOKEN)
|
| 363 |
+
total_num_new_tokens += num_new_tokens
|
| 364 |
+
self.resize_token_embeddings(vocab_size + total_num_new_tokens)
|
| 365 |
+
|
| 366 |
+
if model_args.num_scale_tokens > 0:
|
| 367 |
+
scale_tokens = [model_args.scale_token_format.format(str(i)) for i in range(model_args.num_scale_tokens)]
|
| 368 |
+
num_new_tokens = tokenizer.add_tokens(scale_tokens, special_tokens=False)
|
| 369 |
+
self.config.scale_start_token_id = tokenizer.convert_tokens_to_ids(scale_tokens[0])
|
| 370 |
+
self.config.scale_end_token_id = tokenizer.convert_tokens_to_ids(scale_tokens[-1])
|
| 371 |
+
self.config.num_scale_tokens = model_args.num_scale_tokens
|
| 372 |
+
total_num_new_tokens += num_new_tokens
|
| 373 |
+
self.resize_token_embeddings(vocab_size + total_num_new_tokens)
|
| 374 |
+
|
| 375 |
+
if model_args.num_image_tokens > 0:
|
| 376 |
+
image_tokens = [model_args.image_token_format.format(str(i)) for i in range(model_args.num_image_tokens)]
|
| 377 |
+
num_new_tokens = tokenizer.add_tokens(image_tokens, special_tokens=False)
|
| 378 |
+
self.config.image_start_token_id = tokenizer.convert_tokens_to_ids(image_tokens[0])
|
| 379 |
+
self.config.image_end_token_id = tokenizer.convert_tokens_to_ids(image_tokens[-1])
|
| 380 |
+
self.config.num_image_tokens = model_args.num_image_tokens
|
| 381 |
+
|
| 382 |
+
total_num_new_tokens += num_new_tokens
|
| 383 |
+
self.resize_token_embeddings(vocab_size + total_num_new_tokens)
|
| 384 |
+
if num_new_tokens > 0:
|
| 385 |
+
self.config.num_new_tokens = num_new_tokens
|
| 386 |
+
input_embeddings = self.get_input_embeddings().weight.data
|
| 387 |
+
output_embeddings = self.get_output_embeddings().weight.data
|
| 388 |
+
|
| 389 |
+
input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True)
|
| 390 |
+
output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True)
|
| 391 |
+
|
| 392 |
+
input_embeddings[-num_new_tokens:] = input_embeddings_avg
|
| 393 |
+
output_embeddings[-num_new_tokens:] = output_embeddings_avg
|
| 394 |
+
|
| 395 |
+
vision_tower = self.get_vision_tower()
|
| 396 |
+
if model_args.load_embeddings_from_vision and vision_tower is not None:
|
| 397 |
+
vision_embeddings = vision_tower.get_embedding()
|
| 398 |
+
if model_args.num_image_tokens == vision_embeddings.shape[0] and input_embeddings.shape[1] == vision_embeddings.shape[1]:
|
| 399 |
+
rank0_print("Load vision embeddings from vision tower.")
|
| 400 |
+
input_embeddings[self.config.image_start_token_id:self.config.image_end_token_id+1] = vision_embeddings
|
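
The splicing above is easier to follow on a toy input. Below is a minimal, self-contained sketch (hypothetical 4-dim embeddings; the IMAGE_TOKEN_INDEX and IGNORE_INDEX values are assumed to match blip3o/constants.py) of how a prompt is cut at the image placeholder and the image features are spliced in with IGNORE_INDEX labels:

import torch

IMAGE_TOKEN_INDEX = -200   # assumed placeholder id from blip3o/constants.py
IGNORE_INDEX = -100        # assumed ignore label from blip3o/constants.py

# toy prompt: [bos, "describe", <image>, "please"]
cur_input_ids = torch.tensor([1, 42, IMAGE_TOKEN_INDEX, 77])
cur_labels = torch.tensor([IGNORE_INDEX, IGNORE_INDEX, IGNORE_INDEX, 77])
image_features = torch.randn(3, 4)      # 3 image tokens, hidden size 4 (hypothetical)
embed = torch.nn.Embedding(100, 4)      # stand-in for embed_tokens

# split the sequence around the image placeholder, exactly as in the method above
idx = [-1] + torch.where(cur_input_ids == IMAGE_TOKEN_INDEX)[0].tolist() + [cur_input_ids.shape[0]]
chunks_ids = [cur_input_ids[idx[i] + 1 : idx[i + 1]] for i in range(len(idx) - 1)]
chunks_lab = [cur_labels[idx[i] + 1 : idx[i + 1]] for i in range(len(idx) - 1)]

# embed the text chunks and interleave the image features between them
pieces, labels = [], []
for i, (c_ids, c_lab) in enumerate(zip(chunks_ids, chunks_lab)):
    pieces.append(embed(c_ids))
    labels.append(c_lab)
    if i < len(chunks_ids) - 1:
        pieces.append(image_features)
        labels.append(torch.full((image_features.shape[0],), IGNORE_INDEX))

new_embeds = torch.cat(pieces)   # (2 text + 3 image + 1 text) positions x 4 dims
new_labels = torch.cat(labels)
print(new_embeds.shape, new_labels.shape)  # torch.Size([6, 4]) torch.Size([6])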
blip3o/model/builder.py
ADDED
|
@@ -0,0 +1,44 @@
import torch
from transformers import AutoTokenizer

from blip3o.model import blip3oQwenForCausalLM
from blip3o.utils import rank0_print


def load_pretrained_model(model_path, model_base, model_name, load_8bit=False, load_4bit=False, device_map="auto", torch_dtype="float16", attn_implementation="flash_attention_2", customized_config=None, overwrite_config=None, **kwargs):
    kwargs["device_map"] = device_map
    kwargs.pop("multimodal", None)  # tolerate callers that do not pass `multimodal`

    if customized_config is not None:
        kwargs["config"] = customized_config

    tokenizer = AutoTokenizer.from_pretrained(model_path)
    from blip3o.model.language_model.blip3o_qwen import blip3oQwenConfig

    # breakpoint()  # debug stop removed so the loader can run unattended
    if overwrite_config is not None:
        blip3o_cfg = blip3oQwenConfig.from_pretrained(model_path)
        rank0_print(f"Overwriting config with {overwrite_config}")
        for k, v in overwrite_config.items():
            setattr(blip3o_cfg, k, v)
        model = blip3oQwenForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, attn_implementation=attn_implementation, config=blip3o_cfg, **kwargs)
    else:
        model = blip3oQwenForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, attn_implementation=attn_implementation, **kwargs)

    vision_tower = model.get_vision_tower()
    if not vision_tower.is_loaded:
        vision_tower.load_model(device_map=device_map)
    if device_map != "auto":
        vision_tower.to(device="cuda", dtype=torch.float16)
    image_processor = vision_tower.image_processor

    if hasattr(model.config, "max_sequence_length"):
        context_len = model.config.max_sequence_length
    elif hasattr(model.config, "max_position_embeddings"):
        context_len = model.config.max_position_embeddings
    elif hasattr(model.config, "tokenizer_model_max_length"):
        context_len = model.config.tokenizer_model_max_length
    else:
        context_len = 2048

    return tokenizer, model, image_processor, context_len
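
A hedged usage sketch of load_pretrained_model; the checkpoint id below is a placeholder, and the `multimodal=True` keyword is the flag the loader pops off before forwarding kwargs to from_pretrained:

from blip3o.model.builder import load_pretrained_model

tokenizer, model, image_processor, context_len = load_pretrained_model(
    model_path="path/to/blip3o-checkpoint",   # placeholder, not a published model id
    model_base=None,
    model_name="blip3o_qwen",
    device_map="auto",
    multimodal=True,                           # consumed inside load_pretrained_model
)
print(context_len)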
blip3o/model/language_model/__pycache__/blip3o_qwen.cpython-310.pyc
ADDED
Binary file (6.85 kB).

blip3o/model/language_model/__pycache__/blip3o_qwen.cpython-311.pyc
ADDED
Binary file (13.3 kB).

blip3o/model/language_model/__pycache__/blip3o_qwen_grpo.cpython-310.pyc
ADDED
Binary file (7.82 kB).

blip3o/model/language_model/__pycache__/blip3o_qwen_grpo.cpython-311.pyc
ADDED
Binary file (15.3 kB).

blip3o/model/language_model/__pycache__/blip3o_qwen_inference.cpython-310.pyc
ADDED
Binary file (7.12 kB).

blip3o/model/language_model/__pycache__/blip3o_qwen_inference.cpython-311.pyc
ADDED
Binary file (13.9 kB).

blip3o/model/language_model/__pycache__/llava_qwen.cpython-310.pyc
ADDED
Binary file (6.82 kB).

blip3o/model/language_model/__pycache__/llava_qwen.cpython-311.pyc
ADDED
Binary file (13.2 kB).

blip3o/model/language_model/__pycache__/llava_qwen_grpo.cpython-310.pyc
ADDED
Binary file (7.79 kB).

blip3o/model/language_model/__pycache__/llava_qwen_grpo.cpython-311.pyc
ADDED
Binary file (15.2 kB).

blip3o/model/language_model/__pycache__/llava_qwen_inference.cpython-310.pyc
ADDED
Binary file (7.09 kB).

blip3o/model/language_model/__pycache__/llava_qwen_inference.cpython-311.pyc
ADDED
Binary file (13.9 kB).
blip3o/model/language_model/blip3o_qwen.py
ADDED
|
@@ -0,0 +1,215 @@
from typing import Dict, List, Optional, Tuple, Union

import torch
import torch.nn as nn
from transformers import (
    AutoConfig,
    AutoModelForCausalLM,
    Qwen3Config,
    Qwen3ForCausalLM,
    Qwen3Model,
)
from transformers.generation.utils import GenerateOutput
from transformers.modeling_outputs import CausalLMOutputWithPast

from blip3o.model.blip3o_arch import blip3oMetaForCausalLM, blip3oMetaModel
from diffusers.training_utils import compute_density_for_timestep_sampling, compute_loss_weighting_for_sd3
from blip3o.utils import rank0_print


class blip3oQwenConfig(Qwen3Config):
    model_type = "blip3o_qwen"


class blip3oQwenModel(blip3oMetaModel, Qwen3Model):
    config_class = blip3oQwenConfig

    def __init__(self, config: Qwen3Config):
        super(blip3oQwenModel, self).__init__(config)


class blip3oQwenForCausalLM(Qwen3ForCausalLM, blip3oMetaForCausalLM):
    config_class = blip3oQwenConfig

    def __init__(self, config):
        Qwen3ForCausalLM.__init__(self, config)
        config.model_type = "blip3o_qwen"
        config.rope_scaling = None

        self.model = blip3oQwenModel(config)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    def get_model(self):
        return self.model

    def get_sigmas(self, timesteps, device, n_dim=4, dtype=torch.float32):
        sigmas = self.model.noise_scheduler.sigmas.to(device=device, dtype=dtype)
        schedule_timesteps = self.model.noise_scheduler.timesteps.to(device)
        timesteps = timesteps.to(device)
        step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]

        sigma = sigmas[step_indices].flatten()
        while len(sigma.shape) < n_dim:
            sigma = sigma.unsqueeze(-1)
        return sigma

    def mask_drop(self, latents, drop_prob=0.1):
        if drop_prob <= 0:
            return latents
        mask = torch.bernoulli(torch.zeros(latents.shape[0], device=latents.device, dtype=latents.dtype) + drop_prob)
        while len(mask.shape) < len(latents.shape):
            mask = mask.unsqueeze(-1)
        mask = 1 - mask  # need to flip 0 <-> 1
        return latents * mask

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        images: Optional[torch.FloatTensor] = None,
        target_images: Optional[torch.FloatTensor] = None,
        image_sizes: Optional[List[List[int]]] = None,
        return_dict: Optional[bool] = None,
        modalities: Optional[List[str]] = ["image"],
        dpo_forward: Optional[bool] = False,
        cache_position=None,
    ) -> Union[Tuple, CausalLMOutputWithPast]:

        if inputs_embeds is None:
            (input_ids, position_ids, attention_mask, past_key_values, inputs_embeds, labels) = self.prepare_inputs_labels_for_multimodal(input_ids, position_ids, attention_mask, past_key_values, labels, images, modalities, image_sizes)
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        hidden_states = outputs[0]
        logits = self.lm_head(hidden_states)
        loss = None  # keep loss defined when labels are not provided
        if labels is not None:
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            loss_fct = torch.nn.CrossEntropyLoss()
            shift_logits = shift_logits.view(-1, self.config.vocab_size)
            shift_labels = shift_labels.view(-1)
            shift_labels = shift_labels.to(shift_logits.device)
            loss = loss_fct(shift_logits, shift_labels)

        if target_images is not None:
            vae = self.model.get_sana_vae()
            latents = vae.encode(target_images).latent
            if "shift_factor" in vae.config and vae.config.shift_factor is not None:
                latents = latents - vae.config.shift_factor
            latents = latents * vae.config.scaling_factor
            noise = torch.randn_like(latents, device=latents.device)
            weighting_scheme = "uniform"
            u = compute_density_for_timestep_sampling(
                weighting_scheme=weighting_scheme,
                batch_size=latents.shape[0],
                logit_mean=0.0,
                logit_std=1.0,
                mode_scale=1.29,
            )
            indices = (u * self.model.noise_scheduler.config.num_train_timesteps).long()
            timesteps = self.model.noise_scheduler.timesteps[indices].to(device=latents.device)
            sigmas = self.get_sigmas(timesteps, latents.device, n_dim=latents.ndim, dtype=latents.dtype)
            noisy_latents = (1.0 - sigmas) * latents + sigmas * noise

            sana = self.model.get_sana()

            start_pos = (labels == self.config.image_start_tag_id).float().argmax(dim=1)
            end_pos = (labels == self.config.image_end_tag_id).float().argmax(dim=1)

            # breakpoint()  # debug stop removed so training can run unattended
            selected_hidden_states = []
            for b in range(hidden_states.size(0)):
                start = start_pos[b].item() + 1
                end = end_pos[b].item()
                hidden_states_filter = hidden_states[b, start:end, :]
                if hidden_states_filter.size(1) != 730:
                    hidden_states_filter = hidden_states[b, -730:, :]
                selected_hidden_states.append(hidden_states_filter)

            selected_hidden_states = torch.stack(selected_hidden_states, dim=0)
            diffusion_pred = sana(
                hidden_states=noisy_latents,
                timestep=timesteps,
                encoder_hidden_states=self.model.diffusion_connector(self.mask_drop(selected_hidden_states)),
                encoder_attention_mask=None,
            ).sample

            target = noise - latents
            weighting = compute_loss_weighting_for_sd3(weighting_scheme=weighting_scheme, sigmas=sigmas)
            diff_loss = torch.mean(
                (weighting.float() * (diffusion_pred.float() - target.float()) ** 2).reshape(target.shape[0], -1),
                1,
            )
            diff_loss = diff_loss.mean()
            rank0_print(f" Cross-entropy loss {loss}, Diffusion loss {diff_loss} ")
            loss = diff_loss if loss is None else loss + diff_loss

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    @torch.no_grad()
    def generate(
        self,
        inputs: Optional[torch.Tensor] = None,
        images: Optional[torch.Tensor] = None,
        image_sizes: Optional[torch.Tensor] = None,
        modalities: Optional[List[str]] = ["image"],
        **kwargs,
    ) -> Union[GenerateOutput, torch.LongTensor]:
        position_ids = kwargs.pop("position_ids", None)
        attention_mask = kwargs.pop("attention_mask", None)
        if "inputs_embeds" in kwargs:
            raise NotImplementedError("`inputs_embeds` is not supported")

        if images is not None:
            (inputs, position_ids, attention_mask, _, inputs_embeds, _) = self.prepare_inputs_labels_for_multimodal(inputs, position_ids, attention_mask, None, None, images, modalities, image_sizes=image_sizes)
        else:
            inputs_embeds = self.get_model().embed_tokens(inputs)
        return super().generate(position_ids=position_ids, attention_mask=attention_mask, inputs_embeds=inputs_embeds, **kwargs)

    def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs):
        images = kwargs.pop("images", None)
        image_sizes = kwargs.pop("image_sizes", None)
        inputs = super().prepare_inputs_for_generation(input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, **kwargs)
        if images is not None:
            inputs["images"] = images
        if image_sizes is not None:
            inputs["image_sizes"] = image_sizes
        return inputs


AutoConfig.register("blip3o_qwen", blip3oQwenConfig)
AutoModelForCausalLM.register(blip3oQwenConfig, blip3oQwenForCausalLM)
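
forward() above couples the next-token cross-entropy loss with a flow-matching loss on SANA latents: latents are linearly mixed with noise at a sampled sigma and the transformer is trained to predict the velocity noise - latents. A standalone numeric sketch of that interpolation and target (shapes are illustrative, not the model's real latent shapes):

import torch

latents = torch.randn(2, 32, 8, 8)        # clean VAE latents x
noise = torch.randn_like(latents)         # eps
sigmas = torch.rand(2, 1, 1, 1)           # per-sample sigma in (0, 1)

noisy_latents = (1.0 - sigmas) * latents + sigmas * noise   # point on the line from x to eps
target = noise - latents                                    # velocity the transformer must predict

# a perfect prediction gives zero flow-matching loss
pred = target.clone()
loss = ((pred - target) ** 2).reshape(2, -1).mean(dim=1).mean()
print(loss.item())  # 0.0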
blip3o/model/language_model/blip3o_qwen_grpo.py
ADDED
|
@@ -0,0 +1,255 @@
from typing import Dict, List, Optional, Tuple, Union

import torch
import torch.nn as nn
from transformers import (
    AutoConfig,
    AutoModelForCausalLM,
    Qwen3Config,
    Qwen3ForCausalLM,
    Qwen3Model,
)
from transformers.generation.utils import GenerateOutput
from transformers.modeling_outputs import CausalLMOutputWithPast

from blip3o.model.blip3o_arch import blip3oMetaForCausalLM, blip3oMetaModel
from diffusers.training_utils import compute_density_for_timestep_sampling, compute_loss_weighting_for_sd3
from diffusers.utils.torch_utils import randn_tensor
from diffusers.schedulers import DDPMScheduler, DDIMScheduler, LCMScheduler, FlowMatchEulerDiscreteScheduler, DPMSolverMultistepScheduler
import numpy as np
from tqdm import tqdm
import PIL
from blip3o.utils import rank0_print


def numpy_to_pil(images: np.ndarray):
    """
    Convert a NumPy array of shape (batch, height, width, channels) to a list of PIL Images.
    """
    pil_images = []
    for img in images:
        img_uint8 = (img * 255).round().astype("uint8")
        if img_uint8.shape[2] == 1:
            img_uint8 = img_uint8[..., 0]
        pil_images.append(PIL.Image.fromarray(img_uint8))
    return pil_images


class blip3oQwenConfig(Qwen3Config):
    model_type = "blip3o_qwen_grpo"


class blip3oQwenModel(blip3oMetaModel, Qwen3Model):
    config_class = blip3oQwenConfig

    def __init__(self, config: Qwen3Config):
        super(blip3oQwenModel, self).__init__(config)


class blip3oQwenForGRPOLM(Qwen3ForCausalLM, blip3oMetaForCausalLM):
    config_class = blip3oQwenConfig

    def __init__(self, config):
        Qwen3ForCausalLM.__init__(self, config)
        config.model_type = "blip3o_qwen"
        config.rope_scaling = None

        self.model = blip3oQwenModel(config)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    def get_model(self):
        return self.model

    def get_sigmas(self, timesteps, device, n_dim=4, dtype=torch.float32):
        sigmas = self.model.noise_scheduler.sigmas.to(device=device, dtype=dtype)
        schedule_timesteps = self.model.noise_scheduler.timesteps.to(device)
        timesteps = timesteps.to(device)
        step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]

        sigma = sigmas[step_indices].flatten()
        while len(sigma.shape) < n_dim:
            sigma = sigma.unsqueeze(-1)
        return sigma

    def mask_drop(self, latents, drop_prob=0.1):
        if drop_prob <= 0:
            return latents
        mask = torch.bernoulli(torch.zeros(latents.shape[0], device=latents.device, dtype=latents.dtype) + drop_prob)
        while len(mask.shape) < len(latents.shape):
            mask = mask.unsqueeze(-1)
        mask = 1 - mask  # need to flip 0 <-> 1
        return latents * mask

    @torch.no_grad()
    def generate(
        self,
        inputs: Optional[torch.Tensor] = None,
        images: Optional[torch.Tensor] = None,
        image_sizes: Optional[torch.Tensor] = None,
        modalities: Optional[List[str]] = ["image"],
        **kwargs,
    ) -> Union[GenerateOutput, torch.LongTensor]:
        position_ids = kwargs.pop("position_ids", None)
        attention_mask = kwargs.pop("attention_mask", None)
        if "inputs_embeds" in kwargs:
            raise NotImplementedError("`inputs_embeds` is not supported")

        if images is not None:
            (inputs, position_ids, attention_mask, _, inputs_embeds, _) = self.prepare_inputs_labels_for_multimodal(inputs, position_ids, attention_mask, None, None, images, modalities, image_sizes=image_sizes)
        else:
            inputs_embeds = self.get_model().embed_tokens(inputs)
        return super().generate(position_ids=position_ids, attention_mask=attention_mask, inputs_embeds=inputs_embeds, **kwargs)

    def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs):
        images = kwargs.pop("images", None)
        image_sizes = kwargs.pop("image_sizes", None)
        inputs = super().prepare_inputs_for_generation(input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, **kwargs)
        if images is not None:
            inputs["images"] = images
        if image_sizes is not None:
            inputs["image_sizes"] = image_sizes
        return inputs

    @torch.no_grad()
    def decode_latents(self, latents, normalize=True, return_tensor=False):
        if self.model.sana_vae is not None:
            latents = latents / self.model.sana_vae.config.scaling_factor
            if "shift_factor" in self.model.sana_vae.config and self.model.sana_vae.config.shift_factor is not None:
                latents = latents + self.model.sana_vae.config.shift_factor
            samples = self.model.sana_vae.decode(latents).sample
        else:
            samples = latents
        if normalize:
            samples = (samples / 2 + 0.5).clamp(0, 1)
        else:
            samples = samples.clamp(-1, 1)
        if return_tensor:
            return samples
        samples = samples.cpu().permute(0, 2, 3, 1).float().numpy()
        samples = numpy_to_pil(samples)
        return samples

    @torch.no_grad()
    def generate_images(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        max_new_tokens: Optional[torch.Tensor] = None,
        temperature: Optional[torch.Tensor] = None,
        top_p: Optional[torch.Tensor] = None,
        top_k: Optional[torch.Tensor] = None,
        images: Optional[torch.Tensor] = None,
        image_sizes: Optional[torch.Tensor] = None,
        modalities: Optional[List[str]] = ["image"],
        guidance_scale: float = 2.0,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        num_inference_steps: int = 30,
        num_images_per_prompt: int = 1,
        return_tensor=False,
        enable_progress_bar=False,
        **kwargs,
    ):
        position_ids = kwargs.pop("position_ids", None)
        # attention_mask = (inputs != -100).long()

        gen_ids = super(blip3oQwenForGRPOLM, self).generate(
            input_ids,
            max_new_tokens=max_new_tokens,
            do_sample=True,
            temperature=1.0,
            attention_mask=attention_mask,
        )

        # breakpoint()
        with torch.no_grad():
            outs = self.model(
                input_ids=gen_ids,
                output_hidden_states=True,
                return_dict=True,
            )
        hidden_states = outs.hidden_states[-1]

        start_pos = (gen_ids == self.config.image_start_tag_id).float().argmax(dim=1)
        end_pos = (gen_ids == self.config.image_end_tag_id).float().argmax(dim=1)

        selected_hidden_states = []
        for b in range(hidden_states.size(0)):
            start = start_pos[b].item() + 1
            # end = end_pos[b].item()
            selected_hidden_states.append(hidden_states[b, start:, :])
        pred_latent = torch.stack(selected_hidden_states, dim=0)

        img_hidden_states_null = torch.zeros_like(pred_latent)
        pred_latent = torch.cat([img_hidden_states_null, pred_latent], 0)
        ## sample images from here
        device = next(self.parameters()).device
        dtype = next(self.parameters()).dtype

        bsz = len(pred_latent) // 2
        # latent_size = self.config.input_size
        latent_size = 32
        latent_channels = self.model.sana.config.in_channels

        latents = randn_tensor(
            shape=(bsz * num_images_per_prompt, latent_channels, latent_size, latent_size),
            generator=None,
            device=device,
            dtype=torch.bfloat16,
        )

        # set step values
        if isinstance(self.model.noise_scheduler, FlowMatchEulerDiscreteScheduler):
            sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
            self.model.noise_scheduler.set_timesteps(num_inference_steps, sigmas=sigmas)
        else:
            self.model.noise_scheduler.set_timesteps(num_inference_steps)

        # pred_latent = torch.cat([pred_latent] * 2)
        # Convert to float32 before saving
        for t in tqdm(self.model.noise_scheduler.timesteps, desc="Sampling images", disable=not enable_progress_bar):

            latent_model_input = torch.cat([latents] * 2)
            latent_model_input = latent_model_input.to(pred_latent.dtype)

            if hasattr(self.model.noise_scheduler.timesteps, "scale_model_input"):
                latent_model_input = self.model.noise_scheduler.scale_model_input(latent_model_input, t)
            # predict noise model_output
            noise_pred = self.model.sana(
                hidden_states=latent_model_input,
                encoder_hidden_states=self.model.diffusion_connector(pred_latent),
                timestep=t.unsqueeze(0).expand(latent_model_input.shape[0]).to(latents.device),
                encoder_attention_mask=None,
            ).sample

            noise_pred_uncond, noise_pred = noise_pred.chunk(2)

            noise_pred = noise_pred_uncond + guidance_scale * (noise_pred - noise_pred_uncond)

            # compute previous image: x_t -> x_t-1
            latents = self.model.noise_scheduler.step(noise_pred, t, latents).prev_sample

        samples = self.decode_latents(latents.to(self.model.sana_vae.dtype) if self.model.sana_vae is not None else latents, return_tensor=return_tensor)

        return gen_ids, samples


AutoConfig.register("blip3o_qwen_grpo", blip3oQwenConfig)
AutoModelForCausalLM.register(blip3oQwenConfig, blip3oQwenForGRPOLM)
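
The denoising loop above batches the unconditional and conditional branches together (zeroed hidden states stacked ahead of the real ones) and then applies classifier-free guidance. A minimal sketch of just the combine step, with random stand-ins for the two branches:

import torch

guidance_scale = 2.0
noise_pred_batch = torch.randn(4, 32, 32, 32)        # [uncond | cond] stacked on dim 0
noise_pred_uncond, noise_pred_cond = noise_pred_batch.chunk(2)

guided = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)
print(guided.shape)  # torch.Size([2, 32, 32, 32])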
blip3o/model/language_model/blip3o_qwen_inference.py
ADDED
|
@@ -0,0 +1,241 @@
from typing import Dict, List, Optional, Tuple, Union

import torch
import torch.nn as nn
from transformers import (
    AutoConfig,
    AutoModelForCausalLM,
    Qwen3Config,
    Qwen3ForCausalLM,
    Qwen3Model,
)
from transformers.generation.utils import GenerateOutput
from transformers.modeling_outputs import CausalLMOutputWithPast

from blip3o.model.blip3o_arch import blip3oMetaForCausalLM, blip3oMetaModel
from diffusers.training_utils import compute_density_for_timestep_sampling, compute_loss_weighting_for_sd3
from diffusers.utils.torch_utils import randn_tensor
from diffusers.schedulers import DDPMScheduler, DDIMScheduler, LCMScheduler, FlowMatchEulerDiscreteScheduler, DPMSolverMultistepScheduler
import numpy as np
from tqdm import tqdm
import PIL


def numpy_to_pil(images: np.ndarray):
    """
    Convert a NumPy array of shape (batch, height, width, channels) to a list of PIL Images.
    """
    pil_images = []
    for img in images:
        img_uint8 = (img * 255).round().astype("uint8")
        if img_uint8.shape[2] == 1:
            img_uint8 = img_uint8[..., 0]
        pil_images.append(PIL.Image.fromarray(img_uint8))
    return pil_images


class blip3oQwenConfig(Qwen3Config):
    model_type = "blip3o_qwen_inference"


class blip3oQwenModel(blip3oMetaModel, Qwen3Model):
    config_class = blip3oQwenConfig

    def __init__(self, config: Qwen3Config):
        super(blip3oQwenModel, self).__init__(config)


class blip3oQwenForInferenceLM(Qwen3ForCausalLM, blip3oMetaForCausalLM):
    config_class = blip3oQwenConfig

    def __init__(self, config):
        Qwen3ForCausalLM.__init__(self, config)
        config.model_type = "blip3o_qwen"
        config.rope_scaling = None

        self.model = blip3oQwenModel(config)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    def get_model(self):
        return self.model

    def get_sigmas(self, timesteps, device, n_dim=4, dtype=torch.float32):
        sigmas = self.model.noise_scheduler.sigmas.to(device=device, dtype=dtype)
        schedule_timesteps = self.model.noise_scheduler.timesteps.to(device)
        timesteps = timesteps.to(device)
        step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]

        sigma = sigmas[step_indices].flatten()
        while len(sigma.shape) < n_dim:
            sigma = sigma.unsqueeze(-1)
        return sigma

    @torch.no_grad()
    def generate(
        self,
        inputs: Optional[torch.Tensor] = None,
        images: Optional[torch.Tensor] = None,
        image_sizes: Optional[torch.Tensor] = None,
        modalities: Optional[List[str]] = ["image"],
        **kwargs,
    ) -> Union[GenerateOutput, torch.LongTensor]:
        position_ids = kwargs.pop("position_ids", None)
        attention_mask = kwargs.pop("attention_mask", None)
        if "inputs_embeds" in kwargs:
            raise NotImplementedError("`inputs_embeds` is not supported")

        if images is not None:
            (inputs, position_ids, attention_mask, _, inputs_embeds, _) = self.prepare_inputs_labels_for_multimodal(inputs, position_ids, attention_mask, None, None, images, modalities, image_sizes=image_sizes)
        else:
            inputs_embeds = self.get_model().embed_tokens(inputs)
        return super().generate(position_ids=position_ids, attention_mask=attention_mask, inputs_embeds=inputs_embeds, **kwargs)

    @torch.no_grad()
    def decode_latents(self, latents, normalize=True, return_tensor=False):
        if self.model.sana_vae is not None:
            latents = latents / self.model.sana_vae.config.scaling_factor
            if "shift_factor" in self.model.sana_vae.config and self.model.sana_vae.config.shift_factor is not None:
                latents = latents + self.model.sana_vae.config.shift_factor
            samples = self.model.sana_vae.decode(latents).sample
        else:
            samples = latents
        if normalize:
            samples = (samples / 2 + 0.5).clamp(0, 1)
        else:
            samples = samples.clamp(-1, 1)
        if return_tensor:
            return samples
        samples = samples.cpu().permute(0, 2, 3, 1).float().numpy()
        samples = numpy_to_pil(samples)
        return samples

    @torch.no_grad()
    def generate_images(
        self,
        inputs: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        max_new_tokens: Optional[torch.Tensor] = None,
        temperature: Optional[torch.Tensor] = None,
        top_p: Optional[torch.Tensor] = None,
        top_k: Optional[torch.Tensor] = None,
        images: Optional[torch.Tensor] = None,
        image_sizes: Optional[torch.Tensor] = None,
        modalities: Optional[List[str]] = ["image"],
        guidance_scale: float = 2.0,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        num_inference_steps: int = 30,
        num_images_per_prompt: int = 1,
        return_tensor=False,
        enable_progress_bar=False,
        **kwargs,
    ):
        position_ids = kwargs.pop("position_ids", None)
        # attention_mask = (inputs != -100).long()

        gen_ids = super(blip3oQwenForInferenceLM, self).generate(
            inputs,
            max_new_tokens=max_new_tokens,
            do_sample=True,
            temperature=temperature,
            attention_mask=attention_mask,
            top_p=top_p,
            top_k=top_k)

        # breakpoint()
        with torch.no_grad():
            outs = self.model(
                input_ids=gen_ids,
                output_hidden_states=True,
                return_dict=True,
            )
        hidden_states = outs.hidden_states[-1]

        start_pos = (gen_ids == self.config.image_start_tag_id).float().argmax(dim=1)
        end_pos = (gen_ids == self.config.image_end_tag_id).float().argmax(dim=1)

        selected_hidden_states = []
        for b in range(hidden_states.size(0)):
            start = start_pos[b].item() + 1
            # end = end_pos[b].item()
            selected_hidden_states.append(hidden_states[b, start:, :])
        pred_latent = torch.stack(selected_hidden_states, dim=0)

        img_hidden_states_null = torch.zeros_like(pred_latent)
        pred_latent = torch.cat([img_hidden_states_null, pred_latent], 0)
        ## sample images from here
        device = next(self.parameters()).device
        dtype = next(self.parameters()).dtype

        bsz = len(pred_latent) // 2
        # latent_size = self.config.input_size
        latent_size = 32
        latent_channels = self.model.sana.config.in_channels

        latents = randn_tensor(
            shape=(bsz * num_images_per_prompt, latent_channels, latent_size, latent_size),
            generator=None,
            device=device,
            dtype=torch.bfloat16,
        )

        # set step values
        if isinstance(self.model.noise_scheduler, FlowMatchEulerDiscreteScheduler):
            sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
            self.model.noise_scheduler.set_timesteps(num_inference_steps, sigmas=sigmas)
        else:
            self.model.noise_scheduler.set_timesteps(num_inference_steps)

        # pred_latent = torch.cat([pred_latent] * 2)
        # Convert to float32 before saving
        for t in tqdm(self.model.noise_scheduler.timesteps, desc="Sampling images", disable=not enable_progress_bar):

            latent_model_input = torch.cat([latents] * 2)
            latent_model_input = latent_model_input.to(pred_latent.dtype)

            if hasattr(self.model.noise_scheduler.timesteps, "scale_model_input"):
                latent_model_input = self.model.noise_scheduler.scale_model_input(latent_model_input, t)
            # predict noise model_output
            noise_pred = self.model.sana(
                hidden_states=latent_model_input,
                encoder_hidden_states=self.model.diffusion_connector(pred_latent),
                timestep=t.unsqueeze(0).expand(latent_model_input.shape[0]).to(latents.device),
                encoder_attention_mask=None,
            ).sample

            noise_pred_uncond, noise_pred = noise_pred.chunk(2)

            noise_pred = noise_pred_uncond + guidance_scale * (noise_pred - noise_pred_uncond)

            # compute previous image: x_t -> x_t-1
            latents = self.model.noise_scheduler.step(noise_pred, t, latents).prev_sample

        samples = self.decode_latents(latents.to(self.model.sana_vae.dtype) if self.model.sana_vae is not None else latents, return_tensor=return_tensor)

        return gen_ids, samples


AutoConfig.register("blip3o_qwen_inference", blip3oQwenConfig)
AutoModelForCausalLM.register(blip3oQwenConfig, blip3oQwenForInferenceLM)
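
A hypothetical end-to-end sketch of calling generate_images; the checkpoint path, prompt format, and sampling settings are assumptions for illustration, not part of this file:

import torch
from transformers import AutoTokenizer
from blip3o.model.language_model.blip3o_qwen_inference import blip3oQwenForInferenceLM

model_path = "path/to/blip3o-checkpoint"   # placeholder
tokenizer = AutoTokenizer.from_pretrained(model_path)
model = blip3oQwenForInferenceLM.from_pretrained(model_path, torch_dtype=torch.bfloat16, device_map="cuda")

prompt = "Please generate an image of a red bicycle."   # assumed plain prompt, no chat template
input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(model.device)

gen_ids, images = model.generate_images(
    inputs=input_ids,
    attention_mask=torch.ones_like(input_ids),
    max_new_tokens=1024,
    temperature=1.0,
    top_p=0.95,
    top_k=50,
    guidance_scale=2.0,
    num_inference_steps=30,
)
images[0].save("sample.png")   # decode_latents returns PIL images by default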
blip3o/model/multimodal_decoder/__pycache__/builder.cpython-310.pyc
ADDED
Binary file (661 Bytes).

blip3o/model/multimodal_decoder/__pycache__/builder.cpython-311.pyc
ADDED
Binary file (954 Bytes).

blip3o/model/multimodal_decoder/__pycache__/ta_tok_encoder.cpython-310.pyc
ADDED
Binary file (3.71 kB).

blip3o/model/multimodal_decoder/__pycache__/ta_tok_encoder.cpython-311.pyc
ADDED
Binary file (6.73 kB).
blip3o/model/multimodal_decoder/builder.py
ADDED
|
@@ -0,0 +1,14 @@
from diffusers import AutoencoderDC, SanaTransformer2DModel
import torch


def build_sana(vision_tower_cfg, **kwargs):
    sana = SanaTransformer2DModel.from_pretrained(vision_tower_cfg.diffusion_name_or_path, subfolder="transformer", torch_dtype=torch.bfloat16)
    return sana


def build_vae(vision_tower_cfg, **kwargs):
    vae = AutoencoderDC.from_pretrained(vision_tower_cfg.diffusion_name_or_path, subfolder="vae", torch_dtype=torch.bfloat16)
    return vae
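
A sketch of how these two builders might be called; the config object and the SANA repository id are illustrative assumptions, since the builders only read `diffusion_name_or_path`:

from types import SimpleNamespace
from blip3o.model.multimodal_decoder.builder import build_sana, build_vae

cfg = SimpleNamespace(diffusion_name_or_path="Efficient-Large-Model/Sana_600M_512px_diffusers")  # assumed repo id
sana = build_sana(cfg)   # SanaTransformer2DModel loaded from the "transformer" subfolder
vae = build_vae(cfg)     # AutoencoderDC loaded from the "vae" subfolder
print(sana.config.in_channels, vae.config.scaling_factor)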
blip3o/model/multimodal_encoder/__pycache__/builder.cpython-310.pyc
ADDED
Binary file (472 Bytes).

blip3o/model/multimodal_encoder/__pycache__/builder.cpython-311.pyc
ADDED
Binary file (639 Bytes).

blip3o/model/multimodal_encoder/__pycache__/ta_tok_encoder.cpython-310.pyc
ADDED
Binary file (3.72 kB).

blip3o/model/multimodal_encoder/__pycache__/ta_tok_encoder.cpython-311.pyc
ADDED
Binary file (6.74 kB).