| | from pathlib import Path |
| |
|
| | import matplotlib.pyplot as plt |
| | import numpy as np |
| | import torchvision |
| | from matplotlib import font_manager |
| | from matplotlib.figure import Figure |
| | from matplotlib.gridspec import GridSpec |
| | from PIL import Image |
| |
|
# ImageNet per-channel mean / standard deviation (conventionally R, G, B
# order). In this module they are passed to get_inverse_normalize_transform
# to turn normalized dataset tensors back into displayable images.
IMAGENET_MEAN: list[float] = [0.485, 0.456, 0.406]
IMAGENET_STD: list[float] = [0.229, 0.224, 0.225]

# Default reference distance for get_color's colour bands. Presumably the
# embedding distance beyond which a face is treated as a new individual —
# confirm against the model that produces the distances.
DISTANCE_THRESHOLD_NEW_INDIVIDUAL: float = 0.7
| |
|
| |
|
def get_inverse_normalize_transform(mean, std):
    """Build a torchvision ``Normalize`` transform that undoes ``Normalize(mean, std)``.

    ``Normalize`` maps ``x -> (x - mean) / std`` per channel. Using
    ``mean' = -mean / std`` and ``std' = 1 / std`` yields
    ``(x + mean / std) * std = x * std + mean``, i.e. the inverse mapping.

    Args:
        mean: Per-channel means that were used for the forward normalization.
        std: Per-channel standard deviations used for the forward normalization.

    Returns:
        A ``torchvision.transforms.Normalize`` instance applying the inverse.
    """
    inverse_mean = [-m / s for m, s in zip(mean, std)]
    inverse_std = [1 / s for s in std]
    return torchvision.transforms.Normalize(mean=inverse_mean, std=inverse_std)
| |
|
| |
|
def get_color(
    distance: float,
    distance_threshold_new_individual: float = DISTANCE_THRESHOLD_NEW_INDIVIDUAL,
    margin: float = 0.10,
) -> str:
    """Map a distance to a traffic-light colour name.

    The threshold is widened into a relative band of +/- ``margin``:

    * ``distance < threshold * (1 - margin)``  -> ``"green"``
    * ``distance < threshold * (1 + margin)``  -> ``"orange"``
    * anything else (including NaN)            -> ``"red"``
    """
    # Upper bounds in ascending order; the first one the distance falls
    # under decides the colour.
    bands = (
        (distance_threshold_new_individual * (1.0 - margin), "green"),
        (distance_threshold_new_individual * (1 + margin), "orange"),
    )
    for upper_bound, colour in bands:
        if distance < upper_bound:
            return colour
    return "red"
| |
|
| |
|
def draw_extrated_chip(ax, chip_image) -> None:
    """Show ``chip_image`` on ``ax`` under the title "Extracted chip".

    Args:
        ax: Matplotlib axes to draw into; its axis lines and ticks are hidden.
        chip_image: Image handed straight to ``ax.imshow`` (anything
            ``imshow`` accepts).
    """
    ax.set_title("Extracted chip")
    ax.set_axis_off()
    ax.imshow(chip_image)


