| | """ |
| | Gradio app to showcase the bear face recognition model. |
| | """ |
| |
|
| | import uuid |
| | from pathlib import Path |
| | from typing import Any |
| |
|
| | import gradio as gr |
| | from PIL import Image |
| |
|
| | from ui import bearid_ui |
| | from utils import load_models, run_pipeline, setup |
| |
|
| |
|
def prediction_to_str(
    bear_ids: list[str],
    indexed_k_nearest_individuals: dict[str, Any],
) -> str:
    """
    Turn the prediction into a human friendly string.

    Args:
        bear_ids (list[str]): predicted bear ids; the first one is reported as
            the closest individual.
        indexed_k_nearest_individuals (dict): maps each bear id to its nearest
            neighbours; the ``"distance"`` of the first neighbour of each id is
            reported.

    Returns:
        str: a header naming the closest individual followed by one bullet line
            per bear id with its distance in the embedding space, or a short
            message when ``bear_ids`` is empty.
    """
    if not bear_ids:
        # Previously bear_ids[0] raised IndexError on an empty prediction.
        return "The model did not find any matching individual."

    distance_str = "\n".join(
        f"- {bear_id} at distance "
        f"{indexed_k_nearest_individuals[bear_id][0]['distance']:.2f} "
        "in the embedding space"
        for bear_id in bear_ids
    )
    return f"The model found that the closest individual is {bear_ids[0]}:\n\n{distance_str}"
| |
|
| |
|
def examples(dir_examples: Path, pattern: str = "*.jpg") -> list[Path]:
    """
    List the example images from the dir_examples directory.

    The result is sorted so that its order, and therefore the image picked as
    the default by index, does not depend on the filesystem's listing order.

    Args:
        dir_examples (Path): directory to search (non-recursively).
        pattern (str): glob pattern selecting the image files. Defaults to
            ``"*.jpg"``, the only pattern previously supported.

    Returns:
        filepaths (list[Path]): sorted list of matching image filepaths.
    """
    return sorted(dir_examples.glob(pattern))
| |
|
| |
|
| | |
| | INPUT_PACKAGED_PIPELINE = Path("./data/09_external/artifacts/packaged_pipeline.zip") |
| | PIPELINE_INSTALL_PATH = Path("./data/06_models/pipeline/metriclearning/") |
| | METRIC_LEARNING_MODEL_FILEPATH = Path( |
| | "./data/06_models/pipeline/metriclearning/bearidentification/model.pt" |
| | ) |
| | METRIC_LEARNING_KNN_INDEX_FILEPATH = Path( |
| | "./data/06_models/pipeline/metriclearning/bearidentification/knn.index" |
| | ) |
| | INSTANCE_SEGMENTATION_WEIGHTS_FILEPATH = Path( |
| | "./data/06_models/pipeline/metriclearning/bearfacesegmentation/model.pt" |
| | ) |
| | CACHE_PIPELINE_RUN = {} |
| |
|
| |
|
def pipeline_run_fn(
    loaded_models: dict[str, Any],
    pil_image: Image.Image,
    param_square_dim: int,
    param_k: int,
    param_n_samples_per_individual: int,
    knn_index_filepath: Path,
    cache: dict,
) -> dict:
    """
    A simple cached version of the pipeline.run function.

    The cache key combines the image mode, size and a SHA-256 digest of its
    pixel data with every pipeline parameter. A cached result is therefore
    only reused for the same image run with the same settings, and the key
    stays small instead of holding the whole raw pixel buffer.
    ``loaded_models`` is not part of the key; this assumes the models do not
    change for the lifetime of ``cache`` — TODO confirm.

    __Note__: It keeps the cache in memory and can grow unbounded.

    Args:
        loaded_models (dict[str, Any]): models forwarded to ``run_pipeline``.
        pil_image (PIL.Image.Image): image to run the pipeline on.
        param_square_dim (int): forwarded to ``run_pipeline``.
        param_k (int): forwarded to ``run_pipeline``.
        param_n_samples_per_individual (int): forwarded to ``run_pipeline``.
        knn_index_filepath (Path): forwarded to ``run_pipeline``.
        cache (dict): mutable mapping used as the cache; updated in place.

    Returns:
        dict: the results of ``run_pipeline`` (possibly served from cache).
    """
    import hashlib

    # Mode and size disambiguate images whose raw pixel bytes happen to match
    # (e.g. a 2x3 and a 3x2 image); the parameters prevent returning a result
    # computed with different settings.
    cache_key = (
        pil_image.mode,
        pil_image.size,
        hashlib.sha256(pil_image.tobytes()).hexdigest(),
        param_square_dim,
        param_k,
        param_n_samples_per_individual,
        str(knn_index_filepath),
    )
    if cache_key in cache:
        print("CACHE HIT")
        return cache[cache_key]

    print("CACHE MISS")
    results = run_pipeline(
        loaded_models=loaded_models,
        pil_image=pil_image,
        param_square_dim=param_square_dim,
        param_k=param_k,
        param_n_samples_per_individual=param_n_samples_per_individual,
        knn_index_filepath=knn_index_filepath,
    )
    cache[cache_key] = results
    return results
