File size: 26,595 Bytes
e94400c | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 
525 526 527 528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 558 559 560 561 562 563 564 565 566 567 568 569 570 571 572 573 574 575 576 577 578 579 580 581 582 583 584 585 586 587 588 589 590 591 592 593 594 595 596 597 598 599 600 601 602 603 604 605 606 607 608 609 610 611 612 613 614 615 616 617 618 619 620 621 622 623 624 625 626 627 628 629 630 631 632 633 634 635 636 637 638 639 640 641 642 643 644 645 646 647 648 649 650 651 652 653 654 655 656 657 658 659 | import os
import copy
import json
import random
import logging
import re
import time
import math
import itertools
import ast
from dataclasses import dataclass
from typing import Dict, Optional, Sequence, List, Tuple
from io import BytesIO
import base64
from collections.abc import Sequence
from types import SimpleNamespace
import numpy as np
import torch
from torch.utils.data import Dataset
from PIL import Image
from decord import VideoReader
import transformers
from omegaconf import OmegaConf
from starVLA.dataloader.qwenvl_llavajson.qwen_data_config import data_list
from starVLA.dataloader.qwenvl_llavajson.rope2d import get_rope_index_25, get_rope_index_2
# Label value written into `labels` for positions excluded from supervision
# (system/user turns and padding); -100 matches PyTorch's default
# cross-entropy `ignore_index`.
IGNORE_INDEX: int = -100
# Presumably the ids of <|image_pad|> / <|video_pad|> in the Qwen2-VL vocabulary
# (not referenced in this file) — verify against the tokenizer.
IMAGE_TOKEN_INDEX: int = 151655
VIDEO_TOKEN_INDEX: int = 151656
# Placeholder strings for visual inputs; the text "<image>" / "<video>" part is
# what preprocess_qwen_2_visual searches for.
DEFAULT_IMAGE_TOKEN: str = "<image>\n"
DEFAULT_VIDEO_TOKEN: str = "<video>\n"
# Process rank consulted by rank0_print; None (the default here) disables it.
# NOTE(review): presumably assigned by the training launcher — confirm.
local_rank: Optional[int] = None
def rank0_print(*args):
    """Forward ``args`` to ``print`` only when the module-level ``local_rank`` is 0."""
    if local_rank != 0:
        return
    print(*args)
def read_jsonl(path):
    """Load a JSON Lines file into a list of parsed records.

    Blank (whitespace-only) lines are skipped, so an empty line between
    records or at the end of the file does not raise ``json.JSONDecodeError``.

    Args:
        path: path of the ``.jsonl`` file; it is decoded as UTF-8.

    Returns:
        A list with one parsed JSON value per non-blank line, in file order.
    """
    with open(path, "r", encoding="utf-8") as f:
        return [json.loads(line) for line in f if line.strip()]
