File size: 7,254 Bytes
059e297
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
from PIL import Image
import cv2
import numpy as np
import gradio as gr

# Project-local inference helpers: the segmentation model wrapper, the
# class-id -> label and label -> color maps, and the overlay renderer.
from inference import CoralSegModel, id2label, label2color, create_segmentation_overlay
# Instantiate the model once at import time so every event handler shares it.
model = CoralSegModel()

# ---- helpers ----
def _safe_read(cap):
    ok, frame = cap.read()
    return frame if ok and frame is not None else None

def build_annotations(pred_map: np.ndarray, selected: list[str]) -> list[tuple[np.ndarray, str]]:
    """Return [(mask, label), ...] where mask is a 0/1 float HxW array for AnnotatedImage."""
    if pred_map is None or not selected:
        return []

    # Invert the module-level id2label map (class-id string -> name)
    # into name -> integer class id.
    label2id = {name: int(cid) for cid, name in id2label.items()}

    annotations: list[tuple[np.ndarray, str]] = []
    for name in selected:
        cid = label2id.get(name)
        if cid is None:
            # Unknown label name; nothing to draw for it.
            continue
        binary = (pred_map == cid).astype(np.float32)
        # Only include classes that actually appear in the prediction.
        if binary.any():
            annotations.append((binary, name))
    return annotations
    
# ==============================
# STREAMING EVENT FUNCTIONS
# ==============================
# IMPORTANT: make the event functions themselves generators.
# Also: include the States as outputs so we can update them every frame.
def remote_start(url: str, n: int, pred_state, base_state):
    """Stream a remote RTSP/HTTP source and yield segmented frames.

    Generator event handler: each yield updates (live image, pred-map State,
    base-RGB State) in the UI.

    Args:
        url: Stream URL; an empty/falsy value aborts immediately.
        n: Process every Nth frame (frames in between are skipped).
        pred_state, base_state: Current State values supplied by Gradio's
            input wiring; unused here (the States are written via the yielded
            outputs), but kept so the registered signature stays valid.
    """
    if not url:
        return
    cap = cv2.VideoCapture(url)
    if not cap.isOpened():
        return
    idx = 0
    try:
        while True:
            frame = _safe_read(cap)
            if frame is None:
                break  # stream ended or connection dropped
            if n > 1 and (idx % n) != 0:
                idx += 1
                continue
            # OpenCV delivers BGR; convert to RGB to match upload_start so the
            # model sees a consistent channel order from both entry points.
            frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            pred_map, overlay_rgb, base_rgb = model.predict_map_and_overlay(frame)
            # yield live image + updated States' *values*
            yield overlay_rgb, pred_map, base_rgb
            idx += 1
    finally:
        cap.release()

def upload_start(video_file: str, n: int):
    """Process an uploaded video file, yielding segmented frames as they come.

    Generator event handler: each yield updates (live image, pred-map State,
    base-RGB State). Frames are decimated so only every Nth one is segmented.
    """
    if not video_file:
        return
    cap = cv2.VideoCapture(video_file)
    if not cap.isOpened():
        return
    try:
        frame_idx = 0
        while True:
            ok, bgr = cap.read()
            if not ok or bgr is None:
                break  # end of file
            # Keep every frame when n <= 1, otherwise every Nth frame.
            keep = n <= 1 or frame_idx % n == 0
            frame_idx += 1
            if not keep:
                continue
            # OpenCV reads BGR; the model pipeline expects RGB.
            rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
            pred_map, overlay_rgb, base_rgb = model.predict_map_and_overlay(rgb)
            yield overlay_rgb, pred_map, base_rgb
    finally:
        cap.release()

# ==============================
# SNAPSHOT / TOGGLES (non-streaming)
# ==============================
# NOTE: When you pass gr.State as an input, you receive the *value*, not the wrapper.
def make_snapshot(selected_labels, pred_map, base_rgb, alpha=0.25):
    """Build the (image, annotations) pair consumed by gr.AnnotatedImage.

    Receives State *values* (not wrappers) from Gradio. Returns gr.update()
    untouched when no frame has been captured yet.
    """
    if pred_map is None or base_rgb is None:
        # Nothing captured yet; leave the snapshot component as-is.
        return gr.update()
    base_img = Image.fromarray(base_rgb)
    # Recreate the overlay so the snapshot matches the live view exactly.
    overlay = create_segmentation_overlay(pred_map, id2label, label2color, base_img, alpha=alpha)
    annotations = build_annotations(pred_map, selected_labels or [])
    # (base_image, [(mask, label), ...])
    return (overlay, annotations)

