# COTS detection demo Space — Gradio app for browsing the cots_yolo_dataset
# and comparing YOLO model predictions against ground-truth annotations.
import gradio as gr
import numpy as np
import torch
import torchvision
from torch import nn
from torchvision import transforms
import typing as tp
from huggingface_hub import list_repo_files, hf_hub_download
from ultralytics import YOLO
import cv2
# ---------------------------------
# 0. Get dataset file names
# ---------------------------------
# Hub coordinates for the dataset repo; both are read again by get_sample()
# when it downloads individual images/labels on demand.
repo_type = "dataset"
repo_id = "eloise54/cots_yolo_dataset"
# Flat listing of every file path in the dataset repo (network call at import time).
files = list_repo_files(repo_id, repo_type=repo_type)
def get_dataset_splits(files):
    """Partition a flat list of repo file paths into per-split image/label lists.

    Parameters
    ----------
    files : iterable of str
        Repo-relative paths, e.g. ``"train/images/0001.jpg"``.

    Returns
    -------
    tuple of six lists
        ``(train_images, val_images, test_images,
        train_labels, val_labels, test_labels)``.
        Each label path mirrors its image path with ``images/`` -> ``labels/``
        and ``.jpg`` -> ``.txt``; the label file may not actually exist on the
        Hub when the image contains no annotations.
    """
    # Insertion order gives the same match priority as the original
    # train -> val -> test if/elif chain.
    splits = {"train/": ([], []), "val/": ([], []), "test/": ([], [])}
    for path in files:
        # Suffix check instead of substring: avoids matching names that merely
        # contain ".jpg" (e.g. backups like "x.jpg.bak").
        if not path.endswith(".jpg"):
            continue
        label = path.replace("images/", "labels/").replace(".jpg", ".txt")
        for prefix, (images, labels) in splits.items():
            if prefix in path:
                images.append(path)
                labels.append(label)
                break
    return (
        splits["train/"][0], splits["val/"][0], splits["test/"][0],
        splits["train/"][1], splits["val/"][1], splits["test/"][1],
    )
train_images, val_images, test_images, train_labels, val_labels, test_labels = get_dataset_splits(files)
# ---------------------------------
# 1. Load model
# ---------------------------------
# ---------------------------------
# 1. Load model
# ---------------------------------
# NOTE(review): assumes the trained checkpoint is bundled with the app at this
# relative path — confirm it is present in the Space repo before deploying.
model = YOLO('runs/detect/yolov11m_1920p/weights/best.pt').to("cpu")
model.eval()  # inference mode (freezes dropout / batch-norm statistics)
# ---------------------------------
# 2. Define function to read labels and draw boxes
# ---------------------------------
def read_ground_truth(label_file_path, img_width, img_height):
    """Parse a YOLO-format label file into absolute-pixel corner boxes.

    Each line is ``class x_center y_center width height`` with coordinates
    normalized to [0, 1]. A missing label file simply means the image has
    no COTS, so it yields an empty list rather than an error.

    Parameters
    ----------
    label_file_path : str
        Path to the ``.txt`` label file.
    img_width, img_height : int
        Image dimensions in pixels, used to denormalize coordinates.

    Returns
    -------
    list of dict
        One dict per box with keys ``"class_id"`` (int) and ``"box"``
        (``[x0, y0, x1, y1]`` floats, top-left / bottom-right corners).

    Raises
    ------
    ValueError
        If a line is malformed. (The original bare ``except:`` silently
        swallowed parse errors too; a malformed label file now fails loudly.)
    """
    ground_truth_boxes = []
    try:
        f = open(label_file_path)
    except FileNotFoundError:
        # No label file on disk => no COTS in this image.
        return ground_truth_boxes
    with f:
        for line in f:
            cls, xc, yc, w, h = map(float, line.split())
            # Denormalize from [0, 1] to pixel units.
            xc *= img_width
            yc *= img_height
            w *= img_width
            h *= img_height
            ground_truth_boxes.append({
                "class_id": int(cls),
                "box": [xc - 0.5 * w, yc - 0.5 * h,
                        xc + 0.5 * w, yc + 0.5 * h],
            })
    return ground_truth_boxes
def draw_rectangle(img, box, color, thickness):
    """Blend a rectangle outline onto *img* at 50% opacity.

    Parameters
    ----------
    img : ndarray
        BGR image; not modified in place.
    box : sequence of 4 numbers
        ``[x0, y0, x1, y1]`` corners (floats are truncated to int).
    color : tuple
        BGR color of the outline.
    thickness : int
        Outline thickness in pixels.

    Returns
    -------
    ndarray
        A new image with the semi-transparent rectangle composited in.
    """
    x0, y0, x1, y1 = (int(v) for v in box)
    alpha = 0.5
    # Draw onto a copy, then alpha-blend the copy back over the original so
    # the rectangle appears semi-transparent.
    layer = cv2.rectangle(img.copy(), (x0, y0), (x1, y1), color, thickness)
    return cv2.addWeighted(layer, alpha, img, 1 - alpha, 0)
