# Copyright (c) Meta Platforms, Inc. and affiliates.
import re
from dataclasses import dataclass, field
from logging import getLogger
from typing import Any, Dict, List, Optional

import torch

from core.distributed import get_is_master
logger = getLogger()
@dataclass
class MLLMBatch:
    """One padded multimodal training batch.

    ``x`` / ``y`` are (bsz, seq_len) int64 token tensors (inputs and
    next-token labels); the optional fields carry the vision side.
    """

    x: torch.LongTensor  # input token ids, shape (bsz, seq_len)
    y: torch.LongTensor  # label token ids, same shape/dtype as x
    mask: Optional[torch.BoolTensor] = None  # True where y holds a real target
    image_pos_index: Optional[torch.LongTensor] = None  # image-slot index per position, -1 elsewhere
    images: Optional[torch.Tensor] = None  # concatenated image chunks, None for text-only batches
    # BUG FIX: the old default `(["text"],)` was a *tuple* containing a list
    # (stray trailing comma, apparently to dodge the dataclass mutable-default
    # check), which contradicts the List[str] annotation and is shared across
    # instances. default_factory gives each batch its own fresh list.
    media_type: Optional[List[str]] = field(default_factory=lambda: ["text"])
    num_image_chunks: Optional[List[int]] = None  # chunks per image, aligned with `images`

    def __post_init__(self):
        # Validate shapes/dtypes early so downstream code can assume them.
        assert self.x.dim() == 2, "{} != 2".format(self.x.dim())
        assert self.x.shape == self.y.shape
        assert self.x.dtype == torch.int64
        assert self.y.dtype == torch.int64
        assert self.mask is None or self.mask.shape == self.x.shape
class BaseCollator:
    """Common interface for batch collators.

    Holds the tokenizer and a flag controlling whether the first collated
    batch gets logged; concrete subclasses implement ``__call__``.
    """

    def __init__(self, tokenizer, show_first_batch: bool = False) -> None:
        self.tokenizer = tokenizer
        self.first_batch = show_first_batch

    def __call__(self, features: List[Dict[str, Any]]):
        # Subclasses must turn the list of feature dicts into a batch.
        raise NotImplementedError
class MllmPaddingCollator(BaseCollator):
    """Collate tokenized multimodal samples into a fixed-length, padded MLLMBatch.

    Each feature dict is expected to provide: "text_ids" (list of token ids),
    "media" (image tensor or None), "response_pos" (positions of supervised
    tokens), "image_pos" (positions of image tokens), "num_image_chunks",
    and "media_type".
    """

    # Special tokens that may repeat many times in decoded text; runs of them
    # are collapsed to "<tok>..xN" when logging the first batch.
    # (Hoisted from prettify_decoded_text so the list isn't rebuilt per call.)
    _REPEATED_TOKENS = ("<|end_of_text|>", "<|image|>")

    def prettify_decoded_text(self, texts: List[str]) -> List[str]:
        """
        Prettify decoded text by replacing consecutive repeated special
        tokens (e.g. <|image|>) with a shortened "<tok>..xN" form.
        """
        prettified = []
        for text in texts:
            for token in self._REPEATED_TOKENS:
                # Regex matching a run of two or more consecutive tokens.
                pattern = f"({re.escape(token)})\\1+"

                def replace_consecutive(match, token=token):
                    # Run length = matched span divided by one token's length.
                    count = len(match.group(0)) // len(token)
                    return f"{token}..x{count}"

                text = re.sub(pattern, replace_consecutive, text)
            prettified.append(text)
        return prettified

    def __call__(self, features: List[Dict[str, Any]]) -> MLLMBatch:
        """Pad and label-shift a list of samples into one (bsz, seq_len) MLLMBatch."""
        text, images, media_type = [], [], []
        response_pos, image_pos, num_image_chunks = [], [], []
        for b in features:
            text.append(b["text_ids"])
            images.append(b["media"])
            response_pos.append(b["response_pos"])
            image_pos.append(b["image_pos"])
            num_image_chunks.append(b["num_image_chunks"])
            media_type.append(b["media_type"])

        # Text-only samples contribute no media; concatenate what remains.
        images = [img for img in images if img is not None]
        images = torch.cat(images) if images else None

        bsz = len(text)
        seq_len = self.tokenizer.seq_len
        pad_id = self.tokenizer.pad_token_id
        # Explicit dtype: MLLMBatch.__post_init__ asserts int64, so don't
        # rely on torch.full's inference from the fill value.
        input_ids = torch.full((bsz, seq_len), pad_id, dtype=torch.long)
        label_ids = torch.full((bsz, seq_len), pad_id, dtype=torch.long)
        # -1 marks "no image token here"; otherwise the image-slot index.
        image_pos_index = torch.full((bsz, seq_len), -1, dtype=torch.long)
        for i in range(bsz):
            # Shift labels left by one to train next-token prediction:
            # the target for input position j-1 is the token at position j.
            for j in response_pos[i]:
                label_ids[i][j - 1] = text[i][j]
            # Drop the last token from the input (it has no label to predict).
            text_len = len(text[i]) - 1
            input_ids[i][:text_len] = torch.tensor(text[i][:-1])
            if image_pos[i]:
                # NOTE(review): indices restart at 0 for every row even though
                # `images` is concatenated batch-wide — presumably the consumer
                # combines this with num_image_chunks; confirm.
                image_pos_index[i, image_pos[i]] = torch.arange(len(image_pos[i]))

        # Loss mask: positions that received a real (non-pad) label.
        mask = label_ids.ne(pad_id)
        # Replace remaining pad tokens with eos so the model never sees pad ids.
        input_ids[input_ids == pad_id] = self.tokenizer.eos_token_id
        label_ids[label_ids == pad_id] = self.tokenizer.eos_token_id

        if self.first_batch and get_is_master():
            # One-shot debug dump of the first batch on the master rank.
            input_decoded = self.tokenizer.decode_batch(input_ids)
            label_decoded = self.tokenizer.decode_batch(label_ids)
            logger.info(f"Input text: \n{self.prettify_decoded_text(input_decoded)}")
            logger.info(f"Label text: \n{self.prettify_decoded_text(label_decoded)}")
            self.first_batch = False

        return MLLMBatch(
            x=input_ids,
            y=label_ids,
            mask=mask,
            image_pos_index=image_pos_index,
            images=images,
            media_type=media_type,
            num_image_chunks=num_image_chunks,
        )