Spaces:
Running
Running
| """ | |
| Gradio web application for detecting and measuring objects in images. | |
| Key features: | |
| - Image scaling tool to set measurement reference | |
| - Object detection using YOLOv8 model for scallop/spat detection | |
| - Interactive annotation of detected objects | |
| - Size measurements in mm based on scale reference | |
| - Statistics and histogram visualization of object sizes | |
| - Export results to CSV | |
| """ | |
| # %% #|> Imports | | |
| from pathlib import Path | |
| import cv2 | |
| import gradio as gr | |
| from gradio_image_annotation import image_annotator | |
| import numpy as np | |
| import pandas as pd | |
| import supervision as sv | |
| import logging | |
| import os | |
| import plotly.express as px | |
| from spatstatapp.inference import inference_large | |
| from spatstatapp.plotting import coco_to_detections | |
| from spatstatapp.tile_training_data import load_bboxes | |
| import gradio as gr | |
| import numpy as np | |
| from PIL import Image, ImageDraw | |
| import cv2 | |
| import logging | |
# Configure the root logger once at import time so INFO-level messages from
# this module (and libraries) are emitted; ``logger`` is this module's logger.
logging.basicConfig(level="INFO")
logger = logging.getLogger(name=__name__)
| # %% load image and detections | | |
def get_asset_path(relative_path):
    """Resolve a bundled asset path for the current runtime.

    When the ``SPACE_ID`` environment variable is set to a non-empty value
    (Hugging Face Spaces), the path is anchored at the current working
    directory; otherwise it is returned as a relative ``Path`` unchanged.
    """
    running_on_spaces = bool(os.getenv("SPACE_ID"))
    return Path.cwd() / relative_path if running_on_spaces else Path(relative_path)
# Bundled assets: ONNX detector weights and the folder of example images.
model_path = get_asset_path(Path("models/best.onnx"))
data_dir = get_asset_path(Path("img"))
if not data_dir.exists():
    # Without this folder the examples gallery is simply empty; make that
    # visible in the logs instead of failing silently.
    logger.warning("Example image directory %s not found", data_dir)
# Sorted so the examples gallery has a stable order (glob order is OS-dependent).
example_imgs = sorted(data_dir.glob("*.jpg"))
class PointSelector:
    """Two-click scale-bar tool that converts annotated boxes from pixels to mm.

    The user clicks two points on the image; :meth:`add_point` stores their
    pixel distance in ``line_len_px``. The user then enters the real length of
    that line, which :meth:`set_line_length` stores in ``line_len_mm``.
    :meth:`save_scaled_boxes` converts box extents with the factor
    ``line_len_mm / line_len_px``.

    NOTE(review): the app creates a single module-level instance, so every
    browser session shares these attributes and concurrent users can overwrite
    each other's points and scale. Per-session ``gr.State`` would isolate them.
    """

    def __init__(self, image=None):
        self.points = []  # clicked (x, y) pixel coordinates of the line in progress
        self.image_path = image  # clean image that markers are redrawn on
        self.line_len_px = None  # length of the last completed line, in pixels
        self.line_len_mm = None  # user-entered real length of that line, in mm

    def _start_over(self, image_path):
        """Drop any half-drawn line and remember ``image_path`` as the clean image."""
        self.points = []
        self.image_path = image_path

    def clear_og_img(self, _x_):
        """Forget the clean image so the next click adopts the displayed one.

        Used as the example-click callback; the argument is ignored.
        """
        self._start_over(None)

    def reset_og_img(self, image):
        """Start a new measurement on a freshly uploaded image path."""
        self._start_over(image)

    def add_point(self, image, evt: "gr.SelectData"):
        """Record a click on the image and draw the scale line in progress.

        Args:
            image: path of the image currently displayed (after the first
                click it already carries that click's marker).
            evt: Gradio select event; ``evt.index`` holds the clicked (x, y).

        Returns:
            ``(rgb_image, status_message)``. On the second click the pixel
            distance is stored in ``line_len_px`` and the point list is
            cleared, so the next click starts a new line.

        Raises:
            gr.Error: if the image cannot be read from disk.
        """
        if self.image_path is None:
            self.image_path = image

        # A new line starts from the clean image so old markers disappear; the
        # second click draws on top of the image that shows the first marker.
        if not self.points or len(self.points) >= 2:
            self.points = []
            img_draw = cv2.imread(self.image_path)
        else:
            img_draw = cv2.imread(image)
        if img_draw is None:
            raise gr.Error("Could not read the image; please upload it again.")

        self.points.append((evt.index[0], evt.index[1]))
        for x, y in self.points:
            cv2.circle(img_draw, (int(x), int(y)), 5, (255, 0, 0), -1)

        if len(self.points) == 2:
            (x0, y0), (x1, y1) = self.points
            cv2.line(img_draw, (int(x0), int(y0)), (int(x1), int(y1)), (0, 255, 0), 3)
            dist = float(np.hypot(x1 - x0, y1 - y0))
            msg = f"Distance: {dist:.1f} pixels"
            self.line_len_px = dist
            self.points = []
        else:
            msg = f"Click point {len(self.points)+1}"
        # cv2 works in BGR; the Gradio image component expects RGB.
        return img_draw[:, :, ::-1], msg

    def set_line_length(self, line_len_mm, button):
        """Store the real length (mm) of the measured line; return button visibility."""
        self.line_len_mm = line_len_mm
        return self.check_scale_set(button)

    def _scale_is_set(self):
        """True when both lengths are known and positive, i.e. a usable mm/px factor."""
        mm, px = self.line_len_mm, self.line_len_px
        return mm is not None and px is not None and mm > 0 and px > 0

    def check_scale_set(self, button):
        """Return a Gradio update showing ``button`` only once the scale is usable.

        The ``button`` argument (the component's current value) is ignored.
        """
        return gr.update(visible=self._scale_is_set())

    def save_scaled_boxes(self, annotator):
        """Convert annotated boxes into a table of sizes in mm.

        Args:
            annotator: ``image_annotator`` value; its ``"boxes"`` entry is a
                list of dicts with ``xmin``/``ymin``/``xmax``/``ymax`` in pixels.

        Returns:
            DataFrame with the box columns (without ``color``) plus ``xrange``
            and ``yrange`` (box width/height in mm) and ``mean_daimeter_mm``
            (their mean), all rounded to 2 decimals. ``None`` when there are
            no boxes, the scale is not set (or not positive), or the payload
            is malformed.
        """
        boxes = (annotator or {}).get("boxes") or []
        if not boxes or not self._scale_is_set():
            return None
        mm_per_px = self.line_len_mm / self.line_len_px
        try:
            df = pd.DataFrame(boxes).drop(columns=["color"], errors="ignore")
            df["xrange"] = ((df["xmax"] - df["xmin"]) * mm_per_px).round(2)
            df["yrange"] = ((df["ymax"] - df["ymin"]) * mm_per_px).round(2)
        except (KeyError, TypeError, ValueError):
            logger.warning("Could not scale annotator boxes", exc_info=True)
            return None
        # Column name (sic) is kept as-is: other handlers and the CSV export read it.
        df["mean_daimeter_mm"] = ((df["yrange"] + df["xrange"]) / 2).round(2)
        return df
