Spaces:
Running on Zero
Running on Zero
File size: 12,135 Bytes
import argparse
import logging
import os
import sys
from dataclasses import dataclass
from enum import Enum
from typing import Dict, List, Protocol, Union
import torch
import torch.nn.utils.rnn as rnn_utils
# TODO: ykarnati, use absolute import or
# define train_valid_test_dataloaders_provider in here
# Extend the import search path with the ``examples/multimodal`` directory
# located three levels above this file's directory.
# NOTE(review): presumably this is what makes ``dataloader_provider``
# importable -- confirm against the repository layout.
sys.path.append(
    os.path.abspath(
        os.path.join(
            os.path.dirname(__file__),
            *([os.path.pardir] * 3),
            "examples/multimodal",
        )
    )
)
from dataloader_provider import train_valid_test_dataloaders_provider
from transformers import AutoProcessor
from megatron.energon import (
DefaultTaskEncoder,
VQASample,
WorkerConfig,
get_loader,
get_train_dataset,
)
from megatron.energon.task_encoder.base import stateless
from megatron.training import get_args
from megatron.training.tokenizer.multimodal_tokenizer import mistral_custom_template
@dataclass
class ConversationTemplateConfig:
    """Optional overrides used when rendering a conversation prompt.

    Both fields default to ``None``, which means "no override".
    """

    # System message; when set it is inserted as the first conversation turn.
    system: Union[str, None] = None
    # Chat template; when set it is assigned to the processor before rendering.
    chat_template: Union[str, None] = None
@dataclass
class LlavaConversationTemplateConfig(ConversationTemplateConfig):
    """Conversation template settings used for Llava training.

    Both fields are left at ``None``, so no system message and no chat
    template override are configured by default.
    """

    # No system message by default.
    system: Union[str, None] = None
    # No chat template override by default.
    chat_template: Union[str, None] = None
class ModelType(Enum):
    """Model variants selectable when building batches."""

    # Selects the image batch layout (reads ``pixel_values``).
    LLAVA_VLM = "llava_vlm"
    # Selects the video batch layout (reads ``pixel_values_videos``).
    VIDEO_LLAVA_VLM = "video_llava_vlm"
