File size: 4,693 Bytes
4ca6be9
 
bb387d5
 
 
 
 
 
e9274bc
 
bb387d5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e9274bc
bb387d5
 
e9274bc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bb387d5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4ca6be9
bb387d5
 
 
 
 
 
 
 
 
e9274bc
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
import gradio as gr
from ultralytics import YOLO
import cv2
import numpy as np
from collections import Counter
from PIL import Image
from huggingface_hub import hf_hub_download
import os
import base64
from io import BytesIO

# -------------------------------
# Load YOLO model safely with HF token
# -------------------------------
print("🔧 Loading YOLO model...")
hf_token = os.getenv("HF_TOKEN")  # Make sure your token is in environment variables

try:
    model_path = "best.pt"
    # Explicit existence check: YOLO() is not guaranteed to raise
    # FileNotFoundError for a missing local weight file (Ultralytics may try
    # its own asset resolution first), so don't rely on catching it.
    if os.path.exists(model_path):
        model = YOLO(model_path)
    else:
        if not hf_token:
            raise ValueError("HF_TOKEN not set in environment!")
        print("🌐 Model not found locally — downloading from HF Hub with token...")
        model_path = hf_hub_download(
            repo_id="Faethon88/sar",
            filename="best.pt",
            # `use_auth_token` is deprecated (and removed in recent
            # huggingface_hub releases) — `token` is the supported kwarg.
            token=hf_token,
        )
        model = YOLO(model_path)
    print("✅ Model loaded successfully!")
except Exception as e:
    # Degrade gracefully: the app still starts and detect_ships() reports
    # "Model not loaded" instead of crashing at import time.
    print(f"❌ Model load failed: {e}")
    model = None

# -------------------------------
# Detection logic
# -------------------------------
def detect_ships(image: Image.Image, confidence: float):
    """Run YOLO ship detection on a SAR image.

    Args:
        image: Input image (any PIL mode; converted to RGB internally).
        confidence: Minimum confidence threshold passed to the model.

    Returns:
        Tuple of (annotated RGB ndarray, text summary) on success, or
        (None, error message) when the model is unavailable or inference fails.
    """
    if model is None:
        return None, "❌ Model not loaded."

    try:
        # FIX: Ultralytics treats raw numpy input as BGR, and the final
        # cvtColor below converts BGR->RGB.  The original fed an RGB array
        # straight through, so the returned image came back with red/blue
        # channels swapped.  Convert to BGR up front so the whole cv2
        # pipeline is consistent.
        img_np = cv2.cvtColor(np.array(image.convert("RGB")), cv2.COLOR_RGB2BGR)
        results = model.predict(img_np, conf=confidence, verbose=False)
        result = results[0]

        annotated = img_np.copy()
        # result.boxes may be None/empty when nothing is detected.
        boxes = result.boxes.xyxy.cpu().numpy() if result.boxes else []
        confs = result.boxes.conf.cpu().numpy().tolist() if result.boxes else []
        class_ids = result.boxes.cls.cpu().numpy().tolist() if result.boxes else []

        class_names = []
        for (x1, y1, x2, y2), cls_id, conf in zip(boxes, class_ids, confs):
            # Fall back to "ship" if the class id is missing from model.names.
            cls_name = model.names.get(int(cls_id), "ship")
            class_names.append(cls_name)
            # Box + label drawn in BGR space (converted to RGB below).
            cv2.rectangle(annotated, (int(x1), int(y1)), (int(x2), int(y2)), (0, 255, 0), 2)
            cv2.putText(
                annotated,
                f"{cls_name} {conf:.2f}",
                (int(x1), int(y1) - 10),
                cv2.FONT_HERSHEY_SIMPLEX,
                0.6,
                (255, 255, 0),
                2
            )

        # Back to RGB for Gradio's numpy image component.
        annotated = cv2.cvtColor(annotated, cv2.COLOR_BGR2RGB)
        summary = (
            "Detections:\n" + "\n".join([f"- {cls}: {cnt}" for cls, cnt in Counter(class_names).items()])
            if class_names else "No ships detected."
        )
        summary += f"\nConfidence threshold: {confidence:.2f}\nTotal detections: {len(class_names)}"
        return annotated, summary
    except Exception as e:
        # Return the error as the summary rather than raising, so the Gradio
        # UI shows a message instead of a stack trace.
        return None, f"❌ Detection failed: {e}"

# -------------------------------
# Gradio API function with wrapper for remote dict input
# -------------------------------
def predict(image, confidence):
    """Gradio API entry point wrapping detect_ships().

    Accepts either a PIL image (normal UI path) or a dict payload from a
    remote client (Flask sends {"name": ..., "data": <data URI>}), in which
    case the base64 data URI is decoded into a PIL image first.

    Raises:
        ValueError: if no image was supplied.
    """
    print("DEBUG: predict called")
    print("DEBUG: raw image type:", type(image))
    print("DEBUG: confidence type/value:", type(confidence), confidence)

    if isinstance(image, dict):
        # Remote clients may put the payload under "data" or "image".
        payload = image.get("data") or image.get("image") or ""
        looks_like_uri = isinstance(payload, str) and payload.startswith("data:image")
        if payload and looks_like_uri:
            try:
                _header, encoded = payload.split(",", 1)
                raw_bytes = base64.b64decode(encoded)
                image = Image.open(BytesIO(raw_bytes)).convert("RGB")
                print("DEBUG: decoded image from data URI to PIL.Image")
            except Exception as e:
                # Surface decode failures to the caller rather than hiding them.
                print("ERROR decoding data URI:", e)
                raise

    if image is None:
        raise ValueError("Received empty image")

    return detect_ships(image, confidence)

# -------------------------------
# Gradio UI + API
# -------------------------------
with gr.Blocks(title="πŸ›°οΈ SAR Ship Detection") as demo:
    gr.Markdown("## πŸ›°οΈ SAR Ship Detection\nUpload a SAR image.")
    with gr.Row():
        image_in = gr.Image(type="pil", label="Upload SAR Image")
        conf = gr.Slider(0.1, 1.0, 0.5, label="Confidence")
    with gr.Row():
        image_out = gr.Image(type="numpy", label="Detection Results")
        text_out = gr.Textbox(label="Summary")
    btn = gr.Button("πŸš€ Run Detection")
    btn.click(predict, [image_in, conf], [image_out, text_out], api_name="predict")

# -------------------------------
# Launch Gradio with verbose errors & debug
# -------------------------------
if __name__ == "__main__":
    # Bind to all interfaces on port 7860 (the conventional HF Spaces port).
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        show_error=True,  # Show detailed errors in browser
        debug=True        # Print detailed logs to console
    )