Spaces:
Sleeping
Sleeping
File size: 7,254 Bytes
059e297 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 |
from PIL import Image
import cv2
import numpy as np
import gradio as gr
from inference import CoralSegModel, id2label, label2color, create_segmentation_overlay
model = CoralSegModel()
# ---- helpers ----
def _safe_read(cap):
ok, frame = cap.read()
return frame if ok and frame is not None else None
def build_annotations(pred_map: np.ndarray, selected: list[str]) -> list[tuple[np.ndarray, str]]:
    """Return [(mask, label), ...] for AnnotatedImage, mask being a 0/1 float HxW array.

    Only labels present in *selected* AND actually occurring in *pred_map*
    produce an annotation; unknown label names are skipped silently.
    """
    if pred_map is None or not selected:
        return []
    # id2label maps string class ids to names; invert it for name -> int id.
    name_to_id = {name: int(key) for key, name in id2label.items()}
    annotations: list[tuple[np.ndarray, str]] = []
    for name in selected:
        class_id = name_to_id.get(name)
        if class_id is None:
            continue  # unknown label name
        binary_mask = (pred_map == class_id).astype(np.float32)
        if binary_mask.sum() > 0:  # skip classes absent from this frame
            annotations.append((binary_mask, name))
    return annotations
# ==============================
# STREAMING EVENT FUNCTIONS
# ==============================
# IMPORTANT: make the event functions themselves generators.
# Also: include the States as outputs so we can update them every frame.
def remote_start(url: str, n: int, pred_state, base_state):
    """Stream a remote RTSP/HTTP source, yielding segmented frames.

    Generator event handler for Gradio: for every Nth frame it yields
    (overlay_rgb, pred_map, base_rgb) so the live image and both State
    values are refreshed together on each processed frame.

    Args:
        url: Stream URL (rtsp:// or http://). Falsy value aborts silently.
        n: Process every Nth frame; intermediate frames are skipped.
        pred_state, base_state: incoming State *values* (unused here; the
            States are wired as inputs/outputs in the click handler).
    """
    if not url:
        return
    cap = cv2.VideoCapture(url)
    if not cap.isOpened():
        return
    idx = 0
    try:
        while True:
            frame = _safe_read(cap)
            if frame is None:
                break
            if n > 1 and (idx % n) != 0:
                idx += 1
                continue
            # BUG FIX: OpenCV delivers frames in BGR order; convert to RGB
            # before inference, matching upload_start. Without this, the
            # remote path fed channel-swapped images to the model and
            # rendered color-inverted overlays.
            frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            pred_map, overlay_rgb, base_rgb = model.predict_map_and_overlay(frame)
            # Yield live image + updated State values every processed frame.
            yield overlay_rgb, pred_map, base_rgb
            idx += 1
    finally:
        # Always release the capture, even if the generator is abandoned.
        cap.release()
def upload_start(video_file: str, n: int):
    """Stream an uploaded video through the model, yielding segmented frames.

    Generator event handler: for every Nth frame it yields
    (overlay_rgb, pred_map, base_rgb) so the live preview and the
    snapshot States stay in sync.
    """
    if not video_file:
        return
    cap = cv2.VideoCapture(video_file)
    if not cap.isOpened():
        return
    frame_idx = 0
    try:
        while True:
            ok, raw = cap.read()
            if not ok or raw is None:
                break
            # Honor the "process every Nth frame" slider.
            if n > 1 and frame_idx % n != 0:
                frame_idx += 1
                continue
            # OpenCV decodes BGR; the model expects RGB.
            rgb = cv2.cvtColor(raw, cv2.COLOR_BGR2RGB)
            pred_map, overlay_rgb, base_rgb = model.predict_map_and_overlay(rgb)
            yield overlay_rgb, pred_map, base_rgb
            frame_idx += 1
    finally:
        cap.release()
# ==============================
# SNAPSHOT / TOGGLES (non-streaming)
# ==============================
# NOTE: When you pass gr.State as an input, you receive the *value*, not the wrapper.
def make_snapshot(selected_labels, pred_map, base_rgb, alpha=0.25):
    """Build the AnnotatedImage payload from the most recent frame.

    Returns (overlay_image, [(mask, label), ...]) for gr.AnnotatedImage,
    or a no-op gr.update() when no frame has been captured yet.

    Note: gr.State inputs arrive here as their *values*, not the wrappers.
    """
    if pred_map is None or base_rgb is None:
        return gr.update()
    # Rebuild the overlay so the snapshot matches the live stream's look.
    base_image = Image.fromarray(base_rgb)
    overlay = create_segmentation_overlay(
        pred_map, id2label, label2color, base_image, alpha=alpha
    )
    annotations = build_annotations(pred_map, selected_labels or [])
    return (overlay, annotations)
