HPSv3 / hpsv3 /dataset /pairwise_dataset.py
sdsdgwe's picture
update
9b57ce7
import torch
from torch.utils.data import Dataset, DataLoader
import random
import json
import os
from tqdm import tqdm
class PairwiseOriginalDataset(Dataset):
def __init__(
self,
json_list,
soft_label=False,
confidence_threshold=None,
):
self.samples = []
for json_file in json_list:
with open(json_file, "r") as f:
data = json.load(f)
self.samples.extend(data)
self.soft_label = soft_label
self.confidence_threshold = confidence_threshold
if confidence_threshold is not None:
new_samples = []
for sample in tqdm(
self.samples, desc="Filtering samples according to confidence threshold"
):
if sample.get("confidence", float("inf")) >= confidence_threshold:
new_samples.append(sample)
self.samples = new_samples
def __len__(self):
return len(self.samples)
def __getitem__(self, idx):
while True:
index = idx
try:
return self.get_single_item(index)
except Exception as e:
print(f"Error processing sample at index {idx}: {e}")
import traceback
traceback.print_exc()
index = random.randint(0, len(self.samples) - 1)
if index == idx:
continue
idx = index
def get_single_item(self, idx):
sample = self.samples[idx]
# Load image paths
image_1 = sample["path1"]
image_2 = sample["path2"]
assert os.path.exists(image_1) and os.path.exists(image_2), f'{image_1} or {image_2}'
text_1 = sample["prompt"]
text_2 = sample["prompt"]
# Process Label
if self.soft_label:
choice_dist = sorted(sample["choice_dist"], reverse=True)
assert (
torch.sum(torch.tensor(choice_dist)) > 0
), "Choice distribution cannot be zero."
label = torch.tensor(choice_dist[0]) / torch.sum(torch.tensor(choice_dist))
else:
label = torch.tensor(1).float()
# breakpoint()
return {
"image_1": image_1,
"image_2": image_2,
"text_1": text_1,
"text_2": text_2,
"label": label,
"confidence": sample.get("confidence", 1.0),
"choice_dist": torch.tensor(sample.get("choice_dist", [1.0, 0.0])),
}