Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| from ultralytics import YOLO | |
| import cv2 | |
| import numpy as np | |
| from collections import Counter | |
| from PIL import Image | |
| from huggingface_hub import hf_hub_download | |
| import os | |
| import base64 | |
| from io import BytesIO | |
# -------------------------------
# Load YOLO model safely with HF token
# -------------------------------
# The weights are first looked up as a local "best.pt". If that raises
# FileNotFoundError they are downloaded from the Hugging Face Hub, which
# requires the HF_TOKEN environment variable. Any failure is reported on
# stdout and leaves `model` set to None, so the app still starts and
# detect_ships() can report the problem per request.
print("π§ Loading YOLO model...")
hf_token = os.getenv("HF_TOKEN")  # Make sure your token is in environment variables
try:
    model_path = "best.pt"
    try:
        model = YOLO(model_path)
    except FileNotFoundError:
        if not hf_token:
            raise ValueError("HF_TOKEN not set in environment!")
        print("π Model not found locally β downloading from HF Hub with token...")
        # `token` is the supported keyword; the older `use_auth_token`
        # spelling is deprecated in huggingface_hub.
        model_path = hf_hub_download(
            repo_id="Faethon88/sar",
            filename="best.pt",
            token=hf_token,
        )
        model = YOLO(model_path)
    print("β Model loaded successfully!")
except Exception as e:
    # Deliberate catch-all at the start-up boundary: keep the UI alive.
    print(f"β Model load failed: {e}")
    model = None
# -------------------------------
# Detection logic
# -------------------------------
def _format_summary(class_names, confidence):
    """Build the text report: per-class counts, threshold and total."""
    if class_names:
        per_class = "\n".join(
            f"- {cls}: {cnt}" for cls, cnt in Counter(class_names).items()
        )
        summary = "Detections:\n" + per_class
    else:
        summary = "No ships detected."
    summary += (
        f"\nConfidence threshold: {confidence:.2f}"
        f"\nTotal detections: {len(class_names)}"
    )
    return summary


def detect_ships(image: Image.Image, confidence: float):
    """Run the YOLO model on an image and draw the detections on it.

    Args:
        image: Input picture; it is converted to RGB before use.
        confidence: Threshold forwarded to ``model.predict`` as ``conf``.

    Returns:
        A ``(annotated, summary)`` tuple. ``annotated`` is an RGB numpy
        array with boxes and labels drawn on it, and ``summary`` is a text
        report. When the model is not loaded or anything fails,
        ``annotated`` is None and ``summary`` carries the error message.
    """
    if model is None:
        return None, "β Model not loaded."
    try:
        rgb_image = image.convert("RGB")
        # Pass the PIL image rather than a numpy array: Ultralytics treats
        # numpy input as BGR, whereas PIL input is handled as RGB.
        result = model.predict(rgb_image, conf=confidence, verbose=False)[0]

        # np.array() on a PIL image yields a fresh array we can draw on.
        annotated = np.array(rgb_image)

        detections = result.boxes
        if detections:
            boxes = detections.xyxy.cpu().numpy()
            confs = detections.conf.cpu().numpy().tolist()
            class_ids = detections.cls.cpu().numpy().tolist()
        else:
            boxes, confs, class_ids = [], [], []

        class_names = []
        for (x1, y1, x2, y2), cls_id, conf in zip(boxes, class_ids, confs):
            cls_name = model.names.get(int(cls_id), "ship")
            class_names.append(cls_name)
            left, top = int(x1), int(y1)
            cv2.rectangle(annotated, (left, top), (int(x2), int(y2)), (0, 255, 0), 2)
            # Keep the label baseline inside the image for boxes that touch
            # the top edge.
            label_y = max(top - 10, 15)
            cv2.putText(
                annotated,
                f"{cls_name} {conf:.2f}",
                (left, label_y),
                cv2.FONT_HERSHEY_SIMPLEX,
                0.6,
                (255, 255, 0),
                2,
            )

        # `annotated` is already RGB; no channel swap is needed before
        # returning it.
        return annotated, _format_summary(class_names, confidence)
    except Exception as e:
        return None, f"β Detection failed: {e}"
# -------------------------------
# Gradio API function with wrapper for remote dict input
# -------------------------------
def predict(image, confidence):
    """Entry point wired to the Gradio button and the ``predict`` API route.

    Args:
        image: Either an image object accepted by ``detect_ships`` or a dict
            carrying a ``data:image`` URI under the ``"data"`` or ``"image"``
            key.
        confidence: Threshold forwarded unchanged to ``detect_ships``.

    Returns:
        Whatever ``detect_ships`` returns for the decoded image.

    Raises:
        ValueError: If the image is None, or is a dict that does not carry
            a ``data:image`` URI.
        Exception: Any error raised while decoding the data URI is logged
            and re-raised unchanged.
    """
    print("DEBUG: predict called")
    print("DEBUG: raw image type:", type(image))
    print("DEBUG: confidence type/value:", type(confidence), confidence)
    # Handle dict input from remote client (Flask sends {"name":..., "data": data_uri})
    if isinstance(image, dict):
        data = image.get("data") or image.get("image") or ""
        if not (isinstance(data, str) and data.startswith("data:image")):
            # Fail loudly here instead of handing a raw dict to
            # detect_ships, which would only report an opaque attribute
            # error.
            raise ValueError(
                "Unsupported image payload: expected a data URI under 'data' or 'image'"
            )
        try:
            _header, b64 = data.split(",", 1)
            image = Image.open(BytesIO(base64.b64decode(b64))).convert("RGB")
            print("DEBUG: decoded image from data URI to PIL.Image")
        except Exception as e:
            print("ERROR decoding data URI:", e)
            raise
    if image is None:
        raise ValueError("Received empty image")
    return detect_ships(image, confidence)
# -------------------------------
# Gradio UI + API
# -------------------------------
# Page layout: a heading, a row with the image upload and the confidence
# slider, a row with the annotated image and the text summary, and a button
# that sends both inputs to predict(). The click handler is also published
# under the API name "predict".
with gr.Blocks(title="π°οΈ SAR Ship Detection") as demo:
    gr.Markdown("## π°οΈ SAR Ship Detection\nUpload a SAR image.")

    # Inputs
    with gr.Row():
        image_in = gr.Image(type="pil", label="Upload SAR Image")
        conf = gr.Slider(minimum=0.1, maximum=1.0, value=0.5, label="Confidence")

    # Outputs
    with gr.Row():
        image_out = gr.Image(type="numpy", label="Detection Results")
        text_out = gr.Textbox(label="Summary")

    btn = gr.Button("π Run Detection")
    btn.click(
        fn=predict,
        inputs=[image_in, conf],
        outputs=[image_out, text_out],
        api_name="predict",
    )
# -------------------------------
# Launch Gradio with verbose errors & debug
# -------------------------------
# Only start the server when this file is executed as a script.
if __name__ == "__main__":
    launch_options = {
        "server_name": "0.0.0.0",
        "server_port": 7860,
        "show_error": True,  # Show detailed errors in browser
        "debug": True,  # Print detailed logs to console
    }
    demo.launch(**launch_options)