| import os |
| import tempfile |
| import cv2 |
| from huggingface_hub import snapshot_download |
| from ultralytics import YOLO |
| from PIL import Image |
| import gradio as gr |
|
|
|
|
|
|
| |
| |
| |
def load_model(repo_id, filename="ShrimpandSnail.pt"):
    """Download detector weights from the Hugging Face Hub and build a YOLO model.

    Args:
        repo_id: Hugging Face Hub repository id (``"namespace/name"``).
        filename: Name of the weights file at the top level of the repository.
            Defaults to the weights file this app was built around.

    Returns:
        An ultralytics ``YOLO`` model loaded for the ``"detect"`` task.

    Raises:
        FileNotFoundError: If ``filename`` is not present in the downloaded
            snapshot.
    """
    # Only fetch the weights file; an unrestricted snapshot downloads every
    # file in the repository.
    download_dir = snapshot_download(repo_id, allow_patterns=[filename])
    path = os.path.join(download_dir, filename)
    if not os.path.isfile(path):
        raise FileNotFoundError(
            f"{filename!r} not found in Hugging Face repo {repo_id!r} "
            f"(looked in {download_dir})"
        )
    return YOLO(path, task="detect")
|
|
|
|
| |
| |
| |
def predict_image(pilimg, conf=0.5, iou=0.6):
    """Run detection on a single image and return the annotated result.

    Args:
        pilimg: Image from the ``gr.Image(type="pil")`` input, or ``None`` when
            the user clicks "Run Detection" without uploading anything.
        conf: Confidence threshold passed to ``detection_model.predict``.
        iou: IoU threshold passed to ``detection_model.predict``.

    Returns:
        A PIL image with the predicted boxes drawn on it, or ``None`` when no
        image was provided (which clears the output component).
    """
    # Ultralytics substitutes its bundled demo images when the source is None,
    # which would show detections for an unrelated picture; bail out instead.
    if pilimg is None:
        return None
    results = detection_model.predict(pilimg, conf=conf, iou=iou)
    # Results.plot() returns a BGR array; reverse the channel axis to get RGB
    # before building the PIL image.
    annotated_bgr = results[0].plot()
    return Image.fromarray(annotated_bgr[..., ::-1])
|
|
| |
| |
| |
def predict_video(video_path, conf=0.5, iou=0.6):
    """Run detection on every frame of a video and write an annotated copy.

    Args:
        video_path: Path of the uploaded video from ``gr.Video``, or ``None``
            when nothing was uploaded.
        conf: Confidence threshold passed to ``detection_model.predict``.
        iou: IoU threshold passed to ``detection_model.predict``.

    Returns:
        Path to an ``.mp4`` file (in a fresh temporary directory) containing
        the annotated frames, or ``None`` when no video was provided.

    Raises:
        gr.Error: If the video cannot be opened, the output writer cannot be
            created, or the video contains no readable frames.
    """
    if video_path is None:
        return None

    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        cap.release()
        raise gr.Error("Could not open the uploaded video.")

    fps = cap.get(cv2.CAP_PROP_FPS)
    # Some containers report 0 (or NaN) FPS, which VideoWriter cannot use;
    # fall back to a common default. ``not fps > 0`` is also True for NaN.
    if not fps > 0:
        fps = 30.0

    output_path = os.path.join(tempfile.mkdtemp(), "output.mp4")
    fourcc = cv2.VideoWriter_fourcc(*"mp4v")
    out = None
    try:
        while True:
            ret, frame = cap.read()
            if not ret:
                break

            results = detection_model.predict(frame, conf=conf, iou=iou)
            annotated_frame = results[0].plot()

            if out is None:
                # Size the writer from the first real frame instead of the
                # container's reported width/height: they can disagree (e.g.
                # with rotation metadata), and mismatched frames may be
                # dropped silently, leaving a broken output file.
                height, width = annotated_frame.shape[:2]
                out = cv2.VideoWriter(output_path, fourcc, fps, (width, height))
                if not out.isOpened():
                    raise gr.Error("Could not create the output video.")
            out.write(annotated_frame)
    finally:
        # Always release native handles, even if prediction raises mid-video.
        cap.release()
        if out is not None:
            out.release()

    if out is None:
        raise gr.Error("The uploaded video contains no readable frames.")
    return output_path
|
|
|
|
|
|
# Hugging Face Hub repository that hosts the trained detector weights.
REPO_ID = "cllee67/1274287N"
# Built once at import time; predict_image and predict_video both read this
# module-level model.
detection_model = load_model(repo_id=REPO_ID)
|
|
|
|
| |
| |
| |
# Two-tab Gradio UI: one tab for still images, one for video files. Each tab
# wires its upload component through a button to the matching predict_* function.
with gr.Blocks() as demo:
    gr.Markdown("## Shrimp and Snail Detection – Image & Video Upload")
    gr.Markdown("Upload an image or video to run object detection.")

    with gr.Tab("Image"):
        img_input = gr.Image(type="pil", label="Upload Image")
        img_output = gr.Image(type="pil", label="Detected Image")
        img_btn = gr.Button("Run Detection")
        img_btn.click(fn=predict_image, inputs=img_input, outputs=img_output)

    with gr.Tab("Video"):
        vid_input = gr.Video(label="Upload Video")
        vid_output = gr.Video(label="Detected Video")
        vid_btn = gr.Button("Run Detection")
        vid_btn.click(fn=predict_video, inputs=vid_input, outputs=vid_output)


# Only start the server when run as a script, so importing this module (e.g.
# by `gradio` reload mode or tests) does not launch it as a side effect.
if __name__ == "__main__":
    demo.launch()