"""Gradio demo for a fastai semantic-segmentation Learner hosted on the Hugging Face Hub.

The exported Learner (``model.pkl``) is downloaded and unpickled lazily on the
first request, run on the uploaded image, and the predicted class mask is
returned either as a colour-coded mask or alpha-blended over the input image.
"""

import inspect
import sys
from pathlib import Path

import gradio as gr
import numpy as np
import torch
from PIL import Image

# Hugging Face Hub helpers
from huggingface_hub import hf_hub_download

# FastAI types (needed if your exported Learner references custom transforms)
from fastai.vision.all import ItemTransform, PILImage, PILMask

try:
    import albumentations as A  # noqa: F401
except Exception:
    # Deliberately broad: albumentations can fail to import for reasons other than
    # ImportError (e.g. incompatible cv2/pydantic). If your learner was exported with
    # Albumentations objects inside, you need this installed.
    A = None

# -----------------------------
# Config
# -----------------------------
REPO_ID = "magomerob/resnet-segmentation"
EXAMPLE_IMAGE = "bicho.jpg"

_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


class SegmentationAlbumentationsTransform(ItemTransform):
    """Defined to satisfy pickles exported from notebooks (__main__).

    If your exported Learner includes this transform in its pipeline, it must
    exist (same name) when loading `model.pkl`.
    """

    # fastai applies a transform only to the split whose index matches split_idx
    # (0 = training set), so this augmentation is skipped at inference time.
    split_idx = 0

    def __init__(self, aug):
        self.aug = aug

    def encodes(self, x):
        img, mask = x
        aug = self.aug(image=np.array(img), mask=np.array(mask))
        return PILImage.create(aug["image"]), PILMask.create(aug["mask"])


# Lazily-loaded Learner, cached for the lifetime of the process.
_LEARN = None


def _register_pickled_classes_in_main() -> None:
    """Expose custom classes under ``__main__`` so notebook pickles can find them.

    When this file is run as ``python app.py`` it already *is* ``__main__`` and this
    is a no-op. When it is imported instead (e.g. ``gradio app.py`` reload mode),
    the unpickler would otherwise fail to resolve
    ``__main__.SegmentationAlbumentationsTransform``.
    """
    main_module = sys.modules.get("__main__")
    if main_module is not None and not hasattr(
        main_module, "SegmentationAlbumentationsTransform"
    ):
        setattr(
            main_module,
            "SegmentationAlbumentationsTransform",
            SegmentationAlbumentationsTransform,
        )


def _torch_load_pickle(path):
    """Unpickle a full Python object (not just a state dict) onto the CPU.

    PyTorch 2.6 changed ``torch.load``'s ``weights_only`` default to True, which
    refuses arbitrary pickled objects such as a fastai Learner. Request the full
    unpickler explicitly when the installed torch accepts the argument (older
    versions without it already behave that way).

    SECURITY: unpickling executes code embedded in the file. The file comes from
    the Hub repo ``REPO_ID``; only point it at a repository you trust.
    """
    kwargs = {"map_location": "cpu"}
    if "weights_only" in inspect.signature(torch.load).parameters:
        kwargs["weights_only"] = False
    return torch.load(path, **kwargs)


def _load_learner_from_hub():
    """Load fastai Learner from Hub with a clearer error if unpickling fails.

    Returns:
        The cached Learner, with its model in eval mode on ``_device``.

    Raises:
        RuntimeError: if ``model.pkl`` cannot be unpickled (original error chained).
    """
    global _LEARN
    if _LEARN is not None:
        return _LEARN

    pkl_path = hf_hub_download(repo_id=REPO_ID, filename="model.pkl")
    _register_pickled_classes_in_main()
    try:
        # torch.load will surface the *real* unpickling error (missing class/import, etc.)
        learn = _torch_load_pickle(pkl_path)
    except Exception as e:
        raise RuntimeError(
            "Failed to load 'model.pkl'. This is usually caused by missing custom "
            "classes/functions used when exporting the Learner (e.g., custom ItemTransform) "
            "or missing dependencies (e.g., albumentations).\n\n"
            f"Original error: {type(e).__name__}: {e}"
        ) from e

    # Move to device and eval
    learn.model.to(_device)
    learn.model.eval()
    # learn.predict builds its batches through learn.dls, so the data pipeline must
    # target the same device as the model. Otherwise we get a CPU/CUDA mismatch,
    # or a CUDA request on a CPU-only host if the Learner was exported on a GPU.
    dls = getattr(learn, "dls", None)
    if dls is not None and hasattr(dls, "to"):
        dls.to(_device)

    _LEARN = learn
    return _LEARN


def _ensure_pil_rgb(x) -> Image.Image:
    """Return *x* (a numpy array or PIL image) as an RGB PIL image."""
    if isinstance(x, np.ndarray):
        return Image.fromarray(x).convert("RGB")
    return x.convert("RGB")


def _build_palette(n_classes: int) -> np.ndarray:
    """Deterministic bright-ish palette, class 0 = black.

    Returns:
        A ``(max(n_classes, 1), 3)`` uint8 array; row ``i`` is the colour of class ``i``.
    """
    rng = np.random.default_rng(0)
    pal = rng.integers(0, 256, size=(max(n_classes, 1), 3), dtype=np.uint8)
    pal[0] = 0  # background / class 0 is always black
    return pal


