Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import tensorflow as tf | |
| import numpy as np | |
| import cv2 | |
| import os | |
| import requests | |
| from tensorflow.keras.models import load_model | |
| from PIL import Image, ImageDraw | |
# DEBUG: Print start message
print("Starting Seamount Detection App...")

# Local filename of the Keras model, and the URL it is fetched from when the
# file is absent.
# NOTE(review): MODEL_URL still contains the "YOUR_USERNAME" placeholder; it
# must point at the real repository before a fresh deployment can download.
MODEL_PATH = "objdet_1_2.h5"
MODEL_URL = "https://huggingface.co/YOUR_USERNAME/objdet_1_2/resolve/main/objdet_1_2.h5"
# Seconds to wait when connecting and between received chunks (requests
# applies the timeout per operation, not to the whole transfer).
DOWNLOAD_TIMEOUT_S = 60


def _download_model(url, dest, timeout=DOWNLOAD_TIMEOUT_S):
    """Download *url* to *dest*, failing loudly and never leaving a bad file.

    The response is streamed into ``dest + ".part"`` and only moved onto
    *dest* once the whole body has been written. An HTTP error page or an
    interrupted transfer therefore never ends up at *dest*. The caller's
    ``os.path.exists`` check would otherwise treat such a file as a valid
    cached model on every later start.

    Raises:
        requests.HTTPError: if the server answers with a 4xx/5xx status.
        requests.RequestException: on connection problems or timeouts.
    """
    tmp_path = dest + ".part"
    try:
        with requests.get(url, stream=True, timeout=timeout) as response:
            response.raise_for_status()
            with open(tmp_path, "wb") as f:
                for chunk in response.iter_content(chunk_size=1 << 20):
                    f.write(chunk)
        # Atomic rename: readers see either no file or the complete file.
        os.replace(tmp_path, dest)
    finally:
        # Only present here if something failed before the rename.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)


# Download model if not present
if os.path.exists(MODEL_PATH):
    print("Model found locally.")
else:
    print("Model not found. Downloading from:", MODEL_URL)
    _download_model(MODEL_URL, MODEL_PATH)
    print("Model downloaded successfully.")

# Load the model once at import time so every request reuses it.
print("Loading model...")
# Map the "huber_loss" name to a Keras Huber loss so deserialization can
# resolve it (presumably the model was compiled with a loss saved under that
# name; confirm against the training code).
model = load_model(MODEL_PATH, custom_objects={'huber_loss': tf.keras.losses.Huber()})
print("Model loaded.")

# Target size passed to cv2.resize, which interprets it as (width, height).
IMAGE_SIZE = (256, 256)
# Preprocessing function
def preprocess_image(image):
    """Turn a PIL image into a one-item input batch for the model.

    The image is forced to 3-channel RGB, its 8-bit values are scaled into
    [0, 1], it is resized to IMAGE_SIZE, and a leading batch axis is added.
    """
    print("Preprocessing image...")
    pixels = np.asarray(image.convert("RGB")) / 255.0
    resized = cv2.resize(pixels, IMAGE_SIZE)
    # Same result as np.expand_dims(resized, axis=0).
    return resized[np.newaxis, ...]
# Prediction function
def _pixel_span(start, end, size):
    """Scale a normalized (start, end) pair by *size*; return it ordered and clamped to [0, size]."""
    lo, hi = sorted((int(start * size), int(end * size)))
    # Clamping is monotonic, so lo <= hi still holds afterwards.
    return min(max(lo, 0), size), min(max(hi, 0), size)