# Correctly spelled alias. The misspelled name above is kept so existing
# callers keep working.
draw_extracted_chip = draw_extrated_chip
| |
|
| |
|
def draw_closest_neighbors(
    fig: Figure,
    gs: GridSpec,
    i_start: int,
    k_closest_neighbors: int,
    indexed_k_nearest_individuals: dict,
) -> None:
    """Draw the ``k_closest_neighbors`` closest neighbours on grid row ``i_start``.

    The neighbours of all individuals are pooled and sorted by ascending
    ``"distance"``. The first ``k_closest_neighbors`` are drawn in columns
    0, 1, ... of row ``i_start``. Each is titled ``"<bear_id>: <distance>"``
    and coloured by ``get_color``.

    Args:
        fig: Figure the subplots are added to.
        gs: Grid spec of ``fig``. Row ``i_start`` needs at least as many
            columns as neighbours drawn.
        i_start: Grid row to draw into.
        k_closest_neighbors: Maximum number of neighbours to draw.
        indexed_k_nearest_individuals: Mapping ``bear_id -> list of neighbour
            dicts``, each with ``"distance"`` and ``"dataset_image"`` keys.
            ``"dataset_image"`` is presumably a normalized (C, H, W) float
            tensor — TODO confirm with the producer of this dict.
    """
    inv_normalize = get_inverse_normalize_transform(
        mean=IMAGENET_MEAN,
        std=IMAGENET_STD,
    )

    # Pool every individual's neighbours. Each dict is copied before it is
    # tagged with its bear_id, so the caller's data is not mutated.
    neighbors = []
    for bear_id, bear_neighbors in indexed_k_nearest_individuals.items():
        for neighbor in bear_neighbors:
            tagged = neighbor.copy()
            tagged["bear_id"] = bear_id
            neighbors.append(tagged)

    # sorted() is stable, so equal distances keep their insertion order.
    nearest_neighbors = sorted(neighbors, key=lambda n: n["distance"])[:k_closest_neighbors]

    for j, neighbor in enumerate(nearest_neighbors):
        ax = fig.add_subplot(gs[i_start, j])
        distance = neighbor["distance"]
        bear_id = neighbor["bear_id"]
        image = inv_normalize(neighbor["dataset_image"]).numpy()
        # Move axis 0 to the end (C, H, W -> H, W, C), the layout imshow expects.
        image = np.transpose(image, (1, 2, 0))
        # Undoing the normalization can overshoot [0, 1] slightly. imshow
        # would clip float RGB data anyway, but it logs a warning for every
        # such image; clipping here gives the same picture without the noise.
        image = np.clip(image, 0.0, 1.0)
        color = get_color(distance=distance)
        ax.set_axis_off()
        ax.set_title(label=f"{bear_id}: {distance:.2f}", color=color)
        ax.imshow(image)
| |
|
| |
|
def draw_top_k_individuals(
    fig: Figure,
    gs: GridSpec,
    i_start: int,
    i_end: int,
    indexed_k_nearest_individuals: dict,
    bear_ids: list[str],
    indexed_samples: dict,
):
    """Draw one column per individual: its closest match, then its sample images.

    Grid rows ``i_start`` (inclusive) to ``i_end`` (exclusive) are used.
    Column ``j`` belongs to ``bear_ids[j]``:

    * Row ``i_start`` shows the ``"dataset_image"`` of the first entry of
      ``indexed_k_nearest_individuals[bear_id]``. It is titled
      ``"<bear_id>: <distance>"`` and coloured by ``get_color``.
    * Rows ``i_start + 1`` ... ``i_end - 1`` show
      ``indexed_samples[bear_id][0]``, ``[1]``, ... opened with PIL.

    A cell is left empty when there is nothing to show for it:

    * the individual has an empty match list;
    * it has no entry in ``indexed_samples``;
    * it has fewer samples than there are rows;
    * the sample filepath is falsy.
    """
    if i_end <= i_start:
        return  # no rows available to draw into

    inv_normalize = get_inverse_normalize_transform(
        mean=IMAGENET_MEAN,
        std=IMAGENET_STD,
    )

    # Row i_start: closest match of each individual.
    for j, bear_id in enumerate(bear_ids):
        matches = indexed_k_nearest_individuals[bear_id]
        if not matches:
            continue  # nothing to show; previously an IndexError
        nearest_individual = matches[0]
        ax = fig.add_subplot(gs[i_start, j])
        distance = nearest_individual["distance"]
        image = inv_normalize(nearest_individual["dataset_image"]).numpy()
        # Move axis 0 to the end (C, H, W -> H, W, C), the layout imshow expects.
        image = np.transpose(image, (1, 2, 0))
        # Clip de-normalization overshoot so imshow does not warn per image
        # (it would clip float RGB data to [0, 1] anyway).
        image = np.clip(image, 0.0, 1.0)
        color = get_color(distance=distance)
        ax.set_axis_off()
        ax.set_title(label=f"{bear_id}: {distance:.2f}", color=color)
        ax.imshow(image)

    # Remaining rows: sample image number `sample_idx` of every individual.
    samples_per_column = [indexed_samples.get(bear_id, []) for bear_id in bear_ids]
    for sample_idx, i in enumerate(range(i_start + 1, i_end)):
        for j, samples in enumerate(samples_per_column):
            if sample_idx >= len(samples) or not samples[sample_idx]:
                continue
            ax = fig.add_subplot(gs[i, j])
            with Image.open(samples[sample_idx]) as image:
                ax.set_axis_off()
                ax.imshow(image)