def _prediction_to_mask(pred) -> np.ndarray:
    """Convert a fastai prediction (TensorMask / Tensor / array) to a 2-D int64 mask.

    Leading singleton dimensions (e.g. ``(1, H, W)``) are dropped.

    Raises:
        gr.Error: if the prediction is not a 2-D class map after squeezing.
    """
    if isinstance(pred, torch.Tensor):
        arr = pred.detach().cpu().numpy()
    else:
        arr = np.asarray(pred)
    arr = arr.astype(np.int64)
    while arr.ndim > 2 and arr.shape[0] == 1:
        arr = arr[0]
    if arr.ndim != 2:
        raise gr.Error(f"Expected a 2-D class mask from the model, got shape {arr.shape}.")
    return arr


def _resize_mask_nearest(mask: np.ndarray, height: int, width: int) -> np.ndarray:
    """Nearest-neighbour resize of a 2-D label mask to ``(height, width)``.

    Nearest sampling keeps every output value a valid class id. Interpolating
    would invent ids that were never predicted.
    """
    src_h, src_w = mask.shape
    if (src_h, src_w) == (height, width) or mask.size == 0:
        return mask
    # Sample the source pixel whose area contains each destination pixel centre.
    rows = np.minimum(((np.arange(height) + 0.5) * src_h / height).astype(np.intp), src_h - 1)
    cols = np.minimum(((np.arange(width) + 0.5) * src_w / width).astype(np.intp), src_w - 1)
    return mask[rows[:, None], cols[None, :]]


def _colorize_mask(mask: np.ndarray, palette: np.ndarray) -> np.ndarray:
    """Map a 2-D class mask to an ``(H, W, 3)`` uint8 RGB image using *palette*.

    Ids outside ``[0, len(palette) - 1]`` are clipped to the nearest valid entry.
    """
    return palette[np.clip(mask, 0, len(palette) - 1)]


def segment_image(input_image, overlay_alpha=0.5, output_mode="overlay"):
    """Run semantic segmentation and return either overlay or raw mask (colored).

    Args:
        input_image: numpy array (as sent by ``gr.Image(type="numpy")``) or PIL image.
        overlay_alpha: weight of the coloured mask in the blend, clamped to [0, 1].
        output_mode: ``"mask"`` returns only the coloured mask; anything else returns
            the overlay.

    Returns:
        A PIL RGB image at the input image's resolution.

    Raises:
        gr.Error: when no image was provided, or the model cannot be loaded or
            returns an unexpected mask shape.
    """
    if input_image is None:
        raise gr.Error("Please upload an image first.")

    try:
        learn = _load_learner_from_hub()
    except Exception as e:
        raise gr.Error(str(e)) from e

    img = _ensure_pil_rgb(input_image)

    # FastAI predict returns (pred_class, pred, probs)
    # For segmentation, pred is usually a TensorMask (H, W)
    with torch.no_grad():
        _, pred, _ = learn.predict(img)

    mask = _prediction_to_mask(pred)
    # The mask may come back at the Learner's item-transform resolution (e.g. after a
    # `Resize`) rather than the input's. Bring it back so it can be overlaid.
    # NOTE(review): a crop-based resize will not align perfectly; this is a stretch.
    mask = _resize_mask_nearest(mask, img.height, img.width)

    # Enough colours for every predicted id, and for the full vocab when it is known.
    n_classes = int(mask.max()) + 1 if mask.size else 1
    try:
        n_classes = max(n_classes, len(learn.dls.vocab))
    except (AttributeError, IndexError, KeyError, TypeError):
        pass  # best effort: vocab unavailable, use the ids actually predicted

    palette = _build_palette(n_classes)
    mask_rgb = _colorize_mask(mask, palette)

    if output_mode == "mask":
        return Image.fromarray(mask_rgb)

    # Overlay
    img_np = np.asarray(img, dtype=np.uint8)
    alpha = min(max(float(overlay_alpha), 0.0), 1.0)
    blended = (img_np * (1.0 - alpha) + mask_rgb * alpha).astype(np.uint8)
    return Image.fromarray(blended)


with gr.Blocks() as demo:
    gr.Markdown(
        f"# Semantic Segmentation\nModel: `{REPO_ID}`\n\nUpload an image and get a segmentation mask (or an overlay)."
    )
    with gr.Row():
        inp = gr.Image(type="numpy", label="Input image")
        out = gr.Image(type="pil", label="Output")
    with gr.Row():
        overlay_alpha = gr.Slider(0, 1, step=0.05, value=0.5, label="Overlay alpha")
        output_mode = gr.Radio(["overlay", "mask"], value="overlay", label="Output mode")
    btn = gr.Button("Run segmentation")
    btn.click(
        fn=segment_image,
        inputs=[inp, overlay_alpha, output_mode],
        outputs=[out],
    )
    # Only offer examples when the sample file ships with the app, so a missing file
    # does not break startup.
    if Path(EXAMPLE_IMAGE).is_file():
        gr.Examples(
            examples=[[EXAMPLE_IMAGE, 0.5, "overlay"], [EXAMPLE_IMAGE, 0.5, "mask"]],
            inputs=[inp, overlay_alpha, output_mode],
        )


if __name__ == "__main__":
    demo.launch()