# Hugging Face Spaces page header captured by the export ("Spaces: Running") —
# kept as a comment so the file remains valid Python.
| import gradio as gr | |
| import numpy as np | |
| import torch | |
| import torchvision | |
| from torch import nn | |
| from torchvision import transforms | |
| import typing as tp | |
| from huggingface_hub import list_repo_files, hf_hub_download | |
| from ultralytics import YOLO | |
| import cv2 | |
# ---------------------------------
# 0. Get dataset file names
# ---------------------------------
# List every file in the Hugging Face dataset repo once at startup.
# Paths are repo-relative (e.g. "train/images/xxx.jpg") and are reused
# later both as download filenames and as local paths (local_dir=".").
repo_type = "dataset"
repo_id = "eloise54/cots_yolo_dataset"
files = list_repo_files(repo_id, repo_type=repo_type)
def get_dataset_splits(files):
    """Partition repo file paths into train/val/test image and label lists.

    Only ``*.jpg`` files are considered images; the matching label path is
    derived by swapping ``images/`` -> ``labels/`` and ``.jpg`` -> ``.txt``
    (YOLO layout). A path is assigned to the first split marker it contains,
    checked in train -> val -> test order, like the original elif chain.

    Args:
        files: iterable of repo-relative file paths.

    Returns:
        Tuple ``(train_images, val_images, test_images,
        train_labels, val_labels, test_labels)``.
    """
    # (images, labels) pair per split prefix; dict order fixes match priority.
    splits = {"train/": ([], []), "val/": ([], []), "test/": ([], [])}
    for path in files:
        # endswith avoids false positives such as "x.jpg.bak" that a plain
        # substring test (".jpg" in path) would accept.
        if not path.endswith(".jpg"):
            continue
        label = path.replace("images/", "labels/").replace(".jpg", ".txt")
        for prefix, (images, labels) in splits.items():
            if prefix in path:
                images.append(path)
                labels.append(label)
                break
    return (splits["train/"][0], splits["val/"][0], splits["test/"][0],
            splits["train/"][1], splits["val/"][1], splits["test/"][1])
# Materialise the six split lists once; every callback below reads them.
train_images, val_images, test_images, train_labels, val_labels, test_labels = get_dataset_splits(files)
# ---------------------------------
# 1. Load model
# ---------------------------------
# Trained YOLOv11m checkpoint bundled with the Space; inference runs on CPU.
model = YOLO('runs/detect/yolov11m_1920p/weights/best.pt').to("cpu")
# NOTE(review): YOLO is the ultralytics wrapper, not a bare nn.Module —
# presumably .eval() delegates to the underlying torch model; confirm.
model.eval()
# ---------------------------------
# 2. Define function to read labels and draw boxes
# ---------------------------------
def read_ground_truth(label_file_path, img_width, img_height):
    """Parse a YOLO-format label file into pixel-space ground-truth boxes.

    Each line is ``class x_center y_center width height`` with coordinates
    normalised to [0, 1]; they are scaled by the image size and converted
    to corner (x0, y0, x1, y1) form.

    Args:
        label_file_path: path to the ``.txt`` label file.
        img_width: image width in pixels.
        img_height: image height in pixels.

    Returns:
        List of ``{"class_id": int, "box": [x0, y0, x1, y1]}`` dicts; empty
        when the label file does not exist (an image with no COTS has none).

    Raises:
        ValueError: on a malformed label line (previously a bare ``except``
        silently swallowed this too, hiding bad data).
    """
    ground_truth_boxes = []
    try:
        with open(label_file_path) as f:
            for line in f:
                cls, xc, yc, w, h = map(float, line.split())
                # Scale normalised centre/size to pixels.
                xc *= img_width
                yc *= img_height
                w *= img_width
                h *= img_height
                ground_truth_boxes.append({
                    "class_id": int(cls),
                    "box": [xc - 0.5 * w, yc - 0.5 * h,
                            xc + 0.5 * w, yc + 0.5 * h],
                })
    except FileNotFoundError:
        # No label file means no COTS in this image — empty list is correct.
        pass
    return ground_truth_boxes
def draw_rectangle(img, box, color, thickness):
    """Blend a rectangle outline onto *img* at 50% opacity and return it.

    *box* is any indexable of four coordinates (x0, y0, x1, y1); values are
    cast to int, so torch tensors from YOLO output work as well. The input
    image is not modified; a blended copy is returned.
    """
    top_left = (int(box[0]), int(box[1]))
    bottom_right = (int(box[2]), int(box[3]))
    blend_alpha = 0.5
    # Draw on a copy, then alpha-blend it back over the original so the
    # rectangle appears semi-transparent.
    layer = cv2.rectangle(img.copy(), top_left, bottom_right, color, thickness)
    return cv2.addWeighted(layer, blend_alpha, img, 1 - blend_alpha, 0)
