| import gradio as gr |
| import numpy as np |
| from PIL import Image |
|
|
| import torch |
|
|
| |
| from huggingface_hub import hf_hub_download |
|
|
| |
| from fastai.vision.all import ItemTransform, PILImage, PILMask |
|
|
| try: |
| import albumentations as A |
| except Exception: |
| |
| A = None |
|
|
| |
| |
| |
# Hugging Face Hub repository that hosts the exported Learner (``model.pkl``).
REPO_ID = "magomerob/resnet-segmentation"
# Local image file offered as a clickable example in the UI.
EXAMPLE_IMAGE = "bicho.jpg"

# Inference device: CUDA when PyTorch reports it as available, otherwise the CPU.
_device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
|
|
|
|
class SegmentationAlbumentationsTransform(ItemTransform):
    """Joint image/mask augmentation driven by an albumentations-style callable.

    A Learner exported from a notebook that used this transform pickles a
    reference to the class (under ``__main__``), so a class with exactly this
    name must be importable before ``model.pkl`` is unpickled.
    """

    # fastai reads split_idx to decide which data split a transform applies to;
    # 0 presumably limits it to the training split — confirm against fastai docs.
    split_idx = 0

    def __init__(self, aug):
        # ``aug`` is invoked as aug(image=..., mask=...) and must return a mapping
        # with "image" and "mask" entries (see encodes below).
        self.aug = aug

    def encodes(self, x):
        """Augment an (image, mask) pair together and rebuild fastai PIL types."""
        image, target = x
        result = self.aug(image=np.array(image), mask=np.array(target))
        new_image = PILImage.create(result["image"])
        new_mask = PILMask.create(result["mask"])
        return new_image, new_mask
|
|
|
|
# Process-wide cache for the loaded fastai Learner; stays None until the first
# successful _load_learner_from_hub() call fills it.
_LEARN = None


def _load_learner_from_hub():
    """Download ``model.pkl`` from the Hub and load it as a fastai Learner.

    The result is cached in the module-level ``_LEARN``; once a load succeeds,
    later calls return the cached Learner without downloading again.

    Returns:
        The Learner, with its model and DataLoaders on ``_device`` and the
        model switched to eval mode.

    Raises:
        RuntimeError: if ``model.pkl`` cannot be unpickled, usually because a
            custom class used at export time (e.g. an ItemTransform) or a
            dependency such as albumentations is missing. The original
            exception is chained as ``__cause__``.
        Exception: whatever ``hf_hub_download`` raises (network, auth, missing
            file) propagates unchanged.
    """
    global _LEARN

    if _LEARN is not None:
        return _LEARN

    # load_learner comes from the same fastai package this module already imports.
    from fastai.vision.all import load_learner

    pkl_path = hf_hub_download(repo_id=REPO_ID, filename="model.pkl")

    try:
        # load_learner is fastai's loader for `Learner.export()` pickles. With
        # cpu=True it maps tensors to the CPU *and* moves the DataLoaders there.
        # A bare torch.load(map_location="cpu") leaves the DataLoaders pointing
        # at the export-time device (often CUDA).
        # NOTE(review): recent fastai releases also pass weights_only=False,
        # which torch>=2.6 needs to unpickle a full Learner — verify for the
        # pinned fastai version.
        learn = load_learner(pkl_path, cpu=True)
    except Exception as e:
        raise RuntimeError(
            "Failed to load 'model.pkl'. This is usually caused by missing custom "
            "classes/functions used when exporting the Learner (e.g., custom ItemTransform) "
            "or missing dependencies (e.g., albumentations).\n\n"
            f"Original error: {type(e).__name__}: {e}"
        ) from e

    # Keep the model and the DataLoaders that build its input batches on the
    # same device; moving only the model leads to a device mismatch in predict().
    learn.dls.to(_device)
    learn.model.to(_device)
    learn.model.eval()

    _LEARN = learn
    return _LEARN
|
|
|
|
def _ensure_pil_rgb(x) -> Image.Image:
    """Return ``x`` as a 3-channel RGB PIL image.

    Accepts a NumPy array (what ``gr.Image(type="numpy")`` delivers), a PIL
    image, or a filesystem path (``str`` / ``os.PathLike``).

    Raises:
        ValueError: if ``x`` is None, i.e. no image was provided.
    """
    import os

    if x is None:
        raise ValueError("No input image was provided.")
    if isinstance(x, np.ndarray):
        # fromarray infers the mode (e.g. L, RGB, RGBA); convert normalises to RGB.
        return Image.fromarray(x).convert("RGB")
    if isinstance(x, (str, os.PathLike)):
        # convert() returns an independent in-memory copy, so the file can be closed.
        with Image.open(x) as im:
            return im.convert("RGB")
    return x.convert("RGB")
|
|
|
|
def _build_palette(n_classes: int, seed: int = 0) -> np.ndarray:
    """Return a deterministic ``(max(n_classes, 1), 3)`` uint8 RGB palette.

    Each channel is drawn uniformly from 0-255 by a NumPy Generator seeded with
    ``seed``. For the same ``n_classes`` and ``seed`` the result is identical on
    every call. Row 0 (class 0) is forced to black.

    Args:
        n_classes: number of class colours wanted; values below 1 still yield
            a single (black) row.
        seed: RNG seed. The default 0 reproduces the app's original colours.
    """
    rng = np.random.default_rng(seed)
    palette = rng.integers(0, 256, size=(max(n_classes, 1), 3), dtype=np.uint8)
    # max(n_classes, 1) guarantees at least one row, so row 0 always exists.
    palette[0] = 0
    return palette