| |
|
| |
|
# Install the packaged pipeline archive into PIPELINE_INSTALL_PATH before the
# model files under that directory are loaded below.
setup(
    input_packaged_pipeline=INPUT_PACKAGED_PIPELINE,
    install_path=PIPELINE_INSTALL_PATH,
)

DIR_EXAMPLES = Path("data/images/")
DEFAULT_IMAGE_INDEX = 0
PARAM_K = 5
PARAM_N_SAMPLES_PER_INDIVIDUAL = 4
PARAM_SQUARE_DIM = 300
OUTPUT_DIR_PREDICTION = Path("./data/07_model_output/predict/")

# Display names for a few known bear ids.
BEAR_ID_TO_NAME = {
    "bf_480": "David Martinez",
    "bf_813": "Michael Thompson",
    "bf_503": "Jessica Brown",
}

loaded_models = load_models(
    filepath_metric_learning_weights=METRIC_LEARNING_MODEL_FILEPATH,
    filepath_segmentation_weights=INSTANCE_SEGMENTATION_WEIGHTS_FILEPATH,
)
image_filepaths = examples(dir_examples=DIR_EXAMPLES)

# Pre-fill the input with one of the examples. When there is no example at
# that index (e.g. the app is started from another working directory), start
# with an empty input instead of crashing at import time.
try:
    default_image_filepath = image_filepaths[DEFAULT_IMAGE_INDEX]
except IndexError:
    print(f"No example image found in {DIR_EXAMPLES}; starting with an empty input.")
    default_value_input = None
else:
    default_value_input = Image.open(default_image_filepath)
| |
|
def _load_image(filepath) -> Image.Image:
    """
    Read an image fully into memory and close its file right away.

    ``Image.open`` is lazy and keeps the file open until the pixels are read;
    copying inside the ``with`` block forces the read and then releases the
    handle deterministically.
    """
    with Image.open(filepath) as image:
        return image.copy()


