# SeaIce / app.py
# Author: Ajinkya
# Change: use hf_hub to download model at runtime; remove local checkpoint from Space app
# Commit: 16aa079
import os
import numpy as np
import torch
import gradio as gr
from mmseg.apis import init_model, inference_model
from huggingface_hub import hf_hub_download
# Source model repository on the Hub (avoid storing large files in the Space)
REPO_ID = "triton7777/SeaIce"
# MMSegmentation config file for the SegFormer MiT-B5 sea-ice model
CONFIG_FILENAME = "segformer_mit-b5_8xb1-160k_pre-cityscapes_seaicergb0-1024x1024.py"
# Trained checkpoint fetched from the Hub (or read from ./model for local dev)
CHECKPOINT_FILENAME = "iter_160000.pth"
# Lazily-initialized singletons; populated on first call to _get_model()
_model = None
_palette = None
def _get_model():
    """Lazily initialize and cache the segmentation model and its palette.

    On first call: resolves config/checkpoint paths (local ``model/`` dir for
    dev, otherwise downloaded from the Hub), builds the mmseg model on GPU if
    available, and derives a per-class RGB palette from the model's
    ``dataset_meta`` (falling back to a deterministic random palette).

    Returns:
        tuple: ``(model, palette)`` — the cached mmseg model and a list of
        ``[r, g, b]`` rows indexed by class id.
    """
    global _model, _palette
    if _model is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
        # Prefer local files if present (for local dev), otherwise download from Hub
        local_cfg = os.path.join("model", CONFIG_FILENAME)
        local_ckpt = os.path.join("model", CHECKPOINT_FILENAME)
        if os.path.exists(local_cfg) and os.path.exists(local_ckpt):
            cfg_path = local_cfg
            ckpt_path = local_ckpt
        else:

            def _download_any(name: str) -> str:
                # The repo may keep files at the root or under "model/"; try both.
                # Hub filenames always use forward slashes, so build the nested
                # candidate with "/" rather than os.path.join (which would emit
                # a backslash on Windows and never match a repo file).
                candidates = [name, f"model/{name}"]
                last_err = None
                for cand in candidates:
                    try:
                        return hf_hub_download(repo_id=REPO_ID, filename=cand)
                    except Exception as e:  # pragma: no cover (runtime fallback)
                        last_err = e
                raise last_err

            cfg_path = _download_any(CONFIG_FILENAME)
            ckpt_path = _download_any(CHECKPOINT_FILENAME)
        _model = init_model(cfg_path, ckpt_path, device=device)
        meta = getattr(_model, "dataset_meta", None) or {}
        _palette = meta.get("palette", None)
        if _palette is None:
            # No palette in the metadata: synthesize a deterministic random one.
            # Explicit length check replaces the fragile `x and y or z` idiom;
            # 21 classes is the fallback when the class list is also missing.
            classes = meta.get("classes") or []
            num_classes = len(classes) or 21
            rng = np.random.RandomState(42)
            _palette = rng.randint(0, 255, size=(num_classes, 3)).tolist()
    return _model, _palette
def colorize_mask(mask: np.ndarray, palette):
    """Convert an HxW integer class-id mask into an HxWx3 uint8 RGB image.

    Each pixel receives the palette row for its class id. If the mask
    contains ids beyond the palette, the palette is extended with
    deterministic random colors (seed 123) so repeated calls agree.
    """
    height, width = mask.shape
    colored = np.zeros((height, width, 3), dtype=np.uint8)
    # Largest class id present; an empty mask needs no palette extension.
    top_id = int(mask.max()) if mask.size else 0
    if top_id >= len(palette):
        missing = top_id + 1 - len(palette)
        gen = np.random.RandomState(123)
        # Build a new list rather than mutating the caller's palette in place.
        palette = palette + gen.randint(0, 255, size=(missing, 3)).tolist()
    for class_id, rgb in enumerate(palette):
        colored[mask == class_id] = rgb
    return colored