class VLMTaskEncoder(
    DefaultTaskEncoder[
        Union[VQASample],
        dict,
        dict,
        dict,
    ]
):
    """Task encoder that turns ``VQASample`` items into model-ready batches.

    The pipeline is: :meth:`encode_sample` (prompt rendering, tokenisation,
    label construction) -> :meth:`batch` (pad/stack) -> :meth:`encode_batch`
    (final dict layout selected by ``model_type``).
    """

    def __init__(
        self,
        model_type: ModelType,
        processor,
        conversation_template_config=None,
    ):
        """Store the encoder configuration.

        Args:
            model_type: Selects the batch layout produced by :meth:`encode_batch`.
            processor: Object called to tokenise prompt + image; it must also
                expose ``apply_chat_template`` and a ``tokenizer`` attribute.
            conversation_template_config: Optional object with ``system`` and
                ``chat_template`` attributes used when rendering the prompt.
        """
        # NOTE(review): the base-class ``__init__`` is not invoked here;
        # confirm that ``DefaultTaskEncoder`` needs no initialisation.
        self.model_type = model_type
        self.processor = processor
        self.conversation_template_config = conversation_template_config

    @staticmethod
    def _as_list(value):
        """Return ``value`` unchanged if it is a list, else wrap it in one."""
        return value if isinstance(value, list) else [value]

    def apply_prompt_template(self, input_text: VQASample):
        """Render the sample's turns into a prompt string via the chat template.

        ``input_text.context`` provides the user turns and
        ``input_text.answers`` the assistant turns; each may be a single value
        or a list. Turns are paired positionally with ``zip``, so surplus
        entries on either side are ignored. Every turn is emitted as a
        text-only content item; this method does not add an image placeholder.

        Returns:
            The un-tokenised prompt string produced by
            ``processor.apply_chat_template``.
        """
        user_msgs = self._as_list(input_text.context)
        bot_msgs = self._as_list(input_text.answers)

        conversation = []
        for u_txt, b_txt in zip(user_msgs, bot_msgs):
            conversation.append(
                {
                    "role": "user",
                    "content": [{"type": "text", "text": u_txt}],
                }
            )
            conversation.append(
                {
                    "role": "assistant",
                    "content": [{"type": "text", "text": b_txt}],
                }
            )

        config = self.conversation_template_config

        # Inject optional system message
        if config and config.system:
            conversation.insert(0, {"role": "system", "content": config.system})

        # Select chat template. This overwrites the attribute on the shared
        # processor object, so the override persists for later calls.
        if config and config.chat_template:
            self.processor.chat_template = config.chat_template

        return self.processor.apply_chat_template(
            conversation,
            tokenize=False,
            add_generation_prompt=False,
        )

    def _find_pattern_indices(
        self, template, pattern, start_idx=0, allow_first_mismatch=False
    ):
        """Locate the first occurrence of ``pattern`` inside ``template``.

        Both arguments are sliced and compared element-wise (assumed to be
        1-D tensors -- TODO confirm).

        Args:
            template: Sequence to search in.
            pattern: Sequence to search for.
            start_idx: First index of ``template`` to consider.
            allow_first_mismatch: When true, a window also counts as a match
                if only its first element differs from ``pattern``.

        Returns:
            ``(start, end)`` with ``end`` exclusive, or ``(-1, -1)`` when no
            window matches.
        """
        template_len = len(template)
        pat_len = len(pattern)
        for i in range(start_idx, template_len - pat_len + 1):
            match = template[i : i + pat_len] == pattern
            if torch.all(match) or (allow_first_mismatch and torch.all(match[1:])):
                return i, i + pat_len
        return -1, -1

    @stateless
    def encode_sample(self, sample: VQASample):
        """Return tokenised multimodal sample.

        The processor output is extended with ``labels`` and ``loss_mask``.
        When the sample has answers, only answer tokens keep their id in
        ``labels`` (everything else is ``-100``) and ``input_ids``/``labels``
        are shifted against each other by one position. Without answers both
        entries are ``None`` and ``input_ids`` is left unshifted.

        Raises:
            ValueError: If an answer's token sequence cannot be found in
                ``input_ids`` after the previous answer.
        """
        # Build prompt
        prompt = self.apply_prompt_template(sample)
        logging.debug("prompt: %s", prompt)

        # Process image + prompt
        inputs = self.processor(
            images=getattr(sample, "image", None),
            text=prompt,
            add_special_tokens=False,
            return_tensors="pt",
            do_rescale=False,
        )

        # HF processor returns a dict with batch dim
        # Remove batch dim
        for key in list(inputs.keys()):
            inputs[key] = inputs[key].squeeze(0)

        answers = sample.answers
        if not answers:
            inputs["labels"] = None
            inputs["loss_mask"] = None
            return inputs

        if not isinstance(answers, list):
            answers = [answers]
        tokenizer = self.processor.tokenizer
        input_ids = inputs["input_ids"]
        labels = torch.full_like(input_ids, fill_value=-100)

        # Answers are searched left to right; each search resumes where the
        # previous answer ended so repeated answers map to distinct spans.
        search_idx = 0
        for ans in answers:
            answer_tokens = tokenizer.encode(
                ans, add_special_tokens=False, return_tensors="pt"
            )[0]
            s_idx, e_idx = self._find_pattern_indices(
                input_ids, answer_tokens, search_idx
            )
            if s_idx == -1:
                raise ValueError(f"Answer not found in input_ids: {ans}")
            labels[s_idx:e_idx] = input_ids[s_idx:e_idx]
            search_idx = e_idx

        # shift inputs and labels by 1
        # NOTE(review): only ``input_ids`` and ``labels`` are shifted; any
        # other per-token entry returned by the processor keeps its length.
        inputs["input_ids"] = input_ids[:-1]
        inputs["labels"] = labels[1:]
        inputs["loss_mask"] = (inputs["labels"] != -100).long()
        return inputs

    def batch(self, samples: List[Dict]) -> Dict:
        """Pad/stack individual samples into a single batch dict.

        The keys of the first sample define the keys of the batch. Per key,
        ``None`` entries are dropped and the remaining values are combined by
        the processor registered in ``KEY_PROCESSORS``; unregistered keys are
        stacked when they hold tensors and returned as a plain list otherwise.
        A key that is ``None`` (or absent) in every sample is batched as
        ``None`` instead of being passed to a processor as an empty list.

        Returns:
            The batch dict, or ``{}`` when ``samples`` is empty.
        """
        if not samples:
            return {}

        batched: Dict = {}
        for key in samples[0].keys():
            values = [s[key] for s in samples if key in s and s[key] is not None]

            # Nothing to combine (e.g. ``labels`` for answer-less samples):
            # stacking or padding an empty list would raise.
            if not values:
                batched[key] = None
                continue

            processor = KEY_PROCESSORS.get(key)
            if processor is not None:
                batched[key] = processor(values)
                continue

            # Fallback behaviours if no specific processor is registered.
            if isinstance(values[0], torch.Tensor):
                batched[key] = torch.stack(values, dim=0)
            else:
                batched[key] = values
        return batched

    def _assemble_model_inputs(self, batch_data: Dict, pixel_key: str) -> Dict:
        """Build the model input dict shared by the image and video layouts.

        Args:
            batch_data: Batch dict; must contain ``input_ids`` and may contain
                ``labels``, ``loss_mask`` and ``pixel_key``.
            pixel_key: Key of ``batch_data`` holding the visual input.

        Returns:
            Dict with ``input_ids``, ``labels``, ``loss_mask`` and
            ``position_ids``; ``modality_inputs`` is added only when the
            visual input is present.
        """
        input_ids = batch_data["input_ids"]

        # Positions 0..seq_len-1, repeated once per row of ``input_ids``.
        seq_len = input_ids.size(1)
        position_ids = torch.arange(seq_len, dtype=torch.long, device=input_ids.device)
        position_ids = position_ids.unsqueeze(0).repeat(input_ids.size(0), 1)

        output = {
            "input_ids": input_ids,
            "labels": batch_data.get("labels"),
            "loss_mask": batch_data.get("loss_mask"),
            "position_ids": position_ids,
        }

        pixel_values = batch_data.get(pixel_key)
        if pixel_values is not None:
            output["modality_inputs"] = {
                "images": {"clip_encoder": {"pixel_values": pixel_values}}
            }
        return output

    def encode_batch_vlm_clip_llava(self, batch_data: Dict) -> Dict:
        """Build model inputs for the image model, reading ``pixel_values``."""
        return self._assemble_model_inputs(batch_data, "pixel_values")

    def encode_batch_vlm_clip_llava_video(self, batch_data: Dict) -> Dict:
        """Build model inputs for the video model, reading ``pixel_values_videos``."""
        return self._assemble_model_inputs(batch_data, "pixel_values_videos")

    def encode_batch(self, batch_data: Dict) -> dict:
        """Dispatch to the batch layout matching ``self.model_type``.

        Raises:
            ValueError: If ``self.model_type`` is not a supported ``ModelType``.
        """
        if self.model_type is ModelType.LLAVA_VLM:
            return self.encode_batch_vlm_clip_llava(batch_data)
        if self.model_type is ModelType.VIDEO_LLAVA_VLM:
            return self.encode_batch_vlm_clip_llava_video(batch_data)
        raise ValueError(f"Model type {self.model_type} not supported")
