Spaces:
Runtime error
Runtime error
| import os | |
| import numpy as np | |
| import torch | |
| import gradio as gr | |
| from mmseg.apis import init_model, inference_model | |
| from huggingface_hub import hf_hub_download | |
| # Source model repository on the Hub (avoid storing large files in the Space) | |
| REPO_ID = "triton7777/SeaIce" | |
| CONFIG_FILENAME = "segformer_mit-b5_8xb1-160k_pre-cityscapes_seaicergb0-1024x1024.py" | |
| CHECKPOINT_FILENAME = "iter_160000.pth" | |
| _model = None | |
| _palette = None | |
def _get_model():
    """Return the cached ``(model, palette)`` pair, loading it on first use.

    On the first call the config and checkpoint are taken from the local
    ``model/`` directory when both files exist there; otherwise they are
    downloaded from ``REPO_ID`` on the Hugging Face Hub (trying the repo root
    first, then a ``model/`` sub-folder). The model is initialised on CUDA
    when available, else on CPU.

    The palette comes from ``model.dataset_meta["palette"]``. When that is
    missing, a reproducible random palette is generated with one colour per
    entry in ``dataset_meta["classes"]`` (21 colours if no classes are
    listed).

    Returns:
        Tuple ``(model, palette)`` where ``palette`` is a sequence of RGB
        triplets indexed by class id.

    Raises:
        Exception: whatever ``hf_hub_download`` raised for the last candidate
            path when neither Hub location could be fetched, or whatever
            ``init_model`` raises.
    """
    global _model, _palette
    if _model is not None:
        return _model, _palette

    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Prefer local files if present (for local dev), otherwise download.
    local_cfg = os.path.join("model", CONFIG_FILENAME)
    local_ckpt = os.path.join("model", CHECKPOINT_FILENAME)
    if os.path.exists(local_cfg) and os.path.exists(local_ckpt):
        cfg_path, ckpt_path = local_cfg, local_ckpt
    else:

        def _download_any(name: str):
            """Fetch *name* from the repo root or from its ``model/`` folder."""
            last_err = None
            # Hub file paths always use "/" regardless of the host OS, so the
            # sub-folder candidate is built explicitly rather than with
            # os.path.join (which would yield a backslash on Windows).
            for candidate in (name, f"model/{name}"):
                try:
                    return hf_hub_download(repo_id=REPO_ID, filename=candidate)
                except Exception as err:  # pragma: no cover (runtime fallback)
                    last_err = err
            raise last_err

        cfg_path = _download_any(CONFIG_FILENAME)
        ckpt_path = _download_any(CHECKPOINT_FILENAME)

    model = init_model(cfg_path, ckpt_path, device=device)

    meta = getattr(model, "dataset_meta", None) or {}
    palette = meta.get("palette")
    if palette is None:
        classes = meta.get("classes")
        num_classes = len(classes) if classes else 21
        rng = np.random.RandomState(42)  # fixed seed: stable colours across restarts
        palette = rng.randint(0, 255, size=(num_classes, 3)).tolist()

    # Publish both globals only once everything succeeded, so a failure while
    # building the palette cannot leave a cached model paired with a None
    # palette for subsequent calls.
    _palette = palette
    _model = model
    return _model, _palette
def colorize_mask(mask: np.ndarray, palette):
    """Convert an HxW integer label map into an HxWx3 uint8 colour image.

    Args:
        mask: 2-D array of integer class ids.
        palette: Sequence of RGB triplets indexed by class id. Lists, tuples
            and arrays are accepted; the caller's object is never modified.

    Returns:
        ``np.ndarray`` of shape ``(H, W, 3)`` and dtype ``uint8``. Pixels with
        a negative id stay black. Ids beyond the end of ``palette`` receive
        reproducible pseudo-random colours.
    """
    mask = np.asarray(mask)
    h, w = mask.shape

    # Work on a plain list-of-lists copy so tuple / ndarray palettes can be
    # extended below (tuple + list raises, ndarray + list broadcasts).
    colors = [list(rgb) for rgb in palette]

    max_id = int(mask.max()) if mask.size else 0
    if max_id >= len(colors):
        extra = max_id + 1 - len(colors)
        rng = np.random.RandomState(123)  # fixed seed keeps extra colours stable
        colors += rng.randint(0, 255, size=(extra, 3)).tolist()

    # One lookup-table gather instead of a full-mask comparison per class.
    lut = np.asarray(colors, dtype=np.uint8).reshape(-1, 3)
    color = np.zeros((h, w, 3), dtype=np.uint8)
    labels = mask.astype(np.intp, copy=False)
    valid = labels >= 0  # negative ids match no class and are left black
    color[valid] = lut[labels[valid]]
    return color
