Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
|
@@ -1,270 +1,154 @@
|
|
| 1 |
import gradio as gr
|
| 2 |
-
import torch
|
| 3 |
import numpy as np
|
| 4 |
from PIL import Image
|
| 5 |
-
import
|
| 6 |
-
from
|
|
|
|
| 7 |
from segment_anything import sam_model_registry, SamPredictor
|
| 8 |
import supervision as sv
|
|
|
|
|
|
|
| 9 |
|
| 10 |
-
|
|
|
|
| 11 |
|
| 12 |
-
|
| 13 |
-
|
|
|
|
| 14 |
|
| 15 |
-
|
| 16 |
-
|
| 17 |
|
| 18 |
-
|
| 19 |
-
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
if grounding_dino_model is None:
|
| 23 |
-
print("π¦ Loading Grounding DINO model...")
|
| 24 |
-
grounding_dino_model = GroundingDINOModel(
|
| 25 |
-
model_config_path="GroundingDINO/groundingdino/config/GroundingDINO_SwinT_OGC.py",
|
| 26 |
-
model_checkpoint_path="weights/groundingdino_swint_ogc.pth",
|
| 27 |
-
device=device
|
| 28 |
-
)
|
| 29 |
-
print("β
Grounding DINO loaded!")
|
| 30 |
-
|
| 31 |
-
if sam_predictor is None:
|
| 32 |
-
print("π¦ Loading SAM model...")
|
| 33 |
-
sam_checkpoint = "weights/sam_vit_h_4b8939.pth"
|
| 34 |
-
model_type = "vit_h"
|
| 35 |
-
sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
|
| 36 |
-
sam.to(device=device)
|
| 37 |
-
sam_predictor = SamPredictor(sam)
|
| 38 |
-
print("β
SAM loaded!")
|
| 39 |
|
| 40 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 41 |
"""
|
| 42 |
-
Detect
|
| 43 |
-
|
| 44 |
-
Args:
|
| 45 |
-
image_pil: PIL Image
|
| 46 |
-
text_prompt: Text prompt for detection (default: "fish")
|
| 47 |
-
box_threshold: Confidence threshold for boxes
|
| 48 |
-
text_threshold: Confidence threshold for text matching
|
| 49 |
-
|
| 50 |
-
Returns:
|
| 51 |
-
mask: Binary mask of detected fish
|
| 52 |
-
metadata: Detection metadata
|
| 53 |
"""
|
| 54 |
-
load_models()
|
| 55 |
-
|
| 56 |
-
# Convert PIL to numpy
|
| 57 |
-
image_np = np.array(image_pil)
|
| 58 |
-
|
| 59 |
-
# 1. Grounding DINO: Detect fish boxes
|
| 60 |
-
print(f"π Detecting '{text_prompt}' with Grounding DINO...")
|
| 61 |
-
detections = grounding_dino_model.predict_with_classes(
|
| 62 |
-
image=image_np,
|
| 63 |
-
classes=[text_prompt],
|
| 64 |
-
box_threshold=box_threshold,
|
| 65 |
-
text_threshold=text_threshold
|
| 66 |
-
)
|
| 67 |
-
|
| 68 |
-
print(f"π¦ Found {len(detections.xyxy)} boxes")
|
| 69 |
-
|
| 70 |
-
if len(detections.xyxy) == 0:
|
| 71 |
-
print("β No fish detected!")
|
| 72 |
-
return None, {
|
| 73 |
-
"success": False,
|
| 74 |
-
"mode": "grounded_sam",
|
| 75 |
-
"detection_method": "grounding_dino",
|
| 76 |
-
"fish_detected": False,
|
| 77 |
-
"reason": "No fish found in image"
|
| 78 |
-
}
|
| 79 |
-
|
| 80 |
-
# Select best detection (highest confidence)
|
| 81 |
-
best_idx = np.argmax(detections.confidence)
|
| 82 |
-
best_box = detections.xyxy[best_idx]
|
| 83 |
-
best_conf = float(detections.confidence[best_idx])
|
| 84 |
-
|
| 85 |
-
print(f"π― Best detection: Confidence={best_conf:.2f}, Box={best_box}")
|
| 86 |
-
|
| 87 |
-
# 2. SAM: Segment the detected fish
|
| 88 |
-
print("βοΈ Segmenting with SAM...")
|
| 89 |
-
sam_predictor.set_image(image_np)
|
| 90 |
-
|
| 91 |
-
# Convert box to SAM format
|
| 92 |
-
box_np = best_box.reshape(1, 4)
|
| 93 |
-
|
| 94 |
-
masks, scores, _ = sam_predictor.predict(
|
| 95 |
-
box=box_np,
|
| 96 |
-
multimask_output=False
|
| 97 |
-
)
|
| 98 |
-
|
| 99 |
-
mask = masks[0] # Get best mask
|
| 100 |
-
|
| 101 |
-
# Calculate statistics
|
| 102 |
-
mask_area = int(np.sum(mask))
|
| 103 |
-
total_pixels = mask.shape[0] * mask.shape[1]
|
| 104 |
-
mask_percentage = (mask_area / total_pixels) * 100
|
| 105 |
-
|
| 106 |
-
# Get contours
|
| 107 |
-
contours, _ = cv2.findContours(
|
| 108 |
-
mask.astype(np.uint8),
|
| 109 |
-
cv2.RETR_EXTERNAL,
|
| 110 |
-
cv2.CHAIN_APPROX_SIMPLE
|
| 111 |
-
)
|
| 112 |
-
|
| 113 |
-
# Get fish center
|
| 114 |
-
if len(contours) > 0:
|
| 115 |
-
largest_contour = max(contours, key=cv2.contourArea)
|
| 116 |
-
M = cv2.moments(largest_contour)
|
| 117 |
-
if M["m00"] != 0:
|
| 118 |
-
cx = int(M["m10"] / M["m00"])
|
| 119 |
-
cy = int(M["m01"] / M["m00"])
|
| 120 |
-
else:
|
| 121 |
-
cx, cy = int(best_box[0] + best_box[2]) // 2, int(best_box[1] + best_box[3]) // 2
|
| 122 |
-
else:
|
| 123 |
-
cx, cy = int(best_box[0] + best_box[2]) // 2, int(best_box[1] + best_box[3]) // 2
|
| 124 |
-
|
| 125 |
-
# Convert contours to list format
|
| 126 |
-
contour_points = []
|
| 127 |
-
if len(contours) > 0:
|
| 128 |
-
for point in contours[0][:100]: # Limit to 100 points
|
| 129 |
-
contour_points.append({
|
| 130 |
-
"x": int(point[0][0]),
|
| 131 |
-
"y": int(point[0][1])
|
| 132 |
-
})
|
| 133 |
-
|
| 134 |
-
metadata = {
|
| 135 |
-
"success": True,
|
| 136 |
-
"mode": "grounded_sam",
|
| 137 |
-
"detection_method": "grounding_dino_sam",
|
| 138 |
-
"fish_detected": True,
|
| 139 |
-
"grounding_dino": {
|
| 140 |
-
"confidence": best_conf,
|
| 141 |
-
"bounding_box": [int(x) for x in best_box],
|
| 142 |
-
"text_prompt": text_prompt,
|
| 143 |
-
"total_detections": len(detections.xyxy)
|
| 144 |
-
},
|
| 145 |
-
"mask_area": mask_area,
|
| 146 |
-
"mask_percentage": mask_percentage,
|
| 147 |
-
"num_contours": len(contours),
|
| 148 |
-
"fish_center": [cx, cy],
|
| 149 |
-
"image_size": list(mask.shape),
|
| 150 |
-
"device": device,
|
| 151 |
-
"contours": contour_points
|
| 152 |
-
}
|
| 153 |
-
|
| 154 |
-
print(f"β
Segmentation complete! Mask: {mask_percentage:.2f}%")
|
| 155 |
-
|
| 156 |
-
return mask, metadata
|
| 157 |
-
|
| 158 |
-
def process_image(image, quality="high"):
|
| 159 |
-
"""Main processing function for Gradio interface"""
|
| 160 |
-
|
| 161 |
-
if image is None:
|
| 162 |
-
return None, "β No image provided"
|
| 163 |
-
|
| 164 |
try:
|
| 165 |
-
#
|
| 166 |
-
|
| 167 |
-
|
| 168 |
-
|
| 169 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 170 |
|
| 171 |
-
#
|
| 172 |
-
|
| 173 |
-
image_pil.thumbnail((max_size, max_size), Image.Resampling.LANCZOS)
|
| 174 |
|
| 175 |
-
|
| 176 |
-
|
| 177 |
|
| 178 |
-
|
| 179 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 180 |
|
| 181 |
# Create visualization
|
| 182 |
-
|
| 183 |
|
| 184 |
-
#
|
| 185 |
-
|
| 186 |
-
|
| 187 |
-
|
|
|
|
| 188 |
|
| 189 |
-
# Draw bounding
|
| 190 |
-
box
|
| 191 |
-
|
|
|
|
| 192 |
|
| 193 |
-
#
|
| 194 |
-
|
| 195 |
-
|
| 196 |
-
|
| 197 |
|
| 198 |
-
|
| 199 |
-
|
| 200 |
-
|
| 201 |
-
|
| 202 |
-
|
| 203 |
-
|
| 204 |
-
Detections: {metadata['grounding_dino']['total_detections']}
|
| 205 |
-
|
| 206 |
-
βοΈ SAM Segmentation
|
| 207 |
-
Mask Area: {metadata['mask_percentage']:.2f}%
|
| 208 |
-
Fish Center: {metadata['fish_center']}
|
| 209 |
-
Contours: {metadata['num_contours']}
|
| 210 |
-
|
| 211 |
-
βοΈ System
|
| 212 |
-
Device: {metadata['device']}
|
| 213 |
-
Image Size: {metadata['image_size']}
|
| 214 |
-
"""
|
| 215 |
|
| 216 |
-
return
|
| 217 |
|
| 218 |
except Exception as e:
|
| 219 |
-
|
| 220 |
-
import traceback
|
| 221 |
-
traceback.print_exc()
|
| 222 |
-
return None, f"β Error: {str(e)}"
|
| 223 |
|
| 224 |
-
# Gradio
|
| 225 |
-
with gr.Blocks(title="
|
| 226 |
-
gr.Markdown(""
|
| 227 |
-
|
| 228 |
-
### Powered by Grounding DINO + SAM
|
| 229 |
-
|
| 230 |
-
Upload an image with a fish and watch the AI detect and segment it!
|
| 231 |
-
|
| 232 |
-
β οΈ **CPU Mode**: First run downloads ~680MB models (2-3 min). Processing: ~30-60 sec per image.
|
| 233 |
-
""")
|
| 234 |
|
| 235 |
with gr.Row():
|
| 236 |
with gr.Column():
|
| 237 |
-
input_image = gr.Image(type="pil", label="
|
|
|
|
| 238 |
quality = gr.Radio(
|
| 239 |
-
choices=["
|
| 240 |
-
value="
|
| 241 |
-
label="
|
| 242 |
-
info="High = 1024px, Medium = 768px (faster)"
|
| 243 |
)
|
| 244 |
-
|
| 245 |
|
| 246 |
with gr.Column():
|
| 247 |
-
output_image = gr.Image(label="
|
| 248 |
-
|
| 249 |
|
| 250 |
-
|
| 251 |
-
fn=
|
| 252 |
-
inputs=[input_image, quality],
|
| 253 |
-
outputs=[output_image,
|
| 254 |
)
|
| 255 |
|
| 256 |
-
gr.
|
| 257 |
-
|
| 258 |
-
|
| 259 |
-
|
| 260 |
-
|
| 261 |
-
|
| 262 |
-
|
| 263 |
-
### π Model Info
|
| 264 |
-
- Grounding DINO: Text-prompted object detection
|
| 265 |
-
- SAM (ViT-H): High-quality segmentation
|
| 266 |
-
- Total Model Size: ~680MB
|
| 267 |
-
""")
|
| 268 |
|
| 269 |
if __name__ == "__main__":
|
| 270 |
-
demo.launch(
|
|
|
|
| 1 |
import gradio as gr
import numpy as np
from PIL import Image
import torch
from transformers import pipeline
from groundingdino.util.inference import load_model, load_image, predict
from segment_anything import sam_model_registry, SamPredictor
import supervision as sv
import cv2
import os
import urllib.request

# ---- One-time model setup (runs at import time, i.e. on Space startup) ----
print("Loading models...")

# Load Grounding DINO through Hugging Face transformers -- this avoids the
# repo-local config/checkpoint files the original GroundingDINO API requires.
from transformers import AutoProcessor, AutoModelForZeroShotObjectDetection

dino_processor = AutoProcessor.from_pretrained("IDEA-Research/grounding-dino-tiny")
dino_model = AutoModelForZeroShotObjectDetection.from_pretrained("IDEA-Research/grounding-dino-tiny")

# SAM (ViT-H) checkpoint file name and registry key.
sam_checkpoint = "sam_vit_h_4b8939.pth"
model_type = "vit_h"

# Download SAM weights if not present.
if not os.path.exists(sam_checkpoint):
    # Use urllib instead of `os.system("wget …")`: wget is not guaranteed to
    # exist in the container, and os.system failures are silent, which would
    # leave sam_model_registry to crash on a missing checkpoint below.
    urllib.request.urlretrieve(
        f"https://dl.fbaipublicfiles.com/segment_anything/{sam_checkpoint}",
        sam_checkpoint,
    )

sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
sam_predictor = SamPredictor(sam)

print("Models loaded successfully!")
def detect_and_segment(image, text_prompt="fish", quality="Medium (512px)"):
    """
    Detect objects using Grounding DINO and segment using SAM.

    Args:
        image: Input PIL image; None is handled gracefully.
        text_prompt: Text query for Grounding DINO (default "fish").
        quality: Resolution preset label -- "Low (256px)", "Medium (512px)"
            or "High (1024px)". Unknown values fall back to 512.

    Returns:
        Tuple of (annotated PIL image, metadata dict). On failure the
        original image is returned together with an {"error": ...} dict
        so the Gradio UI surfaces the problem instead of crashing.
    """
    if image is None:
        return None, {"error": "No image provided"}

    try:
        # Resize image based on quality setting.
        quality_map = {
            "Low (256px)": 256,
            "Medium (512px)": 512,
            "High (1024px)": 1024,
        }
        target_size = quality_map.get(quality, 512)

        # Force 3-channel RGB so the mask overlay / box drawing below is
        # valid even for grayscale or RGBA uploads.
        image_np = np.array(image.convert("RGB"))
        h, w = image_np.shape[:2]

        # Resize maintaining aspect ratio; clamp to >= 1px so extreme
        # aspect ratios can never produce a zero-sized dimension.
        scale = min(target_size / w, target_size / h)
        new_w, new_h = max(1, int(w * scale)), max(1, int(h * scale))
        image_resized = cv2.resize(image_np, (new_w, new_h))

        # 1) Grounding DINO: text-prompted box detection.
        inputs = dino_processor(images=image_resized, text=text_prompt, return_tensors="pt")

        with torch.no_grad():
            outputs = dino_model(**inputs)

        # Post-process raw logits into boxes in resized-image coordinates.
        results = dino_processor.post_process_grounded_object_detection(
            outputs,
            inputs.input_ids,
            box_threshold=0.25,
            text_threshold=0.25,
            target_sizes=[(new_h, new_w)]
        )

        if len(results) == 0 or len(results[0]["boxes"]) == 0:
            return image, {"error": "No fish detected", "detections": 0}

        boxes = results[0]["boxes"].cpu().numpy()
        scores = results[0]["scores"].cpu().numpy()

        # 2) SAM: one mask per detected box.
        sam_predictor.set_image(image_resized)

        masks = []
        for box in boxes:
            box_sam = np.array([box[0], box[1], box[2], box[3]])
            mask, _, _ = sam_predictor.predict(box=box_sam, multimask_output=False)
            # Ensure boolean dtype so the fancy indexing below is valid.
            masks.append(mask[0].astype(bool))

        # 3) Visualization: translucent green masks + red bounding boxes.
        annotated_image = image_resized.copy()

        for mask in masks:
            color_mask = np.zeros_like(annotated_image)
            color_mask[mask] = [0, 255, 0]  # Green mask
            annotated_image = cv2.addWeighted(annotated_image, 1, color_mask, 0.5, 0)

        for box in boxes:
            x1, y1, x2, y2 = map(int, box)
            cv2.rectangle(annotated_image, (x1, y1), (x2, y2), (0, 0, 255), 2)

        # Summary statistics for the JSON panel.
        total_pixels = new_w * new_h
        mask_pixels = sum(int(np.sum(mask)) for mask in masks)
        mask_percentage = (mask_pixels / total_pixels) * 100

        metadata = {
            "detections": len(boxes),
            "avg_confidence": float(np.mean(scores)),
            "image_size": f"{new_w}x{new_h}",
            "mask_percentage": f"{mask_percentage:.2f}%"
        }

        return Image.fromarray(annotated_image), metadata

    except Exception as e:
        # Report the failure through the UI rather than crashing the Space.
        return image, {"error": str(e)}
|
|
|
|
|
|
|
|
|
# Create Gradio interface
with gr.Blocks(title="Grounded SAM - Fish Detection") as demo:
    gr.Markdown("# π Grounded SAM: Fish Detection & Segmentation")
    gr.Markdown("Upload an image and detect fish using Grounding DINO + Segment Anything Model")

    with gr.Row():
        with gr.Column():
            # Inputs: image, free-text detection prompt, quality preset.
            input_image = gr.Image(type="pil", label="Upload Image")
            text_prompt = gr.Textbox(value="fish", label="Detection Prompt")
            quality = gr.Radio(
                choices=["Low (256px)", "Medium (512px)", "High (1024px)"],
                value="Medium (512px)",
                label="Processing Quality"
            )
            submit_btn = gr.Button("Process Image", variant="primary")

        with gr.Column():
            output_image = gr.Image(label="Detection Result")
            output_metadata = gr.JSON(label="Detection Metadata")

    submit_btn.click(
        fn=detect_and_segment,
        inputs=[input_image, text_prompt, quality],
        outputs=[output_image, output_metadata]
    )

    # NOTE(review): only register the example when the asset is actually
    # bundled with the Space -- gr.Examples raises at build time on a
    # missing file, which crashes the app on startup.
    if os.path.exists("fish_angler.jpg"):
        gr.Examples(
            examples=[
                ["fish_angler.jpg", "fish", "High (1024px)"],
            ],
            inputs=[input_image, text_prompt, quality]
        )
| 152 |
|
| 153 |
if __name__ == "__main__":
    # Entry point when run directly (Spaces executes app.py as a script).
    demo.launch()