# ==============================
# UI
# ==============================
# Two-tab UI. Each tab pairs a fast live stream (left) with a hover-able,
# per-class snapshot (right-side controls). Component creation order here
# determines on-screen layout, so statements must not be reordered.
with gr.Blocks(title="CoralScapes Streaming Segmentation") as demo:
    gr.Markdown("# CoralScapes Streaming Segmentation")
    gr.Markdown(
        "Left: **live stream** (fast). Right: **snapshot** with **hover labels** and **per-class toggles**."
    )

    with gr.Tab("Remote Stream (RTSP/HTTP)"):
        with gr.Row():
            with gr.Column(scale=2):
                
                # States start as None. We'll UPDATE them on every frame by returning them as outputs.
                pred_state_remote = gr.State(None)  # holds last pred_map (HxW np.uint8)
                base_state_remote = gr.State(None)  # holds last base_rgb (HxWx3 uint8)

                live_remote = gr.Image(label="Live segmented stream")

                start_btn = gr.Button("Start")

                snap_btn_remote = gr.Button("📸 Snapshot (hover-able)")
                hover_remote = gr.AnnotatedImage(label="Snapshot (hover to see label)")
                

            with gr.Column(scale=1):
                url  = gr.Textbox(label="Stream URL", placeholder="rtsp://user:pass@ip:port/…")
                skip = gr.Slider(1, 60, value=10, step=1, label="Process every Nth frame")

                toggles_remote = gr.CheckboxGroup(
                    choices=list(id2label.values()), value=list(id2label.values()),
                    label="Toggle classes in snapshot",
                )

            # NOTE(review): these handlers are registered inside the Row context
            # (after both Columns close); registration placement doesn't affect
            # layout. remote_start ignores the two State inputs — presumably
            # kept for signature symmetry with the outputs; verify.
            start_btn.click(
                    remote_start,
                    inputs=[url, skip, pred_state_remote, base_state_remote],
                    outputs=[live_remote, pred_state_remote, base_state_remote],
                    queue=True,   # be explicit; required for generator streaming
                )
            
            # Snapshot on demand, and re-render whenever the class toggles change.
            snap_btn_remote.click(
                make_snapshot,
                inputs=[toggles_remote, pred_state_remote, base_state_remote],
                outputs=[hover_remote],
            )
            toggles_remote.change(
                make_snapshot,
                inputs=[toggles_remote, pred_state_remote, base_state_remote],
                outputs=[hover_remote],
            )

    with gr.Tab("Upload Video"):
        with gr.Row():
            # Left column (now contains toggles, snapshot button, and live output)
            with gr.Column(scale=2):
                # States remain in the same column as live_upload
                pred_state_upload = gr.State(None)
                base_state_upload = gr.State(None)
                
                live_upload = gr.Image(label="Live segmented output")
                start_btn2 = gr.Button("Process")
                
                snap_btn_upload = gr.Button("📸 Snapshot (hover-able)")
                hover_upload = gr.AnnotatedImage(label="Snapshot (hover to see label)")
                
            # Right column (now contains video input and slider)
            with gr.Column(scale=1):
                vid_in = gr.Video(sources=["upload"], format="mp4", label="Input Video")
                skip2 = gr.Slider(1, 5, value=1, step=1, label="Process every Nth frame")

                toggles_upload = gr.CheckboxGroup(
                    choices=list(id2label.values()), value=list(id2label.values()),
                    label="Toggle classes in snapshot",
                )
                
        # Event handlers remain the same
        start_btn2.click(
            upload_start,
            inputs=[vid_in, skip2],
            outputs=[live_upload, pred_state_upload, base_state_upload],
            queue=True,
        )

        snap_btn_upload.click(
            make_snapshot,
            inputs=[toggles_upload, pred_state_upload, base_state_upload],
            outputs=[hover_upload],
        )
        
        toggles_upload.change(
            make_snapshot,
            inputs=[toggles_upload, pred_state_upload, base_state_upload],
            outputs=[hover_upload],
        )

# Entry point: queuing is required for the generator (streaming) event
# handlers above; share=True exposes a public Gradio link.
if __name__ == "__main__":
    demo.queue().launch(share=True)