File size: 4,269 Bytes
9894d76 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 |
from pathlib import Path
import pandas as pd
from datasets import Dataset, DatasetDict
from sklearn.model_selection import train_test_split
import torch
from torch.utils.data import Dataset as TorchDataset
import numpy as np
from PIL import Image
from transformers import CLIPProcessor
# Repository layout assumption: this file lives one directory below the project
# root, with a sibling "data/" directory holding images and labels.
# TODO(review): confirm this matches the actual repo layout.
DATA_PATH = Path(__file__).parent.parent / "data"
IMAGES_PATH = DATA_PATH / "imgs"  # directory containing the image files
LABELS_CSV = DATA_PATH / "labels.csv"  # CSV with at least "image" and "choice" columns
def load_golden_dataset() -> pd.DataFrame:
    """Load the golden self-labelled dataset as a DataFrame.

    Reads ``LABELS_CSV``, rewrites each image reference to a local path under
    ``IMAGES_PATH``, drops rows whose image file is missing on disk, and adds a
    binary ``label`` column (0 = FAMILY_SAFE/UNCERTAIN, 1 = SUGGESTIVE).

    Returns:
        DataFrame with at least ``image`` (absolute path string), ``choice``
        (original label) and ``label`` (binary int) columns.
    """
    df = pd.read_csv(LABELS_CSV)
    # Map each stored image reference (e.g. /data/imgs/foo.jpg) to the
    # corresponding local file path, keeping only the basename.
    df["image"] = [str(IMAGES_PATH / Path(entry).name) for entry in df["image"]]
    # Keep only rows whose image actually exists on disk; warn about the rest.
    exists_mask = df["image"].map(lambda p: Path(p).exists())
    n_missing = int((~exists_mask).sum())
    if n_missing > 0:
        print(f"Warning: {n_missing} image files not found")
        df = df[exists_mask].copy()
    # Binary target: SUGGESTIVE -> 1; FAMILY_SAFE and UNCERTAIN both -> 0.
    df["label"] = (~df["choice"].isin(["FAMILY_SAFE", "UNCERTAIN"])).astype(int)
    return df
def create_dataset_splits(
    train_size: float = 0.7,
    test_size: float = 0.15,
    val_size: float = 0.15,
    random_state: int = 42
) -> DatasetDict:
    """Build stratified train/test/val splits of the golden dataset.

    Args:
        train_size: Fraction of rows for the training split.
        test_size: Fraction of rows for the test split.
        val_size: Fraction of rows for the validation split.
        random_state: Seed forwarded to both ``train_test_split`` calls.

    Returns:
        DatasetDict with "train", "test" and "val" HuggingFace Datasets.

    Raises:
        ValueError: If the three split fractions do not sum to 1.0.
    """
    # Validate split sizes. Use an explicit raise instead of `assert`:
    # asserts are stripped when Python runs with -O, silently skipping the check.
    if abs(train_size + test_size + val_size - 1.0) >= 1e-6:
        raise ValueError("Split sizes must sum to 1.0")
    # Load data
    df = load_golden_dataset()
    print(f"Loaded {len(df)} golden self-labelled images")
    print("Original label distribution:")
    print(df["choice"].value_counts())
    print("\nBinary label distribution (after preprocessing):")
    print(df["label"].value_counts())
    print(" (0 = FAMILY_SAFE/UNCERTAIN, 1 = SUGGESTIVE)")
    # First split: train vs (test + val).
    # Stratify by binary label to maintain the class distribution in each split.
    train_df, temp_df = train_test_split(
        df,
        test_size=(test_size + val_size),
        stratify=df["label"],
        random_state=random_state
    )
    # Second split: test vs val.
    # Rescale test_size relative to the remaining (test + val) pool.
    test_proportion = test_size / (test_size + val_size)
    test_df, val_df = train_test_split(
        temp_df,
        test_size=(1 - test_proportion),
        stratify=temp_df["label"],
        random_state=random_state
    )
    print("\nSplit sizes:")
    print(f"  Train: {len(train_df)} ({len(train_df)/len(df)*100:.1f}%)")
    print(f"  Test:  {len(test_df)} ({len(test_df)/len(df)*100:.1f}%)")
    print(f"  val:   {len(val_df)} ({len(val_df)/len(df)*100:.1f}%)")
    # Convert to HuggingFace Datasets. preserve_index=False drops the pandas
    # index, which would otherwise leak into each Dataset as a spurious
    # "__index_level_0__" column after the non-contiguous splits.
    train_ds = Dataset.from_pandas(train_df, preserve_index=False)
    test_ds = Dataset.from_pandas(test_df, preserve_index=False)
    val_ds = Dataset.from_pandas(val_df, preserve_index=False)
    # Create DatasetDict
    dataset_dict = DatasetDict({
        "train": train_ds,
        "test": test_ds,
        "val": val_ds
    })
    return dataset_dict
def get_dataset(
    train_size: float = 0.7,
    test_size: float = 0.15,
    val_size: float = 0.15,
    random_state: int = 42
) -> DatasetDict:
    """Public entry point: build and return the train/test/val DatasetDict.

    Thin forwarding wrapper around ``create_dataset_splits``; all arguments
    are passed through unchanged.
    """
    split_kwargs = {
        "train_size": train_size,
        "test_size": test_size,
        "val_size": val_size,
        "random_state": random_state,
    }
    return create_dataset_splits(**split_kwargs)
class ImageDataset(TorchDataset):
    """PyTorch Dataset for image classification.

    Yields ``({"pixel_values": Tensor}, label)`` pairs where the pixel values
    come from the supplied CLIP processor and labels are long tensors.
    """

    def __init__(self, image_paths: list[str], labels: np.ndarray, processor: CLIPProcessor):
        # image_paths[i] corresponds to labels[i]; labels are converted once
        # to a long tensor up front so __getitem__ stays allocation-light.
        self.image_paths = image_paths
        self.labels = torch.tensor(labels, dtype=torch.long)
        self.processor = processor

    def __len__(self) -> int:
        return len(self.image_paths)

    def __getitem__(self, idx: int) -> tuple[dict, torch.Tensor]:
        # Single item fetching (PyTorch DataLoader handles batching automatically)
        img_path = self.image_paths[idx]
        # Use a context manager: PIL opens files lazily and would otherwise keep
        # the file handle open until garbage collection, leaking descriptors
        # across a long training run. convert() forces the pixel data to load
        # before the file is closed.
        with Image.open(img_path) as img:
            image = img.convert("RGB")
        # Process image with CLIP processor
        inputs = self.processor(images=image, return_tensors="pt")
        # Remove the batch dimension the processor adds (it returns shape (1, C, H, W)).
        pixel_values = inputs["pixel_values"].squeeze(0)
        label = self.labels[idx]
        return {"pixel_values": pixel_values}, label
|