import gradio as gr
from ultralytics import YOLO
import cv2
import numpy as np
from collections import Counter
from PIL import Image
from huggingface_hub import hf_hub_download
import os
import base64
from io import BytesIO

# -------------------------------
# Load YOLO model safely with HF token
# -------------------------------
print("🔧 Loading YOLO model...")
hf_token = os.getenv("HF_TOKEN")  # Make sure your token is in environment variables

try:
    model_path = "best.pt"
    # Check for local weights explicitly instead of relying on YOLO() raising
    # FileNotFoundError (ultralytics may raise other exception types on a
    # missing path, which would have skipped the Hub fallback).
    if not os.path.exists(model_path):
        if not hf_token:
            raise ValueError("HF_TOKEN not set in environment!")
        print("🌐 Model not found locally — downloading from HF Hub with token...")
        model_path = hf_hub_download(
            repo_id="Faethon88/sar",
            filename="best.pt",
            token=hf_token,  # `use_auth_token` is deprecated in huggingface_hub
        )
    model = YOLO(model_path)
    print("✅ Model loaded successfully!")
except Exception as e:
    # Keep the app importable even if weights are unavailable; detect_ships()
    # reports the failure to the user instead of crashing at import time.
    print(f"❌ Model load failed: {e}")
    model = None


# -------------------------------
# Detection logic
# -------------------------------
def detect_ships(image: Image.Image, confidence: float):
    """Run YOLO ship detection on a PIL image.

    Args:
        image: Input SAR image (any mode; converted to RGB internally).
        confidence: Detection confidence threshold in [0, 1].

    Returns:
        Tuple of (annotated RGB numpy array or None, text summary string).
        Never raises — failures are reported in the summary string.
    """
    if model is None:
        return None, "❌ Model not loaded."
    try:
        img_np = np.array(image.convert("RGB"))
        results = model.predict(img_np, conf=confidence, verbose=False)
        result = results[0]

        annotated = img_np.copy()
        boxes = result.boxes.xyxy.cpu().numpy() if result.boxes else []
        confs = result.boxes.conf.cpu().numpy().tolist() if result.boxes else []
        class_ids = result.boxes.cls.cpu().numpy().tolist() if result.boxes else []

        class_names = []
        for (x1, y1, x2, y2), cls_id, conf in zip(boxes, class_ids, confs):
            cls_name = model.names.get(int(cls_id), "ship")
            class_names.append(cls_name)
            cv2.rectangle(
                annotated, (int(x1), int(y1)), (int(x2), int(y2)), (0, 255, 0), 2
            )
            cv2.putText(
                annotated,
                f"{cls_name} {conf:.2f}",
                (int(x1), int(y1) - 10),
                cv2.FONT_HERSHEY_SIMPLEX,
                0.6,
                (255, 255, 0),
                2,
            )

        # BUG FIX: `annotated` is already RGB (built from image.convert("RGB")),
        # so the original cv2.cvtColor(annotated, cv2.COLOR_BGR2RGB) call here
        # swapped the red/blue channels of the displayed output and is removed.

        summary = (
            "Detections:\n"
            + "\n".join(
                f"- {cls}: {cnt}" for cls, cnt in Counter(class_names).items()
            )
            if class_names
            else "No ships detected."
        )
        summary += (
            f"\nConfidence threshold: {confidence:.2f}"
            f"\nTotal detections: {len(class_names)}"
        )
        return annotated, summary
    except Exception as e:
        return None, f"❌ Detection failed: {e}"


# -------------------------------
# Gradio API function with wrapper for remote dict input
# -------------------------------
def predict(image, confidence):
    """Gradio-facing entry point wrapping detect_ships().

    Accepts either a PIL image (normal Gradio upload) or a dict payload from a
    remote client of the form {"name": ..., "data": "data:image/...;base64,..."}
    and decodes the latter into a PIL image before detection.

    Raises:
        ValueError: if no usable image was received.
    """
    print("DEBUG: predict called")
    print("DEBUG: raw image type:", type(image))
    print("DEBUG: confidence type/value:", type(confidence), confidence)

    # Handle dict input from remote client (Flask sends {"name":..., "data": data_uri})
    if isinstance(image, dict):
        data = image.get("data") or image.get("image") or ""
        if data and isinstance(data, str) and data.startswith("data:image"):
            try:
                header, b64 = data.split(",", 1)
                image = Image.open(BytesIO(base64.b64decode(b64))).convert("RGB")
                print("DEBUG: decoded image from data URI to PIL.Image")
            except Exception as e:
                print("ERROR decoding data URI:", e)
                raise

    if image is None:
        raise ValueError("Received empty image")

    return detect_ships(image, confidence)
# -------------------------------
# Gradio UI + API
# -------------------------------
with gr.Blocks(title="🛰️ SAR Ship Detection") as demo:
    gr.Markdown("## 🛰️ SAR Ship Detection\nUpload a SAR image.")

    # Input row: the uploaded image and the confidence threshold slider.
    with gr.Row():
        input_image = gr.Image(type="pil", label="Upload SAR Image")
        conf_slider = gr.Slider(
            minimum=0.1, maximum=1.0, value=0.5, label="Confidence"
        )

    # Output row: the annotated image and a per-class detection summary.
    with gr.Row():
        output_image = gr.Image(type="numpy", label="Detection Results")
        summary_box = gr.Textbox(label="Summary")

    run_button = gr.Button("🚀 Run Detection")
    run_button.click(
        fn=predict,
        inputs=[input_image, conf_slider],
        outputs=[output_image, summary_box],
        api_name="predict",  # exposed for remote (non-UI) API clients
    )

# -------------------------------
# Launch Gradio with verbose errors & debug
# -------------------------------
if __name__ == "__main__":
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        show_error=True,  # Show detailed errors in browser
        debug=True,       # Print detailed logs to console
    )