# ==============================
# UI
# ==============================
# Two-tab layout: each tab pairs a fast live stream (left) with a hover-able
# snapshot view, and keeps the latest pred_map/base frame in gr.State so the
# snapshot/toggle handlers can re-render without re-running inference.
with gr.Blocks(title="CoralScapes Streaming Segmentation") as demo:
    gr.Markdown("# CoralScapes Streaming Segmentation")
    gr.Markdown(
        "Left: **live stream** (fast). Right: **snapshot** with **hover labels** and **per-class toggles**."
    )

    with gr.Tab("Remote Stream (RTSP/HTTP)"):
        with gr.Row():
            with gr.Column(scale=2):
                # States start as None. We'll UPDATE them on every frame by returning them as outputs.
                pred_state_remote = gr.State(None)  # holds last pred_map (HxW np.uint8)
                base_state_remote = gr.State(None)  # holds last base_rgb (HxWx3 uint8)
                live_remote = gr.Image(label="Live segmented stream")
                start_btn = gr.Button("Start")
                snap_btn_remote = gr.Button("📸 Snapshot (hover-able)")
                hover_remote = gr.AnnotatedImage(label="Snapshot (hover to see label)")
            with gr.Column(scale=1):
                url = gr.Textbox(label="Stream URL", placeholder="rtsp://user:pass@ip:port/…")
                skip = gr.Slider(1, 60, value=10, step=1, label="Process every Nth frame")
                toggles_remote = gr.CheckboxGroup(
                    choices=list(id2label.values()), value=list(id2label.values()),
                    label="Toggle classes in snapshot",
                )
        # Generator handler: streams frames and refreshes both States each yield.
        start_btn.click(
            remote_start,
            inputs=[url, skip, pred_state_remote, base_state_remote],
            outputs=[live_remote, pred_state_remote, base_state_remote],
            queue=True,  # be explicit; required for generator streaming
        )
        # Snapshot renders from the last stored State values (no new inference pass).
        snap_btn_remote.click(
            make_snapshot,
            inputs=[toggles_remote, pred_state_remote, base_state_remote],
            outputs=[hover_remote],
        )
        # Re-render the snapshot whenever the class toggles change.
        toggles_remote.change(
            make_snapshot,
            inputs=[toggles_remote, pred_state_remote, base_state_remote],
            outputs=[hover_remote],
        )

    with gr.Tab("Upload Video"):
        with gr.Row():
            # Left column (now contains toggles, snapshot button, and live output)
            with gr.Column(scale=2):
                # States remain in the same column as live_upload
                pred_state_upload = gr.State(None)  # last pred_map for this tab
                base_state_upload = gr.State(None)  # last base_rgb for this tab
                live_upload = gr.Image(label="Live segmented output")
                start_btn2 = gr.Button("Process")
                snap_btn_upload = gr.Button("📸 Snapshot (hover-able)")
                hover_upload = gr.AnnotatedImage(label="Snapshot (hover to see label)")
            # Right column (now contains video input and slider)
            with gr.Column(scale=1):
                vid_in = gr.Video(sources=["upload"], format="mp4", label="Input Video")
                skip2 = gr.Slider(1, 5, value=1, step=1, label="Process every Nth frame")
                toggles_upload = gr.CheckboxGroup(
                    choices=list(id2label.values()), value=list(id2label.values()),
                    label="Toggle classes in snapshot",
                )
        # Event handlers remain the same
        start_btn2.click(
            upload_start,
            inputs=[vid_in, skip2],
            outputs=[live_upload, pred_state_upload, base_state_upload],
            queue=True,  # generator handler: needs queuing to stream
        )
        snap_btn_upload.click(
            make_snapshot,
            inputs=[toggles_upload, pred_state_upload, base_state_upload],
            outputs=[hover_upload],
        )
        toggles_upload.change(
            make_snapshot,
            inputs=[toggles_upload, pred_state_upload, base_state_upload],
            outputs=[hover_upload],
        )

if __name__ == "__main__":
    # queue() is required for the generator (streaming) event handlers above.
    demo.queue().launch(share=True)
|