def predict(image):
    """Classify *image* and locate the predicted bounding box.

    Args:
        image: PIL image from the ``gr.Image(type="pil")`` input, or ``None``
            when Predict is clicked before an image is provided.

    Returns:
        A 4-tuple ``(classification, confidence, bbox, annotated_image)``:

        - the label string;
        - the model's classification score as a float (above 0.5 means
          "Seamount");
        - the box as ``(x_min, y_min, x_max, y_max)`` in pixel coordinates of
          *image*, ordered and clamped to the image bounds;
        - the image with the box and label drawn on it.

    Raises:
        gr.Error: if no image was provided.
    """
    if image is None:
        # Without this guard, preprocess_image() fails with an opaque
        # AttributeError; gr.Error shows a readable message in the UI.
        raise gr.Error("Please upload or select an image first.")

    print("Running prediction...")
    processed_image = preprocess_image(image)
    # The model returns two outputs: a classification score and box values.
    pred_class, pred_bbox = model.predict(processed_image)
    confidence = float(pred_class[0][0])
    classification = "Seamount" if confidence > 0.5 else "No Seamount"

    # Box values are scaled by the image size, i.e. treated as normalized
    # coordinates. Regression outputs are unconstrained, so order and clamp
    # each pair: Pillow rejects rectangles with x1 < x0 or y1 < y0.
    width, height = image.size
    bbox = pred_bbox[0]
    x_min, x_max = _pixel_span(bbox[0], bbox[2], width)
    y_min, y_max = _pixel_span(bbox[1], bbox[3], height)
    box = (x_min, y_min, x_max, y_max)

    print(f"Prediction: {classification} with confidence {confidence:.2f}")
    return (
        classification,
        confidence,
        box,
        draw_bounding_box(image, box, classification),
    )
# Function to draw bounding box on the image
def draw_bounding_box(image, bbox, label):
    """Return an RGB copy of *image* with *bbox* outlined and captioned.

    Args:
        image: PIL image. ``convert("RGB")`` returns a new image, so the
            caller's image is left unmodified.
        bbox: ``(x_min, y_min, x_max, y_max)`` in pixel coordinates. The
            corner order is normalized, so swapped corners are accepted.
        label: Caption drawn 10 px above the box, or at the top edge when the
            box is too close to it.

    Returns:
        The annotated RGB image.
    """
    print("Drawing bounding box...")
    x0, y0, x1, y1 = bbox
    # Recent Pillow versions raise ValueError for x1 < x0 or y1 < y0.
    x_min, x_max = sorted((x0, x1))
    y_min, y_max = sorted((y0, y1))

    canvas = image.convert("RGB")
    draw = ImageDraw.Draw(canvas)
    draw.rectangle([x_min, y_min, x_max, y_max], outline="red", width=3)
    # Keep the caption on-canvas; at y < 0 it would be clipped away entirely.
    text_position = (x_min, max(y_min - 10, 0))
    draw.text(text_position, label, fill="red")
    return canvas
# Example images offered in the UI: example1.png ... example4.png. The files
# must be uploaded to the Space next to this script.
example_image_paths = [f"example{index}.png" for index in range(1, 5)]
# Build the Gradio interface using Blocks.
# Offer only the example files that are actually present. A missing file
# would otherwise break the Examples component (at startup or on click).
available_examples = [[path] for path in example_image_paths if os.path.exists(path)]
missing_examples = [path for path in example_image_paths if not os.path.exists(path)]
if missing_examples:
    print("Warning: example images not found, skipping:", missing_examples)

with gr.Blocks() as demo:
    gr.Markdown("# Seamount Detection")
    gr.Markdown("**Instructions:** Either upload your own image or select one of the example images below. Clicking an example will populate the input field.")
    with gr.Row():
        # Single image input; predict() receives it as a PIL image.
        uploaded_image = gr.Image(type="pil", label="Upload an Image")
        # Clickable thumbnails that fill the input field above.
        if available_examples:
            gr.Examples(
                examples=available_examples,
                inputs=uploaded_image,
                label="Or select an example image",
            )
    submit_btn = gr.Button("Predict")
    with gr.Row():
        output_class = gr.Text(label="Classification")
        output_confidence = gr.Text(label="Confidence Score")
        output_bbox = gr.Text(label="Bounding Box (x_min, y_min, x_max, y_max)")
    output_image = gr.Image(label="Image with Bounding Box")
    # Output order matches the 4-tuple returned by predict().
    submit_btn.click(
        predict,
        inputs=uploaded_image,
        outputs=[output_class, output_confidence, output_bbox, output_image],
    )
# Launch the app only when run as a script (as Hugging Face Spaces does),
# not when the module is merely imported.
if __name__ == "__main__":
    # Printed here so the message never claims a launch that doesn't happen.
    print("Launching app...")
    demo.launch()