def llava_vlm_dataloader_provider(train_val_test_num_samples, is_video_input=False):
    """Create the train/valid/test dataloaders backed by a ``VLMTaskEncoder``.

    Args:
        train_val_test_num_samples: Forwarded unchanged to
            ``train_valid_test_dataloaders_provider``.
        is_video_input: Selects ``ModelType.VIDEO_LLAVA_VLM`` when true,
            ``ModelType.LLAVA_VLM`` otherwise.

    Returns:
        Whatever ``train_valid_test_dataloaders_provider`` returns.
    """
    # The processor is loaded from the ``tokenizer_model`` training argument.
    processor = AutoProcessor.from_pretrained(get_args().tokenizer_model)

    model_type = ModelType.VIDEO_LLAVA_VLM if is_video_input else ModelType.LLAVA_VLM

    task_encoder = VLMTaskEncoder(
        model_type=model_type,
        processor=processor,
        conversation_template_config=LlavaConversationTemplateConfig(),
    )
    return train_valid_test_dataloaders_provider(
        train_val_test_num_samples,
        task_encoder=task_encoder,
    )
# -----------------------------------------------------------------------------
# Key processing utilities for batching
# -----------------------------------------------------------------------------
# These definitions precede the script entry point on purpose: the entry point
# iterates a data loader, which looks up ``KEY_PROCESSORS`` while batching, so
# the registry must already exist when the script body runs.


