|
|
import os |
|
|
import numpy as np |
|
|
import torch |
|
|
import gradio as gr |
|
|
|
|
|
from mmseg.apis import init_model, inference_model |
|
|
from huggingface_hub import hf_hub_download |
|
|
|
|
|
|
|
|
# Hugging Face Hub repository that hosts the model files.
REPO_ID: str = "triton7777/SeaIce"

# File names of the model config and the trained checkpoint; they are looked
# up under the local "model" directory first and fetched from REPO_ID
# otherwise.
CONFIG_FILENAME: str = "segformer_mit-b5_8xb1-160k_pre-cityscapes_seaicergb0-1024x1024.py"
CHECKPOINT_FILENAME: str = "iter_160000.pth"

# Process-wide cache filled on first use by _get_model(): the initialised
# segmentor and the per-class colour palette used for rendering.
_model = None
_palette = None
|
|
|
|
|
|
|
|
def _get_model():
    """Return the cached ``(model, palette)`` pair, building it on first use.

    The config and checkpoint are taken from the local ``model`` directory
    when both files exist there; otherwise each file is downloaded from the
    ``REPO_ID`` Hub repository, trying the repository root first and the
    ``model/`` sub-folder second.

    The palette comes from the model's ``dataset_meta["palette"]``. When that
    entry is missing, a deterministic pseudo-random palette is generated with
    one colour per entry of ``dataset_meta["classes"]`` (21 colours when no
    classes are listed).

    Returns:
        tuple: ``(model, palette)`` where ``palette`` is a sequence of RGB
        triplets.

    Raises:
        Exception: whatever the last ``hf_hub_download`` attempt raised when
        a file cannot be fetched from any candidate location, or whatever
        ``init_model`` raises.
    """
    global _model, _palette

    if _model is not None:
        return _model, _palette

    device = "cuda" if torch.cuda.is_available() else "cpu"

    local_cfg = os.path.join("model", CONFIG_FILENAME)
    local_ckpt = os.path.join("model", CHECKPOINT_FILENAME)

    if os.path.exists(local_cfg) and os.path.exists(local_ckpt):
        cfg_path, ckpt_path = local_cfg, local_ckpt
    else:

        def _download_any(name: str) -> str:
            """Download ``name`` from the Hub, trying each candidate location."""
            # Hub file names are repo-relative paths that always use "/",
            # so the sub-folder candidate must not be built with os.path.join
            # (which would produce "model\\name" on Windows).
            candidates = [name, f"model/{name}"]
            last_err = None
            for cand in candidates:
                try:
                    return hf_hub_download(repo_id=REPO_ID, filename=cand)
                except Exception as e:  # best effort: fall through to the next candidate
                    last_err = e
            raise last_err

        cfg_path = _download_any(CONFIG_FILENAME)
        ckpt_path = _download_any(CHECKPOINT_FILENAME)

    model = init_model(cfg_path, ckpt_path, device=device)

    meta = getattr(model, "dataset_meta", None) or {}
    palette = meta.get("palette", None)
    if palette is None:
        classes = meta.get("classes") or ()
        num_classes = len(classes) if classes else 21
        # Fixed seed keeps the fallback colours stable between runs.
        rng = np.random.RandomState(42)
        palette = rng.randint(0, 255, size=(num_classes, 3)).tolist()

    # Publish both globals together, only after everything succeeded, so a
    # failure above can never leave a cached model paired with a None palette.
    _model, _palette = model, palette
    return _model, _palette
|
|
|
|
|
|
|
|
def colorize_mask(mask: np.ndarray, palette) -> np.ndarray:
    """Turn a 2-D class-id mask into an ``(H, W, 3)`` uint8 colour image.

    Args:
        mask: 2-D array of class ids. It is expected to hold integer values;
            it is cast to an integer index type for the lookup.
        palette: sequence of RGB triplets (list, tuple or array); entry ``i``
            is the colour of class ``i``. It is not modified.

    Returns:
        ``(H, W, 3)`` ``uint8`` array. Pixels with a negative class id stay
        black. When the mask contains ids beyond the end of ``palette``, the
        missing colours are generated from a fixed-seed random generator, so
        the result is deterministic.
    """
    h, w = mask.shape

    # Normalise to an (N, 3) table so tuple / ndarray palettes work too and
    # the caller's palette object is never touched.
    table = np.asarray(list(palette), dtype=np.int64).reshape(-1, 3)

    max_id = int(mask.max()) if mask.size else 0
    if max_id >= len(table):
        extra = max_id + 1 - len(table)
        rng = np.random.RandomState(123)
        table = np.concatenate([table, rng.randint(0, 255, size=(extra, 3))])
    table = table.astype(np.uint8)

    # One lookup-table indexing pass instead of one full-image comparison per
    # palette entry.
    idx = mask.astype(np.intp)
    valid = idx >= 0
    color = np.zeros((h, w, 3), dtype=np.uint8)
    color[valid] = table[idx[valid]]
    return color
