File size: 8,574 Bytes
9b57ce7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
import pdb
from dataclasses import dataclass, field
from typing import Optional, List, Union
import numpy as np
import pandas as pd
import torch
from hpsv3.dataset.utils import process_vision_info
from torch.utils.data import Dataset
import torchvision.transforms.functional as F

# Instruction prompt given to the vision-language reward model together with
# each image. The only format field is `{text_prompt}`, which
# QWen2VLDataCollator._clean_message fills with the sample's generation prompt
# via str.format; any literal braces added here would need to be doubled.
# NOTE(review): this text is part of the model input and presumably must match
# what checkpoints were trained with - confirm before editing its wording
# (including known slips such as "give a overall score").
INSTRUCTION = """
You are tasked with evaluating a generated image based on Visual Quality and Text Alignment and give a overall score to estimate the human preference. Please provide a rating from 0 to 10, with 0 being the worst and 10 being the best. 

**Visual Quality:**  
Evaluate the overall visual quality of the image. The following sub-dimensions should be considered:
- **Reasonableness:** The image should not contain any significant biological or logical errors, such as abnormal body structures or nonsensical environmental setups.
- **Clarity:** Evaluate the sharpness and visibility of the image. The image should be clear and easy to interpret, with no blurring or indistinct areas.
- **Detail Richness:** Consider the level of detail in textures, materials, lighting, and other visual elements (e.g., hair, clothing, shadows).
- **Aesthetic and Creativity:** Assess the artistic aspects of the image, including the color scheme, composition, atmosphere, depth of field, and the overall creative appeal. The scene should convey a sense of harmony and balance.
- **Safety:** The image should not contain harmful or inappropriate content, such as political, violent, or adult material. If such content is present, the image quality and satisfaction score should be the lowest possible. 

**Text Alignment:**  
Assess how well the image matches the textual prompt across the following sub-dimensions:
- **Subject Relevance** Evaluate how accurately the subject(s) in the image (e.g., person, animal, object) align with the textual description. The subject should match the description in terms of number, appearance, and behavior.
- **Style Relevance:** If the prompt specifies a particular artistic or stylistic style, evaluate how well the image adheres to this style.
- **Contextual Consistency**: Assess whether the background, setting, and surrounding elements in the image logically fit the scenario described in the prompt. The environment should support and enhance the subject without contradictions.
- **Attribute Fidelity**: Check if specific attributes mentioned in the prompt (e.g., colors, clothing, accessories, expressions, actions) are faithfully represented in the image. Minor deviations may be acceptable, but critical attributes should be preserved.
- **Semantic Coherence**: Evaluate whether the overall meaning and intent of the prompt are captured in the image. The generated content should not introduce elements that conflict with or distort the original description.
Textual prompt - {text_prompt}


"""

# Bare template containing only the text prompt. It is not referenced by the
# code visible here; the name suggests it was used for debugging - TODO confirm
# whether it is still needed.
INSTRUCTION_debug = """
{text_prompt}
"""

# Closing request selected when use_special_tokens is True; it ends the
# question with the `<|Reward|>` marker followed by an "END" line.
# NOTE(review): presumably `<|Reward|>` is a token added to the tokenizer whose
# position the reward head reads - confirm against the model code.
prompt_with_special_token = """
Please provide the overall ratings of this image: <|Reward|>

END
"""

# Closing request selected instead of prompt_with_special_token when
# use_special_tokens is False (it carries no reward marker token).
prompt_without_special_token = """
Please provide the overall ratings of this image: 
"""


