File size: 10,094 Bytes
dbced4f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1457065
dbced4f
 
1457065
 
 
 
 
 
 
dbced4f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
"""
inference.py β€” Video Segmentation Inference Engine
Extracted from U-Net + DeepLabV3 notebook.

Loads DeepLabV3-ResNet50 once at startup and exposes:
  - segment_frame(frame_bgr) -> (seg_rgb, blend_bgr, detected_classes)
  - process_video(input_path, output_path, progress_cb) -> None
"""

import cv2
import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image
from torchvision.models.segmentation import deeplabv3_resnet50, DeepLabV3_ResNet50_Weights
import warnings
import logging
import os
import subprocess
import tempfile

# Install a blanket "ignore" filter at import time: no category or module is
# given, so every Python warning raised in this process is suppressed.
# NOTE(review): this also hides warnings from unrelated code — confirm intended.
warnings.filterwarnings("ignore")
# Module-level logger named after this module.
logger = logging.getLogger(__name__)

# ─── PASCAL VOC 21 Classes ───────────────────────────────────────────────────

# One (class name, RGB colour) entry per PASCAL VOC class. The position of an
# entry in this list is the class ID used to index both tables built below.
# Colours are vibrant and chosen to be visually distinct.
_VOC_LEGEND = [
    ("background",   (0, 0, 0)),        # 0
    ("aeroplane",    (135, 206, 235)),  # 1
    ("bicycle",      (255, 165, 0)),    # 2
    ("bird",         (255, 215, 0)),    # 3
    ("boat",         (0, 191, 255)),    # 4
    ("bottle",       (148, 0, 211)),    # 5
    ("bus",          (255, 20, 147)),   # 6
    ("car",          (220, 20, 60)),    # 7
    ("cat",          (255, 140, 0)),    # 8
    ("chair",        (139, 69, 19)),    # 9
    ("cow",          (255, 255, 0)),    # 10
    ("diningtable",  (210, 105, 30)),   # 11
    ("dog",          (186, 85, 211)),   # 12
    ("horse",        (255, 105, 180)),  # 13
    ("motorbike",    (0, 255, 127)),    # 14
    ("person",       (255, 69, 0)),     # 15
    ("potted plant", (34, 139, 34)),    # 16
    ("sheep",        (240, 230, 140)),  # 17
    ("sofa",         (0, 206, 209)),    # 18
    ("train",        (0, 0, 255)),      # 19
    ("tv/monitor",   (127, 255, 212)),  # 20
]

# Class names, indexed by class ID.
VOC_CLASSES = [name for name, _rgb in _VOC_LEGEND]

# (21, 3) uint8 lookup table: row i holds the RGB colour of class i.
PALETTE = np.array([rgb for _name, rgb in _VOC_LEGEND], dtype=np.uint8)

# ─── Model Singleton ─────────────────────────────────────────────────────────

# Cache slots filled in by get_model(); both stay None until the first load.
_model = None
_preprocess = None
# Torch device string: "cuda" when a CUDA device is available, otherwise "cpu".
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"


def get_ffmpeg() -> str:
    """Locate an ffmpeg executable and return its path.

    A system-wide ``ffmpeg`` found on PATH is preferred. When there is none,
    the binary bundled with the optional ``imageio-ffmpeg`` package is used.

    Raises:
        RuntimeError: if neither a system ffmpeg nor ``imageio-ffmpeg`` is
            available.
    """
    import shutil

    on_path = shutil.which("ffmpeg")
    if not on_path:
        # No system binary: fall back to the one shipped with imageio-ffmpeg.
        try:
            import imageio_ffmpeg
            return imageio_ffmpeg.get_ffmpeg_exe()
        except ImportError:
            raise RuntimeError(
                "ffmpeg not found. Install it: brew install ffmpeg  "
                "or: pip install imageio-ffmpeg"
            )
    return on_path


