Spaces:
Running
Running
| """ | |
| Gradio app to showcase the pyronear model for salmon vision. | |
| """ | |
| from collections import Counter | |
| from pathlib import Path | |
| from typing import Any, Tuple | |
| import torch | |
| import gradio as gr | |
| import numpy as np | |
| from ultralytics import YOLO | |
def bgr_to_rgb(a: np.ndarray) -> np.ndarray:
    """
    Convert an image array from BGR channel order to RGB channel order.

    The third axis of `a` (the channel axis of a height x width x channel
    image) is reversed; the other axes are left untouched. The result is a
    view on `a`, so no pixel data is copied.
    """
    return np.flip(a, axis=2)
| def has_values(maybe_tensor: torch.Tensor | None) -> bool: | |
| """ | |
| Check whether the `maybe_tensor` contains items. | |
| """ | |
| if maybe_tensor is None: | |
| return False | |
| elif isinstance(maybe_tensor, torch.Tensor): | |
| if len(maybe_tensor) == 0: | |
| return False | |
| else: | |
| return True | |
def analyze_predictions(yolo_predictions) -> dict[str, Any]:
    """
    Analyze the raw `yolo_predictions` and output a dict containing information.

    Every tracked identifier is assigned the class that was predicted for it
    most often across all the predictions.

    Args:
        yolo_predictions: result of calling model.track() on a video

    Returns:
        A dict with the following keys:
        counts (int): number of distinct identifiers.
        ids (set[int]): all the assigned identifiers.
        detected_species (dict[int, int]): mapping from identifier to the
            class index most frequently predicted for that identifier.
        names: the class names used by the model, read from the first
            prediction (None when there are no predictions).
    """
    if len(yolo_predictions) == 0:
        return {
            "counts": 0,
            "ids": set(),
            "detected_species": {},
            "names": None,
        }

    names = yolo_predictions[0].names

    # Single pass over the predictions: tally, per track id, how often each
    # class was predicted, instead of rescanning every prediction per id.
    ids: set[int] = set()
    class_votes: dict[int, Counter] = {}
    for prediction in yolo_predictions:
        boxes = prediction.boxes
        if not has_values(boxes.id):
            continue
        # tolist() yields plain Python ints and, unlike .numpy(), does not
        # require the tensors to live on the CPU.
        track_ids = boxes.id.long().tolist()
        class_ids = boxes.cls.long().tolist()
        for track_id, class_id in zip(track_ids, class_ids):
            ids.add(track_id)
            class_votes.setdefault(track_id, Counter())[class_id] += 1

    # Majority vote per track id; iterate `ids` so the mapping keeps the same
    # ordering as the returned set.
    detected_species = {
        track_id: class_votes[track_id].most_common(1)[0][0] for track_id in ids
    }

    return {
        "counts": len(ids),
        "ids": ids,
        "detected_species": detected_species,
        "names": names,
    }
def prediction_to_str(yolo_predictions) -> str:
    """
    Build a human readable summary of the tracking `yolo_predictions`.

    Returns "No prediction" for an empty input. Otherwise returns a headline
    with the number of tracked ids, followed by one bullet line per id giving
    its class name ('Unknown' when the class index has no name). The bullet
    lines are also printed to stdout.
    """
    if len(yolo_predictions) == 0:
        return "No prediction"

    analysis = analyze_predictions(yolo_predictions=yolo_predictions)
    class_names = analysis["names"]
    tracked_ids = analysis["ids"]

    bullet_lines = []
    for fish_id, class_index in analysis["detected_species"].items():
        bullet_lines.append(
            f"- The fish with id {fish_id} is a {class_names.get(class_index, 'Unknown')}"
        )
    summary_str = "\n".join(bullet_lines)

    print(summary_str)
    return f"Detected {len(tracked_ids)} salmons in the video clip with ids {tracked_ids}:\n{summary_str}"
def interface_fn(model: YOLO, video_filepath: Path) -> Tuple[Path, str]:
    """
    Run the tracker on one video and return the tuple that populates the
    gradio interface.

    Args:
        model (YOLO): Loaded ultralytics YOLO model.
        video_filepath (Path): video to run the tracking on.

    Returns:
        filepath_video_prediction (Path): location where the annotated video
            is expected, i.e. runs/track/{stem}/{stem}.avi where {stem} is
            the stem of video_filepath.
        raw_prediction_str (str): human friendly summary of the predictions
            returned by the model.
    """
    project = "runs/track/"
    name = video_filepath.stem

    # Outputs are saved under {project}/{name}; exist_ok lets a rerun on the
    # same video reuse that directory.
    track_options = {
        "source": video_filepath,
        "save": True,
        "tracker": "bytetrack.yaml",
        "exist_ok": True,
        "project": project,
        "name": name,
    }
    predictions = model.track(**track_options)

    raw_prediction_str = prediction_to_str(yolo_predictions=predictions)
    filepath_video_prediction = Path(project) / name / f"{name}.avi"
    return filepath_video_prediction, raw_prediction_str
def examples(dir_examples: Path) -> list[Path]:
    """
    List the example videos from the dir_examples directory.

    Args:
        dir_examples (Path): directory to search (not recursive).

    Returns:
        filepaths (list[Path]): the `*.mp4` files found directly inside
            dir_examples, sorted by path.
    """
    # glob() yields entries in an arbitrary, filesystem-dependent order;
    # sorting makes the list (and anything picked from it by index)
    # reproducible across machines.
    return sorted(dir_examples.glob("*.mp4"))
def load_model(filepath_weights: Path) -> YOLO:
    """
    Build a YOLO model from the weights file located at filepath_weights.
    """
    yolo_model = YOLO(filepath_weights)
    return yolo_model
# Main Gradio interface
MODEL_FILEPATH_WEIGHTS = Path("data/model/weights.pt")
DIR_EXAMPLES = Path("data/videos/")
DEFAULT_IMAGE_INDEX = 0

with gr.Blocks() as demo:
    model = load_model(MODEL_FILEPATH_WEIGHTS)
    videos_filepaths = examples(dir_examples=DIR_EXAMPLES)
    print(f"videos_filepaths: {videos_filepaths}")

    # Do not crash at startup when the examples directory holds no video (or
    # fewer than the default index requires): start without a preloaded
    # video instead, uploads still work.
    try:
        default_value_input = videos_filepaths[DEFAULT_IMAGE_INDEX]
    except IndexError:
        default_value_input = None

    input = gr.Video(
        value=default_value_input,
        format="mp4",
        autoplay=True,
        loop=True,
        label="input video",
        sources=["upload"],
    )
    output_video = gr.Video(
        format="mp4",
        label="model prediction",
        autoplay=True,
        loop=True,
    )
    output_raw = gr.Text(label="raw prediction")

    # Bind the loaded model so the interface callback only takes the video
    # path; the path is wrapped in a Path as interface_fn expects.
    fn = lambda video_filepath: interface_fn(
        model=model, video_filepath=Path(video_filepath)
    )
    gr.Interface(
        title="ML model for wild salmon migration monitoring 🐟",
        fn=fn,
        inputs=input,
        outputs=[output_video, output_raw],
        # An empty list is passed as None so no examples section is requested
        # when there is nothing to show.
        examples=videos_filepaths or None,
        flagging_mode="never",
    )

demo.launch()