# Hugging Face Space file: app.py — author Lui3ui3ui, commit 7d67687 ("update app.py").
import gradio as gr
import tensorflow as tf
import numpy as np
import cv2
import os
import requests
from tensorflow.keras.models import load_model
from PIL import Image, ImageDraw
# DEBUG: Print start message
print("Starting Seamount Detection App...")

# Local cache path for the model weights, and the URL they are fetched from
# when that file is absent.
# NOTE(review): MODEL_URL still contains the "YOUR_USERNAME" placeholder;
# replace it with the real repository before deploying.
MODEL_PATH = "objdet_1_2.h5"
MODEL_URL = "https://huggingface.co/YOUR_USERNAME/objdet_1_2/resolve/main/objdet_1_2.h5"
# Seconds to wait on connect / between received chunks before giving up.
DOWNLOAD_TIMEOUT = 60


def _download_model(url, dest, timeout=DOWNLOAD_TIMEOUT):
    """Download *url* to *dest* without ever leaving a partial file at *dest*.

    The response body is streamed in 1 MiB chunks into ``dest + ".part"``.
    That file is renamed over *dest* only after the whole body has been
    written. Because the existence of *dest* is what decides whether to
    download at start-up, a failed or interrupted transfer must not create it.

    Raises:
        requests.HTTPError: if the server answers with a 4xx/5xx status.
        requests.RequestException: on connection errors or timeouts.
    """
    tmp_path = dest + ".part"
    try:
        with requests.get(url, stream=True, timeout=timeout) as response:
            # Without this check a 404/500 HTML page would be saved as the model.
            response.raise_for_status()
            with open(tmp_path, "wb") as f:
                for chunk in response.iter_content(chunk_size=1 << 20):
                    f.write(chunk)
        os.replace(tmp_path, dest)
    finally:
        # After a successful os.replace the temp file is gone; this only
        # removes leftovers from a failed or interrupted download.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)


# Download model if not present
if not os.path.exists(MODEL_PATH):
    print("Model not found. Downloading from:", MODEL_URL)
    _download_model(MODEL_URL, MODEL_PATH)
    print("Model downloaded successfully.")
else:
    print("Model found locally.")

# Load the model.
# 'huber_loss' is resolved explicitly through custom_objects.
# NOTE(review): presumably the model was compiled/saved with a loss under
# that name; confirm against the training code.
print("Loading model...")
model = load_model(MODEL_PATH, custom_objects={'huber_loss': tf.keras.losses.Huber()})
print("Model loaded.")
# Constants
# Target (width, height) for the network input. cv2.resize takes dsize in
# (width, height) order; this only matters if the size stops being square.
IMAGE_SIZE = (256, 256)


# Preprocessing function
def preprocess_image(image):
    """Turn a PIL image into a one-image batch for the model.

    Steps:
    - force 3-channel RGB;
    - scale pixel values from [0, 255] to [0, 1];
    - resize to IMAGE_SIZE;
    - prepend a batch axis of length 1.
    """
    print("Preprocessing image...")
    pixels = np.array(image.convert("RGB")) / 255.0
    resized = cv2.resize(pixels, IMAGE_SIZE)
    return resized[np.newaxis, ...]
# Prediction function
def _to_pixel_box(bbox, width, height):
    """Scale a fractional (x_min, y_min, x_max, y_max) box to pixel coordinates.

    Each coordinate is multiplied by the image width/height and truncated to
    int. It is then clamped to the image, so an unconstrained regression
    output cannot put the box off-canvas. Finally the corners are ordered so
    that min <= max, because Pillow's rectangle drawing rejects inverted boxes.
    """
    def clamp(value, upper):
        return min(max(value, 0), upper)

    x_a = clamp(int(bbox[0] * width), width - 1)
    y_a = clamp(int(bbox[1] * height), height - 1)
    x_b = clamp(int(bbox[2] * width), width - 1)
    y_b = clamp(int(bbox[3] * height), height - 1)
    x_min, x_max = sorted((x_a, x_b))
    y_min, y_max = sorted((y_a, y_b))
    return x_min, y_min, x_max, y_max