|
|
|
|
def _colorize_mask(mask: np.ndarray, palette: np.ndarray) -> np.ndarray:
    """Map every class index in ``mask`` to its RGB colour from ``palette``.

    Args:
        mask: integer array of class indices, any shape (the app passes H x W).
        palette: ``(K, 3)`` array; row ``i`` is the colour of class ``i``.

    Returns:
        uint8 array of shape ``mask.shape + (3,)``. Indices outside
        ``[0, K - 1]`` are clipped to the nearest valid palette row.

    Raises:
        ValueError: if ``palette`` has no rows.
    """
    n_colors = len(palette)
    if n_colors == 0:
        raise ValueError("palette must contain at least one colour")
    # Clip so out-of-range predictions cannot raise IndexError.
    idx = np.clip(mask, 0, n_colors - 1)
    # Fancy indexing already allocates the (..., 3) result; no pre-filled buffer needed.
    return palette[idx].astype(np.uint8, copy=False)
|
|
|
|
def _resize_mask_nearest(mask: np.ndarray, height: int, width: int) -> np.ndarray:
    """Resample a 2-D label mask to ``(height, width)`` by nearest-neighbour lookup.

    Each output pixel copies the source pixel under its centre, so labels are
    never blended: every value comes from ``mask`` (except the all-zero
    fallback used when either side is empty).
    """
    src_h, src_w = mask.shape
    if (src_h, src_w) == (height, width):
        return mask
    if mask.size == 0 or height == 0 or width == 0:
        # Nothing to sample from (or nothing to fill); return class 0 everywhere.
        return np.zeros((height, width), dtype=mask.dtype)
    rows = ((np.arange(height) + 0.5) * (src_h / height)).astype(np.intp)
    cols = ((np.arange(width) + 0.5) * (src_w / width)).astype(np.intp)
    # Guard against float rounding pushing the last index one past the edge.
    rows = np.minimum(rows, src_h - 1)
    cols = np.minimum(cols, src_w - 1)
    return mask[np.ix_(rows, cols)]


def _num_classes(learn, mask: np.ndarray) -> int:
    """Palette size: enough colours for the Learner's vocab and every predicted index."""
    try:
        n_vocab = len(learn.dls.vocab)
    except Exception:
        # Deliberate best effort: an exported Learner may lack a (sized) vocab;
        # the mask-based count below is the fallback.
        n_vocab = 0
    n_seen = int(mask.max()) + 1 if mask.size else 1
    return max(n_vocab, n_seen)


def segment_image(input_image, overlay_alpha=0.5, output_mode="overlay"):
    """Run semantic segmentation and return either an overlay or the coloured mask.

    Args:
        input_image: the uploaded image (a NumPy array from
            ``gr.Image(type="numpy")``, or a PIL image); converted to RGB first.
        overlay_alpha: weight of the mask colours in the blend, clamped to [0, 1].
        output_mode: ``"mask"`` returns only the coloured mask; any other value
            returns the image/mask overlay.

    Returns:
        A PIL RGB image with the same width and height as the input image.

    Raises:
        gr.Error: if no image was provided or the model cannot be loaded.
    """
    if input_image is None:
        raise gr.Error("Please upload an image before running segmentation.")

    try:
        learn = _load_learner_from_hub()
    except Exception as e:
        raise gr.Error(str(e)) from e

    img = _ensure_pil_rgb(input_image)

    with torch.no_grad():
        _, pred, _ = learn.predict(img)

    mask = pred.detach().cpu().numpy().astype(np.int64)

    # Size the palette from the raw prediction so colours don't depend on resampling.
    palette = _build_palette(_num_classes(learn, mask))

    # predict() can return the mask at the model's working resolution (e.g. when
    # the DataBlock resizes items), which need not match the upload. Resample it
    # so the overlay blend has matching shapes and the mask lines up with the input.
    width, height = img.size
    mask = _resize_mask_nearest(mask, height, width)

    mask_rgb = _colorize_mask(mask, palette)

    if output_mode == "mask":
        return Image.fromarray(mask_rgb)

    img_np = np.array(img, dtype=np.uint8)
    a = min(max(float(overlay_alpha), 0.0), 1.0)
    blended = (img_np * (1 - a) + mask_rgb * a).astype(np.uint8)
    return Image.fromarray(blended)
|
|
|
|
import os

with gr.Blocks() as demo:
    gr.Markdown(
        f"# Semantic Segmentation\nModel: `{REPO_ID}`\n\nUpload an image and get a segmentation mask (or an overlay)."
    )

    with gr.Row():
        inp = gr.Image(type="numpy", label="Input image")
        out = gr.Image(type="pil", label="Output")

    with gr.Row():
        overlay_alpha = gr.Slider(0, 1, step=0.05, value=0.5, label="Overlay alpha")
        output_mode = gr.Radio(["overlay", "mask"], value="overlay", label="Output mode")

    btn = gr.Button("Run segmentation")

    btn.click(
        fn=segment_image,
        inputs=[inp, overlay_alpha, output_mode],
        outputs=[out],
    )

    # Register the examples only when the bundled image is actually present, so a
    # missing example file cannot break app startup.
    if os.path.isfile(EXAMPLE_IMAGE):
        gr.Examples(
            examples=[[EXAMPLE_IMAGE, 0.5, "overlay"], [EXAMPLE_IMAGE, 0.5, "mask"]],
            inputs=[inp, overlay_alpha, output_mode],
        )
|
|
|
|
def _main() -> None:
    """Start the Gradio server for ``demo`` with its default launch settings."""
    demo.launch()


if __name__ == "__main__":
    _main()
|
|