with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column():
            image_input = gr.Image(
                type="pil",
                value=default_value_input,
                label="input image",
                sources=["upload", "clipboard"],
            )
            gr.Examples(
                examples=image_filepaths,
                inputs=image_input,
            )
            submit_btn = gr.Button(value="Identify", variant="primary")

        with gr.Column():
            with gr.Tab("Prediction"):
                with gr.Row():
                    output_bear_id = gr.Text(label="Predicted Individual")
                    output_bear_name = gr.Text(label="Name", visible=False)
                output_bear_id_samples = gr.Gallery(
                    label="Similar faces of the same bear",
                    preview=True,
                    visible=False,
                )

            with gr.Tab("Details", visible=False) as tab_details:
                output_segmentation_image = gr.Image(
                    type="pil", label="model prediction"
                )
                output_cropped_image = gr.Image(type="pil", label="cropped bear face")
                output_identification_prediction_image = gr.Image(
                    type="pil",
                    label="prediction",
                )
                output_raw = gr.Text(label="raw prediction")

    def submit_fn(
        loaded_models: dict[str, Any],
        pil_image: Image.Image,
        param_square_dim: int,
        param_k: int,
        param_n_samples_per_individual: int,
    ) -> dict:
        """
        Main function used for the Gradio interface.

        Runs the (cached) pipeline on the user's image, renders the
        identification figure with ``bearid_ui`` and returns the new value or
        update for every output component.

        Args:
            loaded_models (dict[str, Any]): models returned by ``load_models``.
            pil_image (PIL.Image.Image | None): original image picked by the
                user; ``None`` when the input is empty.
            param_square_dim (int): forwarded to the pipeline.
            param_k (int): forwarded to the pipeline and used as the number of
                closest neighbors shown by ``bearid_ui``.
            param_n_samples_per_individual (int): forwarded to the pipeline.

        Returns:
            dict: maps each output component to its new value or update:
                output_bear_id (str): identifier of the closest bear.
                output_bear_name: display name of that bear, made visible.
                output_bear_id_samples: gallery of sample images of that bear,
                    made visible.
                output_segmentation_image (PIL): image of the segmented bear head.
                output_cropped_image (PIL): image of the cropped bear head.
                output_identification_prediction_image (PIL): image with the
                    prediction from the model.
                output_raw: human friendly string of the raw prediction.
                tab_details: update making the "Details" tab visible.

        Raises:
            gr.Error: when no image was provided or no bear was identified.
        """
        if pil_image is None:
            # Clicking "Identify" with an empty input passes None; surface a
            # message in the UI instead of an AttributeError.
            raise gr.Error("Please provide an image of a bear first.")

        result = pipeline_run_fn(
            loaded_models=loaded_models,
            pil_image=pil_image,
            param_square_dim=param_square_dim,
            param_k=param_k,
            param_n_samples_per_individual=param_n_samples_per_individual,
            knn_index_filepath=METRIC_LEARNING_KNN_INDEX_FILEPATH,
            cache=CACHE_PIPELINE_RUN,
        )
        stages = result["stages"]
        pil_image_segmented_head = stages["segmentation"]["output"]["pil_image"]
        pil_image_cropped_head = stages["crop"]["output"]["pil_images"]["resized"]

        identification = stages["identification"]["output"]
        bear_ids = identification["bear_ids"]
        indexed_samples = identification["indexed_samples"]
        indexed_k_nearest_individuals = identification["indexed_k_nearest_individuals"]
        if not bear_ids:
            raise gr.Error("No bear could be identified in this image.")
        bear_id = bear_ids[0]

        # The uuid4-based name is fresh for every call, so there is no
        # previous file to remove.
        # NOTE(review): prediction images accumulate in OUTPUT_DIR_PREDICTION;
        # consider a cleanup policy for long-running deployments.
        output_filepath = OUTPUT_DIR_PREDICTION / f"prediction_{uuid.uuid4()}.png"
        output_filepath.parent.mkdir(parents=True, exist_ok=True)

        bearid_ui(
            pil_image_chip=pil_image_cropped_head,
            indexed_k_nearest_individuals=indexed_k_nearest_individuals,
            indexed_samples=indexed_samples,
            save_filepath=output_filepath,
            k_closest_neighbors=param_k,
        )
        pil_image_identification_prediction = _load_image(output_filepath)

        raw_prediction_str = prediction_to_str(
            bear_ids=bear_ids,
            indexed_k_nearest_individuals=indexed_k_nearest_individuals,
        )

        return {
            output_bear_id: bear_id,
            # NOTE(review): every id missing from BEAR_ID_TO_NAME shows the
            # placeholder name "Laura Anderson" — confirm this is intended.
            output_bear_name: gr.Text(
                BEAR_ID_TO_NAME.get(bear_id, "Laura Anderson"), visible=True
            ),
            output_bear_id_samples: gr.Gallery(
                [_load_image(fp) for fp in indexed_samples[bear_id]], visible=True
            ),
            output_segmentation_image: pil_image_segmented_head,
            output_cropped_image: pil_image_cropped_head,
            output_identification_prediction_image: pil_image_identification_prediction,
            output_raw: gr.Text(
                raw_prediction_str, lines=raw_prediction_str.count("\n") + 1
            ),
            tab_details: gr.update(visible=True),
        }

    # Event listeners must be registered inside the gr.Blocks context.
    submit_btn.click(
        fn=lambda pil_image: submit_fn(
            loaded_models=loaded_models,
            pil_image=pil_image,
            param_square_dim=PARAM_SQUARE_DIM,
            param_k=PARAM_K,
            param_n_samples_per_individual=PARAM_N_SAMPLES_PER_INDIVIDUAL,
        ),
        inputs=image_input,
        outputs=[
            output_bear_id,
            output_bear_name,
            output_bear_id_samples,
            output_segmentation_image,
            output_cropped_image,
            output_identification_prediction_image,
            output_raw,
            tab_details,
        ],
    )

demo.launch()
| |
|