def get_model():
    """Return the cached ``(model, preprocess)`` pair, loading it on first use.

    The first call builds DeepLabV3-ResNet50 with its default pretrained
    weights, moves it to ``DEVICE``, switches it to eval mode and stores it
    (together with the transforms that belong to those weights) in the
    module-level cache. Later calls return the cached objects unchanged.
    """
    global _model, _preprocess

    # Fast path: the model has already been loaded.
    if _model is not None:
        return _model, _preprocess

    logger.info(f"Loading DeepLabV3-ResNet50 on {DEVICE}...")
    pretrained = DeepLabV3_ResNet50_Weights.DEFAULT
    _model = deeplabv3_resnet50(weights=pretrained).to(DEVICE)
    _model.eval()
    # Use the preprocessing pipeline shipped with the chosen weights.
    _preprocess = pretrained.transforms()
    logger.info("Model loaded successfully.")

    return _model, _preprocess


# ─── Core Inference Helpers ──────────────────────────────────────────────────

def decode_segmap(seg_mask: np.ndarray) -> np.ndarray:
    """Colourise a class-index map using ``PALETTE``.

    Each entry of ``seg_mask`` is replaced by the RGB row of ``PALETTE`` for
    that class, so an (H, W) index map becomes an (H, W, 3) colour image.
    Indices are wrapped modulo the palette size before the lookup.
    """
    palette_size = len(PALETTE)
    wrapped_ids = seg_mask % palette_size
    return PALETTE[wrapped_ids]


def segment_frame(frame_bgr: np.ndarray, alpha: float = 0.55):
    """
    Run semantic segmentation on one BGR frame.

    The frame is converted to RGB, passed through the cached model's
    preprocessing and the model itself, and the per-pixel argmax is resized
    back to the frame's own size with nearest-neighbour interpolation.

    Args:
        frame_bgr: input image in OpenCV BGR channel order
        alpha: weight of the colour mask in the blended output; the original
            frame gets weight ``1 - alpha``

    Returns:
        seg_rgb   : colour mask in RGB order, (H,W,3) uint8
        blend_bgr : frame blended with the mask, BGR order, (H,W,3) uint8
        detected  : set of class IDs present in the prediction, with 0 removed
    """
    net, transform = get_model()
    height, width = frame_bgr.shape[:2]

    # The preprocessing pipeline is fed a PIL image in RGB order.
    rgb_image = Image.fromarray(cv2.cvtColor(frame_bgr, cv2.COLOR_BGR2RGB))
    batch = transform(rgb_image).unsqueeze(0).to(DEVICE)

    with torch.no_grad():
        logits = net(batch)["out"]
        class_map = logits.argmax(dim=1).squeeze().cpu().numpy()

    # Nearest-neighbour keeps class IDs intact while restoring the frame size.
    class_map = cv2.resize(
        class_map.astype(np.uint8),
        (width, height),
        interpolation=cv2.INTER_NEAREST,
    )

    detected = set(np.unique(class_map).tolist())
    detected.discard(0)  # class 0 is not reported

    seg_rgb = decode_segmap(class_map)
    mask_bgr = cv2.cvtColor(seg_rgb, cv2.COLOR_RGB2BGR)
    blend_bgr = cv2.addWeighted(frame_bgr, 1 - alpha, mask_bgr, alpha, 0)

    return seg_rgb, blend_bgr, detected


