File size: 2,536 Bytes
9b57ce7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
import torch
from torch.utils.data import Dataset, DataLoader
import random
import json
import os
from tqdm import tqdm

class PairwiseOriginalDataset(Dataset):
    """Map-style dataset of image pairs that share one text prompt.

    Samples are read from one or more JSON files. Each file must contain a
    list of dicts; the keys read by this class are:

    * ``"path1"``, ``"path2"`` -- paths of the two images (must exist on disk),
    * ``"prompt"`` -- text used for both images,
    * ``"choice_dist"`` -- optional list of numbers; required when
      ``soft_label`` is true,
    * ``"confidence"`` -- optional number used for threshold filtering.

    Items are returned as dicts holding the image *paths* (the images are not
    opened here), the prompt, a scalar ``label`` tensor, the sample confidence
    and the raw choice distribution as a tensor.
    """

    # Upper bound on the number of replacement samples tried by
    # ``__getitem__`` before it gives up, so a broken dataset fails loudly
    # instead of looping forever.
    MAX_RETRIES = 100

    def __init__(
        self,
        json_list,
        soft_label=False,
        confidence_threshold=None,
    ):
        """Load and optionally filter the samples.

        Args:
            json_list: Iterable of JSON file paths. A single path (``str`` or
                ``os.PathLike``) is also accepted and treated as a one-element
                list.
            soft_label: If true, the label is the largest ``choice_dist``
                entry divided by the sum of all entries; otherwise the label
                is always ``1.0``.
            confidence_threshold: If not ``None``, only samples whose
                ``"confidence"`` is at least this value are kept. Samples
                without a ``"confidence"`` key are always kept.
        """
        # Iterating a bare string would yield single characters, so wrap it.
        if isinstance(json_list, (str, os.PathLike)):
            json_list = [json_list]

        self.samples = []
        for json_file in json_list:
            # Explicit encoding: the platform default is locale-dependent.
            with open(json_file, "r", encoding="utf-8") as f:
                self.samples.extend(json.load(f))

        self.soft_label = soft_label
        self.confidence_threshold = confidence_threshold

        if confidence_threshold is not None:
            # A missing confidence defaults to +inf, i.e. the sample passes.
            self.samples = [
                sample
                for sample in tqdm(
                    self.samples,
                    desc="Filtering samples according to confidence threshold",
                )
                if sample.get("confidence", float("inf")) >= confidence_threshold
            ]

    def __len__(self):
        """Return the number of samples left after filtering."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Return the item at ``idx``, falling back to a random other sample.

        If the requested sample cannot be processed, the error is printed and
        a different, randomly chosen sample is tried instead. Each index is
        tried at most once per call.

        Args:
            idx: Integer index; negative values count from the end.

        Returns:
            The dict produced by ``get_single_item``.

        Raises:
            IndexError: If ``idx`` is out of range. This also lets plain
                ``for item in dataset`` iteration terminate.
            RuntimeError: If every sample failed, or ``MAX_RETRIES``
                replacement samples failed in a row. The last underlying
                error is attached as ``__cause__``.
        """
        import operator
        import traceback

        total = len(self.samples)
        current = operator.index(idx)
        if not -total <= current < total:
            raise IndexError(
                f"index {current} is out of range for dataset of size {total}"
            )
        if current < 0:
            current += total

        failed = set()
        last_error = None
        for _ in range(self.MAX_RETRIES + 1):
            try:
                return self.get_single_item(current)
            except Exception as e:
                print(f"Error processing sample at index {current}: {e}")
                traceback.print_exc()
                last_error = e
                failed.add(current)
                if len(failed) >= total:
                    # Nothing left to fall back to.
                    break
                # Draw a replacement that has not already failed; at least
                # one such index exists, so this loop terminates.
                current = random.randint(0, total - 1)
                while current in failed:
                    current = random.randint(0, total - 1)

        raise RuntimeError(
            f"Could not load any sample for index {idx} "
            f"after {len(failed)} failed attempt(s)."
        ) from last_error

    def get_single_item(self, idx):
        """Build the item dict for the sample at ``idx`` without any fallback.

        Args:
            idx: Index into ``self.samples``.

        Returns:
            A dict with keys ``image_1``, ``image_2`` (file paths), ``text_1``,
            ``text_2`` (both the sample prompt), ``label`` (scalar tensor),
            ``confidence`` (``1.0`` if absent from the sample) and
            ``choice_dist`` (tensor, ``[1.0, 0.0]`` if absent).

        Raises:
            FileNotFoundError: If either image path does not exist.
            ValueError: If ``soft_label`` is set and the choice distribution
                does not sum to a positive value.
            KeyError: If a required key is missing from the sample.
        """
        sample = self.samples[idx]

        # Only the paths are returned; existence is checked up front so a
        # broken sample is detected here rather than further downstream.
        image_1 = sample["path1"]
        image_2 = sample["path2"]
        # Explicit raise instead of ``assert`` so the check survives ``-O``.
        missing = [path for path in (image_1, image_2) if not os.path.exists(path)]
        if missing:
            raise FileNotFoundError(f"Image file(s) not found: {missing}")

        # Both images of a pair share the same prompt.
        text_1 = sample["prompt"]
        text_2 = sample["prompt"]

        if self.soft_label:
            # Soft label = share of the most-chosen option.
            dist = torch.tensor(sample["choice_dist"])
            dist_sum = torch.sum(dist)
            # ``not (> 0)`` rather than ``<= 0`` so that NaN is rejected too.
            if not dist_sum > 0:
                raise ValueError("Choice distribution cannot be zero.")
            label = torch.max(dist) / dist_sum
        else:
            label = torch.tensor(1).float()

        return {
            "image_1": image_1,
            "image_2": image_2,
            "text_1": text_1,
            "text_2": text_2,
            "label": label,
            "confidence": sample.get("confidence", 1.0),
            "choice_dist": torch.tensor(sample.get("choice_dist", [1.0, 0.0])),
        }