class QWen2VLDataCollator:
    """Collate paired (image, prompt) samples into two processor batches.

    Each input sample is a mapping with keys ``"text_1"``, ``"text_2"``,
    ``"image_1"``, ``"image_2"`` and ``"choice_dist"``. Each side of the pair
    is turned into single-turn chat messages (image followed by instruction
    text), run through ``process_vision_info`` and the processor, and the two
    resulting batches are right-padded to a common sequence length.

    Args:
        processor: Processor exposing ``apply_chat_template``, ``__call__``
            and ``tokenizer.pad_token_id`` (presumably a Qwen2-VL
            ``AutoProcessor`` - confirm against the caller).
        with_instruction: Stored and forwarded to ``_clean_message``; the
            generated message text does not currently depend on it.
        max_pixels: Upper pixel budget attached to every image entry.
        min_pixels: Lower pixel budget attached to every image entry.
        use_special_tokens: If True the text ends with
            ``prompt_with_special_token``, otherwise with
            ``prompt_without_special_token``.
    """

    def __init__(
        self,
        processor,
        with_instruction=True,
        max_pixels=256 * 28 * 28,  # Default max pixels
        min_pixels=256 * 28 * 28,  # Default min pixels
        use_special_tokens=True,
    ):
        self.processor = processor
        self.with_instruction = with_instruction
        self.max_pixels = max_pixels
        self.min_pixels = min_pixels
        self.use_special_tokens = use_special_tokens

    def _clean_message(
        self,
        texts,
        images,
        max_pixels=256 * 28 * 28,
        min_pixels=256 * 28 * 28,
        with_instruction=True,
        use_special_tokens=True,
    ):
        """Build one single-turn chat message per (text, image) pair.

        Every message is a one-element list holding a ``"user"`` turn whose
        content is the image entry (with its pixel budget) followed by the
        instruction text with the chosen closing suffix.

        Args:
            texts: Generation prompts, substituted into ``INSTRUCTION``.
            images: Images paired element-wise with ``texts``; surplus items
                of the longer sequence are ignored (``zip`` semantics).
            max_pixels: Stored as the image entry's ``"max_pixels"``.
            min_pixels: Stored as the image entry's ``"min_pixels"``.
            with_instruction: Accepted for interface compatibility; the
                instruction text is always included.
            use_special_tokens: Select ``prompt_with_special_token`` (True)
                or ``prompt_without_special_token`` (False) as the suffix.

        Returns:
            A list of messages, one per pair, in input order.
        """
        # Pick the suffix on its own. The previous inline form
        # `A + B if cond else C` parsed as `(A + B) if cond else C`, so with
        # use_special_tokens=False the instruction and the prompt were dropped.
        suffix = (
            prompt_with_special_token
            if use_special_tokens
            else prompt_without_special_token
        )
        return [
            [
                {
                    "role": "user",
                    "content": [
                        {
                            "type": "image",
                            "image": image,
                            "min_pixels": min_pixels,
                            "max_pixels": max_pixels,
                        },
                        {
                            "type": "text",
                            "text": INSTRUCTION.format(text_prompt=text) + suffix,
                        },
                    ],
                }
            ]
            for text, image in zip(texts, images)
        ]

    def _pad_sequence(self, sequences, attention_mask, max_len, padding_side="right"):
        """Pad token ids and attention mask along the last dim up to ``max_len``.

        Token positions are filled with the tokenizer's ``pad_token_id`` and
        mask positions with 0. Inputs already at least ``max_len`` long are
        returned unchanged (no truncation). Inputs are expected to be
        ``(batch, seq)`` tensors, since ``shape[1]`` is read as the length.

        Raises:
            ValueError: If ``padding_side`` is not ``"right"`` or ``"left"``.
        """
        if padding_side not in ("right", "left"):
            raise ValueError(
                f"padding_side must be 'right' or 'left', got {padding_side!r}"
            )
        pad_len = max_len - sequences.shape[1]
        if pad_len <= 0:
            return sequences, attention_mask

        padding = (0, pad_len) if padding_side == "right" else (pad_len, 0)
        pad_id = self.processor.tokenizer.pad_token_id
        sequences_padded = torch.nn.functional.pad(
            sequences, padding, "constant", pad_id
        )
        attention_mask_padded = torch.nn.functional.pad(
            attention_mask, padding, "constant", 0
        )
        return sequences_padded, attention_mask_padded

    def _encode_side(self, texts, images):
        """Turn one side of the pairs into processor inputs.

        Returns:
            ``(processor_output, image_arrays)``, where ``image_arrays`` are
            the vision inputs converted with ``np.array`` and divided by 255.
        """
        messages = self._clean_message(
            texts,
            images,
            max_pixels=self.max_pixels,
            min_pixels=self.min_pixels,
            with_instruction=self.with_instruction,
            use_special_tokens=self.use_special_tokens,
        )
        image_inputs, _ = process_vision_info(messages)
        # Rescale here and tell the processor not to rescale a second time.
        image_inputs = [np.array(image) / 255.0 for image in image_inputs]
        encoded = self.processor(
            text=self.processor.apply_chat_template(
                messages, tokenize=False, add_generation_prompt=True
            ),
            images=image_inputs,
            videos=None,
            padding=True,
            return_tensors="pt",
            images_kwargs={"do_rescale": False},
        )
        return encoded, image_inputs

    def __call__(self, inputs, with_instruction=True):
        """Collate a list of paired samples into a batch.

        Args:
            inputs: Sequence of sample mappings with keys ``"text_1"``,
                ``"text_2"``, ``"image_1"``, ``"image_2"`` and
                ``"choice_dist"`` (tensors that ``torch.stack`` can combine).
            with_instruction: Unused; ``self.with_instruction`` is what gets
                forwarded. Kept for call-site compatibility.

        Returns:
            dict with ``"batch_1"``/``"batch_2"`` (processor outputs whose
            ``input_ids``/``attention_mask`` are right-padded to a common
            length), ``"choice_dist"`` (stacked along a new dim 0), the raw
            prompts ``"text_1"``/``"text_2"`` and the rescaled image arrays
            ``"image_1"``/``"image_2"``.
        """
        texts_1 = [sample["text_1"] for sample in inputs]
        texts_2 = [sample["text_2"] for sample in inputs]
        images_1 = [sample["image_1"] for sample in inputs]
        images_2 = [sample["image_2"] for sample in inputs]

        batch_1, image_inputs_1 = self._encode_side(texts_1, images_1)
        batch_2, image_inputs_2 = self._encode_side(texts_2, images_2)

        # The processor pads each side independently; bring both sides to the
        # same sequence length.
        max_len = max(batch_1["input_ids"].shape[1], batch_2["input_ids"].shape[1])
        for side in (batch_1, batch_2):
            side["input_ids"], side["attention_mask"] = self._pad_sequence(
                side["input_ids"], side["attention_mask"], max_len, "right"
            )

        return {
            "batch_1": batch_1,
            "batch_2": batch_2,
            "choice_dist": torch.stack([sample["choice_dist"] for sample in inputs]),
            # Store original text prompts for visualization
            "text_1": texts_1,
            "text_2": texts_2,
            "image_1": image_inputs_1,
            "image_2": image_inputs_2,
        }