def preprocess_qwen_2_visual(
    sources,
    tokenizer: transformers.PreTrainedTokenizer,
    grid_thw: Optional[List] = None,
    visual_type: str = "image",
) -> Dict:
    """Tokenize conversations with a plain ChatML template and build SFT labels.

    Every ``<image>`` / ``<video>`` tag inside a user turn is expanded into
    ``<|vision_start|>`` + N ``<|{visual_type}_pad|>`` tokens + ``<|vision_end|>``,
    where N is the next entry of ``grid_thw`` (consumed in order across all
    conversations). The default system prompt and user/system turns are masked
    with ``IGNORE_INDEX``; assistant turns are supervised except their first
    three tokens (assumed to be the ``<|im_start|>assistant\\n`` header — TODO
    confirm for non-Qwen tokenizers).

    A leading assistant turn (``from: "gpt"`` or ``role: "assistant"``) is
    dropped so each conversation starts on the user side.

    Args:
        sources: list of conversations; each is a list of turns using either
            ``from``/``value`` or ``role``/``content`` keys.
        tokenizer: tokenizer whose ``apply_chat_template`` is used; it is not
            modified (the template is passed per call).
        grid_thw: pad-token count for each visual tag, in order. ``None`` means
            no visual tags are expected.
        visual_type: ``"image"`` or ``"video"``; selects tag and pad token.

    Returns:
        dict with ``input_ids`` and ``labels`` LongTensors, one row per source.

    Raises:
        ValueError: if ``visual_type`` is invalid or the conversations contain
            more visual tags than ``grid_thw`` entries.
    """
    roles = {"human": "user", "gpt": "assistant"}
    system_message = "You are a helpful assistant."
    if visual_type not in ["image", "video"]:
        raise ValueError("visual_type must be either 'image' or 'video'")
    grid_thw = [] if grid_thw is None else grid_thw

    # Plain ChatML template, supplied per call instead of deep-copying the
    # tokenizer to overwrite its template (the copy was done on every sample).
    chat_template = "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}"

    def encode(messages):
        return list(tokenizer.apply_chat_template(messages, chat_template=chat_template))

    def role_and_content(turn):
        # Accept both the OpenAI-style (role/content) and ShareGPT-style (from/value) schema.
        if "role" in turn and "content" in turn:
            return turn["role"], turn["content"]
        return turn["from"], turn["value"]

    visual_tag = f"<{visual_type}>"
    visual_pad = f"<|{visual_type}_pad|>"
    visual_replicate_index = 0  # next grid_thw entry to consume, shared across sources
    input_ids, targets = [], []
    for source in sources:
        if source:
            first_role = source[0].get("from", source[0].get("role"))
            if roles.get(first_role, first_role) == "assistant":
                source = source[1:]

        input_id = encode([{"role": "system", "content": system_message}])
        target = [IGNORE_INDEX] * len(input_id)
        for turn in source:
            role, content = role_and_content(turn)
            role = roles.get(role, role)
            if role == "user" and visual_tag in content:
                parts = content.split(visual_tag)
                pieces = [parts[0]]
                for part in parts[1:]:
                    if visual_replicate_index >= len(grid_thw):
                        raise ValueError(
                            f"found more {visual_tag} tags than grid_thw entries ({len(grid_thw)})"
                        )
                    n_pad = int(grid_thw[visual_replicate_index])
                    pieces.append("<|vision_start|>" + visual_pad * n_pad + "<|vision_end|>")
                    pieces.append(part)
                    visual_replicate_index += 1
                content = "".join(pieces)

            encode_id = encode([{"role": role, "content": content}])
            input_id += encode_id
            if role in ["user", "system"]:
                target += [IGNORE_INDEX] * len(encode_id)
            else:
                # Mask the turn header; min() keeps label length equal to token length.
                n_mask = min(3, len(encode_id))
                target += [IGNORE_INDEX] * n_mask + encode_id[n_mask:]
        assert len(input_id) == len(target), f"{len(input_id)} != {len(target)}"
        input_ids.append(input_id)
        targets.append(target)

    return dict(
        input_ids=torch.tensor(input_ids, dtype=torch.long),
        labels=torch.tensor(targets, dtype=torch.long),
    )