| # --------------------------------- | |
| # 3. Prediction function | |
| # --------------------------------- | |
| def get_sample(index: int, dataset_choice: str): | |
| images = [] | |
| labels = [] | |
| if dataset_choice == "train": | |
| images = train_images | |
| labels = train_labels | |
| elif dataset_choice == "val": | |
| images = val_images | |
| labels = val_labels | |
| elif dataset_choice == "test": | |
| images = test_images | |
| labels = test_labels | |
| index = max(0, min(index, len(images) - 1)) # clamp index | |
| downloaded_path = hf_hub_download(repo_id=repo_id,repo_type=repo_type,filename=images[index],local_dir=".") | |
| try: | |
| downloaded_path = hf_hub_download(repo_id=repo_id,repo_type=repo_type,filename=labels[index],local_dir=".") | |
| except: | |
| pass #no label txt files means no COTS in image | |
| pred_color = (0, 0, 255) | |
| gt_color = (0, 255, 0) | |
| thickness = 15 | |
| img = cv2.imread(images[index]) | |
| with torch.no_grad(): | |
| results = model(images[index], imgsz=1920) | |
| gt = read_ground_truth(labels[index], img.shape[1], img.shape[0]) | |
| for res in results: | |
| boxes = res.boxes.xyxy | |
| for box in boxes: | |
| img = draw_rectangle(img, box, pred_color, thickness) | |
| for box_dict in gt: | |
| img = draw_rectangle(img, box_dict['box'], gt_color, thickness) | |
| img = img[...,::-1] # BGR to RGB | |
| return img, index, index, dataset_choice | |
| # --------------------------------- | |
| # 4. Navigation functions | |
| # --------------------------------- | |
| def next_sample(index: int, dataset_choice: str): | |
| return get_sample(index + 1, dataset_choice) | |
def prev_sample(index: int, dataset_choice: str):
    """Show the sample before *index* (get_sample clamps below zero)."""
    target = index - 1
    return get_sample(target, dataset_choice)
| # --------------------------------- | |
| # 5. UI elements | |
| # --------------------------------- | |
| dataset_information= """ | |
| ## Dataset overview | |
| [](https://huggingface.co/datasets/eloise54/cots_yolo_dataset) | |
| This dataset is a **modified version** of the [CSIRO COTS and COTS Scars Dataset](https://data.csiro.au/collection/csiro:64235), originally released under the [Creative Commons Attribution 4.0 License (CC BY 4.0)](https://creativecommons.org/licenses/by/4.0/). | |
| The original dataset contains images and annotations for **Crown-of-Thorns Starfish (COTS)** and **COTS scars**, collected to support coral reef monitoring and control efforts on the Great Barrier Reef (GBR). | |
| These starfish are coral predators, and their outbreaks can severely damage reef ecosystems. | |
| **PCSIRO COTS and COTS Scars Dataset reference:** | |
| ```bibtex | |
| @dataset{csiro_cots_2024, | |
| author = {Armin, Ali and Bainbridge, Scott and Page, Geoff and Tychsen-Smith, Lachlan and Coleman, Greg and Oorloff, Jeremy and Harvey, De'vereux and Do, Brendan and Marsh, Benjamin and Lawrence, Emma and Kusy, Brano and Hayder, Zeeshan and Bonin, Mary}, | |
| title = {COTS and COTS scar dataset}, | |
| year = {2024}, | |
| publisher = {CSIRO}, | |
| version = {v1}, | |
| doi = {10.25919/03a7-hn83}, | |
| url = {https://data.csiro.au/collection/csiro:64235} | |
| } | |
| ``` | |
| """ | |
with gr.Blocks() as demo:
    # Header and usage hint.
    gr.Markdown("## 🪸 Crown of thorns starfish detection - protect the great barrier reef")
    gr.Markdown("Use **Next** or **Previous** to browse samples and see model predictions vs ground truth.")
    state = gr.State(0) # holds current index
    with gr.Row():
        # Split selector feeds every callback; the Text box echoes the split.
        dropdown = gr.Dropdown( ["train", "val", "test"], label="Dataset split to use", value="train")
        dataset_choice = gr.Text(label="Using Dataset")
    with gr.Row(equal_height=True):
        # Jump straight to an index; get_sample clamps out-of-range values.
        index_input = gr.Number(label="Enter image number to display: ", value=0, precision=0)
        go_btn = gr.Button("Apply")
    with gr.Row():
        image_output = gr.Image(label="Image")
    with gr.Row():
        gr.Markdown("Green is ground truth, Red is model prediction")
    with gr.Row():
        index = gr.Text(label="Current Image Number", interactive=False)
    with gr.Row():
        prev_btn = gr.Button("⬅️ Prev image")
        next_btn = gr.Button("Next image➡️")
    with gr.Row():
        gr.Markdown(dataset_information)
    # Connect navigation
    # Every callback returns (image, new_state_index, index_text, split_name),
    # matching the outputs list below.
    prev_btn.click(fn=prev_sample, inputs=[state, dropdown], outputs=[image_output, state, index, dataset_choice])
    next_btn.click(fn=next_sample, inputs=[state, dropdown], outputs=[image_output, state, index, dataset_choice])
    go_btn.click(fn=get_sample, inputs=[index_input, dropdown], outputs=[image_output, state, index, dataset_choice])
    # Load initial image
    demo.load(fn=get_sample, inputs=[state, dropdown], outputs=[image_output, state, index, dataset_choice])
| # --------------------------------- | |
| # 6. Run | |
| # --------------------------------- | |
| if __name__ == "__main__": | |
| demo.launch(show_api=False) |