File size: 4,549 Bytes
3cf4fff
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
# Copyright (c) Meta Platforms, Inc. and affiliates.

import re
from dataclasses import dataclass, field
from logging import getLogger
from typing import Any, Dict, List, Optional

import torch

from core.distributed import get_is_master

logger = getLogger()


@dataclass
class MLLMBatch:
    """
    A collated multimodal batch.

    ``__post_init__`` validates that ``x`` is 2-D, that ``y`` has the same
    shape as ``x``, that both are ``torch.int64``, and that ``mask`` (when
    given) has the same shape as ``x``. Violations raise ``AssertionError``.
    """

    # Input token ids; must be 2-D and int64.
    x: torch.LongTensor
    # Target token ids; same shape as ``x`` and int64.
    y: torch.LongTensor
    # Optional mask; when present its shape must equal ``x.shape``.
    mask: Optional[torch.BoolTensor] = None
    # Optional per-position image index (presumably -1 where no image token
    # sits, as the collator in this module fills it) — TODO confirm for
    # other producers.
    image_pos_index: Optional[torch.LongTensor] = None
    # Optional image tensor (all samples' media concatenated by the collator).
    images: Optional[torch.Tensor] = None
    # Media type per sample. default_factory gives each instance its own
    # ``["text"]`` list instead of one shared (tuple-wrapped) object.
    media_type: Optional[List[str]] = field(default_factory=lambda: ["text"])
    # Optional number of image chunks per sample.
    num_image_chunks: Optional[List[int]] = None

    def __post_init__(self):
        """Validate tensor ranks, shapes and dtypes."""
        assert self.x.dim() == 2, "{} != 2".format(self.x.dim())
        assert self.x.shape == self.y.shape
        assert self.x.dtype == torch.int64
        assert self.y.dtype == torch.int64
        assert self.mask is None or self.mask.shape == self.x.shape


class BaseCollator:
    """
    Common base for collators that turn per-sample feature dicts into a batch.

    Holds the tokenizer and a one-shot "show the first batch" flag;
    concrete subclasses must override ``__call__``.
    """

    def __init__(self, tokenizer, show_first_batch: bool = False) -> None:
        # One-shot flag a subclass may consult (and clear) to log its first batch.
        self.first_batch = show_first_batch
        # Tokenizer handed to subclasses for building and decoding batches.
        self.tokenizer = tokenizer

    def __call__(self, features: List[Dict[str, Any]]):
        """Collate ``features`` into a batch. Subclasses must implement this."""
        raise NotImplementedError


class MllmPaddingCollator(BaseCollator):
    """
    Collate multimodal samples into fixed-length, padded ``MLLMBatch`` objects.

    Every row is padded to ``self.tokenizer.seq_len``. Labels are populated
    only at response positions, and padding is finally rewritten to the
    tokenizer's EOS id.
    """

    def prettify_decoded_text(self, texts: List[str]) -> List[str]:
        """
        Shorten runs of repeated special tokens in decoded text for logging.

        Two or more consecutive copies of ``<|end_of_text|>`` or ``<|image|>``
        are replaced by ``"<token>..x<count>"``. For example, three
        ``<|image|>`` tokens in a row become ``"<|image|>..x3"``. A single
        occurrence is left unchanged.

        Args:
            texts: Decoded strings, one per sample.

        Returns:
            A new list of the same length with the runs collapsed.
        """
        special_tokens = ["<|end_of_text|>", "<|image|>"]
        # Compile once per call rather than once per (text, token) pair.
        # The backreference ``\1+`` demands at least two consecutive copies.
        patterns = [
            (token, re.compile(f"({re.escape(token)})\\1+"))
            for token in special_tokens
        ]

        prettified = []
        for text in texts:
            for token, pattern in patterns:
                # Bind ``token`` as a default so the callback never late-binds.
                text = pattern.sub(
                    lambda match, tok=token: (
                        f"{tok}..x{len(match.group(0)) // len(tok)}"
                    ),
                    text,
                )
            prettified.append(text)
        return prettified

    def __call__(self, features: List[Dict[str, Any]]) -> MLLMBatch:
        """
        Collate per-sample feature dicts into a padded ``MLLMBatch``.

        Each feature dict is read for the keys ``"text_ids"``, ``"media"``,
        ``"response_pos"``, ``"image_pos"``, ``"num_image_chunks"`` and
        ``"media_type"``.

        Layout of each row ``i`` (all tensors are ``(len(features),
        tokenizer.seq_len)``):

        * ``x``: ``text_ids`` without its last token, followed by padding.
        * ``y``: for each ``j`` in ``response_pos``, position ``j - 1`` holds
          ``text_ids[j]``. Every other position is padding.
        * ``mask``: True where ``y`` differs from ``pad_token_id``, computed
          before padding is rewritten.
        * ``image_pos_index``: 0..k-1 at the ``k`` entries of ``image_pos``,
          -1 elsewhere.

        Padding in ``x`` and ``y`` is then replaced with ``eos_token_id``.
        ``images`` is the concatenation of all non-None ``media`` entries,
        or None when there are none.

        On the first call, if ``show_first_batch`` was set, the decoded
        inputs/labels are logged on the master rank.
        """
        text = [b["text_ids"] for b in features]
        response_pos = [b["response_pos"] for b in features]
        image_pos = [b["image_pos"] for b in features]
        num_image_chunks = [b["num_image_chunks"] for b in features]
        media_type = [b["media_type"] for b in features]

        media = [b["media"] for b in features if b["media"] is not None]
        images = torch.cat(media) if media else None

        bsz = len(text)
        seq_len = self.tokenizer.seq_len
        pad_id = self.tokenizer.pad_token_id
        input_ids = torch.full((bsz, seq_len), pad_id)
        label_ids = torch.full((bsz, seq_len), pad_id)
        image_pos_index = torch.full((bsz, seq_len), -1)

        for i, (ids, resp_pos, img_pos) in enumerate(
            zip(text, response_pos, image_pos)
        ):
            # Shift labels for next-token prediction: slot j - 1 targets token j.
            # NOTE(review): j == 0 would write to index -1 (the last column);
            # presumably response positions are always >= 1 — confirm upstream.
            for j in resp_pos:
                label_ids[i][j - 1] = ids[j]
            # Inputs drop the final token, mirroring the label shift above.
            text_len = len(ids) - 1
            input_ids[i][:text_len] = torch.tensor(ids[:-1])
            if img_pos:
                # Number this sample's image-token slots 0..k-1.
                image_pos_index[i, img_pos] = torch.arange(len(img_pos))

        # The mask must be taken before padding is overwritten with EOS below.
        mask = label_ids.ne(pad_id)

        eos_id = self.tokenizer.eos_token_id
        input_ids[input_ids == pad_id] = eos_id
        label_ids[label_ids == pad_id] = eos_id

        if self.first_batch:
            # Clear the flag on every rank; otherwise non-master ranks would
            # keep calling get_is_master() on every subsequent batch.
            self.first_batch = False
            if get_is_master():
                input_decoded = self.tokenizer.decode_batch(input_ids)
                label_decoded = self.tokenizer.decode_batch(label_ids)
                logger.info(f"Input text: \n{self.prettify_decoded_text(input_decoded)}")
                logger.info(f"Label text: \n{self.prettify_decoded_text(label_decoded)}")

        return MLLMBatch(
            x=input_ids,
            y=label_ids,
            mask=mask,
            image_pos_index=image_pos_index,
            images=images,
            media_type=media_type,
            num_image_chunks=num_image_chunks,
        )