Spaces:
Runtime error
Runtime error
Commit
·
1f5d8b4
1
Parent(s):
ede1f67
Add DocString
Browse files
app.py
CHANGED
|
@@ -37,9 +37,9 @@ id2label = {
|
|
| 37 |
}
|
| 38 |
label2id = {v: k for k, v in id2label.items()}
|
| 39 |
num_labels = len(id2label)
|
| 40 |
-
checkpoint = "nvidia/segformer-
|
| 41 |
-
image_processor = SegformerImageProcessor()
|
| 42 |
-
state_dict_path = f"runs/{checkpoint}
|
| 43 |
model = SegformerForSemanticSegmentation.from_pretrained(
|
| 44 |
checkpoint,
|
| 45 |
num_labels=num_labels,
|
|
@@ -58,6 +58,17 @@ model.eval()
|
|
| 58 |
|
| 59 |
|
| 60 |
def load_and_prepare_images(image_name, segformer=False):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 61 |
image_path = os.path.join(data_folder, "images", image_name)
|
| 62 |
mask_name = image_name.replace("_leftImg8bit.png", "_gtFine_labelIds.png")
|
| 63 |
mask_path = os.path.join(data_folder, "masks", mask_name)
|
|
@@ -82,35 +93,47 @@ def load_and_prepare_images(image_name, segformer=False):
|
|
| 82 |
|
| 83 |
|
| 84 |
def predict_segmentation(image):
|
| 85 |
-
|
| 86 |
-
|
| 87 |
|
| 88 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 89 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 90 |
model.to(device)
|
| 91 |
-
|
| 92 |
-
# Déplacer les inputs sur le bon device et faire la prédiction
|
| 93 |
pixel_values = inputs.pixel_values.to(device)
|
| 94 |
|
| 95 |
-
with torch.no_grad():
|
| 96 |
outputs = model(pixel_values=pixel_values)
|
| 97 |
logits = outputs.logits
|
| 98 |
|
| 99 |
-
# Redimensionner les logits à la taille de l'image d'origine
|
| 100 |
upsampled_logits = nn.functional.interpolate(
|
| 101 |
logits,
|
| 102 |
size=image.size[::-1], # (height, width)
|
| 103 |
mode="bilinear",
|
| 104 |
align_corners=False,
|
| 105 |
)
|
| 106 |
-
|
| 107 |
-
# Obtenir la prédiction finale
|
| 108 |
pred_seg = upsampled_logits.argmax(dim=1)[0].cpu().numpy()
|
| 109 |
|
| 110 |
return pred_seg
|
| 111 |
|
| 112 |
|
| 113 |
def process_image(image_name):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 114 |
original, true_mask, fpn_pred, segformer_pred = load_and_prepare_images(
|
| 115 |
image_name, segformer=True
|
| 116 |
)
|
|
@@ -131,6 +154,12 @@ def process_image(image_name):
|
|
| 131 |
|
| 132 |
|
| 133 |
def create_cityscapes_label_colormap():
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 134 |
colormap = np.zeros((256, 3), dtype=np.uint8)
|
| 135 |
colormap[0] = [78, 82, 110]
|
| 136 |
colormap[1] = [128, 64, 128]
|
|
@@ -147,68 +176,43 @@ def create_cityscapes_label_colormap():
|
|
| 147 |
cityscapes_colormap = create_cityscapes_label_colormap()
|
| 148 |
|
| 149 |
|
| 150 |
-
def blend_images(original_image, colored_segmentation, alpha=0.6):
|
| 151 |
-
blended_image = Image.blend(original_image, colored_segmentation, alpha)
|
| 152 |
-
return blended_image
|
| 153 |
-
|
| 154 |
-
|
| 155 |
def colorize_mask(mask):
|
| 156 |
return cityscapes_colormap[mask]
|
| 157 |
|
| 158 |
|
| 159 |
# ---- Fin Partie Segmentation
|
| 160 |
|
| 161 |
-
# def compare_masks(real_mask, fpn_mask, segformer_mask):
|
| 162 |
-
# """
|
| 163 |
-
# Compare les masques prédits par FPN et SegFormer avec le masque réel.
|
| 164 |
-
# Retourne un score IoU et une précision pixel par pixel pour chaque modèle.
|
| 165 |
-
|
| 166 |
-
# Args:
|
| 167 |
-
# real_mask (np.array): Le masque réel de référence
|
| 168 |
-
# fpn_mask (np.array): Le masque prédit par le modèle FPN
|
| 169 |
-
# segformer_mask (np.array): Le masque prédit par le modèle SegFormer
|
| 170 |
-
|
| 171 |
-
# Returns:
|
| 172 |
-
# dict: Dictionnaire contenant les scores IoU et les précisions pour chaque modèle
|
| 173 |
-
# """
|
| 174 |
-
|
| 175 |
-
# assert real_mask.shape == fpn_mask.shape == segformer_mask.shape, "Les masques doivent avoir la même forme"
|
| 176 |
-
|
| 177 |
-
# real_flat = real_mask.flatten()
|
| 178 |
-
# fpn_flat = fpn_mask.flatten()
|
| 179 |
-
# segformer_flat = segformer_mask.flatten()
|
| 180 |
-
|
| 181 |
-
# # Calcul du score de Jaccard (IoU)
|
| 182 |
-
# iou_fpn = jaccard_score(real_flat, fpn_flat, average='weighted')
|
| 183 |
-
# iou_segformer = jaccard_score(real_flat, segformer_flat, average='weighted')
|
| 184 |
-
|
| 185 |
-
# # Calcul de la précision pixel par pixel
|
| 186 |
-
# accuracy_fpn = accuracy_score(real_flat, fpn_flat)
|
| 187 |
-
# accuracy_segformer = accuracy_score(real_flat, segformer_flat)
|
| 188 |
-
|
| 189 |
-
# return {
|
| 190 |
-
# 'FPN': {'IoU': iou_fpn, 'Precision': accuracy_fpn},
|
| 191 |
-
# 'SegFormer': {'IoU': iou_segformer, 'Precision': accuracy_segformer}
|
| 192 |
-
# }
|
| 193 |
-
|
| 194 |
# ---- Partie EDA
|
| 195 |
|
| 196 |
|
| 197 |
def analyse_mask(real_mask, num_labels):
|
| 198 |
-
|
| 199 |
-
|
| 200 |
|
| 201 |
-
|
| 202 |
-
|
|
|
|
| 203 |
|
| 204 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 205 |
class_proportions = counts / total_pixels
|
| 206 |
-
|
| 207 |
-
# Créer un dictionnaire avec les proportions
|
| 208 |
return dict(enumerate(class_proportions))
|
| 209 |
|
| 210 |
|
| 211 |
def show_eda(image_name):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 212 |
original_image, true_mask, _ = load_and_prepare_images(image_name)
|
| 213 |
class_proportions = analyse_mask(true_mask, num_labels)
|
| 214 |
cityscapes_colormap = create_cityscapes_label_colormap()
|
|
@@ -266,17 +270,54 @@ def show_eda(image_name):
|
|
| 266 |
|
| 267 |
|
| 268 |
class SegformerWrapper(nn.Module):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 269 |
def __init__(self, model):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 270 |
super().__init__()
|
| 271 |
self.model = model
|
| 272 |
|
| 273 |
def forward(self, x):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 274 |
output = self.model(x)
|
| 275 |
return output.logits
|
| 276 |
|
| 277 |
|
| 278 |
class SemanticSegmentationTarget:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 279 |
def __init__(self, category, mask):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 280 |
self.category = category
|
| 281 |
self.mask = torch.from_numpy(mask)
|
| 282 |
if torch.cuda.is_available():
|
|
@@ -305,12 +346,33 @@ class SemanticSegmentationTarget:
|
|
| 305 |
|
| 306 |
|
| 307 |
def segformer_reshape_transform_huggingface(tensor, width, height):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 308 |
result = tensor.reshape(tensor.size(0), height, width, tensor.size(2))
|
| 309 |
result = result.transpose(2, 3).transpose(1, 2)
|
| 310 |
return result
|
| 311 |
|
| 312 |
|
| 313 |
def explain_model(image_name, category_name):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 314 |
original_image, _, _ = load_and_prepare_images(image_name)
|
| 315 |
rgb_img = np.float32(original_image) / 255
|
| 316 |
img_tensor = transforms.ToTensor()(rgb_img)
|
|
@@ -379,6 +441,12 @@ import random
|
|
| 379 |
|
| 380 |
|
| 381 |
def change_image():
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 382 |
image_dir = (
|
| 383 |
"data_sample/images" # Remplacez par le chemin de votre dossier d'images
|
| 384 |
)
|
|
@@ -388,6 +456,16 @@ def change_image():
|
|
| 388 |
|
| 389 |
|
| 390 |
def apply_augmentation(image, augmentation_names):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 391 |
augmentations = {
|
| 392 |
"Horizontal Flip": A.HorizontalFlip(p=1),
|
| 393 |
"Shift Scale Rotate": A.ShiftScaleRotate(p=1),
|
|
@@ -541,4 +619,4 @@ with gr.Blocks(title="Preuve de concept", theme=my_theme) as demo:
|
|
| 541 |
|
| 542 |
|
| 543 |
# Lancer l'application
|
| 544 |
-
demo.launch(favicon_path="favicon.ico"
|
|
|
|
| 37 |
}
|
| 38 |
label2id = {v: k for k, v in id2label.items()}
|
| 39 |
num_labels = len(id2label)
|
| 40 |
+
checkpoint = "nvidia/segformer-b3-finetuned-cityscapes-1024-1024"
|
| 41 |
+
image_processor = SegformerImageProcessor(do_resize=False)
|
| 42 |
+
state_dict_path = f"runs/{checkpoint}/best_model.pt"
|
| 43 |
model = SegformerForSemanticSegmentation.from_pretrained(
|
| 44 |
checkpoint,
|
| 45 |
num_labels=num_labels,
|
|
|
|
| 58 |
|
| 59 |
|
| 60 |
def load_and_prepare_images(image_name, segformer=False):
|
| 61 |
+
"""
|
| 62 |
+
Charge et prépare les images, les masques et les prédictions associées pour une image donnée.
|
| 63 |
+
|
| 64 |
+
Args:
|
| 65 |
+
image_name (str): Le nom du fichier de l'image à charger.
|
| 66 |
+
segformer (bool, optional): Si True, prédit également le masque avec SegFormer. Par défaut False.
|
| 67 |
+
|
| 68 |
+
Returns:
|
| 69 |
+
tuple: Contient l'image originale redimensionnée, le masque réel, la prédiction FPN,
|
| 70 |
+
et la prédiction SegFormer si `segformer` est True.
|
| 71 |
+
"""
|
| 72 |
image_path = os.path.join(data_folder, "images", image_name)
|
| 73 |
mask_name = image_name.replace("_leftImg8bit.png", "_gtFine_labelIds.png")
|
| 74 |
mask_path = os.path.join(data_folder, "masks", mask_name)
|
|
|
|
| 93 |
|
| 94 |
|
| 95 |
def predict_segmentation(image):
|
| 96 |
+
"""
|
| 97 |
+
Prédit la segmentation d'une image donnée à l'aide d'un modèle pré-entraîné.
|
| 98 |
|
| 99 |
+
Args:
|
| 100 |
+
image (PIL.Image.Image): L'image à segmenter.
|
| 101 |
+
|
| 102 |
+
Returns:
|
| 103 |
+
numpy.ndarray: La carte de segmentation prédite.
|
| 104 |
+
"""
|
| 105 |
+
|
| 106 |
+
inputs = image_processor(images=image, return_tensors="pt")
|
| 107 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 108 |
model.to(device)
|
|
|
|
|
|
|
| 109 |
pixel_values = inputs.pixel_values.to(device)
|
| 110 |
|
| 111 |
+
with torch.no_grad():
|
| 112 |
outputs = model(pixel_values=pixel_values)
|
| 113 |
logits = outputs.logits
|
| 114 |
|
|
|
|
| 115 |
upsampled_logits = nn.functional.interpolate(
|
| 116 |
logits,
|
| 117 |
size=image.size[::-1], # (height, width)
|
| 118 |
mode="bilinear",
|
| 119 |
align_corners=False,
|
| 120 |
)
|
|
|
|
|
|
|
| 121 |
pred_seg = upsampled_logits.argmax(dim=1)[0].cpu().numpy()
|
| 122 |
|
| 123 |
return pred_seg
|
| 124 |
|
| 125 |
|
| 126 |
def process_image(image_name):
|
| 127 |
+
"""
|
| 128 |
+
Traite une image en chargeant l'image originale, le masque réel, et les prédictions de masques.
|
| 129 |
+
Envoie la liste de tuple à l'interface "Predictions" de Gradio
|
| 130 |
+
|
| 131 |
+
Args:
|
| 132 |
+
image_name (str): Le nom de l'image à traiter.
|
| 133 |
+
|
| 134 |
+
Returns:
|
| 135 |
+
list: Une liste de tuples contenant l'image et son titre associé.
|
| 136 |
+
"""
|
| 137 |
original, true_mask, fpn_pred, segformer_pred = load_and_prepare_images(
|
| 138 |
image_name, segformer=True
|
| 139 |
)
|
|
|
|
| 154 |
|
| 155 |
|
| 156 |
def create_cityscapes_label_colormap():
|
| 157 |
+
"""
|
| 158 |
+
Crée une colormap pour les labels Cityscapes.
|
| 159 |
+
|
| 160 |
+
Returns:
|
| 161 |
+
numpy.ndarray: Un tableau 2D où chaque ligne représente la couleur RGB d'un label.
|
| 162 |
+
"""
|
| 163 |
colormap = np.zeros((256, 3), dtype=np.uint8)
|
| 164 |
colormap[0] = [78, 82, 110]
|
| 165 |
colormap[1] = [128, 64, 128]
|
|
|
|
| 176 |
cityscapes_colormap = create_cityscapes_label_colormap()
|
| 177 |
|
| 178 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 179 |
def colorize_mask(mask):
|
| 180 |
return cityscapes_colormap[mask]
|
| 181 |
|
| 182 |
|
| 183 |
# ---- Fin Partie Segmentation
|
| 184 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 185 |
# ---- Partie EDA
|
| 186 |
|
| 187 |
|
| 188 |
def analyse_mask(real_mask, num_labels):
|
| 189 |
+
"""
|
| 190 |
+
Analyse la distribution des classes dans un masque réel.
|
| 191 |
|
| 192 |
+
Args:
|
| 193 |
+
real_mask (numpy.ndarray): Le masque de labels réels.
|
| 194 |
+
num_labels (int): Le nombre total de classes.
|
| 195 |
|
| 196 |
+
Returns:
|
| 197 |
+
dict: Un dictionnaire contenant les proportions des classes dans le masque.
|
| 198 |
+
"""
|
| 199 |
+
counts = np.bincount(real_mask.ravel(), minlength=num_labels)
|
| 200 |
+
total_pixels = real_mask.size
|
| 201 |
class_proportions = counts / total_pixels
|
|
|
|
|
|
|
| 202 |
return dict(enumerate(class_proportions))
|
| 203 |
|
| 204 |
|
| 205 |
def show_eda(image_name):
|
| 206 |
+
"""
|
| 207 |
+
Affiche une analyse exploratoire de la distribution des classes pour une image et son masque associé.
|
| 208 |
+
|
| 209 |
+
Args:
|
| 210 |
+
image_name (str): Le nom de l'image à analyser.
|
| 211 |
+
|
| 212 |
+
Returns:
|
| 213 |
+
tuple: Contient l'image originale, le masque réel coloré et une figure Plotly représentant
|
| 214 |
+
la distribution des classes.
|
| 215 |
+
"""
|
| 216 |
original_image, true_mask, _ = load_and_prepare_images(image_name)
|
| 217 |
class_proportions = analyse_mask(true_mask, num_labels)
|
| 218 |
cityscapes_colormap = create_cityscapes_label_colormap()
|
|
|
|
| 270 |
|
| 271 |
|
| 272 |
class SegformerWrapper(nn.Module):
|
| 273 |
+
"""
|
| 274 |
+
Un wrapper pour le modèle SegFormer qui renvoie uniquement les logits en sortie.
|
| 275 |
+
|
| 276 |
+
Args:
|
| 277 |
+
model (torch.nn.Module): Le modèle SegFormer pré-entraîné.
|
| 278 |
+
"""
|
| 279 |
+
|
| 280 |
def __init__(self, model):
|
| 281 |
+
"""
|
| 282 |
+
Initialise le SegformerWrapper.
|
| 283 |
+
|
| 284 |
+
Args:
|
| 285 |
+
model (torch.nn.Module): Le modèle SegFormer pré-entraîné.
|
| 286 |
+
"""
|
| 287 |
super().__init__()
|
| 288 |
self.model = model
|
| 289 |
|
| 290 |
def forward(self, x):
|
| 291 |
+
"""
|
| 292 |
+
Renvoie les logits du modèle au lieu de renvoyer un dictionnaire.
|
| 293 |
+
|
| 294 |
+
Args:
|
| 295 |
+
x (torch.Tensor): Les entrées du modèle.
|
| 296 |
+
|
| 297 |
+
Returns:
|
| 298 |
+
torch.Tensor: Les logits du modèle.
|
| 299 |
+
"""
|
| 300 |
output = self.model(x)
|
| 301 |
return output.logits
|
| 302 |
|
| 303 |
|
| 304 |
class SemanticSegmentationTarget:
|
| 305 |
+
"""
|
| 306 |
+
Représente une classe cible pour la segmentation sémantique utilisée dans GradCAM.
|
| 307 |
+
|
| 308 |
+
Args:
|
| 309 |
+
category (int): L'index de la catégorie cible.
|
| 310 |
+
mask (numpy.ndarray): Le masque binaire indiquant les pixels d'intérêt.
|
| 311 |
+
"""
|
| 312 |
+
|
| 313 |
def __init__(self, category, mask):
|
| 314 |
+
"""
|
| 315 |
+
Initialise la cible de segmentation sémantique.
|
| 316 |
+
|
| 317 |
+
Args:
|
| 318 |
+
category (int): L'index de la catégorie cible.
|
| 319 |
+
mask (numpy.ndarray): Le masque binaire indiquant les pixels d'intérêt.
|
| 320 |
+
"""
|
| 321 |
self.category = category
|
| 322 |
self.mask = torch.from_numpy(mask)
|
| 323 |
if torch.cuda.is_available():
|
|
|
|
| 346 |
|
| 347 |
|
| 348 |
def segformer_reshape_transform_huggingface(tensor, width, height):
|
| 349 |
+
"""
|
| 350 |
+
Réorganise les dimensions du tenseur pour qu'elles correspondent au format attendu par GradCAM.
|
| 351 |
+
|
| 352 |
+
Args:
|
| 353 |
+
tensor (torch.Tensor): Le tenseur à réorganiser.
|
| 354 |
+
width (int): La nouvelle largeur.
|
| 355 |
+
height (int): La nouvelle hauteur.
|
| 356 |
+
|
| 357 |
+
Returns:
|
| 358 |
+
torch.Tensor: Le tenseur réorganisé.
|
| 359 |
+
"""
|
| 360 |
result = tensor.reshape(tensor.size(0), height, width, tensor.size(2))
|
| 361 |
result = result.transpose(2, 3).transpose(1, 2)
|
| 362 |
return result
|
| 363 |
|
| 364 |
|
| 365 |
def explain_model(image_name, category_name):
|
| 366 |
+
"""
|
| 367 |
+
Explique les prédictions du modèle SegFormer en utilisant GradCAM pour une image et une catégorie données.
|
| 368 |
+
|
| 369 |
+
Args:
|
| 370 |
+
image_name (str): Le nom de l'image à expliquer.
|
| 371 |
+
category_name (str): Le nom de la catégorie cible.
|
| 372 |
+
|
| 373 |
+
Returns:
|
| 374 |
+
matplotlib.figure.Figure: Une figure matplotlib contenant la carte de chaleur GradCAM superposée sur l'image originale.
|
| 375 |
+
"""
|
| 376 |
original_image, _, _ = load_and_prepare_images(image_name)
|
| 377 |
rgb_img = np.float32(original_image) / 255
|
| 378 |
img_tensor = transforms.ToTensor()(rgb_img)
|
|
|
|
| 441 |
|
| 442 |
|
| 443 |
def change_image():
|
| 444 |
+
"""
|
| 445 |
+
Sélectionne et charge aléatoirement une image depuis un dossier spécifié.
|
| 446 |
+
|
| 447 |
+
Returns:
|
| 448 |
+
PIL.Image.Image: L'image sélectionnée.
|
| 449 |
+
"""
|
| 450 |
image_dir = (
|
| 451 |
"data_sample/images" # Remplacez par le chemin de votre dossier d'images
|
| 452 |
)
|
|
|
|
| 456 |
|
| 457 |
|
| 458 |
def apply_augmentation(image, augmentation_names):
|
| 459 |
+
"""
|
| 460 |
+
Applique une ou plusieurs augmentations à une image.
|
| 461 |
+
|
| 462 |
+
Args:
|
| 463 |
+
image (PIL.Image.Image): L'image à augmenter.
|
| 464 |
+
augmentation_names (list of str): Les noms des augmentations à appliquer.
|
| 465 |
+
|
| 466 |
+
Returns:
|
| 467 |
+
PIL.Image.Image: L'image augmentée.
|
| 468 |
+
"""
|
| 469 |
augmentations = {
|
| 470 |
"Horizontal Flip": A.HorizontalFlip(p=1),
|
| 471 |
"Shift Scale Rotate": A.ShiftScaleRotate(p=1),
|
|
|
|
| 619 |
|
| 620 |
|
| 621 |
# Lancer l'application
|
| 622 |
+
demo.launch(favicon_path="favicon.ico")
|