def make_legend_bar(class_ids: set, bar_w: int, bar_h: int = 40) -> np.ndarray:
    """Draw a horizontal legend strip with one labelled swatch per class.

    The strip is ``bar_h`` x ``bar_w`` pixels and starts out black. The class
    IDs are sorted and each one gets an equally wide swatch filled with its
    ``PALETTE`` colour and labelled with its ``VOC_CLASSES`` name (or the
    numeric ID when the ID is beyond the end of that list). An empty
    ``class_ids`` yields the plain black strip.
    """
    strip = np.zeros((bar_h, bar_w, 3), dtype=np.uint8)
    ordered_ids = sorted(class_ids)
    if not ordered_ids:
        return strip

    swatch_w = bar_w // len(ordered_ids)
    text_y = bar_h - 8  # baseline of the label text
    left = 0
    for class_id in ordered_ids:
        right = min(left + swatch_w, bar_w)
        strip[:, left:right] = PALETTE[class_id % len(PALETTE)].tolist()

        if class_id < len(VOC_CLASSES):
            caption = VOC_CLASSES[class_id]
        else:
            caption = str(class_id)
        cv2.putText(
            strip,
            caption,
            (left + 3, text_y),
            cv2.FONT_HERSHEY_SIMPLEX,
            0.38,
            (255, 255, 255),
            1,
            cv2.LINE_AA,
        )
        left += swatch_w

    return strip


# ─── Video Processing ────────────────────────────────────────────────────────

def _reencode_h264(raw_path: str, final_path: str, fps: float):
    """
    Transcode the intermediate OpenCV-written video into an H.264 MP4.

    The result is meant for playback in a browser <video> element, which
    needs H.264. ffmpeg is run as a subprocess with its output captured.

    Args:
        raw_path: path of the intermediate video to read
        final_path: path of the H.264 MP4 to write (overwritten if present)
        fps: frame rate applied to the input

    Raises:
        RuntimeError: if ffmpeg exits with a non-zero status (the tail of its
            stderr is included in the message), or if no ffmpeg is available.
    """
    ffmpeg_bin = get_ffmpeg()

    source_args = [
        "-r", str(fps),          # frame rate of the input
        "-i", raw_path,
    ]
    encoder_args = [
        "-vcodec", "libx264",
        "-pix_fmt", "yuv420p",   # required for QuickTime / Safari
        "-preset", "medium",
        "-crf", "23",            # quality
        "-profile:v", "high",    # H.264 "high" profile
        "-level", "4.0",
        "-movflags", "+faststart",
        "-an",                   # drop audio
    ]
    # "-y" overwrites an existing output file without prompting.
    cmd = [ffmpeg_bin, "-y", *source_args, *encoder_args, final_path]

    logger.info(f"Re-encoding to H.264: {' '.join(cmd)}")
    completed = subprocess.run(cmd, capture_output=True, text=True)

    if completed.returncode == 0:
        logger.info("H.264 re-encoding complete.")
        return

    logger.error(f"ffmpeg error: {completed.stderr[-500:]}")
    raise RuntimeError(f"ffmpeg re-encoding failed: {completed.stderr[-300:]}")


