Practica3 / app.py
magomerob's picture
Upload 2 files
9368dc8 verified
import gradio as gr
import numpy as np
from PIL import Image
import torch
# Hugging Face Hub helpers
from huggingface_hub import hf_hub_download
# FastAI types (needed if your exported Learner references custom transforms)
from fastai.vision.all import ItemTransform, PILImage, PILMask
try:
import albumentations as A # noqa: F401
except Exception:
# If your learner was exported with Albumentations objects inside, you need this installed.
A = None
# -----------------------------
# Config
# -----------------------------
# Hub repository hosting the exported fastai Learner ("model.pkl").
REPO_ID = "magomerob/resnet-segmentation"
# Local image file offered as a clickable example in the UI.
EXAMPLE_IMAGE = "bicho.jpg"
# Use the GPU when torch can see one, otherwise run on the CPU.
_device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
class SegmentationAlbumentationsTransform(ItemTransform):
    """Run an Albumentations pipeline jointly over an (image, mask) pair.

    The class exists mainly so that ``model.pkl`` can be unpickled: the
    exported Learner refers to it by this exact name (it was defined in the
    training notebook's ``__main__``), and pickled instances carry an ``aug``
    attribute. Neither the class name nor that attribute may be renamed.
    """

    # NOTE(review): in fastai, split_idx = 0 restricts a transform to the
    # training split, so presumably this is a no-op at inference — confirm.
    split_idx = 0

    def __init__(self, aug):
        # Albumentations callable invoked as aug(image=..., mask=...).
        self.aug = aug

    def encodes(self, x):
        image, target = x
        result = self.aug(image=np.array(image), mask=np.array(target))
        new_image = PILImage.create(result["image"])
        new_mask = PILMask.create(result["mask"])
        return new_image, new_mask
# Module-level cache so model.pkl is downloaded and unpickled only once.
_LEARN = None


def _load_learner_from_hub():
    """Download ``model.pkl`` from ``REPO_ID`` once and return the cached Learner.

    Returns:
        The unpickled fastai Learner, with its DataLoaders and model placed on
        ``_device`` and the model in eval mode.

    Raises:
        RuntimeError: if the pickle cannot be loaded; the original exception
            is chained as ``__cause__`` and summarised in the message.
    """
    global _LEARN
    if _LEARN is not None:
        return _LEARN

    pkl_path = hf_hub_download(repo_id=REPO_ID, filename="model.pkl")
    try:
        # A pickled Learner is an arbitrary object graph, not a plain
        # state_dict. The restricted ``weights_only`` loader (PyTorch's
        # default since 2.6) rejects it, so full unpickling must be requested
        # explicitly.
        # SECURITY: this runs pickle code downloaded from REPO_ID; only point
        # it at a repository you trust.
        # torch.load also surfaces the *real* unpickling error (missing
        # class/import, etc.) for the message below.
        learn = torch.load(pkl_path, map_location="cpu", weights_only=False)
    except Exception as e:
        raise RuntimeError(
            "Failed to load 'model.pkl'. This is usually caused by missing custom "
            "classes/functions used when exporting the Learner (e.g., custom ItemTransform) "
            "or missing dependencies (e.g., albumentations).\n\n"
            f"Original error: {type(e).__name__}: {e}"
        ) from e

    # map_location only remaps tensors. The DataLoaders keep the device they
    # were exported with, and fastai moves the model to dls.device when
    # predict() starts. Keep both on _device so a GPU-exported Learner does
    # not try to use CUDA on a CPU-only host (and vice versa).
    learn.dls.to(_device)
    learn.model.to(_device)
    learn.model.eval()
    _LEARN = learn
    return _LEARN
def _ensure_pil_rgb(x) -> Image.Image:
    """Return *x* as a 3-channel RGB PIL image.

    Args:
        x: a numpy array (what ``gr.Image(type="numpy")`` delivers) or any
            object with a PIL-style ``convert`` method.

    Raises:
        ValueError: if *x* is None, e.g. the form was submitted with an empty
            image box.
    """
    if x is None:
        # Without this guard the failure is an opaque AttributeError on None.convert.
        raise ValueError("No input image was provided.")
    if isinstance(x, np.ndarray):
        x = Image.fromarray(x)
    return x.convert("RGB")
def _build_palette(n_classes: int) -> np.ndarray:
    """Return a reproducible uint8 RGB palette of shape (max(n_classes, 1), 3).

    Colours come from a generator seeded with 0, so a given class index always
    maps to the same colour across runs. They are uniformly random (not
    guaranteed bright). Row 0, the background class, is forced to black.
    """
    n_rows = max(n_classes, 1)
    colors = np.random.default_rng(0).integers(0, 256, size=(n_rows, 3), dtype=np.uint8)
    colors[0] = 0  # n_rows >= 1, so row 0 always exists
    return colors