def overlay_seg(image_np: np.ndarray, seg_rgb: np.ndarray, alpha: float = 0.5):
    """Alpha-blend a segmentation color map onto an image.

    Args:
        image_np: HxW (grayscale), HxWx3 (RGB) or HxWx4 (RGBA) image; any
            dtype (non-uint8 values are clipped to [0, 255] and cast).
        seg_rgb: color-coded segmentation; resized (nearest-neighbor) to the
            image size if the spatial shapes differ.
        alpha: weight of the *image* in the blend — ``alpha * img +
            (1 - alpha) * seg`` — so alpha=1.0 shows only the input image.

    Returns:
        HxWx3 uint8 blended image.
    """
    img = image_np
    if img.dtype != np.uint8:
        img = np.clip(img, 0, 255).astype(np.uint8)
    if img.ndim == 2:
        # Grayscale -> RGB by channel replication.
        img = np.stack([img] * 3, axis=-1)
    if img.shape[2] == 4:
        # Drop the alpha channel of RGBA inputs.
        img = img[:, :, :3]
    seg_rgb = seg_rgb.astype(np.uint8)
    if seg_rgb.shape[:2] != img.shape[:2]:
        # Pure-numpy nearest-neighbor resize (floor index mapping) instead of
        # importing cv2 just for this rare fallback path; numpy is already a
        # hard dependency of the app.
        dst_h, dst_w = img.shape[:2]
        src_h, src_w = seg_rgb.shape[:2]
        rows = np.arange(dst_h) * src_h // dst_h
        cols = np.arange(dst_w) * src_w // dst_w
        seg_rgb = seg_rgb[rows[:, None], cols]
    blended = (alpha * img + (1 - alpha) * seg_rgb).astype(np.uint8)
    return blended
def predict(image: np.ndarray, alpha: float = 0.5):
    """Segment *image* and return it overlaid with the colorized prediction.

    Args:
        image: input image as a numpy array (RGB expected by the model).
        alpha: overlay blend weight forwarded to ``overlay_seg``; defaults to
            0.5, matching the previous hard-coded value, so existing callers
            are unaffected. Exposing it removes the need for the UI handler
            to re-implement this pipeline just to vary the blend.

    Returns:
        HxWx3 uint8 overlay image.
    """
    model, palette = _get_model()
    result = inference_model(model, image)
    seg = result.pred_sem_seg.data
    # pred_sem_seg is typically a (1, H, W) tensor: squeeze the leading
    # channel dim when present, then move to CPU for numpy conversion.
    if hasattr(seg, "shape") and seg.dim() == 3 and seg.shape[0] == 1:
        seg = seg[0]
    mask = seg.to("cpu").numpy().astype(np.int32)
    seg_rgb = colorize_mask(mask, palette)
    return overlay_seg(image, seg_rgb, alpha=alpha)
# --- Gradio UI -------------------------------------------------------------
with gr.Blocks() as demo:
    gr.Markdown(
        "# Sea Ice Segmentation (SegFormer-B5)\nUpload an RGB image; the app returns the image overlayed with the segmentation."
    )
    # Input image on the left, overlay result on the right.
    with gr.Row():
        image_in = gr.Image(type="numpy", label="Input image")
        image_out = gr.Image(type="numpy", label="Segmentation overlay")
    # Blend control plus the trigger button.
    with gr.Row():
        alpha_slider = gr.Slider(0.0, 1.0, value=0.5, step=0.05, label="Overlay alpha")
        run_btn = gr.Button("Segment")

    def _run_segmentation(img, blend):
        # Full inference pipeline: model -> class mask -> colors -> overlay.
        model, palette = _get_model()
        seg = inference_model(model, img).pred_sem_seg.data
        # Squeeze a leading (1, H, W) channel dim when present.
        if hasattr(seg, "shape") and seg.dim() == 3 and seg.shape[0] == 1:
            seg = seg[0]
        mask = seg.to("cpu").numpy().astype(np.int32)
        return overlay_seg(img, colorize_mask(mask, palette), alpha=blend)

    run_btn.click(fn=_run_segmentation, inputs=[image_in, alpha_slider], outputs=image_out)
    gr.Examples(examples=[], inputs=image_in)

if __name__ == "__main__":
    demo.launch()