def process_video(
    input_path: str,
    output_path: str,
    progress_callback=None,
    alpha: float = 0.55,
    max_dim: int = 640,
):
    """
    Process a video file frame-by-frame and write browser-compatible H.264 MP4.

    Every output frame shows the (possibly downscaled) input frame on the
    left, the segmentation overlay on the right, and underneath a legend
    strip listing all classes detected so far. Frames are first written to a
    temporary mp4v file with OpenCV and then re-encoded to H.264 with ffmpeg.
    The capture, the writer and the temporary file are always cleaned up,
    including when an error occurs part-way through.

    Args:
        input_path: path to input video
        output_path: path to write final H.264 MP4 (browser-playable)
        progress_callback: callable(pct: float, detected_names: list) or None;
            only invoked when the input reports a positive frame count
        alpha: blend alpha for overlay (0=original, 1=mask)
        max_dim: frames are downscaled (never upscaled) so that neither side
            exceeds this many pixels before inference

    Returns:
        Set of class IDs detected anywhere in the video (class 0 excluded).

    Raises:
        ValueError: if the input cannot be opened, or its frame size is
            unknown or too small to produce an output frame.
        RuntimeError: if the temporary writer cannot be opened or the ffmpeg
            re-encode fails.
    """
    legend_h = 44  # height in pixels of the legend strip under the frames

    cap = cv2.VideoCapture(input_path)
    writer = None
    tmp_path = None
    all_detected = set()

    try:
        try:
            if not cap.isOpened():
                raise ValueError(f"Cannot open video: {input_path}")

            total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
            fps = cap.get(cv2.CAP_PROP_FPS) or 25.0
            orig_w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
            orig_h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
            if orig_w <= 0 or orig_h <= 0:
                # Would otherwise surface as a bare ZeroDivisionError below.
                raise ValueError(
                    f"Cannot determine frame size of video: {input_path}"
                )

            # Resize to max_dim on longest edge (keeps aspect ratio)
            scale = min(max_dim / orig_w, max_dim / orig_h, 1.0)
            out_w = int(orig_w * scale)
            out_h = int(orig_h * scale)

            # H.264 requires even dimensions: round odd sizes down.
            out_w -= out_w % 2
            out_h -= out_h % 2
            if out_w <= 0 or out_h <= 0:
                raise ValueError(
                    f"Output frame size {out_w}x{out_h} is not usable "
                    f"(input {orig_w}x{orig_h}, max_dim={max_dim})"
                )

            # Two frames side by side plus the legend; both values stay even
            # because out_w, out_h and legend_h are even.
            combined_w = out_w * 2
            combined_h = out_h + legend_h

            # Write raw frames to a temp file first (mp4v is fastest for
            # write), then re-encode to H.264 for browser compatibility.
            tmp_fd, tmp_path = tempfile.mkstemp(suffix="_raw.mp4")
            os.close(tmp_fd)

            fourcc = cv2.VideoWriter_fourcc(*"mp4v")
            writer = cv2.VideoWriter(
                tmp_path, fourcc, fps, (combined_w, combined_h)
            )
            if not writer.isOpened():
                raise RuntimeError(f"Failed to open VideoWriter for {tmp_path}")

            logger.info(
                f"Processing {total_frames} frames @ {fps:.1f} fps — "
                f"output {combined_w}x{combined_h}"
            )

            frame_idx = 0
            while True:
                ret, frame = cap.read()
                if not ret:
                    break

                # Bring the frame to the inference/output size.
                if frame.shape[1] != out_w or frame.shape[0] != out_h:
                    frame = cv2.resize(
                        frame, (out_w, out_h), interpolation=cv2.INTER_AREA
                    )

                _seg_rgb, blend_bgr, detected = segment_frame(frame, alpha=alpha)
                all_detected.update(detected)

                # Legend bar (colour + label per class); it is drawn with RGB
                # palette colours, so convert to BGR before writing.
                legend = make_legend_bar(all_detected, combined_w, bar_h=legend_h)
                legend_bgr = cv2.cvtColor(legend, cv2.COLOR_RGB2BGR)

                # Side-by-side: original left | segmented overlay right
                side_by_side = np.hstack([frame, blend_bgr])
                writer.write(np.vstack([side_by_side, legend_bgr]))
                frame_idx += 1

                if progress_callback and total_frames > 0:
                    # The reported frame count can be lower than the number
                    # of frames actually decoded, so cap progress at 100%.
                    pct = min(100.0, round(frame_idx / total_frames * 100, 1))
                    detected_names = [
                        VOC_CLASSES[c]
                        for c in sorted(all_detected)
                        if c < len(VOC_CLASSES)
                    ]
                    progress_callback(pct, detected_names)
        finally:
            # Release both handles on success and on failure; the writer must
            # be released before ffmpeg reads the temp file.
            cap.release()
            if writer is not None:
                writer.release()

        logger.info(f"Raw frames written to temp: {tmp_path}")

        # Re-encode raw mp4v to H.264 for browser playback
        _reencode_h264(tmp_path, output_path, fps)
    finally:
        # Always clean up the temp file, whichever stage failed.
        if tmp_path is not None:
            try:
                os.unlink(tmp_path)
            except OSError:
                pass

    logger.info(f"Final H.264 output: {output_path}")
    return all_detected