LhatMjnk committed
Commit 059e297 · verified · 1 Parent(s): d1130ec

Upload folder using huggingface_hub

.gradio/certificate.pem ADDED
@@ -0,0 +1,31 @@
+ -----BEGIN CERTIFICATE-----
+ MIIFazCCA1OgAwIBAgIRAIIQz7DSQONZRGPgu2OCiwAwDQYJKoZIhvcNAQELBQAw
+ TzELMAkGA1UEBhMCVVMxKTAnBgNVBAoTIEludGVybmV0IFNlY3VyaXR5IFJlc2Vh
+ cmNoIEdyb3VwMRUwEwYDVQQDEwxJU1JHIFJvb3QgWDEwHhcNMTUwNjA0MTEwNDM4
+ WhcNMzUwNjA0MTEwNDM4WjBPMQswCQYDVQQGEwJVUzEpMCcGA1UEChMgSW50ZXJu
+ ZXQgU2VjdXJpdHkgUmVzZWFyY2ggR3JvdXAxFTATBgNVBAMTDElTUkcgUm9vdCBY
+ MTCCAiIwDQYJKoZIhvcNAQEBBQADggIPADCCAgoCggIBAK3oJHP0FDfzm54rVygc
+ h77ct984kIxuPOZXoHj3dcKi/vVqbvYATyjb3miGbESTtrFj/RQSa78f0uoxmyF+
+ 0TM8ukj13Xnfs7j/EvEhmkvBioZxaUpmZmyPfjxwv60pIgbz5MDmgK7iS4+3mX6U
+ A5/TR5d8mUgjU+g4rk8Kb4Mu0UlXjIB0ttov0DiNewNwIRt18jA8+o+u3dpjq+sW
+ T8KOEUt+zwvo/7V3LvSye0rgTBIlDHCNAymg4VMk7BPZ7hm/ELNKjD+Jo2FR3qyH
+ B5T0Y3HsLuJvW5iB4YlcNHlsdu87kGJ55tukmi8mxdAQ4Q7e2RCOFvu396j3x+UC
+ B5iPNgiV5+I3lg02dZ77DnKxHZu8A/lJBdiB3QW0KtZB6awBdpUKD9jf1b0SHzUv
+ KBds0pjBqAlkd25HN7rOrFleaJ1/ctaJxQZBKT5ZPt0m9STJEadao0xAH0ahmbWn
+ OlFuhjuefXKnEgV4We0+UXgVCwOPjdAvBbI+e0ocS3MFEvzG6uBQE3xDk3SzynTn
+ jh8BCNAw1FtxNrQHusEwMFxIt4I7mKZ9YIqioymCzLq9gwQbooMDQaHWBfEbwrbw
+ qHyGO0aoSCqI3Haadr8faqU9GY/rOPNk3sgrDQoo//fb4hVC1CLQJ13hef4Y53CI
+ rU7m2Ys6xt0nUW7/vGT1M0NPAgMBAAGjQjBAMA4GA1UdDwEB/wQEAwIBBjAPBgNV
+ HRMBAf8EBTADAQH/MB0GA1UdDgQWBBR5tFnme7bl5AFzgAiIyBpY9umbbjANBgkq
+ hkiG9w0BAQsFAAOCAgEAVR9YqbyyqFDQDLHYGmkgJykIrGF1XIpu+ILlaS/V9lZL
+ ubhzEFnTIZd+50xx+7LSYK05qAvqFyFWhfFQDlnrzuBZ6brJFe+GnY+EgPbk6ZGQ
+ 3BebYhtF8GaV0nxvwuo77x/Py9auJ/GpsMiu/X1+mvoiBOv/2X/qkSsisRcOj/KK
+ NFtY2PwByVS5uCbMiogziUwthDyC3+6WVwW6LLv3xLfHTjuCvjHIInNzktHCgKQ5
+ ORAzI4JMPJ+GslWYHb4phowim57iaztXOoJwTdwJx4nLCgdNbOhdjsnvzqvHu7Ur
+ TkXWStAmzOVyyghqpZXjFaH3pO3JLF+l+/+sKAIuvtd7u+Nxe5AW0wdeRlN8NwdC
+ jNPElpzVmbUq4JUagEiuTDkHzsxHpFKVK7q4+63SM1N95R1NbdWhscdCb+ZAJzVc
+ oyi3B43njTOQ5yOf+1CceWxG1bQVs5ZufpsMljq4Ui0/1lvh+wjChP4kqKOJ2qxq
+ 4RgqsahDYVvTH9w7jXbyLeiNdd8XM2w9U/t7y0Ff/9yi0GE44Za4rF2LN9d11TPA
+ mRGunUHBcnWEvgJBQl9nJEiU0Zsnvgc/ubhPgXRR4Xq37Z0j4r7g1SgEEzwxA57d
+ emyPxgcYxn/eR44/KJ4EBs+lVDR3veyJm+kXQ99b21/+jh5Xos1AnX5iItreGCc=
+ -----END CERTIFICATE-----
README.md CHANGED
@@ -1,12 +1,6 @@
  ---
- title: Coral Streaming
- emoji: 💻
- colorFrom: yellow
- colorTo: yellow
+ title: Coral_Streaming
+ app_file: app.py
  sdk: gradio
  sdk_version: 5.49.0
- app_file: app.py
- pinned: false
  ---
-
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
__pycache__/inference.cpython-311.pyc ADDED
Binary file (9.77 kB).
 
app.py ADDED
@@ -0,0 +1,191 @@
+ from PIL import Image
+ import cv2
+ import numpy as np
+ import gradio as gr
+
+ from inference import CoralSegModel, id2label, label2color, create_segmentation_overlay
+ model = CoralSegModel()
+
+ # ---- helpers ----
+ def _safe_read(cap):
+     ok, frame = cap.read()
+     return frame if ok and frame is not None else None
+
+ def build_annotations(pred_map: np.ndarray, selected: list[str]) -> list[tuple[np.ndarray, str]]:
+     """Return [(mask, label), ...] where mask is a 0/1 float HxW array for AnnotatedImage."""
+     if pred_map is None or not selected:
+         return []
+
+     # Create reverse mapping: label_name -> class_id
+     label2id = {label: int(id_str) for id_str, label in id2label.items()}
+
+     anns = []
+     for label_name in selected:
+         if label_name not in label2id:
+             continue  # Skip unknown labels
+
+         class_id = label2id[label_name]  # Convert label name to class ID
+         mask = (pred_map == class_id).astype(np.float32)
+         if mask.sum() > 0:
+             anns.append((mask, label_name))  # Use the label name for display
+     return anns
+
+ # ==============================
+ # STREAMING EVENT FUNCTIONS
+ # ==============================
+ # IMPORTANT: make the event functions themselves generators.
+ # Also: include the States as outputs so we can update them every frame.
+ def remote_start(url: str, n: int, pred_state, base_state):
+     if not url:
+         return
+     cap = cv2.VideoCapture(url)
+     if not cap.isOpened():
+         return
+     idx = 0
+     try:
+         while True:
+             frame = _safe_read(cap)
+             if frame is None:
+                 break
+             if n > 1 and (idx % n) != 0:
+                 idx += 1
+                 continue
+             pred_map, overlay_rgb, base_rgb = model.predict_map_and_overlay(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))  # OpenCV reads BGR; the model expects RGB
+             # yield live image + updated States' *values*
+             yield overlay_rgb, pred_map, base_rgb
+             idx += 1
+     finally:
+         cap.release()
+
+ def upload_start(video_file: str, n: int):
+     if not video_file:
+         return
+     cap = cv2.VideoCapture(video_file)
+     if not cap.isOpened():
+         return
+     idx = 0
+     try:
+         while True:
+             ok, frame = cap.read()
+             if not ok or frame is None:
+                 break
+             if n > 1 and (idx % n) != 0:
+                 idx += 1
+                 continue
+             frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
+             pred_map, overlay_rgb, base_rgb = model.predict_map_and_overlay(frame)
+             yield overlay_rgb, pred_map, base_rgb
+             idx += 1
+     finally:
+         cap.release()
+
+ # ==============================
+ # SNAPSHOT / TOGGLES (non-streaming)
+ # ==============================
+ # NOTE: When you pass gr.State as an input, you receive the *value*, not the wrapper.
+ def make_snapshot(selected_labels, pred_map, base_rgb, alpha=0.25):
+     if pred_map is None or base_rgb is None:
+         return gr.update()
+     # rebuild overlay to match the live look
+     overlay = create_segmentation_overlay(pred_map, id2label, label2color, Image.fromarray(base_rgb), alpha=alpha)
+     ann = build_annotations(pred_map, selected_labels or [])
+     return (overlay, ann)  # (base_image, [(mask, label), ...])
+
+ # ==============================
+ # UI
+ # ==============================
+ with gr.Blocks(title="CoralScapes Streaming Segmentation") as demo:
+     gr.Markdown("# CoralScapes Streaming Segmentation")
+     gr.Markdown(
+         "Left: **live stream** (fast). Right: **snapshot** with **hover labels** and **per-class toggles**."
+     )
+
+     with gr.Tab("Remote Stream (RTSP/HTTP)"):
+         with gr.Row():
+             with gr.Column(scale=2):
+
+                 # States start as None. We'll UPDATE them on every frame by returning them as outputs.
+                 pred_state_remote = gr.State(None)  # holds last pred_map (HxW np.uint8)
+                 base_state_remote = gr.State(None)  # holds last base_rgb (HxWx3 uint8)
+
+                 live_remote = gr.Image(label="Live segmented stream")
+
+                 start_btn = gr.Button("Start")
+
+                 snap_btn_remote = gr.Button("📸 Snapshot (hover-able)")
+                 hover_remote = gr.AnnotatedImage(label="Snapshot (hover to see label)")
+
+
+             with gr.Column(scale=1):
+                 url = gr.Textbox(label="Stream URL", placeholder="rtsp://user:pass@ip:port/…")
+                 skip = gr.Slider(1, 60, value=10, step=1, label="Process every Nth frame")
+
+                 toggles_remote = gr.CheckboxGroup(
+                     choices=list(id2label.values()), value=list(id2label.values()),
+                     label="Toggle classes in snapshot",
+                 )
+
+         start_btn.click(
+             remote_start,
+             inputs=[url, skip, pred_state_remote, base_state_remote],
+             outputs=[live_remote, pred_state_remote, base_state_remote],
+             queue=True,  # be explicit; required for generator streaming
+         )
+
+         snap_btn_remote.click(
+             make_snapshot,
+             inputs=[toggles_remote, pred_state_remote, base_state_remote],
+             outputs=[hover_remote],
+         )
+         toggles_remote.change(
+             make_snapshot,
+             inputs=[toggles_remote, pred_state_remote, base_state_remote],
+             outputs=[hover_remote],
+         )
+
+     with gr.Tab("Upload Video"):
+         with gr.Row():
+             # Left column (toggles, snapshot button, and live output)
+             with gr.Column(scale=2):
+                 # States live in the same column as live_upload
+                 pred_state_upload = gr.State(None)
+                 base_state_upload = gr.State(None)
+
+                 live_upload = gr.Image(label="Live segmented output")
+                 start_btn2 = gr.Button("Process")
+
+                 snap_btn_upload = gr.Button("📸 Snapshot (hover-able)")
+                 hover_upload = gr.AnnotatedImage(label="Snapshot (hover to see label)")
+
+             # Right column (video input and slider)
+             with gr.Column(scale=1):
+                 vid_in = gr.Video(sources=["upload"], format="mp4", label="Input Video")
+                 skip2 = gr.Slider(1, 5, value=1, step=1, label="Process every Nth frame")
+
+                 toggles_upload = gr.CheckboxGroup(
+                     choices=list(id2label.values()), value=list(id2label.values()),
+                     label="Toggle classes in snapshot",
+                 )
+
+         # Event wiring mirrors the remote tab
+         start_btn2.click(
+             upload_start,
+             inputs=[vid_in, skip2],
+             outputs=[live_upload, pred_state_upload, base_state_upload],
+             queue=True,
+         )
+
+         snap_btn_upload.click(
+             make_snapshot,
+             inputs=[toggles_upload, pred_state_upload, base_state_upload],
+             outputs=[hover_upload],
+         )
+
+         toggles_upload.change(
+             make_snapshot,
+             inputs=[toggles_upload, pred_state_upload, base_state_upload],
+             outputs=[hover_upload],
+         )
+
+ if __name__ == "__main__":
+     demo.queue().launch(share=True)
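
The two `IMPORTANT` comments above flag the load-bearing pattern in this file: the event handlers are generators, and the `gr.State` components are listed in `outputs`, so every `yield` refreshes both the live image and the stored prediction that the snapshot buttons later read. Below is a minimal, self-contained sketch of that pattern with a synthetic counter standing in for video inference; the names `tick`, `snapshot`, `last_state`, and `counter_out` are illustrative, not taken from the app.

```python
# Minimal sketch of the generator-streaming + State-update pattern above.
# All names here are illustrative only.
import time
import gradio as gr

def tick(n_frames: int, last_value):
    # Generator handler: each `yield` pushes new values to all outputs,
    # including the State, which is re-emitted on every frame.
    for i in range(int(n_frames)):
        time.sleep(0.1)            # stand-in for per-frame inference
        yield f"frame {i}", i      # (live display value, new State value)

def snapshot(last_value):
    # Receives the State's *value* as updated by the most recent yield.
    return f"last streamed frame index: {last_value}"

with gr.Blocks() as demo:
    last_state = gr.State(None)    # holds the most recent frame index
    counter_out = gr.Textbox(label="Live output")
    snap_out = gr.Textbox(label="Snapshot")
    n = gr.Slider(1, 50, value=10, step=1, label="Frames")
    gr.Button("Start").click(tick, [n, last_state], [counter_out, last_state])
    gr.Button("Snapshot").click(snapshot, [last_state], [snap_out])

if __name__ == "__main__":
    demo.queue().launch()
```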
gradio_demo.py ADDED
@@ -0,0 +1,16 @@
+ import spaces
+ import gradio as gr
+ from diffusers import DiffusionPipeline
+
+ pipe = DiffusionPipeline.from_pretrained(...)
+ pipe.to('cuda')
+
+ @spaces.GPU
+ def generate(prompt):
+     return pipe(prompt).images
+
+ gr.Interface(
+     fn=generate,
+     inputs=gr.Text(),
+     outputs=gr.Gallery(),
+ ).launch()
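
This file appears to be the stock ZeroGPU template, unrelated to the coral app above, and it leaves the model id elided as `...`, so it is not runnable as committed. A filled-in sketch might look like the following; the model id is an illustrative assumption, not part of the commit.

```python
# Hypothetical filled-in version of the ZeroGPU template above.
# The model id below is an assumption for illustration only.
import spaces
import gradio as gr
from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained("stable-diffusion-v1-5/stable-diffusion-v1-5")
pipe.to("cuda")

@spaces.GPU  # on ZeroGPU Spaces, a GPU is attached only for the duration of this call
def generate(prompt: str):
    return pipe(prompt).images

gr.Interface(fn=generate, inputs=gr.Text(), outputs=gr.Gallery()).launch()
```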
inference.py ADDED
@@ -0,0 +1,168 @@
+ # inference.py
+ import torch
+ import torch.nn.functional as F
+
+ import json
+ import urllib.request
+ import cv2
+ import numpy as np
+ from PIL import Image
+ from transformers import SegformerImageProcessorFast, SegformerForSemanticSegmentation
+
+ id2label = json.load(urllib.request.urlopen(
+     "https://huggingface.co/datasets/EPFL-ECEO/coralscapes/resolve/main/id2label.json"))
+ label2color = json.load(urllib.request.urlopen(
+     "https://huggingface.co/datasets/EPFL-ECEO/coralscapes/resolve/main/label2color.json"))
+
+ # Load model from HF (swap this with your own if you want)
+ HF_MODEL_ID = "EPFL-ECEO/segformer-b5-finetuned-coralscapes-1024-1024"
+
+ def create_segmentation_overlay(pred, id2label, label2color, image, alpha=0.25):
+     """
+     Colorizes the segmentation prediction and creates an overlay image.
+
+     Args:
+         pred: The segmentation prediction (numpy array).
+         id2label: Dictionary mapping class IDs to labels.
+         label2color: Dictionary mapping labels to colors.
+         image: The original PIL Image.
+
+     Returns:
+         A PIL Image overlaying the colorized segmentation mask on the original image.
+     """
+     H, W = pred.shape
+     rgb = np.zeros((H, W, 3), dtype=np.uint8)
+
+     # Get unique class IDs present in the prediction
+     unique_classes = np.unique(pred)
+
+     # Create a mapping from class ID to color
+     id2color = {int(id): label2color[label] for id, label in id2label.items()}
+
+     # Define a default color for unknown classes (e.g., black)
+     default_color = [0, 0, 0]
+
+     # Iterate through unique class IDs and colorize the image
+     for class_id in unique_classes:
+         # Get the color for the current class ID, use default_color if not found
+         rgb_c = id2color.get(int(class_id), default_color)
+         # Assign the color to the pixels with the current class ID
+         rgb[pred == class_id] = rgb_c
+
+     mask_rgb = Image.fromarray(rgb)
+
+     # Alpha overlay
+     overlay = Image.blend(image.convert("RGBA"), mask_rgb.convert("RGBA"), alpha=alpha)
+
+     return overlay
+
+ def resize_image(image, target_size=1024):
+     """
+     Resize the image so that its smaller side equals target_size.
+     """
+     w_img, h_img = image.size  # PIL's size is (width, height)
+     if w_img < h_img:
+         new_w, new_h = target_size, int(h_img * (target_size / w_img))
+     else:
+         new_w, new_h = int(w_img * (target_size / h_img)), target_size
+     resized_img = image.resize((new_w, new_h))
+     return resized_img
+
+ class CoralSegModel:
+     def __init__(self, device=None):
+         self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
+
+         self.processor = SegformerImageProcessorFast.from_pretrained(HF_MODEL_ID)
+
+         self.model = SegformerForSemanticSegmentation.from_pretrained(
+             HF_MODEL_ID,
+             dtype=torch.bfloat16
+         ).to(self.device)
+
+         self.model.eval()
+
+     @torch.inference_mode()
+     def segment_image(self, image, preprocessor, model, crop_size=(1024, 1024), num_classes=40, batch_size=4) -> np.ndarray:
+         """
+         Batched sliding-window inference for improved GPU utilization.
+         """
+         h_crop, w_crop = crop_size
+
+         img = torch.Tensor(np.array(resize_image(image, target_size=1024)).transpose(2, 0, 1)).unsqueeze(0)
+         img = img.to(self.device, torch.bfloat16)
+         _, _, h_img, w_img = img.size()
+
+         h_grids = int(np.round(3 / 2 * h_img / h_crop)) if h_img > h_crop else 1
+         w_grids = int(np.round(3 / 2 * w_img / w_crop)) if w_img > w_crop else 1
+
+         h_stride = int((h_img - h_crop + h_grids - 1) / (h_grids - 1)) if h_grids > 1 else h_crop
+         w_stride = int((w_img - w_crop + w_grids - 1) / (w_grids - 1)) if w_grids > 1 else w_crop
+
+         preds = img.new_zeros((1, num_classes, h_img, w_img))
+         count_mat = img.new_zeros((1, 1, h_img, w_img))
+
+         # Collect all crops and their coordinates
+         crops = []
+         coords = []
+         for h_idx in range(h_grids):
+             for w_idx in range(w_grids):
+                 y1 = h_idx * h_stride
+                 x1 = w_idx * w_stride
+                 y2 = min(y1 + h_crop, h_img)
+                 x2 = min(x1 + w_crop, w_img)
+                 y1 = max(y2 - h_crop, 0)
+                 x1 = max(x2 - w_crop, 0)
+
+                 crop_img = img[:, :, y1:y2, x1:x2]
+                 crops.append(crop_img)
+                 coords.append((x1, x2, y1, y2))
+
+         # Process crops in batches
+         for i in range(0, len(crops), batch_size):
+             batch_crops = crops[i:i + batch_size]
+             batch_coords = coords[i:i + batch_size]
+
+             # Stack crops into a batch
+             batch_tensor = torch.cat(batch_crops, dim=0)
+
+             if preprocessor:
+                 inputs = preprocessor(batch_tensor, return_tensors="pt", device=self.device)
+                 inputs["pixel_values"] = inputs["pixel_values"].to(self.device, torch.bfloat16)
+             else:
+                 inputs = {"pixel_values": batch_tensor}
+
+             outputs = model(**inputs)
+
+             # Process each output in the batch
+             for j, (x1, x2, y1, y2) in enumerate(batch_coords):
+                 resized_logits = F.interpolate(
+                     outputs.logits[j].unsqueeze(dim=0),
+                     size=(y2 - y1, x2 - x1),
+                     mode="bilinear",
+                     align_corners=False
+                 )
+                 preds[:, :, y1:y2, x1:x2] += resized_logits
+                 count_mat[:, :, y1:y2, x1:x2] += 1
+
+         assert (count_mat == 0).sum() == 0
+         preds = preds / count_mat
+         preds = preds.argmax(dim=1)
+         preds = F.interpolate(preds.unsqueeze(0).type(torch.uint8), size=image.size[::-1], mode="nearest")
+         label_pred = preds.squeeze().cpu().numpy()
+         return label_pred
+
+     @torch.inference_mode()
+     def predict_map_and_overlay(self, frame_rgb: np.ndarray):
+         """
+         Returns:
+             pred_map: HxW (uint8/int) with class indices in [0..C-1]
+             overlay: PIL RGBA Image (color mask blended over the original)
+             rgb: HxWx3 RGB uint8 original frame (for AnnotatedImage base)
+         """
+         rgb = frame_rgb  # callers pass RGB frames (see app.py)
+
+         pil = Image.fromarray(rgb)
+         pred = self.segment_image(pil, self.processor, self.model)
+         overlay_rgb = create_segmentation_overlay(pred, id2label, label2color, pil, 0.45)
+
+         return pred, overlay_rgb, rgb
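
The core of `segment_image` is the overlap bookkeeping: each window's logits are accumulated into `preds` while `count_mat` records how many windows covered each pixel, so the final division averages predictions wherever windows overlap. A toy NumPy sketch of that accumulate-then-normalize idea follows; the sizes are arbitrary stand-ins (the real code uses 1024×1024 crops and model logits).

```python
# Toy illustration of the accumulate/normalize scheme in segment_image.
import numpy as np

num_classes, h_img, w_img = 3, 8, 8   # tiny stand-ins for real sizes
h_crop = w_crop = 6
stride = 2

preds = np.zeros((num_classes, h_img, w_img), dtype=np.float32)
count = np.zeros((1, h_img, w_img), dtype=np.float32)

for y1 in range(0, h_img - h_crop + 1, stride):
    for x1 in range(0, w_img - w_crop + 1, stride):
        y2, x2 = y1 + h_crop, x1 + w_crop
        logits = np.random.rand(num_classes, h_crop, w_crop)  # stand-in for model output
        preds[:, y1:y2, x1:x2] += logits  # accumulate overlapping logits
        count[:, y1:y2, x1:x2] += 1       # how many windows saw each pixel

assert (count == 0).sum() == 0            # every pixel covered at least once
label_map = (preds / count).argmax(axis=0)  # average overlaps, then pick class
print(label_map.shape)                    # (8, 8)
```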