def predict(image):
    """Classify *image* and localise the predicted seamount.

    Args:
        image: PIL image from the Gradio input. It is None when "Predict" is
            clicked before an image is provided.

    Returns:
        A 4-tuple ``(classification, confidence, bbox, annotated_image)``:
        - ``classification``: "Seamount" if the first model output exceeds
          0.5, otherwise "No Seamount".
        - ``confidence``: that first output, as a float.
        - ``bbox``: ``(x_min, y_min, x_max, y_max)`` in pixels of the original
          image, clamped to its bounds.
        - ``annotated_image``: the image with the box and label drawn on it.

    Raises:
        gr.Error: if no image was provided (Gradio shows the message to the user).
    """
    if image is None:
        raise gr.Error("Please upload an image or select an example first.")

    print("Running prediction...")
    processed_image = preprocess_image(image)
    # verbose=0 stops Keras printing a progress bar on every request.
    pred_class, pred_bbox = model.predict(processed_image, verbose=0)
    confidence = float(pred_class[0][0])
    classification = "Seamount" if confidence > 0.5 else "No Seamount"

    # NOTE(review): assumes the box head outputs fractions of the image size
    # in (x_min, y_min, x_max, y_max) order; confirm against training labels.
    width, height = image.size
    bbox_px = _to_pixel_box(pred_bbox[0], width, height)

    print(f"Prediction: {classification} with confidence {confidence:.2f}")
    return (
        classification,
        confidence,
        bbox_px,
        draw_bounding_box(image, bbox_px, classification),
    )
# Function to draw bounding box on the image
def draw_bounding_box(image, bbox, label):
    """Return an RGB copy of *image* with *bbox* outlined and *label* written beside it.

    Args:
        image: PIL image. It is not modified, because ``convert`` returns a
            converted copy.
        bbox: ``(x_min, y_min, x_max, y_max)`` in pixel coordinates. The
            corners may arrive in either order.
        label: text drawn in red at the box's top-left corner.

    Returns:
        The annotated RGB image.
    """
    print("Drawing bounding box...")
    x0, y0, x1, y1 = bbox
    # Recent Pillow versions raise ValueError when x1 < x0 or y1 < y0,
    # so normalise the corner order before drawing.
    x_min, x_max = sorted((x0, x1))
    y_min, y_max = sorted((y0, y1))

    image = image.convert("RGB")
    draw = ImageDraw.Draw(image)
    draw.rectangle([x_min, y_min, x_max, y_max], outline="red", width=3)

    # Put the label just above the box. When the box touches the top edge
    # there is no room above it, so draw the label just inside the box instead.
    text_y = y_min - 10 if y_min >= 10 else y_min + 3
    draw.text((x_min, text_y), label, fill="red")
    return image
# Example images offered as clickable thumbnails in the UI:
# example1.png .. example4.png, which must be uploaded to the Space next to this script.
example_image_paths = [f"example{n}.png" for n in range(1, 5)]
# Gradio UI (Blocks API), laid out in creation order:
# 1. title and instructions;
# 2. a row with the image input and the example thumbnails;
# 3. the Predict button;
# 4. a row of text outputs;
# 5. the annotated image.
with gr.Blocks() as demo:
    gr.Markdown("# Seamount Detection")
    gr.Markdown("**Instructions:** Either upload your own image or select one of the example images below. Clicking an example will populate the input field.")

    with gr.Row():
        # The only input: a PIL image, either uploaded or filled in from an example.
        uploaded_image = gr.Image(type="pil", label="Upload an Image")
        # Each example row holds one value, for the single input component.
        gr.Examples(
            examples=[[example] for example in example_image_paths],
            inputs=[uploaded_image],
            label="Or select an example image",
        )

    submit_btn = gr.Button("Predict")

    with gr.Row():
        output_class = gr.Text(label="Classification")
        output_confidence = gr.Text(label="Confidence Score")
        output_bbox = gr.Text(label="Bounding Box (x_min, y_min, x_max, y_max)")
    output_image = gr.Image(label="Image with Bounding Box")

    # predict returns a 4-tuple. Its items map positionally onto these outputs.
    submit_btn.click(
        fn=predict,
        inputs=[uploaded_image],
        outputs=[output_class, output_confidence, output_bbox, output_image],
    )
# Launch the web server only when this file is run as a script (python app.py).
# Importing the module must not claim a launch that never happens.
if __name__ == "__main__":
    # DEBUG: Launch the app
    print("Launching app...")
    demo.launch()