"""Gradio demo for sea-ice semantic segmentation with a SegFormer-B5 model.

The model config and checkpoint are read from a local ``model/`` directory
when present (local development), otherwise downloaded from the Hugging Face
Hub on first use. The app returns the input image alpha-blended with the
colorized segmentation mask.
"""

import os

import gradio as gr
import numpy as np
import torch
from huggingface_hub import hf_hub_download
from mmseg.apis import inference_model, init_model

# Source model repository on the Hub (avoid storing large files in the Space).
REPO_ID = "triton7777/SeaIce"
CONFIG_FILENAME = "segformer_mit-b5_8xb1-160k_pre-cityscapes_seaicergb0-1024x1024.py"
CHECKPOINT_FILENAME = "iter_160000.pth"

# Lazily-initialized singletons, populated by _get_model() on first call.
_model = None
_palette = None


def _get_model():
    """Return the ``(model, palette)`` pair, loading the model on first call.

    Prefers files under a local ``model/`` directory; otherwise downloads the
    config and checkpoint from the Hub, trying both the bare filename and a
    ``model/``-prefixed path inside the repo. The palette comes from the
    model's dataset metadata, with a deterministic random fallback.
    """
    global _model, _palette
    if _model is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
        local_cfg = os.path.join("model", CONFIG_FILENAME)
        local_ckpt = os.path.join("model", CHECKPOINT_FILENAME)
        if os.path.exists(local_cfg) and os.path.exists(local_ckpt):
            cfg_path = local_cfg
            ckpt_path = local_ckpt
        else:
            def _download_any(name: str) -> str:
                """Download *name*, trying the repo root then its model/ dir."""
                candidates = [name, os.path.join("model", name)]
                last_err = None
                for cand in candidates:
                    try:
                        return hf_hub_download(repo_id=REPO_ID, filename=cand)
                    except Exception as e:  # pragma: no cover (runtime fallback)
                        last_err = e
                raise last_err

            cfg_path = _download_any(CONFIG_FILENAME)
            ckpt_path = _download_any(CHECKPOINT_FILENAME)
        _model = init_model(cfg_path, ckpt_path, device=device)
        meta = getattr(_model, "dataset_meta", None) or {}
        _palette = meta.get("palette", None)
        if _palette is None:
            # Fall back to a reproducible random palette. Use the class count
            # from the metadata when available, else 21 (VOC-style default).
            classes = meta.get("classes") or []
            num_classes = len(classes) if classes else 21
            rng = np.random.RandomState(42)
            # high=256: np.random.randint's upper bound is exclusive, so 255
            # would otherwise never appear as a channel value.
            _palette = rng.randint(0, 256, size=(num_classes, 3)).tolist()
    return _model, _palette


def colorize_mask(mask: np.ndarray, palette) -> np.ndarray:
    """Map an HxW integer class mask to an HxWx3 uint8 RGB image.

    Class ids beyond the palette length get deterministic random colors so an
    unexpectedly large id never raises.
    """
    h, w = mask.shape
    color = np.zeros((h, w, 3), dtype=np.uint8)
    max_id = int(mask.max()) if mask.size else 0
    if max_id >= len(palette):
        extra = max_id + 1 - len(palette)
        rng = np.random.RandomState(123)
        # high=256 for the full 0..255 range (randint's upper bound is exclusive).
        palette = list(palette) + rng.randint(0, 256, size=(extra, 3)).tolist()
    for cls_id, rgb in enumerate(palette):
        color[mask == cls_id] = rgb
    return color


def overlay_seg(image_np: np.ndarray, seg_rgb: np.ndarray, alpha: float = 0.5) -> np.ndarray:
    """Alpha-blend *seg_rgb* over *image_np*; return an HxWx3 uint8 array.

    Handles grayscale (HxW) and RGBA inputs, and nearest-neighbor-resizes the
    segmentation when its spatial size differs from the image's.
    """
    img = image_np
    if img.dtype != np.uint8:
        img = np.clip(img, 0, 255).astype(np.uint8)
    if img.ndim == 2:
        # Grayscale -> 3-channel so the blend below broadcasts correctly.
        img = np.stack([img] * 3, axis=-1)
    if img.shape[2] == 4:
        # Drop the alpha channel of RGBA inputs.
        img = img[:, :, :3]
    seg_rgb = seg_rgb.astype(np.uint8)
    if seg_rgb.shape[:2] != img.shape[:2]:
        import cv2  # local import: only needed on the resize path

        seg_rgb = cv2.resize(
            seg_rgb, (img.shape[1], img.shape[0]), interpolation=cv2.INTER_NEAREST
        )
    return (alpha * img + (1 - alpha) * seg_rgb).astype(np.uint8)


def _mask_from_result(result) -> np.ndarray:
    """Extract an HxW int32 class mask from an mmseg inference result."""
    seg = result.pred_sem_seg.data
    # mmseg returns a (1, H, W) tensor; squeeze the leading channel if present.
    if hasattr(seg, "shape") and seg.dim() == 3 and seg.shape[0] == 1:
        seg = seg[0]
    return seg.to("cpu").numpy().astype(np.int32)


def predict(image: np.ndarray, alpha: float = 0.5) -> np.ndarray:
    """Segment *image* and return it blended with the colorized mask.

    ``alpha`` is the blend weight of the original image (default 0.5,
    matching the previous fixed-alpha behavior).
    """
    model, palette = _get_model()
    mask = _mask_from_result(inference_model(model, image))
    seg_rgb = colorize_mask(mask, palette)
    return overlay_seg(image, seg_rgb, alpha=alpha)


with gr.Blocks() as demo:
    gr.Markdown(
        "# Sea Ice Segmentation (SegFormer-B5)\nUpload an RGB image; the app returns the image overlayed with the segmentation."
    )
    with gr.Row():
        inp = gr.Image(type="numpy", label="Input image")
        out = gr.Image(type="numpy", label="Segmentation overlay")
    with gr.Row():
        alpha = gr.Slider(0.0, 1.0, value=0.5, step=0.05, label="Overlay alpha")
        btn = gr.Button("Segment")

    def _predict_with_alpha(img, a):
        """UI callback: segment *img* with the user-chosen overlay alpha *a*."""
        return predict(img, alpha=a)

    btn.click(fn=_predict_with_alpha, inputs=[inp, alpha], outputs=out)
    gr.Examples(examples=[], inputs=inp)

if __name__ == "__main__":
    demo.launch()