import json
import os
import random
import traceback

import torch
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm


class PairwiseOriginalDataset(Dataset):
    """Map-style dataset of image pairs that share a single text prompt.

    Every sample is a dict read from one of the JSON files in ``json_list``.
    The keys read by this class are ``path1``, ``path2`` and ``prompt``
    (required), plus ``choice_dist`` and ``confidence`` (optional, except
    that ``choice_dist`` is required when ``soft_label`` is true).

    Args:
        json_list: Paths of JSON files, each holding a list of sample dicts.
            A single path (``str`` / ``os.PathLike``) is also accepted.
        soft_label: When true, the label is the largest entry of the
            sample's ``choice_dist`` divided by the sum of its entries;
            otherwise the label is the constant ``1.0``.
        confidence_threshold: When not ``None``, samples whose
            ``confidence`` is below this value are dropped at construction
            time. Samples without a ``confidence`` key are always kept.

    Attributes:
        max_retries: Upper bound on the number of replacement samples that
            ``__getitem__`` tries after a failure before giving up.
    """

    # Class-level so that it can be tuned per instance or per subclass
    # without touching the constructor signature.
    max_retries = 100

    def __init__(
        self,
        json_list,
        soft_label: bool = False,
        confidence_threshold=None,
    ):
        super().__init__()

        # A bare path would otherwise be iterated character by character.
        if isinstance(json_list, (str, os.PathLike)):
            json_list = [json_list]

        self.samples = []
        for json_file in json_list:
            # JSON text is UTF-8; do not depend on the platform default.
            with open(json_file, "r", encoding="utf-8") as f:
                self.samples.extend(json.load(f))

        self.soft_label = soft_label
        self.confidence_threshold = confidence_threshold

        if confidence_threshold is not None:
            progress = tqdm(
                self.samples,
                desc="Filtering samples according to confidence threshold",
            )
            # Missing confidence defaults to +inf, i.e. the sample is kept.
            self.samples = [
                sample
                for sample in progress
                if sample.get("confidence", float("inf")) >= confidence_threshold
            ]

    def __len__(self) -> int:
        """Return the number of samples left after optional filtering."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Return the sample at ``idx``, substituting another one on failure.

        If building the requested sample fails, the error and its traceback
        are printed and a different, randomly chosen sample is tried
        instead. At most ``max_retries`` replacements are attempted.

        Raises:
            IndexError: If ``idx`` is outside ``[-len(self), len(self))``.
                It is raised eagerly so that plain iteration over the
                dataset terminates instead of yielding random samples
                forever.
            RuntimeError: If no sample could be built within the retry
                budget (chained to the last underlying error).
        """
        num_samples = len(self.samples)
        if not -num_samples <= idx < num_samples:
            raise IndexError(
                f"index {idx} is out of range for dataset of size {num_samples}"
            )

        current = idx
        last_error = None
        attempts = 0
        for _ in range(self.max_retries + 1):
            attempts += 1
            try:
                return self.get_single_item(current)
            except Exception as exc:  # best effort: any bad sample is skipped
                print(f"Error processing sample at index {current}: {exc}")
                traceback.print_exc()
                last_error = exc

            if num_samples < 2:
                # There is no other sample to fall back to.
                break

            # Draw uniformly among all indices except the one that failed.
            failed = current % num_samples  # normalises negative indices
            replacement = random.randrange(num_samples - 1)
            if replacement >= failed:
                replacement += 1
            current = replacement

        raise RuntimeError(
            f"Could not load any sample after {attempts} attempt(s) "
            f"starting from index {idx}"
        ) from last_error

    def get_single_item(self, idx) -> dict:
        """Build the output dict for the sample at ``idx`` without retrying.

        Returns:
            A dict with keys ``image_1`` / ``image_2`` (the two image paths,
            returned as stored, not loaded), ``text_1`` / ``text_2`` (both
            the sample's prompt), ``label`` (0-d float tensor),
            ``confidence`` (defaults to ``1.0``) and ``choice_dist`` (tensor,
            defaults to ``[1.0, 0.0]``).

        Raises:
            FileNotFoundError: If either image path does not exist.
            ValueError: If ``soft_label`` is set and the entries of
                ``choice_dist`` do not sum to a positive value.
            KeyError: If a required key is missing from the sample.
        """
        sample = self.samples[idx]

        image_1 = sample["path1"]
        image_2 = sample["path2"]
        # Explicit raise instead of ``assert`` so the check survives ``-O``.
        missing = [str(p) for p in (image_1, image_2) if not os.path.exists(p)]
        if missing:
            raise FileNotFoundError(
                f"Image file(s) not found: {', '.join(missing)}"
            )

        # Both images of a pair are described by the same prompt.
        prompt = sample["prompt"]

        if self.soft_label:
            votes = torch.tensor(sample["choice_dist"])
            total = votes.sum()
            # ``not total > 0`` also rejects an empty list and NaN sums.
            if not total > 0:
                raise ValueError("Choice distribution cannot be zero.")
            # Share of the largest entry; true division yields a float tensor
            # even when the entries are integers.
            label = votes.max() / total
        else:
            label = torch.tensor(1.0)

        return {
            "image_1": image_1,
            "image_2": image_2,
            "text_1": prompt,
            "text_2": prompt,
            "label": label,
            "confidence": sample.get("confidence", 1.0),
            "choice_dist": torch.tensor(sample.get("choice_dist", [1.0, 0.0])),
        }