# Image_Classification/data_utils.py
import random
from collections import Counter
from typing import List, Tuple

import pandas as pd
from PIL import Image
from datasets import load_dataset
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms

from config import HF_DATASET_REPO, HF_TOKEN, IMAGE_SIZE, RANDOM_SEED
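
# Module-level caches: the raw dataset, class names, and splits are loaded
# once per process and reused by every function below.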
_RAW_DATASET = None
_CLASS_NAMES = None
_SPLITS = None

class HFDatasetWrapper(Dataset):
    """Adapts a Hugging Face image split to the torch.utils.data.Dataset API."""

    def __init__(self, hf_dataset, transform):
        self.dataset = hf_dataset
        self.transform = transform

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        item = self.dataset[idx]
        image = item["image"]
        # Some dataset formats yield file paths rather than decoded images.
        if not isinstance(image, Image.Image):
            image = Image.open(image)
        image = image.convert("RGB")
        label = int(item["label"])
        if self.transform:
            image = self.transform(image)
        return image, label
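
# Typical use (sketch): wrap a split and hand it to a DataLoader, e.g.
#   loader = DataLoader(HFDatasetWrapper(split, get_train_transform()), batch_size=32)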

def get_train_transform():
    """Training pipeline: light geometric augmentation, then normalization
    with the standard ImageNet statistics."""
    return transforms.Compose(
        [
            transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.RandomVerticalFlip(p=0.5),
            transforms.RandomRotation(degrees=5),
            transforms.ToTensor(),
            transforms.Normalize(
                mean=(0.485, 0.456, 0.406),
                std=(0.229, 0.224, 0.225),
            ),
        ]
    )

def get_eval_transform():
    """Validation/test pipeline: same resize and normalization as training,
    but no augmentation, so evaluation is deterministic."""
    return transforms.Compose(
        [
            transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
            transforms.ToTensor(),
            transforms.Normalize(
                mean=(0.485, 0.456, 0.406),
                std=(0.229, 0.224, 0.225),
            ),
        ]
    )

def load_raw_dataset():
    """Download the dataset from the Hugging Face Hub (once) and resolve class names."""
    global _RAW_DATASET, _CLASS_NAMES
    if _RAW_DATASET is not None:
        return _RAW_DATASET, _CLASS_NAMES
    if not HF_TOKEN:
        raise RuntimeError(
            "HF_TOKEN is missing. Add it to the Hugging Face Space secrets."
        )
    raw = load_dataset(HF_DATASET_REPO, token=HF_TOKEN)
    if "train" not in raw:
        raise RuntimeError("The Hugging Face dataset must contain at least a 'train' split.")
    label_feature = raw["train"].features["label"]
    if hasattr(label_feature, "names") and label_feature.names:
        # ClassLabel column: names come from the dataset metadata.
        class_names = label_feature.names
    else:
        # Plain labels: derive names from the distinct observed values.
        labels = list(raw["train"]["label"])
        class_names = [str(x) for x in sorted(set(labels))]
    _RAW_DATASET = raw
    _CLASS_NAMES = class_names
    return _RAW_DATASET, _CLASS_NAMES

def prepare_splits(
    train_ratio: float = 0.70,
    val_ratio: float = 0.15,
    test_ratio: float = 0.15,
):
    """Return train/validation/test splits, creating them if the dataset
    does not already provide them.

    Stratified splitting is attempted first; it requires a ClassLabel
    column, so we fall back to an unstratified split when it fails.
    """
    global _SPLITS
    if _SPLITS is not None:
        return _SPLITS
    raw, _ = load_raw_dataset()
    # Case 1: the dataset already ships all three splits.
    if "validation" in raw and "test" in raw:
        _SPLITS = {
            "train": raw["train"],
            "validation": raw["validation"],
            "test": raw["test"],
        }
        return _SPLITS
    # Case 2: only a test split exists; carve validation out of train.
    if "test" in raw:
        train_val = raw["train"]
        test = raw["test"]
        relative_val_ratio = val_ratio / (train_ratio + val_ratio)
        try:
            split_train_val = train_val.train_test_split(
                test_size=relative_val_ratio,
                seed=RANDOM_SEED,
                stratify_by_column="label",
            )
        except Exception:
            split_train_val = train_val.train_test_split(
                test_size=relative_val_ratio,
                seed=RANDOM_SEED,
            )
        _SPLITS = {
            "train": split_train_val["train"],
            "validation": split_train_val["test"],
            "test": test,
        }
        return _SPLITS
    # Case 3: only a train split exists; split twice to get all three.
    full = raw["train"]
    try:
        first_split = full.train_test_split(
            test_size=(val_ratio + test_ratio),
            seed=RANDOM_SEED,
            stratify_by_column="label",
        )
    except Exception:
        first_split = full.train_test_split(
            test_size=(val_ratio + test_ratio),
            seed=RANDOM_SEED,
        )
    temp = first_split["test"]
    relative_test_ratio = test_ratio / (val_ratio + test_ratio)
    try:
        second_split = temp.train_test_split(
            test_size=relative_test_ratio,
            seed=RANDOM_SEED,
            stratify_by_column="label",
        )
    except Exception:
        second_split = temp.train_test_split(
            test_size=relative_test_ratio,
            seed=RANDOM_SEED,
        )
    _SPLITS = {
        "train": first_split["train"],
        "validation": second_split["train"],
        "test": second_split["test"],
    }
    return _SPLITS
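
# Worked example: with the default 0.70/0.15/0.15 ratios and only a "train"
# split available, the first split holds out 30% of the data, and the second
# splits that 30% in half (relative_test_ratio = 0.15 / 0.30 = 0.5), giving
# 15% validation and 15% test overall.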

def get_class_names() -> List[str]:
    _, class_names = load_raw_dataset()
    return class_names

def make_loaders(batch_size: int):
    """Build the train/validation/test DataLoaders plus the class-name list."""
    splits = prepare_splits()
    class_names = get_class_names()
    train_dataset = HFDatasetWrapper(splits["train"], get_train_transform())
    val_dataset = HFDatasetWrapper(splits["validation"], get_eval_transform())
    test_dataset = HFDatasetWrapper(splits["test"], get_eval_transform())
    # Only the training loader shuffles; evaluation order stays deterministic.
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
    return train_loader, val_loader, test_loader, class_names
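
# Note (assumption): num_workers and pin_memory are left at PyTorch defaults
# above; on GPU hardware, DataLoader(..., num_workers=2, pin_memory=True) can
# speed up input loading.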

def dataset_overview() -> Tuple[dict, pd.DataFrame]:
    """Summarize the dataset: global counts plus a per-split, per-class table."""
    splits = prepare_splits()
    class_names = get_class_names()
    rows = []
    total = 0
    for split_name, split_data in splits.items():
        labels = list(split_data["label"])
        counter = Counter(labels)
        split_total = len(labels)
        total += split_total
        for label_id, count in sorted(counter.items()):
            rows.append(
                {
                    "split": split_name,
                    "classe": class_names[int(label_id)],
                    "nombre_images": count,
                }
            )
    df = pd.DataFrame(rows)
    # French keys/column names are kept as-is; they appear to be displayed
    # verbatim by the UI.
    summary = {
        "dataset": HF_DATASET_REPO,
        "nombre_total_images": total,
        "nombre_classes": len(class_names),
        "train": len(splits["train"]),
        "validation": len(splits["validation"]),
        "test": len(splits["test"]),
    }
    return summary, df
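
# Example of the returned DataFrame layout (counts are hypothetical):
#   split  classe  nombre_images
#   train  cat     350
#   train  dog     340
#   ...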

def get_images_for_gallery(split_name: str, class_name: str, max_images: int = 24):
    """Sample up to `max_images` (image, label_name) pairs for the gallery view."""
    splits = prepare_splits()
    class_names = get_class_names()
    if split_name not in splits:
        split_name = "train"
    dataset = splits[split_name]
    # "Toutes les classes" is the UI sentinel meaning "all classes"; any
    # other value filters the gallery down to a single class.
    if class_name and class_name != "Toutes les classes":
        class_id = class_names.index(class_name)
        indices = [i for i, x in enumerate(dataset["label"]) if int(x) == class_id]
    else:
        indices = list(range(len(dataset)))
    if not indices:
        return []
    sample_indices = random.sample(indices, min(max_images, len(indices)))
    gallery = []
    for idx in sample_indices:
        item = dataset[idx]
        image = item["image"]
        if not isinstance(image, Image.Image):
            image = Image.open(image)
        image = image.convert("RGB")
        label_id = int(item["label"])
        label_name = class_names[label_id]
        gallery.append((image, label_name))
    return gallery
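
# Minimal smoke-test sketch (assumes HF_TOKEN and HF_DATASET_REPO are
# configured in config.py; batch_size=8 is an arbitrary illustration value).
if __name__ == "__main__":
    summary, df = dataset_overview()
    print(summary)
    print(df.head())
    train_loader, _, _, class_names = make_loaders(batch_size=8)
    images, labels = next(iter(train_loader))
    print(f"first batch: {tuple(images.shape)}, classes: {class_names}")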