|
|
import torch |
|
|
from torch.utils.data import Dataset, DataLoader |
|
|
import random |
|
|
import json |
|
|
import os |
|
|
from tqdm import tqdm |
|
|
|
|
|
class PairwiseOriginalDataset(Dataset):
    """Pairwise image-preference dataset backed by one or more JSON files.

    Each JSON file must contain a list of sample dicts. The keys read here are:

    * ``"path1"`` / ``"path2"`` -- paths to the two images of the pair; both
      must exist on disk when the sample is loaded.
    * ``"prompt"`` -- text used for both images of the pair.
    * ``"choice_dist"`` -- per-choice vote counts/weights. Required when
      ``soft_label=True``; otherwise optional (defaults to ``[1.0, 0.0]``).
    * ``"confidence"`` -- optional; used by ``confidence_threshold`` and
      returned as-is (defaults to ``1.0``).

    Items are dicts holding image *paths* (not decoded images), the prompts,
    and a scalar float32 ``label``. Image decoding presumably happens
    downstream (collate function / model) -- confirm with callers.
    """

    def __init__(
        self,
        json_list,
        soft_label=False,
        confidence_threshold=None,
        max_retries=50,
    ):
        """Load and optionally filter samples.

        Args:
            json_list: Iterable of JSON file paths, or a single path
                (``str`` / ``os.PathLike``). Each file must contain a list.
            soft_label: If True, ``label`` is ``max(choice_dist) / sum(choice_dist)``
                instead of the constant 1.0.
            confidence_threshold: If not None, keep only samples whose
                ``"confidence"`` is >= this value. Samples without a
                ``"confidence"`` key are always kept.
            max_retries: How many fallback samples ``__getitem__`` tries after
                a sample fails to load before giving up with ``RuntimeError``.

        Raises:
            ValueError: If a JSON file's root is not a list, or
                ``max_retries`` is negative.
        """
        if max_retries < 0:
            raise ValueError(f"max_retries must be >= 0, got {max_retries}")

        # Accept a single path as a convenience; iterating a str would
        # otherwise try to open each character as a file.
        if isinstance(json_list, (str, os.PathLike)):
            json_list = [json_list]

        self.samples = []
        for json_file in json_list:
            # JSON is UTF-8 by spec; don't depend on the platform's locale.
            with open(json_file, "r", encoding="utf-8") as f:
                data = json.load(f)
            if not isinstance(data, list):
                raise ValueError(
                    f"Expected a JSON list of samples in {json_file}, "
                    f"got {type(data).__name__}"
                )
            self.samples.extend(data)

        self.soft_label = soft_label
        self.confidence_threshold = confidence_threshold
        self.max_retries = int(max_retries)

        if confidence_threshold is not None:
            # Missing confidence defaults to +inf so such samples always pass.
            self.samples = [
                sample
                for sample in tqdm(
                    self.samples,
                    desc="Filtering samples according to confidence threshold",
                )
                if sample.get("confidence", float("inf")) >= confidence_threshold
            ]

    def __len__(self):
        """Return the number of (post-filter) samples."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Return sample ``idx``, falling back to other samples on load errors.

        Out-of-range indices raise ``IndexError`` (so plain iteration over the
        dataset terminates). If loading fails, the error is printed and a
        different, randomly chosen sample is tried, up to ``self.max_retries``
        times.

        Raises:
            IndexError: If ``idx`` is outside ``[-len(self), len(self))``.
            RuntimeError: If the requested sample and every fallback sample
                failed to load; chained to the last underlying error.
        """
        import traceback

        num_samples = len(self.samples)
        # Check the range before the retry loop: otherwise the IndexError
        # from the list lookup would be swallowed and a random sample returned.
        if not -num_samples <= idx < num_samples:
            raise IndexError(
                f"Index {idx} out of range for dataset of size {num_samples}"
            )

        index = idx
        last_error = None
        for _ in range(self.max_retries + 1):
            try:
                return self.get_single_item(index)
            except Exception as e:  # deliberate best-effort boundary
                print(f"Error processing sample at index {index}: {e}")
                traceback.print_exc()
                last_error = e
                index = self._random_other_index(index)

        raise RuntimeError(
            f"Failed to load sample at index {idx} and {self.max_retries} "
            f"fallback sample(s)"
        ) from last_error

    def _random_other_index(self, index):
        """Pick a uniformly random index different from ``index`` (if possible)."""
        num_samples = len(self.samples)
        if num_samples <= 1:
            return index
        current = index % num_samples  # normalise negative indices
        # Draw from n-1 slots and skip over `current`, so the known-bad
        # sample is never retried immediately.
        candidate = random.randrange(num_samples - 1)
        return candidate + 1 if candidate >= current else candidate

    def get_single_item(self, idx):
        """Build the item dict for ``self.samples[idx]`` with no fallback.

        Returns:
            dict with keys ``image_1``/``image_2`` (paths), ``text_1``/``text_2``
            (both the sample's prompt), ``label`` (0-d float32 tensor),
            ``confidence`` (``sample["confidence"]`` or 1.0), and
            ``choice_dist`` (float32 tensor, default ``[1.0, 0.0]``).

        Raises:
            FileNotFoundError: If either image path does not exist.
            ValueError: In soft-label mode, if ``choice_dist`` sums to <= 0.
            KeyError: If a required key is missing from the sample.
        """
        sample = self.samples[idx]

        image_1 = sample["path1"]
        image_2 = sample["path2"]
        missing = [path for path in (image_1, image_2) if not os.path.exists(path)]
        if missing:
            # Real exception instead of `assert`, which is stripped under -O.
            raise FileNotFoundError(
                "Image file(s) not found: " + ", ".join(str(p) for p in missing)
            )

        # Both images of a pair are paired with the same prompt.
        text_1 = text_2 = sample["prompt"]

        if self.soft_label:
            dist = torch.tensor(sample["choice_dist"], dtype=torch.float32)
            total = dist.sum()
            if total <= 0:  # also covers an empty distribution
                raise ValueError("Choice distribution cannot be zero.")
            # Fraction of weight on the most-chosen option, whichever it is.
            label = dist.max() / total
        else:
            # NOTE(review): constant 1.0 presumably means "path1 is preferred";
            # confirm against the data producer.
            label = torch.tensor(1.0, dtype=torch.float32)

        return {
            "image_1": image_1,
            "image_2": image_2,
            "text_1": text_1,
            "text_2": text_2,
            "label": label,
            "confidence": sample.get("confidence", 1.0),
            # Always float32 so integer vote counts and the float default
            # collate into one batch without dtype mismatches.
            "choice_dist": torch.tensor(
                sample.get("choice_dist", [1.0, 0.0]), dtype=torch.float32
            ),
        }
|
|
|