def detections_to_json(detections: "sv.Detections", image: "np.ndarray | str"):
    """Convert detections into the payload format used by ``image_annotator``.

    Only the pixel boxes (``detections.xyxy``) are used, which avoids depending
    on the tuple layout ``sv.Detections`` yields when iterated (that layout
    differs between supervision versions).

    Args:
        detections: detected objects; ``detections.xyxy`` rows are
            ``(xmin, ymin, xmax, ymax)`` in pixels.
        image: value stored under ``"image"``. Callers in this file pass either
            an image array or a file path.

    Returns:
        ``{"image": image, "boxes": [...]}``, where each box is a dict with
        float ``xmin``/``ymin``/``xmax``/``ymax``, an empty ``label`` and
        ``color`` ``(255, 0, 0)``.
    """
    boxes = [
        {
            "xmin": float(xmin),
            "ymin": float(ymin),
            "xmax": float(xmax),
            "ymax": float(ymax),
            "label": "",
            "color": (255, 0, 0),
        }
        for xmin, ymin, xmax, ymax in detections.xyxy
    ]
    return {"image": image, "boxes": boxes}
def create_histogram(df):
    """Build a histogram of per-object mean diameters in mm.

    Args:
        df: results table (as produced by ``PointSelector.save_scaled_boxes``
            and passed back by the ``gr.DataFrame``), or ``None``.

    Returns:
        A plotly figure, or ``None`` when there is nothing to plot: no table,
        an empty table, a table without a ``mean_daimeter_mm`` column (e.g.
        Gradio's placeholder frame), or no numeric values in that column.
    """
    if df is None or len(df) == 0 or "mean_daimeter_mm" not in df.columns:
        return None
    # Coerce to numbers: blank placeholder cells ("") become NaN and are dropped.
    sizes = pd.to_numeric(df["mean_daimeter_mm"], errors="coerce").dropna()
    if sizes.empty:
        return None
    fig = px.histogram(
        sizes.to_frame(),
        x="mean_daimeter_mm",
        title="Distribution of Shell Sizes",
        labels={"mean_daimeter_mm": "Mean Diameter (mm)"},
        nbins=30,
    )
    return fig
