Spaces:
Running on Zero
Running on Zero
| import argparse | |
| import logging | |
| import os | |
| import sys | |
| from dataclasses import dataclass | |
| from enum import Enum | |
| from typing import Dict, List, Protocol, Union | |
| import torch | |
| import torch.nn.utils.rnn as rnn_utils | |
| # TODO: ykarnati, use absolute import or | |
| # define train_valid_test_dataloaders_provider in here | |
| sys.path.append( | |
| os.path.abspath( | |
| os.path.join( | |
| os.path.dirname(__file__), | |
| os.path.pardir, | |
| os.path.pardir, | |
| os.path.pardir, | |
| "examples/multimodal", | |
| ) | |
| ) | |
| ) | |
| from dataloader_provider import train_valid_test_dataloaders_provider | |
| from transformers import AutoProcessor | |
| from megatron.energon import ( | |
| DefaultTaskEncoder, | |
| VQASample, | |
| WorkerConfig, | |
| get_loader, | |
| get_train_dataset, | |
| ) | |
| from megatron.energon.task_encoder.base import stateless | |
| from megatron.training import get_args | |
| from megatron.training.tokenizer.multimodal_tokenizer import mistral_custom_template | |
| class ConversationTemplateConfig: | |
| system: str = None | |
| chat_template: str = None | |
| class LlavaConversationTemplateConfig(ConversationTemplateConfig): | |
| """Default system prompt and chat template for Llava training.""" | |
| system: str = None | |
| chat_template: str = None | |
| class ModelType(Enum): | |
| LLAVA_VLM = "llava_vlm" | |
| VIDEO_LLAVA_VLM = "video_llava_vlm" | |
| class VLMTaskEncoder( | |
| DefaultTaskEncoder[ | |
| Union[VQASample], | |
| dict, | |
| dict, | |
| dict, | |
| ] | |
| ): | |
| def __init__( | |
| self, | |
| model_type: ModelType, | |
| processor, | |
| conversation_template_config=None, | |
| ): | |
| self.model_type = model_type | |
| self.processor = processor | |
| self.conversation_template_config = conversation_template_config | |
| def apply_prompt_template(self, input_text: VQASample): | |
| """Create conversation prompt string using HF chat template. | |
| The first user turn always contains an image placeholder, later turns are text-only. | |
| Returns a *prompt string* that can be fed into the processor together with an image. | |
| """ | |
| user_msgs = input_text.context | |
| bot_msgs = input_text.answers | |
| def _ensure_list_type(value): | |
| if isinstance(value, list): | |
| return value | |
| return [value] | |
| user_msgs = _ensure_list_type(user_msgs) | |
| bot_msgs = _ensure_list_type(bot_msgs) | |
| conversation = [] | |
| for _, (u_txt, b_txt) in enumerate(zip(user_msgs, bot_msgs)): | |
| conversation.append( | |
| { | |
| "role": "user", | |
| "content": [{"type": "text", "text": u_txt}], | |
| } | |
| ) | |
| conversation.append( | |
| { | |
| "role": "assistant", | |
| "content": [{"type": "text", "text": b_txt}], | |
| } | |
| ) | |
| # Inject optional system message | |
| if ( | |
| self.conversation_template_config | |
| and self.conversation_template_config.system | |
| ): | |
| conversation.insert( | |
| 0, | |
| {"role": "system", "content": self.conversation_template_config.system}, | |
| ) | |
| # Select chat template | |
| if ( | |
| self.conversation_template_config | |
| and self.conversation_template_config.chat_template | |
| ): | |
| self.processor.chat_template = ( | |
| self.conversation_template_config.chat_template | |
| ) | |
| return self.processor.apply_chat_template( | |
| conversation, | |
| tokenize=False, | |
| add_generation_prompt=False, | |
| ) | |
| def _find_pattern_indices( | |
| self, template, pattern, start_idx=0, allow_first_mismatch=False | |
| ): | |
| template_len = len(template) | |
| pat_len = len(pattern) | |
| for i in range(start_idx, template_len - pat_len + 1): | |
| match = template[i : i + pat_len] == pattern | |
| if torch.all(match) or (allow_first_mismatch and torch.all(match[1:])): | |
| return i, i + pat_len | |
| return -1, -1 | |
| def encode_sample(self, sample: VQASample): | |
| """Return tokenised multimodal sample.""" | |
| # Build prompt | |
| prompt = self.apply_prompt_template(sample) | |
| logging.debug(f"prompt: {prompt}") | |
| # Process image + prompt | |
| inputs = self.processor( | |
| images=getattr(sample, "image", None), | |
| text=prompt, | |
| add_special_tokens=False, | |
| return_tensors="pt", | |
| do_rescale=False, | |
| ) | |
| # HF processor returns a dict with batch dim | |
| # Remove batch dim | |
| for k, v in inputs.items(): | |
| inputs[k] = v.squeeze(0) | |
| answers = sample.answers | |
| if answers: | |
| if not isinstance(answers, list): | |
| answers = [answers] | |
| tokenizer = self.processor.tokenizer | |
| inputs["labels"] = torch.full_like(inputs["input_ids"], fill_value=-100) | |
| search_idx = 0 | |
| for ans in answers: | |
| answer_tokens = tokenizer.encode( | |
| ans, add_special_tokens=False, return_tensors="pt" | |
| )[0] | |
| s_idx, e_idx = self._find_pattern_indices( | |
| inputs["input_ids"], answer_tokens, search_idx | |
| ) | |
| if s_idx == -1: | |
| raise ValueError(f"Answer not found in input_ids: {ans}") | |
| inputs["labels"][s_idx:e_idx] = inputs["input_ids"][s_idx:e_idx] | |
| search_idx = e_idx | |
| # shift inputs and labels by 1 | |
| inputs["input_ids"] = inputs["input_ids"][:-1] | |
| inputs["labels"] = inputs["labels"][1:] | |
| inputs["loss_mask"] = (inputs["labels"] != -100).long() | |
| else: | |
| inputs["labels"] = None | |
| inputs["loss_mask"] = None | |
| return inputs | |
| def batch(self, samples: List[Dict]) -> Dict: | |
| """Pad/stack individual samples into a single batch dict.""" | |
| if not samples: | |
| return {} | |
| batched: Dict[str, torch.Tensor] = {} | |
| keys = samples[0].keys() | |
| for key in keys: | |
| values = [s[key] for s in samples if key in s and s[key] is not None] | |
| processor = KEY_PROCESSORS.get(key) | |
| if processor is not None: | |
| batched[key] = processor(values) | |
| continue | |
| # Fallback behaviours if no specific processor is registered. | |
| if isinstance(values[0], torch.Tensor): | |
| batched[key] = torch.stack(values, dim=0) | |
| else: | |
| batched[key] = values | |
| return batched | |
| def encode_batch_vlm_clip_llava(self, batch_data: Dict) -> Dict: | |
| input_ids = batch_data["input_ids"] | |
| labels = batch_data.get("labels") | |
| loss_mask = batch_data.get("loss_mask") | |
| seq_len = input_ids.size(1) | |
| position_ids = torch.arange(seq_len, dtype=torch.long, device=input_ids.device) | |
| position_ids = position_ids.unsqueeze(0).repeat(input_ids.size(0), 1) | |
| pixel_values = batch_data.get("pixel_values") | |
| output = { | |
| "input_ids": input_ids, | |
| "labels": labels, | |
| "loss_mask": loss_mask, | |
| "position_ids": position_ids, | |
| } | |
| if pixel_values is not None: | |
| output["modality_inputs"] = { | |
| "images": {"clip_encoder": {"pixel_values": pixel_values}} | |
| } | |
| return output | |
| def encode_batch_vlm_clip_llava_video(self, batch_data: Dict) -> Dict: | |
| input_ids = batch_data["input_ids"] | |
| labels = batch_data.get("labels") | |
| loss_mask = batch_data.get("loss_mask") | |
| seq_len = input_ids.size(1) | |
| position_ids = torch.arange(seq_len, dtype=torch.long, device=input_ids.device) | |
| position_ids = position_ids.unsqueeze(0).repeat(input_ids.size(0), 1) | |
| pixel_values_videos = batch_data.get("pixel_values_videos") | |
| output = { | |
| "input_ids": input_ids, | |
| "labels": labels, | |
| "loss_mask": loss_mask, | |
| "position_ids": position_ids, | |
| } | |
| if pixel_values_videos is not None: | |
| output["modality_inputs"] = { | |
| "images": {"clip_encoder": {"pixel_values": pixel_values_videos}} | |
| } | |
| return output | |
| def encode_batch(self, batch_data: Dict) -> dict: | |
| if self.model_type is ModelType.LLAVA_VLM: | |
| return self.encode_batch_vlm_clip_llava(batch_data) | |
| elif self.model_type is ModelType.VIDEO_LLAVA_VLM: | |
| return self.encode_batch_vlm_clip_llava_video(batch_data) | |
| else: | |
| raise ValueError(f"Model type {self.model_type} not supported") | |
| def llava_vlm_dataloader_provider(train_val_test_num_samples, is_video_input=False): | |
| args = get_args() | |
| tokenizer_model_id = args.tokenizer_model | |
| processor = AutoProcessor.from_pretrained(tokenizer_model_id) | |
| if is_video_input: | |
| model_type = ModelType.VIDEO_LLAVA_VLM | |
| else: | |
| model_type = ModelType.LLAVA_VLM | |
| return train_valid_test_dataloaders_provider( | |
| train_val_test_num_samples, | |
| task_encoder=VLMTaskEncoder( | |
| model_type=model_type, | |
| processor=processor, | |
| conversation_template_config=LlavaConversationTemplateConfig(), | |
| ), | |
| ) | |
| if __name__ == "__main__": | |
| parser = argparse.ArgumentParser() | |
| parser.add_argument( | |
| "--data_path", | |
| type=str, | |
| required=True, | |
| help="path to the dataset directory in energon format", | |
| ) | |
| args = parser.parse_args() | |
| model_name = "llava-hf/llava-1.5-7b-hf" | |
| processor = AutoProcessor.from_pretrained(model_name) | |
| worker_config = WorkerConfig.default_worker_config(0) | |
| train_loader = get_loader( | |
| get_train_dataset( | |
| args.data_path, | |
| batch_size=8, | |
| shuffle_buffer_size=None, | |
| max_samples_per_sequence=None, | |
| task_encoder=VLMTaskEncoder( | |
| model_type=ModelType.LLAVA_VLM, | |
| processor=processor, | |
| conversation_template_config=LlavaConversationTemplateConfig(), | |
| ), | |
| worker_config=worker_config, | |
| ), | |
| worker_config=worker_config, | |
| ) | |
| print(f"data loader length {len(train_loader)}") | |
| for index, each_batch in enumerate(train_loader): | |
| print( | |
| f"batch index {index} tokens {each_batch['input_ids']} images shape \ | |
| {each_batch['modality_inputs']['images']['clip_encoder']['pixel_values'].shape}" | |
| ) | |
| break | |
| # ----------------------------------------------------------------------------- | |
| # Key processing utilities for batching | |
| # ----------------------------------------------------------------------------- | |
| class KeyProcessor(Protocol): | |
| """Callable that aggregates a list of tensors into a single batched tensor.""" | |
| def __call__(self, values: List[torch.Tensor]) -> torch.Tensor: # pragma: no cover | |
| ... | |
| class StackProcessor: | |
| """Simply stack tensors along a given dimension.""" | |
| def __init__(self, dim: int = 0): | |
| self.dim = dim | |
| def __call__(self, values: List[torch.Tensor]) -> torch.Tensor: | |
| return torch.stack(values, dim=self.dim) | |
| class PaddingProcessor: | |
| """Pad variable-length sequences to the same length.""" | |
| def __init__(self, pad_value: int, batch_first: bool = True): | |
| self.pad_value = pad_value | |
| self.batch_first = batch_first | |
| def __call__(self, values: List[torch.Tensor]) -> torch.Tensor: | |
| return rnn_utils.pad_sequence( | |
| values, batch_first=self.batch_first, padding_value=self.pad_value | |
| ) | |
| # Registry mapping sample keys to their corresponding processor. | |
| KEY_PROCESSORS: Dict[str, KeyProcessor] = { | |
| "pixel_values": StackProcessor(), | |
| "pixel_values_videos": StackProcessor(), | |
| "input_ids": PaddingProcessor(pad_value=0), | |
| "attention_mask": PaddingProcessor(pad_value=0), | |
| "loss_mask": PaddingProcessor(pad_value=0), | |
| "labels": PaddingProcessor(pad_value=-100), | |
| } | |