File size: 4,537 Bytes
61b0d5a
 
 
 
 
 
16aa079
61b0d5a
16aa079
 
 
 
61b0d5a
 
 
 
 
 
 
 
 
16aa079
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
61b0d5a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
import os
import numpy as np
import torch
import gradio as gr

from mmseg.apis import init_model, inference_model
from huggingface_hub import hf_hub_download

# Source model repository on the Hub (avoid storing large files in the Space)
REPO_ID = "triton7777/SeaIce"
# mmseg config and checkpoint filenames as stored in the source repo.
CONFIG_FILENAME = "segformer_mit-b5_8xb1-160k_pre-cityscapes_seaicergb0-1024x1024.py"
CHECKPOINT_FILENAME = "iter_160000.pth"

# Lazily-initialized singletons populated by _get_model(): the mmseg model
# and its class-id -> [r, g, b] palette.
_model = None
_palette = None


def _get_model():
    """Lazily initialize and cache the segmentation model and its palette.

    Prefers config/checkpoint files in a local ``model/`` directory (for
    local development); otherwise downloads them from the Hub repository.

    Returns:
        tuple: ``(model, palette)`` — the mmseg model and a list of
        ``[r, g, b]`` triples indexed by class id.
    """
    global _model, _palette
    if _model is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
        # Prefer local files if present (for local dev), otherwise download from Hub
        local_cfg = os.path.join("model", CONFIG_FILENAME)
        local_ckpt = os.path.join("model", CHECKPOINT_FILENAME)

        if os.path.exists(local_cfg) and os.path.exists(local_ckpt):
            cfg_path = local_cfg
            ckpt_path = local_ckpt
        else:
            def _download_any(name: str):
                # The repo may keep files at the root or under "model/";
                # try both and surface the last error if neither works.
                last_err = None
                for cand in (name, os.path.join("model", name)):
                    try:
                        return hf_hub_download(repo_id=REPO_ID, filename=cand)
                    except Exception as e:  # pragma: no cover (runtime fallback)
                        last_err = e
                raise last_err

            cfg_path = _download_any(CONFIG_FILENAME)
            ckpt_path = _download_any(CHECKPOINT_FILENAME)

        _model = init_model(cfg_path, ckpt_path, device=device)
        meta = getattr(_model, "dataset_meta", None) or {}
        _palette = meta.get("palette")
        if _palette is None:
            # No palette in the checkpoint metadata: derive the class count
            # from the metadata when available, else fall back to 21, and
            # generate a deterministic random palette.
            classes = meta.get("classes") or []
            num_classes = len(classes) if classes else 21
            rng = np.random.RandomState(42)
            # Upper bound 256, not 255: np.random.randint excludes `high`,
            # so a 255 bound could never produce a full-intensity channel.
            _palette = rng.randint(0, 256, size=(num_classes, 3)).tolist()
    return _model, _palette


def colorize_mask(mask: np.ndarray, palette):
    """Map a 2-D integer class mask to an RGB color image.

    Args:
        mask: ``(H, W)`` integer array of class ids.
        palette: list of ``[r, g, b]`` triples indexed by class id.

    Returns:
        ``(H, W, 3)`` uint8 array. Class ids beyond the palette length get
        deterministic random colors (seeded, so stable across calls).
    """
    h, w = mask.shape
    color = np.zeros((h, w, 3), dtype=np.uint8)
    # Guard mask.max() against an empty mask (it raises on size-0 arrays).
    max_id = int(mask.max()) if mask.size else 0
    if max_id >= len(palette):
        # Extend the palette for unexpected class ids.
        extra = max_id + 1 - len(palette)
        rng = np.random.RandomState(123)
        # Upper bound 256, not 255: np.random.randint excludes `high`, so
        # the previous 255 bound could never emit a 255 channel value.
        palette = list(palette) + rng.randint(0, 256, size=(extra, 3)).tolist()
    for cls_id, rgb in enumerate(palette):
        color[mask == cls_id] = rgb
    return color