| # def get_boxes_table(annotator): | |
| # json_data = annotator["boxes"] | |
| # if len(json_data)==0: | |
| # return pd.DataFrame() | |
| # else: | |
| # df = pd.DataFrame(json_data).drop(columns=["color"], errors='ignore') | |
| # return df | |
| from ultralytics.utils.ops import xywhn2xyxy | |
def find_boxes_json(image_path):
    """Run the tiled detector on an image and load the boxes into the annotator.

    Args:
        image_path: path of the image currently shown in the scale tab.

    Returns:
        ``(annotator, boxes)``: a new ``image_annotator`` holding the image
        and detected boxes, plus the list of box dicts.

    Raises:
        gr.Error: if no image is loaded or it cannot be read.
    """
    img = cv2.imread(image_path) if image_path else None
    if img is None:
        raise gr.Error("Could not read the image; please upload it again.")
    # Tiling/threshold settings are passed straight to inference_large; see that
    # function for their exact semantics.
    detections = inference_large(
        img,
        model_path,
        sam_path=None,
        edge_pct=0.01,
        conf_threshold=0.4,
        overlap_px=200,
        tile_px=640,
    )
    # The annotator displays the file itself, so the path goes into the payload.
    annotations = detections_to_json(detections, image_path)
    logger.info("Detected %d objects in %s", len(annotations["boxes"]), image_path)
    annotator = image_annotator(
        annotations,
        boxes_alpha=0.02,
        handle_size=4,
        show_label=False,
    )
    return annotator, annotations["boxes"]
def load_coco_boxes(image_path, coco_file, class_labels="scallop"):
    """Load previously saved boxes from a file into the annotator.

    Args:
        image_path: path of the image currently shown in the scale tab.
        coco_file: uploaded annotation file, handed to ``coco_to_detections``.
        class_labels: currently unused; kept for interface compatibility.

    Returns:
        ``(annotator, boxes)``: a new ``image_annotator`` holding the image
        and loaded boxes, plus the list of box dicts.

    Raises:
        gr.Error: if no image is loaded yet or it cannot be read.
    """
    image = cv2.imread(image_path) if image_path else None
    if image is None:
        raise gr.Error("Load an image in the '1. Image Scale' tab before loading boxes.")
    image = image[:, :, ::-1]  # OpenCV loads BGR; flip channels to RGB
    detections = coco_to_detections(coco_file, image)
    annotations = detections_to_json(detections, image)
    # The annotator displays the file, not the in-memory array.
    annotations["image"] = image_path
    annotator = image_annotator(
        annotations,
        boxes_alpha=0.02,
        handle_size=4,
        show_label=False,
    )
    return annotator, annotations["boxes"]
# Scale/measurement state for the app.
# NOTE(review): module-level, so it is shared by every browser session;
# concurrent users will overwrite each other's points and scale.
selector = PointSelector()


def _load_example(image_path):
    """Reset the scale tool for a clicked example and return its file name."""
    selector.clear_og_img(image_path)
    return Path(image_path).name


def df_mean_count(df):
    """Return ``(mean diameter in mm rounded to 2 dp, object count)`` for the table.

    Returns a prompt and ``None`` when the table holds no scaled sizes yet
    (no table, Gradio's placeholder frame, or no numeric sizes).
    """
    try:
        sizes = pd.to_numeric(df["mean_daimeter_mm"], errors="coerce").dropna()
    except (KeyError, TypeError):
        return "Has the scale size been set?", None
    if sizes.empty:
        return "Has the scale size been set?", None
    return round(float(sizes.mean()), 2), len(sizes)


def save_and_download_table(df, img_name):
    """Write the results table to ``<image stem>_boxes.csv``; return its path.

    Returns ``None`` (nothing to download) when there is no scaled table yet
    or the file cannot be written. Only the stem of ``img_name`` is used, so
    the CSV always lands in the current working directory.
    """
    if df is None or len(df) == 0 or "mean_daimeter_mm" not in df.columns:
        return None
    stem = Path(img_name).stem if img_name else "image"
    csv_path = f"{stem}_boxes.csv"
    try:
        df.to_csv(csv_path, index=False)
    except OSError:
        logger.exception("Could not write %s", csv_path)
        return None
    return csv_path