class LazySupervisedDataset(Dataset):
    """Dataset for supervised fine-tuning.

    Annotations (JSON or JSONL) of every dataset named in the comma-separated
    ``data_args.dataset_use`` are loaded, optionally subsampled, word-count
    filtered and shuffled at construction. Images/videos are decoded and
    conversations tokenized lazily in ``__getitem__``.
    """

    def __init__(self, tokenizer: transformers.PreTrainedTokenizer, data_args):
        super().__init__()
        dataset = data_args.dataset_use.split(",")
        dataset_list = data_list(dataset)
        rank0_print(f"Loading datasets: {dataset_list}")
        self.video_max_total_pixels = getattr(data_args, "video_max_total_pixels", 1664 * 28 * 28)
        self.video_min_total_pixels = getattr(data_args, "video_min_total_pixels", 256 * 28 * 28)
        self.model_type = data_args.model_type
        if data_args.model_type == "qwen2.5vl":
            self.get_rope_index = get_rope_index_25
        else:
            self.get_rope_index = get_rope_index_2

        list_data_dict = []
        for data in dataset_list:
            annotations = self._load_annotations(data["annotation_path"])
            sampling_rate = data.get("sampling_rate", 1.0)
            if sampling_rate < 1.0:
                annotations = random.sample(annotations, int(len(annotations) * sampling_rate))
                print(f"sampling {len(annotations)} examples from dataset {data}")
            else:
                rank0_print(f"dataset name: {data}")
            for ann in annotations:
                # A per-dataset data_path wins; otherwise fall back to the annotation's raw_data root.
                if data["data_path"] != "":
                    ann["data_path"] = data["data_path"]
                elif "raw_data" in ann.keys():
                    ann["data_path"] = ann["raw_data"]["data_root"]
            list_data_dict += annotations

        list_data_dict = self.pre_filter_long_case(list_data_dict, max_words=tokenizer.max_len_single_sentence)
        random.shuffle(list_data_dict)  # Randomly shuffle the data for training
        self.tokenizer = tokenizer
        self.list_data_dict = list_data_dict
        self.data_args = data_args
        rank0_print(f"Total training samples: {len(self.list_data_dict)}")
        rank0_print("Formatting inputs...Skip in lazy mode")

    @staticmethod
    def _load_annotations(annotation_path):
        """Read one annotation file: ``.jsonl`` line by line, anything else as a single JSON document."""
        if annotation_path.split(".")[-1] == "jsonl":
            return read_jsonl(annotation_path)
        with open(annotation_path, "r") as f:
            return json.load(f)

    @staticmethod
    def _count_words(convs):
        """Whitespace-split word count over every turn's ``value`` (turns without one count 0)."""
        return sum(len(conv.get("value", "").split()) for conv in convs)

    def __len__(self):
        return len(self.list_data_dict)

    def pre_filter_long_case(self, list_data_dict, max_words=1024):
        """Filter out conversations whose total word count exceeds ``max_words``.

        Words are a cheap proxy for tokens; visual pad tokens are not counted.
        """
        return [
            item for item in list_data_dict if self._count_words(item.get("conversations", [])) <= max_words
        ]

    @property
    def lengths(self):
        """Approximate length per sample: word count plus a flat 128 for samples with images."""
        length_list = []
        for sample in self.list_data_dict:
            # Annotations in this loader use "images"; "image" is kept for older schemas.
            img_tokens = 128 if ("images" in sample or "image" in sample) else 0
            length_list.append(self._count_words(sample["conversations"]) + img_tokens)
        return length_list

    @property
    def modality_lengths(self):
        """Word count per sample, negated for text-only samples (no "images"/"videos" key)."""
        length_list = []
        for sample in self.list_data_dict:
            cur_len = self._count_words(sample["conversations"])
            cur_len = cur_len if ("images" in sample) or ("videos" in sample) else -cur_len
            length_list.append(cur_len)
        return length_list

    @property
    def pre_calculated_length(self):
        """``num_tokens`` of every sample if the first sample has it, else an array of ones."""
        if self.list_data_dict and "num_tokens" in self.list_data_dict[0]:
            return np.array([sample["num_tokens"] for sample in self.list_data_dict])
        print("No pre-calculated length available.")
        return np.array([1] * len(self.list_data_dict))

    def process_image_unified(self, image_file):
        """Load one image as RGB and run the image processor on it.

        Returns:
            (pixel_values tensor, image_grid_thw row) from the processor output.
        """
        processor = copy.deepcopy(self.data_args.image_processor)
        # convert() returns a new in-memory image, so the file can be closed right away.
        with Image.open(image_file) as img:
            image = img.convert("RGB")
        if getattr(self.data_args, "fix_image_size", None) is not None:
            image = image.resize(
                self.data_args.fix_image_size,
                resample=Image.BICUBIC,
            )
        visual_processed = processor.preprocess(image, return_tensors="pt")
        image_tensor = visual_processed["pixel_values"]
        if isinstance(image_tensor, list):
            image_tensor = image_tensor[0]
        grid_thw = visual_processed["image_grid_thw"][0]
        return image_tensor, grid_thw

    def process_video(self, video_file):
        """Uniformly sample frames from a video and run the processor in video mode.

        About one frame per ``base_interval`` seconds is sampled, clamped to
        ``[video_min_frames, video_max_frames]``.

        Returns:
            (pixel_values_videos tensor, video_grid_thw row, second_per_grid_ts list)

        Raises:
            FileNotFoundError: if ``video_file`` does not exist.
        """
        if not os.path.exists(video_file):
            raise FileNotFoundError(f"File not exist: {video_file}")
        vr = VideoReader(video_file, num_threads=4)
        total_frames = len(vr)
        avg_fps = vr.get_avg_fps()
        video_length = total_frames / avg_fps
        interval = getattr(self.data_args, "base_interval", 4)
        num_frames_to_sample = round(video_length / interval)
        video_min_frames = getattr(self.data_args, "video_min_frames", 4)
        video_max_frames = getattr(self.data_args, "video_max_frames", 8)
        target_frames = min(max(num_frames_to_sample, video_min_frames), video_max_frames)
        frame_idx = np.unique(np.linspace(0, total_frames - 1, target_frames, dtype=int))
        video = vr.get_batch(frame_idx).asnumpy()
        fps = len(frame_idx) / video_length
        # Copy before overriding pixel limits so the shared image processor is untouched.
        processor = copy.deepcopy(self.data_args.image_processor)
        processor.max_pixels = self.data_args.video_max_frame_pixels
        processor.min_pixels = self.data_args.video_min_frame_pixels
        processor.size["longest_edge"] = processor.max_pixels
        processor.size["shortest_edge"] = processor.min_pixels
        video_processed = processor.preprocess(images=None, videos=video, return_tensors="pt")
        video_tensor = video_processed["pixel_values_videos"]
        grid_thw = video_processed["video_grid_thw"][0]
        second_per_grid_ts = [self.data_args.image_processor.temporal_patch_size / fps] * len(grid_thw)
        return video_tensor, grid_thw, second_per_grid_ts

    def __getitem__(self, i) -> Dict[str, torch.Tensor]:
        """Fetch sample ``i`` with retries, falling back to following samples on repeated failure."""
        num_base_retries = 3
        # Try the current sample first; sleep in case it is a transient (cloud disk) issue.
        for attempt_idx in range(num_base_retries):
            try:
                return self._get_item(i)
            except Exception as e:
                print(f"[Try #{attempt_idx}] Failed to fetch sample {i}. Exception:", e)
                time.sleep(1)
        # Try different following samples (wrapping around), in case sample i is corrupted.
        num_samples = len(self.list_data_dict)
        for attempt_idx in range(num_base_retries):
            next_index = (i + attempt_idx + 1) % num_samples
            try:
                return self._get_item(next_index)
            except Exception as e:
                print(
                    f"[Try other #{attempt_idx}] Failed to fetch sample {next_index}. Exception:",
                    e,
                )
        # Final attempt on the original index; any exception propagates to the caller.
        return self._get_item(i)

    def _get_item(self, i) -> Dict[str, torch.Tensor]:
        """Load media, tokenize and compute rope position ids for sample ``i``."""
        sources = self.list_data_dict[i]
        if isinstance(i, int):
            sources = [sources]
        assert len(sources) == 1, "Don't know why it is wrapped to a list"  # FIXME

        # Modality actually loaded for this sample. It decides both preprocessing
        # and which pixel tensors are attached, so a sample with an empty
        # "images"/"videos" list is treated consistently as text-only.
        modality = None
        image = video = grid_thw = second_per_grid_ts = None
        if "images" in sources[0] and len(sources[0]["images"]):
            modality = "image"
            image_folder = self.list_data_dict[i]["data_path"]
            image_files = self.list_data_dict[i]["images"]
            if not isinstance(image_files, list):
                image_files = [image_files]
            results = [self.process_image_unified(os.path.join(image_folder, f)) for f in image_files]
            image = [pixels for pixels, _ in results]
            grid_thw = [thw for _, thw in results]
        elif "videos" in sources[0] and len(sources[0]["videos"]):
            modality = "video"
            video_folder = self.list_data_dict[i]["data_path"]
            video_files = self.list_data_dict[i]["videos"]
            if not isinstance(video_files, list):
                video_files = [video_files]
            results = [self.process_video(os.path.join(video_folder, f)) for f in video_files]
            video = [pixels for pixels, _, _ in results]
            grid_thw = [thw for _, thw, _ in results]
            # One seconds-per-temporal-grid value per video (each process_video list
            # repeats a single value). Assumes get_rope_index indexes this by
            # video — TODO confirm in rope2d.
            second_per_grid_ts = [ts[0] for _, _, ts in results]

        grid_thw_merged = None
        if grid_thw is not None:
            merge_size = self.data_args.image_processor.merge_size
            # Pad tokens per visual input after merge_size x merge_size spatial merging.
            grid_thw_merged = [thw.prod() // merge_size**2 for thw in grid_thw]

        conversations = copy.deepcopy([e["conversations"] for e in sources])
        data_dict = preprocess_qwen_2_visual(
            conversations, self.tokenizer, grid_thw=grid_thw_merged, visual_type=modality or "image"
        )

        if modality == "image":
            position_ids, _ = self.get_rope_index(
                self.data_args.image_processor.merge_size,
                data_dict["input_ids"],
                torch.stack(grid_thw, dim=0),
            )
        elif modality == "video":
            position_ids, _ = self.get_rope_index(
                self.data_args.image_processor.merge_size,
                data_dict["input_ids"],
                video_grid_thw=torch.stack(grid_thw, dim=0),
                second_per_grid_ts=second_per_grid_ts,
            )
        else:
            # Text only: same sequential position on all 3 rope axes, shape (3, 1, seq_len).
            position_ids = torch.arange(0, data_dict["input_ids"].size(1)).view(1, -1).unsqueeze(0).expand(3, -1, -1)

        if isinstance(i, int):
            data_dict = dict(
                input_ids=data_dict["input_ids"][0],
                labels=data_dict["labels"][0],
                position_ids=position_ids,
            )
        if modality == "image":
            data_dict["pixel_values"] = image
            data_dict["image_grid_thw"] = grid_thw
        elif modality == "video":
            data_dict["pixel_values_videos"] = video
            data_dict["video_grid_thw"] = grid_thw

        # NOTE(review): truncation can cut through visual pad tokens, leaving fewer
        # pads than pixel patches; the word-count pre-filter only mitigates this.
        max_len = self.tokenizer.max_len_single_sentence
        if data_dict["input_ids"].shape[0] > max_len:
            data_dict["input_ids"] = data_dict["input_ids"][:max_len]
            data_dict["labels"] = data_dict["labels"][:max_len]
            data_dict["position_ids"] = position_ids[:, :, :max_len]
        return data_dict
def pad_and_cat(tensor_list):
    """Right-pad 3-D tensors on their last axis with the value 1, then concatenate on axis 1.

    Used for rope position ids such as the ``(3, 1, seq_len)`` tensors built by
    the dataset; ``n`` such inputs give a ``(3, n, max_seq_len)`` result.
    """
    target_len = max(t.shape[2] for t in tensor_list)
    return torch.cat(
        [torch.nn.functional.pad(t, (0, target_len - t.shape[2]), "constant", 1) for t in tensor_list],
        dim=1,
    )
@dataclass
class DataCollatorForSupervisedDataset(object):
    """Collate examples for supervised fine-tuning into a padded batch.

    ``input_ids`` are padded with ``pad_token_id`` and ``labels`` with
    ``IGNORE_INDEX`` on ``tokenizer.padding_side``. Position ids are padded with
    1 on that same side, so they stay aligned with their tokens. Visual tensors
    and grid sizes are concatenated across the batch, or set to None when absent.
    """

    tokenizer: transformers.PreTrainedTokenizer

    def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
        input_ids, labels, position_ids = tuple(
            [instance[key] for instance in instances] for key in ("input_ids", "labels", "position_ids")
        )
        padding_side = self.tokenizer.padding_side
        input_ids = torch.nn.utils.rnn.pad_sequence(
            input_ids,
            batch_first=True,
            padding_value=self.tokenizer.pad_token_id,
            padding_side=padding_side,
        )
        labels = torch.nn.utils.rnn.pad_sequence(
            labels, batch_first=True, padding_value=IGNORE_INDEX, padding_side=padding_side
        )
        if padding_side == "left":
            # pad_and_cat pads on the right; reversing the sequence axis before and
            # after turns that into left padding matching input_ids/labels.
            position_ids = pad_and_cat([p.flip(-1) for p in position_ids]).flip(-1)
        else:
            position_ids = pad_and_cat(position_ids)

        # NOTE(review): samples are already truncated to max_len_single_sentence in
        # the dataset, so these slices are normally no-ops.
        input_ids = input_ids[:, : self.tokenizer.model_max_length]
        labels = labels[:, : self.tokenizer.model_max_length]
        position_ids = position_ids[..., : self.tokenizer.model_max_length]  # 3,bs,length
        batch = dict(
            input_ids=input_ids,
            labels=labels,
            attention_mask=input_ids.ne(self.tokenizer.pad_token_id),
        )

        images = list(
            itertools.chain(*(instance["pixel_values"] for instance in instances if "pixel_values" in instance))
        )
        videos = list(
            itertools.chain(
                *(instance["pixel_values_videos"] for instance in instances if "pixel_values_videos" in instance)
            )
        )
        if len(images) != 0:
            concat_images = torch.cat(images, dim=0)
            grid_thw = list(
                itertools.chain(*(instance["image_grid_thw"] for instance in instances if "image_grid_thw" in instance))
            )
            grid_thw = torch.stack(grid_thw, dim=0)
        else:
            concat_images = None
            grid_thw = None
        if len(videos) != 0:
            concat_videos = torch.cat(videos, dim=0)
            video_grid_thw = list(
                itertools.chain(*(instance["video_grid_thw"] for instance in instances if "video_grid_thw" in instance))
            )
            video_grid_thw = torch.stack(video_grid_thw, dim=0)
        else:
            concat_videos = None
            video_grid_thw = None
        batch["pixel_values"] = concat_images
        batch["image_grid_thw"] = grid_thw
        batch["pixel_values_videos"] = concat_videos
        batch["video_grid_thw"] = video_grid_thw
        batch["position_ids"] = position_ids
        return batch
@dataclass
class FlattenedDataCollatorForSupervisedDataset(DataCollatorForSupervisedDataset):
    """Pack a list of examples into one flat sequence (leading batch dim of 1) with multi-modal tensors.

    Instead of a boolean mask, ``attention_mask`` carries the int32 cumulative
    sequence boundaries ``[0, len_0, len_0 + len_1, ...]``. Position ids are
    joined along their last axis.
    """

    tokenizer: transformers.PreTrainedTokenizer

    def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
        token_seqs = [inst["input_ids"] for inst in instances]
        label_seqs = [inst["labels"] for inst in instances]
        pos_seqs = [inst["position_ids"] for inst in instances]

        lengths = torch.tensor([0] + [len(seq) for seq in token_seqs], dtype=torch.int32)
        boundaries = torch.cumsum(lengths, dim=0, dtype=torch.int32)

        batch = {
            "input_ids": torch.cat(token_seqs, dim=0).unsqueeze(0),
            "labels": torch.cat(label_seqs, dim=0).unsqueeze(0),
            "attention_mask": boundaries,
            "position_ids": torch.cat(pos_seqs, dim=2),
        }

        image_chunks = [pv for inst in instances if "pixel_values" in inst for pv in inst["pixel_values"]]
        video_chunks = [
            pv for inst in instances if "pixel_values_videos" in inst for pv in inst["pixel_values_videos"]
        ]

        if image_chunks:
            image_thw = [thw for inst in instances if "image_grid_thw" in inst for thw in inst["image_grid_thw"]]
            batch["pixel_values"] = torch.cat(image_chunks, dim=0)
            batch["image_grid_thw"] = torch.stack(image_thw, dim=0)
        else:
            batch["pixel_values"] = None
            batch["image_grid_thw"] = None

        if video_chunks:
            video_thw = [thw for inst in instances if "video_grid_thw" in inst for thw in inst["video_grid_thw"]]
            batch["pixel_values_videos"] = torch.cat(video_chunks, dim=0)
            batch["video_grid_thw"] = torch.stack(video_thw, dim=0)
        else:
            batch["pixel_values_videos"] = None
            batch["video_grid_thw"] = None

        return batch
def make_supervised_data_module(tokenizer: transformers.PreTrainedTokenizer, data_args) -> Dict:
    """Build the training dataset, an optional eval dataset and the batch collator.

    An eval dataset is created only when ``data_args.eval_dataset`` is set and
    truthy; it reuses a deep copy of ``data_args`` with ``dataset_use`` replaced.
    ``data_args.data_flatten`` selects the packing collator over the padding one.

    Returns:
        dict with ``train_dataset``, ``eval_dataset`` (possibly None) and ``data_collator``.
    """
    train_dataset = LazySupervisedDataset(tokenizer=tokenizer, data_args=data_args)

    eval_dataset = None
    eval_spec = getattr(data_args, "eval_dataset", None)
    if eval_spec:
        eval_data_args = copy.deepcopy(data_args)
        eval_data_args.dataset_use = eval_spec
        eval_dataset = LazySupervisedDataset(tokenizer=tokenizer, data_args=eval_data_args)

    collator_cls = (
        FlattenedDataCollatorForSupervisedDataset if data_args.data_flatten else DataCollatorForSupervisedDataset
    )
    return {
        "train_dataset": train_dataset,
        "eval_dataset": eval_dataset,
        "data_collator": collator_cls(tokenizer=tokenizer),
    }
def make_vlm_dataloader(cfg):
    """Build the VLM training dataloader from an OmegaConf config.

    Reads the processor/tokenizer checkpoint from ``cfg.framework.qwenvl.base_vlm``
    and data options from ``cfg.datasets.vlm_data`` (``model_max_length``,
    ``max_pixels``, ``min_pixels``, ``per_device_batch_size`` and, optionally,
    ``num_workers`` — default 4).

    Returns:
        ``{"train_dataloader": DataLoader}``.
    """
    # Local imports: AutoProcessor is otherwise only imported further down the module.
    from torch.utils.data import DataLoader
    from transformers import AutoProcessor

    data_args = cfg.datasets.vlm_data
    image_processor = AutoProcessor.from_pretrained(
        cfg.framework.qwenvl.base_vlm,
    ).image_processor
    tokenizer = transformers.AutoTokenizer.from_pretrained(
        cfg.framework.qwenvl.base_vlm,
        model_max_length=data_args.model_max_length,
        padding_side="left",  # flash Attention version of Qwen2.5_VL. Make sure to call `tokenizer.padding_side = 'left'` before tokenizing the input.
        use_fast=False,
    )
    # Configure pixel limits once here so the dataset does not have to.
    max_pixels = int(data_args.max_pixels)
    min_pixels = int(data_args.min_pixels)
    image_processor.max_pixels = max_pixels
    image_processor.min_pixels = min_pixels
    image_processor.size["longest_edge"] = max_pixels
    image_processor.size["shortest_edge"] = min_pixels

    data_args_ns = SimpleNamespace(**OmegaConf.to_container(data_args, resolve=True))
    data_args_ns.image_processor = image_processor  # TODO later remove the logic bound to model
    data_module = make_supervised_data_module(tokenizer=tokenizer, data_args=data_args_ns)

    train_dataloader = DataLoader(
        data_module["train_dataset"],
        batch_size=data_args.per_device_batch_size,
        collate_fn=data_module["data_collator"],
        num_workers=getattr(data_args_ns, "num_workers", 4),
    )
    return {
        "train_dataloader": train_dataloader,
    }
from transformers import AutoTokenizer, AutoProcessor
if __name__ == "__main__":
    # Debug entry point: waits for a debugger, builds the VLM dataloader in the
    # main process and iterates up to 101 batches.
    import argparse
    from itertools import islice

    import debugpy
    from torch.utils.data import DataLoader

    parser = argparse.ArgumentParser()
    parser.add_argument("--config_yaml", type=str, default="./examples/LIBERO/train_files/starvla_cotrain_libero.yaml", help="Path to YAML config")
    args, _unknown_args = parser.parse_known_args()

    debugpy.listen(("0.0.0.0", 10092))
    print("🔍 Rank 0 waiting for debugger attach on port 10092...")
    debugpy.wait_for_client()

    cfg = OmegaConf.load(args.config_yaml)
    data_args = cfg.datasets.vlm_data
    image_processor = AutoProcessor.from_pretrained(
        cfg.framework.qwenvl.base_vlm,
    ).image_processor
    tokenizer = transformers.AutoTokenizer.from_pretrained(
        cfg.framework.qwenvl.base_vlm,
        model_max_length=data_args.model_max_length,
        padding_side="left",
        use_fast=False,
    )
    # Cast to int exactly as make_vlm_dataloader does.
    image_processor.max_pixels = int(data_args.max_pixels)
    image_processor.min_pixels = int(data_args.min_pixels)
    image_processor.size["longest_edge"] = int(data_args.max_pixels)
    image_processor.size["shortest_edge"] = int(data_args.min_pixels)
    data_args_ns = SimpleNamespace(**OmegaConf.to_container(data_args, resolve=True))
    data_args_ns.image_processor = image_processor
    data_module = make_supervised_data_module(tokenizer=tokenizer, data_args=data_args_ns)

    train_dataset = data_module["train_dataset"]
    data_collator = data_module["data_collator"]
    # Default num_workers=0: samples are loaded in this process, so breakpoints in the dataset hit.
    train_dataloader = DataLoader(
        train_dataset,
        batch_size=cfg.datasets.vlm_data.per_device_batch_size,
        collate_fn=data_collator,
    )
    batchs = iter(train_dataloader)
    batch_samples = next(batchs)
    # Up to 100 more batches; islice stops cleanly if the dataset is shorter.
    for count, batch_samples in enumerate(islice(batchs, 100)):  # for debug
        print(count)
|