def overlay_seg(image_np: np.ndarray, seg_rgb: np.ndarray, alpha: float = 0.5):
    """Alpha-blend a segmentation color map onto an image.

    Args:
        image_np: input image — ``(H, W)`` grayscale, ``(H, W, 3)`` RGB, or
            ``(H, W, 4)`` RGBA; any dtype (non-uint8 is clipped to 0-255).
        seg_rgb: color-coded segmentation; nearest-neighbor resized to the
            image dimensions when shapes disagree.
        alpha: weight of the original image; ``1 - alpha`` is the weight of
            the segmentation colors.

    Returns:
        ``(H, W, 3)`` uint8 blended image.
    """
    base = image_np
    if base.dtype != np.uint8:
        base = np.clip(base, 0, 255).astype(np.uint8)
    if base.ndim == 2:
        # Promote grayscale to 3 channels.
        base = np.stack([base, base, base], axis=-1)
    if base.shape[2] == 4:
        # Drop the alpha channel of RGBA input.
        base = base[:, :, :3]
    seg = seg_rgb.astype(np.uint8)
    if seg.shape[:2] != base.shape[:2]:
        # Lazy import: cv2 is only needed on this resize path.
        import cv2
        seg = cv2.resize(seg, (base.shape[1], base.shape[0]), interpolation=cv2.INTER_NEAREST)
    blended = alpha * base + (1 - alpha) * seg
    return blended.astype(np.uint8)


def predict(image: np.ndarray):
    """Segment ``image`` and return it blended 50/50 with the class colors.

    Args:
        image: input image as a numpy array (as delivered by gr.Image).

    Returns:
        ``(H, W, 3)`` uint8 overlay of the image and its segmentation.
    """
    model, palette = _get_model()
    result = inference_model(model, image)
    seg = result.pred_sem_seg.data
    # pred_sem_seg.data may carry a leading singleton channel dim
    # (1, H, W); squeeze it before converting to a numpy mask.
    if hasattr(seg, "shape") and seg.dim() == 3 and seg.shape[0] == 1:
        seg = seg[0]
    mask = seg.to("cpu").numpy().astype(np.int32)
    seg_rgb = colorize_mask(mask, palette)
    return overlay_seg(image, seg_rgb, alpha=0.5)


# Build the Gradio UI. Runs at import time; `demo` is the app object that
# Spaces (or the __main__ guard below) launches.
with gr.Blocks() as demo:
    gr.Markdown(
        "# Sea Ice Segmentation (SegFormer-B5)\nUpload an RGB image; the app returns the image overlayed with the segmentation."
    )
    with gr.Row():
        inp = gr.Image(type="numpy", label="Input image")
        out = gr.Image(type="numpy", label="Segmentation overlay")
    with gr.Row():
        # NOTE(review): in overlay_seg, alpha is the weight of the ORIGINAL
        # image, so 1.0 hides the segmentation entirely — the "Overlay alpha"
        # label may suggest the opposite direction; confirm intended meaning.
        alpha = gr.Slider(0.0, 1.0, value=0.5, step=0.05, label="Overlay alpha")
    btn = gr.Button("Segment")

    def _predict_with_alpha(img, a):
        # Same pipeline as predict() above, but threading the user-chosen
        # alpha through to overlay_seg. NOTE(review): consider extracting the
        # shared inference/mask logic into one helper to avoid duplication.
        model, palette = _get_model()
        result = inference_model(model, img)
        seg = result.pred_sem_seg.data
        # Squeeze a leading singleton channel dim (1, H, W) when present.
        if hasattr(seg, "shape") and seg.dim() == 3 and seg.shape[0] == 1:
            mask = seg[0].to("cpu").numpy().astype(np.int32)
        else:
            mask = seg.to("cpu").numpy().astype(np.int32)
        seg_rgb = colorize_mask(mask, palette)
        return overlay_seg(img, seg_rgb, alpha=a)

    btn.click(fn=_predict_with_alpha, inputs=[inp, alpha], outputs=out)

    # No bundled example images yet; the empty list keeps the component slot.
    gr.Examples(examples=[], inputs=inp)

if __name__ == "__main__":
    demo.launch()