Lui3ui3ui commited on
Commit
a855e5a
·
1 Parent(s): 8733c7e

add samples

Browse files
Files changed (7) hide show
  1. 587476.0.png +0 -0
  2. 737083.0.png +0 -0
  3. app.py +51 -17
  4. example1.png +0 -0
  5. example2.png +0 -0
  6. example3.png +0 -0
  7. example4.png +0 -0
587476.0.png ADDED
737083.0.png ADDED
app.py CHANGED
@@ -7,8 +7,20 @@ import requests
7
  from tensorflow.keras.models import load_model
8
  from PIL import Image, ImageDraw
9
 
10
- # Load the trained model
11
- model = load_model('objdet_1_2.h5', custom_objects={'huber_loss': tf.keras.losses.Huber()})
 
 
 
 
 
 
 
 
 
 
 
 
12
 
13
  # Constants
14
  IMAGE_SIZE = (256, 256)
@@ -55,20 +67,42 @@ def draw_bounding_box(image, bbox, label):
55
 
56
  return image
57
 
 
 
 
 
 
 
 
 
58
  # Create Gradio interface
59
- iface = gr.Interface(
60
- fn=predict,
61
- inputs=gr.Image(type="pil"),
62
- outputs=[
63
- gr.Text(label="Classification"),
64
- gr.Text(label="Confidence Score"),
65
- gr.Text(label="Bounding Box (x_min, y_min, x_max, y_max)"),
66
- gr.Image(label="Image with Bounding Box")
67
- ],
68
- title="Seamount Detection",
69
- description="Upload an image to classify and detect seamounts. The bounding box is drawn on detected objects."
70
- )
71
-
72
- # Launch the interface
 
 
 
 
 
 
 
 
 
 
 
 
 
 
73
  if __name__ == "__main__":
74
- iface.launch()
 
7
  from tensorflow.keras.models import load_model
8
  from PIL import Image, ImageDraw
9
 
10
+ # Define model path and URL for dynamic downloading
11
+ MODEL_PATH = "objdet_1_2.h5"
12
+ MODEL_URL = "https://huggingface.co/YOUR_USERNAME/objdet_1_2/resolve/main/objdet_1_2.h5"
13
+
14
+ # Download model if not present
15
+ if not os.path.exists(MODEL_PATH):
16
+ print("Downloading model...")
17
+ response = requests.get(MODEL_URL)
18
+ with open(MODEL_PATH, 'wb') as f:
19
+ f.write(response.content)
20
+ print("Model downloaded successfully.")
21
+
22
+ # Load the model
23
+ model = load_model(MODEL_PATH, custom_objects={'huber_loss': tf.keras.losses.Huber()})
24
 
25
  # Constants
26
  IMAGE_SIZE = (256, 256)
 
67
 
68
  return image
69
 
70
+ # Example images (Replace with actual paths or URLs)
71
+ example_images = [
72
+ "example1.png",
73
+ "example2.png",
74
+ "example3.png",
75
+ "example4.png"
76
+ ]
77
+
78
  # Create Gradio interface
79
+ with gr.Blocks() as demo:
80
+ gr.Markdown("# Seamount Detection")
81
+ gr.Markdown("Upload an image or drag one of the example images below to classify and detect seamounts.")
82
+
83
+ with gr.Row():
84
+ input_image = gr.Image(type="pil", label="Upload or Drag an Image")
85
+ output_image = gr.Image(label="Image with Bounding Box")
86
+
87
+ with gr.Row():
88
+ output_class = gr.Text(label="Classification")
89
+ output_confidence = gr.Text(label="Confidence Score")
90
+ output_bbox = gr.Text(label="Bounding Box (x_min, y_min, x_max, y_max)")
91
+
92
+ submit_btn = gr.Button("Predict")
93
+
94
+ # Example image section
95
+ gr.Markdown("### Example Images (Drag one into the input box)")
96
+ with gr.Row():
97
+ examples = [gr.Image(value=img, type="pil", interactive=True, label=f"Example {i+1}") for i, img in enumerate(example_images)]
98
+
99
+ # Connect the prediction function
100
+ submit_btn.click(predict, inputs=input_image, outputs=[output_class, output_confidence, output_bbox, output_image])
101
+
102
+ # Enable dragging example images to input
103
+ for example in examples:
104
+ example.change(fn=lambda img: img, inputs=example, outputs=input_image)
105
+
106
+ # Launch the app
107
  if __name__ == "__main__":
108
+ demo.launch()
example1.png ADDED
example2.png ADDED
example3.png ADDED
example4.png ADDED