def overlay_seg(image_np: np.ndarray, seg_rgb: np.ndarray, alpha: float = 0.5):
    """Blend an image with a colourised segmentation map.

    Args:
        image_np: Source image, either 2-D grayscale, HxWx3, or HxWx4 (the
            fourth channel is dropped). Non-uint8 input is clipped to
            [0, 255] and cast to uint8 without rescaling.
        seg_rgb: Colour map to overlay. It is cast to uint8 and, when its
            height/width differ from the image, resized with
            nearest-neighbour interpolation.
        alpha: Weight given to the *image*; the segmentation gets
            ``1 - alpha``. Values outside [0, 1] are clamped so the blend
            always stays inside the uint8 range.

    Returns:
        uint8 array with the blended result (fractions are truncated).
    """
    img = image_np
    if img.dtype != np.uint8:
        img = np.clip(img, 0, 255).astype(np.uint8)
    if img.ndim == 2:
        # Replicate the single channel so the blend below is 3-channel.
        img = np.stack([img] * 3, axis=-1)
    elif img.shape[2] == 4:
        img = img[:, :, :3]

    seg_rgb = seg_rgb.astype(np.uint8)
    if seg_rgb.shape[:2] != img.shape[:2]:
        import cv2  # only needed on a size mismatch

        # Nearest-neighbour keeps class colours crisp instead of mixing them.
        seg_rgb = cv2.resize(
            seg_rgb, (img.shape[1], img.shape[0]), interpolation=cv2.INTER_NEAREST
        )

    # An out-of-range weight would push the float blend outside 0..255 and
    # the uint8 cast would then wrap around instead of saturating.
    alpha = min(max(float(alpha), 0.0), 1.0)
    blended = (alpha * img + (1 - alpha) * seg_rgb).astype(np.uint8)
    return blended
def predict(image: np.ndarray, alpha: float = 0.5):
    """Segment *image* and return it blended with the colourised prediction.

    Args:
        image: Input image as a NumPy array.
        alpha: Blend weight forwarded to ``overlay_seg`` (weight of the
            original image). Defaults to 0.5.

    Returns:
        The blended overlay produced by ``overlay_seg``.

    Raises:
        ValueError: if ``image`` is ``None``. This is checked before the
            model is loaded, so an empty request fails fast.
    """
    if image is None:
        raise ValueError("predict() requires an image, got None")

    model, palette = _get_model()
    # NOTE(review): Gradio hands over RGB arrays, while mmseg pipelines
    # commonly assume BGR ndarray input -- confirm the channel order against
    # the config's data_preprocessor settings.
    result = inference_model(model, image)

    seg = result.pred_sem_seg.data
    # The prediction may carry a leading singleton dimension (1xHxW); drop it
    # so colorize_mask receives a 2-D label map.
    if hasattr(seg, "shape") and seg.dim() == 3 and seg.shape[0] == 1:
        seg = seg[0]
    mask = seg.to("cpu").numpy().astype(np.int32)

    seg_rgb = colorize_mask(mask, palette)
    return overlay_seg(image, seg_rgb, alpha=alpha)
# Gradio UI: input and output images side by side, then a row holding the
# blend slider and the trigger button.
with gr.Blocks() as demo:
    gr.Markdown(
        "# Sea Ice Segmentation (SegFormer-B5)\nUpload an RGB image; the app returns the image overlayed with the segmentation."
    )
    with gr.Row():
        inp = gr.Image(type="numpy", label="Input image")
        out = gr.Image(type="numpy", label="Segmentation overlay")
    with gr.Row():
        alpha = gr.Slider(0.0, 1.0, value=0.5, step=0.05, label="Overlay alpha")
        btn = gr.Button("Segment")

    def _predict_with_alpha(img, a):
        """Click handler: segment *img* and blend it using slider value *a*.

        *a* is passed to ``overlay_seg`` as ``alpha``, i.e. as the weight of
        the original image in the blend.
        """
        if img is None:
            # Clicking "Segment" with an empty input would otherwise pass
            # None into inference and fail with an opaque error.
            raise gr.Error("Please upload an image before running segmentation.")

        model, palette = _get_model()
        result = inference_model(model, img)
        seg = result.pred_sem_seg.data
        # Drop a leading singleton dimension (1xHxW) when present.
        if hasattr(seg, "shape") and seg.dim() == 3 and seg.shape[0] == 1:
            seg = seg[0]
        mask = seg.to("cpu").numpy().astype(np.int32)
        seg_rgb = colorize_mask(mask, palette)
        return overlay_seg(img, seg_rgb, alpha=a)

    btn.click(fn=_predict_with_alpha, inputs=[inp, alpha], outputs=out)
    gr.Examples(examples=[], inputs=inp)

if __name__ == "__main__":
    demo.launch()