def _colorize_mask(mask: np.ndarray, palette: np.ndarray) -> np.ndarray:
    """Map a class-index mask to RGB colours using *palette*.

    Args:
        mask: integer array of class indices of any shape (typically (H, W)).
            Indices below 0 or above ``len(palette) - 1`` are clamped to the
            first / last palette entry.
        palette: array of shape (n_colors, 3); row ``i`` is the colour of class ``i``.

    Returns:
        uint8 array of shape ``mask.shape + (3,)``.

    Raises:
        ValueError: if *palette* has no rows.
    """
    if len(palette) == 0:
        raise ValueError("palette must contain at least one colour")
    # Clamp instead of failing so unexpected indices still render.
    idx = np.clip(mask, 0, len(palette) - 1)
    # Fancy indexing already allocates a fresh array; no pre-filled buffer needed.
    return palette[idx].astype(np.uint8, copy=False)
def segment_image(input_image, overlay_alpha=0.5, output_mode="overlay"):
    """Run semantic segmentation and return either overlay or raw mask (colored).

    Args:
        input_image: image from the Gradio input (numpy array or PIL image);
            None when nothing was uploaded.
        overlay_alpha: mask opacity used in overlay mode, clamped to [0, 1].
        output_mode: "mask" returns only the coloured mask; any other value
            blends the mask over the input image.

    Returns:
        A PIL RGB image with the same size as the input image.

    Raises:
        gr.Error: when no image was given or the model could not be loaded
            (Gradio shows the message to the user).
    """
    if input_image is None:
        raise gr.Error("Please upload an image first.")
    try:
        learn = _load_learner_from_hub()
    except Exception as e:
        raise gr.Error(str(e)) from e

    img = _ensure_pil_rgb(input_image)

    # FastAI predict returns a 3-tuple; for segmentation the middle item is
    # expected to be the per-pixel class-index mask (TensorMask / Tensor).
    with torch.no_grad():
        _, pred, _ = learn.predict(img)
    mask = pred.detach().cpu().numpy().astype(np.int64)

    # Best-effort: take the class count from the training vocab (codes) so
    # colours are stable across images; otherwise infer it from the mask.
    try:
        n_classes = len(learn.dls.vocab)
    except Exception:
        n_classes = None
    if not n_classes:
        n_classes = int(mask.max()) + 1 if mask.size else 1

    mask_img = Image.fromarray(_colorize_mask(mask, _build_palette(n_classes)))

    # The mask comes out at the model's input resolution, e.g. after a Resize
    # item transform, and that can differ from the uploaded image. Left as-is,
    # the overlay blend below would fail to broadcast. Nearest-neighbour
    # resampling keeps every pixel an exact palette colour.
    # NOTE(review): if training used a cropping Resize, the mask covers only
    # the crop and this stretch is approximate.
    if mask_img.size != img.size:
        nearest = getattr(Image, "Resampling", Image).NEAREST
        mask_img = mask_img.resize(img.size, nearest)

    if output_mode == "mask":
        return mask_img

    a = min(max(float(overlay_alpha), 0.0), 1.0)
    img_np = np.asarray(img, dtype=np.uint8)
    mask_np = np.asarray(mask_img, dtype=np.uint8)
    blended = (img_np * (1 - a) + mask_np * a).astype(np.uint8)
    return Image.fromarray(blended)
# Gradio UI: input/output images side by side, controls underneath, and a
# button that runs segment_image with the current control values.
with gr.Blocks() as demo:
    gr.Markdown(
        f"# Semantic Segmentation\nModel: `{REPO_ID}`\n\nUpload an image and get a segmentation mask (or an overlay)."
    )
    with gr.Row():
        inp = gr.Image(type="numpy", label="Input image")
        out = gr.Image(type="pil", label="Output")
    with gr.Row():
        overlay_alpha = gr.Slider(0, 1, step=0.05, value=0.5, label="Overlay alpha")
        output_mode = gr.Radio(["overlay", "mask"], value="overlay", label="Output mode")
    btn = gr.Button("Run segmentation")
    btn.click(
        segment_image,
        inputs=[inp, overlay_alpha, output_mode],
        outputs=[out],
    )
    # One example row per output mode, both using the bundled sample image.
    gr.Examples(
        examples=[[EXAMPLE_IMAGE, 0.5, mode] for mode in ("overlay", "mask")],
        inputs=[inp, overlay_alpha, output_mode],
    )

if __name__ == "__main__":
    demo.launch()