Faethon88 commited on
Commit
4ca6be9
Β·
verified Β·
1 Parent(s): da38447

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +106 -1
app.py CHANGED
@@ -1 +1,106 @@
1
- import gradio as gr from ultralytics import YOLO import cv2, numpy as np from collections import Counter from PIL import Image from huggingface_hub import hf_hub_download # ------------------------------- # Load YOLO model safely # ------------------------------- print("πŸ”§ Loading YOLO model...") try: model_path = "best.pt" try: model = YOLO(model_path) except FileNotFoundError: print("🌐 Model not found locally β€” downloading from HF Hub...") model_path = hf_hub_download(repo_id="Faethon88/sar", filename="best.pt") model = YOLO(model_path) print("βœ… Model loaded successfully!") except Exception as e: print(f"❌ Model load failed: {e}") model = None # ------------------------------- # Detection logic # ------------------------------- def detect_ships(image: Image.Image, confidence: float): if model is None: return None, "❌ Model not loaded." img_np = np.array(image.convert("RGB")) results = model.predict(img_np, conf=confidence, verbose=False) result = results[0] annotated = img_np.copy() boxes = result.boxes.xyxy.cpu().numpy() if result.boxes else [] confs = result.boxes.conf.cpu().numpy().tolist() if result.boxes else [] class_ids = result.boxes.cls.cpu().numpy().tolist() if result.boxes else [] class_names = [] for (x1, y1, x2, y2), cls_id, conf in zip(boxes, class_ids, confs): cls_name = model.names.get(int(cls_id), "ship") class_names.append(cls_name) cv2.rectangle(annotated, (int(x1), int(y1)), (int(x2), int(y2)), (0,255,0), 2) cv2.putText(annotated, f"{cls_name} {conf:.2f}", (int(x1), int(y1)-10), cv2.FONT_HERSHEY_SIMPLEX, 0.6, (255,255,0), 2) annotated = cv2.cvtColor(annotated, cv2.COLOR_BGR2RGB) summary = "Detections:\n" + "\n".join([f"- {cls}: {cnt}" for cls, cnt in Counter(class_names).items()]) \ if class_names else "No ships detected." summary += f"\nConfidence threshold: {confidence:.2f}\nTotal detections: {len(class_names)}" return annotated, summary # ------------------------------- # Gradio API function # ------------------------------- def predict(image, confidence): return detect_ships(image, confidence) # ------------------------------- # Gradio UI + API # ------------------------------- with gr.Blocks(title="πŸ›°οΈ SAR Ship Detection") as demo: gr.Markdown("## πŸ›°οΈ SAR Ship Detection\nUpload a SAR image.") with gr.Row(): image_in = gr.Image(type="pil", label="Upload SAR Image") # <-- type="pil" conf = gr.Slider(0.1, 1.0, 0.5, label="Confidence") with gr.Row(): image_out = gr.Image(type="numpy", label="Detection Results") text_out = gr.Textbox(label="Summary") btn = gr.Button("πŸš€ Run Detection") btn.click(predict, [image_in, conf], [image_out, text_out], api_name="predict") if __name__ == "__main__": demo.launch(server_name="0.0.0.0", server_port=7860)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from ultralytics import YOLO
3
+ import cv2
4
+ import numpy as np
5
+ from collections import Counter
6
+ from PIL import Image
7
+ from huggingface_hub import hf_hub_download
8
+ import os
9
+
10
+ # -------------------------------
11
+ # Load YOLO model safely with HF token
12
+ # -------------------------------
13
+ print("πŸ”§ Loading YOLO model...")
14
+ hf_token = os.getenv("HF_TOKEN") # Make sure your token is in environment variables
15
+
16
+ try:
17
+ model_path = "best.pt"
18
+ try:
19
+ model = YOLO(model_path)
20
+ except FileNotFoundError:
21
+ if not hf_token:
22
+ raise ValueError("HF_TOKEN not set in environment!")
23
+ print("🌐 Model not found locally β€” downloading from HF Hub with token...")
24
+ model_path = hf_hub_download(
25
+ repo_id="Faethon88/sar",
26
+ filename="best.pt",
27
+ use_auth_token=hf_token
28
+ )
29
+ model = YOLO(model_path)
30
+ print("βœ… Model loaded successfully!")
31
+ except Exception as e:
32
+ print(f"❌ Model load failed: {e}")
33
+ model = None
34
+
35
+ # -------------------------------
36
+ # Detection logic
37
+ # -------------------------------
38
+ def detect_ships(image: Image.Image, confidence: float):
39
+ if model is None:
40
+ return None, "❌ Model not loaded."
41
+
42
+ try:
43
+ img_np = np.array(image.convert("RGB"))
44
+ results = model.predict(img_np, conf=confidence, verbose=False)
45
+ result = results[0]
46
+
47
+ annotated = img_np.copy()
48
+ boxes = result.boxes.xyxy.cpu().numpy() if result.boxes else []
49
+ confs = result.boxes.conf.cpu().numpy().tolist() if result.boxes else []
50
+ class_ids = result.boxes.cls.cpu().numpy().tolist() if result.boxes else []
51
+
52
+ class_names = []
53
+ for (x1, y1, x2, y2), cls_id, conf in zip(boxes, class_ids, confs):
54
+ cls_name = model.names.get(int(cls_id), "ship")
55
+ class_names.append(cls_name)
56
+ cv2.rectangle(annotated, (int(x1), int(y1)), (int(x2), int(y2)), (0, 255, 0), 2)
57
+ cv2.putText(
58
+ annotated,
59
+ f"{cls_name} {conf:.2f}",
60
+ (int(x1), int(y1) - 10),
61
+ cv2.FONT_HERSHEY_SIMPLEX,
62
+ 0.6,
63
+ (255, 255, 0),
64
+ 2
65
+ )
66
+
67
+ annotated = cv2.cvtColor(annotated, cv2.COLOR_BGR2RGB)
68
+ summary = (
69
+ "Detections:\n" + "\n".join([f"- {cls}: {cnt}" for cls, cnt in Counter(class_names).items()])
70
+ if class_names else "No ships detected."
71
+ )
72
+ summary += f"\nConfidence threshold: {confidence:.2f}\nTotal detections: {len(class_names)}"
73
+ return annotated, summary
74
+ except Exception as e:
75
+ return None, f"❌ Detection failed: {e}"
76
+
77
+ # -------------------------------
78
+ # Gradio API function
79
+ # -------------------------------
80
+ def predict(image, confidence):
81
+ return detect_ships(image, confidence)
82
+
83
+ # -------------------------------
84
+ # Gradio UI + API
85
+ # -------------------------------
86
+ with gr.Blocks(title="πŸ›°οΈ SAR Ship Detection") as demo:
87
+ gr.Markdown("## πŸ›°οΈ SAR Ship Detection\nUpload a SAR image.")
88
+ with gr.Row():
89
+ image_in = gr.Image(type="pil", label="Upload SAR Image")
90
+ conf = gr.Slider(0.1, 1.0, 0.5, label="Confidence")
91
+ with gr.Row():
92
+ image_out = gr.Image(type="numpy", label="Detection Results")
93
+ text_out = gr.Textbox(label="Summary")
94
+ btn = gr.Button("πŸš€ Run Detection")
95
+ btn.click(predict, [image_in, conf], [image_out, text_out], api_name="predict")
96
+
97
+ # -------------------------------
98
+ # Launch Gradio with verbose errors & debug
99
+ # -------------------------------
100
+ if __name__ == "__main__":
101
+ demo.launch(
102
+ server_name="0.0.0.0",
103
+ server_port=7860,
104
+ show_error=True, # Show detailed errors in browser
105
+ debug=True # Print detailed logs to console
106
+ )