|
|
|
|
|
|
|
|
def overlay_seg(image_np: np.ndarray, seg_rgb: np.ndarray, alpha: float = 0.5) -> np.ndarray:
    """Blend a colourised segmentation map with an image.

    The output is ``alpha * image + (1 - alpha) * seg_rgb``; ``alpha`` is
    therefore the weight of the *image*: ``1.0`` shows only the image and
    ``0.0`` only the segmentation colours.

    Args:
        image_np: input image. Non-``uint8`` data is clipped to ``[0, 255]``
            and cast; a 2-D (grayscale) image is replicated to three
            channels; the fourth channel of a 4-channel image is dropped.
        seg_rgb: colour segmentation map. It is resized with nearest-neighbour
            interpolation when its height/width differ from the image's.
        alpha: blend weight of the image.

    Returns:
        ``uint8`` blended image. Values are saturated to ``[0, 255]``, so an
        ``alpha`` outside ``[0, 1]`` cannot wrap around.
    """
    img = image_np
    if img.dtype != np.uint8:
        img = np.clip(img, 0, 255).astype(np.uint8)
    if img.ndim == 2:
        img = np.stack([img] * 3, axis=-1)
    if img.shape[2] == 4:
        img = img[:, :, :3]

    seg_rgb = seg_rgb.astype(np.uint8)
    if seg_rgb.shape[:2] != img.shape[:2]:
        # Imported lazily: OpenCV is only needed when the sizes differ.
        import cv2

        seg_rgb = cv2.resize(seg_rgb, (img.shape[1], img.shape[0]), interpolation=cv2.INTER_NEAREST)

    # Force float arithmetic so an integer alpha cannot make the products
    # overflow inside uint8.
    alpha = float(alpha)
    blended = alpha * img + (1.0 - alpha) * seg_rgb
    # Saturate before the cast; a bare astype(uint8) wraps out-of-range values.
    return np.clip(blended, 0, 255).astype(np.uint8)
|
|
|
|
|
|
|
|
def predict(image: np.ndarray, alpha: float = 0.5) -> np.ndarray:
    """Segment ``image`` and return it blended with the colourised result.

    Args:
        image: input image as a NumPy array.
        alpha: blend weight forwarded to ``overlay_seg`` (weight of the
            input image in the blend).

    Returns:
        The blended image produced by ``overlay_seg``.

    Raises:
        ValueError: if ``image`` is ``None``.
    """
    if image is None:
        raise ValueError("predict() requires an image, got None")

    model, palette = _get_model()
    # NOTE(review): mmseg's ndarray input path conventionally assumes BGR
    # channel order, whereas Gradio hands over RGB - confirm against the
    # model config whether a channel flip is needed here.
    result = inference_model(model, image)

    seg = result.pred_sem_seg.data
    # Drop the leading singleton dimension of a (1, H, W) prediction.
    if seg.dim() == 3 and seg.shape[0] == 1:
        seg = seg[0]
    mask = seg.to("cpu").numpy().astype(np.int32)

    seg_rgb = colorize_mask(mask, palette)
    return overlay_seg(image, seg_rgb, alpha=alpha)


with gr.Blocks() as demo:
    gr.Markdown(
        "# Sea Ice Segmentation (SegFormer-B5)\nUpload an RGB image; the app returns the image overlayed with the segmentation."
    )
    with gr.Row():
        inp = gr.Image(type="numpy", label="Input image")
        out = gr.Image(type="numpy", label="Segmentation overlay")
    with gr.Row():
        alpha = gr.Slider(0.0, 1.0, value=0.5, step=0.05, label="Overlay alpha")
        btn = gr.Button("Segment")

    def _predict_with_alpha(img, a):
        """Button callback: run ``predict`` with the slider's alpha value."""
        # Clicking "Segment" before uploading passes None; show a readable
        # message in the UI instead of failing inside the inference call.
        if img is None:
            raise gr.Error("Please upload an image before clicking Segment.")
        return predict(img, alpha=a)

    btn.click(fn=_predict_with_alpha, inputs=[inp, alpha], outputs=out)

    gr.Examples(examples=[], inputs=inp)
|
|
|
|
|
# Start the Gradio server only when this file is executed as a script, so
# importing the module (e.g. to reuse `demo`) does not launch it.
if __name__ == "__main__":
    demo.launch()
|
|
|