with gr.Blocks(theme=gr.themes.Default()) as demo:
    with gr.Tabs() as tabs:
        # %% #|> Tab0
        with gr.TabItem("0. readme", id=3, visible=True):
            gr.Markdown(
                """
                # Spatstatapp
                Spatstatapp is a tool for detecting and measuring objects in images.
                ## Features
                - Image scaling tool to set measurement reference
                - Object detection using YOLOv11 model for scallop/spat detection
                - Interactive annotation of detected objects
                - Size measurements in mm based on scale reference
                - Statistics and histogram visualization of object sizes
                - Export results to CSV
                ## Note
                - This is a demo application, accuracy and speed are not optimized.
                - The model is trained on a 5 images only
                - accuracy will be low when input data very different from training data.
                - accuracy can be improved with additional data.
                - Speed is slow due to the large image size and free hosting.
                ## Usage
                1. Upload an image
                2. Click two points to create a scale bar.
                3. Set the scale bar length in mm.
                4. Click "find bounding boxes" to detect objects in the image.
                5. Adjust the detected objects by dragging the handles, deleting or create new boxes.
                6. Export boxes and statistics to CSV.
                """
            )

        # %% Tab 1
        with gr.TabItem("1. Image Scale", id=0):
            # Created unrendered so the examples gallery (built in the left
            # column) can write the example's file name into it; rendered in
            # the right column below.
            filename = gr.Textbox(label="Filename", render=False)
            with gr.Row():
                with gr.Column(scale=10):
                    image_input = gr.Image(
                        label="Click two points to measure distance",
                        type="filepath",
                        interactive=True,
                    )
                    gr.Examples(
                        example_imgs,
                        image_input,
                        outputs=[filename],
                        fn=_load_example,
                        cache_examples=False,
                        run_on_click=True,
                    )
                with gr.Column(scale=1, min_width=200):
                    filename.render()
                    output_text = gr.Textbox(
                        label="Status", value="Click two points to measure distance"
                    )
                    line_length_mm = gr.Number(label="line length in mm")
                    button_find = gr.Button("find bounding boxes", visible=False)
                    load_annot_btn = gr.Button(
                        "Load Existing boxes (Optional)", visible=False
                    )
                    load_annot = gr.File(
                        label="Load Existing boxes (Optional)",
                        file_types=[".txt"],
                        file_count="single",
                        visible=False,
                        height=500,
                    )

        # %% #|> Tab2
        with gr.TabItem("2. Object annotation", id=1, visible=True):
            annotator = image_annotator(
                boxes_alpha=0.02,
                handle_size=4,
                show_label=False,
                label_list=["scallop", "spat"],
                label_colors=[(255, 0, 0), (255, 200, 0)],
            )
            download_file = gr.File(label="Download CSV", visible=True)
            with gr.Row():
                with gr.Column(scale=1):
                    obj_count = gr.Textbox(label="Object count")
                with gr.Column(scale=1):
                    obj_size = gr.Textbox(
                        value="Has the scale size been set?", label="Mean size"
                    )
            histogram = gr.Plot()
            table = gr.DataFrame(max_height=500)
            json_data = gr.JSON(value={}, visible=False)

    # %% #|> T1: event handlers
    image_input.upload(selector.reset_og_img, inputs=[image_input])
    image_input.upload(lambda x: Path(x).name, inputs=[image_input], outputs=[filename])

    # Re-check the scale after every click so the action buttons also appear
    # when the length in mm was entered before the line was drawn, and rescale
    # any existing boxes with the newly measured line.
    measured = image_input.select(
        selector.add_point,
        inputs=[image_input],
        outputs=[image_input, output_text],
    )
    measured.then(selector.check_scale_set, inputs=[button_find], outputs=[button_find])
    measured.then(selector.check_scale_set, inputs=[load_annot_btn], outputs=[load_annot_btn])
    measured.then(selector.save_scaled_boxes, inputs=[annotator], outputs=table)

    # Chained with .then so follow-ups read the newly stored length instead of
    # racing the handler that stores it.
    scale_entered = line_length_mm.change(
        selector.set_line_length,
        inputs=[line_length_mm, button_find],
        outputs=[button_find],
    )
    scale_entered.then(selector.check_scale_set, inputs=[load_annot_btn], outputs=[load_annot_btn])
    scale_entered.then(selector.save_scaled_boxes, inputs=[annotator], outputs=table)

    load_annot_btn.click(lambda: gr.update(visible=True), outputs=[load_annot])

    # %% #|> T2: event handlers
    button_find.click(
        fn=find_boxes_json,
        inputs=[image_input],
        outputs=[annotator, json_data],
    )
    button_find.click(fn=lambda: gr.Tabs(selected=1), outputs=tabs)
    load_annot.upload(
        fn=load_coco_boxes,
        inputs=[image_input, load_annot],
        outputs=[annotator, json_data],
    )
    annotator.change(fn=selector.save_scaled_boxes, inputs=[annotator], outputs=table)
    table.change(fn=create_histogram, inputs=[table], outputs=[histogram])
    table.change(fn=df_mean_count, inputs=[table], outputs=[obj_size, obj_count])
    table.change(
        fn=save_and_download_table,
        inputs=[table, filename],
        outputs=[download_file],
    )
if __name__ == "__main__":
    # Launch the local server; show_error surfaces handler exceptions in the UI.
    demo.launch(debug=True, show_error=True)