jiuhai committed on
Commit 6858cdd · verified · 1 Parent(s): 6b139fc

Upload 59 files

This view is limited to 50 files because it contains too many changes.
Files changed (50)
  1. .gitattributes +1 -0
  2. blip3o/__init__.py +0 -0
  3. blip3o/__pycache__/__init__.cpython-310.pyc +0 -0
  4. blip3o/__pycache__/__init__.cpython-311.pyc +0 -0
  5. blip3o/__pycache__/constants.cpython-310.pyc +0 -0
  6. blip3o/__pycache__/constants.cpython-311.pyc +0 -0
  7. blip3o/__pycache__/utils.cpython-310.pyc +0 -0
  8. blip3o/__pycache__/utils.cpython-311.pyc +0 -0
  9. blip3o/constants.py +7 -0
  10. blip3o/conversation.py +296 -0
  11. blip3o/data/__init__.py +1 -0
  12. blip3o/data/__pycache__/__init__.cpython-310.pyc +0 -0
  13. blip3o/data/__pycache__/__init__.cpython-311.pyc +0 -0
  14. blip3o/data/__pycache__/dataset.cpython-310.pyc +0 -0
  15. blip3o/data/__pycache__/dataset.cpython-311.pyc +0 -0
  16. blip3o/data/dataset.py +371 -0
  17. blip3o/mm_utils.py +65 -0
  18. blip3o/model/__init__.py +3 -0
  19. blip3o/model/__pycache__/__init__.cpython-310.pyc +0 -0
  20. blip3o/model/__pycache__/__init__.cpython-311.pyc +0 -0
  21. blip3o/model/__pycache__/blip3o_arch.cpython-310.pyc +0 -0
  22. blip3o/model/__pycache__/blip3o_arch.cpython-311.pyc +0 -0
  23. blip3o/model/__pycache__/llava_arch.cpython-310.pyc +0 -0
  24. blip3o/model/__pycache__/llava_arch.cpython-311.pyc +0 -0
  25. blip3o/model/blip3o_arch.py +400 -0
  26. blip3o/model/builder.py +44 -0
  27. blip3o/model/language_model/__pycache__/blip3o_qwen.cpython-310.pyc +0 -0
  28. blip3o/model/language_model/__pycache__/blip3o_qwen.cpython-311.pyc +0 -0
  29. blip3o/model/language_model/__pycache__/blip3o_qwen_grpo.cpython-310.pyc +0 -0
  30. blip3o/model/language_model/__pycache__/blip3o_qwen_grpo.cpython-311.pyc +0 -0
  31. blip3o/model/language_model/__pycache__/blip3o_qwen_inference.cpython-310.pyc +0 -0
  32. blip3o/model/language_model/__pycache__/blip3o_qwen_inference.cpython-311.pyc +0 -0
  33. blip3o/model/language_model/__pycache__/llava_qwen.cpython-310.pyc +0 -0
  34. blip3o/model/language_model/__pycache__/llava_qwen.cpython-311.pyc +0 -0
  35. blip3o/model/language_model/__pycache__/llava_qwen_grpo.cpython-310.pyc +0 -0
  36. blip3o/model/language_model/__pycache__/llava_qwen_grpo.cpython-311.pyc +0 -0
  37. blip3o/model/language_model/__pycache__/llava_qwen_inference.cpython-310.pyc +0 -0
  38. blip3o/model/language_model/__pycache__/llava_qwen_inference.cpython-311.pyc +0 -0
  39. blip3o/model/language_model/blip3o_qwen.py +215 -0
  40. blip3o/model/language_model/blip3o_qwen_grpo.py +255 -0
  41. blip3o/model/language_model/blip3o_qwen_inference.py +241 -0
  42. blip3o/model/multimodal_decoder/__pycache__/builder.cpython-310.pyc +0 -0
  43. blip3o/model/multimodal_decoder/__pycache__/builder.cpython-311.pyc +0 -0
  44. blip3o/model/multimodal_decoder/__pycache__/ta_tok_encoder.cpython-310.pyc +0 -0
  45. blip3o/model/multimodal_decoder/__pycache__/ta_tok_encoder.cpython-311.pyc +0 -0
  46. blip3o/model/multimodal_decoder/builder.py +14 -0
  47. blip3o/model/multimodal_encoder/__pycache__/builder.cpython-310.pyc +0 -0
  48. blip3o/model/multimodal_encoder/__pycache__/builder.cpython-311.pyc +0 -0
  49. blip3o/model/multimodal_encoder/__pycache__/ta_tok_encoder.cpython-310.pyc +0 -0
  50. blip3o/model/multimodal_encoder/__pycache__/ta_tok_encoder.cpython-311.pyc +0 -0
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+blip3o/train/__pycache__/grpo_trainer.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
blip3o/__init__.py ADDED
File without changes
blip3o/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (145 Bytes)
blip3o/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (156 Bytes)
blip3o/__pycache__/constants.cpython-310.pyc ADDED
Binary file (357 Bytes)
blip3o/__pycache__/constants.cpython-311.pyc ADDED
Binary file (389 Bytes)
blip3o/__pycache__/utils.cpython-310.pyc ADDED
Binary file (610 Bytes)
blip3o/__pycache__/utils.cpython-311.pyc ADDED
Binary file (1.02 kB)
blip3o/constants.py ADDED
@@ -0,0 +1,7 @@
+# Model Constants
+IGNORE_INDEX = -100
+IMAGE_TOKEN_INDEX = -200
+DEFAULT_IMAGE_TOKEN = "<image>"
+DEFAULT_IMAGE_PATCH_TOKEN = "<im_patch>"
+DEFAULT_IM_START_TOKEN = "<im_start>"
+DEFAULT_IM_END_TOKEN = "<im_end>"
blip3o/conversation.py ADDED
@@ -0,0 +1,296 @@
+import base64
+import dataclasses
+import re
+from enum import Enum, auto
+from io import BytesIO
+from typing import Any, Dict, List, Tuple, Union
+
+from PIL import Image
+
+
+class SeparatorStyle(Enum):
+    """Different separator style."""
+
+    SINGLE = auto()
+    TWO = auto()
+    PLAIN = auto()
+    CHATML = auto()
+    QWEN = auto()
+
+
+@dataclasses.dataclass
+class Conversation:
+    """A class that keeps all conversation history."""
+
+    system: str
+    roles: List[str]
+    messages: List[List[str]]
+    offset: int
+    sep_style: SeparatorStyle = SeparatorStyle.SINGLE
+    sep: str = "###"
+    sep2: str = None
+    version: str = "Unknown"
+
+    tokenizer_id: str = ""
+    tokenizer: Any = None
+    # Stop criteria (the default one is EOS token)
+    stop_str: Union[str, List[str]] = None
+    # Stops generation if meeting any token in this list
+    stop_token_ids: List[int] = None
+
+    skip_next: bool = False
+
+    def get_prompt(self):
+        messages = self.messages
+        if len(messages) > 0 and type(messages[0][1]) is tuple:
+            messages = self.messages.copy()
+            init_role, init_msg = messages[0].copy()
+            init_msg = init_msg[0]
+            if "mmtag" in self.version:
+                init_msg = init_msg.replace("<image>", "").strip()
+                messages[0] = (init_role, init_msg)
+                messages.insert(0, (self.roles[0], "<Image><image></Image>"))
+                messages.insert(1, (self.roles[1], "Received."))
+            elif not init_msg.startswith("<image>"):
+                init_msg = init_msg.replace("<image>", "").strip()
+                messages[0] = (init_role, "<image>\n" + init_msg)
+            else:
+                messages[0] = (init_role, init_msg)
+
+        if self.sep_style == SeparatorStyle.SINGLE:
+            ret = self.system + self.sep
+            for role, message in messages:
+                if message:
+                    if type(message) is tuple:
+                        message, _, _ = message
+                    ret += role + ": " + message + self.sep
+                else:
+                    ret += role + ":"
+
+        elif self.sep_style == SeparatorStyle.TWO:
+            seps = [self.sep, self.sep2]
+            ret = self.system + seps[0]
+            for i, (role, message) in enumerate(messages):
+                if message:
+                    if type(message) is tuple:
+                        message, _, _ = message
+                    ret += role + ": " + message + seps[i % 2]
+                else:
+                    ret += role + ":"
+
+        elif self.sep_style == SeparatorStyle.CHATML:
+            ret = "" if self.system == "" else self.system + self.sep + "\n"
+            for role, message in messages:
+                if message:
+                    if type(message) is tuple:
+                        message, images, _ = message
+                        message = "<image>" * len(images) + message
+                    ret += role + "\n" + message + self.sep + "\n"
+                else:
+                    ret += role + "\n"
+            return ret
+
+        elif self.sep_style == SeparatorStyle.PLAIN:
+            seps = [self.sep, self.sep2]
+            ret = self.system
+            for i, (role, message) in enumerate(messages):
+                if message:
+                    if type(message) is tuple:
+                        message, _, _ = message
+                    ret += message + seps[i % 2]
+                else:
+                    ret += ""
+        else:
+            raise ValueError(f"Invalid style: {self.sep_style}")
+
+        return ret
+
+    def append_message(self, role, message):
+        self.messages.append([role, message])
+
+    def process_image(self, image, image_process_mode, return_pil=False, image_format="PNG"):
+        if image_process_mode == "Pad":
+
+            def expand2square(pil_img, background_color=(122, 116, 104)):
+                width, height = pil_img.size
+                if width == height:
+                    return pil_img
+                elif width > height:
+                    result = Image.new(pil_img.mode, (width, width), background_color)
+                    result.paste(pil_img, (0, (width - height) // 2))
+                    return result
+                else:
+                    result = Image.new(pil_img.mode, (height, height), background_color)
+                    result.paste(pil_img, ((height - width) // 2, 0))
+                    return result
+
+            image = expand2square(image)
+        elif image_process_mode in ["Default", "Crop"]:
+            pass
+        elif image_process_mode == "Resize":
+            image = image.resize((336, 336))
+        else:
+            raise ValueError(f"Invalid image_process_mode: {image_process_mode}")
+
+        if type(image) is not Image.Image:
+            image = Image.open(image).convert("RGB")
+
+        max_hw, min_hw = max(image.size), min(image.size)
+        aspect_ratio = max_hw / min_hw
+        max_len, min_len = 672, 448
+        shortest_edge = int(min(max_len / aspect_ratio, min_len, min_hw))
+        longest_edge = int(shortest_edge * aspect_ratio)
+        W, H = image.size
+        if H > W:
+            H, W = longest_edge, shortest_edge
+        else:
+            H, W = shortest_edge, longest_edge
+        image = image.resize((W, H))
+        if return_pil:
+            return image
+        else:
+            buffered = BytesIO()
+            image.save(buffered, format=image_format)
+            img_b64_str = base64.b64encode(buffered.getvalue()).decode()
+            return img_b64_str
+
+    def get_images(self, return_pil=False, return_path=False):
+        images = []
+        for i, (role, msg) in enumerate(self.messages[self.offset :]):
+            if i % 2 == 0:
+                if type(msg) is tuple:
+                    msg, image, image_process_mode = msg
+                    if type(image) != list:
+                        image = [image]
+                    for img in image:
+                        if not return_path and self.is_image_file(img):
+                            img = self.process_image(img, image_process_mode, return_pil=return_pil)
+                        else:
+                            images.append(img)
+        return images
+
+    def is_image_file(self, filename):
+        image_extensions = [".png", ".jpg", ".jpeg", ".gif", ".bmp", ".tiff", ".webp"]
+        return any(filename.lower().endswith(ext) for ext in image_extensions)
+
+    def is_video_file(self, filename):
+        video_extensions = [".mp4", ".mov", ".avi", ".mkv", ".wmv", ".flv", ".mpeg", ".mpg"]
+        return any(filename.lower().endswith(ext) for ext in video_extensions)
+
+    def to_gradio_chatbot(self):
+        ret = []
+        for i, (role, msg) in enumerate(self.messages[self.offset :]):
+            if i % 2 == 0:
+                if type(msg) is tuple:
+                    msg, image, image_process_mode = msg
+                    if type(image) != list:
+                        image = [image]
+                    if len(image) == 1:
+                        msg = "<image>\n" + msg.replace("<image>", "").strip()
+                    else:
+                        msg = re.sub(r"(<image>)\n(?=<image>)", r"\1 ", msg)
+
+                    img_str_list = []
+                    for img in image:
+                        if self.is_image_file(img):
+                            img_b64_str = self.process_image(img, "Default", return_pil=False, image_format="JPEG")
+                            img_str = f'<img src="data:image/jpeg;base64,{img_b64_str}" style="max-width: 256px; max-height: 256px; width: auto; height: auto; object-fit: contain;"/>'
+                            img_str_list.append(img_str)
+                        elif self.is_video_file(img):
+                            ret.append(((img,), None))
+
+                    msg = msg.strip()
+                    img_place_holder = ""
+                    for img_str in img_str_list:
+                        img_place_holder += f"{img_str}\n\n"
+
+                    if len(img_str_list) > 0:
+                        msg = f"{img_place_holder}\n\n{msg}"
+
+                    if len(msg) > 0:
+                        ret.append([msg, None])
+                else:
+                    ret.append([msg, None])
+            else:
+                ret[-1][-1] = msg
+        return ret
+
+    def copy(self):
+        return Conversation(system=self.system, roles=self.roles, messages=[[x, y] for x, y in self.messages], offset=self.offset, sep_style=self.sep_style, sep=self.sep, sep2=self.sep2, version=self.version)
+
+    def dict(self):
+        if len(self.get_images()) > 0:
+            return {
+                "system": self.system,
+                "roles": self.roles,
+                "messages": [[x, y[0] if type(y) is tuple else y] for x, y in self.messages],
+                "offset": self.offset,
+                "sep": self.sep,
+                "sep2": self.sep2,
+            }
+        return {
+            "system": self.system,
+            "roles": self.roles,
+            "messages": self.messages,
+            "offset": self.offset,
+            "sep": self.sep,
+            "sep2": self.sep2,
+        }
+
+
+conv_vicuna_v0 = Conversation(
+    system="A chat between a curious human and an artificial intelligence assistant. " "The assistant gives helpful, detailed, and polite answers to the human's questions.",
+    roles=("Human", "Assistant"),
+    messages=[
+        ["Human", "What are the key differences between renewable and non-renewable energy sources?"],
+        [
+            "Assistant",
+            "Renewable energy sources are those that can be replenished naturally in a relatively "
+            "short amount of time, such as solar, wind, hydro, geothermal, and biomass. "
+            "Non-renewable energy sources, on the other hand, are finite and will eventually be "
+            "depleted, such as coal, oil, and natural gas. Here are some key differences between "
+            "renewable and non-renewable energy sources:\n"
+            "1. Availability: Renewable energy sources are virtually inexhaustible, while non-renewable "
+            "energy sources are finite and will eventually run out.\n"
+            "2. Environmental impact: Renewable energy sources have a much lower environmental impact "
+            "than non-renewable sources, which can lead to air and water pollution, greenhouse gas emissions, "
+            "and other negative effects.\n"
+            "3. Cost: Renewable energy sources can be more expensive to initially set up, but they typically "
+            "have lower operational costs than non-renewable sources.\n"
+            "4. Reliability: Renewable energy sources are often more reliable and can be used in more remote "
+            "locations than non-renewable sources.\n"
+            "5. Flexibility: Renewable energy sources are often more flexible and can be adapted to different "
+            "situations and needs, while non-renewable sources are more rigid and inflexible.\n"
+            "6. Sustainability: Renewable energy sources are more sustainable over the long term, while "
+            "non-renewable sources are not, and their depletion can lead to economic and social instability.\n",
+        ],
+    ],
+    offset=2,
+    sep_style=SeparatorStyle.SINGLE,
+    sep="###",
+)
+
+
+conv_qwen = Conversation(
+    system="""<|im_start|>system
+You are a helpful assistant.""",
+    roles=("<|im_start|>user", "<|im_start|>assistant"),
+    version="qwen",
+    messages=[],
+    offset=0,
+    sep_style=SeparatorStyle.CHATML,
+    sep="<|im_end|>",
+)
+
+
+default_conversation = conv_vicuna_v0
+conv_templates = {
+    "default": conv_vicuna_v0,
+    "v0": conv_vicuna_v0,
+    "qwen_1_5": conv_qwen,
+    "qwen_2": conv_qwen,
+}
+
+
+if __name__ == "__main__":
+    print(default_conversation.get_prompt())
blip3o/data/__init__.py ADDED
@@ -0,0 +1 @@
+from .dataset import *
blip3o/data/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (174 Bytes)
blip3o/data/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (193 Bytes)
blip3o/data/__pycache__/dataset.cpython-310.pyc ADDED
Binary file (11.4 kB)
blip3o/data/__pycache__/dataset.cpython-311.pyc ADDED
Binary file (20.9 kB)
blip3o/data/dataset.py ADDED
@@ -0,0 +1,371 @@
+import copy
+import glob
+import io
+import json
+import math
+import os
+import random
+import re
+from dataclasses import dataclass
+from typing import Dict, List, Optional, Sequence
+import pyarrow.parquet as pq
+import torch
+import transformers
+import yaml
+from PIL import Image, ImageFile
+from torch.utils.data import Dataset
+from torchvision.transforms import v2
+from torchvision import transforms
+from datasets import load_dataset, concatenate_datasets
+from blip3o.constants import (
+    DEFAULT_IM_END_TOKEN,
+    DEFAULT_IM_START_TOKEN,
+    DEFAULT_IMAGE_TOKEN,
+    IGNORE_INDEX,
+    IMAGE_TOKEN_INDEX,
+)
+from blip3o.utils import rank0_print
+
+
+ImageFile.LOAD_TRUNCATED_IMAGES = True
+
+
+## target transform for sana
+target_transform = v2.Compose(
+    [
+        v2.Resize(1024),
+        v2.CenterCrop(1024),
+        v2.ToImage(),
+        v2.ToDtype(torch.float32, scale=True),
+        v2.Normalize([0.5], [0.5]),
+    ]
+)
+
+
+def expand2square(pil_img, background_color):
+    width, height = pil_img.size
+    if width == height:
+        return pil_img
+    elif width > height:
+        result = Image.new(pil_img.mode, (width, width), background_color)
+        result.paste(pil_img, (0, (width - height) // 2))
+        return result
+    else:
+        result = Image.new(pil_img.mode, (height, height), background_color)
+        result.paste(pil_img, ((height - width) // 2, 0))
+        return result
+
+
+def preprocess_multimodal(sources: Sequence[str], data_args) -> Dict:
+    is_multimodal = data_args.is_multimodal
+    if not is_multimodal:
+        return sources
+
+    for source in sources:
+        for sentence in source:
+            replace_token = DEFAULT_IMAGE_TOKEN
+            # NOTE: only add im_start_end when image generation
+            if data_args.mm_use_im_start_end and sentence['from'] == 'gpt':
+                replace_token = DEFAULT_IM_START_TOKEN + replace_token + DEFAULT_IM_END_TOKEN
+            sentence["value"] = sentence["value"].replace(DEFAULT_IMAGE_TOKEN, replace_token)
+
+            # For videoInstruct-100k noisy_data. TODO: Ask Yuanhan to clean the data instead of leaving the noise code here.
+            sentence["value"] = sentence["value"].replace("QA_GT_caption_based_noisy", "")
+
+    return sources
+
+
+def preprocess_qwen(sources, tokenizer: transformers.PreTrainedTokenizer, has_image: bool = False, max_len=2048, system_message: str = "You are a helpful assistant.") -> Dict:
+    # roles = {"human": "<|im_start|>user", "gpt": "<|im_start|>assistant"}
+    roles = {"human": "user", "gpt": "assistant"}
+
+    # tokenizer = copy.deepcopy(tokenizer)
+    # When there is actually an image, we add the image tokens as a special token
+    if 'image_token_index' not in globals():
+        tokenizer.add_tokens(["<image>"], special_tokens=True)
+        global image_token_index
+        image_token_index = tokenizer.convert_tokens_to_ids("<image>")
+    # if has_image:
+    #     tokenizer.add_tokens(["<image>"], special_tokens=True)
+
+    # image_token_index = tokenizer.convert_tokens_to_ids("<image>")
+    im_start, im_end = tokenizer.additional_special_tokens_ids[:2]
+    # unmask_tokens = ["<|im_start|>", "<|im_start|>", "\n"]
+    unmask_tokens_idx = [198, im_start, im_end]
+    # nl_tokens = tokenizer("\n").input_ids
+
+    # Reset Qwen chat templates so that it won't include system message every time we apply
+    chat_template = "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}"
+    tokenizer.chat_template = chat_template
+
+    # _system = tokenizer("system").input_ids + nl_tokens
+    # _user = tokenizer("user").input_ids + nl_tokens
+    # _assistant = tokenizer("assistant").input_ids + nl_tokens
+
+    # Apply prompt templates
+    input_ids, targets = [], []
+    for i, source in enumerate(sources):
+        if roles[source[0]["from"]] != roles["human"]:
+            source = source[1:]
+
+        input_id, target = [], []
+
+        # New version, use apply chat template
+        # Build system message for each sentence
+        input_id += tokenizer.apply_chat_template([{"role" : "system", "content" : system_message}])
+
+        # target += [IGNORE_INDEX] * len(input_id)
+        target += input_id
+
+        for conv in source:
+            # Make sure blip3o data can load
+            try:
+                role = conv["role"]
+                content = conv["content"]
+            except:
+                role = conv["from"]
+                content = conv["value"]
+
+            role = roles.get(role, role)
+
+            conv = [{"role" : role, "content" : content}]
+            encode_id = tokenizer.apply_chat_template(conv)
+            input_id += encode_id
+            if role in ["user", "system"]:
+                # target += [IGNORE_INDEX] * len(encode_id)
+                target += encode_id
+            else:
+                target += encode_id
+
+        assert len(input_id) == len(target), f"{len(input_id)} != {len(target)}"
+        for idx, encode_id in enumerate(input_id):
+            if encode_id in unmask_tokens_idx:
+                target[idx] = encode_id
+            if encode_id == image_token_index:
+                input_id[idx] = IMAGE_TOKEN_INDEX
+        input_ids.append(input_id)
+        targets.append(target)
+    input_ids = torch.tensor(input_ids, dtype=torch.long)
+    targets = torch.tensor(targets, dtype=torch.long)
+
+    return dict(
+        input_ids=input_ids,
+        labels=targets,
+    )
+
+
+class LazySupervisedMixDataset(Dataset):
+    """Dataset for supervised fine-tuning."""
+
+    def __init__(
+        self,
+        tokenizer: transformers.PreTrainedTokenizer,
+        data_path: str,
+        data_args
+    ):
+        super(LazySupervisedMixDataset, self).__init__()
+
+        self.data_args = data_args
+        list_data_dict = []
+
+        data_files = glob.glob('/fsx/sfr/data/jiuhai/hub/datasets--BLIP3o--BLIP3o-60k/snapshots/f7316b0aa446338ee1707484924aa59457b4bbf3/*.tar')
+        data_files.sort()
+        train_dataset = load_dataset("webdataset", data_files=data_files, split="train", num_proc=1, cache_dir='/fsx/sfr/data/jiuhai/webdataset')
+        train_dataset = train_dataset.rename_column("jpg", "image")
+        train_dataset = train_dataset.add_column('type', len(train_dataset) * ['T2I'])
+        train_dataset = train_dataset.remove_columns([col for col in train_dataset.column_names if not col in (
+            ["image", "txt", "type"])])
+        print(f"finish loading image {len(train_dataset)}")
+        list_data_dict.append(train_dataset)
+
+        if len(list_data_dict) > 1:
+            list_data_dict = concatenate_datasets(list_data_dict)
+        else:
+            list_data_dict = list_data_dict[0]
+        list_data_dict = list_data_dict.shuffle(seed=42)
+
+        rank0_print(f"Total number of training instances: {len(list_data_dict)}")
+        self.tokenizer = tokenizer
+        self.list_data_dict = list_data_dict
+        self.modality = torch.tensor(0)  # 0 is for und task, 1 is for gen task
+
+    def __len__(self):
+        return len(self.list_data_dict)
+
+    def process_image(self, image):
+        processor = self.data_args.image_processor
+        image_size = image.size
+        image = processor.preprocess(image, return_tensors="pt")["pixel_values"][0]
+        return image, image_size, self.modality
+
+    def process_target_image(self, image):
+        image = target_transform(image)
+        return image
+
+    @property
+    def lengths(self):
+        length_list = []
+        for sample in self.list_data_dict:
+            img_tokens = 128 if "image" in sample else 0
+            length_list.append(sum(len(conv["value"].split()) for conv in sample["conversations"]) + img_tokens)
+        return length_list
+
+    @property
+    def modality_lengths(self):
+        length_list = []
+        for sample in self.list_data_dict:
+            cur_len = sum(len(conv["value"].split()) for conv in sample["conversations"])
+            cur_len = cur_len if "image" in sample else -cur_len
+            length_list.append(cur_len)
+        return length_list
+
+    def __getitem__(self, i) -> Dict[str, torch.Tensor]:
+
+        while True:
+            sources = self.list_data_dict[i]
+
+            if sources["type"] == "T2I":
+                sources["conversations"] = [
+                    {"from": "human", "value": f"Please generate image based on the following caption: {sources['txt']}"},
+                    {"from": "gpt", "value": "<image>"},
+                ]
+            elif sources["type"] == "I2I":
+                sources["conversations"] = [
+                    {
+                        "from": "human",
+                        "value": f"<image>\nPlease reconstruct the given image.",
+                    },
+                    {"from": "gpt", "value": ""},
+                ]
+            else:
+                raise ValueError("Unknown source type. Please check the 'type' in 'sources'.")
+
+            if "image" in sources:
+                if sources["type"] == "T2I" or sources["type"] == "I2I":
+                    image_files = self.list_data_dict[i]["image"]
+
+                    if not isinstance(image_files, list):
+                        image_files = [image_files]
+
+                    images = []
+
+                    for img in image_files:
+                        try:
+                            if sources["type"] == "T2I" or sources["type"] == "I2I":
+                                img = img.convert("RGB")
+                            else:
+                                raise ValueError("Unknown source type. Please check the 'type' in 'sources'.")
+                            images.append(img)
+                        except Exception as e:
+                            print(f"Error opening image {img}: {e}")
+                            images = None
+                            break  # Skip to the next image if there's an error
+
+                ## test if can apply img_process
+                if images is not None:
+                    try:
+                        process_images = [self.process_image(f) for f in images]
+                    except Exception as e:
+                        print(f"Error wrong number of channels: {e}")
+                        images = None
+
+                # If no valid images were found, randomly pick another item
+                if images is None:
+                    print(sources)
+                    print(f"warning false image!!!!!!")
+                    i = random.randint(0, len(self.list_data_dict) - 1)
+                    continue
+
+                sources = preprocess_multimodal(copy.deepcopy([sources["conversations"]]), self.data_args)
+            else:
+                sources = copy.deepcopy([sources["conversations"]])
+
+            data_dict = preprocess_qwen(sources, self.tokenizer, has_image=("image" in self.list_data_dict[i]))
+            if isinstance(i, int):
+                data_dict = dict(input_ids=data_dict["input_ids"][0], labels=data_dict["labels"][0])
+
+            # image exist in the data
+            if "image" in self.list_data_dict[i]:
+                data_dict["image"] = process_images
+                data_dict["target_image"] = [self.process_target_image(f) for f in images]
+
+            data_dict["ids"] = self.list_data_dict[i]["id"] if "id" in self.list_data_dict[i] else "unk"
+            return data_dict
+
+
+@dataclass
+class DataCollatorForSupervisedDataset(object):
+    """Collate examples for supervised fine-tuning."""
+
+    tokenizer: transformers.PreTrainedTokenizer
+
+    def pad_sequence(self, input_ids, batch_first, padding_value):
+        if self.tokenizer.padding_side == "left":
+            input_ids = [torch.flip(_input_ids, [0]) for _input_ids in input_ids]
+        input_ids = torch.nn.utils.rnn.pad_sequence(input_ids, batch_first=batch_first, padding_value=padding_value)
+        if self.tokenizer.padding_side == "left":
+            input_ids = torch.flip(input_ids, [1])
+        return input_ids
+
+    def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
+        input_ids, labels = tuple([instance[key] for instance in instances] for key in ("input_ids", "labels"))
+        input_ids = [_input_ids[: self.tokenizer.model_max_length] for _input_ids in input_ids]
+        labels = [_labels[: self.tokenizer.model_max_length] for _labels in labels]
+        if self.tokenizer.pad_token_id is None:
+            self.tokenizer.pad_token_id = 0  # This gets the best result. Don't know why.
+        input_ids = self.pad_sequence(input_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id)
+        labels = self.pad_sequence(labels, batch_first=True, padding_value=IGNORE_INDEX)
+        batch = dict(input_ids=input_ids, labels=labels.long() if labels.dtype == torch.int32 else labels, attention_mask=input_ids.ne(self.tokenizer.pad_token_id))
+        if "image" in instances[0]:
+            images = [instance["image"] for instance in instances]
+
+            batch["image_sizes"] = [im[1] for im_list in images for im in im_list]
+            batch["modalities"] = [im[2] for im_list in images for im in im_list]
+            images = [im[0] for im_list in images for im in im_list]
+
+            batch["images"] = images
+
+            target_images = [instance["target_image"][0] for instance in instances]
+            target_images = torch.stack(target_images, dim=0) if target_images else None
+            batch["target_images"] = target_images
+
+        if "prompt" in instances[0]:
+            batch["prompts"] = [instance["prompt"] for instance in instances]
+        return batch
+
+
+def get_dataset_cls(name):
+    if name == 'mix':
+        dataset_cls = LazySupervisedMixDataset
+    else:
+        raise ValueError(f'Unknown dataset class {name}')
+    return dataset_cls
+
+
+def make_supervised_data_module(tokenizer: transformers.PreTrainedTokenizer, data_args) -> Dict:
+    """Make dataset and collator for supervised fine-tuning."""
+    dataset_cls = get_dataset_cls(data_args.dataset_cls)
+    train_dataset = dataset_cls(tokenizer=tokenizer, data_path=data_args.data_path, data_args=data_args)
+    data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer)
+    return dict(train_dataset=train_dataset, eval_dataset=None, data_collator=data_collator)
blip3o/mm_utils.py ADDED
@@ -0,0 +1,65 @@
+import torch
+from transformers import StoppingCriteria
+
+from blip3o.constants import IMAGE_TOKEN_INDEX
+
+
+def process_images(images, image_processor, model_cfg):
+    return image_processor.preprocess(images, return_tensors="pt")["pixel_values"]
+
+
+def tokenizer_image_token(prompt, tokenizer, image_token_index=IMAGE_TOKEN_INDEX, return_tensors=None):
+    prompt_chunks = [tokenizer(chunk).input_ids for chunk in prompt.split("<image>")]
+
+    def insert_separator(X, sep):
+        return [ele for sublist in zip(X, [sep] * len(X)) for ele in sublist][:-1]
+
+    input_ids = []
+    offset = 0
+    if len(prompt_chunks) > 0 and len(prompt_chunks[0]) > 0 and prompt_chunks[0][0] == tokenizer.bos_token_id:
+        offset = 1
+        input_ids.append(prompt_chunks[0][0])
+
+    for x in insert_separator(prompt_chunks, [image_token_index] * (offset + 1)):
+        input_ids.extend(x[offset:])
+
+    if return_tensors is not None:
+        if return_tensors == "pt":
+            return torch.tensor(input_ids, dtype=torch.long)
+        raise ValueError(f"Unsupported tensor type: {return_tensors}")
+    return input_ids
+
+
+def get_model_name_from_path(model_path):
+    model_path = model_path.strip("/")
+    model_paths = model_path.split("/")
+    if model_paths[-1].startswith("checkpoint-"):
+        return model_paths[-2] + "_" + model_paths[-1]
+    else:
+        return model_paths[-1]
+
+
+class KeywordsStoppingCriteria(StoppingCriteria):
+    def __init__(self, keywords, tokenizer, input_ids):
+        self.keywords = keywords
+        self.keyword_ids = []
+        for keyword in keywords:
+            cur_keyword_ids = tokenizer(keyword).input_ids
+            if len(cur_keyword_ids) > 1 and cur_keyword_ids[0] == tokenizer.bos_token_id:
+                cur_keyword_ids = cur_keyword_ids[1:]
+            self.keyword_ids.append(torch.tensor(cur_keyword_ids))
+        self.tokenizer = tokenizer
+        self.start_len = input_ids.shape[1]
+
+    def __call__(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
+        assert output_ids.shape[0] == 1, "Only support batch size 1 (yet)"  # TODO
+        offset = min(output_ids.shape[1] - self.start_len, 3)
+        self.keyword_ids = [keyword_id.to(output_ids.device) for keyword_id in self.keyword_ids]
+        for keyword_id in self.keyword_ids:
+            # compare the trailing tokens element-wise; a bare tensor comparison would raise for multi-token keywords
+            if (output_ids[0, -keyword_id.shape[0] :] == keyword_id).all():
+                return True
+        outputs = self.tokenizer.batch_decode(output_ids[:, -offset:], skip_special_tokens=True)[0]
+        for keyword in self.keywords:
+            if keyword in outputs:
+                return True
+        return False
blip3o/model/__init__.py ADDED
@@ -0,0 +1,3 @@
+from blip3o.model.language_model.blip3o_qwen import blip3oQwenConfig, blip3oQwenForCausalLM
+from blip3o.model.language_model.blip3o_qwen_inference import blip3oQwenForInferenceLM
+from blip3o.model.language_model.blip3o_qwen_grpo import blip3oQwenForGRPOLM
blip3o/model/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (449 Bytes)
blip3o/model/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (512 Bytes)
blip3o/model/__pycache__/blip3o_arch.cpython-310.pyc ADDED
Binary file (11.5 kB)
blip3o/model/__pycache__/blip3o_arch.cpython-311.pyc ADDED
Binary file (26.1 kB)
blip3o/model/__pycache__/llava_arch.cpython-310.pyc ADDED
Binary file (11.5 kB)
blip3o/model/__pycache__/llava_arch.cpython-311.pyc ADDED
Binary file (26 kB)
blip3o/model/blip3o_arch.py ADDED
@@ -0,0 +1,400 @@
+import os
+import random
+from abc import ABC, abstractmethod
+
+import torch
+import torch.nn as nn
+
+from blip3o.constants import (
+    DEFAULT_IM_END_TOKEN,
+    DEFAULT_IM_START_TOKEN,
+    IGNORE_INDEX,
+    IMAGE_TOKEN_INDEX,
+)
+from blip3o.utils import rank0_print
+from .multimodal_encoder.builder import build_vision_tower
+from .multimodal_decoder.builder import build_sana, build_vae
+from diffusers.models.normalization import RMSNorm
+from diffusers import AutoencoderDC, FlowMatchEulerDiscreteScheduler, SanaTransformer2DModel
+import math
+
+
+class blip3oMetaModel:
+
+    def __init__(self, config):
+        super(blip3oMetaModel, self).__init__(config)
+
+        if hasattr(config, "mm_vision_tower"):
+            delay_load = getattr(config, "delay_load", False)
+            self.vision_tower = build_vision_tower(config, delay_load=delay_load)
+
+            self.sana = build_sana(config)
+            self.sana_vae = build_vae(config)
+            norm = RMSNorm(2304, eps=1e-5, elementwise_affine=True)
+
+            with torch.no_grad():
+                norm.weight.fill_(math.sqrt(5.5))
+            self.diffusion_connector = nn.Sequential(
+                nn.Linear(config.hidden_size, 2304),
+                nn.GELU(approximate="tanh"),
+                nn.Linear(2304, 2304),
+                norm,
+            )
+            self.noise_scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(config.diffusion_name_or_path, subfolder="scheduler")
+
+            self.scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(config.diffusion_name_or_path, subfolder="scheduler")
+
+    def get_vision_tower(self):
+        vision_tower = getattr(self, "vision_tower", None)
+        if type(vision_tower) is list:
+            vision_tower = vision_tower[0]
+        return vision_tower
+
+    def get_sana(self):
+        sana = getattr(self, 'sana', None)
+        if type(sana) is list:
+            sana = sana[0]
+        if sana is not None:
+            sana.to(self.device)
+        return sana
+
+    def get_sana_vae(self):
+        sana_vae = getattr(self, 'sana_vae', None)
+        if type(sana_vae) is list:
+            sana_vae = sana_vae[0]
+        if sana_vae is not None:
+            sana_vae.to(self.device)
+        return sana_vae
+
+    def initialize_vision_modules(self, model_args, fsdp=None):
+        vision_tower = model_args.vision_tower
+        mm_vision_select_layer = model_args.mm_vision_select_layer
+        mm_vision_select_feature = model_args.mm_vision_select_feature
+        mm_patch_merge_type = model_args.mm_patch_merge_type
+
+        self.config.mm_vision_tower = vision_tower
+        self.config.vision_tower_pretrained = getattr(model_args, "vision_tower_pretrained", "")
+
+        if self.get_vision_tower() is None:
+            vision_tower = build_vision_tower(model_args)
+
+            if fsdp is not None and len(fsdp) > 0:
+                self.vision_tower = [vision_tower]
+            else:
+                self.vision_tower = vision_tower
+        else:
+            if fsdp is not None and len(fsdp) > 0:
+                vision_tower = self.vision_tower[0]
+            else:
+                vision_tower = self.vision_tower
+            vision_tower.load_model()
+
+        if self.get_sana() is None:
+            sana = build_sana(model_args)
+            self.noise_scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(model_args.diffusion_name_or_path, subfolder="scheduler")
+            self.scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(model_args.diffusion_name_or_path, subfolder="scheduler")
+
+            if fsdp is not None and len(fsdp) > 0:
+                self.sana = [sana]
+            else:
+                self.sana = sana
+        else:
+            if fsdp is not None and len(fsdp) > 0:
+                sana = self.sana[0]
+            else:
+                sana = self.sana
+
+        if self.get_sana_vae() is None:
+            sana_vae = build_vae(model_args)
+
+            if fsdp is not None and len(fsdp) > 0:
+                self.sana_vae = [sana_vae]
+            else:
+                self.sana_vae = sana_vae
+        else:
+            if fsdp is not None and len(fsdp) > 0:
+                sana_vae = self.sana_vae[0]
+            else:
+                sana_vae = self.sana_vae
+
+        if getattr(self, 'diffusion_connector', None) is None:
+            norm = RMSNorm(2304, eps=1e-5, elementwise_affine=True)
+            with torch.no_grad():
+                norm.weight.fill_(math.sqrt(5.5))
+            self.diffusion_connector = nn.Sequential(
+                nn.Linear(self.config.hidden_size, 2304),
+                nn.GELU(approximate="tanh"),
+                nn.Linear(2304, 2304),
+                norm,
+            )
+        else:
+            for p in self.diffusion_connector.parameters():
+                p.requires_grad = True
+
+        self.config.use_mm_proj = True
+        self.config.mm_hidden_size = vision_tower.hidden_size
+        self.config.mm_vision_select_layer = mm_vision_select_layer
+        self.config.mm_vision_select_feature = mm_vision_select_feature
+        self.config.mm_patch_merge_type = mm_patch_merge_type
+
+
+class blip3oMetaForCausalLM(ABC):
+
+    @abstractmethod
+    def get_model(self):
+        pass
+
+    def get_vision_tower(self):
+        return self.get_model().get_vision_tower()
+
+    def encode_images(self, images, modalities, pool_scale=None):
+        image_features = self.get_model().get_vision_tower()(images, pool_scale=pool_scale)
+
+        assert 'tokens' in image_features
+        image_tokens = image_features['tokens']
+
+        # discrete features for gen related tasks
+        image_tokens = image_tokens + self.config.image_start_token_id
+        image_features = self.get_model().embed_tokens(image_tokens)
+
+        return {'image_features': image_features, 'image_tokens': image_tokens}
+
+    def prepare_inputs_labels_for_multimodal(self, input_ids, position_ids, attention_mask, past_key_values, labels, images, modalities=None, image_sizes=None):
+        vision_tower = self.get_vision_tower()
+
+        if vision_tower is None or images is None or input_ids.shape[1] == 1:
+            return input_ids, position_ids, attention_mask, past_key_values, None, labels
+
+        if not isinstance(modalities, list):
+            modalities = [modalities]
+
+        # random scale for training, but scale 1 for understanding evaluation
+        if self.training:
+            pool_scale = random.choice(vision_tower.pool_scales)
+        else:
+            pool_scale = 1
+
+        if type(images) is list or images.ndim == 5:
+            if type(images) is list:
+                images = [x.unsqueeze(0) if x.ndim == 3 else x for x in images]
+
+            images_list = []
+            for image in images:
+                if image.ndim == 4:
+                    images_list.append(image)
+                else:
+                    images_list.append(image.unsqueeze(0))
+
+            concat_images = torch.cat([image for image in images_list], dim=0)
+            split_sizes = [image.shape[0] for image in images_list]
+            encoded_image_features = self.encode_images(concat_images, modalities, pool_scale=pool_scale)
+            image_tokens = encoded_image_features['image_tokens']
+            encoded_image_features = encoded_image_features['image_features']
+
+            # This is a list, each element is [num_images, patch * patch, dim]
+            encoded_image_features = torch.split(encoded_image_features, split_sizes)
+            if image_tokens is not None:
+                image_tokens = torch.split(image_tokens, split_sizes)
+            image_features = []
+            for idx, image_feat in enumerate(encoded_image_features):
+                image_features.append(image_feat)
+
+            mm_patch_merge_type = getattr(self.config, "mm_patch_merge_type", "flat")
+
+            if mm_patch_merge_type == "flat":
+                image_features = [x.flatten(0, 1) for x in image_features]
+                if image_tokens is not None:
+                    image_tokens = [x.flatten(0, 1) for x in image_tokens]
+            else:
+                raise ValueError(f"Unexpected mm_patch_merge_type: {self.config.mm_patch_merge_type}")
+        else:
+            image_features = self.encode_images(images, modalities, pool_scale=pool_scale)
+            image_tokens = image_features['image_tokens']
+            image_features = image_features['image_features']
+        # Let's just add dummy tensors if they do not exist,
+        # it is a headache to deal with None all the time.
+        # But it is not ideal, and if you have a better idea,
+        # please open an issue / submit a PR, thanks.
+        _labels = labels
+        _position_ids = position_ids
+        _attention_mask = attention_mask
+        if attention_mask is None:
+            attention_mask = torch.ones_like(input_ids, dtype=torch.bool)
+        else:
+            attention_mask = attention_mask.bool()
+        if position_ids is None:
+            position_ids = torch.arange(0, input_ids.shape[1], dtype=torch.long, device=input_ids.device)
+        if labels is None:
+            labels = torch.full_like(input_ids, IGNORE_INDEX)
+
+        # remove the padding using attention_mask -- FIXME
+        _input_ids = input_ids
+        input_ids = [cur_input_ids[cur_attention_mask] for cur_input_ids, cur_attention_mask in zip(input_ids, attention_mask)]
+        labels = [cur_labels[cur_attention_mask] for cur_labels, cur_attention_mask in zip(labels, attention_mask)]
+
+        new_input_embeds = []
+        new_labels = []
+        cur_image_idx = 0
+        # rank_print("Inserting Images embedding")
+        for batch_idx, cur_input_ids in enumerate(input_ids):
+            num_images = (cur_input_ids == IMAGE_TOKEN_INDEX).sum()
+            # rank0_print(num_images)
+            if num_images == 0:
+                # cur_image_features = image_features[cur_image_idx]
+                cur_input_embeds_1 = self.get_model().embed_tokens(cur_input_ids)
+                # cur_input_embeds = torch.cat([cur_input_embeds_1, cur_image_features[0:0]], dim=0)
+                cur_input_embeds = torch.cat([cur_input_embeds_1, cur_input_embeds_1[0:0]], dim=0)
+                new_input_embeds.append(cur_input_embeds)
+                new_labels.append(labels[batch_idx])
+                cur_image_idx += 1
+                continue
+
+            image_token_indices = [-1] + torch.where(cur_input_ids == IMAGE_TOKEN_INDEX)[0].tolist() + [cur_input_ids.shape[0]]
+            cur_input_ids_noim = []
+            cur_labels = labels[batch_idx]
+            cur_labels_noim = []
+            for i in range(len(image_token_indices) - 1):
+                cur_input_ids_noim.append(cur_input_ids[image_token_indices[i] + 1 : image_token_indices[i + 1]])
+                cur_labels_noim.append(cur_labels[image_token_indices[i] + 1 : image_token_indices[i + 1]])
+            split_sizes = [x.shape[0] for x in cur_labels_noim]
+            cur_input_embeds = self.get_model().embed_tokens(torch.cat(cur_input_ids_noim))
+            cur_input_embeds_no_im = torch.split(cur_input_embeds, split_sizes, dim=0)
+            cur_new_input_embeds = []
+            cur_new_labels = []
+
+            for i in range(num_images + 1):
+                cur_new_input_embeds.append(cur_input_embeds_no_im[i])
+                cur_new_labels.append(cur_labels_noim[i])
+                if i < num_images:
+                    try:
+                        cur_image_features = image_features[cur_image_idx]
+                    except IndexError:
+                        rank0_print("Error image_features[cur_image_idx]!")
+                        break
+                    # [Assistant\n<start_image><image><end_image>]
+                    if self.config.image_start_tag_id == cur_labels_noim[i][-1] and image_tokens is not None:
+                        cur_image_tokens = image_tokens[cur_image_idx]
+                        if pool_scale is not None:
+                            pool_token = self.config.scale_start_token_id + pool_scale - 1
+                            pool_token = torch.tensor([pool_token], dtype=torch.long, device=cur_image_tokens.device)
+                            cur_image_tokens = torch.cat([pool_token, cur_image_tokens])
+                            pool_embed = self.get_model().embed_tokens(pool_token)
+                            cur_image_features = torch.cat([pool_embed, cur_image_features])
+                    else:
+                        cur_image_tokens = torch.full((cur_image_features.shape[0],), IGNORE_INDEX, device=cur_labels.device, dtype=cur_labels.dtype)
+                    cur_image_idx += 1
+                    cur_new_input_embeds.append(cur_image_features)
+                    cur_new_labels.append(cur_image_tokens)
+            cur_new_input_embeds = [x.to(self.device) for x in cur_new_input_embeds]
+
+            cur_new_input_embeds = torch.cat(cur_new_input_embeds)
+            cur_new_labels = torch.cat(cur_new_labels)
+
+            new_input_embeds.append(cur_new_input_embeds)
+            new_labels.append(cur_new_labels)
+
+        # Truncate sequences to max length as image embeddings can make the sequence longer
+        tokenizer_model_max_length = getattr(self.config, "tokenizer_model_max_length", None)
+
+        new_input_embeds = [x[:tokenizer_model_max_length] for x, modality in zip(new_input_embeds, modalities)]
+        new_labels = [x[:tokenizer_model_max_length] for x, modality in zip(new_labels, modalities)]
+
+        # Combine them
+        max_len = max(x.shape[0] for x in new_input_embeds)
+        batch_size = len(new_input_embeds)
+
+        new_input_embeds_padded = []
+        new_labels_padded = torch.full((batch_size, max_len), IGNORE_INDEX, dtype=new_labels[0].dtype, device=new_labels[0].device)
+        attention_mask = torch.zeros((batch_size, max_len), dtype=attention_mask.dtype, device=attention_mask.device)
+        position_ids = torch.zeros((batch_size, max_len), dtype=position_ids.dtype, device=position_ids.device)
+
+        for i, (cur_new_embed, cur_new_labels) in enumerate(zip(new_input_embeds, new_labels)):
+            cur_len = cur_new_embed.shape[0]
+            if getattr(self.config, "tokenizer_padding_side", "right") == "left":
+                new_input_embeds_padded.append(torch.cat((torch.zeros((max_len - cur_len, cur_new_embed.shape[1]), dtype=cur_new_embed.dtype, device=cur_new_embed.device), cur_new_embed), dim=0))
+                if cur_len > 0:
+                    new_labels_padded[i, -cur_len:] = cur_new_labels
+                    attention_mask[i, -cur_len:] = True
+                    position_ids[i, -cur_len:] = torch.arange(0, cur_len, dtype=position_ids.dtype, device=position_ids.device)
+            else:
+                new_input_embeds_padded.append(torch.cat((cur_new_embed, torch.zeros((max_len - cur_len, cur_new_embed.shape[1]), dtype=cur_new_embed.dtype, device=cur_new_embed.device)), dim=0))
+                if cur_len > 0:
+                    new_labels_padded[i, :cur_len] = cur_new_labels
+                    attention_mask[i, :cur_len] = True
+                    position_ids[i, :cur_len] = torch.arange(0, cur_len, dtype=position_ids.dtype, device=position_ids.device)
+
+        new_input_embeds = torch.stack(new_input_embeds_padded, dim=0)
+
+        if _labels is None:
+            new_labels = None
+        else:
+            new_labels = new_labels_padded
+
+        if _attention_mask is None:
+            attention_mask = None
+        else:
+            attention_mask = attention_mask.to(dtype=_attention_mask.dtype)
+
+        if _position_ids is None:
+            position_ids = None
+        if getattr(self.config, "use_pos_skipping", False) and self.training:
+            position_ids = torch.arange(new_input_embeds.size(1), device=new_input_embeds.device).unsqueeze(0).to(new_input_embeds.device)
+            split_position = random.randint(0, new_input_embeds.size(1))
+            left_add = random.randint(0, self.config.pos_skipping_range)
+            right_add = random.randint(left_add, self.config.pos_skipping_range)
+            position_ids[:, :split_position] += left_add
+            position_ids[:, split_position:] += right_add
+
+        return None, position_ids, attention_mask, past_key_values, new_input_embeds, new_labels
+
+    def initialize_vision_tokenizer(self, model_args, tokenizer):
+        total_num_new_tokens = 0
+        vocab_size = len(tokenizer)
+        if model_args.mm_use_im_start_end:
+            num_new_tokens = tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True)
+            self.config.image_start_tag_id = tokenizer.convert_tokens_to_ids(DEFAULT_IM_START_TOKEN)
+            self.config.image_end_tag_id = tokenizer.convert_tokens_to_ids(DEFAULT_IM_END_TOKEN)
+            total_num_new_tokens += num_new_tokens
+            self.resize_token_embeddings(vocab_size + total_num_new_tokens)
+
+        if model_args.num_scale_tokens > 0:
+            scale_tokens = [model_args.scale_token_format.format(str(i)) for i in range(model_args.num_scale_tokens)]
+            num_new_tokens = tokenizer.add_tokens(scale_tokens, special_tokens=False)
+            self.config.scale_start_token_id = tokenizer.convert_tokens_to_ids(scale_tokens[0])
+            self.config.scale_end_token_id = tokenizer.convert_tokens_to_ids(scale_tokens[-1])
+            self.config.num_scale_tokens = model_args.num_scale_tokens
+            total_num_new_tokens += num_new_tokens
+            self.resize_token_embeddings(vocab_size + total_num_new_tokens)
+
+        if model_args.num_image_tokens > 0:
+            image_tokens = [model_args.image_token_format.format(str(i)) for i in range(model_args.num_image_tokens)]
+            num_new_tokens = tokenizer.add_tokens(image_tokens, special_tokens=False)
+            self.config.image_start_token_id = tokenizer.convert_tokens_to_ids(image_tokens[0])
+            self.config.image_end_token_id = tokenizer.convert_tokens_to_ids(image_tokens[-1])
+            self.config.num_image_tokens = model_args.num_image_tokens
+
+            total_num_new_tokens += num_new_tokens
+            self.resize_token_embeddings(vocab_size + total_num_new_tokens)
+            if num_new_tokens > 0:
+                self.config.num_new_tokens = num_new_tokens
+                input_embeddings = self.get_input_embeddings().weight.data
+                output_embeddings = self.get_output_embeddings().weight.data
+
+                input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True)
+                output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True)
+
+                input_embeddings[-num_new_tokens:] = input_embeddings_avg
+                output_embeddings[-num_new_tokens:] = output_embeddings_avg
+
+                vision_tower = self.get_vision_tower()
+                if model_args.load_embeddings_from_vision and vision_tower is not None:
+                    vision_embeddings = vision_tower.get_embedding()
+                    if model_args.num_image_tokens == vision_embeddings.shape[0] and input_embeddings.shape[1] == vision_embeddings.shape[1]:
+                        rank0_print("Load vision embeddings from vision tower.")
+                        input_embeddings[self.config.image_start_token_id:self.config.image_end_token_id+1] = vision_embeddings
blip3o/model/builder.py ADDED
@@ -0,0 +1,44 @@
+import torch
+from transformers import AutoTokenizer
+
+from blip3o.model import blip3oQwenForCausalLM
+from blip3o.utils import rank0_print
+
+
+def load_pretrained_model(model_path, model_base, model_name, load_8bit=False, load_4bit=False, device_map="auto", torch_dtype="float16", attn_implementation="flash_attention_2", customized_config=None, overwrite_config=None, **kwargs):
+    kwargs["device_map"] = device_map
+    kwargs.pop("multimodal")
+
+    if customized_config is not None:
+        kwargs["config"] = customized_config
+
+    tokenizer = AutoTokenizer.from_pretrained(model_path)
+    from blip3o.model.language_model.blip3o_qwen import blip3oQwenConfig
+
+    if overwrite_config is not None:
+        blip3o_cfg = blip3oQwenConfig.from_pretrained(model_path)
+        rank0_print(f"Overwriting config with {overwrite_config}")
+        for k, v in overwrite_config.items():
+            setattr(blip3o_cfg, k, v)
+        model = blip3oQwenForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, attn_implementation=attn_implementation, config=blip3o_cfg, **kwargs)
+    else:
+        model = blip3oQwenForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, attn_implementation=attn_implementation, **kwargs)
+
+    vision_tower = model.get_vision_tower()
+    if not vision_tower.is_loaded:
+        vision_tower.load_model(device_map=device_map)
+    if device_map != "auto":
+        vision_tower.to(device="cuda", dtype=torch.float16)
+    image_processor = vision_tower.image_processor
+
+    if hasattr(model.config, "max_sequence_length"):
+        context_len = model.config.max_sequence_length
+    elif hasattr(model.config, "max_position_embeddings"):
+        context_len = model.config.max_position_embeddings
+    elif hasattr(model.config, "tokenizer_model_max_length"):
+        context_len = model.config.tokenizer_model_max_length
+    else:
+        context_len = 2048
+
+    return tokenizer, model, image_processor, context_len
blip3o/model/language_model/__pycache__/blip3o_qwen.cpython-310.pyc ADDED
Binary file (6.85 kB)
blip3o/model/language_model/__pycache__/blip3o_qwen.cpython-311.pyc ADDED
Binary file (13.3 kB)
blip3o/model/language_model/__pycache__/blip3o_qwen_grpo.cpython-310.pyc ADDED
Binary file (7.82 kB)
blip3o/model/language_model/__pycache__/blip3o_qwen_grpo.cpython-311.pyc ADDED
Binary file (15.3 kB)
blip3o/model/language_model/__pycache__/blip3o_qwen_inference.cpython-310.pyc ADDED
Binary file (7.12 kB)
blip3o/model/language_model/__pycache__/blip3o_qwen_inference.cpython-311.pyc ADDED
Binary file (13.9 kB)
blip3o/model/language_model/__pycache__/llava_qwen.cpython-310.pyc ADDED
Binary file (6.82 kB)
blip3o/model/language_model/__pycache__/llava_qwen.cpython-311.pyc ADDED
Binary file (13.2 kB)
blip3o/model/language_model/__pycache__/llava_qwen_grpo.cpython-310.pyc ADDED
Binary file (7.79 kB)
blip3o/model/language_model/__pycache__/llava_qwen_grpo.cpython-311.pyc ADDED
Binary file (15.2 kB)
blip3o/model/language_model/__pycache__/llava_qwen_inference.cpython-310.pyc ADDED
Binary file (7.09 kB)
blip3o/model/language_model/__pycache__/llava_qwen_inference.cpython-311.pyc ADDED
Binary file (13.9 kB)
blip3o/model/language_model/blip3o_qwen.py ADDED
@@ -0,0 +1,215 @@
+ from typing import Dict, List, Optional, Tuple, Union
+
+ import torch
+ import torch.nn as nn
+ from transformers import (
+     AutoConfig,
+     AutoModelForCausalLM,
+     Qwen3Config,
+     Qwen3ForCausalLM,
+     Qwen3Model,
+ )
+ from transformers.generation.utils import GenerateOutput
+ from transformers.modeling_outputs import CausalLMOutputWithPast
+
+ from blip3o.model.blip3o_arch import blip3oMetaForCausalLM, blip3oMetaModel
+ from diffusers.training_utils import compute_density_for_timestep_sampling, compute_loss_weighting_for_sd3
+ from blip3o.utils import rank0_print
+
+
+ class blip3oQwenConfig(Qwen3Config):
+     model_type = "blip3o_qwen"
+
+
+ class blip3oQwenModel(blip3oMetaModel, Qwen3Model):
+     config_class = blip3oQwenConfig
+
+     def __init__(self, config: Qwen3Config):
+         super(blip3oQwenModel, self).__init__(config)
+
+
+ class blip3oQwenForCausalLM(Qwen3ForCausalLM, blip3oMetaForCausalLM):
+     config_class = blip3oQwenConfig
+
+     def __init__(self, config):
+         Qwen3ForCausalLM.__init__(self, config)
+         config.model_type = "blip3o_qwen"
+         config.rope_scaling = None
+
+         self.model = blip3oQwenModel(config)
+         self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+
+         # Initialize weights and apply final processing
+         self.post_init()
+
+     def get_model(self):
+         return self.model
+
+     def get_sigmas(self, timesteps, device, n_dim=4, dtype=torch.float32):
+         sigmas = self.model.noise_scheduler.sigmas.to(device=device, dtype=dtype)
+         schedule_timesteps = self.model.noise_scheduler.timesteps.to(device)
+         timesteps = timesteps.to(device)
+         step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]
+
+         sigma = sigmas[step_indices].flatten()
+         while len(sigma.shape) < n_dim:
+             sigma = sigma.unsqueeze(-1)
+         return sigma
+
+     def mask_drop(self, latents, drop_prob=0.1):
+         # Randomly zero out conditioning for a fraction of the batch (classifier-free guidance training).
+         if drop_prob <= 0:
+             return latents
+         mask = torch.bernoulli(torch.zeros(latents.shape[0], device=latents.device, dtype=latents.dtype) + drop_prob)
+         while len(mask.shape) < len(latents.shape):
+             mask = mask.unsqueeze(-1)
+         mask = 1 - mask  # need to flip 0 <-> 1
+         return latents * mask
+
+     def forward(
+         self,
+         input_ids: torch.LongTensor = None,
+         attention_mask: Optional[torch.Tensor] = None,
+         position_ids: Optional[torch.LongTensor] = None,
+         past_key_values: Optional[List[torch.FloatTensor]] = None,
+         inputs_embeds: Optional[torch.FloatTensor] = None,
+         labels: Optional[torch.LongTensor] = None,
+         use_cache: Optional[bool] = None,
+         output_attentions: Optional[bool] = None,
+         output_hidden_states: Optional[bool] = None,
+         images: Optional[torch.FloatTensor] = None,
+         target_images: Optional[torch.FloatTensor] = None,
+         image_sizes: Optional[List[List[int]]] = None,
+         return_dict: Optional[bool] = None,
+         modalities: Optional[List[str]] = ["image"],
+         dpo_forward: Optional[bool] = False,
+         cache_position=None,
+     ) -> Union[Tuple, CausalLMOutputWithPast]:
+
+         if inputs_embeds is None:
+             (input_ids, position_ids, attention_mask, past_key_values, inputs_embeds, labels) = self.prepare_inputs_labels_for_multimodal(input_ids, position_ids, attention_mask, past_key_values, labels, images, modalities, image_sizes)
+
+         outputs = self.model(
+             input_ids=input_ids,
+             attention_mask=attention_mask,
+             position_ids=position_ids,
+             past_key_values=past_key_values,
+             inputs_embeds=inputs_embeds,
+             use_cache=use_cache,
+             output_attentions=output_attentions,
+             output_hidden_states=output_hidden_states,
+             return_dict=return_dict,
+         )
+
+         hidden_states = outputs[0]
+         logits = self.lm_head(hidden_states)
+
+         loss = None
+         if labels is not None:
+             # Standard next-token cross-entropy over the shifted sequence.
+             shift_logits = logits[..., :-1, :].contiguous()
+             shift_labels = labels[..., 1:].contiguous()
+             loss_fct = torch.nn.CrossEntropyLoss()
+             shift_logits = shift_logits.view(-1, self.config.vocab_size)
+             shift_labels = shift_labels.view(-1)
+             shift_labels = shift_labels.to(shift_logits.device)
+             loss = loss_fct(shift_logits, shift_labels)
+
+         if target_images is not None:
+             # Flow-matching objective on the SANA VAE latents of the target image.
+             vae = self.model.get_sana_vae()
+             latents = vae.encode(target_images).latent
+             if "shift_factor" in vae.config and vae.config.shift_factor is not None:
+                 latents = latents - vae.config.shift_factor
+             latents = latents * vae.config.scaling_factor
+             noise = torch.randn_like(latents, device=latents.device)
+             weighting_scheme = "uniform"
+             u = compute_density_for_timestep_sampling(
+                 weighting_scheme=weighting_scheme,
+                 batch_size=latents.shape[0],
+                 logit_mean=0.0,
+                 logit_std=1.0,
+                 mode_scale=1.29,
+             )
+             indices = (u * self.model.noise_scheduler.config.num_train_timesteps).long()
+             timesteps = self.model.noise_scheduler.timesteps[indices].to(device=latents.device)
+             sigmas = self.get_sigmas(timesteps, latents.device, n_dim=latents.ndim, dtype=latents.dtype)
+             noisy_latents = (1.0 - sigmas) * latents + sigmas * noise
+
+             sana = self.model.get_sana()
+
+             # Hidden states between the image start/end tags condition the diffusion decoder.
+             start_pos = (labels == self.config.image_start_tag_id).float().argmax(dim=1)
+             end_pos = (labels == self.config.image_end_tag_id).float().argmax(dim=1)
+
+             selected_hidden_states = []
+             for b in range(hidden_states.size(0)):
+                 start = start_pos[b].item() + 1
+                 end = end_pos[b].item()
+                 hidden_states_filter = hidden_states[b, start:end, :]
+                 if hidden_states_filter.size(0) != 730:
+                     # Fall back to the last 730 positions if the tagged span has an unexpected length.
+                     hidden_states_filter = hidden_states[b, -730:, :]
+                 selected_hidden_states.append(hidden_states_filter)
+
+             selected_hidden_states = torch.stack(selected_hidden_states, dim=0)
+             diffusion_pred = sana(
+                 hidden_states=noisy_latents,
+                 timestep=timesteps,
+                 encoder_hidden_states=self.model.diffusion_connector(self.mask_drop(selected_hidden_states)),
+                 encoder_attention_mask=None,
+             ).sample
+
+             target = noise - latents
+             weighting = compute_loss_weighting_for_sd3(weighting_scheme=weighting_scheme, sigmas=sigmas)
+             diff_loss = torch.mean(
+                 (weighting.float() * (diffusion_pred.float() - target.float()) ** 2).reshape(target.shape[0], -1),
+                 1,
+             )
+             diff_loss = diff_loss.mean()
+             rank0_print(f" Cross-entropy loss {loss}, Diffusion loss {diff_loss} ")
+             loss = diff_loss if loss is None else loss + diff_loss
+
+         return CausalLMOutputWithPast(
+             loss=loss,
+             logits=logits,
+             past_key_values=outputs.past_key_values,
+             hidden_states=outputs.hidden_states,
+             attentions=outputs.attentions,
+         )
+
+     @torch.no_grad()
+     def generate(
+         self,
+         inputs: Optional[torch.Tensor] = None,
+         images: Optional[torch.Tensor] = None,
+         image_sizes: Optional[torch.Tensor] = None,
+         modalities: Optional[List[str]] = ["image"],
+         **kwargs,
+     ) -> Union[GenerateOutput, torch.LongTensor]:
+         position_ids = kwargs.pop("position_ids", None)
+         attention_mask = kwargs.pop("attention_mask", None)
+         if "inputs_embeds" in kwargs:
+             raise NotImplementedError("`inputs_embeds` is not supported")
+
+         if images is not None:
+             (inputs, position_ids, attention_mask, _, inputs_embeds, _) = self.prepare_inputs_labels_for_multimodal(inputs, position_ids, attention_mask, None, None, images, modalities, image_sizes=image_sizes)
+         else:
+             inputs_embeds = self.get_model().embed_tokens(inputs)
+         return super().generate(position_ids=position_ids, attention_mask=attention_mask, inputs_embeds=inputs_embeds, **kwargs)
+
+     def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs):
+         images = kwargs.pop("images", None)
+         image_sizes = kwargs.pop("image_sizes", None)
+         inputs = super().prepare_inputs_for_generation(input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, **kwargs)
+         if images is not None:
+             inputs["images"] = images
+         if image_sizes is not None:
+             inputs["image_sizes"] = image_sizes
+         return inputs
+
+
+ AutoConfig.register("blip3o_qwen", blip3oQwenConfig)
+ AutoModelForCausalLM.register(blip3oQwenConfig, blip3oQwenForCausalLM)
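
The two `register` calls above hook the custom config and model classes into the `transformers` Auto classes, so a checkpoint whose `config.json` carries `model_type: "blip3o_qwen"` can be loaded through the standard API. A hedged sketch of what that enables; the checkpoint path is a placeholder:

from transformers import AutoModelForCausalLM
import blip3o.model.language_model.blip3o_qwen  # noqa: F401  (import side effect: runs the Auto* registration)

# Dispatches to blip3oQwenForCausalLM because of the registered model_type.
model = AutoModelForCausalLM.from_pretrained("path/to/blip3o-checkpoint")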
blip3o/model/language_model/blip3o_qwen_grpo.py ADDED
@@ -0,0 +1,255 @@
+ from typing import Dict, List, Optional, Tuple, Union
+
+ import torch
+ import torch.nn as nn
+ from transformers import (
+     AutoConfig,
+     AutoModelForCausalLM,
+     Qwen3Config,
+     Qwen3ForCausalLM,
+     Qwen3Model,
+ )
+ from transformers.generation.utils import GenerateOutput
+ from transformers.modeling_outputs import CausalLMOutputWithPast
+
+ from blip3o.model.blip3o_arch import blip3oMetaForCausalLM, blip3oMetaModel
+ from diffusers.training_utils import compute_density_for_timestep_sampling, compute_loss_weighting_for_sd3
+ from diffusers.utils.torch_utils import randn_tensor
+ from diffusers.schedulers import DDPMScheduler, DDIMScheduler, LCMScheduler, FlowMatchEulerDiscreteScheduler, DPMSolverMultistepScheduler
+ import numpy as np
+ from tqdm import tqdm
+ import PIL
+ from blip3o.utils import rank0_print
+
+
+ def numpy_to_pil(images: np.ndarray):
+     """
+     Convert a NumPy array of shape (batch, height, width, channels) to a list of PIL Images.
+     """
+     pil_images = []
+     for img in images:
+         img_uint8 = (img * 255).round().astype("uint8")
+         if img_uint8.shape[2] == 1:
+             img_uint8 = img_uint8[..., 0]
+         pil_images.append(PIL.Image.fromarray(img_uint8))
+     return pil_images
+
+
+ class blip3oQwenConfig(Qwen3Config):
+     model_type = "blip3o_qwen_grpo"
+
+
+ class blip3oQwenModel(blip3oMetaModel, Qwen3Model):
+     config_class = blip3oQwenConfig
+
+     def __init__(self, config: Qwen3Config):
+         super(blip3oQwenModel, self).__init__(config)
+
+
+ class blip3oQwenForGRPOLM(Qwen3ForCausalLM, blip3oMetaForCausalLM):
+     config_class = blip3oQwenConfig
+
+     def __init__(self, config):
+         Qwen3ForCausalLM.__init__(self, config)
+         config.model_type = "blip3o_qwen"
+         config.rope_scaling = None
+
+         self.model = blip3oQwenModel(config)
+         self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+
+         # Initialize weights and apply final processing
+         self.post_init()
+
+     def get_model(self):
+         return self.model
+
+     def get_sigmas(self, timesteps, device, n_dim=4, dtype=torch.float32):
+         sigmas = self.model.noise_scheduler.sigmas.to(device=device, dtype=dtype)
+         schedule_timesteps = self.model.noise_scheduler.timesteps.to(device)
+         timesteps = timesteps.to(device)
+         step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]
+
+         sigma = sigmas[step_indices].flatten()
+         while len(sigma.shape) < n_dim:
+             sigma = sigma.unsqueeze(-1)
+         return sigma
+
+     def mask_drop(self, latents, drop_prob=0.1):
+         if drop_prob <= 0:
+             return latents
+         mask = torch.bernoulli(torch.zeros(latents.shape[0], device=latents.device, dtype=latents.dtype) + drop_prob)
+         while len(mask.shape) < len(latents.shape):
+             mask = mask.unsqueeze(-1)
+         mask = 1 - mask  # need to flip 0 <-> 1
+         return latents * mask
+
+     @torch.no_grad()
+     def generate(
+         self,
+         inputs: Optional[torch.Tensor] = None,
+         images: Optional[torch.Tensor] = None,
+         image_sizes: Optional[torch.Tensor] = None,
+         modalities: Optional[List[str]] = ["image"],
+         **kwargs,
+     ) -> Union[GenerateOutput, torch.LongTensor]:
+         position_ids = kwargs.pop("position_ids", None)
+         attention_mask = kwargs.pop("attention_mask", None)
+         if "inputs_embeds" in kwargs:
+             raise NotImplementedError("`inputs_embeds` is not supported")
+
+         if images is not None:
+             (inputs, position_ids, attention_mask, _, inputs_embeds, _) = self.prepare_inputs_labels_for_multimodal(inputs, position_ids, attention_mask, None, None, images, modalities, image_sizes=image_sizes)
+         else:
+             inputs_embeds = self.get_model().embed_tokens(inputs)
+         return super().generate(position_ids=position_ids, attention_mask=attention_mask, inputs_embeds=inputs_embeds, **kwargs)
+
+     def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs):
+         images = kwargs.pop("images", None)
+         image_sizes = kwargs.pop("image_sizes", None)
+         inputs = super().prepare_inputs_for_generation(input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, **kwargs)
+         if images is not None:
+             inputs["images"] = images
+         if image_sizes is not None:
+             inputs["image_sizes"] = image_sizes
+         return inputs
+
+     @torch.no_grad()
+     def decode_latents(self, latents, normalize=True, return_tensor=False):
+         if self.model.sana_vae is not None:
+             latents = latents / self.model.sana_vae.config.scaling_factor
+             if "shift_factor" in self.model.sana_vae.config and self.model.sana_vae.config.shift_factor is not None:
+                 latents = latents + self.model.sana_vae.config.shift_factor
+             samples = self.model.sana_vae.decode(latents).sample
+         else:
+             samples = latents
+         if normalize:
+             samples = (samples / 2 + 0.5).clamp(0, 1)
+         else:
+             samples = samples.clamp(-1, 1)
+         if return_tensor:
+             return samples
+         samples = samples.cpu().permute(0, 2, 3, 1).float().numpy()
+         samples = numpy_to_pil(samples)
+         return samples
+
+     @torch.no_grad()
+     def generate_images(
+         self,
+         input_ids: Optional[torch.Tensor] = None,
+         attention_mask: Optional[torch.Tensor] = None,
+         max_new_tokens: Optional[torch.Tensor] = None,
+         temperature: Optional[torch.Tensor] = None,
+         top_p: Optional[torch.Tensor] = None,
+         top_k: Optional[torch.Tensor] = None,
+         images: Optional[torch.Tensor] = None,
+         image_sizes: Optional[torch.Tensor] = None,
+         modalities: Optional[List[str]] = ["image"],
+         guidance_scale: float = 2.0,
+         generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+         num_inference_steps: int = 30,
+         num_images_per_prompt: int = 1,
+         return_tensor=False,
+         enable_progress_bar=False,
+         **kwargs,
+     ):
+         position_ids = kwargs.pop("position_ids", None)
+
+         # Sample rollouts at a fixed temperature of 1.0 for GRPO exploration.
+         gen_ids = super(blip3oQwenForGRPOLM, self).generate(
+             input_ids,
+             max_new_tokens=max_new_tokens,
+             do_sample=True,
+             temperature=1.0,
+             attention_mask=attention_mask,
+         )
+
+         with torch.no_grad():
+             outs = self.model(
+                 input_ids=gen_ids,
+                 output_hidden_states=True,
+                 return_dict=True,
+             )
+         hidden_states = outs.hidden_states[-1]
+
+         # Hidden states after the image start tag condition the diffusion decoder.
+         start_pos = (gen_ids == self.config.image_start_tag_id).float().argmax(dim=1)
+         end_pos = (gen_ids == self.config.image_end_tag_id).float().argmax(dim=1)
+
+         selected_hidden_states = []
+         for b in range(hidden_states.size(0)):
+             start = start_pos[b].item() + 1
+             selected_hidden_states.append(hidden_states[b, start:, :])
+         pred_latent = torch.stack(selected_hidden_states, dim=0)
+
+         # Prepend a zeroed copy of the conditioning for the unconditional branch of classifier-free guidance.
+         img_hidden_states_null = torch.zeros_like(pred_latent)
+         pred_latent = torch.cat([img_hidden_states_null, pred_latent], 0)
+
+         ## sample images from here
+         device = next(self.parameters()).device
+
+         bsz = len(pred_latent) // 2
+         # latent_size = self.config.input_size
+         latent_size = 32
+         latent_channels = self.model.sana.config.in_channels
+
+         latents = randn_tensor(
+             shape=(bsz * num_images_per_prompt, latent_channels, latent_size, latent_size),
+             generator=None,
+             device=device,
+             dtype=torch.bfloat16,
+         )
+
+         # set step values
+         if isinstance(self.model.noise_scheduler, FlowMatchEulerDiscreteScheduler):
+             sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
+             self.model.noise_scheduler.set_timesteps(num_inference_steps, sigmas=sigmas)
+         else:
+             self.model.noise_scheduler.set_timesteps(num_inference_steps)
+
+         for t in tqdm(self.model.noise_scheduler.timesteps, desc="Sampling images", disable=not enable_progress_bar):
+             latent_model_input = torch.cat([latents] * 2)
+             latent_model_input = latent_model_input.to(pred_latent.dtype)
+
+             if hasattr(self.model.noise_scheduler, "scale_model_input"):
+                 latent_model_input = self.model.noise_scheduler.scale_model_input(latent_model_input, t)
+             # predict noise model_output
+             noise_pred = self.model.sana(
+                 hidden_states=latent_model_input,
+                 encoder_hidden_states=self.model.diffusion_connector(pred_latent),
+                 timestep=t.unsqueeze(0).expand(latent_model_input.shape[0]).to(latents.device),
+                 encoder_attention_mask=None,
+             ).sample
+
+             noise_pred_uncond, noise_pred = noise_pred.chunk(2)
+             noise_pred = noise_pred_uncond + guidance_scale * (noise_pred - noise_pred_uncond)
+
+             # compute previous image: x_t -> x_t-1
+             latents = self.model.noise_scheduler.step(noise_pred, t, latents).prev_sample
+
+         samples = self.decode_latents(latents.to(self.model.sana_vae.dtype) if self.model.sana_vae is not None else latents, return_tensor=return_tensor)
+
+         return gen_ids, samples
+
+
+ AutoConfig.register("blip3o_qwen_grpo", blip3oQwenConfig)
+ AutoModelForCausalLM.register(blip3oQwenConfig, blip3oQwenForGRPOLM)
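
The sampling loop in `generate_images` duplicates the latents so a single SANA call produces both the unconditional branch (zeroed conditioning) and the conditional branch, then recombines them with standard classifier-free guidance. A minimal standalone sketch of that recombination step (function and tensor names here are illustrative, not part of the repository):

import torch

def apply_cfg(noise_pred: torch.Tensor, guidance_scale: float) -> torch.Tensor:
    # noise_pred stacks [unconditional, conditional] predictions along the batch axis,
    # mirroring torch.cat([latents] * 2) in generate_images above.
    uncond, cond = noise_pred.chunk(2)
    return uncond + guidance_scale * (cond - uncond)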
blip3o/model/language_model/blip3o_qwen_inference.py ADDED
@@ -0,0 +1,241 @@
+ from typing import Dict, List, Optional, Tuple, Union
+
+ import torch
+ import torch.nn as nn
+ from transformers import (
+     AutoConfig,
+     AutoModelForCausalLM,
+     Qwen3Config,
+     Qwen3ForCausalLM,
+     Qwen3Model,
+ )
+ from transformers.generation.utils import GenerateOutput
+ from transformers.modeling_outputs import CausalLMOutputWithPast
+
+ from blip3o.model.blip3o_arch import blip3oMetaForCausalLM, blip3oMetaModel
+ from diffusers.training_utils import compute_density_for_timestep_sampling, compute_loss_weighting_for_sd3
+ from diffusers.utils.torch_utils import randn_tensor
+ from diffusers.schedulers import DDPMScheduler, DDIMScheduler, LCMScheduler, FlowMatchEulerDiscreteScheduler, DPMSolverMultistepScheduler
+ import numpy as np
+ from tqdm import tqdm
+ import PIL
+
+
+ def numpy_to_pil(images: np.ndarray):
+     """
+     Convert a NumPy array of shape (batch, height, width, channels) to a list of PIL Images.
+     """
+     pil_images = []
+     for img in images:
+         img_uint8 = (img * 255).round().astype("uint8")
+         if img_uint8.shape[2] == 1:
+             img_uint8 = img_uint8[..., 0]
+         pil_images.append(PIL.Image.fromarray(img_uint8))
+     return pil_images
+
+
+ class blip3oQwenConfig(Qwen3Config):
+     model_type = "blip3o_qwen_inference"
+
+
+ class blip3oQwenModel(blip3oMetaModel, Qwen3Model):
+     config_class = blip3oQwenConfig
+
+     def __init__(self, config: Qwen3Config):
+         super(blip3oQwenModel, self).__init__(config)
+
+
+ class blip3oQwenForInferenceLM(Qwen3ForCausalLM, blip3oMetaForCausalLM):
+     config_class = blip3oQwenConfig
+
+     def __init__(self, config):
+         Qwen3ForCausalLM.__init__(self, config)
+         config.model_type = "blip3o_qwen"
+         config.rope_scaling = None
+
+         self.model = blip3oQwenModel(config)
+         self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+
+         # Initialize weights and apply final processing
+         self.post_init()
+
+     def get_model(self):
+         return self.model
+
+     def get_sigmas(self, timesteps, device, n_dim=4, dtype=torch.float32):
+         sigmas = self.model.noise_scheduler.sigmas.to(device=device, dtype=dtype)
+         schedule_timesteps = self.model.noise_scheduler.timesteps.to(device)
+         timesteps = timesteps.to(device)
+         step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]
+
+         sigma = sigmas[step_indices].flatten()
+         while len(sigma.shape) < n_dim:
+             sigma = sigma.unsqueeze(-1)
+         return sigma
+
+     @torch.no_grad()
+     def generate(
+         self,
+         inputs: Optional[torch.Tensor] = None,
+         images: Optional[torch.Tensor] = None,
+         image_sizes: Optional[torch.Tensor] = None,
+         modalities: Optional[List[str]] = ["image"],
+         **kwargs,
+     ) -> Union[GenerateOutput, torch.LongTensor]:
+         position_ids = kwargs.pop("position_ids", None)
+         attention_mask = kwargs.pop("attention_mask", None)
+         if "inputs_embeds" in kwargs:
+             raise NotImplementedError("`inputs_embeds` is not supported")
+
+         if images is not None:
+             (inputs, position_ids, attention_mask, _, inputs_embeds, _) = self.prepare_inputs_labels_for_multimodal(inputs, position_ids, attention_mask, None, None, images, modalities, image_sizes=image_sizes)
+         else:
+             inputs_embeds = self.get_model().embed_tokens(inputs)
+         return super().generate(position_ids=position_ids, attention_mask=attention_mask, inputs_embeds=inputs_embeds, **kwargs)
+
+     @torch.no_grad()
+     def decode_latents(self, latents, normalize=True, return_tensor=False):
+         if self.model.sana_vae is not None:
+             latents = latents / self.model.sana_vae.config.scaling_factor
+             if "shift_factor" in self.model.sana_vae.config and self.model.sana_vae.config.shift_factor is not None:
+                 latents = latents + self.model.sana_vae.config.shift_factor
+             samples = self.model.sana_vae.decode(latents).sample
+         else:
+             samples = latents
+         if normalize:
+             samples = (samples / 2 + 0.5).clamp(0, 1)
+         else:
+             samples = samples.clamp(-1, 1)
+         if return_tensor:
+             return samples
+         samples = samples.cpu().permute(0, 2, 3, 1).float().numpy()
+         samples = numpy_to_pil(samples)
+         return samples
+
+     @torch.no_grad()
+     def generate_images(
+         self,
+         inputs: Optional[torch.Tensor] = None,
+         attention_mask: Optional[torch.Tensor] = None,
+         max_new_tokens: Optional[torch.Tensor] = None,
+         temperature: Optional[torch.Tensor] = None,
+         top_p: Optional[torch.Tensor] = None,
+         top_k: Optional[torch.Tensor] = None,
+         images: Optional[torch.Tensor] = None,
+         image_sizes: Optional[torch.Tensor] = None,
+         modalities: Optional[List[str]] = ["image"],
+         guidance_scale: float = 2.0,
+         generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+         num_inference_steps: int = 30,
+         num_images_per_prompt: int = 1,
+         return_tensor=False,
+         enable_progress_bar=False,
+         **kwargs,
+     ):
+         position_ids = kwargs.pop("position_ids", None)
+
+         gen_ids = super(blip3oQwenForInferenceLM, self).generate(
+             inputs,
+             max_new_tokens=max_new_tokens,
+             do_sample=True,
+             temperature=temperature,
+             attention_mask=attention_mask,
+             top_p=top_p,
+             top_k=top_k,
+         )
+
+         with torch.no_grad():
+             outs = self.model(
+                 input_ids=gen_ids,
+                 output_hidden_states=True,
+                 return_dict=True,
+             )
+         hidden_states = outs.hidden_states[-1]
+
+         # Hidden states after the image start tag condition the diffusion decoder.
+         start_pos = (gen_ids == self.config.image_start_tag_id).float().argmax(dim=1)
+         end_pos = (gen_ids == self.config.image_end_tag_id).float().argmax(dim=1)
+
+         selected_hidden_states = []
+         for b in range(hidden_states.size(0)):
+             start = start_pos[b].item() + 1
+             selected_hidden_states.append(hidden_states[b, start:, :])
+         pred_latent = torch.stack(selected_hidden_states, dim=0)
+
+         # Prepend a zeroed copy of the conditioning for the unconditional branch of classifier-free guidance.
+         img_hidden_states_null = torch.zeros_like(pred_latent)
+         pred_latent = torch.cat([img_hidden_states_null, pred_latent], 0)
+
+         ## sample images from here
+         device = next(self.parameters()).device
+
+         bsz = len(pred_latent) // 2
+         # latent_size = self.config.input_size
+         latent_size = 32
+         latent_channels = self.model.sana.config.in_channels
+
+         latents = randn_tensor(
+             shape=(bsz * num_images_per_prompt, latent_channels, latent_size, latent_size),
+             generator=None,
+             device=device,
+             dtype=torch.bfloat16,
+         )
+
+         # set step values
+         if isinstance(self.model.noise_scheduler, FlowMatchEulerDiscreteScheduler):
+             sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
+             self.model.noise_scheduler.set_timesteps(num_inference_steps, sigmas=sigmas)
+         else:
+             self.model.noise_scheduler.set_timesteps(num_inference_steps)
+
+         for t in tqdm(self.model.noise_scheduler.timesteps, desc="Sampling images", disable=not enable_progress_bar):
+             latent_model_input = torch.cat([latents] * 2)
+             latent_model_input = latent_model_input.to(pred_latent.dtype)
+
+             if hasattr(self.model.noise_scheduler, "scale_model_input"):
+                 latent_model_input = self.model.noise_scheduler.scale_model_input(latent_model_input, t)
+             # predict noise model_output
+             noise_pred = self.model.sana(
+                 hidden_states=latent_model_input,
+                 encoder_hidden_states=self.model.diffusion_connector(pred_latent),
+                 timestep=t.unsqueeze(0).expand(latent_model_input.shape[0]).to(latents.device),
+                 encoder_attention_mask=None,
+             ).sample
+
+             noise_pred_uncond, noise_pred = noise_pred.chunk(2)
+             noise_pred = noise_pred_uncond + guidance_scale * (noise_pred - noise_pred_uncond)
+
+             # compute previous image: x_t -> x_t-1
+             latents = self.model.noise_scheduler.step(noise_pred, t, latents).prev_sample
+
+         samples = self.decode_latents(latents.to(self.model.sana_vae.dtype) if self.model.sana_vae is not None else latents, return_tensor=return_tensor)
+
+         return gen_ids, samples
+
+
+ AutoConfig.register("blip3o_qwen_inference", blip3oQwenConfig)
+ AutoModelForCausalLM.register(blip3oQwenConfig, blip3oQwenForInferenceLM)
+
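
A hedged end-to-end sketch of driving `blip3oQwenForInferenceLM.generate_images`. The checkpoint path, prompt text, and sampling values are placeholders; the real prompt format (including any image-generation tags expected by the tokenizer) follows the repository's own conversation templates, which are not reproduced here.

import torch
from transformers import AutoTokenizer
from blip3o.model.language_model.blip3o_qwen_inference import blip3oQwenForInferenceLM

model_path = "path/to/blip3o-checkpoint"   # placeholder
tokenizer = AutoTokenizer.from_pretrained(model_path)
model = blip3oQwenForInferenceLM.from_pretrained(model_path, torch_dtype=torch.bfloat16).cuda().eval()

prompt = "A photo of a corgi surfing a wave"   # placeholder prompt text
input_ids = tokenizer(prompt, return_tensors="pt").input_ids.cuda()

gen_ids, images = model.generate_images(
    input_ids,
    attention_mask=torch.ones_like(input_ids),
    max_new_tokens=1024,
    temperature=1.0,
    top_p=0.95,
    top_k=50,
    guidance_scale=2.0,
    num_inference_steps=30,
)
images[0].save("sample.png")   # `images` is a list of PIL images from decode_latents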
blip3o/model/multimodal_decoder/__pycache__/builder.cpython-310.pyc ADDED
Binary file (661 Bytes). View file
 
blip3o/model/multimodal_decoder/__pycache__/builder.cpython-311.pyc ADDED
Binary file (954 Bytes). View file
 
blip3o/model/multimodal_decoder/__pycache__/ta_tok_encoder.cpython-310.pyc ADDED
Binary file (3.71 kB). View file
 
blip3o/model/multimodal_decoder/__pycache__/ta_tok_encoder.cpython-311.pyc ADDED
Binary file (6.73 kB). View file
 
blip3o/model/multimodal_decoder/builder.py ADDED
@@ -0,0 +1,14 @@
+ from diffusers import AutoencoderDC, SanaTransformer2DModel
+ import torch
+
+
+ def build_sana(vision_tower_cfg, **kwargs):
+     sana = SanaTransformer2DModel.from_pretrained(vision_tower_cfg.diffusion_name_or_path, subfolder="transformer", torch_dtype=torch.bfloat16)
+     return sana
+
+
+ def build_vae(vision_tower_cfg, **kwargs):
+     vae = AutoencoderDC.from_pretrained(vision_tower_cfg.diffusion_name_or_path, subfolder="vae", torch_dtype=torch.bfloat16)
+     return vae
+
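
A small illustrative sketch of how these builders are called. `SimpleNamespace` stands in for the real model config object, and the path is a placeholder for a SANA diffusers checkpoint that exposes `transformer/` and `vae/` subfolders:

from types import SimpleNamespace
from blip3o.model.multimodal_decoder.builder import build_sana, build_vae

cfg = SimpleNamespace(diffusion_name_or_path="path/to/sana-diffusers-checkpoint")  # placeholder
sana = build_sana(cfg)   # SanaTransformer2DModel used as the latent image decoder
vae = build_vae(cfg)     # AutoencoderDC used to encode/decode image latents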
blip3o/model/multimodal_encoder/__pycache__/builder.cpython-310.pyc ADDED
Binary file (472 Bytes). View file
 
blip3o/model/multimodal_encoder/__pycache__/builder.cpython-311.pyc ADDED
Binary file (639 Bytes). View file
 
blip3o/model/multimodal_encoder/__pycache__/ta_tok_encoder.cpython-310.pyc ADDED
Binary file (3.72 kB). View file
 
blip3o/model/multimodal_encoder/__pycache__/ta_tok_encoder.cpython-311.pyc ADDED
Binary file (6.74 kB). View file