class KeyProcessor(Protocol):
    """Callable that aggregates a list of tensors into a single batched tensor."""

    def __call__(self, values: List[torch.Tensor]) -> torch.Tensor:  # pragma: no cover
        ...


class StackProcessor:
    """Simply stack tensors along a given dimension."""

    def __init__(self, dim: int = 0):
        """Remember the dimension along which tensors are stacked."""
        self.dim = dim

    def __call__(self, values: List[torch.Tensor]) -> torch.Tensor:
        """Return ``values`` stacked along ``self.dim``."""
        return torch.stack(values, dim=self.dim)


class PaddingProcessor:
    """Pad variable-length sequences to the same length."""

    def __init__(self, pad_value: int, batch_first: bool = True):
        """Remember the fill value and whether the batch dimension comes first."""
        self.pad_value = pad_value
        self.batch_first = batch_first

    def __call__(self, values: List[torch.Tensor]) -> torch.Tensor:
        """Return ``values`` padded with ``self.pad_value`` and combined into one tensor."""
        return rnn_utils.pad_sequence(
            values, batch_first=self.batch_first, padding_value=self.pad_value
        )


# Registry mapping sample keys to their corresponding processor.
KEY_PROCESSORS: Dict[str, KeyProcessor] = {
    "pixel_values": StackProcessor(),
    "pixel_values_videos": StackProcessor(),
    "input_ids": PaddingProcessor(pad_value=0),
    "attention_mask": PaddingProcessor(pad_value=0),
    "loss_mask": PaddingProcessor(pad_value=0),
    # -100 matches the "ignore" fill value used when labels are built.
    "labels": PaddingProcessor(pad_value=-100),
}


def _main() -> None:
    """Load one training batch from an energon dataset and print a summary.

    Expects ``--data_path`` on the command line, pointing at a dataset
    directory in energon format.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--data_path",
        type=str,
        required=True,
        help="path to the dataset directory in energon format",
    )
    args = parser.parse_args()

    model_name = "llava-hf/llava-1.5-7b-hf"
    processor = AutoProcessor.from_pretrained(model_name)
    worker_config = WorkerConfig.default_worker_config(0)
    train_loader = get_loader(
        get_train_dataset(
            args.data_path,
            batch_size=8,
            shuffle_buffer_size=None,
            max_samples_per_sequence=None,
            task_encoder=VLMTaskEncoder(
                model_type=ModelType.LLAVA_VLM,
                processor=processor,
                conversation_template_config=LlavaConversationTemplateConfig(),
            ),
            worker_config=worker_config,
        ),
        worker_config=worker_config,
    )

    print(f"data loader length {len(train_loader)}")
    # Only the first batch is inspected.
    for index, each_batch in enumerate(train_loader):
        pixel_values = each_batch["modality_inputs"]["images"]["clip_encoder"]["pixel_values"]
        print(
            f"batch index {index} tokens {each_batch['input_ids']} "
            f"images shape {pixel_values.shape}"
        )
        break


if __name__ == "__main__":
    _main()
|