| |
|
| |
|
def _draw_section_header(fig, gs, row, text, font_properties) -> None:
    """Write ``text`` as a section header spanning every column of grid row ``row``."""
    ax = fig.add_subplot(gs[row, :])
    ax.set_axis_off()
    ax.text(
        y=0.2,
        x=0,
        s=text,
        font_properties=font_properties,
    )


def bearid_ui(
    pil_image_chip: Image.Image,
    indexed_k_nearest_individuals: dict,
    indexed_samples: dict,
    save_filepath: Path,
    k_closest_neighbors: int = 5,
) -> None:
    """Main UI for identifying bears: render the summary figure and save it.

    Grid layout, one entry per row:

    * row 0: the extracted chip (``pil_image_chip``);
    * row 1: "Closest faces" section header;
    * row 2: the ``k_closest_neighbors`` closest neighbours pooled over all
      individuals (see ``draw_closest_neighbors``);
    * row 3: "Closest N individuals" section header;
    * row 4: the closest match of each individual;
    * rows 5+: each individual's sample images (see ``draw_top_k_individuals``).

    Args:
        pil_image_chip: The chip being identified.
        indexed_k_nearest_individuals: Mapping ``bear_id -> list of neighbour
            dicts`` with ``"distance"`` and ``"dataset_image"`` keys.
        indexed_samples: Mapping ``bear_id -> list of sample image filepaths``.
        save_filepath: Destination passed to ``Figure.savefig``.
        k_closest_neighbors: How many pooled neighbours to show in row 2.
    """
    bear_ids = list(indexed_k_nearest_individuals.keys())

    i_chip = 0
    i_closest_neighbors = 2
    i_top_k_individual = 4

    # Rows 0..i_top_k_individual are fixed. Each sample index needs its own
    # row below the closest-match row. `default=0` handles empty
    # indexed_samples.
    max_samples = max((len(xs) for xs in indexed_samples.values()), default=0)
    nrows = i_top_k_individual + 1 + max_samples
    # Wide enough for both image rows. At least one column, so GridSpec
    # accepts it.
    ncols = max(len(bear_ids), k_closest_neighbors, 1)

    fig = plt.figure(constrained_layout=True, figsize=(3 * ncols, 3 * nrows))
    try:
        gs = GridSpec(nrows=nrows, ncols=ncols, figure=fig)
        font_properties_section = font_manager.FontProperties(size=35)

        # Row 0: the chip being identified.
        draw_extrated_chip(fig.add_subplot(gs[i_chip, :]), pil_image_chip)

        _draw_section_header(
            fig, gs, i_closest_neighbors - 1, "Closest faces", font_properties_section
        )
        draw_closest_neighbors(
            fig=fig,
            gs=gs,
            i_start=i_closest_neighbors,
            k_closest_neighbors=k_closest_neighbors,
            indexed_k_nearest_individuals=indexed_k_nearest_individuals,
        )

        _draw_section_header(
            fig,
            gs,
            i_top_k_individual - 1,
            f"Closest {len(bear_ids)} individuals",
            font_properties_section,
        )
        draw_top_k_individuals(
            fig=fig,
            gs=gs,
            i_end=nrows,
            i_start=i_top_k_individual,
            indexed_k_nearest_individuals=indexed_k_nearest_individuals,
            bear_ids=bear_ids,
            indexed_samples=indexed_samples,
        )

        fig.savefig(save_filepath, bbox_inches="tight")
    finally:
        # Release the figure even if drawing or saving raised.
        plt.close(fig)
| |
|