# ---------------------------------
# 3. Prediction function
# ---------------------------------
def get_sample(index: int, dataset_choice: str):
    """Download, run inference on, and annotate one dataset image.

    Parameters
    ----------
    index : int
        Requested image number; clamped into the valid range for the split.
    dataset_choice : str
        One of ``"train"``, ``"val"`` or ``"test"``.

    Returns
    -------
    tuple
        ``(annotated RGB image, clamped index, clamped index, dataset_choice)``
        — the index appears twice because it feeds both the ``state`` holder
        and the visible "Current Image Number" box in the UI wiring.

    Raises
    ------
    ValueError
        If *dataset_choice* names an unknown or empty split (previously this
        surfaced as an opaque IndexError).
    """
    splits = {
        "train": (train_images, train_labels),
        "val": (val_images, val_labels),
        "test": (test_images, test_labels),
    }
    images, labels = splits.get(dataset_choice, ([], []))
    if not images:
        raise ValueError(f"Unknown or empty dataset split: {dataset_choice!r}")
    index = max(0, min(index, len(images) - 1))  # clamp index into range

    # Materialize the image (and, when present, its label file) locally so the
    # repo-relative paths below resolve against the working directory.
    hf_hub_download(repo_id=repo_id, repo_type=repo_type,
                    filename=images[index], local_dir=".")
    try:
        hf_hub_download(repo_id=repo_id, repo_type=repo_type,
                        filename=labels[index], local_dir=".")
    except Exception:
        pass  # best-effort: no label file on the Hub means no COTS in image

    pred_color = (0, 0, 255)  # red (BGR): model predictions
    gt_color = (0, 255, 0)    # green (BGR): ground truth
    thickness = 15

    img = cv2.imread(images[index])
    with torch.no_grad():
        results = model(images[index], imgsz=1920)
    gt = read_ground_truth(labels[index], img.shape[1], img.shape[0])

    # Predictions first, then ground truth on top.
    for res in results:
        for box in res.boxes.xyxy:
            img = draw_rectangle(img, box, pred_color, thickness)
    for box_dict in gt:
        img = draw_rectangle(img, box_dict["box"], gt_color, thickness)

    return img[..., ::-1], index, index, dataset_choice  # BGR -> RGB
# ---------------------------------
# 4. Navigation functions
# ---------------------------------
def next_sample(index: int, dataset_choice: str):
    """Advance one image within the chosen split (get_sample clamps the index)."""
    target = index + 1
    return get_sample(target, dataset_choice)
def prev_sample(index: int, dataset_choice: str):
    """Step back one image within the chosen split (get_sample clamps the index)."""
    target = index - 1
    return get_sample(target, dataset_choice)
# ---------------------------------
# 5. UI elements
# ---------------------------------
# Markdown shown at the bottom of the UI: provenance and citation for the data.
dataset_information = """
## Dataset overview
[![Hugging Face Dataset](https://img.shields.io/badge/huggingface-dataset-blue?logo=huggingface)](https://huggingface.co/datasets/eloise54/cots_yolo_dataset)
This dataset is a **modified version** of the [CSIRO COTS and COTS Scars Dataset](https://data.csiro.au/collection/csiro:64235), originally released under the [Creative Commons Attribution 4.0 License (CC BY 4.0)](https://creativecommons.org/licenses/by/4.0/).
The original dataset contains images and annotations for **Crown-of-Thorns Starfish (COTS)** and **COTS scars**, collected to support coral reef monitoring and control efforts on the Great Barrier Reef (GBR).
These starfish are coral predators, and their outbreaks can severely damage reef ecosystems.
**CSIRO COTS and COTS Scars Dataset reference:**
```bibtex
@dataset{csiro_cots_2024,
author = {Armin, Ali and Bainbridge, Scott and Page, Geoff and Tychsen-Smith, Lachlan and Coleman, Greg and Oorloff, Jeremy and Harvey, De'vereux and Do, Brendan and Marsh, Benjamin and Lawrence, Emma and Kusy, Brano and Hayder, Zeeshan and Bonin, Mary},
title = {COTS and COTS scar dataset},
year = {2024},
publisher = {CSIRO},
version = {v1},
doi = {10.25919/03a7-hn83},
url = {https://data.csiro.au/collection/csiro:64235}
}
```
"""
# UI layout and event wiring. Every handler returns the same 4-tuple
# (image, state index, displayed index, split name) so all click/load events
# share one outputs list.
with gr.Blocks() as demo:
    gr.Markdown("## 🪸 Crown of thorns starfish detection - protect the great barrier reef")
    gr.Markdown("Use **Next** or **Previous** to browse samples and see model predictions vs ground truth.")
    state = gr.State(0)  # holds current index
    with gr.Row():
        dropdown = gr.Dropdown( ["train", "val", "test"], label="Dataset split to use", value="train")
        dataset_choice = gr.Text(label="Using Dataset")  # read-back of which split was actually used
    with gr.Row(equal_height=True):
        index_input = gr.Number(label="Enter image number to display: ", value=0, precision=0)
        go_btn = gr.Button("Apply")
    with gr.Row():
        image_output = gr.Image(label="Image")
    with gr.Row():
        gr.Markdown("Green is ground truth, Red is model prediction")
    with gr.Row():
        index = gr.Text(label="Current Image Number", interactive=False)
    with gr.Row():
        prev_btn = gr.Button("⬅️ Prev image")
        next_btn = gr.Button("Next image➡️")
    with gr.Row():
        gr.Markdown(dataset_information)
    # Connect navigation: prev/next step from the stored state; Apply jumps to
    # the typed index. All three refresh image, state, index display and split.
    prev_btn.click(fn=prev_sample, inputs=[state, dropdown], outputs=[image_output, state, index, dataset_choice])
    next_btn.click(fn=next_sample, inputs=[state, dropdown], outputs=[image_output, state, index, dataset_choice])
    go_btn.click(fn=get_sample, inputs=[index_input, dropdown], outputs=[image_output, state, index, dataset_choice])
    # Load initial image when the page first renders.
    demo.load(fn=get_sample, inputs=[state, dropdown], outputs=[image_output, state, index, dataset_choice])
# ---------------------------------
# 6. Run
# ---------------------------------
if __name__ == "__main__":
    # show_api=False hides the auto-generated API docs page for this demo.
    demo.launch(show_api=False)