File size: 1,544 Bytes
f846a93
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
from __future__ import annotations

from typing import Any

import pandas as pd
import torch
from torch.utils.data import Dataset, WeightedRandomSampler

from .data_discovery import LABEL_TO_ID
from .preprocessing import load_pil_image


class EggImageDataset(Dataset):
    """Map-style dataset yielding ``(image, label, filepath)`` for each dataframe row.

    Every row must provide a ``filepath`` column. The class id is taken from
    ``label_id`` when that column exists; otherwise the ``label`` column is
    translated through ``LABEL_TO_ID``.
    """

    def __init__(self, dataframe: pd.DataFrame, transform: Any | None = None) -> None:
        """Store the rows and the optional image transform.

        Args:
            dataframe: One row per image; see the class docstring for columns.
            transform: Callable applied to the loaded PIL image. When it is
                ``None`` (or otherwise falsy) torchvision's ``ToTensor`` is
                used instead.
        """
        # Drop the caller's index so the stored frame is labelled 0..n-1.
        self.df = dataframe.reset_index(drop=True)
        self.transform = transform

    def __len__(self) -> int:
        """Return the number of rows (samples)."""
        return len(self.df)

    @staticmethod
    def _resolve_label(row: pd.Series) -> int:
        """Return the integer class id for ``row``.

        ``LABEL_TO_ID`` is consulted only when ``label_id`` is absent, so rows
        that already carry an id do not need a (valid) ``label`` entry.

        Raises:
            KeyError: If ``label_id`` is absent and ``label`` is missing or
                not a key of ``LABEL_TO_ID``.
        """
        if "label_id" in row.index:
            return int(row["label_id"])
        return int(LABEL_TO_ID[row["label"]])

    def __getitem__(self, index: int) -> tuple[torch.Tensor, torch.Tensor, str]:
        """Load and return the sample at positional ``index``.

        Returns:
            A tuple of the transformed image (whatever ``transform`` returns;
            a tensor when the ``ToTensor`` fallback is used), the class id as
            a 0-dim ``torch.long`` tensor, and the file path as ``str``.
        """
        row = self.df.iloc[index]
        image = load_pil_image(row["filepath"], mode="RGB")
        if self.transform:
            image_tensor = self.transform(image)
        else:
            # Imported lazily so torchvision is only needed for the fallback.
            from torchvision import transforms

            image_tensor = transforms.ToTensor()(image)
        label = self._resolve_label(row)
        return image_tensor, torch.tensor(label, dtype=torch.long), str(row["filepath"])


def create_balanced_sampler(dataframe: pd.DataFrame, seed: int) -> WeightedRandomSampler:
    """Build a seeded sampler that draws every class with equal total probability.

    Each row is weighted by the inverse of its class frequency, and one pass
    of the sampler draws ``largest_class_size * number_of_classes`` indices
    with replacement.

    Args:
        dataframe: Frame with a ``label_id`` column whose values can be cast
            to ``int``.
        seed: Seed for the sampler's private ``torch.Generator``.

    Returns:
        A ``WeightedRandomSampler`` over the row positions of ``dataframe``.

    Raises:
        ValueError: If ``dataframe`` has no rows.
    """
    # Cast once and derive both the per-row labels and the class counts from
    # the same series, so the count lookup keys always match the labels
    # (e.g. when the column holds numeric strings).
    labels = dataframe["label_id"].astype(int)
    counts = labels.value_counts().to_dict()
    if not counts:
        raise ValueError("Cannot build a balanced sampler from an empty dataframe.")

    weights = torch.tensor(
        [1.0 / counts[label] for label in labels.tolist()],
        dtype=torch.double,
    )
    num_samples = int(max(counts.values()) * len(counts))

    generator = torch.Generator()
    generator.manual_seed(seed)
    return WeightedRandomSampler(
        weights,
        num_samples=num_samples,
        replacement=True,
        generator=generator,
    )