Spaces:
Sleeping
Sleeping
File size: 4,693 Bytes
4ca6be9 bb387d5 e9274bc bb387d5 e9274bc bb387d5 e9274bc bb387d5 4ca6be9 bb387d5 e9274bc |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 |
import gradio as gr
from ultralytics import YOLO
import cv2
import numpy as np
from collections import Counter
from PIL import Image
from huggingface_hub import hf_hub_download
import os
import base64
from io import BytesIO
# -------------------------------
# Load YOLO model safely with HF token
# -------------------------------
print("π§ Loading YOLO model...")
# Hub access token; it is only required when the weights are not available
# locally and have to be downloaded.
hf_token = os.getenv("HF_TOKEN")  # Make sure your token is in environment variables
try:
    model_path = "best.pt"
    try:
        # Prefer a local copy of the weights.
        model = YOLO(model_path)
    except FileNotFoundError:
        if not hf_token:
            raise ValueError("HF_TOKEN not set in environment!")
        print("π Model not found locally β downloading from HF Hub with token...")
        model_path = hf_hub_download(
            repo_id="Faethon88/sar",
            filename="best.pt",
            # `token` supersedes the deprecated `use_auth_token` keyword.
            token=hf_token,
        )
        model = YOLO(model_path)
    print("β Model loaded successfully!")
except Exception as e:
    # Keep the module importable when loading fails; detect_ships checks
    # for `model is None` and reports it to the user.
    print(f"β Model load failed: {e}")
    model = None
# -------------------------------
# Detection logic
# -------------------------------
def detect_ships(image: Image.Image, confidence: float):
    """Run the YOLO model on an image and draw the detections on it.

    Args:
        image: PIL image to analyse; it is converted to RGB first.
        confidence: Threshold forwarded to ``model.predict`` as ``conf``.

    Returns:
        A ``(annotated, summary)`` tuple. ``annotated`` is an RGB numpy
        array with boxes and labels drawn on it, or ``None`` when the model
        is not loaded or detection fails. ``summary`` is the text report,
        or the error message in the failure cases.
    """
    if model is None:
        return None, "β Model not loaded."
    try:
        rgb = np.array(image.convert("RGB"))
        # Ultralytics interprets numpy input as BGR and OpenCV draws in BGR,
        # so work in BGR throughout and convert back to RGB once at the end.
        bgr = cv2.cvtColor(rgb, cv2.COLOR_RGB2BGR)
        result = model.predict(bgr, conf=confidence, verbose=False)[0]
        annotated = bgr.copy()

        detections = result.boxes
        if detections:
            boxes = detections.xyxy.cpu().numpy()
            confs = detections.conf.cpu().numpy().tolist()
            class_ids = detections.cls.cpu().numpy().tolist()
        else:
            boxes, confs, class_ids = [], [], []

        class_names = []
        for (x1, y1, x2, y2), cls_id, conf in zip(boxes, class_ids, confs):
            cls_name = model.names.get(int(cls_id), "ship")
            class_names.append(cls_name)
            left, top = int(x1), int(y1)
            cv2.rectangle(annotated, (left, top), (int(x2), int(y2)), (0, 255, 0), 2)
            cv2.putText(
                annotated,
                f"{cls_name} {conf:.2f}",
                # Clamp the baseline so labels of boxes touching the top
                # edge are not drawn outside the image.
                (left, max(top - 10, 15)),
                cv2.FONT_HERSHEY_SIMPLEX,
                0.6,
                (255, 255, 0),
                2,
            )

        # The caller displays RGB, so undo the BGR working format.
        annotated = cv2.cvtColor(annotated, cv2.COLOR_BGR2RGB)

        if class_names:
            counts = Counter(class_names)
            summary = "Detections:\n" + "\n".join(
                f"- {cls}: {cnt}" for cls, cnt in counts.items()
            )
        else:
            summary = "No ships detected."
        summary += (
            f"\nConfidence threshold: {confidence:.2f}"
            f"\nTotal detections: {len(class_names)}"
        )
        return annotated, summary
    except Exception as e:
        return None, f"β Detection failed: {e}"
# -------------------------------
# Gradio API function with wrapper for remote dict input
# -------------------------------
def predict(image, confidence):
    """Normalise the incoming image and run ship detection on it.

    Args:
        image: Either an image object accepted by ``detect_ships`` or a dict
            whose ``"data"`` (or ``"image"``) entry holds a ``data:image...``
            URI, as sent by the remote client.
        confidence: Confidence threshold forwarded to ``detect_ships``.

    Returns:
        The ``(annotated_image, summary)`` result of ``detect_ships``.

    Raises:
        ValueError: If ``image`` is ``None`` or a dict that does not carry a
            ``data:image`` URI.
        Exception: Any error raised while decoding the data URI is logged
            and re-raised unchanged.
    """
    print("DEBUG: predict called")
    print("DEBUG: raw image type:", type(image))
    print("DEBUG: confidence type/value:", type(confidence), confidence)

    # Handle dict input from remote client (Flask sends {"name":..., "data": data_uri})
    if isinstance(image, dict):
        data = image.get("data") or image.get("image") or ""
        if not (isinstance(data, str) and data.startswith("data:image")):
            # Fail here with a clear message instead of handing the raw dict
            # to detect_ships, which cannot process it.
            raise ValueError("Unsupported image payload: expected a data:image URI")
        try:
            _, b64 = data.split(",", 1)
            image = Image.open(BytesIO(base64.b64decode(b64))).convert("RGB")
            print("DEBUG: decoded image from data URI to PIL.Image")
        except Exception as e:
            print("ERROR decoding data URI:", e)
            raise

    if image is None:
        raise ValueError("Received empty image")
    return detect_ships(image, confidence)
# -------------------------------
# Gradio UI + API
# -------------------------------
# Page layout: a row of inputs, a row of outputs, then the trigger button.
with gr.Blocks(title="π°οΈ SAR Ship Detection") as demo:
    gr.Markdown("## π°οΈ SAR Ship Detection\nUpload a SAR image.")

    with gr.Row():
        image_in = gr.Image(type="pil", label="Upload SAR Image")
        conf = gr.Slider(minimum=0.1, maximum=1.0, value=0.5, label="Confidence")

    with gr.Row():
        image_out = gr.Image(type="numpy", label="Detection Results")
        text_out = gr.Textbox(label="Summary")

    btn = gr.Button("π Run Detection")
    # Clicking the button runs `predict`; the event is registered under the
    # API name "predict".
    btn.click(
        fn=predict,
        inputs=[image_in, conf],
        outputs=[image_out, text_out],
        api_name="predict",
    )
# -------------------------------
# Launch Gradio with verbose errors & debug
# -------------------------------
if __name__ == "__main__":
    # Options for running the app directly as a script.
    launch_options = dict(
        server_name="0.0.0.0",
        server_port=7860,
        show_error=True,  # Show detailed errors in browser
        debug=True,  # Print detailed logs to console
    )
    demo.launch(**launch_options)