# sar_imaging / app.py
# Source: Hugging Face Space "Faethon88/sar" by Faethon88 (commit e9274bc, verified)
import gradio as gr
from ultralytics import YOLO
import cv2
import numpy as np
from collections import Counter
from PIL import Image
from huggingface_hub import hf_hub_download
import os
import base64
from io import BytesIO
# -------------------------------
# Load YOLO model safely with HF token
# -------------------------------
print("πŸ”§ Loading YOLO model...")
hf_token = os.getenv("HF_TOKEN")  # Only required when the weights must be fetched from the Hub
model = None  # Sentinel checked by detect_ships(); set before try so it always exists
try:
    model_path = "best.pt"
    if not os.path.exists(model_path):
        # BUG FIX: the old code relied on YOLO("best.pt") raising FileNotFoundError,
        # which ultralytics does not guarantee (it may try its own asset resolution
        # or raise a different exception). Check for the local file explicitly.
        if not hf_token:
            raise ValueError("HF_TOKEN not set in environment!")
        print("🌐 Model not found locally β€” downloading from HF Hub with token...")
        model_path = hf_hub_download(
            repo_id="Faethon88/sar",
            filename="best.pt",
            token=hf_token,  # `use_auth_token` is deprecated in huggingface_hub; `token` is current
        )
    model = YOLO(model_path)
    print("βœ… Model loaded successfully!")
except Exception as e:
    # Broad catch is deliberate: the app must still start (with model=None)
    # so the UI can report the failure instead of crashing on import.
    print(f"❌ Model load failed: {e}")
    model = None
# -------------------------------
# Detection logic
# -------------------------------
def detect_ships(image: Image.Image, confidence: float):
    """Run YOLO ship detection on a PIL image and annotate the results.

    Args:
        image: Input SAR image (any PIL mode; converted to RGB internally).
        confidence: Minimum detection confidence threshold passed to the model.

    Returns:
        A ``(annotated, summary)`` pair: an RGB numpy array with boxes/labels
        drawn (or ``None`` on failure) and a human-readable summary string.
    """
    if model is None:
        return None, "❌ Model not loaded."
    try:
        img_np = np.array(image.convert("RGB"))
        results = model.predict(img_np, conf=confidence, verbose=False)
        result = results[0]
        annotated = img_np.copy()
        # Guard each tensor extraction: result.boxes is falsy when nothing was found.
        boxes = result.boxes.xyxy.cpu().numpy() if result.boxes else []
        confs = result.boxes.conf.cpu().numpy().tolist() if result.boxes else []
        class_ids = result.boxes.cls.cpu().numpy().tolist() if result.boxes else []
        class_names = []
        for (x1, y1, x2, y2), cls_id, conf in zip(boxes, class_ids, confs):
            cls_name = model.names.get(int(cls_id), "ship")
            class_names.append(cls_name)
            cv2.rectangle(annotated, (int(x1), int(y1)), (int(x2), int(y2)), (0, 255, 0), 2)
            cv2.putText(
                annotated,
                f"{cls_name} {conf:.2f}",
                (int(x1), int(y1) - 10),
                cv2.FONT_HERSHEY_SIMPLEX,
                0.6,
                (255, 255, 0),  # yellow in RGB
                2
            )
        # BUG FIX: `annotated` came from a PIL image, so it is ALREADY RGB.
        # The previous cv2.cvtColor(annotated, cv2.COLOR_BGR2RGB) swapped the
        # red and blue channels of the final output (yellow labels rendered
        # cyan, image hues corrupted). No conversion is needed here.
        summary = (
            "Detections:\n" + "\n".join(f"- {cls}: {cnt}" for cls, cnt in Counter(class_names).items())
            if class_names else "No ships detected."
        )
        summary += f"\nConfidence threshold: {confidence:.2f}\nTotal detections: {len(class_names)}"
        return annotated, summary
    except Exception as e:
        # Return the error in-band so the Gradio UI shows it in the summary box.
        return None, f"❌ Detection failed: {e}"
# -------------------------------
# Gradio API function with wrapper for remote dict input
# -------------------------------
def predict(image, confidence):
    """Gradio endpoint: normalize remote payloads, then delegate to detect_ships().

    Args:
        image: A PIL.Image, or a dict sent by a remote client of the form
            ``{"name": ..., "data": "data:image/...;base64,..."}``.
        confidence: Confidence threshold forwarded unchanged to detect_ships().

    Returns:
        The ``(annotated_image, summary)`` pair from detect_ships().

    Raises:
        ValueError: if the image is missing, or a dict payload carries no
            decodable ``data:image`` URI.
    """
    print("DEBUG: predict called")
    print("DEBUG: raw image type:", type(image))
    print("DEBUG: confidence type/value:", type(confidence), confidence)
    # Handle dict input from remote client (Flask sends {"name":..., "data": data_uri})
    if isinstance(image, dict):
        data = image.get("data") or image.get("image") or ""
        if isinstance(data, str) and data.startswith("data:image"):
            try:
                _header, b64 = data.split(",", 1)
                image = Image.open(BytesIO(base64.b64decode(b64))).convert("RGB")
                print("DEBUG: decoded image from data URI to PIL.Image")
            except Exception as e:
                print("ERROR decoding data URI:", e)
                raise
        else:
            # BUG FIX: previously an undecodable dict fell through into
            # detect_ships and failed there with a confusing AttributeError
            # ("'dict' object has no attribute 'convert'"). Fail loudly here.
            raise ValueError("Dict image payload has no decodable 'data:image' URI")
    if image is None:
        raise ValueError("Received empty image")
    return detect_ships(image, confidence)
# -------------------------------
# Gradio UI + API
# -------------------------------
with gr.Blocks(title="πŸ›°οΈ SAR Ship Detection") as demo:
    gr.Markdown("## πŸ›°οΈ SAR Ship Detection\nUpload a SAR image.")
    # Top row: the input image and the confidence slider, side by side.
    with gr.Row():
        input_image = gr.Image(type="pil", label="Upload SAR Image")
        conf_slider = gr.Slider(0.1, 1.0, 0.5, label="Confidence")
    # Bottom row: annotated output image and the textual detection summary.
    with gr.Row():
        output_image = gr.Image(type="numpy", label="Detection Results")
        summary_box = gr.Textbox(label="Summary")
    run_button = gr.Button("πŸš€ Run Detection")
    # Expose the handler both to the button and as a named REST endpoint.
    run_button.click(
        predict,
        inputs=[input_image, conf_slider],
        outputs=[output_image, summary_box],
        api_name="predict",
    )
# -------------------------------
# Launch Gradio with verbose errors & debug
# -------------------------------
if __name__ == "__main__":
    # Bind on all interfaces (container-friendly) on the standard Spaces port.
    launch_options = {
        "server_name": "0.0.0.0",
        "server_port": 7860,
        "show_error": True,  # Show detailed errors in browser
        "debug": True,       # Print detailed logs to console
    }
    demo.launch(**launch_options)